This commit is contained in:
hofee 2024-09-30 10:04:59 +08:00
commit 551282a0ec
10 changed files with 545 additions and 342 deletions

View File

@ -24,12 +24,6 @@ runner:
max_height: 0.15 max_height: 0.15
min_radius: 0.3 min_radius: 0.3
max_radius: 0.5 max_radius: 0.5
min_R: 0.05
max_R: 0.3
min_G: 0.05
max_G: 0.3
min_B: 0.05
max_B: 0.3
display_object: display_object:
min_x: 0 min_x: 0
max_x: 0.03 max_x: 0.03

View File

@ -0,0 +1,100 @@
import torch
from torch import nn
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.utils import Log
@stereotype.pipeline("nbv_reconstruction_global_pts_n_num_pipeline")
class NBVReconstructionGlobalPointsPipeline(nn.Module):
def __init__(self, config):
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
self.config = config
self.module_config = config["modules"]
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
self.eps = float(self.config["eps"])
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
def forward(self, data):
mode = data["mode"]
if mode == namespace.Mode.TRAIN:
return self.forward_train(data)
elif mode == namespace.Mode.TEST:
return self.forward_test(data)
else:
Log.error("Unknown mode: {}".format(mode), True)
def pertube_data(self, gt_delta_9d):
bs = gt_delta_9d.shape[0]
random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
random_t = random_t.unsqueeze(-1)
mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
std = std.view(-1, 1)
z = torch.randn_like(gt_delta_9d)
perturbed_x = mu + z * std
target_score = - z * std / (std ** 2)
return perturbed_x, random_t, target_score, std
def forward_train(self, data):
main_feat = self.get_main_feat(data)
''' get std '''
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch)
input_data = {
"sampled_pose": perturbed_x,
"t": random_t,
"main_feat": main_feat,
}
estimated_score = self.view_finder(input_data)
output = {
"estimated_score": estimated_score,
"target_score": target_score,
"std": std
}
return output
def forward_test(self,data):
main_feat = self.get_main_feat(data)
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
result = {
"pred_pose_9d": estimated_delta_rot_9d,
"in_process_sample": in_process_sample
}
return result
def get_main_feat(self, data):
scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
scanned_target_pts_num_batch = data['scanned_target_points_num']
device = next(self.parameters()).device
embedding_list_batch = []
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
scanned_target_pts_num = scanned_target_pts_num.to(device)
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
if torch.isnan(main_feat).any():
Log.error("nan in main_feat", True)
return main_feat

View File

