diff --git a/active_grasp/policy.py b/active_grasp/policy.py index e257321..f621eb6 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -138,7 +138,7 @@ class MultiViewPolicy(Policy): with Timer("grasp_prediction"): tsdf_grid, voxel_size = self.tsdf.get_grid(), self.tsdf.voxel_size out = self.vgn.predict(tsdf_grid) - self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.5) + self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.8) t = (len(self.views) - 1) % self.T self.qual_hist[t, ...] = out.qual @@ -150,10 +150,9 @@ class MultiViewPolicy(Policy): self.vis.clear_grasps() if len(grasps) > 0: - smin, smax = np.min(scores), np.max(scores) self.best_grasp = grasps[0] - self.vis.grasps(self.base_frame, grasps, scores, smin, smax) - self.vis.best_grasp(self.base_frame, grasps[0], scores[0], smin, smax) + self.vis.grasps(self.base_frame, grasps) + self.vis.best_grasp(self.base_frame, self.best_grasp) else: self.best_grasp = None diff --git a/active_grasp/visualization.py b/active_grasp/visualization.py index f9599d7..3e7a7b0 100644 --- a/active_grasp/visualization.py +++ b/active_grasp/visualization.py @@ -69,15 +69,14 @@ class Visualizer: marker = create_line_list_marker(frame, pose, scale, color, lines, ns="bbox") self.draw([marker]) - def best_grasp(self, frame, grasp, score, smin=0.9, smax=1.0, alpha=1.0): - color = cmap((score - smin) / (smax - smin)) - color = [color[0], color[1], color[2], alpha] + def best_grasp(self, frame, grasp, qmin=0.5, qmax=1.0): + color = cmap((grasp.quality - qmin) / (qmax - qmin)) self.draw(create_grasp_markers(frame, grasp, color, "best_grasp", radius=0.006)) - def grasps(self, frame, grasps, scores, smin=0.9, smax=1.0, alpha=0.8): + def grasps(self, frame, grasps, qmin=0.5, qmax=1.0, alpha=0.8): markers = [] - for i, (grasp, score) in enumerate(zip(grasps, scores)): - color = cmap((score - smin) / (smax - smin)) + for i, grasp in enumerate(grasps): + color = cmap((grasp.quality - qmin) / (qmax - qmin)) color = [color[0], color[1], color[2], alpha] markers += create_grasp_markers(frame, grasp, color, "grasps", 4 * i) self.grasp_marker_count = len(markers)