Pass arguments directly to the policy

This commit is contained in:
Michel Breyer 2021-08-06 15:23:50 +02:00
parent 6fa4007727
commit 4eeb309a8f
4 changed files with 28 additions and 25 deletions

View File

@ -13,7 +13,7 @@ class SingleView(BasePolicy):
def update(self, img, extrinsic):
self.integrate_img(img, extrinsic)
self.best_grasp = self.predict_best_grasp()
self.best_grasp = self.compute_best_grasp()
self.done = True
@ -32,7 +32,7 @@ class TopView(BasePolicy):
self.integrate_img(img, extrinsic)
error = extrinsic.translation - self.target.translation
if np.linalg.norm(error) < 0.01:
self.best_grasp = self.predict_best_grasp()
self.best_grasp = self.compute_best_grasp()
self.done = True
return self.target
@ -58,7 +58,7 @@ class RandomView(BasePolicy):
self.integrate_img(img, extrinsic)
error = extrinsic.translation - self.target.translation
if np.linalg.norm(error) < 0.01:
self.best_grasp = self.predict_best_grasp()
self.best_grasp = self.compute_best_grasp()
self.done = True
return self.target
@ -83,7 +83,7 @@ class FixedTrajectory(BasePolicy):
self.integrate_img(img, extrinsic)
elapsed_time = (rospy.Time.now() - self.tic).to_sec()
if elapsed_time > self.duration:
self.best_grasp = self.predict_best_grasp()
self.best_grasp = self.compute_best_grasp()
self.done = True
else:
t = self.m(elapsed_time)
@ -106,7 +106,7 @@ class AlignmentView(BasePolicy):
self.integrate_img(img, extrinsic)
if not self.target:
grasp = self.predict_best_grasp()
grasp = self.compute_best_grasp()
if not grasp:
self.done = True
return
@ -118,6 +118,6 @@ class AlignmentView(BasePolicy):
error = extrinsic.translation - self.target.translation
if np.linalg.norm(error) < 0.01:
self.best_grasp = self.predict_best_grasp()
self.best_grasp = self.compute_best_grasp()
self.done = True
return self.target

View File

@ -3,10 +3,9 @@ import cv_bridge
from geometry_msgs.msg import PoseStamped
import numpy as np
import rospy
from sensor_msgs.msg import CameraInfo, Image
from sensor_msgs.msg import Image
from .bbox import from_bbox_msg
from .policy import make
from .timer import Timer
from active_grasp.srv import Reset, ResetRequest
from robot_helpers.ros import tf
@ -16,19 +15,18 @@ from robot_helpers.spatial import Rotation, Transform
class GraspController:
def __init__(self, policy_id):
def __init__(self, policy):
self.policy = policy
self.reset_env = rospy.ServiceProxy("reset", Reset)
self.load_parameters()
self.lookup_transforms()
self.init_robot_connection()
self.init_camera_stream()
self.make_policy(policy_id)
def load_parameters(self):
self.base_frame = rospy.get_param("~base_frame_id")
self.ee_frame = rospy.get_param("~ee_frame_id")
self.cam_frame = rospy.get_param("~camera/frame_id")
self.info_topic = rospy.get_param("~camera/info_topic")
self.depth_topic = rospy.get_param("~camera/depth_topic")
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
@ -44,17 +42,12 @@ class GraspController:
self.target_pose_pub.publish(msg)
def init_camera_stream(self):
msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0))
self.intrinsic = from_camera_info_msg(msg)
self.cv_bridge = cv_bridge.CvBridge()
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
def sensor_cb(self, msg):
self.latest_depth_msg = msg
def make_policy(self, name):
self.policy = make(name, self.intrinsic)
def run(self):
bbox = self.reset()
with Timer("search_time"):

View File

@ -1,4 +1,5 @@
import numpy as np
from sensor_msgs.msg import CameraInfo
from pathlib import Path
import rospy
@ -19,15 +20,17 @@ class Policy:
class BasePolicy(Policy):
def __init__(self, intrinsic):
self.intrinsic = intrinsic
self.rate = 5
def __init__(self, rate=5):
self.rate = rate
self.load_parameters()
self.init_visualizer()
def load_parameters(self):
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
self.task_frame = "task"
info_topic = rospy.get_param("active_grasp/camera/info_topic")
msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0))
self.intrinsic = from_camera_info_msg(msg)
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
def init_visualizer(self):
@ -56,6 +59,9 @@ class BasePolicy(Policy):
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
self.visualizer.path(self.viewpoints)
def compute_best_grasp(self):
return self.predict_best_grasp()
def predict_best_grasp(self):
tsdf_grid = self.tsdf.get_grid()
out = self.vgn.predict(tsdf_grid)

View File

@ -6,16 +6,19 @@ import rospy
from tqdm import tqdm
from active_grasp.controller import *
from active_grasp.policy import registry
from active_grasp.policy import make, registry
from active_grasp.srv import Seed
def main():
rospy.init_node("active_grasp")
parser = create_parser()
args = parser.parse_args()
controller = GraspController(args.policy)
logger = Logger(args.logdir, args.policy)
policy = make(args.policy, args.rate)
controller = GraspController(policy)
logger = Logger(args)
seed_simulation(args.seed)
@ -29,15 +32,16 @@ def create_parser():
parser.add_argument("policy", type=str, choices=registry.keys())
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--logdir", type=Path, default="logs")
parser.add_argument("--rate", type=int, default=5)
parser.add_argument("--seed", type=int, default=12)
return parser
class Logger:
def __init__(self, logdir, policy):
def __init__(self, args):
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
name = "{}_policy={}".format(stamp, policy)
self.path = logdir / (name + ".csv")
descr = "policy={},rate={}".format(args.policy, args.rate)
self.path = args.logdir / (stamp + "_" + descr + ".csv")
def log_run(self, info):
df = pd.DataFrame.from_records([info])