nbv_rec_control/runners/cad_strategy.py
2024-10-09 21:46:13 +08:00

222 lines
9.7 KiB
Python

import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil
@stereotype.runner("CAD_strategy_runner")
class CADStrategyRunner(Runner):
    """Plan a view sequence from a CAD model and execute it on the robot.

    For each model the runner registers the first captured view against the
    CAD mesh, renders candidate views with Blender, computes a view sequence
    with ``ReconstructionUtil``, moves the robot through that sequence and
    stores the captured and the rendered point clouds side by side.
    """

    # Working directory shared with the Blender generator script.
    # NOTE(review): machine-specific absolute path kept from the original
    # code; consider moving it into the "generate" config section.
    TEMP_OUTPUT_DIR = "/home/yan20/nbv_rec/project/franka_control/temp_output"
    # Base name of the exported CAD model and of the scene folder derived from it.
    TEMP_SCENE_NAME = "cad_model_world"

    def __init__(self, config_path: str):
        """Load the experiment and cache the generate/reconstruct settings.

        Args:
            config_path: Path of the configuration file handed to ``Runner``.
        """
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "cad_strategy"
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")

        # Settings for the Blender based view generation.
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]

        # Thresholds forwarded to the view-sequence computation.
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        """Delegate experiment creation to the base ``Runner``."""
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate experiment loading to the base ``Runner``."""
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, scan_pts_z, z_threshold = 0.003):
        """Split points into two groups by comparing ``scan_pts_z`` to a threshold.

        Args:
            world_pts: Array indexed with a boolean mask built from
                ``scan_pts_z`` (presumably one row per point - TODO confirm).
            scan_pts_z: Values compared element-wise against ``z_threshold``.
            z_threshold: Entries strictly below it go to the first group.

        Returns:
            Tuple ``(scan_pts, obj_pts)``: entries of ``world_pts`` whose
            ``scan_pts_z`` is below, respectively at or above, the threshold.
        """
        scan_pts = world_pts[scan_pts_z < z_threshold]
        obj_pts = world_pts[scan_pts_z >= z_threshold]
        return scan_pts, obj_pts

    def run_one_model(self, model_name):
        """Register, plan and capture the view sequence for a single model.

        Args:
            model_name: Sub-directory of ``self.model_dir`` containing ``mesh.ply``.
        """
        temp_dir = self.TEMP_OUTPUT_DIR
        temp_name = self.TEMP_SCENE_NAME
        scene_dir = os.path.join(temp_dir, temp_name)

        ControlUtil.connect_robot()

        # ---- Part 1: init robot, load the CAD model and register the first view ----
        Log.info("[Part 1/5] start init and register")
        ControlUtil.init()

        model_path = os.path.join(self.model_dir, model_name, "mesh.ply")
        cad_model = trimesh.load(model_path)

        Log.info("[Part 1/5] take first view data")
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_pts = ViewUtil.get_pts(view_data)

        Log.info("[Part 1/5] do registeration")
        cam_to_cad = PtsUtil.register(first_cam_pts, cad_model)
        first_cam_to_world = ControlUtil.get_pose()
        cad_to_world = first_cam_to_world @ np.linalg.inv(cam_to_cad)
        Log.success("[Part 1/5] finish init and register")

        # Translate along z into the Blender world frame.
        # NOTE(review): 0.9215 is presumably the table height - TODO confirm.
        world_to_blender_world = np.eye(4)
        world_to_blender_world[:3, 3] = np.asarray([0, 0, 0.9215])
        cad_to_blender_world = world_to_blender_world @ cad_to_world

        # The export fails if the working directory is missing, so create it.
        os.makedirs(temp_dir, exist_ok=True)
        cad_model = cad_model.apply_transform(cad_to_blender_world)
        cad_model.export(os.path.join(temp_dir, f"{temp_name}.obj"))

        # ---- Part 2: render candidate views with Blender ----
        Log.info("[Part 2/5] start running renderer")
        result = subprocess.run([
            self.blender_bin_path, '-b', '-P', self.generator_script_path, '--', temp_dir
        ], capture_output=True, text=True)
        if result.returncode != 0:
            # Surface the renderer failure instead of only failing later on
            # a missing output file.
            Log.error(f"Blender exited with code {result.returncode}: {result.stderr}")
        else:
            Log.success("[Part 2/5] finish running renderer")
        world_model_points = np.loadtxt(os.path.join(scene_dir, "points_and_normals.txt"))[:,:3]

        # ---- Part 3: preprocess the rendered data and load per-frame points ----
        Log.info("[Part 3/5] start preprocessing data")
        save_scene_data(temp_dir, temp_name)
        Log.success("[Part 3/5] finish preprocessing data")

        pts_dir = os.path.join(scene_dir, "pts")
        idx_dir = os.path.join(scene_dir, "scan_points_indices")
        sample_view_pts_list = []
        scan_points_idx_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            point_cloud = np.loadtxt(os.path.join(pts_dir, f"{frame_idx}.txt"))
            if point_cloud.shape[0] != 0:
                sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, self.voxel_size)
            else:
                # Keep one entry per frame so list positions stay equal to
                # frame indices; never reuse the previous frame's points.
                # Assumes xyz points (3 columns) - TODO confirm.
                sampled_point_cloud = np.empty((0, 3))
            # loadtxt yields a 0-d array for a single value; promote it to 1-d.
            indices = np.atleast_1d(
                np.loadtxt(os.path.join(idx_dir, f"{frame_idx}.txt"), dtype=np.int32)
            )
            sample_view_pts_list.append(sampled_point_cloud)
            scan_points_idx_list.append(indices)

        # ---- Part 4: compute the view sequence ----
        Log.info("[Part 4/5] start generating strategy")
        limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
            world_model_points, sample_view_pts_list,
            scan_points_indices_list = scan_points_idx_list,
            init_view=0,
            threshold=self.voxel_size,
            soft_overlap_threshold = self.soft_overlap_threshold,
            hard_overlap_threshold = self.hard_overlap_threshold,
            scan_points_threshold = self.scan_points_threshold,
            status_info=self.status_info
        )
        Log.success("[Part 4/5] finish generating strategy")

        # Collect the camera pose and rendered points of every selected view.
        cam_to_world_seq = []
        render_pts = []
        idx_seq = []
        for idx, _coverage_rate in limited_useful_view:
            path = DataLoadUtil.get_path(temp_dir, temp_name, idx)
            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
            cam_to_world_seq.append(cam_info["cam_to_world"])
            idx_seq.append(idx)
            render_pts.append(sample_view_pts_list[idx])

        # ---- Part 5: drive the robot through the sequence and capture ----
        Log.info("[Part 5/5] start running robot")
        # Each capture is stored with the rendered cloud of the same view so
        # that a failed capture cannot shift the pairing of later views.
        captured = []
        for cam_to_world, rendered in zip(cam_to_world_seq, render_pts):
            ControlUtil.move_to(cam_to_world)
            time.sleep(1)  # pause before capturing (presumably to let the arm settle)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("Failed to get view data")
                continue
            captured.append((ViewUtil.get_pts(view_data), rendered))

        print(idx_seq)
        shot_dir = os.path.join(scene_dir, "shot_pts")
        render_dir = os.path.join(scene_dir, "render_pts")
        os.makedirs(shot_dir, exist_ok=True)
        os.makedirs(render_dir, exist_ok=True)
        for idx, (cam_pts, rendered) in enumerate(captured):
            # NOTE(review): every shot is transformed with the FIRST view's
            # pose, as in the original code - confirm this is intended rather
            # than the pose at which the shot was taken.
            shot_pts = PtsUtil.transform_point_cloud(cam_pts, first_cam_to_world)
            np.savetxt(os.path.join(shot_dir, f"{idx}.txt"), shot_pts)
            np.savetxt(os.path.join(render_dir, f"{idx}.txt"), rendered)
        Log.success("[Part 5/5] finish running robot")

    def run(self):
        """Process every model in ``self.model_dir`` from ``model_start_idx`` on."""
        # os.listdir returns entries in arbitrary order; sort so the start
        # index refers to the same model on every run.
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        # Slice the listing (not the path string) and keep the counter moving.
        for count_object, model_name in enumerate(
            model_names[model_start_idx:], start=model_start_idx
        ):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    # Manual check: load a mesh from a developer-local path.
    model = trimesh.load(r"C:\Users\hofee\Downloads\mesh.obj")

    # test register (disabled experiment, kept for reference):
    # perturb a captured cloud with a random rigid transform, register it
    # against the model and dump both clouds for visual inspection.
    # test_pts_L = np.load(r"C:\Users\hofee\Downloads\0.npy")
    # import open3d as o3d
    # def add_noise(points, translation, rotation):
    #     R = o3d.geometry.get_rotation_matrix_from_axis_angle(rotation)
    #     noisy_points = points @ R.T + translation
    #     return noisy_points
    # translation_noise = np.random.uniform(-0.5, 0.5, size=3)
    # rotation_noise = np.random.uniform(-np.pi/4, np.pi/4, size=3)
    # noisy_pts_L = add_noise(test_pts_L, translation_noise, rotation_noise)
    # cad_to_cam_L = PtsUtil.register(noisy_pts_L, model)
    # cad_pts_L = PtsUtil.transform_point_cloud(noisy_pts_L, cad_to_cam_L)
    # np.savetxt(r"test.txt", cad_pts_L)
    # np.savetxt(r"src.txt", noisy_pts_L)