SRGAN in Practice: Photo-Realistic Super-Resolution with Python and PyTorch
When you dig out a photo from ten years ago, have you ever been frustrated by blurry pixels and lost detail? Super-resolution reconstruction is quietly changing that. Among the many approaches, SRGAN stands out: its generative-adversarial architecture can recover strikingly convincing high-frequency detail from a low-resolution image. This article walks through a complete SRGAN implementation from scratch, covering not only the core code but also practical tuning tips and pitfalls encountered during real training runs.
1. Environment Setup and Data Preparation

Sharpen your tools before you start. We need a GPU-accelerated PyTorch environment; Python 3.8+ and PyTorch 1.10+ are recommended, as they offer better support for GAN training.
```bash
conda create -n srgan python=3.8
conda activate srgan
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python pillow matplotlib tqdm
```

The choice of dataset directly affects model quality. DIV2K is the standard benchmark for super-resolution, with 800 training images and 100 validation images covering a rich variety of scenes. In real applications you may also want to mix in your own domain data:
```python
from PIL import Image
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(96),              # random 96x96 HR patches
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# The low-resolution image is obtained by bicubic downsampling.
# Note: do NOT resize it back up -- the generator itself performs the 4x upsampling.
def get_lr_image(hr_img, scale=4):
    return hr_img.resize((hr_img.width // scale, hr_img.height // scale), Image.BICUBIC)
```

The data loader needs to be memory-efficient. For large datasets, load images on demand with a `Dataset` class:
```python
import torch

class SRDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, crop_size=96, scale=4):
        self.image_paths = image_paths
        self.scale = scale
        # PIL-level augmentations applied to the HR image only
        self.augment = transforms.Compose([
            transforms.RandomCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        hr_img = Image.open(self.image_paths[idx]).convert('RGB')
        # Augment the HR patch first, then derive the LR patch from it,
        # so the two stay spatially aligned (independent random crops would not).
        hr_img = self.augment(hr_img)
        lr_img = get_lr_image(hr_img, self.scale)
        return self.to_tensor(lr_img), self.to_tensor(hr_img)
```

2. Model Architecture
The heart of SRGAN is the adversarial pairing of generator and discriminator. The generator uses a deep residual structure, while the discriminator borrows its discriminative power from VGG-style convolutional stacks.

2.1 Generator Network (SRResNet)

The generator is built on ResNet ideas: a stack of residual blocks followed by sub-pixel convolution layers for upsampling:
```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return out + residual

class Generator(nn.Module):
    def __init__(self, scale_factor=4, num_residual=16):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.prelu = nn.PReLU()
        # Stack of residual blocks
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residual)])
        # Upsampling: each stage doubles resolution via sub-pixel convolution
        upsampling = []
        for _ in range(scale_factor // 2):
            upsampling += [
                nn.Conv2d(64, 256, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ]
        self.upsampling = nn.Sequential(*upsampling)
        self.conv2 = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        x = self.prelu(self.conv1(x))
        residual = x
        x = self.res_blocks(x)
        x = x + residual
        x = self.upsampling(x)
        x = self.conv2(x)
        return torch.tanh(x)
```

2.2 Discriminator Network
The discriminator follows the VGG-style design from the SRGAN paper: stacked strided convolutions progressively halve the spatial resolution, ending in a single real/fake score per image:
```python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # Repeated conv blocks: double the channels, halve the resolution
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        return self.net(x)
```

3. Loss Functions and Training Strategy
SRGAN's success owes much to its carefully designed perceptual loss, which combines a content loss with an adversarial loss to balance pixel-level accuracy against perceptual quality.
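To make that balance explicit, the objective used in this article's training step can be written as follows, where φ denotes the VGG19 conv5_4 feature map and D the discriminator (the weights 0.006 and 10⁻³ match the training code in this article):

```latex
L_{total} \;=\; \underbrace{\lVert I^{SR} - I^{HR} \rVert_1}_{\text{pixel loss}}
\;+\; 0.006\,\underbrace{\lVert \phi(I^{HR}) - \phi(I^{SR}) \rVert_2^{2}}_{\text{VGG perceptual loss}}
\;+\; 10^{-3}\,\underbrace{\bigl(-\log D(I^{SR})\bigr)}_{\text{adversarial loss}}
```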
3.1 Implementing the Perceptual Loss

A VGG feature extractor is used to compute the content loss:
```python
import torch
import torch.nn as nn
import torchvision

class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True)
        # Truncate at conv5_4 (the "VGG54" features from the SRGAN paper)
        self.features = nn.Sequential(*list(vgg.features.children())[:35])
        # The extractor is fixed; freeze its weights
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        # Map from the generator's [-1, 1] range to VGG's expected normalization
        x = (x + 1) / 2  # [-1, 1] -> [0, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
        return self.features((x - mean) / std)

def perceptual_loss(hr, sr, feature_extractor):
    return nn.functional.mse_loss(feature_extractor(sr), feature_extractor(hr))
```

3.2 Adversarial Loss and Optimizer Configuration
GAN training requires balancing how fast the generator and discriminator learn:
```python
# Initialize the models
generator = Generator().to(device)
discriminator = Discriminator().to(device)
feature_extractor = VGGFeatureExtractor().to(device).eval()

# Optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))

# Loss functions
adversarial_criterion = nn.BCEWithLogitsLoss()
pixel_criterion = nn.L1Loss()

def train_step(lr, hr):
    # --- Generator step ---
    sr = generator(lr)
    real_label = torch.ones(hr.size(0), 1, 1, 1).to(device)

    # Content loss: pixel-wise L1 plus VGG perceptual loss
    content_loss = pixel_criterion(sr, hr) + 0.006 * perceptual_loss(hr, sr, feature_extractor)
    # Adversarial loss: the generator wants the discriminator to say "real"
    g_loss = adversarial_criterion(discriminator(sr), real_label)
    total_loss = content_loss + 1e-3 * g_loss

    g_optimizer.zero_grad()
    total_loss.backward()
    g_optimizer.step()

    # --- Discriminator step (detach sr so no gradients reach the generator) ---
    d_loss_real = adversarial_criterion(discriminator(hr), real_label)
    d_loss_fake = adversarial_criterion(discriminator(sr.detach()), torch.zeros_like(real_label))
    d_loss = (d_loss_real + d_loss_fake) / 2

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    return total_loss.item(), d_loss.item()
```

4. Training Tricks and Quality Optimization
GAN training is notoriously unstable; the following techniques noticeably improve SRGAN's training stability:

4.1 Two-Stage Training Strategy

- Generator pre-training: train the generator alone with MSE loss for 20-30 epochs
- Joint training: then bring in the discriminator for adversarial training
```python
# Stage 1: generator pre-training with plain MSE
def pretrain_generator(generator, dataloader, epochs=20):
    optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        for lr, hr in dataloader:
            lr, hr = lr.to(device), hr.to(device)
            sr = generator(lr)
            loss = criterion(sr, hr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

4.2 Learning-Rate Scheduling and Gradient Clipping
```python
# Learning-rate schedulers: decay both learning rates by 10x every 1000 steps
# (call scheduler.step() once per training step)
g_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=1000, gamma=0.1)
d_scheduler = torch.optim.lr_scheduler.StepLR(d_optimizer, step_size=1000, gamma=0.1)

# Gradient clipping: call between loss.backward() and optimizer.step()
torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
```

4.3 Training Monitoring and Visualization
Monitoring training in real time helps you adjust strategy early:
```python
import os
import matplotlib.pyplot as plt

def save_sample(lr, sr, hr, epoch, path="samples"):
    os.makedirs(path, exist_ok=True)
    # Take the first sample of the batch and convert CHW -> HWC
    lr = lr[0].cpu().detach().numpy().transpose(1, 2, 0)
    sr = sr[0].cpu().detach().numpy().transpose(1, 2, 0)
    hr = hr[0].cpu().detach().numpy().transpose(1, 2, 0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # Undo the [-1, 1] normalization for display, clipping any overshoot
    axes[0].imshow(((lr + 1) / 2).clip(0, 1))
    axes[0].set_title("Low Resolution")
    axes[1].imshow(((sr + 1) / 2).clip(0, 1))
    axes[1].set_title("Super Resolution")
    axes[2].imshow(((hr + 1) / 2).clip(0, 1))
    axes[2].set_title("High Resolution")
    plt.savefig(f"{path}/epoch_{epoch}.png")
    plt.close()
```

5. Model Evaluation and Application
Once training finishes, evaluate the model thoroughly:

5.1 Quantitative Metrics
```python
import torch
from skimage.metrics import structural_similarity

def calculate_psnr(sr, hr, max_val=1.0):
    mse = torch.mean((sr - hr) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

def calculate_ssim(sr, hr):
    # Delegate to scikit-image rather than re-implementing the sliding-window math.
    # Inputs: (C, H, W) tensors in [0, 1].
    sr_np = sr.cpu().numpy().transpose(1, 2, 0)
    hr_np = hr.cpu().numpy().transpose(1, 2, 0)
    return structural_similarity(sr_np, hr_np, data_range=1.0, channel_axis=2)
```

5.2 A Practical Application Example
Apply the trained model to a real image:
```python
def enhance_image(image_path, generator, device):
    lr_img = Image.open(image_path).convert('RGB')
    # Match the training normalization: [0, 1] -> [-1, 1]
    lr_tensor = transforms.ToTensor()(lr_img).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        sr_tensor = generator(lr_tensor)
    # Map the tanh output back to [0, 1] before converting to PIL
    sr_tensor = (sr_tensor.squeeze(0).cpu() + 1) / 2
    return transforms.ToPILImage()(sr_tensor.clamp(0, 1))
```

5.3 Model Export and Deployment
```python
# Export as TorchScript
traced_generator = torch.jit.trace(generator, torch.rand(1, 3, 96, 96).to(device))
traced_generator.save("srgan_generator.pt")

# ONNX export
torch.onnx.export(generator, torch.rand(1, 3, 96, 96).to(device), "srgan.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
```

In real projects I have found that more residual blocks in the generator is not always better: beyond about 20 blocks, training tends to become unstable. Also, when using Adam, keeping beta2 at 0.999 rather than lowering it (for example to 0.99) gives a noticeably smoother training process. For super-resolving 4K images, process the image in tiles and then merge the results; this substantially reduces GPU memory consumption.
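The tiling advice above can be sketched as follows. This is a minimal illustration rather than a production implementation: `tiled_sr` is a hypothetical helper (not part of the article's code) that splits the low-resolution input into overlapping tiles, runs each tile through the generator, and pastes the upscaled tiles back so that later tiles overwrite the overlapping borders of earlier ones, which hides seam artifacts.

```python
import torch
import torch.nn as nn

def tiled_sr(generator, lr, tile=96, overlap=8, scale=4):
    """Super-resolve a (1, C, H, W) tensor tile by tile to limit GPU memory use."""
    _, c, h, w = lr.shape
    # Accumulate the result on the CPU so only one tile lives on the GPU at a time
    out = torch.zeros(1, c, h * scale, w * scale)
    step = tile - overlap
    for top in range(0, h, step):
        for left in range(0, w, step):
            # Clamp so tiles near the right/bottom edge stay inside the image
            t = min(top, max(h - tile, 0))
            l = min(left, max(w - tile, 0))
            patch = lr[:, :, t:t + tile, l:l + tile]
            with torch.no_grad():
                sr_patch = generator(patch)
            ph, pw = patch.shape[2], patch.shape[3]
            # Paste the whole upscaled patch; later tiles overwrite overlap regions
            out[:, :, t * scale:(t + ph) * scale, l * scale:(l + pw) * scale] = sr_patch.cpu()
    return out
```

Because the generator is fully convolutional, each tile can be processed independently; a larger `overlap` trades a little extra compute for fewer visible seams.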