news 2026/4/20 0:11:56

SRGAN实战:用Python+PyTorch实现照片级超分辨率重建(附代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
SRGAN实战:用Python+PyTorch实现照片级超分辨率重建(附代码)

SRGAN实战:用Python+PyTorch实现照片级超分辨率重建

当你翻出十年前的老照片,是否曾被模糊的像素和失真的细节所困扰?超分辨率重建技术正悄然改变这一现状。在众多解决方案中,SRGAN凭借其生成对抗网络的独特架构,能够从低分辨率图像中还原出令人惊艳的高频细节。本文将带你从零实现一个完整的SRGAN模型,不仅涵盖核心代码实现,更会分享实际训练中的调参技巧和避坑指南。

1. 环境配置与数据准备

工欲善其事,必先利其器。我们需要搭建一个支持GPU加速的PyTorch开发环境。推荐使用Python 3.8+和PyTorch 1.10+版本,它们对GAN训练提供了更好的支持。

conda create -n srgan python=3.8 conda activate srgan pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install opencv-python pillow matplotlib tqdm

数据集的选择直接影响模型效果。DIV2K是超分辨率任务的标准数据集,包含800张训练图像和100张验证图像,涵盖丰富场景。实际应用中,你可能还需要加入自己的业务数据:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomCrop(96), # 随机裁剪96x96 patches transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # 低分辨率图像通过双三次下采样获得 def get_lr_image(hr_img, scale=4): lr_img = hr_img.resize((hr_img.width//scale, hr_img.height//scale), Image.BICUBIC) return lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

数据加载器的实现需要考虑内存效率。对于大型数据集,建议使用Dataset类按需加载:

class SRDataset(torch.utils.data.Dataset): def __init__(self, image_paths, transform=None): self.image_paths = image_paths self.transform = transform def __getitem__(self, idx): hr_img = Image.open(self.image_paths[idx]).convert('RGB') lr_img = get_lr_image(hr_img) if self.transform: hr_img = self.transform(hr_img) lr_img = self.transform(lr_img) return lr_img, hr_img

2. 模型架构设计

SRGAN的核心在于生成器与判别器的对抗设计。生成器采用深度残差结构,而判别器则借鉴VGG网络的判别能力。

2.1 生成器网络(SRResNet)

生成器基于ResNet构建,包含多个残差块和亚像素卷积层:

import torch.nn as nn class ResidualBlock(nn.Module): def __init__(self, channels): super().__init__() self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.prelu = nn.PReLU() self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(channels) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.prelu(out) out = self.conv2(out) out = self.bn2(out) return out + residual class Generator(nn.Module): def __init__(self, scale_factor=4, num_residual=16): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4) self.prelu = nn.PReLU() # 残差块堆叠 self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residual)]) # 上采样部分 upsampling = [] for _ in range(scale_factor//2): upsampling += [ nn.Conv2d(64, 256, kernel_size=3, padding=1), nn.PixelShuffle(2), nn.PReLU() ] self.upsampling = nn.Sequential(*upsampling) self.conv2 = nn.Conv2d(64, 3, kernel_size=9, padding=4) def forward(self, x): x = self.prelu(self.conv1(x)) residual = x x = self.res_blocks(x) x = x + residual x = self.upsampling(x) x = self.conv2(x) return torch.tanh(x)

2.2 判别器网络

判别器采用PatchGAN结构,对图像的局部区域进行真伪判断:

class Discriminator(nn.Module): def __init__(self): super().__init__() self.net = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), # 重复堆叠卷积层 nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.AdaptiveAvgPool2d(1), nn.Conv2d(512, 1024, kernel_size=1), nn.LeakyReLU(0.2), nn.Conv2d(1024, 1, kernel_size=1) ) def forward(self, x): return self.net(x)

3. 损失函数与训练策略

SRGAN的成功很大程度上归功于其精心设计的感知损失函数。它结合了内容损失和对抗损失,在像素级准确性和感知质量之间取得平衡。

3.1 感知损失实现

VGG特征提取器用于计算内容损失:

class VGGFeatureExtractor(nn.Module): def __init__(self): super().__init__() vgg = torchvision.models.vgg19(pretrained=True) self.features = nn.Sequential(*list(vgg.features.children())[:35]) # 截取到conv5_4 def forward(self, x): # 输入图像需要归一化到VGG的训练范围 x = (x + 1) / 2 # [-1,1] -> [0,1] x = x.sub(torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1).to(x.device)) x = x.div(torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1).to(x.device)) return self.features(x) def perceptual_loss(hr, sr, feature_extractor): mse_loss = nn.MSELoss() hr_features = feature_extractor(hr) sr_features = feature_extractor(sr) return mse_loss(hr_features, sr_features)

3.2 对抗损失与优化器配置

GAN训练需要平衡生成器和判别器的学习进度:

# 初始化模型 generator = Generator().to(device) discriminator = Discriminator().to(device) feature_extractor = VGGFeatureExtractor().to(device).eval() # 优化器设置 g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999)) d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999)) # 损失函数 adversarial_criterion = nn.BCEWithLogitsLoss() pixel_criterion = nn.L1Loss() def train_step(lr, hr): # 生成器训练 sr = generator(lr) real_label = torch.ones(hr.size(0), 1, 1, 1).to(device) # 内容损失 content_loss = pixel_criterion(sr, hr) + 0.006 * perceptual_loss(hr, sr, feature_extractor) # 对抗损失 g_loss = adversarial_criterion(discriminator(sr), real_label) total_loss = content_loss + 1e-3 * g_loss g_optimizer.zero_grad() total_loss.backward() g_optimizer.step() # 判别器训练 d_loss_real = adversarial_criterion(discriminator(hr), real_label) d_loss_fake = adversarial_criterion(discriminator(sr.detach()), torch.zeros_like(real_label)) d_loss = (d_loss_real + d_loss_fake) / 2 d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() return total_loss.item(), d_loss.item()

