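"""Run a pretrained GraspNet (GSNet) model over rendered scene frames (loaded
via OmniUtil) and save per-object graspable points and scores as JSON labels."""
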
import os
import re
import sys
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from omni_util import OmniUtil

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, "pointnet2"))
sys.path.append(os.path.join(ROOT_DIR, "utils"))
sys.path.append(os.path.join(ROOT_DIR, "models"))
sys.path.append(os.path.join(ROOT_DIR, "dataset"))

from models.graspnet import GraspNet
from dataset.graspnet_dataset import minkowski_collate_fn


class GSNetInferenceDataset(Dataset):
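    """Dataset of rendered frames: each item is a scene point cloud plus
    per-object point clouds, in the dict layout minkowski_collate_fn expects."""
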
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = os.path.join(self.data_dir, source, data_type)
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = 0.005

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path)
        return frame_data

    def get_datalist(self):
        # One camera_params_<idx>.json exists per rendered frame; collect a
        # <scene_path>/<frame_index> entry for every frame of every scene.
        frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            for file in os.listdir(scene_path):
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_list.append(os.path.join(scene_path, frame_index))
        return frame_list

    def load_frame_data(self, frame_path):
        # Segment the frame into a full scene cloud plus one cloud per object.
        target_list = OmniUtil.get_object_list(path=frame_path, contains_nonobj=True)
        scene_pts, obj_pcl_dict = OmniUtil.get_segmented_points(
            path=frame_path, target_list=target_list
        )
        ret_dict = {
            "frame_path": frame_path,
            "point_clouds": scene_pts.astype(np.float32),
            # Voxel-quantized coordinates and constant features, as
            # minkowski_collate_fn expects.
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
        }
        return ret_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        # Sample n_pts points, with replacement when the cloud has fewer points.
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]


class GSNetPreprocessor:
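    """Runs GSNet inference over the dataset, aggregates per-object grasp
    scores, and writes the results back as per-frame label files."""
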
    LABEL_TEMPLATE = "label_{}.json"

    def __init__(self):
        self.voxel_size = 0.005
        self.camera = "kinect"
        self.num_point = 15000
        self.batch_size = 1
        self.seed_feat_dim = 512
        self.checkpoint_path = "logs/log_kn/epoch10.tar"
        self.dump_dir = "logs/log_kn/dump_kinect"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_dataloader(self, dataset_config=None):
        def my_worker_init_fn(worker_id):
            # Seed each worker differently, derived from the parent numpy state.
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset()
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=self.seed_feat_dim, is_training=False)
        model.to(self.device)
        # map_location lets a CUDA-trained checkpoint load on a CPU-only machine.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("-> loaded checkpoint %s (epoch: %d)" % (self.checkpoint_path, start_epoch))
        model.eval()
        return model

    def prediction(self, model, dataloader):
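        """Run inference and, per frame, attach each predicted grasp point to
        the object whose cloud contains it; keep the top_k points per object."""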
        preds = {}
        total = len(dataloader)
        for idx, batch_data in enumerate(dataloader):
            print(f"predicting... [{idx + 1}/{total}]")
            # Move tensors (and nested lists of tensors) onto the model device.
            for key in batch_data:
                if "list" in key:
                    for i in range(len(batch_data[key])):
                        for j in range(len(batch_data[key][i])):
                            batch_data[key][i][j] = batch_data[key][i][j].to(self.device)
                elif not isinstance(batch_data[key], list):
                    batch_data[key] = batch_data[key].to(self.device)
            with torch.no_grad():
                end_points = model(batch_data)
                grasp_preds = self.decode_pred(end_points)
            for frame_idx in range(len(batch_data["frame_path"])):
                frame_path = batch_data["frame_path"][frame_idx]
                preds[frame_path] = grasp_preds[frame_idx]
                preds[frame_path]["obj_pcl_dict"] = batch_data["obj_pcl_dict"][frame_idx]

        results = {}
        top_k = 50
        for frame_path in preds:
            predict_results = {}
            grasp_center = preds[frame_path]["grasp_center"]
            grasp_score = preds[frame_path]["grasp_score"]
            obj_pcl_dict = preds[frame_path]["obj_pcl_dict"]
            grasp_center = grasp_center.unsqueeze(1)  # (G, 1, 3) for broadcasting
            for obj_name in obj_pcl_dict:
                if obj_name in OmniUtil.NON_OBJECT_LIST:
                    continue
                obj_pcl = obj_pcl_dict[obj_name]
                obj_pcl = torch.tensor(
                    obj_pcl.astype(np.float32), device=grasp_center.device
                )
                obj_pcl = obj_pcl.unsqueeze(0)  # (1, P, 3)
                # A grasp point belongs to this object iff its coordinates exactly
                # match one of the object's points (both come from the same cloud).
                grasp_obj_table = (grasp_center == obj_pcl).all(dim=-1)  # (G, P)
                obj_pts_on_grasp = grasp_obj_table.any(dim=1)  # (G,)
                obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)
                obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                # Pack [x, y, z, score] rows for ranking.
                obj_graspable_pts_info = torch.cat(
                    [obj_graspable_pts, obj_graspable_pts_score], dim=1
                )
                if obj_graspable_pts.shape[0] == 0:
                    obj_graspable_pts_info = torch.zeros(
                        (top_k, 4), device=grasp_center.device
                    )
                ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                    obj_graspable_pts_info, top_k=top_k
                )
                predict_results[obj_name] = {
                    "positions": ranked_obj_graspable_pts_info[:, :3].cpu().numpy().tolist(),
                    "scores": ranked_obj_graspable_pts_info[:, 3].cpu().numpy().tolist(),
                }
            results[frame_path] = {"predicted_results": predict_results}
        return results

    def preprocess(self, predicted_data):
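        """Add per-frame "sum_score" and cross-frame max-normalized
        "regularized_score" entries for every detected object."""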
        # First pass: per-frame sum of grasp scores for each object, plus a
        # global list of those sums per object.
        obj_score_list_dict = {}
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["sum_score"] = {}
            for obj_name in frame_obj_info:
                if obj_name not in obj_score_list_dict:
                    obj_score_list_dict[obj_name] = []
                obj_score_sum = np.sum(frame_obj_info[obj_name]["scores"])
                obj_score_list_dict[obj_name].append(obj_score_sum)
                predicted_data[frame_path]["sum_score"][obj_name] = obj_score_sum

        # Second pass: normalize each frame's sum by the object's maximum sum
        # over all frames (epsilon guards against division by zero).
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["regularized_score"] = {}
            for obj_name in frame_obj_info:
                obj_score_sum = predicted_data[frame_path]["sum_score"][obj_name]
                max_obj_score = max(obj_score_list_dict[obj_name])
                predicted_data[frame_path]["regularized_score"][obj_name] = (
                    obj_score_sum / (max_obj_score + 1e-6)
                )
        return predicted_data

    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        # Pad by resampling with replacement when fewer than top_k rows exist,
        # then keep the top_k rows with the highest score (column 3).
        if graspable_pts.shape[0] < top_k:
            sampled_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[sampled_indices]
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        return graspable_pts[sorted_indices][:top_k]

    def save_processed_data(self, processed_data, dataset_config):
        for frame_path in processed_data:
            data_item = processed_data[frame_path]
            # frame_path is <scene_path>/<frame_index>; write label_<idx>.json
            # next to the frame's other files.
            save_root, idx = os.path.split(frame_path)
            label_save_path = os.path.join(save_root, self.LABEL_TEMPLATE.format(idx))
            with open(label_save_path, "w+") as f:
                json.dump(data_item, f)

    def decode_pred(self, end_points):
        # Keep, for every graspable seed point, its 3D center and the maximum
        # predicted grasp score across all candidate configurations.
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, _ = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds


if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor()
    dataloader = gs_preproc.get_dataloader()
    model = gs_preproc.get_model()
    results = gs_preproc.prediction(model=model, dataloader=dataloader)
    results = gs_preproc.preprocess(results)
    gs_preproc.save_processed_data(results, None)
    # gs_preproc.evaluate()