import os
import json
import torch
from tqdm import tqdm
from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log


@stereotype.runner("default_predictor")
class DefaultPredictor(Runner):
    """Run a trained pipeline over the configured test datasets and save its outputs.

    Reads from this runner's config:
      * the pipeline component name (key ``namespace.Stereotype.PIPELINE``),
      * ``experiment.model_path`` -- checkpoint loaded into the pipeline,
      * ``experiment.save_original_data`` -- whether input batches are saved too.
    The datasets to predict on come from the TEST runner config's ``dataset_list``.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        # --- Pipeline ---
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline: torch.nn.Module = self.pipeline.to(self.device)

        # --- Experiment ---
        # model_path must be set before load_experiment(), which loads the checkpoint from it.
        self.model_path = self.config["experiment"]["model_path"]
        self.load_experiment("default_predictor")
        self.save_original_data = self.config["experiment"]["save_original_data"]

        # --- Test sets ---
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        # NOTE(review): never populated in this class; kept in case external code reads it.
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            # Names key the result dict and the per-dataset output directories,
            # so a duplicate would collide; reject it up front.
            if test_dataset_name in seen_name:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            seen_name.add(test_dataset_name)
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            self.test_set_list.append(test_set)
        self.print_info()

    def run(self):
        """Predict on every test set, then write the results to disk."""
        predict_result = self.predict()
        self.save_predict_result(predict_result)

    def predict(self):
        """Run the pipeline in eval mode, without gradients, over every test set.

        Returns:
            dict mapping each test-set name to ``{"output": [...], "data": [...]}``,
            holding one pipeline output and one (processed) input batch per
            loader batch, in loader order.
        """
        self.pipeline.eval()
        predict_result = {}
        num_test_sets = len(self.test_set_list)
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                test_set_config = test_set.get_config()
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                # NOTE(review): every output and input batch is kept in memory
                # (on whatever device process_batch places it) until saving.
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                loop = tqdm(test_loader, total=len(test_loader))
                loop.set_description(
                    f'Predicting [{dataset_idx+1}/{num_test_sets}] (Test: {test_set_name}, ratio={ratio})')
                for data in loop:
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                predict_result[test_set_name] = {
                    "output": output_list,
                    "data": data_list
                }
        return predict_result

    def save_predict_result(self, predict_result):
        """Save the result of :meth:`predict` as ``.pth`` files, one per batch.

        Layout: ``<experiment>/<RESULT_DIR_NAME>/<file_name>_predict_result/<test set>/``
        containing ``output_<i>.pth`` and, if ``save_original_data`` is set,
        ``data_<i>.pth``.

        Raises:
            FileExistsError: if the result directory already exists (existing
                results are never overwritten).
        """
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME,
                                  self.file_name + "_predict_result")
        os.makedirs(result_dir)
        for test_set_name, result in predict_result.items():
            test_set_dir = os.path.join(result_dir, test_set_name)
            os.mkdir(test_set_dir)
            for idx, (output, data) in enumerate(zip(result["output"], result["data"])):
                torch.save(output, os.path.join(test_set_dir, f"output_{idx}.pth"))
                if self.save_original_data:
                    torch.save(data, os.path.join(test_set_dir, f"data_{idx}.pth"))
            Log.success(f"Saved predict result of {test_set_name} to {result_dir}")
        Log.success(f"Saved all predict result to {result_dir}")

    def load_checkpoint(self):
        """Load the pipeline weights from the configured ``model_path``."""
        self.load(self.model_path)
        Log.success(f"Loaded checkpoint from {self.model_path}")

    def load_experiment(self, backup_name=None):
        """Load the experiment via the base Runner, then load the pipeline checkpoint."""
        super().load_experiment(backup_name)
        self.load_checkpoint()

    def create_experiment(self, backup_name=None):
        """Create the experiment via the base Runner and add its result directory."""
        super().create_experiment(backup_name)
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        os.makedirs(result_dir)

    def load(self, path):
        """Load a state dict from ``path`` into ``self.pipeline``.

        ``map_location`` remaps stored tensors onto ``self.device``, so a
        checkpoint saved on a GPU (or another CUDA index) still loads here.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)

    def print_info(self):
        """Log base runner info, the pipeline module and each test set's config."""
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)

        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')