nbv_grasping/baselines/grasping/GSNet/models/backbone_resunet14.py
2024-10-09 16:13:22 +00:00

225 lines
6.9 KiB
Python
Executable File

import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
from models.resnet import ResNetBase
class MinkUNetBase(ResNetBase):
    """U-Net shaped sparse-convolution backbone built from MinkowskiEngine layers.

    The network has a stem convolution, four stride-2 downsampling stages
    (``block1``-``block4``) and four stride-2 transposed-convolution stages
    (``block5``-``block8``). Each upsampled feature map is concatenated with
    the feature map of matching resolution from the downsampling path before
    its residual stage runs.

    Subclasses configure the architecture through class attributes:

    * ``BLOCK``  - residual block class handed to ``_make_layer``; must expose
      an ``expansion`` attribute. It is ``None`` here, so this base class is
      not meant to be instantiated directly.
    * ``LAYERS`` - one entry per stage (``block1`` .. ``block8``), forwarded
      to ``_make_layer``.
    * ``PLANES`` - one channel width per stage (``block1`` .. ``block8``).
    * ``INIT_DIM`` - channel width of the stem convolution.

    NOTE(review): this code relies on ``ResNetBase._make_layer`` leaving
    ``self.inplanes`` equal to ``planes * BLOCK.expansion`` after each call;
    ``ResNetBase`` is not visible from this file - confirm.
    """

    BLOCK = None
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    # Legacy usage note kept from the original file:
    # To use the model, must call initialize_coords before forward pass.
    # Once data is processed, call clear to reset the model before calling
    # initialize_coords.
    # NOTE(review): neither `initialize_coords` nor `clear` is defined in this
    # class - confirm whether ResNetBase still provides them or drop the note.
    def __init__(self, in_channels, out_channels, D=3):
        """Build the network.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by the final 1x1 conv.
            D: Spatial dimension passed to every MinkowskiEngine layer.
        """
        # Layer construction happens in `network_initialization`, which is
        # presumably invoked by ResNetBase.__init__ - TODO confirm.
        super().__init__(in_channels, out_channels, D)

    def network_initialization(self, in_channels, out_channels, D):
        """Create all layers of the U-Net.

        Module attribute names and creation order are part of the checkpoint
        format (state-dict keys), so they must not be renamed or reordered.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of channels produced by ``self.final``.
            D: Spatial dimension passed to every MinkowskiEngine layer.
        """
        # Stem: its output is kept in `forward` and concatenated with the last
        # upsampled feature map, right before `block8`.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D)
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # ---- Downsampling path: stride-2 conv followed by a residual stage.
        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0],
                                       self.LAYERS[0])

        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1],
                                       self.LAYERS[1])

        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2],
                                       self.LAYERS[2])

        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3],
                                       self.LAYERS[3])

        # ---- Upsampling path: stride-2 transposed conv, then a residual
        # stage. Before each stage `self.inplanes` is set to the channel count
        # of the concatenation performed in `forward`: the upsampled features
        # plus the skip features of the matching resolution.
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D)
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4],
                                       self.LAYERS[4])

        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D)
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5],
                                       self.LAYERS[5])

        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D)
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6],
                                       self.LAYERS[6])

        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D)
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
        # The last skip connection comes from the stem, which has INIT_DIM
        # channels (no block expansion involved).
        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7],
                                       self.LAYERS[7])

        # 1x1 projection from the last stage's channels to `out_channels`.
        self.final = ME.MinkowskiConvolution(
            self.PLANES[7] * self.BLOCK.expansion,
            out_channels,
            kernel_size=1,
            bias=True,
            dimension=D)
        # A single in-place ReLU module is shared by every stage.
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        """Run the U-Net and return the output of the final 1x1 convolution.

        Args:
            x: Input accepted by ``ME.MinkowskiConvolution`` (presumably an
                ``ME.SparseTensor`` with ``in_channels`` features - not
                checked here).

        Returns:
            The result of ``self.final`` applied to the last decoder stage.
        """
        # Stem at the input resolution; kept as the outermost skip connection.
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        skip_stem = self.relu(out)

        # Downsampling path. Each stage halves the resolution (stride-2 conv);
        # the stage outputs are kept for the skip connections below.
        out = self.conv1p1s2(skip_stem)
        out = self.bn1(out)
        out = self.relu(out)
        skip_block1 = self.block1(out)

        out = self.conv2p2s2(skip_block1)
        out = self.bn2(out)
        out = self.relu(out)
        skip_block2 = self.block2(out)

        out = self.conv3p4s2(skip_block2)
        out = self.bn3(out)
        out = self.relu(out)
        skip_block3 = self.block3(out)

        # tensor_stride=16
        out = self.conv4p8s2(skip_block3)
        out = self.bn4(out)
        out = self.relu(out)
        out = self.block4(out)

        # Upsampling path. Each stage doubles the resolution (stride-2
        # transposed conv) and concatenates the matching skip features.
        # tensor_stride=8
        out = self.convtr4p16s2(out)
        out = self.bntr4(out)
        out = self.relu(out)
        out = ME.cat(out, skip_block3)
        out = self.block5(out)

        # tensor_stride=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)
        out = ME.cat(out, skip_block2)
        out = self.block6(out)

        # tensor_stride=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)
        out = ME.cat(out, skip_block1)
        out = self.block7(out)

        # tensor_stride=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)
        out = ME.cat(out, skip_stem)
        out = self.block8(out)

        return self.final(out)
