import numpy as np
from sklearn.cluster import DBSCAN


class PredictResult:
    """Cluster a set of sampled 9-D pose predictions by rotation similarity.

    Each sampled pose is a 9-vector: the first six values are a 6-D rotation
    encoding (two 3-vectors that are orthonormalised into a rotation matrix)
    and the last three are a translation.  On construction the poses are
    converted to 4x4 homogeneous matrices, a pairwise rotation-angle distance
    matrix is computed, the poses are clustered with DBSCAN on that matrix,
    and one representative ("candidate") pose is picked per cluster.

    Attributes:
        input_pts: Optional point array stored as given; only used by
            ``visualize`` (columns 0..2 are plotted as x, y, z).
        cluster_params: Dict with the DBSCAN ``eps`` and ``min_samples``.
        sampled_9d_pose: The raw predictions exactly as passed in.
        sampled_matrix_pose: ``(n, 4, 4)`` array of homogeneous poses.
        distance_matrix: ``(n, n)`` pairwise rotation angles in radians.
        clusters: List of clusters (largest first); each cluster is a list
            of 4x4 poses.  Poses labelled as noise by DBSCAN are dropped.
        candidate_matrix_poses: One 4x4 medoid pose per cluster.
        candidate_9d_poses: The candidates re-encoded as 9-vectors.
        cluster_num: Number of clusters.
    """

    # Used when the caller supplies no (or only partial) cluster parameters.
    DEFAULT_CLUSTER_PARAMS = {"eps": 0.5, "min_samples": 2}

    # Colorscale names handed to ``go.Cone``; cycled when there are more
    # clusters than entries.
    _CLUSTER_COLORSCALES = (
        'aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
        'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
    )

    def __init__(self, raw_predict_result, input_pts=None, cluster_params=None):
        """Convert, cluster and summarise the sampled poses.

        Args:
            raw_predict_result: Iterable of 9-element pose vectors
                (6-D rotation followed by a 3-D translation).
            input_pts: Optional points to show in ``visualize``.
            cluster_params: Optional dict overriding ``eps`` and/or
                ``min_samples``.  Missing keys fall back to
                ``DEFAULT_CLUSTER_PARAMS``.  The dict is copied, so the
                caller's object is never shared or mutated.
        """
        params = dict(self.DEFAULT_CLUSTER_PARAMS)
        if cluster_params is not None:
            params.update(cluster_params)

        self.input_pts = input_pts
        self.cluster_params = params
        self.sampled_9d_pose = raw_predict_result
        # The remaining attributes depend on each other in this exact order.
        self.sampled_matrix_pose = self.get_sampled_matrix_pose()
        self.distance_matrix = self.calculate_distance_matrix()
        self.clusters = self.get_cluster_result()
        self.candidate_matrix_poses = self.get_candidate_poses()
        self.candidate_9d_poses = [
            np.concatenate(
                (
                    self.matrix_to_rotation_6d_numpy(matrix[:3, :3]),
                    matrix[:3, 3].reshape(-1,),
                ),
                axis=-1,
            )
            for matrix in self.candidate_matrix_poses
        ]
        self.cluster_num = len(self.clusters)

    @staticmethod
    def rotation_6d_to_matrix_numpy(d6):
        """Turn a 6-D rotation encoding into a 3x3 rotation matrix.

        The two 3-vectors ``d6[:3]`` and ``d6[3:]`` are orthonormalised with
        Gram-Schmidt; the third basis vector is their cross product.  The
        three basis vectors become the ROWS of the returned matrix.
        """
        a1, a2 = d6[:3], d6[3:]
        b1 = a1 / np.linalg.norm(a1)
        # Remove the component of a2 along b1 before normalising.
        b2 = a2 - np.dot(b1, a2) * b1
        b2 = b2 / np.linalg.norm(b2)
        b3 = np.cross(b1, b2)
        return np.stack((b1, b2, b3), axis=-2)

    @staticmethod
    def matrix_to_rotation_6d_numpy(matrix):
        """Return the first two rows of ``matrix`` flattened to 6 values.

        This is the inverse of ``rotation_6d_to_matrix_numpy`` for a proper
        rotation matrix.  A copy is returned, never a view of the input.
        """
        return np.copy(matrix[:2, :]).reshape((6,))

    @staticmethod
    def _pairwise_rotation_distances(rotations):
        """Return the ``(n, n)`` matrix of rotation angles between rotations.

        ``rotations`` has shape ``(n, 3, 3)``.  ``trace(R_i.T @ R_j)`` equals
        the element-wise product of the two matrices summed over all nine
        entries, so all traces come from one matrix product of the flattened
        rotations.
        """
        rotations = np.asarray(rotations, dtype=float)
        count = rotations.shape[0]
        flat = rotations.reshape(count, 9)
        traces = flat @ flat.T
        distances = np.arccos(np.clip((traces - 1.0) / 2.0, -1.0, 1.0))
        # arccos is very sensitive near 1, so round-off can leave a tiny
        # non-zero self-distance; a pose is at distance 0 from itself.
        np.fill_diagonal(distances, 0.0)
        return distances

    def __str__(self):
        """Summarise pose count, cluster sizes and the distance range."""
        lines = [
            "Predict Result:",
            f" Predicted pose number: {len(self.sampled_9d_pose)}",
            f" Cluster number: {self.cluster_num}",
        ]
        lines.extend(
            f" - Cluster {i} size: {len(cluster)}"
            for i, cluster in enumerate(self.clusters)
        )
        # Zero entries (the diagonal and exact duplicates) are excluded.
        nonzero = self.distance_matrix[self.distance_matrix != 0]
        if nonzero.size:
            lines.append(f" Max distance: {np.max(nonzero)}")
            lines.append(f" Min distance: {np.min(nonzero)}")
        else:
            # Fewer than two distinct rotations: there is no range to report.
            lines.append(" Max distance: n/a")
            lines.append(" Min distance: n/a")
        return "\n".join(lines) + "\n"

    def get_sampled_matrix_pose(self):
        """Convert every sampled 9-D pose to a 4x4 homogeneous matrix.

        Returns:
            Array of shape ``(n, 4, 4)``; ``(0, 4, 4)`` when there are no
            sampled poses.
        """
        matrices = []
        for pose in self.sampled_9d_pose:
            pose = np.asarray(pose)
            matrix = np.eye(4)
            matrix[:3, :3] = self.rotation_6d_to_matrix_numpy(pose[:6])
            matrix[:3, 3] = pose[6:]
            matrices.append(matrix)
        if not matrices:
            return np.empty((0, 4, 4))
        return np.array(matrices)

    def rotation_distance(self, R1, R2):
        """Return the rotation angle (radians) between two 3x3 rotations.

        The angle of the relative rotation ``R1.T @ R2`` is recovered from
        its trace; the argument of ``arccos`` is clipped to ``[-1, 1]`` to
        absorb floating-point round-off.
        """
        relative = np.dot(R1.T, R2)
        cos_angle = (np.trace(relative) - 1) / 2
        return np.arccos(np.clip(cos_angle, -1, 1))

    def calculate_distance_matrix(self):
        """Return the ``(n, n)`` rotation-angle matrix of all sampled poses."""
        poses = self.sampled_matrix_pose
        if len(poses) == 0:
            return np.zeros((0, 0))
        return self._pairwise_rotation_distances(poses[:, :3, :3])

    def cluster_rotations(self):
        """Run DBSCAN on the precomputed distance matrix.

        Returns:
            One integer label per sampled pose; ``-1`` is treated as noise by
            ``get_cluster_result``.  An empty label array is returned when
            there are no poses, because DBSCAN rejects empty input.
        """
        if self.distance_matrix.shape[0] == 0:
            return np.empty(0, dtype=int)
        clustering = DBSCAN(
            eps=self.cluster_params['eps'],
            min_samples=self.cluster_params['min_samples'],
            metric='precomputed',
        )
        return clustering.fit_predict(self.distance_matrix)

    def get_cluster_result(self):
        """Group the sampled poses by cluster label.

        Returns:
            List of clusters sorted by size, largest first (ties keep label
            order).  Poses labelled ``-1`` are discarded.
        """
        grouped = {}
        for matrix_pose, label in zip(self.sampled_matrix_pose, self.cluster_rotations()):
            if label == -1:
                continue
            grouped.setdefault(int(label), []).append(matrix_pose)
        clusters = [grouped[label] for label in sorted(grouped)]
        clusters.sort(key=len, reverse=True)
        return clusters

    def get_center_matrix_pose_from_cluster(self, cluster):
        """Return the medoid pose of ``cluster``.

        The medoid is the member whose summed rotation distance to all other
        members is smallest; the first such member wins ties.  Returns
        ``None`` for an empty cluster and the sole member for a cluster of
        one.
        """
        if len(cluster) == 0:
            return None
        if len(cluster) == 1:
            return cluster[0]
        rotations = np.stack([pose[:3, :3] for pose in cluster])
        totals = self._pairwise_rotation_distances(rotations).sum(axis=1)
        return cluster[int(np.argmin(totals))]

    def get_candidate_poses(self):
        """Return one medoid pose per cluster, in cluster order."""
        return [
            self.get_center_matrix_pose_from_cluster(cluster)
            for cluster in self.clusters
        ]

    def visualize(self):
        """Show the clusters (and input points, if any) in a plotly figure.

        Every clustered pose is drawn as a small cone placed at its
        translation and pointing along the third column of its rotation
        block; each cluster's candidate pose is drawn as a larger cone in
        the same colorscale.
        """
        # Imported lazily so plotly is only required for visualisation.
        import plotly.graph_objects as go

        fig = go.Figure()

        def add_cone(pose, colorscale, sizeref):
            origin = pose[:3, 3]
            z_axis = pose[:3, 2]
            fig.add_trace(go.Cone(
                x=[origin[0]], y=[origin[1]], z=[origin[2]],
                u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]],
                colorscale=colorscale,
                sizemode="absolute",
                sizeref=sizeref,
                anchor="tail",
                showscale=False
            ))

        if self.input_pts is not None:
            fig.add_trace(go.Scatter3d(
                x=self.input_pts[:, 0],
                y=self.input_pts[:, 1],
                z=self.input_pts[:, 2],
                mode='markers',
                marker=dict(size=1, color='gray', opacity=0.5),
                name='Input Points'
            ))

        palette = self._CLUSTER_COLORSCALES
        for i, cluster in enumerate(self.clusters):
            # Wrap around instead of raising IndexError past the last entry.
            color = palette[i % len(palette)]
            for pose in cluster:
                add_cone(pose, color, 0.05)
            add_cone(self.candidate_matrix_poses[i], color, 0.1)

        fig.update_layout(
            title="Clustered Poses and Input Points",
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z'
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25))
        )
        fig.show()


def main():
    """Load one saved inference result, cluster it, print and plot it."""
    step = 0
    raw_predict_result = np.load(f"inference_result_pack/inference_result_pack/{step}/all_pred_pose_9d.npy")
    input_pts = np.loadtxt(f"inference_result_pack/inference_result_pack/{step}/input_pts.txt")
    print(raw_predict_result.shape)
    predict_result = PredictResult(raw_predict_result, input_pts, cluster_params=dict(eps=0.25, min_samples=3))
    print(predict_result)
    print(len(predict_result.candidate_matrix_poses))
    print(predict_result.distance_matrix)
    predict_result.visualize()


if __name__ == "__main__":
    main()