add inference

This commit is contained in:
hofee 2024-09-19 00:14:26 +08:00
parent 9ec3a00fd4
commit 935069d68c
10 changed files with 302 additions and 139 deletions

16
app_inference.py Normal file
View File

@ -0,0 +1,16 @@
from PytorchBoot.application import PytorchBootApplication
from runners.inferencer import Inferencer
@PytorchBootApplication("inference")
class InferenceApp:
    @staticmethod
    def start():
        """
        Entry point of the "inference" application.

        Runners (default or custom) are invoked from here; the original
        template notes that this code is executed automatically when
        "pytorch-boot run" or "ptb run" is typed in a terminal, e.g.:
            Trainer("path_to_your_train_config").run()
            Evaluator("path_to_your_eval_config").run()

        This app builds an Inferencer from the local inference config and
        runs it.
        """
        config_path = "./configs/local/inference_config.yaml"
        inferencer = Inferencer(config_path)
        inferencer.run()

View File

@ -0,0 +1,70 @@
runner:
general:
seed: 1
device: cuda
cuda_visible_devices: "0,1,2,3,4,5,6,7"
experiment:
name: local_eval
root_dir: "experiments"
epoch: 600 # -1 stands for last epoch
test:
dataset_list:
- OmniObject3d_train
blender_script_path: "/media/hofee/data/project/python/nbv_reconstruction/blender/data_renderer.py"
output_dir: "/media/hofee/data/project/python/nbv_reconstruction/sample_for_training/inference_result"
pipeline: nbv_reconstruction_pipeline
dataset:
OmniObject3d_train:
root_dir: "/media/hofee/data/project/python/nbv_reconstruction/sample_for_training/scenes"
model_dir: "/media/hofee/data/data/scaled_object_meshes"
source: seq_nbv_reconstruction_dataset
split_file: "/media/hofee/data/project/python/nbv_reconstruction/sample_for_training/OmniObject3d_train.txt"
type: test
filter_degree: 75
ratio: 1
batch_size: 1
num_workers: 12
pts_num: 4096
pipeline:
nbv_reconstruction_pipeline:
pts_encoder: pointnet_encoder
seq_encoder: transformer_seq_encoder
pose_encoder: pose_encoder
view_finder: gf_view_finder
module:
pointnet_encoder:
in_dim: 3
out_dim: 1024
global_feat: True
feature_transform: False
transformer_seq_encoder:
pts_embed_dim: 1024
pose_embed_dim: 256
num_heads: 4
ffn_dim: 256
num_layers: 3
output_dim: 2048
gf_view_finder:
t_feat_dim: 128
pose_feat_dim: 256
main_feat_dim: 2048
regression_head: Rx_Ry_and_T
pose_mode: rot_matrix
per_point_feature: False
sample_mode: ode
sampling_steps: 500
sde_mode: ve
pose_encoder:
pose_dim: 9
out_dim: 256

View File

@ -27,7 +27,7 @@ class NBVReconstructionDataset(BaseDataset):
self.pts_num = config["pts_num"] self.pts_num = config["pts_num"]
self.type = config["type"] self.type = config["type"]
self.cache = config["cache"] self.cache = config.get("cache")
if self.type == namespace.Mode.TEST: if self.type == namespace.Mode.TEST:
self.model_dir = config["model_dir"] self.model_dir = config["model_dir"]
self.filter_degree = config["filter_degree"] self.filter_degree = config["filter_degree"]
@ -105,6 +105,9 @@ class NBVReconstructionDataset(BaseDataset):
nR_to_world_pose = cam_info["cam_to_world_R"] nR_to_world_pose = cam_info["cam_to_world_R"]
n_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), n_to_world_pose) n_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), n_to_world_pose)
nR_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), nR_to_world_pose) nR_to_1_pose = np.dot(np.linalg.inv(first_frame_to_world), nR_to_world_pose)
cached_data = None
if self.cache:
cached_data = self.load_from_cache(scene_name, first_frame_idx, frame_idx) cached_data = self.load_from_cache(scene_name, first_frame_idx, frame_idx)
if cached_data is None: if cached_data is None:
@ -116,6 +119,7 @@ class NBVReconstructionDataset(BaseDataset):
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536) point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R) overlap_points = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num) downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(overlap_points, self.pts_num)
if self.cache:
self.save_to_cache(scene_name, first_frame_idx, frame_idx, downsampled_target_point_cloud) self.save_to_cache(scene_name, first_frame_idx, frame_idx, downsampled_target_point_cloud)
else: else:
downsampled_target_point_cloud = cached_data downsampled_target_point_cloud = cached_data

View File

@ -1,44 +1,15 @@
import torch import torch
import os
import json
import numpy as np import numpy as np
import subprocess
import tempfile
from utils.data_load import DataLoadUtil
from utils.reconstruction import ReconstructionUtil from utils.reconstruction import ReconstructionUtil
from utils.pose import PoseUtil from utils.pose import PoseUtil
from utils.pts import PtsUtil from utils.pts import PtsUtil
from utils.render import RenderUtil
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
import PytorchBoot.namespace as namespace import PytorchBoot.namespace as namespace
from PytorchBoot.utils.log_util import Log from PytorchBoot.utils.log_util import Log
def render_pts(cam_pose, scene_path,script_path, model_points_normals, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None):
nO_to_world_pose = cam_pose.cpu().numpy() @ nO_to_nL_pose
nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose)
with tempfile.TemporaryDirectory() as temp_dir:
params = {
"cam_pose": nO_to_world_pose.tolist(),
"scene_path": scene_path
}
params_data_path = os.path.join(temp_dir, "params.json")
with open(params_data_path, 'w') as f:
json.dump(params, f)
result = subprocess.run([
'blender', '-b', '-P', script_path, '--', temp_dir
], capture_output=True, text=True)
if result.returncode != 0:
print("Blender script failed:")
print(result.stderr)
return None
path = os.path.join(temp_dir, "tmp")
point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree)
return sampled_point_cloud
@stereotype.evaluation_method("pose_diff") @stereotype.evaluation_method("pose_diff")
class PoseDiff: class PoseDiff:
def __init__(self, _): def __init__(self, _):
@ -110,7 +81,7 @@ class ConverageRateIncrease:
filter_degree = filter_degree_list[idx] filter_degree = filter_degree_list[idx]
nO_to_nL_pose = nO_to_nL_pose_batch[idx] nO_to_nL_pose = nO_to_nL_pose_batch[idx]
try: try:
new_pts = render_pts(pred_pose, scene_path, self.renderer_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=nO_to_nL_pose) new_pts, _ = RenderUtil.render_pts(pred_pose, scene_path, self.renderer_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=nO_to_nL_pose)
pred_cr = self.compute_coverage_rate(scanned_view_pts, new_pts, down_sampled_model_pts, threshold=voxel_threshold) pred_cr = self.compute_coverage_rate(scanned_view_pts, new_pts, down_sampled_model_pts, threshold=voxel_threshold)
except Exception as e: except Exception as e:
Log.warning(f"Error in scene {scene_path}, {e}") Log.warning(f"Error in scene {scene_path}, {e}")

View File

@ -5,7 +5,7 @@ import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.utils import Log from PytorchBoot.utils import Log
@stereotype.pipeline("nbv_reconstruction_pipeline", comment="should be tested") @stereotype.pipeline("nbv_reconstruction_pipeline")
class NBVReconstructionPipeline(nn.Module): class NBVReconstructionPipeline(nn.Module):
def __init__(self, config): def __init__(self, config):
super(NBVReconstructionPipeline, self).__init__() super(NBVReconstructionPipeline, self).__init__()
@ -67,14 +67,13 @@ class NBVReconstructionPipeline(nn.Module):
def get_seq_feat(self, data): def get_seq_feat(self, data):
scanned_pts_batch = data['scanned_pts'] scanned_pts_batch = data['scanned_pts']
scanned_n_to_1_pose_9d_batch = data['scanned_n_to_1_pose_9d'] scanned_n_to_1_pose_9d_batch = data['scanned_n_to_1_pose_9d']
best_to_1_pose_9d_batch = data["best_to_1_pose_9d"]
pts_feat_seq_list = [] pts_feat_seq_list = []
pose_feat_seq_list = [] pose_feat_seq_list = []
device = next(self.parameters()).device
for scanned_pts,scanned_n_to_1_pose_9d in zip(scanned_pts_batch,scanned_n_to_1_pose_9d_batch): for scanned_pts,scanned_n_to_1_pose_9d in zip(scanned_pts_batch,scanned_n_to_1_pose_9d_batch):
scanned_pts = scanned_pts.to(best_to_1_pose_9d_batch.device) scanned_pts = scanned_pts.to(device)
scanned_n_to_1_pose_9d = scanned_n_to_1_pose_9d.to(best_to_1_pose_9d_batch.device) scanned_n_to_1_pose_9d = scanned_n_to_1_pose_9d.to(device)
pts_feat_seq_list.append(self.pts_encoder.encode_points(scanned_pts)) pts_feat_seq_list.append(self.pts_encoder.encode_points(scanned_pts))
pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_1_pose_9d)) pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_1_pose_9d))

View File

@ -51,7 +51,7 @@ class SeqNBVReconstructionDataset(BaseDataset):
"first_frame": first_frame, "first_frame": first_frame,
"max_coverage_rate": max_coverage_rate "max_coverage_rate": max_coverage_rate
}) })
return datalist return datalist[5:]
def __getitem__(self, index): def __getitem__(self, index):
data_item_info = self.datalist[index] data_item_info = self.datalist[index]
@ -85,6 +85,7 @@ class SeqNBVReconstructionDataset(BaseDataset):
voxel_threshold = diag*0.02 voxel_threshold = diag*0.02
first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose) first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose)
scene_path = os.path.join(self.root_dir, scene_name) scene_path = os.path.join(self.root_dir, scene_name)
model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name)
data_item = { data_item = {
"first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32), "first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32),
"first_to_first_9d": np.asarray([first_to_first_9d],dtype=np.float32), "first_to_first_9d": np.asarray([first_to_first_9d],dtype=np.float32),
@ -92,10 +93,11 @@ class SeqNBVReconstructionDataset(BaseDataset):
"max_coverage_rate": max_coverage_rate, "max_coverage_rate": max_coverage_rate,
"voxel_threshold": voxel_threshold, "voxel_threshold": voxel_threshold,
"filter_degree": self.filter_degree, "filter_degree": self.filter_degree,
"first_frame_to_world": first_frame_to_world, "first_frame_to_world": np.asarray(first_frame_to_world, dtype=np.float32),
"first_O_to_first_L_pose": first_O_to_first_L_pose, "O_to_L_pose": first_O_to_first_L_pose,
"first_frame_coverage": first_frame_coverage, "first_frame_coverage": first_frame_coverage,
"scene_path": scene_path "scene_path": scene_path,
"model_points_normals": model_points_normals,
} }
return data_item return data_item

View File

