news 2026/4/15 1:36:52

ViT的demo实现与解读

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
ViT的demo实现与解读

首先可以看看ViT的流程视频:

15分钟认识ViT!【视觉Transformer】_哔哩哔哩_bilibili

输入大小为:

torch.Size([4, 3, 224, 224])

也就是batch_size=4,三个通道,224*224大小的图片

具体的forward过程函数如下:

patch_embed部分就是将一个图片按照16*16的大小进行分割:

输入前和输入后的x的大小变化:

前面的4代表batch_size。

一个patch的大小是3*16*16。 196=224*224/(16*16)=14*14。

也就是一张224*224的图片被分割成了196个14*14的图片patch,这个patch可以看作一个单词。

768=3*16*16。也就是将一个三通道的图片patch,延展成一个一维的向量。

然后是增加一个CLS token:

x的变化为:

也就是增加一个特殊的token

添加位置编码x的大小不变:

类似于transformer的位置编码,不过这里的位置编码是一个可以学习的矩阵:

之后就是正常的transformer结构:

完整的模型结构如下:

模型结构: VisionTransformer( (patch_embed): PatchEmbedding( (projection): Sequential( (0): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) (1): Rearrange('b e h w -> b (h w) e') ) ) (pos_dropout): Dropout(p=0.1, inplace=False) (blocks): ModuleList( (0-11): 12 x TransformerBlock( (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): MultiHeadAttention( (qkv): Linear(in_features=768, out_features=2304, bias=True) (proj): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): MLP( (fc1): Linear(in_features=768, out_features=3072, bias=True) (act): GELU(approximate='none') (fc2): Linear(in_features=3072, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (head): Linear(in_features=768, out_features=1000, bias=True) )

完整的demo代码如下:

""" Vision Transformer (ViT) 完整实现 用于图像分类任务 """ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange class PatchEmbedding(nn.Module): """ 将图像分割成patches并进行嵌入 """ def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 # 使用卷积层将图像分割成patches并投影到embed_dim维度 self.projection = nn.Sequential( nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size), Rearrange('b e h w -> b (h w) e'), # 重排维度 ) def forward(self, x): """ x: (batch_size, channels, height, width) return: (batch_size, n_patches, embed_dim) """ x = self.projection(x) return x class MultiHeadAttention(nn.Module): """ 多头自注意力机制 """ def __init__(self, embed_dim=768, num_heads=12, dropout=0.0): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scale = self.head_dim ** -0.5 assert embed_dim % num_heads == 0, "embed_dim必须能被num_heads整除" # Q, K, V的线性变换 self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) self.proj = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): """ x: (batch_size, seq_len, embed_dim) """ batch_size, seq_len, embed_dim = x.shape # 生成Q, K, V qkv = self.qkv(x) # (batch_size, seq_len, embed_dim * 3) qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, seq_len, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] # 计算注意力分数 attn = (q @ k.transpose(-2, -1)) * self.scale # (batch_size, num_heads, seq_len, seq_len) attn = attn.softmax(dim=-1) attn = self.dropout(attn) # 加权求和 out = attn @ v # (batch_size, num_heads, seq_len, head_dim) out = out.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim) out = out.reshape(batch_size, seq_len, embed_dim) # 输出投影 out = self.proj(out) out = self.dropout(out) return out class MLP(nn.Module): """ 前馈神经网络 """ def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.0): super().__init__() hidden_dim = int(embed_dim * mlp_ratio) self.fc1 = nn.Linear(embed_dim, hidden_dim) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.dropout(x) x = self.fc2(x) x = self.dropout(x) return x class TransformerBlock(nn.Module): """ Transformer编码器块 """ def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = MultiHeadAttention(embed_dim, num_heads, dropout) self.norm2 = nn.LayerNorm(embed_dim) self.mlp = MLP(embed_dim, mlp_ratio, dropout) def forward(self, x): # 注意力块 + 残差连接 x = x + self.attn(self.norm1(x)) # MLP块 + 残差连接 x = x + self.mlp(self.norm2(x)) return x class VisionTransformer(nn.Module): """ 完整的Vision Transformer模型 """ def __init__( self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0, emb_dropout=0.0, ): super().__init__() # Patch嵌入 self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches = self.patch_embed.n_patches # CLS token (可学习参数) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 位置编码 (可学习参数) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_dropout = nn.Dropout(emb_dropout) # Transformer编码器 self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) # 归一化层 self.norm = nn.LayerNorm(embed_dim) # 分类头 self.head = nn.Linear(embed_dim, num_classes) # 初始化权重 self._init_weights() def _init_weights(self): """初始化模型权重""" nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.head.weight, std=0.02) nn.init.constant_(self.head.bias, 0) def forward(self, x): """ x: (batch_size, channels, height, width) return: (batch_size, num_classes) """ batch_size = x.shape[0] # Patch嵌入 x = self.patch_embed(x) # (batch_size, n_patches, embed_dim) # 添加CLS token cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (batch_size, 1, embed_dim) x = torch.cat([cls_tokens, x], dim=1) # (batch_size, n_patches + 1, embed_dim) # 添加位置编码 x = x + self.pos_embed x = self.pos_dropout(x) # 通过Transformer编码器 for block in self.blocks: x = block(x) # 归一化 x = self.norm(x) # 使用CLS token进行分类 cls_token_final = x[:, 0] # (batch_size, embed_dim) logits = self.head(cls_token_final) # (batch_size, num_classes) return logits def create_vit_base(): """创建ViT-Base模型""" return VisionTransformer( img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1, emb_dropout=0.1, ) def create_vit_small(): """创建ViT-Small模型""" return VisionTransformer( img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0, dropout=0.1, emb_dropout=0.1, ) # 测试代码 if __name__ == "__main__": # 创建模型 model = create_vit_base() print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M") # 创建随机输入 batch_size = 4 x = torch.randn(batch_size, 3, 224, 224) # 前向传播 with torch.no_grad(): output = model(x) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"输出示例: {output[0, :5]}") # 打印模型结构 print("\n模型结构:") print(model)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/3 15:43:06

3.1IT治理

1、IT治理的驱动因素:解决信息孤岛 2、IT治理主要目标包括:与业务目标一致、有效利用信息与数据资源、风险管理。 3、管理层次分为三层:最高管理层、执行管理层、业务与服务执行层。 4、IT治理体系的具体构成包括:IT定位、IT治理架…

作者头像 李华
网站建设 2026/4/6 7:16:37

中小企业的营销“暖心伙伴”——北京易美之尚,让增长不再难

“深夜改完的营销方案,投出去却石沉大海;花大价钱引的流量,转头就成了‘一次性过客’”——这大概是很多中小企业主的日常焦虑。在互联网营销的浪潮里,不是不想冲,而是怕方向错;不是没投入,而是…

作者头像 李华
网站建设 2026/4/10 21:45:46

Excalidraw链接功能全解析:超链接与跳转处理

Excalidraw链接功能全解析:超链接与跳转处理 在远程协作日益频繁的今天,一张图是否“能点”,往往决定了它是装饰还是生产力工具。许多团队还在用静态截图传递信息时,另一些人已经通过 Excalidraw 构建起可交互的知识网络——点击一…

作者头像 李华
网站建设 2026/4/14 13:44:52

LobeChat能否实现AI香道师?气味搭配与情绪调节芳香疗法推荐

LobeChat能否实现AI香道师?气味搭配与情绪调节芳香疗法推荐 在快节奏的都市生活中,越来越多的人开始寻求非药物方式来缓解压力、调节情绪。冥想、音乐疗愈、自然接触……而其中,“香气”作为一种古老却始终鲜活的感官媒介,正悄然回…

作者头像 李华
网站建设 2026/4/14 17:13:20

HunyuanVideo-Foley:高保真拟音生成扩散模型

HunyuanVideo-Foley:高保真拟音生成扩散模型 你有没有遇到过这样的情况:一段精心制作的AI生成视频,画面流畅、细节丰富,可一旦播放,却像“默片”一样缺乏声音支撑?再逼真的奔跑镜头配上静音,观…

作者头像 李华