import random
from collections import defaultdict

import bmesh
import numpy as np
from scipy.spatial.transform import Rotation as R

from utils.pose import PoseUtil
from utils.pts import PtsUtil


class CADViewSampleUtil:
    @staticmethod
    def farthest_point_sampling(points, num_samples):
        # Iteratively pick the point farthest from all previously selected points.
        num_points = points.shape[0]
        if num_samples >= num_points:
            return points, np.arange(num_points)
        sampled_indices = np.zeros(num_samples, dtype=int)
        sampled_indices[0] = np.random.randint(num_points)
        min_distances = np.full(num_points, np.inf)
        for i in range(1, num_samples):
            current_point = points[sampled_indices[i - 1]]
            dist_to_current_point = np.linalg.norm(points - current_point, axis=1)
            min_distances = np.minimum(min_distances, dist_to_current_point)
            sampled_indices[i] = np.argmax(min_distances)
        downsampled_points = points[sampled_indices]
        return downsampled_points, sampled_indices

    @staticmethod
    def voxel_downsample(points, voxel_size):
        # Keep the first point that falls into each voxel cell.
        voxel_grid = defaultdict(list)
        for i, point in enumerate(points):
            voxel_index = tuple((point // voxel_size).astype(int))
            voxel_grid[voxel_index].append(i)

        downsampled_points = []
        downsampled_indices = []
        for indices in voxel_grid.values():
            selected_index = indices[0]
            downsampled_points.append(points[selected_index])
            downsampled_indices.append(selected_index)

        return np.array(downsampled_points), downsampled_indices

    @staticmethod
    def sample_view_data(
        obj,
        distance_range: tuple = (0.25, 0.5),
        voxel_size: float = 0.005,
        max_views: int = 1,
        pertube_repeat: int = 1,
    ) -> dict:
        view_data = {
            "look_at_points": [],
            "cam_positions": [],
        }
        mesh = obj.data
        bm = bmesh.new()
        bm.from_mesh(mesh)
        bm.verts.ensure_lookup_table()
        bm.faces.ensure_lookup_table()
        bm.normal_update()

        look_at_points = []
        cam_positions = []
        normals = []
        for v in bm.verts:
            if not v.link_loops:
                continue  # isolated vertex: no loops to average a normal from
            look_at_point = np.array(v.co)

            # Average the loop normals around the vertex.
            normal = np.zeros(3)
            for loop in v.link_loops:
                normal += np.array(loop.calc_normal())
            normal /= len(v.link_loops)
            norm = np.linalg.norm(normal)
            if norm == 0 or np.isnan(normal).any():
                continue
            normal = normal / norm
            # Orient the normal away from the object origin.
            if np.dot(normal, look_at_point) < 0:
                normal = -normal

            for _ in range(pertube_repeat):
                # Perturb the view direction by a random rotation of up to 10 degrees.
                perturb_angle = np.radians(np.random.uniform(0, 10))
                perturb_axis = np.random.normal(size=3)
                perturb_axis /= np.linalg.norm(perturb_axis)
                rotation_matrix = R.from_rotvec(perturb_angle * perturb_axis).as_matrix()
                perturbed_normal = np.dot(rotation_matrix, normal)

                # Place the camera near the middle of the requested distance range.
                middle_distance = (distance_range[0] + distance_range[1]) / 2
                perturbed_distance = random.uniform(middle_distance - 0.05, middle_distance + 0.05)
                cam_position = look_at_point + perturbed_distance * perturbed_normal
                look_at_points.append(look_at_point)
                cam_positions.append(cam_position)
                # Append once per perturbation so normals stay index-aligned with
                # cam_positions even when pertube_repeat > 1.
                normals.append(normal)
        bm.free()

        look_at_points = np.array(look_at_points)
        cam_positions = np.array(cam_positions)
        normals = np.array(normals)

        # Thin the samples spatially, then pick well-spread views with FPS.
        voxel_downsampled_look_at_points, selected_indices = CADViewSampleUtil.voxel_downsample(look_at_points, voxel_size)
        voxel_downsampled_cam_positions = cam_positions[selected_indices]
        voxel_downsampled_normals = normals[selected_indices]

        fps_downsampled_look_at_points, selected_indices = CADViewSampleUtil.farthest_point_sampling(voxel_downsampled_look_at_points, max_views * 2)
        fps_downsampled_cam_positions = voxel_downsampled_cam_positions[selected_indices]

        view_data["look_at_points"] = fps_downsampled_look_at_points.tolist()
        view_data["cam_positions"] = fps_downsampled_cam_positions.tolist()
        view_data["normals"] = voxel_downsampled_normals
        view_data["voxel_down_sampled_points"] = voxel_downsampled_look_at_points
        return view_data
    @staticmethod
    def get_world_points_and_normals(view_data: dict, obj_world_pose: np.ndarray) -> tuple:
        # Transform object-space sample points and normals into world space.
        world_points = []
        world_normals = []
        for point, normal in zip(view_data["voxel_down_sampled_points"], view_data["normals"]):
            point_world = obj_world_pose @ np.append(point, 1.0)
            normal_world = obj_world_pose[:3, :3] @ normal
            world_points.append(point_world[:3])
            world_normals.append(normal_world)
        return np.array(world_points), np.array(world_normals)

    @staticmethod
    def get_cam_pose(view_data: dict, obj_world_pose: np.ndarray, max_views: int) -> np.ndarray:
        cam_poses = []
        for look_at_point, cam_position in zip(view_data["look_at_points"], view_data["cam_positions"]):
            look_at_point_world = (obj_world_pose @ np.append(look_at_point, 1.0))[:3]
            cam_position_world = (obj_world_pose @ np.append(cam_position, 1.0))[:3]

            # Build a look-at rotation: +Z (forward) points from the target to the camera.
            forward_vector = cam_position_world - look_at_point_world
            forward_vector /= np.linalg.norm(forward_vector)
            up_vector = np.array([0.0, 0.0, 1.0])
            if abs(np.dot(forward_vector, up_vector)) > 0.999:
                # Fallback when looking straight along the world Z axis.
                up_vector = np.array([0.0, 1.0, 0.0])
            right_vector = np.cross(up_vector, forward_vector)
            right_vector /= np.linalg.norm(right_vector)
            # Re-derive up so the basis is orthonormal.
            up_vector = np.cross(forward_vector, right_vector)
            rotation_matrix = np.stack([right_vector, up_vector, forward_vector], axis=1)

            cam_pose = np.eye(4)
            cam_pose[:3, :3] = rotation_matrix
            cam_pose[:3, 3] = cam_position_world
            cam_poses.append(cam_pose)

        # sample_view_data returns up to max_views * 2 candidates; FPS them down.
        if len(cam_poses) > max_views:
            cam_points = np.array([cam_pose[:3, 3] for cam_pose in cam_poses])
            _, indices = PtsUtil.fps_downsample_point_cloud(cam_points, max_views, require_idx=True)
            cam_poses = [cam_poses[i] for i in indices]
        return np.array(cam_poses)

    @staticmethod
    def sample_view_data_world_space(
        obj,
        distance_range: tuple = (0.3, 0.5),
        voxel_size: float = 0.005,
        max_views: int = 1,
        min_cam_table_included_degree: int = 20,
        random_view_ratio: float = 0.2,
    ) -> dict:
        # NOTE: min_cam_table_included_degree and random_view_ratio are accepted
        # for API compatibility but are not used by this implementation; passing
        # them to get_cam_pose would raise a TypeError.
        obj_world_pose = np.asarray(obj.matrix_world)
        view_data = CADViewSampleUtil.sample_view_data(obj, distance_range, voxel_size, max_views)
        view_data["cam_poses"] = CADViewSampleUtil.get_cam_pose(view_data, obj_world_pose, max_views)
        view_data["voxel_down_sampled_points"], view_data["normals"] = CADViewSampleUtil.get_world_points_and_normals(view_data, obj_world_pose)
        return view_data
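
# A minimal usage sketch, assuming this module is run inside Blender's Python
# environment (bpy available) with a mesh object in the scene. The object name
# "Model" and max_views value below are illustrative, not part of the API.
if __name__ == "__main__":
    import bpy

    obj = bpy.data.objects["Model"]  # hypothetical object name
    view_data = CADViewSampleUtil.sample_view_data_world_space(obj, max_views=4)
    print("cam poses:", view_data["cam_poses"].shape)                   # (N, 4, 4)
    print("world points:", view_data["voxel_down_sampled_points"].shape)
    print("world normals:", view_data["normals"].shape)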