@ -1,33 +1,35 @@
import os import os
import json import json
from datetime import datetime from utils.render import RenderUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import numpy as np
import pickle
from PytorchBoot.config import ConfigManager from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory from PytorchBoot.factory import ComponentFactory
from PytorchBoot.factory import OptimizerFactory
from PytorchBoot.dataset import BaseDataset from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner from PytorchBoot.runners.runner import Runner
from PytorchBoot.stereotype import EXTERNAL_FRONZEN_MODULES
from PytorchBoot.utils import Log from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager from PytorchBoot.status import status_manager
@stereotype.runner("nbv_evaluator") @stereotype.runner("inferencer", comment="not tested")
class NextBestViewEvaluator(Runner): class Inferencer(Runner):
def __init__(self, config_path): def __init__(self, config_path):
super().__init__(config_path) super().__init__(config_path)
self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
''' Pipeline ''' ''' Pipeline '''
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE] self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
self.parallel = self.config["general"]["parallel"]
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name) self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
if self.parallel and self.device == "cuda":
self.pipeline = torch.nn.DataParallel(self.pipeline)
self.pipeline = self.pipeline.to(self.device) self.pipeline = self.pipeline.to(self.device)
''' Experiment ''' ''' Experiment '''
@ -46,55 +48,135 @@ class NextBestViewEvaluator(Runner):
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name)) raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name) test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
self.test_set_list.append(test_set) self.test_set_list.append(test_set)
self.print_info() self.print_info()
def run(self): def run(self):
Log.info("Loading from epoch {}.".format(self.current_epoch)) Log.info("Loading from epoch {}.".format(self.current_epoch))
self.test() self.inference()
Log.success("Inference finished.")
def test(self): def inference(self):
self.pipeline.eval() self.pipeline.eval()
with torch.no_grad(): with torch.no_grad():
test_set: BaseDataset test_set: BaseDataset
for dataset_idx, test_set in enumerate(self.test_set_list): for dataset_idx, test_set in enumerate(self.test_set_list):
test_set_config = test_set.get_config() status_manager.set_progress("inference", "inferencer", f"dataset", dataset_idx, len(self.test_set_list))
eval_list = test_set_config["eval_list"]
ratio = test_set_config["ratio"]
test_set_name = test_set.get_name() test_set_name = test_set.get_name()
output_list = []
data_list = []
test_loader = test_set.get_loader() test_loader = test_set.get_loader()
if test_loader.batch_size > 1:
Log.error("Batch size should be 1 for inference, found {} in {}".format(test_loader.batch_size, test_set_name), terminate=True)
total=int(len(test_loader)) total=int(len(test_loader))
loop = tqdm(enumerate(test_loader), total=total) loop = tqdm(enumerate(test_loader), total=total)
for i, data in loop: for i, data in loop:
status_manager.set_progress("train", "default_trainer", f"(test) Batch[{test_set_name}]", i+1, total) status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
test_set.process_batch(data, self.device) test_set.process_batch(data, self.device)
data["mode"] = namespace.Mode.TEST output = self.predict_sequence(data)
output = self.pipeline(data) self.save_inference_result(output, data)
output_list.append(output)
data_list.append(data)
loop.set_description(
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
result_dict = self.eval_fn(output_list, data_list, eval_list)
@staticmethod status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
def eval_fn(output_list, data_list, eval_list):
collected_result = {} def predict_sequence(self, data, cr_increase_threshold=0, max_iter=100):
for eval_method_name in eval_list: pred_cr_seq = []
eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name) scene_name = data["scene_name"][0]
eval_results:dict = eval_method.evaluate(output_list, data_list) Log.info(f"Processing scene: {scene_name}")
for data_type, eval_result in eval_results.items(): status_manager.set_status("inference", "inferencer", "scene", scene_name)
if data_type not in collected_result:
collected_result[data_type] = {} ''' data for rendering '''
for name, value in eval_result.items(): scene_path = data["scene_path"][0]
collected_result[data_type][name] = value O_to_L_pose = data["O_to_L_pose"][0]
status_manager.set_status("train", "default_trainer", f"[eval]{name}", value) voxel_threshold = data["voxel_threshold"][0]
filter_degree = data["filter_degree"][0]
model_points_normals = data["model_points_normals"][0]
model_pts = model_points_normals[:,:3]
down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
first_frame_to_world = data["first_frame_to_world"][0]
''' data for inference '''
input_data = {}
input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
input_data["scanned_n_to_1_pose_9d"] = [data["first_to_first_9d"][0].to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_pts_N = input_data["scanned_pts"][0].shape[1]
first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
scanned_view_pts = [first_frame_target_pts]
last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
while len(pred_cr_seq) < max_iter:
output = self.pipeline(input_data)
next_pose_9d = output["pred_pose_9d"]
pred_pose = torch.eye(4, device=next_pose_9d.device)
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(next_pose_9d[:,:6])[0]
pred_pose[:3,3] = next_pose_9d[0,6:]
pred_n_to_world_pose_mat = torch.matmul(first_frame_to_world, pred_pose)
try:
new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_n_to_world_pose_mat, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
except Exception as e:
Log.warning(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
continue
pred_cr = self.compute_coverage_rate(scanned_view_pts, new_target_pts_world, down_sampled_model_pts, threshold=voxel_threshold)
pred_cr_seq.append(pred_cr)
if pred_cr >= data["max_coverage_rate"]:
break
if pred_cr < last_pred_cr + cr_increase_threshold:
break
scanned_view_pts.append(new_target_pts_world)
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_pts_world, input_pts_N)
new_pts_world_aug = np.hstack([down_sampled_new_pts_world, np.ones((down_sampled_new_pts_world.shape[0], 1))])
new_pts = np.dot(np.linalg.inv(first_frame_to_world.cpu()), new_pts_world_aug.T).T[:,:3]
new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device)
input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0] , new_pts_tensor], dim=0)]
input_data["scanned_n_to_1_pose_9d"] = [torch.cat([input_data["scanned_n_to_1_pose_9d"][0], next_pose_9d], dim=0)]
last_pred_cr = pred_cr
# ------ Debug Start ------
import ipdb;ipdb.set_trace()
# ------ Debug End ------
input_data["scanned_pts"] = input_data["scanned_pts"][0].cpu().numpy().tolist()
input_data["scanned_n_to_1_pose_9d"] = input_data["scanned_n_to_1_pose_9d"][0].cpu().numpy().tolist()
result = {
"pred_pose_9d_seq": input_data["scanned_n_to_1_pose_9d"],
"pts_seq": input_data["scanned_pts"],
"target_pts_seq": scanned_view_pts,
"coverage_rate_seq": pred_cr_seq,
"max_coverage_rate": data["max_coverage_rate"],
"pred_max_coverage_rate": max(pred_cr_seq)
}
return result
def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
    """Return the coverage rate of `model_pts` by the scanned views.

    `scanned_view_pts` is a list of point arrays; when `new_pts` is not None
    it is appended to a copy of that list (the caller's list is untouched).
    All views are stacked, voxel-downsampled with `threshold`, and the result
    is passed to `ReconstructionUtil.compute_coverage_rate` with the same
    threshold.
    """
    if new_pts is None:
        views = scanned_view_pts
    else:
        views = scanned_view_pts + [new_pts]
    merged_pts = np.vstack(views)
    merged_pts_ds = PtsUtil.voxel_downsample_point_cloud(merged_pts, threshold)
    return ReconstructionUtil.compute_coverage_rate(model_pts, merged_pts_ds, threshold)
