101 lines
4.3 KiB
Python
101 lines
4.3 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
from PytorchBoot.runners.runner import Runner
|
|
import PytorchBoot.stereotype as stereotype
|
|
import PytorchBoot.namespace as namespace
|
|
from PytorchBoot.utils.log_util import Log
|
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
|
|
|
@stereotype.runner("reconstruction_runner")
class ReconstructionRunner(Runner):
    """Runner that drives an iterative, active NeRF reconstruction.

    The runner builds a pipeline component from the ``pipeline`` section of
    the configuration and uses it to train the model. The active loop then
    alternates between selecting new views, capturing images for them,
    retraining, and extracting a mesh.

    NOTE(review): this class reads ``self.device``, ``self.output_dir``,
    ``self.policy`` and ``self.nerf_model`` and calls ``self.extract_mesh``,
    ``self._simulate_image_capture`` and ``self.evaluate_reconstruction``.
    None of them is defined in this file; presumably they are provided by the
    ``Runner`` base class or by a subclass/mixin -- confirm before relying on
    ``run_active_reconstruction``.
    """

    def __init__(self, config_path):
        """Initialise the base runner and create the pipeline component.

        Args:
            config_path: Configuration path, forwarded to ``Runner.__init__``.
        """
        super().__init__(config_path)
        self.config_path = config_path
        # ``self.config`` is not set here; presumably populated by
        # ``Runner.__init__`` -- TODO confirm.
        self.module_config = self.config.get("module", {})
        self.pipeline_config = self.config.get("pipeline", {})
        self.pipeline = ComponentFactory.create(
            namespace.Stereotype.PIPELINE, self.pipeline_config
        )

    def run(self):
        """Do nothing; call ``run_active_reconstruction`` explicitly instead."""
        pass

    def run_active_reconstruction(self,
                                  initial_poses: np.ndarray,
                                  initial_images: torch.Tensor = None,
                                  max_iterations: int = 3):
        """Run the active reconstruction loop and return every pose used.

        The model is first trained on the initial views. Each iteration then
        selects new views, captures their images, resets the model to the
        weights obtained from the initial training and retrains on the
        enlarged data set. A mesh is extracted after the initial training,
        after every iteration and once more at the end.

        Args:
            initial_poses: Poses of the initial views. Concatenated with the
                selected views along axis 0 and converted with
                ``torch.from_numpy`` for training.
            initial_images: Images of the initial views. Concatenated with the
                newly captured images along dim 0. Required despite the
                ``None`` default, which is kept for signature compatibility.
            max_iterations: Number of select/capture/retrain rounds.

        Returns:
            ``np.ndarray`` holding the initial poses followed by all views
            selected during the iterations.

        Raises:
            ValueError: If ``initial_images`` is ``None``.
        """
        # Fail fast: the original only crashed on ``initial_images.clone()``
        # after a complete (and expensive) initial training run.
        if initial_images is None:
            raise ValueError(
                "initial_images is required for run_active_reconstruction"
            )

        Log.info("start active reconstruction...")

        reconstruction_config = self.config.get("reconstruction", {})
        epochs = reconstruction_config.get("epochs_per_iteration", 2000)
        mesh_resolution = reconstruction_config.get("mesh_resolution", 256)

        self.pipeline.train_nerf(
            initial_images,
            torch.from_numpy(initial_poses).float().to(self.device),
            epochs=epochs,
        )
        self.pipeline.save()

        # Persist the weights of the initialisation step so every iteration
        # can restart from them (``initial_model_path`` used to be undefined).
        initial_model_path = os.path.join(self.output_dir, "initial_model.pth")
        torch.save(self.nerf_model.state_dict(), initial_model_path)

        all_poses = initial_poses.copy()
        current_poses = initial_poses.copy()
        all_images = initial_images.clone()

        # Extract the initial mesh.
        initial_mesh_path = os.path.join(self.output_dir, "initial_mesh.obj")
        self.extract_mesh(initial_mesh_path, resolution=mesh_resolution)

        # Iteratively perform the active reconstruction.
        for iteration in range(max_iterations):
            print(f"\n开始迭代 {iteration+1}/{max_iterations}")

            # Select the next batch of views.
            next_views = self.policy.select_next_views(
                self.nerf_model, current_poses
            )
            print(f"选择了 {len(next_views)} 个新视角")

            # Capture images for the new views.
            new_images = self._simulate_image_capture(next_views)

            # Add the newly selected views to the current poses and images.
            current_poses = np.concatenate([current_poses, next_views], axis=0)
            all_poses = np.concatenate([all_poses, next_views], axis=0)
            all_images = torch.cat([all_images, new_images], dim=0)

            # Following the authors' description, re-initialise from the
            # initial model instead of continuing training:
            # "After selecting additional images, we initialize the network
            # with the model from the initialization step and refine the model
            # further with the updated training set."
            # So load the initial weights first, then retrain on the extended
            # data set.
            self.nerf_model.load_state_dict(
                torch.load(initial_model_path, map_location=self.device)
            )

            # Retrain on the extended data set through the pipeline, exactly
            # like the initial training (this class defines no ``train_nerf``).
            self.pipeline.train_nerf(
                all_images,
                torch.from_numpy(current_poses).float().to(self.device),
                epochs=epochs,
            )

            # Extract a mesh after every iteration to observe how the
            # reconstruction quality improves.
            iter_mesh_path = os.path.join(
                self.output_dir, f"mesh_iter_{iteration+1}.obj"
            )
            self.extract_mesh(iter_mesh_path, resolution=mesh_resolution)

        # Extract the final 3D mesh.
        output_mesh_path = os.path.join(self.output_dir, "final_mesh.obj")
        self.extract_mesh(output_mesh_path, resolution=mesh_resolution)

        # Evaluate the reconstruction quality.
        self.evaluate_reconstruction()

        print("主动重建过程完成")
        return all_poses

    def create_experiment(self, backup_name=None):
        """Delegate to ``Runner.create_experiment`` and return its result."""
        return super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate to ``Runner.load_experiment`` and return its result."""
        # Return the parent's result, mirroring ``create_experiment``.
        return super().load_experiment(backup_name)