Appearance
Attention 基础操作:view、matmul、einsum、softmax、topk
Abstract
这篇不重新讲完整 attention 理论。
它只解释源码中最容易卡住的基础操作:Q/K/V 怎样拆多头,分数怎样算,softmax 在哪一维做,Informer 为什么还要
topk。
0. 文件索引
| 项目 | 内容 |
|---|---|
| 源文件 | ts_benchmark/baselines/time_series_library/layers/SelfAttention_Family.py |
| 相关类 | AttentionLayer / FullAttention / ProbAttention |
| 覆盖模型 | Informer / PatchTST |
| toy 参数 | B=2, L=6, S=6, H=2, d_model=16, d_k=8 |
1. Level 1:AttentionLayer 里的多头拆分
源码:
python
class AttentionLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)
out = out.view(B, L, -1)
return self.out_projection(out), attntoy:
text
queries: (B, L, d_model) = (2, 6, 16)
H = 2
d_k = d_model // H = 8逐步:
text
Linear:
(2, 6, 16) -> (2, 6, 16)
view(B, L, H, -1):
(2, 6, 16) -> (2, 6, 2, 8)-1 的意思是:
这一维让 PyTorch 自动推断。这里自动推断成
8。
下图按本文 toy 参数重画这条 shape 链:(2,6,16) 先经过 Linear,再用 view(B,L,H,-1) 拆成 (2,6,2,8)。
2. Level 2:FullAttention 里的 einsum
源码:
python
class FullAttention(nn.Module):
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
...
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
return V.contiguous(), None输入:
text
queries: (B, L, H, E) = (2, 6, 2, 8)
keys: (B, S, H, E) = (2, 6, 2, 8)
values: (B, S, H, D) = (2, 6, 2, 8)第一句:
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)下标解释:
| 字母 | 含义 |
|---|---|
b | batch |
l | query 位置 |
s | key/value 位置 |
h | head |
e | key/query 向量维度,被求和消掉 |
输出:
text
scores.shape = (B, H, L, S) = (2, 2, 6, 6)这就是每个 head 内,每个 query 对每个 key 的相似度分数。
上面的图右下角同时画了 scores[b,h,:,:] 的矩阵语义:横轴是 key 位置 S,纵轴是 query 位置 L;softmax(dim=-1) 就是在每一行上做归一化。
3. Level 3:softmax 为什么是 dim=-1
源码:
python
A = self.dropout(torch.softmax(scale * scores, dim=-1))此时:
text
scores.shape = (B, H, L, S)最后一维 S 是 key 位置。
softmax(dim=-1) 表示:
对每个 query,让它看所有 key 的分数归一化成概率。
所以:
text
A[b, h, l, :].sum() = 14. Level 4:第二个 einsum 做加权求和
源码:
python
V = torch.einsum("bhls,bshd->blhd", A, values)输入:
text
A: (B, H, L, S)
values: (B, S, H, D)输出:
text
V: (B, L, H, D)含义:
对每个 query 位置
l,用它对所有 key/value 位置s的 attention 权重,加权求和 value 向量。
5. Level 5:可算小例子
只看一个 query、一个 head。
attention 分数:
text
scores = [2.0, 1.0, 0.0]softmax 后:
text
weights ≈ [0.665, 0.245, 0.090]value:
text
v0 = [10, 0]
v1 = [0, 20]
v2 = [10, 10]输出:
text
out = 0.665*v0 + 0.245*v1 + 0.090*v2
= 0.665*[10,0] + 0.245*[0,20] + 0.090*[10,10]
= [7.55, 5.80]这张图对应上面这个可算例子,只看一个 query、一个 head:
这就是 softmax(QK^T)V。
6. Level 6:ProbAttention 里的 matmul
Informer 的 ProbAttention 不直接对所有 query-key 组合完整计算,而是先采样估计重要 query。
源码片段:
python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()这里的核心是:
python
torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1))如果:
text
Q.unsqueeze(-2): (..., 1, E)
K_sample.transpose(-2,-1):(..., E, sample_k)那么:
text
matmul -> (..., 1, sample_k)也就是一个 query 和若干 sampled keys 的点积分数。
7. Level 7:ProbAttention 里的 topk
源码:
python
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]M 是每个 query 的稀疏度估计。
直觉:
text
如果某个 query 对少数 key 特别敏感,它的 max 会明显大于平均值。
这个 query 更值得完整计算 attention。topk 返回两个东西:
python
values, indices = M.topk(n_top)源码里取 [1]:
python
M_top = M.topk(n_top, sorted=False)[1]表示只要 top query 的下标。
8. Level 8:FullAttention 和 ProbAttention 的区别
| 项目 | FullAttention | ProbAttention |
|---|---|---|
| 代表模型 | PatchTST | Informer |
| 核心计算 | 所有 query 对所有 key | 先采样估计,再只重点算 top query |
| 主要函数 | einsum / softmax | randint / matmul / topk / softmax |
| 理解重点 | 标准 QK^T | 稀疏 query 选择 |
9. 常见错误
9.1 忘记 view 只是改形状,不是新建 attention head
真正生成 Q/K/V 的是:
python
Linear拆成多头的是:
python
view(B, L, H, -1)9.2 softmax 维度看错
在 attention 里,softmax(dim=-1) 通常是在 key 维做。
也就是:
每个 query 对所有 key 的权重和为 1。
10. 一句话总结
Attention 里的基础操作可以压成一句:
Linear得到 Q/K/V,view拆多头,matmul/einsum算分数,softmax(dim=-1)变权重,再对 value 加权求和;Informer 的topk是为了只挑重要 query 精算。