import torch.nn as nn

from typing import List, Tuple


class SharedMLP(nn.Sequential):
    """A stack of pointwise (1x1) Conv2d blocks; ``args`` lists the channel
    sizes of consecutive layers, so ``len(args) - 1`` layers are built."""

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.ReLU(inplace=True),
        preact: bool = False,
        first: bool = False,
        name: str = "",
        instance_norm: bool = False,
    ):
        super().__init__()

        for i in range(len(args) - 1):
            self.add_module(
                name + 'layer{}'.format(i),
                Conv2d(
                    args[i],
                    args[i + 1],
                    # In a pre-activation stack, the very first layer gets no
                    # normalization or activation in front of its conv.
                    bn=(not first or not preact or (i != 0)) and bn,
                    activation=activation
                    if (not first or not preact or (i != 0)) else None,
                    preact=preact,
                    instance_norm=instance_norm
                )
            )
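

# Usage sketch (illustrative addition, not part of the original file; the
# function name and shapes are hypothetical): SharedMLP is typically applied
# to grouped point features of shape (batch, channels, npoint, nsample).
def _example_shared_mlp():
    import torch

    mlp = SharedMLP([3, 64, 128], bn=True)
    grouped = torch.randn(2, 3, 512, 32)  # (batch, channels, points, neighbors)
    return mlp(grouped)  # -> (2, 128, 512, 32)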


class _ConvBase(nn.Sequential):

    def __init__(
        self,
        in_size,
        out_size,
        kernel_size,
        stride,
        padding,
        activation,
        bn,
        init,
        conv=None,
        batch_norm=None,
        bias=True,
        preact=False,
        name="",
        instance_norm=False,
        instance_norm_func=None
    ):
        super().__init__()

        # The conv bias is redundant when batch norm follows the conv.
        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        # Pre-activation normalizes the input channels; post-activation
        # normalizes the output channels.
        if bn:
            if not preact:
                bn_unit = batch_norm(out_size)
            else:
                bn_unit = batch_norm(in_size)
        if instance_norm:
            if not preact:
                in_unit = instance_norm_func(
                    out_size, affine=False, track_running_stats=False
                )
            else:
                in_unit = instance_norm_func(
                    in_size, affine=False, track_running_stats=False
                )

        if preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)

        self.add_module(name + 'conv', conv_unit)

        if not preact:
            if bn:
                self.add_module(name + 'bn', bn_unit)

            if activation is not None:
                self.add_module(name + 'activation', activation)

            if not bn and instance_norm:
                self.add_module(name + 'in', in_unit)
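

# Sketch (hypothetical helper, not in the original file): inspect the child
# ordering _ConvBase produces for post- vs. pre-activation, assuming the
# Conv2d wrapper defined later in this module.
def _example_preact_ordering():
    post = Conv2d(3, 16, bn=True)              # conv -> bn -> activation
    pre = Conv2d(3, 16, bn=True, preact=True)  # bn -> activation -> conv
    return (
        [n for n, _ in post.named_children()],  # ['conv', 'bn', 'activation']
        [n for n, _ in pre.named_children()],   # ['bn', 'activation', 'conv']
    )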


class _BNBase(nn.Sequential):

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        self.add_module(name + "bn", batch_norm(in_size))

        nn.init.constant_(self[0].weight, 1.0)
        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
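

# Sketch (illustrative only, hypothetical function name): the _BNBase wrappers
# start out as an identity-like affine transform, since the norm's weight is
# initialized to 1 and its bias to 0.
def _example_bn_init():
    bn = BatchNorm2d(16)
    return float(bn[0].weight[0]), float(bn[0].bias[0])  # (1.0, 0.0)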


class Conv1d(_ConvBase):

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d
        )
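

# Usage sketch (hypothetical, shapes chosen for illustration): with the
# default kernel_size=1, Conv1d acts as a per-point linear layer over
# features of shape (batch, channels, points).
def _example_conv1d():
    import torch

    layer = Conv1d(64, 128, bn=True)
    features = torch.randn(4, 64, 1024)  # (batch, channels, points)
    return layer(features)  # -> (4, 128, 1024)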


class Conv2d(_ConvBase):

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d
        )


class FC(nn.Sequential):

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=None,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__()

        # The linear bias is redundant when batch norm follows the layer.
        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
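

# Usage sketch (hypothetical, not part of the original file): with bn=True,
# FC drops the linear bias (bias=not bn) and applies
# Linear -> BatchNorm1d -> ReLU by default.
def _example_fc():
    import torch

    head = FC(256, 40, bn=True)
    return head(torch.randn(8, 256))  # -> (8, 40)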