@ -7,6 +7,7 @@ from PytorchBoot.utils.log_util import Log
import torch import torch
import os import os
import sys import sys
sys.path.append(r"/home/data/hofee/project/nbv_rec/nbv_reconstruction") sys.path.append(r"/home/data/hofee/project/nbv_rec/nbv_reconstruction")
from utils.data_load import DataLoadUtil from utils.data_load import DataLoadUtil
@ -29,20 +30,17 @@ class NBVReconstructionDataset(BaseDataset):
self.cache = config.get("cache") self.cache = config.get("cache")
self.load_from_preprocess = config.get("load_from_preprocess", False) self.load_from_preprocess = config.get("load_from_preprocess", False)
if self.type == namespace.Mode.TEST: if self.type == namespace.Mode.TEST:
self.model_dir = config["model_dir"] self.model_dir = config["model_dir"]
self.filter_degree = config["filter_degree"] self.filter_degree = config["filter_degree"]
if self.type == namespace.Mode.TRAIN: if self.type == namespace.Mode.TRAIN:
scale_ratio = 100 scale_ratio = 100
self.datalist = self.datalist*scale_ratio self.datalist = self.datalist * scale_ratio
if self.cache: if self.cache:
expr_root = ConfigManager.get("runner", "experiment", "root_dir") expr_root = ConfigManager.get("runner", "experiment", "root_dir")
expr_name = ConfigManager.get("runner", "experiment", "name") expr_name = ConfigManager.get("runner", "experiment", "name")
self.cache_dir = os.path.join(expr_root, expr_name, "cache") self.cache_dir = os.path.join(expr_root, expr_name, "cache")
#self.preprocess_cache() # self.preprocess_cache()
def load_scene_name_list(self): def load_scene_name_list(self):
scene_name_list = [] scene_name_list = []
@ -60,7 +58,9 @@ class NBVReconstructionDataset(BaseDataset):
max_coverage_rate_list = [] max_coverage_rate_list = []
for seq_idx in range(seq_num): for seq_idx in range(seq_num):
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx) label_path = DataLoadUtil.get_label_path(
self.root_dir, scene_name, seq_idx
)
label_data = DataLoadUtil.load_label(label_path) label_data = DataLoadUtil.load_label(label_path)
max_coverage_rate = label_data["max_coverage_rate"] max_coverage_rate = label_data["max_coverage_rate"]
if max_coverage_rate > scene_max_coverage_rate: if max_coverage_rate > scene_max_coverage_rate:
@ -69,20 +69,24 @@ class NBVReconstructionDataset(BaseDataset):
mean_coverage_rate = np.mean(max_coverage_rate_list) mean_coverage_rate = np.mean(max_coverage_rate_list)
for seq_idx in range(seq_num): for seq_idx in range(seq_num):
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx) label_path = DataLoadUtil.get_label_path(
self.root_dir, scene_name, seq_idx
)
label_data = DataLoadUtil.load_label(label_path) label_data = DataLoadUtil.load_label(label_path)
if max_coverage_rate_list[seq_idx] > mean_coverage_rate - 0.1: if max_coverage_rate_list[seq_idx] > mean_coverage_rate - 0.1:
for data_pair in label_data["data_pairs"]: for data_pair in label_data["data_pairs"]:
scanned_views = data_pair[0] scanned_views = data_pair[0]
next_best_view = data_pair[1] next_best_view = data_pair[1]
datalist.append({ datalist.append(
"scanned_views": scanned_views, {
"next_best_view": next_best_view, "scanned_views": scanned_views,
"seq_max_coverage_rate": max_coverage_rate, "next_best_view": next_best_view,
"scene_name": scene_name, "seq_max_coverage_rate": max_coverage_rate,
"label_idx": seq_idx, "scene_name": scene_name,
"scene_max_coverage_rate": scene_max_coverage_rate "label_idx": seq_idx,
}) "scene_max_coverage_rate": scene_max_coverage_rate,
}
)
return datalist return datalist
def preprocess_cache(self): def preprocess_cache(self):
@ -107,9 +111,6 @@ class NBVReconstructionDataset(BaseDataset):
np.savetxt(cache_path, data) np.savetxt(cache_path, data)
except Exception as e: except Exception as e:
Log.error(f"Save cache failed: {e}") Log.error(f"Save cache failed: {e}")
# ----- Debug Trace ----- #
import ipdb; ipdb.set_trace()
# ------------------------ #
def __getitem__(self, index): def __getitem__(self, index):
data_item_info = self.datalist[index] data_item_info = self.datalist[index]
@ -117,18 +118,28 @@ class NBVReconstructionDataset(BaseDataset):
nbv = data_item_info["next_best_view"] nbv = data_item_info["next_best_view"]
max_coverage_rate = data_item_info["seq_max_coverage_rate"] max_coverage_rate = data_item_info["seq_max_coverage_rate"]
scene_name = data_item_info["scene_name"] scene_name = data_item_info["scene_name"]
scanned_views_pts, scanned_coverages_rate, scanned_n_to_world_pose = [], [], [] (
scanned_views_pts,
scanned_coverages_rate,
scanned_n_to_world_pose,
scanned_target_pts_num,
) = ([], [], [], [])
target_pts_num_dict = DataLoadUtil.load_target_pts_num_dict(
self.root_dir, scene_name
)
for view in scanned_views: for view in scanned_views:
frame_idx = view[0] frame_idx = view[0]
coverage_rate = view[1] coverage_rate = view[1]
view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True)
target_pts_num = target_pts_num_dict[frame_idx]
n_to_world_pose = cam_info["cam_to_world"] n_to_world_pose = cam_info["cam_to_world"]
nR_to_world_pose = cam_info["cam_to_world_R"] nR_to_world_pose = cam_info["cam_to_world_R"]
if self.load_from_preprocess: if self.load_from_preprocess:
downsampled_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(view_path) downsampled_target_point_cloud = (
DataLoadUtil.load_from_preprocessed_pts(view_path)
)
else: else:
cached_data = None cached_data = None
if self.cache: if self.cache:
@ -136,72 +147,90 @@ class NBVReconstructionDataset(BaseDataset):
if cached_data is None: if cached_data is None:
print("load depth") print("load depth")
depth_L, depth_R = DataLoadUtil.load_depth(view_path, cam_info['near_plane'], cam_info['far_plane'], binocular=True) depth_L, depth_R = DataLoadUtil.load_depth(
point_cloud_L = DataLoadUtil.get_point_cloud(depth_L, cam_info['cam_intrinsic'], n_to_world_pose)['points_world'] view_path,
point_cloud_R = DataLoadUtil.get_point_cloud(depth_R, cam_info['cam_intrinsic'], nR_to_world_pose)['points_world'] cam_info["near_plane"],
cam_info["far_plane"],
binocular=True,
)
point_cloud_L = DataLoadUtil.get_point_cloud(
depth_L, cam_info["cam_intrinsic"], n_to_world_pose
)["points_world"]
point_cloud_R = DataLoadUtil.get_point_cloud(
depth_R, cam_info["cam_intrinsic"], nR_to_world_pose
)["points_world"]
point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536) point_cloud_L = PtsUtil.random_downsample_point_cloud(
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536) point_cloud_L, 65536
overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R) )
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num) point_cloud_R = PtsUtil.random_downsample_point_cloud(
point_cloud_R, 65536
)
overlap_points = DataLoadUtil.get_overlapping_points(
point_cloud_L, point_cloud_R
)
downsampled_target_point_cloud = (
PtsUtil.random_downsample_point_cloud(
overlap_points, self.pts_num
)
)
if self.cache: if self.cache:
self.save_to_cache(scene_name, frame_idx, downsampled_target_point_cloud) self.save_to_cache(
scene_name, frame_idx, downsampled_target_point_cloud
)
else: else:
downsampled_target_point_cloud = cached_data downsampled_target_point_cloud = cached_data
scanned_views_pts.append(downsampled_target_point_cloud) scanned_views_pts.append(downsampled_target_point_cloud)
scanned_coverages_rate.append(coverage_rate) scanned_coverages_rate.append(coverage_rate)
n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(n_to_world_pose[:3,:3])) n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
n_to_world_trans = n_to_world_pose[:3,3] np.asarray(n_to_world_pose[:3, :3])
)
n_to_world_trans = n_to_world_pose[:3, 3]
n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
scanned_n_to_world_pose.append(n_to_world_9d) scanned_n_to_world_pose.append(n_to_world_9d)
scanned_target_pts_num.append(target_pts_num)
nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
cam_info = DataLoadUtil.load_cam_info(nbv_path) cam_info = DataLoadUtil.load_cam_info(nbv_path)
best_frame_to_world = cam_info["cam_to_world"] best_frame_to_world = cam_info["cam_to_world"]
best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(best_frame_to_world[:3,:3])) best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
best_to_world_trans = best_frame_to_world[:3,3] np.asarray(best_frame_to_world[:3, :3])
best_to_world_9d = np.concatenate([best_to_world_6d, best_to_world_trans], axis=0) )
best_to_world_trans = best_frame_to_world[:3, 3]
best_to_world_9d = np.concatenate(
[best_to_world_6d, best_to_world_trans], axis=0
)
combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0) combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002) voxel_downsampled_combined_scanned_pts_np = (
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, self.pts_num) PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
)
random_downsampled_combined_scanned_pts_np = (
PtsUtil.random_downsample_point_cloud(
voxel_downsampled_combined_scanned_pts_np, self.pts_num
)
)
data_item = { data_item = {
"scanned_pts": np.asarray(scanned_views_pts,dtype=np.float32), "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),
"combined_scanned_pts": np.asarray(random_downsampled_combined_scanned_pts_np,dtype=np.float32), "combined_scanned_pts": np.asarray(
random_downsampled_combined_scanned_pts_np, dtype=np.float32
),
"scanned_coverage_rate": scanned_coverages_rate, "scanned_coverage_rate": scanned_coverages_rate,
"scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose,dtype=np.float32), "scanned_n_to_world_pose_9d": np.asarray(
scanned_n_to_world_pose, dtype=np.float32
),
"best_coverage_rate": nbv_coverage_rate, "best_coverage_rate": nbv_coverage_rate,
"best_to_world_pose_9d": np.asarray(best_to_world_9d,dtype=np.float32), "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),
"seq_max_coverage_rate": max_coverage_rate, "seq_max_coverage_rate": max_coverage_rate,
"scene_name": scene_name "scene_name": scene_name,
"scanned_target_points_num": np.asarray(
scanned_target_pts_num, dtype=np.int32
),
} }
# if self.type == namespace.Mode.TEST:
# diag = DataLoadUtil.get_bbox_diag(self.model_dir, scene_name)
# voxel_threshold = diag*0.02
# model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name)
# pts_list = []
# for view in scanned_views:
# frame_idx = view[0]
# view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
# point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(view_path, binocular=True)
# cam_params = DataLoadUtil.load_cam_info(view_path, binocular=True)
# sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
# pts_list.append(sampled_point_cloud)
# nL_to_world_pose = cam_params["cam_to_world"]
# nO_to_world_pose = cam_params["cam_to_world_O"]
# nO_to_nL_pose = np.dot(np.linalg.inv(nL_to_world_pose), nO_to_world_pose)
# data_item["scanned_target_pts_list"] = pts_list
# data_item["model_points_normals"] = model_points_normals
# data_item["voxel_threshold"] = voxel_threshold
# data_item["filter_degree"] = self.filter_degree
# data_item["scene_path"] = os.path.join(self.root_dir, scene_name)
# data_item["first_frame_to_world"] = np.asarray(first_frame_to_world, dtype=np.float32)
# data_item["nO_to_nL_pose"] = np.asarray(nO_to_nL_pose, dtype=np.float32)
return data_item return data_item
def __len__(self): def __len__(self):
@ -210,22 +239,44 @@ class NBVReconstructionDataset(BaseDataset):
def get_collate_fn(self): def get_collate_fn(self):
def collate_fn(batch): def collate_fn(batch):
collate_data = {} collate_data = {}
collate_data["scanned_pts"] = [torch.tensor(item['scanned_pts']) for item in batch] collate_data["scanned_pts"] = [
collate_data["scanned_n_to_world_pose_9d"] = [torch.tensor(item['scanned_n_to_world_pose_9d']) for item in batch] torch.tensor(item["scanned_pts"]) for item in batch
collate_data["best_to_world_pose_9d"] = torch.stack([torch.tensor(item['best_to_world_pose_9d']) for item in batch]) ]
collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch]) collate_data["scanned_n_to_world_pose_9d"] = [
torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
]
collate_data["scanned_target_points_num"] = [
torch.tensor(item["scanned_target_points_num"]) for item in batch
]
collate_data["best_to_world_pose_9d"] = torch.stack(
[torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
)
collate_data["combined_scanned_pts"] = torch.stack(
[torch.tensor(item["combined_scanned_pts"]) for item in batch]
)
if "first_frame_to_world" in batch[0]: if "first_frame_to_world" in batch[0]:
collate_data["first_frame_to_world"] = torch.stack([torch.tensor(item["first_frame_to_world"]) for item in batch]) collate_data["first_frame_to_world"] = torch.stack(
[torch.tensor(item["first_frame_to_world"]) for item in batch]
)
for key in batch[0].keys(): for key in batch[0].keys():
if key not in ["scanned_pts", "scanned_n_to_world_pose_9d", "best_to_world_pose_9d", "first_frame_to_world", "combined_scanned_pts"]: if key not in [
"scanned_pts",
"scanned_n_to_world_pose_9d",
"best_to_world_pose_9d",
"first_frame_to_world",
"combined_scanned_pts",
"scanned_target_points_num",
]:
collate_data[key] = [item[key] for item in batch] collate_data[key] = [item[key] for item in batch]
return collate_data return collate_data
return collate_fn return collate_fn
# -------------- Debug ---------------- # # -------------- Debug ---------------- #
if __name__ == "__main__": if __name__ == "__main__":
import torch import torch
seed = 0 seed = 0
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
@ -244,41 +295,13 @@ if __name__ == "__main__":
} }
ds = NBVReconstructionDataset(config) ds = NBVReconstructionDataset(config)
print(len(ds)) print(len(ds))
#ds.__getitem__(10) # ds.__getitem__(10)
dl = ds.get_loader(shuffle=True) dl = ds.get_loader(shuffle=True)
for idx, data in enumerate(dl): for idx, data in enumerate(dl):
data = ds.process_batch(data, "cuda:0") data = ds.process_batch(data, "cuda:0")
print(data) print(data)
# ------ Debug Start ------ # ------ Debug Start ------
import ipdb;ipdb.set_trace() import ipdb
ipdb.set_trace()
# ------ Debug End ------ # ------ Debug End ------
#
# for idx, data in enumerate(dl):
# cnt=0
# print(data["scene_name"])
# print(data["scanned_coverage_rate"])
# print(data["best_coverage_rate"])
# for pts in data["scanned_pts"][0]:
# #np.savetxt(f"pts_{cnt}.txt", pts)
# cnt+=1
# #np.savetxt("best_pts.txt", best_pts)
# for key, value in data.items():
# if isinstance(value, torch.Tensor):
# print(key, ":" ,value.shape)
# else:
# print(key, ":" ,len(value))
# if key == "scanned_n_to_world_pose_9d":
# for val in value:
# print(val.shape)
# if key == "scanned_pts":
# print("scanned_pts")
# for val in value:
# print(val.shape)
# cnt = 0
# for v in val:
# import ipdb;ipdb.set_trace()
# np.savetxt(f"pts_{cnt}.txt", v)
# cnt+=1
# print()

