Pass arguments directly to the policy
This commit is contained in:
parent
6fa4007727
commit
4eeb309a8f
@ -13,7 +13,7 @@ class SingleView(BasePolicy):
|
|||||||
|
|
||||||
def update(self, img, extrinsic):
|
def update(self, img, extrinsic):
|
||||||
self.integrate_img(img, extrinsic)
|
self.integrate_img(img, extrinsic)
|
||||||
self.best_grasp = self.predict_best_grasp()
|
self.best_grasp = self.compute_best_grasp()
|
||||||
self.done = True
|
self.done = True
|
||||||
|
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class TopView(BasePolicy):
|
|||||||
self.integrate_img(img, extrinsic)
|
self.integrate_img(img, extrinsic)
|
||||||
error = extrinsic.translation - self.target.translation
|
error = extrinsic.translation - self.target.translation
|
||||||
if np.linalg.norm(error) < 0.01:
|
if np.linalg.norm(error) < 0.01:
|
||||||
self.best_grasp = self.predict_best_grasp()
|
self.best_grasp = self.compute_best_grasp()
|
||||||
self.done = True
|
self.done = True
|
||||||
return self.target
|
return self.target
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class RandomView(BasePolicy):
|
|||||||
self.integrate_img(img, extrinsic)
|
self.integrate_img(img, extrinsic)
|
||||||
error = extrinsic.translation - self.target.translation
|
error = extrinsic.translation - self.target.translation
|
||||||
if np.linalg.norm(error) < 0.01:
|
if np.linalg.norm(error) < 0.01:
|
||||||
self.best_grasp = self.predict_best_grasp()
|
self.best_grasp = self.compute_best_grasp()
|
||||||
self.done = True
|
self.done = True
|
||||||
return self.target
|
return self.target
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ class FixedTrajectory(BasePolicy):
|
|||||||
self.integrate_img(img, extrinsic)
|
self.integrate_img(img, extrinsic)
|
||||||
elapsed_time = (rospy.Time.now() - self.tic).to_sec()
|
elapsed_time = (rospy.Time.now() - self.tic).to_sec()
|
||||||
if elapsed_time > self.duration:
|
if elapsed_time > self.duration:
|
||||||
self.best_grasp = self.predict_best_grasp()
|
self.best_grasp = self.compute_best_grasp()
|
||||||
self.done = True
|
self.done = True
|
||||||
else:
|
else:
|
||||||
t = self.m(elapsed_time)
|
t = self.m(elapsed_time)
|
||||||
@ -106,7 +106,7 @@ class AlignmentView(BasePolicy):
|
|||||||
self.integrate_img(img, extrinsic)
|
self.integrate_img(img, extrinsic)
|
||||||
|
|
||||||
if not self.target:
|
if not self.target:
|
||||||
grasp = self.predict_best_grasp()
|
grasp = self.compute_best_grasp()
|
||||||
if not grasp:
|
if not grasp:
|
||||||
self.done = True
|
self.done = True
|
||||||
return
|
return
|
||||||
@ -118,6 +118,6 @@ class AlignmentView(BasePolicy):
|
|||||||
|
|
||||||
error = extrinsic.translation - self.target.translation
|
error = extrinsic.translation - self.target.translation
|
||||||
if np.linalg.norm(error) < 0.01:
|
if np.linalg.norm(error) < 0.01:
|
||||||
self.best_grasp = self.predict_best_grasp()
|
self.best_grasp = self.compute_best_grasp()
|
||||||
self.done = True
|
self.done = True
|
||||||
return self.target
|
return self.target
|
||||||
|
@ -3,10 +3,9 @@ import cv_bridge
|
|||||||
from geometry_msgs.msg import PoseStamped
|
from geometry_msgs.msg import PoseStamped
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rospy
|
import rospy
|
||||||
from sensor_msgs.msg import CameraInfo, Image
|
from sensor_msgs.msg import Image
|
||||||
|
|
||||||
from .bbox import from_bbox_msg
|
from .bbox import from_bbox_msg
|
||||||
from .policy import make
|
|
||||||
from .timer import Timer
|
from .timer import Timer
|
||||||
from active_grasp.srv import Reset, ResetRequest
|
from active_grasp.srv import Reset, ResetRequest
|
||||||
from robot_helpers.ros import tf
|
from robot_helpers.ros import tf
|
||||||
@ -16,19 +15,18 @@ from robot_helpers.spatial import Rotation, Transform
|
|||||||
|
|
||||||
|
|
||||||
class GraspController:
|
class GraspController:
|
||||||
def __init__(self, policy_id):
|
def __init__(self, policy):
|
||||||
|
self.policy = policy
|
||||||
self.reset_env = rospy.ServiceProxy("reset", Reset)
|
self.reset_env = rospy.ServiceProxy("reset", Reset)
|
||||||
self.load_parameters()
|
self.load_parameters()
|
||||||
self.lookup_transforms()
|
self.lookup_transforms()
|
||||||
self.init_robot_connection()
|
self.init_robot_connection()
|
||||||
self.init_camera_stream()
|
self.init_camera_stream()
|
||||||
self.make_policy(policy_id)
|
|
||||||
|
|
||||||
def load_parameters(self):
|
def load_parameters(self):
|
||||||
self.base_frame = rospy.get_param("~base_frame_id")
|
self.base_frame = rospy.get_param("~base_frame_id")
|
||||||
self.ee_frame = rospy.get_param("~ee_frame_id")
|
self.ee_frame = rospy.get_param("~ee_frame_id")
|
||||||
self.cam_frame = rospy.get_param("~camera/frame_id")
|
self.cam_frame = rospy.get_param("~camera/frame_id")
|
||||||
self.info_topic = rospy.get_param("~camera/info_topic")
|
|
||||||
self.depth_topic = rospy.get_param("~camera/depth_topic")
|
self.depth_topic = rospy.get_param("~camera/depth_topic")
|
||||||
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
|
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
|
||||||
|
|
||||||
@ -44,17 +42,12 @@ class GraspController:
|
|||||||
self.target_pose_pub.publish(msg)
|
self.target_pose_pub.publish(msg)
|
||||||
|
|
||||||
def init_camera_stream(self):
|
def init_camera_stream(self):
|
||||||
msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0))
|
|
||||||
self.intrinsic = from_camera_info_msg(msg)
|
|
||||||
self.cv_bridge = cv_bridge.CvBridge()
|
self.cv_bridge = cv_bridge.CvBridge()
|
||||||
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
|
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
|
||||||
|
|
||||||
def sensor_cb(self, msg):
|
def sensor_cb(self, msg):
|
||||||
self.latest_depth_msg = msg
|
self.latest_depth_msg = msg
|
||||||
|
|
||||||
def make_policy(self, name):
|
|
||||||
self.policy = make(name, self.intrinsic)
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
bbox = self.reset()
|
bbox = self.reset()
|
||||||
with Timer("search_time"):
|
with Timer("search_time"):
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
from sensor_msgs.msg import CameraInfo
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import rospy
|
import rospy
|
||||||
|
|
||||||
@ -19,15 +20,17 @@ class Policy:
|
|||||||
|
|
||||||
|
|
||||||
class BasePolicy(Policy):
|
class BasePolicy(Policy):
|
||||||
def __init__(self, intrinsic):
|
def __init__(self, rate=5):
|
||||||
self.intrinsic = intrinsic
|
self.rate = rate
|
||||||
self.rate = 5
|
|
||||||
self.load_parameters()
|
self.load_parameters()
|
||||||
self.init_visualizer()
|
self.init_visualizer()
|
||||||
|
|
||||||
def load_parameters(self):
|
def load_parameters(self):
|
||||||
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
|
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
|
||||||
self.task_frame = "task"
|
self.task_frame = "task"
|
||||||
|
info_topic = rospy.get_param("active_grasp/camera/info_topic")
|
||||||
|
msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0))
|
||||||
|
self.intrinsic = from_camera_info_msg(msg)
|
||||||
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
|
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
|
||||||
|
|
||||||
def init_visualizer(self):
|
def init_visualizer(self):
|
||||||
@ -56,6 +59,9 @@ class BasePolicy(Policy):
|
|||||||
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
|
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
|
||||||
self.visualizer.path(self.viewpoints)
|
self.visualizer.path(self.viewpoints)
|
||||||
|
|
||||||
|
def compute_best_grasp(self):
|
||||||
|
return self.predict_best_grasp()
|
||||||
|
|
||||||
def predict_best_grasp(self):
|
def predict_best_grasp(self):
|
||||||
tsdf_grid = self.tsdf.get_grid()
|
tsdf_grid = self.tsdf.get_grid()
|
||||||
out = self.vgn.predict(tsdf_grid)
|
out = self.vgn.predict(tsdf_grid)
|
||||||
|
@ -6,16 +6,19 @@ import rospy
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from active_grasp.controller import *
|
from active_grasp.controller import *
|
||||||
from active_grasp.policy import registry
|
from active_grasp.policy import make, registry
|
||||||
from active_grasp.srv import Seed
|
from active_grasp.srv import Seed
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
rospy.init_node("active_grasp")
|
rospy.init_node("active_grasp")
|
||||||
|
|
||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
controller = GraspController(args.policy)
|
|
||||||
logger = Logger(args.logdir, args.policy)
|
policy = make(args.policy, args.rate)
|
||||||
|
controller = GraspController(policy)
|
||||||
|
logger = Logger(args)
|
||||||
|
|
||||||
seed_simulation(args.seed)
|
seed_simulation(args.seed)
|
||||||
|
|
||||||
@ -29,15 +32,16 @@ def create_parser():
|
|||||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||||
parser.add_argument("--runs", type=int, default=10)
|
parser.add_argument("--runs", type=int, default=10)
|
||||||
parser.add_argument("--logdir", type=Path, default="logs")
|
parser.add_argument("--logdir", type=Path, default="logs")
|
||||||
|
parser.add_argument("--rate", type=int, default=5)
|
||||||
parser.add_argument("--seed", type=int, default=12)
|
parser.add_argument("--seed", type=int, default=12)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
class Logger:
|
||||||
def __init__(self, logdir, policy):
|
def __init__(self, args):
|
||||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
name = "{}_policy={}".format(stamp, policy)
|
descr = "policy={},rate={}".format(args.policy, args.rate)
|
||||||
self.path = logdir / (name + ".csv")
|
self.path = args.logdir / (stamp + "_" + descr + ".csv")
|
||||||
|
|
||||||
def log_run(self, info):
|
def log_run(self, info):
|
||||||
df = pd.DataFrame.from_records([info])
|
df = pd.DataFrame.from_records([info])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user