import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager

from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data, save_scene_data_multithread
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil


@stereotype.runner("CAD_close_loop_strategy_runner")
class CADCloseLoopStrategyRunner(Runner):
    """Close-loop next-best-view scanning guided by a CAD model.

    For each model: take a first real view, register it against the CAD mesh,
    export the registered mesh for the Blender generator script, load the
    rendered candidate views, then repeatedly move to the candidate returned by
    ReconstructionUtil.compute_next_best_view_with_overlap and merge the
    captured points until a stop criterion is met.
    """

    # Working directory shared with the Blender generator script. It is a fixed
    # path (not a temporary directory) so renders and debug dumps stay on disk.
    TEMP_DIR = "/home/yan20/nbv_rec/project/franka_control/temp_output"
    # z translation applied to go from the real-world frame to the Blender world frame.
    REAL_WORLD_TO_BLENDER_Z_OFFSET = 0.9215

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "CAD_close_loop_strategy_runner",
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config[
            "min_cam_table_included_degree"
        ]
        self.max_shot_view_num = self.generate_config["max_shot_view_num"]
        self.min_shot_new_pts_num = self.generate_config["min_shot_new_pts_num"]
        self.min_coverage_increase = self.generate_config["min_coverage_increase"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        """Delegate to Runner.create_experiment."""
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate to Runner.load_experiment."""
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0):
        """Split an (N, >=3) world-frame point array by its z coordinate.

        Returns:
            (scan_pts, obj_pts): rows with z < z_threshold and rows with
            z >= z_threshold. Rows whose z is NaN satisfy neither comparison
            and are dropped from both.
        """
        z = world_pts[:, 2]
        scan_pts = world_pts[z < z_threshold]
        obj_pts = world_pts[z >= z_threshold]
        return scan_pts, obj_pts

    def _run_renderer(self, temp_dir):
        """Run the Blender generator script in background mode on temp_dir.

        A non-zero exit code is logged rather than raised, so a failed render
        is visible in the log without aborting the run.
        """
        result = subprocess.run(
            [
                self.blender_bin_path,
                "-b",
                "-P",
                self.generator_script_path,
                "--",
                temp_dir,
            ],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            Log.error(
                f"Renderer exited with code {result.returncode}: {result.stderr}"
            )

    def _load_sampled_views(self, temp_dir, temp_name):
        """Load the preprocessed rendered views from temp_dir/temp_name.

        Returns:
            (pts_list, scan_idx_list, frame_idx_list), one entry per view with
            a non-empty point cloud:
              - pts_list: the view's point cloud, voxel-downsampled;
              - scan_idx_list: the view's scan-point indices as a 1-D array;
              - frame_idx_list: the rendered frame index of that view.
            Empty views are skipped, so a list position is NOT a frame index;
            use frame_idx_list to map back.
        """
        scene_dir = os.path.join(temp_dir, temp_name)
        pts_dir = os.path.join(scene_dir, "pts")
        idx_dir = os.path.join(scene_dir, "scan_points_indices")
        pts_list, scan_idx_list, frame_idx_list = [], [], []
        for frame_idx in range(len(os.listdir(pts_dir))):
            point_cloud = np.loadtxt(os.path.join(pts_dir, f"{frame_idx}.txt"))
            if point_cloud.shape[0] == 0:
                continue
            pts_list.append(
                PtsUtil.voxel_downsample_point_cloud(point_cloud, self.voxel_size)
            )
            # A file storing a single index loads as a 0-d array; make it 1-D.
            indices = np.load(os.path.join(idx_dir, f"{frame_idx}.npy"))
            scan_idx_list.append(np.atleast_1d(indices))
            frame_idx_list.append(frame_idx)
        return pts_list, scan_idx_list, frame_idx_list

    def run_one_model(self, model_name):
        """Register, render candidate views for, and close-loop scan one model.

        Args:
            model_name: sub-directory of model_dir containing "mesh.ply".
        """
        temp_dir = self.TEMP_DIR
        temp_name = "cad_model_world"
        ControlUtil.connect_robot()

        # ---- init robot ----
        Log.info("[Part 1/4] start init and register")
        ControlUtil.init()

        # ---- load CAD model ----
        model_path = os.path.join(self.model_dir, model_name, "mesh.ply")
        cad_model = trimesh.load(model_path)

        # ---- take first view ----
        Log.info("[Part 1/4] take first view data")
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_pts = ViewUtil.get_pts(view_data)
        first_cam_to_real_world = ControlUtil.get_pose()
        first_real_world_pts = PtsUtil.transform_point_cloud(
            first_cam_pts, first_cam_to_real_world
        )
        _, first_splitted_real_world_pts = self.split_scan_pts_and_obj_pts(
            first_real_world_pts
        )
        np.savetxt(f"first_real_pts_{model_name}.txt", first_splitted_real_world_pts)

        # ---- register ----
        Log.info("[Part 1/4] do registeration")
        real_world_to_cad = PtsUtil.register(first_splitted_real_world_pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("[Part 1/4] finish init and register")

        real_world_to_blender_world = np.eye(4)
        real_world_to_blender_world[:3, 3] = np.asarray(
            [0, 0, self.REAL_WORLD_TO_BLENDER_Z_OFFSET]
        )
        # apply_transform mutates the mesh in place, so transform copies to keep
        # each variable in the frame its name states.
        os.makedirs(temp_dir, exist_ok=True)
        cad_model_real_world = cad_model.copy().apply_transform(cad_to_real_world)
        cad_model_real_world.export(
            os.path.join(temp_dir, f"real_world_{temp_name}.obj")
        )
        cad_model_blender_world = cad_model_real_world.copy().apply_transform(
            real_world_to_blender_world
        )
        cad_model_blender_world.export(os.path.join(temp_dir, f"{temp_name}.obj"))

        # ---- sample view ----
        Log.info("[Part 2/4] start running renderer")
        self._run_renderer(temp_dir)
        Log.success("[Part 2/4] finish running renderer")

        # ---- preprocess ----
        Log.info("[Part 3/4] start preprocessing data")
        save_scene_data(temp_dir, temp_name)
        Log.success("[Part 3/4] finish preprocessing data")

        (
            sample_view_pts_list,
            scan_points_idx_list,
            view_frame_idx_list,
        ) = self._load_sampled_views(temp_dir, temp_name)

        # ---- close-loop strategy ----
        scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            first_splitted_real_world_pts, self.voxel_size
        )
        shot_pts_list = [first_splitted_real_world_pts]
        history_indices = []
        last_coverage = 0
        debug_dir = os.path.join(temp_dir, "debug")
        os.makedirs(debug_dir, exist_ok=True)
        Log.info("[Part 4/4] start close-loop control")
        cnt = 0
        while True:
            next_best_view, next_best_coverage, next_best_covered_num = (
                ReconstructionUtil.compute_next_best_view_with_overlap(
                    scanned_pts,
                    sample_view_pts_list,
                    history_indices,
                    scan_points_idx_list,
                    threshold=self.voxel_size,
                    overlap_area_threshold=25,
                    scan_points_threshold=self.scan_points_threshold,
                )
            )
            # next_best_view indexes the non-empty sampled views; map it back
            # to the rendered frame index before loading that frame's camera.
            nbv_frame_idx = view_frame_idx_list[next_best_view]
            nbv_path = DataLoadUtil.get_path(temp_dir, temp_name, nbv_frame_idx)
            nbv_cam_info = DataLoadUtil.load_cam_info(nbv_path, binocular=True)
            nbv_cam_to_world = nbv_cam_info["cam_to_world_O"]
            ControlUtil.move_to(nbv_cam_to_world)

            # ---- get world pts ----
            time.sleep(0.5)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                continue
            cam_shot_pts = ViewUtil.get_pts(view_data)
            # The camera has moved since the first view: transform with the
            # pose at this shot, not first_cam_to_real_world.
            cam_to_real_world = ControlUtil.get_pose()
            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
                world_shot_pts
            )
            shot_pts_list.append(world_splitted_shot_pts)

            np.savetxt(
                os.path.join(debug_dir, f"shot_pts_{cnt}.txt"), world_splitted_shot_pts
            )
            np.savetxt(
                os.path.join(debug_dir, f"render_pts_{cnt}.txt"),
                sample_view_pts_list[next_best_view],
            )

            last_scanned_pts_num = scanned_pts.shape[0]
            scanned_pts = PtsUtil.voxel_downsample_point_cloud(
                np.vstack([scanned_pts, world_splitted_shot_pts]), self.voxel_size
            )
            new_added_pts_num = scanned_pts.shape[0] - last_scanned_pts_num
            history_indices.append(scan_points_idx_list[next_best_view])
            Log.info(
                f"Next Best cover pts: {next_best_covered_num}, Best coverage: {next_best_coverage}"
            )

            # Stop criteria: these log "stop scanning", so they must actually
            # end the loop; otherwise the loop never exits and run() can never
            # reach the next model.
            coverage_rate_increase = next_best_coverage - last_coverage
            if coverage_rate_increase < self.min_coverage_increase:
                Log.info(f"Coverage rate = {coverage_rate_increase} < {self.min_coverage_increase}, stop scanning")
                break
            last_coverage = next_best_coverage

            if new_added_pts_num < self.min_shot_new_pts_num:
                Log.info(f"New added pts num = {new_added_pts_num} < {self.min_shot_new_pts_num}")

            if len(shot_pts_list) >= self.max_shot_view_num:
                Log.info(f"Scanned view num = {len(shot_pts_list)} >= {self.max_shot_view_num}, stop scanning")
                break
            cnt += 1

        Log.success("[Part 4/4] finish close-loop control")

    def run(self):
        """Run run_one_model on each entry of model_dir, from generate.model_start_idx on."""
        # Sort so that model_start_idx refers to a stable order; os.listdir
        # order is arbitrary.
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(
            model_names[model_start_idx:], start=model_start_idx
        ):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")


# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    model_path = r"C:\Users\hofee\Downloads\mesh.obj"
    model = trimesh.load(model_path)