add embedding_seq_encoder and remove specific seq_encoder

This commit is contained in:
hofee 2024-09-29 20:43:01 +08:00
parent f42e45d608
commit 2f6d156abd
3 changed files with 15 additions and 159 deletions

View File

@ -1,72 +0,0 @@
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import PytorchBoot.stereotype as stereotype
@stereotype.module("transformer_pose_n_num_seq_encoder")
class TransformerPoseAndNumSequenceEncoder(nn.Module):
def __init__(self, config):
super(TransformerPoseAndNumSequenceEncoder, self).__init__()
self.config = config
embed_dim = config["pts_num_embed_dim"] + config["pose_embed_dim"]
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=config["num_heads"],
dim_feedforward=config["ffn_dim"],
batch_first=True,
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config["num_layers"]
)
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pts_num_embedding_list_batch, pose_embedding_list_batch):
combined_features_batch = []
lengths = []
for pts_num_embedding_list, pose_embedding_list in zip(pts_num_embedding_list_batch, pose_embedding_list_batch):
combined_features = [
torch.cat((pts_num_embed, pose_embed), dim=-1)
for pts_num_embed, pose_embed in zip(pts_num_embedding_list, pose_embedding_list)
]
combined_features_batch.append(torch.stack(combined_features))
lengths.append(len(combined_features))
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature)
return final_output
if __name__ == "__main__":
config = {
"pts_num_embed_dim": 128,
"pose_embed_dim": 256,
"num_heads": 4,
"ffn_dim": 256,
"num_layers": 3,
"output_dim": 2048,
}
encoder = TransformerPoseAndNumSequenceEncoder(config)
seq_len = [5, 8, 9, 4]
batch_size = 4
pts_num_embedding_list_batch = [
torch.randn(seq_len[idx], config["pts_num_embed_dim"]) for idx in range(batch_size)
]
pose_embedding_list_batch = [
torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
]
output_feature = encoder.encode_sequence(
pts_num_embedding_list_batch, pose_embedding_list_batch
)
print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape)

View File

@ -1,72 +0,0 @@
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import PytorchBoot.stereotype as stereotype
@stereotype.module("transformer_pose_n_pts_seq_encoder")
class TransformerSequenceEncoder(nn.Module):
def __init__(self, config):
super(TransformerSequenceEncoder, self).__init__()
self.config = config
embed_dim = config["pts_embed_dim"] + config["pose_embed_dim"]
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=config["num_heads"],
dim_feedforward=config["ffn_dim"],
batch_first=True,
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config["num_layers"]
)
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
combined_features_batch = []
lengths = []
for pts_embedding_list, pose_embedding_list in zip(pts_embedding_list_batch, pose_embedding_list_batch):
combined_features = [
torch.cat((pts_embed, pose_embed), dim=-1)
for pts_embed, pose_embed in zip(pts_embedding_list, pose_embedding_list)
]
combined_features_batch.append(torch.stack(combined_features))
lengths.append(len(combined_features))
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature)
return final_output
if __name__ == "__main__":
config = {
"pts_embed_dim": 1024,
"pose_embed_dim": 256,
"num_heads": 4,
"ffn_dim": 256,
"num_layers": 3,
"output_dim": 2048,
}
encoder = TransformerSequenceEncoder(config)
seq_len = [5, 8, 9, 4]
batch_size = 4
pts_embedding_list_batch = [
torch.randn(seq_len[idx], config["pts_embed_dim"]) for idx in range(batch_size)
]
pose_embedding_list_batch = [
torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
]
output_feature = encoder.encode_sequence(
pts_embedding_list_batch, pose_embedding_list_batch
)
print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape)

View File

@ -4,12 +4,12 @@ from torch.nn.utils.rnn import pad_sequence
import PytorchBoot.stereotype as stereotype
@stereotype.module("transformer_pose_seq_encoder")
class TransformerPoseSequenceEncoder(nn.Module):
@stereotype.module("transformer_seq_encoder")
class TransformerSequenceEncoder(nn.Module):
def __init__(self, config):
super(TransformerPoseSequenceEncoder, self).__init__()
super(TransformerSequenceEncoder, self).__init__()
self.config = config
embed_dim = config["pose_embed_dim"]
embed_dim = config["embed_dim"]
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=config["num_heads"],
@ -21,19 +21,19 @@ class TransformerPoseSequenceEncoder(nn.Module):
)
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pose_embedding_list_batch):
def encode_sequence(self, embedding_list_batch):
lengths = []
for pose_embedding_list in pose_embedding_list_batch:
lengths.append(len(pose_embedding_list))
for embedding_list in embedding_list_batch:
lengths.append(len(embedding_list))
combined_tensor = pad_sequence(pose_embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(embedding_tensor.device)
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature)
@ -42,22 +42,22 @@ class TransformerPoseSequenceEncoder(nn.Module):
if __name__ == "__main__":
config = {
"pose_embed_dim": 256,
"embed_dim": 256,
"num_heads": 4,
"ffn_dim": 256,
"num_layers": 3,
"output_dim": 1024,
}
encoder = TransformerPoseSequenceEncoder(config)
encoder = TransformerSequenceEncoder(config)
seq_len = [5, 8, 9, 4]
batch_size = 4
pose_embedding_list_batch = [
torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
embedding_list_batch = [
torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
]
output_feature = encoder.encode_sequence(
pose_embedding_list_batch
embedding_list_batch
)
print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape)