import torch
import numpy as np


def weight_init(shape, mode, fan_in, fan_out):
    """Draw a freshly initialized weight tensor of the given shape.

    Every supported mode is ``scale * sample``:

    * ``'xavier_uniform'``  -- sample ~ U(-1, 1), scale = sqrt(6 / (fan_in + fan_out))
    * ``'xavier_normal'``   -- sample ~ N(0, 1),  scale = sqrt(2 / (fan_in + fan_out))
    * ``'kaiming_uniform'`` -- sample ~ U(-1, 1), scale = sqrt(3 / fan_in)
    * ``'kaiming_normal'``  -- sample ~ N(0, 1),  scale = sqrt(1 / fan_in)

    The kaiming variants therefore target variance ``1 / fan_in``. They do not
    include a ReLU gain factor.

    Args:
        shape: Sequence of ints giving the output tensor shape.
        mode: One of the four mode names above.
        fan_in: Fan-in used by every mode.
        fan_out: Fan-out, used only by the xavier modes.

    Returns:
        A tensor of ``shape`` drawn with ``torch.rand`` / ``torch.randn``
        (default dtype and device) and multiplied by the mode's scale.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
        ZeroDivisionError: If the relevant fan sum is zero.
    """
    # Choose the variance numerator/denominator first. The division is done
    # only for the selected family, so an unused fan can never raise.
    if mode in ('xavier_uniform', 'xavier_normal'):
        numerator = 6 if mode == 'xavier_uniform' else 2
        denominator = fan_in + fan_out
    elif mode in ('kaiming_uniform', 'kaiming_normal'):
        numerator = 3 if mode == 'kaiming_uniform' else 1
        denominator = fan_in
    else:
        raise ValueError(f'Invalid init mode "{mode}"')

    # Compute the scale before touching the RNG. A failing division then
    # consumes no random numbers.
    scale = np.sqrt(numerator / denominator)

    if mode in ('xavier_uniform', 'kaiming_uniform'):
        sample = torch.rand(*shape) * 2 - 1  # map U(0, 1) onto U(-1, 1)
    else:
        sample = torch.randn(*shape)
    return scale * sample


class Linear(torch.nn.Module):
    """Affine layer computing ``x @ weight.T + bias``.

    The weight has shape ``(out_features, in_features)`` and is drawn with
    ``weight_init(..., init_mode, fan_in=in_features, fan_out=out_features)``,
    then multiplied by ``init_weight``. The optional bias of shape
    ``(out_features,)`` is drawn with the same mode and fans, then multiplied
    by ``init_bias``. With the default ``init_bias=0`` the bias starts at zero.
    """

    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        fans = {'fan_in': in_features, 'fan_out': out_features}

        # The weight is drawn before the bias, so the RNG stream is consumed
        # in a fixed order.
        initial_weight = weight_init([out_features, in_features], init_mode, **fans)
        self.weight = torch.nn.Parameter(initial_weight * init_weight)

        if bias:
            initial_bias = weight_init([out_features], init_mode, **fans)
            self.bias = torch.nn.Parameter(initial_bias * init_bias)
        else:
            self.bias = None

    def forward(self, x):
        """Apply the layer to ``x``; the last dimension of ``x`` must be ``in_features``.

        Parameters are cast to the input's dtype on every call, so the output
        keeps the caller's precision.
        """
        projected = x @ self.weight.to(x.dtype).t()
        if self.bias is None:
            return projected
        # ``projected`` is a fresh tensor, so adding the bias in place is safe.
        return projected.add_(self.bias.to(projected.dtype))