nbv_reconstruction/modules/gf_view_finder.py
2024-09-13 09:40:08 +00:00

168 lines
5.7 KiB
Python

import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib
def zero_module(module):
    """
    Zero out the parameters of a module and return it.

    Every tensor yielded by ``module.parameters()`` is overwritten with zeros
    in place. The zeroing is done outside autograd, so it is not recorded as
    part of any computation graph.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` object that was passed in, so the call can be
        used inline while building a network.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
@stereotype.module("gf_view_finder")
class GradientFieldViewFinder(nn.Module):
    """
    Conditional score network over poses, plus a sampler built on top of it.

    ``forward`` encodes a diffusion time ``t`` and a (noisy) pose, concatenates
    both with a conditioning feature ``seq_feat`` and regresses a score vector
    that is normalised by the std returned from the configured SDE's marginal
    probability function. ``sample`` / ``next_best_view`` hand the module to
    ``flib.cond_ode_sampler`` to draw poses conditioned on ``seq_feat``.

    Only the combination ``regression_head == 'Rx_Ry_and_T'``,
    ``pose_mode == 'rot_matrix'``, ``per_point_feature == False`` and
    ``sample_mode == 'ode'`` is implemented; everything else raises
    ``NotImplementedError``.
    """

    def __init__(self, config):
        """
        Build the encoders and regression heads.

        Args:
            config: mapping with the keys
                'regression_head', 'per_point_feature', 'sample_mode',
                'pose_mode', 'sde_mode', 'sampling_steps', 't_feat_dim',
                'pose_feat_dim' and 'main_feat_dim'. All of them are read
                unconditionally.

        Raises:
            NotImplementedError: for an unsupported regression head, pose mode
                or when per-point features are requested.
        """
        super(GradientFieldViewFinder, self).__init__()
        self.regression_head = config["regression_head"]
        self.per_point_feature = config["per_point_feature"]
        # One in-place ReLU instance is shared by every Sequential below; it
        # holds no parameters, so sharing does not couple the sub-networks.
        self.act = nn.ReLU(True)
        self.sample_mode = config["sample_mode"]
        self.pose_mode = config["pose_mode"]
        pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
        (
            self.prior_fn,
            self.marginal_prob_fn,
            self.sde_fn,
            self.sampling_eps,
            self.T,
        ) = flib.init_sde(config["sde_mode"])
        self.sampling_steps = config["sampling_steps"]
        self.t_feat_dim = config["t_feat_dim"]
        self.pose_feat_dim = config["pose_feat_dim"]
        self.main_feat_dim = config["main_feat_dim"]

        # ---- encode pose ----
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, self.pose_feat_dim),
            self.act,
            nn.Linear(self.pose_feat_dim, self.pose_feat_dim),
            self.act,
        )

        # ---- encode t ----
        self.t_encoder = nn.Sequential(
            mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim),
            nn.Linear(self.t_feat_dim, self.t_feat_dim),
            self.act,
        )

        # ---- fusion tail ----
        if self.regression_head != 'Rx_Ry_and_T':
            raise NotImplementedError(
                f"regression_head {self.regression_head!r} is not supported"
            )
        if self.pose_mode != 'rot_matrix':
            raise NotImplementedError(
                f"pose_mode {self.pose_mode!r} is not supported with "
                f"regression_head 'Rx_Ry_and_T'"
            )
        if self.per_point_feature:
            raise NotImplementedError("per_point_feature=True is not supported")

        # Width of the vector produced by torch.cat in forward().
        fused_dim = self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim
        # Heads are created in a fixed order (rot_x, rot_y, trans) so that
        # parameter registration and random-init consumption stay stable.
        # rotation x-axis regression head
        self.fusion_tail_rot_x = self._make_head(fused_dim)
        # rotation y-axis regression head
        self.fusion_tail_rot_y = self._make_head(fused_dim)
        # translation regression head
        self.fusion_tail_trans = self._make_head(fused_dim)

    def _make_head(self, in_dim, hidden_dim=256, out_dim=3):
        """Return one Linear -> ReLU -> zero-initialised Linear regression head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            self.act,
            # The last layer starts at zero, so each head initially outputs 0.
            zero_module(nn.Linear(hidden_dim, out_dim)),
        )

    def forward(self, data):
        """
        Predict the score for a batch of poses at diffusion time ``t``.

        Args:
            data, dict {
                'seq_feat': [bs, c]
                'sampled_pose': [bs, pose_dim]
                't': [bs, 1]
            }

        Returns:
            Tensor formed by concatenating the rot_x, rot_y and trans head
            outputs along the last dim (3 + 3 + 3 values per sample), divided
            by ``std + 1e-7``.

        Raises:
            NotImplementedError: for per-point features or an unsupported
                regression head.
        """
        seq_feat = data['seq_feat']
        sampled_pose = data['sampled_pose']
        t = data['t']

        # The time encoder is fed t with its dim 1 squeezed away.
        t_feat = self.t_encoder(t.squeeze(1))
        pose_feat = self.pose_encoder(sampled_pose)

        if self.per_point_feature:
            raise NotImplementedError("per_point_feature=True is not supported")
        total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)

        # Only the std is used; the first return value is discarded.
        _, std = self.marginal_prob_fn(total_feat, t)

        if self.regression_head != 'Rx_Ry_and_T':
            raise NotImplementedError(
                f"regression_head {self.regression_head!r} is not supported"
            )
        rot_x = self.fusion_tail_rot_x(total_feat)
        rot_y = self.fusion_tail_rot_y(total_feat)
        trans = self.fusion_tail_trans(total_feat)
        # normalisation; the epsilon guards against division by a zero std
        out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std + 1e-7)
        return out_score

    def marginal_prob(self, x, t):
        """Forward ``(x, t)`` to the SDE's marginal probability function."""
        return self.marginal_prob_fn(x, t)

    def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
        """
        Draw samples with ``flib.cond_ode_sampler`` using this module as the
        score model.

        Args:
            data: conditioning dict passed through to the sampler.
            atol, rtol: tolerances forwarded to the sampler.
            snr: accepted for interface compatibility; not used by the 'ode'
                branch.
            denoise: forwarded to the sampler.
            init_x: optional initial sample forwarded to the sampler.
            T0: end time forwarded as ``T``; defaults to ``self.T`` when None.

        Returns:
            ``(in_process_sample, res)`` exactly as returned by the sampler.

        Raises:
            NotImplementedError: if ``sample_mode`` is not 'ode'.
        """
        if self.sample_mode != 'ode':
            raise NotImplementedError(
                f"sample_mode {self.sample_mode!r} is not supported"
            )
        T0 = self.T if T0 is None else T0
        in_process_sample, res = flib.cond_ode_sampler(
            score_model=self,
            data=data,
            prior=self.prior_fn,
            sde_coeff=self.sde_fn,
            atol=atol,
            rtol=rtol,
            eps=self.sampling_eps,
            T=T0,
            num_steps=self.sampling_steps,
            pose_mode=self.pose_mode,
            denoise=denoise,
            init_x=init_x,
        )
        return in_process_sample, res

    def next_best_view(self, seq_feat):
        """
        Sample poses conditioned on ``seq_feat`` with default sampler settings.

        Returns:
            ``(res, in_process_sample)`` where ``res`` is cast to float32.
            Note the order is swapped relative to ``sample``.
        """
        data = {
            'seq_feat': seq_feat,
        }
        in_process_sample, res = self.sample(data)
        return res.to(dtype=torch.float32), in_process_sample
