27 lines
1.0 KiB
Python
27 lines
1.0 KiB
Python
from PytorchBoot.component import Component
|
|
from PytorchBoot.stereotype import *
|
|
from PytorchBoot.utils.log_util import Log
|
|
from PytorchBoot.config import ConfigManager
|
|
|
|
class ComponentFactory:
    """Factory that instantiates registered components from their configuration."""

    @staticmethod
    def create(component_type: str, name: str) -> Component:
        """Build the component called ``name`` of stereotype ``component_type``.

        The component's configuration is fetched with
        ``ConfigManager.get(component_type, name)`` and passed to the
        constructor of the selected class.

        Datasets are handled specially: the concrete class is chosen by the
        ``"source"`` entry of the dataset's configuration (looked up in
        ``dataset_classes``) rather than by ``name``, and the created object
        gets ``NAME`` and ``TYPE`` attributes set to ``name`` and
        ``component_type``.

        Args:
            component_type: Stereotype of the component. It is compared against
                ``namespace.Stereotype.DATASET`` to select the dataset branch.
            name: Component name. It is used as the config key and, for
                non-dataset components, as the key into the registry
                returned by ``get_component_classes``.

        Returns:
            The newly constructed component instance.

        Unsupported types, names, or dataset sources are reported via
        ``Log.error(..., True)``. Whether that call aborts execution is
        decided by ``Log`` (not visible here).
        """
        component_classes = get_component_classes(component_type=component_type)
        if component_classes is None:
            Log.error(f"Unsupported component type: {component_type}", True)

        if component_type == namespace.Stereotype.DATASET:
            config = ConfigManager.get(component_type, name)
            source = config["source"]
            # Validate the dataset source the same way component names are
            # validated below, instead of failing with a bare KeyError.
            if source not in dataset_classes:
                Log.error(f"Unsupported dataset source: {source}", True)
            cls = dataset_classes[source]
            dataset_obj = cls(config)
            # Tag the dataset with its configured name and stereotype.
            dataset_obj.NAME = name
            dataset_obj.TYPE = component_type
            return dataset_obj

        if name not in component_classes:
            Log.error(f"Unsupported component name: {name}", True)

        cls = component_classes[name]
        config = ConfigManager.get(component_type, name)
        return cls(config)
|