update fps algo and fps mask
This commit is contained in:
parent
276f45dcc3
commit
fa69f9f879
@ -12,41 +12,55 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
|||||||
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
|
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.module_config = config["modules"]
|
self.module_config = config["modules"]
|
||||||
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
|
self.pts_encoder = ComponentFactory.create(
|
||||||
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
|
namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
|
||||||
self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
|
)
|
||||||
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
|
self.pose_encoder = ComponentFactory.create(
|
||||||
self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
|
namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
|
||||||
|
)
|
||||||
|
self.pose_n_num_seq_encoder = ComponentFactory.create(
|
||||||
|
namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
|
||||||
|
)
|
||||||
|
self.view_finder = ComponentFactory.create(
|
||||||
|
namespace.Stereotype.MODULE, self.module_config["view_finder"]
|
||||||
|
)
|
||||||
|
self.pts_num_encoder = ComponentFactory.create(
|
||||||
|
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
|
||||||
|
)
|
||||||
|
|
||||||
self.eps = float(self.config["eps"])
|
self.eps = float(self.config["eps"])
|
||||||
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
|
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
|
||||||
|
|
||||||
def forward(self, data):
|
def forward(self, data):
|
||||||
mode = data["mode"]
|
mode = data["mode"]
|
||||||
|
|
||||||
if mode == namespace.Mode.TRAIN:
|
if mode == namespace.Mode.TRAIN:
|
||||||
return self.forward_train(data)
|
return self.forward_train(data)
|
||||||
elif mode == namespace.Mode.TEST:
|
elif mode == namespace.Mode.TEST:
|
||||||
return self.forward_test(data)
|
return self.forward_test(data)
|
||||||
else:
|
else:
|
||||||
Log.error("Unknown mode: {}".format(mode), True)
|
Log.error("Unknown mode: {}".format(mode), True)
|
||||||
|
|
||||||
def pertube_data(self, gt_delta_9d):
|
def pertube_data(self, gt_delta_9d):
|
||||||
bs = gt_delta_9d.shape[0]
|
bs = gt_delta_9d.shape[0]
|
||||||
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
|
random_t = (
|
||||||
|
torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps
|
||||||
|
)
|
||||||
random_t = random_t.unsqueeze(-1)
|
random_t = random_t.unsqueeze(-1)
|
||||||
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
|
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
|
||||||
std = std.view(-1, 1)
|
std = std.view(-1, 1)
|
||||||
z = torch.randn_like(gt_delta_9d)
|
z = torch.randn_like(gt_delta_9d)
|
||||||
perturbed_x = mu + z * std
|
perturbed_x = mu + z * std
|
||||||
target_score = - z * std / (std ** 2)
|
target_score = -z * std / (std**2)
|
||||||
return perturbed_x, random_t, target_score, std
|
return perturbed_x, random_t, target_score, std
|
||||||
|
|
||||||
def forward_train(self, data):
|
def forward_train(self, data):
|
||||||
main_feat = self.get_main_feat(data)
|
main_feat = self.get_main_feat(data)
|
||||||
''' get std '''
|
""" get std """
|
||||||
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
|
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
|
||||||
perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch)
|
perturbed_x, random_t, target_score, std = self.pertube_data(
|
||||||
|
best_to_world_pose_9d_batch
|
||||||
|
)
|
||||||
input_data = {
|
input_data = {
|
||||||
"sampled_pose": perturbed_x,
|
"sampled_pose": perturbed_x,
|
||||||
"t": random_t,
|
"t": random_t,
|
||||||
@ -56,45 +70,69 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
|||||||
output = {
|
output = {
|
||||||
"estimated_score": estimated_score,
|
"estimated_score": estimated_score,
|
||||||
"target_score": target_score,
|
"target_score": target_score,
|
||||||
"std": std
|
"std": std,
|
||||||
}
|
}
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_test(self,data):
|
def forward_test(self, data):
|
||||||
main_feat = self.get_main_feat(data)
|
main_feat = self.get_main_feat(data)
|
||||||
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
|
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(
|
||||||
|
main_feat
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
"pred_pose_9d": estimated_delta_rot_9d,
|
"pred_pose_9d": estimated_delta_rot_9d,
|
||||||
"in_process_sample": in_process_sample
|
"in_process_sample": in_process_sample,
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_main_feat(self, data):
|
def get_main_feat(self, data):
|
||||||
scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
|
scanned_n_to_world_pose_9d_batch = data[
|
||||||
scanned_target_pts_num_batch = data['scanned_target_points_num']
|
"scanned_n_to_world_pose_9d"
|
||||||
|
] # List(B): Tensor(S x 9)
|
||||||
|
scanned_pts_mask_batch = data[
|
||||||
|
"scanned_pts_mask"
|
||||||
|
] # Tensor(B x N)
|
||||||
|
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
|
|
||||||
embedding_list_batch = []
|
embedding_list_batch = []
|
||||||
|
|
||||||
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
|
|
||||||
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
|
|
||||||
scanned_target_pts_num = scanned_target_pts_num.to(device)
|
|
||||||
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
|
|
||||||
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
|
|
||||||
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
|
|
||||||
|
|
||||||
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
|
combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3)
|
||||||
|
global_scanned_feat, perpoint_scanned_feat_batch = self.pts_encoder.encode_points(
|
||||||
|
combined_scanned_pts_batch, require_per_point_feat=True
|
||||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
) # global_scanned_feat: Tensor(B x Dg), perpoint_scanned_feat: Tensor(B x N x Dl)
|
||||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
|
||||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
for scanned_n_to_world_pose_9d, scanned_mask, perpoint_scanned_feat in zip(
|
||||||
|
scanned_n_to_world_pose_9d_batch,
|
||||||
|
scanned_pts_mask_batch,
|
||||||
|
perpoint_scanned_feat_batch,
|
||||||
|
):
|
||||||
|
scanned_target_pts_num = [] # List(S): Int
|
||||||
|
partial_feat_seq = []
|
||||||
|
|
||||||
|
seq_len = len(scanned_n_to_world_pose_9d)
|
||||||
|
for seq_idx in range(seq_len):
|
||||||
|
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
|
||||||
|
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
|
||||||
|
partial_feat = torch.mean(partial_perpoint_feat, dim=0)[0] # Tensor(Dl)
|
||||||
|
partial_feat_seq.append(partial_feat)
|
||||||
|
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
|
||||||
|
|
||||||
|
scanned_target_pts_num = torch.tensor(scanned_target_pts_num, dtype=torch.int32).to(device) # Tensor(S)
|
||||||
|
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
|
||||||
|
|
||||||
|
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
|
||||||
|
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num) # Tensor(S x Dn)
|
||||||
|
partial_feat_seq = torch.stack(partial_feat_seq) # Tensor(S x Dl)
|
||||||
|
|
||||||
|
seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl))
|
||||||
|
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl))
|
||||||
|
|
||||||
|
seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
|
||||||
|
|
||||||
|
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
|
||||||
|
|
||||||
if torch.isnan(main_feat).any():
|
if torch.isnan(main_feat).any():
|
||||||
Log.error("nan in main_feat", True)
|
Log.error("nan in main_feat", True)
|
||||||
|
|
||||||
return main_feat
|
return main_feat
|
||||||
|
|
||||||
|
@ -122,7 +122,6 @@ class NBVReconstructionDataset(BaseDataset):
|
|||||||
scanned_views_pts,
|
scanned_views_pts,
|
||||||
scanned_coverages_rate,
|
scanned_coverages_rate,
|
||||||
scanned_n_to_world_pose,
|
scanned_n_to_world_pose,
|
||||||
scanned_target_pts_num,
|
|
||||||
) = ([], [], [], [])
|
) = ([], [], [], [])
|
||||||
for view in scanned_views:
|
for view in scanned_views:
|
||||||
frame_idx = view[0]
|
frame_idx = view[0]
|
||||||
@ -134,7 +133,6 @@ class NBVReconstructionDataset(BaseDataset):
|
|||||||
target_point_cloud = (
|
target_point_cloud = (
|
||||||
DataLoadUtil.load_from_preprocessed_pts(view_path)
|
DataLoadUtil.load_from_preprocessed_pts(view_path)
|
||||||
)
|
)
|
||||||
target_pts_num = target_point_cloud.shape[0]
|
|
||||||
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
|
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
|
||||||
target_point_cloud, self.pts_num
|
target_point_cloud, self.pts_num
|
||||||
)
|
)
|
||||||
@ -146,7 +144,7 @@ class NBVReconstructionDataset(BaseDataset):
|
|||||||
n_to_world_trans = n_to_world_pose[:3, 3]
|
n_to_world_trans = n_to_world_pose[:3, 3]
|
||||||
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
|
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
|
||||||
scanned_n_to_world_pose.append(n_to_world_9d)
|
scanned_n_to_world_pose.append(n_to_world_9d)
|
||||||
scanned_target_pts_num.append(target_pts_num)
|
|
||||||
|
|
||||||
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
|
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
|
||||||
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
|
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
|
||||||
@ -162,35 +160,33 @@ class NBVReconstructionDataset(BaseDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
|
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
|
||||||
fps_downsampled_combined_scanned_pts, fps_mask = PtsUtil.fps_downsample_point_cloud(
|
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
||||||
combined_scanned_views_pts, self.pts_num, require_mask=True
|
combined_scanned_views_pts, self.pts_num, require_idx=True
|
||||||
)
|
)
|
||||||
|
|
||||||
view_start_indices = np.cumsum([0] + [pts.shape[0] for pts in scanned_views_pts[:-1]])
|
combined_scanned_views_pts_mask = np.zeros(len(scanned_views_pts), dtype=np.uint8)
|
||||||
scanned_pts_mask = []
|
|
||||||
|
start_idx = 0
|
||||||
for i, start_idx in enumerate(view_start_indices[:-1]):
|
for i in range(len(scanned_views_pts)):
|
||||||
end_idx = view_start_indices[i + 1]
|
end_idx = start_idx + len(scanned_views_pts[i])
|
||||||
view_mask = fps_mask[start_idx:end_idx]
|
combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
||||||
scanned_pts_mask.append(view_mask)
|
start_idx = end_idx
|
||||||
|
|
||||||
|
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
data_item = {
|
data_item = {
|
||||||
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),
|
"scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3)
|
||||||
"scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.uint8),
|
"scanned_pts_mask": np.asarray(fps_downsampled_combined_scanned_pts_mask,dtype=np.uint8), # Ndarray(N), range(0, S)
|
||||||
"combined_scanned_pts": np.asarray(
|
"combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32), # Ndarray(N x 3)
|
||||||
fps_downsampled_combined_scanned_pts, dtype=np.float32
|
"scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1)
|
||||||
),
|
"scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9)
|
||||||
"scanned_coverage_rate": scanned_coverages_rate,
|
"best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1)
|
||||||
"scanned_n_to_world_pose_9d": np.asarray(
|
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32), # Ndarray(9)
|
||||||
scanned_n_to_world_pose, dtype=np.float32
|
"seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1)
|
||||||
),
|
"scene_name": scene_name, # String
|
||||||
"best_coverage_rate": nbv_coverage_rate,
|
|
||||||
"best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
|
|
||||||
"seq_max_coverage_rate": max_coverage_rate,
|
|
||||||
"scene_name": scene_name,
|
|
||||||
"scanned_target_points_num": np.asarray(
|
|
||||||
scanned_target_pts_num, dtype=np.int32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return data_item
|
return data_item
|
||||||
@ -201,37 +197,35 @@ class NBVReconstructionDataset(BaseDataset):
|
|||||||
def get_collate_fn(self):
|
def get_collate_fn(self):
|
||||||
def collate_fn(batch):
|
def collate_fn(batch):
|
||||||
collate_data = {}
|
collate_data = {}
|
||||||
|
|
||||||
|
''' ------ Varialbe Length ------ '''
|
||||||
|
|
||||||
collate_data["scanned_pts"] = [
|
collate_data["scanned_pts"] = [
|
||||||
torch.tensor(item["scanned_pts"]) for item in batch
|
torch.tensor(item["scanned_pts"]) for item in batch
|
||||||
]
|
]
|
||||||
collate_data["scanned_pts_mask"] = [
|
|
||||||
torch.tensor(item["scanned_pts_mask"]) for item in batch
|
|
||||||
]
|
|
||||||
collate_data["scanned_n_to_world_pose_9d"] = [
|
collate_data["scanned_n_to_world_pose_9d"] = [
|
||||||
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
|
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
|
||||||
]
|
]
|
||||||
collate_data["scanned_target_points_num"] = [
|
|
||||||
torch.tensor(item["scanned_target_points_num"]) for item in batch
|
''' ------ Fixed Length ------ '''
|
||||||
]
|
|
||||||
collate_data["best_to_world_pose_9d"] = torch.stack(
|
collate_data["best_to_world_pose_9d"] = torch.stack(
|
||||||
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
|
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
|
||||||
)
|
)
|
||||||
collate_data["combined_scanned_pts"] = torch.stack(
|
collate_data["combined_scanned_pts"] = torch.stack(
|
||||||
[torch.tensor(item["combined_scanned_pts"]) for item in batch]
|
[torch.tensor(item["combined_scanned_pts"]) for item in batch]
|
||||||
)
|
)
|
||||||
if "first_frame_to_world" in batch[0]:
|
collate_data["scanned_pts_mask"] = torch.stack(
|
||||||
collate_data["first_frame_to_world"] = torch.stack(
|
[torch.tensor(item["scanned_pts_mask"]) for item in batch]
|
||||||
[torch.tensor(item["first_frame_to_world"]) for item in batch]
|
)
|
||||||
)
|
|
||||||
for key in batch[0].keys():
|
for key in batch[0].keys():
|
||||||
if key not in [
|
if key not in [
|
||||||
"scanned_pts",
|
"scanned_pts",
|
||||||
"scanned_pts_mask",
|
"scanned_pts_mask",
|
||||||
"scanned_n_to_world_pose_9d",
|
"scanned_n_to_world_pose_9d",
|
||||||
"best_to_world_pose_9d",
|
"best_to_world_pose_9d",
|
||||||
"first_frame_to_world",
|
|
||||||
"combined_scanned_pts",
|
"combined_scanned_pts",
|
||||||
"scanned_target_points_num",
|
|
||||||
]:
|
]:
|
||||||
collate_data[key] = [item[key] for item in batch]
|
collate_data[key] = [item[key] for item in batch]
|
||||||
return collate_data
|
return collate_data
|
||||||
|
@ -24,13 +24,12 @@ class PtsUtil:
|
|||||||
return point_cloud[idx]
|
return point_cloud[idx]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False):
|
def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
|
||||||
N = point_cloud.shape[0]
|
N = point_cloud.shape[0]
|
||||||
mask = np.zeros(N, dtype=bool)
|
mask = np.zeros(N, dtype=bool)
|
||||||
|
|
||||||
sampled_indices = np.zeros(num_points, dtype=int)
|
sampled_indices = np.zeros(num_points, dtype=int)
|
||||||
sampled_indices[0] = np.random.randint(0, N)
|
sampled_indices[0] = np.random.randint(0, N)
|
||||||
mask[sampled_indices[0]] = True
|
|
||||||
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
|
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
|
||||||
for i in range(1, num_points):
|
for i in range(1, num_points):
|
||||||
farthest_index = np.argmax(distances)
|
farthest_index = np.argmax(distances)
|
||||||
@ -41,8 +40,8 @@ class PtsUtil:
|
|||||||
distances = np.minimum(distances, new_distances)
|
distances = np.minimum(distances, new_distances)
|
||||||
|
|
||||||
sampled_points = point_cloud[sampled_indices]
|
sampled_points = point_cloud[sampled_indices]
|
||||||
if require_mask:
|
if require_idx:
|
||||||
return sampled_points, mask
|
return sampled_points, sampled_indices
|
||||||
return sampled_points
|
return sampled_points
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
Loading…
x
Reference in New Issue
Block a user