96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
from scipy import integrate
|
|
from utils.pose import PoseUtil
|
|
|
|
|
|
def global_prior_likelihood(z, sigma_max):
    """Return the log-likelihood of ``z`` under a zero-mean isotropic Gaussian.

    For each batch element this evaluates

        -N / 2 * log(2 * pi * sigma_max**2) - sum(z**2, dim=-1) / (2 * sigma_max**2)

    where ``N`` is the product of all non-batch dimensions of ``z``.

    Args:
        z: Tensor whose first dimension is the batch; the expected layout is
            ``[bs, pose_dim]``.
        sigma_max: Standard deviation of the Gaussian. Either a tensor or a
            plain Python/NumPy scalar.

    Returns:
        Tensor holding one log-likelihood per batch element (``z`` reduced
        over its last dimension).
    """
    if not torch.is_tensor(sigma_max):
        # torch.log rejects plain floats, so promote scalars to a tensor that
        # lives next to z. Tensors are passed through untouched.
        dtype = z.dtype if z.is_floating_point() else torch.get_default_dtype()
        sigma_max = torch.as_tensor(sigma_max, dtype=dtype, device=z.device)

    variance = sigma_max**2
    # z: [bs, pose_dim] -> n_dims == pose_dim
    n_dims = int(np.prod(z.shape[1:]))

    log_normalizer = -n_dims / 2.0 * torch.log(2 * np.pi * variance)
    quadratic_term = torch.sum(z**2, dim=-1) / (2 * variance)
    return log_normalizer - quadratic_term
|
|
|
|
|
|
def cond_ode_sampler(
    score_model,
    data,
    prior,
    sde_coeff,
    atol=1e-5,
    rtol=1e-5,
    device="cuda",
    eps=1e-5,
    T=1.0,
    num_steps=None,
    pose_mode="quat_wxyz",
    denoise=True,
    init_x=None,
):
    """Sample poses by integrating an ODE backwards in time from ``T`` to ``eps``.

    The right-hand side of the ODE is
    ``drift - 0.5 * diffusion**2 * score_model(data)`` and it is integrated
    with SciPy's adaptive RK45 solver.

    Side effects:
        ``data["sampled_pose"]`` and ``data["t"]`` are overwritten on every
        score evaluation, so the caller's ``data`` mapping is modified.

    Args:
        score_model: Callable taking ``data`` and returning the score, which
            is reshaped to ``[bs, pose_dim]`` when flattened for the solver.
        data: Mapping that contains at least ``"main_feat"``; its first
            dimension defines the batch size.
        prior: Callable ``prior(shape, T=T)`` returning the initial noise.
        sde_coeff: Callable ``sde_coeff(t)`` returning ``(drift, diffusion)``
            as tensors.
        atol: Absolute tolerance passed to the ODE solver.
        rtol: Relative tolerance passed to the ODE solver.
        device: Device on which tensors handed to ``score_model`` are created.
        eps: Final integration time.
        T: Initial integration time.
        num_steps: When given, the trajectory is reported at ``num_steps``
            evenly spaced times from ``T`` down to ``eps``; otherwise the
            solver chooses the reported times.
        pose_mode: Pose parameterisation forwarded to ``PoseUtil``.
        denoise: When true, apply one extra predictor step at ``t = eps``
            to the final sample.
        init_x: Optional starting pose; the prior sample is added on top of it.

    Returns:
        A tuple ``(xs, x)`` where ``xs`` is the trajectory laid out as
        ``[bs, num_reported_steps, pose_dim]`` and ``x`` is the final sample
        of shape ``[bs, pose_dim]``. In both, every component except the last
        three has been passed through ``PoseUtil.normalize_rotation``.
    """
    import warnings

    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data["main_feat"].shape[0]

    noise = prior((batch_size, pose_dim), T=T).to(device)
    init_x = noise if init_x is None else init_x + noise
    shape = init_x.shape

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        # The solver works on flat NumPy vectors.
        return score.cpu().numpy().reshape((-1,))

    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data["sampled_pose"] = x
        data["t"] = time_steps
        return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)

    # Run the black-box ODE solver. Note that the time span runs backwards,
    # from T down to eps.
    t_eval = None
    if num_steps is not None:
        # num_steps, from T -> eps
        t_eval = np.linspace(T, eps, num_steps)
    res = integrate.solve_ivp(
        ode_func,
        (T, eps),
        init_x.reshape(-1).cpu().numpy(),
        rtol=rtol,
        atol=atol,
        method="RK45",
        t_eval=t_eval,
    )
    if not res.success:
        # Keep returning whatever the solver produced (as before), but make
        # the failure visible instead of silently using a partial trajectory.
        warnings.warn(
            f"ODE solver did not finish successfully "
            f"(status={res.status}): {res.message}",
            RuntimeWarning,
            stacklevel=2,
        )

    # res.y is [bs * pose_dim, num_reported_steps]
    xs = torch.tensor(res.y, device=device).T.reshape(
        -1, batch_size, pose_dim
    )  # [num_steps, bs, pose_dim]
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]

    # denoise, using the predictor step in P-C sampler
    if denoise:
        # Reverse diffusion predictor for denoising. Run it without autograd,
        # consistent with the score evaluations made inside the solver.
        with torch.no_grad():
            vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
            drift, diffusion = sde_coeff(vec_eps)
            data["sampled_pose"] = x.float()
            data["t"] = vec_eps
            grad = score_model(data)
            reverse_drift = drift - diffusion**2 * grad  # R-SDE
            # NOTE(review): the step size uses a literal 1 rather than T;
            # confirm this is intended when T != 1.
            step_size = (1 - eps) / (1000 if num_steps is None else num_steps)
            x = x + reverse_drift * step_size

    # Normalise the rotation part of every trajectory entry; the last three
    # components are left untouched (presumably translation - verify).
    num_reported = xs.shape[0]
    xs = xs.reshape(batch_size * num_reported, -1)
    xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
    xs = xs.reshape(num_reported, batch_size, -1)

    x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
    return xs.permute(1, 0, 2), x
|