change dataset.py to nbv_dataset.py

hofee 2024-09-20 06:43:19 +00:00
parent 4e4fcb2ce5
commit 18333e6831
2 changed files with 1 addition and 7 deletions


@@ -22,7 +22,6 @@ class TransformerSequenceEncoder(nn.Module):
         self.fc = nn.Linear(embed_dim, config["output_dim"])
 
     def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
-        # Combine features and pad sequences
         combined_features_batch = []
         lengths = []
@@ -36,16 +35,11 @@ class TransformerSequenceEncoder(nn.Module):
         combined_tensor = pad_sequence(combined_features_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
-        # Prepare mask for padding
         max_len = max(lengths)
         padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
-        # Transformer encoding
         transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
-        # Mean pooling
         final_feature = transformer_output.mean(dim=1)
-        # Fully connected layer
         final_output = self.fc(final_feature)
         return final_output