add split dataset

hofee 2024-09-02 18:21:38 +08:00
parent f58360c0c0
commit 2fcfcd1966
5 changed files with 111 additions and 9 deletions

9
app_split.py Normal file
View File

@@ -0,0 +1,9 @@
from PytorchBoot.application import PytorchBootApplication
from runners.data_splitor import DataSplitor


@PytorchBootApplication("split")
class DataSplitApp:
    @staticmethod
    def start():
        DataSplitor(r"configs\split_dataset_config.yaml").run()
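A quick way to exercise this entry point without the PytorchBoot launcher (whose CLI is not shown in this commit) is to call the registered app's static method directly. A minimal sketch, assuming the working directory is the project root so the relative config path in app_split.py resolves:

# Sketch: run the split app without the framework CLI.
from app_split import DataSplitApp

DataSplitApp.start()  # constructs DataSplitor(...) and calls run()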

22
configs/split_dataset_config.yaml Normal file
View File

@@ -0,0 +1,22 @@
runner:
  general:
    seed: 0
    device: cpu
    cuda_visible_devices: "0,1,2,3,4,5,6,7"

  experiment:
    name: debug
    root_dir: "experiments"

  split:
    root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
    type: "unseen_instance" # "unseen_category"
    datasets:
      OmniObject3d_train:
        path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
        ratio: 0.5
      OmniObject3d_test:
        path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
        ratio: 0.5
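A worked sketch of what the two 0.5 ratios do: they carve a shuffled scene list into consecutive slices, mirroring the slicing logic in runners/data_splitor.py below (the scene names here are made up):

import random

scenes = [f"scene_{i:03d}" for i in range(10)]  # hypothetical scene folders
random.shuffle(scenes)

start = 0
for name, ratio in {"OmniObject3d_train": 0.5, "OmniObject3d_test": 0.5}.items():
    size = int(len(scenes) * ratio)  # 0.5 * 10 -> 5 scenes per split
    print(name, scenes[start:start + size])
    start += size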

View File

@@ -11,10 +11,14 @@ runner:
 train:
   dataset_list:
-    - OmniObject3d
+    - OmniObject3d_train
 datasets:
-  OmniObject3d:
-    root_dir: "/media/hofee/data/data/nbv_rec/sample"
+  OmniObject3d_train:
+    root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
+    split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
+  OmniObject3d_test:
+    root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
+    split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"

View File

@@ -1,10 +1,9 @@
 import os
 import numpy as np
 from PytorchBoot.dataset import BaseDataset
 import PytorchBoot.stereotype as stereotype
 import sys
-sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction")
+sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
 from utils.data_load import DataLoadUtil
 from utils.pose import PoseUtil
@@ -16,13 +15,22 @@ class NBVReconstructionDataset(BaseDataset):
         super(NBVReconstructionDataset, self).__init__(config)
         self.config = config
         self.root_dir = config["root_dir"]
+        self.split_file_path = config["split_file"]
+        self.scene_name_list = self.load_scene_name_list()
         self.datalist = self.get_datalist()
         self.pts_num = 1024

+    def load_scene_name_list(self):
+        scene_name_list = []
+        with open(self.split_file_path, "r") as f:
+            for line in f:
+                scene_name = line.strip()
+                scene_name_list.append(scene_name)
+        return scene_name_list
+
     def get_datalist(self):
         datalist = []
-        scene_name_list = os.listdir(self.root_dir)
-        for scene_name in scene_name_list:
+        for scene_name in self.scene_name_list:
             label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
             label_data = DataLoadUtil.load_label(label_path)
             for data_pair in label_data["data_pairs"]:
@@ -97,8 +105,12 @@ class NBVReconstructionDataset(BaseDataset):
 if __name__ == "__main__":
     import torch
     seed = 0
     torch.manual_seed(seed)
     np.random.seed(seed)
     config = {
-        "root_dir": "/media/hofee/data/data/nbv_rec/sample",
+        "root_dir": "C:\\Document\\Local Project\\nbv_rec\\data\\sample",
+        "split_file": "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt",
         "ratio": 0.05,
         "batch_size": 1,
         "num_workers": 0,

55
runners/data_splitor.py Normal file
View File

@@ -0,0 +1,55 @@
import os
import random

from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype


@stereotype.runner("data_splitor", comment="unfinished")
class DataSplitor(Runner):
    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("data_split")
        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
        self.type = ConfigManager.get("runner", "split", "type")
        self.datasets = ConfigManager.get("runner", "split", "datasets")
        self.datapath_list = self.load_all_datapath()

    def run(self):
        self.split_dataset()

    def split_dataset(self):
        random.shuffle(self.datapath_list)
        start_idx = 0
        for dataset in self.datasets:
            ratio = self.datasets[dataset]["ratio"]
            path = self.datasets[dataset]["path"]
            split_size = int(len(self.datapath_list) * ratio)
            split_files = self.datapath_list[start_idx:start_idx + split_size]
            start_idx += split_size
            self.save_split_files(path, split_files)
            Log.success(f"saved {dataset} split files to {path}")

    def save_split_files(self, path, split_files):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write("\n".join(split_files))

    def load_all_datapath(self):
        return os.listdir(self.root_dir)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
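One behavior worth noting in split_dataset: int(len(...) * ratio) truncates, so with ratios summing to 1.0 and an odd number of scenes, the trailing scene is assigned to no split file. A minimal demonstration:

scenes = list(range(11))            # 11 scenes, ratios 0.5 / 0.5
train = int(len(scenes) * 0.5)      # 5
test = int(len(scenes) * 0.5)       # 5
print(len(scenes) - train - test)   # 1 scene is never written out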