From 2fcfcd19667f19dc4f345737af25d8e5d4f7de78 Mon Sep 17 00:00:00 2001 From: hofee <64160135+GitHofee@users.noreply.github.com> Date: Mon, 2 Sep 2024 18:21:38 +0800 Subject: [PATCH] add split dataset --- app_split.py | 9 +++++ configs/split_dataset_config.yaml | 22 +++++++++++++ configs/train_config.yaml | 12 ++++--- core/dataset.py | 22 ++++++++++--- runners/data_splitor.py | 55 +++++++++++++++++++++++++++++++ 5 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 app_split.py create mode 100644 configs/split_dataset_config.yaml create mode 100644 runners/data_splitor.py diff --git a/app_split.py b/app_split.py new file mode 100644 index 0000000..4ad462f --- /dev/null +++ b/app_split.py @@ -0,0 +1,9 @@ +from PytorchBoot.application import PytorchBootApplication +from runners.data_splitor import DataSplitor + +@PytorchBootApplication("split") +class DataSplitApp: + @staticmethod + def start(): + DataSplitor(r"configs\split_dataset_config.yaml").run() + \ No newline at end of file diff --git a/configs/split_dataset_config.yaml b/configs/split_dataset_config.yaml new file mode 100644 index 0000000..f2f2805 --- /dev/null +++ b/configs/split_dataset_config.yaml @@ -0,0 +1,22 @@ + +runner: + general: + seed: 0 + device: cpu + cuda_visible_devices: "0,1,2,3,4,5,6,7" + + experiment: + name: debug + root_dir: "experiments" + + split: + root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample" + type: "unseen_instance" # "unseen_category" + datasets: + OmniObject3d_train: + path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt" + ratio: 0.5 + + OmniObject3d_test: + path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt" + ratio: 0.5 \ No newline at end of file diff --git a/configs/train_config.yaml b/configs/train_config.yaml index 9766f91..4def879 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -11,10 +11,14 @@ runner: train: dataset_list: - - OmniObject3d + - OmniObject3d_train datasets: - OmniObject3d: - root_dir: "/media/hofee/data/data/nbv_rec/sample" + OmniObject3d_train: + root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample" + split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt" + + OmniObject3d_test: + root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample" + split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt" - diff --git a/core/dataset.py b/core/dataset.py index 4dcc5de..1337452 100644 --- a/core/dataset.py +++ b/core/dataset.py @@ -1,10 +1,9 @@ -import os import numpy as np from PytorchBoot.dataset import BaseDataset import PytorchBoot.stereotype as stereotype import sys -sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction") +sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction") from utils.data_load import DataLoadUtil from utils.pose import PoseUtil @@ -16,13 +15,22 @@ class NBVReconstructionDataset(BaseDataset): super(NBVReconstructionDataset, self).__init__(config) self.config = config self.root_dir = config["root_dir"] + self.split_file_path = config["split_file"] + self.scene_name_list = self.load_scene_name_list() self.datalist = self.get_datalist() self.pts_num = 1024 + def load_scene_name_list(self): + scene_name_list = [] + with open(self.split_file_path, "r") as f: + for line in f: + scene_name = line.strip() + scene_name_list.append(scene_name) + return scene_name_list + def get_datalist(self): datalist = [] - scene_name_list = os.listdir(self.root_dir) - for scene_name in scene_name_list: + for scene_name in self.scene_name_list: label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name) label_data = DataLoadUtil.load_label(label_path) for data_pair in label_data["data_pairs"]: @@ -97,8 +105,12 @@ class NBVReconstructionDataset(BaseDataset): if __name__ == "__main__": import torch + seed = 0 + torch.manual_seed(seed) + np.random.seed(seed) config = { - "root_dir": "/media/hofee/data/data/nbv_rec/sample", + "root_dir": "C:\\Document\\Local Project\\nbv_rec\\data\\sample", + "split_file": "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt", "ratio": 0.05, "batch_size": 1, "num_workers": 0, diff --git a/runners/data_splitor.py b/runners/data_splitor.py new file mode 100644 index 0000000..2a8bb87 --- /dev/null +++ b/runners/data_splitor.py @@ -0,0 +1,55 @@ +import os +import random +from PytorchBoot.runners.runner import Runner +from PytorchBoot.config import ConfigManager +from PytorchBoot.utils import Log +import PytorchBoot.stereotype as stereotype + + +@stereotype.runner("data_splitor", comment="unfinished") +class DataSplitor(Runner): + def __init__(self, config): + super().__init__(config) + self.load_experiment("data_split") + self.root_dir = ConfigManager.get("runner", "split", "root_dir") + self.type = ConfigManager.get("runner", "split", "type") + self.datasets = ConfigManager.get("runner", "split", "datasets") + self.datapath_list = self.load_all_datapath() + + def run(self): + self.split_dataset() + + def split_dataset(self): + + random.shuffle(self.datapath_list) + start_idx = 0 + for dataset in self.datasets: + ratio = self.datasets[dataset]["ratio"] + path = self.datasets[dataset]["path"] + split_size = int(len(self.datapath_list) * ratio) + split_files = self.datapath_list[start_idx:start_idx + split_size] + start_idx += split_size + self.save_split_files(path, split_files) + Log.success(f"save {dataset} split files to {path}") + + def save_split_files(self, path, split_files): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as f: + f.write("\n".join(split_files)) + + + def load_all_datapath(self): + return os.listdir(self.root_dir) + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + + + + + + + \ No newline at end of file