Appearance
ProbAttention:ProbSparse 稀疏注意力
Abstract
ProbAttention是 Informer 的核心:不对所有 query 都完整计算 attention。它先随机采样 key 估计每个 query 的稀疏度,只挑 top query 精算,再把结果写回 context。
1. 图解
![[zdocs/pytorch-basics/assets/self_attention_probattention.svg]]
2. 源码结构
ProbAttention 有 4 个关键方法:
| 方法 | 作用 |
|---|---|
_prob_QK | 随机采样 key,估计 query 稀疏度,选 top query 并精算 QK |
_get_initial_context | 为所有 query 初始化默认 context |
_update_context | 只更新 top query 的 context |
forward | 串起上面三步 |
3. forward 主链
源码:
python
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)toy:
text
queries: (B,L_Q,H,D) = (2,5,2,4)
keys: (B,L_K,H,D) = (2,6,2,4)
factor = 2transpose 后:
text
queries: (2,2,5,4)
keys: (2,2,6,4)
values: (2,2,6,4)采样数量:
top query 数:
所以这组 toy 中:每个 query 采样 4 个 key,最后选 4 个 query 精算。
4. _prob_QK 精读
源码:
python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]shape:
text
Q: (B,H,L_Q,D) = (2,2,5,4)
K: (B,H,L_K,D) = (2,2,6,4)
K_expand:
(2,2,6,4) -> unsqueeze(-3) -> (2,2,1,6,4)
expand -> (2,2,5,6,4)
K_sample:
(2,2,5,4,4)
Q.unsqueeze(-2):
(2,2,5,4) -> (2,2,5,1,4)
K_sample.transpose(-2,-1):
(2,2,5,4,4) -> (2,2,5,4,4)
Q_K_sample:
(2,2,5,1,4) @ (2,2,5,4,4) -> (2,2,5,1,4) -> squeeze -> (2,2,5,4)稀疏度公式:
直觉:
text
如果 max 很大、平均值不大,说明该 query 只强烈关注少数 key。
这种 query 是 ProbSparse 认为“值得精算”的 query。5. _get_initial_context
源码:
python
if not self.mask_flag:
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else:
assert L_Q == L_V
contex = V.cumsum(dim=-2)非 mask 场景:
也就是所有 query 先拿同一个 value 平均值当默认输出。
mask 场景:
这用于 causal self-attention。
6. _update_context
源码:
python
attn = torch.softmax(scores, dim=-1)
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)含义:
text
只更新 index 指向的 top query。
没被选中的 query 继续保留 initial_context。加权求和公式:
7. 输出 shape 注意
FullAttention 返回:
text
(B,L,H,D)ProbAttention.forward 返回:
text
(B,H,L,D)这看起来不一致,但 Informer 后续代码会配合它使用。读代码时一定要看调用者是否期待这个格式。
8. 常见误区
ProbAttention 不是“不算 attention”,而是只对 top query 完整算 attention。
topk 选的是 query,不是 key。
factor 越大,采样 key 和 top query 越多,越接近 FullAttention,但计算更重。