# File-viewer metadata: 225 lines, 6.9 KiB, Python, executable file.
import MinkowskiEngine as ME
|
|
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
|
|
from models.resnet import ResNetBase
|
|
|
|
|
|
class MinkUNetBase(ResNetBase):
    """Sparse U-Net assembled from MinkowskiEngine layers.

    Encoder: a kernel-size-5 stem convolution, then four downsampling stages
    (i = 0..3). Each stage is a kernel-size-2 / stride-2 convolution, batch
    norm, ReLU and a layer of ``LAYERS[i]`` ``BLOCK`` blocks built with
    ``_make_layer``, producing ``PLANES[i] * BLOCK.expansion`` channels.

    Decoder: four stages (i = 4..7). Each is a kernel-size-2 / stride-2
    transposed convolution to ``PLANES[i]`` channels, batch norm and ReLU,
    then ``ME.cat`` with the encoder features of the same tensor stride and a
    layer of ``LAYERS[i]`` blocks. The last decoder stage is joined with the
    stem output (``INIT_DIM`` channels).

    A final kernel-size-1 convolution (with bias) maps to ``out_channels``.

    Subclasses must set ``BLOCK``: ``network_initialization`` reads
    ``BLOCK.expansion``, so the default ``None`` cannot be built.
    """

    # Block class used by every stage; must be provided by subclasses.
    BLOCK = None
    # Not read by this class. NOTE(review): presumably kept for parity with
    # ResNetBase-style configs -- confirm before removing.
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    # Number of blocks per stage: indices 0-3 encoder, 4-7 decoder.
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    # Base channel width per stage (block outputs are scaled by
    # BLOCK.expansion): indices 0-3 encoder, 4-7 decoder.
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    # Channels of the stem output, which is also the full-resolution skip
    # connection concatenated before block8.
    INIT_DIM = 32
    # Not read by this class; the decoder returns to tensor_stride=1.
    OUT_TENSOR_STRIDE = 1

    def __init__(self, in_channels, out_channels, D=3):
        """Create the network.

        Args:
            in_channels: Feature channels of the input sparse tensor.
            out_channels: Feature channels of the network output.
            D: Spatial dimension of the sparse tensors (default 3).
        """
        # NOTE(review): ResNetBase.__init__ is expected to call
        # network_initialization(); confirm against models.resnet.
        ResNetBase.__init__(self, in_channels, out_channels, D)

    def network_initialization(self, in_channels, out_channels, D):
        """Create every layer of the U-Net.

        Args:
            in_channels: Feature channels of the input sparse tensor.
            out_channels: Feature channels produced by ``self.final``.
            D: Spatial dimension passed to every Minkowski layer.

        Submodule attribute names and their creation order are kept as-is
        because they presumably appear as checkpoint (``state_dict``) keys.
        """
        # Stem. Its INIT_DIM-channel output is reused as the skip connection
        # concatenated with the last decoder upsampling (before block8).
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D)
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # Encoder. NOTE(review): ``_make_layer`` appears to leave
        # ``self.inplanes`` at the stage's output width, which the next
        # strided convolution reads -- confirm against models.resnet.
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0],
                                       self.LAYERS[0])

        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1],
                                       self.LAYERS[1])

        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2],
                                       self.LAYERS[2])

        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3],
                                       self.LAYERS[3])

        # Decoder. Each block layer's input width is the upsampled channels
        # plus the channels of the concatenated encoder features.
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D)
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])

        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4],
                                       self.LAYERS[4])
        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D)
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])

        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5],
                                       self.LAYERS[5])
        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D)
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])

        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6],
                                       self.LAYERS[6])
        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D)
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])

        # The last skip is the stem output, which has INIT_DIM channels.
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7],
                                       self.LAYERS[7])

        self.final = ME.MinkowskiConvolution(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            bias=True,
            dimension=D)
        # A single ReLU module shared by every conv/bn pair in forward().
        self.relu = ME.MinkowskiReLU(inplace=True)

    def _conv_bn_relu(self, conv, bn, x):
        """Apply ``conv``, then ``bn``, then the shared ReLU to ``x``."""
        return self.relu(bn(conv(x)))

    def forward(self, x):
        """Run the U-Net on input ``x`` and return the output of ``final``."""
        # Stem (tensor_stride=1); kept as the skip for the last decoder stage.
        out_p1 = self._conv_bn_relu(self.conv0p1s1, self.bn0, x)

        # Encoder; each block output is kept for the matching decoder stage.
        # tensor_stride=2
        out = self._conv_bn_relu(self.conv1p1s2, self.bn1, out_p1)
        out_b1p2 = self.block1(out)
        # tensor_stride=4
        out = self._conv_bn_relu(self.conv2p2s2, self.bn2, out_b1p2)
        out_b2p4 = self.block2(out)
        # tensor_stride=8
        out = self._conv_bn_relu(self.conv3p4s2, self.bn3, out_b2p4)
        out_b3p8 = self.block3(out)
        # tensor_stride=16
        out = self._conv_bn_relu(self.conv4p8s2, self.bn4, out_b3p8)
        out = self.block4(out)

        # Decoder: upsample, then concatenate with the encoder features of
        # the same tensor stride before the next block layer.
        # tensor_stride=8
        out = self._conv_bn_relu(self.convtr4p16s2, self.bntr4, out)
        out = self.block5(ME.cat(out, out_b3p8))
        # tensor_stride=4
        out = self._conv_bn_relu(self.convtr5p8s2, self.bntr5, out)
        out = self.block6(ME.cat(out, out_b2p4))
        # tensor_stride=2
        out = self._conv_bn_relu(self.convtr6p4s2, self.bntr6, out)
        out = self.block7(ME.cat(out, out_b1p2))
        # tensor_stride=1
        out = self._conv_bn_relu(self.convtr7p2s2, self.bntr7, out)
        out = self.block8(ME.cat(out, out_p1))

        return self.final(out)
