news 2026/4/17 18:05:25

Day49 - CBAM注意力机制

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day49 - CBAM注意力机制

1. 简介

CBAM (Convolutional Block Attention Module) 是一种轻量级的注意力模块,它可以无缝集成到任何CNN架构中,通过引入额外的开销来显著提升模型的性能。

与SE (Squeeze-and-Excitation) 模块主要关注通道注意力不同,CBAM 同时结合了通道注意力 (Channel Attention)空间注意力 (Spatial Attention)

这种串联的注意力机制使得网络能够依次学习"关注什么" (What to focus on) 和 "关注哪里" (Where to focus on)。

2. 核心原理

CBAM 包含两个子模块,通常采用串联方式连接(先通道后空间):

2.1 通道注意力模块 (Channel Attention Module, CAM)

通道注意力旨在探索通道之间的依赖关系。CBAM 的通道注意力改进了 SE 模块:

  • 不仅使用全局平均池化 (Average Pooling),还引入了全局最大池化 (Max Pooling)。
  • 认为最大池化能收集到更独特的对象特征,与平均池化互补。
  • 两个池化后的特征向量共享同一个多层感知机 (MLP) 网络。
  • 最终将两个输出相加并通过 Sigmoid 激活函数生成通道权重。

2.2 空间注意力模块 (Spatial Attention Module, SAM)

空间注意力旨在探索特征图在空间维度上的重要性(即哪些区域更重要)。

  • 在通道维度上进行平均池化和最大池化,得到两个 2D 特征图。
  • 将这两个特征图在通道维度拼接 (Concat)。
  • 通过一个 7x7 的卷积层进行特征融合。
  • 通过 Sigmoid 激活函数生成空间权重图。

3. 代码实现

以下是基于 PyTorch 的 CBAM 完整实现,包括通道注意力、空间注意力及其在 CNN 中的集成。

3.1 通道注意力 (ChannelAttention)

import torch import torch.nn as nn class ChannelAttention(nn.Module): def __init__(self, in_channels, ratio=16): super().__init__() # 平均池化和最大池化 self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) # 共享的全连接层 (MLP) # 使用1x1卷积代替全连接层,减少参数量并保持输入形状 self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // ratio, bias=False), nn.ReLU(), nn.Linear(in_channels // ratio, in_channels, bias=False) ) self.sigmoid = nn.Sigmoid() def forward(self, x): b, c, h, w = x.shape # 平均池化分支 avg_out = self.fc(self.avg_pool(x).view(b, c)) # 最大池化分支 max_out = self.fc(self.max_pool(x).view(b, c)) # 结果相加后经过Sigmoid attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1) # 权重作用于原特征图 return x * attention

3.2 空间注意力 (SpatialAttention)

class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super().__init__() # padding计算保证输出大小不变 padding = kernel_size // 2 # 输入通道为2 (AvgPool 1 + MaxPool 1) self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): # 在通道维度上求平均 (b, 1, h, w) avg_out = torch.mean(x, dim=1, keepdim=True) # 在通道维度上求最大 (b, 1, h, w) max_out, _ = torch.max(x, dim=1, keepdim=True) # 拼接 (b, 2, h, w) pool_out = torch.cat([avg_out, max_out], dim=1) # 卷积 + Sigmoid attention = self.conv(pool_out) return x * self.sigmoid(attention)

3.3 CBAM 模块组合

class CBAM(nn.Module): def __init__(self, in_channels, ratio=16, kernel_size=7): super().__init__() self.channel_attention = ChannelAttention(in_channels, ratio) self.spatial_attention = SpatialAttention(kernel_size) def forward(self, x): # 串联结构:先通道后空间 x = self.channel_attention(x) x = self.spatial_attention(x) return x

3.4 集成到 CNN 模型

在经典的卷积神经网络中,CBAM 模块通常被放置在卷积层和激活函数之后,或者池化层之前。以下是一个简单的 CBAM-CNN 示例:

