# Source: nbv_reconstruction/utils/reconstruction.py
# Snapshot metadata: 2024-10-05 12:24:53 -05:00, 169 lines, 8.2 KiB, Python

import numpy as np
from scipy.spatial import cKDTree
from utils.pts import PtsUtil
class ReconstructionUtil:
    """Static helpers for coverage/overlap metrics and greedy view selection.

    Point clouds are NumPy arrays with one point per row (presumably
    ``(N, 3)`` -- confirm against callers). Nearest-neighbour lookups use
    ``scipy.spatial.cKDTree``.
    """

    @staticmethod
    def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
        """Return the fraction of target points that lie close to the combined cloud.

        A target point counts as covered when the distance to its nearest
        neighbour in ``combined_point_cloud`` is strictly less than
        ``2 * threshold``.

        Args:
            target_point_cloud: Points that should be covered, one per row.
            combined_point_cloud: Points observed so far, one per row.
            threshold: Base distance; the coverage radius is twice this value.

        Returns:
            ``covered / len(target_point_cloud)``, or ``0.0`` when either
            cloud is empty (avoids a 0/0 division and building an empty tree).
        """
        if len(target_point_cloud) == 0 or len(combined_point_cloud) == 0:
            return 0.0
        kdtree = cKDTree(combined_point_cloud)
        distances, _ = kdtree.query(target_point_cloud)
        covered_points = np.sum(distances < threshold * 2)
        return covered_points / target_point_cloud.shape[0]

    @staticmethod
    def compute_overlap_rate(new_point_cloud, combined_point_cloud, threshold=0.01):
        """Return the fraction of new points that already lie near the combined cloud.

        A new point counts as overlapping when the distance to its nearest
        neighbour in ``combined_point_cloud`` is strictly less than
        ``threshold``.

        Args:
            new_point_cloud: Candidate points, one per row.
            combined_point_cloud: Points observed so far, one per row.
            threshold: Maximum nearest-neighbour distance for an overlap.

        Returns:
            ``overlapping / len(new_point_cloud)``, or ``0`` when either
            cloud is empty.
        """
        # Guard before touching the KD-tree: nothing can overlap with nothing.
        if len(new_point_cloud) == 0 or len(combined_point_cloud) == 0:
            return 0
        kdtree = cKDTree(combined_point_cloud)
        distances, _ = kdtree.query(new_point_cloud)
        overlapping_points = np.sum(distances < threshold)
        return overlapping_points / new_point_cloud.shape[0]

    @staticmethod
    def combine_point_with_view_sequence(point_list, view_sequence):
        """Stack the point clouds of the views named in ``view_sequence``.

        Args:
            point_list: Sequence of per-view point clouds, indexable by view index.
            view_sequence: Iterable of ``(view_index, value)`` pairs; only the
                index is used.

        Returns:
            The selected clouds stacked row-wise with ``np.vstack``, in
            sequence order.

        Raises:
            ValueError: Raised by ``np.vstack`` when ``view_sequence`` is empty.
        """
        selected_views = [point_list[view_index] for view_index, _ in view_sequence]
        return np.vstack(selected_views)

    @staticmethod
    def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01):
        """Pick the view whose addition increases coverage of the target the most.

        Args:
            views: Iterable of candidate point clouds.
            combined_point_cloud: Points observed so far, either as a single
                array or as a list/tuple of arrays (stacked row-wise here).
            target_point_cloud: Points that should be covered.
            threshold: Passed to ``PtsUtil.voxel_downsample_point_cloud``
                (presumably the voxel size -- confirm) and to
                ``compute_coverage_rate``.

        Returns:
            ``(best_view, best_coverage_increase)`` where ``best_view`` is the
            position of the best candidate in ``views``. ``(None, -1)`` when
            ``views`` is empty.
        """
        # The original used ``combined_point_cloud + [view]``, which on an
        # ndarray is element-wise addition rather than concatenation. Normalise
        # to one array and stack rows explicitly instead.
        if isinstance(combined_point_cloud, (list, tuple)):
            base_cloud = np.vstack(combined_point_cloud)
        else:
            base_cloud = np.asarray(combined_point_cloud)

        best_view = None
        best_coverage_increase = -1
        current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, base_cloud, threshold)
        for view_index, view in enumerate(views):
            candidate_cloud = np.vstack([base_cloud, view])
            down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(candidate_cloud, threshold)
            new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
            coverage_increase = new_coverage - current_coverage
            if coverage_increase > best_coverage_increase:
                best_coverage_increase = coverage_increase
                best_view = view_index
        return best_view, best_coverage_increase

    @staticmethod
    def get_new_added_points(old_combined_pts, new_pts, threshold=0.005):
        """Return the points of ``new_pts`` that are not already in ``old_combined_pts``.

        A new point is kept when the distance to its nearest neighbour in
        ``old_combined_pts`` is strictly greater than ``threshold``.

        Args:
            old_combined_pts: Previously accumulated points.
            new_pts: Newly observed points.
            threshold: Minimum nearest-neighbour distance for a point to count as new.

        Returns:
            ``new_pts`` itself when ``old_combined_pts`` is empty, an empty
            1-D array when ``new_pts`` is empty, otherwise the filtered rows
            of ``new_pts``.
        """
        if old_combined_pts.size == 0:
            return new_pts
        if new_pts.size == 0:
            return np.array([])
        tree = cKDTree(old_combined_pts)
        distances, _ = tree.query(new_pts, k=1)
        return new_pts[distances > threshold]

    @staticmethod
    def _report_progress(status_info, processed, total, coverage=None):
        """Push progress (and optionally coverage) to the status manager, if one is given."""
        if status_info is None:
            return
        sm = status_info["status_manager"]
        app_name = status_info["app_name"]
        runner_name = status_info["runner_name"]
        if coverage is not None:
            sm.set_status(app_name, runner_name, "current coverage", coverage)
        sm.set_progress(app_name, runner_name, "processed view", processed, total)

    @staticmethod
    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, soft_overlap_threshold=0.5, hard_overlap_threshold=0.7, init_view = 0, scan_points_threshold=5, status_info=None, min_coverage_increase=3e-3):
        """Greedily build a view sequence that maximises coverage of the target.

        Starting from ``init_view``, each round evaluates every remaining
        non-empty view. A candidate is skipped when its overlap rate with the
        already selected cloud is below the applicable threshold:
        ``soft_overlap_threshold`` if it shares at least
        ``scan_points_threshold`` scan-point indices with any selected view
        (see ``check_scan_points_overlap``), otherwise
        ``hard_overlap_threshold``. Among the remaining candidates the one
        with the largest coverage increase is selected. The search stops when
        no candidate qualifies or the best increase is not greater than
        ``min_coverage_increase``.

        Args:
            target_point_cloud: Points that should be covered.
            point_cloud_list: Per-view point clouds (arrays with ``.shape``).
            scan_points_indices_list: Per-view collections of scan-point indices.
            threshold: Passed to ``PtsUtil.voxel_downsample_point_cloud``
                (presumably the voxel size -- confirm) and used as the
                distance threshold for coverage and overlap.
            soft_overlap_threshold: Minimum overlap rate for candidates that
                share enough scan points with the history.
            hard_overlap_threshold: Minimum overlap rate for all other candidates.
            init_view: Index of the first view in the sequence.
            scan_points_threshold: Shared scan-point count that enables the
                soft threshold.
            status_info: Optional dict with keys ``"status_manager"``,
                ``"app_name"`` and ``"runner_name"`` used for progress reporting.
            min_coverage_increase: Selection stops when the best coverage
                increase is less than or equal to this value.

        Returns:
            ``(view_sequence, remaining_views, combined_cloud)`` where
            ``view_sequence`` is a list of ``(view_index, coverage)`` pairs,
            ``remaining_views`` lists the unselected view indices, and
            ``combined_cloud`` is the down-sampled stack of the selected views.
        """
        downsample = PtsUtil.voxel_downsample_point_cloud
        total_views = len(point_cloud_list)

        selected_views = [point_cloud_list[init_view]]
        history_indices = [scan_points_indices_list[init_view]]
        # Down-sampled cloud of the views selected so far. It is refreshed only
        # when a view is selected, instead of once per candidate.
        selected_cloud = downsample(np.vstack(selected_views), threshold)
        current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, selected_cloud, threshold)

        remaining_views = list(range(total_views))
        remaining_views.remove(init_view)
        view_sequence = [(init_view, current_coverage)]
        cnt_processed_view = 0
        # Each view's own down-sampled cloud does not change between rounds
        # (assumes the down-sampler is deterministic), so compute it lazily once.
        view_cloud_cache = {}

        while remaining_views:
            best_view = None
            best_coverage_increase = -1
            best_cloud = None
            for view_index in remaining_views:
                view_points = point_cloud_list[view_index]
                if view_points.shape[0] == 0:
                    continue

                shares_scan_points = ReconstructionUtil.check_scan_points_overlap(
                    history_indices, scan_points_indices_list[view_index], scan_points_threshold
                )
                overlap_threshold = soft_overlap_threshold if shares_scan_points else hard_overlap_threshold

                if view_index not in view_cloud_cache:
                    view_cloud_cache[view_index] = downsample(view_points, threshold)
                overlap_rate = ReconstructionUtil.compute_overlap_rate(view_cloud_cache[view_index], selected_cloud, threshold)
                if overlap_rate < overlap_threshold:
                    continue

                candidate_cloud = downsample(np.vstack(selected_views + [view_points]), threshold)
                new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, candidate_cloud, threshold)
                coverage_increase = new_coverage - current_coverage
                if coverage_increase > best_coverage_increase:
                    best_coverage_increase = coverage_increase
                    best_view = view_index
                    best_cloud = candidate_cloud

            if best_view is None or best_coverage_increase <= min_coverage_increase:
                break

            selected_views.append(point_cloud_list[best_view])
            remaining_views.remove(best_view)
            history_indices.append(scan_points_indices_list[best_view])
            # The winning candidate cloud is exactly the down-sampled stack of
            # the new selection, so reuse it.
            selected_cloud = best_cloud
            current_coverage += best_coverage_increase
            cnt_processed_view += 1
            ReconstructionUtil._report_progress(status_info, cnt_processed_view, total_views, current_coverage)
            view_sequence.append((best_view, current_coverage))

        ReconstructionUtil._report_progress(status_info, total_views, total_views)
        # Return the cloud of the selected views. The original returned the
        # cloud of whichever candidate happened to be evaluated last, which
        # could include a view that was never selected.
        return view_sequence, remaining_views, selected_cloud

    @staticmethod
    def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 500, max_attempts = 1000):
        """Randomly place scan points on a disc, keeping a minimum spacing.

        Each attempt draws an angle in ``[0, 2*pi)`` and a radius in
        ``[0, display_table_radius)`` with ``np.random.uniform`` and accepts
        the point when it is at least ``min_distance`` from every accepted point.

        Args:
            display_table_top: z coordinate assigned to every point.
            display_table_radius: Upper bound of the sampled radius.
            min_distance: Minimum Euclidean distance between accepted points.
            max_points_num: Stop once this many points have been accepted.
            max_attempts: Stop after this many draws, accepted or not.

        Returns:
            List of ``(x, y, z)`` tuples.
        """
        points = []
        accepted = np.empty((0, 3))  # array mirror of ``points`` for vectorised distance checks
        attempts = 0
        while len(points) < max_points_num and attempts < max_attempts:
            angle = np.random.uniform(0, 2 * np.pi)
            # NOTE: drawing the radius uniformly makes points denser near the
            # centre than area-uniform sampling would; kept to preserve the
            # existing random sequence and outputs.
            r = np.random.uniform(0, display_table_radius)
            new_point = (r * np.cos(angle), r * np.sin(angle), display_table_top)
            # With no accepted points the norm array is empty and np.all is True.
            if np.all(np.linalg.norm(accepted - new_point, axis=1) >= min_distance):
                points.append(new_point)
                accepted = np.vstack([accepted, new_point])
            attempts += 1
        return points

    @staticmethod
    def compute_covered_scan_points(scan_points, point_cloud, threshold=0.01):
        """Return the scan points that have at least one cloud point within ``threshold``.

        Args:
            scan_points: Iterable of points to test.
            point_cloud: Observed points, one per row.
            threshold: Search radius passed to ``cKDTree.query_ball_point``.

        Returns:
            ``(covered_points, indices)``: the covered scan points and their
            positions in ``scan_points``. Both lists are empty when
            ``point_cloud`` is empty.
        """
        if len(point_cloud) == 0:
            return [], []
        tree = cKDTree(point_cloud)
        covered_points = []
        indices = []
        for i, scan_point in enumerate(scan_points):
            # query_ball_point returns the neighbour indices; empty means uncovered.
            if tree.query_ball_point(scan_point, threshold):
                covered_points.append(scan_point)
                indices.append(i)
        return covered_points, indices

    @staticmethod
    def check_scan_points_overlap(history_indices, indices2, threshold=5):
        """Return True if ``indices2`` shares at least ``threshold`` items with any history entry.

        Args:
            history_indices: Iterable of index collections from earlier views.
            indices2: Index collection of the view being tested.
            threshold: Minimum size of the intersection.

        Returns:
            ``True`` as soon as one history entry reaches the threshold,
            ``False`` otherwise (including for an empty history).
        """
        candidate = set(indices2)  # build once instead of once per history entry
        return any(len(candidate.intersection(indices1)) >= threshold for indices1 in history_indices)