news 2026/1/5 11:08:38

SAM2跟踪的理解7——mask decoder

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
SAM2跟踪的理解7——mask decoder

目录

一、前言

四、MaskDecoder.forward

4.1 MaskDecoder.predict_masks

4.1.2 TwoWayTransformer.forward

4.1.2.1 TwoWayAttentionBlock.forward

4.1.2.2 self.self_attn——Attention.forward

线性映射前后维度是不变的,那它里面做了什么?有什么作用?

你的意思是,本来q,k,v都是相同的,线性映射之后就不同了是吗,那如何理解线性映射呢?后面什么要对q、k、v拆头,为什么是拆8个头?F.scaled_dot_product_attention做了什么?out_proj做了什么?

这里很多函数进不去,标记一下

如何理解自注意力机制呢?我自己还不了解我自己吗,还要用这种方式?

所以说其实是如果我是token的话,我的身份是受整个向量的其他token影响的,每个token都去询问一遍其他所有token以确定自己的身份,是这样理解吗

这里感觉没有理清楚,标记一下

4.1.2.3 self.cross_attn_token_to_image——Attention.forward

为什么out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 之后out: torch.Size([1, 8, 9, 16])

如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

代入SAM2分割这一情景,如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

4.1.2.4 自注意力和交叉注意力有什么区别?为什么要先自注意力?

4.1.2.5 为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?

4.1.2.6 MLP.forward


一、前言

下面是第一帧情况下的函数调用顺序。

2.12 <重点> add_new_prompt

2.13 <重点> _run_single_frame_inference

2.14 <重点> track_step

2.15 <重点> _prepare_memory_conditioned_features

2.16 _use_multimask

2.17 <重点> _forward_sam_heads

2.18 提示编码器:类PromptEncoder.forward

2.19 类PositionEmbeddingRandom.forward_with_coords

2.20 类PromptEncoder.get_dense_pe

2.21 掩码解码器 类MaskDecoder.forward

2.22 类MaskDecoder.predict_masks

2.23 TwoWayTransformer.forward(这篇开头在这)

2.24 TwoWayAttentionBlock.forward

2.25 Attention.forward(这篇结束在这)

2.26 <重点> MLP.forward

2.27 Attention.forward

2.28 LayerNorm2d.forward

2.29 MaskDecoder._dynamic_multimask_via_stability

2.30 MaskDecoder._get_stability_scores

2.31 fill_holes_in_mask_scores

2.32 _get_maskmem_pos_enc

2.33 _consolidate_temp_output_across_obj

2.34 _get_orig_video_res_output

四、MaskDecoder.forward

4.1 MaskDecoder.predict_masks

4.1.2 TwoWayTransformer.forward

sam2/modeling/sam/transformer.py

hs, src = self.transformer(src, pos_src, tokens)

上面这句进去就是调用TwoWayTransformer的forward函数。

class TwoWayTransformer(nn.Module): """ 双向 Transformer: 1. 先让「稀疏点 token」(queries) 与「稠密图像 token」(keys) 做若干层双向 cross-attention; 2. 最后再让 queries 单独对图像做一次 attention,得到增强后的 queries 作为最终输出。 图像 token 只做中间传递,最终原样返回。 """ def __init__( self, depth: int, # 双向 attention block 重复次数 embedding_dim: int, # 通道维度 C num_heads: int, # 多头注意力的头数 mlp_dim: int, # FFN 中间层维度 activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, # Attention 内部 Q/K 下采样比例 ) -> None: super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() # 堆叠 depth 个双向 attention block for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), # 第一层无需给 query 加 PE(已在输入时加好) ) ) # 最后一层:queries → 图像的 attention self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, # [B, C, H, W] image_pe: Tensor, # [B, C, H, W] 图像位置编码 point_embedding: Tensor, # [B, Np, C] 点提示的 embedding(已含 PE) ) -> Tuple[Tensor, Tensor]: """ Returns: processed point embedding [B, Np, C] processed image embedding [B, H*W, C] (与输入内容相同,仅 reshape) """ # image_embedding: torch.Size([1, 256, 64, 64]) # image_pe: torch.Size([1, 256, 64, 64]) # point_embedding: torch.Size([1, 9, 256]) # 1. 把图像展平成 token 序列 bs, c, h, w = image_embedding.shape # bs:1 c:256 h:64 w:64 image_embedding = image_embedding.flatten(2).permute(0, 2, 1) # [B, H*W, C] # image_embedding: torch.Size([1, 4096, 256]) image_pe = image_pe.flatten(2).permute(0, 2, 1) # [B, H*W, C] # image_pe: torch.Size([1, 4096, 256]) queries = point_embedding # [B, Np, C] # queries: torch.Size([1, 9, 256]) keys = image_embedding # [B, H*W, C] # keys: torch.Size([1, 4096, 256]) # 2. 逐层双向 attention 更新 queries 和 keys for layer in self.layers: # 进入TwoWayAttentionBlock.forward queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, # 每次把原始 PE 作为 Q 的偏置传进去 key_pe=image_pe, ) # 3. 最后一层:queries 再对图像做一次 attention q = queries + point_embedding # 残差加回原始 PE k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) # [B, Np, C] queries = queries + attn_out # 残差连接 queries = self.norm_final_attn(queries) # LayerNorm # 4. 返回增强后的 queries 和原图 token(下游只拿 queries 用即可) return queries, keys

整体流程一句话总结
“稀疏点 token” 先和“稠密图像 token”在多层的双向 cross-attention 里互相更新;
最后再把更新后的点 token 单独对图像做一次 attention 并残差+Norm,得到最终点特征。
图像 token 只充当信息搬运工,原样返回即可。

4.1.2.1 TwoWayAttentionBlock.forward

sam2/modeling/sam/transformer.py

for layer in self.layers:

# 进入TwoWayAttentionBlock.forward
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding, # 每次把原始 PE 作为 Q 的偏置传进去
key_pe=image_pe,
)

