update fps algo and fps mask

This commit is contained in:
hofee 2024-10-06 13:48:54 +08:00
parent 276f45dcc3
commit fa69f9f879
3 changed files with 115 additions and 84 deletions

View File

@ -12,11 +12,21 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
super(NBVReconstructionGlobalPointsPipeline, self).__init__() super(NBVReconstructionGlobalPointsPipeline, self).__init__()
self.config = config self.config = config
self.module_config = config["modules"] self.module_config = config["modules"]
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"]) self.pts_encoder = ComponentFactory.create(
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"]) namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]) )
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"]) self.pose_encoder = ComponentFactory.create(
self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]) namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
)
self.pose_n_num_seq_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
)
self.view_finder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["view_finder"]
)
self.pts_num_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
)
self.eps = float(self.config["eps"]) self.eps = float(self.config["eps"])
self.enable_global_scanned_feat = self.config["global_scanned_feat"] self.enable_global_scanned_feat = self.config["global_scanned_feat"]
@ -33,20 +43,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
def pertube_data(self, gt_delta_9d): def pertube_data(self, gt_delta_9d):
bs = gt_delta_9d.shape[0] bs = gt_delta_9d.shape[0]
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps random_t = (
torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps
)
random_t = random_t.unsqueeze(-1) random_t = random_t.unsqueeze(-1)
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t) mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
std = std.view(-1, 1) std = std.view(-1, 1)
z = torch.randn_like(gt_delta_9d) z = torch.randn_like(gt_delta_9d)
perturbed_x = mu + z * std perturbed_x = mu + z * std
target_score = - z * std / (std ** 2) target_score = -z * std / (std**2)
return perturbed_x, random_t, target_score, std return perturbed_x, random_t, target_score, std
def forward_train(self, data): def forward_train(self, data):
main_feat = self.get_main_feat(data) main_feat = self.get_main_feat(data)
''' get std ''' """ get std """
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"] best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch) perturbed_x, random_t, target_score, std = self.pertube_data(
best_to_world_pose_9d_batch
)
input_data = { input_data = {
"sampled_pose": perturbed_x, "sampled_pose": perturbed_x,
"t": random_t, "t": random_t,
@ -56,45 +70,69 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
output = { output = {
"estimated_score": estimated_score, "estimated_score": estimated_score,
"target_score": target_score, "target_score": target_score,
"std": std "std": std,
} }
return output return output
def forward_test(self,data): def forward_test(self, data):
main_feat = self.get_main_feat(data) main_feat = self.get_main_feat(data)
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat) estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(
main_feat
)
result = { result = {
"pred_pose_9d": estimated_delta_rot_9d, "pred_pose_9d": estimated_delta_rot_9d,
"in_process_sample": in_process_sample "in_process_sample": in_process_sample,
} }
return result return result
def get_main_feat(self, data): def get_main_feat(self, data):
scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d'] scanned_n_to_world_pose_9d_batch = data[
scanned_target_pts_num_batch = data['scanned_target_points_num'] "scanned_n_to_world_pose_9d"
] # List(B): Tensor(S x 9)
scanned_pts_mask_batch = data[
"scanned_pts_mask"
] # Tensor(B x N)
device = next(self.parameters()).device device = next(self.parameters()).device
embedding_list_batch = [] embedding_list_batch = []
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch): combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3)
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
scanned_target_pts_num = scanned_target_pts_num.to(device) combined_scanned_pts_batch, require_per_point_feat=True
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) ) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
scanned_n_to_world_pose_9d_batch,
scanned_pts_mask_batch,
perpoint_scanned_feat_batch,
):
scanned_target_pts_num = [] # List(S): Int
partial_feat_seq = []
seq_len = len(scanned_n_to_world_pose_9d)
for seq_idx in range(seq_len):
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl)
partial_feat_seq.append(partial_feat)
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
combined_scanned_pts_batch = data['combined_scanned_pts'] scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S)
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch) scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn)
partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl)
seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl))
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl))
seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
if torch.isnan(main_feat).any(): if torch.isnan(main_feat).any():
Log.error("nan in main_feat", True) Log.error("nan in main_feat", True)
return main_feat return main_feat

View File

