import torch
from typing import Tuple


class VolumeRendererUtil:
    @staticmethod
    def render_rays(
        nerf_model,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        coarse_samples: int = 64,
        fine_samples: int = 128,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Render a batch of rays and compute per-ray uncertainty (entropy).

        Args:
            nerf_model: NeRF model (must implement forward)
            rays_o: ray origins [N_rays, 3]
            rays_d: ray directions (normalized) [N_rays, 3]
            near: near-plane distances [N_rays]
            far: far-plane distances [N_rays]
            coarse_samples: number of coarse samples
            fine_samples: number of fine samples
            perturb: whether to jitter the samples

        Returns:
            rgb_map: rendered colors [N_rays, 3]
            weights: weight distribution [N_rays, N_samples]
            t_vals: sample parameters [N_rays, N_samples]
            entropy: per-ray entropy [N_rays]
        """
        # Coarse pass: stratified samples along each ray
        t_vals_coarse, points_coarse = VolumeRendererUtil.sample_along_ray(
            rays_o, rays_d, near, far, coarse_samples, perturb)

        # Importance (fine) sampling guided by the coarse weights
        with torch.no_grad():
            sigma_coarse, _ = nerf_model(points_coarse[..., :3], rays_d.unsqueeze(1))
            weights_coarse = VolumeRendererUtil.compute_weights(
                sigma_coarse, t_vals_coarse, rays_d)
            t_vals_fine = VolumeRendererUtil.importance_sampling(
                t_vals_coarse, weights_coarse, fine_samples)

        # Merge and sort coarse and fine samples
        t_vals = torch.sort(torch.cat([t_vals_coarse, t_vals_fine], -1), dim=-1).values
        points = rays_o[..., None, :] + t_vals[..., None] * rays_d[..., None, :]

        # Fine rendering pass
        sigma, color = nerf_model(points[..., :3], rays_d.unsqueeze(1))
        rgb_map, weights = VolumeRendererUtil.volume_rendering(sigma, color, t_vals, rays_d)
        entropy = VolumeRendererUtil.calculate_entropy(weights)

        return rgb_map, weights, t_vals, entropy

    @staticmethod
    def importance_sampling(
        t_vals: torch.Tensor,
        weights: torch.Tensor,
        n_samples: int
    ) -> torch.Tensor:
        """
        Importance sampling: draw new samples from the weight distribution.

        Args:
            t_vals: original sample parameters [N_rays, N_coarse]
            weights: weight distribution [N_rays, N_coarse]
            n_samples: number of new samples to draw

        Returns:
            samples: new sample parameters [N_rays, N_fine]
        """
        weights = weights + 1e-5  # avoid division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)
        cdf = torch.cumsum(pdf, -1)

        # Inverse-transform sampling
        u = torch.linspace(0., 1., n_samples, device=weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples]).contiguous()
        indices = torch.searchsorted(cdf, u, right=True)

        # Interpolate between the two CDF bins that bracket each u
        below = torch.clamp(indices - 1, min=0)
        above = torch.clamp(indices, max=cdf.shape[-1] - 1)
        indices_g = torch.stack([below, above], -1)                    # [N_rays, N_fine, 2]

        matched_shape = list(indices_g.shape[:-1]) + [cdf.shape[-1]]   # [N_rays, N_fine, N_coarse]
        cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, indices_g)
        t_vals_g = torch.gather(t_vals.unsqueeze(-2).expand(matched_shape), -1, indices_g)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        samples = t_vals_g[..., 0] + t * (t_vals_g[..., 1] - t_vals_g[..., 0])

        return samples

    @staticmethod
    def sample_along_ray(
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        n_samples: int,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stratified sampling of points along each ray.

        Args:
            rays_o: ray origins [N_rays, 3]
            rays_d: ray directions [N_rays, 3]
            near: near-plane distances [N_rays]
            far: far-plane distances [N_rays]
            n_samples: number of samples
            perturb: whether to jitter the samples

        Returns:
            t_vals: sample parameters [N_rays, N_samples]
            points: sampled 3D coordinates [N_rays, N_samples, 3]
        """
        # Evenly spaced samples between near and far, per ray
        t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
        t_vals = near.unsqueeze(-1) + (far - near).unsqueeze(-1) * t_vals.unsqueeze(0)

        if perturb:
            # Stratified jitter: one uniform draw per interval
            mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = torch.cat([mids, t_vals[..., -1:]], -1)
            lower = torch.cat([t_vals[..., :1], mids], -1)
            t_rand = torch.rand(t_vals.shape, device=rays_o.device)
            t_vals = lower + (upper - lower) * t_rand

        # Lift the 1D samples to 3D points along each ray
        points = rays_o.unsqueeze(1) + t_vals.unsqueeze(-1) * rays_d.unsqueeze(1)
        return t_vals, points

    @staticmethod
    def volume_rendering(
        sigma: torch.Tensor,
        color: torch.Tensor,
        t_vals: torch.Tensor,
        rays_d: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform volume rendering along each ray.

        Args:
            sigma: volume density [N_rays, N_samples, 1]
            color: RGB colors [N_rays, N_samples, 3]
            t_vals: sample parameters [N_rays, N_samples]
            rays_d: ray directions [N_rays, 3]

        Returns:
            rgb_map: rendered colors [N_rays, 3]
            weights: weight distribution [N_rays, N_samples]
        """
        # Distances between adjacent samples; the last interval is effectively infinite
        dists = t_vals[..., 1:] - t_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d, dim=-1, keepdim=True)

        # Optical thickness, per-interval opacity, and accumulated transmittance
        tau = sigma.squeeze(-1) * dists                    # [N_rays, N_samples]
        alpha = 1. - torch.exp(-tau)
        trans = torch.exp(-torch.cat([
            torch.zeros_like(tau[..., :1]),
            torch.cumsum(tau[..., :-1], dim=-1)
        ], dim=-1))

        weights = alpha * trans
        rgb_map = torch.sum(weights.unsqueeze(-1) * color, dim=-2)
        return rgb_map, weights

    @staticmethod
    def calculate_entropy(weights: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """
        Compute the entropy of the weight distribution.

        Args:
            weights: weight distribution [N_rays, N_samples]
            eps: small constant to avoid log(0)

        Returns:
            entropy: per-ray entropy [N_rays]
        """
        norm_weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(norm_weights * torch.log(norm_weights + eps), dim=-1)
        return entropy

    @staticmethod
    def compute_weights(sigma: torch.Tensor, t_vals: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """Compute sample weights (used for importance sampling)."""
        dists = t_vals[..., 1:] - t_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d, dim=-1, keepdim=True)

        tau = sigma.squeeze(-1) * dists
        alpha = 1. - torch.exp(-tau)
        trans = torch.exp(-torch.cat([
            torch.zeros_like(tau[..., :1]),
            torch.cumsum(tau[..., :-1], dim=-1)
        ], dim=-1))
        return alpha * trans
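

# --- Minimal usage sketch (illustrative only) ---
# `ToyNeRF` below is a hypothetical stand-in, not a real NeRF: it is assumed
# only to return (sigma, color) with the shapes render_rays expects. All names,
# shapes, and values here are illustrative and only show how the utility is called.
class ToyNeRF(torch.nn.Module):
    def forward(self, points: torch.Tensor, view_dirs: torch.Tensor):
        # points: [N_rays, N_samples, 3], view_dirs: [N_rays, 1, 3]
        n_rays, n_samples, _ = points.shape
        sigma = torch.rand(n_rays, n_samples, 1)   # fake densities
        color = torch.rand(n_rays, n_samples, 3)   # fake RGB
        return sigma, color


if __name__ == "__main__":
    n_rays = 4
    rays_o = torch.zeros(n_rays, 3)
    rays_d = torch.randn(n_rays, 3)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)   # normalized directions
    near = torch.full((n_rays,), 2.0)
    far = torch.full((n_rays,), 6.0)

    rgb, weights, t_vals, entropy = VolumeRendererUtil.render_rays(
        ToyNeRF(), rays_o, rays_d, near, far,
        coarse_samples=32, fine_samples=64, perturb=False)

    # Expected shapes: [4, 3], [4, 96], [4, 96], [4]
    print(rgb.shape, weights.shape, t_vals.shape, entropy.shape)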