|
|
|
|
|
|
class MinkUNet14(MinkUNetBase):
    """MinkUNet with a single BasicBlock in each of its eight stages."""

    LAYERS = (1,) * 8
    BLOCK = BasicBlock
|
|
|
|
|
|
class MinkUNet18(MinkUNetBase):
    """MinkUNet with two BasicBlocks in each of its eight stages."""

    LAYERS = (2,) * 8
    BLOCK = BasicBlock
|
|
|
|
|
|
class MinkUNet34(MinkUNetBase):
    """MinkUNet built from BasicBlocks with a deeper encoder.

    Encoder stages hold 2, 3, 4 and 6 blocks; each decoder stage holds 2.
    """

    BLOCK = BasicBlock
    # encoder depths + decoder depths
    LAYERS = (2, 3, 4, 6) + (2,) * 4
|
|
|
|
|
|
class MinkUNet50(MinkUNetBase):
    """MinkUNet built from Bottleneck blocks.

    Encoder stages hold 2, 3, 4 and 6 blocks; each decoder stage holds 2.
    """

    BLOCK = Bottleneck
    # encoder depths + decoder depths
    LAYERS = (2, 3, 4, 6) + (2,) * 4
|
|
|
|
|
|
class MinkUNet101(MinkUNetBase):
    """MinkUNet built from Bottleneck blocks with a very deep fourth stage.

    Encoder stages hold 2, 3, 4 and 23 blocks; each decoder stage holds 2.
    """

    BLOCK = Bottleneck
    # encoder depths + decoder depths
    LAYERS = (2, 3, 4, 23) + (2,) * 4
|
|
|
|
|
|
class MinkUNet14A(MinkUNet14):
    """MinkUNet14 with decoder widths 128, 128, 96, 96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)
|
|
|
|
|
|
class MinkUNet14B(MinkUNet14):
    """MinkUNet14 with every decoder stage 128 channels wide."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128,) * 4
|
|
|
|
|
|
class MinkUNet14C(MinkUNet14):
    """MinkUNet14 with decoder widths 192, 192, 128, 128."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (192, 192, 128, 128)
|
|
|
|
|
|
class MinkUNet14Dori(MinkUNet14):
    """MinkUNet14 with every decoder stage 384 channels wide."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (384,) * 4
|
|
|
|
|
|
class MinkUNet14E(MinkUNet14):
    """MinkUNet14 with every decoder stage 384 channels wide.

    These widths are identical to those of ``MinkUNet14Dori``.
    """

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (384,) * 4
|
|
|
|
|
|
class MinkUNet14D(MinkUNet14):
    """MinkUNet14 with every decoder stage 192 channels wide."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (192,) * 4
|
|
|
|
|
|
class MinkUNet18A(MinkUNet18):
    """MinkUNet18 with decoder widths 128, 128, 96, 96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128, 128, 96, 96)
|
|
|
|
|
|
class MinkUNet18B(MinkUNet18):
    """MinkUNet18 with every decoder stage 128 channels wide."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (128,) * 4
|
|
|
|
|
|
class MinkUNet18D(MinkUNet18):
    """MinkUNet18 with every decoder stage 384 channels wide."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (384,) * 4
|
|
|
|
|
|
class MinkUNet34A(MinkUNet34):
    """MinkUNet34 with decoder widths 256, 128, 64, 64."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 64, 64)
|
|
|
|
|
|
class MinkUNet34B(MinkUNet34):
    """MinkUNet34 with decoder widths 256, 128, 64, 32."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 64, 32)
|
|
|
|
|
|
class MinkUNet34C(MinkUNet34):
    """MinkUNet34 with decoder widths 256, 128, 96, 96."""

    # encoder widths + decoder widths
    PLANES = (32, 64, 128, 256) + (256, 128, 96, 96)
|