nbv_rec_control/runners/cad_strategy.py
2024-10-07 16:20:56 +08:00

144 lines
6.2 KiB
Python

import os
import trimesh
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.view_sample_util import ViewSampleUtil
from utils.reconstruction_util import ReconstructionUtil
@stereotype.runner("CAD_strategy_runner")
class CADStrategyRunner(Runner):
    """Runner that plans and executes a view sequence for CAD models.

    For each model file in ``generate.model_dir`` it:
      1. registers the CAD model against a first camera view,
      2. samples candidate views around the registered model,
      3. scans every candidate view,
      4. selects a view sequence with
         ``ReconstructionUtil.compute_next_best_view_sequence_with_overlap``,
      5. revisits the selected views.

    NOTE(review): the point-cloud extraction from ``view_data`` is still a
    placeholder (``None``) -- see the TODOs in ``run_one_model`` and
    ``_scan_views``.
    """

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        # Forwarded as ``status_info`` to ReconstructionUtil.
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "cad_strategy"
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")

        # View-sampling parameters.
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]

        # View-sequence (reconstruction) parameters.
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)

    def _view_num_for_diag(self, diag):
        """Map a bounding-box diagonal length to a candidate-view count.

        Linearly interpolates from ``min_view`` (at ``min_diag``) to
        ``max_view`` (at ``max_diag``). The interpolation ratio is clamped to
        [0, 1] so that models outside [min_diag, max_diag] cannot yield a view
        count outside [min_view, max_view] (e.g. a negative one).

        Raises:
            ValueError: if ``max_diag <= min_diag`` (interpolation undefined).
        """
        diag_span = self.max_diag - self.min_diag
        if diag_span <= 0:
            raise ValueError(
                f"max_diag ({self.max_diag}) must be greater than min_diag ({self.min_diag})"
            )
        ratio = (diag - self.min_diag) / diag_span
        ratio = min(max(ratio, 0.0), 1.0)
        # int() truncates, matching the original unclamped formula in range.
        return int(self.min_view + ratio * (self.max_view - self.min_view))

    def _scan_views(self, cam_to_world_poses):
        """Move to each pose in order and capture its world-frame points.

        Args:
            cam_to_world_poses: iterable of camera-to-world poses to visit.

        Returns:
            ``(world_pts_list, scan_points_idx_list)``, with one entry per pose
            in visiting order.
        """
        world_pts_list = []
        scan_points_idx_list = []
        for cam_to_world in cam_to_world_poses:
            ControlUtil.move_to(cam_to_world)
            # get world pts
            view_data = CommunicateUtil.get_view_data()
            # TODO: extract camera-frame points and scanned-point indices from
            # view_data; both are still placeholders.
            cam_pts = None
            scan_points_idx = None
            world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_world)
            world_pts_list.append(world_pts)
            scan_points_idx_list.append(scan_points_idx)
        return world_pts_list, scan_points_idx_list

    def run_one_model(self, model_name):
        """Plan and execute a view sequence for a single CAD model.

        Args:
            model_name: model file name, relative to ``self.model_dir``.
        """
        # init robot
        ControlUtil.init()

        # load CAD model
        model_path = os.path.join(self.model_dir, model_name)
        cad_model = trimesh.load(model_path)

        # take first view
        view_data = CommunicateUtil.get_view_data()
        # TODO: extract the camera-frame point cloud from view_data; the
        # registration below currently receives None.
        first_cam_pts = None

        # register the CAD model to the first view, then express it in world frame
        cad_to_cam = PtsUtil.register_icp(first_cam_pts, cad_model)
        cam_to_world = ControlUtil.get_pose()
        cad_to_world = cam_to_world @ cad_to_cam
        cad_model: trimesh.Trimesh = cad_model.apply_transform(cad_to_world)

        # sample candidate views; count scales with the bounding-box diagonal
        min_corner = cad_model.bounds[0]
        max_corner = cad_model.bounds[1]
        diag = np.linalg.norm(max_corner - min_corner)
        view_num = self._view_num_for_diag(diag)
        # NOTE(review): cad_model has already been transformed by cad_to_world
        # above; confirm sample_view_data_world_space does not apply it again.
        sampled_view_data = ViewSampleUtil.sample_view_data_world_space(
            cad_model, cad_to_world,
            voxel_size=self.voxel_size,
            max_views=view_num,
            min_cam_table_included_degree=self.min_cam_table_included_degree,
            random_view_ratio=self.random_view_ratio
        )
        cam_to_world_poses = sampled_view_data["cam_to_world_poses"]
        world_model_points = sampled_view_data["voxel_down_sampled_points"]

        # scan every candidate view
        sample_view_pts_list, scan_points_idx_list = self._scan_views(cam_to_world_poses)

        # generate strategy
        limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
            world_model_points, sample_view_pts_list,
            scan_points_indices_list=scan_points_idx_list,
            init_view=0,
            threshold=self.voxel_size,
            soft_overlap_threshold=self.soft_overlap_threshold,
            hard_overlap_threshold=self.hard_overlap_threshold,
            scan_points_threshold=self.scan_points_threshold,
            status_info=self.status_info
        )

        # extract the cam_to_world sequence (entries are (view_idx, coverage_rate))
        cam_to_world_seq = [cam_to_world_poses[idx] for idx, _ in limited_useful_view]

        # revisit the selected views in planned order
        # TODO: the captured best-sequence scans are not consumed yet.
        self._scan_views(cam_to_world_seq)

    def run(self):
        """Process each model in ``model_dir``, starting at ``model_start_idx``."""
        # os.listdir order is arbitrary; sort so model_start_idx is a stable
        # position across runs.
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(
            model_names[model_start_idx:], start=model_start_idx
        ):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
if __name__ == "__main__":
    # Ad-hoc check of PtsUtil.register_icp on two captured point clouds of a
    # test model; prints the two resulting transforms.
    data_dir = "/home/yan20/nbv_rec/data/test_CAD"
    model_path = os.path.join(data_dir, "test_model", "bear_scaled.ply")
    model = trimesh.load(model_path)
    test_pts_L = np.loadtxt(os.path.join(data_dir, "cam_pts_0_L.txt"))
    test_pts_R = np.loadtxt(os.path.join(data_dir, "cam_pts_0_R.txt"))
    # Named cad_to_cam to match how run_one_model uses register_icp's result
    # (cad_to_cam = PtsUtil.register_icp(cam_pts, cad_model)).
    cad_to_cam_L = PtsUtil.register_icp(test_pts_L, model)
    cad_to_cam_R = PtsUtil.register_icp(test_pts_R, model)
    print(cad_to_cam_L)
    print("================================")
    print(cad_to_cam_R)