View File

@ -22,12 +22,10 @@ class PointNetEncoder(nn.Module):
self.conv2 = torch.nn.Conv1d(64, 128, 1) self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 512, 1) self.conv3 = torch.nn.Conv1d(128, 512, 1)
self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1) self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
self.global_feat = config["global_feat"]
if self.feature_transform: if self.feature_transform:
self.f_stn = STNkd(k=64) self.f_stn = STNkd(k=64)
def forward(self, x): def forward(self, x):
n_pts = x.shape[2]
trans = self.stn(x) trans = self.stn(x)
x = x.transpose(2, 1) x = x.transpose(2, 1)
x = torch.bmm(x, trans) x = torch.bmm(x, trans)
@ -46,20 +44,15 @@ class PointNetEncoder(nn.Module):
x = self.conv4(x) x = self.conv4(x)
x = torch.max(x, 2, keepdim=True)[0] x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, self.out_dim) x = x.view(-1, self.out_dim)
if self.global_feat: return x, point_feat
return x
else:
x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
return torch.cat([x, point_feat], 1)
def encode_points(self, pts): def encode_points(self, pts, require_per_point_feat=False):
pts = pts.transpose(2, 1) pts = pts.transpose(2, 1)
global_pts_feature, per_point_feature = self(pts)
if not self.global_feat: if require_per_point_feat:
pts_feature = self(pts).transpose(2, 1) return global_pts_feature, per_point_feature.transpose(2, 1)
else: else:
pts_feature = self(pts) return global_pts_feature
return pts_feature
class STNkd(nn.Module): class STNkd(nn.Module):
def __init__(self, k=64): def __init__(self, k=64):
@ -102,21 +95,13 @@ if __name__ == "__main__":
config = { config = {
"in_dim": 3, "in_dim": 3,
"out_dim": 1024, "out_dim": 1024,
"global_feat": True,
"feature_transform": False "feature_transform": False
} }
pointnet_global = PointNetEncoder(config) pointnet = PointNetEncoder(config)
out = pointnet_global.encode_points(sim_data) out = pointnet.encode_points(sim_data)
print("global feat", out.size()) print("global feat", out.size())
config = { out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True)
"in_dim": 3,
"out_dim": 1024,
"global_feat": False,
"feature_transform": False
}
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
print("point feat", out.size()) print("point feat", out.size())
print("per point feat", per_point_out.size())

