diff --git a/active_grasp/__init__.py b/active_grasp/__init__.py index 4e14257..e7a1e5b 100644 --- a/active_grasp/__init__.py +++ b/active_grasp/__init__.py @@ -3,3 +3,4 @@ from .baselines import * register("single-view", SingleViewBaseline) register("top", TopBaseline) +register("fixed-trajectory", FixedTrajectoryBaseline) diff --git a/active_grasp/baselines.py b/active_grasp/baselines.py index b297b72..50937e2 100644 --- a/active_grasp/baselines.py +++ b/active_grasp/baselines.py @@ -1,4 +1,6 @@ import numpy as np +import scipy.interpolate +import rospy from active_grasp.policy import BasePolicy from robot_utils.ros import tf @@ -41,3 +43,38 @@ class TopBaseline(BasePolicy): self.draw_scene_cloud() self.draw_camera_path() return self.target + + +class FixedTrajectoryBaseline(BasePolicy): + """ + Follow a pre-defined circular trajectory centered above the target object. + """ + + def __init__(self): + super().__init__() + self.r = 0.06 + self.h = 0.3 + self.duration = 6.0 + self.m = scipy.interpolate.interp1d([0, self.duration], [np.pi, 3.0 * np.pi]) + + def activate(self, bbox): + super().activate(bbox) + self.tic = rospy.Time.now() + self.circle_center = (bbox.min + bbox.max) / 2.0 + self.circle_center[2] += self.h + + def update(self): + elapsed_time = (rospy.Time.now() - self.tic).to_sec() + if elapsed_time > self.duration: + self.best_grasp = self.predict_best_grasp() + self.done = True + else: + self.integrate_latest_image() + t = self.m(elapsed_time) + eye = self.circle_center + np.r_[self.r * np.cos(t), self.r * np.sin(t), 0] + center = (self.bbox.min + self.bbox.max) / 2.0 + up = np.r_[1.0, 0.0, 0.0] + target = self.T_B_task * (self.T_EE_cam * look_at(eye, center, up)).inv() + self.draw_scene_cloud() + self.draw_camera_path() + return target diff --git a/active_grasp/simulation.py b/active_grasp/simulation.py index ac53370..9c9d85b 100644 --- a/active_grasp/simulation.py +++ b/active_grasp/simulation.py @@ -33,7 +33,7 @@ class Simulation(BtSim): ori = Rotation.from_rotvec(np.array([0, 0, np.pi / 2])).as_quat() p.loadURDF("table/table.urdf", baseOrientation=ori, useFixedBase=True) self.length = 0.3 - self.origin = [-0.3, -0.5 * self.length, 0.5] + self.origin = [-0.35, -0.5 * self.length, 0.5] def load_robot(self): self.T_W_B = Transform(Rotation.identity(), np.r_[-0.6, 0.0, 0.4])