def forward(self, data):
    """Dispatch the pipeline to the train or test pass.

    Args:
        data: batch dict; must contain "mode", one of namespace.Mode.TRAIN
            or namespace.Mode.TEST.

    Returns:
        The dict produced by forward_train / forward_test; on an unknown
        mode, Log.error is invoked with terminate=True.
    """
    run_mode = data["mode"]
    if run_mode == namespace.Mode.TRAIN:
        return self.forward_train(data)
    if run_mode == namespace.Mode.TEST:
        return self.forward_test(data)
    # Unknown mode: report and terminate (second arg requests a hard stop).
    Log.error("Unknown mode: {}".format(run_mode), True)
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps + random_t = ( + torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps + ) random_t = random_t.unsqueeze(-1) mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t) std = std.view(-1, 1) z = torch.randn_like(gt_delta_9d) perturbed_x = mu + z * std - target_score = - z * std / (std ** 2) + target_score = -z * std / (std**2) return perturbed_x, random_t, target_score, std - + def forward_train(self, data): main_feat = self.get_main_feat(data) - ''' get std ''' + """ get std """ best_to_world_pose_9d_batch = data["best_to_world_pose_9d"] - perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch) + perturbed_x, random_t, target_score, std = self.pertube_data( + best_to_world_pose_9d_batch + ) input_data = { "sampled_pose": perturbed_x, "t": random_t, @@ -56,45 +70,69 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): output = { "estimated_score": estimated_score, "target_score": target_score, - "std": std + "std": std, } return output - - def forward_test(self,data): + + def forward_test(self, data): main_feat = self.get_main_feat(data) - estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat) + estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view( + main_feat + ) result = { "pred_pose_9d": estimated_delta_rot_9d, - "in_process_sample": in_process_sample + "in_process_sample": in_process_sample, } return result - - + def get_main_feat(self, data): - scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d'] - scanned_target_pts_num_batch = data['scanned_target_points_num'] + scanned_n_to_world_pose_9d_batch = data[ + "scanned_n_to_world_pose_9d" + ] # List(B): Tensor(S x 9) + scanned_pts_mask_batch = data[ + "scanned_pts_mask" + ] # Tensor(B x N) device = next(self.parameters()).device embedding_list_batch = [] - - for 
scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch): - scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) - scanned_target_pts_num = scanned_target_pts_num.to(device) - pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) - pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) - embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1)) - main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) - - - combined_scanned_pts_batch = data['combined_scanned_pts'] - global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch) - main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1) - - + combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3) + global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points( + combined_scanned_pts_batch, require_per_point_feat=True + ) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl) + + for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip( + scanned_n_to_world_pose_9d_batch, + scanned_pts_mask_batch, + perpoint_scanned_feat_batch, + ): + scanned_target_pts_num = [] # List(S): Int + partial_feat_seq = [] + + seq_len = len(scanned_n_to_world_pose_9d) + for seq_idx in range(seq_len): + partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask + partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl) + partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl) + partial_feat_seq.append(partial_feat) + scanned_target_pts_num.append(partial_perpoint_feat.shape[0]) + + scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S) + scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9) + + pose_feat_seq = 
self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp) + pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn) + partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl) + + seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl)) + embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl)) + + seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds) + + main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg)) + if torch.isnan(main_feat).any(): Log.error("nan in main_feat", True) - + return main_feat - diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 3880ed5..4013d98 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -122,7 +122,6 @@ class NBVReconstructionDataset(BaseDataset): scanned_views_pts, scanned_coverages_rate, scanned_n_to_world_pose, - scanned_target_pts_num, ) = ([], [], [], []) for view in scanned_views: frame_idx = view[0] @@ -134,7 +133,6 @@ class NBVReconstructionDataset(BaseDataset): target_point_cloud = ( DataLoadUtil.load_from_preprocessed_pts(view_path) ) - target_pts_num = target_point_cloud.shape[0] downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( target_point_cloud, self.pts_num ) @@ -146,7 +144,7 @@ class NBVReconstructionDataset(BaseDataset): n_to_world_trans = n_to_world_pose[:3, 3] n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) scanned_n_to_world_pose.append(n_to_world_9d) - scanned_target_pts_num.append(target_pts_num) + nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) @@ -162,35 +160,33 @@ class NBVReconstructionDataset(BaseDataset): ) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) - fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud( - 
combined_scanned_views_pts, self.pts_num, require_mask=True + fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud( + combined_scanned_views_pts, self.pts_num, require_idx=True ) - view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]]) - scanned_pts_mask = [] - - for i, start_idx in enumerate(view_start_indices[:-1]): - end_idx = view_start_indices[i + 1] - view_mask = fps_mask[start_idx:end_idx] - scanned_pts_mask.append(view_mask) + combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8) + + start_idx = 0 + for i in range(len(scanned_views_pts)): + end_idx = start_idx + len(scanned_views_pts[i]) + combined_scanned_views_pts_mask[start_idx:end_idx] = i + start_idx = end_idx + + fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx] + + + data_item = { - "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), - "scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8), - "combined_scanned_pts": np.asarray( - fps_downsampled_combined_scanned_pts, dtype=np.float32 - ), - "scanned_coverage_rate": scanned_coverages_rate, - "scanned_n_to_world_pose_9d": np.asarray( - scanned_n_to_world_pose, dtype=np.float32 - ), - "best_coverage_rate": nbv_coverage_rate, - "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), - "seq_max_coverage_rate": max_coverage_rate, - "scene_name": scene_name, - "scanned_target_points_num": np.asarray( - scanned_target_pts_num, dtype=np.int32 - ), + "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3) + "scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S) + "combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3) + "scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1) + "scanned_n_to_world_pose_9d": 
def collate_fn(batch):
    """Collate NBV dataset items into a batch dict.

    Variable-length fields (per-view sequences of differing length S) are
    kept as lists of tensors; fixed-length fields are stacked into batched
    tensors; everything else is passed through as a plain list.

    Args:
        batch: list of data_item dicts produced by __getitem__.

    Returns:
        dict with list-of-Tensor entries for "scanned_pts" and
        "scanned_n_to_world_pose_9d", stacked Tensor entries for
        "best_to_world_pose_9d", "combined_scanned_pts" and
        "scanned_pts_mask", and lists for all remaining keys.
    """
    collate_data = {}

    # ------ Variable length (S differs per item): keep as lists ------
    collate_data["scanned_pts"] = [
        torch.tensor(item["scanned_pts"]) for item in batch
    ]
    collate_data["scanned_n_to_world_pose_9d"] = [
        torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
    ]

    # ------ Fixed length: stack into (B x ...) tensors ------
    collate_data["best_to_world_pose_9d"] = torch.stack(
        [torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
    )
    collate_data["combined_scanned_pts"] = torch.stack(
        [torch.tensor(item["combined_scanned_pts"]) for item in batch]
    )
    collate_data["scanned_pts_mask"] = torch.stack(
        [torch.tensor(item["scanned_pts_mask"]) for item in batch]
    )

    # Pass every remaining key through untouched, as a per-item list.
    for key in batch[0].keys():
        if key not in [
            "scanned_pts",
            "scanned_pts_mask",
            "scanned_n_to_world_pose_9d",
            "best_to_world_pose_9d",
            "combined_scanned_pts",
        ]:
            collate_data[key] = [item[key] for item in batch]
    return collate_data
fps_downsample_point_cloud(point_cloud, num_points, require_mask=False): + def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False): N = point_cloud.shape[0] mask = np.zeros(N, dtype=bool) sampled_indices = np.zeros(num_points, dtype=int) sampled_indices[0] = np.random.randint(0, N) - mask[sampled_indices[0]] = True distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1) for i in range(1, num_points): farthest_index = np.argmax(distances) @@ -41,8 +40,8 @@ class PtsUtil: distances = np.minimum(distances, new_distances) sampled_points = point_cloud[sampled_indices] - if require_mask: - return sampled_points, mask + if require_idx: + return sampled_points, sampled_indices return sampled_points @staticmethod