# 2025-04-20 10:26:09 +08:00
#
# 101 lines
# 4.3 KiB
# Python

import os
import torch
import numpy as np
from PytorchBoot.runners.runner import Runner
import PytorchBoot.stereotype as stereotype
import PytorchBoot.namespace as namespace
from PytorchBoot.utils.log_util import Log
from PytorchBoot.factory.component_factory import ComponentFactory
@stereotype.runner("reconstruction_runner")
class ReconstructionRunner(Runner):
    """Runner that drives an iterative active-reconstruction loop.

    A pipeline component is built from the ``pipeline`` section of the loaded
    config. :meth:`run_active_reconstruction` alternates between selecting new
    views, capturing images for them, and retraining on the enlarged set.
    """

    def __init__(self, config_path):
        """Initialise the base runner and build the pipeline component.

        Args:
            config_path: Path to the configuration file. It is forwarded to
                ``Runner.__init__`` and also kept on ``self.config_path``.
        """
        super().__init__(config_path)
        self.config_path = config_path
        # Sub-sections of the runner config; a missing section falls back to {}.
        self.module_config = self.config.get("module", {})
        self.pipeline_config = self.config.get("pipeline", {})
        self.pipeline = ComponentFactory.create(
            namespace.Stereotype.PIPELINE, self.pipeline_config
        )

    def run(self):
        """Override of ``Runner.run``; currently a no-op.

        The reconstruction loop is started explicitly through
        :meth:`run_active_reconstruction`.
        """
        pass

    def run_active_reconstruction(self,
                                  initial_poses: np.ndarray,
                                  initial_images: torch.Tensor = None,
                                  max_iterations: int = 3):
        """Run the active-reconstruction loop and return every pose used.

        Steps:
          1. Train on the initial views, save the pipeline, snapshot the
             resulting weights, and extract ``initial_mesh.obj``.
          2. Each iteration does the following:
             - select new views with ``self.policy``;
             - capture images for them and append both to the training set;
             - reset the model to the initial weights and retrain on the
               enlarged set;
             - extract ``mesh_iter_<k>.obj``.
          3. Extract ``final_mesh.obj`` and evaluate the reconstruction.

        Args:
            initial_poses: Poses of the initial views. They are converted with
                ``torch.from_numpy(...).float()`` before training.
            initial_images: Images for the initial views. A tensor is
                required: it is cloned and concatenated with newly captured
                images along dim 0. The ``None`` default only exists for
                signature compatibility.
            max_iterations: Number of view-selection / retraining rounds.

        Returns:
            np.ndarray: The initial poses followed by every selected pose,
            concatenated along axis 0.

        Raises:
            ValueError: If ``initial_images`` is ``None``.
        """
        # Local import keeps this method self-contained.
        import copy

        # NOTE(review): nerf_model, policy, output_dir, device, extract_mesh,
        # _simulate_image_capture and evaluate_reconstruction are not defined
        # in this class; presumably they come from the base Runner or are set
        # elsewhere. Confirm before relying on this method.
        if initial_images is None:
            # Without this check the method would crash later at
            # ``initial_images.clone()``, after training had already run.
            raise ValueError(
                "initial_images is required for active reconstruction"
            )

        Log.info("start active reconstruction...")

        # Hoisted config lookups shared by every training and mesh step.
        recon_config = self.config.get("reconstruction", {})
        epochs = recon_config.get("epochs_per_iteration", 2000)
        mesh_resolution = recon_config.get("mesh_resolution", 256)

        # Initial training on the starting views.
        self.pipeline.train_nerf(
            initial_images,
            torch.from_numpy(initial_poses).float().to(self.device),
            epochs=epochs,
        )
        self.pipeline.save()

        # Snapshot the weights produced by the initial training so every
        # iteration can restart from them. deepcopy is required: state_dict()
        # returns references to the live parameters, which retraining would
        # otherwise overwrite.
        # NOTE(review): assumes self.nerf_model is the network that
        # self.pipeline.train_nerf trains. Confirm.
        initial_state = copy.deepcopy(self.nerf_model.state_dict())

        # Poses/images accumulated so far (initial views plus selected views).
        all_poses = initial_poses.copy()
        all_images = initial_images.clone()

        # Extract the initial mesh.
        initial_mesh_path = os.path.join(self.output_dir, "initial_mesh.obj")
        self.extract_mesh(initial_mesh_path, resolution=mesh_resolution)

        for iteration in range(max_iterations):
            print(f"\n开始迭代 {iteration+1}/{max_iterations}")

            # Select the next batch of views.
            next_views = self.policy.select_next_views(self.nerf_model, all_poses)
            print(f"选择了 {len(next_views)} 个新视角")

            # Capture images for the new views.
            new_images = self._simulate_image_capture(next_views)

            # Add the new views and images to the training set.
            all_poses = np.concatenate([all_poses, next_views], axis=0)
            all_images = torch.cat([all_images, new_images], dim=0)

            # Following the method description, restart from the initial
            # model instead of continuing training:
            # "After selecting additional images, we initialize the network with the model from the initialization step and refine the model further with the updated training set."
            self.nerf_model.load_state_dict(initial_state)

            # Retrain on the enlarged dataset, using the same entry point as
            # the initial training.
            self.pipeline.train_nerf(
                all_images,
                torch.from_numpy(all_poses).float().to(self.device),
                epochs=epochs,
            )

            # Extract a mesh after every iteration so quality gains can be
            # inspected.
            iter_mesh_path = os.path.join(self.output_dir, f"mesh_iter_{iteration+1}.obj")
            self.extract_mesh(iter_mesh_path, resolution=mesh_resolution)

        # Extract the final mesh.
        output_mesh_path = os.path.join(self.output_dir, "final_mesh.obj")
        self.extract_mesh(output_mesh_path, resolution=mesh_resolution)

        # Evaluate reconstruction quality.
        self.evaluate_reconstruction()
        print("主动重建过程完成")
        return all_poses

    def create_experiment(self, backup_name=None):
        """Delegate to ``Runner.create_experiment`` and return its result."""
        return super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate to ``Runner.load_experiment``; its return value is discarded."""
        super().load_experiment(backup_name)