diff --git a/nodes/simulation.py b/bt_sim_node.py
similarity index 72%
rename from nodes/simulation.py
rename to bt_sim_node.py
index 3cae5f1..e56a5b6 100755
--- a/nodes/simulation.py
+++ b/bt_sim_node.py
@@ -9,11 +9,13 @@ import rospy
import franka_gripper.msg
from geometry_msgs.msg import Pose
from sensor_msgs.msg import JointState
+import std_srvs.srv
-from robot_tools.btsim import BtPandaEnv
from robot_tools.controllers import CartesianPoseController
from robot_tools.ros import *
+from simulation import BtPandaEnv
+
CONTROLLER_UPDATE_RATE = 60
JOINT_STATE_PUBLISHER_RATE = 60
@@ -22,26 +24,37 @@ class BtSimNode:
def __init__(self, gui):
self.sim = BtPandaEnv(gui=gui, sleep=False)
self.controller = CartesianPoseController(self.sim.arm)
-
- self.joint_state_pub = rospy.Publisher(
- "/joint_states", JointState, queue_size=10
- )
+ self.reset_server = rospy.Service("reset", std_srvs.srv.Trigger, self.reset)
self.move_server = actionlib.SimpleActionServer(
"move",
franka_gripper.msg.MoveAction,
execute_cb=self.move,
auto_start=False,
)
+ self.joint_state_pub = rospy.Publisher(
+ "joint_states", JointState, queue_size=10
+ )
+ rospy.Subscriber("target", Pose, self.target_pose_cb)
+ self.step_cnt = 0
+ self.reset_requested = False
self.move_server.start()
- rospy.Subscriber("/target", Pose, self.target_pose_cb)
+
+ def reset(self, req):
+ self.reset_requested = True
+ rospy.sleep(1.0) # wait for the latest sim step to finish
+ self.sim.reset()
+ self.controller.set_target(self.sim.arm.pose())
+ self.step_cnt = 0
+ self.reset_requested = False
+ return std_srvs.srv.TriggerResponse(success=True)
def run(self):
rate = rospy.Rate(self.sim.rate)
- self.step_cnt = 0
while not rospy.is_shutdown():
- self.handle_updates()
- self.sim.step()
- self.step_cnt = (self.step_cnt + 1) % self.sim.rate
+ if not self.reset_requested:
+ self.handle_updates()
+ self.sim.step()
+ self.step_cnt = (self.step_cnt + 1) % self.sim.rate
rate.sleep()
def handle_updates(self):
diff --git a/launch/simulation.launch b/launch/simulation.launch
index 1e1c92f..8acd595 100644
--- a/launch/simulation.launch
+++ b/launch/simulation.launch
@@ -7,7 +7,7 @@
-
+
diff --git a/run.py b/run.py
index e7a12c6..af1fa46 100644
--- a/run.py
+++ b/run.py
@@ -1,12 +1,12 @@
import argparse
-from geometry_msgs.msg import Pose
import numpy as np
import rospy
-from std_srvs.srv import Trigger
+
+from geometry_msgs.msg import Pose
+import std_srvs.srv
from policies import get_policy
-
from robot_tools.ros import *
@@ -14,16 +14,25 @@ class GraspController:
def __init__(self, policy, rate):
self.policy = policy
self.rate = rate
-
- self.target_pose_pub = rospy.Publisher("/target", Pose, queue_size=10)
+ self.reset_client = rospy.ServiceProxy("reset", std_srvs.srv.Trigger)
+ self.target_pose_pub = rospy.Publisher("target", Pose, queue_size=10)
self.gripper = PandaGripperRosInterface()
+ rospy.sleep(1.0)
+
+ def run(self):
+ self.reset()
+ self.explore()
+ self.execute_grasp()
+
+ def reset(self):
+ req = std_srvs.srv.TriggerRequest()
+ self.reset_client(req)
def explore(self):
r = rospy.Rate(self.rate)
- done = False
self.policy.start()
- while not done:
- done = self.policy.update()
+ while not self.policy.done:
+ self.policy.update()
r.sleep()
def execute_grasp(self):
@@ -41,11 +50,9 @@ class GraspController:
def main(args):
rospy.init_node("panda_grasp")
-
policy = get_policy(args.policy)
gc = GraspController(policy, args.rate)
- gc.explore()
- gc.execute_grasp()
+ gc.run()
if __name__ == "__main__":
diff --git a/simulation.py b/simulation.py
new file mode 100644
index 0000000..dea6ba8
--- /dev/null
+++ b/simulation.py
@@ -0,0 +1,35 @@
+import pybullet as p
+
+from robot_tools.btsim import *
+from robot_tools.spatial import Rotation, Transform
+
+
+class BtPandaEnv(BtBaseEnv):
+ def __init__(self, gui=True, sleep=True):
+ super().__init__(gui, sleep)
+ self.arm = BtPandaArm()
+ self.gripper = BtPandaGripper()
+ self.T_W_B = Transform(Rotation.identity(), np.r_[-0.6, 0.0, 0.4])
+ self.load_table()
+ self.load_robot()
+ self.load_objects()
+
+ def reset(self):
+ q = self.arm.configurations["ready"]
+ for i, q_i in enumerate(q):
+ p.resetJointState(self.arm.uid, i, q_i, 0)
+
+ def load_table(self):
+ p.loadURDF("plane.urdf")
+ p.loadURDF(
+ "table/table.urdf",
+ baseOrientation=Rotation.from_rotvec(np.array([0, 0, np.pi / 2])).as_quat(),
+ useFixedBase=True,
+ )
+
+ def load_robot(self):
+ self.arm.load(self.T_W_B)
+ self.gripper.uid = self.arm.uid
+
+ def load_objects(self):
+ p.loadURDF("cube_small.urdf", [-0.2, 0.0, 0.8])