87 lines
3.2 KiB
Python
87 lines
3.2 KiB
Python
import os
|
|
import trimesh
|
|
import numpy as np
|
|
from PytorchBoot.runners.runner import Runner
|
|
from PytorchBoot.config import ConfigManager
|
|
import PytorchBoot.stereotype as stereotype
|
|
from PytorchBoot.utils.log_util import Log
|
|
import PytorchBoot.namespace as namespace
|
|
from PytorchBoot.status import status_manager
|
|
|
|
from utils.control_util import ControlUtil
|
|
from utils.communicate_util import CommunicateUtil
|
|
from utils.pts_util import PtsUtil
|
|
from utils.view_sample_util import ViewSampleUtil
|
|
from utils.reconstruction_util import ReconstructionUtil
|
|
|
|
|
|
@stereotype.runner("inferencer")
class Inferencer(Runner):
    """Iterative inference runner.

    Repeatedly asks the inference service (CommunicateUtil) for a camera
    pose, moves the robot there (ControlUtil), captures a view, and merges
    the captured points into a voxel-downsampled combined point cloud that
    is sent back with the next inference request.

    Reads ``voxel_size`` and ``max_iter`` from the ``runner.reconstruct``
    config section.
    """

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("inferencer")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        # Passed as ``voxel_size`` to PtsUtil.voxel_downsample_point_cloud.
        self.voxel_size = self.reconstruct_config["voxel_size"]
        # Number of inference/capture iterations performed after the first view.
        self.max_iter = self.reconstruct_config["max_iter"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)

    def _extract_view(self, view_data):
        """Split raw view data into ``(camera_points, camera_pose)``.

        The structure of the object returned by
        ``CommunicateUtil.get_view_data()`` is not established in this
        module, so extraction is not implemented yet. Fill this in (or
        override it in a subclass) once that format is known. The caller
        uses ``camera_points.shape[0]`` and concatenates the points along
        axis 0, so presumably an (N, 3) array is expected — TODO confirm.

        Raises:
            NotImplementedError: always, until the view-data format is wired up.
        """
        raise NotImplementedError(
            "Inferencer._extract_view: extracting camera points and pose "
            "from view data is not implemented yet"
        )

    def run_inference(self, model_name=None):
        """Run the capture/inference loop for ``self.max_iter`` iterations.

        Args:
            model_name: Currently unused. It is kept (now optional) for
                interface compatibility.

        Returns:
            The final request dict. It holds the combined downsampled points
            (``"combined_scanned_pts"``), the per-view point counts
            (``"scanned_target_points_num"``), and the per-view poses
            (``"scanned_n_to_world_pose_9d"``).
        """
        # Initialise the robot.
        ControlUtil.init()

        # Take the first view; it seeds the combined point cloud.
        view_data = CommunicateUtil.get_view_data()
        first_cam_pts, first_cam_pose = self._extract_view(view_data)
        combined_pts = first_cam_pts
        input_data = {
            "scanned_target_points_num": [first_cam_pts.shape[0]],
            "scanned_n_to_world_pose_9d": [first_cam_pose],
            "combined_scanned_pts": combined_pts,
        }

        iteration = 0
        while True:
            # Ask the inference service for the next camera pose.
            inference_result = CommunicateUtil.get_inference_data(input_data)
            cam_to_world = inference_result["cam_to_world"]

            # Move the robot to that pose and capture a new view.
            ControlUtil.set_pose(cam_to_world)
            view_data = CommunicateUtil.get_view_data()
            curr_cam_pts, curr_cam_pose = self._extract_view(view_data)

            # Merge the new points, then downsample so the cloud stays bounded.
            combined_pts = np.concatenate([combined_pts, curr_cam_pts], axis=0)
            combined_pts = PtsUtil.voxel_downsample_point_cloud(
                combined_pts, voxel_size=self.voxel_size
            )

            # Update the request sent with the next inference call.
            input_data["combined_scanned_pts"] = combined_pts
            input_data["scanned_target_points_num"].append(curr_cam_pts.shape[0])
            input_data["scanned_n_to_world_pose_9d"].append(curr_cam_pose)

            # Advance the counter before checking it. Without the increment,
            # the stop condition could never become true.
            iteration += 1
            if iteration >= self.max_iter:
                break

        return input_data

    def run(self):
        """Runner entry point: start the inference loop."""
        self.run_inference()
|
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc check: load a mesh and two camera point files, run
    # PtsUtil.register_icp on each point set against the mesh, and print the
    # returned cam_to_world results. Paths default to the original
    # developer-machine locations and can be overridden on the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run PtsUtil.register_icp on left/right camera points against a model."
    )
    parser.add_argument(
        "--model",
        default="/home/yan20/nbv_rec/data/test_CAD/test_model/bear_scaled.ply",
        help="Path to the model file, loaded with trimesh.load.",
    )
    parser.add_argument(
        "--pts-left",
        default="/home/yan20/nbv_rec/data/test_CAD/cam_pts_0_L.txt",
        help="Text file of left-camera points, loaded with numpy.loadtxt.",
    )
    parser.add_argument(
        "--pts-right",
        default="/home/yan20/nbv_rec/data/test_CAD/cam_pts_0_R.txt",
        help="Text file of right-camera points, loaded with numpy.loadtxt.",
    )
    args = parser.parse_args()

    model = trimesh.load(args.model)
    test_pts_L = np.loadtxt(args.pts_left)
    test_pts_R = np.loadtxt(args.pts_right)

    cam_to_world_L = PtsUtil.register_icp(test_pts_L, model)
    cam_to_world_R = PtsUtil.register_icp(test_pts_R, model)

    print(cam_to_world_L)
    print("================================")
    print(cam_to_world_R)
|
|
|