92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import PytorchBoot.stereotype as stereotype
|
|
|
|
from utils.pose import PoseUtil
|
|
import modules.module_lib as mlib
|
|
import modules.func_lib as flib
|
|
|
|
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module.

    Gradient tracking is disabled while zeroing, so no autograd history is
    recorded.

    Args:
        module: any ``nn.Module``; it is mutated in place.

    Returns:
        The very same ``module`` object, allowing ``zero_module(nn.Linear(...))``
        to be used inline when building a network.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
@stereotype.module("mlp_view_finder")
class MLPViewFinder(nn.Module):
    """Regress a pose vector from a single feature vector with three MLP heads.

    Each head maps ``main_feat`` to a 3-vector: one for the rotation x-axis,
    one for the rotation y-axis and one for the translation. The last layer
    of every head is zero-initialised (via ``zero_module``), so a freshly
    built model outputs all zeros.

    Config keys:
        main_feat_dim (int, required): width of the input feature.
        hidden_dim (int, optional): width of each head's hidden layer.
            Defaults to 256, the previously hard-coded value.
    """

    def __init__(self, config):
        super(MLPViewFinder, self).__init__()

        # Fixed output parameterisation; these are not read from ``config``.
        self.regression_head = 'Rx_Ry_and_T'
        self.per_point_feature = False
        # A single ReLU instance is reused by all heads (ReLU has no parameters).
        self.act = nn.ReLU(True)
        self.main_feat_dim = config["main_feat_dim"]
        hidden_dim = config.get("hidden_dim", 256)

        # Rotation regression heads (x-axis, then y-axis), 3 outputs each.
        # Build order is kept identical to the original so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.fusion_tail_rot_x = self._build_head(hidden_dim)
        self.fusion_tail_rot_y = self._build_head(hidden_dim)
        # Translation regression head, 3 outputs.
        self.fusion_tail_trans = self._build_head(hidden_dim)

    def _build_head(self, hidden_dim):
        """Build one ``Linear -> ReLU -> zero-initialised Linear(hidden_dim, 3)`` head.

        The returned ``nn.Sequential`` has the same layer indices as the
        original hand-written heads, so ``state_dict`` keys are unchanged.
        """
        return nn.Sequential(
            nn.Linear(self.main_feat_dim, hidden_dim),
            self.act,
            zero_module(nn.Linear(hidden_dim, 3)),
        )

    def forward(self, data):
        """Run all three heads on ``data['main_feat']`` and concatenate them.

        Args:
            data (dict): must contain ``'main_feat'``, a tensor of shape
                ``[bs, c]`` where ``c == main_feat_dim``.

        Returns:
            torch.Tensor: shape ``[bs, 9]``, laid out as
            ``[rot_x (3), rot_y (3), trans (3)]`` along the last dimension.
        """
        total_feat = data['main_feat']
        rot_x = self.fusion_tail_rot_x(total_feat)
        rot_y = self.fusion_tail_rot_y(total_feat)
        trans = self.fusion_tail_trans(total_feat)
        return torch.cat([rot_x, rot_y, trans], dim=-1)

    def next_best_view(self, main_feat):
        """Predict the pose vector for ``main_feat``.

        Args:
            main_feat: feature tensor of shape ``[bs, c]``; it is wrapped as
                ``{'main_feat': main_feat}`` and passed to ``forward``.

        Returns:
            tuple: ``(pose, None)``. ``pose`` is the ``forward`` output cast to
            ``torch.float32``. The second element is always ``None``;
            presumably it is a placeholder for intermediate results in a
            shared view-finder interface (TODO confirm with callers).
        """
        data = {
            'main_feat': main_feat,
        }
        res = self(data)
        return res.to(dtype=torch.float32), None
|
|
|
|
# ----------- DEBUG -----------
if __name__ == "__main__":
    # Smoke test: build the model, run forward and next_best_view, print shapes.
    # Falls back to CPU so the script also runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # ``main_feat_dim`` is the only key MLPViewFinder reads.
    config = {
        "main_feat_dim": 2048,
    }

    test_main_feat = torch.rand(32, 2048).to(device)

    view_finder = MLPViewFinder(config).to(device)

    # forward() reads the 'main_feat' key.
    test_data = {
        'main_feat': test_main_feat,
    }
    score = view_finder(test_data)
    print(score.shape)  # expected: torch.Size([32, 9])

    res, inprocess = view_finder.next_best_view(test_main_feat)
    # next_best_view always returns None as its second element, so print it
    # directly instead of accessing ``.shape`` on it.
    print(res.shape, inprocess)
|