def save_inference_result(self, dataset_name, scene_name, output):
    """Pickle one scene's inference result under the dataset's output directory.

    The result is written to
    `<self.output_dir>/<dataset_name>/result_<scene_name>.pkl`; the dataset
    directory is created when it does not exist yet.

    Args:
        dataset_name: name of the sub-directory of `self.output_dir` to write into.
        scene_name: scene identifier used in the file name.
        output: picklable object to store.
    """
    # NOTE(review): the call in `inference` passes `(output, data)`, which does
    # not match this signature -- confirm and align the caller.
    dataset_dir = os.path.join(self.output_dir, dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)
    # Write inside dataset_dir (previously the file landed in the current
    # working directory) and close the handle deterministically.
    result_path = os.path.join(dataset_dir, f"result_{scene_name}.pkl")
    with open(result_path, "wb") as result_file:
        pickle.dump(output, result_file)
return collected_result
def get_checkpoint_path(self, is_last=False): def get_checkpoint_path(self, is_last=False):
return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME, return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
@ -116,55 +198,19 @@ class NextBestViewEvaluator(Runner):
self.current_epoch = meta["last_epoch"] self.current_epoch = meta["last_epoch"]
self.current_iter = meta["last_iter"] self.current_iter = meta["last_iter"]
def save_checkpoint(self, is_last=False):
self.save(self.get_checkpoint_path(is_last))
if not is_last:
Log.success(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}")
else:
meta = {
"last_epoch": self.current_epoch,
"last_iter": self.current_iter,
"time": str(datetime.now())
}
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
file_path = os.path.join(checkpoint_root, "meta.json")
with open(file_path, "w") as f:
json.dump(meta, f)
def load_experiment(self, backup_name=None): def load_experiment(self, backup_name=None):
super().load_experiment(backup_name) super().load_experiment(backup_name)
if self.experiments_config["use_checkpoint"]:
self.current_epoch = self.experiments_config["epoch"] self.current_epoch = self.experiments_config["epoch"]
self.load_checkpoint(is_last=(self.current_epoch == -1)) self.load_checkpoint(is_last=(self.current_epoch == -1))
def create_experiment(self, backup_name=None): def create_experiment(self, backup_name=None):
super().create_experiment(backup_name) super().create_experiment(backup_name)
ckpt_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.CHECKPOINT_DIR_NAME)
os.makedirs(ckpt_dir)
tensorboard_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.TENSORBOARD_DIR_NAME)
os.makedirs(tensorboard_dir)
def load(self, path): def load(self, path):
state_dict = torch.load(path) state_dict = torch.load(path)
if self.parallel:
self.pipeline.module.load_state_dict(state_dict)
else:
self.pipeline.load_state_dict(state_dict) self.pipeline.load_state_dict(state_dict)
def save(self, path):
if self.parallel:
state_dict = self.pipeline.module.state_dict()
else:
state_dict = self.pipeline.state_dict()
for name, module in self.pipeline.named_modules():
if module.__class__ in EXTERNAL_FRONZEN_MODULES:
if name in state_dict:
del state_dict[name]
torch.save(state_dict, path)
def print_info(self): def print_info(self):
def print_dataset(dataset: BaseDataset): def print_dataset(dataset: BaseDataset):
config = dataset.get_config() config = dataset.get_config()
@ -178,8 +224,6 @@ class NextBestViewEvaluator(Runner):
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+') Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
Log.blue(self.pipeline) Log.blue(self.pipeline)
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+') Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
Log.blue("train dataset: ")
print_dataset(self.train_set)
for i, test_set in enumerate(self.test_set_list): for i, test_set in enumerate(self.test_set_list):
Log.blue(f"test dataset {i}: ") Log.blue(f"test dataset {i}: ")
print_dataset(test_set) print_dataset(test_set)