View File

@ -0,0 +1,20 @@
from torch import nn
import PytorchBoot.stereotype as stereotype
@stereotype.module("pts_num_encoder")
class PointsNumEncoder(nn.Module):
def __init__(self, config):
super(PointsNumEncoder, self).__init__()
self.config = config
out_dim = config["out_dim"]
self.act = nn.ReLU(True)
self.pts_num_encoder = nn.Sequential(
nn.Linear(1, out_dim),
self.act,
nn.Linear(out_dim, out_dim),
self.act,
)
def encode_pts_num(self, num_seq):
return self.pts_num_encoder(num_seq)

View File

@ -1,63 +0,0 @@
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import PytorchBoot.stereotype as stereotype
@stereotype.module("transformer_pose_seq_encoder")
class TransformerPoseSequenceEncoder(nn.Module):
def __init__(self, config):
super(TransformerPoseSequenceEncoder, self).__init__()
self.config = config
embed_dim = config["pose_embed_dim"]
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=config["num_heads"],
dim_feedforward=config["ffn_dim"],
batch_first=True,
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config["num_layers"]
)
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pose_embedding_list_batch):
lengths = []
for pose_embedding_list in pose_embedding_list_batch:
lengths.append(len(pose_embedding_list))
combined_tensor = pad_sequence(pose_embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature)
return final_output
if __name__ == "__main__":
config = {
"pose_embed_dim": 256,
"num_heads": 4,
"ffn_dim": 256,
"num_layers": 3,
"output_dim": 1024,
}
encoder = TransformerPoseSequenceEncoder(config)
seq_len = [5, 8, 9, 4]
batch_size = 4
pose_embedding_list_batch = [
torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
]
output_feature = encoder.encode_sequence(
pose_embedding_list_batch
)
print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape)

View File

@ -9,7 +9,7 @@ class TransformerSequenceEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super(TransformerSequenceEncoder, self).__init__() super(TransformerSequenceEncoder, self).__init__()
self.config = config self.config = config
embed_dim = config["pts_embed_dim"] + config["pose_embed_dim"] embed_dim = config["embed_dim"]
encoder_layer = nn.TransformerEncoderLayer( encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, d_model=embed_dim,
nhead=config["num_heads"], nhead=config["num_heads"],
@ -21,24 +21,19 @@ class TransformerSequenceEncoder(nn.Module):
) )
self.fc = nn.Linear(embed_dim, config["output_dim"]) self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, pts_embedding_list_batch, pose_embedding_list_batch): def encode_sequence(self, embedding_list_batch):
combined_features_batch = []
lengths = [] lengths = []
for pts_embedding_list, pose_embedding_list in zip(pts_embedding_list_batch, pose_embedding_list_batch): for embedding_list in embedding_list_batch:
combined_features = [ lengths.append(len(embedding_list))
torch.cat((pts_embed, pose_embed), dim=-1)
for pts_embed, pose_embed in zip(pts_embedding_list, pose_embedding_list)
]
combined_features_batch.append(torch.stack(combined_features))
lengths.append(len(combined_features))
combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim] embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths) max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device) padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(embedding_tensor.device)
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask) transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1) final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature) final_output = self.fc(final_feature)
@ -47,26 +42,22 @@ class TransformerSequenceEncoder(nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
config = { config = {
"pts_embed_dim": 1024, "embed_dim": 256,
"pose_embed_dim": 256,
"num_heads": 4, "num_heads": 4,
"ffn_dim": 256, "ffn_dim": 256,
"num_layers": 3, "num_layers": 3,
"output_dim": 2048, "output_dim": 1024,
} }
encoder = TransformerSequenceEncoder(config) encoder = TransformerSequenceEncoder(config)
seq_len = [5, 8, 9, 4] seq_len = [5, 8, 9, 4]
batch_size = 4 batch_size = 4
pts_embedding_list_batch = [ embedding_list_batch = [
torch.randn(seq_len[idx], config["pts_embed_dim"]) for idx in range(batch_size) torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
]
pose_embedding_list_batch = [
torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
] ]
output_feature = encoder.encode_sequence( output_feature = encoder.encode_sequence(
pts_embedding_list_batch, pose_embedding_list_batch embedding_list_batch
) )
print("Encoded Feature:", output_feature) print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape) print("Feature Shape:", output_feature.shape)

