dataset update

This commit is contained in:
hofee 2024-08-30 17:57:47 +08:00
parent 71676e2f4e
commit be5a2d57fa
3 changed files with 13 additions and 10 deletions

View File

@@ -15,7 +15,6 @@ runner:
   datasets:
     OmniObject3d:
-      root_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_dataset"
-      label_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_output"
+      root_dir: "/media/hofee/data/data/nbv_rec/sample"

View File

@@ -4,7 +4,7 @@ from PytorchBoot.dataset import BaseDataset
 import PytorchBoot.stereotype as stereotype
 import sys
-sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
+sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction")
 from utils.data_load import DataLoadUtil
 from utils.pose import PoseUtil
@@ -14,7 +14,6 @@ class NBVReconstructionDataset(BaseDataset):
     def __init__(self, config):
         super(NBVReconstructionDataset, self).__init__(config)
         self.config = config
-        self.label_dir = config["label_dir"]
         self.root_dir = config["root_dir"]
         self.datalist = self.get_datalist()
@@ -22,7 +21,7 @@ class NBVReconstructionDataset(BaseDataset):
         datalist = []
         scene_name_list = os.listdir(self.root_dir)
         for scene_name in scene_name_list:
-            label_path = DataLoadUtil.get_label_path(self.label_dir, scene_name)
+            label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
             label_data = DataLoadUtil.load_label(label_path)
             for data_pair in label_data["data_pairs"]:
                 scanned_views = data_pair[0]
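
Note on the label files resolved above: from the indexing in this diff (data_pair[0], view[0]/view[1]) and the max_coverage_rate field read in __getitem__, the label appears to be a dict with a data_pairs list of (scanned_views, next_view) entries, each view being a (frame_idx, coverage_rate) pair. A minimal sketch under that assumption (all values, and any keys beyond these two, are hypothetical):

    # Hypothetical label layout, inferred only from how this diff indexes it.
    label_data = {
        "max_coverage_rate": 0.95,
        "data_pairs": [
            # (scanned_views, next_view); each view is (frame_idx, coverage_rate)
            ([(0, 0.42)], (5, 0.61)),
            ([(0, 0.42), (5, 0.61)], (12, 0.73)),
        ],
    }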
@@ -45,12 +44,17 @@ class NBVReconstructionDataset(BaseDataset):
         max_coverage_rate = data_item_info["max_coverage_rate"]
         scene_name = data_item_info["scene_name"]
         scanned_views_pts, scanned_coverages_rate, scanned_cam_pose = [], [], []
+        first_frame_idx = scanned_views[0][0]
+        first_frame_pose = DataLoadUtil.load_cam_info(DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx))["cam_to_world"]
         for view in scanned_views:
             frame_idx = view[0]
             coverage_rate = view[1]
             view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
-            pts = DataLoadUtil.load_depth(view_path)
-            scanned_views_pts.append(pts)
+            depth = DataLoadUtil.load_depth(view_path)
+            cam_info = DataLoadUtil.load_cam_info(view_path)
+            mask = DataLoadUtil.load_seg(view_path)
+            target_point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info["cam_intrinsic"], cam_info["cam_to_world"], mask)
+            scanned_views_pts.append(target_point_cloud)
             scanned_coverages_rate.append(coverage_rate)
             cam_pose = DataLoadUtil.load_cam_info(view_path)["cam_to_world"]
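
The main change in this hunk: instead of appending raw depth maps, each scanned view is converted into a masked, world-frame target point cloud via DataLoadUtil.get_target_point_cloud, and the first frame's cam_to_world is cached as first_frame_pose (its use falls outside the shown context). The helper's implementation lives in utils.data_load and is not part of this diff; the sketch below is only an assumption of what such a back-projection typically looks like, under a hypothetical name:

    import numpy as np

    def back_project_masked_depth(depth, intrinsic, cam_to_world, mask):
        # Hypothetical stand-in for DataLoadUtil.get_target_point_cloud:
        # back-project masked depth pixels into a world-frame point cloud.
        v, u = np.nonzero(mask)                      # pixel rows/cols of target pixels
        z = depth[v, u]
        valid = z > 0                                # drop empty depth readings
        u, v, z = u[valid], v[valid], z[valid]
        fx, fy = intrinsic[0, 0], intrinsic[1, 1]
        cx, cy = intrinsic[0, 2], intrinsic[1, 2]
        x = (u - cx) * z / fx                        # pinhole back-projection
        y = (v - cy) * z / fy
        pts_cam = np.stack([x, y, z, np.ones_like(z)], axis=-1)   # (N, 4) homogeneous
        return (cam_to_world @ pts_cam.T).T[:, :3]   # camera frame -> world frame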
@@ -86,13 +90,13 @@ class NBVReconstructionDataset(BaseDataset):
 if __name__ == "__main__":
     import torch
     config = {
-        "root_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_dataset",
-        "label_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_output",
+        "root_dir": "/media/hofee/data/data/nbv_rec/sample",
         "ratio": 0.1,
         "batch_size": 1,
         "num_workers": 0,
     }
     ds = NBVReconstructionDataset(config)
+    print(len(ds))
     dl = ds.get_loader(shuffle=True)
     for idx, data in enumerate(dl):
         for key, value in data.items():

View File

@@ -76,7 +76,7 @@ class StrategyGenerator(Runner):
     def generate_data_pairs(self, useful_view):
         data_pairs = []
-        for next_view_idx in range(len(useful_view)):
+        for next_view_idx in range(1, len(useful_view)):
             scanned_views = useful_view[:next_view_idx]
             next_view = useful_view[next_view_idx]
             data_pairs.append((scanned_views, next_view))
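
Starting the loop at index 1 guarantees every (scanned_views, next_view) pair contains at least one already-scanned view; with the previous range(len(useful_view)) the first pair had an empty scanned_views list, which the dataset change above cannot handle (it reads scanned_views[0][0]). A quick illustration with made-up (frame_idx, coverage_rate) views:

    useful_view = [(0, 0.42), (5, 0.61), (12, 0.73)]   # made-up views
    pairs = [(useful_view[:i], useful_view[i]) for i in range(1, len(useful_view))]
    # -> [([(0, 0.42)], (5, 0.61)),
    #     ([(0, 0.42), (5, 0.61)], (12, 0.73))]
    # With range(len(useful_view)) the first pair would have been ([], (0, 0.42)).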