Pass arguments directly to the policy
This commit is contained in:
parent
6fa4007727
commit
4eeb309a8f
@ -13,7 +13,7 @@ class SingleView(BasePolicy):
|
||||
|
||||
def update(self, img, extrinsic):
|
||||
self.integrate_img(img, extrinsic)
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ class TopView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
||||
@ -58,7 +58,7 @@ class RandomView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
||||
@ -83,7 +83,7 @@ class FixedTrajectory(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
elapsed_time = (rospy.Time.now() - self.tic).to_sec()
|
||||
if elapsed_time > self.duration:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
else:
|
||||
t = self.m(elapsed_time)
|
||||
@ -106,7 +106,7 @@ class AlignmentView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
|
||||
if not self.target:
|
||||
grasp = self.predict_best_grasp()
|
||||
grasp = self.compute_best_grasp()
|
||||
if not grasp:
|
||||
self.done = True
|
||||
return
|
||||
@ -118,6 +118,6 @@ class AlignmentView(BasePolicy):
|
||||
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
@ -3,10 +3,9 @@ import cv_bridge
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import CameraInfo, Image
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
from .bbox import from_bbox_msg
|
||||
from .policy import make
|
||||
from .timer import Timer
|
||||
from active_grasp.srv import Reset, ResetRequest
|
||||
from robot_helpers.ros import tf
|
||||
@ -16,19 +15,18 @@ from robot_helpers.spatial import Rotation, Transform
|
||||
|
||||
|
||||
class GraspController:
|
||||
def __init__(self, policy_id):
|
||||
def __init__(self, policy):
|
||||
self.policy = policy
|
||||
self.reset_env = rospy.ServiceProxy("reset", Reset)
|
||||
self.load_parameters()
|
||||
self.lookup_transforms()
|
||||
self.init_robot_connection()
|
||||
self.init_camera_stream()
|
||||
self.make_policy(policy_id)
|
||||
|
||||
def load_parameters(self):
|
||||
self.base_frame = rospy.get_param("~base_frame_id")
|
||||
self.ee_frame = rospy.get_param("~ee_frame_id")
|
||||
self.cam_frame = rospy.get_param("~camera/frame_id")
|
||||
self.info_topic = rospy.get_param("~camera/info_topic")
|
||||
self.depth_topic = rospy.get_param("~camera/depth_topic")
|
||||
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
|
||||
|
||||
@ -44,17 +42,12 @@ class GraspController:
|
||||
self.target_pose_pub.publish(msg)
|
||||
|
||||
def init_camera_stream(self):
|
||||
msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0))
|
||||
self.intrinsic = from_camera_info_msg(msg)
|
||||
self.cv_bridge = cv_bridge.CvBridge()
|
||||
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
|
||||
|
||||
def sensor_cb(self, msg):
|
||||
self.latest_depth_msg = msg
|
||||
|
||||
def make_policy(self, name):
|
||||
self.policy = make(name, self.intrinsic)
|
||||
|
||||
def run(self):
|
||||
bbox = self.reset()
|
||||
with Timer("search_time"):
|
||||
|
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from sensor_msgs.msg import CameraInfo
|
||||
from pathlib import Path
|
||||
import rospy
|
||||
|
||||
@ -19,15 +20,17 @@ class Policy:
|
||||
|
||||
|
||||
class BasePolicy(Policy):
|
||||
def __init__(self, intrinsic):
|
||||
self.intrinsic = intrinsic
|
||||
self.rate = 5
|
||||
def __init__(self, rate=5):
|
||||
self.rate = rate
|
||||
self.load_parameters()
|
||||
self.init_visualizer()
|
||||
|
||||
def load_parameters(self):
|
||||
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
|
||||
self.task_frame = "task"
|
||||
info_topic = rospy.get_param("active_grasp/camera/info_topic")
|
||||
msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0))
|
||||
self.intrinsic = from_camera_info_msg(msg)
|
||||
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
|
||||
|
||||
def init_visualizer(self):
|
||||
@ -56,6 +59,9 @@ class BasePolicy(Policy):
|
||||
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
|
||||
self.visualizer.path(self.viewpoints)
|
||||
|
||||
def compute_best_grasp(self):
|
||||
return self.predict_best_grasp()
|
||||
|
||||
def predict_best_grasp(self):
|
||||
tsdf_grid = self.tsdf.get_grid()
|
||||
out = self.vgn.predict(tsdf_grid)
|
||||
|
@ -6,16 +6,19 @@ import rospy
|
||||
from tqdm import tqdm
|
||||
|
||||
from active_grasp.controller import *
|
||||
from active_grasp.policy import registry
|
||||
from active_grasp.policy import make, registry
|
||||
from active_grasp.srv import Seed
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
controller = GraspController(args.policy)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
policy = make(args.policy, args.rate)
|
||||
controller = GraspController(policy)
|
||||
logger = Logger(args)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
@ -29,15 +32,16 @@ def create_parser():
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--rate", type=int, default=5)
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, logdir, policy):
|
||||
def __init__(self, args):
|
||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
name = "{}_policy={}".format(stamp, policy)
|
||||
self.path = logdir / (name + ".csv")
|
||||
descr = "policy={},rate={}".format(args.policy, args.rate)
|
||||
self.path = args.logdir / (stamp + "_" + descr + ".csv")
|
||||
|
||||
def log_run(self, info):
|
||||
df = pd.DataFrame.from_records([info])
|
||||
|
Loading…
x
Reference in New Issue
Block a user