import os
import random

from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager


@stereotype.runner("data_spliter")
class DataSpliter(Runner):
    """Shuffle the entries under ``root_dir`` and write ratio-based split
    files (one listed entry per line) for each configured dataset."""

    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("data_split")
        # All settings come from the runner.split section of the config.
        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
        self.type = ConfigManager.get("runner", "split", "type")
        # Mapping: dataset name -> {"ratio": float, "path": str}.
        self.datasets = ConfigManager.get("runner", "split", "datasets")
        self.datapath_list = self.load_all_datapath()

    def run(self):
        """Entry point invoked by the PytorchBoot runner framework."""
        self.split_dataset()

    def split_dataset(self):
        """Shuffle all data paths, then carve consecutive slices out of the
        shuffled list — one slice per configured dataset, sized by its ratio.

        NOTE(review): ``int()`` truncation means a few trailing entries can be
        left unassigned when the ratios do not divide the list evenly —
        confirm this is acceptable, or assign the remainder to the last split.
        """
        random.shuffle(self.datapath_list)
        start_idx = 0
        total = len(self.datasets)  # hoisted loop invariant
        # Iterate items directly instead of indexing a rebuilt key list.
        for dataset_idx, (dataset, spec) in enumerate(self.datasets.items()):
            ratio = spec["ratio"]
            path = spec["path"]
            split_size = int(len(self.datapath_list) * ratio)
            split_files = self.datapath_list[start_idx:start_idx + split_size]
            start_idx += split_size
            self.save_split_files(path, split_files)
            # NOTE(review): progress key "data_splitor" differs from the
            # registered runner id "data_spliter" — kept as-is since other
            # components may watch this exact key; verify intent.
            status_manager.set_progress("split", "data_splitor", "split dataset", dataset_idx, total)
            Log.success(f"save {dataset} split files to {path}")
        status_manager.set_progress("split", "data_splitor", "split dataset", total, total)

    def save_split_files(self, path, split_files):
        """Write *split_files* to *path*, one entry per line, creating the
        parent directory if needed."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write("\n".join(split_files))

    def load_all_datapath(self):
        """Return the names of all entries directly under ``root_dir``."""
        return os.listdir(self.root_dir)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)