import os
import sys
import argparse
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from graspnetAPI.graspnet_eval import GraspGroup, GraspNetEval

# Make the repo-local packages importable when the script is run directly.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'pointnet2'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
sys.path.append(os.path.join(ROOT_DIR, 'dataset'))

from models.graspnet import GraspNet, pred_decode
from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn
from collision_detector import ModelFreeCollisionDetector

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', required=True, help='GraspNet dataset root')
parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path')
parser.add_argument('--dump_dir', required=True, help='Dump dir to save outputs')
parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point-wise feature dim')
parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]')
parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]')
parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]')
parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size for sparse convolution')
parser.add_argument('--collision_thresh', type=float, default=0.01,
                    help='Collision Threshold in collision detection [default: 0.01]')
parser.add_argument('--voxel_size_cd', type=float, default=0.01, help='Voxel Size for collision detection')
parser.add_argument('--infer', action='store_true', default=False, help='Run inference and dump grasps')
parser.add_argument('--eval', action='store_true', default=False, help='Evaluate dumped grasps')
cfgs = parser.parse_args()
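
# Example invocation (script name and paths are illustrative):
#   python test.py --dataset_root /path/to/graspnet \
#       --checkpoint_path logs/checkpoint.tar --dump_dir logs/dump \
#       --camera kinect --infer --eval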

# ------------------------------------------------------------------------- GLOBAL CONFIG BEG
# makedirs(exist_ok=True) handles nested paths and avoids the race in the
# old exists()/mkdir() pattern.
os.makedirs(cfgs.dump_dir, exist_ok=True)
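# Predictions are dumped as <dump_dir>/<scene_name>/<camera>/<frame:04d>.npy
# (see the save logic in inference() below).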


# Init datasets and dataloaders
def my_worker_init_fn(worker_id):
    # Reseed NumPy per worker so dataloader workers do not share random state.
    np.random.seed(np.random.get_state()[1][0] + worker_id)


def inference():
    test_dataset = GraspNetDataset(cfgs.dataset_root, split='test_seen', camera=cfgs.camera,
                                   num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
                                   remove_outlier=True, augment=False, load_label=False)
    print('Test dataset length: ', len(test_dataset))
    scene_list = test_dataset.scene_list()
    test_dataloader = DataLoader(test_dataset, batch_size=cfgs.batch_size, shuffle=False,
                                 num_workers=0, worker_init_fn=my_worker_init_fn,
                                 collate_fn=minkowski_collate_fn)
    print('Test dataloader length: ', len(test_dataloader))

    # Init the model
    net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # Load checkpoint; map_location keeps this working on CPU-only machines
    checkpoint = torch.load(cfgs.checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch))

    batch_interval = 100
    net.eval()
    tic = time.time()
    for batch_idx, batch_data in enumerate(test_dataloader):
        # Move the batch to the device; keys containing 'list' hold nested
        # lists of tensors and need element-wise transfer.
        for key in batch_data:
            if 'list' in key:
                for i in range(len(batch_data[key])):
                    for j in range(len(batch_data[key][i])):
                        batch_data[key][i][j] = batch_data[key][i][j].to(device)
            else:
                batch_data[key] = batch_data[key].to(device)

        # Forward pass
        with torch.no_grad():
            end_points = net(batch_data)
            grasp_preds = pred_decode(end_points)

        # Dump results for evaluation
        for i in range(cfgs.batch_size):
            data_idx = batch_idx * cfgs.batch_size + i
            preds = grasp_preds[i].detach().cpu().numpy()

            gg = GraspGroup(preds)
            # Collision detection: prune grasps whose gripper would collide
            # with the raw scene cloud
            if cfgs.collision_thresh > 0:
                cloud = test_dataset.get_data(data_idx, return_raw_cloud=True)
                mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size_cd)
                collision_mask = mfcdetector.detect(gg, approach_dist=0.05,
                                                    collision_thresh=cfgs.collision_thresh)
                gg = gg[~collision_mask]

            # Save grasps; each scene has 256 frames, hence data_idx % 256
            save_dir = os.path.join(cfgs.dump_dir, scene_list[data_idx], cfgs.camera)
            save_path = os.path.join(save_dir, str(data_idx % 256).zfill(4) + '.npy')
            os.makedirs(save_dir, exist_ok=True)
            gg.save_npy(save_path)

        if (batch_idx + 1) % batch_interval == 0:
            toc = time.time()
            print('Eval batch: %d, time: %fs' % (batch_idx + 1, (toc - tic) / batch_interval))
            tic = time.time()


def evaluate(dump_dir):
    ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test_seen')
    res, ap = ge.eval_seen(dump_folder=dump_dir, proc=6)
    print('AP (test_seen, %s): %f' % (cfgs.camera, ap))
    save_path = os.path.join(cfgs.dump_dir, 'ap_{}.npy'.format(cfgs.camera))
    np.save(save_path, res)


if __name__ == '__main__':
    # --infer dumps grasp predictions; --eval scores them with graspnetAPI.
    if cfgs.infer:
        inference()
    if cfgs.eval:
        evaluate(cfgs.dump_dir)
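
# Dumped grasps can be reloaded for inspection, since GraspGroup accepts a
# numpy array as above (path is illustrative):
#   gg = GraspGroup(np.load('logs/dump/scene_0100/kinect/0000.npy'))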