class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ 一个 Transformer 块,内部 4 步: 1) sparse queries 自注意力 2) queries cross-attend 到 dense keys(token→image) 3) 对 queries 做 MLP 4) dense keys cross-attend 到 sparse queries(image→token) 通过双向交叉,实现“稀疏点”与“稠密图”信息互通。 """ super().__init__() # 1. 自注意力 self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) # 2. token→image 交叉注意力 # 又进入TwoWayAttentionBlock.forward # attention_downsample_rate:2 self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) # 3. MLP self.mlp = MLP( embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation ) self.norm3 = nn.LayerNorm(embedding_dim) # 4. image→token 交叉注意力 self.norm4 = nn.LayerNorm(embedding_dim) # attention_downsample_rate:2 self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe # 首块是否给 Q 加 PE def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # 输入形状示例: # queries: torch.Size([1, 9, 256]) 稀疏点 token # keys: torch.Size([1, 4096, 256]) 稠密图像 token # query_pe:torch.Size([1, 9, 256]) 稀疏点token的绝对位置编码 # key_pe:torch.Size([1, 4096, 256]) 稠密图像token的绝对位置编码 # ---------- 1. 自注意力 ---------- # self.skip_first_layer_pe: True if self.skip_first_layer_pe: # 首层不加 PE,直接 self-attn # queries: torch.Size([1, 9, 256]) queries = self.self_attn(q=queries, k=queries, v=queries) # queries: torch.Size([1, 9, 256]) else: q = queries + query_pe # 残差加 PE attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out # 残差连接 queries = self.norm1(queries) # [B, 9, 256] # queries: torch.Size([1, 9, 256]) # ---------- 2. token→image 交叉注意力 ---------- q = queries + query_pe # 给 query 加 PE # q: torch.Size([1, 9, 256]) k = keys + key_pe # 给 key 加 PE # k: torch.Size([1, 4096, 256]) # q: torch.Size([1, 9, 256]) # k: torch.Size([1, 4096, 256]) # keys: torch.Size([1, 4096, 256]) attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # 下采样在内部完成 # attn_out: torch.Size([1, 9, 256]) queries = queries + attn_out # 残差 # queries: torch.Size([1, 9, 256]) queries = self.norm2(queries) # [B, 9, 256] # queries: torch.Size([1, 9, 256]) # ---------- 3. MLP ---------- mlp_out = self.mlp(queries) queries = queries + mlp_out # 残差 queries = self.norm3(queries) # [B, 9, 256] # ---------- 4. image→token 交叉注意力 ---------- # 注意:这里“角色互换”——用图像 token 做 Q,去 attend 稀疏点 q = queries + query_pe # 稀疏点继续当“被 attend”的 K/V k = keys + key_pe # 图像当 Q attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) # 形状 [B, 4096, 256] keys = keys + attn_out # 残差更新图像 token keys = self.norm4(keys) # [B, 4096, 256] # 返回更新后的 (queries, keys),供下一层或下游使用 return queries, keys

总结

  1. 稀疏点先 self-attn,增强自身上下文。

  2. 再把增强后的点去 attend 图像,提取对应位置特征。

  3. 过一遍 MLP,进一步非线性变换。

  4. 最后让图像 token 反过来看这些点,把“哪些区域有点”信息写回图像特征。
    于是“点”与“图”完成一次双向融合,形状全程保持不变:
    queries 始终 [B, Np, C],keys 始终 [B, H·W, C]。

注意上面TwoWayAttentionBlock初始化里面,创建Attention的时候,自注意力是没有传入downsample_rate,所以是默认的1,而交叉注意力是传入downsample_rate=2的,这也是为什么后面线性映射的时候自注意力的Attention是没有降维的(一直是256),而交叉注意力的Attention里面线性映射的时候降维到128了(不过最后会升回256)

4.1.2.2self.self_attn——Attention.forward

sam2/modeling/sam/transformer.py

TwoWayAttentionBlock.forward中

queries =self.self_attn(q=queries, k=queries, v=queries)

上面的语句进入Attention的forward

class Attention(nn.Module): """ 标准多头注意力,但支持「把 Q/K/V 映射到更低维度」以节省计算。 内部使用 PyTorch 2.x 的 scaled_dot_product_attention,可自动选 Flash / Mem-efficient / Math kernel。 """ def __init__( self, embedding_dim: int, # 输入 token 的通道数 C num_heads: int, # 头数 h downsample_rate: int = 1, # 把 C 压缩成 C//downsample_rate,再分头 dropout: float = 0.0, kv_in_dim: int = None, # 如果 K/V 的输入维度与 Q 不同,可单独指定 ) -> None: super().__init__() self.embedding_dim = embedding_dim self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim # downsample_rate:1 self.internal_dim = embedding_dim // downsample_rate # 压缩后的通道 self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads 必须整除 internal_dim" # 线性映射:Q 来自 embedding_dim,K/V 可能来自别的维度 self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) # 输出再映射回原始通道 self.out_proj = nn.Linear(self.internal_dim, embedding_dim) self.dropout_p = dropout # 把 [B, N, C] 拆成 [B, h, N, C//h] 以便并行算多头 def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: # x:[1,9,256] num_heads:8 b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) # [1,9,8,32] return x.transpose(1, 2) # [1,8,9,32] # 逆操作:把 [B, h, N, C//h] 还原成 [B, N, C] def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape # [1,8,9,32] x = x.transpose(1, 2) # [1,9,8,32] return x.reshape(b, n_tokens, n_heads * c_per_head) # [1,9,256] def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # 输入示例:q/k/v 均为 [1,9,256] # 1. 线性映射 + 降维 q = self.q_proj(q) # q: [1,9,256] k = self.k_proj(k) # k: [1,9,256] v = self.v_proj(v) # v: [1,9,256] # 2. 拆头 q = self._separate_heads(q, self.num_heads) # q: [1,8,9,32] k = self._separate_heads(k, self.num_heads) # k: [1,8,9,32] v = self._separate_heads(v, self.num_heads) # v: [1,8,9,32] # 3. 计算 dropout 开关(推理时关闭) # self.training: False dropout_p = self.dropout_p if self.training else 0.0 # dropout_p: 0.0 # 4. 调用最优 kernel 做 scaled dot-product attention with torch.backends.cuda.sdp_kernel( # USE_FLASH_ATTN: False enable_flash=USE_FLASH_ATTN, # 全局开关:是否尝试 FlashAttention # OLD_GPU:True dropout_p:0.0 MATH_KERNEL_ON:True enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, # 兜底 math kernel enable_mem_efficient=OLD_GPU, # 老 GPU 可开 memory-efficient ): out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) # [1,8,9,32] # 5. 合并头 + 输出投影 # out: torch.Size([1, 8, 9, 32]) out = self._recombine_heads(out) # out: torch.Size([1, 9, 256]) out = self.out_proj(out) # out: torch.Size([1, 9, 256]) return out
  1. 先把 Q/K/V 线性映射到更低维度internal_dim,省计算。

  2. 拆成多头形状,用 PyTorch 官方scaled_dot_product_attention自动选最快 kernel。

  3. 计算完再合并头,并线性映射回原始通道数。

  4. 整个模块对外保持“输入 [B, N, C] → 输出 [B, N, C]”的透明接口,内部按需降维。

线性映射前后维度是不变的,那它里面做了什么?有什么作用?

“维度不变”只是最外圈的假象——
forward真正做矩阵乘的维度已经被q_proj / k_proj / v_proj压成了internal_dim = embedding_dim // downsample_rate,自注意力这里是取downsample_rate=1
算完注意力再用out_proj把通道升回来。

所以“线性映射”里面干了三件事:

  1. 降维投影(省钱)
    把原始C映射到C/r(r=downsample_rate,默认 1 即不降;SAM 里常取 2)。
    参数量与计算量都近似降为1/r²

  2. 多头拆分准备
    投影后的通道必须能被num_heads整除,才能直接.reshape(..., heads, dim_per_head)
    这一步把“物理维度”转成“逻辑头维度”,后面sdpa并行算 h 个头。

  3. 可学习的线性组合(混通道)
    即使 r=1 不降维,投影矩阵仍是满秩可训练参数,给每个头提供不同的子空间,
    否则所有头都看到完全相同的 Q/K/V,多头就退化成“单头复制”。


形象一句话
“先压通道省钱,再拆头并行,最后升通道还给你”——
外表看 [B,N,C]→[B,N,C],里面其实偷偷把计算搬到了更小的空间。

你的意思是,本来q,k,v都是相同的,线性映射之后就不同了是吗,那如何理解线性映射呢?后面什么要对q、k、v拆头,为什么是拆8个头?F.scaled_dot_product_attention做了什么?out_proj做了什么?

把四个问题拆成四步,一次讲清:

  1. 线性映射到底干了什么
    输入q=k=v=x确实完全一样,但

    • q_projk_projv_proj三个独立的权重矩阵(形状都是[C, C/r])。

    • 各自乘一遍后,得到三套不同Q/K/V,通道数只剩C/r
      这就是 Attention 论文里最经典的“用三个线性层把同一批 token 变成查询-键-值”。

  2. 为什么要“拆 8 个头”

    • 投影后的internal_dim必须能被头数整除,于是
      internal_dim = num_heads × (dim_per_head)

    • SAM 把num_heads设成 8,相当于把256 → 8×32
      每个头只看 32 维的子空间,并行学 8 组不同的相似度函数,最后再拼回来。

    • 头数=8 是超参,经验值;越大越分散,计算量也线性涨。

  3. F.scaled_dot_product_attention内部做了什么
    给定[B, h, N, d]Q/K/V,它一步完成:

    scores = Q @ K.transpose(-2,-1) / sqrt(d) # [B,h,N,N] 相似度 attn = softmax(scores, dim=-1) # 行归一化得权重 output = attn @ V # 加权求和

    并自动选 FlashAttention / memory-efficient / math 三种 kernel 中最快且支持你硬件的那一条。

  4. out_proj的作用
    多头输出拼完后形状是[B, N, internal_dim],但下游网络期望原通道C
    out_proj就是一个Linear(internal_dim, C),把压缩后的特征升维+线性混合,再残差接回原路径。
    它同时让不同头的信息再次融合,并恢复网络下一层所需的通道数。

一句话串起来:
“同一份x三个独立线性层得到Q/K/V→ 压通道 → 拆 8 个头并行算相似度 → 拼回 → 再线性映射回原始通道。”

这里很多函数进不去,标记一下
如何理解自注意力机制呢?我自己还不了解我自己吗,还要用这种方式?

“自己还不了解自己吗?”——这句话把“自注意力”拟人化了,结果就绕进了死胡同。
把它当成纯数学的“特征再分配”来看,就通了:

  1. token ≠ 人
    每个 token 只是 C 维向量,同一位置在不同层含义一直在变
    “了解我自己”不是指它突然有了意识,而是用全局信息重新给这 C 个数加权,让“该强调的子空间放大,该抑制的噪声归零”。

  2. 没有全局视野的向量 = 局部盲人
    单层 Conv 或 FFN 只能看到邻域(甚至只能看单点);
    自注意力把整句/整图所有位置拉成一张相似度表,一次完成任意两两通信
    结果:

    • “猫” token 发现当前特征里“狗”分量很高,就把这部分降权;

    • “边界” token 发现远处还有一条同向边界,就把对应特征加强。
      这就是上下文校正,而不是“自我怀疑”。

  3. 数学上就是“可学习的加权平均”

    new_x_i = Σ_j softmax( (x_i W_q)(x_j W_k)^T /√d ) · (x_j W_v)

    权重矩阵 W 训练好后,每个位置自动知道该抄谁、抄多少
    所谓“自”仅表示 Query、Key、Value 都来自同一组 x,不引入额外外部信号

  4. 为什么一定要“投影三次”
    如果直接用x点积x,只能得到原始特征的相似度;
    投影成Q/K/V后,网络可以给“查询侧”“键侧”“值侧”各自学不同的子空间
    否则就退化成固定核的平滑,没有可学习参数

  5. 拆多头 = 多组“聚光灯”
    8 个头 ≈ 8 台不同颜色的探照灯,每台只照 32 维,
    有的专盯“颜色相似”,有的专盯“空间相邻”,最后拼起来就同时兼顾多种上下文线索

一句话收尸:
自注意力不是让 token“反省人生”,而是给它一次“看完全局再重新整理自己特征”的机会;
三次线性映射 + 多头 + softmax 加权,只是让这次“整理”变得可学习、可并行、多视角。

你的意思是token是我的局部,我的局部与局部之间要互相了解是吗

可以这么类比,但别再往“人”上靠——把“局部”想成纯数据块就不会绕:

  1. 一张图被切成 64×64 的小 patch,每个 patch 只“看见”自己那 8×8 像素;
    它对应的 token 向量里,也许有“边缘”“红色”“纹理”这些分量,但不知道 20 个 patch 之外还有一条同方向的边缘
    自注意力就是让这条“局部边缘” token 去查整张图里所有其他边缘 token,发现“哦,原来咱俩在同一条大轮廓上”,于是把自己的“轮廓置信度”拉高。

  2. 一句话里,“bank” 这个 token 如果只盯自己,无法区分“河岸”还是“银行”;
    它扫完全句,发现前面有“river”,就把“水”相关的分量加大,把“金融”分量压掉。
    这就是局部与局部之间互相校正,而不是“谈心”。

  3. 数学上就是用相似度做加权平均
    每个局部发出一个“查询向量”,和其他所有局部的“键向量”点积,得到“我跟你有多相关”;
    再用这组相关度去加权所有局部的“值向量”,更新自己
    所以“局部了解局部”=用全局相似度重新加权融合信息,让原本孤立的 C 维向量变成“全局-aware”的新向量。

一句话:
不是 token 有了情感,而是它的数值列表通过“与所有邻居算相似度→加权抄作业”被重新填写了一次,从而带上全局上下文。

所以说其实是如果我是token的话,我的身份是受整个向量的其他token影响的,每个token都去询问一遍其他所有token以确定自己的身份,是这样理解吗

可以这么“拟人”地打比方,但一定记住——这只是帮你建立直觉的比喻,底层全是矩阵乘法,没有“身份”“询问”这些意识活动:

  1. 在自注意力前,你(某个 token)手里的 C 个数字只来自你自己的小区域(8×8 像素或一个子词)。

  2. 自注意力让你把这张“全局相似度表”当成权重,去加权抄全场所有 token 的特征,然后用抄来的结果替换/加回自己原来的 C 个数字

  3. 于是,新的 C 个数字里既保留了原来的局部线索,又混入了“谁跟我很像/谁对我重要”的全局信息。

  4. 下一层网络看到的新向量,就不再只是局部,而是全局-aware 的表示——这就是你说的“身份被其他 token 影响”。

所以,“询问”=点积算相似度,“确定身份”=用相似度加权平均后更新自己的特征向量
整个过程是可学习的线性投影 + 数值加权,没有任何语义层面的“自我认知”。

全局相似度表这是指q、k、v都有各自的可学习的权重矩阵是吗?那这个权重矩阵是怎么被初始化,怎么被学习的呢?它的数值是怎么来的? 是不是有种真值知道更新后的是不是正确的,然后反向更新这个权重矩阵呢?但这个真值是什么?

把“比喻”全部剥掉,只剩可学习的张量可求导的 loss,就清楚了:

  1. 三个权重矩阵哪来的

    • 形状:q_proj.weight [C, C/r]k_proj.weight [C, C/r]v_proj.weight [C, C/r]

    • 初始化:PyTorch 默认用KaimingUniformxavier_uniform随机采样,数值服从某种 scaled 均匀分布,与任务无关。

    • 它们就是普通 Linear 层,跟 CNN 的卷积核一样,一开始完全随机,没有任何“真值”先验

  2. “全局相似度表”怎么来的
    前向时:

    Q = x W_q, K = x W_k, S = Q K^T /√d

    S就是[B, h, N, N]的“相似度表”,每一次 forward 都重新算一次,不存储、不保留,也没有可学习参数
    可学习的只有W_qW_k(以及W_vout_proj)。

  3. 数值怎么被“纠正”——没有“真值相似度”
    下游任务给的才是终极真值:

    • 分割任务:像素级 mask

    • 分类任务:类别标签

    • 生成任务:下一个 token ID

    网络输出 → 与这些标签算交叉熵、MSE、Dice … → 得到标量 loss → 反向传播一直走到W_qW_kW_vout_proj→ 自动求导调整矩阵元素。
    没有任何人告诉模型“第 i 与 j 的相似度应该是 0.7”,它只是发现“把这两个 patch 的权重调高,最终 mask 的 IoU 会变大”,于是下次继续调高。

  1. 总结一句话
    可学习的是四个线性矩阵,相似度表只是它们在前向时临时算出的副产品;
    “对不对”由下游任务 loss 说了算,loss 只认最终输出与人工标签的差距,不认中间相似度。

这里感觉没有理清楚,标记一下
4.1.2.3 self.cross_attn_token_to_image——Attention.forward

TwoWayAttentionBlock.forward中

# ---------- 2. token→image 交叉注意力 ----------
q = queries + query_pe # 给 query 加 PE
# q: torch.Size([1, 9, 256])

k = keys + key_pe # 给 key 加 PE
# k: torch.Size([1, 4096, 256])

# 又进入Attention.forward

attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # 下采样在内部完成
queries = queries + attn_out # 残差
queries = self.norm2(queries) # [B, 9, 256]

class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, dropout: float = 0.0, kv_in_dim: int = None, ) -> None: super().__init__() self.embedding_dim = embedding_dim # 原始输入维度(q 的输入维度) self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim # k/v 的输入维度,可与 q 不同 # self.internal_dim = 256 // 2 = 128 self.internal_dim = embedding_dim // downsample_rate # 经过降采样后的“内部”维度,用于多头计算 self.num_heads = num_heads # 注意力头数 assert ( self.internal_dim % num_heads == 0 ), "num_heads must divide embedding_dim." # 线性映射:把输入映射到统一的 internal_dim 空间 # embedding_dim:256 self.internal_dim:128 self.q_proj = nn.Linear(embedding_dim, self.internal_dim) # 仅 q 来自 embedding_dim # self.kv_in_dim:256 self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) # k/v 可能来自不同维度 self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) # 输出映射:把拼接后的多头结果再映射回原始 embedding_dim # embedding_dim:256 self.internal_dim:128 self.out_proj = nn.Linear(self.internal_dim, embedding_dim) self.dropout_p = dropout # attention dropout 比例 def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: """把 [B, N, C] 拆成 [B, num_heads, N, C//num_heads],方便并行算多头""" b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: """与 _separate_heads 相反,把多头结果重新拼接回 [B, N, C]""" # x: torch.Size([1, 8, 9, 16]) b, n_heads, n_tokens, c_per_head = x.shape # b:1 n_heads:8 n_tokens:9 c_per_head:16 x = x.transpose(1, 2) # 先交换维度,变成 [B, N_tokens, N_heads, C_per_head] # x: torch.Size([1, 9, 8, 16]) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: """ 参数: q: [B, Nq, embedding_dim] 查询序列 k: [B, Nk, kv_in_dim] 键序列 v: [B, Nk, kv_in_dim] 值序列 返回: out: [B, Nq, embedding_dim] """ # 输入: # q: torch.Size([1, 9, 256]) # k: torch.Size([1, 4096, 256]) # v: torch.Size([1, 4096, 256]) # Input projections # 初始化的时候 self.internal_dim = embedding_dim // downsample_rate # downsample_rate = 2, 所以交叉注意力里的线性映射发生降维了 q = self.q_proj(q) # q: torch.Size([1, 9, 128]) k = self.k_proj(k) # k: torch.Size([1, 4096, 128]) v = self.v_proj(v) # v: torch.Size([1, 4096, 128]) # Separate into heads q = self._separate_heads(q, self.num_heads) # q: torch.Size([1, 8, 9, 16]) k = self._separate_heads(k, self.num_heads) # k: torch.Size([1, 8, 4096, 16]) v = self._separate_heads(v, self.num_heads) # v:torch.Size([1, 8, 4096, 16]) # self.dropout_p:0 self.training:False dropout_p = self.dropout_p if self.training else 0.0 # 推理时关闭 dropout # dropout_p: 0.0 # Attention # 根据 GPU 能力及配置选择最优 kernel:FlashAttention / Math / MemoryEfficient with torch.backends.cuda.sdp_kernel( enable_flash=USE_FLASH_ATTN, # USE_FLASH_ATTN:False enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, # OLD_GPU:True dropout: 0.0 MATH_KERNEL_ON: True enable_mem_efficient=OLD_GPU, ): out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) # out: torch.Size([1, 8, 9, 16]) out = self._recombine_heads(out) # out: torch.Size([1, 9, 128]) out = self.out_proj(out) # out: torch.Size([1, 9, 256]) return out

代码整体解释

  1. 作用:实现一个支持“降维”的多头交叉注意力层。

  2. 降维:通过downsample_rate把 q/k/v 映射到更低的internal_dim再计算注意力,减少计算量;计算完再映射回原始embedding_dim

  3. 支持异构输入:k/v 的输入维度kv_in_dim可与 q 的embedding_dim不同,方便像 ViT 编码器-解码器结构那样使用。

  4. 多头拆分/合并:_separate_heads_recombine_heads负责把张量维度在“token”与“head”之间来回变换,以调用 PyTorch 的高效scaled_dot_product_attention

  5. 后端选择:sdp_kernel的上下文管理器根据硬件与超参自动挑选 FlashAttention、Math 或 MemoryEfficient kernel,兼顾速度与显存。

为什么out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 之后out: torch.Size([1, 8, 9, 16])

一句话:因为F.scaled_dot_product_attention“多头已经拆好”的 4-D 张量上做的是逐头计算,输出形状和输入的“头维”保持一致,所以
[B, num_heads, Nq, C_per_head]进去 → 同样形状出来。


详细推导

  1. 输入形状
    q:[1, 8, 9, 16]→ 含义[batch=1, heads=8, query_tokens=9, dim_per_head=16]
    k:[1, 8, 4096, 16]
    v:[1, 8, 4096, 16]

  2. scaled_dot_product_attention内部逻辑
    对每一个头(dim=1上切片)独立地执行:

    # 对于1个头,q乘以k的转置 [9, 16] x [16, 4096] = [9, 4096] attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k)) # [9, 4096] # 对于1个头,[9, 4096] x [4096, 16] = [9, 16] out_h = attn @ v_h # [9, 16]

    因此每个头都会得到一个[9, 16]的输出。

  3. 把 8 个头的结果按“头维”堆回去,就是[1, 8, 9, 16]

  4. 输出形状
    函数约定:

    • 不改变头数

    • 不改变 batch 大小

    • token 数始终与q 的 token 数相同

    • 特征维度始终与v 的 per-head 维度相同

    所以出来还是[1, 8, 9, 16]


小结

“多头”维度在sdpa眼里只是普通的 batch 维度之一,它不会合并或重塑它;合并/重塑是_recombine_headsout_proj后续才做的事。

如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

把“注意力”想成查资料

  • q_h:你此刻想提的问题(9 个关键词,每个 16 维向量)。

  • k_h:资料库里所有标题(4096 条,每条也是 16 维向量)。

  • v_h:资料库里对应标题的正文(4096 条,每条 16 维向量)。


1. 打分:问题与标题有多匹配

score = q_h @ k_h.transpose(-2, -1) # [9, 4096]

每条问题向量与每条标题向量做点积(cosine 的分子),数值越大表示越相关。
除以√d_k√16 = 4)防止点积绝对值太大导致 softmax 饱和,这一步叫scale


