update pipeline: train test mode
This commit is contained in:
parent
96fb34be09
commit
68b4325dbd
@ -3,6 +3,7 @@ from torch import nn
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||
from PytorchBoot.utils import Log
|
||||
|
||||
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
||||
class NBVReconstructionPipeline(nn.Module):
|
||||
@ -15,6 +16,15 @@ class NBVReconstructionPipeline(nn.Module):
|
||||
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
|
||||
|
||||
def forward(self, data):
|
||||
mode = data["mode"]
|
||||
if mode == namespace.Mode.TRAIN:
|
||||
return self.forward_train(data)
|
||||
elif mode == namespace.Mode.TEST:
|
||||
return self.forward_test(data)
|
||||
else:
|
||||
Log.error("Unknown mode: {}".format(mode), True)
|
||||
|
||||
def forward_train(self, data):
|
||||
output = {}
|
||||
pts_list = data['pts_list']
|
||||
pose_list = data['pose_list']
|
||||
@ -26,4 +36,7 @@ class NBVReconstructionPipeline(nn.Module):
|
||||
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
||||
output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
def forward_test(self,data):
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user