diff --git a/policies.py b/policies.py index 1e7445b..a1982b0 100644 --- a/policies.py +++ b/policies.py @@ -57,15 +57,64 @@ class Policy: self.H_B_T = self.tf.lookup(self.base_frame_id, self.frame_id, rospy.Time.now()) rospy.Subscriber(depth_topic, Image, self.sensor_cb, queue_size=1) + vis.draw_workspace(0.3) + def sensor_cb(self, msg): self.last_depth_img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) self.last_extrinsic = self.tf.lookup( self.cam_frame_id, self.frame_id, msg.header.stamp, rospy.Duration(0.1) ) + def get_tsdf_grid(self): + map_cloud = self.tsdf.get_map_cloud() + points = np.asarray(map_cloud.points) + distances = np.asarray(map_cloud.colors)[:, 0] + return create_grid_from_map_cloud(points, distances, self.tsdf.voxel_size) + + def plan_best_grasp(self): + tsdf_grid = self.get_tsdf_grid() + out = self.vgn.predict(tsdf_grid) + grasps = compute_grasps(out, voxel_size=self.tsdf.voxel_size) + + vis.draw_tsdf(tsdf_grid, self.tsdf.voxel_size) + vis.draw_grasps(grasps, 0.05) + + # Ensure that the camera is pointing forward. + grasp = grasps[0] + rot = grasp.pose.rotation + axis = rot.as_matrix()[:, 0] + if axis[0] < 0: + grasp.pose.rotation = rot * Rotation.from_euler("z", np.pi) + + # Compute target pose of the EE + H_T_G = grasp.pose + H_B_EE = self.H_B_T * H_T_G * self.H_EE_G.inv() + return H_B_EE + class SingleViewBaseline(Policy): - pass + def __init__(sel): + super().__init__() + + def start(self): + self.done = False + + def update(self): + # Integrate image + self.tsdf.integrate( + self.last_depth_img, + self.intrinsic, + self.last_extrinsic, + ) + + # Visualize reconstruction + cloud = self.tsdf.get_scene_cloud() + vis.draw_points(np.asarray(cloud.points)) + + # Plan grasp + self.best_grasp = self.plan_best_grasp() + self.done = True + return class FixedTrajectoryBaseline(Policy): @@ -74,7 +123,6 @@ class FixedTrajectoryBaseline(Policy): self.duration = 4.0 self.radius = 0.1 self.m = scipy.interpolate.interp1d([0, self.duration], [np.pi, 3.0 * np.pi]) - vis.draw_workspace(0.3) def start(self): self.tic = rospy.Time.now() @@ -99,31 +147,7 @@ class FixedTrajectoryBaseline(Policy): vis.draw_points(np.asarray(cloud.points)) if elapsed_time > self.duration: - # Plan grasps - map_cloud = self.tsdf.get_map_cloud() - points = np.asarray(map_cloud.points) - distances = np.asarray(map_cloud.colors)[:, 0] - tsdf_grid = create_grid_from_map_cloud( - points, distances, self.tsdf.voxel_size - ) - out = self.vgn.predict(tsdf_grid) - grasps = compute_grasps(out, voxel_size=self.tsdf.voxel_size) - - # Visualize - vis.draw_tsdf(tsdf_grid, self.tsdf.voxel_size) - vis.draw_grasps(grasps, 0.05) - - # Ensure that the camera is pointing forward. - grasp = grasps[0] - rot = grasp.pose.rotation - axis = rot.as_matrix()[:, 0] - if axis[0] < 0: - grasp.pose.rotation = rot * Rotation.from_euler("z", np.pi) - - # Compute target pose of the EE - H_T_G = grasp.pose - H_B_EE = self.H_B_T * H_T_G * self.H_EE_G.inv() - self.best_grasp = H_B_EE + self.best_grasp = self.plan_best_grasp() self.done = True return