View File

@ -82,28 +82,40 @@ class StrategyGenerator(Runner):
model_points_normals = DataLoadUtil.load_points_normals(root, scene_name) model_points_normals = DataLoadUtil.load_points_normals(root, scene_name)
model_pts = model_points_normals[:,:3] model_pts = model_points_normals[:,:3]
down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold) down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
radius = display_table_info["radius"]
top = DataLoadUtil.get_display_table_top(root, scene_name)
scan_points = ReconstructionUtil.generate_scan_points(display_table_top=top,display_table_radius=radius)
pts_list = [] pts_list = []
scan_points_indices_list = []
for frame_idx in range(frame_num): for frame_idx in range(frame_num):
if self.load_pts and os.path.exists(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt")): if self.load_pts and os.path.exists(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt")):
sampled_point_cloud = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt")) sampled_point_cloud = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt"))
indices = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}_indices.txt")).astype(np.int32).tolist()
status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num) status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
pts_list.append(sampled_point_cloud) pts_list.append(sampled_point_cloud)
continue scan_points_indices_list.append(indices)
else: else:
path = DataLoadUtil.get_path(root, scene_name, frame_idx) path = DataLoadUtil.get_path(root, scene_name, frame_idx)
cam_params = DataLoadUtil.load_cam_info(path, binocular=True) cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num) status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True) point_cloud, display_table_pts = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True, get_display_table_pts=True)
sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree) sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
covered_pts, indices = ReconstructionUtil.compute_covered_scan_points(scan_points, display_table_pts)
if self.save_pts: if self.save_pts:
pts_dir = os.path.join(root,scene_name, "pts") pts_dir = os.path.join(root,scene_name, "pts")
covered_pts_dir = os.path.join(pts_dir, "covered_scan_pts")
if not os.path.exists(pts_dir): if not os.path.exists(pts_dir):
os.makedirs(pts_dir) os.makedirs(pts_dir)
if not os.path.exists(covered_pts_dir):
os.makedirs(covered_pts_dir)
np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud) np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
np.savetxt(os.path.join(covered_pts_dir, f"{frame_idx}.txt"), covered_pts)
np.savetxt(os.path.join(pts_dir, f"{frame_idx}_indices.txt"), indices)
pts_list.append(sampled_point_cloud) pts_list.append(sampled_point_cloud)
scan_points_indices_list.append(indices)
status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num) status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num)
seq_num = min(self.seq_num, len(pts_list)) seq_num = min(self.seq_num, len(pts_list))

View File

