news 2026/4/15 14:55:39

别再死记公式了!用Python+PyTorch亲手画图理解卷积的‘放大’与‘缩小’

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记公式了!用Python+PyTorch亲手画图理解卷积的‘放大’与‘缩小’

别再死记公式了!用Python+PyTorch亲手画图理解卷积的‘放大’与‘缩小’

卷积神经网络(CNN)中的"下采样"和"上采样"概念常常让初学者感到困惑。与其死记硬背公式,不如通过代码和可视化来直观理解这些操作的本质。本文将带你用Python和PyTorch亲手实现这些操作,并通过动态可视化来观察特征图的变化过程。

1. 准备工作:搭建可视化环境

在开始之前,我们需要准备一个能够实时显示卷积操作效果的环境。推荐使用Jupyter Notebook配合Matplotlib进行交互式可视化。

首先安装必要的库:

!pip install torch torchvision matplotlib numpy

然后导入所需的模块:

import torch import torch.nn as nn import torch.nn.functional as F import matplotlib.pyplot as plt import numpy as np from matplotlib.animation import FuncAnimation

为了更直观地观察卷积过程,我们可以创建一个简单的可视化函数:

def visualize_convolution(input_tensor, kernel, output_tensor, title=""): fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5)) ax1.imshow(input_tensor.squeeze(), cmap='gray') ax1.set_title('Input') ax1.axis('off') ax2.imshow(kernel.squeeze(), cmap='gray') ax2.set_title('Kernel') ax2.axis('off') ax3.imshow(output_tensor.squeeze(), cmap='gray') ax3.set_title('Output') ax3.axis('off') plt.suptitle(title) plt.show()

2. 下采样:特征图的缩小过程

下采样是CNN中常见的操作,它通过卷积和池化等方式减小特征图的尺寸。让我们通过代码来观察这一过程。

2.1 标准卷积的下采样效果

首先创建一个简单的4×4输入矩阵:

input_data = torch.tensor([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16] ], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # shape: [1, 1, 4, 4]

定义一个3×3的卷积核,步长(stride)为1:

conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False) # 手动设置卷积核权重 with torch.no_grad(): conv.weight.data = torch.ones_like(conv.weight.data)

执行卷积操作并可视化:

output = conv(input_data) visualize_convolution(input_data, conv.weight.data, output, "Standard Convolution (stride=1)")

观察输出结果,你会发现特征图从4×4缩小到了2×2。这是因为在没有填充(padding=0)的情况下,3×3的卷积核在4×4的输入上只能滑动2×2次。

2.2 增大步长的下采样效果

现在让我们增大步长(stride)来观察更明显的下采样效果:

conv_stride2 = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=0, bias=False) with torch.no_grad(): conv_stride2.weight.data = torch.ones_like(conv_stride2.weight.data) output_stride2 = conv_stride2(input_data) visualize_convolution(input_data, conv_stride2.weight.data, output_stride2, "Convolution with stride=2")

这次输出变成了1×1的特征图。通过调整步长,我们可以控制下采样的程度。

3. 上采样:特征图的放大过程

上采样是下采样的逆过程,常用于图像分割等需要输出与输入尺寸相同的任务中。PyTorch提供了几种上采样方法,我们重点看看转置卷积。

3.1 转置卷积的基本原理

转置卷积(Transposed Convolution)常被误称为"反卷积",它实际上是一种特殊的正向卷积操作,能够实现上采样。

创建一个2×2的输入:

small_input = torch.tensor([ [1, 2], [3, 4] ], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # shape: [1, 1, 2, 2]

定义转置卷积层:

trans_conv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False) with torch.no_grad(): trans_conv.weight.data = torch.ones_like(trans_conv.weight.data)

执行转置卷积并可视化:

output_trans = trans_conv(small_input) visualize_convolution(small_input, trans_conv.weight.data, output_trans, "Transposed Convolution (stride=1)")

你会看到2×2的输入被放大到了4×4。这是因为转置卷积在输入元素之间插入了零值,然后进行常规卷积操作。

3.2 转置卷积的步长效应

增大转置卷积的步长可以进一步放大特征图:

trans_conv_stride2 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=0, bias=False) with torch.no_grad(): trans_conv_stride2.weight.data = torch.ones_like(trans_conv_stride2.weight.data) output_trans_stride2 = trans_conv_stride2(small_input) visualize_convolution(small_input, trans_conv_stride2.weight.data, output_trans_stride2, "Transposed Convolution (stride=2)")

这次2×2的输入被放大到了5×5。理解转置卷积的关键在于认识到它实际上是在输入元素之间插入(stride-1)个零值,然后进行常规卷积。

4. 空洞卷积:扩大感受野而不增加参数

空洞卷积(Dilated Convolution)通过在卷积核元素之间插入空洞来扩大感受野,同时不增加参数数量。

4.1 基本空洞卷积实现

创建一个7×7的输入:

large_input = torch.tensor([ [1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14], [15, 16, 17, 18, 19, 20, 21], [22, 23, 24, 25, 26, 27, 28], [29, 30, 31, 32, 33, 34, 35], [36, 37, 38, 39, 40, 41, 42], [43, 44, 45, 46, 47, 48, 49] ], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

定义空洞卷积层(空洞率=2):

dilated_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) with torch.no_grad(): dilated_conv.weight.data = torch.ones_like(dilated_conv.weight.data)

