nbv_sim/scripts/bt_sim_node.py
2024-12-20 15:28:39 +08:00

405 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
from actionlib import SimpleActionServer
import control_msgs.msg as control_msgs
from controller_manager_msgs.srv import *
import cv_bridge
from franka_msgs.msg import FrankaState, ErrorRecoveryAction
from franka_gripper.msg import *
from geometry_msgs.msg import Twist
import numpy as np
import rospy
from sensor_msgs.msg import JointState, Image, CameraInfo
from scipy import interpolate
from std_msgs.msg import Header
from threading import Thread
from active_grasp.bbox import to_bbox_msg
from active_grasp.srv import *
from active_grasp.simulation import Simulation
from robot_helpers.ros.conversions import *
from vgn.simulation import apply_noise
class BtSimNode:
def __init__(self):
gui = rospy.get_param("~gui")
scene_id = rospy.get_param("~scene")
vgn_path = rospy.get_param("vgn/model")
self.sim = Simulation(gui, scene_id, vgn_path)
self.init_plugins()
self.advertise_services()
def init_plugins(self):
self.plugins = [
PhysicsPlugin(self.sim),
RobotStatePlugin(self.sim.arm, self.sim.gripper),
MoveActionPlugin(self.sim.gripper),
GraspActionPlugin(self.sim.gripper),
GripperActionPlugin(),
CameraPlugin(self.sim.camera),
MockActionsPlugin(),
]
self.controllers = {
"cartesian_velocity_controller": CartesianVelocityControllerPlugin(
self.sim.arm, self.sim.model
),
"position_joint_trajectory_controller": JointTrajectoryControllerPlugin(
self.sim.arm
),
}
def start_plugins(self):
for plugin in self.plugins + list(self.controllers.values()):
plugin.thread.start()
def activate_plugins(self):
for plugin in self.plugins:
plugin.activate()
def deactivate_plugins(self):
for plugin in self.plugins:
plugin.deactivate()
def deactivate_controllers(self):
for controller in self.controllers.values():
controller.deactivate()
def advertise_services(self):
rospy.Service("seed", Seed, self.seed)
rospy.Service("reset", Reset, self.reset)
rospy.Service(
"/controller_manager/switch_controller",
SwitchController,
self.switch_controller,
)
rospy.Service("get_target_seg_id", TargetID, self.get_target_seg_id)
rospy.Service("get_support_seg_id", TargetID, self.get_support_seg_id)
def seed(self, req):
self.sim.seed(req.seed)
rospy.loginfo(f"Seeded the rng with {req.seed}.")
return SeedResponse()
def reset(self, req):
self.deactivate_plugins()
self.deactivate_controllers()
rospy.sleep(1.0) # TODO replace with a read-write lock
bbox = self.sim.reset()
#import ipdb; ipdb.set_trace()
self.activate_plugins()
return ResetResponse(to_bbox_msg(bbox))
def get_target_seg_id(self, req):
response = TargetIDResponse()
response.id = self.sim.select_target()
return response
def get_support_seg_id(self, req):
response = TargetIDResponse()
response.id = self.sim.select_support()
return response
def switch_controller(self, req):
for controller in req.stop_controllers:
self.controllers[controller].deactivate()
for controller in req.start_controllers:
self.controllers[controller].activate()
return SwitchControllerResponse(ok=True)
def run(self):
self.start_plugins()
self.activate_plugins()
rospy.spin()
class Plugin:
"""A plugin that spins at a constant rate in its own thread."""
def __init__(self, rate):
self.rate = rate
self.thread = Thread(target=self.loop, daemon=True)
self.is_running = False
def activate(self):
self.is_running = True
def deactivate(self):
self.is_running = False
def loop(self):
rate = rospy.Rate(self.rate)
while not rospy.is_shutdown():
if self.is_running:
self.update()
rate.sleep()
def update(self):
raise NotImplementedError
class PhysicsPlugin(Plugin):
def __init__(self, sim):
super().__init__(sim.rate)
self.sim = sim
def update(self):
self.sim.step()
class RobotStatePlugin(Plugin):
def __init__(self, arm, gripper, rate=30):
super().__init__(rate)
self.arm = arm
self.gripper = gripper
self.arm_state_pub = rospy.Publisher(
"/franka_state_controller/franka_states", FrankaState, queue_size=10
)
self.gripper_state_pub = rospy.Publisher(
"/franka_gripper/joint_states", JointState, queue_size=10
)
self.joint_states_pub = rospy.Publisher(
"joint_states", JointState, queue_size=10
)
def update(self):
q, dq = self.arm.get_state()
width = self.gripper.read()
header = Header(stamp=rospy.Time.now())
msg = FrankaState(header=header, q=q, dq=dq)
self.arm_state_pub.publish(msg)
msg = JointState(header=header)
msg.name = ["panda_finger_joint1", "panda_finger_joint2"]
msg.position = [0.5 * width, 0.5 * width]
self.gripper_state_pub.publish(msg)
msg = JointState(header=header)
msg.name = ["panda_joint{}".format(i) for i in range(1, 8)] + [
"panda_finger_joint1",
"panda_finger_joint2",
]
msg.position = np.r_[q, 0.5 * width, 0.5 * width]
self.joint_states_pub.publish(msg)
class CartesianVelocityControllerPlugin(Plugin):
def __init__(self, arm, model, rate=30):
super().__init__(rate)
self.arm = arm
self.model = model
topic = rospy.get_param("cartesian_velocity_controller/topic")
rospy.Subscriber(topic, Twist, self.target_cb)
def target_cb(self, msg):
self.dx_d = from_twist_msg(msg)
def activate(self):
self.dx_d = np.zeros(6)
self.is_running = True
def deactivate(self):
self.dx_d = np.zeros(6)
self.is_running = False
self.arm.set_desired_joint_velocities(np.zeros(7))
def update(self):
q, _ = self.arm.get_state()
J_pinv = np.linalg.pinv(self.model.jacobian(q))
cmd = np.dot(J_pinv, self.dx_d)
self.arm.set_desired_joint_velocities(cmd)
class JointTrajectoryControllerPlugin(Plugin):
def __init__(self, arm, rate=30):
super().__init__(rate)
self.arm = arm
self.dt = 1.0 / self.rate # TODO this might not be reliable
self.init_action_server()
def init_action_server(self):
name = "position_joint_trajectory_controller/follow_joint_trajectory"
self.action_server = SimpleActionServer(
name, control_msgs.FollowJointTrajectoryAction, auto_start=False
)
self.action_server.register_goal_callback(self.action_goal_cb)
self.action_server.start()
def action_goal_cb(self):
goal = self.action_server.accept_new_goal()
self.interpolate_trajectory(goal.trajectory.points)
self.elapsed_time = 0.0
def interpolate_trajectory(self, points):
t, y = np.zeros(len(points)), np.zeros((7, len(points)))
for i, point in enumerate(points):
t[i] = point.time_from_start.to_sec()
y[:, i] = point.positions
self.m = interpolate.interp1d(t, y)
self.duration = t[-1]
def update(self):
if self.action_server.is_active():
self.elapsed_time += self.dt
if self.elapsed_time > self.duration:
self.action_server.set_succeeded()
return
self.arm.set_desired_joint_positions(self.m(self.elapsed_time))
class MoveActionPlugin(Plugin):
def __init__(self, gripper, rate=10):
super().__init__(rate)
self.gripper = gripper
self.dt = 1.0 / self.rate
self.init_action_server()
def init_action_server(self):
name = "/franka_gripper/move"
self.action_server = SimpleActionServer(name, MoveAction, auto_start=False)
self.action_server.register_goal_callback(self.action_goal_cb)
self.action_server.start()
def action_goal_cb(self):
self.elapsed_time = 0.0
goal = self.action_server.accept_new_goal()
self.gripper.set_desired_width(goal.width)
def update(self):
if self.action_server.is_active():
self.elapsed_time += self.dt
if self.elapsed_time > 1.0:
self.action_server.set_succeeded()
class GraspActionPlugin(Plugin):
def __init__(self, gripper, rate=10):
super().__init__(rate)
self.gripper = gripper
self.dt = 1.0 / self.rate
self.force = rospy.get_param("~gripper_force")
self.init_action_server()
def init_action_server(self):
name = "/franka_gripper/grasp"
self.action_server = SimpleActionServer(name, GraspAction, auto_start=False)
self.action_server.register_goal_callback(self.action_goal_cb)
self.action_server.start()
def action_goal_cb(self):
self.elapsed_time = 0.0
goal = self.action_server.accept_new_goal()
self.gripper.set_desired_speed(-0.1, force=self.force)
def update(self):
if self.action_server.is_active():
self.elapsed_time += self.dt
if self.elapsed_time > 1.0:
self.action_server.set_succeeded()
class GripperActionPlugin(Plugin):
"""Empty action server to make MoveIt happy"""
def __init__(self, rate=1):
super().__init__(rate)
self.init_action_server()
def init_action_server(self):
name = "/franka_gripper/gripper_action"
self.action_server = SimpleActionServer(
name, control_msgs.GripperCommandAction, auto_start=False
)
self.action_server.register_goal_callback(self.action_goal_cb)
self.action_server.start()
def action_goal_cb(self):
self.action_server.accept_new_goal()
def update(self):
if self.action_server.is_active():
self.action_server.set_succeeded()
class CameraPlugin(Plugin):
def __init__(self, camera, name="camera", rate=5):
super().__init__(rate)
self.camera = camera
self.name = name
self.cam_noise = rospy.get_param("~cam_noise", False)
self.cv_bridge = cv_bridge.CvBridge()
self.init_publishers()
def init_publishers(self):
topic = self.name + "/depth/camera_info"
self.info_pub = rospy.Publisher(topic, CameraInfo, queue_size=10)
topic = self.name + "/depth/image_rect_raw"
self.depth_pub = rospy.Publisher(topic, Image, queue_size=10)
topic = self.name + "/color/image_rect_raw"
self.rgb_pub = rospy.Publisher(topic, Image, queue_size=10)
topic = self.name + "/segmentation/image_rect_raw"
self.seg_pub = rospy.Publisher(topic, Image, queue_size=10)
def update(self):
stamp = rospy.Time.now()
msg = to_camera_info_msg(self.camera.intrinsic)
msg.header.frame_id = self.name + "_optical_frame"
msg.header.stamp = stamp
self.info_pub.publish(msg)
color, depth, seg = self.camera.get_image()
if self.cam_noise:
depth = apply_noise(depth)
msg = self.cv_bridge.cv2_to_imgmsg((1000 * depth).astype(np.uint16))
msg.header.stamp = stamp
self.depth_pub.publish(msg)
msg = self.cv_bridge.cv2_to_imgmsg(color.astype(np.uint8), encoding='rgb8')
msg.header.stamp = stamp
self.rgb_pub.publish(msg)
msg = self.cv_bridge.cv2_to_imgmsg(seg.astype(np.uint8), encoding='mono8')
msg.header.stamp = stamp
self.seg_pub.publish(msg)
class MockActionsPlugin(Plugin):
def __init__(self):
super().__init__(1)
self.init_recovery_action_server()
self.init_homing_action_server()
def init_homing_action_server(self):
self.homing_as = SimpleActionServer(
"/franka_gripper/homing", HomingAction, auto_start=False
)
self.homing_as.register_goal_callback(self.action_goal_cb)
self.homing_as.start()
def init_recovery_action_server(self):
self.recovery_as = SimpleActionServer(
"/franka_control/error_recovery", ErrorRecoveryAction, auto_start=False
)
self.recovery_as.register_goal_callback(self.action_goal_cb)
self.recovery_as.start()
def action_goal_cb(self):
pass
def update(self):
pass
def main():
rospy.init_node("bt_sim")
server = BtSimNode()
server.run()
if __name__ == "__main__":
main()