From e8dff9bf5c3e0ad2473449f038d5830ef2135b34 Mon Sep 17 00:00:00 2001 From: Michel Breyer Date: Sat, 11 Sep 2021 22:31:48 +0200 Subject: [PATCH] Add stable grasp prediction stopping criteria --- active_grasp/nbv.py | 15 ++++++++++++++- active_grasp/policy.py | 11 +++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/active_grasp/nbv.py b/active_grasp/nbv.py index 33b2f7a..3b9b38a 100644 --- a/active_grasp/nbv.py +++ b/active_grasp/nbv.py @@ -24,7 +24,7 @@ class NextBestView(MultiViewPolicy): self.view_candidates.append(view) def update(self, img, x): - if len(self.views) > self.max_views: + if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable(): self.done = True else: self.integrate(img, x) @@ -37,6 +37,19 @@ class NextBestView(MultiViewPolicy): nbv, _ = views[i], gains[i] self.x_d = nbv + def best_grasp_prediction_is_stable(self): + if self.best_grasp: + t = (self.T_task_base * self.best_grasp.pose).translation + i, j, k = (t / self.tsdf.voxel_size).astype(int) + qs = self.qual_hist[:, i, j, k] + if ( + np.count_nonzero(qs) == self.T + and np.mean(qs) > 0.9 + and np.std(qs) < 0.05 + ): + return True + return False + def ig_fn(self, view, downsample=20): fx = self.intrinsic.fx / downsample fy = self.intrinsic.fy / downsample diff --git a/active_grasp/policy.py b/active_grasp/policy.py index 7d52b68..017ab2d 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -47,7 +47,7 @@ class Policy: self.T_base_task = Transform.translation(self.bbox.center - np.full(3, 0.15)) self.T_task_base = self.T_base_task.inv() tf.broadcast(self.T_base_task, self.base_frame, self.task_frame) - rospy.sleep(0.5) # Wait for tf tree to be updated. + rospy.sleep(0.5) # Wait for tf tree to be updated self.vis.workspace(self.task_frame, 0.3) def update(self, img, pose): @@ -59,7 +59,6 @@ class Policy: for grasp in in_grasps: pose = self.T_base_task * grasp.pose - R, t = pose.rotation, pose.translation # Filter out artifacts close to the support @@ -102,6 +101,11 @@ class SingleViewPolicy(Policy): class MultiViewPolicy(Policy): + def activate(self, bbox, view_sphere): + super().activate(bbox, view_sphere) + self.T = 5 # Window size of grasp prediction history + self.qual_hist = np.zeros((self.T,) + (40,) * 3, np.float32) + def integrate(self, img, x): self.views.append(x) self.tsdf.integrate(img, self.intrinsic, x.inv() * self.T_base_task) @@ -114,6 +118,9 @@ class MultiViewPolicy(Policy): out = self.vgn.predict(tsdf_grid) self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.5) + t = (len(self.views) - 1) % self.T + self.qual_hist[t, ...] = out.qual + grasps = select_grid(voxel_size, out, threshold=self.qual_threshold) grasps, scores = self.sort_grasps(grasps)