"""Factory for building ``torch.optim`` optimizers from config mappings."""
from typing import Any, Iterable, Mapping

import torch.optim as optim


def _coerce_float(value: Any) -> Any:
    """Return ``float(value)`` for strings, otherwise ``value`` unchanged.

    PyYAML (YAML 1.1 resolver) loads exponent literals that lack a decimal
    point, such as ``1e-3``, as strings; torch's optimizers reject those
    because their range checks compare the value against floats. Non-string
    values (floats, ints, tensors, ...) are passed through untouched so
    existing configs behave exactly as before.
    """
    if isinstance(value, str):
        return float(value)
    return value


class OptimizerFactory:
    """Build a ``torch.optim`` optimizer from a config mapping."""

    @staticmethod
    def create(config: Mapping[str, Any], params: Iterable) -> optim.Optimizer:
        """Instantiate the optimizer described by ``config`` for ``params``.

        Args:
            config: Mapping with a required ``"type"`` key (``"sgd"`` or
                ``"adam"``) plus optional hyper-parameters:

                * both types: ``lr`` (default ``1e-3``)
                * ``"sgd"``: ``momentum`` (default ``0.9``) and
                  ``weight_decay`` (default ``1e-4``)
                * ``"adam"``: ``betas`` (default ``(0.9, 0.999)``) and
                  ``eps`` (default ``1e-8``)

                Scalar hyper-parameters supplied as strings (e.g. ``"1e-3"``
                from a YAML file) are converted to ``float``. ``momentum``
                and ``weight_decay`` are not forwarded for ``"adam"``.
            params: Parameters (or parameter-group dicts), forwarded
                unchanged to the optimizer constructor.

        Returns:
            A configured ``optim.SGD`` or ``optim.Adam`` instance.

        Raises:
            KeyError: If ``config`` has no ``"type"`` key.
            NotImplementedError: If ``config["type"]`` is not supported.
            ValueError: If a string hyper-parameter cannot be parsed as a
                float.
        """
        optim_type = config["type"]
        # Coercion happens only inside the matching branch so an unknown
        # type still reports NotImplementedError first.
        lr = config.get("lr", 1e-3)

        if optim_type == "sgd":
            return optim.SGD(
                params,
                lr=_coerce_float(lr),
                momentum=_coerce_float(config.get("momentum", 0.9)),
                weight_decay=_coerce_float(config.get("weight_decay", 1e-4)),
            )
        if optim_type == "adam":
            return optim.Adam(
                params,
                lr=_coerce_float(lr),
                betas=config.get("betas", (0.9, 0.999)),
                eps=_coerce_float(config.get("eps", 1e-8)),
            )
        raise NotImplementedError("Unknown optimizer type: {}".format(optim_type))


# ------------ Debug ------------
if __name__ == "__main__":
    # Project-local import stays inside the guard, so importing this module
    # does not pull in the ``configs`` package.
    from configs.config import ConfigManager

    # NOTE(review): relative path, presumably resolved against the current
    # working directory — confirm how ConfigManager resolves it.
    ConfigManager.load_config_with("../configs/local_train_config.yaml")
    ConfigManager.print_config()