This commit is contained in:
Michel Breyer 2021-11-08 14:39:44 +01:00
parent e78c70f173
commit 2df4f09d0b

View File

@ -78,8 +78,13 @@ class Policy:
def update(self, img, x, q):
raise NotImplementedError
def select_best_grasp(self, grasps, qualities, q):
filtered_grasps, scores = [], []
def filter_grasps(self, out, q):
grasps, qualities = select_local_maxima(
self.tsdf.voxel_size,
out,
self.qual_thresh,
)
filtered_grasps, filtered_qualities = [], []
for grasp, quality in zip(grasps, qualities):
pose = self.T_base_task * grasp.pose
R, t = pose.rotation, pose.translation
@ -89,12 +94,13 @@ class Policy:
q_grasp = self.solve_ee_ik(q, pose * self.T_grasp_ee)
if q_grasp is not None:
filtered_grasps.append(grasp)
scores.append(self.score_fn(grasp, quality, q, q_grasp))
i = np.argmax(scores)
return filtered_grasps[i], qualities[i], scores[i]
filtered_qualities.append(quality)
return filtered_grasps, filtered_qualities
def score_fn(self, grasp, quality, q, q_grasp):
return quality
def select_best_grasp(grasps, qualities):
i = np.argmax(qualities)
return grasps[i], qualities[i]
class SingleViewPolicy(Policy):
@ -109,11 +115,12 @@ class SingleViewPolicy(Policy):
out = self.vgn.predict(tsdf_grid)
self.vis.quality(self.task_frame, voxel_size, out.qual, 0.5)
grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh)
grasps, qualities = self.filter_grasps(out, q)
if len(grasps) > 0:
self.best_grasp, qual, _ = self.select_best_grasp(grasps, qualities, q)
self.vis.grasp(self.base_frame, self.best_grasp, qual)
self.best_grasp, quality = select_best_grasp(grasps, qualities)
self.vis.grasp(self.base_frame, self.best_grasp, quality)
self.done = True
@ -137,7 +144,7 @@ class MultiViewPolicy(Policy):
self.vis.map_cloud(self.task_frame, self.tsdf.get_map_cloud())
with Timer("grasp_prediction"):
tsdf_grid, voxel_size = self.tsdf.get_grid(), self.tsdf.voxel_size
tsdf_grid = self.tsdf.get_grid()
out = self.vgn.predict(tsdf_grid)
self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.9)
@ -145,10 +152,10 @@ class MultiViewPolicy(Policy):
self.qual_hist[t, ...] = out.qual
with Timer("grasp_selection"):
grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh)
grasps, qualities = self.filter_grasps(out, q)
if len(grasps) > 0:
self.best_grasp, quality, _ = self.select_best_grasp(grasps, qualities, q)
self.best_grasp, quality = select_best_grasp(grasps, qualities)
self.vis.grasp(self.base_frame, self.best_grasp, quality)
else:
self.best_grasp = None