new_nbv_rec/modules/pointnet++_encoder.py
2025-05-13 09:03:38 +08:00

150 lines
4.8 KiB
Python

import torch
import torch.nn as nn
import os
import sys
# Project root = the directory two levels above this file (drop the file
# name, then drop the directory that contains it). It is appended to
# sys.path so that project-absolute imports such as `modules.*` can be
# resolved regardless of the current working directory.
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
import PytorchBoot.stereotype as stereotype
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
# Preset returned by select_params('dense').
# Every list has one entry per set-abstraction layer; PointNet2Encoder hands
# entry k to PointnetSAModuleMSG as npoint / radii / nsamples / mlps.
# Inner lists hold one value per grouping scale (two scales per layer here).
# The last layer uses None for npoint, radius and nsample
# (presumably "group all remaining points" -- confirm in PointnetSAModuleMSG).
ClsMSG_CFG_Dense = {
    'NPOINTS': [512, 256, 128, None],
    'RADIUS': [
        [0.02, 0.04],
        [0.04, 0.08],
        [0.08, 0.16],
        [None, None],
    ],
    'NSAMPLE': [
        [32, 64],
        [16, 32],
        [8, 16],
        [None, None],
    ],
    # MLP widths per layer and per scale; the encoder prepends the input
    # channel count to each inner list before building the layer.
    'MLPS': [
        [[16, 16, 32], [32, 32, 64]],
        [[64, 64, 128], [64, 96, 128]],
        [[128, 196, 256], [128, 196, 256]],
        [[256, 256, 512], [256, 384, 512]],
    ],
    # NOTE(review): DP_RATIO is not read by any code visible in this file.
    'DP_RATIO': 0.5,
}
# Preset returned by select_params('light'); also the default preset used by
# PointNet2Encoder when the config has no "params_name".
# Every list has one entry per set-abstraction layer, and inner lists hold one
# value per grouping scale. Unlike the dense preset, every sampled layer uses
# the same neighbour counts [16, 32].
ClsMSG_CFG_Light = {
    'NPOINTS': [512, 256, 128, None],
    'RADIUS': [
        [0.02, 0.04],
        [0.04, 0.08],
        [0.08, 0.16],
        [None, None],
    ],
    'NSAMPLE': [
        [16, 32],
        [16, 32],
        [16, 32],
        [None, None],
    ],
    # MLP widths per layer and per scale; the encoder prepends the input
    # channel count to each inner list before building the layer.
    'MLPS': [
        [[16, 16, 32], [32, 32, 64]],
        [[64, 64, 128], [64, 96, 128]],
        [[128, 196, 256], [128, 196, 256]],
        [[256, 256, 512], [256, 384, 512]],
    ],
    # NOTE(review): DP_RATIO is not read by any code visible in this file.
    'DP_RATIO': 0.5,
}
# Preset returned by select_params('light_2048').
# Same sampling schedule as the 'light' preset, but the final layer's two
# scales each end in 1024 channels (1024 + 1024 = 2048 once the encoder sums
# the per-scale output widths).
ClsMSG_CFG_Light_2048 = {
    'NPOINTS': [512, 256, 128, None],
    'RADIUS': [
        [0.02, 0.04],
        [0.04, 0.08],
        [0.08, 0.16],
        [None, None],
    ],
    'NSAMPLE': [
        [16, 32],
        [16, 32],
        [16, 32],
        [None, None],
    ],
    # MLP widths per layer and per scale; the encoder prepends the input
    # channel count to each inner list before building the layer.
    'MLPS': [
        [[16, 16, 32], [32, 32, 64]],
        [[64, 64, 128], [64, 96, 128]],
        [[128, 196, 256], [128, 196, 256]],
        [[256, 256, 1024], [256, 512, 1024]],
    ],
    # NOTE(review): DP_RATIO is not read by any code visible in this file.
    'DP_RATIO': 0.5,
}
# Preset returned by select_params('strong').
# Five set-abstraction layers (one more sampled layer than the light/dense
# presets), two grouping scales per layer. The final layer's scales each end
# in 2048 channels.
ClsMSG_CFG_Strong = {
    'NPOINTS': [512, 256, 128, 64, None],
    'RADIUS': [
        [0.02, 0.04],
        [0.04, 0.08],
        [0.08, 0.16],
        [0.16, 0.32],
        [None, None],
    ],
    'NSAMPLE': [
        [16, 32],
        [16, 32],
        [16, 32],
        [16, 32],
        [None, None],
    ],
    # MLP widths per layer and per scale; the encoder prepends the input
    # channel count to each inner list before building the layer.
    'MLPS': [
        [[16, 16, 32], [32, 32, 64]],
        [[64, 64, 128], [64, 96, 128]],
        [[128, 196, 256], [128, 196, 256]],
        [[256, 256, 512], [256, 512, 512]],
        [[512, 512, 2048], [512, 1024, 2048]],
    ],
    # NOTE(review): DP_RATIO is not read by any code visible in this file.
    'DP_RATIO': 0.5,
}
# Preset returned by select_params('lighter').
# Five set-abstraction layers with a single grouping scale each (every inner
# list has exactly one entry). The last layer uses None for npoint, radius
# and nsample.
ClsMSG_CFG_Lighter = {
    'NPOINTS': [512, 256, 128, 64, None],
    'RADIUS': [
        [0.01],
        [0.02],
        [0.04],
        [0.08],
        [None],
    ],
    'NSAMPLE': [
        [64],
        [32],
        [16],
        [8],
        [None],
    ],
    # MLP widths per layer (one scale each); the encoder prepends the input
    # channel count to each inner list before building the layer.
    'MLPS': [
        [[32, 32, 64]],
        [[64, 64, 128]],
        [[128, 196, 256]],
        [[256, 256, 512]],
        [[512, 512, 1024]],
    ],
    # NOTE(review): DP_RATIO is not read by any code visible in this file.
    'DP_RATIO': 0.5,
}
def select_params(name):
    """Return the set-abstraction preset registered under *name*.

    Args:
        name: One of ``'light'``, ``'lighter'``, ``'dense'``,
            ``'light_2048'`` or ``'strong'``.

    Returns:
        The matching module-level ``ClsMSG_CFG_*`` dict. The dict itself is
        returned (not a copy), so callers must not mutate it in place.

    Raises:
        NotImplementedError: If *name* is not a known preset. The message
            names the rejected value and lists the accepted ones.
    """
    if name == 'light':
        return ClsMSG_CFG_Light
    if name == 'lighter':
        return ClsMSG_CFG_Lighter
    if name == 'dense':
        return ClsMSG_CFG_Dense
    if name == 'light_2048':
        return ClsMSG_CFG_Light_2048
    if name == 'strong':
        return ClsMSG_CFG_Strong
    # Same exception type as before, but now it says what was wrong so a
    # typo in a config file can be diagnosed from the traceback alone.
    raise NotImplementedError(
        f"Unknown params_name {name!r}; expected one of: "
        "'light', 'lighter', 'dense', 'light_2048', 'strong'"
    )
