AttentionLayer: the multi-head projection shell
Abstract
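AttentionLayer does not implement the attention formula itself. It does four things: it projects the inputs into Q/K/V, splits them into heads, calls `inner_attention`, and merges the multi-head output back to `d_model`.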
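The four steps can be sketched end to end as follows. This is a minimal sketch, not the original source: `AttentionLayerSketch` and `ToyFullAttention` are hypothetical names, the toy inner attention is a plain scaled dot-product softmax that ignores masking, and `d_keys`/`d_values` are passed explicitly rather than defaulted.

```python
import math

import torch
import torch.nn as nn


class ToyFullAttention(nn.Module):
    """Hypothetical stand-in for an inner attention: scaled dot-product softmax.

    Expects q: (B, L, H, E), k: (B, S, H, E), v: (B, S, H, D) and returns
    out: (B, L, H, D) together with the weights attn: (B, H, L, S).
    """

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        # attn_mask, tau and delta are accepted only to match the call site; ignored here.
        scale = 1.0 / math.sqrt(queries.shape[-1])
        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * scale
        attn = torch.softmax(scores, dim=-1)
        out = torch.einsum("bhls,bshd->blhd", attn, values)
        return out.contiguous(), attn


class AttentionLayerSketch(nn.Module):
    """Hypothetical minimal projection shell mirroring the four steps."""

    def __init__(self, attention, d_model, n_heads, d_keys, d_values):
        super().__init__()
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        # steps 1 and 2: project, then split the last dimension into H heads
        q = self.query_projection(queries).view(B, L, H, -1)
        k = self.key_projection(keys).view(B, S, H, -1)
        v = self.value_projection(values).view(B, S, H, -1)
        # step 3: delegate the actual attention formula
        out, attn = self.inner_attention(q, k, v, attn_mask, tau=tau, delta=delta)
        # step 4: merge heads and project back to d_model
        out = out.view(B, L, -1)
        return self.out_projection(out), attn


torch.manual_seed(0)
layer = AttentionLayerSketch(ToyFullAttention(), d_model=8, n_heads=2, d_keys=3, d_values=4)
x_q = torch.randn(2, 5, 8)
x_kv = torch.randn(2, 6, 8)
y, attn = layer(x_q, x_kv, x_kv)
print(y.shape, attn.shape)
```

Keeping the softmax inside a separate module is what lets the same shell be reused with a different inner attention.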
1. Where it sits in the source code
```python
import torch.nn as nn


class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        ...
        # the wrapped module that implements the actual attention formula
        self.inner_attention = attention
        # Q and K project to n_heads * d_keys, V projects to n_heads * d_values
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        # merge: concatenated heads -> d_model
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
```

2. Diagram
![[zdocs/pytorch-basics/assets/self_attention_attentionlayer.svg]]
3. I/O
Toy configuration:

```text
B=2, L=5, S=6, d_model=8, H=2, d_keys=3, d_values=4
```

| Tensor | Shape | Description |
|---|---|---|
| queries | (2,5,8) | query sequence (input) |
| keys | (2,6,8) | key sequence (input) |
| values | (2,6,8) | value sequence (input) |
| output | (2,5,8) | back to d_model |
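As a quick sanity check on the toy configuration in the table above, here is a sketch that builds only the four `nn.Linear` layers from section 1 (assuming the default `bias=True`) and reports their weight shapes and parameter counts. B, L and S do not appear: the projections act only on the last dimension.

```python
import torch.nn as nn

# Toy configuration from the table above
d_model, H, d_keys, d_values = 8, 2, 3, 4

projections = {
    "query_projection": nn.Linear(d_model, d_keys * H),    # 8 -> 6
    "key_projection": nn.Linear(d_model, d_keys * H),      # 8 -> 6
    "value_projection": nn.Linear(d_model, d_values * H),  # 8 -> 8
    "out_projection": nn.Linear(d_values * H, d_model),    # 8 -> 8
}

for name, lin in projections.items():
    n_params = sum(p.numel() for p in lin.parameters())
    # nn.Linear stores its weight as (out_features, in_features)
    print(f"{name:17s} weight={tuple(lin.weight.shape)} params={n_params}")
```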
4. Line-by-line reading
Source:
```python
# forward() method of AttentionLayer
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    B, L, _ = queries.shape
    _, S, _ = keys.shape
    H = self.n_heads

    queries = self.query_projection(queries).view(B, L, H, -1)
    keys = self.key_projection(keys).view(B, S, H, -1)
    values = self.value_projection(values).view(B, S, H, -1)

    out, attn = self.inner_attention(
        queries, keys, values, attn_mask, tau=tau, delta=delta
    )
    out = out.view(B, L, -1)

    return self.out_projection(out), attn
```

Annotations:
```python
B, L, _ = queries.shape
# queries: (2,5,8), so B=2, L=5
_, S, _ = keys.shape
# keys: (2,6,8), so S=6
H = self.n_heads
# H = 2
```

Q/K/V projections:
```python
queries = self.query_projection(queries).view(B, L, H, -1)
# Linear(8 -> d_keys*n_heads = 3*2 = 6)
# (2,5,8) -> (2,5,6)
# view(2,5,2,-1) -> (2,5,2,3)
keys = self.key_projection(keys).view(B, S, H, -1)
# (2,6,8) -> Linear -> (2,6,6) -> view -> (2,6,2,3)
values = self.value_projection(values).view(B, S, H, -1)
# Linear(8 -> d_values*n_heads = 4*2 = 8)
# (2,6,8) -> (2,6,8) -> view -> (2,6,2,4)
```

Calling the inner attention:
```python
out, attn = self.inner_attention(queries, keys, values, ...)
# if inner_attention is FullAttention:
# out: (B,L,H,d_values) = (2,5,2,4)
```

Merging the heads:
```python
out = out.view(B, L, -1)
# (2,5,2,4) -> (2,5,8)
return self.out_projection(out), attn
# Linear(d_values*n_heads = 8 -> d_model = 8)
# (2,5,8) -> (2,5,8)
```

5. The math
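The mathematical core of AttentionLayer is not softmax but linear projection (writing $d_k$ for `d_keys` and $d_v$ for `d_values`):

$$
Q = X_q W_Q + b_Q,\qquad K = X_k W_K + b_K,\qquad V = X_v W_V + b_V
$$

where $W_Q, W_K \in \mathbb{R}^{d_{model} \times H d_k}$ and $W_V \in \mathbb{R}^{d_{model} \times H d_v}$.

Splitting into heads only reinterprets the last dimension:

$$
Q \in \mathbb{R}^{B \times L \times H d_k} \;\longrightarrow\; \mathbb{R}^{B \times L \times H \times d_k},\qquad Q^{(h)} = Q_{:,\,:,\;h d_k : (h+1) d_k},\quad h = 0, \dots, H-1
$$

Finally the heads are merged:

$$
O = \left[\, O^{(0)} \,\|\, O^{(1)} \,\|\, \cdots \,\|\, O^{(H-1)} \right] \in \mathbb{R}^{B \times L \times H d_v},\qquad Y = O\, W_O + b_O,\quad W_O \in \mathbb{R}^{H d_v \times d_{model}}
$$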
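The split and merge formulas can be checked numerically. This sketch uses random tensors in place of the real projection and inner-attention outputs (an assumption for illustration) and confirms that `view(B, L, H, -1)` is column slicing and `view(B, L, -1)` is concatenation along the last dimension.

```python
import torch

B, L, H, d_k, d_v = 2, 5, 2, 3, 4

# Split: view(B, L, H, -1) regroups the last dimension into H blocks of d_k columns.
Q = torch.randn(B, L, H * d_k)          # stands in for the query_projection output, (2, 5, 6)
Q_heads = Q.view(B, L, H, -1)           # (2, 5, 2, 3)
for h in range(H):
    # head h sees columns h*d_k ... (h+1)*d_k - 1 of the projected tensor
    assert torch.equal(Q_heads[:, :, h, :], Q[:, :, h * d_k:(h + 1) * d_k])

# Merge: view(B, L, -1) equals concatenating the heads along the last dimension.
O = torch.randn(B, L, H, d_v)           # stands in for the inner attention output, (2, 5, 2, 4)
merged = O.view(B, L, -1)               # (2, 5, 8)
concat = torch.cat([O[:, :, h, :] for h in range(H)], dim=-1)
print(torch.equal(merged, concat))
```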
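6. Common misconceptions
`view(B, L, H, -1)` does not "create" multiple heads; it only changes the shape. What gives each head its own subspace is the preceding Linear layer: each head reads its own slice of that layer's output features, and therefore its own block of the projection weights.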
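To make this concrete, here is a sketch (toy sizes assumed) showing that head `h` of the viewed output is exactly a smaller, independent linear map built from rows `h*d_k` to `(h+1)*d_k - 1` of the PyTorch-stored weight and bias. The row/column distinction is only because `nn.Linear` stores its weight as (out_features, in_features).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, H, d_k = 8, 2, 3
query_projection = nn.Linear(d_model, d_k * H)  # weight: (6, 8), bias: (6,)
x = torch.randn(2, 5, d_model)

# What AttentionLayer does: one big Linear, then view into heads
q_heads = query_projection(x).view(2, 5, H, -1)

W, b = query_projection.weight, query_projection.bias
for h in range(H):
    rows = slice(h * d_k, (h + 1) * d_k)
    # Head h behaves like its own Linear(d_model -> d_k) made of a block of W and b
    q_h = F.linear(x, W[rows], b[rows])
    print(h, torch.allclose(q_heads[:, :, h, :], q_h, atol=1e-6))
```

Fusing all heads into one Linear is simply a cheaper way to run H small projections at once.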
AttentionLayer can wrap FullAttention, ProbAttention, or DSAttention, so when reading a model, look at both the shell and the inner attention it wraps.
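The shell does not care which inner attention it wraps, but it does rely on a contract: `out` must come back as (B, L, H, d_values) in a memory layout that `view` accepts. This sketch uses two hypothetical inner-attention outputs with identical values. One is a transposed, non-contiguous tensor, and the shell's merge step `out.view(B, L, -1)` rejects it.

```python
import torch

B, L, H, D = 2, 5, 2, 4

# Hypothetical inner attention A: computes in (B, H, L, D) and returns a transposed view.
ctx = torch.randn(B, H, L, D)
out_noncontig = ctx.transpose(1, 2)       # shape (B, L, H, D), but not contiguous

# Hypothetical inner attention B: same values, but calls .contiguous() before returning.
out_contig = ctx.transpose(1, 2).contiguous()

# The shell's merge step is out.view(B, L, -1); it needs a compatible memory layout.
print(out_contig.view(B, L, -1).shape)
try:
    out_noncontig.view(B, L, -1)
except RuntimeError as err:
    print("view failed:", type(err).__name__)
```

Either `.contiguous()` in the inner attention or `.reshape` in the shell would avoid the error; which side takes responsibility is exactly why both have to be read together.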