View File

@ -16,10 +16,10 @@ from utils.pts import PtsUtil
class StrategyGenerator(Runner): class StrategyGenerator(Runner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.load_experiment("generate") self.load_experiment("generate_strategy")
self.status_info = { self.status_info = {
"status_manager": status_manager, "status_manager": status_manager,
"app_name": "generate", "app_name": "generate_strategy",
"runner_name": "strategy_generator" "runner_name": "strategy_generator"
} }
self.to_specified_dir = ConfigManager.get("runner", "generate", "to_specified_dir") self.to_specified_dir = ConfigManager.get("runner", "generate", "to_specified_dir")
@ -36,7 +36,7 @@ class StrategyGenerator(Runner):
self.save_pts = ConfigManager.get("runner","generate","save_points") self.save_pts = ConfigManager.get("runner","generate","save_points")
for dataset_idx in range(len(dataset_name_list)): for dataset_idx in range(len(dataset_name_list)):
dataset_name = dataset_name_list[dataset_idx] dataset_name = dataset_name_list[dataset_idx]
status_manager.set_progress("generate", "strategy_generator", "dataset", dataset_idx, len(dataset_name_list)) status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", dataset_idx, len(dataset_name_list))
root_dir = ConfigManager.get("datasets", dataset_name, "root_dir") root_dir = ConfigManager.get("datasets", dataset_name, "root_dir")
model_dir = ConfigManager.get("datasets", dataset_name, "model_dir") model_dir = ConfigManager.get("datasets", dataset_name, "model_dir")
scene_name_list = os.listdir(root_dir) scene_name_list = os.listdir(root_dir)
@ -44,10 +44,10 @@ class StrategyGenerator(Runner):
total = len(scene_name_list) total = len(scene_name_list)
for scene_name in scene_name_list: for scene_name in scene_name_list:
Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}") Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}")
status_manager.set_progress("generate", "strategy_generator", "scene", cnt, total) status_manager.set_progress("generate_strategy", "strategy_generator", "scene", cnt, total)
diag = DataLoadUtil.get_bbox_diag(model_dir, scene_name) diag = DataLoadUtil.get_bbox_diag(model_dir, scene_name)
voxel_threshold = diag*0.02 voxel_threshold = diag*0.02
status_manager.set_status("generate", "strategy_generator", "voxel_threshold", voxel_threshold) status_manager.set_status("generate_strategy", "strategy_generator", "voxel_threshold", voxel_threshold)
output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name) output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name)
if os.path.exists(output_label_path) and not self.overwrite: if os.path.exists(output_label_path) and not self.overwrite:
Log.info(f"Scene <{scene_name}> Already Exists, Skip") Log.info(f"Scene <{scene_name}> Already Exists, Skip")
@ -58,8 +58,8 @@ class StrategyGenerator(Runner):
except Exception as e: except Exception as e:
Log.error(f"Scene <{scene_name}> Failed, Error: {e}") Log.error(f"Scene <{scene_name}> Failed, Error: {e}")
cnt += 1 cnt += 1
status_manager.set_progress("generate", "strategy_generator", "scene", total, total) status_manager.set_progress("generate_strategy", "strategy_generator", "scene", total, total)
status_manager.set_progress("generate", "strategy_generator", "dataset", len(dataset_name_list), len(dataset_name_list)) status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", len(dataset_name_list), len(dataset_name_list))
def create_experiment(self, backup_name=None): def create_experiment(self, backup_name=None):
super().create_experiment(backup_name) super().create_experiment(backup_name)
@ -70,7 +70,7 @@ class StrategyGenerator(Runner):
super().load_experiment(backup_name) super().load_experiment(backup_name)
def generate_sequence(self, root, model_dir, scene_name, voxel_threshold, overlap_threshold): def generate_sequence(self, root, model_dir, scene_name, voxel_threshold, overlap_threshold):
status_manager.set_status("generate", "strategy_generator", "scene", scene_name) status_manager.set_status("generate_strategy", "strategy_generator", "scene", scene_name)
frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name) frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name)
model_points_normals = DataLoadUtil.load_points_normals(root, scene_name) model_points_normals = DataLoadUtil.load_points_normals(root, scene_name)
model_pts = model_points_normals[:,:3] model_pts = model_points_normals[:,:3]
@ -81,7 +81,7 @@ class StrategyGenerator(Runner):
for frame_idx in range(frame_num): for frame_idx in range(frame_num):
path = DataLoadUtil.get_path(root, scene_name, frame_idx) path = DataLoadUtil.get_path(root, scene_name, frame_idx)
cam_params = DataLoadUtil.load_cam_info(path, binocular=True) cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
status_manager.set_progress("generate", "strategy_generator", "loading frame", frame_idx, frame_num) status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True) point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
#display_table = None #DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True, target_mask_label=()) #TODO #display_table = None #DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True, target_mask_label=()) #TODO
sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree) sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
@ -92,7 +92,7 @@ class StrategyGenerator(Runner):
os.makedirs(pts_dir) os.makedirs(pts_dir)
np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud) np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
pts_list.append(sampled_point_cloud) pts_list.append(sampled_point_cloud)
status_manager.set_progress("generate", "strategy_generator", "loading frame", frame_num, frame_num) status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num)
limited_useful_view, _, best_combined_pts = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_model_pts, pts_list, threshold=voxel_threshold, overlap_threshold=overlap_threshold, status_info=self.status_info) limited_useful_view, _, best_combined_pts = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_model_pts, pts_list, threshold=voxel_threshold, overlap_threshold=overlap_threshold, status_info=self.status_info)
data_pairs = self.generate_data_pairs(limited_useful_view) data_pairs = self.generate_data_pairs(limited_useful_view)
@ -102,7 +102,7 @@ class StrategyGenerator(Runner):
"max_coverage_rate": limited_useful_view[-1][1] "max_coverage_rate": limited_useful_view[-1][1]
} }
status_manager.set_status("generate", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1]) status_manager.set_status("generate_strategy", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1])
Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}") Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}")
output_label_path = DataLoadUtil.get_label_path(root, scene_name) output_label_path = DataLoadUtil.get_label_path(root, scene_name)

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
import open3d as o3d import open3d as o3d
import torch
class PtsUtil: class PtsUtil:
@ -20,3 +21,8 @@ class PtsUtil:
def random_downsample_point_cloud(point_cloud, num_points): def random_downsample_point_cloud(point_cloud, num_points):
idx = np.random.choice(len(point_cloud), num_points, replace=True) idx = np.random.choice(len(point_cloud), num_points, replace=True)
return point_cloud[idx] return point_cloud[idx]
@staticmethod
def random_downsample_point_cloud_tensor(point_cloud, num_points):
idx = torch.randint(0, len(point_cloud), (num_points,))
return point_cloud[idx]

