nbv_rec_control/runners/cad_close_loop_strategy.py
2024-10-18 17:13:45 +08:00

245 lines
11 KiB
Python

import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data, save_scene_data_multithread
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil
@stereotype.runner("CAD_close_loop_strategy_runner")
class CADCloseLoopStrategyRunner(Runner):
    """Close-loop next-best-view (NBV) scanning runner guided by a CAD model.

    For each model directory under ``model_dir`` the runner:

    1. takes a first real view and registers the model's ``mesh.ply`` to it,
    2. renders candidate views of the registered mesh with Blender and
       preprocesses them (``save_scene_data``),
    3. loads the rendered candidate point clouds,
    4. repeatedly moves the robot to the next best candidate view and merges
       the real shot into the scanned point set, until the coverage gain drops
       below ``min_coverage_increase`` or ``max_shot_view_num`` shots are taken.
    """

    # Working directory shared with the Blender generator script. Hard-coded
    # for the lab machine; override on the class or an instance if needed.
    TEMP_OUTPUT_DIR = "/home/yan20/nbv_rec/project/franka_control/temp_output"
    # Scene name (sub-directory / mesh file stem) used for rendered candidates.
    TEMP_SCENE_NAME = "cad_model_world"
    # z translation from the real-world frame to the Blender scene frame.
    # NOTE(review): presumably the table height in metres — confirm.
    REAL_TO_BLENDER_Z_OFFSET = 0.9215
    # Value passed as ``overlap_area_threshold`` to the NBV selector.
    OVERLAP_AREA_THRESHOLD = 25
    # Seconds to wait after a robot move before grabbing a view.
    SETTLE_TIME_S = 0.5

    def __init__(self, config_path: str):
        """Load the ``cad_strategy`` experiment and cache generate/reconstruct settings.

        Args:
            config_path: path to the PytorchBoot configuration file.
        """
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "CAD_close_loop_strategy_runner",
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config[
            "min_cam_table_included_degree"
        ]
        self.max_shot_view_num = self.generate_config["max_shot_view_num"]
        self.min_shot_new_pts_num = self.generate_config["min_shot_new_pts_num"]
        self.min_coverage_increase = self.generate_config["min_coverage_increase"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        """Delegate to ``Runner.create_experiment``."""
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate to ``Runner.load_experiment``."""
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0):
        """Split world points by height into (below, at-or-above) ``z_threshold``.

        Args:
            world_pts: array indexed as ``[:, 2]`` for z (assumed (N, 3+)).
            z_threshold: split height; points with ``z < z_threshold`` go to the
                first output, points with ``z >= z_threshold`` to the second.

        Returns:
            ``(scan_pts, obj_pts)``. Points whose z compares false both ways
            (e.g. NaN) appear in neither.
        """
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    def run_one_model(self, model_name):
        """Run the full register → render → close-loop pipeline for one model.

        Args:
            model_name: sub-directory of ``model_dir`` containing ``mesh.ply``.
        """
        temp_dir = self.TEMP_OUTPUT_DIR
        temp_name = self.TEMP_SCENE_NAME
        os.makedirs(temp_dir, exist_ok=True)

        ControlUtil.connect_robot()
        first_obj_pts = self._register_first_view(model_name, temp_dir, temp_name)
        self._render_candidate_views(temp_dir, temp_name)
        (
            sample_view_pts_list,
            scan_points_idx_list,
            view_frame_idx_list,
        ) = self._load_candidate_views(temp_dir, temp_name)
        if not sample_view_pts_list:
            Log.error(f"No non-empty candidate views rendered for {model_name}")
            return
        self._close_loop(
            temp_dir,
            temp_name,
            first_obj_pts,
            sample_view_pts_list,
            scan_points_idx_list,
            view_frame_idx_list,
        )

    def _register_first_view(self, model_name, temp_dir, temp_name):
        """Init the robot, take the first real view, register the CAD mesh to it.

        Exports the registered mesh in the real-world frame as
        ``real_world_<temp_name>.obj`` and in the Blender frame as
        ``<temp_name>.obj`` (the latter is what the renderer consumes).

        Returns:
            The first view's object points (z >= 0) in the real-world frame.
        """
        Log.info("[Part 1/4] start init and register")
        ControlUtil.init()
        model_path = os.path.join(self.model_dir, model_name, "mesh.ply")
        cad_model = trimesh.load(model_path)

        Log.info("[Part 1/4] take first view data")
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_pts = ViewUtil.get_pts(view_data)
        first_cam_to_real_world = ControlUtil.get_pose()
        first_real_world_pts = PtsUtil.transform_point_cloud(
            first_cam_pts, first_cam_to_real_world
        )
        _, first_obj_pts = self.split_scan_pts_and_obj_pts(first_real_world_pts)
        # Debug dump of the first view (written to the current working directory).
        np.savetxt(f"first_real_pts_{model_name}.txt", first_obj_pts)

        Log.info("[Part 1/4] do registeration")
        real_world_to_cad = PtsUtil.register(first_obj_pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("[Part 1/4] finish init and register")

        real_world_to_blender_world = np.eye(4)
        real_world_to_blender_world[:3, 3] = np.asarray(
            [0, 0, self.REAL_TO_BLENDER_Z_OFFSET]
        )
        # apply_transform mutates in place; work on copies so each variable
        # really holds the mesh in the frame its name says.
        cad_model_real_world: trimesh.Trimesh = cad_model.copy()
        cad_model_real_world.apply_transform(cad_to_real_world)
        cad_model_real_world.export(
            os.path.join(temp_dir, f"real_world_{temp_name}.obj")
        )
        cad_model_blender_world: trimesh.Trimesh = cad_model_real_world.copy()
        cad_model_blender_world.apply_transform(real_world_to_blender_world)
        cad_model_blender_world.export(os.path.join(temp_dir, f"{temp_name}.obj"))
        return first_obj_pts

    def _render_candidate_views(self, temp_dir, temp_name):
        """Render candidate views with Blender, then preprocess them.

        Raises:
            RuntimeError: if the Blender process exits with a non-zero code.
        """
        Log.info("[Part 2/4] start running renderer")
        result = subprocess.run(
            [
                self.blender_bin_path,
                "-b",
                "-P",
                self.generator_script_path,
                "--",
                temp_dir,
            ],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            # Output is captured, so surface it explicitly instead of failing
            # later with a confusing missing-file error.
            Log.error(f"Renderer failed (exit code {result.returncode}): {result.stderr}")
            raise RuntimeError(
                f"Blender renderer exited with code {result.returncode}"
            )
        Log.success("[Part 2/4] finish running renderer")

        Log.info("[Part 3/4] start preprocessing data")
        save_scene_data(temp_dir, temp_name)
        Log.success("[Part 3/4] finish preprocessing data")

    def _load_candidate_views(self, temp_dir, temp_name):
        """Load rendered candidate views, skipping those with no points.

        Frames are expected as ``pts/<i>.txt`` and
        ``scan_points_indices/<i>.npy`` for ``i`` in ``0..len(pts)-1``.

        Returns:
            ``(sample_view_pts_list, scan_points_idx_list, view_frame_idx_list)``
            where entry ``k`` of each list describes the same kept view and
            ``view_frame_idx_list[k]`` is its original frame index.
        """
        scene_dir = os.path.join(temp_dir, temp_name)
        pts_dir = os.path.join(scene_dir, "pts")
        idx_dir = os.path.join(scene_dir, "scan_points_indices")
        sample_view_pts_list = []
        scan_points_idx_list = []
        view_frame_idx_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            point_cloud = np.loadtxt(os.path.join(pts_dir, f"{frame_idx}.txt"))
            if point_cloud.shape[0] == 0:
                continue
            sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(
                point_cloud, self.voxel_size
            )
            # A single stored index loads as a 0-d array; make it 1-d.
            indices = np.atleast_1d(np.load(os.path.join(idx_dir, f"{frame_idx}.npy")))
            sample_view_pts_list.append(sampled_point_cloud)
            scan_points_idx_list.append(indices)
            view_frame_idx_list.append(frame_idx)
        return sample_view_pts_list, scan_points_idx_list, view_frame_idx_list

    def _close_loop(
        self,
        temp_dir,
        temp_name,
        first_obj_pts,
        sample_view_pts_list,
        scan_points_idx_list,
        view_frame_idx_list,
    ):
        """Greedily visit next-best views until a stop criterion fires.

        Stops when the NBV coverage gain is below ``min_coverage_increase`` or
        when ``max_shot_view_num`` shots (first view included) have been taken.
        Each shot's real points and the matching rendered points are dumped
        to ``<temp_dir>/debug``.
        """
        scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            first_obj_pts, self.voxel_size
        )
        shot_pts_list = [first_obj_pts]
        history_indices = []
        last_coverage = 0
        debug_dir = os.path.join(temp_dir, "debug")
        os.makedirs(debug_dir, exist_ok=True)

        Log.info("[Part 4/4] start close-loop control")
        cnt = 0
        while True:
            next_best_view, next_best_coverage, next_best_covered_num = (
                ReconstructionUtil.compute_next_best_view_with_overlap(
                    scanned_pts,
                    sample_view_pts_list,
                    history_indices,
                    scan_points_idx_list,
                    threshold=self.voxel_size,
                    overlap_area_threshold=self.OVERLAP_AREA_THRESHOLD,
                    scan_points_threshold=self.scan_points_threshold,
                )
            )
            # next_best_view indexes the kept (non-empty) candidates; map it back
            # to the rendered frame index before loading that frame's camera.
            nbv_frame_idx = view_frame_idx_list[next_best_view]
            nbv_path = DataLoadUtil.get_path(temp_dir, temp_name, nbv_frame_idx)
            nbv_cam_info = DataLoadUtil.load_cam_info(nbv_path, binocular=True)
            nbv_cam_to_world = nbv_cam_info["cam_to_world_O"]
            # NOTE(review): this pose comes from the rendered (Blender-frame) data;
            # confirm ControlUtil.move_to expects that frame and not real world.
            ControlUtil.move_to(nbv_cam_to_world)

            time.sleep(self.SETTLE_TIME_S)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                continue
            cam_shot_pts = ViewUtil.get_pts(view_data)
            # Transform with the pose the camera is at now, not the first view's.
            cam_to_real_world = ControlUtil.get_pose()
            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
                world_shot_pts
            )
            shot_pts_list.append(world_splitted_shot_pts)
            np.savetxt(
                os.path.join(debug_dir, f"shot_pts_{cnt}.txt"), world_splitted_shot_pts
            )
            np.savetxt(
                os.path.join(debug_dir, f"render_pts_{cnt}.txt"),
                sample_view_pts_list[next_best_view],
            )

            last_scanned_pts_num = scanned_pts.shape[0]
            scanned_pts = PtsUtil.voxel_downsample_point_cloud(
                np.vstack([scanned_pts, world_splitted_shot_pts]), self.voxel_size
            )
            new_added_pts_num = scanned_pts.shape[0] - last_scanned_pts_num
            history_indices.append(scan_points_idx_list[next_best_view])
            Log.info(
                f"Next Best cover pts: {next_best_covered_num}, Best coverage: {next_best_coverage}"
            )

            coverage_rate_increase = next_best_coverage - last_coverage
            if coverage_rate_increase < self.min_coverage_increase:
                Log.info(f"Coverage rate = {coverage_rate_increase} < {self.min_coverage_increase}, stop scanning")
                break
            last_coverage = next_best_coverage

            if new_added_pts_num < self.min_shot_new_pts_num:
                # Informational only: a low-yield shot does not stop the scan.
                Log.info(f"New added pts num = {new_added_pts_num} < {self.min_shot_new_pts_num}")

            if len(shot_pts_list) >= self.max_shot_view_num:
                Log.info(f"Scanned view num = {len(shot_pts_list)} >= {self.max_shot_view_num}, stop scanning")
                break
            cnt += 1
        Log.success("[Part 4/4] finish close-loop control")

    def run(self):
        """Process every model under ``model_dir``, starting at ``model_start_idx``.

        Model directories are sorted by name so the start index is stable
        across runs.
        """
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(
            model_names[model_start_idx:], start=model_start_idx
        ):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    # Manual smoke test: load a mesh with trimesh. The mesh path may be given as
    # the first command-line argument; otherwise the original hard-coded path
    # is used.
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else r"C:\Users\hofee\Downloads\mesh.obj"
    model = trimesh.load(model_path)