add split dataset
This commit is contained in:
parent
f58360c0c0
commit
2fcfcd1966
9
app_split.py
Normal file
9
app_split.py
Normal file
@ -0,0 +1,9 @@
|
||||
from PytorchBoot.application import PytorchBootApplication
|
||||
from runners.data_splitor import DataSplitor
|
||||
|
||||
@PytorchBootApplication("split")
|
||||
class DataSplitApp:
|
||||
@staticmethod
|
||||
def start():
|
||||
DataSplitor(r"configs\split_dataset_config.yaml").run()
|
||||
|
22
configs/split_dataset_config.yaml
Normal file
22
configs/split_dataset_config.yaml
Normal file
@ -0,0 +1,22 @@
|
||||
|
||||
runner:
|
||||
general:
|
||||
seed: 0
|
||||
device: cpu
|
||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
||||
|
||||
experiment:
|
||||
name: debug
|
||||
root_dir: "experiments"
|
||||
|
||||
split:
|
||||
root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
|
||||
type: "unseen_instance" # "unseen_category"
|
||||
datasets:
|
||||
OmniObject3d_train:
|
||||
path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
|
||||
ratio: 0.5
|
||||
|
||||
OmniObject3d_test:
|
||||
path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
|
||||
ratio: 0.5
|
@ -11,10 +11,14 @@ runner:
|
||||
|
||||
train:
|
||||
dataset_list:
|
||||
- OmniObject3d
|
||||
- OmniObject3d_train
|
||||
|
||||
datasets:
|
||||
OmniObject3d:
|
||||
root_dir: "/media/hofee/data/data/nbv_rec/sample"
|
||||
OmniObject3d_train:
|
||||
root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
|
||||
split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
|
||||
|
||||
OmniObject3d_test:
|
||||
root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
|
||||
split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
|
||||
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from PytorchBoot.dataset import BaseDataset
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
import sys
|
||||
sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction")
|
||||
sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
|
||||
|
||||
from utils.data_load import DataLoadUtil
|
||||
from utils.pose import PoseUtil
|
||||
@ -16,13 +15,22 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
super(NBVReconstructionDataset, self).__init__(config)
|
||||
self.config = config
|
||||
self.root_dir = config["root_dir"]
|
||||
self.split_file_path = config["split_file"]
|
||||
self.scene_name_list = self.load_scene_name_list()
|
||||
self.datalist = self.get_datalist()
|
||||
self.pts_num = 1024
|
||||
|
||||
def load_scene_name_list(self):
|
||||
scene_name_list = []
|
||||
with open(self.split_file_path, "r") as f:
|
||||
for line in f:
|
||||
scene_name = line.strip()
|
||||
scene_name_list.append(scene_name)
|
||||
return scene_name_list
|
||||
|
||||
def get_datalist(self):
|
||||
datalist = []
|
||||
scene_name_list = os.listdir(self.root_dir)
|
||||
for scene_name in scene_name_list:
|
||||
for scene_name in self.scene_name_list:
|
||||
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
|
||||
label_data = DataLoadUtil.load_label(label_path)
|
||||
for data_pair in label_data["data_pairs"]:
|
||||
@ -97,8 +105,12 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
seed = 0
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
config = {
|
||||
"root_dir": "/media/hofee/data/data/nbv_rec/sample",
|
||||
"root_dir": "C:\\Document\\Local Project\\nbv_rec\\data\\sample",
|
||||
"split_file": "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt",
|
||||
"ratio": 0.05,
|
||||
"batch_size": 1,
|
||||
"num_workers": 0,
|
||||
|
55
runners/data_splitor.py
Normal file
55
runners/data_splitor.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
import random
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils import Log
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
|
||||
@stereotype.runner("data_splitor", comment="unfinished")
|
||||
class DataSplitor(Runner):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.load_experiment("data_split")
|
||||
self.root_dir = ConfigManager.get("runner", "split", "root_dir")
|
||||
self.type = ConfigManager.get("runner", "split", "type")
|
||||
self.datasets = ConfigManager.get("runner", "split", "datasets")
|
||||
self.datapath_list = self.load_all_datapath()
|
||||
|
||||
def run(self):
|
||||
self.split_dataset()
|
||||
|
||||
def split_dataset(self):
|
||||
|
||||
random.shuffle(self.datapath_list)
|
||||
start_idx = 0
|
||||
for dataset in self.datasets:
|
||||
ratio = self.datasets[dataset]["ratio"]
|
||||
path = self.datasets[dataset]["path"]
|
||||
split_size = int(len(self.datapath_list) * ratio)
|
||||
split_files = self.datapath_list[start_idx:start_idx + split_size]
|
||||
start_idx += split_size
|
||||
self.save_split_files(path, split_files)
|
||||
Log.success(f"save {dataset} split files to {path}")
|
||||
|
||||
def save_split_files(self, path, split_files):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(split_files))
|
||||
|
||||
|
||||
def load_all_datapath(self):
|
||||
return os.listdir(self.root_dir)
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user