nbv_reconstruction/core/pipeline.py
2024-08-21 17:59:42 +08:00

29 lines
1.3 KiB
Python

from torch import nn
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory
@stereotype.pipeline("nbv_reconstruction_pipeline")
class NBVReconstructionPipeline(nn.Module):
    """Next-best-view (NBV) reconstruction pipeline.

    Decorated with ``stereotype.pipeline("nbv_reconstruction_pipeline")``
    (presumably this registers the class with PytorchBoot under that name --
    confirm against PytorchBoot). Composes four sub-components, each built by
    ``ComponentFactory`` from the matching entry of ``config``:

    * ``pts_encoder``  -- per-view point encoding via ``encode_points``.
    * ``pose_encoder`` -- per-view pose encoding via ``encode_pose``.
    * ``seq_encoder``  -- fuses the per-view features via ``encode_sequence``.
    * ``view_finder``  -- produces the output via ``next_best_view``.
    """

    def __init__(self, config):
        """Build the pipeline's sub-components.

        Args:
            config: Mapping that must contain the keys ``"pts_encoder"``,
                ``"pose_encoder"``, ``"seq_encoder"`` and ``"view_finder"``.
                Each value is passed unchanged to ``ComponentFactory.create``
                with ``namespace.Stereotype.MODULE``. The full mapping is also
                kept on ``self.config``.

        Raises:
            KeyError: if any of the four keys is missing from ``config``.
        """
        super().__init__()
        self.config = config
        self.pts_encoder = ComponentFactory.create(
            namespace.Stereotype.MODULE, config["pts_encoder"]
        )
        self.pose_encoder = ComponentFactory.create(
            namespace.Stereotype.MODULE, config["pose_encoder"]
        )
        self.seq_encoder = ComponentFactory.create(
            namespace.Stereotype.MODULE, config["seq_encoder"]
        )
        self.view_finder = ComponentFactory.create(
            namespace.Stereotype.MODULE, config["view_finder"]
        )

    def forward(self, data):
        """Encode every view, fuse the sequence and score it.

        Args:
            data: Mapping with keys ``"pts_list"`` and ``"pose_list"``. The two
                values are treated as parallel, sized sequences: element ``i``
                of each is encoded together as view ``i``. (Assumed to be
                lists/tuples or tensors indexed along dim 0; per-element
                shapes are whatever the configured encoders expect -- TODO
                confirm.)

        Returns:
            dict: ``{"estimated_score": ...}`` holding the return value of
            ``view_finder.next_best_view`` applied to the fused sequence
            feature.

        Raises:
            ValueError: if ``pts_list`` and ``pose_list`` differ in length.
        """
        pts_list = data["pts_list"]
        pose_list = data["pose_list"]
        # zip() below would silently stop at the shorter input and drop the
        # trailing views of the longer one; surface the mismatch instead.
        if len(pts_list) != len(pose_list):
            raise ValueError(
                "pts_list and pose_list must have the same length, "
                f"got {len(pts_list)} and {len(pose_list)}"
            )

        pts_feat_list = []
        pose_feat_list = []
        # Encode view by view (points, then pose) so the call order -- and any
        # stateful/random behaviour inside the encoders -- matches per view.
        for pts, pose in zip(pts_list, pose_list):
            pts_feat_list.append(self.pts_encoder.encode_points(pts))
            pose_feat_list.append(self.pose_encoder.encode_pose(pose))

        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
        return {"estimated_score": self.view_finder.next_best_view(seq_feat)}