nbv_rec_control/runners/cad_open_loop_strategy.py
2024-10-13 15:24:41 +08:00

224 lines
11 KiB
Python

import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data, save_scene_data_multithread
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil
@stereotype.runner("CAD_open_loop_strategy_runner")
class CADOpenLoopStrategyRunner(Runner):
    """Open-loop view-planning runner driven by a CAD model.

    For each model the runner registers the CAD mesh against a first real
    capture, renders candidate views with Blender, computes a view sequence
    with ``ReconstructionUtil.compute_next_best_view_sequence_with_overlap``
    and then drives the robot through that sequence while logging the
    coverage reached by the real captures.
    """

    # Working directory shared with the Blender generator script. It is a
    # fixed, persistent location (not a throw-away temp dir), so files from
    # earlier runs may still be present in it.
    _TEMP_OUTPUT_DIR = "/home/yan20/nbv_rec/project/franka_control/temp_output"
    # Scene name used for the exported mesh and the rendered scene folder.
    _TEMP_SCENE_NAME = "cad_model_world"
    # z translation applied when going from the real-world frame to the
    # Blender world frame.
    _REAL_WORLD_TO_BLENDER_Z_OFFSET = 0.9215

    def __init__(self, config_path: str):
        """Load the experiment and cache the generate/reconstruct settings.

        Args:
            config_path: Path of the configuration file handed to ``Runner``.
        """
        super().__init__(config_path)
        self.load_experiment("cad_open_loop_strategy")
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "CAD_open_loop_strategy_runner"
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")

        # Settings of the view-generation stage.
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]

        # Settings of the strategy-generation stage.
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        """Delegate experiment creation to the base ``Runner``."""
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate experiment loading to the base ``Runner``."""
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold = 0):
        """Split a point array by its z coordinate (column index 2).

        Args:
            world_pts: 2-D array whose third column is the z coordinate.
            z_threshold: Points with ``z < z_threshold`` are returned as scan
                points; all remaining points are returned as object points.

        Returns:
            Tuple ``(scan_pts, obj_pts)``.
        """
        below_threshold = world_pts[:, 2] < z_threshold
        scan_pts = world_pts[below_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    def run_one_model(self, model_name):
        """Run the full register / render / plan / execute pipeline for one model.

        Args:
            model_name: Name of the sub-directory of ``self.model_dir`` that
                contains the model's ``mesh.ply``.

        Returns:
            dict with ``"gt_final_coverage_rate_cad"`` and
            ``"real_coverage_rate_seq"``; ``"real_final_coverage_rate"`` is
            present only when at least one robot shot succeeded.

        Raises:
            RuntimeError: If none of the rendered frames contains any point.
        """
        temp_dir = self._TEMP_OUTPUT_DIR
        temp_name = self._TEMP_SCENE_NAME
        scene_dir = os.path.join(temp_dir, temp_name)
        result = dict()
        shot_pts_list = []

        ControlUtil.connect_robot()
        # init robot
        Log.info("[Part 1/5] start init and register")
        ControlUtil.init()

        # load CAD model
        model_path = os.path.join(self.model_dir, model_name, "mesh.ply")
        cad_model = trimesh.load(model_path)

        # take first view
        Log.info("[Part 1/5] take first view data")
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_pts = ViewUtil.get_pts(view_data)
        first_cam_to_real_world = ControlUtil.get_pose()
        first_real_world_pts = PtsUtil.transform_point_cloud(first_cam_pts, first_cam_to_real_world)
        _, first_splitted_real_world_pts = self.split_scan_pts_and_obj_pts(first_real_world_pts)
        np.savetxt(f"first_real_pts_{model_name}.txt", first_splitted_real_world_pts)

        # register
        Log.info("[Part 1/5] do registeration")
        real_world_to_cad = PtsUtil.register(first_splitted_real_world_pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("[Part 1/5] finish init and register")

        real_world_to_blender_world = np.eye(4)
        real_world_to_blender_world[:3, 3] = np.asarray([0, 0, self._REAL_WORLD_TO_BLENDER_Z_OFFSET])

        # apply_transform mutates the mesh in place, so work on explicit
        # copies: the loaded mesh stays untouched and the two world-frame
        # meshes are distinct objects. The Blender-world mesh is still
        # (real_world_to_blender_world @ cad_to_real_world) applied to the CAD mesh.
        os.makedirs(temp_dir, exist_ok=True)
        cad_model_real_world:trimesh.Trimesh = cad_model.copy()
        cad_model_real_world.apply_transform(cad_to_real_world)
        cad_model_real_world.export(os.path.join(temp_dir, f"real_world_{temp_name}.obj"))
        cad_model_blender_world:trimesh.Trimesh = cad_model_real_world.copy()
        cad_model_blender_world.apply_transform(real_world_to_blender_world)
        cad_model_blender_world.export(os.path.join(temp_dir, f"{temp_name}.obj"))

        # sample view
        Log.info("[Part 2/5] start running renderer")
        render_result = subprocess.run([
            self.blender_bin_path, '-b', '-P', self.generator_script_path, '--', temp_dir
        ], capture_output=True, text=True)
        if render_result.returncode != 0:
            # Output is captured, so surface failures explicitly instead of
            # letting them show up later as a missing or stale file.
            Log.error(f"Blender exited with code {render_result.returncode}")
            Log.error(render_result.stderr)
        Log.success("[Part 2/5] finish running renderer")
        world_model_points = np.loadtxt(os.path.join(scene_dir, "points_and_normals.txt"))[:,:3]

        # preprocess
        Log.info("[Part 3/5] start preprocessing data")
        save_scene_data(temp_dir, temp_name)
        Log.success("[Part 3/5] finish preprocessing data")

        sample_view_pts_list, scan_points_idx_list, frame_idx_list = self._load_sampled_views(scene_dir)
        if not sample_view_pts_list:
            raise RuntimeError(f"No rendered frame with points found in {scene_dir}")

        # generate strategy
        Log.info("[Part 4/5] start generating strategy")
        limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
            world_model_points, sample_view_pts_list,
            scan_points_indices_list = scan_points_idx_list,
            init_view=0,
            threshold=self.voxel_size,
            soft_overlap_threshold = self.soft_overlap_threshold,
            hard_overlap_threshold = self.hard_overlap_threshold,
            scan_points_threshold = self.scan_points_threshold,
            status_info=self.status_info
        )
        Log.success("[Part 4/5] finish generating strategy")

        # extract cam_to_world sequence
        cam_to_world_seq = []
        render_pts = []
        for idx, _coverage_rate in limited_useful_view:
            # idx is a position in sample_view_pts_list; map it back to the
            # on-disk frame index, which differs once an empty frame was skipped.
            path = DataLoadUtil.get_path(temp_dir, temp_name, frame_idx_list[idx])
            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
            cam_to_world_seq.append(cam_info["cam_to_world_O"])
            render_pts.append(sample_view_pts_list[idx])

        Log.info("[Part 5/5] start running robot")
        # take best seq view
        target_scanned_pts = np.concatenate(sample_view_pts_list)
        voxel_downsampled_target_scanned_pts = PtsUtil.voxel_downsample_point_cloud(target_scanned_pts, self.voxel_size)
        gt_scanned_pts = np.concatenate(render_pts, axis=0)
        voxel_down_sampled_gt_scanned_pts = PtsUtil.voxel_downsample_point_cloud(gt_scanned_pts, self.voxel_size)
        result["gt_final_coverage_rate_cad"] = ReconstructionUtil.compute_coverage_rate(voxel_downsampled_target_scanned_pts, voxel_down_sampled_gt_scanned_pts, self.voxel_size)
        result["real_coverage_rate_seq"] = []

        # Rendered cloud of every view whose real shot succeeded, kept in the
        # same order as shot_pts_list so both can be saved pairwise.
        shot_render_pts_list = []
        for step, (cam_to_world, view_render_pts) in enumerate(zip(cam_to_world_seq, render_pts), start=1):
            try:
                ControlUtil.move_to(cam_to_world)
                # get world pts
                time.sleep(0.5)
                view_data = CommunicateUtil.get_view_data()
                if view_data is None:
                    Log.error("Failed to get view data")
                    continue
                cam_pts = ViewUtil.get_pts(view_data)
                shot_pts_list.append(cam_pts)
                shot_render_pts_list.append(view_render_pts)
                scanned_pts = np.concatenate(shot_pts_list, axis=0)
                voxel_down_sampled_scanned_pts = PtsUtil.voxel_downsample_point_cloud(scanned_pts, self.voxel_size)
                # NOTE(review): every shot is transformed with the pose of the
                # FIRST view, not the pose it was captured from - confirm that
                # this is intended before relying on the reported coverage.
                voxel_down_sampled_scanned_pts_world = PtsUtil.transform_point_cloud(voxel_down_sampled_scanned_pts, first_cam_to_real_world)
                curr_CR = ReconstructionUtil.compute_coverage_rate(voxel_downsampled_target_scanned_pts, voxel_down_sampled_scanned_pts_world, self.voxel_size)
                Log.success(f"(step {step}/{len(cam_to_world_seq)}) current coverage: {curr_CR} | gt coverage: {result['gt_final_coverage_rate_cad']}")
                result["real_final_coverage_rate"] = curr_CR
                result["real_coverage_rate_seq"].append(curr_CR)
            except Exception as e:
                # Best effort: a failing view must not abort the remaining sequence.
                Log.error(f"Failed to move to {cam_to_world}")
                Log.error(e)

        shot_pts_dir = os.path.join(scene_dir, "shot_pts")
        render_pts_dir = os.path.join(scene_dir, "render_pts")
        os.makedirs(shot_pts_dir, exist_ok=True)
        os.makedirs(render_pts_dir, exist_ok=True)
        for idx, (cam_pts, view_render_pts) in enumerate(zip(shot_pts_list, shot_render_pts_list)):
            # NOTE(review): same first-view pose as above - confirm.
            shot_pts = PtsUtil.transform_point_cloud(cam_pts, first_cam_to_real_world)
            np.savetxt(os.path.join(shot_pts_dir, f"{idx}.txt"), shot_pts)
            np.savetxt(os.path.join(render_pts_dir, f"{idx}.txt"), view_render_pts)
        Log.success("[Part 5/5] finish running robot")
        Log.debug(result)
        return result

    def _load_sampled_views(self, scene_dir):
        """Load the voxel-downsampled point cloud of every non-empty frame.

        Args:
            scene_dir: Scene folder containing ``pts/<i>.txt`` and
                ``scan_points_indices/<i>.npy``.

        Returns:
            Tuple ``(sample_view_pts_list, scan_points_idx_list, frame_idx_list)``
            of equal length; ``frame_idx_list[k]`` is the on-disk frame index
            of the k-th kept view.
        """
        pts_dir = os.path.join(scene_dir, "pts")
        idx_dir = os.path.join(scene_dir, "scan_points_indices")
        sample_view_pts_list = []
        scan_points_idx_list = []
        frame_idx_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            point_cloud = np.loadtxt(os.path.join(pts_dir, f"{frame_idx}.txt"))
            if point_cloud.shape[0] == 0:
                continue
            sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, self.voxel_size)
            # A file holding a single index loads as a 0-d array; promote it
            # so every entry supports len() and iteration.
            indices = np.atleast_1d(np.load(os.path.join(idx_dir, f"{frame_idx}.npy")))
            sample_view_pts_list.append(sampled_point_cloud)
            scan_points_idx_list.append(indices)
            frame_idx_list.append(frame_idx)
        return sample_view_pts_list, scan_points_idx_list, frame_idx_list

    def run(self):
        """Process every model directory, starting at ``model_start_idx``.

        The directory listing is sorted because ``os.listdir`` returns names
        in arbitrary order, which would make a start index meaningless.
        """
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(model_names[model_start_idx:], start=model_start_idx):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    # Manual check only: load a mesh from a fixed local path when this file
    # is executed directly. Nothing is done with the loaded mesh.
    model_path = r"C:\Users\hofee\Downloads\mesh.obj"
    model = trimesh.load(model_path)