import random
from collections import defaultdict

import numpy as np
import bmesh
from scipy.spatial.transform import Rotation as R

from utils.pose import PoseUtil
from utils.pts import PtsUtil


class ViewSampleUtil:
    @staticmethod
    def farthest_point_sampling(points, num_samples):
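        """Greedily pick `num_samples` points that are mutually far apart.

        Starts from a random seed point, then repeatedly adds the point
        farthest from the current sample set. Returns the sampled points
        and their indices into `points`.
        """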
        num_points = points.shape[0]
        if num_samples >= num_points:
            return points, np.arange(num_points)
        sampled_indices = np.zeros(num_samples, dtype=int)
        sampled_indices[0] = np.random.randint(num_points)
        min_distances = np.full(num_points, np.inf)
        for i in range(1, num_samples):
            current_point = points[sampled_indices[i - 1]]
            dist_to_current_point = np.linalg.norm(points - current_point, axis=1)
            min_distances = np.minimum(min_distances, dist_to_current_point)
            sampled_indices[i] = np.argmax(min_distances)
        downsampled_points = points[sampled_indices]
        return downsampled_points, sampled_indices

    @staticmethod
    def voxel_downsample(points, voxel_size):
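        """Downsample `points` by keeping one point per occupied voxel.

        Voxel indices come from integer division of the coordinates by
        `voxel_size`; the first point that falls into a voxel is kept.
        Returns the kept points and their indices into `points`.
        """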
        voxel_grid = defaultdict(list)
        for i, point in enumerate(points):
            voxel_index = tuple((point // voxel_size).astype(int))
            voxel_grid[voxel_index].append(i)

        downsampled_points = []
        downsampled_indices = []
        for indices in voxel_grid.values():
            selected_index = indices[0]
            downsampled_points.append(points[selected_index])
            downsampled_indices.append(selected_index)

        return np.array(downsampled_points), downsampled_indices

    @staticmethod
    def sample_view_data(obj, distance_range: tuple = (0.25, 0.5), voxel_size: float = 0.005, max_views: int = 1, perturb_repeat: int = 1) -> dict:
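        """Sample candidate look-at points and camera positions on a Blender mesh.

        Each mesh vertex becomes a look-at point; the camera sits along the
        vertex's averaged normal, with the direction perturbed by up to 10
        degrees and the distance jittered around the midpoint of
        `distance_range`, repeated `perturb_repeat` times. Candidates are
        voxel-downsampled, then reduced to `max_views * 2` views by farthest
        point sampling.
        """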
        view_data = {
            "look_at_points": [],
            "cam_positions": [],
        }
        mesh = obj.data
        bm = bmesh.new()
        bm.from_mesh(mesh)
        bm.verts.ensure_lookup_table()
        bm.faces.ensure_lookup_table()
        bm.normal_update()
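
        # One candidate view per vertex (and per perturbation): look at the
        # vertex, place the camera along its averaged, outward-facing normal.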
        look_at_points = []
        cam_positions = []
        normals = []
        for v in bm.verts:
            look_at_point = np.array(v.co)

            normal = np.zeros(3)
            for loop in v.link_loops:
                normal += np.array(loop.calc_normal())
            normal /= len(v.link_loops)
            normal = normal / np.linalg.norm(normal)
            if np.isnan(normal).any():
                continue
            # Flip the normal away from the mesh origin so the camera ends
            # up outside the object.
            if np.dot(normal, look_at_point) < 0:
                normal = -normal

            for _ in range(perturb_repeat):
                perturb_angle = np.radians(np.random.uniform(0, 10))
                perturb_axis = np.random.normal(size=3)
                perturb_axis /= np.linalg.norm(perturb_axis)
                rotation_matrix = R.from_rotvec(perturb_angle * perturb_axis).as_matrix()
                perturbed_normal = np.dot(rotation_matrix, normal)
                middle_distance = (distance_range[0] + distance_range[1]) / 2
                perturbed_distance = random.uniform(middle_distance - 0.05, middle_distance + 0.05)
                cam_position = look_at_point + perturbed_distance * perturbed_normal
                look_at_points.append(look_at_point)
                cam_positions.append(cam_position)
                # Append once per candidate so `normals` stays index-aligned
                # with `look_at_points` and `cam_positions` when
                # perturb_repeat > 1.
                normals.append(normal)

        bm.free()
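
        # Two-stage reduction: voxel downsampling evens out the vertex
        # density, then farthest point sampling picks well-spread views.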
        look_at_points = np.array(look_at_points)
        cam_positions = np.array(cam_positions)
        voxel_downsampled_look_at_points, selected_indices = ViewSampleUtil.voxel_downsample(look_at_points, voxel_size)
        voxel_downsampled_cam_positions = cam_positions[selected_indices]
        voxel_downsampled_normals = np.array(normals)[selected_indices]

        fps_downsampled_look_at_points, selected_indices = ViewSampleUtil.farthest_point_sampling(voxel_downsampled_look_at_points, max_views * 2)
        fps_downsampled_cam_positions = voxel_downsampled_cam_positions[selected_indices]

        view_data["look_at_points"] = fps_downsampled_look_at_points.tolist()
        view_data["cam_positions"] = fps_downsampled_cam_positions.tolist()
        view_data["normals"] = voxel_downsampled_normals
        view_data["voxel_down_sampled_points"] = voxel_downsampled_look_at_points
        return view_data

    @staticmethod
    def get_world_points_and_normals(view_data: dict, obj_world_pose: np.ndarray) -> tuple:
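        """Transform the voxel-downsampled points and normals into world space.

        Points are lifted to homogeneous coordinates and multiplied by the
        4x4 `obj_world_pose`; normals are rotated by its 3x3 rotation block.
        """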
        world_points = []
        world_normals = []
        for voxel_down_sampled_point, normal in zip(view_data["voxel_down_sampled_points"], view_data["normals"]):
            voxel_down_sampled_point_world = obj_world_pose @ np.append(voxel_down_sampled_point, 1.0)
            normal_world = obj_world_pose[:3, :3] @ normal
            world_points.append(voxel_down_sampled_point_world[:3])
            world_normals.append(normal_world)
        return np.array(world_points), np.array(world_normals)

    @staticmethod
    def get_cam_pose(view_data: dict, obj_world_pose: np.ndarray, max_views: int, min_cam_table_included_degree: int, random_view_ratio: float) -> np.ndarray:
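        """Build 4x4 world-space camera poses from the sampled views.

        Cameras whose viewing ray sits too close to the table plane (less
        than `min_cam_table_included_degree` above it) are tilted back to a
        valid angle. Poses are then filtered by height and table angle,
        optionally augmented with randomly perturbed copies, and capped at
        `max_views` by farthest point sampling.
        """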
        cam_poses = []
        min_height_z = 1000  # sentinel; lowered to the lowest look-at point z
        for look_at_point, cam_position in zip(view_data["look_at_points"], view_data["cam_positions"]):
            look_at_point_world = obj_world_pose @ np.append(look_at_point, 1.0)
            cam_position_world = obj_world_pose @ np.append(cam_position, 1.0)
            if look_at_point_world[2] < min_height_z:
                min_height_z = look_at_point_world[2]
            look_at_point_world = look_at_point_world[:3]
            cam_position_world = cam_position_world[:3]

            forward_vector = cam_position_world - look_at_point_world
            forward_vector /= np.linalg.norm(forward_vector)

            up_vector = np.array([0, 0, 1])

            dot_product = np.dot(forward_vector, up_vector)
            angle = np.degrees(np.arccos(dot_product))
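
            # If the camera's elevation above the table plane is below the
            # allowed minimum, rotate the camera position up around the
            # look-at point to a random valid elevation.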
            if angle > 90 - min_cam_table_included_degree:
                max_angle = 90 - min_cam_table_included_degree
                min_angle = max(90 - min_cam_table_included_degree * 2, 30)
                target_angle = np.random.uniform(min_angle, max_angle)
                angle_difference = np.radians(target_angle - angle)

                rotation_axis = np.cross(forward_vector, up_vector)
                rotation_axis /= np.linalg.norm(rotation_axis)
                rotation_matrix = PoseUtil.rotation_matrix_from_axis_angle(rotation_axis, -angle_difference)
                new_cam_position_world = np.dot(rotation_matrix, cam_position_world - look_at_point_world) + look_at_point_world
                cam_position_world = new_cam_position_world
                forward_vector = cam_position_world - look_at_point_world
                forward_vector /= np.linalg.norm(forward_vector)

            # Build a look-at rotation: the columns are the right, up, and
            # forward axes of the camera frame.
            right_vector = np.cross(up_vector, forward_vector)
            right_vector /= np.linalg.norm(right_vector)
            corrected_up_vector = np.cross(forward_vector, right_vector)
            rotation_matrix = np.array([right_vector, corrected_up_vector, forward_vector]).T

            cam_pose = np.eye(4)
            cam_pose[:3, :3] = rotation_matrix
            cam_pose[:3, 3] = cam_position_world
            cam_poses.append(cam_pose)
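
        # Keep only poses above the lowest look-at point that still meet the
        # table-angle constraint; optionally add a randomly perturbed copy.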
        filtered_cam_poses = []
        for cam_pose in cam_poses:
            if cam_pose[2, 3] > min_height_z:
                direction_vector = cam_pose[:3, 2]
                horizontal_normal = np.array([0, 0, 1])
                cos_angle = np.dot(direction_vector, horizontal_normal) / (np.linalg.norm(direction_vector) * np.linalg.norm(horizontal_normal))
                angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
                angle_degree = np.degrees(angle)
                if angle_degree < 90 - min_cam_table_included_degree:
                    filtered_cam_poses.append(cam_pose)
                    if random.random() < random_view_ratio:
                        perturb_pose = PoseUtil.get_uniform_pose([0.1, 0.1, 0.1], [3, 3, 3], 0, 180, "cm")
                        filtered_cam_poses.append(perturb_pose @ cam_pose)

        if len(filtered_cam_poses) > max_views:
            cam_points = np.array([cam_pose[:3, 3] for cam_pose in filtered_cam_poses])
            _, indices = PtsUtil.fps_downsample_point_cloud(cam_points, max_views, require_idx=True)
            filtered_cam_poses = [filtered_cam_poses[i] for i in indices]

        return np.array(filtered_cam_poses)

    @staticmethod
    def sample_view_data_world_space(obj, distance_range: tuple = (0.3, 0.5), voxel_size: float = 0.005, max_views: int = 1, min_cam_table_included_degree: int = 20, random_view_ratio: float = 0.2) -> dict:
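        """Sample views on `obj` and express everything in world space.

        Combines `sample_view_data`, `get_cam_pose`, and
        `get_world_points_and_normals` using the object's `matrix_world`.
        """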
        obj_world_pose = np.asarray(obj.matrix_world)
        view_data = ViewSampleUtil.sample_view_data(obj, distance_range, voxel_size, max_views)
        view_data["cam_poses"] = ViewSampleUtil.get_cam_pose(view_data, obj_world_pose, max_views, min_cam_table_included_degree, random_view_ratio)
        view_data["voxel_down_sampled_points"], view_data["normals"] = ViewSampleUtil.get_world_points_and_normals(view_data, obj_world_pose)
        return view_data
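

# Minimal usage sketch. This module assumes it runs inside Blender (bmesh and
# matrix_world come from bpy objects); `bpy.context.active_object` below is an
# assumption about how the caller obtains a mesh object, not part of this
# module's API.
if __name__ == "__main__":
    import bpy  # available only inside Blender's bundled Python

    obj = bpy.context.active_object  # hypothetical: any mesh object works
    view_data = ViewSampleUtil.sample_view_data_world_space(
        obj,
        distance_range=(0.3, 0.5),
        voxel_size=0.005,
        max_views=8,
        min_cam_table_included_degree=20,
        random_view_ratio=0.2,
    )
    print(f"sampled {len(view_data['cam_poses'])} camera poses")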