# ----------- DEBUG -----------
if __name__ == "__main__":
    # Smoke test: build the module, run one forward pass and one sampling call.
    # Fall back to CPU so the script can at least start without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    batch_size = 32
    # Must equal the width of test_seq_feat, since forward() concatenates it
    # with the time and pose features before the regression heads.
    main_feat_dim = 2048

    config = {
        "regression_head": "Rx_Ry_and_T",
        "per_point_feature": False,
        "pose_mode": "rot_matrix",
        "sde_mode": "ve",
        "sampling_steps": 500,
        "sample_mode": "ode",
        # The three keys below are read unconditionally by __init__; without
        # them construction fails with KeyError. The values are debug-only
        # choices, not tuned settings.
        "t_feat_dim": 128,
        "pose_feat_dim": 256,
        "main_feat_dim": main_feat_dim,
    }

    test_seq_feat = torch.rand(batch_size, main_feat_dim).to(device)
    # Width 9 equals the three concatenated 3-d head outputs in forward();
    # assumes PoseUtil.get_pose_dim('rot_matrix') is also 9 -- TODO confirm.
    test_pose = torch.rand(batch_size, 9).to(device)
    test_t = torch.rand(batch_size, 1).to(device)

    view_finder = GradientFieldViewFinder(config).to(device)
    test_data = {
        'seq_feat': test_seq_feat,
        'sampled_pose': test_pose,
        't': test_t,
    }
    score = view_finder(test_data)
    print(score.shape)

    res, inprocess = view_finder.next_best_view(test_seq_feat)
    print(res.shape, inprocess.shape)