2. 转成概率:softmax

attn = softmax(score, dim=-1) # [9, 4096],每行和为 1

对每条问题(每一行)做 softmax,把“相关分”变成选资料的概率分布
结果attn[i, j]就是“问题 i 对标题 j 的关注权重”。


3. 拿概率去加权正文

out_h = attn @ v_h # [9, 16]

按关注权重把 4096 条正文向量做加权平均

  • 权重大的正文对结果贡献大;

  • 权重小的几乎被忽略。
    于是 4096 条信息被压缩成 9 条“精炼答案”,维度仍保持 16。


一句话总结

“先算相关性,再按相关性加权求和”——这就是注意力机制的核心。

代入SAM2分割这一情景,如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

在 SAM 2(Segment Anything Model 2)里,这套注意力被用来做“记忆-查询”式的跨帧传播

  • 不是普通 NLP 里的“单词→单词”,而是
    当前帧的像素查询← →记忆库中过去帧的像素特征

把符号代入视频分割场景,就能一眼看懂那两行公式在干什么。


  1. 张量含义(单头视角,h 代表“每头”)

    • q_h :当前帧待解码的 9 个像素 token(可以是某个 prompt 点对应的窗口特征,也可以是整张特征图展平后采样得到的 9 个向量)。

    • k_h :记忆库里的4096 个记忆位置(过去帧+当前帧已编码的高置信度像素特征,带空间-时间坐标)。

    • v_h :与 k_h 一一对应的16 维掩码嵌入(memory value),后续要拿来加权求和,生成当前像素的掩码 logits。


  1. 打分 → softmax attn = softmax(q_h @ k_h.T / √d_k)
    “当前 9 个像素” 分别对 “记忆库里 4096 个历史像素” 算相似度,再归一化。
    结果每行是一个4096 维的注意力权重,告诉模型:
    “为了知道眼前这个像素是不是‘狗’,我最该看记忆库里的哪几个历史像素?”


  1. 加权求和 out_h = attn @ v_h
    → 用刚刚的权重,把记忆库里的4096 条掩码嵌入做加权平均,得到 9 条16 维的聚合嵌入
    这 16 维向量已经融合了跨帧、跨视角的物体一致性信息,后续再喂给轻量级解码器就能直接吐出当前 9 个像素的前景/背景概率实例掩码 logits


