From 8f82bec66cd1d2cad23a88271d503fa1ad52400b Mon Sep 17 00:00:00 2001
From: hofee <64160135+GitHofee@users.noreply.github.com>
Date: Thu, 22 Aug 2024 00:41:30 +0800
Subject: [PATCH] batchify transformer and add forward_train/test in pipeline

---
 core/pipeline.py                              | 50 ++++++++++++++---
 .../seq_encoder/transformer_seq_encoder.py    | 53 ++++++++++++-------
 2 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/core/pipeline.py b/core/pipeline.py
index a949654..1771e52 100644
--- a/core/pipeline.py
+++ b/core/pipeline.py
@@ -1,4 +1,4 @@
-
+import torch
 from torch import nn
 import PytorchBoot.namespace as namespace
 import PytorchBoot.stereotype as stereotype
@@ -24,8 +24,43 @@ class NBVReconstructionPipeline(nn.Module):
         else:
             Log.error("Unknown mode: {}".format(mode), True)
 
+    def perturb_data(self, gt_delta_rot_6d):
+        bs = gt_delta_rot_6d.shape[0]
+        random_t = torch.rand(bs, device=self.device) * (1. - self.eps) + self.eps
+        random_t = random_t.unsqueeze(-1)
+        mu, std = self.view_finder.marginal_prob(gt_delta_rot_6d, random_t)
+        std = std.view(-1, 1)
+        z = torch.randn_like(gt_delta_rot_6d)
+        perturbed_x = mu + z * std
+        target_score = - z * std / (std ** 2)
+        return perturbed_x, random_t, target_score, std
+
     def forward_train(self, data):
-        output = {}
+        pts_list = data['pts_list']
+        pose_list = data['pose_list']
+        gt_delta_rot_6d = data["delta_rot_6d"]
+        pts_feat_list = []
+        pose_feat_list = []
+        for pts, pose in zip(pts_list, pose_list):
+            pts_feat_list.append(self.pts_encoder.encode_points(pts))
+            pose_feat_list.append(self.pose_encoder.encode_pose(pose))
+        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
+        ''' perturb data '''
+        perturbed_x, random_t, target_score, std = self.perturb_data(gt_delta_rot_6d)
+        input_data = {
+            "sampled_pose": perturbed_x,
+            "t": random_t,
+            "seq_feat": seq_feat,
+        }
+        estimated_score = self.view_finder(input_data)
+        output = {
+            "estimated_score": estimated_score,
+            "target_score": target_score,
+            "std": std
+        }
+        return output
+
+    def forward_test(self, data):
         pts_list = data['pts_list']
         pose_list = data['pose_list']
         pts_feat_list = []
@@ -34,9 +69,10 @@ class NBVReconstructionPipeline(nn.Module):
             pts_feat_list.append(self.pts_encoder.encode_points(pts))
             pose_feat_list.append(self.pose_encoder.encode_pose(pose))
         seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
-        output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
+        estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(seq_feat)
+        result = {
+            "estimated_delta_rot_6d": estimated_delta_rot_6d,
+            "in_process_sample": in_process_sample
+        }
+        return result
 
-        return output
-
-    def forward_test(self,data):
-        pass
\ No newline at end of file
diff --git a/modules/seq_encoder/transformer_seq_encoder.py b/modules/seq_encoder/transformer_seq_encoder.py
index c318a2a..47f0736 100644
--- a/modules/seq_encoder/transformer_seq_encoder.py
+++ b/modules/seq_encoder/transformer_seq_encoder.py
@@ -16,32 +16,49 @@ class TransformerSequenceEncoder(SequenceEncoder):
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config['num_layers'])
         self.fc = nn.Linear(embed_dim, config['output_dim'])
 
-    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
-        combined_features = [torch.cat((pts_embed, pose_embed), dim=-1) for pts_embed, pose_embed in zip(pts_embedding_list[:-1], pose_embedding_list[:-1])]
-        combined_tensor = torch.stack(combined_features)
-        pos_encoding = self.positional_encoding[:, :combined_tensor.size(0), :]
-        combined_tensor = combined_tensor.unsqueeze(0) + pos_encoding
-        transformer_output = self.transformer_encoder(combined_tensor).squeeze(0)
-        final_feature = transformer_output.mean(dim=0)
+    def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
+        batch_size = len(pts_embedding_list_batch)
+        combined_features_batch = []
+
+        for i in range(batch_size):
+            combined_features = [torch.cat((pts_embed, pose_embed), dim=-1)
+                                 for pts_embed, pose_embed in zip(pts_embedding_list_batch[i][:-1], pose_embedding_list_batch[i][:-1])]
+            combined_features_batch.append(torch.stack(combined_features))
+
+        combined_tensor = torch.stack(combined_features_batch)  # Shape: [batch_size, seq_len-1, embed_dim]
+
+        # Adjust positional encoding to match batch size
+        pos_encoding = self.positional_encoding[:, :combined_tensor.size(1), :].repeat(batch_size, 1, 1)
+        combined_tensor = combined_tensor + pos_encoding
+
+        # Transformer encoding
+        transformer_output = self.transformer_encoder(combined_tensor)
+
+        # Mean pooling
+        final_feature = transformer_output.mean(dim=1)
+
+        # Fully connected layer
         final_output = self.fc(final_feature)
         return final_output
 
 if __name__ == "__main__":
     config = {
-        'pts_embed_dim': 1024,  # dimension of each point cloud embedding
-        'pose_embed_dim': 256,  # dimension of each pose embedding
-        'num_heads': 4,  # number of attention heads
-        'ffn_dim': 256,  # feed-forward network dimension
-        'num_layers': 3,  # number of Transformer encoder layers
-        'max_seq_len': 10,  # maximum sequence length
-        'output_dim': 2048,  # output feature dimension
+        'pts_embed_dim': 1024,   # dimension of each point cloud embedding
+        'pose_embed_dim': 256,   # dimension of each pose embedding
+        'num_heads': 4,          # number of attention heads
+        'ffn_dim': 256,          # feed-forward network dimension
+        'num_layers': 3,         # number of Transformer encoder layers
+        'max_seq_len': 10,       # maximum sequence length
+        'output_dim': 2048,      # output feature dimension
     }
     encoder = TransformerSequenceEncoder(config)
     seq_len = 5
-    pts_embedding_list = [torch.randn(config['pts_embed_dim']) for _ in range(seq_len)]
-    pose_embedding_list = [torch.randn(config['pose_embed_dim']) for _ in range(seq_len)]
-    output_feature = encoder.encode_sequence(pts_embedding_list, pose_embedding_list)
+    batch_size = 4
+
+    pts_embedding_list_batch = [torch.randn(seq_len, config['pts_embed_dim']) for _ in range(batch_size)]
+    pose_embedding_list_batch = [torch.randn(seq_len, config['pose_embed_dim']) for _ in range(batch_size)]
+    output_feature = encoder.encode_sequence(pts_embedding_list_batch, pose_embedding_list_batch)
     print("Encoded Feature:", output_feature)
-    print("Feature Shape:", output_feature.shape)
\ No newline at end of file
+    print("Feature Shape:", output_feature.shape)
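
Note on perturb_data: the patch relies on view_finder.marginal_prob(x, t) to supply the mean and standard deviation of the noising kernel, but that module is not part of this diff. Below is a minimal sketch of what it is assumed to compute, following the common variance-exploding SDE convention; the standalone function, its signature, and the sigma hyperparameter are assumptions for illustration, not code from this patch.

    import torch

    def marginal_prob(x, t, sigma=25.0):
        # Variance-exploding convention: the mean stays at x and the std
        # grows with t, so perturbed_x = x + z * std and the score of the
        # perturbed marginal is -z / std (consistent with the patch, where
        # target_score = -z * std / std**2 simplifies to -z / std).
        std = torch.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * torch.log(torch.tensor(sigma))))
        return x, std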
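forward_train now returns estimated_score, target_score, and std but no loss, so the trainer presumably combines them elsewhere. A plausible companion loss under the standard denoising score-matching setup is sketched below; the function name and the std weighting are assumptions, and only the output dict keys come from this patch.

    import torch

    def score_matching_loss(output):
        estimated_score = output["estimated_score"]  # predicted score, [bs, 6]
        target_score = output["target_score"]        # equals -z / std, [bs, 6]
        std = output["std"]                          # noise level, [bs, 1]
        # Weighting the residual by std cancels the 1/std in the target,
        # so every noise level contributes on a comparable scale.
        return torch.mean(torch.sum(((estimated_score - target_score) * std) ** 2, dim=-1))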
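One optional simplification in TransformerSequenceEncoder.encode_sequence: self.positional_encoding already carries a leading singleton batch dimension (the old code added it to combined_tensor.unsqueeze(0)), so the .repeat(batch_size, 1, 1) copy is not strictly needed; plain broadcasting produces the same result without materializing batch_size copies.

    # Equivalent to the .repeat(...) version via broadcasting:
    # [batch_size, seq_len-1, embed_dim] + [1, seq_len-1, embed_dim]
    # broadcasts over the batch dimension.
    combined_tensor = combined_tensor + self.positional_encoding[:, :combined_tensor.size(1), :]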