change dataset.py to nbv_dataset.py
This commit is contained in:
parent 4e4fcb2ce5
commit 18333e6831
@@ -22,7 +22,6 @@ class TransformerSequenceEncoder(nn.Module):
        self.fc = nn.Linear(embed_dim, config["output_dim"])

    def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
        # Combine features and pad sequences
        combined_features_batch = []
        lengths = []
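The hunk above shows only the tail of the constructor; the `self.transformer_encoder` used further down is defined elsewhere in `__init__`. A minimal sketch of what that constructor presumably looks like, assuming standard `nn.TransformerEncoder` usage — every config key except "output_dim" is a hypothetical name, not taken from this commit:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence  # used by encode_sequence above


class TransformerSequenceEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Hypothetical config keys; only "output_dim" appears in this diff.
        embed_dim = config["embed_dim"]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=config["num_heads"],
            dim_feedforward=config["ffn_dim"],
            batch_first=True,  # consistent with pad_sequence(..., batch_first=True) below
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config["num_layers"])
        self.fc = nn.Linear(embed_dim, config["output_dim"])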
@@ -36,16 +35,11 @@ class TransformerSequenceEncoder(nn.Module):

        combined_tensor = pad_sequence(combined_features_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]

        # Prepare mask for padding
        max_len = max(lengths)
        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)

        # Transformer encoding
        transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)

        # Mean pooling
        final_feature = transformer_output.mean(dim=1)

        # Fully connected layer
        final_output = self.fc(final_feature)

        return final_output
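A rough usage sketch, assuming each argument of encode_sequence is a per-sample list of per-view embedding tensors and that point and pose embeddings are combined into a single embed_dim-sized vector per view (the combination step falls outside the shown hunks); all shapes and config values below are illustrative assumptions:

# Hypothetical configuration and inputs, for illustration only.
config = {"embed_dim": 256, "num_heads": 4, "ffn_dim": 512, "num_layers": 3, "output_dim": 128}
encoder = TransformerSequenceEncoder(config)

# Two samples with different numbers of captured views (5 and 3).
seq_lens = [5, 3]
pts_embedding_list_batch = [[torch.randn(config["embed_dim"]) for _ in range(n)] for n in seq_lens]
pose_embedding_list_batch = [[torch.randn(config["embed_dim"]) for _ in range(n)] for n in seq_lens]

global_feature = encoder.encode_sequence(pts_embedding_list_batch, pose_embedding_list_batch)
print(global_feature.shape)  # expected: torch.Size([2, 128])

Note that the padding mask built above marks padded positions with 1 (True), which matches PyTorch's src_key_padding_mask convention of ignoring True entries as attention keys; the subsequent mean pooling still averages over all time steps, padded positions included.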