update fps algo and fps mask
This commit is contained in:
parent
276f45dcc3
commit
fa69f9f879
@ -12,11 +12,21 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
|
||||
self.config = config
|
||||
self.module_config = config["modules"]
|
||||
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
|
||||
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
|
||||
self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
|
||||
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
|
||||
self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
|
||||
self.pts_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
|
||||
)
|
||||
self.pose_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
|
||||
)
|
||||
self.pose_n_num_seq_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
|
||||
)
|
||||
self.view_finder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["view_finder"]
|
||||
)
|
||||
self.pts_num_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
|
||||
)
|
||||
|
||||
self.eps = float(self.config["eps"])
|
||||
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
|
||||
@ -33,20 +43,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
def pertube_data(self, gt_delta_9d):
|
||||
bs = gt_delta_9d.shape[0]
|
||||
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
|
||||
random_t = (
|
||||
torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps
|
||||
)
|
||||
random_t = random_t.unsqueeze(-1)
|
||||
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
|
||||
std = std.view(-1, 1)
|
||||
z = torch.randn_like(gt_delta_9d)
|
||||
perturbed_x = mu + z * std
|
||||
target_score = - z * std / (std ** 2)
|
||||
target_score = -z * std / (std**2)
|
||||
return perturbed_x, random_t, target_score, std
|
||||
|
||||
def forward_train(self, data):
|
||||
main_feat = self.get_main_feat(data)
|
||||
''' get std '''
|
||||
""" get std """
|
||||
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
|
||||
perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch)
|
||||
perturbed_x, random_t, target_score, std = self.pertube_data(
|
||||
best_to_world_pose_9d_batch
|
||||
)
|
||||
input_data = {
|
||||
"sampled_pose": perturbed_x,
|
||||
"t": random_t,
|
||||
@ -56,45 +70,69 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
output = {
|
||||
"estimated_score": estimated_score,
|
||||
"target_score": target_score,
|
||||
"std": std
|
||||
"std": std,
|
||||
}
|
||||
return output
|
||||
|
||||
def forward_test(self,data):
|
||||
def forward_test(self, data):
|
||||
main_feat = self.get_main_feat(data)
|
||||
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
|
||||
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(
|
||||
main_feat
|
||||
)
|
||||
result = {
|
||||
"pred_pose_9d": estimated_delta_rot_9d,
|
||||
"in_process_sample": in_process_sample
|
||||
"in_process_sample": in_process_sample,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def get_main_feat(self, data):
|
||||
scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
|
||||
scanned_target_pts_num_batch = data['scanned_target_points_num']
|
||||
scanned_n_to_world_pose_9d_batch = data[
|
||||
"scanned_n_to_world_pose_9d"
|
||||
] # List(B): Tensor(S x 9)
|
||||
scanned_pts_mask_batch = data[
|
||||
"scanned_pts_mask"
|
||||
] # Tensor(B x N)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
embedding_list_batch = []
|
||||
|
||||
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
|
||||
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
|
||||
scanned_target_pts_num = scanned_target_pts_num.to(device)
|
||||
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
|
||||
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
|
||||
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
|
||||
combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3)
|
||||
global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
|
||||
combined_scanned_pts_batch, require_per_point_feat=True
|
||||
) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
|
||||
|
||||
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
|
||||
for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
|
||||
scanned_n_to_world_pose_9d_batch,
|
||||
scanned_pts_mask_batch,
|
||||
perpoint_scanned_feat_batch,
|
||||
):
|
||||
scanned_target_pts_num = [] # List(S): Int
|
||||
partial_feat_seq = []
|
||||
|
||||
seq_len = len(scanned_n_to_world_pose_9d)
|
||||
for seq_idx in range(seq_len):
|
||||
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
|
||||
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
|
||||
partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl)
|
||||
partial_feat_seq.append(partial_feat)
|
||||
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
|
||||
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
||||
scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S)
|
||||
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
|
||||
|
||||
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
|
||||
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn)
|
||||
partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl)
|
||||
|
||||
seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl))
|
||||
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl))
|
||||
|
||||
seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
|
||||
|
||||
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
|
||||
|
||||
if torch.isnan(main_feat).any():
|
||||
Log.error("nan in main_feat", True)
|
||||
|
||||
return main_feat
|
||||
|
||||
|
@ -122,7 +122,6 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
scanned_views_pts,
|
||||
scanned_coverages_rate,
|
||||
scanned_n_to_world_pose,
|
||||
scanned_target_pts_num,
|
||||
) = ([], [], [], [])
|
||||
for view in scanned_views:
|
||||
frame_idx = view[0]
|
||||
@ -134,7 +133,6 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
target_point_cloud = (
|
||||
DataLoadUtil.load_from_preprocessed_pts(view_path)
|
||||
)
|
||||
target_pts_num = target_point_cloud.shape[0]
|
||||
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
|
||||
target_point_cloud, self.pts_num
|
||||
)
|
||||
@ -146,7 +144,7 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
n_to_world_trans = n_to_world_pose[:3, 3]
|
||||
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
|
||||
scanned_n_to_world_pose.append(n_to_world_9d)
|
||||
scanned_target_pts_num.append(target_pts_num)
|
||||
|
||||
|
||||
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
|
||||
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
|
||||
@ -162,35 +160,33 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
)
|
||||
|
||||
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
|
||||
fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud(
|
||||
combined_scanned_views_pts, self.pts_num, require_mask=True
|
||||
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
||||
combined_scanned_views_pts, self.pts_num, require_idx=True
|
||||
)
|
||||
|
||||
view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]])
|
||||
scanned_pts_mask = []
|
||||
combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8)
|
||||
|
||||
start_idx = 0
|
||||
for i in range(len(scanned_views_pts)):
|
||||
end_idx = start_idx + len(scanned_views_pts[i])
|
||||
combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
||||
start_idx = end_idx
|
||||
|
||||
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
||||
|
||||
|
||||
|
||||
for i, start_idx in enumerate(view_start_indices[:-1]):
|
||||
end_idx = view_start_indices[i + 1]
|
||||
view_mask = fps_mask[start_idx:end_idx]
|
||||
scanned_pts_mask.append(view_mask)
|
||||
|
||||
data_item = {
|
||||
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),
|
||||
"scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8),
|
||||
"combined_scanned_pts": np.asarray(
|
||||
fps_downsampled_combined_scanned_pts, dtype=np.float32
|
||||
),
|
||||
"scanned_coverage_rate": scanned_coverages_rate,
|
||||
"scanned_n_to_world_pose_9d": np.asarray(
|
||||
scanned_n_to_world_pose, dtype=np.float32
|
||||
),
|
||||
"best_coverage_rate": nbv_coverage_rate,
|
||||
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
|
||||
"seq_max_coverage_rate": max_coverage_rate,
|
||||
"scene_name": scene_name,
|
||||
"scanned_target_points_num": np.asarray(
|
||||
scanned_target_pts_num, dtype=np.int32
|
||||
),
|
||||
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3)
|
||||
"scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S)
|
||||
"combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3)
|
||||
"scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1)
|
||||
"scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9)
|
||||
"best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1)
|
||||
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), # Ndarray(9)
|
||||
"seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1)
|
||||
"scene_name": scene_name, # String
|
||||
}
|
||||
|
||||
return data_item
|
||||
@ -201,37 +197,35 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
def get_collate_fn(self):
|
||||
def collate_fn(batch):
|
||||
collate_data = {}
|
||||
|
||||
''' ------ Varialbe Length ------ '''
|
||||
|
||||
collate_data["scanned_pts"] = [
|
||||
torch.tensor(item["scanned_pts"]) for item in batch
|
||||
]
|
||||
collate_data["scanned_pts_mask"] = [
|
||||
torch.tensor(item["scanned_pts_mask"]) for item in batch
|
||||
]
|
||||
collate_data["scanned_n_to_world_pose_9d"] = [
|
||||
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
|
||||
]
|
||||
collate_data["scanned_target_points_num"] = [
|
||||
torch.tensor(item["scanned_target_points_num"]) for item in batch
|
||||
]
|
||||
|
||||
''' ------ Fixed Length ------ '''
|
||||
|
||||
collate_data["best_to_world_pose_9d"] = torch.stack(
|
||||
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
|
||||
)
|
||||
collate_data["combined_scanned_pts"] = torch.stack(
|
||||
[torch.tensor(item["combined_scanned_pts"]) for item in batch]
|
||||
)
|
||||
if "first_frame_to_world" in batch[0]:
|
||||
collate_data["first_frame_to_world"] = torch.stack(
|
||||
[torch.tensor(item["first_frame_to_world"]) for item in batch]
|
||||
collate_data["scanned_pts_mask"] = torch.stack(
|
||||
[torch.tensor(item["scanned_pts_mask"]) for item in batch]
|
||||
)
|
||||
|
||||
for key in batch[0].keys():
|
||||
if key not in [
|
||||
"scanned_pts",
|
||||
"scanned_pts_mask",
|
||||
"scanned_n_to_world_pose_9d",
|
||||
"best_to_world_pose_9d",
|
||||
"first_frame_to_world",
|
||||
"combined_scanned_pts",
|
||||
"scanned_target_points_num",
|
||||
]:
|
||||
collate_data[key] = [item[key] for item in batch]
|
||||
return collate_data
|
||||
|
@ -24,13 +24,12 @@ class PtsUtil:
|
||||
return point_cloud[idx]
|
||||
|
||||
@staticmethod
|
||||
def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False):
|
||||
def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
|
||||
N = point_cloud.shape[0]
|
||||
mask = np.zeros(N, dtype=bool)
|
||||
|
||||
sampled_indices = np.zeros(num_points, dtype=int)
|
||||
sampled_indices[0] = np.random.randint(0, N)
|
||||
mask[sampled_indices[0]] = True
|
||||
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
|
||||
for i in range(1, num_points):
|
||||
farthest_index = np.argmax(distances)
|
||||
@ -41,8 +40,8 @@ class PtsUtil:
|
||||
distances = np.minimum(distances, new_distances)
|
||||
|
||||
sampled_points = point_cloud[sampled_indices]
|
||||
if require_mask:
|
||||
return sampled_points, mask
|
||||
if require_idx:
|
||||
return sampled_points, sampled_indices
|
||||
return sampled_points
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user