117 lines
4.9 KiB
Python
117 lines
4.9 KiB
Python
import os
|
|
import json
|
|
import torch
|
|
import numpy as np
|
|
from flask import Flask, request, jsonify
|
|
|
|
import PytorchBoot.namespace as namespace
|
|
import PytorchBoot.stereotype as stereotype
|
|
from PytorchBoot.factory import ComponentFactory
|
|
|
|
from PytorchBoot.runners.runner import Runner
|
|
from PytorchBoot.utils import Log
|
|
|
|
from utils.pts import PtsUtil
|
|
from beans.predict_result import PredictResult
|
|
|
|
@stereotype.runner("inferencer_server")
class InferencerServer(Runner):
    """Flask server exposing the configured pipeline's ``forward_test`` over HTTP.

    On construction the runner builds a Flask app, creates the pipeline named
    in the config via ``ComponentFactory``, moves it to ``self.device`` and
    loads the experiment checkpoint. ``run()`` then serves
    ``POST /inference`` on 0.0.0.0:5000.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        # Web server
        self.app = Flask(__name__)

        # Pipeline (must exist before load_experiment, which loads weights into it)
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline = self.pipeline.to(self.device)
        # Target point count for PtsUtil.fps_downsample_point_cloud.
        self.pts_num = 8192
        # Voxel size for PtsUtil.voxel_downsample_point_cloud (units not visible here; TODO confirm).
        self.voxel_size = 0.002

        # Experiment: loads the experiment and restores the checkpoint weights.
        self.load_experiment("inferencer_server")

    def get_input_data(self, data):
        """Convert a decoded JSON request body into numpy-backed pipeline inputs.

        Args:
            data: Request body containing ``"scanned_pts"`` (a list of per-view
                point lists, concatenated along axis 0) and
                ``"scanned_n_to_world_pose_9d"`` (presumably one 9-value pose per
                view -- TODO confirm against the client).

        Returns:
            dict with ``"scanned_pts"`` (passed through unchanged),
            ``"scanned_n_to_world_pose_9d"`` (float32 array) and
            ``"combined_scanned_pts"`` (float32 array: all views merged, then
            reduced by ``PtsUtil.voxel_downsample_point_cloud`` with
            ``self.voxel_size`` and ``PtsUtil.fps_downsample_point_cloud`` to
            ``self.pts_num`` points).

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``"scanned_pts"`` contains no views.
        """
        scanned_pts = data["scanned_pts"]
        scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]
        if not scanned_pts:
            raise ValueError("'scanned_pts' must contain at least one view")

        combined_scanned_pts = np.concatenate(scanned_pts, axis=0)
        voxel_downsampled_pts = PtsUtil.voxel_downsample_point_cloud(
            combined_scanned_pts, self.voxel_size
        )
        # Only the sampled points are needed; the returned indices are discarded.
        fps_downsampled_pts, _ = PtsUtil.fps_downsample_point_cloud(
            voxel_downsampled_pts, self.pts_num, require_idx=True
        )

        return {
            "scanned_pts": scanned_pts,
            "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32),
            "combined_scanned_pts": np.asarray(fps_downsampled_pts, dtype=np.float32),
        }

    def get_result(self, output_data):
        """Build the JSON-serialisable response from the pipeline output.

        ``output_data["pred_pose_9d"]`` is moved to CPU and handed to
        ``PredictResult`` with clustering parameters ``eps=0.25`` and
        ``min_samples=3``; its ``candidate_9d_poses`` are returned as nested
        lists under ``"pred_pose_9d"``.
        """
        pred_pose_9d = output_data["pred_pose_9d"]
        pred_pose_9d = np.asarray(
            PredictResult(pred_pose_9d.cpu().numpy(), None, cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses,
            dtype=np.float32,
        )
        return {"pred_pose_9d": pred_pose_9d.tolist()}

    def collate_input(self, input_data):
        """Wrap the inputs from ``get_input_data`` as a batch of size one on ``self.device``.

        Per-sample entries are wrapped in single-element lists, and
        ``combined_scanned_pts`` gets a leading batch dimension.
        NOTE(review): ``torch.tensor`` on ``scanned_pts`` requires every view to
        have the same number of points; ragged views raise ``ValueError``.
        """
        collated_input_data = {}
        collated_input_data["scanned_pts"] = [torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)]
        collated_input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)]
        collated_input_data["combined_scanned_pts"] = torch.tensor(input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device).unsqueeze(0)
        return collated_input_data

    def run(self):
        """Register ``POST /inference`` and block serving it on 0.0.0.0:5000.

        Malformed payloads (missing keys, empty or ragged scans) are answered
        with HTTP 400 and an ``"error"`` message instead of an unhandled 500.
        """
        Log.info("Loading from epoch {}.".format(self.current_epoch))
        # nn.Modules start in training mode; switch to eval so layers such as
        # dropout / batch-norm (if the pipeline has any) act in inference mode.
        self.pipeline.eval()

        @self.app.route("/inference", methods=["POST"])
        def inference():
            data = request.json
            try:
                input_data = self.get_input_data(data)
                collated_input_data = self.collate_input(input_data)
            except (KeyError, TypeError, ValueError) as err:
                # Bad client input is a client error, not a server failure.
                return jsonify({"error": "invalid request payload: {}".format(err)}), 400

            # Pure inference: no autograd graph is needed, and get_result calls
            # .numpy(), which fails on tensors that require grad.
            with torch.no_grad():
                output_data = self.pipeline.forward_test(collated_input_data)
                result = self.get_result(output_data)
            return jsonify(result)

        self.app.run(host="0.0.0.0", port=5000)

    def get_checkpoint_path(self, is_last=False):
        """Return ``<experiment>/<checkpoint dir>/Epoch_<n>.pth``.

        ``<n>`` is ``"last"`` when ``is_last`` is true or ``self.current_epoch``
        is -1, otherwise ``self.current_epoch``.
        """
        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
                            "Epoch_{}.pth".format(
                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

    def load_checkpoint(self, is_last=False):
        """Load pipeline weights and, for the "last" checkpoint, its epoch/iter.

        When ``is_last`` is true, ``meta.json`` in the checkpoint directory is
        read to set ``self.current_epoch`` and ``self.current_iter`` from its
        ``"last_epoch"`` and ``"last_iter"`` fields.

        Raises:
            FileNotFoundError: if ``is_last`` and ``meta.json`` does not exist.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if not is_last:
            return

        checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
        meta_path = os.path.join(checkpoint_root, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(
                "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        self.current_epoch = meta["last_epoch"]
        self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        """Load the experiment, then the checkpoint for the configured epoch.

        An ``"epoch"`` of -1 in the experiment config selects the "last" checkpoint.
        """
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        """Delegate to ``Runner.create_experiment`` unchanged."""
        super().create_experiment(backup_name)

    def load(self, path):
        """Load a state dict from ``path`` into the pipeline.

        ``map_location`` remaps the stored tensors onto ``self.device``, so a
        checkpoint saved on one device (e.g. a GPU) can be loaded on a host
        without that device.
        """
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)