dataset update

This commit is contained in:
hofee 2024-08-30 17:57:47 +08:00
parent 71676e2f4e
commit be5a2d57fa
3 changed files with 13 additions and 10 deletions

View File

@ -15,7 +15,6 @@ runner:
datasets: datasets:
OmniObject3d: OmniObject3d:
root_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_dataset" root_dir: "/media/hofee/data/data/nbv_rec/sample"
label_dir: "C:\\Document\\Local Project\\nbv_rec\\sample_output"

View File

@ -4,7 +4,7 @@ from PytorchBoot.dataset import BaseDataset
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
import sys import sys
sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction") sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction")
from utils.data_load import DataLoadUtil from utils.data_load import DataLoadUtil
from utils.pose import PoseUtil from utils.pose import PoseUtil
@ -14,7 +14,6 @@ class NBVReconstructionDataset(BaseDataset):
def __init__(self, config): def __init__(self, config):
super(NBVReconstructionDataset, self).__init__(config) super(NBVReconstructionDataset, self).__init__(config)
self.config = config self.config = config
self.label_dir = config["label_dir"]
self.root_dir = config["root_dir"] self.root_dir = config["root_dir"]
self.datalist = self.get_datalist() self.datalist = self.get_datalist()
@ -22,7 +21,7 @@ class NBVReconstructionDataset(BaseDataset):
datalist = [] datalist = []
scene_name_list = os.listdir(self.root_dir) scene_name_list = os.listdir(self.root_dir)
for scene_name in scene_name_list: for scene_name in scene_name_list:
label_path = DataLoadUtil.get_label_path(self.label_dir, scene_name) label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
label_data = DataLoadUtil.load_label(label_path) label_data = DataLoadUtil.load_label(label_path)
for data_pair in label_data["data_pairs"]: for data_pair in label_data["data_pairs"]:
scanned_views = data_pair[0] scanned_views = data_pair[0]
@ -45,12 +44,17 @@ class NBVReconstructionDataset(BaseDataset):
max_coverage_rate = data_item_info["max_coverage_rate"] max_coverage_rate = data_item_info["max_coverage_rate"]
scene_name = data_item_info["scene_name"] scene_name = data_item_info["scene_name"]
scanned_views_pts, scanned_coverages_rate, scanned_cam_pose = [], [], [] scanned_views_pts, scanned_coverages_rate, scanned_cam_pose = [], [], []
first_frame_idx = scanned_views[0][0]
first_frame_pose = DataLoadUtil.load_cam_info(DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx))["cam_to_world"]
for view in scanned_views: for view in scanned_views:
frame_idx = view[0] frame_idx = view[0]
coverage_rate = view[1] coverage_rate = view[1]
view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
pts = DataLoadUtil.load_depth(view_path) depth = DataLoadUtil.load_depth(view_path)
scanned_views_pts.append(pts) cam_info = DataLoadUtil.load_cam_info(view_path)
mask = DataLoadUtil.load_seg(view_path)
target_point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info["cam_intrinsic"], cam_info["cam_to_world"], mask)
scanned_views_pts.append(target_point_cloud)
scanned_coverages_rate.append(coverage_rate) scanned_coverages_rate.append(coverage_rate)
cam_pose = DataLoadUtil.load_cam_info(view_path)["cam_to_world"] cam_pose = DataLoadUtil.load_cam_info(view_path)["cam_to_world"]
@ -86,13 +90,13 @@ class NBVReconstructionDataset(BaseDataset):
if __name__ == "__main__": if __name__ == "__main__":
import torch import torch
config = { config = {
"root_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_dataset", "root_dir": "/media/hofee/data/data/nbv_rec/sample",
"label_dir": "C:\\Document\\Local Project\\nbv_rec\\sample_output",
"ratio": 0.1, "ratio": 0.1,
"batch_size": 1, "batch_size": 1,
"num_workers": 0, "num_workers": 0,
} }
ds = NBVReconstructionDataset(config) ds = NBVReconstructionDataset(config)
print(len(ds))
dl = ds.get_loader(shuffle=True) dl = ds.get_loader(shuffle=True)
for idx, data in enumerate(dl): for idx, data in enumerate(dl):
for key, value in data.items(): for key, value in data.items():

View File

@ -76,7 +76,7 @@ class StrategyGenerator(Runner):
def generate_data_pairs(self, useful_view): def generate_data_pairs(self, useful_view):
data_pairs = [] data_pairs = []
for next_view_idx in range(len(useful_view)): for next_view_idx in range(1, len(useful_view)):
scanned_views = useful_view[:next_view_idx] scanned_views = useful_view[:next_view_idx]
next_view = useful_view[next_view_idx] next_view = useful_view[next_view_idx]
data_pairs.append((scanned_views, next_view)) data_pairs.append((scanned_views, next_view))