Add top baseline

This commit is contained in:
Michel Breyer 2021-07-07 17:46:11 +02:00
parent 6658b8c7f0
commit 5855f67c2a
4 changed files with 68 additions and 40 deletions

View File

@ -0,0 +1,5 @@
from .policy import register
from .baselines import *
register("single-view", SingleViewBaseline)
register("top", TopBaseline)

43
active_grasp/baselines.py Normal file
View File

@ -0,0 +1,43 @@
import numpy as np
from active_grasp.policy import BasePolicy
from robot_utils.ros import tf
from vgn.utils import look_at
class SingleViewBaseline(BasePolicy):
"""
Process a single image from the initial viewpoint.
"""
def update(self):
self.integrate_latest_image()
self.draw_scene_cloud()
self.best_grasp = self.predict_best_grasp()
self.done = True
class TopBaseline(BasePolicy):
"""
Move the camera to a top-down view of the target object.
"""
def activate(self, bbox):
super().activate(bbox)
center = (bbox.min + bbox.max) / 2.0
eye = np.r_[center[:2], center[2] + 0.3]
up = np.r_[1.0, 0.0, 0.0]
self.target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv()
def update(self):
current = tf.lookup(self.base_frame, self.ee_frame)
error = current.translation - self.target.translation
if np.linalg.norm(error) < 0.01:
self.best_grasp = self.predict_best_grasp()
self.done = True
else:
self.integrate_latest_image()
self.draw_scene_cloud()
self.draw_camera_path()
return self.target

View File

@ -16,22 +16,7 @@ from vgn.perception import UniformTSDFVolume
from vgn.utils import * from vgn.utils import *
def get_policy(name): class BasePolicy:
if name == "single-view":
return SingleView()
else:
raise ValueError("{} policy does not exist.".format(name))
class Policy:
def activate(self, bbox):
raise NotImplementedError
def update(self):
raise NotImplementedError
class BasePolicy(Policy):
def __init__(self): def __init__(self):
self.cv_bridge = cv_bridge.CvBridge() self.cv_bridge = cv_bridge.CvBridge()
self.vgn = VGN(Path(rospy.get_param("vgn/model"))) self.vgn = VGN(Path(rospy.get_param("vgn/model")))
@ -42,11 +27,12 @@ class BasePolicy(Policy):
self.connect_to_camera() self.connect_to_camera()
self.connect_to_rviz() self.connect_to_rviz()
self.rate = 2 self.rate = 5
def load_parameters(self): def load_parameters(self):
self.base_frame = rospy.get_param("~base_frame_id")
self.task_frame = rospy.get_param("~frame_id") self.task_frame = rospy.get_param("~frame_id")
self.base_frame = rospy.get_param("~base_frame_id")
self.ee_frame = rospy.get_param("~ee_frame_id")
self.cam_frame = rospy.get_param("~camera/frame_id") self.cam_frame = rospy.get_param("~camera/frame_id")
self.info_topic = rospy.get_param("~camera/info_topic") self.info_topic = rospy.get_param("~camera/info_topic")
self.depth_topic = rospy.get_param("~camera/depth_topic") self.depth_topic = rospy.get_param("~camera/depth_topic")
@ -55,6 +41,7 @@ class BasePolicy(Policy):
tf._init_listener() tf._init_listener()
rospy.sleep(1.0) # wait to receive transforms rospy.sleep(1.0) # wait to receive transforms
self.T_B_task = tf.lookup(self.base_frame, self.task_frame) self.T_B_task = tf.lookup(self.base_frame, self.task_frame)
self.T_EE_cam = tf.lookup(self.ee_frame, self.cam_frame)
def connect_to_camera(self): def connect_to_camera(self):
msg = rospy.wait_for_message( msg = rospy.wait_for_message(
@ -160,13 +147,16 @@ class BasePolicy(Policy):
self.path_pub.publish(MarkerArray([spheres, lines])) self.path_pub.publish(MarkerArray([spheres, lines]))
class SingleView(BasePolicy): registry = {}
"""
Process a single image from the initial viewpoint.
"""
def update(self):
self.integrate_latest_image() def register(id, cls):
self.draw_scene_cloud() global registry
self.best_grasp = self.predict_best_grasp() registry[id] = cls
self.done = True
def make(id):
if id in registry:
return registry[id]()
else:
raise ValueError("{} policy does not exist.".format(id))

View File

@ -2,22 +2,12 @@ import argparse
import rospy import rospy
from active_grasp.controller import GraspController from active_grasp.controller import GraspController
from active_grasp.policies import get_policy from active_grasp.policy import make, registry
def create_parser(): def create_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--policy", type=str, choices=registry.keys())
"--policy",
type=str,
choices=[
"single-view",
"top",
"alignment",
"random",
"fixed-trajectory",
],
)
return parser return parser
@ -25,7 +15,7 @@ def main():
rospy.init_node("active_grasp") rospy.init_node("active_grasp")
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
policy = get_policy(args.policy) policy = make(args.policy)
controller = GraspController(policy) controller = GraspController(policy)
while True: while True: