from torch import nn

import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.utils import Log


@stereotype.pipeline("nbv_reconstruction_pipeline")
class NBVReconstructionPipeline(nn.Module):
    """Next-best-view reconstruction pipeline.

    Composes four sub-modules, each built by ``ComponentFactory`` from its own
    section of ``config``:

    * ``pts_encoder``  -- encodes one point set (``encode_points``)
    * ``pose_encoder`` -- encodes one pose (``encode_pose``)
    * ``seq_encoder``  -- fuses the per-view features (``encode_sequence``)
    * ``view_finder``  -- produces the estimated score from the fused
      feature (``next_best_view``)
    """

    def __init__(self, config):
        """Build the sub-modules.

        Args:
            config: mapping with the keys ``"pts_encoder"``, ``"pose_encoder"``,
                ``"seq_encoder"`` and ``"view_finder"``. Each value is passed
                unchanged to ``ComponentFactory.create`` as a MODULE stereotype.
        """
        super().__init__()
        self.config = config
        # Registration order is preserved so state_dict key order is unchanged.
        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
        self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["seq_encoder"])
        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])

    def forward(self, data):
        """Dispatch on ``data["mode"]``.

        Args:
            data: mapping that must contain ``"mode"``; the remaining keys are
                consumed by the mode-specific method.

        Returns:
            The result of ``forward_train`` for ``namespace.Mode.TRAIN`` or of
            ``forward_test`` for ``namespace.Mode.TEST``. Any other mode is
            reported through ``Log.error``.
        """
        mode = data["mode"]
        if mode == namespace.Mode.TRAIN:
            return self.forward_train(data)
        elif mode == namespace.Mode.TEST:
            return self.forward_test(data)
        else:
            # NOTE(review): the second argument presumably makes Log terminate or
            # raise; if it does not, this path returns None -- confirm against
            # PytorchBoot.utils.Log.
            Log.error("Unknown mode: {}".format(mode), True)

    def forward_train(self, data):
        """Training forward pass.

        Each view's points and pose are encoded, the per-view features are
        fused by the sequence encoder, and the view finder produces the score.

        Args:
            data: mapping with ``'pts_list'`` and ``'pose_list'``, two
                sequences of equal length (one entry per view).

        Returns:
            dict with the key ``'estimated_score'``.

        Raises:
            ValueError: if ``pts_list`` and ``pose_list`` differ in length.
        """
        pts_list = data['pts_list']
        pose_list = data['pose_list']
        if len(pts_list) != len(pose_list):
            # zip() would otherwise silently drop the unmatched trailing views.
            raise ValueError(
                "pts_list and pose_list must have the same length, got {} and {}".format(
                    len(pts_list), len(pose_list)
                )
            )

        pts_feat_list = []
        pose_feat_list = []
        # Interleaved per view (points, then pose) to keep the original call order.
        for pts, pose in zip(pts_list, pose_list):
            pts_feat_list.append(self.pts_encoder.encode_points(pts))
            pose_feat_list.append(self.pose_encoder.encode_pose(pose))

        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
        return {'estimated_score': self.view_finder.next_best_view(seq_feat)}

    def forward_test(self, data):
        """Test-time forward pass -- not implemented yet; returns ``None``."""
        # TODO: implement test-time next-best-view inference.
        return None