From b209ce050c31cfbe5376d6e090600e0a8e76b3fb Mon Sep 17 00:00:00 2001 From: hofee Date: Mon, 23 Sep 2024 17:45:01 +0800 Subject: [PATCH] change world space origin --- core/nbv_dataset.py | 11 +++++++---- utils/data_load.py | 30 ++++++++++++++++++++++++------ utils/render.py | 2 +- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 891a62c..c3bc9b8 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -29,11 +29,14 @@ class NBVReconstructionDataset(BaseDataset): self.type = config["type"] self.cache = config.get("cache") self.load_from_preprocess = config.get("load_from_preprocess", False) + + if self.type == namespace.Mode.TEST: self.model_dir = config["model_dir"] self.filter_degree = config["filter_degree"] if self.type == namespace.Mode.TRAIN: - self.datalist = self.datalist*100 + scale_ratio = 1 + self.datalist = self.datalist*scale_ratio if self.cache: expr_root = ConfigManager.get("runner", "experiment", "root_dir") expr_name = ConfigManager.get("runner", "experiment", "name") @@ -53,7 +56,7 @@ class NBVReconstructionDataset(BaseDataset): def get_datalist(self): datalist = [] for scene_name in self.scene_name_list: - label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name) + label_path = DataLoadUtil.get_label_path_old(self.root_dir, scene_name) label_data = DataLoadUtil.load_label(label_path) for data_pair in label_data["data_pairs"]: scanned_views = data_pair[0] @@ -208,11 +211,11 @@ if __name__ == "__main__": torch.manual_seed(seed) np.random.seed(seed) config = { - "root_dir": "/media/hofee/data/project/python/nbv_reconstruction/sample_for_training/preprocessed_scenes/", + "root_dir": "/media/hofee/repository/nbv_reconstruction_data_512", "model_dir": "/media/hofee/data/data/scaled_object_meshes", "source": "nbv_reconstruction_dataset", "split_file": "/media/hofee/data/project/python/nbv_reconstruction/sample_for_training/OmniObject3d_train.txt", - "load_from_preprocess": True, + "load_from_preprocess": False, "ratio": 0.5, "batch_size": 2, "filter_degree": 75, diff --git a/utils/data_load.py b/utils/data_load.py index 78ddc74..16f568b 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -7,7 +7,20 @@ import torch from utils.pts import PtsUtil class DataLoadUtil: - DISPLAY_TABLE_POSITION = np.asarray([0,0,0.895]) + TABLE_POSITION = np.asarray([0,0,0.8215]) + + @staticmethod + def get_display_table_info(root, scene_name): + scene_info = DataLoadUtil.load_scene_info(root, scene_name) + display_table_info = scene_info["display_table"] + return display_table_info + + @staticmethod + def get_display_table_top(root, scene_name): + display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)["height"] + display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray([0,0,display_table_height]) + return display_table_top + @staticmethod def get_path(root, scene_name, frame_idx): path = os.path.join(root, scene_name, f"{frame_idx}") @@ -64,7 +77,7 @@ class DataLoadUtil: target_name = scene_info["target_name"] transformation = scene_info[target_name] if display_table_as_world_space_origin: - location = transformation["location"] - DataLoadUtil.DISPLAY_TABLE_POSITION + location = transformation["location"] - DataLoadUtil.get_display_table_top(root, scene_name) else: location = transformation["location"] rotation_euler = transformation["rotation_euler"] @@ -168,13 +181,16 @@ class DataLoadUtil: @staticmethod def load_cam_info(path, binocular=False, display_table_as_world_space_origin=True): + scene_dir = os.path.dirname(path) + root_dir = os.path.dirname(scene_dir) + scene_name = os.path.basename(scene_dir) camera_params_path = os.path.join(os.path.dirname(path), "camera_params", os.path.basename(path) + ".json") with open(camera_params_path, 'r') as f: label_data = json.load(f) cam_to_world = np.asarray(label_data["extrinsic"]) cam_to_world = DataLoadUtil.cam_pose_transformation(cam_to_world) world_to_display_table = np.eye(4) - world_to_display_table[:3, 3] = - DataLoadUtil.DISPLAY_TABLE_POSITION + world_to_display_table[:3, 3] = - DataLoadUtil.get_display_table_top(root_dir, scene_name) if display_table_as_world_space_origin: cam_to_world = np.dot(world_to_display_table, cam_to_world) cam_intrinsic = np.asarray(label_data["intrinsic"]) @@ -197,13 +213,15 @@ class DataLoadUtil: return cam_info @staticmethod - def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, display_table_as_world_space_origin=True): + def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True): + root_dir = os.path.dirname(scene_path) + scene_name = os.path.basename(scene_path) if isinstance(cam_L, torch.Tensor): cam_L = cam_L.cpu().numpy() nO_to_display_table_pose = cam_L @ cam_O_to_cam_L if display_table_as_world_space_origin: display_table_to_world = np.eye(4) - display_table_to_world[:3, 3] = DataLoadUtil.DISPLAY_TABLE_POSITION + display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top(root_dir, scene_name) nO_to_world_pose = np.dot(display_table_to_world, nO_to_display_table_pose) nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose) return nO_to_world_pose @@ -292,5 +310,5 @@ class DataLoadUtil: points_path = os.path.join(root, scene_name, "points_and_normals.txt") points_normals = np.loadtxt(points_path) if display_table_as_world_space_origin: - points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.DISPLAY_TABLE_POSITION + points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.get_display_table_top(root, scene_name) return points_normals \ No newline at end of file diff --git a/utils/render.py b/utils/render.py index 2282542..1b3c5a3 100644 --- a/utils/render.py +++ b/utils/render.py @@ -11,7 +11,7 @@ class RenderUtil: @staticmethod def render_pts(cam_pose, scene_path,script_path, model_points_normals, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None, require_full_scene=False): - nO_to_world_pose = DataLoadUtil.get_real_cam_O_from_cam_L(cam_pose, nO_to_nL_pose) + nO_to_world_pose = DataLoadUtil.get_real_cam_O_from_cam_L(cam_pose, nO_to_nL_pose, scene_path=scene_path) with tempfile.TemporaryDirectory() as temp_dir: