add split dataset

This commit is contained in:
hofee 2024-09-02 18:21:38 +08:00
parent f58360c0c0
commit 2fcfcd1966
5 changed files with 111 additions and 9 deletions

9
app_split.py Normal file
View File

@ -0,0 +1,9 @@
from PytorchBoot.application import PytorchBootApplication
from runners.data_splitor import DataSplitor
@PytorchBootApplication("split")
class DataSplitApp:
    """Application entry point that launches the dataset-splitting runner."""

    @staticmethod
    def start():
        # Build the splitter from the split config and execute it right away.
        config_path = r"configs\split_dataset_config.yaml"
        splitter = DataSplitor(config_path)
        splitter.run()

View File

@ -0,0 +1,22 @@
runner:
general:
seed: 0
device: cpu
cuda_visible_devices: "0,1,2,3,4,5,6,7"
experiment:
name: debug
root_dir: "experiments"
split:
root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
type: "unseen_instance" # options: "unseen_instance" or "unseen_category"
datasets:
OmniObject3d_train:
path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
ratio: 0.5
OmniObject3d_test:
path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
ratio: 0.5

View File

@ -11,10 +11,14 @@ runner:
train: train:
dataset_list: dataset_list:
- OmniObject3d - OmniObject3d_train
datasets: datasets:
OmniObject3d: OmniObject3d_train:
root_dir: "/media/hofee/data/data/nbv_rec/sample" root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
OmniObject3d_test:
root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"

View File

@ -1,10 +1,9 @@
import os
import numpy as np import numpy as np
from PytorchBoot.dataset import BaseDataset from PytorchBoot.dataset import BaseDataset
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
import sys import sys
sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction") sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
from utils.data_load import DataLoadUtil from utils.data_load import DataLoadUtil
from utils.pose import PoseUtil from utils.pose import PoseUtil
@ -16,13 +15,22 @@ class NBVReconstructionDataset(BaseDataset):
super(NBVReconstructionDataset, self).__init__(config) super(NBVReconstructionDataset, self).__init__(config)
self.config = config self.config = config
self.root_dir = config["root_dir"] self.root_dir = config["root_dir"]
self.split_file_path = config["split_file"]
self.scene_name_list = self.load_scene_name_list()
self.datalist = self.get_datalist() self.datalist = self.get_datalist()
self.pts_num = 1024 self.pts_num = 1024
def load_scene_name_list(self):
scene_name_list = []
with open(self.split_file_path, "r") as f:
for line in f:
scene_name = line.strip()
scene_name_list.append(scene_name)
return scene_name_list
def get_datalist(self): def get_datalist(self):
datalist = [] datalist = []
scene_name_list = os.listdir(self.root_dir) for scene_name in self.scene_name_list:
for scene_name in scene_name_list:
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name) label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
label_data = DataLoadUtil.load_label(label_path) label_data = DataLoadUtil.load_label(label_path)
for data_pair in label_data["data_pairs"]: for data_pair in label_data["data_pairs"]:
@ -97,8 +105,12 @@ class NBVReconstructionDataset(BaseDataset):
if __name__ == "__main__": if __name__ == "__main__":
import torch import torch
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
config = { config = {
"root_dir": "/media/hofee/data/data/nbv_rec/sample", "root_dir": "C:\\Document\\Local Project\\nbv_rec\\data\\sample",
"split_file": "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt",
"ratio": 0.05, "ratio": 0.05,
"batch_size": 1, "batch_size": 1,
"num_workers": 0, "num_workers": 0,

55
runners/data_splitor.py Normal file
View File

@ -0,0 +1,55 @@
import os
import random
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype
@stereotype.runner("data_splitor", comment="unfinished")
class DataSplitor(Runner):
    """Split the scene entries under a root directory into named dataset
    splits (e.g. train/test) according to configured ratios, writing each
    split's scene names to a text file, one name per line.

    Config is read from ``runner.split``: ``root_dir`` (directory whose
    entries are the scenes), ``type`` (split strategy name), and
    ``datasets`` (mapping of split name -> {``ratio``, ``path``}).
    """

    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("data_split")
        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
        # NOTE(review): `type` ("unseen_instance"/"unseen_category") is stored
        # but never consulted by split_dataset — presumably the strategy still
        # to be implemented (runner is registered with comment="unfinished").
        self.type = ConfigManager.get("runner", "split", "type")
        self.datasets = ConfigManager.get("runner", "split", "datasets")
        self.datapath_list = self.load_all_datapath()

    def run(self):
        """Runner entry point: perform the split."""
        self.split_dataset()

    def split_dataset(self):
        """Shuffle the scene list and carve consecutive slices per dataset.

        Each dataset receives ``int(len * ratio)`` entries; because of
        integer truncation, a few trailing entries may remain unassigned
        even when the ratios sum to 1.
        """
        random.shuffle(self.datapath_list)
        start_idx = 0
        for dataset in self.datasets:
            ratio = self.datasets[dataset]["ratio"]
            path = self.datasets[dataset]["path"]
            split_size = int(len(self.datapath_list) * ratio)
            split_files = self.datapath_list[start_idx:start_idx + split_size]
            start_idx += split_size
            self.save_split_files(path, split_files)
            Log.success(f"save {dataset} split files to {path}")

    def save_split_files(self, path, split_files):
        """Write one scene name per line to ``path``, creating parent dirs.

        Fix: ``os.path.dirname`` returns ``""`` for a bare filename, and
        ``os.makedirs("")`` raises FileNotFoundError — only create the
        directory when a directory component is actually present.
        """
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path, "w") as f:
            f.write("\n".join(split_files))

    def load_all_datapath(self):
        """Return the names of all entries directly under ``root_dir``."""
        return os.listdir(self.root_dir)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)