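# Batch-runs a pretrained GSNet model over rendered scene frames, matches the
# predicted grasp centers back to each segmented object, and writes per-frame
# label JSON files with the top-k graspable points and scores for each object.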
import os
import re
import sys
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from omni_util import OmniUtil

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, "pointnet2"))
sys.path.append(os.path.join(ROOT_DIR, "utils"))
sys.path.append(os.path.join(ROOT_DIR, "models"))
sys.path.append(os.path.join(ROOT_DIR, "dataset"))

from models.graspnet import GraspNet
from dataset.graspnet_dataset import minkowski_collate_fn
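

# Dataset that walks every scene directory under data_dir/source/data_type and
# yields one item per frame, using the camera_params_*.json files as the frame
# index. Each item carries the fused scene cloud plus per-object clouds from
# OmniUtil, laid out as minkowski_collate_fn expects (point_clouds/coors/feats).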
class GSNetInferenceDataset(Dataset):
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = 0.005

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path)
        return frame_data

    def get_datalist(self):
        # Accumulate frame paths across all scenes.
        frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    frame_list.append(frame_path)
        return frame_list
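
    # "coors" are the scene points quantized by voxel_size and "feats" are
    # constant ones, the sparse-voxel inputs the Minkowski backbone consumes.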
    def load_frame_data(self, frame_path):
        target_list = OmniUtil.get_object_list(path=frame_path, contains_nonobj=True)
        scene_pts, obj_pcl_dict = OmniUtil.get_segmented_points(
            path=frame_path, target_list=target_list
        )
        ret_dict = {
            "frame_path": frame_path,
            "point_clouds": scene_pts.astype(np.float32),
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
        }
        return ret_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        # Sample n_pts indices; allow replacement only when the cloud is
        # smaller than the requested count.
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]
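

# Drives the full pipeline: build the dataloader, restore the GSNet checkpoint,
# run inference, regularize per-object scores across frames, and save labels.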
class GSNetPreprocessor:
    LABEL_TEMPLATE = "label_{}.json"

    def __init__(self):
        self.voxel_size = 0.005
        self.camera = "kinect"
        self.num_point = 15000
        self.batch_size = 1
        self.seed_feat_dim = 512
        self.checkpoint_path = "logs/log_kn/epoch10.tar"
        self.dump_dir = "logs/log_kn/dump_kinect"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_dataloader(self, dataset_config=None):
        def my_worker_init_fn(worker_id):
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset()
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader
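
    # The checkpoint is expected to hold "model_state_dict" and "epoch", as
    # saved by the GSNet training loop.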
    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=self.seed_feat_dim, is_training=False)
        model.to(self.device)
        # map_location keeps loading functional on CPU-only machines.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(
            "-> loaded checkpoint %s (epoch: %d)" % (self.checkpoint_path, start_epoch)
        )
        model.eval()
        return model
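
    # Runs the network over every batch and reduces the raw outputs to
    # per-object top-k graspable points. Keys containing "list" hold nested
    # lists of tensors, so those are moved to the device element by element.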
    def prediction(self, model, dataloader):
        preds = {}
        total = len(dataloader)
        for idx, batch_data in enumerate(dataloader):
            print(f"predicting... [{idx}/{total}]")
            for key in batch_data:
                if "list" in key:
                    for i in range(len(batch_data[key])):
                        for j in range(len(batch_data[key][i])):
                            batch_data[key][i][j] = batch_data[key][i][j].to(
                                self.device
                            )
                elif not isinstance(batch_data[key], list):
                    batch_data[key] = batch_data[key].to(self.device)
            with torch.no_grad():
                end_points = model(batch_data)
                grasp_preds = self.decode_pred(end_points)
            for frame_idx in range(len(batch_data["frame_path"])):
                preds[batch_data["frame_path"][frame_idx]] = grasp_preds[frame_idx]
                preds[batch_data["frame_path"][frame_idx]]["obj_pcl_dict"] = (
                    batch_data["obj_pcl_dict"][frame_idx]
                )
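
        # Match grasp centers to objects by exact coordinate equality. This
        # relies on the grasp centers being sampled, unmodified, from the same
        # segmented scene cloud that produced obj_pcl_dict.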
        results = {}
        top_k = 50
        for frame_path in preds:
            predict_results = {}
            grasp_center = preds[frame_path]["grasp_center"]
            grasp_score = preds[frame_path]["grasp_score"]
            obj_pcl_dict = preds[frame_path]["obj_pcl_dict"]
            grasp_center = grasp_center.unsqueeze(1)
            for obj_name in obj_pcl_dict:
                if obj_name in OmniUtil.NON_OBJECT_LIST:
                    continue
                obj_pcl = obj_pcl_dict[obj_name]
                obj_pcl = torch.tensor(
                    obj_pcl.astype(np.float32), device=grasp_center.device
                )
                obj_pcl = obj_pcl.unsqueeze(0)
                # (M, 1, 3) == (1, N, 3) -> (M, N); True where a grasp center
                # coincides with an object point.
                grasp_obj_table = (grasp_center == obj_pcl).all(dim=-1)
                obj_pts_on_grasp = grasp_obj_table.any(dim=1)
                obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)
                obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                obj_graspable_pts_info = torch.cat(
                    [obj_graspable_pts, obj_graspable_pts_score], dim=1
                )
                if obj_graspable_pts.shape[0] == 0:
                    # No grasp landed on this object: fall back to top_k
                    # all-zero rows so downstream code sees a fixed shape.
                    obj_graspable_pts_info = torch.zeros((top_k, 4))
                ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                    obj_graspable_pts_info, top_k=top_k
                )
                predict_results[obj_name] = {
                    "positions": ranked_obj_graspable_pts_info[:, :3]
                    .cpu()
                    .numpy()
                    .tolist(),
                    "scores": ranked_obj_graspable_pts_info[:, 3]
                    .cpu()
                    .numpy()
                    .tolist(),
                }
            results[frame_path] = {"predicted_results": predict_results}
        return results
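
    # For each object, sum its top-k scores per frame, then normalize by the
    # best frame-sum for that object, yielding a relative per-frame
    # "graspability" score in [0, 1] (1e-6 guards against division by zero).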
    def preprocess(self, predicted_data):
        obj_score_list_dict = {}
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["sum_score"] = {}
            for obj_name in frame_obj_info:
                if obj_name not in obj_score_list_dict:
                    obj_score_list_dict[obj_name] = []
                obj_score_sum = np.sum(frame_obj_info[obj_name]["scores"])
                obj_score_list_dict[obj_name].append(obj_score_sum)
                predicted_data[frame_path]["sum_score"][obj_name] = obj_score_sum
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["regularized_score"] = {}
            for obj_name in frame_obj_info:
                obj_score_sum = predicted_data[frame_path]["sum_score"][obj_name]
                max_obj_score = max(obj_score_list_dict[obj_name])
                predicted_data[frame_path]["regularized_score"][obj_name] = (
                    obj_score_sum / (max_obj_score + 1e-6)
                )
        return predicted_data
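
    # Returns exactly top_k rows sorted by score (column 3, descending),
    # resampling with replacement when fewer candidates are available.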
    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        if graspable_pts.shape[0] < top_k:
            # Too few candidates: resample with replacement up to top_k rows.
            resample_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[resample_indices]
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        return graspable_pts[sorted_indices][:top_k]
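
    # Writes one label_{frame}.json per frame next to the frame data. This
    # assumes frame_path ends in a fixed-width 4-character frame index, which
    # the slicing below splits off from the scene directory.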
    def save_processed_data(self, processed_data, dataset_config):
        for frame_path in processed_data:
            data_item = processed_data[frame_path]
            save_root, idx = frame_path[:-4], frame_path[-4:]
            label_save_path = os.path.join(
                str(save_root), self.LABEL_TEMPLATE.format(idx)
            )
            with open(label_save_path, "w+") as f:
                json.dump(data_item, f)
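
    # Reduces raw GSNet outputs to one (center, score) pair per seed point by
    # taking the max over each point's flattened score predictions.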
    def decode_pred(self, end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, _ = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds
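

# Entry point: expects the dataset under /mnt/h/AI/Datasets/nbv1/sample and
# the checkpoint at logs/log_kn/epoch10.tar (see the defaults above).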
if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor()
    dataloader = gs_preproc.get_dataloader()
    model = gs_preproc.get_model()
    results = gs_preproc.prediction(model=model, dataloader=dataloader)
    results = gs_preproc.preprocess(results)
    gs_preproc.save_processed_data(results, None)
    # gs_preproc.evaluate()