import os
import json
import time

import numpy as np

from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager

from utils.data_load import DataLoadUtil
from utils.reconstruction import ReconstructionUtil
from utils.pts import PtsUtil


@stereotype.runner("strategy_generator")
class StrategyGenerator(Runner):
    """Generate next-best-view scan strategies (label files) for dataset scenes.

    Settings come from the ``runner.generate`` config section. For every scene
    of every configured dataset, one sequence per selected initial view is
    computed with ``ReconstructionUtil`` and written as JSON to the path given
    by ``DataLoadUtil.get_label_path(root, scene_name, seq_idx)``.
    """

    # A frame must contain strictly more than this many points to be used as
    # the initial view of a sequence.
    MIN_INIT_VIEW_POINTS = 50

    def __init__(self, config):
        super().__init__(config)
        self.load_experiment("generate_strategy")
        # Forwarded to ReconstructionUtil (as ``status_info``) alongside the
        # other sequence-computation arguments.
        self.status_info = {
            "status_manager": status_manager,
            "app_name": "generate_strategy",
            "runner_name": "strategy_generator",
        }
        self.overwrite = ConfigManager.get("runner", "generate", "overwrite")
        self.seq_num = ConfigManager.get("runner", "generate", "seq_num")
        self.overlap_area_threshold = ConfigManager.get("runner", "generate", "overlap_area_threshold")
        self.compute_with_normal = ConfigManager.get("runner", "generate", "compute_with_normal")
        self.scan_points_threshold = ConfigManager.get("runner", "generate", "scan_points_threshold")

    def run(self):
        """Generate strategies for the ``[from, to)`` scene range of each dataset.

        A scene is skipped when its label file for sequence 0 already exists
        and ``overwrite`` is disabled. A ``to`` value of -1 means "until the
        last scene".
        """
        dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list")
        voxel_threshold = ConfigManager.get("runner", "generate", "voxel_threshold")
        dataset_total = len(dataset_name_list)
        for dataset_idx, dataset_name in enumerate(dataset_name_list):
            status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", dataset_idx, dataset_total)
            root_dir = ConfigManager.get("datasets", dataset_name, "root_dir")
            from_idx = ConfigManager.get("datasets", dataset_name, "from")
            to_idx = ConfigManager.get("datasets", dataset_name, "to")
            # os.listdir order is arbitrary; sort so a given [from, to) slice
            # always selects the same scenes across runs and machines.
            scene_name_list = sorted(os.listdir(root_dir))
            if to_idx == -1:
                to_idx = len(scene_name_list)
            selected_scenes = scene_name_list[from_idx:to_idx]
            total = len(selected_scenes)
            Log.info(f"Processing Dataset: {dataset_name}, From: {from_idx}, To: {to_idx}")
            for cnt, scene_name in enumerate(selected_scenes):
                Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}")
                status_manager.set_progress("generate_strategy", "strategy_generator", "scene", cnt, total)
                output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name, 0)
                if os.path.exists(output_label_path) and not self.overwrite:
                    Log.info(f"Scene <{scene_name}> Already Exists, Skip")
                    continue
                self.generate_sequence(root_dir, scene_name, voxel_threshold)
            status_manager.set_progress("generate_strategy", "strategy_generator", "scene", total, total)
        status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", dataset_total, dataset_total)

    def create_experiment(self, backup_name=None):
        """Create the experiment and its ``output`` subdirectory."""
        super().create_experiment(backup_name)
        output_dir = os.path.join(str(self.experiment_path), "output")
        # exist_ok: an output directory left over from a previous run must not
        # make experiment creation crash.
        os.makedirs(output_dir, exist_ok=True)

    def load_experiment(self, backup_name=None):
        """Delegate to ``Runner.load_experiment``."""
        super().load_experiment(backup_name)

    def _load_scan_frames(self, root, scene_name, frame_num):
        """Load the per-frame scan data stored under ``<root>/<scene_name>``.

        Reads ``pts/<i>.npy``, ``scan_points_indices/<i>.npy`` and, only when
        ``compute_with_normal`` is set, ``nrm/<i>.npy`` for i in [0, frame_num).

        Returns:
            (pts_list, nrm_list, scan_points_indices_list, non_zero_cnt), where
            nrm_list is empty unless normals are enabled and non_zero_cnt is the
            number of frames containing at least one point.
        """
        pts_list = []
        nrm_list = []
        scan_points_indices_list = []
        non_zero_cnt = 0
        scene_dir = os.path.join(root, scene_name)
        for frame_idx in range(frame_num):
            status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
            frame_file = f"{frame_idx}.npy"
            pts = np.load(os.path.join(scene_dir, "pts", frame_file))
            if self.compute_with_normal:
                # Frames without points get an empty (0, 3) normal array
                # instead of reading the normal file.
                if pts.shape[0] == 0:
                    nrm = np.zeros((0, 3))
                else:
                    nrm = np.load(os.path.join(scene_dir, "nrm", frame_file))
                nrm_list.append(nrm)
            pts_list.append(pts)
            scan_points_indices_list.append(np.load(os.path.join(scene_dir, "scan_points_indices", frame_file)))
            if pts.shape[0] > 0:
                non_zero_cnt += 1
        return pts_list, nrm_list, scan_points_indices_list, non_zero_cnt

    def _select_init_views(self, pts_list, seq_num):
        """Return up to ``seq_num`` frame indices, in frame order, whose point
        count exceeds ``MIN_INIT_VIEW_POINTS``."""
        init_view_list = []
        for frame_idx, pts in enumerate(pts_list):
            if len(init_view_list) >= seq_num:
                break
            if pts.shape[0] > self.MIN_INIT_VIEW_POINTS:
                init_view_list.append(frame_idx)
        return init_view_list

    def generate_sequence(self, root, scene_name, voxel_threshold):
        """Compute and save next-best-view sequences for one scene.

        The model point cloud (first three columns of the loaded array) is
        voxel-downsampled with ``voxel_threshold``. One sequence is computed
        per selected initial view, and the result is written as JSON with the
        keys ``data_pairs``, ``best_sequence`` and ``max_coverage_rate``.
        """
        status_manager.set_status("generate_strategy", "strategy_generator", "scene", scene_name)
        frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name)
        model_points_normals = DataLoadUtil.load_points_normals(root, scene_name)
        model_pts = model_points_normals[:, :3]
        down_sampled_model_pts, sample_idx = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold, require_idx=True)
        # Keep the normal columns (3:) of exactly the points that survived downsampling.
        down_sampled_model_nrm = model_points_normals[sample_idx, 3:]

        pts_list, nrm_list, scan_points_indices_list, non_zero_cnt = self._load_scan_frames(root, scene_name, frame_num)
        status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num)

        seq_num = min(self.seq_num, non_zero_cnt)
        init_view_list = self._select_init_views(pts_list, seq_num)
        seq_total = len(init_view_list)

        for seq_idx, init_view in enumerate(init_view_list):
            status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", seq_idx, seq_total)
            start = time.perf_counter()
            if not self.compute_with_normal:
                limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence(
                    down_sampled_model_pts,
                    pts_list,
                    scan_points_indices_list=scan_points_indices_list,
                    init_view=init_view,
                    threshold=voxel_threshold,
                    scan_points_threshold=self.scan_points_threshold,
                    overlap_area_threshold=self.overlap_area_threshold,
                    status_info=self.status_info,
                )
            else:
                limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_normal(
                    down_sampled_model_pts,
                    down_sampled_model_nrm,
                    pts_list,
                    nrm_list,
                    scan_points_indices_list=scan_points_indices_list,
                    init_view=init_view,
                    threshold=voxel_threshold,
                    scan_points_threshold=self.scan_points_threshold,
                    overlap_area_threshold=self.overlap_area_threshold,
                    status_info=self.status_info,
                )
            Log.info(f"Time: {time.perf_counter() - start}")

            # NOTE(review): each entry is indexed as [view, coverage]; an empty
            # sequence from ReconstructionUtil would raise IndexError here.
            max_coverage_rate = limited_useful_view[-1][1]
            seq_save_data = {
                "data_pairs": self.generate_data_pairs(limited_useful_view),
                "best_sequence": limited_useful_view,
                "max_coverage_rate": max_coverage_rate,
            }
            status_manager.set_status("generate_strategy", "strategy_generator", "max_coverage_rate", max_coverage_rate)
            Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {max_coverage_rate}, Best Sequence length: {len(limited_useful_view)}")
            output_label_path = DataLoadUtil.get_label_path(root, scene_name, seq_idx)
            with open(output_label_path, 'w') as f:
                json.dump(seq_save_data, f)
        status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", seq_total, seq_total)

    def generate_data_pairs(self, useful_view):
        """Split a view sequence into (scanned_prefix, next_view) pairs.

        Pair i (for i in 1..len-1) is ``(useful_view[:i], useful_view[i])``.
        A sequence with fewer than two views yields an empty list.
        """
        return [(useful_view[:i], useful_view[i]) for i in range(1, len(useful_view))]