Skip to content

Layer 4 — AttentionLayer 精读

父层 04A-Layer3-EncoderLayer 的第一步调用 self.attention(x, x, x, ...)
本文档只覆盖 AttentionLayer.forward 这一层(d_model ↔ 多头格式桥梁)。
子层 FullAttention 见 04C-Layer5-FullAttention

如果你对代码本身不敏感,优先看 04D-Layer4-例子-AttentionLayer到return。那篇只用一个 B=1, L=2, d_model=4, n_heads=2 的可手算矩阵例子,把 x→Q/K/V→scores→softmax→AV→return 画到底。


1. 在父层中的位置

EncoderLayer.forward
  └─ new_x, attn = self.attention(x, x, x, ...)   ← 本文档
       └─ Q/K/V projection + .view()    (8,6,16) → (8,6,2,8)
       └─ self.inner_attention(Q,K,V)   → 详见 Layer5 FullAttention
       └─ out.view() + out_projection   (8,6,2,8) → (8,6,16)

2. I/O 接口定义

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
shape(toy)含义
输入 queries(8, 6, 16) = (B*C, patch_num, d_model)自注意力,三者均为同一 x
输入 keys(8, 6, 16)同上
输入 values(8, 6, 16)同上
输出 new_x(8, 6, 16)注意力加权后的 patch 表示,形状与输入相同
输出 attnNoneoutput_attention=False 时为 None

3. 顺序图(具体层)


4. 语义分组图(索引层)

本层的核心是 A→B→C 这个**"进来拆开,算完合并"**的对称结构。格式变换只在这里发生,FullAttention 永远看不到 d_model


5. 逐步解析

5.0 完整原始代码

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    B, L, _ = queries.shape
    _, S, _ = keys.shape
    H = self.n_heads

    queries = self.query_projection(queries).view(B, L, H, -1)
    keys = self.key_projection(keys).view(B, S, H, -1)
    values = self.value_projection(values).view(B, S, H, -1)

    out, attn = self.inner_attention(
        queries,
        keys,
        values,
        attn_mask,
        tau=tau,
        delta=delta
    )
    out = out.view(B, L, -1)

    return self.out_projection(out), attn

整体数据流:三路 projection+view 把输入拆成多头格式,FullAttention 完成注意力计算,再 view+projection 合并回 d_model。


5.1 格式转换 IN

本节的作用

(B, L, d_model) 的输入分别投影成 Q/K/V,再拆成 ==多头格式== (B, L, n_heads, d_keys),让每个头独立地"看"同一序列。

步骤一 — 提取形状变量

python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads

queries.shape = (8, 6, 16),拆出 B=8(B×C 合并后)、L=6(query 的 patch 数)、_ 丢弃 d_model=16(后续由 projection 重建)。S=6(key 的 patch 数,自注意力 S==L)。H=2(n_heads)。

步骤二 — Q/K/V projection + .view() 拆多头

python
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)

以 queries 为例:query_projectionLinear(16, 16),输出仍是 (8, 6, 16),随后 .view(8, 6, 2, -1) 把末维 16 重新解释为 (n_heads=2, d_keys=8)——前 8 维归 head 0,后 8 维归 head 1。每个头得到 (8, 6, 8) 的 query 子矩阵。

keysvalues 同理,都从 (8, 6, 16) 变成 (8, 6, 2, 8)

为什么要先 projection 再拆多头,而不是直接切片?

如果直接把 d_model=16 切成两段各 8 维,每个头的 query 向量只是原始向量的子集,两个头"看的是同一组特征的不同片段"。

Linear(16, 16) 投影后再拆,每个头实际上拥有独立的权重矩阵:head 0 的投影权重与 head 1 完全不同,可以各自学习关注不同的语义特征。这是多头注意力表达能力的来源。

下图展示一个 token 的 16 维向量经投影后如何分配到两个头:

论文描述代码实现原因
每个头有独立的 Q/K/V 投影query_projectionLinear(16,16)让每头学不同的关注方向
多头拆分.view(B, L, n_heads, d_keys)把 d_model 重解释为 (头数, 每头维度)

5.2 注意力计算(委托 FullAttention)

本节的作用

把格式化后的 Q/K/V 交给 FullAttention 做纯数学计算。本层不参与计算,只负责传递。

步骤三 — 调用 FullAttention

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)

传入 queries=(8,6,2,8)keys=(8,6,2,8)values=(8,6,2,8)attn_mask=None。FullAttention 完成 QKTsoftmaxAV 的全部计算,返回 out=(8,6,2,8)attn=None

→ 详见 04C-Layer5-FullAttention


5.3 格式转换 OUT

本节的作用

把多头输出 (B, L, n_heads, d_keys) 合并回 (B, L, d_model),再经线性变换做==跨头信息融合==。

步骤四 — .view() 合并 + out_projection

python
out = out.view(B, L, -1)

return self.out_projection(out), attn

.view(B, L, -1) 是步骤二 .view(B, L, H, D) 的逆操作:-1 = H × D = 2 × 8 = 16,把两个头各 8 维的输出拼接回 16 维,得到 (8, 6, 16)

拼接 ≠ 融合,out_projection 才是融合

.view() 只是把两个头的向量并排放置:[head0_0, head0_1, ..., head0_7, head1_0, ..., head1_7],两段之间没有任何交互。

out_projectionLinear(16, 16),其权重矩阵 W ∈ R^{16×16} 让输出维度的每一个位置都是 16 个输入维度的线性组合——head 0 的特征和 head 1 的特征在这里真正混合,产生比单头更丰富的表示。

整个函数返回 new_x=(8,6,16)attn=None


6. 下钻子组件

子组件职责下层文档
FullAttentionself.inner_attention纯数学:Q@K.T → softmax → A@V,不知道 d_model04C-Layer5-FullAttention

*记录并在线阅读我的笔记*