change dataset.py to nbv_dataset.py
This commit is contained in:
parent
4e4fcb2ce5
commit
18333e6831
@ -22,7 +22,6 @@ class TransformerSequenceEncoder(nn.Module):
|
||||
self.fc = nn.Linear(embed_dim, config["output_dim"])
|
||||
|
||||
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
|
||||
# Combine features and pad sequences
|
||||
combined_features_batch = []
|
||||
lengths = []
|
||||
|
||||
@ -36,16 +35,11 @@ class TransformerSequenceEncoder(nn.Module):
|
||||
|
||||
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
|
||||
|
||||
# Prepare mask for padding
|
||||
max_len = max(lengths)
|
||||
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
|
||||
# Transformer encoding
|
||||
|
||||
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
|
||||
|
||||
# Mean pooling
|
||||
final_feature = transformer_output.mean(dim=1)
|
||||
|
||||
# Fully connected layer
|
||||
final_output = self.fc(final_feature)
|
||||
|
||||
return final_output
|
||||
|
Loading…
x
Reference in New Issue
Block a user