From 493639287e0ba297d56f99cba1c4b9af2d99d5d2 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Thu, 7 Nov 2024 19:42:44 +0800 Subject: [PATCH] update calculating pts_num in inference.py --- configs/local/view_generate_config.yaml | 8 ++++---- runners/inferencer.py | 18 +++++++----------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/configs/local/view_generate_config.yaml b/configs/local/view_generate_config.yaml index 803bae8..4396de7 100644 --- a/configs/local/view_generate_config.yaml +++ b/configs/local/view_generate_config.yaml @@ -8,11 +8,11 @@ runner: root_dir: experiments generate: port: 5002 - from: 0 - to: -1 # -1 means all - object_dir: C:\\Document\\Datasets\\ball_meshes + from: 1 + to: 50 # -1 means all + object_dir: C:\\Document\\Datasets\\scaled_object_meshes table_model_path: C:\\Document\\Datasets\\table.obj - output_dir: C:\\Document\\Datasets\\debug_ball_generate_view + output_dir: C:\\Document\\Datasets\\debug_generate_view binocular_vision: true plane_size: 10 max_views: 512 diff --git a/runners/inferencer.py b/runners/inferencer.py index 0ba8a4f..238da0b 100644 --- a/runners/inferencer.py +++ b/runners/inferencer.py @@ -128,7 +128,7 @@ class Inferencer(Runner): retry = 0 pred_cr_seq = [last_pred_cr] success = 0 - last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], 0.002) + last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], 0.002).shape[0] import time while len(pred_cr_seq) < max_iter and retry < max_retry: start_time = time.time() @@ -151,7 +151,7 @@ class Inferencer(Runner): curr_overlap_area_threshold = overlap_area_threshold * 0.5 downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold) - overlap, new_added_pts_num = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True) + overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True) if not overlap: retry += 1 retry_overlap_pose.append(pred_pose.cpu().numpy().tolist()) @@ -175,27 +175,22 @@ class Inferencer(Runner): continue start_time = time.time() - pred_cr, covered_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold) + pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold) end_time = time.time() print(f"Time taken for coverage rate computation: {end_time - start_time} seconds") print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"]) if pred_cr >= data["seq_max_coverage_rate"] - 1e-3: print("max coverage rate reached!: ", pred_cr) success += 1 - if pred_cr <= last_pred_cr + cr_increase_threshold: - retry += 1 - retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) - continue + retry = 0 pred_cr_seq.append(pred_cr) scanned_view_pts.append(new_target_pts) - down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N) - new_pts = down_sampled_new_pts_world input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)] - combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), new_pts], axis=0) + combined_scanned_pts = np.vstack(scanned_view_pts) voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002) random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N) input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device) @@ -204,8 +199,9 @@ class Inferencer(Runner): break last_pred_cr = pred_cr pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0] - if pts_num - last_pts_num < 10: + if pts_num - last_pts_num < 10 and pred_cr < data["seq_max_coverage_rate"] - 1e-3: retry += 1 + retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) print("delta pts num < 10:", pts_num, last_pts_num) last_pts_num = pts_num