# nbv_reconstruction/runners/inference_server.py
# 2025-01-07 19:32:02 +08:00
#
# 117 lines
# 4.9 KiB
# Python

import os
import json
import torch
import numpy as np
from flask import Flask, request, jsonify
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log
from utils.pts import PtsUtil
from beans.predict_result import PredictResult
@stereotype.runner("inferencer_server")
class InferencerServer(Runner):
    """Serve pipeline pose predictions over HTTP.

    On construction the configured pipeline is created through
    ``ComponentFactory``, moved to ``self.device``, and its weights are
    restored from the experiment's checkpoint directory. :meth:`run` then
    starts a Flask app that exposes ``POST /inference``.
    """

    # Fields every /inference request body must contain.
    REQUIRED_FIELDS = ("scanned_pts", "scanned_n_to_world_pose_9d")

    def __init__(self, config_path):
        super().__init__(config_path)
        # Web server
        self.app = Flask(__name__)
        # Pipeline
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(
            namespace.Stereotype.PIPELINE, self.pipeline_name
        )
        self.pipeline = self.pipeline.to(self.device)
        # Number of points kept by farthest-point sampling of the merged cloud.
        self.pts_num = 8192
        # Voxel size for the voxel downsampling that precedes FPS
        # (presumably in the same units as the incoming points — confirm).
        self.voxel_size = 0.002
        # Experiment: resolves the epoch to serve and loads its checkpoint.
        self.load_experiment("inferencer_server")

    def get_input_data(self, data):
        """Turn a decoded request body into numpy pipeline inputs.

        Args:
            data: Decoded JSON object with ``scanned_pts`` (a list of per-view
                point lists) and ``scanned_n_to_world_pose_9d`` (one pose per
                view — presumably 9 numbers each; TODO confirm with the client).

        Returns:
            dict with:
              - ``scanned_pts``: passed through unchanged,
              - ``scanned_n_to_world_pose_9d``: the poses as a float32 array,
              - ``combined_scanned_pts``: all views concatenated, voxel
                downsampled with ``self.voxel_size``, then FPS-downsampled to
                ``self.pts_num`` points, as float32.
        """
        scanned_pts = data["scanned_pts"]
        scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]

        combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0)
        voxel_downsampled_combined_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            combined_scanned_views_pts, self.voxel_size
        )
        # Only the sampled points are needed; the FPS indices are discarded.
        fps_downsampled_combined_scanned_pts, _ = PtsUtil.fps_downsample_point_cloud(
            voxel_downsampled_combined_scanned_pts, self.pts_num, require_idx=True
        )

        return {
            "scanned_pts": scanned_pts,
            "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32),
            "combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32),
        }

    def get_result(self, output_data):
        """Convert pipeline output into a JSON-serialisable response dict.

        ``output_data["pred_pose_9d"]`` is moved to the CPU and passed to
        ``PredictResult`` with ``cluster_params`` ``eps=0.25, min_samples=3``
        (presumably DBSCAN-style clustering — confirm in PredictResult). The
        resulting ``candidate_9d_poses`` are returned as nested float lists
        under ``"pred_pose_9d"``.
        """
        pred_pose_9d = output_data["pred_pose_9d"]
        predict_result = PredictResult(
            pred_pose_9d.cpu().numpy(), None, cluster_params=dict(eps=0.25, min_samples=3)
        )
        candidate_poses = np.asarray(predict_result.candidate_9d_poses, dtype=np.float32)
        return {"pred_pose_9d": candidate_poses.tolist()}

    def collate_input(self, input_data):
        """Wrap single-request inputs as a batch of one on ``self.device``.

        ``scanned_pts`` and ``scanned_n_to_world_pose_9d`` become one-element
        lists of float32 tensors. ``combined_scanned_pts`` gets a leading batch
        dimension via ``unsqueeze(0)``.

        NOTE(review): ``torch.tensor`` cannot build a ragged tensor, so every
        view in ``scanned_pts`` must have the same number of points.
        """
        return {
            "scanned_pts": [
                torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)
            ],
            "scanned_n_to_world_pose_9d": [
                torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)
            ],
            "combined_scanned_pts": torch.tensor(
                input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device
            ).unsqueeze(0),
        }

    def _validate_request(self, data):
        """Return an error message for an unusable request body, or None if it is valid."""
        if not isinstance(data, dict):
            return "request body must be a JSON object"
        missing = [key for key in self.REQUIRED_FIELDS if key not in data]
        if missing:
            return "missing required field(s): {}".format(", ".join(missing))
        # np.concatenate in get_input_data raises on an empty list of views.
        if not data["scanned_pts"]:
            return "scanned_pts must contain at least one view"
        return None

    def run(self):
        """Start the Flask server on 0.0.0.0:5000; blocks while serving.

        ``POST /inference`` expects a JSON object with ``scanned_pts`` and
        ``scanned_n_to_world_pose_9d``. It responds with
        ``{"pred_pose_9d": [...]}``. A malformed body gets HTTP 400 with an
        ``{"error": ...}`` payload.
        """
        Log.info("Loading from epoch {}.".format(self.current_epoch))
        # nn.Module is in training mode by default; switch to eval so layers
        # such as dropout / batch-norm use their inference behaviour.
        self.pipeline.eval()

        @self.app.route("/inference", methods=["POST"])
        def inference():
            # silent=True yields None for a missing/invalid JSON body,
            # which _validate_request reports as a 400.
            data = request.get_json(silent=True)
            error = self._validate_request(data)
            if error is not None:
                return jsonify({"error": error}), 400
            input_data = self.get_input_data(data)
            collated_input_data = self.collate_input(input_data)
            # Serving needs no gradients. This also keeps the output tensors
            # free of autograd history so get_result can call .numpy() on them.
            with torch.no_grad():
                output_data = self.pipeline.forward_test(collated_input_data)
            result = self.get_result(output_data)
            return jsonify(result)

        self.app.run(host="0.0.0.0", port=5000)

    def get_checkpoint_path(self, is_last=False):
        """Return the checkpoint file path for the epoch being served.

        Returns ``Epoch_last.pth`` when ``is_last`` is True or
        ``current_epoch`` is -1. Otherwise returns
        ``Epoch_<current_epoch>.pth``, inside the experiment's checkpoint
        directory.
        """
        use_last = is_last or self.current_epoch == -1
        epoch_tag = "last" if use_last else self.current_epoch
        return os.path.join(
            self.experiment_path,
            # "Direcotry" is the attribute name used on namespace here
            # (presumably as spelled by PytorchBoot) — do not "fix" it locally.
            namespace.Direcotry.CHECKPOINT_DIR_NAME,
            "Epoch_{}.pth".format(epoch_tag),
        )

    def load_checkpoint(self, is_last=False):
        """Load pipeline weights; for the "last" checkpoint also restore progress counters.

        When ``is_last`` is True, ``meta.json`` in the checkpoint directory
        must exist. Its ``last_epoch`` and ``last_iter`` are copied into
        ``current_epoch`` and ``current_iter``.

        Raises:
            FileNotFoundError: if ``is_last`` is True and ``meta.json`` is missing.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        """Load the experiment and its checkpoint.

        ``experiments_config["epoch"] == -1`` selects the "last" checkpoint.
        """
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        """Delegate to the base Runner unchanged."""
        super().create_experiment(backup_name)

    def load(self, path):
        """Load a state dict from ``path`` into the pipeline.

        ``map_location`` places the tensors on ``self.device``, so a checkpoint
        saved on a GPU can be served on a CPU-only or different-GPU host.
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)