diff --git a/core/pipeline.py b/core/pipeline.py index ceb5c03..a949654 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -3,6 +3,7 @@ from torch import nn import PytorchBoot.namespace as namespace import PytorchBoot.stereotype as stereotype from PytorchBoot.factory.component_factory import ComponentFactory +from PytorchBoot.utils import Log @stereotype.pipeline("nbv_reconstruction_pipeline") class NBVReconstructionPipeline(nn.Module): @@ -15,6 +16,15 @@ class NBVReconstructionPipeline(nn.Module): self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"]) def forward(self, data): + mode = data["mode"] + if mode == namespace.Mode.TRAIN: + return self.forward_train(data) + elif mode == namespace.Mode.TEST: + return self.forward_test(data) + else: + Log.error("Unknown mode: {}".format(mode), True) + + def forward_train(self, data): output = {} pts_list = data['pts_list'] pose_list = data['pose_list'] @@ -26,4 +36,7 @@ class NBVReconstructionPipeline(nn.Module): seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list) output['estimated_score'] = self.view_finder.next_best_view(seq_feat) - return output \ No newline at end of file + return output + + def forward_test(self,data): + pass \ No newline at end of file