diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 6e44234..bf12e1b 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -90,7 +90,7 @@ pipeline: nbv_reconstruction_global_pts_pipeline: modules: pts_encoder: pointnet_encoder - pose_seq_encoder: transformer_pose_seq_encoder + pose_seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder eps: 1e-5 @@ -107,20 +107,12 @@ module: feature_transform: False transformer_seq_encoder: - pts_embed_dim: 1024 - pose_embed_dim: 256 + embed_dim: 1344 num_heads: 4 ffn_dim: 256 num_layers: 3 output_dim: 2048 - transformer_pose_seq_encoder: - pose_embed_dim: 256 - num_heads: 4 - ffn_dim: 256 - num_layers: 3 - output_dim: 1024 - gf_view_finder: t_feat_dim: 128 pose_feat_dim: 256 diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py index efc2bb8..04a360b 100644 --- a/core/global_pts_n_num_pipeline.py +++ b/core/global_pts_n_num_pipeline.py @@ -12,21 +12,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): super(NBVReconstructionGlobalPointsPipeline, self).__init__() self.config = config self.module_config = config["modules"] + self.pts_encoder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["pts_encoder"] ) self.pose_encoder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["pose_encoder"] ) - self.pose_n_num_seq_encoder = ComponentFactory.create( - namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"] + self.pts_num_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"] + ) + + self.transformer_seq_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"] ) self.view_finder = ComponentFactory.create( namespace.Stereotype.MODULE, self.module_config["view_finder"] ) - self.pts_num_encoder = ComponentFactory.create( - namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"] - ) + self.eps = float(self.config["eps"]) self.enable_global_scanned_feat = self.config["global_scanned_feat"] @@ -128,7 +131,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl)) embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl)) - seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) + seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))