
AttentionLayer: the multi-head projection shell

Abstract

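AttentionLayer does not implement the attention formula itself.

It does four things: projects the inputs into Q/K/V, splits them into heads, calls inner_attention, and merges the per-head outputs back into d_model.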

1. Where it sits in the source

python
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        ...
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
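
The four projections can be inspected on their own. The following is a minimal sketch, assuming illustrative sizes (d_model=8, n_heads=2, d_keys=3, d_values=4) chosen for this example; it only builds the `nn.Linear` layers from the constructor and prints their weight shapes.

```python
import torch.nn as nn

# Illustrative sizes (assumptions for this sketch, not taken from a real config).
d_model, n_heads, d_keys, d_values = 8, 2, 3, 4

query_projection = nn.Linear(d_model, d_keys * n_heads)
key_projection = nn.Linear(d_model, d_keys * n_heads)
value_projection = nn.Linear(d_model, d_values * n_heads)
out_projection = nn.Linear(d_values * n_heads, d_model)

# nn.Linear stores its weight as (out_features, in_features).
shapes = {
    "query": tuple(query_projection.weight.shape),
    "key": tuple(key_projection.weight.shape),
    "value": tuple(value_projection.weight.shape),
    "out": tuple(out_projection.weight.shape),
}
print(shapes)
# {'query': (6, 8), 'key': (6, 8), 'value': (8, 8), 'out': (8, 8)}
```

Q and K share the width d_keys * n_heads because they are later multiplied against each other; V can use a different width, and out_projection maps that width back to d_model.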

2. Diagram

![[zdocs/pytorch-basics/assets/self_attention_attentionlayer.svg]]

3. I/O

toy:

text
B=2, L=5, S=6, d_model=8, H=2, d_keys=3, d_values=4
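
| Tensor | Shape | Description |
| --- | --- | --- |
| queries | (2,5,8) | query sequence |
| keys | (2,6,8) | key sequence |
| values | (2,6,8) | value sequence |
| output | (2,5,8) | back to d_model |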
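
The shapes above can be checked end to end. This is a sketch under stated assumptions: `ToyInnerAttention` is a hypothetical stand-in (plain, unmasked scaled dot-product attention) and is not the library's FullAttention; the layer body follows the excerpts in this note, and because d_keys and d_values are always passed explicitly here, whatever the elided part of `__init__` does with the `None` defaults is not reproduced.

```python
import torch
import torch.nn as nn


class ToyInnerAttention(nn.Module):
    """Hypothetical stand-in for inner_attention: unmasked scaled dot-product."""

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        # queries: (B, L, H, E); keys: (B, S, H, E); values: (B, S, H, D)
        scale = queries.shape[-1] ** -0.5
        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * scale
        attn = torch.softmax(scores, dim=-1)                 # (B, H, L, S)
        out = torch.einsum("bhls,bshd->blhd", attn, values)  # (B, L, H, D)
        return out.contiguous(), attn


class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys, d_values):
        super().__init__()
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(B, L, -1)
        return self.out_projection(out), attn


layer = AttentionLayer(ToyInnerAttention(), d_model=8, n_heads=2, d_keys=3, d_values=4)
queries = torch.randn(2, 5, 8)
keys = torch.randn(2, 6, 8)
values = torch.randn(2, 6, 8)

out, attn = layer(queries, keys, values)
print(out.shape)   # torch.Size([2, 5, 8])
print(attn.shape)  # torch.Size([2, 2, 5, 6])
```

The output keeps the query length L=5, while the attention map has one (L, S) = (5, 6) matrix per head.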

4. Line-by-line reading

Source:

python
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads

queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)

out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)
out = out.view(B, L, -1)

return self.out_projection(out), attn

Annotated:

python
B, L, _ = queries.shape
# queries: (2,5,8), so B=2, L=5

_, S, _ = keys.shape
# keys: (2,6,8), so S=6

H = self.n_heads
# H = 2

Q/K/V projections:

python
queries = self.query_projection(queries).view(B, L, H, -1)
# Linear(8 -> d_keys*n_heads = 3*2 = 6)
# (2,5,8) -> (2,5,6)
# view(2,5,2,-1) -> (2,5,2,3)

keys = self.key_projection(keys).view(B, S, H, -1)
# (2,6,8) -> Linear -> (2,6,6) -> view -> (2,6,2,3)

values = self.value_projection(values).view(B, S, H, -1)
# Linear(8 -> d_values*n_heads = 4*2 = 8)
# (2,6,8) -> (2,6,8) -> view -> (2,6,2,4)

Calling the inner attention:

python
out, attn = self.inner_attention(queries, keys, values, ...)
# If inner_attention is FullAttention:
# out: (B,L,H,d_values) = (2,5,2,4)

Merging the heads:

python
out = out.view(B, L, -1)
# (2,5,2,4) -> (2,5,8)

return self.out_projection(out), attn
# Linear(8 -> 8)
# (2,5,8) -> (2,5,8)
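
What the merging `view` does to the data can be made concrete. A minimal sketch, assuming the toy sizes B=2, L=5, H=2, d_values=4 and a contiguous `out` tensor filled with `arange` values so positions are easy to trace:

```python
import torch

B, L, H, D = 2, 5, 2, 4
# Stand-in for the inner attention output, shape (B, L, H, D).
out = torch.arange(B * L * H * D, dtype=torch.float32).view(B, L, H, D)

merged = out.view(B, L, -1)  # (2, 5, 8)

# The merged last dimension is head 0's D values followed by head 1's.
head0_matches = torch.equal(merged[..., :D], out[:, :, 0, :])
head1_matches = torch.equal(merged[..., D:], out[:, :, 1, :])

# Equivalent to concatenating the heads along the feature axis.
concat = torch.cat([out[:, :, h, :] for h in range(H)], dim=-1)
concat_matches = torch.equal(merged, concat)

print(head0_matches, head1_matches, concat_matches)  # True True True
```

So merging is a per-position concatenation of the heads; it is out_projection, applied afterwards, that mixes information across them.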

5. The math

The mathematical backbone of AttentionLayer is linear projection, not softmax:

$$Q = X_q W_Q + b_Q, \quad K = X_k W_K + b_K, \quad V = X_v W_V + b_V$$

Splitting into heads only reinterprets the last dimension:

$$(B, L, H d_k) \rightarrow (B, L, H, d_k)$$

Finally, the heads are merged again:

$$(B, L, H, d_v) \rightarrow (B, L, H d_v) \rightarrow (B, L, d_{\text{model}})$$
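
That the split is a pure reinterpretation can be checked directly. A small sketch, assuming toy sizes B=2, L=5, H=2, d_k=3 and a random tensor standing in for a projected Q:

```python
import torch

B, L, H, d_k = 2, 5, 2, 3
projected = torch.randn(B, L, H * d_k)  # stand-in for a projected Q

split = projected.view(B, L, H, d_k)

# view() returns a new shape over the same memory: no copy, no arithmetic.
same_storage = split.data_ptr() == projected.data_ptr()

# Splitting and then merging gives back the original tensor exactly.
round_trip = torch.equal(split.view(B, L, H * d_k), projected)

print(same_storage, round_trip)  # True True
```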

6. Common pitfalls

view(B,L,H,-1) does not "create multiple heads"; it only changes the shape. What actually gives each head its own subspace is the weight of the preceding Linear layer.
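
One way to see this is to compare each head after the `view` with a projection that uses only that head's rows of the Linear weight. A minimal sketch, assuming toy sizes and a standalone `nn.Linear` named `proj` standing in for query_projection:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, d_model, H, d_k = 2, 5, 8, 2, 3

proj = nn.Linear(d_model, d_k * H)  # stand-in for query_projection
x = torch.randn(B, L, d_model)

heads = proj(x).view(B, L, H, -1)  # (2, 5, 2, 3)

matches = []
for h in range(H):
    rows = slice(h * d_k, (h + 1) * d_k)
    W_h = proj.weight[rows]  # (d_k, d_model): the weight rows that feed head h
    b_h = proj.bias[rows]    # (d_k,)
    per_head = x @ W_h.T + b_h  # project with head h's rows only
    matches.append(torch.allclose(heads[:, :, h, :], per_head, atol=1e-5))

print(matches)  # [True, True]
```

Each head is the input seen through its own block of weight rows; the `view` only labels which block of output features belongs to which head.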

AttentionLayer can wrap FullAttention, ProbAttention, or DSAttention. When reading a model, look at both the shell and the inner attention.
