Fix
This commit is contained in:
parent
e78c70f173
commit
2df4f09d0b
@ -78,8 +78,13 @@ class Policy:
|
||||
def update(self, img, x, q):
|
||||
raise NotImplementedError
|
||||
|
||||
def select_best_grasp(self, grasps, qualities, q):
|
||||
filtered_grasps, scores = [], []
|
||||
def filter_grasps(self, out, q):
|
||||
grasps, qualities = select_local_maxima(
|
||||
self.tsdf.voxel_size,
|
||||
out,
|
||||
self.qual_thresh,
|
||||
)
|
||||
filtered_grasps, filtered_qualities = [], []
|
||||
for grasp, quality in zip(grasps, qualities):
|
||||
pose = self.T_base_task * grasp.pose
|
||||
R, t = pose.rotation, pose.translation
|
||||
@ -89,12 +94,13 @@ class Policy:
|
||||
q_grasp = self.solve_ee_ik(q, pose * self.T_grasp_ee)
|
||||
if q_grasp is not None:
|
||||
filtered_grasps.append(grasp)
|
||||
scores.append(self.score_fn(grasp, quality, q, q_grasp))
|
||||
i = np.argmax(scores)
|
||||
return filtered_grasps[i], qualities[i], scores[i]
|
||||
filtered_qualities.append(quality)
|
||||
return filtered_grasps, filtered_qualities
|
||||
|
||||
def score_fn(self, grasp, quality, q, q_grasp):
|
||||
return quality
|
||||
|
||||
def select_best_grasp(grasps, qualities):
|
||||
i = np.argmax(qualities)
|
||||
return grasps[i], qualities[i]
|
||||
|
||||
|
||||
class SingleViewPolicy(Policy):
|
||||
@ -109,11 +115,12 @@ class SingleViewPolicy(Policy):
|
||||
|
||||
out = self.vgn.predict(tsdf_grid)
|
||||
self.vis.quality(self.task_frame, voxel_size, out.qual, 0.5)
|
||||
grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh)
|
||||
|
||||
grasps, qualities = self.filter_grasps(out, q)
|
||||
|
||||
if len(grasps) > 0:
|
||||
self.best_grasp, qual, _ = self.select_best_grasp(grasps, qualities, q)
|
||||
self.vis.grasp(self.base_frame, self.best_grasp, qual)
|
||||
self.best_grasp, quality = select_best_grasp(grasps, qualities)
|
||||
self.vis.grasp(self.base_frame, self.best_grasp, quality)
|
||||
|
||||
self.done = True
|
||||
|
||||
@ -137,7 +144,7 @@ class MultiViewPolicy(Policy):
|
||||
self.vis.map_cloud(self.task_frame, self.tsdf.get_map_cloud())
|
||||
|
||||
with Timer("grasp_prediction"):
|
||||
tsdf_grid, voxel_size = self.tsdf.get_grid(), self.tsdf.voxel_size
|
||||
tsdf_grid = self.tsdf.get_grid()
|
||||
out = self.vgn.predict(tsdf_grid)
|
||||
self.vis.quality(self.task_frame, self.tsdf.voxel_size, out.qual, 0.9)
|
||||
|
||||
@ -145,10 +152,10 @@ class MultiViewPolicy(Policy):
|
||||
self.qual_hist[t, ...] = out.qual
|
||||
|
||||
with Timer("grasp_selection"):
|
||||
grasps, qualities = select_local_maxima(voxel_size, out, self.qual_thresh)
|
||||
grasps, qualities = self.filter_grasps(out, q)
|
||||
|
||||
if len(grasps) > 0:
|
||||
self.best_grasp, quality, _ = self.select_best_grasp(grasps, qualities, q)
|
||||
self.best_grasp, quality = select_best_grasp(grasps, qualities)
|
||||
self.vis.grasp(self.base_frame, self.best_grasp, quality)
|
||||
else:
|
||||
self.best_grasp = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user