51
utils/render.py Normal file
View File

@ -0,0 +1,51 @@
import os
import json
import subprocess
import tempfile
from utils.data_load import DataLoadUtil
from utils.reconstruction import ReconstructionUtil
from utils.pts import PtsUtil
class RenderUtil:
    """Helpers for rendering a camera view of a scene through a Blender subprocess."""

    @staticmethod
    def render_pts(cam_pose, scene_path, script_path, model_points_normals, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None, require_full_scene=False):
        """Render `scene_path` from `cam_pose` with Blender and load the point clouds.

        `cam_pose` is composed with `nO_to_nL_pose`, passed through
        `DataLoadUtil.cam_pose_transformation`, and written together with
        `scene_path` to `params.json` in a temporary directory. Blender is then
        run in background mode on `script_path` with that directory as its only
        script argument, and the results are loaded from `<temp_dir>/tmp`.

        Args:
            cam_pose: pose exposing `.cpu().numpy()` (presumably a 4x4 torch
                tensor -- TODO confirm).
            scene_path: scene path forwarded to the Blender script.
            script_path: path of the Blender Python script to execute.
            model_points_normals: forwarded to `ReconstructionUtil.filter_points`.
            voxel_threshold: forwarded to `filter_points` as `voxel_size`.
            filter_degree: forwarded to `filter_points` as `theta`.
            nO_to_nL_pose: matrix right-multiplied onto `cam_pose`; required
                despite the `None` default.
            require_full_scene: when true, also build a point cloud from the
                left/right depth maps.

        Returns:
            `(filtered_point_cloud, full_scene_point_cloud)`; the second item is
            `None` unless `require_full_scene` is true.

        Raises:
            ValueError: if `nO_to_nL_pose` is None.
            RuntimeError: if the Blender process exits with a non-zero code.
        """
        if nO_to_nL_pose is None:
            # Fail early with a clear message instead of an opaque matmul error.
            raise ValueError("nO_to_nL_pose is required to render a view")

        nO_to_world_pose = cam_pose.cpu().numpy() @ nO_to_nL_pose
        nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose)

        with tempfile.TemporaryDirectory() as temp_dir:
            params = {
                "cam_pose": nO_to_world_pose.tolist(),
                "scene_path": scene_path
            }
            params_data_path = os.path.join(temp_dir, "params.json")
            with open(params_data_path, 'w') as f:
                json.dump(params, f)

            result = subprocess.run([
                'blender', '-b', '-P', script_path, '--', temp_dir
            ], capture_output=True, text=True)
            if result.returncode != 0:
                # Callers unpack a 2-tuple, so a bare None would only surface as
                # an unrelated unpacking error; raise with Blender's stderr.
                raise RuntimeError(f"Blender script failed:\n{result.stderr}")

            # Everything below must run while temp_dir still exists.
            path = os.path.join(temp_dir, "tmp")
            point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
            cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
            filtered_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree)

            full_scene_point_cloud = None
            if require_full_scene:
                depth_L, depth_R = DataLoadUtil.load_depth(path, cam_params['near_plane'], cam_params['far_plane'], binocular=True)
                point_cloud_L = DataLoadUtil.get_point_cloud(depth_L, cam_params['cam_intrinsic'], cam_params['cam_to_world'])['points_world']
                point_cloud_R = DataLoadUtil.get_point_cloud(depth_R, cam_params['cam_intrinsic'], cam_params['cam_to_world_R'])['points_world']
                point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536)
                point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
                full_scene_point_cloud = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)

            return filtered_point_cloud, full_scene_point_cloud