class CBAM_CNN(nn.Module): def __init__(self): super().__init__() # 第一层卷积块 self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(kernel_size=2) self.cbam1 = CBAM(in_channels=32) # 集成CBAM # ... 后续层省略 ... # 假设这里还有更多层 # 全连接层 self.fc1 = nn.Linear(128 * 4 * 4, 512) self.dropout = nn.Dropout(p=0.5) self.fc2 = nn.Linear(512, 10) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.pool1(x) x = self.cbam1(x) # 应用注意力机制 # ... 后续前向传播 ... x = x.view(-1, 128 * 4 * 4) x = self.fc1(x) x = self.relu1(x) # 注意这里应该是对应的激活函数 x = self.dropout(x) x = self.fc2(x) return x

4. 训练与实验

在 CIFAR-10 数据集上的训练过程显示,引入 CBAM 后,模型能够更有效地聚焦于图像的关键特征。

  • 优化器: 使用 Adam 优化器,自适应调整学习率。
  • 学习率调度: 使用ReduceLROnPlateau,当验证集损失不再下降时自动降低学习率,有助于模型收敛到更优解。
  • 性能: 在约 50 个 Epoch 的训练中,模型能够达到较高的准确率 (如 86% 左右),证明了注意力机制对特征提取能力的增强作用。

5. 总结

CBAM 通过结合通道注意力和空间注意力,提供了一种即插即用的性能提升方案。

  • 轻量级: 参数量和计算量增加很少。
  • 通用性: 适用于各种 CNN 架构 (ResNet, MobileNet 等)。
  • 互补性: MaxPool 和 AvgPool 的结合保留了更丰富的特征信息。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 8:44:51

Cogito v2预览版:109B混合推理AI模型来了

Cogito v2预览版:109B混合推理AI模型来了 【免费下载链接】cogito-v2-preview-llama-109B-MoE 项目地址: https://ai.gitcode.com/hf_mirrors/unsloth/cogito-v2-preview-llama-109B-MoE 大语言模型领域再添重量级选手——Cogito v2预览版(cogit…

作者头像 李华
网站建设 2026/4/16 23:33:13

AMD SMU调试工具:深度掌控Ryzen处理器性能的终极方案

AMD SMU调试工具:深度掌控Ryzen处理器性能的终极方案 【免费下载链接】SMUDebugTool A dedicated tool to help write/read various parameters of Ryzen-based systems, such as manual overclock, SMU, PCI, CPUID, MSR and Power Table. 项目地址: https://git…

作者头像 李华
网站建设 2026/4/17 3:38:39

CoreCycler终极指南:轻松搞定CPU稳定性测试的完整教程

CoreCycler终极指南:轻松搞定CPU稳定性测试的完整教程 【免费下载链接】corecycler Stability test script for PBO & Curve Optimizer stability testing on AMD Ryzen processors 项目地址: https://gitcode.com/gh_mirrors/co/corecycler 还在为CPU超…

作者头像 李华
网站建设 2026/4/16 17:06:48

GPT-SoVITS推理速度优化:实时合成可行吗?

GPT-SoVITS推理速度优化:实时合成可行吗? 在虚拟主播直播间里,观众刚打出一句提问,几秒后才听到“数字人”慢半拍地回应——这种延迟虽然不至于中断体验,却足以打破沉浸感。类似场景也出现在智能客服、游戏NPC对话甚至…

作者头像 李华
网站建设 2026/3/31 1:56:16

如何快速掌握NBT编辑器:从入门到精通的完整指南

如何快速掌握NBT编辑器:从入门到精通的完整指南 【免费下载链接】NBTExplorer A graphical NBT editor for all Minecraft NBT data sources 项目地址: https://gitcode.com/gh_mirrors/nb/NBTExplorer NBT编辑器是一款功能强大的图形化NBT数据编辑工具&…

作者头像 李华
网站建设 2026/4/17 4:33:13

小红书链接解析实战:从失败到成功的完整心路历程

小红书链接解析实战:从失败到成功的完整心路历程 【免费下载链接】XHS-Downloader 免费;轻量;开源,基于 AIOHTTP 模块实现的小红书图文/视频作品采集工具 项目地址: https://gitcode.com/gh_mirrors/xh/XHS-Downloader 作为…

作者头像 李华