new_nbv_rec/modules/mlp_view_finder.py
2025-05-13 09:03:38 +08:00

92 lines
2.6 KiB
Python

import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib
def zero_module(module):
    """Overwrite every parameter of ``module`` with zeros, in place.

    Args:
        module: An ``nn.Module`` whose parameters (weights, biases, ...) are
            reset. ``requires_grad`` flags are left untouched.

    Returns:
        The very same ``module`` object, so the call can wrap a layer inline,
        e.g. as an element of ``nn.Sequential``.
    """
    # Disable autograd tracking so the in-place write on leaf parameters is
    # not recorded as part of any graph.
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module
@stereotype.module("mlp_view_finder")
class MLPViewFinder(nn.Module):
    """Regress a view pose directly from a global feature vector.

    Three independent MLP heads (``Linear -> ReLU -> Linear``) each map
    ``main_feat`` to a 3-vector:

    * ``fusion_tail_rot_x``: rotation x-axis,
    * ``fusion_tail_rot_y``: rotation y-axis,
    * ``fusion_tail_trans``: translation.

    Their outputs are concatenated along the last dimension into a 9-vector.

    The last linear layer of every head is zero-initialized (weights and
    bias) through ``zero_module``, so a freshly constructed model outputs
    all zeros for any input.

    NOTE(review): if a caller orthonormalizes rot_x / rot_y (e.g. via
    Gram-Schmidt), the all-zero initial output could yield NaNs. Confirm that
    the downstream code handles this.

    Config keys read:
        main_feat_dim: size of the last dimension of ``main_feat``.
    """

    # Width of the hidden layer in every regression head.
    _HIDDEN_DIM = 256
    # Output width of each head (one 3-vector per head).
    _HEAD_OUT_DIM = 3

    def __init__(self, config):
        super().__init__()
        # These two attributes are not read anywhere in this class; presumably
        # they are inspected by callers. Verify before removing.
        self.regression_head = 'Rx_Ry_and_T'
        self.per_point_feature = False
        # A single stateless ReLU instance, shared by all three heads.
        # Registered before the heads to keep the original submodule order.
        self.act = nn.ReLU(True)
        self.main_feat_dim = config["main_feat_dim"]
        # Rotation x-axis regression head.
        self.fusion_tail_rot_x = self._build_head()
        # Rotation y-axis regression head.
        self.fusion_tail_rot_y = self._build_head()
        # Translation regression head.
        self.fusion_tail_trans = self._build_head()

    def _build_head(self):
        """Build one ``Linear -> ReLU -> zero-initialized Linear`` head.

        The layers are created in the same order as before, so random
        initialization and ``state_dict`` keys (``0.*`` and ``2.*``) are
        unchanged.
        """
        return nn.Sequential(
            nn.Linear(self.main_feat_dim, self._HIDDEN_DIM),
            self.act,
            zero_module(nn.Linear(self._HIDDEN_DIM, self._HEAD_OUT_DIM)),
        )

    def forward(self, data):
        """Predict the concatenated pose vector.

        Args:
            data (dict): must contain ``'main_feat'``, documented as
                ``[bs, c]`` with ``c == main_feat_dim``.

        Returns:
            torch.Tensor: the same leading shape as ``main_feat``, with the
            last dimension 9 = rot_x (3) | rot_y (3) | translation (3).
        """
        total_feat = data['main_feat']
        rot_x = self.fusion_tail_rot_x(total_feat)
        rot_y = self.fusion_tail_rot_y(total_feat)
        trans = self.fusion_tail_trans(total_feat)
        return torch.cat([rot_x, rot_y, trans], dim=-1)

    def next_best_view(self, main_feat):
        """Run ``forward`` on ``main_feat`` and return ``(pose, None)``.

        Args:
            main_feat (torch.Tensor): global feature, documented as
                ``[bs, main_feat_dim]``.

        Returns:
            tuple: ``(pose, None)``. ``pose`` is the 9-dim output of
            ``forward``, cast to ``torch.float32``. Presumably the cast
            normalizes the dtype under mixed precision; confirm with callers.
            The second element is always ``None``; this model produces no
            intermediate results.
        """
        data = {
            'main_feat': main_feat,
        }
        res = self(data)
        return res.to(dtype=torch.float32), None
# ----------- DEBUG -----------
if __name__ == "__main__":
    # Smoke test: build the model and exercise both entry points on random
    # input. Falls back to CPU so the check also runs without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    batch_size = 32
    feat_dim = 2048
    config = {
        # The only key MLPViewFinder reads; it must match the feature width.
        "main_feat_dim": feat_dim,
    }
    test_main_feat = torch.rand(batch_size, feat_dim).to(device)
    view_finder = MLPViewFinder(config).to(device)
    # forward() reads data['main_feat'] (not 'seq_feat').
    test_data = {
        'main_feat': test_main_feat,
    }
    output = view_finder(test_data)
    print(output.shape)
    res, inprocess = view_finder.next_best_view(test_main_feat)
    # next_best_view returns None as its second element; guard before .shape.
    print(res.shape, None if inprocess is None else inprocess.shape)