Skip to content

Attention 基础操作:view、matmul、einsum、softmax、topk

Abstract

这篇不重新讲完整 attention 理论。

它只解释源码中最容易卡住的基础操作:Q/K/V 怎样拆多头,分数怎样算,softmax 在哪一维做,Informer 为什么还要 topk

0. 文件索引

项目内容
源文件ts_benchmark/baselines/time_series_library/layers/SelfAttention_Family.py
相关类AttentionLayer / FullAttention / ProbAttention
覆盖模型Informer / PatchTST
toy 参数B=2, L=6, S=6, H=2, d_model=16, d_k=8

1. Level 1:AttentionLayer 里的多头拆分

源码:

python
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

toy:

text
queries: (B, L, d_model) = (2, 6, 16)
H = 2
d_k = d_model // H = 8

逐步:

text
Linear:
  (2, 6, 16) -> (2, 6, 16)

view(B, L, H, -1):
  (2, 6, 16) -> (2, 6, 2, 8)

-1 的意思是:

这一维让 PyTorch 自动推断。这里自动推断成 8

下图按本文 toy 参数重画这条 shape 链:(2,6,16) 先经过 Linear,再用 view(B,L,H,-1) 拆成 (2,6,2,8)

2. Level 2:FullAttention 里的 einsum

源码:

python
class FullAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            ...

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        return V.contiguous(), None

输入:

text
queries: (B, L, H, E) = (2, 6, 2, 8)
keys:    (B, S, H, E) = (2, 6, 2, 8)
values:  (B, S, H, D) = (2, 6, 2, 8)

第一句:

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

下标解释:

字母含义
bbatch
lquery 位置
skey/value 位置
hhead
ekey/query 向量维度,被求和消掉

输出:

text
scores.shape = (B, H, L, S) = (2, 2, 6, 6)

这就是每个 head 内,每个 query 对每个 key 的相似度分数。

上面的图右下角同时画了 scores[b,h,:,:] 的矩阵语义:横轴是 key 位置 S,纵轴是 query 位置 Lsoftmax(dim=-1) 就是在每一行上做归一化。

3. Level 3:softmax 为什么是 dim=-1

源码:

python
A = self.dropout(torch.softmax(scale * scores, dim=-1))

此时:

text
scores.shape = (B, H, L, S)

最后一维 S 是 key 位置。

softmax(dim=-1) 表示:

对每个 query,让它看所有 key 的分数归一化成概率。

所以:

text
A[b, h, l, :].sum() = 1

4. Level 4:第二个 einsum 做加权求和

源码:

python
V = torch.einsum("bhls,bshd->blhd", A, values)

输入:

text
A:      (B, H, L, S)
values: (B, S, H, D)

输出:

text
V: (B, L, H, D)

含义:

对每个 query 位置 l,用它对所有 key/value 位置 s 的 attention 权重,加权求和 value 向量。

5. Level 5:可算小例子

只看一个 query、一个 head。

attention 分数:

text
scores = [2.0, 1.0, 0.0]

softmax 后:

text
weights ≈ [0.665, 0.245, 0.090]

value:

text
v0 = [10, 0]
v1 = [0, 20]
v2 = [10, 10]

输出:

text
out = 0.665*v0 + 0.245*v1 + 0.090*v2
    = 0.665*[10,0] + 0.245*[0,20] + 0.090*[10,10]
    = [7.55, 5.80]

这张图对应上面这个可算例子,只看一个 query、一个 head:

这就是 softmax(QK^T)V

6. Level 6:ProbAttention 里的 matmul

Informer 的 ProbAttention 不直接对所有 query-key 组合完整计算,而是先采样估计重要 query。

源码片段:

python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

这里的核心是:

python
torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1))

如果:

text
Q.unsqueeze(-2):          (..., 1, E)
K_sample.transpose(-2,-1):(..., E, sample_k)

那么:

text
matmul -> (..., 1, sample_k)

也就是一个 query 和若干 sampled keys 的点积分数。

7. Level 7:ProbAttention 里的 topk

源码:

python
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]

M 是每个 query 的稀疏度估计。

直觉:

text
如果某个 query 对少数 key 特别敏感,它的 max 会明显大于平均值。
这个 query 更值得完整计算 attention。

topk 返回两个东西:

python
values, indices = M.topk(n_top)

源码里取 [1]

python
M_top = M.topk(n_top, sorted=False)[1]

表示只要 top query 的下标。

8. Level 8:FullAttention 和 ProbAttention 的区别

项目FullAttentionProbAttention
代表模型PatchTSTInformer
核心计算所有 query 对所有 key先采样估计,再只重点算 top query
主要函数einsum / softmaxrandint / matmul / topk / softmax
理解重点标准 QK^T稀疏 query 选择

9. 常见错误

9.1 忘记 view 只是改形状,不是新建 attention head

真正生成 Q/K/V 的是:

python
Linear

拆成多头的是:

python
view(B, L, H, -1)

9.2 softmax 维度看错

在 attention 里,softmax(dim=-1) 通常是在 key 维做。

也就是:

每个 query 对所有 key 的权重和为 1。

10. 一句话总结

Attention 里的基础操作可以压成一句:

Linear 得到 Q/K/V,view 拆多头,matmul/einsum 算分数,softmax(dim=-1) 变权重,再对 value 加权求和;Informer 的 topk 是为了只挑重要 query 精算。

*记录并在线阅读我的笔记*