From bb9b3f81c3dc8f545a47cbad903bb394467db689 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sat, 5 Oct 2024 15:10:31 -0500 Subject: [PATCH 1/7] update reconstruction --- configs/local/strategy_generate_config.yaml | 4 +- preprocess/preprocessor.py | 29 ++----- runners/strategy_generator.py | 27 +++++-- utils/reconstruction.py | 83 ++++++++++----------- 4 files changed, 68 insertions(+), 75 deletions(-) diff --git a/configs/local/strategy_generate_config.yaml b/configs/local/strategy_generate_config.yaml index 3e4576d..e9601f4 100644 --- a/configs/local/strategy_generate_config.yaml +++ b/configs/local/strategy_generate_config.yaml @@ -28,8 +28,8 @@ runner: datasets: OmniObject3d: #"/media/hofee/data/data/temp_output" - root_dir: "/media/hofee/repository/new_full_box_data" - model_dir: "/media/hofee/data/data/scaled_object_meshes" + root_dir: "C:\\Document\\Local Project\\nbv_rec\\nbv_reconstruction\\test\\test_sample" + model_dir: "H:\\AI\\Datasets\\scaled_object_meshes" from: 0 to: -1 # -1 means end #output_dir: "/media/hofee/data/data/label_output" diff --git a/preprocess/preprocessor.py b/preprocess/preprocessor.py index e3064be..e7fd616 100644 --- a/preprocess/preprocessor.py +++ b/preprocess/preprocessor.py @@ -30,22 +30,6 @@ def save_scan_points_indices(root, scene, frame_idx, scan_points_indices: np.nda def save_scan_points(root, scene, scan_points: np.ndarray): scan_points_path = os.path.join(root,scene, "scan_points.txt") save_np_pts(scan_points_path, scan_points) - - -def old_get_world_points(depth, cam_intrinsic, cam_extrinsic): - h, w = depth.shape - i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") - # ----- Debug Trace ----- # - import ipdb; ipdb.set_trace() - # ------------------------ # - z = depth - x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] - y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] - points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) - points_camera_aug = np.concatenate((points_camera, np.ones((points_camera.shape[0], 1))), axis=-1) - points_camera_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] - - return points_camera_world def get_world_points(depth, mask, cam_intrinsic, cam_extrinsic): z = depth[mask] @@ -74,7 +58,7 @@ def get_scan_points_indices(scan_points, mask, display_table_mask_label, cam_int return selected_points_indices -def save_scene_data(root, scene, scene_idx=0, scene_total=1): +def save_scene_data(root, scene, scene_idx=0, scene_total=1,file_type="txt"): ''' configuration ''' target_mask_label = (0, 255, 0, 255) @@ -128,8 +112,9 @@ def save_scene_data(root, scene, scene_idx=0, scene_total=1): sampled_target_points_L, sampled_target_points_R, voxel_size ) - - has_points = target_points.shape[0] > 0 + if has_points: + has_points = target_points.shape[0] > 0 + if has_points: points_normals = DataLoadUtil.load_points_normals(root, scene, display_table_as_world_space_origin=True) target_points = PtsUtil.filter_points( @@ -145,8 +130,8 @@ def save_scene_data(root, scene, scene_idx=0, scene_total=1): if not has_points: target_points = np.zeros((0, 3)) - save_target_points(root, scene, frame_id, target_points) - save_scan_points_indices(root, scene, frame_id, scan_points_indices) + save_target_points(root, scene, frame_id, target_points, file_type=file_type) + save_scan_points_indices(root, scene, frame_id, scan_points_indices, file_type=file_type) save_scan_points(root, scene, scan_points) # The "done" flag of scene preprocess @@ -168,7 +153,7 @@ if __name__ == "__main__": total = to_idx - from_idx for scene in scene_list[from_idx:to_idx]: start = time.time() - save_scene_data(root, scene, cnt, total) + save_scene_data(root, scene, cnt, total, file_type="npy") cnt+=1 end = time.time() print(f"Time cost: {end-start}") \ No newline at end of file diff --git a/runners/strategy_generator.py b/runners/strategy_generator.py index 62ddffd..a9872a2 100644 --- a/runners/strategy_generator.py +++ b/runners/strategy_generator.py @@ -84,27 +84,38 @@ class StrategyGenerator(Runner): pts_list = [] scan_points_indices_list = [] non_zero_cnt = 0 + for frame_idx in range(frame_num): status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num) - pts_path = os.path.join(root,scene_name, "target_pts", f"{frame_idx}.txt") - sampled_point_cloud = np.loadtxt(pts_path) - indices = None # ReconstructionUtil.compute_covered_scan_points(scan_points, display_table_pts) + pts_path = os.path.join(root,scene_name, "pts", f"{frame_idx}.npy") + idx_path = os.path.join(root,scene_name, "scan_points_indices", f"{frame_idx}.npy") + point_cloud = np.load(pts_path) + sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, voxel_threshold) + indices = np.load(idx_path) pts_list.append(sampled_point_cloud) + scan_points_indices_list.append(indices) + if sampled_point_cloud.shape[0] > 0: + non_zero_cnt += 1 status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num) - + seq_num = min(self.seq_num, non_zero_cnt) init_view_list = [] - for i in range(seq_num): - if pts_list[i].shape[0] < 100: - continue - init_view_list.append(i) + idx = 0 + while len(init_view_list) < seq_num: + if pts_list[idx].shape[0] > 100: + init_view_list.append(idx) + idx += 1 seq_idx = 0 + import time for init_view in init_view_list: status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", seq_idx, len(init_view_list)) + start = time.time() limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_model_pts, pts_list, scan_points_indices_list = scan_points_indices_list,init_view=init_view, threshold=voxel_threshold, soft_overlap_threshold=soft_overlap_threshold, hard_overlap_threshold= hard_overlap_threshold, scan_points_threshold=10, status_info=self.status_info) + end = time.time() + print(f"Time: {end-start}") data_pairs = self.generate_data_pairs(limited_useful_view) seq_save_data = { "data_pairs": data_pairs, diff --git a/utils/reconstruction.py b/utils/reconstruction.py index 530358d..376c056 100644 --- a/utils/reconstruction.py +++ b/utils/reconstruction.py @@ -22,29 +22,7 @@ class ReconstructionUtil: else: overlap_rate = overlapping_points / new_point_cloud.shape[0] return overlap_rate - - @staticmethod - def combine_point_with_view_sequence(point_list, view_sequence): - selected_views = [] - for view_index, _ in view_sequence: - selected_views.append(point_list[view_index]) - return np.vstack(selected_views) - - @staticmethod - def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01): - best_view = None - best_coverage_increase = -1 - current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold) - - for view_index, view in enumerate(views): - candidate_views = combined_point_cloud + [view] - down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(candidate_views, threshold) - new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold) - coverage_increase = new_coverage - current_coverage - if coverage_increase > best_coverage_increase: - best_coverage_increase = coverage_increase - best_view = view_index - return best_view, best_coverage_increase + @staticmethod def get_new_added_points(old_combined_pts, new_pts, threshold=0.005): @@ -60,54 +38,70 @@ class ReconstructionUtil: @staticmethod def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, soft_overlap_threshold=0.5, hard_overlap_threshold=0.7, init_view = 0, scan_points_threshold=5, status_info=None): - selected_views = [point_cloud_list[init_view]] - combined_point_cloud = np.vstack(selected_views) + selected_views = [init_view] + combined_point_cloud = point_cloud_list[init_view] history_indices = [scan_points_indices_list[init_view]] - down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold) - new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold) + + max_rec_pts = np.vstack(point_cloud_list) + downsampled_max_rec_pts = PtsUtil.voxel_downsample_point_cloud(max_rec_pts, threshold) + + max_rec_pts_num = downsampled_max_rec_pts.shape[0] + max_rec_pts_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, downsampled_max_rec_pts, threshold) + + new_coverage = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, combined_point_cloud, threshold) current_coverage = new_coverage remaining_views = list(range(len(point_cloud_list))) view_sequence = [(init_view, current_coverage)] cnt_processed_view = 0 remaining_views.remove(init_view) - + curr_rec_pts_num = combined_point_cloud.shape[0] + + import time while remaining_views: best_view = None best_coverage_increase = -1 + best_combined_point_cloud = None for view_index in remaining_views: if point_cloud_list[view_index].shape[0] == 0: continue - if selected_views: new_scan_points_indices = scan_points_indices_list[view_index] - if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold): overlap_threshold = hard_overlap_threshold else: overlap_threshold = soft_overlap_threshold - - combined_old_point_cloud = np.vstack(selected_views) - down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold) - down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold) - overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold) + start = time.time() + overlap_rate = ReconstructionUtil.compute_overlap_rate(point_cloud_list[view_index],combined_point_cloud, threshold) + end = time.time() + # print(f"overlap_rate Time: {end-start}") if overlap_rate < overlap_threshold: continue - - candidate_views = selected_views + [point_cloud_list[view_index]] - combined_point_cloud = np.vstack(candidate_views) - down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold) - new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold) + + start = time.time() + new_combined_point_cloud = np.vstack([combined_point_cloud, point_cloud_list[view_index]]) + new_downsampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(new_combined_point_cloud,threshold) + new_coverage = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, threshold) + end = time.time() + #print(f"compute_coverage_rate Time: {end-start}") coverage_increase = new_coverage - current_coverage if coverage_increase > best_coverage_increase: best_coverage_increase = coverage_increase best_view = view_index + best_combined_point_cloud = new_downsampled_combined_point_cloud if best_view is not None: - if best_coverage_increase <=3e-3: + if best_coverage_increase <=1e-3: break - selected_views.append(point_cloud_list[best_view]) + + selected_views.append(best_view) + best_rec_pts_num = best_combined_point_cloud.shape[0] + print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Max rec pts num: {max_rec_pts_num}") + print(f"Current coverage: {current_coverage}, Best coverage increase: {best_coverage_increase}, Max coverage: {max_rec_pts_coverage}") + + curr_rec_pts_num = best_rec_pts_num + combined_point_cloud = best_combined_point_cloud remaining_views.remove(best_view) history_indices.append(scan_points_indices_list[best_view]) current_coverage += best_coverage_increase @@ -123,12 +117,15 @@ class ReconstructionUtil: else: break + # ----- Debug Trace ----- # + import ipdb; ipdb.set_trace() + # ------------------------ # if status_info is not None: sm = status_info["status_manager"] app_name = status_info["app_name"] runner_name = status_info["runner_name"] sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list)) - return view_sequence, remaining_views, down_sampled_combined_point_cloud + return view_sequence, remaining_views, combined_point_cloud @staticmethod From 1a3ae15130f52d6f61ecf192dfa37e8b7f7be6b3 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sat, 5 Oct 2024 15:17:54 -0500 Subject: [PATCH 2/7] update nbv_dataset: scene_points to target_points --- core/nbv_dataset.py | 61 +++++++-------------------------------------- utils/data_load.py | 2 +- 2 files changed, 10 insertions(+), 53 deletions(-) diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 4573bb0..a3c5e1f 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -124,63 +124,20 @@ class NBVReconstructionDataset(BaseDataset): scanned_n_to_world_pose, scanned_target_pts_num, ) = ([], [], [], []) - target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict( - self.root_dir, scene_name - ) for view in scanned_views: frame_idx = view[0] coverage_rate = view[1] view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) - target_pts_num = target_pts_num_dict[frame_idx] - n_to_world_pose = cam_info["cam_to_world"] - nR_to_world_pose = cam_info["cam_to_world_R"] - - if self.load_from_preprocess: - downsampled_target_point_cloud = ( - DataLoadUtil.load_from_preprocessed_pts(view_path) - ) - else: - cached_data = None - if self.cache: - cached_data = self.load_from_cache(scene_name, frame_idx) - - if cached_data is None: - print("load depth") - depth_L, depth_R = DataLoadUtil.load_depth( - view_path, - cam_info["near_plane"], - cam_info["far_plane"], - binocular=True, - ) - point_cloud_L = DataLoadUtil.get_point_cloud( - depth_L, cam_info["cam_intrinsic"], n_to_world_pose - )["points_world"] - point_cloud_R = DataLoadUtil.get_point_cloud( - depth_R, cam_info["cam_intrinsic"], nR_to_world_pose - )["points_world"] - - point_cloud_L = PtsUtil.random_downsample_point_cloud( - point_cloud_L, 65536 - ) - point_cloud_R = PtsUtil.random_downsample_point_cloud( - point_cloud_R, 65536 - ) - overlap_points = PtsUtil.get_overlapping_points( - point_cloud_L, point_cloud_R - ) - downsampled_target_point_cloud = ( - PtsUtil.random_downsample_point_cloud( - overlap_points, self.pts_num - ) - ) - if self.cache: - self.save_to_cache( - scene_name, frame_idx, downsampled_target_point_cloud - ) - else: - downsampled_target_point_cloud = cached_data - + + n_to_world_pose = cam_info["cam_to_world"] + target_point_cloud = ( + DataLoadUtil.load_from_preprocessed_pts(view_path) + ) + target_pts_num = target_point_cloud.shape[0] + downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( + target_point_cloud, self.pts_num + ) scanned_views_pts.append(downsampled_target_point_cloud) scanned_coverages_rate.append(coverage_rate) n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( diff --git a/utils/data_load.py b/utils/data_load.py index ed5a144..4ec8936 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -237,7 +237,7 @@ class DataLoadUtil: @staticmethod def load_from_preprocessed_pts(path): npy_path = os.path.join( - os.path.dirname(path), "points", os.path.basename(path) + ".npy" + os.path.dirname(path), "pts", os.path.basename(path) + ".npy" ) pts = np.load(npy_path) return pts From e315fd99ee7780e4771683f6ea970ab9c6688093 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sat, 5 Oct 2024 15:36:38 -0500 Subject: [PATCH 3/7] update new_num limit --- core/seq_dataset.py | 21 ++++++--------------- utils/reconstruction.py | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/core/seq_dataset.py b/core/seq_dataset.py index 0e1b836..3196949 100644 --- a/core/seq_dataset.py +++ b/core/seq_dataset.py @@ -74,24 +74,12 @@ class SeqNBVReconstructionDataset(BaseDataset): max_coverage_rate = data_item_info["max_coverage_rate"] scene_name = data_item_info["scene_name"] first_cam_info = DataLoadUtil.load_cam_info(DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx), binocular=True) - first_view_path = DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx) first_left_cam_pose = first_cam_info["cam_to_world"] - first_right_cam_pose = first_cam_info["cam_to_world_R"] first_center_cam_pose = first_cam_info["cam_to_world_O"] - if self.load_from_preprocess: - first_downsampled_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(first_view_path) - else: - first_depth_L, first_depth_R = DataLoadUtil.load_depth(first_view_path, first_cam_info['near_plane'], first_cam_info['far_plane'], binocular=True) - - first_point_cloud_L = DataLoadUtil.get_point_cloud(first_depth_L, first_cam_info['cam_intrinsic'], first_left_cam_pose)['points_world'] - first_point_cloud_R = DataLoadUtil.get_point_cloud(first_depth_R, first_cam_info['cam_intrinsic'], first_right_cam_pose)['points_world'] - - first_point_cloud_L = PtsUtil.random_downsample_point_cloud(first_point_cloud_L, 65536) - first_point_cloud_R = PtsUtil.random_downsample_point_cloud(first_point_cloud_R, 65536) - first_overlap_points = PtsUtil.get_overlapping_points(first_point_cloud_L, first_point_cloud_R) - first_downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(first_overlap_points, self.pts_num) - + first_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(first_view_path) + first_pts_num = first_target_point_cloud.shape[0] + first_downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(first_target_point_cloud, self.pts_num) first_to_world_rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(first_left_cam_pose[:3,:3])) first_to_world_trans = first_left_cam_pose[:3,3] first_to_world_9d = np.concatenate([first_to_world_rot_6d, first_to_world_trans], axis=0) @@ -102,6 +90,9 @@ class SeqNBVReconstructionDataset(BaseDataset): model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name) data_item = { + "first_pts_num": np.asarray( + first_pts_num, dtype=np.int32 + ), "first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32), "combined_scanned_pts": np.asarray(first_downsampled_target_point_cloud,dtype=np.float32), "first_to_world_9d": np.asarray([first_to_world_9d],dtype=np.float32), diff --git a/utils/reconstruction.py b/utils/reconstruction.py index 376c056..5aebdf3 100644 --- a/utils/reconstruction.py +++ b/utils/reconstruction.py @@ -8,9 +8,9 @@ class ReconstructionUtil: def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01): kdtree = cKDTree(combined_point_cloud) distances, _ = kdtree.query(target_point_cloud) - covered_points = np.sum(distances < threshold*2) - coverage_rate = covered_points / target_point_cloud.shape[0] - return coverage_rate + covered_points_num = np.sum(distances < threshold) + coverage_rate = covered_points_num / target_point_cloud.shape[0] + return coverage_rate, covered_points_num @staticmethod def compute_overlap_rate(new_point_cloud, combined_point_cloud, threshold=0.01): @@ -46,10 +46,12 @@ class ReconstructionUtil: downsampled_max_rec_pts = PtsUtil.voxel_downsample_point_cloud(max_rec_pts, threshold) max_rec_pts_num = downsampled_max_rec_pts.shape[0] - max_rec_pts_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, downsampled_max_rec_pts, threshold) + max_real_rec_pts_coverage, _ = ReconstructionUtil.compute_coverage_rate(target_point_cloud, downsampled_max_rec_pts, threshold) - new_coverage = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, combined_point_cloud, threshold) + new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, combined_point_cloud, threshold) current_coverage = new_coverage + current_covered_num = new_covered_num + remaining_views = list(range(len(point_cloud_list))) view_sequence = [(init_view, current_coverage)] cnt_processed_view = 0 @@ -61,6 +63,7 @@ class ReconstructionUtil: best_view = None best_coverage_increase = -1 best_combined_point_cloud = None + best_covered_num = 0 for view_index in remaining_views: if point_cloud_list[view_index].shape[0] == 0: @@ -81,25 +84,26 @@ class ReconstructionUtil: start = time.time() new_combined_point_cloud = np.vstack([combined_point_cloud, point_cloud_list[view_index]]) new_downsampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(new_combined_point_cloud,threshold) - new_coverage = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, threshold) + new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, threshold) end = time.time() #print(f"compute_coverage_rate Time: {end-start}") coverage_increase = new_coverage - current_coverage if coverage_increase > best_coverage_increase: best_coverage_increase = coverage_increase best_view = view_index + best_covered_num = new_covered_num best_combined_point_cloud = new_downsampled_combined_point_cloud if best_view is not None: - if best_coverage_increase <=1e-3: + if best_coverage_increase <=1e-3 or best_covered_num - current_covered_num <= 5: break selected_views.append(best_view) best_rec_pts_num = best_combined_point_cloud.shape[0] - print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Max rec pts num: {max_rec_pts_num}") - print(f"Current coverage: {current_coverage}, Best coverage increase: {best_coverage_increase}, Max coverage: {max_rec_pts_coverage}") - + print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Best cover pts: {best_covered_num}, Max rec pts num: {max_rec_pts_num}") + print(f"Current coverage: {current_coverage}, Best coverage increase: {best_coverage_increase}, Max Real coverage: {max_real_rec_pts_coverage}") + current_covered_num = best_covered_num curr_rec_pts_num = best_rec_pts_num combined_point_cloud = best_combined_point_cloud remaining_views.remove(best_view) From a84417ef62f5e0bcd9bbdaf133ed724c5a3e14d1 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 6 Oct 2024 11:49:03 +0800 Subject: [PATCH 4/7] add fps --- core/nbv_dataset.py | 4 ++-- utils/pts.py | 46 +++++++++++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index a3c5e1f..702134e 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -162,8 +162,8 @@ class NBVReconstructionDataset(BaseDataset): ) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) - voxel_downsampled_combined_scanned_pts_np = ( - PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002) + voxel_downsampled_combined_scanned_pts_np, _ = ( + PtsUtil.voxelize_points(combined_scanned_views_pts, 0.002) ) random_downsampled_combined_scanned_pts_np = ( PtsUtil.random_downsample_point_cloud( diff --git a/utils/pts.py b/utils/pts.py index 0551149..4716ce1 100644 --- a/utils/pts.py +++ b/utils/pts.py @@ -12,12 +12,6 @@ class PtsUtil: downsampled_pc = o3d_pc.voxel_down_sample(voxel_size) return np.asarray(downsampled_pc.points) - @staticmethod - def transform_point_cloud(points, pose_mat): - points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1) - points_h = np.dot(pose_mat, points_h.T).T - return points_h[:, :3] - @staticmethod def random_downsample_point_cloud(point_cloud, num_points, require_idx=False): if point_cloud.shape[0] == 0: @@ -29,6 +23,28 @@ class PtsUtil: return point_cloud[idx], idx return point_cloud[idx] + @staticmethod + def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False): + N = point_cloud.shape[0] + mask = np.zeros(N, dtype=bool) + + sampled_indices = np.zeros(num_points, dtype=int) + sampled_indices[0] = np.random.randint(0, N) + mask[sampled_indices[0]] = True + distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1) + for i in range(1, num_points): + farthest_index = np.argmax(distances) + sampled_indices[i] = farthest_index + mask[farthest_index] = True + + new_distances = np.linalg.norm(point_cloud - point_cloud[farthest_index], axis=1) + distances = np.minimum(distances, new_distances) + + sampled_points = point_cloud[sampled_indices] + if require_mask: + return sampled_points, mask + return sampled_points + @staticmethod def random_downsample_point_cloud_tensor(point_cloud, num_points): idx = torch.randint(0, len(point_cloud), (num_points,)) @@ -40,6 +56,12 @@ class PtsUtil: unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True) return unique_voxels + @staticmethod + def transform_point_cloud(points, pose_mat): + points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1) + points_h = np.dot(pose_mat, points_h.T).T + return points_h[:, :3] + @staticmethod def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005, require_idx=False): voxels_L, indices_L = PtsUtil.voxelize_points(point_cloud_L, voxel_size) @@ -56,18 +78,6 @@ class PtsUtil: return overlapping_points, mask_L return overlapping_points - @staticmethod - def new_filter_points(points, normals, cam_pose, theta=75, require_idx=False): - camera_axis = -cam_pose[:3, 2] - normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True) - cos_theta = np.dot(normals_normalized, camera_axis) - theta_rad = np.deg2rad(theta) - idx = cos_theta > np.cos(theta_rad) - filtered_points= points[idx] - if require_idx: - return filtered_points, idx - return filtered_points - @staticmethod def filter_points(points, points_normals, cam_pose, voxel_size=0.002, theta=45, z_range=(0.2, 0.45)): From 276f45dcc369e06e10bc957bec0eb075fdfe0172 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 6 Oct 2024 12:01:10 +0800 Subject: [PATCH 5/7] add scanned_pts_mask --- core/nbv_dataset.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 702134e..3880ed5 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -162,18 +162,23 @@ class NBVReconstructionDataset(BaseDataset): ) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) - voxel_downsampled_combined_scanned_pts_np, _ = ( - PtsUtil.voxelize_points(combined_scanned_views_pts, 0.002) - ) - random_downsampled_combined_scanned_pts_np = ( - PtsUtil.random_downsample_point_cloud( - voxel_downsampled_combined_scanned_pts_np, self.pts_num - ) + fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud( + combined_scanned_views_pts, self.pts_num, require_mask=True ) + + view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]]) + scanned_pts_mask = [] + + for i, start_idx in enumerate(view_start_indices[:-1]): + end_idx = view_start_indices[i + 1] + view_mask = fps_mask[start_idx:end_idx] + scanned_pts_mask.append(view_mask) + data_item = { "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), + "scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8), "combined_scanned_pts": np.asarray( - random_downsampled_combined_scanned_pts_np, dtype=np.float32 + fps_downsampled_combined_scanned_pts, dtype=np.float32 ), "scanned_coverage_rate": scanned_coverages_rate, "scanned_n_to_world_pose_9d": np.asarray( @@ -199,6 +204,9 @@ class NBVReconstructionDataset(BaseDataset): collate_data["scanned_pts"] = [ torch.tensor(item["scanned_pts"]) for item in batch ] + collate_data["scanned_pts_mask"] = [ + torch.tensor(item["scanned_pts_mask"]) for item in batch + ] collate_data["scanned_n_to_world_pose_9d"] = [ torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch ] @@ -218,6 +226,7 @@ class NBVReconstructionDataset(BaseDataset): for key in batch[0].keys(): if key not in [ "scanned_pts", + "scanned_pts_mask", "scanned_n_to_world_pose_9d", "best_to_world_pose_9d", "first_frame_to_world", From fa69f9f8794ddaac6c89ca86048ab73f37545e40 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 6 Oct 2024 13:48:54 +0800 Subject: [PATCH 6/7] update fps algo and fps mask --- core/global_pts_n_num_pipeline.py | 118 ++++++++++++++++++++---------- core/nbv_dataset.py | 74 +++++++++---------- utils/pts.py | 7 +- 3 files changed, 115 insertions(+), 84 deletions(-) diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py index 3970bd3..efc2bb8 100644 --- a/core/global_pts_n_num_pipeline.py +++ b/core/global_pts_n_num_pipeline.py @@ -12,41 +12,55 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): super(NBVReconstructionGlobalPointsPipeline, self).__init__() self.config = config self.module_config = config["modules"] - self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"]) - self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"]) - self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]) - self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"]) - self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]) - + self.pts_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pts_encoder"] + ) + self.pose_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pose_encoder"] + ) + self.pose_n_num_seq_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"] + ) + self.view_finder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["view_finder"] + ) + self.pts_num_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"] + ) + self.eps = float(self.config["eps"]) self.enable_global_scanned_feat = self.config["global_scanned_feat"] - + def forward(self, data): mode = data["mode"] - + if mode == namespace.Mode.TRAIN: return self.forward_train(data) elif mode == namespace.Mode.TEST: return self.forward_test(data) else: Log.error("Unknown mode: {}".format(mode), True) - + def pertube_data(self, gt_delta_9d): bs = gt_delta_9d.shape[0] - random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps + random_t = ( + torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps + ) random_t = random_t.unsqueeze(-1) mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t) std = std.view(-1, 1) z = torch.randn_like(gt_delta_9d) perturbed_x = mu + z * std - target_score = - z * std / (std ** 2) + target_score = -z * std / (std**2) return perturbed_x, random_t, target_score, std - + def forward_train(self, data): main_feat = self.get_main_feat(data) - ''' get std ''' + """ get std """ best_to_world_pose_9d_batch = data["best_to_world_pose_9d"] - perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch) + perturbed_x, random_t, target_score, std = self.pertube_data( + best_to_world_pose_9d_batch + ) input_data = { "sampled_pose": perturbed_x, "t": random_t, @@ -56,45 +70,69 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): output = { "estimated_score": estimated_score, "target_score": target_score, - "std": std + "std": std, } return output - - def forward_test(self,data): + + def forward_test(self, data): main_feat = self.get_main_feat(data) - estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat) + estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view( + main_feat + ) result = { "pred_pose_9d": estimated_delta_rot_9d, - "in_process_sample": in_process_sample + "in_process_sample": in_process_sample, } return result - - + def get_main_feat(self, data): - scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d'] - scanned_target_pts_num_batch = data['scanned_target_points_num'] + scanned_n_to_world_pose_9d_batch = data[ + "scanned_n_to_world_pose_9d" + ] # List(B): Tensor(S x 9) + scanned_pts_mask_batch = data[ + "scanned_pts_mask" + ] # Tensor(B x N) device = next(self.parameters()).device embedding_list_batch = [] - - for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch): - scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) - scanned_target_pts_num = scanned_target_pts_num.to(device) - pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) - pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) - embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1)) - main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) - - - combined_scanned_pts_batch = data['combined_scanned_pts'] - global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch) - main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1) - - + combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3) + global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points( + combined_scanned_pts_batch, require_per_point_feat=True + ) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl) + + for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip( + scanned_n_to_world_pose_9d_batch, + scanned_pts_mask_batch, + perpoint_scanned_feat_batch, + ): + scanned_target_pts_num = [] # List(S): Int + partial_feat_seq = [] + + seq_len = len(scanned_n_to_world_pose_9d) + for seq_idx in range(seq_len): + partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask + partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl) + partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl) + partial_feat_seq.append(partial_feat) + scanned_target_pts_num.append(partial_perpoint_feat.shape[0]) + + scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S) + scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9) + + pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp) + pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn) + partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl) + + seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl)) + embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl)) + + seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) + + main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg)) + if torch.isnan(main_feat).any(): Log.error("nan in main_feat", True) - + return main_feat - diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 3880ed5..4013d98 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -122,7 +122,6 @@ class NBVReconstructionDataset(BaseDataset): scanned_views_pts, scanned_coverages_rate, scanned_n_to_world_pose, - scanned_target_pts_num, ) = ([], [], [], []) for view in scanned_views: frame_idx = view[0] @@ -134,7 +133,6 @@ class NBVReconstructionDataset(BaseDataset): target_point_cloud = ( DataLoadUtil.load_from_preprocessed_pts(view_path) ) - target_pts_num = target_point_cloud.shape[0] downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( target_point_cloud, self.pts_num ) @@ -146,7 +144,7 @@ class NBVReconstructionDataset(BaseDataset): n_to_world_trans = n_to_world_pose[:3, 3] n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) scanned_n_to_world_pose.append(n_to_world_9d) - scanned_target_pts_num.append(target_pts_num) + nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) @@ -162,35 +160,33 @@ class NBVReconstructionDataset(BaseDataset): ) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) - fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud( - combined_scanned_views_pts, self.pts_num, require_mask=True + fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud( + combined_scanned_views_pts, self.pts_num, require_idx=True ) - view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]]) - scanned_pts_mask = [] - - for i, start_idx in enumerate(view_start_indices[:-1]): - end_idx = view_start_indices[i + 1] - view_mask = fps_mask[start_idx:end_idx] - scanned_pts_mask.append(view_mask) + combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8) + + start_idx = 0 + for i in range(len(scanned_views_pts)): + end_idx = start_idx + len(scanned_views_pts[i]) + combined_scanned_views_pts_mask[start_idx:end_idx] = i + start_idx = end_idx + + fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx] + + + data_item = { - "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), - "scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8), - "combined_scanned_pts": np.asarray( - fps_downsampled_combined_scanned_pts, dtype=np.float32 - ), - "scanned_coverage_rate": scanned_coverages_rate, - "scanned_n_to_world_pose_9d": np.asarray( - scanned_n_to_world_pose, dtype=np.float32 - ), - "best_coverage_rate": nbv_coverage_rate, - "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), - "seq_max_coverage_rate": max_coverage_rate, - "scene_name": scene_name, - "scanned_target_points_num": np.asarray( - scanned_target_pts_num, dtype=np.int32 - ), + "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3) + "scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S) + "combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3) + "scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1) + "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9) + "best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1) + "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), # Ndarray(9) + "seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1) + "scene_name": scene_name, # String } return data_item @@ -201,37 +197,35 @@ class NBVReconstructionDataset(BaseDataset): def get_collate_fn(self): def collate_fn(batch): collate_data = {} + + ''' ------ Varialbe Length ------ ''' + collate_data["scanned_pts"] = [ torch.tensor(item["scanned_pts"]) for item in batch ] - collate_data["scanned_pts_mask"] = [ - torch.tensor(item["scanned_pts_mask"]) for item in batch - ] collate_data["scanned_n_to_world_pose_9d"] = [ torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch ] - collate_data["scanned_target_points_num"] = [ - torch.tensor(item["scanned_target_points_num"]) for item in batch - ] + + ''' ------ Fixed Length ------ ''' + collate_data["best_to_world_pose_9d"] = torch.stack( [torch.tensor(item["best_to_world_pose_9d"]) for item in batch] ) collate_data["combined_scanned_pts"] = torch.stack( [torch.tensor(item["combined_scanned_pts"]) for item in batch] ) - if "first_frame_to_world" in batch[0]: - collate_data["first_frame_to_world"] = torch.stack( - [torch.tensor(item["first_frame_to_world"]) for item in batch] - ) + collate_data["scanned_pts_mask"] = torch.stack( + [torch.tensor(item["scanned_pts_mask"]) for item in batch] + ) + for key in batch[0].keys(): if key not in [ "scanned_pts", "scanned_pts_mask", "scanned_n_to_world_pose_9d", "best_to_world_pose_9d", - "first_frame_to_world", "combined_scanned_pts", - "scanned_target_points_num", ]: collate_data[key] = [item[key] for item in batch] return collate_data diff --git a/utils/pts.py b/utils/pts.py index 4716ce1..87818cb 100644 --- a/utils/pts.py +++ b/utils/pts.py @@ -24,13 +24,12 @@ class PtsUtil: return point_cloud[idx] @staticmethod - def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False): + def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False): N = point_cloud.shape[0] mask = np.zeros(N, dtype=bool) sampled_indices = np.zeros(num_points, dtype=int) sampled_indices[0] = np.random.randint(0, N) - mask[sampled_indices[0]] = True distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1) for i in range(1, num_points): farthest_index = np.argmax(distances) @@ -41,8 +40,8 @@ class PtsUtil: distances = np.minimum(distances, new_distances) sampled_points = point_cloud[sampled_indices] - if require_mask: - return sampled_points, mask + if require_idx: + return sampled_points, sampled_indices return sampled_points @staticmethod From bfc8ba0f4bace77ea878fdbd9267a9331dfe4bb9 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 6 Oct 2024 13:53:32 +0800 Subject: [PATCH 7/7] update transformer_seq_encoder's config --- configs/server/server_train_config.yaml | 12 ++---------- core/global_pts_n_num_pipeline.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 6e44234..bf12e1b 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -90,7 +90,7 @@ pipeline: nbv_reconstruction_global_pts_pipeline: modules: pts_encoder: pointnet_encoder - pose_seq_encoder: transformer_pose_seq_encoder + pose_seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder eps: 1e-5 @@ -107,20 +107,12 @@ module: feature_transform: False transformer_seq_encoder: - pts_embed_dim: 1024 - pose_embed_dim: 256 + embed_dim: 1344 num_heads: 4 ffn_dim: 256 num_layers: 3 output_dim: 2048 - transformer_pose_seq_encoder: - pose_embed_dim: 256 - num_heads: 4 - ffn_dim: 256 - num_layers: 3 - output_dim: 1024 - gf_view_finder: t_feat_dim: 128 pose_feat_dim: 256 diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py index efc2bb8..04a360b 100644 --- a/core/global_pts_n_num_pipeline.py +++ b/core/global_pts_n_num_pipeline.py @@ -12,21 +12,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): super(NBVReconstructionGlobalPointsPipeline, self).__init__() self.config = config self.module_config = config["modules"] + self.pts_encoder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["pts_encoder"] ) self.pose_encoder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["pose_encoder"] ) - self.pose_n_num_seq_encoder = ComponentFactory.create( - namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"] + self.pts_num_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"] + ) + + self.transformer_seq_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"] ) self.view_finder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["view_finder"] ) - self.pts_num_encoder = ComponentFactory.create( - namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"] - ) + self.eps = float(self.config["eps"]) self.enable_global_scanned_feat = self.config["global_scanned_feat"] @@ -128,7 +131,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl)) embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl)) - seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) + seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))