From cef7ab442975471dd886eff61d6890094ca5e701 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Mon, 30 Sep 2024 00:55:34 +0800 Subject: [PATCH] add scan points check --- configs/local/view_generate_config.yaml | 6 --- core/global_pts_n_num_pipeline.py | 10 ++--- utils/reconstruction.py | 55 +++++++++++++++++++++---- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/configs/local/view_generate_config.yaml b/configs/local/view_generate_config.yaml index 624adb7..4112baa 100644 --- a/configs/local/view_generate_config.yaml +++ b/configs/local/view_generate_config.yaml @@ -24,12 +24,6 @@ runner: max_height: 0.15 min_radius: 0.3 max_radius: 0.5 - min_R: 0.05 - max_R: 0.3 - min_G: 0.05 - max_G: 0.3 - min_B: 0.05 - max_B: 0.3 display_object: min_x: 0 max_x: 0.03 diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py index 948b047..3970bd3 100644 --- a/core/global_pts_n_num_pipeline.py +++ b/core/global_pts_n_num_pipeline.py @@ -76,16 +76,16 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): device = next(self.parameters()).device - pose_feat_seq_list = [] - pts_num_feat_seq_list = [] + embedding_list_batch = [] for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch): scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) scanned_target_pts_num = scanned_target_pts_num.to(device) - pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)) - pts_num_feat_seq_list.append(self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)) + pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) + pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) + embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1)) - main_feat = self.pose_n_num_seq_encoder.encode_sequence(pts_num_feat_seq_list, pose_feat_seq_list) + main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) combined_scanned_pts_batch = data['combined_scanned_pts'] diff --git a/utils/reconstruction.py b/utils/reconstruction.py index 7b9d5cf..0b78e59 100644 --- a/utils/reconstruction.py +++ b/utils/reconstruction.py @@ -45,9 +45,10 @@ class ReconstructionUtil: @staticmethod - def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list,threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None): + def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None): selected_views = [point_cloud_list[init_view]] combined_point_cloud = np.vstack(selected_views) + combined_scan_points_indices = scan_points_indices_list[init_view] down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold) new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold) current_coverage = new_coverage @@ -63,12 +64,14 @@ class ReconstructionUtil: for view_index in remaining_views: if selected_views: - combined_old_point_cloud = np.vstack(selected_views) - down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold) - down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold) - overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold) - if overlap_rate < overlap_threshold: - continue + new_scan_points_indices = scan_points_indices_list[view_index] + if not ReconstructionUtil.check_scan_points_overlap(combined_scan_points_indices, new_scan_points_indices): + combined_old_point_cloud = np.vstack(selected_views) + down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold) + down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold) + overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold) + if overlap_rate < overlap_threshold: + continue candidate_views = selected_views + [point_cloud_list[view_index]] combined_point_cloud = np.vstack(candidate_views) @@ -85,6 +88,7 @@ class ReconstructionUtil: break selected_views.append(point_cloud_list[best_view]) remaining_views.remove(best_view) + combined_scan_points_indices = ReconstructionUtil.combine_scan_points_indices(combined_scan_points_indices, scan_points_indices_list[best_view]) current_coverage += best_coverage_increase cnt_processed_view += 1 if status_info is not None: @@ -120,4 +124,39 @@ class ReconstructionUtil: filtered_sampled_points= sampled_points[cos_theta > np.cos(theta_rad)] return filtered_sampled_points[:, :3] - \ No newline at end of file + + @staticmethod + def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 100, max_attempts = 1000): + points = [] + attempts = 0 + while len(points) < max_points_num and attempts < max_attempts: + angle = np.random.uniform(0, 2 * np.pi) + r = np.random.uniform(0, display_table_radius) + x = r * np.cos(angle) + y = r * np.sin(angle) + z = display_table_top + new_point = (x, y, z) + if all(np.linalg.norm(np.array(new_point) - np.array(existing_point)) >= min_distance for existing_point in points): + points.append(new_point) + attempts += 1 + return points + + @staticmethod + def compute_covered_scan_points(scan_points, point_cloud, threshold=0.01): + tree = cKDTree(point_cloud) + covered_points = [] + indices = [] + for i, scan_point in enumerate(scan_points): + if tree.query_ball_point(scan_point, threshold): + covered_points.append(scan_point) + indices.append(i) + return covered_points, indices + + @staticmethod + def check_scan_points_overlap(indices1, indices2, threshold=5): + return len(set(indices1).intersection(set(indices2))) > threshold + + @staticmethod + def combine_scan_points_indices(indices1, indices2): + combined_indices = set(indices1) | set(indices2) + return sorted(combined_indices) \ No newline at end of file