import random
from collections import defaultdict

import numpy as np
import trimesh
from scipy.spatial.transform import Rotation as R

from utils.pose_util import PoseUtil


class ViewSampleUtil:
    @staticmethod
    def farthest_point_sampling(points, num_samples):
        """Greedy farthest-point sampling; returns the sampled points and their indices."""
        num_points = points.shape[0]
        if num_samples >= num_points:
            return points, np.arange(num_points)
        sampled_indices = np.zeros(num_samples, dtype=int)
        sampled_indices[0] = np.random.randint(num_points)
        min_distances = np.full(num_points, np.inf)
        for i in range(1, num_samples):
            # Track each point's distance to the nearest sample chosen so far,
            # then pick the point farthest from the current sample set.
            current_point = points[sampled_indices[i - 1]]
            dist_to_current_point = np.linalg.norm(points - current_point, axis=1)
            min_distances = np.minimum(min_distances, dist_to_current_point)
            sampled_indices[i] = np.argmax(min_distances)
        return points[sampled_indices], sampled_indices

    @staticmethod
    def voxel_downsample(points, voxel_size):
        """Keep one representative point per occupied voxel; returns points and their indices."""
        voxel_grid = defaultdict(list)
        for i, point in enumerate(points):
            voxel_index = tuple((point // voxel_size).astype(int))
            voxel_grid[voxel_index].append(i)
        downsampled_points = []
        downsampled_indices = []
        for indices in voxel_grid.values():
            # The first point that fell into the voxel represents it.
            selected_index = indices[0]
            downsampled_points.append(points[selected_index])
            downsampled_indices.append(selected_index)
        return np.array(downsampled_points), np.array(downsampled_indices)

    @staticmethod
    def sample_view_data(
        mesh: trimesh.Trimesh,
        distance_range: tuple = (0.25, 0.5),
        voxel_size: float = 0.005,
        max_views: int = 1,
        perturb_repeat: int = 1,
    ) -> dict:
        """Sample candidate camera positions around a mesh in its object (CAD) frame.

        For every vertex with a valid normal, place `perturb_repeat` cameras along
        the (randomly perturbed) outward normal at a random distance, then thin the
        candidates with voxel downsampling followed by farthest-point sampling.
        """
        look_at_points = []
        cam_positions = []
        normals = []
        vertex_normals = mesh.vertex_normals
        for i, vertex in enumerate(mesh.vertices):
            look_at_point = vertex
            normal = vertex_normals[i]
            if np.isnan(normal).any():
                continue
            if np.dot(normal, look_at_point) < 0:
                # Flip inward-facing normals so cameras are placed outside the mesh.
                normal = -normal
            for _ in range(perturb_repeat):
                # Rotate the normal by a random angle (up to 30 degrees) about a
                # random axis, then push the camera out along the perturbed normal.
                perturb_angle = np.radians(np.random.uniform(0, 30))
                perturb_axis = np.random.normal(size=3)
                perturb_axis /= np.linalg.norm(perturb_axis)
                rotation_matrix = R.from_rotvec(perturb_angle * perturb_axis).as_matrix()
                perturbed_normal = rotation_matrix @ normal
                distance = np.random.uniform(*distance_range)
                cam_position = look_at_point + distance * perturbed_normal
                look_at_points.append(look_at_point)
                cam_positions.append(cam_position)
                # Append once per sample so normals stay index-aligned with
                # look_at_points/cam_positions even when perturb_repeat > 1.
                normals.append(normal)
        look_at_points = np.array(look_at_points)
        cam_positions = np.array(cam_positions)

        voxel_points, voxel_indices = ViewSampleUtil.voxel_downsample(look_at_points, voxel_size)
        voxel_cam_positions = cam_positions[voxel_indices]
        voxel_normals = np.array(normals)[voxel_indices]
        # Oversample by 2x here; get_cam_pose later filters and truncates to max_views.
        fps_points, fps_indices = ViewSampleUtil.farthest_point_sampling(voxel_points, max_views * 2)
        fps_cam_positions = voxel_cam_positions[fps_indices]

        return {
            "look_at_points": fps_points.tolist(),
            "cam_positions": fps_cam_positions.tolist(),
            # Normals and surface points are kept at voxel resolution (not FPS)
            # so get_world_points_and_normals can zip them index-for-index.
            "normals": voxel_normals.tolist(),
            "voxel_down_sampled_points": voxel_points,
        }

    @staticmethod
    def get_world_points_and_normals(view_data: dict, obj_world_pose: np.ndarray) -> tuple:
        """Transform the voxel-downsampled surface points and normals into the world frame."""
        world_points = []
        world_normals = []
        for point, normal in zip(view_data["voxel_down_sampled_points"], view_data["normals"]):
            point_world = obj_world_pose @ np.append(point, 1.0)
            normal_world = obj_world_pose[:3, :3] @ normal
            world_points.append(point_world[:3])
            world_normals.append(normal_world)
        return np.array(world_points), np.array(world_normals)

    @staticmethod
    def get_cam_pose(
        view_data: dict,
        obj_world_pose: np.ndarray,
        max_views: int,
        min_cam_table_included_degree: int,
        random_view_ratio: float,
    ) -> np.ndarray:
        """Build 4x4 cam-to-world poses, reject views that graze the table, and cap at max_views."""
        cam_poses = []
        min_height_z = np.inf  # lowest look-at height, used as the table-height proxy
        for look_at_point, cam_position in zip(view_data["look_at_points"], view_data["cam_positions"]):
            look_at_point_world = obj_world_pose @ np.append(look_at_point, 1.0)
            cam_position_world = obj_world_pose @ np.append(cam_position, 1.0)
            min_height_z = min(min_height_z, look_at_point_world[2])
            look_at_point_world = look_at_point_world[:3]
            cam_position_world = cam_position_world[:3]
            # Look-at construction: +Z points from the target to the camera
            # (OpenGL-style convention), with world +Z as the up reference.
            forward_vector = cam_position_world - look_at_point_world
            forward_vector /= np.linalg.norm(forward_vector)
            up_vector = np.array([0.0, 0.0, 1.0])
            right_vector = np.cross(up_vector, forward_vector)
            right_norm = np.linalg.norm(right_vector)
            if right_norm < 1e-8:
                # Degenerate case: viewing axis (anti)parallel to world up;
                # fall back to an arbitrary right axis to avoid division by zero.
                right_vector = np.array([1.0, 0.0, 0.0])
            else:
                right_vector /= right_norm
            corrected_up_vector = np.cross(forward_vector, right_vector)
            rotation_matrix = np.array([right_vector, corrected_up_vector, forward_vector]).T
            cam_pose = np.eye(4)
            cam_pose[:3, :3] = rotation_matrix
            cam_pose[:3, 3] = cam_position_world
            cam_poses.append(cam_pose)

        filtered_cam_poses = []
        for cam_pose in cam_poses:
            if cam_pose[2, 3] > min_height_z:
                # Keep the view only if the camera's viewing axis is inclined at
                # least min_cam_table_included_degree above the table plane.
                direction_vector = cam_pose[:3, 2]
                horizontal_normal = np.array([0.0, 0.0, 1.0])
                cos_angle = np.dot(direction_vector, horizontal_normal) / (
                    np.linalg.norm(direction_vector) * np.linalg.norm(horizontal_normal)
                )
                angle_degree = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
                if angle_degree < 90 - min_cam_table_included_degree:
                    filtered_cam_poses.append(cam_pose)
                    if random.random() < random_view_ratio:
                        # Occasionally add a randomly perturbed copy of the pose.
                        perturb_pose = PoseUtil.get_uniform_pose([0.1, 0.1, 0.1], [3, 3, 3], 0, 180, "cm")
                        filtered_cam_poses.append(perturb_pose @ cam_pose)

        if len(filtered_cam_poses) > max_views:
            indices = np.random.choice(len(filtered_cam_poses), max_views, replace=False)
            filtered_cam_poses = [filtered_cam_poses[i] for i in indices]
        return np.array(filtered_cam_poses)

    @staticmethod
    def sample_view_data_world_space(
        mesh: trimesh.Trimesh,
        cad_to_world: np.ndarray,
        distance_range: tuple = (0.25, 0.5),
        voxel_size: float = 0.005,
        max_views: int = 1,
        min_cam_table_included_degree: int = 20,
        random_view_ratio: float = 0.2,
    ) -> dict:
        """Sample views in the object frame, then express poses, points, and normals in world space."""
        view_data = ViewSampleUtil.sample_view_data(mesh, distance_range, voxel_size, max_views)
        view_data["cam_to_world_poses"] = ViewSampleUtil.get_cam_pose(
            view_data, cad_to_world, max_views, min_cam_table_included_degree, random_view_ratio
        )
        view_data["voxel_down_sampled_points"], view_data["normals"] = ViewSampleUtil.get_world_points_and_normals(view_data, cad_to_world)
        return view_data
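

if __name__ == "__main__":
    # Minimal usage sketch (an addition, not part of the original module): it
    # exercises the sampler on a trimesh primitive so no mesh file is needed.
    # With random_view_ratio=0.0 the PoseUtil perturbation branch is never
    # taken, so this demo only depends on numpy, scipy, and trimesh.
    demo_mesh = trimesh.creation.box(extents=(0.2, 0.2, 0.2))
    cad_to_world = np.eye(4)  # object frame assumed to coincide with the world frame
    result = ViewSampleUtil.sample_view_data_world_space(
        demo_mesh,
        cad_to_world,
        distance_range=(0.25, 0.5),
        voxel_size=0.005,
        max_views=4,
        min_cam_table_included_degree=20,
        random_view_ratio=0.0,
    )
    print("cam-to-world poses:", np.asarray(result["cam_to_world_poses"]).shape)
    print("world surface points:", result["voxel_down_sampled_points"].shape)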