import os
import random

from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager
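
# Configuration consumed by this runner (illustrative sketch). This assumes that
# ConfigManager.get("runner", "split", ...) walks nested keys of a YAML-style config;
# only the keys read in __init__ below are required, and the dataset names and all
# values shown here are made-up placeholders.
#
#   runner:
#     split:
#       root_dir: "/path/to/samples"      # directory whose entries are shuffled and split
#       type: "scene"                     # stored on self.type; not used elsewhere in this file
#       datasets:
#         train:
#           ratio: 0.9                    # fraction of all entries assigned to this split
#           path: "/path/to/train_index.txt"
#         test:
#           ratio: 0.1
#           path: "/path/to/test_index.txt"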
@stereotype.runner("data_spliter")
class DataSpliter(Runner):
    """Shuffles the entries under root_dir and writes them into per-dataset split index files."""

    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("data_split")
        # Split settings: source directory, sample type, and the per-dataset ratio/path table.
        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
        self.type = ConfigManager.get("runner", "split", "type")
        self.datasets = ConfigManager.get("runner", "split", "datasets")
        self.datapath_list = self.load_all_datapath()

    def run(self):
        self.split_dataset()

    def split_dataset(self):
        # Shuffle once, then carve consecutive windows out of the shuffled list,
        # one window per dataset, sized by that dataset's configured ratio.
        random.shuffle(self.datapath_list)
        start_idx = 0
        for dataset_idx, dataset in enumerate(self.datasets):
            ratio = self.datasets[dataset]["ratio"]
            path = self.datasets[dataset]["path"]
            split_size = int(len(self.datapath_list) * ratio)
            split_files = self.datapath_list[start_idx:start_idx + split_size]
            start_idx += split_size
            self.save_split_files(path, split_files)
            status_manager.set_progress("split", "data_splitor", "split dataset", dataset_idx, len(self.datasets))
            Log.success(f"saved {dataset} split files to {path}")
        status_manager.set_progress("split", "data_splitor", "split dataset", len(self.datasets), len(self.datasets))
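
    # Worked example (illustrative): with 100 entries under root_dir and two datasets with
    # ratios 0.9 and 0.1, split_size is 90 and then 10, so the shuffled list is carved into
    # the windows [0:90] and [90:100]. Because each split_size is truncated with int(), a
    # few trailing entries can be left out of every split when the ratios do not divide
    # the list length evenly.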

    def save_split_files(self, path, split_files):
        # Write one entry name per line; create the parent directory first if the path has one.
        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(path, "w") as f:
            f.write("\n".join(split_files))
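
    # Reading a split file back (illustrative sketch, not part of this runner): each file
    # written above holds one entry name per line, so a downstream loader could recover
    # the list with plain Python, for example:
    #
    #   with open("/path/to/train_index.txt") as f:
    #       split_files = [line.strip() for line in f if line.strip()]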

    def load_all_datapath(self):
        # Each entry directly under root_dir counts as one sample to be assigned to a split.
        return os.listdir(self.root_dir)

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)