Skip to content

Layer4 AttentionLayer

覆盖 EncoderLayer.forward() 中的 self.attention(x, x, x, ...) 调用。AttentionLayer 完成三件事:Linear 投影 Q/K/V → .view() 拆分多头 → FullAttention 计算注意力 → 还原形状 → out_projection。inner_attention 的细节见 [[04A-Layer5-FullAttention]]。


§1 在父层中的位置

EncoderLayer.forward() 第一步:

python
new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
  • Q = K = V = x (3, 9, 8) — Self-Attention,输入都是同一张量
  • 输出:new_x (3, 9, 8)attn = None

§2 I/O 接口定义

参数shape(toy)含义
queries (=x)(3, 9, 8)B=3,token_count=9,d_model=8
keys (=x)(3, 9, 8)Self-Attention:与 queries 相同张量
values (=x)(3, 9, 8)Self-Attention:与 queries 相同张量
attn_maskNoneiTransformer 不使用掩码
输出 out(3, 9, 8)注意力加权聚合后的 token 表示
输出 attnNoneoutput_attention=False

§3 顺序图


§4 语义分组图


§5 逐步精读

§5.0 完整原始代码

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    B, L, _ = queries.shape
    _, S, _ = keys.shape
    H = self.n_heads

    queries = self.query_projection(queries).view(B, L, H, -1)
    keys = self.key_projection(keys).view(B, S, H, -1)
    values = self.value_projection(values).view(B, S, H, -1)

    out, attn = self.inner_attention(
        queries, keys, values, attn_mask, tau=tau, delta=delta
    )
    out = out.view(B, L, -1)

    return self.out_projection(out), attn

§5.1 宏观逻辑

用小例子(B=1,L=S=3,d_model=4,n_heads=2,d_keys=2)串起来:

① Q/K/V 各做一次 Linear(4,4),形状不变
  queries/keys/values: (1, 3, 4) → (1, 3, 4)

② .view(1, 3, 2, 2):把 d_model=4 拆成 n_heads=2 × d_keys=2
  相当于"把 4 维分给 2 个头,每头 2 维"
  Head 0: queries[:, :, 0, :] shape (1, 3, 2)
  Head 1: queries[:, :, 1, :] shape (1, 3, 2)

③ FullAttention 对 4 个张量 (B,L,H,d_keys) 做多头注意力
  每头独立算一个 3×3 矩阵,输出 (1, 3, 2, 2)

④ .view(1, 3, 4):把 n_heads × d_keys = 2×2=4 合并
  再做 Linear(4,4):out_projection,最终 (1, 3, 4)

为什么先 Linear 再 .view,不直接让 Linear 输出 (B,L,H,d_keys) 形状?

nn.Linear 只作用于最后一维,必须先产出 (B,L,d_model) 再用 .view() 重排。.view() 不做任何计算,只是重新解释内存布局(等效于 reshape),零代价。


§5.2 步骤一:解包形状

python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
  • B=3, L=9(query token 数),S=9(key token 数),H=4
  • Self-Attention 时 L=S(queries、keys 来自同一 tensor)
  • _ 表示 d_model=8,用不到(投影层内部定义了输入维度)

§5.3 步骤二:Q/K/V 投影

python
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)

Linear 的语义:

query_projection = nn.Linear(d_model=8, n_heads × d_keys = 4×2 = 8)

queries\_proj=queriesWQT+bQ,WQR8×8

nn.Linear 作用于最后一维,前面所有维度视为 batch:(3,9,8) 视为 27 条长度 8 的行向量,各自乘 WQT,输出仍 (3,9,8)

.view(B, L, H, -1) 的语义:

.view(3, 9, 4, -1) 把最后一维 8 拆成 (n_heads=4, d_keys=2),即 8=4×2

  • 拆分前:queries_proj (3, 9, 8),连续内存,每 token 有 8 个数
  • 拆分后:queries (3, 9, 4, 2),每 token 的 8 个数被"分配"给 4 个头,每头 2 个数

toy 数值(batch=0,token=0):

queries[0, 0, :] 原始 = [q0, q1, q2, q3, q4, q5, q6, q7]   8 维

经过 Linear(8,8) → [q'0, ..., q'7](线性混合)

.view(4, 2) 后:
  Head 0: [q'0, q'1]   ← queries[0, 0, 0, :]
  Head 1: [q'2, q'3]   ← queries[0, 0, 1, :]
  Head 2: [q'4, q'5]   ← queries[0, 0, 2, :]
  Head 3: [q'6, q'7]   ← queries[0, 0, 3, :]

keys, values 同理,各自有独立的投影矩阵 WK,WV


§5.4 步骤三:FullAttention

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)
  • self.inner_attention = FullAttention 实例
  • 输入:queries (3,9,4,2)keys (3,9,4,2)values (3,9,4,2)
  • 输出:out (3,9,4,2)attn = None(iTransformer 的 FullAttention 用 output_attention=False

FullAttention 对每个 head 独立算 9×9 注意力矩阵,再用 V 加权求和。详见 [[04A-Layer5-FullAttention]]。


§5.5 步骤四:合并多头 + out_projection

python
out = out.view(B, L, -1)
return self.out_projection(out), attn

.view(B, L, -1) 的语义:

.view(3, 9, -1):把 (n_heads=4, d_keys=2) 合并回 d_model=8,即 1=4×2=8

  • 拆分前:out (3, 9, 4, 2)
  • 合并后:out (3, 9, 8)

这是 .view(B, L, H, -1) 的逆操作,内存布局不变。

out_projection = nn.Linear(d_model=8, d_model=8)

output=outWOT+bO,WOR8×8

线性混合各头输出,得到最终 (3, 9, 8),作为 new_x 返回给 EncoderLayer

toy 数值(batch=0,token=0):

out[0, 0, :, :] shape (4, 2) → .view(8) → [o0, ..., o7]
out_projection: [o0,...,o7] × W_O^T + b_O → [p0,...,p7]
new_x[0, 0, :] = [p0,...,p7]   ← var_0 token 的注意力更新结果

§6 初始化参数(toy)

AttentionLayer.__init__() 关键参数:

参数含义
d_model8输入/输出维度
n_heads4注意力头数
d_keys2= d_model / n_heads = 8/4
d_values2= d_keys(通常相同)
query_projectionLinear(8, 8)8 = n_heads × d_keys
key_projectionLinear(8, 8)同上
value_projectionLinear(8, 8)同上
out_projectionLinear(8, 8)合并后输出投影
inner_attentionFullAttention(...)见下层

§7 下钻子组件

组件输入 shape输出 shape下层文档
FullAttentionQ/K/V: (3,9,4,2)(3,9,4,2)[[04A-Layer5-FullAttention]]

§8 多头拆分图

*记录并在线阅读我的笔记*