From 5855f67c2a97756754b10a849e21be249e7bde5f Mon Sep 17 00:00:00 2001 From: Michel Breyer Date: Wed, 7 Jul 2021 17:46:11 +0200 Subject: [PATCH] Add top baseline --- active_grasp/__init__.py | 5 +++ active_grasp/baselines.py | 43 ++++++++++++++++++++++++ active_grasp/{policies.py => policy.py} | 44 ++++++++++--------------- scripts/run.py | 16 ++------- 4 files changed, 68 insertions(+), 40 deletions(-) create mode 100644 active_grasp/baselines.py rename active_grasp/{policies.py => policy.py} (89%) diff --git a/active_grasp/__init__.py b/active_grasp/__init__.py index e69de29..4e14257 100644 --- a/active_grasp/__init__.py +++ b/active_grasp/__init__.py @@ -0,0 +1,5 @@ +from .policy import register +from .baselines import * + +register("single-view", SingleViewBaseline) +register("top", TopBaseline) diff --git a/active_grasp/baselines.py b/active_grasp/baselines.py new file mode 100644 index 0000000..b297b72 --- /dev/null +++ b/active_grasp/baselines.py @@ -0,0 +1,43 @@ +import numpy as np + +from active_grasp.policy import BasePolicy +from robot_utils.ros import tf +from vgn.utils import look_at + + +class SingleViewBaseline(BasePolicy): + """ + Process a single image from the initial viewpoint. + """ + + def update(self): + self.integrate_latest_image() + self.draw_scene_cloud() + self.best_grasp = self.predict_best_grasp() + self.done = True + + +class TopBaseline(BasePolicy): + """ + Move the camera to a top-down view of the target object. + """ + + def activate(self, bbox): + super().activate(bbox) + center = (bbox.min + bbox.max) / 2.0 + eye = np.r_[center[:2], center[2] + 0.3] + up = np.r_[1.0, 0.0, 0.0] + self.target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv() + + def update(self): + current = tf.lookup(self.base_frame, self.ee_frame) + error = current.translation - self.target.translation + + if np.linalg.norm(error) < 0.01: + self.best_grasp = self.predict_best_grasp() + self.done = True + else: + self.integrate_latest_image() + self.draw_scene_cloud() + self.draw_camera_path() + return self.target diff --git a/active_grasp/policies.py b/active_grasp/policy.py similarity index 89% rename from active_grasp/policies.py rename to active_grasp/policy.py index 92c41b3..d9c2045 100644 --- a/active_grasp/policies.py +++ b/active_grasp/policy.py @@ -16,22 +16,7 @@ from vgn.perception import UniformTSDFVolume from vgn.utils import * -def get_policy(name): - if name == "single-view": - return SingleView() - else: - raise ValueError("{} policy does not exist.".format(name)) - - -class Policy: - def activate(self, bbox): - raise NotImplementedError - - def update(self): - raise NotImplementedError - - -class BasePolicy(Policy): +class BasePolicy: def __init__(self): self.cv_bridge = cv_bridge.CvBridge() self.vgn = VGN(Path(rospy.get_param("vgn/model"))) @@ -42,11 +27,12 @@ class BasePolicy(Policy): self.connect_to_camera() self.connect_to_rviz() - self.rate = 2 + self.rate = 5 def load_parameters(self): - self.base_frame = rospy.get_param("~base_frame_id") self.task_frame = rospy.get_param("~frame_id") + self.base_frame = rospy.get_param("~base_frame_id") + self.ee_frame = rospy.get_param("~ee_frame_id") self.cam_frame = rospy.get_param("~camera/frame_id") self.info_topic = rospy.get_param("~camera/info_topic") self.depth_topic = rospy.get_param("~camera/depth_topic") @@ -55,6 +41,7 @@ class BasePolicy(Policy): tf._init_listener() rospy.sleep(1.0) # wait to receive transforms self.T_B_task = tf.lookup(self.base_frame, self.task_frame) + self.T_EE_cam = tf.lookup(self.ee_frame, self.cam_frame) def connect_to_camera(self): msg = rospy.wait_for_message( @@ -160,13 +147,16 @@ class BasePolicy(Policy): self.path_pub.publish(MarkerArray([spheres, lines])) -class SingleView(BasePolicy): - """ - Process a single image from the initial viewpoint. - """ +registry = {} - def update(self): - self.integrate_latest_image() - self.draw_scene_cloud() - self.best_grasp = self.predict_best_grasp() - self.done = True + +def register(id, cls): + global registry + registry[id] = cls + + +def make(id): + if id in registry: + return registry[id]() + else: + raise ValueError("{} policy does not exist.".format(id)) diff --git a/scripts/run.py b/scripts/run.py index 31f5689..f05792d 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -2,22 +2,12 @@ import argparse import rospy from active_grasp.controller import GraspController -from active_grasp.policies import get_policy +from active_grasp.policy import make, registry def create_parser(): parser = argparse.ArgumentParser() - parser.add_argument( - "--policy", - type=str, - choices=[ - "single-view", - "top", - "alignment", - "random", - "fixed-trajectory", - ], - ) + parser.add_argument("--policy", type=str, choices=registry.keys()) return parser @@ -25,7 +15,7 @@ def main(): rospy.init_node("active_grasp") parser = create_parser() args = parser.parse_args() - policy = get_policy(args.policy) + policy = make(args.policy) controller = GraspController(policy) while True: