import json
import os
import re
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from omni_util import OmniUtil

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, "pointnet2"))
sys.path.append(os.path.join(ROOT_DIR, "utils"))
sys.path.append(os.path.join(ROOT_DIR, "models"))
sys.path.append(os.path.join(ROOT_DIR, "dataset"))

from models.graspnet import GraspNet
from dataset.graspnet_dataset import minkowski_collate_fn


class GSNetInferenceDataset(Dataset):
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"

    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = os.path.join(self.data_dir, source, data_type)
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = 0.005

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path)
        return frame_data

    def get_datalist(self):
        # Collect one frame path per camera_params_*.json file, across all scenes.
        frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    frame_list.append(frame_path)
        return frame_list

    def load_frame_data(self, frame_path):
        target_list = OmniUtil.get_object_list(path=frame_path, contains_nonobj=True)
        scene_pts, obj_pcl_dict = OmniUtil.get_segmented_points(
            path=frame_path, target_list=target_list
        )
        ret_dict = {
            "frame_path": frame_path,
            "point_clouds": scene_pts.astype(np.float32),
            # Voxel coordinates for the Minkowski sparse-convolution backbone.
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
        }
        return ret_dict

    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        # Sample with replacement only when the cloud has fewer than n_pts points.
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]


class GSNetPreprocessor:
    LABEL_TEMPLATE = "label_{}.json"

    def __init__(self):
        self.voxel_size = 0.005
        self.camera = "kinect"
        self.num_point = 15000
        self.batch_size = 1
        self.seed_feat_dim = 512
        self.checkpoint_path = "logs/log_kn/epoch10.tar"
        self.dump_dir = "logs/log_kn/dump_kinect"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_dataloader(self, dataset_config=None):
        def my_worker_init_fn(worker_id):
            # Give each dataloader worker a distinct, reproducible seed.
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset()
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=self.seed_feat_dim, is_training=False)
        model.to(self.device)
        # map_location lets the checkpoint load on CPU-only machines as well.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(
            "-> loaded checkpoint %s (epoch: %d)" % (self.checkpoint_path, start_epoch)
        )
        model.eval()
        return model
    def prediction(self, model, dataloader):
        preds = {}
        total = len(dataloader)
        for idx, batch_data in enumerate(dataloader):
            print(f"predicting... [{idx}/{total}]")
            # Move all tensors (including nested lists of tensors) to the device.
            for key in batch_data:
                if "list" in key:
                    for i in range(len(batch_data[key])):
                        for j in range(len(batch_data[key][i])):
                            batch_data[key][i][j] = batch_data[key][i][j].to(
                                self.device
                            )
                elif not isinstance(batch_data[key], list):
                    batch_data[key] = batch_data[key].to(self.device)
            with torch.no_grad():
                end_points = model(batch_data)
                grasp_preds = self.decode_pred(end_points)

            for frame_idx in range(len(batch_data["frame_path"])):
                preds[batch_data["frame_path"][frame_idx]] = grasp_preds[frame_idx]
                preds[batch_data["frame_path"][frame_idx]]["obj_pcl_dict"] = (
                    batch_data["obj_pcl_dict"][frame_idx]
                )

        results = {}
        top_k = 50
        for frame_path in preds:
            predict_results = {}
            grasp_center = preds[frame_path]["grasp_center"]
            grasp_score = preds[frame_path]["grasp_score"]
            obj_pcl_dict = preds[frame_path]["obj_pcl_dict"]
            # (M, 1, 3) so it broadcasts against each object's (1, N, 3) cloud.
            grasp_center = grasp_center.unsqueeze(1)
            for obj_name in obj_pcl_dict:
                if obj_name in OmniUtil.NON_OBJECT_LIST:
                    continue
                obj_pcl = obj_pcl_dict[obj_name]
                obj_pcl = torch.tensor(
                    obj_pcl.astype(np.float32), device=grasp_center.device
                )
                obj_pcl = obj_pcl.unsqueeze(0)
                # Exact coordinate match works because grasp points were sampled
                # from the same segmented scene cloud.
                grasp_obj_table = (grasp_center == obj_pcl).all(dim=-1)
                obj_pts_on_grasp = grasp_obj_table.any(dim=1)
                obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)
                obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                obj_graspable_pts_info = torch.cat(
                    [obj_graspable_pts, obj_graspable_pts_score], dim=1
                )
                if obj_graspable_pts.shape[0] == 0:
                    obj_graspable_pts_info = torch.zeros((top_k, 4))
                ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                    obj_graspable_pts_info, top_k=top_k
                )
                predict_results[obj_name] = {
                    "positions": ranked_obj_graspable_pts_info[:, :3]
                    .cpu()
                    .numpy()
                    .tolist(),
                    "scores": ranked_obj_graspable_pts_info[:, 3]
                    .cpu()
                    .numpy()
                    .tolist(),
                }
            results[frame_path] = {"predicted_results": predict_results}
        return results

    def preprocess(self, predicted_data):
        # First pass: per-frame score sums and the per-object list of those sums.
        obj_score_list_dict = {}
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["sum_score"] = {}
            for obj_name in frame_obj_info:
                if obj_name not in obj_score_list_dict:
                    obj_score_list_dict[obj_name] = []
                obj_score_sum = np.sum(frame_obj_info[obj_name]["scores"])
                obj_score_list_dict[obj_name].append(obj_score_sum)
                predicted_data[frame_path]["sum_score"][obj_name] = obj_score_sum

        # Second pass: normalize each frame's sum by the object's best frame.
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["regularized_score"] = {}
            for obj_name in frame_obj_info:
                obj_score_sum = predicted_data[frame_path]["sum_score"][obj_name]
                max_obj_score = max(obj_score_list_dict[obj_name])
                predicted_data[frame_path]["regularized_score"][obj_name] = (
                    obj_score_sum / (max_obj_score + 1e-6)
                )
        return predicted_data

    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        # Pad by resampling with replacement when fewer than top_k points exist.
        if graspable_pts.shape[0] < top_k:
            sampled_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[sampled_indices]
        # Rank by score (column 3) and keep the top_k best.
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        return graspable_pts[sorted_indices][:top_k]

    def save_processed_data(self, processed_data, dataset_config):
        for frame_path in processed_data:
            data_item = processed_data[frame_path]
            # frame_path ends with a 4-digit frame index, e.g. ".../scene/0001".
            save_root, idx = frame_path[:-4], frame_path[-4:]
            label_save_path = os.path.join(save_root, self.LABEL_TEMPLATE.format(idx))
            with open(label_save_path, "w+") as f:
                json.dump(data_item, f)

    def decode_pred(self, end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            # Keep each point's best score over all prediction bins: [M_POINT]
            grasp_score, _ = torch.max(grasp_score, -1)
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds


if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor()
    dataloader = gs_preproc.get_dataloader()
    model = gs_preproc.get_model()
    results = gs_preproc.prediction(model=model, dataloader=dataloader)
    results = gs_preproc.preprocess(results)
    gs_preproc.save_processed_data(results, None)
    # gs_preproc.evaluate()
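
# Note: a sketch of the saved label layout, inferred from save_processed_data
# above rather than from an external spec. Each frame's label is written next
# to its source data:
#
#   <scene_dir>/label_<frame_index>.json
#   {
#       "predicted_results": {
#           "<obj_name>": {"positions": [[x, y, z], ...],  # top_k entries each
#                          "scores": [s, ...]}
#       },
#       "sum_score": {"<obj_name>": ...},          # sum of the frame's top_k scores
#       "regularized_score": {"<obj_name>": ...}   # sum / max sum over all frames
#   }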