一句话场景化总结
当前像素问:‘我以前长什么样?’
记忆库回答:‘你跟我这些历史像素最像,按相似程度给你加权平均一下以前的掩码特征,你就知道自己该属于哪一类了。’”
那两行公式就是“找最像的历史像素 → 把它们的掩码特征借过来”的数学实现,保证视频里物体 mask 在帧间时序一致、ID 不跳变

但现在只是第一帧图像预测

第一帧就用“记忆-查询”式注意力,看起来像是“自己查自己”,但这一步仍然有意义,而且并不浪费。把公式放到 SAM 2 第一帧的上下文里,可以这样理解:

  1. 记忆库 ≠ 下一帧,而是同一帧内部
    在 SAM 2 的 encoder 端,输入图像先被拆成 16×16 的窗口 token(外加绝对位置编码),这些 token 同时扮演kv;而q则是更稀疏的 prompt token(用户给的 1 个点或 9 个框采样点对应的特征)。
    因此
    q_h:9 个 prompt token
    k_h / v_h:4096 个图像 patch token
    注意力就是“prompt 点应该去看图像的哪些区域”——本质上是一次空间非局部聚合,而不是时序聚合。

  2. softmax 打分 → 拿到“空间关联图”
    attn = softmax(q_h @ k_h.T / √d_k)
    这一步得到 9×4096 的权重:每个 prompt 点对全图所有 patch 的关注程度。
    第一帧没有“历史”,但全图 patch 本身就是最丰富的上下文;让 prompt 点直接“看到”整幅图,就能迅速把局部提示传播到全局目标区域,解决“一个点太小,边界在哪”的问题。

  3. 加权求和 → 生成“全局感知的 prompt 嵌入”
    out_h = attn @ v_h
    把权重乘回图像特征,得到 9 个已经融合了整幅图语义的 16 维向量。
    这些向量再送进解码器,就能回归出精细的实例掩码,而不会出现“只给一个点就只能猜出一个小圆” 的局部歧义。

