import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype

import sys
# NOTE(review): machine-specific absolute path; kept for backward
# compatibility, but it only resolves on the original author's workstation.
sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")

from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib


def zero_module(module):
    """Zero out the parameters of a module in place and return it.

    Args:
        module: an ``nn.Module`` whose parameters should all be set to zero.

    Returns:
        The same ``module`` object, to allow inline use inside
        ``nn.Sequential(...)``.
    """
    for p in module.parameters():
        # detach() shares storage with the parameter, so zero_() mutates the
        # parameter itself without being recorded by autograd.
        p.detach().zero_()
    return module


@stereotype.module("gf_view_finder")
class GradientFieldViewFinder(nn.Module):
    """Score network that predicts a gradient field over camera poses.

    The network fuses a sequence feature, an embedding of the diffusion time
    ``t`` and an embedding of a sampled pose, and regresses a score through
    three separate heads (rotation x-axis, rotation y-axis, translation).

    Expected ``config`` keys (all read in ``__init__``):
        regression_head: only ``'Rx_Ry_and_T'`` is implemented.
        per_point_feature: must be falsy; the per-point path is not implemented.
        sample_mode: only ``'ode'`` is implemented (checked in ``sample``).
        pose_mode: must be ``'rot_matrix'`` for the ``'Rx_Ry_and_T'`` head.
        sde_mode: forwarded to ``flib.init_sde``.
        sampling_steps: forwarded to the sampler as ``num_steps``.
        t_feat_dim: width of the time embedding.
        pose_feat_dim: width of the pose embedding.
        main_feat_dim: width of the incoming ``seq_feat``.
    """

    # Hidden width and per-head output width of the regression heads.
    _HEAD_HIDDEN_DIM = 256
    _HEAD_OUT_DIM = 3

    def __init__(self, config):
        """Build encoders and regression heads from ``config``.

        Raises:
            KeyError: if a required key is missing from ``config``.
            NotImplementedError: for any unsupported combination of
                ``regression_head``, ``pose_mode`` and ``per_point_feature``.
        """
        super().__init__()
        self.regression_head = config["regression_head"]
        self.per_point_feature = config["per_point_feature"]
        # A single in-place ReLU instance is shared by every sub-network.
        self.act = nn.ReLU(True)
        self.sample_mode = config["sample_mode"]
        self.pose_mode = config["pose_mode"]
        pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
        (
            self.prior_fn,
            self.marginal_prob_fn,
            self.sde_fn,
            self.sampling_eps,
            self.T,
        ) = flib.init_sde(config["sde_mode"])
        self.sampling_steps = config["sampling_steps"]
        self.t_feat_dim = config["t_feat_dim"]
        self.pose_feat_dim = config["pose_feat_dim"]
        self.main_feat_dim = config["main_feat_dim"]

        # Encode pose.
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, self.pose_feat_dim),
            self.act,
            nn.Linear(self.pose_feat_dim, self.pose_feat_dim),
            self.act,
        )

        # Encode t.
        self.t_encoder = nn.Sequential(
            mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim),
            nn.Linear(self.t_feat_dim, self.t_feat_dim),
            self.act,
        )

        # Fusion tail.
        if self.regression_head != 'Rx_Ry_and_T':
            raise NotImplementedError(
                f"unsupported regression_head: {self.regression_head!r}"
            )
        if self.pose_mode != 'rot_matrix':
            raise NotImplementedError(
                "regression_head 'Rx_Ry_and_T' requires pose_mode "
                f"'rot_matrix', got {self.pose_mode!r}"
            )
        if self.per_point_feature:
            raise NotImplementedError("per_point_feature is not supported")

        # Heads are created in the order rot_x, rot_y, trans so that
        # random-initialisation order matches earlier versions of this class.
        self.fusion_tail_rot_x = self._make_regress_head()  # rotation x-axis
        self.fusion_tail_rot_y = self._make_regress_head()  # rotation y-axis
        self.fusion_tail_trans = self._make_regress_head()  # translation

    def _make_regress_head(self):
        """Return one regression head: Linear -> ReLU -> zero-initialised Linear."""
        fused_dim = self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim
        return nn.Sequential(
            nn.Linear(fused_dim, self._HEAD_HIDDEN_DIM),
            self.act,
            # Zero-initialised output layer: the head outputs zeros at init.
            zero_module(nn.Linear(self._HEAD_HIDDEN_DIM, self._HEAD_OUT_DIM)),
        )

    def forward(self, data):
        """Predict the score for a batch of noisy poses.

        Args:
            data, dict {
                'seq_feat': [bs, c]
                'sampled_pose': [bs, pose_dim]
                't': [bs, 1]
            }

        Returns:
            Tensor formed by concatenating the rot_x, rot_y and trans head
            outputs along the last dimension, divided by the marginal std.

        Raises:
            NotImplementedError: if ``per_point_feature`` is set or the
                regression head is not ``'Rx_Ry_and_T'``.
        """
        seq_feat = data['seq_feat']
        sampled_pose = data['sampled_pose']
        t = data['t']

        # The time encoder receives t with its second dimension squeezed out.
        t_feat = self.t_encoder(t.squeeze(1))
        pose_feat = self.pose_encoder(sampled_pose)

        if self.per_point_feature:
            raise NotImplementedError("per_point_feature is not supported")
        total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)

        # Only the std returned by the marginal-probability function is used.
        _, std = self.marginal_prob_fn(total_feat, t)

        if self.regression_head != 'Rx_Ry_and_T':
            raise NotImplementedError(
                f"unsupported regression_head: {self.regression_head!r}"
            )
        rot_x = self.fusion_tail_rot_x(total_feat)
        rot_y = self.fusion_tail_rot_y(total_feat)
        trans = self.fusion_tail_trans(total_feat)
        # Normalise by std; the epsilon guards against division by zero.
        out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std + 1e-7)
        return out_score

    def marginal_prob(self, x, t):
        """Delegate to the marginal-probability function of the configured SDE."""
        return self.marginal_prob_fn(x, t)

    def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True,
               init_x=None, T0=None):
        """Draw pose samples with the configured sampler.

        Args:
            data: conditioning dict forwarded to the sampler.
            atol, rtol: tolerances forwarded to ``flib.cond_ode_sampler``.
            snr: currently unused by the ``'ode'`` sampler; kept for
                interface compatibility.
            denoise: forwarded to the sampler.
            init_x: optional initial sample forwarded to the sampler.
            T0: sampling horizon; defaults to ``self.T`` when ``None``.

        Returns:
            ``(in_process_sample, res)`` as returned by the sampler.

        Raises:
            NotImplementedError: if ``sample_mode`` is not ``'ode'``.
        """
        if self.sample_mode != 'ode':
            raise NotImplementedError(
                f"unsupported sample_mode: {self.sample_mode!r}"
            )

        T0 = self.T if T0 is None else T0
        in_process_sample, res = flib.cond_ode_sampler(
            score_model=self,
            data=data,
            prior=self.prior_fn,
            sde_coeff=self.sde_fn,
            atol=atol,
            rtol=rtol,
            eps=self.sampling_eps,
            T=T0,
            num_steps=self.sampling_steps,
            pose_mode=self.pose_mode,
            denoise=denoise,
            init_x=init_x,
        )
        return in_process_sample, res

    def next_best_view(self, seq_feat):
        """Sample a pose conditioned on ``seq_feat``.

        Returns:
            ``(res, in_process_sample)`` where ``res`` is cast to float32.
            Note the order is swapped relative to ``sample``.
        """
        data = {
            'seq_feat': seq_feat,
        }
        in_process_sample, res = self.sample(data)
        return res.to(dtype=torch.float32), in_process_sample


# ----------- DEBUG -----------
if __name__ == "__main__":
    # Fall back to CPU so the smoke test can run on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    config = {
        "regression_head": "Rx_Ry_and_T",
        "per_point_feature": False,
        "pose_mode": "rot_matrix",
        "sde_mode": "ve",
        "sampling_steps": 500,
        "sample_mode": "ode",
        # The three keys below are required by __init__ and were missing,
        # which made this script fail with KeyError.
        # main_feat_dim must match the width of test_seq_feat (2048).
        # t_feat_dim / pose_feat_dim are placeholder debug values --
        # TODO confirm against the real training configuration.
        "t_feat_dim": 128,
        "pose_feat_dim": 256,
        "main_feat_dim": 2048,
    }
    test_seq_feat = torch.rand(32, 2048).to(device)
    test_pose = torch.rand(32, 9).to(device)
    test_t = torch.rand(32, 1).to(device)
    view_finder = GradientFieldViewFinder(config).to(device)
    test_data = {
        'seq_feat': test_seq_feat,
        'sampled_pose': test_pose,
        't': test_t,
    }
    score = view_finder(test_data)
    print(score.shape)
    res, inprocess = view_finder.next_best_view(test_seq_feat)
    print(res.shape, inprocess.shape)