import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


@stereotype.module("mlp_view_finder")
class MLPViewFinder(nn.Module):
    """MLP regression head that maps a global feature vector to a 9-d output.

    The output is the concatenation of three independent 3-d heads:
    ``[rot_x (3), rot_y (3), trans (3)]`` (regression mode ``'Rx_Ry_and_T'``).
    The final Linear of every head is zero-initialised, so a freshly
    constructed model outputs all zeros.

    Config keys read:
        main_feat_dim: size of the last dimension of ``data['main_feat']``.
    """

    def __init__(self, config):
        super().__init__()
        self.regression_head = 'Rx_Ry_and_T'
        self.per_point_feature = False
        # A single ReLU instance is shared by all heads (it has no parameters).
        self.act = nn.ReLU(True)
        self.main_feat_dim = config["main_feat_dim"]

        # Heads are created in the original order (rot_x, rot_y, trans) so
        # parameter initialisation consumes the RNG identically.
        # rotation x-axis regression head
        self.fusion_tail_rot_x = self._make_head()
        # rotation y-axis regression head
        self.fusion_tail_rot_y = self._make_head()
        # translation regression head
        self.fusion_tail_trans = self._make_head()

    def _make_head(self, hidden_dim=256, out_dim=3):
        """Build one ``Linear -> ReLU -> zero-initialised Linear`` head.

        Layer indices (0, 1, 2) match the original hand-written heads, so
        state_dict keys such as ``fusion_tail_rot_x.2.weight`` are unchanged.
        """
        return nn.Sequential(
            nn.Linear(self.main_feat_dim, hidden_dim),
            self.act,
            zero_module(nn.Linear(hidden_dim, out_dim)),
        )

    def forward(self, data):
        """
        Args:
            data, dict {
                'main_feat': [bs, c]
            }

        Returns:
            Tensor whose last dimension is 9: ``cat([rot_x, rot_y, trans], -1)``.
        """
        total_feat = data['main_feat']
        rot_x = self.fusion_tail_rot_x(total_feat)
        rot_y = self.fusion_tail_rot_y(total_feat)
        trans = self.fusion_tail_trans(total_feat)
        output = torch.cat([rot_x, rot_y, trans], dim=-1)
        return output

    def next_best_view(self, main_feat):
        """Run the regression heads on ``main_feat``.

        Returns:
            ``(output, None)``: ``output`` is ``forward``'s result cast to
            float32. The second element is always ``None`` here (presumably
            kept for interface parity with other view finders that return
            intermediate results — confirm against callers).
        """
        data = {
            'main_feat': main_feat,
        }
        res = self(data)
        return res.to(dtype=torch.float32), None


# ----------- DEBUG -----------
if __name__ == "__main__":
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    config = {
        "main_feat_dim": 2048,
    }
    test_main_feat = torch.rand(32, 2048).to(device)
    view_finder = MLPViewFinder(config).to(device)
    test_data = {
        'main_feat': test_main_feat,
    }
    output = view_finder(test_data)
    print(output.shape)
    res, inprocess = view_finder.next_best_view(test_main_feat)
    # next_best_view always returns None as its second element.
    print(res.shape, inprocess)