一句话总结
第一帧的“记忆库”就是当前帧自身的全部 patch;注意力做的事是
用稀疏提示去查询整张图,把提示特征扩散成完整的物体表示”——
所以它依旧是那两行公式,只是查询范围从‘跨帧历史’变成了‘同帧全局’,目的都是把局部线索扩展成完整目标

4.1.2.4 自注意力和交叉注意力有什么区别?为什么要先自注意力?

把“自注意力”和“交叉注意力”放到 SAM 2 的 TwoWayAttentionBlock 里,可以一句话区分:

  • 自注意力:同一组 token 内部“互相聊天”——稀疏点之间交换信息。

  • 交叉注意力:两组不同 token“跨组对话”——稀疏点 ⇄ 整张图交换信息。


1. 为什么先跑自注意力?

  1. 稀疏提示点先自我对齐
    用户可能一次给了 1 个点、3 个点或 9 个框,这些提示彼此独立喂进网络。自注意力让它们先“通气”:

    • 点 A 知道点 B 也在物体上,于是更确信“这是一整个物体”而非孤立边缘。

    • 点 C 落在背景,通过自注意力被其他前景点“投票”抑制,减少误报。
      结果:每个 query token 的 256 维向量已经携带了多提示共识,再拿去和图像特征交叉时,查询信号更稳定、噪声更小

  2. 保持排列等变性(permutation-equivariant)
    自注意力对输入顺序不敏感,保证“交换两个提示点顺序”输出不变,符合提示点的集合特性。

  3. 工程经验:先内聚、再外联
    几乎所有 Transformer 块都遵循
    “Self-Attn → Cross-Attn → FFN”三部曲:

    • Self:把自己人说明白;

    • Cross:再出去谈合作;

    • FFN:最后各自消化。
      这样梯度、收敛、精度都更友好。


