new_nbv_rec/beans/predict_result.py
2025-05-19 16:32:04 +08:00

165 lines
6.7 KiB
Python

import numpy as np
from sklearn.cluster import DBSCAN
class PredictResult:
    """Cluster a set of predicted 9D poses by rotation and pick representatives.

    Each 9D pose is a 6D rotation representation (first 6 values) followed by
    a 3D translation (last 3 values).  Poses are converted to 4x4 homogeneous
    matrices, clustered with DBSCAN on the pairwise geodesic rotation angle,
    and the medoid of each cluster (by summed rotation angle) becomes a
    candidate pose.
    """

    def __init__(self, raw_predict_result, input_pts=None, cluster_params=None):
        """Build matrices, distances, clusters and candidates from raw poses.

        Args:
            raw_predict_result: iterable of 9-element pose vectors
                (6D rotation followed by a 3D translation).
            input_pts: optional point array; only used by ``visualize``, which
                reads its columns 0, 1 and 2.
            cluster_params: optional dict of DBSCAN parameters ``eps`` (in
                radians, since the precomputed distances are angles) and
                ``min_samples``.  Missing keys fall back to ``eps=0.5`` and
                ``min_samples=2``.
        """
        # Copy into a fresh dict so no mutable default is shared between
        # instances, and so a partially specified dict still works.
        params = dict(eps=0.5, min_samples=2)
        if cluster_params is not None:
            params.update(cluster_params)

        self.input_pts = input_pts
        self.cluster_params = params
        self.sampled_9d_pose = raw_predict_result
        # (N, 4, 4) homogeneous matrices, one per sampled pose.
        self.sampled_matrix_pose = self.get_sampled_matrix_pose()
        # (N, N) pairwise geodesic rotation angles in radians.
        self.distance_matrix = self.calculate_distance_matrix()
        # List of clusters (lists of 4x4 matrices), largest first.
        self.clusters = self.get_cluster_result()
        # One medoid 4x4 pose per cluster, in the same order as self.clusters.
        self.candidate_matrix_poses = self.get_candidate_poses()
        # The candidates converted back to 9D (6D rotation + translation).
        self.candidate_9d_poses = [
            np.concatenate(
                (self.matrix_to_rotation_6d_numpy(matrix[:3, :3]), matrix[:3, 3].reshape(-1,)),
                axis=-1,
            )
            for matrix in self.candidate_matrix_poses
        ]
        self.cluster_num = len(self.clusters)

    @staticmethod
    def rotation_6d_to_matrix_numpy(d6):
        """Convert a 6D rotation vector to a 3x3 matrix via Gram-Schmidt.

        The two 3-vectors ``d6[:3]`` and ``d6[3:]`` are orthonormalized into
        ``b1`` and ``b2``; ``b3 = b1 x b2``.  The basis vectors are stacked as
        the ROWS of the returned matrix (``axis=-2``).
        """
        a1, a2 = d6[:3], d6[3:]
        b1 = a1 / np.linalg.norm(a1)
        b2 = a2 - np.dot(b1, a2) * b1
        b2 = b2 / np.linalg.norm(b2)
        b3 = np.cross(b1, b2)
        return np.stack((b1, b2, b3), axis=-2)

    @staticmethod
    def matrix_to_rotation_6d_numpy(matrix):
        """Return the first two rows of a rotation matrix flattened to 6 values.

        This is the inverse of ``rotation_6d_to_matrix_numpy`` for orthonormal
        input, since that function places its basis vectors in the rows.
        """
        return np.copy(matrix[:2, :]).reshape((6,))

    def __str__(self):
        """Summarize pose/cluster counts and the off-diagonal distance range."""
        info = "Predict Result:\n"
        info += f" Predicted pose number: {len(self.sampled_9d_pose)}\n"
        info += f" Cluster number: {self.cluster_num}\n"
        for i, cluster in enumerate(self.clusters):
            info += f" - Cluster {i} size: {len(cluster)}\n"
        # Select the pairwise entries by position rather than by "!= 0": the
        # diagonal may carry round-off noise, and genuine duplicates have a
        # legitimate distance of 0.  With fewer than two poses there are no
        # pairs at all, and np.max/np.min would raise on the empty selection.
        n = len(self.distance_matrix)
        pair_distances = self.distance_matrix[~np.eye(n, dtype=bool)]
        if pair_distances.size > 0:
            info += f" Max distance: {np.max(pair_distances)}\n"
            info += f" Min distance: {np.min(pair_distances)}\n"
        else:
            info += " Max distance: N/A\n"
            info += " Min distance: N/A\n"
        return info

    def get_sampled_matrix_pose(self):
        """Convert every 9D pose into a 4x4 homogeneous transform.

        Returns:
            np.ndarray of shape (N, 4, 4) with the rotation in ``[:3, :3]``,
            the translation in ``[:3, 3]`` and ``[0, 0, 0, 1]`` as last row.
        """
        sampled_matrix_pose = []
        for pose in self.sampled_9d_pose:
            rotation = pose[:6]
            translation = pose[6:]
            pose = self.rotation_6d_to_matrix_numpy(rotation)
            pose = np.concatenate((pose, translation.reshape(-1, 1)), axis=-1)
            pose = np.concatenate((pose, np.array([[0, 0, 0, 1]])), axis=-2)
            sampled_matrix_pose.append(pose)
        return np.array(sampled_matrix_pose)

    def rotation_distance(self, R1, R2):
        """Geodesic angle (radians, in [0, pi]) between rotations R1 and R2."""
        R = np.dot(R1.T, R2)
        trace = np.trace(R)
        # Clip guards arccos against round-off pushing the cosine outside [-1, 1].
        angle = np.arccos(np.clip((trace - 1) / 2, -1, 1))
        return angle

    def calculate_distance_matrix(self):
        """Pairwise geodesic rotation angles between all sampled poses.

        Uses the identity trace(R_i^T R_j) == sum(R_i * R_j) so the whole
        matrix is computed in one vectorized pass; the formula is the same as
        ``rotation_distance``.

        Returns:
            float64 array of shape (N, N) with an exactly-zero diagonal.
        """
        n = len(self.sampled_matrix_pose)
        if n == 0:
            return np.zeros((0, 0))
        rotations = np.asarray(self.sampled_matrix_pose, dtype=np.float64)[:, :3, :3]
        traces = np.einsum('iab,jab->ij', rotations, rotations)
        dist_matrix = np.arccos(np.clip((traces - 1) / 2, -1, 1))
        # A pose is at distance 0 from itself; remove Gram-Schmidt round-off.
        np.fill_diagonal(dist_matrix, 0.0)
        return dist_matrix

    def cluster_rotations(self):
        """Run DBSCAN on the precomputed rotation-distance matrix.

        Returns:
            One label per sampled pose; -1 marks noise points.
        """
        clustering = DBSCAN(
            eps=self.cluster_params['eps'],
            min_samples=self.cluster_params['min_samples'],
            metric='precomputed',
        )
        labels = clustering.fit_predict(self.distance_matrix)
        return labels

    def get_cluster_result(self):
        """Group pose matrices by DBSCAN label, dropping noise (-1).

        Returns:
            List of clusters (each a list of 4x4 matrices), sorted by size,
            largest first.
        """
        labels = self.cluster_rotations()
        cluster_num = len(set(labels)) - (1 if -1 in labels else 0)
        clusters = [[] for _ in range(cluster_num)]
        for matrix_pose, label in zip(self.sampled_matrix_pose, labels):
            if label != -1:
                clusters[label].append(matrix_pose)
        clusters.sort(key=len, reverse=True)
        return clusters

    def get_center_matrix_pose_from_cluster(self, cluster):
        """Return the medoid of ``cluster`` under the rotation distance.

        The medoid is the member whose summed rotation angle to all members
        is smallest; ties go to the earliest member.  Returns ``None`` for an
        empty cluster.
        """
        if len(cluster) == 0:
            return None
        if len(cluster) == 1:
            return cluster[0]
        return min(
            cluster,
            key=lambda matrix_pose: sum(
                self.rotation_distance(matrix_pose[:3, :3], other[:3, :3])
                for other in cluster
            ),
        )

    def get_candidate_poses(self):
        """Medoid pose of every cluster, in cluster order."""
        return [self.get_center_matrix_pose_from_cluster(cluster) for cluster in self.clusters]

    def visualize(self):
        """Plot input points plus every clustered pose as a cone along its z axis.

        Candidate (medoid) poses are drawn with larger cones.  Requires plotly.
        """
        import plotly.graph_objects as go

        fig = go.Figure()
        if self.input_pts is not None:
            fig.add_trace(go.Scatter3d(
                x=self.input_pts[:, 0], y=self.input_pts[:, 1], z=self.input_pts[:, 2],
                mode='markers', marker=dict(size=1, color='gray', opacity=0.5), name='Input Points'
            ))
        colors = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
                  'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg']
        for i, cluster in enumerate(self.clusters):
            # Cycle the palette so more than len(colors) clusters do not crash.
            color = colors[i % len(colors)]
            candidate_pose = self.candidate_matrix_poses[i]
            origin_candidate = candidate_pose[:3, 3]
            z_axis_candidate = candidate_pose[:3, 2]
            for pose in cluster:
                origin = pose[:3, 3]
                z_axis = pose[:3, 2]
                fig.add_trace(go.Cone(
                    x=[origin[0]], y=[origin[1]], z=[origin[2]],
                    u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]],
                    colorscale=color,
                    sizemode="absolute", sizeref=0.05, anchor="tail", showscale=False
                ))
            fig.add_trace(go.Cone(
                x=[origin_candidate[0]], y=[origin_candidate[1]], z=[origin_candidate[2]],
                u=[z_axis_candidate[0]], v=[z_axis_candidate[1]], w=[z_axis_candidate[2]],
                colorscale=color,
                sizemode="absolute", sizeref=0.1, anchor="tail", showscale=False
            ))
        fig.update_layout(
            title="Clustered Poses and Input Points",
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z'
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25))
        )
        fig.show()
if __name__ == "__main__":
    # Debug entry point: load one saved inference step, cluster its predicted
    # poses, print a summary and show the interactive plot.
    step = 0
    debug_cluster_params = dict(eps=0.25, min_samples=3)

    raw_predict_result = np.load(f"inference_result_pack/inference_result_pack/{step}/all_pred_pose_9d.npy")
    input_pts = np.loadtxt(f"inference_result_pack/inference_result_pack/{step}/input_pts.txt")
    print(raw_predict_result.shape)

    predict_result = PredictResult(
        raw_predict_result,
        input_pts,
        cluster_params=debug_cluster_params,
    )
    print(predict_result)
    print(len(predict_result.candidate_matrix_poses))
    print(predict_result.distance_matrix)
    predict_result.visualize()