From 99e57c3f4cdf7dd8d112185ad695ded61a1a7e17 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Sun, 29 Sep 2024 18:11:55 +0800 Subject: [PATCH] add target_pts_num into dataset --- core/nbv_dataset.py | 251 ++++++++++++++++++++++++-------------------- utils/data_load.py | 7 ++ 2 files changed, 144 insertions(+), 114 deletions(-) diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index f2748b8..829bb01 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -7,6 +7,7 @@ from PytorchBoot.utils.log_util import Log import torch import os import sys + sys.path.append(r"/home/data/hofee/project/nbv_rec/nbv_reconstruction") from utils.data_load import DataLoadUtil @@ -29,20 +30,17 @@ class NBVReconstructionDataset(BaseDataset): self.cache = config.get("cache") self.load_from_preprocess = config.get("load_from_preprocess", False) - if self.type == namespace.Mode.TEST: self.model_dir = config["model_dir"] self.filter_degree = config["filter_degree"] if self.type == namespace.Mode.TRAIN: scale_ratio = 100 - self.datalist = self.datalist*scale_ratio + self.datalist = self.datalist * scale_ratio if self.cache: expr_root = ConfigManager.get("runner", "experiment", "root_dir") expr_name = ConfigManager.get("runner", "experiment", "name") self.cache_dir = os.path.join(expr_root, expr_name, "cache") - #self.preprocess_cache() - - + # self.preprocess_cache() def load_scene_name_list(self): scene_name_list = [] @@ -51,7 +49,7 @@ class NBVReconstructionDataset(BaseDataset): scene_name = line.strip() scene_name_list.append(scene_name) return scene_name_list - + def get_datalist(self): datalist = [] for scene_name in self.scene_name_list: @@ -60,7 +58,9 @@ class NBVReconstructionDataset(BaseDataset): max_coverage_rate_list = [] for seq_idx in range(seq_num): - label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx) + label_path = DataLoadUtil.get_label_path( + self.root_dir, scene_name, seq_idx + ) label_data = DataLoadUtil.load_label(label_path) max_coverage_rate = label_data["max_coverage_rate"] if max_coverage_rate > scene_max_coverage_rate: @@ -69,20 +69,24 @@ class NBVReconstructionDataset(BaseDataset): mean_coverage_rate = np.mean(max_coverage_rate_list) for seq_idx in range(seq_num): - label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx) + label_path = DataLoadUtil.get_label_path( + self.root_dir, scene_name, seq_idx + ) label_data = DataLoadUtil.load_label(label_path) if max_coverage_rate_list[seq_idx] > mean_coverage_rate - 0.1: for data_pair in label_data["data_pairs"]: scanned_views = data_pair[0] next_best_view = data_pair[1] - datalist.append({ - "scanned_views": scanned_views, - "next_best_view": next_best_view, - "seq_max_coverage_rate": max_coverage_rate, - "scene_name": scene_name, - "label_idx": seq_idx, - "scene_max_coverage_rate": scene_max_coverage_rate - }) + datalist.append( + { + "scanned_views": scanned_views, + "next_best_view": next_best_view, + "seq_max_coverage_rate": max_coverage_rate, + "scene_name": scene_name, + "label_idx": seq_idx, + "scene_max_coverage_rate": scene_max_coverage_rate, + } + ) return datalist def preprocess_cache(self): @@ -90,7 +94,7 @@ class NBVReconstructionDataset(BaseDataset): for item_idx in range(len(self.datalist)): self.__getitem__(item_idx) Log.success("finish preprocessing cache.") - + def load_from_cache(self, scene_name, curr_frame_idx): cache_name = f"{scene_name}_{curr_frame_idx}.txt" cache_path = os.path.join(self.cache_dir, cache_name) @@ -99,7 +103,7 @@ class NBVReconstructionDataset(BaseDataset): return data else: return None - + def save_to_cache(self, scene_name, curr_frame_idx, data): cache_name = f"{scene_name}_{curr_frame_idx}.txt" cache_path = os.path.join(self.cache_dir, cache_name) @@ -107,125 +111,172 @@ class NBVReconstructionDataset(BaseDataset): np.savetxt(cache_path, data) except Exception as e: Log.error(f"Save cache failed: {e}") - # ----- Debug Trace ----- # - import ipdb; ipdb.set_trace() - # ------------------------ # - + def __getitem__(self, index): data_item_info = self.datalist[index] scanned_views = data_item_info["scanned_views"] nbv = data_item_info["next_best_view"] max_coverage_rate = data_item_info["seq_max_coverage_rate"] scene_name = data_item_info["scene_name"] - scanned_views_pts, scanned_coverages_rate, scanned_n_to_world_pose = [], [], [] - + ( + scanned_views_pts, + scanned_coverages_rate, + scanned_n_to_world_pose, + scanned_target_pts_num, + ) = ([], [], [], []) + target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict( + self.root_dir, scene_name + ) for view in scanned_views: frame_idx = view[0] coverage_rate = view[1] view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) + target_pts_num = target_pts_num_dict[frame_idx] n_to_world_pose = cam_info["cam_to_world"] - nR_to_world_pose = cam_info["cam_to_world_R"] + nR_to_world_pose = cam_info["cam_to_world_R"] if self.load_from_preprocess: - downsampled_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(view_path) + downsampled_target_point_cloud = ( + DataLoadUtil.load_from_preprocessed_pts(view_path) + ) else: cached_data = None if self.cache: cached_data = self.load_from_cache(scene_name, frame_idx) - + if cached_data is None: print("load depth") - depth_L, depth_R = DataLoadUtil.load_depth(view_path, cam_info['near_plane'], cam_info['far_plane'], binocular=True) - point_cloud_L = DataLoadUtil.get_point_cloud(depth_L, cam_info['cam_intrinsic'], n_to_world_pose)['points_world'] - point_cloud_R = DataLoadUtil.get_point_cloud(depth_R, cam_info['cam_intrinsic'], nR_to_world_pose)['points_world'] - - point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536) - point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536) - overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R) - downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num) + depth_L, depth_R = DataLoadUtil.load_depth( + view_path, + cam_info["near_plane"], + cam_info["far_plane"], + binocular=True, + ) + point_cloud_L = DataLoadUtil.get_point_cloud( + depth_L, cam_info["cam_intrinsic"], n_to_world_pose + )["points_world"] + point_cloud_R = DataLoadUtil.get_point_cloud( + depth_R, cam_info["cam_intrinsic"], nR_to_world_pose + )["points_world"] + + point_cloud_L = PtsUtil.random_downsample_point_cloud( + point_cloud_L, 65536 + ) + point_cloud_R = PtsUtil.random_downsample_point_cloud( + point_cloud_R, 65536 + ) + overlap_points = DataLoadUtil.get_overlapping_points( + point_cloud_L, point_cloud_R + ) + downsampled_target_point_cloud = ( + PtsUtil.random_downsample_point_cloud( + overlap_points, self.pts_num + ) + ) if self.cache: - self.save_to_cache(scene_name, frame_idx, downsampled_target_point_cloud) + self.save_to_cache( + scene_name, frame_idx, downsampled_target_point_cloud + ) else: downsampled_target_point_cloud = cached_data - + scanned_views_pts.append(downsampled_target_point_cloud) - scanned_coverages_rate.append(coverage_rate) - n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(n_to_world_pose[:3,:3])) - n_to_world_trans = n_to_world_pose[:3,3] + scanned_coverages_rate.append(coverage_rate) + n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( + np.asarray(n_to_world_pose[:3, :3]) + ) + n_to_world_trans = n_to_world_pose[:3, 3] n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) scanned_n_to_world_pose.append(n_to_world_9d) + scanned_target_pts_num.append(target_pts_num) nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) cam_info = DataLoadUtil.load_cam_info(nbv_path) best_frame_to_world = cam_info["cam_to_world"] - - best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(best_frame_to_world[:3,:3])) - best_to_world_trans = best_frame_to_world[:3,3] - best_to_world_9d = np.concatenate([best_to_world_6d, best_to_world_trans], axis=0) + + best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( + np.asarray(best_frame_to_world[:3, :3]) + ) + best_to_world_trans = best_frame_to_world[:3, 3] + best_to_world_9d = np.concatenate( + [best_to_world_6d, best_to_world_trans], axis=0 + ) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) - voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002) - random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, self.pts_num) + voxel_downsampled_combined_scanned_pts_np = ( + PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002) + ) + random_downsampled_combined_scanned_pts_np = ( + PtsUtil.random_downsample_point_cloud( + voxel_downsampled_combined_scanned_pts_np, self.pts_num + ) + ) data_item = { - "scanned_pts": np.asarray(scanned_views_pts,dtype=np.float32), - "combined_scanned_pts": np.asarray(random_downsampled_combined_scanned_pts_np,dtype=np.float32), + "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), + "combined_scanned_pts": np.asarray( + random_downsampled_combined_scanned_pts_np, dtype=np.float32 + ), "scanned_coverage_rate": scanned_coverages_rate, - "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose,dtype=np.float32), + "scanned_n_to_world_pose_9d": np.asarray( + scanned_n_to_world_pose, dtype=np.float32 + ), "best_coverage_rate": nbv_coverage_rate, - "best_to_world_pose_9d": np.asarray(best_to_world_9d,dtype=np.float32), + "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), "seq_max_coverage_rate": max_coverage_rate, - "scene_name": scene_name + "scene_name": scene_name, + "scanned_target_points_num": np.asarray( + scanned_target_pts_num, dtype=np.int32 + ), } - - - # if self.type == namespace.Mode.TEST: - # diag = DataLoadUtil.get_bbox_diag(self.model_dir, scene_name) - # voxel_threshold = diag*0.02 - # model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name) - # pts_list = [] - # for view in scanned_views: - # frame_idx = view[0] - # view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) - # point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(view_path, binocular=True) - # cam_params = DataLoadUtil.load_cam_info(view_path, binocular=True) - # sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree) - # pts_list.append(sampled_point_cloud) - # nL_to_world_pose = cam_params["cam_to_world"] - # nO_to_world_pose = cam_params["cam_to_world_O"] - # nO_to_nL_pose = np.dot(np.linalg.inv(nL_to_world_pose), nO_to_world_pose) - # data_item["scanned_target_pts_list"] = pts_list - # data_item["model_points_normals"] = model_points_normals - # data_item["voxel_threshold"] = voxel_threshold - # data_item["filter_degree"] = self.filter_degree - # data_item["scene_path"] = os.path.join(self.root_dir, scene_name) - # data_item["first_frame_to_world"] = np.asarray(first_frame_to_world, dtype=np.float32) - # data_item["nO_to_nL_pose"] = np.asarray(nO_to_nL_pose, dtype=np.float32) + return data_item def __len__(self): return len(self.datalist) - + def get_collate_fn(self): def collate_fn(batch): collate_data = {} - collate_data["scanned_pts"] = [torch.tensor(item['scanned_pts']) for item in batch] - collate_data["scanned_n_to_world_pose_9d"] = [torch.tensor(item['scanned_n_to_world_pose_9d']) for item in batch] - collate_data["best_to_world_pose_9d"] = torch.stack([torch.tensor(item['best_to_world_pose_9d']) for item in batch]) - collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch]) + collate_data["scanned_pts"] = [ + torch.tensor(item["scanned_pts"]) for item in batch + ] + collate_data["scanned_n_to_world_pose_9d"] = [ + torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch + ] + collate_data["scanned_target_points_num"] = [ + torch.tensor(item["scanned_target_points_num"]) for item in batch + ] + collate_data["best_to_world_pose_9d"] = torch.stack( + [torch.tensor(item["best_to_world_pose_9d"]) for item in batch] + ) + collate_data["combined_scanned_pts"] = torch.stack( + [torch.tensor(item["combined_scanned_pts"]) for item in batch] + ) if "first_frame_to_world" in batch[0]: - collate_data["first_frame_to_world"] = torch.stack([torch.tensor(item["first_frame_to_world"]) for item in batch]) + collate_data["first_frame_to_world"] = torch.stack( + [torch.tensor(item["first_frame_to_world"]) for item in batch] + ) for key in batch[0].keys(): - if key not in ["scanned_pts", "scanned_n_to_world_pose_9d", "best_to_world_pose_9d", "first_frame_to_world", "combined_scanned_pts"]: + if key not in [ + "scanned_pts", + "scanned_n_to_world_pose_9d", + "best_to_world_pose_9d", + "first_frame_to_world", + "combined_scanned_pts", + "scanned_target_points_num", + ]: collate_data[key] = [item[key] for item in batch] return collate_data + return collate_fn # -------------- Debug ---------------- # if __name__ == "__main__": import torch + seed = 0 torch.manual_seed(seed) np.random.seed(seed) @@ -244,41 +295,13 @@ if __name__ == "__main__": } ds = NBVReconstructionDataset(config) print(len(ds)) - #ds.__getitem__(10) + # ds.__getitem__(10) dl = ds.get_loader(shuffle=True) for idx, data in enumerate(dl): data = ds.process_batch(data, "cuda:0") print(data) # ------ Debug Start ------ - import ipdb;ipdb.set_trace() + import ipdb + + ipdb.set_trace() # ------ Debug End ------ - # - # for idx, data in enumerate(dl): - # cnt=0 - # print(data["scene_name"]) - # print(data["scanned_coverage_rate"]) - # print(data["best_coverage_rate"]) - # for pts in data["scanned_pts"][0]: - # #np.savetxt(f"pts_{cnt}.txt", pts) - # cnt+=1 - # #np.savetxt("best_pts.txt", best_pts) - # for key, value in data.items(): - # if isinstance(value, torch.Tensor): - # print(key, ":" ,value.shape) - # else: - # print(key, ":" ,len(value)) - # if key == "scanned_n_to_world_pose_9d": - # for val in value: - # print(val.shape) - # if key == "scanned_pts": - # print("scanned_pts") - # for val in value: - # print(val.shape) - # cnt = 0 - # for v in val: - # import ipdb;ipdb.set_trace() - # np.savetxt(f"pts_{cnt}.txt", v) - # cnt+=1 - - - # print() \ No newline at end of file diff --git a/utils/data_load.py b/utils/data_load.py index 16f568b..5cebfd0 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -98,6 +98,13 @@ class DataLoadUtil: scene_info = json.load(f) return scene_info + @staticmethod + def load_target_pts_num_dict(root, scene_name): + target_pts_num_path = os.path.join(root, scene_name, "target_pts_num.json") + with open(target_pts_num_path, "r") as f: + target_pts_num_dict = json.load(f) + return target_pts_num_dict + @staticmethod def load_target_object_pose(root, scene_name): scene_info = DataLoadUtil.load_scene_info(root, scene_name)