42 lines
1.7 KiB
Python
42 lines
1.7 KiB
Python
|
|
from torch import nn
|
|
import PytorchBoot.namespace as namespace
|
|
import PytorchBoot.stereotype as stereotype
|
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
|
from PytorchBoot.utils import Log
|
|
|
|
@stereotype.pipeline("nbv_reconstruction_pipeline")
class NBVReconstructionPipeline(nn.Module):
    """Pipeline that chains four factory-built modules.

    The constructor asks ``ComponentFactory`` for a points encoder, a pose
    encoder, a sequence encoder and a view finder, each described by its own
    entry in ``config``. ``forward`` dispatches on ``data["mode"]``.
    """

    def __init__(self, config):
        """Store ``config`` and build the four sub-modules it describes.

        Args:
            config: Mapping read at the keys ``"pts_encoder"``,
                ``"pose_encoder"``, ``"seq_encoder"`` and ``"view_finder"``;
                each value is handed to ``ComponentFactory.create``.
        """
        super(NBVReconstructionPipeline, self).__init__()
        self.config = config

        # Every sub-component is requested under the same stereotype.
        module_kind = namespace.Stereotype.MODULE
        self.pts_encoder = ComponentFactory.create(module_kind, config["pts_encoder"])
        self.pose_encoder = ComponentFactory.create(module_kind, config["pose_encoder"])
        self.seq_encoder = ComponentFactory.create(module_kind, config["seq_encoder"])
        self.view_finder = ComponentFactory.create(module_kind, config["view_finder"])

    def forward(self, data):
        """Route ``data`` to the train or test path based on ``data["mode"]``.

        Args:
            data: Mapping that must contain a ``"mode"`` entry.

        Returns:
            Whatever ``forward_train`` or ``forward_test`` returns for the
            matching mode; ``None`` if the mode is unknown and ``Log.error``
            returns control.
        """
        mode = data["mode"]
        if mode == namespace.Mode.TRAIN:
            return self.forward_train(data)
        if mode == namespace.Mode.TEST:
            return self.forward_test(data)
        # NOTE(review): the second positional argument presumably asks the
        # logger to terminate -- confirm against PytorchBoot.utils.Log.
        Log.error("Unknown mode: {}".format(mode), True)

    def forward_train(self, data):
        """Encode each points/pose pair, fuse the sequence and score it.

        Args:
            data: Mapping with ``'pts_list'`` and ``'pose_list'`` iterables,
                which are walked in lockstep (``zip`` stops at the shorter).

        Returns:
            A dict with the single key ``'estimated_score'`` holding the
            result of ``view_finder.next_best_view``.
        """
        pts_list = data['pts_list']
        pose_list = data['pose_list']

        # Encode pair by pair (points first, then pose) so the encoders are
        # invoked in the same interleaved order for every step.
        encoded_pairs = [
            (self.pts_encoder.encode_points(pts), self.pose_encoder.encode_pose(pose))
            for pts, pose in zip(pts_list, pose_list)
        ]
        pts_feat_list = [pts_feat for pts_feat, _ in encoded_pairs]
        pose_feat_list = [pose_feat for _, pose_feat in encoded_pairs]

        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
        return {'estimated_score': self.view_finder.next_best_view(seq_feat)}

    def forward_test(self, data):
        """Test-mode forward pass; not implemented yet, so it returns ``None``."""
        return None