diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 4573bb0..a3c5e1f 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -124,63 +124,20 @@ class NBVReconstructionDataset(BaseDataset): scanned_n_to_world_pose, scanned_target_pts_num, ) = ([], [], [], []) - target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict( - self.root_dir, scene_name - ) for view in scanned_views: frame_idx = view[0] coverage_rate = view[1] view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) - target_pts_num = target_pts_num_dict[frame_idx] - n_to_world_pose = cam_info["cam_to_world"] - nR_to_world_pose = cam_info["cam_to_world_R"] - - if self.load_from_preprocess: - downsampled_target_point_cloud = ( - DataLoadUtil.load_from_preprocessed_pts(view_path) - ) - else: - cached_data = None - if self.cache: - cached_data = self.load_from_cache(scene_name, frame_idx) - - if cached_data is None: - print("load depth") - depth_L, depth_R = DataLoadUtil.load_depth( - view_path, - cam_info["near_plane"], - cam_info["far_plane"], - binocular=True, - ) - point_cloud_L = DataLoadUtil.get_point_cloud( - depth_L, cam_info["cam_intrinsic"], n_to_world_pose - )["points_world"] - point_cloud_R = DataLoadUtil.get_point_cloud( - depth_R, cam_info["cam_intrinsic"], nR_to_world_pose - )["points_world"] - - point_cloud_L = PtsUtil.random_downsample_point_cloud( - point_cloud_L, 65536 - ) - point_cloud_R = PtsUtil.random_downsample_point_cloud( - point_cloud_R, 65536 - ) - overlap_points = PtsUtil.get_overlapping_points( - point_cloud_L, point_cloud_R - ) - downsampled_target_point_cloud = ( - PtsUtil.random_downsample_point_cloud( - overlap_points, self.pts_num - ) - ) - if self.cache: - self.save_to_cache( - scene_name, frame_idx, downsampled_target_point_cloud - ) - else: - downsampled_target_point_cloud = cached_data - + + n_to_world_pose = cam_info["cam_to_world"] + target_point_cloud = ( + DataLoadUtil.load_from_preprocessed_pts(view_path) + ) + target_pts_num = target_point_cloud.shape[0] + downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( + target_point_cloud, self.pts_num + ) scanned_views_pts.append(downsampled_target_point_cloud) scanned_coverages_rate.append(coverage_rate) n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( diff --git a/utils/data_load.py b/utils/data_load.py index ed5a144..4ec8936 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -237,7 +237,7 @@ class DataLoadUtil: @staticmethod def load_from_preprocessed_pts(path): npy_path = os.path.join( - os.path.dirname(path), "points", os.path.basename(path) + ".npy" + os.path.dirname(path), "pts", os.path.basename(path) + ".npy" ) pts = np.load(npy_path) return pts