new_nbv_rec/runners/data_spliter.py

import os
import random
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager

@stereotype.runner("data_spliter")
class DataSpliter(Runner):
    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("data_split")
        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
        self.type = ConfigManager.get("runner", "split", "type")
        self.datasets = ConfigManager.get("runner", "split", "datasets")
        self.datapath_list = self.load_all_datapath()

    def run(self):
        self.split_dataset()

    def split_dataset(self):
        # Shuffle once, then carve consecutive slices off the shuffled list
        # according to each dataset's configured ratio.
        random.shuffle(self.datapath_list)
        start_idx = 0
        for dataset_idx, (dataset, dataset_config) in enumerate(self.datasets.items()):
            ratio = dataset_config["ratio"]
            path = dataset_config["path"]
            split_size = int(len(self.datapath_list) * ratio)
            split_files = self.datapath_list[start_idx:start_idx + split_size]
            start_idx += split_size
            self.save_split_files(path, split_files)
            status_manager.set_progress("split", "data_splitor", "split dataset", dataset_idx, len(self.datasets))
            Log.success(f"save {dataset} split files to {path}")
        # Mark the split stage as finished.
        status_manager.set_progress("split", "data_splitor", "split dataset", len(self.datasets), len(self.datasets))

    def save_split_files(self, path, split_files):
        # Write one sample name per line; assumes path contains a directory component.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write("\n".join(split_files))

    def load_all_datapath(self):
        # Each entry under root_dir is treated as one sample.
        return os.listdir(self.root_dir)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
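
# Example of the config structure this runner expects (a minimal sketch,
# inferred from the ConfigManager.get("runner", "split", ...) calls and the
# per-dataset "ratio"/"path" lookups above; dataset names, values, and paths
# are hypothetical, and the real project config may live in a YAML file):
#
# example_split_config = {
#     "runner": {
#         "split": {
#             "root_dir": "/path/to/samples",   # each os.listdir() entry is one sample
#             "type": "scene",                  # read in __init__ but not used elsewhere in this class
#             "datasets": {
#                 "train": {"ratio": 0.8, "path": "/path/to/train_split.txt"},
#                 "test":  {"ratio": 0.2, "path": "/path/to/test_split.txt"},
#             },
#         },
#     },
# }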