4. 训练技巧与效果优化

GAN训练 notoriously unstable,以下技巧可显著提升SRGAN的训练稳定性:

4.1 两阶段训练策略

  1. 预训练生成器:仅使用MSE损失训练生成器20-30个epoch
  2. 联合训练:加入判别器进行对抗训练
# 生成器预训练 def pretrain_generator(generator, dataloader, epochs=20): optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4) criterion = nn.MSELoss() for epoch in range(epochs): for lr, hr in dataloader: lr, hr = lr.to(device), hr.to(device) sr = generator(lr) loss = criterion(sr, hr) optimizer.zero_grad() loss.backward() optimizer.step()

4.2 学习率调度与梯度裁剪

# 学习率调度器 g_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=1000, gamma=0.1) d_scheduler = torch.optim.lr_scheduler.StepLR(d_optimizer, step_size=1000, gamma=0.1) # 梯度裁剪 torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)

4.3 训练监控与可视化

实时监控训练过程有助于及时调整策略:

def save_sample(lr, sr, hr, epoch, path="samples"): os.makedirs(path, exist_ok=True) lr = lr[0].cpu().detach().numpy().transpose(1,2,0) sr = sr[0].cpu().detach().numpy().transpose(1,2,0) hr = hr[0].cpu().detach().numpy().transpose(1,2,0) fig, axes = plt.subplots(1, 3, figsize=(15,5)) axes[0].imshow((lr+1)/2) axes[0].set_title("Low Resolution") axes[1].imshow((sr+1)/2) axes[1].set_title("Super Resolution") axes[2].imshow((hr+1)/2) axes[2].set_title("High Resolution") plt.savefig(f"{path}/epoch_{epoch}.png") plt.close()

5. 模型评估与应用

训练完成后,我们需要全面评估模型性能:

5.1 定量指标评估

def calculate_psnr(sr, hr, max_val=1.0): mse = torch.mean((sr - hr) ** 2) return 10 * torch.log10(max_val**2 / mse) def calculate_ssim(sr, hr, window_size=11): # 实现SSIM计算 pass

5.2 实际应用示例

将训练好的模型应用于真实场景:

def enhance_image(image_path, generator, device): lr_img = Image.open(image_path).convert('RGB') lr_tensor = transforms.ToTensor()(lr_img).unsqueeze(0).to(device) with torch.no_grad(): sr_tensor = generator(lr_tensor) sr_img = transforms.ToPILImage()(sr_tensor.squeeze().cpu()) return sr_img

5.3 模型导出与部署

# 导出为TorchScript traced_generator = torch.jit.trace(generator, torch.rand(1,3,96,96).to(device)) traced_generator.save("srgan_generator.pt") # ONNX导出 torch.onnx.export(generator, torch.rand(1,3,96,96).to(device), "srgan.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})

在实际项目中,我发现生成器的残差块数量并非越多越好。当超过20个残差块时,模型容易出现训练不稳定的情况。此外,使用Adam优化器时,beta2参数设置为0.999比默认的0.99能带来更稳定的训练过程。对于4K图像的超分辨率处理,建议先对图像分块处理再合并,可以有效降低显存消耗。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/14 8:46:05

.NET对象转JSON,到底有几种方式?荡

背景 在软件开发的漫长旅途中,"构建"这个词往往让人又爱又恨。爱的是,一键点击,代码变成产品,那是程序员最迷人的时刻;恨的是,维护那一堆乱糟糟的构建脚本,简直是噩梦。 在很多项目中…

作者头像 李华
网站建设 2026/4/20 0:11:56

Python FastAPI 请求超时机制

Python FastAPI 请求超时机制解析 在构建高性能Web应用时,请求超时是开发者必须面对的关键问题之一。FastAPI作为现代Python异步框架,其超时机制不仅影响用户体验,还直接关系到系统稳定性。本文将深入探讨FastAPI的请求超时设计,…

作者头像 李华
网站建设 2026/4/14 8:44:47

磁珠与电感的本质区别

磁珠与电感的基本概念磁珠(Ferrite Bead)是一种由铁氧体材料制成的被动元件,主要用于高频噪声抑制,通过将噪声能量转化为热能消耗掉。 电感(Inductor)是储能元件,利用电磁感应原理存储和释放能量…

作者头像 李华
网站建设 2026/4/19 6:10:45

如何用罗技鼠标宏实现绝地求生压枪:5分钟快速配置指南

如何用罗技鼠标宏实现绝地求生压枪:5分钟快速配置指南 【免费下载链接】logitech-pubg PUBG no recoil script for Logitech gaming mouse / 绝地求生 罗技 鼠标宏 项目地址: https://gitcode.com/gh_mirrors/lo/logitech-pubg 想要在《绝地求生》中实现专业…

作者头像 李华
网站建设 2026/4/19 6:06:37

如何解决微信网页版无法访问:wechat-need-web浏览器插件解决方案

如何解决微信网页版无法访问:wechat-need-web浏览器插件解决方案 【免费下载链接】wechat-need-web 让微信网页版可用 / Allow the use of WeChat via webpage access 项目地址: https://gitcode.com/gh_mirrors/we/wechat-need-web 你是否曾经在办公电脑上急…

作者头像 李华