import os
import sys
import numpy as np
import argparse
import time
import torch
from torch.utils.data import DataLoader
from graspnetAPI.graspnet_eval import GraspGroup, GraspNetEval
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'pointnet2'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
sys.path.append(os.path.join(ROOT_DIR, 'dataset'))
from models.graspnet import GraspNet, pred_decode
from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn
from collision_detector import ModelFreeCollisionDetector
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default=None, required=True)
parser.add_argument('--checkpoint_path', help='Model checkpoint path', default=None, required=True)
parser.add_argument('--dump_dir', help='Dump dir to save outputs', default=None, required=True)
parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point wise feature dim')
parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]')
parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]')
parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]')
parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size for sparse convolution')
parser.add_argument('--collision_thresh', type=float, default=0.01,
                    help='Collision Threshold in collision detection [default: 0.01]')
parser.add_argument('--voxel_size_cd', type=float, default=0.01, help='Voxel Size for collision detection')
parser.add_argument('--infer', action='store_true', default=False)
parser.add_argument('--eval', action='store_true', default=False)
cfgs = parser.parse_args()
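# Example invocation (script name and paths are placeholders for illustration):
#   python test.py --dataset_root /path/to/graspnet \
#       --checkpoint_path /path/to/checkpoint.tar --dump_dir /path/to/dump \
#       --camera kinect --infer --eval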
# ------------------------------------------------------------------------- GLOBAL CONFIG BEG
if not os.path.exists(cfgs.dump_dir):
    os.mkdir(cfgs.dump_dir)
# Init datasets and dataloaders
def my_worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)
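# Note: with num_workers=0 in the DataLoader below, worker_init_fn is never
# invoked; it only takes effect if multi-process loading is enabled.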
def inference():
    test_dataset = GraspNetDataset(cfgs.dataset_root, split='test_seen', camera=cfgs.camera,
                                   num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
                                   remove_outlier=True, augment=False, load_label=False)
    print('Test dataset length: ', len(test_dataset))
    scene_list = test_dataset.scene_list()
    test_dataloader = DataLoader(test_dataset, batch_size=cfgs.batch_size, shuffle=False,
                                 num_workers=0, worker_init_fn=my_worker_init_fn,
                                 collate_fn=minkowski_collate_fn)
    print('Test dataloader length: ', len(test_dataloader))
    # Init the model
    net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    # Load checkpoint (map_location keeps this working on CPU-only machines)
    checkpoint = torch.load(cfgs.checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch))
    batch_interval = 100
    net.eval()
    tic = time.time()
    for batch_idx, batch_data in enumerate(test_dataloader):
        for key in batch_data:
            if 'list' in key:
                for i in range(len(batch_data[key])):
                    for j in range(len(batch_data[key][i])):
                        batch_data[key][i][j] = batch_data[key][i][j].to(device)
            else:
                batch_data[key] = batch_data[key].to(device)
        # Forward pass
        with torch.no_grad():
            end_points = net(batch_data)
            grasp_preds = pred_decode(end_points)
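        # Each prediction is an (N, 17) array in graspnetAPI's grasp layout
        # (score, width, height, depth, 3x3 rotation, translation, object id),
        # the same layout GraspGroup consumes below.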
        # Dump results for evaluation
        for i in range(cfgs.batch_size):
            data_idx = batch_idx * cfgs.batch_size + i
            preds = grasp_preds[i].detach().cpu().numpy()
            gg = GraspGroup(preds)
            # collision detection
            if cfgs.collision_thresh > 0:
                cloud = test_dataset.get_data(data_idx, return_raw_cloud=True)
                mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size_cd)
                collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh)
                gg = gg[~collision_mask]
            # save grasps
            save_dir = os.path.join(cfgs.dump_dir, scene_list[data_idx], cfgs.camera)
            save_path = os.path.join(save_dir, str(data_idx % 256).zfill(4) + '.npy')
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            gg.save_npy(save_path)
        if (batch_idx + 1) % batch_interval == 0:
            toc = time.time()
            print('Eval batch: %d, time: %fs' % (batch_idx + 1, (toc - tic) / batch_interval))
            tic = time.time()
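# Helper sketch (an addition for illustration, not part of the original
# pipeline): reload a dumped .npy file for quick inspection. GraspGroup
# accepts the same (N, 17) array that save_npy writes above.
def load_dumped_grasps(npy_path):
    gg = GraspGroup(np.load(npy_path))
    print('loaded %d grasps from %s' % (len(gg), npy_path))
    return gg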
def evaluate(dump_dir):
    ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test_seen')
    res, ap = ge.eval_seen(dump_folder=dump_dir, proc=6)
    save_path = os.path.join(cfgs.dump_dir, 'ap_{}.npy'.format(cfgs.camera))
    np.save(save_path, res)
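    # Note: 'res' is the raw per-scene accuracy array; 'ap' (its mean, the
    # reported AP) is computed by graspnetAPI but not saved here.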
if __name__ == '__main__':
    if cfgs.infer:
        inference()
    if cfgs.eval:
        evaluate(cfgs.dump_dir)