diff --git a/modules/func_lib/samplers.py b/modules/func_lib/samplers.py index cdefdcc..b0274fa 100644 --- a/modules/func_lib/samplers.py +++ b/modules/func_lib/samplers.py @@ -7,170 +7,37 @@ from utils.pose import PoseUtil def global_prior_likelihood(z, sigma_max): """The likelihood of a Gaussian distribution with mean zero and - standard deviation sigma.""" + standard deviation sigma.""" # z: [bs, pose_dim] shape = z.shape N = np.prod(shape[1:]) # pose_dim - return -N / 2. * torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2) - - -def cond_ode_likelihood( - score_model, - data, - prior, - sde_coeff, - marginal_prob_fn, - atol=1e-5, - rtol=1e-5, - device='cuda', - eps=1e-5, - num_steps=None, - pose_mode='quat_wxyz', - init_x=None, -): - pose_dim = PoseUtil.get_pose_dim(pose_mode) - batch_size = data['pts'].shape[0] - epsilon = prior((batch_size, pose_dim)).to(device) - init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x - shape = init_x.shape - init_logp = np.zeros((shape[0],)) # [bs] - init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0) - - def score_eval_wrapper(data): - """A wrapper of the score-based model for use by the ODE solver.""" - with torch.no_grad(): - score = score_model(data) - return score.cpu().numpy().reshape((-1,)) - - def divergence_eval(data, epsilon): - """Compute the divergence of the score-based model with Skilling-Hutchinson.""" - # save ckpt of sampled_pose - origin_sampled_pose = data['sampled_pose'].clone() - with torch.enable_grad(): - # make sampled_pose differentiable - data['sampled_pose'].requires_grad_(True) - score = score_model(data) - score_energy = torch.sum(score * epsilon) # [, ] - grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim] - # reset sampled_pose - data['sampled_pose'] = origin_sampled_pose - return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1] - - def divergence_eval_wrapper(data): - """A wrapper for evaluating the divergence of score for the black-box ODE solver.""" - with torch.no_grad(): - # Compute likelihood. 
- div = divergence_eval(data, epsilon) # [bs, 1] - return div.cpu().numpy().reshape((-1,)).astype(np.float64) - - def ode_func(t, inp): - """The ODE function for use by the ODE solver.""" - # split x, logp from inp - x = inp[:-shape[0]] - # calc x-grad - x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) - time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t - drift, diffusion = sde_coeff(torch.tensor(t)) - drift = drift.cpu().numpy() - diffusion = diffusion.cpu().numpy() - data['sampled_pose'] = x - data['t'] = time_steps - x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) - # calc logp-grad - logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data) - # concat curr grad - return np.concatenate([x_grad, logp_grad], axis=0) - - # Run the black-box ODE solver, note the - res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45') - zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)] - z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim] - delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp - _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1 - prior_logp = global_prior_likelihood(z, sigma_max) - log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls) - return z, log_likelihoods - - -def cond_pc_sampler( - score_model, - data, - prior, - sde_coeff, - num_steps=500, - snr=0.16, - device='cuda', - eps=1e-5, - pose_mode='quat_wxyz', - init_x=None, -): - pose_dim = PoseUtil.get_pose_dim(pose_mode) - batch_size = data['target_pts_feat'].shape[0] - init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x - time_steps = torch.linspace(1., eps, num_steps, device=device) - step_size = time_steps[0] - time_steps[1] - noise_norm = np.sqrt(pose_dim) - x = init_x - poses = [] - with torch.no_grad(): - for time_step in time_steps: - batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step - # Corrector step (Langevin MCMC) - data['sampled_pose'] = x - data['t'] = batch_time_step - grad = score_model(data) - grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean() - langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2 - x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x) - - # normalisation - if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw': - # quat, should be normalised - x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True) - elif pose_mode == 'euler_xyz': - pass - else: - # rotation(x axis, y axis), should be normalised - x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True) - x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True) - - # Predictor step (Euler-Maruyama) - drift, diffusion = sde_coeff(batch_time_step) - drift = drift - diffusion ** 2 * grad # R-SDE - mean_x = x + drift * step_size - x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x) - - # normalisation - x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode) - poses.append(x.unsqueeze(0)) - - xs = torch.cat(poses, dim=0) - xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) - mean_x[:, -3:] += data['pts_center'] - mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode) - # The last step does not include any noise - return xs.permute(1, 0, 2), mean_x + return -N / 2.0 * torch.log(2 * np.pi * sigma_max**2) - torch.sum( + z**2, dim=-1 + ) / (2 * sigma_max**2) 
def cond_ode_sampler( - score_model, - data, - prior, - sde_coeff, - atol=1e-5, - rtol=1e-5, - device='cuda', - eps=1e-5, - T=1.0, - num_steps=None, - pose_mode='quat_wxyz', - denoise=True, - init_x=None, + score_model, + data, + prior, + sde_coeff, + atol=1e-5, + rtol=1e-5, + device="cuda", + eps=1e-5, + T=1.0, + num_steps=None, + pose_mode="quat_wxyz", + denoise=True, + init_x=None, ): pose_dim = PoseUtil.get_pose_dim(pose_mode) - batch_size = data['target_feat'].shape[0] - init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim), - T=T).to(device) + batch_size = data["seq_feat"].shape[0] + init_x = ( + prior((batch_size, pose_dim), T=T).to(device) + if init_x is None + else init_x + prior((batch_size, pose_dim), T=T).to(device) + ) shape = init_x.shape def score_eval_wrapper(data): @@ -186,28 +53,37 @@ def cond_ode_sampler( drift, diffusion = sde_coeff(torch.tensor(t)) drift = drift.cpu().numpy() diffusion = diffusion.cpu().numpy() - data['sampled_pose'] = x - data['t'] = time_steps - return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) + data["sampled_pose"] = x + data["t"] = time_steps + return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data) # Run the black-box ODE solver, note the t_eval = None if num_steps is not None: # num_steps, from T -> eps t_eval = np.linspace(T, eps, num_steps) - res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', - t_eval=t_eval) - xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim] + res = integrate.solve_ivp( + ode_func, + (T, eps), + init_x.reshape(-1).cpu().numpy(), + rtol=rtol, + atol=atol, + method="RK45", + t_eval=t_eval, + ) + xs = torch.tensor(res.y, device=device).T.view( + -1, batch_size, pose_dim + ) # [num_steps, bs, pose_dim] x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim] # denoise, using the predictor step in P-C sampler if denoise: # Reverse diffusion predictor for denoising vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps drift, diffusion = sde_coeff(vec_eps) - data['sampled_pose'] = x.float() - data['t'] = vec_eps + data["sampled_pose"] = x.float() + data["t"] = vec_eps grad = score_model(data) - drift = drift - diffusion ** 2 * grad # R-SDE + drift = drift - diffusion**2 * grad # R-SDE mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps)) x = mean_x @@ -217,64 +93,3 @@ def cond_ode_sampler( xs = xs.reshape(num_steps, batch_size, -1) x = PoseUtil.normalize_rotation(x, pose_mode) return xs.permute(1, 0, 2), x - - -def cond_edm_sampler( - decoder_model, data, prior_fn, randn_like=torch.randn_like, - num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, - S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, - pose_mode='quat_wxyz', device='cuda' -): - pose_dim = PoseUtil.get_pose_dim(pose_mode) - batch_size = data['pts'].shape[0] - latents = prior_fn((batch_size, pose_dim)).to(device) - - # Time step discretion. 
note that sigma and t is interchangeable
-    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
-    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
-            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
-    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0
-
-    def decoder_wrapper(decoder, data, x, t):
-        # save temp
-        x_, t_ = data['sampled_pose'], data['t']
-        # init data
-        data['sampled_pose'], data['t'] = x, t
-        # denoise
-        data, denoised = decoder(data)
-        # recover data
-        data['sampled_pose'], data['t'] = x_, t_
-        return denoised.to(torch.float64)
-
-    # Main sampling loop.
-    x_next = latents.to(torch.float64) * t_steps[0]
-    xs = []
-    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
-        x_cur = x_next
-
-        # Increase noise temporarily.
-        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
-        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
-        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
-
-        # Euler step.
-        denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
-        d_cur = (x_hat - denoised) / t_hat
-        x_next = x_hat + (t_next - t_hat) * d_cur
-
-        # Apply 2nd order correction.
-        if i < num_steps - 1:
-            denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
-            d_prime = (x_next - denoised) / t_next
-            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
-        xs.append(x_next.unsqueeze(0))
-
-    xs = torch.stack(xs, dim=0)  # [num_steps, bs, pose_dim]
-    x = xs[-1]  # [bs, pose_dim]
-
-    # post-processing
-    xs = xs.reshape(batch_size * num_steps, -1)
-    xs = PoseUtil.normalize_rotation(xs, pose_mode)
-    xs = xs.reshape(num_steps, batch_size, -1)
-    x = PoseUtil.normalize_rotation(x, pose_mode)
-    return xs.permute(1, 0, 2), x
diff --git a/modules/seq_encoder/transformer_seq_encoder.py b/modules/seq_encoder/transformer_seq_encoder.py
index 42af3b4..c318a2a 100644
--- a/modules/seq_encoder/transformer_seq_encoder.py
+++ b/modules/seq_encoder/transformer_seq_encoder.py
@@ -1,10 +1,47 @@
+import torch
 from torch import nn
+
 import PytorchBoot.stereotype as stereotype
+import sys; sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
+from modules.seq_encoder.abstract_seq_encoder import SequenceEncoder
 
 
 @stereotype.module("transformer_seq_encoder")
-class TransformerSequenceEncoder(nn.Module):
+class TransformerSequenceEncoder(SequenceEncoder):
     def __init__(self, config):
         super(TransformerSequenceEncoder, self).__init__()
         self.config = config
+        embed_dim = config['pts_embed_dim'] + config['pose_embed_dim']
+        self.positional_encoding = nn.Parameter(torch.zeros(1, config['max_seq_len'], embed_dim))  # learned positional encoding over the view sequence
+        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=config['num_heads'], dim_feedforward=config['ffn_dim'], batch_first=True)  # batch_first=True: attention runs over the view sequence, not a length-1 sequence
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config['num_layers'])
+        self.fc = nn.Linear(embed_dim, config['output_dim'])
 
     def encode_sequence(self, pts_embedding_list, pose_embedding_list):
-        pass
+        combined_features = [torch.cat((pts_embed, pose_embed), dim=-1) for pts_embed, pose_embed in zip(pts_embedding_list[:-1], pose_embedding_list[:-1])]  # fuse pts/pose per view; the last (current) view is excluded
+        combined_tensor = torch.stack(combined_features)  # [seq_len-1, embed_dim]
+        pos_encoding = self.positional_encoding[:, :combined_tensor.size(0), :]
+        combined_tensor = combined_tensor.unsqueeze(0) + pos_encoding  # [1, seq_len-1, embed_dim]
+        transformer_output = self.transformer_encoder(combined_tensor).squeeze(0)  # [seq_len-1, embed_dim]
+        final_feature = transformer_output.mean(dim=0)  # [embed_dim], mean-pooled over views
+        final_output = self.fc(final_feature)
+
+        return final_output
+
+if __name__ == "__main__":
+    config = {
+        'pts_embed_dim': 1024,  # dimension of each point-cloud embedding
+        'pose_embed_dim': 256,  # dimension of each pose embedding
+        'num_heads': 4,         # number of attention heads
+        'ffn_dim': 256,         # feed-forward network dimension
+        'num_layers': 3,        # number of Transformer encoder layers
+        'max_seq_len': 10,      # maximum sequence length
+        'output_dim': 2048,     # output feature dimension
+    }
+
+    encoder = TransformerSequenceEncoder(config)
+    seq_len = 5
+    pts_embedding_list = [torch.randn(config['pts_embed_dim']) for _ in range(seq_len)]
+    pose_embedding_list = [torch.randn(config['pose_embed_dim']) for _ in range(seq_len)]
+    output_feature = encoder.encode_sequence(pts_embedding_list, pose_embedding_list)
+    print("Encoded Feature:", output_feature)
+    print("Feature Shape:", output_feature.shape)
\ No newline at end of file
diff --git a/modules/view_finder/gf_view_finder.py b/modules/view_finder/gf_view_finder.py
index 47b2b6c..ea7b037 100644
--- a/modules/view_finder/gf_view_finder.py
+++ b/modules/view_finder/gf_view_finder.py
@@ -54,12 +54,12 @@ class GradientFieldViewFinder(ViewFinder):
         if not self.per_point_feature:
             ''' rotation_x_axis regress head '''
             self.fusion_tail_rot_x = nn.Sequential(
-                nn.Linear(128 + 256 + 1024 + 1024, 256),
+                nn.Linear(128 + 256 + 2048, 256),
                 self.act,
                 zero_module(nn.Linear(256, 3)),
             )
             self.fusion_tail_rot_y = nn.Sequential(
-                nn.Linear(128 + 256 + 1024 + 1024, 256),
+                nn.Linear(128 + 256 + 2048, 256),
                 self.act,
                 zero_module(nn.Linear(256, 3)),
             )
@@ -72,15 +72,13 @@ class GradientFieldViewFinder(ViewFinder):
         """
         Args:
             data, dict {
-                'target_pts_feat': [bs, c]
-                'scene_pts_feat': [bs, c]
+                'seq_feat': [bs, c]
                 'pose_sample': [bs, pose_dim]
                 't': [bs, 1]
             }
         """
-        scene_pts_feat = data['scene_feat']
-        target_pts_feat = data['target_feat']
+        seq_feat = data['seq_feat']
         sampled_pose = data['sampled_pose']
         t = data['t']
         t_feat = self.t_encoder(t.squeeze(1))
@@ -89,7 +87,7 @@ class GradientFieldViewFinder(ViewFinder):
         if self.per_point_feature:
             raise NotImplementedError
         else:
-            total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
+            total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
         _, std = self.marginal_prob_fn(total_feat, t)
 
         if self.regression_head == 'Rx_Ry':
@@ -106,20 +104,7 @@
     def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
 
-        if self.sample_mode == 'pc':
-            in_process_sample, res = flib.cond_pc_sampler(
-                score_model=self,
-                data=data,
-                prior=self.prior_fn,
-                sde_coeff=self.sde_fn,
-                num_steps=self.sampling_steps,
-                snr=snr,
-                eps=self.sampling_eps,
-                pose_mode=self.pose_mode,
-                init_x=init_x
-            )
-
-        elif self.sample_mode == 'ode':
+        if self.sample_mode == 'ode':
             T0 = self.T if T0 is None else T0
             in_process_sample, res = flib.cond_ode_sampler(
                 score_model=self,
@@ -140,10 +125,9 @@ class GradientFieldViewFinder(ViewFinder):
 
         return in_process_sample, res
 
-    def next_best_view(self, scene_pts_feat, target_pts_feat):
+    def next_best_view(self, seq_feat):
         data = {
-            'scene_feat': scene_pts_feat,
-            'target_feat': target_pts_feat,
+            'seq_feat': seq_feat,
         }
         in_process_sample, res = self.sample(data)
         return res.to(dtype=torch.float32), in_process_sample
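
A minimal usage sketch tying the three files together (not part of the patch): the sequence encoder fuses per-view point-cloud and pose embeddings into a single `seq_feat`, whose `output_dim` (2048) must match the new fusion-tail input width 128 + 256 + 2048 in gf_view_finder.py, and `next_best_view` consumes it in place of the old `scene_feat`/`target_feat` pair. `encoder` follows the `__main__` smoke test above; `view_finder` is an assumed, already-constructed GradientFieldViewFinder whose setup is outside this diff.

# Hedged sketch under the assumptions stated above; not part of the patch.
import torch

seq_len = 5
pts_embedding_list = [torch.randn(1024) for _ in range(seq_len)]   # per-view point-cloud features
pose_embedding_list = [torch.randn(256) for _ in range(seq_len)]   # per-view pose features

seq_feat = encoder.encode_sequence(pts_embedding_list, pose_embedding_list)  # [2048]
seq_feat = seq_feat.unsqueeze(0)  # [1, 2048]; the view finder expects [bs, c]

# pose, in_process_sample = view_finder.next_best_view(seq_feat)
# `pose` would be [bs, pose_dim] in float32; `in_process_sample` holds the
# intermediate ODE states, shaped [bs, num_steps, pose_dim].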