From 2df4f09d0b13ccb8c4beb7420c3b5d915f5987b7 Mon Sep 17 00:00:00 2001 From: Michel Breyer Date: Mon, 8 Nov 2021 14:39:44 +0100 Subject: [PATCH] Fix --- src/active_grasp/policy.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/active_grasp/policy.py b/src/active_grasp/policy.py index 110a2e5..c23b352 100644 --- a/src/active_grasp/policy.py +++ b/src/active_grasp/policy.py @@ -78,8 +78,13 @@ class Policy: def update(self, img, x, q): raise NotImplementedError - def select_best_grasp(self, grasps, qualities, q): - filtered_grasps, scores = [], [] + def filter_grasps(self, out, q): + grasps, qualities = select_local_maxima( + self.tsdf.voxel_size, + out, + self.qual_thresh, + ) + filtered_grasps, filtered_qualities = [], [] for grasp, quality in zip(grasps, qualities): pose = self.T_base_task * grasp.pose R, t = pose.rotation, pose.translation @@ -89,12 +94,13 @@ class Policy: q_grasp = self.solve_ee_ik(q, pose * self.T_grasp_ee) if q_grasp is not None: filtered_grasps.append(grasp) - scores.append(self.score_fn(grasp, quality, q, q_grasp)) - i = np.argmax(scores) - return filtered_grasps[i], qualities[i], scores[i] + filtered_qualities.append(quality) + return filtered_grasps, filtered_qualities - def score_fn(self, grasp, quality, q, q_grasp): - return quality + +def select_best_grasp(grasps, qualities): + i = np.argmax(qualities) + return grasps[i], qualities[i] class SingleViewPolicy(Policy): @@ -109,11 +115,12 @@ class SingleViewPolicy(Policy): out = self.vgn.predict(tsdf_grid) self.vis.quality(self.task_frame, voxel_size, out.qual, 0.5) - grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh) + + grasps, qualities = self.filter_grasps(out, q) if len(grasps) > 0: - self.best_grasp, qual, _ = self.select_best_grasp(grasps, qualities, q) - self.vis.grasp(self.base_frame, self.best_grasp, qual) + self.best_grasp, quality = select_best_grasp(grasps, qualities) + self.vis.grasp(self.base_frame, self.best_grasp, quality) self.done = True @@ -137,7 +144,7 @@ class MultiViewPolicy(Policy): self.vis.map_cloud(self.task_frame, self.tsdf.get_map_cloud()) with Timer("grasp_prediction"): - tsdf_grid, voxel_size = self.tsdf.get_grid(), self.tsdf.voxel_size + tsdf_grid = self.tsdf.get_grid() out = self.vgn.predict(tsdf_grid) self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.9) @@ -145,10 +152,10 @@ class MultiViewPolicy(Policy): self.qual_hist[t, ...] = out.qual with Timer("grasp_selection"): - grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh) + grasps, qualities = self.filter_grasps(out, q) if len(grasps) > 0: - self.best_grasp, quality, _ = self.select_best_grasp(grasps, qualities, q) + self.best_grasp, quality = select_best_grasp(grasps, qualities) self.vis.grasp(self.base_frame, self.best_grasp, quality) else: self.best_grasp = None