执行空洞卷积并可视化:

output_dilated = dilated_conv(large_input) visualize_convolution(large_input, dilated_conv.weight.data, output_dilated, "Dilated Convolution (rate=2)")

虽然使用了3×3的卷积核,但由于空洞率为2,实际感受野相当于5×5。观察输出结果,你会发现中心像素受到了更广泛区域的影响。

4.2 空洞卷积与标准卷积的对比

为了更清楚地看到空洞卷积的效果,我们可以创建一个动画来对比标准卷积和空洞卷积:

def create_comparison_animation(): fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) # 标准卷积 standard_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False) with torch.no_grad(): standard_conv.weight.data = torch.ones_like(standard_conv.weight.data) # 空洞卷积 dilated_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) with torch.no_grad(): dilated_conv.weight.data = torch.ones_like(dilated_conv.weight.data) def update(i): ax1.clear() ax2.clear() # 在输入上标记当前卷积核位置 marked_input = large_input.clone() h, w = marked_input.shape[-2:] # 标准卷积的覆盖区域 center_h, center_w = i//w, i%w for dh in [-1, 0, 1]: for dw in [-1, 0, 1]: nh, nw = center_h + dh, center_w + dw if 0 <= nh < h and 0 <= nw < w: marked_input[0, 0, nh, nw] = 10 # 高亮显示 ax1.imshow(marked_input.squeeze(), cmap='gray') ax1.set_title('Standard Conv Coverage') # 空洞卷积的覆盖区域 marked_input_dilated = large_input.clone() for dh in [-2, 0, 2]: for dw in [-2, 0, 2]: nh, nw = center_h + dh, center_w + dw if 0 <= nh < h and 0 <= nw < w: marked_input_dilated[0, 0, nh, nw] = 10 # 高亮显示 ax2.imshow(marked_input_dilated.squeeze(), cmap='gray') ax2.set_title('Dilated Conv Coverage (rate=2)') plt.suptitle(f'Position {i}: ({center_h}, {center_w})') anim = FuncAnimation(fig, update, frames=49, interval=200) plt.close() return anim # 显示动画 create_comparison_animation()

这个动画清晰地展示了标准卷积和空洞卷积在感受野上的差异。虽然两者都使用3×3的卷积核,但空洞卷积能够覆盖更大的区域。

5. 综合应用:构建简单的上采样-下采样网络

现在让我们把这些概念结合起来,构建一个简单的网络,先下采样再上采样一张图片:

class SimpleUpDownNet(nn.Module): def __init__(self): super().__init__() self.down1 = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1) self.down2 = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1) self.up1 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1, output_padding=1) self.up2 = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1, output_padding=1) def forward(self, x): x = F.relu(self.down1(x)) x = F.relu(self.down2(x)) x = F.relu(self.up1(x)) x = self.up2(x) return x # 加载测试图像 from skimage.data import camera test_image = torch.from_numpy(camera()).float().unsqueeze(0).unsqueeze(0) / 255.0 # 创建并运行网络 net = SimpleUpDownNet() with torch.no_grad(): output_image = net(test_image) # 可视化结果 plt.figure(figsize=(10, 5)) plt.subplot(1, 2, 1) plt.imshow(test_image.squeeze(), cmap='gray') plt.title('Original Image') plt.axis('off') plt.subplot(1, 2, 2) plt.imshow(output_image.squeeze(), cmap='gray') plt.title('After Down-Up Sampling') plt.axis('off') plt.show()

通过这个简单的例子,你可以看到下采样和上采样操作对图像的影响。虽然最终图像尺寸恢复了,但一些细节信息在过程中丢失了。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/15 14:50:04

为什么92%的多模态推理服务在峰值期崩溃?——基于QPS/显存/时延三维指标的负载均衡重构指南

第一章&#xff1a;为什么92%的多模态推理服务在峰值期崩溃&#xff1f;——基于QPS/显存/时延三维指标的负载均衡重构指南 2026奇点智能技术大会(https://ml-summit.org) 多模态推理服务在真实业务场景中并非线性扩展&#xff1a;图像编码、文本解码、跨模态对齐三阶段存在显…

作者头像 李华
网站建设 2026/4/15 14:47:06

Xtreme Download Manager:5倍下载加速与视频捕获完全指南

Xtreme Download Manager&#xff1a;5倍下载加速与视频捕获完全指南 【免费下载链接】xdm Powerfull download accelerator and video downloader 项目地址: https://gitcode.com/gh_mirrors/xd/xdm 你是否厌倦了缓慢的下载速度&#xff1f;是否曾因为网络中断而不得不…

作者头像 李华
网站建设 2026/4/15 14:44:19

Cursor Free VIP:一键解锁AI编程助手Pro功能的终极解决方案

Cursor Free VIP&#xff1a;一键解锁AI编程助手Pro功能的终极解决方案 【免费下载链接】cursor-free-vip [Support 0.45]&#xff08;Multi Language 多语言&#xff09;自动注册 Cursor Ai &#xff0c;自动重置机器ID &#xff0c; 免费升级使用Pro 功能: Youve reached you…

作者头像 李华