From 954fed1122a7e2ea3d6f55e9d4bc2f5c85281006 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:28:20 +0800 Subject: [PATCH] finish nbv_reconstruction_dataset --- configs/generate_config.yaml | 14 +++--- configs/train_config.yaml | 21 ++++++++ core/dataset.py | 90 +++++++++++++++++++++++++++++++---- runners/strategy_generator.py | 4 +- utils/data_load.py | 6 +++ 5 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 configs/train_config.yaml diff --git a/configs/generate_config.yaml b/configs/generate_config.yaml index 0ff838b..148efe5 100644 --- a/configs/generate_config.yaml +++ b/configs/generate_config.yaml @@ -4,17 +4,17 @@ runner: seed: 0 device: cpu cuda_visible_devices: "0,1,2,3,4,5,6,7" - - generate: - voxel_threshold: 0.005 - overlap_threshold: 0.3 + experiment: name: debug root_dir: "experiments" - - dataset_list: - - OmniObject3d + + generate: + voxel_threshold: 0.005 + overlap_threshold: 0.3 + dataset_list: + - OmniObject3d datasets: OmniObject3d: diff --git a/configs/train_config.yaml b/configs/train_config.yaml new file mode 100644 index 0000000..450f477 --- /dev/null +++ b/configs/train_config.yaml @@ -0,0 +1,21 @@ + +runner: + general: + seed: 0 + device: cpu + cuda_visible_devices: "0,1,2,3,4,5,6,7" + + experiment: + name: debug + root_dir: "experiments" + + train: + dataset_list: + - OmniObject3d + +datasets: + OmniObject3d: + root_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_dataset" + label_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_output" + + diff --git a/core/dataset.py b/core/dataset.py index 7b377d8..c4dd9a3 100644 --- a/core/dataset.py +++ b/core/dataset.py @@ -1,17 +1,89 @@ +import numpy as np from PytorchBoot.dataset import BaseDataset import PytorchBoot.stereotype as stereotype +from utils.data_load import DataLoadUtil -@stereotype.dataset("nbv_reconstruction_dataset", comment="unfinished") + +@stereotype.dataset("nbv_reconstruction_dataset") class NBVReconstructionDataset(BaseDataset): def __init__(self, config): super(NBVReconstructionDataset, self).__init__(config) self.config = config - + self.label_dir = config["label_dir"] + self.root_dir = config["root_dir"] + self.datalist = self.get_datalist() + def get_datalist(self): - pass - - def load_view(path): - pass - - def load_data_item(self, idx): - pass \ No newline at end of file + datalist = [] + scene_idx_list = DataLoadUtil.get_scene_idx_list(self.root_dir) + for scene_idx in scene_idx_list: + label_path = DataLoadUtil.get_label_path(self.label_dir, scene_idx) + label_data = DataLoadUtil.load_label(label_path) + for data_pair in label_data["data_pairs"]: + scanned_views = data_pair[0] + next_best_view = data_pair[1] + max_coverage_rate = label_data["max_coverage_rate"] + datalist.append( + { + "scanned_views": scanned_views, + "next_best_view": next_best_view, + "max_coverage_rate": max_coverage_rate, + "scene_idx": scene_idx, + } + ) + return datalist + + def __getitem__(self, index): + data_item_info = self.datalist[index] + scanned_views = data_item_info["scanned_views"] + nbv = data_item_info["next_best_view"] + max_coverage_rate = data_item_info["max_coverage_rate"] + scene_idx = data_item_info["scene_idx"] + scanned_views_pts, scanned_coverages_rate, scanned_cam_pose = [], [], [] + for view in scanned_views: + frame_idx = view[0] + coverage_rate = view[1] + view_path = DataLoadUtil.get_path(self.root_dir, scene_idx, frame_idx) + pts = DataLoadUtil.load_depth(view_path) + scanned_views_pts.append(pts) + scanned_coverages_rate.append(coverage_rate) + cam_pose = DataLoadUtil.load_cam_info(view_path)["cam_to_world"] + scanned_cam_pose.append(cam_pose) + + nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] + nbv_path = DataLoadUtil.get_path(self.root_dir, scene_idx, nbv_idx) + nbv_pts = DataLoadUtil.load_depth(nbv_path) + cam_info = DataLoadUtil.load_cam_info(nbv_path) + nbv_cam_pose = cam_info["cam_to_world"] + + data_item = { + "scanned_views_pts": np.asarray(scanned_views_pts,dtype=np.float32), + "scanned_coverages_rate": np.asarray(scanned_coverages_rate,dtype=np.float32), + "scanned_cam_pose": np.asarray(scanned_cam_pose,dtype=np.float32), + "nbv_pts": np.asarray(nbv_pts,dtype=np.float32), + "nbv_coverage_rate": nbv_coverage_rate, + "nbv_cam_pose": nbv_cam_pose, + "max_coverage_rate": max_coverage_rate, + } + + return data_item + + def __len__(self): + return len(self.datalist) + +if __name__ == "__main__": + import torch + config = { + "root_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_dataset", + "label_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_output", + "ratio": 0.1, + "batch_size": 1, + "num_workers": 0, + } + ds = NBVReconstructionDataset(config) + dl = ds.get_loader(shuffle=True) + for idx, data in enumerate(dl): + for key, value in data.items(): + if isinstance(value, torch.Tensor): + print(key, ":" ,value.shape) + print() \ No newline at end of file diff --git a/runners/strategy_generator.py b/runners/strategy_generator.py index f0b05b2..06fca5c 100644 --- a/runners/strategy_generator.py +++ b/runners/strategy_generator.py @@ -7,14 +7,14 @@ import PytorchBoot.stereotype as stereotype from utils.data_load import DataLoadUtil from utils.reconstruction import ReconstructionUtil -@stereotype.runner("strategy_generator", comment="unfinished") +@stereotype.runner("strategy_generator") class StrategyGenerator(Runner): def __init__(self, config): super().__init__(config) self.load_experiment("generate") def run(self): - dataset_name_list = ConfigManager.get("runner", "dataset_list") + dataset_name_list = ConfigManager.get("runner", "generate" "dataset_list") voxel_threshold, overlap_threshold = ConfigManager.get("runner","generate","voxel_threshold"), ConfigManager.get("runner","generate","overlap_threshold") for dataset_name in dataset_name_list: root_dir = ConfigManager.get("datasets", dataset_name, "root_dir") diff --git a/utils/data_load.py b/utils/data_load.py index f487e9c..ee5cd64 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -66,6 +66,12 @@ class DataLoadUtil: depth_map = DataLoadUtil.read_exr_depth(depth_path) return depth_map + @staticmethod + def load_label(path): + with open(path, 'r') as f: + label_data = json.load(f) + return label_data + @staticmethod def load_rgb(path): rgb_path = path + ".camera.png"