2. 自注意力 vs. 交叉注意力 对照表

维度自注意力交叉注意力
Q/K/V 来源同一组 token不同组 token
输入形状[B, 9, 256]Q: [B, 9, 256] K/V: [B, 4096, 256]
作用范围稀疏点之间稀疏点 ↔ 图像块
目的提示点内部共识把提示扩散到全图 / 把图特征回传提示
参数量一套投影矩阵两套投影矩阵(q_proj 与 kv_proj 分开)
计算复杂度O(9²) 可忽略O(9×4096) 主导

一句话总结

先自注意力=让“点”先达成内部共识,再代表整个物体去和整张图谈判;
否则每个点各说各话,交叉注意力会被噪声带偏,分割边缘就飘了。

4.1.2.5 为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?

q = queries + query_pe # 给 query 加 PE
# q: torch.Size([1, 9, 256])

k = keys + key_pe # 给 key 加 PE
# k: torch.Size([1, 4096, 256])

# q: torch.Size([1, 9, 256])
# k: torch.Size([1, 4096, 256])
# keys: torch.Size([1, 4096, 256])
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # 下采样在内部完成
# attn_out: torch.Size([1, 9, 256])

为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?

一句话:
Q 和 K 需要“位置”来对齐“谁问谁”,V 只负责“给答案”不需要位置。


  1. 注意力打分靠 Q·Kᵀ
    相似度 scoreᵢⱼ = (qᵢ + peᵢ) · (kⱼ + peⱼ)
    这里 pe 让网络知道
    “第 i 个提示点” 与 “第 j 个图像块”在图像坐标上是否对应
    去掉 pe,所有向量就只剩“外观”没有“坐标”,点乘结果会乱掉——提示点可能去 attend 很远但外观相似的区域


  1. V 只是“被加权取平均”的原材料
    attn_out = Σⱼ αᵢⱼ · vⱼ
    权重 αᵢⱼ 已经由带位置的 Q、K算好,vⱼ 用纯外观特征即可;
    再把 pe 加进 V 只会把位置信息重复注入,反而让网络难以区分“外观”与“位置”,实验上通常掉点。


  1. 工程惯例
    Transformer 原始论文、DETR、SAM/SAM2 都沿用
    “Q、K 加位置,V 不加”的套路,已成为视觉任务默认配置。


