change pipeline's name
This commit is contained in:
parent
837e1c870a
commit
96fb34be09
@ -5,9 +5,9 @@ import PytorchBoot.stereotype as stereotype
|
|||||||
from PytorchBoot.factory.component_factory import ComponentFactory
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||||
|
|
||||||
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
||||||
class ViewFinderPipeline(nn.Module):
|
class NBVReconstructionPipeline(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(ViewFinderPipeline, self).__init__()
|
super(NBVReconstructionPipeline, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
|
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
|
||||||
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
|
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
|
||||||
@ -24,6 +24,6 @@ class ViewFinderPipeline(nn.Module):
|
|||||||
pts_feat_list.append(self.pts_encoder.encode_points(pts))
|
pts_feat_list.append(self.pts_encoder.encode_points(pts))
|
||||||
pose_feat_list.append(self.pose_encoder.encode_pose(pose))
|
pose_feat_list.append(self.pose_encoder.encode_pose(pose))
|
||||||
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
||||||
output['next_best_view'] = self.view_finder.next_best_view(seq_feat)
|
output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
|
||||||
|
|
||||||
return output
|
return output
|
Loading…
x
Reference in New Issue
Block a user