62 lines
2.5 KiB
Python
62 lines
2.5 KiB
Python
import os
|
|
import time
|
|
from abc import abstractmethod, ABC
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from PytorchBoot.config import ConfigManager
|
|
from PytorchBoot.utils.log_util import Log
|
|
|
|
class Runner(ABC):
    """Abstract base class for config-driven experiment runners.

    Every method decorated with ``@abstractmethod`` must be overridden by a
    subclass before it can be instantiated. The base implementations of
    ``__init__``, ``load_experiment`` and ``create_experiment`` are concrete,
    so overrides can reuse them through ``super()``.
    """

    @abstractmethod
    def __init__(self, config_path):
        """Load the configuration and prepare the runtime environment.

        Loads and prints the config at ``config_path`` and reads its
        ``runner`` section. Then it:
        - exports ``CUDA_VISIBLE_DEVICES``;
        - seeds NumPy and PyTorch;
        - builds a local-time timestamp (``self.file_name``) used to name
          config backups.

        Args:
            config_path: path handed to ``ConfigManager.load_config_with``.
        """
        ConfigManager.load_config_with(config_path)
        ConfigManager.print_config()
        self.config = ConfigManager.get("runner")

        general = self.config["general"]
        self.seed = general["seed"]
        self.device = general["device"]
        self.cuda_visible_devices = general["cuda_visible_devices"]
        # os.environ only accepts str values. A YAML scalar such as `0`, or a
        # list of ids, would otherwise raise TypeError here.
        cuda_env = self._format_cuda_devices(self.cuda_visible_devices)
        if cuda_env is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = cuda_env

        self.experiments_config = self.config["experiment"]
        self.experiment_path = os.path.join(
            self.experiments_config["root_dir"], self.experiments_config["name"]
        )

        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Timestamp fields are intentionally not zero-padded (existing format).
        lt = time.localtime()
        self.file_name = (
            f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_"
            f"{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
        )

    @staticmethod
    def _format_cuda_devices(devices):
        """Convert a config value into a ``CUDA_VISIBLE_DEVICES`` string.

        Strings pass through unchanged and other scalars are passed through
        ``str``. Lists and tuples of device ids are comma-joined, so
        ``[0, 1]`` becomes ``"0,1"``. ``None`` is returned as ``None`` so the
        caller can leave the environment variable untouched.
        """
        if devices is None:
            return None
        if isinstance(devices, (list, tuple)):
            return ",".join(str(device) for device in devices)
        return str(devices)

    def _subdir(self, name):
        """Return the path of sub-directory ``name`` inside the experiment directory."""
        return os.path.join(str(self.experiment_path), name)

    @abstractmethod
    def run(self):
        """Execute the runner's main work; subclasses must implement this."""
        pass

    @abstractmethod
    def load_experiment(self, backup_name=None):
        """Ensure the experiment directory exists and back up the config.

        If the experiment directory is missing, it is built via
        ``create_experiment``. Otherwise the directory's ``configs``
        sub-directory is created if needed and a config backup is written
        into it.

        Args:
            backup_name: forwarded to ``ConfigManager.backup_config_to``.
        """
        if not os.path.exists(self.experiment_path):
            Log.info(f"experiments environment {self.experiments_config['name']} does not exist.")
            self.create_experiment(backup_name)
        else:
            Log.info(f"experiments environment {self.experiments_config['name']}")
            backup_config_dir = self._subdir("configs")
            # exist_ok avoids the race of a separate exists() check followed
            # by makedirs().
            os.makedirs(backup_config_dir, exist_ok=True)
            ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)

    @abstractmethod
    def create_experiment(self, backup_name=None):
        """Create a fresh experiment directory tree and back up the config.

        Creates the experiment directory with ``configs``, ``log`` and
        ``cache`` sub-directories, and writes a config backup into
        ``configs``. ``os.makedirs`` raises ``FileExistsError`` if any of
        these directories already exists.

        Args:
            backup_name: forwarded to ``ConfigManager.backup_config_to``.
        """
        Log.info("creating experiment: " + self.experiments_config["name"])
        os.makedirs(self.experiment_path)

        backup_config_dir = self._subdir("configs")
        os.makedirs(backup_config_dir)
        ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)

        os.makedirs(self._subdir("log"))
        os.makedirs(self._subdir("cache"))

    def print_info(self):
        """Print a bordered banner with the experiment name via ``Log.blue``."""
        table_size = 80
        border = "+" + "-" * table_size + "+"
        Log.blue(border)
        Log.blue(f"| Experiment <{self.experiments_config['name']}>")
        Log.blue(border)
|