diff --git a/active_grasp/baselines.py b/active_grasp/baselines.py index 18ffb0c..4e74a51 100644 --- a/active_grasp/baselines.py +++ b/active_grasp/baselines.py @@ -17,6 +17,7 @@ class TopView(SingleViewPolicy): eye = np.r_[self.center[:2], self.bbox.max[2] + self.min_z_dist] up = np.r_[1.0, 0.0, 0.0] self.x_d = look_at(eye, self.center, up) + self.done = False if self.is_view_feasible(self.x_d) else True class TopTrajectory(MultiViewPolicy): @@ -25,6 +26,7 @@ class TopTrajectory(MultiViewPolicy): eye = np.r_[self.center[:2], self.bbox.max[2] + self.min_z_dist] up = np.r_[1.0, 0.0, 0.0] self.x_d = look_at(eye, self.center, up) + self.done = False if self.is_view_feasible(self.x_d) else True def update(self, img, x): self.integrate(img, x) diff --git a/active_grasp/controller.py b/active_grasp/controller.py index 73e48b8..c7f9f01 100644 --- a/active_grasp/controller.py +++ b/active_grasp/controller.py @@ -20,7 +20,6 @@ class GraspController: def __init__(self, policy): self.policy = policy self.load_parameters() - self.lookup_transforms() self.init_service_proxies() self.init_robot_connection() self.init_moveit() @@ -33,10 +32,6 @@ class GraspController: self.depth_topic = rospy.get_param("~camera/depth_topic") self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv() - def lookup_transforms(self): - tf.init() - self.T_ee_cam = tf.lookup(self.ee_frame, self.cam_frame) - def init_service_proxies(self): self.reset_env = rospy.ServiceProxy("reset", Reset) self.switch_controller = rospy.ServiceProxy( @@ -52,6 +47,7 @@ class GraspController: rospy.sleep(1.0) # wait for connections to be established # msg = to_pose_stamped_msg(Transform.t([0.4, 0, 0.4]), self.base_frame) # self.moveit.scene.add_box("table", msg, size=(0.5, 0.5, 0.02)) + self.policy.moveit = self.moveit def switch_to_cartesian_velocity_control(self): req = SwitchControllerRequest() diff --git a/active_grasp/policy.py b/active_grasp/policy.py index 5a00a70..ee1f7c8 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -16,9 +16,11 @@ class Policy: self.rate = rate self.load_parameters() self.init_visualizer() + self.lookup_transforms() def load_parameters(self): self.base_frame = rospy.get_param("~base_frame_id") + self.cam_frame = rospy.get_param("~camera/frame_id") info_topic = rospy.get_param("~camera/info_topic") self.linear_vel = rospy.get_param("~linear_vel") self.min_z_dist = rospy.get_param("~camera/min_z_dist") @@ -31,6 +33,9 @@ class Policy: def init_visualizer(self): self.vis = Visualizer() + def lookup_transforms(self): + self.T_cam_ee = tf.lookup(self.cam_frame, "panda_link8") + def activate(self, bbox): self.bbox = bbox self.vis.clear() @@ -50,17 +55,8 @@ class Policy: tf.broadcast(self.T_base_task, self.base_frame, self.task_frame) rospy.sleep(0.5) - def compute_error(self, x_d, x): - linear = x_d.translation - x.translation - angular = (x_d.rotation * x.rotation.inv()).as_rotvec() - return linear, angular - - def compute_velocity_cmd(self, linear, angular): - kp = 4.0 - linear = kp * linear - scale = np.linalg.norm(linear) - linear *= np.clip(scale, 0.0, self.linear_vel) / scale - return np.r_[linear, angular] + def score_fn(self, grasp): + return grasp.quality def sort_grasps(self, in_grasps): # Transforms grasps into base frame, checks whether they lie on the target, and sorts by their score @@ -85,12 +81,26 @@ class Policy: indices = np.argsort(-scores) return grasps[indices], scores[indices] - def score_fn(self, grasp): - return grasp.quality - def update(self, img, pose): raise NotImplementedError + def is_view_feasible(self, view): + # Check whether MoveIt can find a trajectory to the given view + success, _ = self.moveit.plan(view * self.T_cam_ee) + return success + + def compute_error(self, x_d, x): + linear = x_d.translation - x.translation + angular = (x_d.rotation * x.rotation.inv()).as_rotvec() + return linear, angular + + def compute_velocity_cmd(self, linear, angular): + kp = 4.0 + linear = kp * linear + scale = np.linalg.norm(linear) + linear *= np.clip(scale, 0.0, self.linear_vel) / scale + return np.r_[linear, angular] + class SingleViewPolicy(Policy): def update(self, img, x): diff --git a/scripts/run.py b/scripts/run.py index 4951f63..5f6b3b1 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -8,10 +8,12 @@ from tqdm import tqdm from active_grasp.controller import * from active_grasp.policy import make, registry from active_grasp.srv import Seed +from robot_helpers.ros import tf def main(): rospy.init_node("grasp_controller") + tf.init() parser = create_parser() args = parser.parse_args()