diff --git a/configs/local/view_generate_config.yaml b/configs/local/view_generate_config.yaml
index 09a5da8..aae49c0 100644
--- a/configs/local/view_generate_config.yaml
+++ b/configs/local/view_generate_config.yaml
@@ -24,12 +24,6 @@ runner:
     max_height: 0.15
     min_radius: 0.3
     max_radius: 0.5
-    min_R: 0.05
-    max_R: 0.3
-    min_G: 0.05
-    max_G: 0.3
-    min_B: 0.05
-    max_B: 0.3
   display_object:
     min_x: 0
     max_x: 0.03
diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py
new file mode 100644
index 0000000..3970bd3
--- /dev/null
+++ b/core/global_pts_n_num_pipeline.py
@@ -0,0 +1,100 @@
+import torch
+from torch import nn
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory.component_factory import ComponentFactory
+from PytorchBoot.utils import Log
+
+
+@stereotype.pipeline("nbv_reconstruction_global_pts_n_num_pipeline")
+class NBVReconstructionGlobalPointsPipeline(nn.Module):
+    def __init__(self, config):
+        super(NBVReconstructionGlobalPointsPipeline, self).__init__()
+        self.config = config
+        self.module_config = config["modules"]
+        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
+        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
+        self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
+        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
+        self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
+
+        self.eps = float(self.config["eps"])
+        self.enable_global_scanned_feat = self.config["global_scanned_feat"]
+
+    def forward(self, data):
+        mode = data["mode"]
+
+        if mode == namespace.Mode.TRAIN:
+            return self.forward_train(data)
+        elif mode == namespace.Mode.TEST:
+            return self.forward_test(data)
+        else:
+            Log.error("Unknown mode: {}".format(mode), True)
+
+    def perturb_data(self, gt_delta_9d):
+        bs = gt_delta_9d.shape[0]
+        random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
+        random_t = random_t.unsqueeze(-1)
+        mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
+        std = std.view(-1, 1)
+        z = torch.randn_like(gt_delta_9d)
+        perturbed_x = mu + z * std
+        target_score = - z * std / (std ** 2)
+        return perturbed_x, random_t, target_score, std
+
+    def forward_train(self, data):
+        main_feat = self.get_main_feat(data)
+        ''' get std '''
+        best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
+        perturbed_x, random_t, target_score, std = self.perturb_data(best_to_world_pose_9d_batch)
+        input_data = {
+            "sampled_pose": perturbed_x,
+            "t": random_t,
+            "main_feat": main_feat,
+        }
+        estimated_score = self.view_finder(input_data)
+        output = {
+            "estimated_score": estimated_score,
+            "target_score": target_score,
+            "std": std
+        }
+        return output
+
+    def forward_test(self, data):
+        main_feat = self.get_main_feat(data)
+        estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
+        result = {
+            "pred_pose_9d": estimated_delta_rot_9d,
+            "in_process_sample": in_process_sample
+        }
+        return result
+
+
+    def get_main_feat(self, data):
+        scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
+        scanned_target_pts_num_batch = data['scanned_target_points_num']
+
+        device = next(self.parameters()).device
+
+        embedding_list_batch = []
+
+        for scanned_n_to_world_pose_9d, scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch, scanned_target_pts_num_batch):
+            scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
+            scanned_target_pts_num = scanned_target_pts_num.to(device)
+            pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
+            pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
+            embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
+
+        main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
+
+
+        combined_scanned_pts_batch = data['combined_scanned_pts']
+        global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
+        main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
+
+
+        if torch.isnan(main_feat).any():
+            Log.error("nan in main_feat", True)
+
+        return main_feat
+
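The perturbation step in `perturb_data` above is standard denoising score matching: draw t ~ U(eps, 1), query the SDE's marginal mean/std at t, perturb the ground-truth 9D pose, and regress the score of the perturbed distribution. Since `-z * std / std**2` reduces to `-z / std`, the target is exactly the score of N(mu, std²) at the perturbed sample. A minimal standalone sketch — the `marginal_prob` here assumes a VE-SDE with illustrative sigma values; in the real pipeline it lives on the `view_finder` module and may differ:

```python
import torch

def marginal_prob(x, t, sigma_min=0.01, sigma_max=50.0):
    # VE-SDE marginal: mean is x itself, std grows geometrically with t
    std = sigma_min * (sigma_max / sigma_min) ** t
    return x, std

def perturb_data(gt_pose_9d, eps=1e-5):
    bs = gt_pose_9d.shape[0]
    # eps keeps t away from 0, where the marginal std degenerates
    random_t = torch.rand(bs, device=gt_pose_9d.device) * (1.0 - eps) + eps
    mu, std = marginal_prob(gt_pose_9d, random_t.unsqueeze(-1))
    std = std.view(-1, 1)
    z = torch.randn_like(gt_pose_9d)
    perturbed_x = mu + z * std
    target_score = -z / std  # score of N(mu, std^2) at perturbed_x
    return perturbed_x, random_t, target_score, std

perturbed_x, t, target, std = perturb_data(torch.randn(4, 9))
print(perturbed_x.shape, target.shape)  # torch.Size([4, 9]) torch.Size([4, 9])
```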
diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py
index f2748b8..829bb01 100644
--- a/core/nbv_dataset.py
+++ b/core/nbv_dataset.py
@@ -7,6 +7,7 @@ from PytorchBoot.utils.log_util import Log
 import torch
 import os
 import sys
+
 sys.path.append(r"/home/data/hofee/project/nbv_rec/nbv_reconstruction")
 
 from utils.data_load import DataLoadUtil
@@ -29,20 +30,17 @@ class NBVReconstructionDataset(BaseDataset):
         self.cache = config.get("cache")
         self.load_from_preprocess = config.get("load_from_preprocess", False)
-
         if self.type == namespace.Mode.TEST:
             self.model_dir = config["model_dir"]
             self.filter_degree = config["filter_degree"]
         if self.type == namespace.Mode.TRAIN:
             scale_ratio = 100
-            self.datalist = self.datalist*scale_ratio
+            self.datalist = self.datalist * scale_ratio
         if self.cache:
             expr_root = ConfigManager.get("runner", "experiment", "root_dir")
             expr_name = ConfigManager.get("runner", "experiment", "name")
             self.cache_dir = os.path.join(expr_root, expr_name, "cache")
-            #self.preprocess_cache()
-
-
+            # self.preprocess_cache()
 
     def load_scene_name_list(self):
         scene_name_list = []
@@ -51,7 +49,7 @@ class NBVReconstructionDataset(BaseDataset):
                 scene_name = line.strip()
                 scene_name_list.append(scene_name)
         return scene_name_list
-    
+
     def get_datalist(self):
         datalist = []
         for scene_name in self.scene_name_list:
@@ -60,7 +58,9 @@ class NBVReconstructionDataset(BaseDataset):
             max_coverage_rate_list = []
 
             for seq_idx in range(seq_num):
-                label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx)
+                label_path = DataLoadUtil.get_label_path(
+                    self.root_dir, scene_name, seq_idx
+                )
                 label_data = DataLoadUtil.load_label(label_path)
                 max_coverage_rate = label_data["max_coverage_rate"]
                 if max_coverage_rate > scene_max_coverage_rate:
@@ -69,20 +69,24 @@ class NBVReconstructionDataset(BaseDataset):
             mean_coverage_rate = np.mean(max_coverage_rate_list)
 
             for seq_idx in range(seq_num):
-                label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx)
+                label_path = DataLoadUtil.get_label_path(
+                    self.root_dir, scene_name, seq_idx
+                )
                 label_data = DataLoadUtil.load_label(label_path)
                 if max_coverage_rate_list[seq_idx] > mean_coverage_rate - 0.1:
                     for data_pair in label_data["data_pairs"]:
                         scanned_views = data_pair[0]
                         next_best_view = data_pair[1]
-                        datalist.append({
-                            "scanned_views": scanned_views,
-                            "next_best_view": next_best_view,
-                            "seq_max_coverage_rate": max_coverage_rate,
-                            "scene_name": scene_name,
-                            "label_idx": seq_idx,
-                            "scene_max_coverage_rate": scene_max_coverage_rate
-                        })
+                        datalist.append(
+                            {
+                                "scanned_views": scanned_views,
+                                "next_best_view": next_best_view,
+                                "seq_max_coverage_rate": max_coverage_rate,
+                                "scene_name": scene_name,
+                                "label_idx": seq_idx,
+                                "scene_max_coverage_rate": scene_max_coverage_rate,
+                            }
+                        )
         return datalist
 
     def preprocess_cache(self):
@@ -90,7 +94,7 @@ class NBVReconstructionDataset(BaseDataset):
         for item_idx in range(len(self.datalist)):
             self.__getitem__(item_idx)
         Log.success("finish preprocessing cache.")
-    
+
     def load_from_cache(self, scene_name, curr_frame_idx):
         cache_name = f"{scene_name}_{curr_frame_idx}.txt"
         cache_path = os.path.join(self.cache_dir, cache_name)
@@ -99,7 +103,7 @@ class NBVReconstructionDataset(BaseDataset):
             return data
         else:
             return None
-    
+
     def save_to_cache(self, scene_name, curr_frame_idx, data):
         cache_name = f"{scene_name}_{curr_frame_idx}.txt"
         cache_path = os.path.join(self.cache_dir, cache_name)
@@ -107,125 +111,172 @@ class NBVReconstructionDataset(BaseDataset):
             np.savetxt(cache_path, data)
         except Exception as e:
             Log.error(f"Save cache failed: {e}")
-        # ----- Debug Trace ----- #
-        import ipdb; ipdb.set_trace()
-        # ------------------------ #
-    
+
     def __getitem__(self, index):
         data_item_info = self.datalist[index]
         scanned_views = data_item_info["scanned_views"]
         nbv = data_item_info["next_best_view"]
         max_coverage_rate = data_item_info["seq_max_coverage_rate"]
         scene_name = data_item_info["scene_name"]
-        scanned_views_pts, scanned_coverages_rate, scanned_n_to_world_pose = [], [], []
-
+        (
+            scanned_views_pts,
+            scanned_coverages_rate,
+            scanned_n_to_world_pose,
+            scanned_target_pts_num,
+        ) = ([], [], [], [])
+        target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict(
+            self.root_dir, scene_name
+        )
         for view in scanned_views:
             frame_idx = view[0]
             coverage_rate = view[1]
             view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
             cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True)
+            target_pts_num = target_pts_num_dict[frame_idx]
             n_to_world_pose = cam_info["cam_to_world"]
-            nR_to_world_pose = cam_info["cam_to_world_R"] 
+            nR_to_world_pose = cam_info["cam_to_world_R"]
 
             if self.load_from_preprocess:
-                downsampled_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(view_path)
+                downsampled_target_point_cloud = (
+                    DataLoadUtil.load_from_preprocessed_pts(view_path)
+                )
             else:
                 cached_data = None
                 if self.cache:
                     cached_data = self.load_from_cache(scene_name, frame_idx)
-
+
                 if cached_data is None:
                     print("load depth")
-                    depth_L, depth_R = DataLoadUtil.load_depth(view_path, cam_info['near_plane'], cam_info['far_plane'], binocular=True)
-                    point_cloud_L = DataLoadUtil.get_point_cloud(depth_L, cam_info['cam_intrinsic'], n_to_world_pose)['points_world']
-                    point_cloud_R = DataLoadUtil.get_point_cloud(depth_R, cam_info['cam_intrinsic'], nR_to_world_pose)['points_world']
-
-                    point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536)
-                    point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
-                    overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
-                    downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num)
+                    depth_L, depth_R = DataLoadUtil.load_depth(
+                        view_path,
+                        cam_info["near_plane"],
+                        cam_info["far_plane"],
+                        binocular=True,
+                    )
+                    point_cloud_L = DataLoadUtil.get_point_cloud(
+                        depth_L, cam_info["cam_intrinsic"], n_to_world_pose
+                    )["points_world"]
+                    point_cloud_R = DataLoadUtil.get_point_cloud(
+                        depth_R, cam_info["cam_intrinsic"], nR_to_world_pose
+                    )["points_world"]
+
+                    point_cloud_L = PtsUtil.random_downsample_point_cloud(
+                        point_cloud_L, 65536
+                    )
+                    point_cloud_R = PtsUtil.random_downsample_point_cloud(
+                        point_cloud_R, 65536
+                    )
+                    overlap_points = DataLoadUtil.get_overlapping_points(
+                        point_cloud_L, point_cloud_R
+                    )
+                    downsampled_target_point_cloud = (
+                        PtsUtil.random_downsample_point_cloud(
+                            overlap_points, self.pts_num
+                        )
+                    )
                    if self.cache:
-                        self.save_to_cache(scene_name, frame_idx, downsampled_target_point_cloud)
+                        self.save_to_cache(
+                            scene_name, frame_idx, downsampled_target_point_cloud
+                        )
                 else:
                     downsampled_target_point_cloud = cached_data
-
+
             scanned_views_pts.append(downsampled_target_point_cloud)
-            scanned_coverages_rate.append(coverage_rate)            
-            n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(n_to_world_pose[:3,:3]))
-            n_to_world_trans = n_to_world_pose[:3,3]
+            scanned_coverages_rate.append(coverage_rate)
+            n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
+                np.asarray(n_to_world_pose[:3, :3])
+            )
+            n_to_world_trans = n_to_world_pose[:3, 3]
             n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
             scanned_n_to_world_pose.append(n_to_world_9d)
+            scanned_target_pts_num.append(target_pts_num)
 
         nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
         nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
         cam_info = DataLoadUtil.load_cam_info(nbv_path)
         best_frame_to_world = cam_info["cam_to_world"]
-
-        best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(best_frame_to_world[:3,:3]))
-        best_to_world_trans = best_frame_to_world[:3,3]
-        best_to_world_9d = np.concatenate([best_to_world_6d, best_to_world_trans], axis=0)
+
+        best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
+            np.asarray(best_frame_to_world[:3, :3])
+        )
+        best_to_world_trans = best_frame_to_world[:3, 3]
+        best_to_world_9d = np.concatenate(
+            [best_to_world_6d, best_to_world_trans], axis=0
+        )
 
         combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
-        voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
-        random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, self.pts_num)
+        voxel_downsampled_combined_scanned_pts_np = (
+            PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
+        )
+        random_downsampled_combined_scanned_pts_np = (
+            PtsUtil.random_downsample_point_cloud(
+                voxel_downsampled_combined_scanned_pts_np, self.pts_num
+            )
+        )
         data_item = {
-            "scanned_pts": np.asarray(scanned_views_pts,dtype=np.float32),
-            "combined_scanned_pts": np.asarray(random_downsampled_combined_scanned_pts_np,dtype=np.float32),
+            "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),
+            "combined_scanned_pts": np.asarray(
+                random_downsampled_combined_scanned_pts_np, dtype=np.float32
+            ),
             "scanned_coverage_rate": scanned_coverages_rate,
-            "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose,dtype=np.float32),
+            "scanned_n_to_world_pose_9d": np.asarray(
+                scanned_n_to_world_pose, dtype=np.float32
+            ),
             "best_coverage_rate": nbv_coverage_rate,
-            "best_to_world_pose_9d": np.asarray(best_to_world_9d,dtype=np.float32),
+            "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
             "seq_max_coverage_rate": max_coverage_rate,
-            "scene_name": scene_name
+            "scene_name": scene_name,
+            "scanned_target_points_num": np.asarray(
+                scanned_target_pts_num, dtype=np.int32
+            ),
         }
-        
-        
-        # if self.type == namespace.Mode.TEST:
-        #     diag = DataLoadUtil.get_bbox_diag(self.model_dir, scene_name)
-        #     voxel_threshold = diag*0.02
-        #     model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name)
-        #     pts_list = []
-        #     for view in scanned_views:
-        #         frame_idx = view[0]
-        #         view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
-        #         point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(view_path, binocular=True)
-        #         cam_params = DataLoadUtil.load_cam_info(view_path, binocular=True)
-        #         sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
-        #         pts_list.append(sampled_point_cloud)
-        #     nL_to_world_pose = cam_params["cam_to_world"]
-        #     nO_to_world_pose = cam_params["cam_to_world_O"]
-        #     nO_to_nL_pose = np.dot(np.linalg.inv(nL_to_world_pose), nO_to_world_pose)
-        #     data_item["scanned_target_pts_list"] = pts_list
-        #     data_item["model_points_normals"] = model_points_normals
-        #     data_item["voxel_threshold"] = voxel_threshold
-        #     data_item["filter_degree"] = self.filter_degree
-        #     data_item["scene_path"] = os.path.join(self.root_dir, scene_name)
-        #     data_item["first_frame_to_world"] = np.asarray(first_frame_to_world, dtype=np.float32)
-        #     data_item["nO_to_nL_pose"] = np.asarray(nO_to_nL_pose, dtype=np.float32)
+
         return data_item
 
     def __len__(self):
         return len(self.datalist)
-    
+
     def get_collate_fn(self):
         def collate_fn(batch):
             collate_data = {}
-            collate_data["scanned_pts"] = [torch.tensor(item['scanned_pts']) for item in batch]
-            collate_data["scanned_n_to_world_pose_9d"] = [torch.tensor(item['scanned_n_to_world_pose_9d']) for item in batch]
-            collate_data["best_to_world_pose_9d"] = torch.stack([torch.tensor(item['best_to_world_pose_9d']) for item in batch])
-            collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch])
+            collate_data["scanned_pts"] = [
+                torch.tensor(item["scanned_pts"]) for item in batch
+            ]
+            collate_data["scanned_n_to_world_pose_9d"] = [
+                torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
+            ]
+            collate_data["scanned_target_points_num"] = [
+                torch.tensor(item["scanned_target_points_num"]) for item in batch
+            ]
+            collate_data["best_to_world_pose_9d"] = torch.stack(
+                [torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
+            )
+            collate_data["combined_scanned_pts"] = torch.stack(
+                [torch.tensor(item["combined_scanned_pts"]) for item in batch]
+            )
             if "first_frame_to_world" in batch[0]:
-                collate_data["first_frame_to_world"] = torch.stack([torch.tensor(item["first_frame_to_world"]) for item in batch])
+                collate_data["first_frame_to_world"] = torch.stack(
+                    [torch.tensor(item["first_frame_to_world"]) for item in batch]
+                )
             for key in batch[0].keys():
-                if key not in ["scanned_pts", "scanned_n_to_world_pose_9d", "best_to_world_pose_9d", "first_frame_to_world", "combined_scanned_pts"]:
+                if key not in [
+                    "scanned_pts",
+                    "scanned_n_to_world_pose_9d",
+                    "best_to_world_pose_9d",
+                    "first_frame_to_world",
+                    "combined_scanned_pts",
+                    "scanned_target_points_num",
+                ]:
                     collate_data[key] = [item[key] for item in batch]
             return collate_data
+
         return collate_fn
 
 # -------------- Debug ---------------- #
 if __name__ == "__main__":
     import torch
+
     seed = 0
     torch.manual_seed(seed)
     np.random.seed(seed)
@@ -244,41 +295,13 @@ if __name__ == "__main__":
     }
     ds = NBVReconstructionDataset(config)
     print(len(ds))
-    #ds.__getitem__(10)
+    # ds.__getitem__(10)
     dl = ds.get_loader(shuffle=True)
     for idx, data in enumerate(dl):
         data = ds.process_batch(data, "cuda:0")
         print(data)
         # ------ Debug Start ------
-        import ipdb;ipdb.set_trace()
+        import ipdb
+
+        ipdb.set_trace()
         # ------ Debug End ------
-    #
-    # for idx, data in enumerate(dl):
-    #     cnt=0
-    #     print(data["scene_name"])
-    #     print(data["scanned_coverage_rate"])
-    #     print(data["best_coverage_rate"])
-    #     for pts in data["scanned_pts"][0]:
-    #         #np.savetxt(f"pts_{cnt}.txt", pts)
-    #         cnt+=1
-    #     #np.savetxt("best_pts.txt", best_pts)
-    #     for key, value in data.items():
-    #         if isinstance(value, torch.Tensor):
-    #             print(key, ":" ,value.shape)
-    #         else:
-    #             print(key, ":" ,len(value))
-    #         if key == "scanned_n_to_world_pose_9d":
-    #             for val in value:
-    #                 print(val.shape)
-    #         if key == "scanned_pts":
-    #             print("scanned_pts")
-    #             for val in value:
-    #                 print(val.shape)
-    #                 cnt = 0
-    #                 for v in val:
-    #                     import ipdb;ipdb.set_trace()
-    #                     np.savetxt(f"pts_{cnt}.txt", v)
-    #                     cnt+=1
-        
-        
-    # print()
\ No newline at end of file
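The dataset flattens every camera pose into nine numbers: the first two columns of the rotation matrix (the continuous 6D rotation representation) followed by the translation. A standalone sketch of that packing — the real conversion is `PoseUtil.matrix_to_rotation_6d_numpy`, and the exact column/flattening convention is an assumption here:

```python
import numpy as np

def pose_to_9d(cam_to_world: np.ndarray) -> np.ndarray:
    # 6D rotation representation: the first two basis columns of R,
    # flattened column by column (assumed convention), plus translation
    rot = cam_to_world[:3, :3]
    rot_6d = rot[:, :2].T.reshape(-1)
    trans = cam_to_world[:3, 3]
    return np.concatenate([rot_6d, trans], axis=0)  # shape (9,)

print(pose_to_9d(np.eye(4)))  # [1. 0. 0. 0. 1. 0. 0. 0. 0.]
```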
diff --git a/modules/pointnet_encoder.py b/modules/pointnet_encoder.py
index 6483709..6e414f2 100644
--- a/modules/pointnet_encoder.py
+++ b/modules/pointnet_encoder.py
@@ -22,12 +22,10 @@ class PointNetEncoder(nn.Module):
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 512, 1)
         self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
-        self.global_feat = config["global_feat"]
         if self.feature_transform:
             self.f_stn = STNkd(k=64)
 
     def forward(self, x):
-        n_pts = x.shape[2]
         trans = self.stn(x)
         x = x.transpose(2, 1)
         x = torch.bmm(x, trans)
@@ -46,20 +44,15 @@ class PointNetEncoder(nn.Module):
         x = self.conv4(x)
         x = torch.max(x, 2, keepdim=True)[0]
         x = x.view(-1, self.out_dim)
-        if self.global_feat:
-            return x
-        else:
-            x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
-            return torch.cat([x, point_feat], 1)
+        return x, point_feat
 
-    def encode_points(self, pts):
+    def encode_points(self, pts, require_per_point_feat=False):
         pts = pts.transpose(2, 1)
-
-        if not self.global_feat:
-            pts_feature = self(pts).transpose(2, 1)
+        global_pts_feature, per_point_feature = self(pts)
+        if require_per_point_feat:
+            return global_pts_feature, per_point_feature.transpose(2, 1)
         else:
-            pts_feature = self(pts)
-        return pts_feature
+            return global_pts_feature
 
 class STNkd(nn.Module):
     def __init__(self, k=64):
@@ -102,21 +95,13 @@ if __name__ == "__main__":
     config = {
         "in_dim": 3,
         "out_dim": 1024,
-        "global_feat": True,
         "feature_transform": False
     }
-    pointnet_global = PointNetEncoder(config)
-    out = pointnet_global.encode_points(sim_data)
+    pointnet = PointNetEncoder(config)
+    out = pointnet.encode_points(sim_data)
 
     print("global feat", out.size())
 
-    config = {
-        "in_dim": 3,
-        "out_dim": 1024,
-        "global_feat": False,
-        "feature_transform": False
-    }
-
-    pointnet = PointNetEncoder(config)
-    out = pointnet.encode_points(sim_data)
+    out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True)
     print("point feat", out.size())
+    print("per point feat", per_point_out.size())
diff --git a/modules/pts_num_encoder.py b/modules/pts_num_encoder.py
new file mode 100644
index 0000000..2210c21
--- /dev/null
+++ b/modules/pts_num_encoder.py
@@ -0,0 +1,20 @@
+from torch import nn
+import PytorchBoot.stereotype as stereotype
+
+@stereotype.module("pts_num_encoder")
+class PointsNumEncoder(nn.Module):
+    def __init__(self, config):
+        super(PointsNumEncoder, self).__init__()
+        self.config = config
+        out_dim = config["out_dim"]
+        self.act = nn.ReLU(True)
+
+        self.pts_num_encoder = nn.Sequential(
+            nn.Linear(1, out_dim),
+            self.act,
+            nn.Linear(out_dim, out_dim),
+            self.act,
+        )
+
+    def encode_pts_num(self, num_seq):
+        return self.pts_num_encoder(num_seq)
diff --git a/modules/transformer_pose_seq_encoder.py b/modules/transformer_pose_seq_encoder.py
deleted file mode 100644
index 926a0e4..0000000
--- a/modules/transformer_pose_seq_encoder.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-from torch import nn
-from torch.nn.utils.rnn import pad_sequence
-import PytorchBoot.stereotype as stereotype
-
-
-@stereotype.module("transformer_pose_seq_encoder")
-class TransformerPoseSequenceEncoder(nn.Module):
-    def __init__(self, config):
-        super(TransformerPoseSequenceEncoder, self).__init__()
-        self.config = config
-        embed_dim = config["pose_embed_dim"]
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=config["num_heads"],
-            dim_feedforward=config["ffn_dim"],
-            batch_first=True,
-        )
-        self.transformer_encoder = nn.TransformerEncoder(
-            encoder_layer, num_layers=config["num_layers"]
-        )
-        self.fc = nn.Linear(embed_dim, config["output_dim"])
-
-    def encode_sequence(self, pose_embedding_list_batch):
-
-        lengths = []
-
-        for pose_embedding_list in pose_embedding_list_batch:
-            lengths.append(len(pose_embedding_list))
-
-        combined_tensor = pad_sequence(pose_embedding_list_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
-
-        max_len = max(lengths)
-        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
-
-        transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
-        final_feature = transformer_output.mean(dim=1)
-        final_output = self.fc(final_feature)
-
-        return final_output
-
-
-if __name__ == "__main__":
-    config = {
-        "pose_embed_dim": 256,
-        "num_heads": 4,
-        "ffn_dim": 256,
-        "num_layers": 3,
-        "output_dim": 1024,
-    }
-
-    encoder = TransformerPoseSequenceEncoder(config)
-    seq_len = [5, 8, 9, 4]
-    batch_size = 4
-
-    pose_embedding_list_batch = [
-        torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
-    ]
-    output_feature = encoder.encode_sequence(
-        pose_embedding_list_batch
-    )
-    print("Encoded Feature:", output_feature)
-    print("Feature Shape:", output_feature.shape)
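The refactor above makes `forward` always return both the pooled global feature and the per-point features, moving the choice to `encode_points`. PointNet's order-invariance comes from the channel-wise max over points; a minimal runnable sketch of that trick, with illustrative dimensions:

```python
import torch
from torch import nn

conv = nn.Conv1d(3, 1024, 1)             # shared per-point MLP (1x1 conv)
pts = torch.rand(4, 2048, 3)             # (B, N, 3), like encode_points input
feat = conv(pts.transpose(2, 1))         # (B, 1024, N) per-point features
global_feat = torch.max(feat, dim=2)[0]  # (B, 1024), permutation-invariant
print(global_feat.shape)                 # torch.Size([4, 1024])
```

Keeping the per-point tensor around costs nothing when only the global embedding is consumed, since the max-pool result is computed either way.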
diff --git a/modules/transformer_seq_encoder.py b/modules/transformer_seq_encoder.py
index 1eae505..13aabb2 100644
--- a/modules/transformer_seq_encoder.py
+++ b/modules/transformer_seq_encoder.py
@@ -9,7 +9,7 @@ class TransformerSequenceEncoder(nn.Module):
     def __init__(self, config):
         super(TransformerSequenceEncoder, self).__init__()
         self.config = config
-        embed_dim = config["pts_embed_dim"] + config["pose_embed_dim"]
+        embed_dim = config["embed_dim"]
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=embed_dim,
             nhead=config["num_heads"],
@@ -21,24 +21,19 @@ class TransformerSequenceEncoder(nn.Module):
         )
         self.fc = nn.Linear(embed_dim, config["output_dim"])
 
-    def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch):
-        combined_features_batch = []
+    def encode_sequence(self, embedding_list_batch):
+
         lengths = []
+
+        for embedding_list in embedding_list_batch:
+            lengths.append(len(embedding_list))
 
-        for pts_embedding_list, pose_embedding_list in zip(pts_embedding_list_batch, pose_embedding_list_batch):
-            combined_features = [
-                torch.cat((pts_embed, pose_embed), dim=-1)
-                for pts_embed, pose_embed in zip(pts_embedding_list, pose_embedding_list)
-            ]
-            combined_features_batch.append(torch.stack(combined_features))
-            lengths.append(len(combined_features))
-
-        combined_tensor = pad_sequence(combined_features_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
+        embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
 
         max_len = max(lengths)
-        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
+        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(embedding_tensor.device)
 
-        transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
+        transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
         final_feature = transformer_output.mean(dim=1)
         final_output = self.fc(final_feature)
 
@@ -47,26 +42,22 @@ class TransformerSequenceEncoder(nn.Module):
 
 if __name__ == "__main__":
     config = {
-        "pts_embed_dim": 1024,
-        "pose_embed_dim": 256,
+        "embed_dim": 256,
         "num_heads": 4,
         "ffn_dim": 256,
         "num_layers": 3,
-        "output_dim": 2048,
+        "output_dim": 1024,
     }
 
     encoder = TransformerSequenceEncoder(config)
     seq_len = [5, 8, 9, 4]
     batch_size = 4
 
-    pts_embedding_list_batch = [
-        torch.randn(seq_len[idx], config["pts_embed_dim"]) for idx in range(batch_size)
-    ]
-    pose_embedding_list_batch = [
-        torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
+    embedding_list_batch = [
+        torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
     ]
     output_feature = encoder.encode_sequence(
-        pts_embedding_list_batch, pose_embedding_list_batch
+        embedding_list_batch
     )
     print("Encoded Feature:", output_feature)
     print("Feature Shape:", output_feature.shape)
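`encode_sequence` now takes a single batch of pre-fused, variable-length embedding sequences; padding and masking follow the usual `pad_sequence` + `src_key_padding_mask` pattern (True marks slots attention should ignore). A runnable sketch:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

embed_dim = 256
enc = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=embed_dim, nhead=4, dim_feedforward=256, batch_first=True
    ),
    num_layers=1,
)
seqs = [torch.randn(n, embed_dim) for n in (5, 8, 3)]
lengths = [len(s) for s in seqs]
x = pad_sequence(seqs, batch_first=True)  # (3, 8, 256), zero-padded
mask = torch.tensor([[i >= n for i in range(x.shape[1])] for n in lengths])
out = enc(x, src_key_padding_mask=mask)   # padded slots ignored by attention
print(out.shape)                          # torch.Size([3, 8, 256])
# Caveat: out.mean(dim=1), as used in encode_sequence, still averages over
# padded rows; a mask-weighted mean would exclude them.
```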
diff --git a/runners/strategy_generator.py b/runners/strategy_generator.py
index c8f91be..db7cc2f 100644
--- a/runners/strategy_generator.py
+++ b/runners/strategy_generator.py
@@ -82,28 +82,40 @@ class StrategyGenerator(Runner):
         model_points_normals = DataLoadUtil.load_points_normals(root, scene_name)
         model_pts = model_points_normals[:,:3]
         down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
-
+        display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
+        radius = display_table_info["radius"]
+        top = DataLoadUtil.get_display_table_top(root, scene_name)
+        scan_points = ReconstructionUtil.generate_scan_points(display_table_top=top, display_table_radius=radius)
         pts_list = []
+        scan_points_indices_list = []
         for frame_idx in range(frame_num):
             if self.load_pts and os.path.exists(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt")):
                 sampled_point_cloud = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt"))
+                indices = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}_indices.txt")).astype(np.int32).tolist()
                 status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
                 pts_list.append(sampled_point_cloud)
-                continue
+                scan_points_indices_list.append(indices)
+
             else:
                 path = DataLoadUtil.get_path(root, scene_name, frame_idx)
                 cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
                 status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
-                point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
+                point_cloud, display_table_pts = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True, get_display_table_pts=True)
                 sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
-
+                covered_pts, indices = ReconstructionUtil.compute_covered_scan_points(scan_points, display_table_pts)
                 if self.save_pts:
                     pts_dir = os.path.join(root,scene_name, "pts")
+                    covered_pts_dir = os.path.join(pts_dir, "covered_scan_pts")
                     if not os.path.exists(pts_dir):
                         os.makedirs(pts_dir)
+                    if not os.path.exists(covered_pts_dir):
+                        os.makedirs(covered_pts_dir)
                     np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
+                    np.savetxt(os.path.join(covered_pts_dir, f"{frame_idx}.txt"), covered_pts)
+                    np.savetxt(os.path.join(pts_dir, f"{frame_idx}_indices.txt"), indices)
                 pts_list.append(sampled_point_cloud)
+                scan_points_indices_list.append(indices)
         status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num)
 
         seq_num = min(self.seq_num, len(pts_list))
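With the bookkeeping above, each frame's cache now holds two text files: the filtered target cloud in `pts/{frame_idx}.txt` and the covered scan-point indices in `pts/{frame_idx}_indices.txt`. `np.savetxt` serializes the integer indices as floats, which is why the reload path casts with `.astype(np.int32)`; a round-trip check:

```python
import os
import tempfile

import numpy as np

cache_dir = tempfile.mkdtemp()
indices = [0, 3, 7]  # covered scan-point indices for one frame
np.savetxt(os.path.join(cache_dir, "0_indices.txt"), indices)
loaded = np.loadtxt(os.path.join(cache_dir, "0_indices.txt")).astype(np.int32).tolist()
print(loaded)  # [0, 3, 7]
# Edge case worth guarding: a frame with exactly one covered point reloads as
# a 0-d array, so .tolist() yields a scalar; np.atleast_1d would normalize it.
```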
diff --git a/utils/data_load.py b/utils/data_load.py
index 16f568b..318a6d5 100644
--- a/utils/data_load.py
+++ b/utils/data_load.py
@@ -6,56 +6,61 @@ import trimesh
 import torch
 from utils.pts import PtsUtil
 
+
 class DataLoadUtil:
-    TABLE_POSITION = np.asarray([0,0,0.8215])
+    TABLE_POSITION = np.asarray([0, 0, 0.8215])
 
     @staticmethod
     def get_display_table_info(root, scene_name):
         scene_info = DataLoadUtil.load_scene_info(root, scene_name)
         display_table_info = scene_info["display_table"]
         return display_table_info
-    
+
     @staticmethod
     def get_display_table_top(root, scene_name):
-        display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)["height"]
-        display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray([0,0,display_table_height])
+        display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)[
+            "height"
+        ]
+        display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray(
+            [0, 0, display_table_height]
+        )
         return display_table_top
-    
+
     @staticmethod
     def get_path(root, scene_name, frame_idx):
         path = os.path.join(root, scene_name, f"{frame_idx}")
         return path
-    
+
     @staticmethod
     def get_label_num(root, scene_name):
-        label_dir = os.path.join(root,scene_name,"label")
+        label_dir = os.path.join(root, scene_name, "label")
         return len(os.listdir(label_dir))
-    
+
     @staticmethod
     def get_label_path(root, scene_name, seq_idx):
-        label_dir = os.path.join(root,scene_name,"label")
+        label_dir = os.path.join(root, scene_name, "label")
         if not os.path.exists(label_dir):
             os.makedirs(label_dir)
-        path = os.path.join(label_dir,f"{seq_idx}.json")
+        path = os.path.join(label_dir, f"{seq_idx}.json")
         return path
-    
+
     @staticmethod
     def get_label_path_old(root, scene_name):
-        path = os.path.join(root,scene_name,"label.json")
+        path = os.path.join(root, scene_name, "label.json")
         return path
-    
+
     @staticmethod
     def get_scene_seq_length(root, scene_name):
         camera_params_path = os.path.join(root, scene_name, "camera_params")
         return len(os.listdir(camera_params_path))
-    
+
     @staticmethod
     def load_mesh_at(model_dir, object_name, world_object_pose):
         model_path = os.path.join(model_dir, object_name, "mesh.obj")
         mesh = trimesh.load(model_path)
         mesh.apply_transform(world_object_pose)
         return mesh
-    
+
     @staticmethod
     def get_bbox_diag(model_dir, object_name):
         model_path = os.path.join(model_dir, object_name, "mesh.obj")
@@ -63,8 +68,7 @@ class DataLoadUtil:
         bbox = mesh.bounding_box.extents
         diagonal_length = np.linalg.norm(bbox)
         return diagonal_length
-    
-    
+
     @staticmethod
     def save_mesh_at(model_dir, output_dir, object_name, scene_name, world_object_pose):
         mesh = DataLoadUtil.load_mesh_at(model_dir, object_name, world_object_pose)
@@ -72,12 +76,16 @@ class DataLoadUtil:
         mesh.export(model_path)
 
     @staticmethod
-    def save_target_mesh_at_world_space(root, model_dir, scene_name, display_table_as_world_space_origin=True):
+    def save_target_mesh_at_world_space(
+        root, model_dir, scene_name, display_table_as_world_space_origin=True
+    ):
         scene_info = DataLoadUtil.load_scene_info(root, scene_name)
         target_name = scene_info["target_name"]
         transformation = scene_info[target_name]
         if display_table_as_world_space_origin:
-            location = transformation["location"] - DataLoadUtil.get_display_table_top(root, scene_name)
+            location = transformation["location"] - DataLoadUtil.get_display_table_top(
+                root, scene_name
+            )
         else:
             location = transformation["location"]
         rotation_euler = transformation["rotation_euler"]
@@ -90,14 +98,21 @@ class DataLoadUtil:
             os.makedirs(mesh_dir)
         model_path = os.path.join(mesh_dir, "world_target_mesh.obj")
         mesh.export(model_path)
-    
+
     @staticmethod
     def load_scene_info(root, scene_name):
         scene_info_path = os.path.join(root, scene_name, "scene_info.json")
         with open(scene_info_path, "r") as f:
             scene_info = json.load(f)
         return scene_info
-    
+
+    @staticmethod
+    def load_target_pts_num_dict(root, scene_name):
+        target_pts_num_path = os.path.join(root, scene_name, "target_pts_num.json")
+        with open(target_pts_num_path, "r") as f:
+            target_pts_num_dict = json.load(f)
+        return target_pts_num_dict
+
     @staticmethod
     def load_target_object_pose(root, scene_name):
         scene_info = DataLoadUtil.load_scene_info(root, scene_name)
@@ -108,10 +123,10 @@ class DataLoadUtil:
         pose_mat = trimesh.transformations.euler_matrix(*rotation_euler)
         pose_mat[:3, 3] = location
         return pose_mat
-    
+
     @staticmethod
-    def load_depth(path, min_depth=0.01,max_depth=5.0,binocular=False):
-        
+    def load_depth(path, min_depth=0.01, max_depth=5.0, binocular=False):
+
         def load_depth_from_real_path(real_path, min_depth, max_depth):
             depth = cv2.imread(real_path, cv2.IMREAD_UNCHANGED)
             depth = depth.astype(np.float32) / 65535.0
@@ -119,78 +134,104 @@ class DataLoadUtil:
             max_depth = max_depth
             depth_meters = min_depth + (max_depth - min_depth) * depth
             return depth_meters
-        
+
         if binocular:
-            depth_path_L = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + "_L.png")
-            depth_path_R = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + "_R.png")
-            depth_meters_L = load_depth_from_real_path(depth_path_L, min_depth, max_depth)
-            depth_meters_R = load_depth_from_real_path(depth_path_R, min_depth, max_depth)
+            depth_path_L = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + "_L.png"
+            )
+            depth_path_R = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + "_R.png"
+            )
+            depth_meters_L = load_depth_from_real_path(
+                depth_path_L, min_depth, max_depth
+            )
+            depth_meters_R = load_depth_from_real_path(
+                depth_path_R, min_depth, max_depth
+            )
             return depth_meters_L, depth_meters_R
         else:
-            depth_path = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + ".png")
+            depth_path = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + ".png"
+            )
             depth_meters = load_depth_from_real_path(depth_path, min_depth, max_depth)
             return depth_meters
-    
+
     @staticmethod
     def load_seg(path, binocular=False):
         if binocular:
+
             def clean_mask(mask_image):
                 green = [0, 255, 0, 255]
                 red = [255, 0, 0, 255]
                 threshold = 2
-                mask_image = np.where(np.abs(mask_image - green) <= threshold, green, mask_image)
-                mask_image = np.where(np.abs(mask_image - red) <= threshold, red, mask_image)
+                mask_image = np.where(
+                    np.abs(mask_image - green) <= threshold, green, mask_image
+                )
+                mask_image = np.where(
+                    np.abs(mask_image - red) <= threshold, red, mask_image
+                )
                 return mask_image
-            mask_path_L = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + "_L.png")
+
+            mask_path_L = os.path.join(
+                os.path.dirname(path), "mask", os.path.basename(path) + "_L.png"
+            )
             mask_image_L = clean_mask(cv2.imread(mask_path_L, cv2.IMREAD_UNCHANGED))
-            mask_path_R = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + "_R.png")
+            mask_path_R = os.path.join(
+                os.path.dirname(path), "mask", os.path.basename(path) + "_R.png"
+            )
             mask_image_R = clean_mask(cv2.imread(mask_path_R, cv2.IMREAD_UNCHANGED))
             return mask_image_L, mask_image_R
         else:
-            mask_path = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + ".png")
+            mask_path = os.path.join(
+                os.path.dirname(path), "mask", os.path.basename(path) + ".png"
+            )
             mask_image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
             return mask_image
-    
+
     @staticmethod
     def load_label(path):
-        with open(path, 'r') as f:
+        with open(path, "r") as f:
             label_data = json.load(f)
         return label_data
-    
+
     @staticmethod
     def load_rgb(path):
-        rgb_path = os.path.join(os.path.dirname(path), "rgb", os.path.basename(path) + ".png")
+        rgb_path = os.path.join(
+            os.path.dirname(path), "rgb", os.path.basename(path) + ".png"
+        )
        rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
         return rgb_image
-    
+
     @staticmethod
     def load_from_preprocessed_pts(path):
-        npy_path = os.path.join(os.path.dirname(path), "points", os.path.basename(path) + ".npy")
+        npy_path = os.path.join(
+            os.path.dirname(path), "points", os.path.basename(path) + ".npy"
+        )
         pts = np.load(npy_path)
         return pts
 
     @staticmethod
     def cam_pose_transformation(cam_pose_before):
-        offset = np.asarray([
-            [1, 0, 0, 0],
-            [0, -1, 0, 0],
-            [0, 0, -1, 0],
-            [0, 0, 0, 1]])
-        cam_pose_after = cam_pose_before @ offset
+        offset = np.asarray([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+        cam_pose_after = cam_pose_before @ offset
         return cam_pose_after
-    
+
     @staticmethod
     def load_cam_info(path, binocular=False, display_table_as_world_space_origin=True):
         scene_dir = os.path.dirname(path)
         root_dir = os.path.dirname(scene_dir)
         scene_name = os.path.basename(scene_dir)
-        camera_params_path = os.path.join(os.path.dirname(path), "camera_params", os.path.basename(path) + ".json")
-        with open(camera_params_path, 'r') as f:
+        camera_params_path = os.path.join(
+            os.path.dirname(path), "camera_params", os.path.basename(path) + ".json"
+        )
+        with open(camera_params_path, "r") as f:
             label_data = json.load(f)
         cam_to_world = np.asarray(label_data["extrinsic"])
         cam_to_world = DataLoadUtil.cam_pose_transformation(cam_to_world)
         world_to_display_table = np.eye(4)
-        world_to_display_table[:3, 3] = - DataLoadUtil.get_display_table_top(root_dir, scene_name)
+        world_to_display_table[:3, 3] = -DataLoadUtil.get_display_table_top(
+            root_dir, scene_name
+        )
         if display_table_as_world_space_origin:
             cam_to_world = np.dot(world_to_display_table, cam_to_world)
         cam_intrinsic = np.asarray(label_data["intrinsic"])
@@ -198,7 +239,7 @@ class DataLoadUtil:
             "cam_to_world": cam_to_world,
             "cam_intrinsic": cam_intrinsic,
             "far_plane": label_data["far_plane"],
-            "near_plane": label_data["near_plane"]
+            "near_plane": label_data["near_plane"],
         }
         if binocular:
             cam_to_world_R = np.asarray(label_data["extrinsic_R"])
@@ -211,104 +252,165 @@ class DataLoadUtil:
             cam_info["cam_to_world_O"] = cam_to_world_O
             cam_info["cam_to_world_R"] = cam_to_world_R
         return cam_info
-    
+
     @staticmethod
-    def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True):
+    def get_real_cam_O_from_cam_L(
+        cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True
+    ):
         root_dir = os.path.dirname(scene_path)
         scene_name = os.path.basename(scene_path)
         if isinstance(cam_L, torch.Tensor):
             cam_L = cam_L.cpu().numpy()
-        nO_to_display_table_pose = cam_L @ cam_O_to_cam_L 
+        nO_to_display_table_pose = cam_L @ cam_O_to_cam_L
         if display_table_as_world_space_origin:
             display_table_to_world = np.eye(4)
-            display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top(root_dir, scene_name)
+            display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top(
+                root_dir, scene_name
+            )
             nO_to_world_pose = np.dot(display_table_to_world, nO_to_display_table_pose)
         nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose)
         return nO_to_world_pose
-    
+
     @staticmethod
-    def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0,255,0,255)):
+    def get_target_point_cloud(
+        depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0, 255, 0, 255)
+    ):
         h, w = depth.shape
-        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
-        
+        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
+
         z = depth
         x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
         y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
-        
-        points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
-        mask = mask.reshape(-1,4)
-        target_mask = (mask == target_mask_label).all(axis=-1)
-        
+
+        points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
+        mask = mask.reshape(-1, 4)
+
+        target_mask = (mask == target_mask_label).all(axis=-1)
+
         target_points_camera = points_camera[target_mask]
-        target_points_camera_aug = np.concatenate([target_points_camera, np.ones((target_points_camera.shape[0], 1))], axis=-1)
-        
+        target_points_camera_aug = np.concatenate(
+            [target_points_camera, np.ones((target_points_camera.shape[0], 1))], axis=-1
+        )
+
         target_points_world = np.dot(cam_extrinsic, target_points_camera_aug.T).T[:, :3]
         return {
             "points_world": target_points_world,
-            "points_camera": target_points_camera
+            "points_camera": target_points_camera,
         }
-    
+
     @staticmethod
     def get_point_cloud(depth, cam_intrinsic, cam_extrinsic):
         h, w = depth.shape
-        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
-        
+        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
+
         z = depth
         x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
         y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
-        
         points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
-        points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
-        
+        points_camera_aug = np.concatenate(
+            [points_camera, np.ones((points_camera.shape[0], 1))], axis=-1
+        )
+
         points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
-        return {
-            "points_world": points_world,
-            "points_camera": points_camera
-        }
-    
+        return {"points_world": points_world, "points_camera": points_camera}
+
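A quick numeric check of the two steps above: `load_depth` maps a 16-bit PNG value v to min_depth + (max_depth - min_depth) * v / 65535, and `get_point_cloud` back-projects pixel (u, v) at depth z via x = (u - cx) z / fx, y = (v - cy) z / fy. The intrinsic matrix here is illustrative:

```python
import numpy as np

raw = np.full((480, 640), 32767, dtype=np.uint16)  # mid-range 16-bit depth
min_depth, max_depth = 0.01, 5.0
depth = min_depth + (max_depth - min_depth) * (raw.astype(np.float32) / 65535.0)

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])
i, j = np.meshgrid(np.arange(640), np.arange(480), indexing="xy")
x = (i - K[0, 2]) * depth / K[0, 0]
y = (j - K[1, 2]) * depth / K[1, 1]
pts_cam = np.stack((x, y, depth), axis=-1).reshape(-1, 3)
# the principal-point pixel (320, 240) lies on the optical axis: x = y = 0
print(pts_cam[240 * 640 + 320])  # [0. 0. 2.5049...]
```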
     @staticmethod
-    def get_target_point_cloud_world_from_path(path, binocular=False, random_downsample_N=65536, voxel_size = 0.005, target_mask_label=(0,255,0,255)):
+    def get_target_point_cloud_world_from_path(
+        path,
+        binocular=False,
+        random_downsample_N=65536,
+        voxel_size=0.005,
+        target_mask_label=(0, 255, 0, 255),
+        display_table_mask_label=(255, 0, 0, 255),
+        get_display_table_pts=False
+    ):
         cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular)
         if binocular:
-            depth_L, depth_R = DataLoadUtil.load_depth(path, cam_info['near_plane'], cam_info['far_plane'], binocular=True)
+            depth_L, depth_R = DataLoadUtil.load_depth(
+                path, cam_info["near_plane"], cam_info["far_plane"], binocular=True
+            )
             mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
-            point_cloud_L = DataLoadUtil.get_target_point_cloud(depth_L, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask_L, target_mask_label)['points_world']
-            point_cloud_R = DataLoadUtil.get_target_point_cloud(depth_R, cam_info['cam_intrinsic'], cam_info['cam_to_world_R'], mask_R, target_mask_label)['points_world']
-            point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, random_downsample_N)
-            point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, random_downsample_N)
-            overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size)
+            point_cloud_L = DataLoadUtil.get_target_point_cloud(
+                depth_L,
+                cam_info["cam_intrinsic"],
+                cam_info["cam_to_world"],
+                mask_L,
+                target_mask_label,
+            )["points_world"]
+            point_cloud_R = DataLoadUtil.get_target_point_cloud(
+                depth_R,
+                cam_info["cam_intrinsic"],
+                cam_info["cam_to_world_R"],
+                mask_R,
+                target_mask_label,
+            )["points_world"]
+            point_cloud_L = PtsUtil.random_downsample_point_cloud(
+                point_cloud_L, random_downsample_N
+            )
+            point_cloud_R = PtsUtil.random_downsample_point_cloud(
+                point_cloud_R, random_downsample_N
+            )
+            overlap_points = DataLoadUtil.get_overlapping_points(
+                point_cloud_L, point_cloud_R, voxel_size
+            )
+            if get_display_table_pts:
+                display_pts_L = DataLoadUtil.get_target_point_cloud(
+                    depth_L,
+                    cam_info["cam_intrinsic"],
+                    cam_info["cam_to_world"],
+                    mask_L,
+                    display_table_mask_label,
+                )["points_world"]
+                display_pts_R = DataLoadUtil.get_target_point_cloud(
+                    depth_R,
+                    cam_info["cam_intrinsic"],
+                    cam_info["cam_to_world_R"],
+                    mask_R,
+                    display_table_mask_label,
+                )["points_world"]
+                display_pts_overlap = DataLoadUtil.get_overlapping_points(
+                    display_pts_L, display_pts_R, voxel_size
+                )
+                return overlap_points, display_pts_overlap
             return overlap_points
         else:
-            depth = DataLoadUtil.load_depth(path, cam_info['near_plane'], cam_info['far_plane'])
+            depth = DataLoadUtil.load_depth(
+                path, cam_info["near_plane"], cam_info["far_plane"]
+            )
             mask = DataLoadUtil.load_seg(path)
-            point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask)['points_world']
+            point_cloud = DataLoadUtil.get_target_point_cloud(
+                depth, cam_info["cam_intrinsic"], cam_info["cam_to_world"], mask
+            )["points_world"]
             return point_cloud
-        
-    
+
     @staticmethod
     def voxelize_points(points, voxel_size):
-        
+
         voxel_indices = np.floor(points / voxel_size).astype(np.int32)
         unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
         return unique_voxels
-    
+
     @staticmethod
     def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005):
         voxels_L, indices_L = DataLoadUtil.voxelize_points(point_cloud_L, voxel_size)
         voxels_R, _ = DataLoadUtil.voxelize_points(point_cloud_R, voxel_size)
 
-        voxel_indices_L = voxels_L.view([('', voxels_L.dtype)]*3)
-        voxel_indices_R = voxels_R.view([('', voxels_R.dtype)]*3)
+        voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
+        voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
         overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
-        mask_L = np.isin(indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0])
+        mask_L = np.isin(
+            indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
+        )
         overlapping_points = point_cloud_L[mask_L]
         return overlapping_points
-    
+
     @staticmethod
     def load_points_normals(root, scene_name, display_table_as_world_space_origin=True):
         points_path = os.path.join(root, scene_name, "points_and_normals.txt")
         points_normals = np.loadtxt(points_path)
         if display_table_as_world_space_origin:
-            points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.get_display_table_top(root, scene_name)
-        return points_normals
\ No newline at end of file
+            points_normals[:, :3] = points_normals[
+                :, :3
+            ] - DataLoadUtil.get_display_table_top(root, scene_name)
+        return points_normals
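The voxel-intersection trick in `get_overlapping_points` above works by viewing each (N, 3) integer voxel array as a structured array, so rows become scalar records that `np.intersect1d` can match whole-voxel-at-a-time. A self-contained sketch of the same technique:

```python
import numpy as np

def overlap(pc_a, pc_b, voxel_size=0.005):
    va = np.floor(pc_a / voxel_size).astype(np.int32)
    vb = np.floor(pc_b / voxel_size).astype(np.int32)
    ua, inv_a = np.unique(va, axis=0, return_inverse=True)
    ub, _ = np.unique(vb, axis=0, return_inverse=True)
    ra = ua.view([("", ua.dtype)] * 3)  # rows -> structured scalars
    rb = ub.view([("", ub.dtype)] * 3)
    shared = np.intersect1d(ra, rb)     # voxels occupied by both clouds
    # map shared voxels back to A's points via the return_inverse indices
    keep = np.isin(inv_a, np.where(np.isin(ra, shared))[0])
    return pc_a[keep]

a = np.array([[0.001, 0.001, 0.001], [0.1, 0.1, 0.1]])
b = np.array([[0.002, 0.000, 0.004]])
print(overlap(a, b))  # only the first point shares a voxel with b
```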
diff --git a/utils/reconstruction.py b/utils/reconstruction.py
index 7b9d5cf..0b78e59 100644
--- a/utils/reconstruction.py
+++ b/utils/reconstruction.py
@@ -45,9 +45,10 @@ class ReconstructionUtil:
 
     @staticmethod
-    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list,threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None):
+    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None):
         selected_views = [point_cloud_list[init_view]]
         combined_point_cloud = np.vstack(selected_views)
+        combined_scan_points_indices = scan_points_indices_list[init_view]
         down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold)
         new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
         current_coverage = new_coverage
@@ -63,12 +64,14 @@ class ReconstructionUtil:
 
         for view_index in remaining_views:
             if selected_views:
-                combined_old_point_cloud = np.vstack(selected_views)
-                down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold)
-                down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold)
-                overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold)
-                if overlap_rate < overlap_threshold:
-                    continue
+                new_scan_points_indices = scan_points_indices_list[view_index]
+                if not ReconstructionUtil.check_scan_points_overlap(combined_scan_points_indices, new_scan_points_indices):
+                    combined_old_point_cloud = np.vstack(selected_views)
+                    down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold)
+                    down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold)
+                    overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold)
+                    if overlap_rate < overlap_threshold:
+                        continue
 
             candidate_views = selected_views + [point_cloud_list[view_index]]
             combined_point_cloud = np.vstack(candidate_views)
@@ -85,6 +88,7 @@ class ReconstructionUtil:
                 break
             selected_views.append(point_cloud_list[best_view])
             remaining_views.remove(best_view)
+            combined_scan_points_indices = ReconstructionUtil.combine_scan_points_indices(combined_scan_points_indices, scan_points_indices_list[best_view])
             current_coverage += best_coverage_increase
             cnt_processed_view += 1
             if status_info is not None:
@@ -120,4 +124,39 @@ class ReconstructionUtil:
 
         filtered_sampled_points= sampled_points[cos_theta > np.cos(theta_rad)]
         return filtered_sampled_points[:, :3]
-    
\ No newline at end of file
+
+    @staticmethod
+    def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 100, max_attempts = 1000):
+        points = []
+        attempts = 0
+        while len(points) < max_points_num and attempts < max_attempts:
+            angle = np.random.uniform(0, 2 * np.pi)
+            r = np.random.uniform(0, display_table_radius)
+            x = r * np.cos(angle)
+            y = r * np.sin(angle)
+            z = display_table_top[2]
+            new_point = (x, y, z)
+            if all(np.linalg.norm(np.array(new_point) - np.array(existing_point)) >= min_distance for existing_point in points):
+                points.append(new_point)
+            attempts += 1
+        return points
+
+    @staticmethod
+    def compute_covered_scan_points(scan_points, point_cloud, threshold=0.01):
+        tree = cKDTree(point_cloud)
+        covered_points = []
+        indices = []
+        for i, scan_point in enumerate(scan_points):
+            if tree.query_ball_point(scan_point, threshold):
+                covered_points.append(scan_point)
+                indices.append(i)
+        return covered_points, indices
+
+    @staticmethod
+    def check_scan_points_overlap(indices1, indices2, threshold=5):
+        return len(set(indices1).intersection(set(indices2))) > threshold
+
+    @staticmethod
+    def combine_scan_points_indices(indices1, indices2):
+        combined_indices = set(indices1) | set(indices2)
+        return sorted(combined_indices)
\ No newline at end of file
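The new utilities close the loop: `generate_scan_points` rejection-samples table-top probe points (note `get_display_table_top` returns a 3-vector, so only its z-component is a height; also, sampling r uniformly biases points toward the disk center — area-uniform sampling would use the square root of a uniform variate), and `compute_covered_scan_points` marks a scan point covered when any cloud point lies within `threshold` of it. A usage sketch — this assumes SciPy is available and that `from scipy.spatial import cKDTree` is present at the top of reconstruction.py, which the hunks above do not show:

```python
import numpy as np
from scipy.spatial import cKDTree

def generate_scan_points(top_z, radius, min_distance=0.03,
                         max_points_num=100, max_attempts=1000):
    # rejection sampling: keep a candidate only if it stays min_distance
    # away from every previously accepted point on the table disk
    points, attempts = [], 0
    while len(points) < max_points_num and attempts < max_attempts:
        angle = np.random.uniform(0, 2 * np.pi)
        r = np.random.uniform(0, radius)
        cand = (r * np.cos(angle), r * np.sin(angle), top_z)
        if all(np.linalg.norm(np.subtract(cand, p)) >= min_distance for p in points):
            points.append(cand)
        attempts += 1
    return points

scan_points = generate_scan_points(top_z=0.9, radius=0.25)
cloud = np.random.rand(1000, 3) * [0.5, 0.5, 0.1] + [-0.25, -0.25, 0.85]
tree = cKDTree(cloud)
covered = [i for i, p in enumerate(scan_points) if tree.query_ball_point(p, 0.01)]
print(len(scan_points), "scan points,", len(covered), "covered")
```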