@ -122,7 +122,6 @@ class NBVReconstructionDataset(BaseDataset):
scanned_views_pts, scanned_views_pts,
scanned_coverages_rate, scanned_coverages_rate,
scanned_n_to_world_pose, scanned_n_to_world_pose,
scanned_target_pts_num,
) = ([], [], [], []) ) = ([], [], [], [])
for view in scanned_views: for view in scanned_views:
frame_idx = view[0] frame_idx = view[0]
@ -134,7 +133,6 @@ class NBVReconstructionDataset(BaseDataset):
target_point_cloud = ( target_point_cloud = (
DataLoadUtil.load_from_preprocessed_pts(view_path) DataLoadUtil.load_from_preprocessed_pts(view_path)
) )
target_pts_num = target_point_cloud.shape[0]
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
target_point_cloud, self.pts_num target_point_cloud, self.pts_num
) )
@ -146,7 +144,7 @@ class NBVReconstructionDataset(BaseDataset):
n_to_world_trans = n_to_world_pose[:3, 3] n_to_world_trans = n_to_world_pose[:3, 3]
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
scanned_n_to_world_pose.append(n_to_world_9d) scanned_n_to_world_pose.append(n_to_world_9d)
scanned_target_pts_num.append(target_pts_num)
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
@ -162,35 +160,33 @@ class NBVReconstructionDataset(BaseDataset):
) )
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud( fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
combined_scanned_views_pts, self.pts_num, require_mask=True combined_scanned_views_pts, self.pts_num, require_idx=True
) )
view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]]) combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8)
scanned_pts_mask = []
start_idx = 0
for i in range(len(scanned_views_pts)):
end_idx = start_idx + len(scanned_views_pts[i])
combined_scanned_views_pts_mask[start_idx:end_idx] = i
start_idx = end_idx
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
for i, start_idx in enumerate(view_start_indices[:-1]):
end_idx = view_start_indices[i + 1]
view_mask = fps_mask[start_idx:end_idx]
scanned_pts_mask.append(view_mask)
data_item = { data_item = {
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3)
"scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8), "scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S)
"combined_scanned_pts": np.asarray( "combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3)
fps_downsampled_combined_scanned_pts, dtype=np.float32 "scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1)
), "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9)
"scanned_coverage_rate": scanned_coverages_rate, "best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1)
"scanned_n_to_world_pose_9d": np.asarray( "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), # Ndarray(9)
scanned_n_to_world_pose, dtype=np.float32 "seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1)
), "scene_name": scene_name, # String
"best_coverage_rate": nbv_coverage_rate,
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
"seq_max_coverage_rate": max_coverage_rate,
"scene_name": scene_name,
"scanned_target_points_num": np.asarray(
scanned_target_pts_num, dtype=np.int32
),
} }
return data_item return data_item
@ -201,37 +197,35 @@ class NBVReconstructionDataset(BaseDataset):
def get_collate_fn(self): def get_collate_fn(self):
def collate_fn(batch): def collate_fn(batch):
collate_data = {} collate_data = {}
''' ------ Varialbe Length ------ '''
collate_data["scanned_pts"] = [ collate_data["scanned_pts"] = [
torch.tensor(item["scanned_pts"]) for item in batch torch.tensor(item["scanned_pts"]) for item in batch
] ]
collate_data["scanned_pts_mask"] = [
torch.tensor(item["scanned_pts_mask"]) for item in batch
]
collate_data["scanned_n_to_world_pose_9d"] = [ collate_data["scanned_n_to_world_pose_9d"] = [
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
] ]
collate_data["scanned_target_points_num"] = [
torch.tensor(item["scanned_target_points_num"]) for item in batch ''' ------ Fixed Length ------ '''
]
collate_data["best_to_world_pose_9d"] = torch.stack( collate_data["best_to_world_pose_9d"] = torch.stack(
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch] [torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
) )
collate_data["combined_scanned_pts"] = torch.stack( collate_data["combined_scanned_pts"] = torch.stack(
[torch.tensor(item["combined_scanned_pts"]) for item in batch] [torch.tensor(item["combined_scanned_pts"]) for item in batch]
) )
if "first_frame_to_world" in batch[0]: collate_data["scanned_pts_mask"] = torch.stack(
collate_data["first_frame_to_world"] = torch.stack( [torch.tensor(item["scanned_pts_mask"]) for item in batch]
[torch.tensor(item["first_frame_to_world"]) for item in batch] )
)
for key in batch[0].keys(): for key in batch[0].keys():
if key not in [ if key not in [
"scanned_pts", "scanned_pts",
"scanned_pts_mask", "scanned_pts_mask",
"scanned_n_to_world_pose_9d", "scanned_n_to_world_pose_9d",
"best_to_world_pose_9d", "best_to_world_pose_9d",
"first_frame_to_world",
"combined_scanned_pts", "combined_scanned_pts",
"scanned_target_points_num",
]: ]:
collate_data[key] = [item[key] for item in batch] collate_data[key] = [item[key] for item in batch]
return collate_data return collate_data

View File

@ -24,13 +24,12 @@ class PtsUtil:
return point_cloud[idx] return point_cloud[idx]
@staticmethod @staticmethod
def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False): def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
N = point_cloud.shape[0] N = point_cloud.shape[0]
mask = np.zeros(N, dtype=bool) mask = np.zeros(N, dtype=bool)
sampled_indices = np.zeros(num_points, dtype=int) sampled_indices = np.zeros(num_points, dtype=int)
sampled_indices[0] = np.random.randint(0, N) sampled_indices[0] = np.random.randint(0, N)
mask[sampled_indices[0]] = True
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1) distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
for i in range(1, num_points): for i in range(1, num_points):
farthest_index = np.argmax(distances) farthest_index = np.argmax(distances)
@ -41,8 +40,8 @@ class PtsUtil:
distances = np.minimum(distances, new_distances) distances = np.minimum(distances, new_distances)
sampled_points = point_cloud[sampled_indices] sampled_points = point_cloud[sampled_indices]
if require_mask: if require_idx:
return sampled_points, mask return sampled_points, sampled_indices
return sampled_points return sampled_points
@staticmethod @staticmethod