# nbv_rec_control/runners/cad_strategy.py
# 2024-10-08 21:28:30 +08:00
#
# 193 lines
# 8.3 KiB
# Python

import os
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data
from utils.data_load import DataLoadUtil
@stereotype.runner("CAD_strategy_runner")
class CADStrategyRunner(Runner):
    """Plan a next-best-view camera sequence for each CAD model in ``model_dir``.

    For every model the runner places ``mesh.obj`` in the Blender world frame,
    loads previously rendered per-view point clouds from disk, and asks
    ``ReconstructionUtil.compute_next_best_view_sequence_with_overlap`` for a
    view sequence. The camera poses of the selected views are then read back.

    NOTE(review): robot control, first-view capture, CAD registration and the
    Blender rendering step are currently disabled (commented out). Views are
    read from ``DEBUG_OUTPUT_DIR`` instead of being rendered on the fly.
    """

    # Directory holding previously rendered scene data. While the Blender
    # rendering step is disabled, views are read from here rather than from
    # the temporary directory the transformed model is exported to.
    DEBUG_OUTPUT_DIR = "/home/user/nbv_rec/nbv_rec_control/test_output"
    # Blender data-generation script (used only by the disabled render call).
    BLENDER_SCRIPT_PATH = "/home/user/nbv_rec/blender_app/data_generator.py"

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        # Forwarded as ``status_info`` to the NBV sequence computation.
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "cad",
            "runner_name": "cad_strategy",
        }
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.model_dir = self.generate_config["model_dir"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_view = self.generate_config["max_view"]
        self.min_view = self.generate_config["min_view"]
        self.max_diag = self.generate_config["max_diag"]
        self.min_diag = self.generate_config["min_diag"]
        self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
        self.random_view_ratio = self.generate_config["random_view_ratio"]
        self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
        self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)

    def get_pts_from_view_data(self, view_data):
        """Back-project a view's depth image into a point cloud.

        Args:
            view_data: mapping with ``depth_image``, ``depth_intrinsics`` and
                ``depth_extrinsics`` entries.

        Returns:
            Whatever ``PtsUtil.get_pts_from_depth`` produces for that view.
        """
        depth = view_data["depth_image"]
        depth_intrinsics = view_data["depth_intrinsics"]
        depth_extrinsics = view_data["depth_extrinsics"]
        cam_pts = PtsUtil.get_pts_from_depth(depth, depth_intrinsics, depth_extrinsics)
        return cam_pts

    def split_scan_pts_and_obj_pts(self, world_pts, scan_pts_z, z_threshold=0.003):
        """Split points by height: ``scan_pts_z < z_threshold`` vs. the rest.

        Args:
            world_pts: point array indexed row-wise by a boolean mask.
            scan_pts_z: per-point values compared against ``z_threshold``;
                must have one entry per row of ``world_pts``.
            z_threshold: split value.

        Returns:
            ``(scan_pts, obj_pts)``: rows below the threshold, rows at/above it.
        """
        scan_pts = world_pts[scan_pts_z < z_threshold]
        obj_pts = world_pts[scan_pts_z >= z_threshold]
        return scan_pts, obj_pts

    def _load_sample_views(self, scene_dir):
        """Load and voxel-downsample every rendered view under ``scene_dir``.

        Expects ``pts/<i>.txt`` and ``scan_points_indices/<i>.txt`` for every
        ``i`` in ``0 .. (number of entries in pts/) - 1``.

        Returns:
            ``(sample_view_pts_list, scan_points_idx_list)``, one entry per view.
        """
        pts_dir = os.path.join(scene_dir, "pts")
        idx_dir = os.path.join(scene_dir, "scan_points_indices")
        sample_view_pts_list = []
        scan_points_idx_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            point_cloud = np.loadtxt(os.path.join(pts_dir, f"{frame_idx}.txt"))
            sample_view_pts_list.append(
                PtsUtil.voxel_downsample_point_cloud(point_cloud, self.voxel_size)
            )
            # np.loadtxt returns a 0-d array when the file holds a single
            # value; promote it to 1-d so every entry supports len()/iteration.
            indices = np.atleast_1d(
                np.loadtxt(os.path.join(idx_dir, f"{frame_idx}.txt"), dtype=np.int32)
            )
            scan_points_idx_list.append(indices)
        return sample_view_pts_list, scan_points_idx_list

    def run_one_model(self, model_name):
        """Compute a next-best-view camera sequence for one CAD model.

        Args:
            model_name: sub-directory of ``model_dir`` containing ``mesh.obj``.

        Returns:
            ``(cam_to_world_seq, coverage_rate_seq)``: the ``cam_to_world``
            pose of each selected view, in selection order, and the coverage
            rate reported alongside each view.
        """
        # init robot (disabled)
        # ControlUtil.init()

        # load CAD model
        model_path = os.path.join(self.model_dir, model_name, "mesh.obj")
        cad_model = trimesh.load(model_path)

        # take first view + register (disabled)
        # view_data = CommunicateUtil.get_view_data(init=True)
        # first_cam_pts = self.get_pts_from_view_data(view_data)
        # cad_to_cam = PtsUtil.register(first_cam_pts, cad_model)
        # cam_to_world = ControlUtil.get_pose()
        cad_to_world = np.eye(4)  # will be cam_to_world @ cad_to_cam once registration is enabled

        # Pure +z translation of 0.9215 from the robot world to the Blender world.
        # NOTE(review): presumably the table-top height in the Blender scene — confirm.
        world_to_blender_world = np.eye(4)
        world_to_blender_world[:3, 3] = np.asarray([0, 0, 0.9215])
        cad_to_blender_world = world_to_blender_world @ cad_to_world
        cad_model: trimesh.Trimesh = cad_model.apply_transform(cad_to_blender_world)

        name = "cad_model_world"
        with tempfile.TemporaryDirectory() as temp_dir:
            cad_model.export(os.path.join(temp_dir, f"{name}.obj"))
            # Rendering is disabled, so read pre-rendered data from a fixed
            # directory instead of the temp dir the model was exported to.
            output_dir = self.DEBUG_OUTPUT_DIR
            scene_dir = os.path.join(output_dir, name)

            # sample view (disabled)
            # print("start running renderer")
            # result = subprocess.run([
            #     'blender', '-b', '-P', self.BLENDER_SCRIPT_PATH, '--', output_dir
            # ], capture_output=True, text=True)
            # print("finish running renderer")

            world_model_points = np.loadtxt(os.path.join(scene_dir, "points_and_normals.txt"))[:, :3]

            # preprocess (disabled)
            # save_scene_data(output_dir, name)
            sample_view_pts_list, scan_points_idx_list = self._load_sample_views(scene_dir)

            # generate strategy
            limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
                world_model_points, sample_view_pts_list,
                scan_points_indices_list=scan_points_idx_list,
                init_view=0,
                threshold=self.voxel_size,
                soft_overlap_threshold=self.soft_overlap_threshold,
                hard_overlap_threshold=self.hard_overlap_threshold,
                scan_points_threshold=self.scan_points_threshold,
                status_info=self.status_info,
            )

            # extract cam_to_world sequence
            cam_to_world_seq = []
            coverage_rate_seq = []
            for idx, coverage_rate in limited_useful_view:
                path = DataLoadUtil.get_path(output_dir, name, idx)
                cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
                cam_to_world_seq.append(cam_info["cam_to_world"])
                coverage_rate_seq.append(coverage_rate)

        # take best seq view (disabled)
        # for cam_to_world in cam_to_world_seq:
        #     ControlUtil.move_to(cam_to_world)
        #     ''' get world pts '''
        #     view_data = CommunicateUtil.get_view_data()
        #     cam_pts = self.get_pts_from_view_data(view_data)
        #     scan_points_idx = None
        #     world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_world)
        #     sample_view_pts_list.append(world_pts)
        #     scan_points_idx_list.append(scan_points_idx)
        return cam_to_world_seq, coverage_rate_seq

    def run(self):
        """Process every model in ``model_dir``, starting at ``model_start_idx``.

        The directory listing is sorted so that ``model_start_idx`` refers to a
        stable position across runs (``os.listdir`` order is arbitrary).
        """
        model_names = sorted(os.listdir(self.model_dir))
        total = len(model_names)
        model_start_idx = self.generate_config["model_start_idx"]
        for count_object, model_name in enumerate(
            model_names[model_start_idx:], start=model_start_idx
        ):
            Log.info(f"[{count_object}/{total}]Processing {model_name}")
            self.run_one_model(model_name)
            Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    import sys

    # Ad-hoc manual test. Usage: python cad_strategy.py [path/to/mesh.obj]
    # Falls back to the original developer-local path when no argument is given.
    model_path = sys.argv[1] if len(sys.argv) > 1 else r"C:\Users\hofee\Downloads\mesh.obj"
    model = trimesh.load(model_path)

    # test register (disabled): perturb a captured view with a random rigid
    # transform, register it back to the CAD model, and dump both clouds.
    # test_pts_L = np.load(r"C:\Users\hofee\Downloads\0.npy")
    # import open3d as o3d
    # def add_noise(points, translation, rotation):
    #     R = o3d.geometry.get_rotation_matrix_from_axis_angle(rotation)
    #     noisy_points = points @ R.T + translation
    #     return noisy_points
    # translation_noise = np.random.uniform(-0.5, 0.5, size=3)
    # rotation_noise = np.random.uniform(-np.pi/4, np.pi/4, size=3)
    # noisy_pts_L = add_noise(test_pts_L, translation_noise, rotation_noise)
    # cad_to_cam_L = PtsUtil.register(noisy_pts_L, model)
    # cad_pts_L = PtsUtil.transform_point_cloud(noisy_pts_L, cad_to_cam_L)
    # np.savetxt(r"test.txt", cad_pts_L)
    # np.savetxt(r"src.txt", noisy_pts_L)