nbv_rec_uncertainty_guide/utils/volume_render_util.py
2025-04-20 10:26:09 +08:00

201 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn.functional as F
from typing import Tuple
class VolumeRendererUtil:
    """Static helpers for NeRF-style volume rendering with a per-ray entropy score."""

    @staticmethod
    def render_rays(
        nerf_model,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        coarse_samples: int = 64,
        fine_samples: int = 128,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Render rays with coarse + importance sampling and compute their entropy.

        Args:
            nerf_model: Callable invoked as ``nerf_model(points, view_dirs)`` that
                returns a ``(sigma, color)`` pair. Expected output shapes are
                [N_rays, N_samples, 1] and [N_rays, N_samples, 3] -- TODO confirm
                against the model implementation.
            rays_o: Ray origins [N_rays, 3].
            rays_d: Ray directions [N_rays, 3].
            near: Near bound per ray [N_rays].
            far: Far bound per ray [N_rays].
            coarse_samples: Number of stratified (coarse) samples per ray.
            fine_samples: Number of importance (fine) samples per ray.
            perturb: Whether to jitter the coarse samples inside their strata.

        Returns:
            rgb_map: Rendered colour [N_rays, 3].
            weights: Rendering weights [N_rays, coarse_samples + fine_samples].
            t_vals: Sorted ray parameters [N_rays, coarse_samples + fine_samples].
            entropy: Entropy of the normalised weights per ray [N_rays].
        """
        # Coarse pass: stratified samples between near and far.
        t_vals_coarse, points_coarse = VolumeRendererUtil.sample_along_ray(
            rays_o, rays_d, near, far, coarse_samples, perturb)

        # The coarse pass only steers where fine samples go, so no gradients
        # are tracked through it.
        with torch.no_grad():
            sigma_coarse, _ = nerf_model(points_coarse[..., :3], rays_d.unsqueeze(1))
            weights_coarse = VolumeRendererUtil.compute_weights(
                sigma_coarse, t_vals_coarse, rays_d)
            t_vals_fine = VolumeRendererUtil.importance_sampling(
                t_vals_coarse, weights_coarse, fine_samples)

        # Merge coarse and fine samples and restore front-to-back order.
        t_vals = torch.sort(torch.cat([t_vals_coarse, t_vals_fine], -1)).values
        points = rays_o[..., None, :] + t_vals[..., None] * rays_d[..., None, :]

        # Final pass on the merged sample set.
        sigma, color = nerf_model(points[..., :3], rays_d.unsqueeze(1))
        rgb_map, weights = VolumeRendererUtil.volume_rendering(sigma, color, t_vals, rays_d)
        entropy = VolumeRendererUtil.calculate_entropy(weights)
        return rgb_map, weights, t_vals, entropy

    @staticmethod
    def importance_sampling(
        t_vals: torch.Tensor,
        weights: torch.Tensor,
        n_samples: int
    ) -> torch.Tensor:
        """Draw new ray parameters by inverse-transform sampling of the weights.

        Args:
            t_vals: Existing ray parameters [N_rays, N_coarse].
            weights: Weight per existing sample [N_rays, N_coarse].
            n_samples: Number of new samples to generate per ray.

        Returns:
            New ray parameters [N_rays, n_samples].
        """
        # NOTE(review): the CDF is aligned with the sample positions themselves
        # (cdf[i] is treated as the CDF value at t_vals[i]); the reference NeRF
        # sampler uses bin midpoints instead -- confirm this is intended.
        weights = weights + 1e-5  # keep the PDF well defined when all weights are 0
        pdf = weights / torch.sum(weights, -1, keepdim=True)
        cdf = torch.cumsum(pdf, -1)

        # Deterministic, evenly spaced quantiles. searchsorted needs matching
        # dtypes and prefers contiguous inputs, hence dtype= and .contiguous().
        u = torch.linspace(0, 1, n_samples, device=weights.device, dtype=weights.dtype)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples]).contiguous()
        indices = torch.searchsorted(cdf, u, right=True)

        # Indices of the CDF entries bracketing each quantile, clamped to range.
        below = torch.clamp(indices - 1, min=0)
        above = torch.clamp(indices, max=cdf.shape[-1] - 1)

        # torch.gather needs index and input to have the same rank, so the two
        # brackets are gathered separately instead of through a stacked index.
        cdf_below = torch.gather(cdf, -1, below)
        cdf_above = torch.gather(cdf, -1, above)
        t_below = torch.gather(t_vals, -1, below)
        t_above = torch.gather(t_vals, -1, above)

        # Linear interpolation inside the bracket; degenerate brackets get a
        # unit denominator to avoid dividing by ~0.
        denom = cdf_above - cdf_below
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_below) / denom
        samples = t_below + t * (t_above - t_below)
        return samples

    @staticmethod
    def sample_along_ray(
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        n_samples: int,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Place stratified samples along each ray between ``near`` and ``far``.

        Args:
            rays_o: Ray origins [N_rays, 3].
            rays_d: Ray directions [N_rays, 3].
            near: Near bound per ray, [N_rays] (a [N_rays, 1] or single-element
                tensor is also accepted).
            far: Far bound per ray, same layout as ``near``.
            n_samples: Number of samples per ray.
            perturb: Whether to jitter each sample uniformly inside its stratum.

        Returns:
            t_vals: Ray parameters [N_rays, n_samples].
            points: 3D sample positions [N_rays, n_samples, 3].
        """
        steps = torch.linspace(0., 1., n_samples, device=rays_o.device)

        # Give the per-ray bounds a trailing sample axis so they broadcast
        # against the [1, n_samples] grid instead of colliding with it.
        near = near.reshape(-1, 1)
        far = far.reshape(-1, 1)
        t_vals = near + (far - near) * steps.unsqueeze(0)
        # Single-element bounds produce one row; give every ray its own row.
        t_vals = t_vals.expand(rays_o.shape[0], n_samples)

        if perturb:
            # Jitter each sample uniformly within the interval bounded by the
            # midpoints to its neighbours.
            mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = torch.cat([mids, t_vals[..., -1:]], -1)
            lower = torch.cat([t_vals[..., :1], mids], -1)
            t_rand = torch.rand(t_vals.shape, device=rays_o.device)
            t_vals = lower + (upper - lower) * t_rand

        # points = o + t * d
        points = rays_o.unsqueeze(1) + t_vals.unsqueeze(-1) * rays_d.unsqueeze(1)
        return t_vals, points

    @staticmethod
    def volume_rendering(
        sigma: torch.Tensor,
        color: torch.Tensor,
        t_vals: torch.Tensor,
        rays_d: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Composite per-sample colours into one colour per ray.

        Args:
            sigma: Volume density [N_rays, N_samples, 1] (or [N_rays, N_samples]).
            color: RGB colour [N_rays, N_samples, 3].
            t_vals: Ray parameters [N_rays, N_samples].
            rays_d: Ray directions [N_rays, 3].

        Returns:
            rgb_map: Rendered colour [N_rays, 3].
            weights: Rendering weights [N_rays, N_samples].
        """
        # Same weights as the importance-sampling pass; computed in one place.
        weights = VolumeRendererUtil.compute_weights(sigma, t_vals, rays_d)
        rgb_map = torch.sum(weights.unsqueeze(-1) * color, dim=-2)
        return rgb_map, weights

    @staticmethod
    def calculate_entropy(weights: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """Compute the Shannon entropy of each ray's normalised weights.

        Args:
            weights: Weight distribution [N_rays, N_samples].
            eps: Small constant guarding the division and the logarithm.

        Returns:
            Entropy per ray [N_rays].
        """
        total = torch.sum(weights, dim=-1, keepdim=True)
        norm_weights = weights / (total + eps)
        entropy = -torch.sum(norm_weights * torch.log(norm_weights + eps), dim=-1)
        return entropy

    @staticmethod
    def compute_weights(sigma: torch.Tensor, t_vals: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """Compute rendering weights ``w_i = T_i * alpha_i`` for every sample.

        Args:
            sigma: Volume density [N_rays, N_samples, 1] (or [N_rays, N_samples]).
            t_vals: Ray parameters [N_rays, N_samples].
            rays_d: Ray directions [N_rays, 3].

        Returns:
            Weights [N_rays, N_samples].
        """
        # Distance between consecutive samples; the last sample gets a very
        # large interval so its segment is treated as extending indefinitely.
        dists = t_vals[..., 1:] - t_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(t_vals[..., :1], 1e10)], -1)
        # Scale parameter distances by the direction length to get world units.
        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

        # Drop the trailing channel axis so density lines up with dists.
        density = sigma.reshape(dists.shape)
        optical_depth = density * dists
        alpha = 1. - torch.exp(-optical_depth)

        # Transmittance T_i = exp(-sum_{j<i} sigma_j * delta_j): an exclusive
        # cumulative sum along the sample axis, so the first sample sees T = 1.
        accumulated = torch.cumsum(optical_depth[..., :-1], dim=-1)
        trans = torch.exp(-torch.cat(
            [torch.zeros_like(optical_depth[..., :1]), accumulated], dim=-1))
        return alpha * trans