Skip to content

ProbAttention:ProbSparse 稀疏注意力

Abstract

ProbAttention 是 Informer 的核心:不对所有 query 都完整计算 attention。

它先随机采样 key 估计每个 query 的稀疏度,只挑 top query 精算,再把结果写回 context。

1. 图解

![[zdocs/pytorch-basics/assets/self_attention_probattention.svg]]

2. 源码结构

ProbAttention 有 4 个关键方法:

方法作用
_prob_QK随机采样 key,估计 query 稀疏度,选 top query 并精算 QK
_get_initial_context为所有 query 初始化默认 context
_update_context只更新 top query 的 context
forward串起上面三步

3. forward 主链

源码:

python
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape

queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)

U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q

scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

toy:

text
queries: (B,L_Q,H,D) = (2,5,2,4)
keys:    (B,L_K,H,D) = (2,6,2,4)
factor = 2

transpose 后:

text
queries: (2,2,5,4)
keys:    (2,2,6,4)
values:  (2,2,6,4)

采样数量:

Upart=factorln(LK)=2ln(6)=4

top query 数:

u=factorln(LQ)=2ln(5)=4

所以这组 toy 中:每个 query 采样 4 个 key,最后选 4 个 query 精算。

4. _prob_QK 精读

源码:

python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]

shape:

text
Q: (B,H,L_Q,D) = (2,2,5,4)
K: (B,H,L_K,D) = (2,2,6,4)

K_expand:
(2,2,6,4) -> unsqueeze(-3) -> (2,2,1,6,4)
expand -> (2,2,5,6,4)

K_sample:
(2,2,5,4,4)

Q.unsqueeze(-2):
(2,2,5,4) -> (2,2,5,1,4)

K_sample.transpose(-2,-1):
(2,2,5,4,4) -> (2,2,5,4,4)

Q_K_sample:
(2,2,5,1,4) @ (2,2,5,4,4) -> (2,2,5,1,4) -> squeeze -> (2,2,5,4)

稀疏度公式:

Mi=maxj(QiKj)1LKjQiKj

直觉:

text
如果 max 很大、平均值不大,说明该 query 只强烈关注少数 key。
这种 query 是 ProbSparse 认为“值得精算”的 query。

5. _get_initial_context

源码:

python
if not self.mask_flag:
    V_sum = V.mean(dim=-2)
    contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else:
    assert L_Q == L_V
    contex = V.cumsum(dim=-2)

非 mask 场景:

contextb,h,l,:=1LVsVb,h,s,:

也就是所有 query 先拿同一个 value 平均值当默认输出。

mask 场景:

contextb,h,l,:=slVb,h,s,:

这用于 causal self-attention。

6. _update_context

源码:

python
attn = torch.softmax(scores, dim=-1)

context_in[
    torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)

含义:

text
只更新 index 指向的 top query。
没被选中的 query 继续保留 initial_context。

加权求和公式:

contextb,h,indexi,d=sattnb,h,i,sVb,h,s,d

7. 输出 shape 注意

FullAttention 返回:

text
(B,L,H,D)

ProbAttention.forward 返回:

text
(B,H,L,D)

这看起来不一致,但 Informer 后续代码会配合它使用。读代码时一定要看调用者是否期待这个格式。

8. 常见误区

ProbAttention 不是“不算 attention”,而是只对 top query 完整算 attention。

topk 选的是 query,不是 key。

factor 越大,采样 key 和 top query 越多,越接近 FullAttention,但计算更重。

*记录并在线阅读我的笔记*