update nbv_dataset: scene_points to target_points

This commit is contained in:
hofee 2024-10-05 15:17:54 -05:00
parent 60c9357491
commit 1a3ae15130
2 changed files with 10 additions and 53 deletions

View File

@ -124,63 +124,20 @@ class NBVReconstructionDataset(BaseDataset):
scanned_n_to_world_pose, scanned_n_to_world_pose,
scanned_target_pts_num, scanned_target_pts_num,
) = ([], [], [], []) ) = ([], [], [], [])
target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict(
self.root_dir, scene_name
)
for view in scanned_views: for view in scanned_views:
frame_idx = view[0] frame_idx = view[0]
coverage_rate = view[1] coverage_rate = view[1]
view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True)
target_pts_num = target_pts_num_dict[frame_idx]
n_to_world_pose = cam_info["cam_to_world"] n_to_world_pose = cam_info["cam_to_world"]
nR_to_world_pose = cam_info["cam_to_world_R"] target_point_cloud = (
DataLoadUtil.load_from_preprocessed_pts(view_path)
if self.load_from_preprocess: )
downsampled_target_point_cloud = ( target_pts_num = target_point_cloud.shape[0]
DataLoadUtil.load_from_preprocessed_pts(view_path) downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
) target_point_cloud, self.pts_num
else: )
cached_data = None
if self.cache:
cached_data = self.load_from_cache(scene_name, frame_idx)
if cached_data is None:
print("load depth")
depth_L, depth_R = DataLoadUtil.load_depth(
view_path,
cam_info["near_plane"],
cam_info["far_plane"],
binocular=True,
)
point_cloud_L = DataLoadUtil.get_point_cloud(
depth_L, cam_info["cam_intrinsic"], n_to_world_pose
)["points_world"]
point_cloud_R = DataLoadUtil.get_point_cloud(
depth_R, cam_info["cam_intrinsic"], nR_to_world_pose
)["points_world"]
point_cloud_L = PtsUtil.random_downsample_point_cloud(
point_cloud_L, 65536
)
point_cloud_R = PtsUtil.random_downsample_point_cloud(
point_cloud_R, 65536
)
overlap_points = PtsUtil.get_overlapping_points(
point_cloud_L, point_cloud_R
)
downsampled_target_point_cloud = (
PtsUtil.random_downsample_point_cloud(
overlap_points, self.pts_num
)
)
if self.cache:
self.save_to_cache(
scene_name, frame_idx, downsampled_target_point_cloud
)
else:
downsampled_target_point_cloud = cached_data
scanned_views_pts.append(downsampled_target_point_cloud) scanned_views_pts.append(downsampled_target_point_cloud)
scanned_coverages_rate.append(coverage_rate) scanned_coverages_rate.append(coverage_rate)
n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(

View File

@ -237,7 +237,7 @@ class DataLoadUtil:
@staticmethod @staticmethod
def load_from_preprocessed_pts(path): def load_from_preprocessed_pts(path):
npy_path = os.path.join( npy_path = os.path.join(
os.path.dirname(path), "points", os.path.basename(path) + ".npy" os.path.dirname(path), "pts", os.path.basename(path) + ".npy"
) )
pts = np.load(npy_path) pts = np.load(npy_path)
return pts return pts