change pipeline's name

This commit is contained in:
hofee 2024-08-21 17:59:42 +08:00
parent 837e1c870a
commit 96fb34be09

View File

@ -5,9 +5,9 @@ import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory from PytorchBoot.factory.component_factory import ComponentFactory
@stereotype.pipeline("nbv_reconstruction_pipeline") @stereotype.pipeline("nbv_reconstruction_pipeline")
class ViewFinderPipeline(nn.Module): class NBVReconstructionPipeline(nn.Module):
def __init__(self, config): def __init__(self, config):
super(ViewFinderPipeline, self).__init__() super(NBVReconstructionPipeline, self).__init__()
self.config = config self.config = config
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"]) self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"]) self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
@ -24,6 +24,6 @@ class ViewFinderPipeline(nn.Module):
pts_feat_list.append(self.pts_encoder.encode_points(pts)) pts_feat_list.append(self.pts_encoder.encode_points(pts))
pose_feat_list.append(self.pose_encoder.encode_pose(pose)) pose_feat_list.append(self.pose_encoder.encode_pose(pose))
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list) seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
output['next_best_view'] = self.view_finder.next_best_view(seq_feat) output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
return output return output