update pipeline: train test mode

This commit is contained in:
hofee 2024-08-21 18:03:35 +08:00
parent 96fb34be09
commit 68b4325dbd

View File

@ -3,6 +3,7 @@ from torch import nn
import PytorchBoot.namespace as namespace import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.utils import Log
@stereotype.pipeline("nbv_reconstruction_pipeline") @stereotype.pipeline("nbv_reconstruction_pipeline")
class NBVReconstructionPipeline(nn.Module): class NBVReconstructionPipeline(nn.Module):
@ -15,6 +16,15 @@ class NBVReconstructionPipeline(nn.Module):
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"]) self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
def forward(self, data): def forward(self, data):
mode = data["mode"]
if mode == namespace.Mode.TRAIN:
return self.forward_train(data)
elif mode == namespace.Mode.TEST:
return self.forward_test(data)
else:
Log.error("Unknown mode: {}".format(mode), True)
def forward_train(self, data):
output = {} output = {}
pts_list = data['pts_list'] pts_list = data['pts_list']
pose_list = data['pose_list'] pose_list = data['pose_list']
@ -27,3 +37,6 @@ class NBVReconstructionPipeline(nn.Module):
output['estimated_score'] = self.view_finder.next_best_view(seq_feat) output['estimated_score'] = self.view_finder.next_best_view(seq_feat)
return output return output
def forward_test(self,data):
pass