2024-09-13 16:58:34 +08:00

62 lines
2.5 KiB
Python

import os
import time
from abc import abstractmethod, ABC
import numpy as np
import torch
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils.log_util import Log
class Runner(ABC):
    """Abstract base class for experiment runners.

    The constructor loads the configuration through ``ConfigManager``, reads
    the ``"runner"`` section, seeds NumPy and PyTorch, and derives the
    experiment directory plus a timestamp-based file name.  Subclasses must
    override ``__init__``, ``run``, ``load_experiment`` and
    ``create_experiment``; the implementations here are meant to be reused
    via ``super()``.
    """

    @abstractmethod
    def __init__(self, config_path):
        """Load the configuration and initialise the shared runner state.

        Args:
            config_path: Forwarded unchanged to
                ``ConfigManager.load_config_with`` (presumably the path of
                the configuration file -- confirm against ConfigManager).
        """
        ConfigManager.load_config_with(config_path)
        ConfigManager.print_config()
        self.config = ConfigManager.get("runner")

        general_config = self.config["general"]
        self.seed = general_config["seed"]
        self.device = general_config["device"]
        self.cuda_visible_devices = general_config["cuda_visible_devices"]
        # os.environ only accepts strings; a config value such as a bare
        # integer (e.g. ``0``) would otherwise raise TypeError here.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.cuda_visible_devices)

        self.experiments_config = self.config["experiment"]
        self.experiment_path = os.path.join(
            self.experiments_config["root_dir"],
            self.experiments_config["name"],
        )

        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Timestamp used to name config backups; fields are intentionally
        # left un-padded to keep the existing file-name format.
        lt = time.localtime()
        self.file_name = (
            f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_"
            f"{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
        )

    @abstractmethod
    def run(self):
        """Execute the runner; subclasses provide the implementation."""
        pass

    @abstractmethod
    def load_experiment(self, backup_name=None):
        """Open the experiment directory, creating it when it is missing.

        If ``self.experiment_path`` does not exist, ``create_experiment`` is
        called.  Otherwise the ``configs`` sub-directory is ensured and the
        current configuration is backed up into it.

        Args:
            backup_name: Forwarded unchanged to ``create_experiment`` or
                ``ConfigManager.backup_config_to``.
        """
        experiment_name = self.experiments_config["name"]
        if not os.path.exists(self.experiment_path):
            Log.info(f"experiments environment {experiment_name} does not exists.")
            self.create_experiment(backup_name)
        else:
            Log.info(f"experiments environment {experiment_name}")
            backup_config_dir = os.path.join(self.experiment_path, "configs")
            # exist_ok avoids a check-then-create race when several
            # processes open the same experiment concurrently.
            os.makedirs(backup_config_dir, exist_ok=True)
            ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)

    @abstractmethod
    def create_experiment(self, backup_name=None):
        """Create the experiment directory layout and back up the config.

        Creates ``self.experiment_path`` with the sub-directories
        ``configs``, ``log`` and ``cache``.  The configuration backup is
        written right after ``configs`` is created.

        Args:
            backup_name: Forwarded unchanged to
                ``ConfigManager.backup_config_to``.
        """
        Log.info("creating experiment: " + self.experiments_config["name"])
        # exist_ok keeps this safe if another process (or an earlier,
        # partially completed run) already created some of the directories.
        os.makedirs(self.experiment_path, exist_ok=True)

        backup_config_dir = os.path.join(self.experiment_path, "configs")
        os.makedirs(backup_config_dir, exist_ok=True)
        ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)

        for sub_dir in ("log", "cache"):
            os.makedirs(os.path.join(self.experiment_path, sub_dir), exist_ok=True)

    def print_info(self):
        """Log a small banner containing the experiment name."""
        table_size = 80
        border = "+" + "-" * table_size + "+"
        Log.blue(border)
        Log.blue(f"| Experiment <{self.experiments_config['name']}>")
        Log.blue(border)