import os

import torch
import numpy as np

from PytorchBoot.factory.component_factory import ComponentFactory
import PytorchBoot.stereotype as stereotype
import PytorchBoot.namespace as namespace
from PytorchBoot.utils.log_util import Log

from utils.volume_render_util import VolumeRendererUtil


@stereotype.pipeline("reconstruction_pipeline")
class ReconstructionPipeline:
    """Train, save and load a NeRF module for per-object reconstruction.

    Configuration keys read by this class:

    * ``modules.nerf`` (required): handed to ``ComponentFactory.create`` to
      build the NeRF module.
    * ``nerf_model_output_dir``: checkpoint root, default
      ``"./output/nerf_model"``.
    * ``camera``: optional ``focal`` (1000.0), ``near`` (2.0), ``far`` (6.0).
    * ``sampling``: optional ``coarse_samples`` (64), ``fine_samples`` (128),
      ``perturb`` (True).

    Checkpoints are stored as ``<nerf_model_output_dir>/<object_name>/*.pth``.
    """

    def __init__(self, config:dict):
        """Build the NeRF module described by ``config["modules"]["nerf"]``."""
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.module_config = config["modules"]
        self.nerf = ComponentFactory.create(
            namespace.Stereotype.MODULE, self.module_config["nerf"]
        )
        self.nerf_model_output_dir = self.config.get(
            "nerf_model_output_dir", "./output/nerf_model"
        )

    # NOTE(review): this class declares no base class, so the two ``super()``
    # calls below resolve against ``object`` and would raise AttributeError
    # unless the decorator arranges a base providing these methods -- confirm
    # the intended base class.
    def create_experiment(self, backup_name=None):
        """Delegate experiment creation to the parent implementation."""
        return super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate experiment loading to the parent implementation."""
        super().load_experiment(backup_name)

    def save(self, object_name: str, best_model: bool = True, name: str|None = None):
        """Write the NeRF weights for ``object_name`` to disk.

        Args:
            object_name: Sub-directory of the checkpoint root to write into.
            best_model: When true, write ``best_model.pth`` (``name`` is
                ignored).
            name: Checkpoint file stem used when ``best_model`` is false.

        Returns:
            The directory the checkpoint was written to.
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)
        os.makedirs(output_dir, exist_ok=True)
        if best_model:
            torch.save(self.nerf.state_dict(), os.path.join(output_dir, "best_model.pth"))
        elif name is not None:
            torch.save(self.nerf.state_dict(), os.path.join(output_dir, f"{name}.pth"))
        else:
            Log.error("save failed, best_model and name cannot be None at the same time", terminate=True)
        Log.info(f"save {object_name} to {output_dir}")
        return output_dir

    def load(self, object_name: str, best_model: bool = True, name: str|None = None):
        """Load NeRF weights for ``object_name`` from disk.

        Args:
            object_name: Sub-directory of the checkpoint root to read from.
            best_model: When true, read ``best_model.pth`` (``name`` is
                ignored).
            name: Checkpoint file stem used when ``best_model`` is false.

        Returns:
            The directory the checkpoint was read from.
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)
        # map_location lets a checkpoint written on a GPU be read on a
        # CPU-only host (and vice versa).
        if best_model:
            self.nerf.load_state_dict(
                torch.load(os.path.join(output_dir, "best_model.pth"), map_location=self.device)
            )
        elif name is not None:
            self.nerf.load_state_dict(
                torch.load(os.path.join(output_dir, f"{name}.pth"), map_location=self.device)
            )
        else:
            Log.error("load failed, best_model and name cannot be None at the same time", terminate=True)
        Log.info(f"load {object_name} from {output_dir}")
        return output_dir

    def train_nerf(self,
                   images: torch.Tensor,
                   poses: torch.Tensor,
                   epochs: int = 5000,
                   batch_size: int = 4096,
                   lr: float = 5e-4,
                   start_from_model=None,
                   object_name: str = "unknown") -> float:
        """Optimise the NeRF module on posed images and keep the best weights.

        Each step draws ``batch_size`` random rays across all images, renders
        them with ``VolumeRendererUtil.render_rays`` and minimises the MSE
        against the corresponding pixels using Adam.

        Args:
            images: Image stack indexed as ``[N, H, W, 3]`` (height and width
                are read from dims 1 and 2, pixels are reshaped to 3 values).
            poses: One pose matrix per image, see ``generate_rays``.
            epochs: Number of optimisation steps.
            batch_size: Rays sampled per step.
            lr: Adam learning rate.
            start_from_model: Optional module whose ``state_dict`` initialises
                the NeRF before training.
            object_name: Sub-directory of the checkpoint root used for
                ``best_model.pth``.

        Returns:
            The lowest batch loss observed (``inf`` if no step improved on it).
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)
        os.makedirs(output_dir, exist_ok=True)
        best_model_path = os.path.join(output_dir, "best_model.pth")

        Log.info("train NeRF model with {} images".format(len(images)))
        H, W = images.shape[1], images.shape[2]

        sampling_config = self.config.get("sampling", {})
        camera_config = self.config.get("camera", {})
        focal = camera_config.get("focal", 1000.0)
        near = camera_config.get("near", 2.0)
        far = camera_config.get("far", 6.0)
        coarse_samples = sampling_config.get("coarse_samples", 64)
        fine_samples = sampling_config.get("fine_samples", 128)
        perturb = sampling_config.get("perturb", True)

        if start_from_model is not None:
            self.nerf.load_state_dict(start_from_model.state_dict())

        # Rays and images are moved to self.device below; the network has to
        # live on the same device (no-op when it is already there).
        self.nerf.to(self.device)

        optimizer = torch.optim.Adam(self.nerf.parameters(), lr=lr)
        mse_loss = torch.nn.MSELoss()
        self.nerf.train()

        rays_o, rays_d = ReconstructionPipeline.generate_rays(poses, H, W, focal)
        rays_o = rays_o.to(self.device)
        rays_d = rays_d.to(self.device)
        images = images.to(self.device)

        best_loss = float('inf')
        saved_best = False  # whether this run wrote best_model.pth
        for epoch in range(epochs):
            batch_rays_o, batch_rays_d, target_pixels = ReconstructionPipeline.sample_pixel_batch(
                images, rays_o, rays_d, batch_size)
            batch_rays_d = torch.nn.functional.normalize(batch_rays_d, dim=-1)

            # One near/far bound per ray, matching the rays' dtype and device.
            near_tensor = torch.full_like(batch_rays_o[..., 0], near)
            far_tensor = torch.full_like(batch_rays_o[..., 0], far)

            optimizer.zero_grad()
            rgb_map, _, _, _ = VolumeRendererUtil.render_rays(
                self.nerf, batch_rays_o, batch_rays_d,
                near_tensor, far_tensor,
                coarse_samples, fine_samples, perturb
            )
            loss = mse_loss(rgb_map, target_pixels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() forces a device sync.
            loss_value = loss.item()

            if (epoch + 1) % 100 == 0:
                psnr = -10.0 * torch.log10(loss)
                Log.info(f"Epoch {epoch+1}/{epochs}, Loss: {loss_value:.6f}, PSNR: {psnr.item():.2f}")

            # NOTE(review): evaluated on every step; confirm this was not
            # meant to run only at the logging interval above.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(self.nerf.state_dict(), best_model_path)
                saved_best = True

        if saved_best:
            self.nerf.load_state_dict(torch.load(best_model_path, map_location=self.device))
        else:
            # Nothing was written (zero epochs or non-finite losses): do not
            # pick up a stale checkpoint left behind by an earlier run.
            Log.info("no checkpoint written during training, keeping current NeRF weights")

        Log.info(f"finish training, best loss: {best_loss:.6f}")
        return best_loss

    @staticmethod
    def generate_rays(
            poses: torch.Tensor,
            H: int,
            W: int,
            focal: float) -> tuple:
        """Build one ray per pixel for every pose.

        Camera-frame directions are ``((x - W/2) / focal, -(y - H/2) / focal,
        -1)`` for pixel column ``x`` and row ``y``. For each pose the upper-left
        3x3 block rotates these directions and the first three entries of the
        last column are used as the ray origin. Directions are NOT normalised.

        Args:
            poses: Iterable of pose matrices (presumably camera-to-world,
                at least 3x4 -- confirm against the caller).
            H: Image height in pixels.
            W: Image width in pixels.
            focal: Focal length in pixels.

        Returns:
            ``(rays_o, rays_d)``, each of shape ``[N, H*W, 3]``.
        """
        # Build the pixel grid where the poses live so the products below do
        # not mix devices; dtype is left at the default.
        device = poses.device if isinstance(poses, torch.Tensor) else None
        i, j = torch.meshgrid(
            torch.linspace(0, W - 1, W, device=device),
            torch.linspace(0, H - 1, H, device=device),
            indexing='xy'
        )  # both [H, W]; i holds the column index, j the row index

        dirs = torch.stack([
            (i - W * 0.5) / focal,
            -(j - H * 0.5) / focal,
            -torch.ones_like(i)
        ], dim=-1)  # [H, W, 3]

        rays_o_list = []
        rays_d_list = []
        for pose in poses:
            # Row-wise dot product == rotation matrix applied to each direction.
            rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], dim=-1)  # [H, W, 3]
            rays_o = pose[:3, -1].expand(rays_d.shape)  # [H, W, 3]
            rays_o_list.append(rays_o.reshape(-1, 3))  # [H*W, 3]
            rays_d_list.append(rays_d.reshape(-1, 3))  # [H*W, 3]

        rays_o_all = torch.stack(rays_o_list, dim=0)  # [N, H*W, 3]
        rays_d_all = torch.stack(rays_d_list, dim=0)  # [N, H*W, 3]
        return rays_o_all, rays_d_all

    @staticmethod
    def sample_pixel_batch(
            images: torch.Tensor,
            rays_o: torch.Tensor,
            rays_d: torch.Tensor,
            batch_size: int) -> tuple:
        """Draw ``batch_size`` random rays (with replacement) and their pixels.

        Args:
            images: Image stack indexed as ``[N, H, W, 3]``.
            rays_o: Ray origins, ``[N, H*W, 3]``.
            rays_d: Ray directions, ``[N, H*W, 3]``.
            batch_size: Number of rays to draw uniformly over all images.

        Returns:
            ``(sampled_rays_o, sampled_rays_d, sampled_pixels)``, each of
            shape ``[batch_size, 3]``.
        """
        N = images.shape[0]
        H = images.shape[1]
        W = images.shape[2]
        rays_per_image = H * W
        total_rays = N * rays_per_image

        pixels = images.reshape(N, -1, 3)  # [N, H*W, 3]

        # Drawn from the default (CPU) generator so seeded runs pick the same
        # rays regardless of where the data lives.
        indices = torch.randint(0, total_rays, size=(batch_size,))
        img_indices = indices // rays_per_image
        pixel_indices = indices % rays_per_image

        def gather(source: torch.Tensor) -> torch.Tensor:
            # One advanced-indexing gather instead of a Python loop per ray.
            return source[img_indices.to(source.device), pixel_indices.to(source.device)]

        return gather(rays_o), gather(rays_d), gather(pixels)