@ -6,8 +6,9 @@ import trimesh
import torch import torch
from utils.pts import PtsUtil from utils.pts import PtsUtil
class DataLoadUtil: class DataLoadUtil:
TABLE_POSITION = np.asarray([0,0,0.8215]) TABLE_POSITION = np.asarray([0, 0, 0.8215])
@staticmethod @staticmethod
def get_display_table_info(root, scene_name): def get_display_table_info(root, scene_name):
@ -17,8 +18,12 @@ class DataLoadUtil:
@staticmethod @staticmethod
def get_display_table_top(root, scene_name): def get_display_table_top(root, scene_name):
display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)["height"] display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)[
display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray([0,0,display_table_height]) "height"
]
display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray(
[0, 0, display_table_height]
)
return display_table_top return display_table_top
@staticmethod @staticmethod
@ -28,20 +33,20 @@ class DataLoadUtil:
@staticmethod @staticmethod
def get_label_num(root, scene_name): def get_label_num(root, scene_name):
label_dir = os.path.join(root,scene_name,"label") label_dir = os.path.join(root, scene_name, "label")
return len(os.listdir(label_dir)) return len(os.listdir(label_dir))
@staticmethod @staticmethod
def get_label_path(root, scene_name, seq_idx): def get_label_path(root, scene_name, seq_idx):
label_dir = os.path.join(root,scene_name,"label") label_dir = os.path.join(root, scene_name, "label")
if not os.path.exists(label_dir): if not os.path.exists(label_dir):
os.makedirs(label_dir) os.makedirs(label_dir)
path = os.path.join(label_dir,f"{seq_idx}.json") path = os.path.join(label_dir, f"{seq_idx}.json")
return path return path
@staticmethod @staticmethod
def get_label_path_old(root, scene_name): def get_label_path_old(root, scene_name):
path = os.path.join(root,scene_name,"label.json") path = os.path.join(root, scene_name, "label.json")
return path return path
@staticmethod @staticmethod
@ -64,7 +69,6 @@ class DataLoadUtil:
diagonal_length = np.linalg.norm(bbox) diagonal_length = np.linalg.norm(bbox)
return diagonal_length return diagonal_length
@staticmethod @staticmethod
def save_mesh_at(model_dir, output_dir, object_name, scene_name, world_object_pose): def save_mesh_at(model_dir, output_dir, object_name, scene_name, world_object_pose):
mesh = DataLoadUtil.load_mesh_at(model_dir, object_name, world_object_pose) mesh = DataLoadUtil.load_mesh_at(model_dir, object_name, world_object_pose)
@ -72,12 +76,16 @@ class DataLoadUtil:
mesh.export(model_path) mesh.export(model_path)
@staticmethod @staticmethod
def save_target_mesh_at_world_space(root, model_dir, scene_name, display_table_as_world_space_origin=True): def save_target_mesh_at_world_space(
root, model_dir, scene_name, display_table_as_world_space_origin=True
):
scene_info = DataLoadUtil.load_scene_info(root, scene_name) scene_info = DataLoadUtil.load_scene_info(root, scene_name)
target_name = scene_info["target_name"] target_name = scene_info["target_name"]
transformation = scene_info[target_name] transformation = scene_info[target_name]
if display_table_as_world_space_origin: if display_table_as_world_space_origin:
location = transformation["location"] - DataLoadUtil.get_display_table_top(root, scene_name) location = transformation["location"] - DataLoadUtil.get_display_table_top(
root, scene_name
)
else: else:
location = transformation["location"] location = transformation["location"]
rotation_euler = transformation["rotation_euler"] rotation_euler = transformation["rotation_euler"]
@ -98,6 +106,13 @@ class DataLoadUtil:
scene_info = json.load(f) scene_info = json.load(f)
return scene_info return scene_info
@staticmethod
def load_target_pts_num_dict(root, scene_name):
target_pts_num_path = os.path.join(root, scene_name, "target_pts_num.json")
with open(target_pts_num_path, "r") as f:
target_pts_num_dict = json.load(f)
return target_pts_num_dict
@staticmethod @staticmethod
def load_target_object_pose(root, scene_name): def load_target_object_pose(root, scene_name):
scene_info = DataLoadUtil.load_scene_info(root, scene_name) scene_info = DataLoadUtil.load_scene_info(root, scene_name)
@ -110,7 +125,7 @@ class DataLoadUtil:
return pose_mat return pose_mat
@staticmethod @staticmethod
def load_depth(path, min_depth=0.01,max_depth=5.0,binocular=False): def load_depth(path, min_depth=0.01, max_depth=5.0, binocular=False):
def load_depth_from_real_path(real_path, min_depth, max_depth): def load_depth_from_real_path(real_path, min_depth, max_depth):
depth = cv2.imread(real_path, cv2.IMREAD_UNCHANGED) depth = cv2.imread(real_path, cv2.IMREAD_UNCHANGED)
@ -121,62 +136,84 @@ class DataLoadUtil:
return depth_meters return depth_meters
if binocular: if binocular:
depth_path_L = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + "_L.png") depth_path_L = os.path.join(
depth_path_R = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + "_R.png") os.path.dirname(path), "depth", os.path.basename(path) + "_L.png"
depth_meters_L = load_depth_from_real_path(depth_path_L, min_depth, max_depth) )
depth_meters_R = load_depth_from_real_path(depth_path_R, min_depth, max_depth) depth_path_R = os.path.join(
os.path.dirname(path), "depth", os.path.basename(path) + "_R.png"
)
depth_meters_L = load_depth_from_real_path(
depth_path_L, min_depth, max_depth
)
depth_meters_R = load_depth_from_real_path(
depth_path_R, min_depth, max_depth
)
return depth_meters_L, depth_meters_R return depth_meters_L, depth_meters_R
else: else:
depth_path = os.path.join(os.path.dirname(path), "depth", os.path.basename(path) + ".png") depth_path = os.path.join(
os.path.dirname(path), "depth", os.path.basename(path) + ".png"
)
depth_meters = load_depth_from_real_path(depth_path, min_depth, max_depth) depth_meters = load_depth_from_real_path(depth_path, min_depth, max_depth)
return depth_meters return depth_meters
@staticmethod @staticmethod
def load_seg(path, binocular=False): def load_seg(path, binocular=False):
if binocular: if binocular:
def clean_mask(mask_image): def clean_mask(mask_image):
green = [0, 255, 0, 255] green = [0, 255, 0, 255]
red = [255, 0, 0, 255] red = [255, 0, 0, 255]
threshold = 2 threshold = 2
mask_image = np.where(np.abs(mask_image - green) <= threshold, green, mask_image) mask_image = np.where(
mask_image = np.where(np.abs(mask_image - red) <= threshold, red, mask_image) np.abs(mask_image - green) <= threshold, green, mask_image
)
mask_image = np.where(
np.abs(mask_image - red) <= threshold, red, mask_image
)
return mask_image return mask_image
mask_path_L = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + "_L.png")
mask_path_L = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + "_L.png"
)
mask_image_L = clean_mask(cv2.imread(mask_path_L, cv2.IMREAD_UNCHANGED)) mask_image_L = clean_mask(cv2.imread(mask_path_L, cv2.IMREAD_UNCHANGED))
mask_path_R = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + "_R.png") mask_path_R = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + "_R.png"
)
mask_image_R = clean_mask(cv2.imread(mask_path_R, cv2.IMREAD_UNCHANGED)) mask_image_R = clean_mask(cv2.imread(mask_path_R, cv2.IMREAD_UNCHANGED))
return mask_image_L, mask_image_R return mask_image_L, mask_image_R
else: else:
mask_path = os.path.join(os.path.dirname(path), "mask", os.path.basename(path) + ".png") mask_path = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + ".png"
)
mask_image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) mask_image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
return mask_image return mask_image
@staticmethod @staticmethod
def load_label(path): def load_label(path):
with open(path, 'r') as f: with open(path, "r") as f:
label_data = json.load(f) label_data = json.load(f)
return label_data return label_data
@staticmethod @staticmethod
def load_rgb(path): def load_rgb(path):
rgb_path = os.path.join(os.path.dirname(path), "rgb", os.path.basename(path) + ".png") rgb_path = os.path.join(
os.path.dirname(path), "rgb", os.path.basename(path) + ".png"
)
rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR) rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
return rgb_image return rgb_image
@staticmethod @staticmethod
def load_from_preprocessed_pts(path): def load_from_preprocessed_pts(path):
npy_path = os.path.join(os.path.dirname(path), "points", os.path.basename(path) + ".npy") npy_path = os.path.join(
os.path.dirname(path), "points", os.path.basename(path) + ".npy"
)
pts = np.load(npy_path) pts = np.load(npy_path)
return pts return pts
@staticmethod @staticmethod
def cam_pose_transformation(cam_pose_before): def cam_pose_transformation(cam_pose_before):
offset = np.asarray([ offset = np.asarray([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
[1, 0, 0, 0], cam_pose_after = cam_pose_before @ offset
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]])
cam_pose_after = cam_pose_before @ offset
return cam_pose_after return cam_pose_after
@staticmethod @staticmethod
@ -184,13 +221,17 @@ class DataLoadUtil:
scene_dir = os.path.dirname(path) scene_dir = os.path.dirname(path)
root_dir = os.path.dirname(scene_dir) root_dir = os.path.dirname(scene_dir)
scene_name = os.path.basename(scene_dir) scene_name = os.path.basename(scene_dir)
camera_params_path = os.path.join(os.path.dirname(path), "camera_params", os.path.basename(path) + ".json") camera_params_path = os.path.join(
with open(camera_params_path, 'r') as f: os.path.dirname(path), "camera_params", os.path.basename(path) + ".json"
)
with open(camera_params_path, "r") as f:
label_data = json.load(f) label_data = json.load(f)
cam_to_world = np.asarray(label_data["extrinsic"]) cam_to_world = np.asarray(label_data["extrinsic"])
cam_to_world = DataLoadUtil.cam_pose_transformation(cam_to_world) cam_to_world = DataLoadUtil.cam_pose_transformation(cam_to_world)
world_to_display_table = np.eye(4) world_to_display_table = np.eye(4)
world_to_display_table[:3, 3] = - DataLoadUtil.get_display_table_top(root_dir, scene_name) world_to_display_table[:3, 3] = -DataLoadUtil.get_display_table_top(
root_dir, scene_name
)
if display_table_as_world_space_origin: if display_table_as_world_space_origin:
cam_to_world = np.dot(world_to_display_table, cam_to_world) cam_to_world = np.dot(world_to_display_table, cam_to_world)
cam_intrinsic = np.asarray(label_data["intrinsic"]) cam_intrinsic = np.asarray(label_data["intrinsic"])
@ -198,7 +239,7 @@ class DataLoadUtil:
"cam_to_world": cam_to_world, "cam_to_world": cam_to_world,
"cam_intrinsic": cam_intrinsic, "cam_intrinsic": cam_intrinsic,
"far_plane": label_data["far_plane"], "far_plane": label_data["far_plane"],
"near_plane": label_data["near_plane"] "near_plane": label_data["near_plane"],
} }
if binocular: if binocular:
cam_to_world_R = np.asarray(label_data["extrinsic_R"]) cam_to_world_R = np.asarray(label_data["extrinsic_R"])
@ -213,79 +254,136 @@ class DataLoadUtil:
return cam_info return cam_info
@staticmethod @staticmethod
def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True): def get_real_cam_O_from_cam_L(
cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True
):
root_dir = os.path.dirname(scene_path) root_dir = os.path.dirname(scene_path)
scene_name = os.path.basename(scene_path) scene_name = os.path.basename(scene_path)
if isinstance(cam_L, torch.Tensor): if isinstance(cam_L, torch.Tensor):
cam_L = cam_L.cpu().numpy() cam_L = cam_L.cpu().numpy()
nO_to_display_table_pose = cam_L @ cam_O_to_cam_L nO_to_display_table_pose = cam_L @ cam_O_to_cam_L
if display_table_as_world_space_origin: if display_table_as_world_space_origin:
display_table_to_world = np.eye(4) display_table_to_world = np.eye(4)
display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top(root_dir, scene_name) display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top(
root_dir, scene_name
)
nO_to_world_pose = np.dot(display_table_to_world, nO_to_display_table_pose) nO_to_world_pose = np.dot(display_table_to_world, nO_to_display_table_pose)
nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose) nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose)
return nO_to_world_pose return nO_to_world_pose
@staticmethod @staticmethod
def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0,255,0,255)): def get_target_point_cloud(
depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0, 255, 0, 255)
):
h, w = depth.shape h, w = depth.shape
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy') i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
z = depth z = depth
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
mask = mask.reshape(-1,4) mask = mask.reshape(-1, 4)
target_mask = (mask == target_mask_label).all(axis=-1) target_mask = (mask == target_mask_label).all(axis=-1)
target_points_camera = points_camera[target_mask] target_points_camera = points_camera[target_mask]
target_points_camera_aug = np.concatenate([target_points_camera, np.ones((target_points_camera.shape[0], 1))], axis=-1) target_points_camera_aug = np.concatenate(
[target_points_camera, np.ones((target_points_camera.shape[0], 1))], axis=-1
)
target_points_world = np.dot(cam_extrinsic, target_points_camera_aug.T).T[:, :3] target_points_world = np.dot(cam_extrinsic, target_points_camera_aug.T).T[:, :3]
return { return {
"points_world": target_points_world, "points_world": target_points_world,
"points_camera": target_points_camera "points_camera": target_points_camera,
} }
@staticmethod @staticmethod
def get_point_cloud(depth, cam_intrinsic, cam_extrinsic): def get_point_cloud(depth, cam_intrinsic, cam_extrinsic):
h, w = depth.shape h, w = depth.shape
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy') i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
z = depth z = depth
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1) points_camera_aug = np.concatenate(
[points_camera, np.ones((points_camera.shape[0], 1))], axis=-1
)
points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
return { return {"points_world": points_world, "points_camera": points_camera}
"points_world": points_world,
"points_camera": points_camera
}
@staticmethod @staticmethod
def get_target_point_cloud_world_from_path(path, binocular=False, random_downsample_N=65536, voxel_size = 0.005, target_mask_label=(0,255,0,255)): def get_target_point_cloud_world_from_path(
path,
binocular=False,
random_downsample_N=65536,
voxel_size=0.005,
target_mask_label=(0, 255, 0, 255),
display_table_mask_label=(255, 0, 0, 255),
get_display_table_pts=False
):
cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular) cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular)
if binocular: if binocular:
depth_L, depth_R = DataLoadUtil.load_depth(path, cam_info['near_plane'], cam_info['far_plane'], binocular=True) depth_L, depth_R = DataLoadUtil.load_depth(
path, cam_info["near_plane"], cam_info["far_plane"], binocular=True
)
mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True) mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
point_cloud_L = DataLoadUtil.get_target_point_cloud(depth_L, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask_L, target_mask_label)['points_world'] point_cloud_L = DataLoadUtil.get_target_point_cloud(
point_cloud_R = DataLoadUtil.get_target_point_cloud(depth_R, cam_info['cam_intrinsic'], cam_info['cam_to_world_R'], mask_R, target_mask_label)['points_world'] depth_L,
point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, random_downsample_N) cam_info["cam_intrinsic"],
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, random_downsample_N) cam_info["cam_to_world"],
overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size) mask_L,
target_mask_label,
)["points_world"]
point_cloud_R = DataLoadUtil.get_target_point_cloud(
depth_R,
cam_info["cam_intrinsic"],
cam_info["cam_to_world_R"],
mask_R,
target_mask_label,
)["points_world"]
point_cloud_L = PtsUtil.random_downsample_point_cloud(
point_cloud_L, random_downsample_N
)
point_cloud_R = PtsUtil.random_downsample_point_cloud(
point_cloud_R, random_downsample_N
)
overlap_points = DataLoadUtil.get_overlapping_points(
point_cloud_L, point_cloud_R, voxel_size
)
if get_display_table_pts:
display_pts_L = DataLoadUtil.get_target_point_cloud(
depth_L,
cam_info["cam_intrinsic"],
cam_info["cam_to_world"],
mask_L,
display_table_mask_label,
)["points_world"]
display_pts_R = DataLoadUtil.get_target_point_cloud(
depth_R,
cam_info["cam_intrinsic"],
cam_info["cam_to_world_R"],
mask_R,
display_table_mask_label,
)["points_world"]
display_pts_overlap = DataLoadUtil.get_overlapping_points(
display_pts_L, display_pts_R, voxel_size
)
return overlap_points, display_pts_overlap
return overlap_points return overlap_points
else: else:
depth = DataLoadUtil.load_depth(path, cam_info['near_plane'], cam_info['far_plane']) depth = DataLoadUtil.load_depth(
path, cam_info["near_plane"], cam_info["far_plane"]
)
mask = DataLoadUtil.load_seg(path) mask = DataLoadUtil.load_seg(path)
point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask)['points_world'] point_cloud = DataLoadUtil.get_target_point_cloud(
depth, cam_info["cam_intrinsic"], cam_info["cam_to_world"], mask
)["points_world"]
return point_cloud return point_cloud
@staticmethod @staticmethod
def voxelize_points(points, voxel_size): def voxelize_points(points, voxel_size):
@ -298,10 +396,12 @@ class DataLoadUtil:
voxels_L, indices_L = DataLoadUtil.voxelize_points(point_cloud_L, voxel_size) voxels_L, indices_L = DataLoadUtil.voxelize_points(point_cloud_L, voxel_size)
voxels_R, _ = DataLoadUtil.voxelize_points(point_cloud_R, voxel_size) voxels_R, _ = DataLoadUtil.voxelize_points(point_cloud_R, voxel_size)
voxel_indices_L = voxels_L.view([('', voxels_L.dtype)]*3) voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
voxel_indices_R = voxels_R.view([('', voxels_R.dtype)]*3) voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R) overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
mask_L = np.isin(indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]) mask_L = np.isin(
indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
)
overlapping_points = point_cloud_L[mask_L] overlapping_points = point_cloud_L[mask_L]
return overlapping_points return overlapping_points
@ -310,5 +410,7 @@ class DataLoadUtil:
points_path = os.path.join(root, scene_name, "points_and_normals.txt") points_path = os.path.join(root, scene_name, "points_and_normals.txt")
points_normals = np.loadtxt(points_path) points_normals = np.loadtxt(points_path)
if display_table_as_world_space_origin: if display_table_as_world_space_origin:
points_normals[:,:3] = points_normals[:,:3] - DataLoadUtil.get_display_table_top(root, scene_name) points_normals[:, :3] = points_normals[
:, :3
] - DataLoadUtil.get_display_table_top(root, scene_name)
return points_normals return points_normals

