update transformer_seq_encoder's config
parent fa69f9f879
commit bfc8ba0f4b
@@ -90,7 +90,7 @@ pipeline:
   nbv_reconstruction_global_pts_pipeline:
     modules:
       pts_encoder: pointnet_encoder
-      pose_seq_encoder: transformer_pose_seq_encoder
+      pose_seq_encoder: transformer_seq_encoder
       pose_encoder: pose_encoder
       view_finder: gf_view_finder
     eps: 1e-5
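The `pose_seq_encoder` slot now points at the merged `transformer_seq_encoder` module. The constructor below resolves such strings via `ComponentFactory.create(namespace.Stereotype.MODULE, name)`; here is a minimal sketch of that kind of name-based registry, purely illustrative (the project's real `ComponentFactory` and `namespace` may differ):

```python
# Illustrative registry sketch -- not the project's actual ComponentFactory.
MODULE_REGISTRY = {}  # maps config names ("transformer_seq_encoder", ...) to classes

def register_module(name):
    """Decorator that exposes a class under its config-file name."""
    def wrap(cls):
        MODULE_REGISTRY[name] = cls
        return cls
    return wrap

class ComponentFactory:
    @staticmethod
    def create(stereotype, name, **kwargs):
        # 'stereotype' would select the registry (MODULE, PIPELINE, ...);
        # only the MODULE case is sketched here.
        return MODULE_REGISTRY[name](**kwargs)
```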
@@ -107,20 +107,12 @@ module:
     feature_transform: False

   transformer_seq_encoder:
-    pts_embed_dim: 1024
-    pose_embed_dim: 256
+    embed_dim: 1344
     num_heads: 4
     ffn_dim: 256
     num_layers: 3
     output_dim: 2048

-  transformer_pose_seq_encoder:
-    pose_embed_dim: 256
-    num_heads: 4
-    ffn_dim: 256
-    num_layers: 3
-    output_dim: 1024
-
   gf_view_finder:
     t_feat_dim: 128
     pose_feat_dim: 256
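The two separate widths (`pts_embed_dim: 1024`, `pose_embed_dim: 256`) collapse into a single `embed_dim: 1344`, because the unified encoder now consumes one concatenated per-frame embedding (see the forward-pass hunk below: `pose_feat_seq`, `pts_num_feat_seq`, `partial_feat_seq`). The individual widths are not stated in this commit; a plausible split, with 1024 and 256 carried over and 64 assumed for the points-count embedding:

```python
# Hypothetical breakdown -- only the sum (1344) is fixed by the config above.
Dl = 1024  # partial point-cloud feature (pointnet_encoder), carried over
Dp = 256   # pose embedding (pose_encoder), carried over
Dn = 64    # scanned-points-count embedding (pts_num_encoder), assumed
assert Dp + Dn + Dl == 1344  # must match transformer_seq_encoder.embed_dim
```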
@@ -12,21 +12,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
         super(NBVReconstructionGlobalPointsPipeline, self).__init__()
         self.config = config
         self.module_config = config["modules"]

         self.pts_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
         )
         self.pose_encoder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
         )
-        self.pose_n_num_seq_encoder = ComponentFactory.create(
-            namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
+        self.pts_num_encoder = ComponentFactory.create(
+            namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
+        )
+
+        self.transformer_seq_encoder = ComponentFactory.create(
+            namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"]
         )
         self.view_finder = ComponentFactory.create(
             namespace.Stereotype.MODULE, self.module_config["view_finder"]
         )
-        self.pts_num_encoder = ComponentFactory.create(
-            namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
-        )
+
         self.eps = float(self.config["eps"])
         self.enable_global_scanned_feat = self.config["global_scanned_feat"]
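With the old `pose_n_num_seq_encoder` folded away, the module behind the `transformer_seq_encoder` name plausibly looks like the sketch below, sized by the config in the first file (embed_dim 1344, 4 heads, ffn_dim 256, 3 layers, output_dim 2048). Its `encode_sequence` signature is inferred from the call site in the last hunk; the actual implementation may differ:

```python
import torch
import torch.nn as nn

class TransformerSequenceEncoder(nn.Module):
    """Inferred sketch of transformer_seq_encoder; sized from the config above."""

    def __init__(self, embed_dim=1344, num_heads=4, ffn_dim=256,
                 num_layers=3, output_dim=2048):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=ffn_dim, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, output_dim)

    def encode_sequence(self, embedding_list_batch):
        """List(B) of Tensor(S_i x D) -> Tensor(B x output_dim)."""
        lengths = torch.tensor([e.shape[0] for e in embedding_list_batch],
                               device=embedding_list_batch[0].device)
        # Pad ragged sequences to B x S_max x D and mask the padded steps
        # so they affect neither attention nor the pooled feature.
        padded = nn.utils.rnn.pad_sequence(embedding_list_batch, batch_first=True)
        pad_mask = (torch.arange(padded.shape[1], device=padded.device)[None, :]
                    >= lengths[:, None])                            # B x S_max, True = pad
        feat = self.encoder(padded, src_key_padding_mask=pad_mask)  # B x S_max x D
        # Mean-pool over the valid (unpadded) steps only.
        feat = feat.masked_fill(pad_mask[..., None], 0.0).sum(dim=1) / lengths[:, None]
        return self.fc(feat)  # Tensor(B x Ds)
```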
@@ -128,7 +131,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
             seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1)  # Tensor(S x (Dp+Dn+Dl))
             embedding_list_batch.append(seq_embedding)  # List(B): Tensor(S x (Dp+Dn+Dl))

-        seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)  # Tensor(B x Ds)
+        seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch)  # Tensor(B x Ds)

         main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1)  # Tensor(B x (Ds+Dg))
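Continuing the sketch above, the call site reduces a ragged batch of scan histories to one fixed-width feature per sample (shapes hypothetical):

```python
import torch

# Two samples with different numbers of scanned views (S=2 and S=4);
# each per-view embedding is Dp+Dn+Dl = 1344 wide.
embedding_list_batch = [torch.randn(2, 1344), torch.randn(4, 1344)]
encoder = TransformerSequenceEncoder()  # class from the sketch above
seq_feat = encoder.encode_sequence(embedding_list_batch)
print(seq_feat.shape)  # torch.Size([2, 2048]) -- Tensor(B x Ds)
```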