update pipeline: train test mode
This commit is contained in:
parent
96fb34be09
commit
68b4325dbd
@ -3,6 +3,7 @@ from torch import nn
|
|||||||
import PytorchBoot.namespace as namespace
|
import PytorchBoot.namespace as namespace
|
||||||
import PytorchBoot.stereotype as stereotype
|
import PytorchBoot.stereotype as stereotype
|
||||||
from PytorchBoot.factory.component_factory import ComponentFactory
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||||
|
from PytorchBoot.utils import Log
|
||||||
|
|
||||||
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
||||||
class NBVReconstructionPipeline(nn.Module):
|
class NBVReconstructionPipeline(nn.Module):
|
||||||
@ -15,6 +16,15 @@ class NBVReconstructionPipeline(nn.Module):
|
|||||||
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
|
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
|
||||||
|
|
||||||
def forward(self, data):
|
def forward(self, data):
|
||||||
|
mode = data["mode"]
|
||||||
|
if mode == namespace.Mode.TRAIN:
|
||||||
|
return self.forward_train(data)
|
||||||
|
elif mode == namespace.Mode.TEST:
|
||||||
|
return self.forward_test(data)
|
||||||
|
else:
|
||||||
|
Log.error("Unknown mode: {}".format(mode), True)
|
||||||
|
|
||||||
|
def forward_train(self, data):
|
||||||
output = {}
|
output = {}
|
||||||
pts_list = data['pts_list']
|
pts_list = data['pts_list']
|
||||||
pose_list = data['pose_list']
|
pose_list = data['pose_list']
|
||||||
@ -27,3 +37,6 @@ class NBVReconstructionPipeline(nn.Module):
|
|||||||
output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
|
output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def forward_test(self,data):
|
||||||
|
pass
|
Loading…
x
Reference in New Issue
Block a user