import os
import json
import pickle

import numpy as np
import torch
from tqdm import tqdm

from utils.render import RenderUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager


@stereotype.runner("inferencer")
class Inferencer(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)
        self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
        self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
        self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")

        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline = self.pipeline.to(self.device)

        ''' Experiment '''
        self.load_experiment("nbv_evaluator")
        self.stat_result = {}

        ''' Test '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name in seen_name:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            seen_name.add(test_dataset_name)
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            self.test_set_list.append(test_set)

        self.print_info()

    def run(self):
        Log.info("Loading from epoch {}.".format(self.current_epoch))
        self.inference()
        Log.success("Inference finished.")

    def inference(self):
        self.pipeline.eval()
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                status_manager.set_progress("inference", "inferencer", "dataset", dataset_idx, len(self.test_set_list))
                test_set_name = test_set.get_name()
                total = int(len(test_set))
                scene_name_list = test_set.get_scene_name_list()
                for i in range(total):
                    scene_name = scene_name_list[i]
                    # Skip scenes whose result file already exists, so the run is resumable.
                    inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
                    if os.path.exists(inference_result_path):
                        Log.info(f"Inference result already exists for scene: {scene_name}")
                        continue
                    data = test_set[i]
                    status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i + 1, total)
                    output = self.predict_sequence(data)
                    self.save_inference_result(test_set_name, data["scene_name"], output)
            status_manager.set_progress("inference", "inferencer", "dataset", len(self.test_set_list), len(self.test_set_list))
    def predict_sequence(self, data, cr_increase_threshold=0, max_iter=50, max_retry=5):
        scene_name = data["scene_name"]
        Log.info(f"Processing scene: {scene_name}")
        status_manager.set_status("inference", "inferencer", "scene", scene_name)

        ''' data for rendering '''
        scene_path = data["scene_path"]
        O_to_L_pose = data["O_to_L_pose"]
        voxel_threshold = self.voxel_size
        filter_degree = 75
        down_sampled_model_pts = data["gt_pts"]

        # Build the first frame's 4x4 world pose from its 9D encoding
        # (6D rotation representation in [:6], translation in [6:]).
        first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
        first_frame_to_world = np.eye(4)
        first_frame_to_world[:3, :3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
        first_frame_to_world[:3, 3] = first_frame_to_world_9d[6:]

        ''' data for inference '''
        input_data = {}
        input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
        input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
        input_data["mode"] = namespace.Mode.TEST
        input_pts_N = input_data["combined_scanned_pts"].shape[1]

        first_frame_target_pts, first_frame_target_normals = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
        scanned_view_pts = [first_frame_target_pts]
        last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)

        retry_duplication_pose = []
        retry_no_pts_pose = []
        retry = 0
        pred_cr_seq = [last_pred_cr]
        success = 0
        while len(pred_cr_seq) < max_iter and retry < max_retry:
            output = self.pipeline(input_data)
            pred_pose_9d = output["pred_pose_9d"]
            pred_pose = torch.eye(4, device=pred_pose_9d.device)
            pred_pose[:3, :3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:, :6])[0]
            pred_pose[:3, 3] = pred_pose_9d[0, 6:]

            try:
                new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
            except Exception as e:
                Log.warning(f"Error in scene {scene_path}, {e}")
                print("current pose: ", pred_pose)
                print("curr_pred_cr: ", last_pred_cr)
                retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                retry += 1
                continue

            if new_target_pts.shape[0] == 0:
                print("no pts in new target")
                retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                retry += 1
                continue

            pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
            print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
            if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
                print("max coverage rate reached!: ", pred_cr)
                success += 1
            elif new_added_pts_num < 10:
                print("min added pts num reached!: ", new_added_pts_num)

            # Reject views that do not improve coverage beyond the threshold.
            if pred_cr <= last_pred_cr + cr_increase_threshold:
                retry += 1
                retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
                continue

            retry = 0
            pred_cr_seq.append(pred_cr)
            scanned_view_pts.append(new_target_pts)

            # Accept the view: append its pose, merge the new points into the
            # combined cloud, voxel-downsample to deduplicate, then randomly
            # downsample back to the fixed input size the pipeline expects.
            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
            down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N)
            combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), down_sampled_new_pts_world], axis=0)
            voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002)
            random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
            input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)

            if success > 3:
                break
            last_pred_cr = pred_cr

        input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
        result = {
            "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
            "combined_scanned_pts": input_data["combined_scanned_pts"],
            "target_pts_seq": scanned_view_pts,
            "coverage_rate_seq": pred_cr_seq,
            "max_coverage_rate": data["seq_max_coverage_rate"],
            "pred_max_coverage_rate": max(pred_cr_seq),
            "scene_name": scene_name,
            "retry_no_pts_pose": retry_no_pts_pose,
            "retry_duplication_pose": retry_duplication_pose,
            "best_seq_len": data["best_seq_len"],
        }
        self.stat_result[scene_name] = {
            "coverage_rate_seq": pred_cr_seq,
            "pred_max_coverage_rate": max(pred_cr_seq),
            "pred_seq_len": len(pred_cr_seq),
        }
        print("pred max coverage rate: ", max(pred_cr_seq))
        return result
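
    @staticmethod
    def pose_9d_to_matrix_numpy(pose_9d):
        # Hypothetical convenience helper, not in the original file: it simply
        # factors out the 9D-pose -> 4x4-homogeneous-matrix conversion that
        # predict_sequence performs inline above (6D rotation representation
        # in pose_9d[:6], translation in pose_9d[6:]).
        pose = np.eye(4)
        pose[:3, :3] = PoseUtil.rotation_6d_to_matrix_numpy(pose_9d[:6])
        pose[:3, 3] = pose_9d[6:]
        return pose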
    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
        # Merge all scanned views (plus the candidate view, if given),
        # voxel-downsample, and compare against the ground-truth model points.
        if new_pts is not None:
            new_scanned_view_pts = scanned_view_pts + [new_pts]
        else:
            new_scanned_view_pts = scanned_view_pts
        combined_point_cloud = np.vstack(new_scanned_view_pts)
        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)

    def save_inference_result(self, dataset_name, scene_name, output):
        dataset_dir = os.path.join(self.output_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)
        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
        with open(output_path, "wb") as f:
            pickle.dump(output, f)
        with open(os.path.join(dataset_dir, "stat.json"), "w") as f:
            json.dump(self.stat_result, f)

    def get_checkpoint_path(self, is_last=False):
        # Note: "Direcotry" matches the spelling used in PytorchBoot.namespace.
        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
                            "Epoch_{}.pth".format(self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

    def load_checkpoint(self, is_last=False):
        self.load(self.get_checkpoint_path(is_last))
        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load(self, path):
        state_dict = torch.load(path)
        self.pipeline.load_state_dict(state_dict)

    def print_info(self):
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)
        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
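

# Minimal usage sketch (an assumption, not part of the original file). In a
# PytorchBoot project the registered runner is normally launched through the
# framework's own entry point; constructing it directly should also work,
# since __init__ only needs a config path. The path below is hypothetical.
if __name__ == "__main__":
    inferencer = Inferencer("configs/inference_config.yaml")  # hypothetical path
    inferencer.run()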