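"""Factory for constructing ``torch.optim`` optimizers from a plain config dict."""
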
import torch.optim as optim


class OptimizerFactory:
    """Builds a ``torch.optim.Optimizer`` from a config dict."""

    @staticmethod
    def create(config: dict, params) -> optim.Optimizer:
        """Create an optimizer over ``params`` from ``config``.

        ``config`` must contain a ``"type"`` key naming the optimizer;
        every other hyperparameter is optional and falls back to a default.
        ``params`` is an iterable of tensors or parameter groups, e.g.
        ``model.parameters()``.
        """
        optim_type = config["type"]
        lr = config.get("lr", 1e-3)

        if optim_type == "SGD":
            return optim.SGD(
                params,
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=config.get("weight_decay", 1e-4),
            )
        elif optim_type == "Adam":
            return optim.Adam(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
            )
        elif optim_type == "AdamW":
            # AdamW applies decoupled weight decay; 1e-2 matches torch's default.
            return optim.AdamW(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-2),
            )
        elif optim_type == "RMSprop":
            return optim.RMSprop(
                params,
                lr=lr,
                alpha=config.get("alpha", 0.99),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-4),
                momentum=config.get("momentum", 0.9),
            )
        elif optim_type == "Adagrad":
            return optim.Adagrad(
                params,
                lr=lr,
                lr_decay=config.get("lr_decay", 0),
                weight_decay=config.get("weight_decay", 0),
            )
        elif optim_type == "Adamax":
            return optim.Adamax(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 0),
            )
        elif optim_type == "LBFGS":
            # Unlike the optimizers above, LBFGS.step() requires a closure
            # that re-evaluates the loss.
            return optim.LBFGS(
                params,
                lr=lr,
                max_iter=config.get("max_iter", 20),
                max_eval=config.get("max_eval", None),
                tolerance_grad=config.get("tolerance_grad", 1e-7),
                tolerance_change=config.get("tolerance_change", 1e-9),
                history_size=config.get("history_size", 100),
                line_search_fn=config.get("line_search_fn", None),
            )
        else:
            # An unrecognized config value is a caller error, so raise
            # ValueError rather than NotImplementedError.
            raise ValueError(f"Unknown optimizer type: {optim_type!r}")
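

# Usage sketch: the model and config values below are hypothetical placeholders;
# any nn.Module's .parameters() can be passed the same way.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 2)  # hypothetical stand-in model
    config = {"type": "AdamW", "lr": 3e-4, "weight_decay": 0.01}
    optimizer = OptimizerFactory.create(config, model.parameters())
    print(optimizer)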