diff --git a/README.md b/README.md index 335bff4..ca6fc38 100644 --- a/README.md +++ b/README.md @@ -1 +1,13 @@ -# active_grasp \ No newline at end of file +# active_grasp + +First, run the simulation + +``` +roslaunch active_grasp simulation.launch +``` + +Then you can run a policy. + +``` +python3 run.py ... +``` diff --git a/launch/panda_visualization.launch b/launch/panda_visualization.launch deleted file mode 100644 index d694dc6..0000000 --- a/launch/panda_visualization.launch +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/launch/simulation.launch b/launch/simulation.launch new file mode 100644 index 0000000..1e1c92f --- /dev/null +++ b/launch/simulation.launch @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/sim.py b/nodes/simulation.py old mode 100644 new mode 100755 similarity index 69% rename from sim.py rename to nodes/simulation.py index 150ee89..3cae5f1 --- a/sim.py +++ b/nodes/simulation.py @@ -1,8 +1,12 @@ +#!/usr/bin/env python3 + import argparse +import actionlib import numpy as np import rospy +import franka_gripper.msg from geometry_msgs.msg import Pose from sensor_msgs.msg import JointState @@ -22,24 +26,31 @@ class BtSimNode: self.joint_state_pub = rospy.Publisher( "/joint_states", JointState, queue_size=10 ) - rospy.Subscriber("/target", Pose, self._target_pose_cb) + self.move_server = actionlib.SimpleActionServer( + "move", + franka_gripper.msg.MoveAction, + execute_cb=self.move, + auto_start=False, + ) + self.move_server.start() + rospy.Subscriber("/target", Pose, self.target_pose_cb) def run(self): rate = rospy.Rate(self.sim.rate) self.step_cnt = 0 while not rospy.is_shutdown(): - self._handle_updates() + self.handle_updates() self.sim.step() self.step_cnt = (self.step_cnt + 1) % self.sim.rate rate.sleep() - def _handle_updates(self): + def handle_updates(self): if self.step_cnt % int(self.sim.rate / CONTROLLER_UPDATE_RATE) == 0: self.controller.update() if self.step_cnt % int(self.sim.rate / JOINT_STATE_PUBLISHER_RATE) == 0: - self._publish_joint_state() + self.publish_joint_state() - def _publish_joint_state(self): + def publish_joint_state(self): q, dq = self.sim.arm.get_state() width = self.sim.gripper.read() msg = JointState() @@ -52,7 +63,11 @@ class BtSimNode: msg.velocity = dq self.joint_state_pub.publish(msg) - def _target_pose_cb(self, msg): + def move(self, goal): + self.sim.gripper.move(goal.width) + self.move_server.set_succeeded() + + def target_pose_cb(self, msg): self.controller.set_target(from_pose_msg(msg)) @@ -64,6 +79,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--gui", type=str, default=True) - args = parser.parse_args() + parser.add_argument("--gui", action="store_true") + args, _ = parser.parse_known_args() main(args) diff --git a/policies.py b/policies.py index d4e441b..ca674c7 100644 --- a/policies.py +++ b/policies.py @@ -3,12 +3,15 @@ import numpy as np import rospy import scipy.interpolate +from robot_tools.spatial import Rotation, Transform from robot_tools.ros import * def get_policy(name): - if name == "fixed": + if name == "fixed-trajectory": return FixedTrajectory() + else: + raise ValueError("{} policy does not exist.".format(name)) class BasePolicy: diff --git a/run.py b/run.py index ce44641..e7a12c6 100644 --- a/run.py +++ b/run.py @@ -7,25 +7,52 @@ from std_srvs.srv import Trigger from policies import get_policy +from robot_tools.ros import * + + +class GraspController: + def __init__(self, policy, rate): + self.policy = policy + self.rate = rate + + self.target_pose_pub = rospy.Publisher("/target", Pose, queue_size=10) + self.gripper = PandaGripperRosInterface() + + def explore(self): + r = rospy.Rate(self.rate) + done = False + self.policy.start() + while not done: + done = self.policy.update() + r.sleep() + + def execute_grasp(self): + self.gripper.move(0.08) + rospy.sleep(1.0) + target = self.policy.best_grasp + self.target_pose_pub.publish(to_pose_msg(target)) + rospy.sleep(2.0) + self.gripper.move(0.0) + rospy.sleep(1.0) + target.translation[2] += 0.1 + self.target_pose_pub.publish(to_pose_msg(target)) + rospy.sleep(2.0) + def main(args): rospy.init_node("panda_grasp") policy = get_policy(args.policy) - - r = rospy.Rate(args.rate) - done = False - policy.start() - while not done: - done = policy.update() - r.sleep() - - # TODO execute grasp + gc = GraspController(policy, args.rate) + gc.explore() + gc.execute_grasp() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--policy", type=str, choices=["fixed"]) + parser.add_argument( + "--policy", type=str, choices=["single-view", "fixed-trajectory"] + ) parser.add_argument("--rate", type=int, default=10) args = parser.parse_args() main(args)