add transformer seq encoder and add seq_feat in gf_view_finder
This commit is contained in:
parent
68b4325dbd
commit
b06dede4b8
@ -11,145 +11,9 @@ def global_prior_likelihood(z, sigma_max):
|
|||||||
# z: [bs, pose_dim]
|
# z: [bs, pose_dim]
|
||||||
shape = z.shape
|
shape = z.shape
|
||||||
N = np.prod(shape[1:]) # pose_dim
|
N = np.prod(shape[1:]) # pose_dim
|
||||||
return -N / 2. * torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2)
|
return -N / 2.0 * torch.log(2 * np.pi * sigma_max**2) - torch.sum(
|
||||||
|
z**2, dim=-1
|
||||||
|
) / (2 * sigma_max**2)
|
||||||
def cond_ode_likelihood(
|
|
||||||
score_model,
|
|
||||||
data,
|
|
||||||
prior,
|
|
||||||
sde_coeff,
|
|
||||||
marginal_prob_fn,
|
|
||||||
atol=1e-5,
|
|
||||||
rtol=1e-5,
|
|
||||||
device='cuda',
|
|
||||||
eps=1e-5,
|
|
||||||
num_steps=None,
|
|
||||||
pose_mode='quat_wxyz',
|
|
||||||
init_x=None,
|
|
||||||
):
|
|
||||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
|
||||||
batch_size = data['pts'].shape[0]
|
|
||||||
epsilon = prior((batch_size, pose_dim)).to(device)
|
|
||||||
init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x
|
|
||||||
shape = init_x.shape
|
|
||||||
init_logp = np.zeros((shape[0],)) # [bs]
|
|
||||||
init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)
|
|
||||||
|
|
||||||
def score_eval_wrapper(data):
|
|
||||||
"""A wrapper of the score-based model for use by the ODE solver."""
|
|
||||||
with torch.no_grad():
|
|
||||||
score = score_model(data)
|
|
||||||
return score.cpu().numpy().reshape((-1,))
|
|
||||||
|
|
||||||
def divergence_eval(data, epsilon):
|
|
||||||
"""Compute the divergence of the score-based model with Skilling-Hutchinson."""
|
|
||||||
# save ckpt of sampled_pose
|
|
||||||
origin_sampled_pose = data['sampled_pose'].clone()
|
|
||||||
with torch.enable_grad():
|
|
||||||
# make sampled_pose differentiable
|
|
||||||
data['sampled_pose'].requires_grad_(True)
|
|
||||||
score = score_model(data)
|
|
||||||
score_energy = torch.sum(score * epsilon) # [, ]
|
|
||||||
grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim]
|
|
||||||
# reset sampled_pose
|
|
||||||
data['sampled_pose'] = origin_sampled_pose
|
|
||||||
return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1]
|
|
||||||
|
|
||||||
def divergence_eval_wrapper(data):
|
|
||||||
"""A wrapper for evaluating the divergence of score for the black-box ODE solver."""
|
|
||||||
with torch.no_grad():
|
|
||||||
# Compute likelihood.
|
|
||||||
div = divergence_eval(data, epsilon) # [bs, 1]
|
|
||||||
return div.cpu().numpy().reshape((-1,)).astype(np.float64)
|
|
||||||
|
|
||||||
def ode_func(t, inp):
|
|
||||||
"""The ODE function for use by the ODE solver."""
|
|
||||||
# split x, logp from inp
|
|
||||||
x = inp[:-shape[0]]
|
|
||||||
# calc x-grad
|
|
||||||
x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
|
|
||||||
time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
|
|
||||||
drift, diffusion = sde_coeff(torch.tensor(t))
|
|
||||||
drift = drift.cpu().numpy()
|
|
||||||
diffusion = diffusion.cpu().numpy()
|
|
||||||
data['sampled_pose'] = x
|
|
||||||
data['t'] = time_steps
|
|
||||||
x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
|
|
||||||
# calc logp-grad
|
|
||||||
logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data)
|
|
||||||
# concat curr grad
|
|
||||||
return np.concatenate([x_grad, logp_grad], axis=0)
|
|
||||||
|
|
||||||
# Run the black-box ODE solver, note the
|
|
||||||
res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
|
|
||||||
zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)]
|
|
||||||
z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim]
|
|
||||||
delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp
|
|
||||||
_, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1
|
|
||||||
prior_logp = global_prior_likelihood(z, sigma_max)
|
|
||||||
log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls)
|
|
||||||
return z, log_likelihoods
|
|
||||||
|
|
||||||
|
|
||||||
def cond_pc_sampler(
|
|
||||||
score_model,
|
|
||||||
data,
|
|
||||||
prior,
|
|
||||||
sde_coeff,
|
|
||||||
num_steps=500,
|
|
||||||
snr=0.16,
|
|
||||||
device='cuda',
|
|
||||||
eps=1e-5,
|
|
||||||
pose_mode='quat_wxyz',
|
|
||||||
init_x=None,
|
|
||||||
):
|
|
||||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
|
||||||
batch_size = data['target_pts_feat'].shape[0]
|
|
||||||
init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
|
|
||||||
time_steps = torch.linspace(1., eps, num_steps, device=device)
|
|
||||||
step_size = time_steps[0] - time_steps[1]
|
|
||||||
noise_norm = np.sqrt(pose_dim)
|
|
||||||
x = init_x
|
|
||||||
poses = []
|
|
||||||
with torch.no_grad():
|
|
||||||
for time_step in time_steps:
|
|
||||||
batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
|
|
||||||
# Corrector step (Langevin MCMC)
|
|
||||||
data['sampled_pose'] = x
|
|
||||||
data['t'] = batch_time_step
|
|
||||||
grad = score_model(data)
|
|
||||||
grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
|
|
||||||
langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
|
|
||||||
x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
|
|
||||||
|
|
||||||
# normalisation
|
|
||||||
if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
|
|
||||||
# quat, should be normalised
|
|
||||||
x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
|
|
||||||
elif pose_mode == 'euler_xyz':
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# rotation(x axis, y axis), should be normalised
|
|
||||||
x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
|
|
||||||
x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Predictor step (Euler-Maruyama)
|
|
||||||
drift, diffusion = sde_coeff(batch_time_step)
|
|
||||||
drift = drift - diffusion ** 2 * grad # R-SDE
|
|
||||||
mean_x = x + drift * step_size
|
|
||||||
x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)
|
|
||||||
|
|
||||||
# normalisation
|
|
||||||
x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
|
|
||||||
poses.append(x.unsqueeze(0))
|
|
||||||
|
|
||||||
xs = torch.cat(poses, dim=0)
|
|
||||||
xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
|
|
||||||
mean_x[:, -3:] += data['pts_center']
|
|
||||||
mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode)
|
|
||||||
# The last step does not include any noise
|
|
||||||
return xs.permute(1, 0, 2), mean_x
|
|
||||||
|
|
||||||
|
|
||||||
def cond_ode_sampler(
|
def cond_ode_sampler(
|
||||||
@ -159,18 +23,21 @@ def cond_ode_sampler(
|
|||||||
sde_coeff,
|
sde_coeff,
|
||||||
atol=1e-5,
|
atol=1e-5,
|
||||||
rtol=1e-5,
|
rtol=1e-5,
|
||||||
device='cuda',
|
device="cuda",
|
||||||
eps=1e-5,
|
eps=1e-5,
|
||||||
T=1.0,
|
T=1.0,
|
||||||
num_steps=None,
|
num_steps=None,
|
||||||
pose_mode='quat_wxyz',
|
pose_mode="quat_wxyz",
|
||||||
denoise=True,
|
denoise=True,
|
||||||
init_x=None,
|
init_x=None,
|
||||||
):
|
):
|
||||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||||
batch_size = data['target_feat'].shape[0]
|
batch_size = data["seq_feat"].shape[0]
|
||||||
init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim),
|
init_x = (
|
||||||
T=T).to(device)
|
prior((batch_size, pose_dim), T=T).to(device)
|
||||||
|
if init_x is None
|
||||||
|
else init_x + prior((batch_size, pose_dim), T=T).to(device)
|
||||||
|
)
|
||||||
shape = init_x.shape
|
shape = init_x.shape
|
||||||
|
|
||||||
def score_eval_wrapper(data):
|
def score_eval_wrapper(data):
|
||||||
@ -186,8 +53,8 @@ def cond_ode_sampler(
|
|||||||
drift, diffusion = sde_coeff(torch.tensor(t))
|
drift, diffusion = sde_coeff(torch.tensor(t))
|
||||||
drift = drift.cpu().numpy()
|
drift = drift.cpu().numpy()
|
||||||
diffusion = diffusion.cpu().numpy()
|
diffusion = diffusion.cpu().numpy()
|
||||||
data['sampled_pose'] = x
|
data["sampled_pose"] = x
|
||||||
data['t'] = time_steps
|
data["t"] = time_steps
|
||||||
return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
|
return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
|
||||||
|
|
||||||
# Run the black-box ODE solver, note the
|
# Run the black-box ODE solver, note the
|
||||||
@ -195,17 +62,26 @@ def cond_ode_sampler(
|
|||||||
if num_steps is not None:
|
if num_steps is not None:
|
||||||
# num_steps, from T -> eps
|
# num_steps, from T -> eps
|
||||||
t_eval = np.linspace(T, eps, num_steps)
|
t_eval = np.linspace(T, eps, num_steps)
|
||||||
res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45',
|
res = integrate.solve_ivp(
|
||||||
t_eval=t_eval)
|
ode_func,
|
||||||
xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
|
(T, eps),
|
||||||
|
init_x.reshape(-1).cpu().numpy(),
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol,
|
||||||
|
method="RK45",
|
||||||
|
t_eval=t_eval,
|
||||||
|
)
|
||||||
|
xs = torch.tensor(res.y, device=device).T.view(
|
||||||
|
-1, batch_size, pose_dim
|
||||||
|
) # [num_steps, bs, pose_dim]
|
||||||
x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
|
x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
|
||||||
# denoise, using the predictor step in P-C sampler
|
# denoise, using the predictor step in P-C sampler
|
||||||
if denoise:
|
if denoise:
|
||||||
# Reverse diffusion predictor for denoising
|
# Reverse diffusion predictor for denoising
|
||||||
vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
|
vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
|
||||||
drift, diffusion = sde_coeff(vec_eps)
|
drift, diffusion = sde_coeff(vec_eps)
|
||||||
data['sampled_pose'] = x.float()
|
data["sampled_pose"] = x.float()
|
||||||
data['t'] = vec_eps
|
data["t"] = vec_eps
|
||||||
grad = score_model(data)
|
grad = score_model(data)
|
||||||
drift = drift - diffusion**2 * grad # R-SDE
|
drift = drift - diffusion**2 * grad # R-SDE
|
||||||
mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
|
mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
|
||||||
@ -217,64 +93,3 @@ def cond_ode_sampler(
|
|||||||
xs = xs.reshape(num_steps, batch_size, -1)
|
xs = xs.reshape(num_steps, batch_size, -1)
|
||||||
x = PoseUtil.normalize_rotation(x, pose_mode)
|
x = PoseUtil.normalize_rotation(x, pose_mode)
|
||||||
return xs.permute(1, 0, 2), x
|
return xs.permute(1, 0, 2), x
|
||||||
|
|
||||||
|
|
||||||
def cond_edm_sampler(
|
|
||||||
decoder_model, data, prior_fn, randn_like=torch.randn_like,
|
|
||||||
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
|
|
||||||
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
|
||||||
pose_mode='quat_wxyz', device='cuda'
|
|
||||||
):
|
|
||||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
|
||||||
batch_size = data['pts'].shape[0]
|
|
||||||
latents = prior_fn((batch_size, pose_dim)).to(device)
|
|
||||||
|
|
||||||
# Time step discretion. note that sigma and t is interchangeable
|
|
||||||
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
|
||||||
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
|
||||||
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
|
||||||
t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
|
|
||||||
|
|
||||||
def decoder_wrapper(decoder, data, x, t):
|
|
||||||
# save temp
|
|
||||||
x_, t_ = data['sampled_pose'], data['t']
|
|
||||||
# init data
|
|
||||||
data['sampled_pose'], data['t'] = x, t
|
|
||||||
# denoise
|
|
||||||
data, denoised = decoder(data)
|
|
||||||
# recover data
|
|
||||||
data['sampled_pose'], data['t'] = x_, t_
|
|
||||||
return denoised.to(torch.float64)
|
|
||||||
|
|
||||||
# Main sampling loop.
|
|
||||||
x_next = latents.to(torch.float64) * t_steps[0]
|
|
||||||
xs = []
|
|
||||||
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
|
|
||||||
x_cur = x_next
|
|
||||||
|
|
||||||
# Increase noise temporarily.
|
|
||||||
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
|
|
||||||
t_hat = torch.as_tensor(t_cur + gamma * t_cur)
|
|
||||||
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
|
|
||||||
|
|
||||||
# Euler step.
|
|
||||||
denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
|
|
||||||
d_cur = (x_hat - denoised) / t_hat
|
|
||||||
x_next = x_hat + (t_next - t_hat) * d_cur
|
|
||||||
|
|
||||||
# Apply 2nd order correction.
|
|
||||||
if i < num_steps - 1:
|
|
||||||
denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
|
|
||||||
d_prime = (x_next - denoised) / t_next
|
|
||||||
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
|
||||||
xs.append(x_next.unsqueeze(0))
|
|
||||||
|
|
||||||
xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim]
|
|
||||||
x = xs[-1] # [bs, pose_dim]
|
|
||||||
|
|
||||||
# post-processing
|
|
||||||
xs = xs.reshape(batch_size * num_steps, -1)
|
|
||||||
xs = PoseUtil.normalize_rotation(xs, pose_mode)
|
|
||||||
xs = xs.reshape(num_steps, batch_size, -1)
|
|
||||||
x = PoseUtil.normalize_rotation(x, pose_mode)
|
|
||||||
return xs.permute(1, 0, 2), x
|
|
||||||
|
@ -1,10 +1,47 @@
|
|||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import PytorchBoot.stereotype as stereotype
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
import sys; sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
|
||||||
|
from modules.seq_encoder.abstract_seq_encoder import SequenceEncoder
|
||||||
|
|
||||||
@stereotype.module("transformer_seq_encoder")
|
@stereotype.module("transformer_seq_encoder")
|
||||||
class TransformerSequenceEncoder(nn.Module):
|
class TransformerSequenceEncoder(SequenceEncoder):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(TransformerSequenceEncoder, self).__init__()
|
super(TransformerSequenceEncoder, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
embed_dim = config['pts_embed_dim'] + config['pose_embed_dim']
|
||||||
|
self.positional_encoding = nn.Parameter(torch.zeros(1, config['max_seq_len'], embed_dim))
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=config['num_heads'], dim_feedforward=config['ffn_dim'])
|
||||||
|
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config['num_layers'])
|
||||||
|
self.fc = nn.Linear(embed_dim, config['output_dim'])
|
||||||
|
|
||||||
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||||
pass
|
combined_features = [torch.cat((pts_embed, pose_embed), dim=-1) for pts_embed, pose_embed in zip(pts_embedding_list[:-1], pose_embedding_list[:-1])]
|
||||||
|
combined_tensor = torch.stack(combined_features)
|
||||||
|
pos_encoding = self.positional_encoding[:, :combined_tensor.size(0), :]
|
||||||
|
combined_tensor = combined_tensor.unsqueeze(0) + pos_encoding
|
||||||
|
transformer_output = self.transformer_encoder(combined_tensor).squeeze(0)
|
||||||
|
final_feature = transformer_output.mean(dim=0)
|
||||||
|
final_output = self.fc(final_feature)
|
||||||
|
|
||||||
|
return final_output
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config = {
|
||||||
|
'pts_embed_dim': 1024, # 每个点云embedding的维度
|
||||||
|
'pose_embed_dim': 256, # 每个姿态embedding的维度
|
||||||
|
'num_heads': 4, # 多头注意力机制的头数
|
||||||
|
'ffn_dim': 256, # 前馈神经网络的维度
|
||||||
|
'num_layers': 3, # Transformer 编码层数
|
||||||
|
'max_seq_len': 10, # 最大序列长度
|
||||||
|
'output_dim': 2048, # 输出特征维度
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder = TransformerSequenceEncoder(config)
|
||||||
|
seq_len = 5
|
||||||
|
pts_embedding_list = [torch.randn(config['pts_embed_dim']) for _ in range(seq_len)]
|
||||||
|
pose_embedding_list = [torch.randn(config['pose_embed_dim']) for _ in range(seq_len)]
|
||||||
|
output_feature = encoder.encode_sequence(pts_embedding_list, pose_embedding_list)
|
||||||
|
print("Encoded Feature:", output_feature)
|
||||||
|
print("Feature Shape:", output_feature.shape)
|
@ -54,12 +54,12 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
if not self.per_point_feature:
|
if not self.per_point_feature:
|
||||||
''' rotation_x_axis regress head '''
|
''' rotation_x_axis regress head '''
|
||||||
self.fusion_tail_rot_x = nn.Sequential(
|
self.fusion_tail_rot_x = nn.Sequential(
|
||||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
nn.Linear(128 + 256 + 2048, 256),
|
||||||
self.act,
|
self.act,
|
||||||
zero_module(nn.Linear(256, 3)),
|
zero_module(nn.Linear(256, 3)),
|
||||||
)
|
)
|
||||||
self.fusion_tail_rot_y = nn.Sequential(
|
self.fusion_tail_rot_y = nn.Sequential(
|
||||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
nn.Linear(128 + 256 + 2048, 256),
|
||||||
self.act,
|
self.act,
|
||||||
zero_module(nn.Linear(256, 3)),
|
zero_module(nn.Linear(256, 3)),
|
||||||
)
|
)
|
||||||
@ -72,15 +72,13 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
data, dict {
|
data, dict {
|
||||||
'target_pts_feat': [bs, c]
|
'seq_feat': [bs, c]
|
||||||
'scene_pts_feat': [bs, c]
|
|
||||||
'pose_sample': [bs, pose_dim]
|
'pose_sample': [bs, pose_dim]
|
||||||
't': [bs, 1]
|
't': [bs, 1]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scene_pts_feat = data['scene_feat']
|
seq_feat = data['seq_feat']
|
||||||
target_pts_feat = data['target_feat']
|
|
||||||
sampled_pose = data['sampled_pose']
|
sampled_pose = data['sampled_pose']
|
||||||
t = data['t']
|
t = data['t']
|
||||||
t_feat = self.t_encoder(t.squeeze(1))
|
t_feat = self.t_encoder(t.squeeze(1))
|
||||||
@ -89,7 +87,7 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
if self.per_point_feature:
|
if self.per_point_feature:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
|
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
|
||||||
_, std = self.marginal_prob_fn(total_feat, t)
|
_, std = self.marginal_prob_fn(total_feat, t)
|
||||||
|
|
||||||
if self.regression_head == 'Rx_Ry':
|
if self.regression_head == 'Rx_Ry':
|
||||||
@ -106,20 +104,7 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
|
|
||||||
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
||||||
|
|
||||||
if self.sample_mode == 'pc':
|
if self.sample_mode == 'ode':
|
||||||
in_process_sample, res = flib.cond_pc_sampler(
|
|
||||||
score_model=self,
|
|
||||||
data=data,
|
|
||||||
prior=self.prior_fn,
|
|
||||||
sde_coeff=self.sde_fn,
|
|
||||||
num_steps=self.sampling_steps,
|
|
||||||
snr=snr,
|
|
||||||
eps=self.sampling_eps,
|
|
||||||
pose_mode=self.pose_mode,
|
|
||||||
init_x=init_x
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.sample_mode == 'ode':
|
|
||||||
T0 = self.T if T0 is None else T0
|
T0 = self.T if T0 is None else T0
|
||||||
in_process_sample, res = flib.cond_ode_sampler(
|
in_process_sample, res = flib.cond_ode_sampler(
|
||||||
score_model=self,
|
score_model=self,
|
||||||
@ -140,10 +125,9 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
|
|
||||||
return in_process_sample, res
|
return in_process_sample, res
|
||||||
|
|
||||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
def next_best_view(self, seq_feat):
|
||||||
data = {
|
data = {
|
||||||
'scene_feat': scene_pts_feat,
|
'seq_feat': seq_feat,
|
||||||
'target_feat': target_pts_feat,
|
|
||||||
}
|
}
|
||||||
in_process_sample, res = self.sample(data)
|
in_process_sample, res = self.sample(data)
|
||||||
return res.to(dtype=torch.float32), in_process_sample
|
return res.to(dtype=torch.float32), in_process_sample
|
||||||
|
Loading…
x
Reference in New Issue
Block a user