import os
import trimesh
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager

from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.view_sample_util import ViewSampleUtil
from utils.reconstruction_util import ReconstructionUtil


@stereotype.runner("CAD_strategy_runner")
class CADStrategyRunner(Runner):
    """Plan and execute a view sequence for scanning CAD models with the robot.

    For each model in ``model_dir`` the runner registers the CAD model against a
    first camera view, samples candidate views around it, scans those views,
    asks ``ReconstructionUtil`` for a view sequence, and then moves through that
    sequence again.

    NOTE: camera-frame point extraction from ``CommunicateUtil.get_view_data()``
    is still a placeholder (points are ``None``) throughout this runner.
    """

    def __init__(self, config_path: str):
        """Read the ``generate`` and ``reconstruct`` runner configs.

        Args:
            config_path: configuration file path forwarded to ``Runner``.
        """
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "cad_strategy"
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]

        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)

    def _compute_view_num(self, diag):
        """Map a bounding-box diagonal linearly onto a view count.

        The result is clamped to ``[min_view, max_view]``. Without the clamp,
        diagonals outside ``[min_diag, max_diag]`` would extrapolate to counts
        outside the configured range, including zero or negative.
        """
        ratio = (diag - self.min_diag) / (self.max_diag - self.min_diag)
        view_num = int(self.min_view + ratio * (self.max_view - self.min_view))
        return max(self.min_view, min(self.max_view, view_num))

    def _capture_world_points(self, cam_to_world_poses):
        """Move the robot to each pose, capture a view, and collect world-space points.

        Args:
            cam_to_world_poses: iterable of camera-to-world poses accepted by
                ``ControlUtil.move_to``.

        Returns:
            ``(world_pts_list, scan_points_idx_list)``, with one entry per pose
            in visiting order.
        """
        world_pts_list = []
        scan_points_idx_list = []
        for cam_to_world in cam_to_world_poses:
            ControlUtil.move_to(cam_to_world)
            # get world pts
            view_data = CommunicateUtil.get_view_data()
            # TODO: extract camera-frame points and scan-point indices from view_data;
            # both are still placeholders, so transform_point_cloud receives None.
            cam_pts = None
            scan_points_idx = None
            world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_world)
            world_pts_list.append(world_pts)
            scan_points_idx_list.append(scan_points_idx)
        return world_pts_list, scan_points_idx_list

    def run_one_model(self, model_name):
        """Register, plan, and scan a single CAD model.

        Args:
            model_name: file name of the model inside ``self.model_dir``.
        """
        # init robot
        ControlUtil.init()

        # load CAD model
        model_path = os.path.join(self.model_dir, model_name)
        cad_model = trimesh.load(model_path)

        # take first view
        view_data = CommunicateUtil.get_view_data()
        # TODO: extract the first view's camera-frame points from view_data;
        # until then registration receives None.
        first_cam_pts = None

        # register: ICP gives CAD->camera; chain with the current camera pose to get CAD->world
        cad_to_cam = PtsUtil.register_icp(first_cam_pts, cad_model)
        cam_to_world = ControlUtil.get_pose()
        cad_to_world = cam_to_world @ cad_to_cam
        cad_model: trimesh.Trimesh = cad_model.apply_transform(cad_to_world)

        # sample views; the view budget scales with the bounding-box diagonal
        min_corner = cad_model.bounds[0]
        max_corner = cad_model.bounds[1]
        diag = np.linalg.norm(max_corner - min_corner)
        view_num = self._compute_view_num(diag)
        # NOTE(review): cad_model has already been transformed by cad_to_world above,
        # and cad_to_world is passed again here -- confirm the sampler does not apply it twice.
        sampled_view_data = ViewSampleUtil.sample_view_data_world_space(
            cad_model,
            cad_to_world,
            voxel_size=self.voxel_size,
            max_views=view_num,
            min_cam_table_included_degree=self.min_cam_table_included_degree,
            random_view_ratio=self.random_view_ratio,
        )
        cam_to_world_poses = sampled_view_data["cam_to_world_poses"]
        world_model_points = sampled_view_data["voxel_down_sampled_points"]

        # take sample views
        sample_view_pts_list, scan_points_idx_list = self._capture_world_points(cam_to_world_poses)

        # generate strategy
        limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
            world_model_points,
            sample_view_pts_list,
            scan_points_indices_list=scan_points_idx_list,
            init_view=0,
            threshold=self.voxel_size,
            soft_overlap_threshold=self.soft_overlap_threshold,
            hard_overlap_threshold=self.hard_overlap_threshold,
            scan_points_threshold=self.scan_points_threshold,
            status_info=self.status_info,
        )

        # extract cam_to_world sequence (and the coverage reached after each view)
        cam_to_world_seq = []
        coverage_rate_seq = []
        for idx, coverage_rate in limited_useful_view:
            cam_to_world_seq.append(cam_to_world_poses[idx])
            coverage_rate_seq.append(coverage_rate)

        # take best-sequence views
        # TODO: the captured best-view points are not yet consumed or persisted.
        best_view_pts_list, best_scan_points_idx_list = self._capture_world_points(cam_to_world_seq)

    def run(self):
        """Process every model in ``model_dir``, starting at ``model_start_idx``.

        The directory listing is sorted so that ``model_start_idx`` refers to the
        same model across runs. ``os.listdir`` order alone is arbitrary.
        """
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(model_names[model_start_idx:], start=model_start_idx):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")


if __name__ == "__main__":
    # Smoke test: register two captured point clouds against a test model and print the results.
    # NOTE(review): in run_one_model, register_icp's result is treated as CAD->camera,
    # although the variables below are named cam_to_world -- confirm which frame it returns.
    model_path = "/home/yan20/nbv_rec/data/test_CAD/test_model/bear_scaled.ply"
    model = trimesh.load(model_path)
    test_pts_L = np.loadtxt("/home/yan20/nbv_rec/data/test_CAD/cam_pts_0_L.txt")
    test_pts_R = np.loadtxt("/home/yan20/nbv_rec/data/test_CAD/cam_pts_0_R.txt")
    cam_to_world_L = PtsUtil.register_icp(test_pts_L, model)
    cam_to_world_R = PtsUtil.register_icp(test_pts_R, model)
    print(cam_to_world_L)
    print("================================")
    print(cam_to_world_R)