193 lines
7.5 KiB
Python
193 lines
7.5 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
|
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
|
import PytorchBoot.stereotype as stereotype
|
|
import PytorchBoot.namespace as namespace
|
|
from PytorchBoot.utils.log_util import Log
|
|
|
|
from utils.volume_render_util import VolumeRendererUtil
|
|
|
|
|
|
@stereotype.pipeline("reconstruction_pipeline")
class ReconstructionPipeline:
    """Train, save and reload a single NeRF module built from configuration.

    Configuration keys read by this class:

    * ``modules.nerf`` -- handed to ``ComponentFactory.create`` to build the model.
    * ``nerf_model_output_dir`` -- checkpoint root (default ``./output/nerf_model``).
    * ``camera.focal`` / ``camera.near`` / ``camera.far`` -- defaults 1000.0 / 2.0 / 6.0.
    * ``sampling.coarse_samples`` / ``sampling.fine_samples`` / ``sampling.perturb``
      -- defaults 64 / 128 / True.

    Checkpoints are written to ``<nerf_model_output_dir>/<object_name>/<file>.pth``.
    """

    def __init__(self, config: dict):
        """Build the NeRF module described by ``config["modules"]["nerf"]``.

        Args:
            config: Pipeline configuration; must contain a ``"modules"`` mapping
                with a ``"nerf"`` entry.
        """
        self.config = config
        # Training tensors are placed on this device (CUDA when available).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.module_config = config["modules"]
        self.nerf = ComponentFactory.create(
            namespace.Stereotype.MODULE, self.module_config["nerf"]
        )
        self.nerf_model_output_dir = self.config.get(
            "nerf_model_output_dir", "./output/nerf_model"
        )

    # NOTE(review): this class declares no base class, so the two ``super()``
    # delegations below only work if the ``stereotype.pipeline`` decorator (or a
    # subclass) supplies ``create_experiment`` / ``load_experiment`` -- confirm.
    def create_experiment(self, backup_name=None):
        """Delegate experiment creation to the parent implementation."""
        return super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate experiment loading to the parent implementation."""
        super().load_experiment(backup_name)

    @staticmethod
    def _checkpoint_filename(best_model: bool, name: str | None) -> str | None:
        """Return the checkpoint file name for the request, or ``None`` if unspecified."""
        if best_model:
            return "best_model.pth"
        if name is not None:
            return f"{name}.pth"
        return None

    def save(self, object_name: str, best_model: bool = True, name: str | None = None):
        """Write the NeRF weights to the checkpoint directory of ``object_name``.

        Args:
            object_name: Sub-directory (below ``nerf_model_output_dir``) to write to.
            best_model: When true, the file is ``best_model.pth`` and ``name`` is ignored.
            name: File stem used when ``best_model`` is false.

        Returns:
            The directory the checkpoint belongs to.
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)
        os.makedirs(output_dir, exist_ok=True)

        filename = self._checkpoint_filename(best_model, name)
        if filename is None:
            Log.error("save failed, best_model and name cannot be None at the same time", terminate=True)
        else:
            torch.save(self.nerf.state_dict(), os.path.join(output_dir, filename))

        Log.info(f"save {object_name} to {output_dir}")
        return output_dir

    def load(self, object_name: str, best_model: bool = True, name: str | None = None):
        """Load NeRF weights from the checkpoint directory of ``object_name``.

        Args:
            object_name: Sub-directory (below ``nerf_model_output_dir``) to read from.
            best_model: When true, ``best_model.pth`` is read and ``name`` is ignored.
            name: File stem used when ``best_model`` is false.

        Returns:
            The directory the checkpoint was read from.
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)

        filename = self._checkpoint_filename(best_model, name)
        if filename is None:
            Log.error("load failed, best_model and name cannot be None at the same time", terminate=True)
        else:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine (and vice versa).
            state = torch.load(os.path.join(output_dir, filename), map_location=self.device)
            self.nerf.load_state_dict(state)

        Log.info(f"load {object_name} from {output_dir}")
        return output_dir

    def train_nerf(self,
                   images: torch.Tensor,
                   poses: torch.Tensor,
                   epochs: int = 5000,
                   batch_size: int = 4096,
                   lr: float = 5e-4,
                   start_from_model=None,
                   object_name: str = "unknown") -> float:
        """Optimise the NeRF on random pixel batches and keep the best checkpoint.

        Every 100th iteration the current batch loss is logged and, if it is the
        lowest seen so far, the weights are written to ``best_model.pth`` in the
        object's checkpoint directory. After training the best checkpoint written
        during this call is loaded back into ``self.nerf``.

        Args:
            images: Target images; indexed as ``[N, H, W, 3]``.
            poses: One pose matrix per image (see ``generate_rays``).
            epochs: Number of optimisation steps (one random batch per step).
            batch_size: Rays sampled per step.
            lr: Adam learning rate.
            start_from_model: Optional model whose ``state_dict`` initialises the NeRF.
            object_name: Checkpoint sub-directory name.

        Returns:
            The lowest logged batch loss, or ``inf`` when no evaluation step ran
            (``epochs < 100``).
        """
        output_dir = os.path.join(self.nerf_model_output_dir, object_name)
        os.makedirs(output_dir, exist_ok=True)
        best_path = os.path.join(output_dir, "best_model.pth")

        Log.info("train NeRF model with {} images".format(len(images)))
        H, W = images.shape[1], images.shape[2]

        sampling_config = self.config.get("sampling", {})
        camera_config = self.config.get("camera", {})
        focal = camera_config.get("focal", 1000.0)
        near = camera_config.get("near", 2.0)
        far = camera_config.get("far", 6.0)
        coarse_samples = sampling_config.get("coarse_samples", 64)
        fine_samples = sampling_config.get("fine_samples", 128)
        perturb = sampling_config.get("perturb", True)

        if start_from_model is not None:
            self.nerf.load_state_dict(start_from_model.state_dict())

        # The rays and images are moved to self.device below; the model must
        # live on the same device or rendering fails with a device mismatch.
        self.nerf.to(self.device)

        optimizer = torch.optim.Adam(self.nerf.parameters(), lr=lr)
        mse_loss = torch.nn.MSELoss()
        self.nerf.train()

        rays_o, rays_d = ReconstructionPipeline.generate_rays(poses, H, W, focal)
        rays_o = rays_o.to(self.device)
        rays_d = rays_d.to(self.device)
        images = images.to(self.device)

        best_loss = float('inf')
        checkpoint_written = False
        for epoch in range(epochs):
            batch_rays_o, batch_rays_d, target_pixels = ReconstructionPipeline.sample_pixel_batch(
                images, rays_o, rays_d, batch_size)
            batch_rays_d = torch.nn.functional.normalize(batch_rays_d, dim=-1)

            # One near/far bound per ray in the batch.
            near_tensor = torch.ones_like(batch_rays_o[..., 0]) * near
            far_tensor = torch.ones_like(batch_rays_o[..., 0]) * far

            optimizer.zero_grad()
            rgb_map, _, _, _ = VolumeRendererUtil.render_rays(
                self.nerf,
                batch_rays_o,
                batch_rays_d,
                near_tensor,
                far_tensor,
                coarse_samples,
                fine_samples,
                perturb
            )

            loss = mse_loss(rgb_map, target_pixels)
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 100 == 0:
                loss_value = loss.item()
                # Detached so the metric does not extend the autograd graph.
                # PSNR in this form presumably assumes pixel values in [0, 1].
                psnr = -10.0 * torch.log10(loss.detach())
                Log.info(f"Epoch {epoch+1}/{epochs}, Loss: {loss_value:.6f}, PSNR: {psnr.item():.2f}")

                if loss_value < best_loss:
                    best_loss = loss_value
                    torch.save(self.nerf.state_dict(), best_path)
                    checkpoint_written = True

        if checkpoint_written:
            self.nerf.load_state_dict(torch.load(best_path, map_location=self.device))
        else:
            # Nothing was saved in this call (e.g. epochs < 100). Reloading would
            # either fail or silently restore a stale checkpoint from an earlier
            # run, so the freshly trained weights are kept instead.
            Log.info("no checkpoint was written during this run, keeping the final weights")

        Log.info(f"finish training, best loss: {best_loss:.6f}")
        return best_loss

    @staticmethod
    def generate_rays(
            poses: torch.Tensor,
            H: int,
            W: int,
            focal: float) -> tuple:
        """Build one ray (origin, direction) per pixel for every pose.

        Camera-space directions are ``((x - W/2) / focal, -(y - H/2) / focal, -1)``,
        i.e. the camera looks along ``-z`` with ``y`` pointing up. Each direction is
        rotated by ``pose[:3, :3]`` and every ray of a pose starts at
        ``pose[:3, -1]`` (the pose is therefore used as a camera-to-world matrix).
        Directions are not normalised here.

        Args:
            poses: Iterable of pose matrices, each at least ``3 x 4``.
            H: Image height in pixels.
            W: Image width in pixels.
            focal: Focal length in pixels.

        Returns:
            ``(rays_o, rays_d)``, both of shape ``[N, H*W, 3]`` in row-major pixel order.
        """
        # Build the pixel grid on the poses' device so poses that already live
        # on an accelerator do not trigger a device-mismatch error.
        grid_device = poses.device if torch.is_tensor(poses) else None
        xs = torch.linspace(0, W - 1, W, device=grid_device)
        ys = torch.linspace(0, H - 1, H, device=grid_device)
        # rows[h, w] == h and cols[h, w] == w, both shaped [H, W].
        rows, cols = torch.meshgrid(ys, xs, indexing='ij')

        dirs = torch.stack([
            (cols - W * 0.5) / focal,
            -(rows - H * 0.5) / focal,
            -torch.ones_like(cols)
        ], dim=-1)  # [H, W, 3]

        rays_o_list = []
        rays_d_list = []
        for pose in poses:
            cam_dirs = dirs.to(pose.device)  # no-op when already co-located
            # Row-wise dot product with the rotation == R @ d for every pixel.
            rays_d = torch.sum(cam_dirs[..., None, :] * pose[:3, :3], dim=-1)  # [H, W, 3]
            rays_o = pose[:3, -1].expand(rays_d.shape)  # [H, W, 3]

            rays_o_list.append(rays_o.reshape(-1, 3))  # [H*W, 3]
            rays_d_list.append(rays_d.reshape(-1, 3))  # [H*W, 3]

        rays_o_all = torch.stack(rays_o_list, dim=0)  # [N, H*W, 3]
        rays_d_all = torch.stack(rays_d_list, dim=0)  # [N, H*W, 3]
        return rays_o_all, rays_d_all

    @staticmethod
    def sample_pixel_batch(
            images: torch.Tensor,
            rays_o: torch.Tensor,
            rays_d: torch.Tensor,
            batch_size: int) -> tuple:
        """Draw ``batch_size`` random pixels (with replacement) across all images.

        Args:
            images: Images indexed as ``[N, H, W, 3]``.
            rays_o: Ray origins, ``[N, H*W, 3]``.
            rays_d: Ray directions, ``[N, H*W, 3]``.
            batch_size: Number of rays to sample.

        Returns:
            ``(sampled_rays_o, sampled_rays_d, sampled_pixels)``, each ``[batch_size, 3]``
            and aligned row by row.
        """
        N = images.shape[0]
        H = images.shape[1]
        W = images.shape[2]
        pixels_per_image = H * W
        total_rays = N * pixels_per_image

        pixels = images.reshape(N, -1, 3)  # [N, H*W, 3]

        # Drawn on the CPU (default generator) so seeded runs see the same
        # sequence regardless of where the image/ray tensors live.
        flat_indices = torch.randint(0, total_rays, size=(batch_size,))
        img_indices = flat_indices // pixels_per_image
        pixel_indices = flat_indices % pixels_per_image

        def gather(source: torch.Tensor) -> torch.Tensor:
            # A single advanced-indexing gather replaces batch_size scalar lookups.
            return source[img_indices.to(source.device), pixel_indices.to(source.device)]

        return gather(rays_o), gather(rays_d), gather(pixels)