165 lines
6.7 KiB
Python
165 lines
6.7 KiB
Python
import numpy as np
|
|
from sklearn.cluster import DBSCAN
|
|
|
|
class PredictResult:
|
|
def __init__(self, raw_predict_result, input_pts=None, cluster_params=dict(eps=0.5, min_samples=2)):
|
|
self.input_pts = input_pts
|
|
self.cluster_params = cluster_params
|
|
self.sampled_9d_pose = raw_predict_result
|
|
self.sampled_matrix_pose = self.get_sampled_matrix_pose()
|
|
self.distance_matrix = self.calculate_distance_matrix()
|
|
self.clusters = self.get_cluster_result()
|
|
self.candidate_matrix_poses = self.get_candidate_poses()
|
|
self.candidate_9d_poses = [np.concatenate((self.matrix_to_rotation_6d_numpy(matrix[:3,:3]), matrix[:3,3].reshape(-1,)), axis=-1) for matrix in self.candidate_matrix_poses]
|
|
self.cluster_num = len(self.clusters)
|
|
|
|
@staticmethod
|
|
def rotation_6d_to_matrix_numpy(d6):
|
|
a1, a2 = d6[:3], d6[3:]
|
|
b1 = a1 / np.linalg.norm(a1)
|
|
b2 = a2 - np.dot(b1, a2) * b1
|
|
b2 = b2 / np.linalg.norm(b2)
|
|
b3 = np.cross(b1, b2)
|
|
return np.stack((b1, b2, b3), axis=-2)
|
|
|
|
@staticmethod
|
|
def matrix_to_rotation_6d_numpy(matrix):
|
|
return np.copy(matrix[:2, :]).reshape((6,))
|
|
|
|
def __str__(self):
|
|
info = "Predict Result:\n"
|
|
info += f" Predicted pose number: {len(self.sampled_9d_pose)}\n"
|
|
info += f" Cluster number: {self.cluster_num}\n"
|
|
for i, cluster in enumerate(self.clusters):
|
|
info += f" - Cluster {i} size: {len(cluster)}\n"
|
|
max_distance = np.max(self.distance_matrix[self.distance_matrix != 0])
|
|
min_distance = np.min(self.distance_matrix[self.distance_matrix != 0])
|
|
info += f" Max distance: {max_distance}\n"
|
|
info += f" Min distance: {min_distance}\n"
|
|
return info
|
|
|
|
def get_sampled_matrix_pose(self):
|
|
sampled_matrix_pose = []
|
|
for pose in self.sampled_9d_pose:
|
|
rotation = pose[:6]
|
|
translation = pose[6:]
|
|
pose = self.rotation_6d_to_matrix_numpy(rotation)
|
|
pose = np.concatenate((pose, translation.reshape(-1, 1)), axis=-1)
|
|
pose = np.concatenate((pose, np.array([[0, 0, 0, 1]])), axis=-2)
|
|
sampled_matrix_pose.append(pose)
|
|
return np.array(sampled_matrix_pose)
|
|
|
|
def rotation_distance(self, R1, R2):
|
|
R = np.dot(R1.T, R2)
|
|
trace = np.trace(R)
|
|
angle = np.arccos(np.clip((trace - 1) / 2, -1, 1))
|
|
return angle
|
|
|
|
def calculate_distance_matrix(self):
|
|
n = len(self.sampled_matrix_pose)
|
|
dist_matrix = np.zeros((n, n))
|
|
for i in range(n):
|
|
for j in range(n):
|
|
dist_matrix[i, j] = self.rotation_distance(self.sampled_matrix_pose[i][:3, :3], self.sampled_matrix_pose[j][:3, :3])
|
|
return dist_matrix
|
|
|
|
def cluster_rotations(self):
|
|
clustering = DBSCAN(eps=self.cluster_params['eps'], min_samples=self.cluster_params['min_samples'], metric='precomputed')
|
|
labels = clustering.fit_predict(self.distance_matrix)
|
|
return labels
|
|
|
|
def get_cluster_result(self):
|
|
labels = self.cluster_rotations()
|
|
cluster_num = len(set(labels)) - (1 if -1 in labels else 0)
|
|
clusters = []
|
|
for _ in range(cluster_num):
|
|
clusters.append([])
|
|
for matrix_pose, label in zip(self.sampled_matrix_pose, labels):
|
|
if label != -1:
|
|
clusters[label].append(matrix_pose)
|
|
clusters.sort(key=len, reverse=True)
|
|
return clusters
|
|
|
|
def get_center_matrix_pose_from_cluster(self, cluster):
|
|
min_total_distance = float('inf')
|
|
center_matrix_pose = None
|
|
if len(cluster) == 1:
|
|
return cluster[0]
|
|
for matrix_pose in cluster:
|
|
total_distance = 0
|
|
for other_matrix_pose in cluster:
|
|
rot_distance = self.rotation_distance(matrix_pose[:3, :3], other_matrix_pose[:3, :3])
|
|
total_distance += rot_distance
|
|
|
|
if total_distance < min_total_distance:
|
|
min_total_distance = total_distance
|
|
center_matrix_pose = matrix_pose
|
|
|
|
|
|
return center_matrix_pose
|
|
|
|
def get_candidate_poses(self):
|
|
candidate_poses = []
|
|
for cluster in self.clusters:
|
|
candidate_poses.append(self.get_center_matrix_pose_from_cluster(cluster))
|
|
return candidate_poses
|
|
|
|
def visualize(self):
|
|
import plotly.graph_objects as go
|
|
fig = go.Figure()
|
|
if self.input_pts is not None:
|
|
fig.add_trace(go.Scatter3d(
|
|
x=self.input_pts[:, 0], y=self.input_pts[:, 1], z=self.input_pts[:, 2],
|
|
mode='markers', marker=dict(size=1, color='gray', opacity=0.5), name='Input Points'
|
|
))
|
|
colors = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
|
|
'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg']
|
|
for i, cluster in enumerate(self.clusters):
|
|
color = colors[i]
|
|
candidate_pose = self.candidate_matrix_poses[i]
|
|
origin_candidate = candidate_pose[:3, 3]
|
|
z_axis_candidate = candidate_pose[:3, 2]
|
|
for pose in cluster:
|
|
origin = pose[:3, 3]
|
|
z_axis = pose[:3, 2]
|
|
fig.add_trace(go.Cone(
|
|
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
|
u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]],
|
|
colorscale=color,
|
|
sizemode="absolute", sizeref=0.05, anchor="tail", showscale=False
|
|
))
|
|
fig.add_trace(go.Cone(
|
|
x=[origin_candidate[0]], y=[origin_candidate[1]], z=[origin_candidate[2]],
|
|
u=[z_axis_candidate[0]], v=[z_axis_candidate[1]], w=[z_axis_candidate[2]],
|
|
colorscale=color,
|
|
sizemode="absolute", sizeref=0.1, anchor="tail", showscale=False
|
|
))
|
|
|
|
fig.update_layout(
|
|
title="Clustered Poses and Input Points",
|
|
scene=dict(
|
|
xaxis_title='X',
|
|
yaxis_title='Y',
|
|
zaxis_title='Z'
|
|
),
|
|
margin=dict(l=0, r=0, b=0, t=40),
|
|
scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25))
|
|
)
|
|
|
|
fig.show()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load one inference step, cluster its predicted
    # poses, print a summary and (optionally) show the plotly visualization.
    # With no arguments this reproduces the original hard-coded run.
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(
        description="Cluster predicted 9D poses for one inference step and visualize them."
    )
    parser.add_argument("--root", default="inference_result_pack/inference_result_pack",
                        help="directory holding one sub-directory per inference step")
    parser.add_argument("--step", type=int, default=0, help="inference step to load")
    parser.add_argument("--eps", type=float, default=0.25,
                        help="DBSCAN eps, in radians of rotation angle")
    parser.add_argument("--min-samples", type=int, default=3, help="DBSCAN min_samples")
    parser.add_argument("--no-vis", action="store_true", help="skip the plotly visualization")
    args = parser.parse_args()

    step_dir = Path(args.root) / str(args.step)
    raw_predict_result = np.load(step_dir / "all_pred_pose_9d.npy")
    input_pts = np.loadtxt(step_dir / "input_pts.txt")
    print(raw_predict_result.shape)

    predict_result = PredictResult(
        raw_predict_result,
        input_pts,
        cluster_params=dict(eps=args.eps, min_samples=args.min_samples),
    )
    print(predict_result)
    print(len(predict_result.candidate_matrix_poses))
    print(predict_result.distance_matrix)

    if not args.no_vis:
        predict_result.visualize()
|
|
|