Implementing NeRF from Scratch in Python: Understanding Neural Radiance Fields Through Code
When you first hear about NeRF (Neural Radiance Fields), the dense mathematics can be intimidating. In this article we take a different route: understanding NeRF through code. We will implement NeRF's core components step by step in Python, turning abstract concepts into runnable logic.
1. Setup and Basic Concepts
Before writing any code, we need a few key concepts. A neural radiance field is essentially a function: it takes a spatial coordinate and a viewing direction as input and outputs the color and volume density at that point. This representation lets us render photorealistic views of a 3D scene from arbitrary angles.
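To make this concrete, here is a toy, hand-written radiance field (a hypothetical example, not from the NeRF paper): a function from a 3D point and a viewing direction to an RGB color and a density. NeRF replaces this hand-written rule with a learned neural network of the same signature.

```python
import torch

# A toy radiance field: a solid red sphere of radius 0.5 at the origin.
# NeRF learns a function with this same signature instead of hard-coding it.
def toy_radiance_field(pos, view_dir):
    # pos, view_dir: tensors of shape [3]
    inside = torch.norm(pos) < 0.5
    color = torch.tensor([1.0, 0.0, 0.0])            # red everywhere
    density = torch.tensor(10.0 if inside else 0.0)  # opaque inside, empty outside
    return color, density

color, density = toy_radiance_field(torch.tensor([0.0, 0.0, 0.0]),
                                    torch.tensor([0.0, 0.0, 1.0]))
print(density.item())  # 10.0
```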
First, install the required Python libraries:
```shell
pip install torch numpy matplotlib imageio
```
Next, import the modules we will use:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
```
2. Implementing Positional Encoding
The original NeRF paper found that feeding raw coordinates directly into the network causes high-frequency detail to be lost. The solution is positional encoding, which maps the input into a higher-dimensional space.
```python
class PositionalEncoder(nn.Module):
    def __init__(self, L=10):
        super().__init__()
        self.L = L

    def forward(self, x):
        # x: [batch_size, 3] spatial coordinates or viewing directions
        encoded = []
        for i in range(self.L):
            encoded.append(torch.sin(2**i * torch.pi * x))
            encoded.append(torch.cos(2**i * torch.pi * x))
        return torch.cat(encoded, dim=-1)
```
This encoder produces 2L features per input dimension (alternating sine and cosine). For a 3D coordinate, the output dimension is 3 × 2L = 60 when L = 10.
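As a quick sanity check on those dimensions, here is a standalone snippet that applies the same sin/cos encoding inline and verifies the 3 × 2L output shape:

```python
import torch

# Inline version of the encoding: for each of L frequency bands, append
# sin and cos of the scaled input, giving 2L features per input dimension.
L = 10
x = torch.rand(5, 3)  # 5 points, 3D coordinates
encoded = []
for i in range(L):
    encoded.append(torch.sin(2**i * torch.pi * x))
    encoded.append(torch.cos(2**i * torch.pi * x))
out = torch.cat(encoded, dim=-1)
print(out.shape)  # torch.Size([5, 60])
```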
3. The Core NeRF Network
NeRF's backbone is a multilayer perceptron (MLP) with two distinctive design choices:
- The density branch depends only on the spatial coordinate
- The color branch depends on both the spatial coordinate and the viewing direction
```python
class NeRFModel(nn.Module):
    def __init__(self, pos_L=10, dir_L=4, hidden_dim=256):
        super().__init__()
        # Positional encoders
        self.pos_encoder = PositionalEncoder(pos_L)
        self.dir_encoder = PositionalEncoder(dir_L)
        # Backbone MLP
        self.layer1 = nn.Linear(3*2*pos_L, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, hidden_dim)
        self.layer4 = nn.Linear(hidden_dim, hidden_dim)
        # Density head
        self.density_out = nn.Linear(hidden_dim, 1)
        # Feature + direction -> color branch
        self.layer5 = nn.Linear(hidden_dim + 3*2*dir_L, hidden_dim//2)
        self.color_out = nn.Linear(hidden_dim//2, 3)

    def forward(self, pos, view_dir):
        # Encode inputs
        pos_encoded = self.pos_encoder(pos)
        dir_encoded = self.dir_encoder(view_dir)
        # Backbone
        x = F.relu(self.layer1(pos_encoded))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        # Density prediction (ReLU keeps it non-negative)
        density = F.relu(self.density_out(x))
        # Color prediction, conditioned on the viewing direction
        x = torch.cat([x, dir_encoded], dim=-1)
        x = F.relu(self.layer5(x))
        color = torch.sigmoid(self.color_out(x))
        return color, density
```
4. Implementing Volume Rendering
Volume rendering is the heart of NeRF: it computes the final pixel color by integrating color and density along each camera ray. We need two key functions:
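Numerically, the integral along a ray is approximated by a sum over sampled points: C ≈ Σ_i T_i (1 − exp(−σ_i δ_i)) c_i, where σ_i is the density, c_i the color, δ_i the spacing between adjacent samples, and T_i = exp(−Σ_{j&lt;i} σ_j δ_j) is the transmittance (the sum excludes the current sample). A useful standalone sanity check, re-deriving the weights inline: for constant density σ over [near, far], the weights telescope and sum to exactly 1 − exp(−σ · (far − near)).

```python
import torch

# Volume-rendering weights for a constant density sigma over [near, far].
# With exclusive transmittance, the weights form a telescoping sum equal
# to 1 - exp(-sigma * (far - near)).
sigma, near, far, n = 2.0, 0.0, 1.0, 64
t_vals = torch.linspace(near, far, n)
deltas = t_vals[1:] - t_vals[:-1]                # n-1 uniform intervals
densities = torch.full((n - 1,), sigma)          # constant density per interval

accum = torch.cumsum(densities * deltas, dim=0)
accum = torch.cat([torch.zeros(1), accum[:-1]])  # exclusive cumulative sum
transmittance = torch.exp(-accum)
weights = transmittance * (1 - torch.exp(-densities * deltas))

expected = 1 - torch.exp(torch.tensor(-sigma * (far - near)))
print(torch.isclose(weights.sum(), expected, atol=1e-5))  # tensor(True)
```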
```python
def compute_transmittance(densities, deltas):
    # densities: [n_samples], deltas: [n_samples] (spacing between adjacent samples)
    # T_i = exp(-sum_{j<i} sigma_j * delta_j); the sum excludes the current
    # sample, so the cumulative sum is shifted by one (exclusive cumsum)
    accum = torch.cumsum(densities * deltas, dim=0)
    accum = torch.cat([torch.zeros(1), accum[:-1]])
    return torch.exp(-accum)


def render_rays(model, ray_origin, ray_dir, near=0.0, far=1.0, n_samples=64):
    # Sample points along the ray
    t_vals = torch.linspace(near, far, n_samples)
    points = ray_origin + ray_dir * t_vals.unsqueeze(-1)
    # Spacing between adjacent samples
    deltas = t_vals[1:] - t_vals[:-1]
    deltas = torch.cat([deltas, torch.tensor([1e10])])  # large value for the last sample
    # Query the model
    dirs = ray_dir.expand(n_samples, -1)
    colors, densities = model(points, dirs)
    # Transmittance along the ray
    transmittance = compute_transmittance(densities.squeeze(), deltas)
    # Per-sample contribution weights
    weights = transmittance * (1 - torch.exp(-densities.squeeze() * deltas))
    # Integrate to get the final pixel color
    pixel_color = torch.sum(weights.unsqueeze(-1) * colors, dim=0)
    return pixel_color
```
5. Training Loop and Visualization
Now we can put these components together and train the NeRF model. Here is a simplified training loop:
```python
def train(model, optimizer, dataset, n_epochs=1000):
    for epoch in range(n_epochs):
        total_loss = 0
        for ray_origin, ray_dir, target_color in dataset:
            optimizer.zero_grad()
            # Render the ray
            pred_color = render_rays(model, ray_origin, ray_dir)
            # Photometric loss
            loss = F.mse_loss(pred_color, target_color)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")
            visualize_results(model)
```
A visualization function lets us watch training progress:
```python
def visualize_results(model, resolution=50):
    # Generate a test view
    xx, yy = torch.meshgrid(torch.linspace(-1, 1, resolution),
                            torch.linspace(-1, 1, resolution), indexing='ij')
    zz = torch.ones_like(xx) * 0.5  # fixed depth
    # Render each pixel
    image = torch.zeros(resolution, resolution, 3)
    for i in range(resolution):
        for j in range(resolution):
            ray_origin = torch.tensor([0.0, 0.0, -1.0])  # camera position
            ray_dir = torch.tensor([xx[i, j], yy[i, j], zz[i, j]]) - ray_origin
            ray_dir = ray_dir / torch.norm(ray_dir)
            with torch.no_grad():
                color = render_rays(model, ray_origin, ray_dir)
            image[i, j] = color
    # Display the image
    plt.imshow(image.numpy())
    plt.axis('off')
    plt.show()
```
6. Hierarchical Sampling
The original NeRF paper introduces a hierarchical (coarse-to-fine) sampling strategy to improve rendering quality. We can add it on top of the existing code:
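Before implementing it, here is a standalone sketch of the inverse-transform sampling idea it relies on: if all the weight sits in one coarse interval, the fine samples should cluster inside that interval.

```python
import torch

# Inverse-transform sampling in miniature: build a CDF from per-interval
# weights, then map uniform samples through it.
bins = torch.linspace(0, 1, 9)   # 9 bin edges, 8 intervals
weights = torch.zeros(8)
weights[3] = 1.0                 # all mass in the interval [bins[3], bins[4]]

pdf = (weights + 1e-5) / torch.sum(weights + 1e-5)
cdf = torch.cat([torch.zeros(1), torch.cumsum(pdf, dim=0)])

u = torch.linspace(0.01, 0.99, 100)          # deterministic "uniform" samples
inds = torch.searchsorted(cdf, u, right=True)
below = torch.clamp(inds - 1, min=0)
above = torch.clamp(inds, max=cdf.shape[0] - 1)
t = (u - cdf[below]) / torch.clamp(cdf[above] - cdf[below], min=1e-5)
samples = bins[below] + t * (bins[above] - bins[below])

# All samples land in the heavy interval
in_heavy = ((samples >= bins[3]) & (samples <= bins[4])).float().mean()
print(in_heavy.item())  # 1.0
```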
```python
def compute_weights(densities, deltas):
    # Weights over the n-1 intervals between samples, as in render_rays
    transmittance = compute_transmittance(densities[:-1], deltas)
    return transmittance * (1 - torch.exp(-densities[:-1] * deltas))


def integrate(colors, densities, deltas):
    # Same weighting and summation as render_rays, with a large final delta
    deltas = torch.cat([deltas, torch.tensor([1e10])])
    transmittance = compute_transmittance(densities.squeeze(), deltas)
    weights = transmittance * (1 - torch.exp(-densities.squeeze() * deltas))
    return torch.sum(weights.unsqueeze(-1) * colors, dim=0)


def hierarchical_sampling(model, ray_origin, ray_dir, n_coarse=64, n_fine=128):
    # Coarse pass
    t_vals_coarse = torch.linspace(0, 1, n_coarse)
    points_coarse = ray_origin + ray_dir * t_vals_coarse.unsqueeze(-1)
    # Query the coarse samples
    with torch.no_grad():
        _, densities = model(points_coarse, ray_dir.expand(n_coarse, -1))
    # Redistribute samples according to density
    weights = compute_weights(densities.squeeze(),
                              t_vals_coarse[1:] - t_vals_coarse[:-1])
    t_vals_fine = sample_pdf(t_vals_coarse, weights, n_fine)
    # Merge coarse and fine samples
    t_vals = torch.sort(torch.cat([t_vals_coarse, t_vals_fine]))[0]
    points = ray_origin + ray_dir * t_vals.unsqueeze(-1)
    # Final render
    colors, densities = model(points, ray_dir.expand(len(t_vals), -1))
    pixel_color = integrate(colors, densities, t_vals[1:] - t_vals[:-1])
    return pixel_color


def sample_pdf(bins, weights, n_samples):
    # Importance sampling according to the weights
    weights = weights + 1e-5  # avoid division by zero
    pdf = weights / torch.sum(weights)
    cdf = torch.cumsum(pdf, dim=0)
    cdf = torch.cat([torch.zeros_like(cdf[:1]), cdf])
    # Inverse transform sampling
    u = torch.linspace(0, 1, n_samples).contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[0] - 1)
    # Fancy-index the 1D cdf/bins (torch.gather would require matching dims)
    cdf_g = torch.stack([cdf[below], cdf[above]], dim=-1)
    bins_g = torch.stack([bins[below], bins[above]], dim=-1)
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
```
7. Practical Tips and Performance Optimization
In real projects, a few techniques can significantly improve NeRF's performance and training efficiency.
Batched ray rendering:
```python
def batch_render_rays(model, ray_origins, ray_dirs, batch_size=1024):
    n_rays = ray_origins.shape[0]
    all_colors = []
    for i in range(0, n_rays, batch_size):
        batch_origins = ray_origins[i:i+batch_size]
        batch_dirs = ray_dirs[i:i+batch_size]
        with torch.no_grad():
            # render_rays above handles one ray at a time, so render the
            # batch ray by ray and stack the results
            batch_colors = torch.stack([render_rays(model, o, d)
                                        for o, d in zip(batch_origins, batch_dirs)])
        all_colors.append(batch_colors)
    return torch.cat(all_colors, dim=0)
```
Learning-rate scheduling:
```python
def get_lr_scheduler(optimizer, warmup_steps=5000):
    def lr_lambda(step):
        # Linear warmup, then decay by a factor of 10 every 250k steps
        if step < warmup_steps:
            return step / warmup_steps
        return 0.1 ** ((step - warmup_steps) / 250000)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```
Tuning the positional encoding:
- For spatial coordinates, L = 10 is typical (60 output features)
- For viewing directions, L = 4 is typical (24 output features)
Loss-function enhancements:
```python
def compute_loss(pred_color, target_color, pred_density, coef=0.01):
    color_loss = F.mse_loss(pred_color, target_color)
    # Sparsity regularizer: encourage zero density in most of the volume
    density_reg = torch.mean(pred_density)
    return color_loss + coef * density_reg
```
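A quick standalone illustration of the regularizer's effect (re-defining the loss inline with dummy tensors): with identical color error, higher predicted densities raise the total loss, nudging the model toward empty space.

```python
import torch
import torch.nn.functional as F

# Same loss as above: photometric MSE plus a mean-density sparsity term
def compute_loss(pred_color, target_color, pred_density, coef=0.01):
    color_loss = F.mse_loss(pred_color, target_color)
    density_reg = torch.mean(pred_density)
    return color_loss + coef * density_reg

pred = torch.zeros(4, 3)
target = torch.zeros(4, 3)
sparse = compute_loss(pred, target, torch.zeros(4, 1))        # zero density
dense = compute_loss(pred, target, torch.full((4, 1), 10.0))  # high density
print(sparse.item(), dense.item())  # 0.0 0.1
```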