def break_up_pc(pc):
    """Split a point tensor into coordinates and extra per-point features.

    Args:
        pc: Tensor whose last dimension holds the per-point channels; the
            first three channels are taken as xyz. The feature branch swaps
            dims 1 and 2, so a 3-D input is assumed there
            (presumably (batch, num_points, channels) -- confirm with callers).

    Returns:
        ``(xyz, features)`` where ``xyz`` is a contiguous copy of the first
        three channels, and ``features`` is the remaining channels with dims
        1 and 2 swapped (contiguous), or ``None`` when the input has no
        channels beyond the first three.
    """
    xyz = pc[..., :3].contiguous()
    # Coordinates only: there is nothing to hand over as features.
    if pc.size(-1) <= 3:
        return xyz, None
    extra_channels = pc[..., 3:]
    features = extra_channels.transpose(1, 2).contiguous()
    return xyz, features
@stereotype.module("pointnet++_encoder")
class PointNet2Encoder(nn.Module):
    """Point-cloud encoder built from a stack of PointnetSAModuleMSG layers.

    The layer stack is described by one of the ``ClsMSG_CFG_*`` presets,
    chosen through ``config["params_name"]``.
    """

    def __init__(self, config: dict):
        """Build the set-abstraction stack.

        Args:
            config: Mapping with optional keys
                ``"in_dim"`` (total channels per input point, default 3) and
                ``"params_name"`` (preset name, default ``"light"``).

        Raises:
            NotImplementedError: Propagated from ``select_params`` when the
                preset name is unknown.
        """
        super().__init__()
        # Channels beyond the three xyz coordinates are the feature input
        # of the first layer.
        channel_in = config.get("in_dim", 3) - 3
        params_name = config.get("params_name", "light")
        self.SA_modules = nn.ModuleList()
        preset = select_params(params_name)

        for layer_idx, npoint in enumerate(preset['NPOINTS']):
            # Fresh lists are built here, so the shared preset dict is
            # never modified: each scale's MLP spec gets the current input
            # width prepended.
            scale_mlps = [
                [channel_in] + widths for widths in preset['MLPS'][layer_idx]
            ]
            sa_layer = PointnetSAModuleMSG(
                npoint=npoint,
                radii=preset['RADIUS'][layer_idx],
                nsamples=preset['NSAMPLE'][layer_idx],
                mlps=scale_mlps,
                use_xyz=True,
                bn=True,
            )
            self.SA_modules.append(sa_layer)
            # The next layer's input width is the sum of the final widths
            # of all scales in this layer.
            channel_in = sum(spec[-1] for spec in scale_mlps)

    def encode_points(self, pts, require_per_point_feat=False):
        """Encode *pts* by delegating to :meth:`forward`.

        ``require_per_point_feat`` is accepted but currently ignored; the
        result is always whatever ``forward`` returns.
        """
        return self.forward(pts)

    def forward(self, point_cloud: torch.cuda.FloatTensor):
        """Run the point cloud through every set-abstraction layer.

        Args:
            point_cloud: Point tensor; it is split by ``break_up_pc`` into
                xyz (first three channels) and optional extra features.

        Returns:
            The feature tensor produced by the last layer with its trailing
            dimension squeezed (presumably one global feature vector per
            batch item -- confirm against PointnetSAModuleMSG).
        """
        xyz, features = break_up_pc(point_cloud)
        # Each layer consumes the previous layer's (xyz, features) pair.
        for sa_layer in self.SA_modules:
            xyz, features = sa_layer(xyz, features)
        return features.squeeze(-1)
if __name__ == '__main__':
    # Manual smoke test: builds the 'strong' encoder on the GPU and encodes
    # a random batch. Requires a CUDA device.
    seed = 100
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # The network is created before the random input is drawn, so the
    # random-number consumption order stays as it was.
    encoder_config = {"in_dim": 3, "params_name": "strong"}
    net = PointNet2Encoder(config=encoder_config).cuda()

    # Two clouds of 2444 xyz-only points each.
    pts = torch.randn(2, 2444, 3).cuda()
    print(torch.mean(pts, dim=1))

    pre = net.encode_points(pts)
    print(pre.shape)