diff --git a/active_grasp/baselines.py b/active_grasp/baselines.py index cec0141..ea16176 100644 --- a/active_grasp/baselines.py +++ b/active_grasp/baselines.py @@ -13,7 +13,7 @@ class SingleView(BasePolicy): def update(self, img, extrinsic): self.integrate_img(img, extrinsic) - self.best_grasp = self.predict_best_grasp() + self.best_grasp = self.compute_best_grasp() self.done = True @@ -32,7 +32,7 @@ class TopView(BasePolicy): self.integrate_img(img, extrinsic) error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self.predict_best_grasp() + self.best_grasp = self.compute_best_grasp() self.done = True return self.target @@ -58,7 +58,7 @@ class RandomView(BasePolicy): self.integrate_img(img, extrinsic) error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self.predict_best_grasp() + self.best_grasp = self.compute_best_grasp() self.done = True return self.target @@ -83,7 +83,7 @@ class FixedTrajectory(BasePolicy): self.integrate_img(img, extrinsic) elapsed_time = (rospy.Time.now() - self.tic).to_sec() if elapsed_time > self.duration: - self.best_grasp = self.predict_best_grasp() + self.best_grasp = self.compute_best_grasp() self.done = True else: t = self.m(elapsed_time) @@ -106,7 +106,7 @@ class AlignmentView(BasePolicy): self.integrate_img(img, extrinsic) if not self.target: - grasp = self.predict_best_grasp() + grasp = self.compute_best_grasp() if not grasp: self.done = True return @@ -118,6 +118,6 @@ class AlignmentView(BasePolicy): error = extrinsic.translation - self.target.translation if np.linalg.norm(error) < 0.01: - self.best_grasp = self.predict_best_grasp() + self.best_grasp = self.compute_best_grasp() self.done = True return self.target diff --git a/active_grasp/controller.py b/active_grasp/controller.py index abf937f..5519607 100644 --- a/active_grasp/controller.py +++ b/active_grasp/controller.py @@ -3,10 +3,9 @@ import cv_bridge from geometry_msgs.msg import PoseStamped import numpy as np import rospy -from sensor_msgs.msg import CameraInfo, Image +from sensor_msgs.msg import Image from .bbox import from_bbox_msg -from .policy import make from .timer import Timer from active_grasp.srv import Reset, ResetRequest from robot_helpers.ros import tf @@ -16,19 +15,18 @@ from robot_helpers.spatial import Rotation, Transform class GraspController: - def __init__(self, policy_id): + def __init__(self, policy): + self.policy = policy self.reset_env = rospy.ServiceProxy("reset", Reset) self.load_parameters() self.lookup_transforms() self.init_robot_connection() self.init_camera_stream() - self.make_policy(policy_id) def load_parameters(self): self.base_frame = rospy.get_param("~base_frame_id") self.ee_frame = rospy.get_param("~ee_frame_id") self.cam_frame = rospy.get_param("~camera/frame_id") - self.info_topic = rospy.get_param("~camera/info_topic") self.depth_topic = rospy.get_param("~camera/depth_topic") self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv() @@ -44,17 +42,12 @@ class GraspController: self.target_pose_pub.publish(msg) def init_camera_stream(self): - msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0)) - self.intrinsic = from_camera_info_msg(msg) self.cv_bridge = cv_bridge.CvBridge() rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1) def sensor_cb(self, msg): self.latest_depth_msg = msg - def make_policy(self, name): - self.policy = make(name, self.intrinsic) - def run(self): bbox = self.reset() with Timer("search_time"): diff --git a/active_grasp/policy.py b/active_grasp/policy.py index 49e3c3f..9fc1e7f 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -1,4 +1,5 @@ import numpy as np +from sensor_msgs.msg import CameraInfo from pathlib import Path import rospy @@ -19,15 +20,17 @@ class Policy: class BasePolicy(Policy): - def __init__(self, intrinsic): - self.intrinsic = intrinsic - self.rate = 5 + def __init__(self, rate=5): + self.rate = rate self.load_parameters() self.init_visualizer() def load_parameters(self): self.base_frame = rospy.get_param("active_grasp/base_frame_id") self.task_frame = "task" + info_topic = rospy.get_param("active_grasp/camera/info_topic") + msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0)) + self.intrinsic = from_camera_info_msg(msg) self.vgn = VGN(Path(rospy.get_param("vgn/model"))) def init_visualizer(self): @@ -56,6 +59,9 @@ class BasePolicy(Policy): self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud()) self.visualizer.path(self.viewpoints) + def compute_best_grasp(self): + return self.predict_best_grasp() + def predict_best_grasp(self): tsdf_grid = self.tsdf.get_grid() out = self.vgn.predict(tsdf_grid) diff --git a/scripts/run.py b/scripts/run.py index 1870b0b..9a2ccbd 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -6,16 +6,19 @@ import rospy from tqdm import tqdm from active_grasp.controller import * -from active_grasp.policy import registry +from active_grasp.policy import make, registry from active_grasp.srv import Seed def main(): rospy.init_node("active_grasp") + parser = create_parser() args = parser.parse_args() - controller = GraspController(args.policy) - logger = Logger(args.logdir, args.policy) + + policy = make(args.policy, args.rate) + controller = GraspController(policy) + logger = Logger(args) seed_simulation(args.seed) @@ -29,15 +32,16 @@ def create_parser(): parser.add_argument("policy", type=str, choices=registry.keys()) parser.add_argument("--runs", type=int, default=10) parser.add_argument("--logdir", type=Path, default="logs") + parser.add_argument("--rate", type=int, default=5) parser.add_argument("--seed", type=int, default=12) return parser class Logger: - def __init__(self, logdir, policy): + def __init__(self, args): stamp = datetime.now().strftime("%y%m%d-%H%M%S") - name = "{}_policy={}".format(stamp, policy) - self.path = logdir / (name + ".csv") + descr = "policy={},rate={}".format(args.policy, args.rate) + self.path = args.logdir / (stamp + "_" + descr + ".csv") def log_run(self, info): df = pd.DataFrame.from_records([info])