class MinkUNet14(MinkUNetBase):
    """MinkUNet built from ``BasicBlock`` with a LAYERS value of 1 for every stage."""

    BLOCK = BasicBlock
    # One entry per stage (block1 .. block8), forwarded to `_make_layer`.
    LAYERS = (1,) * 8
class MinkUNet18(MinkUNetBase):
    """MinkUNet built from ``BasicBlock`` with a LAYERS value of 2 for every stage."""

    BLOCK = BasicBlock
    # One entry per stage (block1 .. block8), forwarded to `_make_layer`.
    LAYERS = (2,) * 8
class MinkUNet34(MinkUNetBase):
    """MinkUNet built from ``BasicBlock`` with LAYERS (2, 3, 4, 6) then 2 per stage."""

    BLOCK = BasicBlock
    LAYERS = (
        2, 3, 4, 6,  # block1 .. block4 (downsampling path)
        2, 2, 2, 2,  # block5 .. block8 (upsampling path)
    )
class MinkUNet50(MinkUNetBase):
    """MinkUNet built from ``Bottleneck`` with LAYERS (2, 3, 4, 6) then 2 per stage."""

    BLOCK = Bottleneck
    LAYERS = (
        2, 3, 4, 6,  # block1 .. block4 (downsampling path)
        2, 2, 2, 2,  # block5 .. block8 (upsampling path)
    )
class MinkUNet101(MinkUNetBase):
    """MinkUNet built from ``Bottleneck`` with LAYERS (2, 3, 4, 23) then 2 per stage."""

    BLOCK = Bottleneck
    LAYERS = (
        2, 3, 4, 23,  # block1 .. block4 (downsampling path)
        2, 2, 2, 2,   # block5 .. block8 (upsampling path)
    )
class MinkUNet14A(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (128, 128, 96, 96)."""

    PLANES = (
        32, 64, 128, 256,  # block1 .. block4 (downsampling path)
        128, 128, 96, 96,  # block5 .. block8 (upsampling path)
    )
class MinkUNet14B(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (128, 128, 128, 128)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        128, 128, 128, 128,  # block5 .. block8 (upsampling path)
    )
class MinkUNet14C(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (192, 192, 128, 128)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        192, 192, 128, 128,  # block5 .. block8 (upsampling path)
    )
class MinkUNet14Dori(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (384, 384, 384, 384)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        384, 384, 384, 384,  # block5 .. block8 (upsampling path)
    )
class MinkUNet14E(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (384, 384, 384, 384).

    The PLANES tuple is identical to the one declared by ``MinkUNet14Dori``.
    """

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        384, 384, 384, 384,  # block5 .. block8 (upsampling path)
    )
class MinkUNet14D(MinkUNet14):
    """MinkUNet14 with upsampling-path widths (192, 192, 192, 192)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        192, 192, 192, 192,  # block5 .. block8 (upsampling path)
    )
class MinkUNet18A(MinkUNet18):
    """MinkUNet18 with upsampling-path widths (128, 128, 96, 96)."""

    PLANES = (
        32, 64, 128, 256,  # block1 .. block4 (downsampling path)
        128, 128, 96, 96,  # block5 .. block8 (upsampling path)
    )
class MinkUNet18B(MinkUNet18):
    """MinkUNet18 with upsampling-path widths (128, 128, 128, 128)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        128, 128, 128, 128,  # block5 .. block8 (upsampling path)
    )
class MinkUNet18D(MinkUNet18):
    """MinkUNet18 with upsampling-path widths (384, 384, 384, 384)."""

    PLANES = (
        32, 64, 128, 256,    # block1 .. block4 (downsampling path)
        384, 384, 384, 384,  # block5 .. block8 (upsampling path)
    )
class MinkUNet34A(MinkUNet34):
    """MinkUNet34 with upsampling-path widths (256, 128, 64, 64)."""

    PLANES = (
        32, 64, 128, 256,  # block1 .. block4 (downsampling path)
        256, 128, 64, 64,  # block5 .. block8 (upsampling path)
    )
class MinkUNet34B(MinkUNet34):
    """MinkUNet34 with upsampling-path widths (256, 128, 64, 32)."""

    PLANES = (
        32, 64, 128, 256,  # block1 .. block4 (downsampling path)
        256, 128, 64, 32,  # block5 .. block8 (upsampling path)
    )
class MinkUNet34C(MinkUNet34):
    """MinkUNet34 with upsampling-path widths (256, 128, 96, 96)."""

    PLANES = (
        32, 64, 128, 256,  # block1 .. block4 (downsampling path)
        256, 128, 96, 96,  # block5 .. block8 (upsampling path)
    )