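"""Default evaluation runner for PytorchBoot.

Loads a trained pipeline from a checkpoint, runs it over one or more test
datasets, applies the configured evaluation methods, and writes the collected
metrics to a JSON file inside the experiment's result directory.
"""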
import os
import json

import torch
from tqdm import tqdm

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

@stereotype.runner("default_evaluator")
class DefaultEvaluator(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)

        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(
            namespace.Stereotype.PIPELINE, self.pipeline_name
        ).to(self.device)

        ''' Experiment '''
        self.model_path = self.config["experiment"]["model_path"]
        self.load_experiment("default_evaluator")

        ''' Test '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_names = set()
        for test_dataset_name in self.test_dataset_name_list:
            # Each test dataset may appear only once in the config.
            if test_dataset_name in seen_names:
                raise ValueError(f"Duplicate test dataset name: {test_dataset_name}")
            seen_names.add(test_dataset_name)
            test_set: BaseDataset = ComponentFactory.create(
                namespace.Stereotype.DATASET, test_dataset_name
            )
            self.test_set_list.append(test_set)
        self.print_info()
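
    # A sketch of the configuration __init__ and test() expect, in YAML form.
    # The literal section names (e.g. "pipeline", "test") mirror constants in
    # PytorchBoot.namespace and are assumptions; "experiment", "model_path",
    # "dataset_list", "eval_list", and "ratio" are the keys read verbatim.
    #
    # pipeline: my_pipeline
    # experiment:
    #   model_path: experiments/ckpt/best.pth
    # test:
    #   dataset_list:
    #     - my_test_set      # each dataset's own config supplies
    #                        # "eval_list" and "ratio"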

    def run(self):
        # Evaluate all configured test sets and persist the metrics.
        eval_result = self.test()
        self.save_eval_result(eval_result)

    def test(self):
        self.pipeline.eval()
        eval_result = {}
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                test_set_config = test_set.get_config()
                eval_list = test_set_config["eval_list"]
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                loop = tqdm(enumerate(test_loader), total=len(test_loader))
                for _, data in loop:
                    # Move the batch to the runner's device and tag it as test data.
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                    loop.set_description(
                        f"Evaluating [{dataset_idx + 1}/{len(self.test_set_list)}] "
                        f"(Test: {test_set_name}, ratio={ratio})"
                    )
                # Run every configured evaluation method over the collected batches.
                result_dict = self.eval_fn(output_list, data_list, eval_list)
                eval_result[test_set_name] = result_dict
        return eval_result

    def save_eval_result(self, eval_result):
        # "Direcotry" is the attribute's spelling in PytorchBoot.namespace.
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        eval_result_path = os.path.join(result_dir, self.file_name + "_eval_result.json")
        with open(eval_result_path, "w") as f:
            json.dump(eval_result, f, indent=4)
        Log.success(f"Saved evaluation result to {eval_result_path}")

    @staticmethod
    def eval_fn(output_list, data_list, eval_list):
        collected_result = {}
        for eval_method_name in eval_list:
            eval_method = ComponentFactory.create(
                namespace.Stereotype.EVALUATION_METHOD, eval_method_name
            )
            eval_results: dict = eval_method.evaluate(output_list, data_list)
            # Merge this method's metrics into the per-data-type result table.
            for data_type, eval_result in eval_results.items():
                collected_result.setdefault(data_type, {}).update(eval_result)
        return collected_result
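
    # For reference, a minimal sketch of an evaluation-method component as
    # eval_fn() consumes it. The decorator name and registration string are
    # assumptions about the PytorchBoot stereotype API, and the "pred"/"gt"
    # batch keys are hypothetical; only the evaluate() signature and the
    # {data_type: {metric: value}} return shape are implied by the code above.
    #
    # @stereotype.evaluation_method("l1_distance")
    # class L1DistanceEval:
    #     def evaluate(self, output_list, data_list):
    #         total, count = 0.0, 0
    #         for output, data in zip(output_list, data_list):
    #             total += (output["pred"] - data["gt"]).abs().mean().item()
    #             count += 1
    #         return {"pred": {"l1_distance": total / max(count, 1)}}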

    def load_checkpoint(self):
        self.load(self.model_path)
        Log.success(f"Loaded checkpoint from {self.model_path}")

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.load_checkpoint()

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        os.makedirs(result_dir, exist_ok=True)

    def load(self, path):
        # map_location keeps CPU-only evaluation working for checkpoints saved on GPU.
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)

    def print_info(self):
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}+")
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}+")
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)
        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}+")