import os
import json
from datetime import datetime

import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.factory import OptimizerFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils.tensorboard_util import TensorboardWriter
from PytorchBoot.stereotype import EXTERNAL_FRONZEN_MODULES
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager


@stereotype.runner("default_trainer")
class DefaultTrainer(Runner):
    """Epoch-based trainer registered as the ``default_trainer`` runner.

    Builds the pipeline, the training dataset and optimizer, and one or more
    test datasets from the PytorchBoot config. It then alternates ``train()``
    and ``test()`` until ``max_epochs`` is reached, logging to TensorBoard and
    writing checkpoints into the experiment directory.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        tensorboard_path = os.path.join(self.experiment_path, namespace.Direcotry.TENSORBOARD_DIR_NAME)

        # --- Pipeline ---
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.parallel = self.config["general"]["parallel"]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        # DataParallel is only applied on "cuda"; load()/save() therefore check the
        # actual wrapper type rather than the `parallel` config flag.
        if self.parallel and self.device == "cuda":
            self.pipeline = torch.nn.DataParallel(self.pipeline)
        self.pipeline = self.pipeline.to(self.device)

        # --- Experiment ---
        self.current_epoch = 0
        self.current_iter = 0
        self.max_epochs = self.experiments_config["max_epochs"]
        self.test_first = self.experiments_config["test_first"]
        # May restore model weights and current_epoch/current_iter from a checkpoint.
        self.load_experiment("default_trainer")

        # --- Train ---
        self.train_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TRAIN)
        self.train_dataset_name = self.train_config["dataset"]
        self.train_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, self.train_dataset_name)
        self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters())
        self.train_writer = SummaryWriter(
            log_dir=os.path.join(tensorboard_path, f"[{namespace.Mode.TRAIN}]{self.train_dataset_name}"))

        # --- Test ---
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            # Each test dataset gets its own TensorBoard directory, so names must be unique.
            if test_dataset_name in seen_name:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            seen_name.add(test_dataset_name)
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            test_writer = SummaryWriter(
                log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_name}"))
            self.test_set_list.append(test_set)
            self.test_writer_list.append(test_writer)

        self.print_info()

    def run(self):
        """Train until ``max_epochs``, testing after every epoch.

        A checkpoint is saved every ``save_checkpoint_interval`` epochs; a value
        <= 0 disables periodic checkpoints. The final model is always saved as
        the "last" checkpoint together with ``meta.json``.
        """
        save_interval = self.experiments_config["save_checkpoint_interval"]
        if self.current_epoch != 0:
            Log.info("Continue training from epoch {}.".format(self.current_epoch))
        else:
            Log.info("Start training from initial model.")
        if self.test_first:
            Log.info("Do test first.")
            self.test()
        while self.current_epoch < self.max_epochs:
            self.current_epoch += 1
            status_manager.set_progress("train", "default_trainer", "Epoch", self.current_epoch, self.max_epochs)
            self.train()
            self.test()
            # Guard against a zero interval, which would raise ZeroDivisionError.
            if save_interval > 0 and self.current_epoch % save_interval == 0:
                self.save_checkpoint()
        self.save_checkpoint(is_last=True)
        # Make sure all buffered TensorBoard events reach disk before returning.
        for writer in (self.train_writer, *self.test_writer_list):
            writer.flush()

    def train(self):
        """Run one epoch over the training set, logging per-iteration losses."""
        self.pipeline.train()
        train_set_name = self.train_dataset_name
        config = self.train_set.get_config()
        train_loader = self.train_set.get_loader(shuffle=True)

        total = len(train_loader)
        loop = tqdm(enumerate(train_loader), total=total)
        for i, data in loop:
            status_manager.set_progress("train", "default_trainer", f"(train) Batch[{train_set_name}]", i + 1, total)
            self.train_set.process_batch(data, self.device)
            loss_dict = self.train_step(data)
            loop.set_description(
                f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={config["ratio"]})')
            loop.set_postfix(loss=loss_dict)
            for loss_name, loss in loss_dict.items():
                status_manager.set_status("train", "default_trainer", f"[loss]{loss_name}", loss)
            TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, self.current_iter, simple_scalar=True)
            self.current_iter += 1

    def train_step(self, data):
        """Do one optimisation step on ``data``.

        Returns:
            dict mapping loss name (plus ``'total_loss'``) to its float value
            rounded to 5 decimals.
        """
        self.optimizer.zero_grad()
        data["mode"] = namespace.Mode.TRAIN
        output = self.pipeline(data)
        total_loss, loss_dict = self.loss_fn(output, data)
        total_loss.backward()
        self.optimizer.step()
        return {k: round(v, 5) for k, v in loss_dict.items()}

    def loss_fn(self, output, data):
        """Sum every loss listed in ``train_config["losses"]``.

        Returns:
            ``(total_loss, loss_dict)``, where ``total_loss`` is the summed tensor
            used for backprop and ``loss_dict`` maps each loss name and
            ``'total_loss'`` to its Python float value.
        """
        loss_name_list = self.train_config["losses"]
        loss_dict = {}
        total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
        for loss_name in loss_name_list:
            target_loss_fn = ComponentFactory.create(namespace.Stereotype.LOSS_FUNCTION, loss_name)
            loss = target_loss_fn.compute(output, data)
            loss_dict[loss_name] = loss.item()
            total_loss += loss
        loss_dict['total_loss'] = total_loss.item()
        return total_loss, loss_dict

    def test(self):
        """Evaluate the pipeline on every test dataset without gradients.

        All outputs and batches of a dataset are collected, evaluated with the
        dataset's ``eval_list``, and written to that dataset's writer at step
        ``current_epoch - 1``.
        """
        self.pipeline.eval()
        with torch.no_grad():
            test_set: BaseDataset
            for test_set, writer in zip(self.test_set_list, self.test_writer_list):
                test_set_config = test_set.get_config()
                eval_list = test_set_config["eval_list"]
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                total = len(test_loader)
                loop = tqdm(enumerate(test_loader), total=total)
                for i, data in loop:
                    status_manager.set_progress("train", "default_trainer", f"(test) Batch[{test_set_name}]", i + 1, total)
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                    loop.set_description(
                        f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
                result_dict = self.eval_fn(output_list, data_list, eval_list)
                TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1)

    @staticmethod
    def eval_fn(output_list, data_list, eval_list):
        """Run each evaluation method in ``eval_list`` and merge the results.

        Returns:
            ``{data_type: {metric_name: value}}``, merged across all evaluation
            methods. A later method overwrites an equal ``(data_type, name)``.
        """
        collected_result = {}
        for eval_method_name in eval_list:
            eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
            eval_results: dict = eval_method.evaluate(output_list, data_list)
            for data_type, eval_result in eval_results.items():
                bucket = collected_result.setdefault(data_type, {})
                for name, value in eval_result.items():
                    bucket[name] = value
                    status_manager.set_status("train", "default_trainer", f"[eval]{name}", value)
        return collected_result

    def get_checkpoint_path(self, is_last=False):
        """Return ``<checkpoint dir>/Epoch_<current_epoch>.pth``.

        Returns ``Epoch_last.pth`` instead when ``is_last`` is set or
        ``current_epoch == -1``.
        """
        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
                            "Epoch_{}.pth".format(
                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

    def load_checkpoint(self, is_last=False):
        """Load model weights and, for the last checkpoint, restore epoch/iter counters.

        Raises:
            FileNotFoundError: if ``is_last`` is set and ``meta.json`` is missing.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def save_checkpoint(self, is_last=False):
        """Save model weights.

        For the last checkpoint, also write ``meta.json`` recording the epoch,
        the iteration and a timestamp.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.save(checkpoint_path)
        if not is_last:
            Log.success(f"Checkpoint at epoch {self.current_epoch} saved to {checkpoint_path}")
            return
        meta = {
            "last_epoch": self.current_epoch,
            "last_iter": self.current_iter,
            "time": str(datetime.now())
        }
        checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
        with open(os.path.join(checkpoint_root, "meta.json"), "w") as f:
            json.dump(meta, f)

    def load_experiment(self, backup_name=None):
        """Load the experiment and optionally resume from a checkpoint.

        Resumes when ``use_checkpoint`` is set; ``epoch == -1`` selects the last checkpoint.
        """
        super().load_experiment(backup_name)
        if self.experiments_config["use_checkpoint"]:
            self.current_epoch = self.experiments_config["epoch"]
            self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        """Create the experiment plus its checkpoint and TensorBoard directories."""
        super().create_experiment(backup_name)
        ckpt_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.CHECKPOINT_DIR_NAME)
        os.makedirs(ckpt_dir)
        tensorboard_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.TENSORBOARD_DIR_NAME)
        os.makedirs(tensorboard_dir)

    def _unwrapped_pipeline(self):
        """Return the pipeline without any (Distributed)DataParallel wrapper."""
        if isinstance(self.pipeline, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            return self.pipeline.module
        return self.pipeline

    @staticmethod
    def _frozen_module_prefixes(model):
        """Return the state-dict key prefixes (``"name."``) of submodules in EXTERNAL_FRONZEN_MODULES."""
        # The root module (name == "") is skipped so the whole model is never excluded.
        return tuple(
            f"{name}." for name, module in model.named_modules()
            if name and module.__class__ in EXTERNAL_FRONZEN_MODULES
        )

    def load(self, path):
        """Load a state dict from ``path`` into the (unwrapped) pipeline.

        Keys belonging to external frozen modules may be absent, because
        ``save`` strips them. Any other missing or unexpected key raises
        RuntimeError, as a strict load would.
        """
        model = self._unwrapped_pipeline()
        # map_location lets GPU-saved checkpoints load on the current device (e.g. CPU).
        state_dict = torch.load(path, map_location=self.device)
        frozen_prefixes = self._frozen_module_prefixes(model)
        if not frozen_prefixes:
            model.load_state_dict(state_dict)
            return
        incompatible = model.load_state_dict(state_dict, strict=False)
        missing = [k for k in incompatible.missing_keys if not k.startswith(frozen_prefixes)]
        if missing or incompatible.unexpected_keys:
            raise RuntimeError(
                f"Error(s) in loading state_dict from {path}: "
                f"missing keys {missing}, unexpected keys {list(incompatible.unexpected_keys)}")

    def save(self, path):
        """Save the (unwrapped) pipeline's state dict to ``path``.

        Parameters and buffers of submodules whose class is in
        EXTERNAL_FRONZEN_MODULES are excluded.
        """
        model = self._unwrapped_pipeline()
        state_dict = model.state_dict()
        frozen_prefixes = self._frozen_module_prefixes(model)
        # State-dict keys are "<module path>.<param>", so match by module-path prefix.
        if frozen_prefixes:
            for key in [k for k in state_dict if k.startswith(frozen_prefixes)]:
                del state_dict[key]
        torch.save(state_dict, path)

    def print_info(self):
        """Log the runner info, the pipeline, and the config of every dataset."""
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        Log.blue("train dataset: ")
        print_dataset(self.train_set)
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)

        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')