import torch.optim as optim


class OptimizerFactory:
    """Build ``torch.optim`` optimizers from a plain configuration mapping."""

    @staticmethod
    def create(config: dict, params) -> optim.Optimizer:
        """Instantiate the optimizer described by *config* for *params*.

        Args:
            config: Mapping with a mandatory ``"type"`` key naming the
                optimizer (``"SGD"``, ``"Adam"``, ``"AdamW"``, ``"RMSprop"``,
                ``"Adagrad"``, ``"Adamax"`` or ``"LBFGS"``). All other keys
                are optional hyperparameters; a missing key falls back to
                the default written in the matching branch below. ``"lr"``
                defaults to ``1e-3`` for every optimizer.
            params: Passed unchanged as the first argument of the optimizer
                constructor.

        Returns:
            The newly constructed optimizer.

        Raises:
            KeyError: If *config* has no ``"type"`` entry.
            NotImplementedError: If ``config["type"]`` is not one of the
                supported names.
        """
        optim_type = config["type"]
        lr = config.get("lr", 1e-3)

        if optim_type == "SGD":
            return optim.SGD(
                params,
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=config.get("weight_decay", 1e-4),
            )

        if optim_type == "Adam":
            return optim.Adam(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                # Forward weight_decay like every other branch does; a
                # configured value used to be dropped silently here. The
                # fallback of 0 keeps configs without the key unchanged.
                weight_decay=config.get("weight_decay", 0),
            )

        if optim_type == "AdamW":
            return optim.AdamW(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-2),
            )

        if optim_type == "RMSprop":
            return optim.RMSprop(
                params,
                lr=lr,
                alpha=config.get("alpha", 0.99),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-4),
                momentum=config.get("momentum", 0.9),
            )

        if optim_type == "Adagrad":
            return optim.Adagrad(
                params,
                lr=lr,
                lr_decay=config.get("lr_decay", 0),
                weight_decay=config.get("weight_decay", 0),
            )

        if optim_type == "Adamax":
            return optim.Adamax(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 0),
            )

        if optim_type == "LBFGS":
            return optim.LBFGS(
                params,
                lr=lr,
                max_iter=config.get("max_iter", 20),
                max_eval=config.get("max_eval", None),
                tolerance_grad=config.get("tolerance_grad", 1e-7),
                tolerance_change=config.get("tolerance_change", 1e-9),
                history_size=config.get("history_size", 100),
                line_search_fn=config.get("line_search_fn", None),
            )

        raise NotImplementedError(f"Unknown optimizer: {optim_type}")