View File

@ -45,9 +45,10 @@ class ReconstructionUtil:
@staticmethod @staticmethod
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list,threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None): def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None):
selected_views = [point_cloud_list[init_view]] selected_views = [point_cloud_list[init_view]]
combined_point_cloud = np.vstack(selected_views) combined_point_cloud = np.vstack(selected_views)
combined_scan_points_indices = scan_points_indices_list[init_view]
down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold) down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold)
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold) new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
current_coverage = new_coverage current_coverage = new_coverage
@ -63,12 +64,14 @@ class ReconstructionUtil:
for view_index in remaining_views: for view_index in remaining_views:
if selected_views: if selected_views:
combined_old_point_cloud = np.vstack(selected_views) new_scan_points_indices = scan_points_indices_list[view_index]
down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold) if not ReconstructionUtil.check_scan_points_overlap(combined_scan_points_indices, new_scan_points_indices):
down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold) combined_old_point_cloud = np.vstack(selected_views)
overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold) down_sampled_old_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_old_point_cloud,threshold)
if overlap_rate < overlap_threshold: down_sampled_new_view_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud_list[view_index],threshold)
continue overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_new_view_point_cloud,down_sampled_old_point_cloud, threshold)
if overlap_rate < overlap_threshold:
continue
candidate_views = selected_views + [point_cloud_list[view_index]] candidate_views = selected_views + [point_cloud_list[view_index]]
combined_point_cloud = np.vstack(candidate_views) combined_point_cloud = np.vstack(candidate_views)
@ -85,6 +88,7 @@ class ReconstructionUtil:
break break
selected_views.append(point_cloud_list[best_view]) selected_views.append(point_cloud_list[best_view])
remaining_views.remove(best_view) remaining_views.remove(best_view)
combined_scan_points_indices = ReconstructionUtil.combine_scan_points_indices(combined_scan_points_indices, scan_points_indices_list[best_view])
current_coverage += best_coverage_increase current_coverage += best_coverage_increase
cnt_processed_view += 1 cnt_processed_view += 1
if status_info is not None: if status_info is not None:
@ -121,3 +125,38 @@ class ReconstructionUtil:
return filtered_sampled_points[:, :3] return filtered_sampled_points[:, :3]
@staticmethod
def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 100, max_attempts = 1000):
points = []
attempts = 0
while len(points) < max_points_num and attempts < max_attempts:
angle = np.random.uniform(0, 2 * np.pi)
r = np.random.uniform(0, display_table_radius)
x = r * np.cos(angle)
y = r * np.sin(angle)
z = display_table_top
new_point = (x, y, z)
if all(np.linalg.norm(np.array(new_point) - np.array(existing_point)) >= min_distance for existing_point in points):
points.append(new_point)
attempts += 1
return points
@staticmethod
def compute_covered_scan_points(scan_points, point_cloud, threshold=0.01):
tree = cKDTree(point_cloud)
covered_points = []
indices = []
for i, scan_point in enumerate(scan_points):
if tree.query_ball_point(scan_point, threshold):
covered_points.append(scan_point)
indices.append(i)
return covered_points, indices
@staticmethod
def check_scan_points_overlap(indices1, indices2, threshold=5):
return len(set(indices1).intersection(set(indices2))) > threshold
@staticmethod
def combine_scan_points_indices(indices1, indices2):
combined_indices = set(indices1) | set(indices2)
return sorted(combined_indices)