记忆口诀
“打分需要地址,送货只看内容。”
Q、K 带地址(pe)才能寄对快递;V 只管把货(特征)搬过来,地址早由权重 α 指定好了。

4.1.2.6 MLP.forward

sam2/modeling/sam2_utils.py

TwoWayAttentionBlock.forward里面调用了

mlp_out = self.mlp(queries)

# Lightly adapted from # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa class MLP(nn.Module): """ 经典多层感知机(MLP): - 支持任意层数 - 最后一层不加激活 - 可选 sigmoid 输出 常用于 Transformer 中的 FFN 子模块。 """ def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, activation: nn.Module = nn.ReLU, sigmoid_output: bool = False, ) -> None: super().__init__() self.num_layers = num_layers # 构造隐藏层维度列表:中间层全部用 hidden_dim h = [hidden_dim] * (num_layers - 1) # 顺序拼接 Linear:输入 → 隐藏 → ... → 输出 self.layers = nn.ModuleList( nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) ) self.sigmoid_output = sigmoid_output # 是否对最后一层加 sigmoid self.act = activation() # 实例化激活函数 def forward(self, x): # x: torch.Size([1, 9, 256]) # 逐层前向:除最后一层外均接激活 for i, layer in enumerate(self.layers): x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) # i=0 x: torch.Size([1, 9, 2048]) # 第一层升维 # i=1 x: torch.Size([1, 9, 256]) # 第二层降回原维(残差分支用) # self.sigmoid_output: False if self.sigmoid_output: x = F.sigmoid(x) # 若需要 0~1 范围则再套 sigmoid # x: torch.Size([1, 9, 256]) return x
  1. 一个可复用的 MLP 积木,通常作为 Transformer 块里的 FFN(Feed-Forward Network)。

  2. 默认2 层:先升维到 2048,再降回 256,配合残差连接,给模型增加非线性且保持通道维度一致。

  3. sigmoid_output开关方便在需要概率输出(如 mask logits 后处理)时直接得到 0~1 值。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/17 1:10:40

