PyTorchBoot/PytorchBoot/factory/component_factory.py
2024-09-13 16:58:34 +08:00

27 lines
1.0 KiB
Python

from PytorchBoot.component import Component
from PytorchBoot.stereotype import *
from PytorchBoot.utils.log_util import Log
from PytorchBoot.config import ConfigManager
class ComponentFactory:
    """Build configured component instances from a stereotype and a name."""

    @staticmethod
    def create(component_type: str, name: str) -> Component:
        """Instantiate the component registered as ``name`` under ``component_type``.

        The component's configuration is fetched with
        ``ConfigManager.get(component_type, name)`` and passed to the class
        constructor.

        Dataset components are resolved differently. The class is looked up in
        ``dataset_classes`` by the ``"source"`` entry of the dataset's config
        rather than by ``name``. The resulting object then gets its ``NAME``
        and ``TYPE`` attributes set to ``name`` and ``component_type``.

        Args:
            component_type: Stereotype of the component, compared against
                ``namespace.Stereotype.DATASET`` to select the dataset path.
            name: Registered name of the component. It is also the key used
                to look up its configuration.

        Returns:
            The newly constructed component.

        Errors (unknown type, name, or dataset source) are reported via
        ``Log.error(message, True)``. NOTE(review): this code assumes that
        call does not return (e.g. it raises or exits); confirm in log_util.
        """
        component_classes = get_component_classes(component_type=component_type)
        if component_classes is None:
            Log.error(f"Unsupported component type: {component_type}", True)

        if component_type == namespace.Stereotype.DATASET:
            config = ConfigManager.get(component_type, name)
            source = config["source"]
            # Report an unregistered dataset source through the same channel
            # as an unknown component name, instead of a bare KeyError.
            if source not in dataset_classes:
                Log.error(f"Unsupported dataset source: {source}", True)
            dataset_obj = dataset_classes[source](config)
            dataset_obj.NAME = name
            dataset_obj.TYPE = component_type
            return dataset_obj

        if name not in component_classes:
            Log.error(f"Unsupported component name: {name}", True)
        cls = component_classes[name]
        config = ConfigManager.get(component_type, name)
        return cls(config)