nbv_grasping/evaluations/eval_function_factory.py
2024-10-09 16:13:22 +00:00

36 lines
1.2 KiB
Python
Executable File

from annotations.stereotype import evaluation_methods
import importlib
import pkgutil
import os
# Name of the package whose submodules are imported eagerly below.
package_name = "evaluations"
package = importlib.import_module(package_name)


def _import_all_submodules(pkg):
    """Import every module and subpackage found below *pkg*.

    The imports are performed only for their side effects. Presumably each
    submodule registers its evaluation functions in ``evaluation_methods``
    (from ``annotations.stereotype``) when imported -- confirm against that
    module. Any exception raised while importing a submodule propagates.

    Args:
        pkg: An imported package object (must have ``__path__``).
    """
    prefix = pkg.__name__ + "."
    for _, module_name, _ in pkgutil.walk_packages(pkg.__path__, prefix):
        # Modules already present in sys.modules (including one that is still
        # being initialised, such as this file) are returned as-is, not re-run.
        importlib.import_module(module_name)


_import_all_submodules(package)
class EvalFunctionFactory:
    """Build composite evaluation callables from registered evaluation methods."""

    @staticmethod
    def create(eval_type_list, *, registry=None):
        """Return a function that runs the requested evaluations and merges results.

        Args:
            eval_type_list: Iterable of evaluation-type keys to look up in the
                registry. Keys that are not registered are skipped silently.
                ``None`` is treated as an empty list. The iterable is read each
                time the returned function is called.
            registry: Optional mapping from evaluation-type key to a callable
                ``method(output, data) -> dict``. When ``None`` (the default),
                the global ``evaluation_methods`` mapping is used, looked up at
                call time of the returned function.

        Returns:
            ``eval_func(output, data) -> dict``. Each evaluation method must
            return a dict mapping a result category (e.g. ``"scalars"``,
            ``"points"``, ``"images"``) to a dict of named results. These are
            merged per category in ``eval_type_list`` order, so a later method
            overwrites an earlier one on a name collision. Categories that end
            up empty are omitted from the returned dict.
        """
        def eval_func(output, data):
            methods = evaluation_methods if registry is None else registry
            eval_types = () if eval_type_list is None else eval_type_list
            # Pre-seed the standard categories so they keep a stable order in
            # the output; any other category a method returns is appended
            # instead of raising KeyError.
            merged = {"scalars": {}, "points": {}, "images": {}}
            for eval_type in eval_types:
                if eval_type not in methods:
                    continue  # unregistered evaluation types are ignored
                result = methods[eval_type](output, data)
                for category, values in result.items():
                    merged.setdefault(category, {}).update(values)
            return {category: values for category, values in merged.items() if values}

        return eval_func
# ------------ Debug ------------
if __name__ == "__main__":
    from configs.config import ConfigManager

    # Resolve the config relative to this file rather than the current working
    # directory, so the script works regardless of where it is launched from.
    _config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..",
        "configs",
        "local_train_config.yaml",
    )
    ConfigManager.load_config_with(_config_path)
    ConfigManager.print_config()