代码之恋(第十二篇:公开的合并与意外的提交)

周一的 办公室&#xff0c;晨光刚漫过工位隔板&#xff0c;咖啡机还在 “咕嘟咕嘟” 预热&#xff0c;空气里飘着速溶咖啡和隔夜外卖的混合气味 —— 典型的 “节后重启失败” 现场。李磊站在电梯口等艾丽&#xff0c;指尖无意识地摩挲着手机壳边缘&#xff0c;心里像跑着十个异…

作者头像 李华
网站建设 2025/12/17 1:09:14

基于89C51单片机的交通灯控制系统设计

基于89C51单片机的交通灯控制系统设计 第一章 系统概述 传统十字路口交通灯多采用固定时序电路&#xff0c;存在时序不可调、无法响应实时交通变化的问题&#xff0c;易在早晚高峰引发拥堵。基于89C51单片机的交通灯控制系统&#xff0c;以低成本、高可靠性的89C51为核心&#…

作者头像 李华
网站建设 2025/12/17 1:08:28

0基础转行网络安全,到底行不行?附全网最全人才发展路线图

最近有同学在后台留言&#xff0c;0基础怎么学网络安全&#xff1f;0基础可以转行做网络安全吗&#xff1f;以前也碰到过类似的问题&#xff0c;想了想&#xff0c;今天简单写一下。 我的回答是先了解&#xff0c;再入行。 具体怎么做呢&#xff1f; 首先&#xff0c;你要确…

作者头像 李华
网站建设 2026/1/3 16:10:34

收藏级干货:从零开始学Agent开发,万字详解核心链路与实战技巧

本文系统介绍了AI Agent的开发核心链路&#xff0c;涵盖Agent的概念、四大核心能力&#xff08;环境感知、智能决策、任务执行、持续学习&#xff09;、技术架构&#xff08;规划模块、记忆模块、工具调用&#xff09;及上下文工程策略。通过腾讯Dola案例分析&#xff0c;展示了…

作者头像 李华