import os
import json
import torch
import numpy as np
from flask import Flask, request, jsonify

import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

from utils.pts import PtsUtil


@stereotype.runner("inferencer_server")
class InferencerServer(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)

        ''' Web Server '''
        self.app = Flask(__name__)

        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline = self.pipeline.to(self.device)
        self.pts_num = 8192

        ''' Experiment '''
        self.load_experiment("inferencer_server")
        self.pipeline.eval()  # the server only runs inference, so switch to eval mode

    def get_input_data(self, data):
        """Downsample the combined scan to a fixed point count and build the raw input dict."""
        input_data = {}
        scanned_pts = data["scanned_pts"]
        scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]
        combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0)
        fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
            combined_scanned_views_pts, self.pts_num, require_idx=True
        )
        # combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
        # start_idx = 0
        # for i in range(len(scanned_pts)):
        #     end_idx = start_idx + len(scanned_pts[i])
        #     combined_scanned_views_pts_mask[start_idx:end_idx] = i
        #     start_idx = end_idx
        # fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
        input_data["scanned_pts"] = scanned_pts
        # input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8)
        input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32)
        input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32)
        return input_data

    def get_result(self, output_data):
        """Convert the predicted pose tensor into a JSON-serializable response."""
        pred_pose_9d = output_data["pred_pose_9d"]
        result = {
            "pred_pose_9d": pred_pose_9d.tolist()
        }
        return result

    def collate_input(self, input_data):
        """Move the numpy inputs onto the device and add a batch dimension of 1."""
        collated_input_data = {}
        collated_input_data["scanned_pts"] = [
            torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)
        ]
        collated_input_data["scanned_n_to_world_pose_9d"] = [
            torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)
        ]
        collated_input_data["combined_scanned_pts"] = torch.tensor(
            input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        return collated_input_data

    def run(self):
        Log.info("Serving pipeline loaded from epoch {}.".format(self.current_epoch))

        @self.app.route("/inference", methods=["POST"])
        def inference():
            data = request.json
            input_data = self.get_input_data(data)
            collated_input_data = self.collate_input(input_data)
            with torch.no_grad():  # no gradients are needed when serving
                output_data = self.pipeline.forward_test(collated_input_data)
            result = self.get_result(output_data)
            return jsonify(result)

        self.app.run(host="0.0.0.0", port=5000)

    def get_checkpoint_path(self, is_last=False):
        # note: "Direcotry" is the attribute name as spelled in PytorchBoot.namespace
        return os.path.join(
            self.experiment_path,
            namespace.Direcotry.CHECKPOINT_DIR_NAME,
            "Epoch_{}.pth".format(self.current_epoch if self.current_epoch != -1 and not is_last else "last"),
        )

    def load_checkpoint(self, is_last=False):
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if is_last:
            # recover the epoch/iteration counters from the checkpoint metadata
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"])
                )
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        # epoch == -1 in the experiment config means "load the latest checkpoint"
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load(self, path):
        # map_location keeps loading working when the checkpoint was saved on a different device
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)
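

# A minimal client-side sketch, not part of the runner itself. It assumes the
# server is already running on localhost:5000 and that the third-party
# `requests` package is installed. It only illustrates the JSON layout that
# `get_input_data` expects: one (N, 3) point list and one 9-D pose vector per
# scanned view, with every view the same size so `collate_input` can stack
# them. The random payload below is purely illustrative.
if __name__ == "__main__":
    import requests

    num_views = 2
    payload = {
        # each view: a nested list representing an (N, 3) point cloud
        "scanned_pts": [np.random.rand(1024, 3).tolist() for _ in range(num_views)],
        # each view: a 9-D pose vector
        "scanned_n_to_world_pose_9d": [np.random.rand(9).tolist() for _ in range(num_views)],
    }
    response = requests.post("http://localhost:5000/inference", json=payload)
    print(response.json()["pred_pose_9d"])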