Layer 5: ProbAttention Deep Dive
Called by `self.inner_attention(Q, K, V, ...)` inside AttentionLayer (04-Layer4-AttentionLayer).
This document covers ProbAttention.forward and its three helper functions: _prob_QK, _get_initial_context and _update_context.
1. Position in the parent layer
```
AttentionLayer.forward
└─ self.inner_attention(Q, K, V, attn_mask)    ← ProbAttention.forward (this document)
   ├─ _prob_QK(Q, K, sample_k, n_top)          ← sparse query selection
   ├─ _get_initial_context(V, L_Q)             ← initialize the context
   └─ _update_context(context, V, ...)         ← fill in exact values for the active queries
```
Q/K/V arrive in the (B, L, H, D) = (3, 10, 4, 2) layout, which is the multi-head layout AttentionLayer has already split. ProbAttention first applies transpose(2,1) to switch to (B, H, L, D) for internal processing. It returns context: (B, H, L_Q, D) = (3, 4, 10, 2).
2. I/O interface
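```python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    ...  # full body shown below; tau and delta are not used in it
```
The example below uses the self-attention of EncoderLayer 0 (the toy baseline):

| | shape (toy) | meaning |
|---|---|---|
| input queries | (3, 10, 4, 2) = (B, L_Q, H, D) | multi-head query split by AttentionLayer |
| input keys | (3, 10, 4, 2) | multi-head key |
| input values | (3, 10, 4, 2) | multi-head value |
| output context | (3, 4, 10, 2) = (B, H, L_Q, D) | aggregated context (exact for the top-u queries, mean(V) for the rest) |
| output attn | None | None when output_attention=False |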
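⚠️ The output layout is (B, H, L_Q, D), not the (B, L, H, D) layout that AttentionLayer passed in.

AttentionLayer merges the heads of the returned tensor with .view(B, L, -1). Because .view requires a contiguous tensor, forward returns context.contiguous().

Note that .view does not undo the transpose. It reinterprets the (B, H, L_Q, D) memory as (B, L, H·D), so for H > 1 each output row mixes heads and positions instead of holding one position's H heads.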
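The following sketch makes the warning concrete, using random tensors of the toy sizes (all variable names are illustrative). It checks two things:

- `.view` fails on the transposed tensor.
- `.contiguous().view(B, L, -1)` gives a different result from transposing the heads back first.

`heads_back` is a hypothetical head-preserving merge, included only for comparison. It is not what AttentionLayer does.

```python
import torch

torch.manual_seed(0)
B, L, H, D = 3, 10, 4, 2
x = torch.randn(B, L, H, D)   # (B, L, H, D): the layout AttentionLayer passes in
ctx = x.transpose(2, 1)       # (B, H, L, D): a strided view, no data copied

print(ctx.is_contiguous())    # False: transpose only swaps strides
raised = False
try:
    ctx.view(B, L, -1)        # view cannot merge these strides
except RuntimeError:
    raised = True
print("view raised RuntimeError:", raised)

merged = ctx.contiguous().view(B, L, -1)            # the .contiguous() + .view path
heads_back = ctx.transpose(2, 1).reshape(B, L, -1)  # undo the transpose first, then merge heads

print(torch.equal(heads_back, x.reshape(B, L, -1)))        # True: one position's H heads per row
print(torch.equal(merged, heads_back))                      # False: rows mix heads and positions
print(torch.equal(merged[0, 0], ctx[0, 0, :4].reshape(-1)))  # True: row 0 = head 0, positions 0..3
```

With the toy sizes, row 0 of `merged` therefore holds positions 0 to 3 of head 0 rather than the four heads of position 0.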
3. Sequence diagram (concrete level)
4. Semantic grouping diagram (index level)
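Core idea: most queries are "lazy" (their attention is close to uniform), so their output ≈ mean(V) and can be filled in without computing attention. Exact attention is computed only for a few "active" queries.

The cost savings come from two limits:
- each of the L_Q queries is scored against only U_part = c·ln L_K sampled keys;
- only u = c·ln L_Q queries get full attention over all keys.

Together these bring the overall cost down from O(L²) to O(L log L).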
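The core idea can be condensed into a short runnable sketch. This is a simplified re-implementation for illustration, not the TFB source:

- The function name `prob_sparse_attention` is hypothetical.
- It covers only the unmasked path (mask_flag=False).
- It has no dropout and no attention-map output.
- It gathers the sampled keys with `K[:, :, idx, :]` instead of building `K_expand`. This is an equivalent gather for the same per-query sampled indices.

```python
import math

import torch


def prob_sparse_attention(Q, K, V, factor=1):
    """Simplified, unmasked ProbSparse attention (mask_flag=False path only).

    Q: (B, H, L_Q, D); K, V: (B, H, L_K, D). Returns (context, top_index).
    """
    B, H, L_Q, D = Q.shape
    L_K = K.shape[2]
    U_part = min(factor * math.ceil(math.log(L_K)), L_K)  # keys sampled per query
    u = min(factor * math.ceil(math.log(L_Q)), L_Q)       # active queries computed exactly

    # 1) Score each query against U_part randomly sampled keys.
    idx = torch.randint(L_K, (L_Q, U_part))                               # (L_Q, U_part)
    K_sample = K[:, :, idx, :]                                            # (B, H, L_Q, U_part, D)
    qk = (Q.unsqueeze(-2) @ K_sample.transpose(-2, -1)).squeeze(-2)       # (B, H, L_Q, U_part)

    # 2) Max-mean sparsity measurement, then keep the top-u queries.
    M = qk.max(dim=-1).values - qk.sum(dim=-1) / L_K                      # (B, H, L_Q)
    top = M.topk(u, dim=-1).indices                                       # (B, H, u)

    # 3) Exact softmax attention for the selected queries only.
    b = torch.arange(B)[:, None, None]
    h = torch.arange(H)[None, :, None]
    scores = Q[b, h, top] @ K.transpose(-2, -1) / math.sqrt(D)           # (B, H, u, L_K)
    attn = scores.softmax(dim=-1)

    # 4) Lazy queries get mean(V); active rows are overwritten with exact outputs.
    context = V.mean(dim=-2, keepdim=True).expand(B, H, L_Q, D).clone()
    context[b, h, top] = attn @ V
    return context, top


torch.manual_seed(0)
Q, K, V = (torch.randn(3, 4, 10, 2) for _ in range(3))
full = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(2), dim=-1) @ V

ctx5, _ = prob_sparse_attention(Q, K, V, factor=5)     # u = min(15, 10) = 10: every query exact
ctx1, top1 = prob_sparse_attention(Q, K, V, factor=1)  # u = 3 exact rows per (batch, head)
print(torch.allclose(ctx5, full, atol=1e-5))
```

The two factor settings behave differently on the toy sizes:

- With factor=5, u is capped at L_Q = 10. Every query is selected, so the sketch reproduces full attention.
- With factor=1, only 3 rows per (batch, head) are exact, and the remaining 7 equal mean(V).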
5. Step-by-step walkthrough
5.1 Big picture
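The paper's core observation: the attention distributions of a standard Transformer are sparse and long-tailed. A few query-key dot products carry most of the weight, while most queries attend almost uniformly.

The sparsity measurement M = max − mean quantifies how active a query is. _prob_QK computes it over the U_part sampled keys S_i of query i, before scaling:

$$M(q_i, K) = \max_{j \in S_i} q_i k_j^\top - \frac{1}{L_K} \sum_{j \in S_i} q_i k_j^\top$$

A large M means the query's score row is far from uniform, i.e. an active query. In the paper's full form (max minus mean over all L_K keys), a perfectly uniform score row gives M = 0.

The paper also divides the scores by √d. That multiplies every M by the same positive constant, so it leaves the top-u ranking unchanged.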
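The formula can be checked by hand on the toy numbers used in the _prob_QK trace. This pure-Python sketch follows the code's convention: no 1/√D factor, and division by L_K rather than by the number of sampled keys. The helper names `sampled_scores` and `sparsity_measure` are hypothetical.

```python
def sampled_scores(q, keys):
    """Dot products q·k for each sampled key k (no 1/sqrt(D) scaling, as in _prob_QK)."""
    return [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def sparsity_measure(scores, L_K):
    """M = max of the sampled scores minus their sum divided by the full key count L_K."""
    return max(scores) - sum(scores) / L_K


q0 = [0.5, -0.3]
keys_for_q0 = [[0.8, 0.1], [0.2, 0.9], [-0.4, 0.6]]  # keys 2, 7, 9 from the toy trace
print(round(sparsity_measure(sampled_scores(q0, keys_for_q0), L_K=10), 3))  # 0.388 (active q0)
print(round(sparsity_measure([0.09, 0.11, 0.02], L_K=10), 3))              # 0.088 (lazy q5)
print(sparsity_measure([1.0, 1.0, 1.0], L_K=3))  # 0.0: all keys used, perfectly uniform scores
```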
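The two parameters are easy to confuse. The key distinction:
- U_part = factor·⌈ln L_K⌉ is how many keys each query samples. It is used only to estimate M.
- u = factor·⌈ln L_Q⌉ is how many top-M queries (the top-u) receive exact attention over all L_K keys.

Each parameter is capped at the corresponding sequence length.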
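The two formulas are easy to tabulate. The pure-Python sketch below uses `min` in place of the `x if x < L else L` caps in forward; the function name `sampling_params` is hypothetical. It reproduces the scenario values derived in the parameter section, and also shows what the constructor default factor=5 does at the toy length.

```python
import math


def sampling_params(L_Q, L_K, factor):
    """Return (U_part, u): factor*ceil(ln L) for keys and queries, each capped at its length."""
    U_part = min(factor * math.ceil(math.log(L_K)), L_K)
    u = min(factor * math.ceil(math.log(L_Q)), L_Q)
    return U_part, u


scenarios = [
    ("encoder layer 0 self-attention", 10, 10),
    ("encoder layer 1 self-attention", 6, 6),
    ("decoder cross-attention", 12, 6),
    ("decoder masked self-attention", 12, 12),
]
for name, L_Q, L_K in scenarios:
    print(name, sampling_params(L_Q, L_K, factor=1))  # (3, 3), (2, 2), (2, 3), (3, 3)
print(sampling_params(10, 10, factor=5))              # (10, 10): both values hit the cap
```

With factor=5 and L = 10, both values hit the cap, so every query is computed exactly and nothing is sparse.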
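A small example (L_Q = L_K = 10, factor = 1, so U_part = u = 3; the M values are illustrative):
```
All queries:   q0   q1   q2   q3   q4   q5   q6   q7   q8   q9
                ↓ estimate activity M from 3 randomly sampled keys
Activity M:    0.38 0.09 0.41 0.07 0.35 0.12 0.08 0.15 0.06 0.20
                ↓ topk(3): pick the 3 queries with the highest M
Top-3 chosen:  q0 q2 q4

Stage 1: initialize (every row = mean(V)):
  context[q0..q9] = [0.46, 0.45]          ← identical at every position

Stage 2: exact update (only 3 rows):
  context[q0] = softmax(q0·K.T/√D) @ V    ← exact value
  context[q2] = softmax(q2·K.T/√D) @ V
  context[q4] = softmax(q4·K.T/√D) @ V
  context[q1/q3/q5..q9] = [0.46, 0.45]    ← unchanged
```
5.1.1 Full original code (forward)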
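What this section does
forward is ProbAttention's entry point. It runs these steps in order:
1. Transpose the layout.
2. Compute U_part and u.
3. Select active queries with _prob_QK.
4. Scale the scores.
5. Initialize the context with _get_initial_context.
6. Apply the exact update with _update_context.
7. Return the context.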
```python
# SelfAttention_Family.py:95-206
# Module-level imports this excerpt relies on:
import numpy as np
import torch
import torch.nn as nn
from math import sqrt


class ProbAttention(nn.Module):
    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)  # not used by any method shown in this document

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        # (B, L, H, D) -> (B, H, L, D)
        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()  # c*ln(L_K) sampled keys
        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()  # c*ln(L_Q) active queries

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1.0 / sqrt(D)
        if scale is not None:  # always true here, since scale falls back to 1/sqrt(D)
            scores_top = scores_top * scale

        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )
        return context.contiguous(), attn
```
5.1.2 Input transformation + parameter computation
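What this section does
transpose(2,1) converts the (B,L,H,D) layout passed in by AttentionLayer into ProbAttention's internal (B,H,L,D) layout.

U_part and u are the two core ProbSparse parameters. U_part controls the number of sampled keys, and u controls the number of active queries.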
```python
B, L_Q, H, D = queries.shape       # B=3, L_Q=10, H=4, D=2
_, L_K, _, _ = keys.shape          # L_K=10
queries = queries.transpose(2, 1)  # (3,10,4,2) → (3,4,10,2) = (B,H,L_Q,D)
keys = keys.transpose(2, 1)        # (3,10,4,2) → (3,4,10,2) = (B,H,L_K,D)
values = values.transpose(2, 1)    # (3,10,4,2) → (3,4,10,2) = (B,H,L_K,D)
```
transpose(2,1) swaps dim 1 (sequence position L) and dim 2 (attention head H). AttentionLayer passes in (B,L,H,D), but ProbAttention needs (B,H,L,D). In that layout each head's (L, D) matrix occupies the last two dimensions, where matmul, topk and softmax operate head by head.
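```python
# toy setting: self.factor = 1 (the constructor default is factor=5)
U_part = 1 * np.ceil(np.log(10)).astype("int").item()
# np.log(10) ≈ 2.303 → ceil → 3 → U_part = 3 (keys randomly sampled per query)
u = 1 * np.ceil(np.log(10)).astype("int").item()
# u = 3 (number of top-u active queries selected)
U_part = 3 if 3 < 10 else 10  # = 3 ✓
u = 3 if 3 < 10 else 10       # = 3 ✓
```
Parameter derivation for different scenarios: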
```
Encoder layer 0 self-attention (L_K = L_Q = 10, factor = 1):
  U_part = ceil(ln 10) = 3, u = 3
Encoder layer 1 self-attention (L = 6 after distilling):
  U_part = ceil(ln 6) = 2, u = 2
Decoder cross-attention (L_K = 6, L_Q = 12):
  U_part = ceil(ln 6) = 2, u = ceil(ln 12) = 3
Decoder masked self-attention (L_Q = L_K = 12):
  U_part = ceil(ln 12) = 3, u = 3
```
5.2 _prob_QK: sparse query selection
Full original code
```python
def _prob_QK(self, Q, K, sample_k, n_top):
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    # find the Top_k query with sparisty measurement  ⚠️ typo: sparisty → sparsity
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    Q_reduce = Q[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
    ]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
    return Q_K, M_top
```
Annotated version:
```python
# Q: (3,4,10,2) = (B,H,L_Q,E)
# K: (3,4,10,2) = (B,H,L_K,E)
# sample_k=3, n_top=3
B, H, L_K, E = K.shape  # 3, 4, 10, 2
_, _, L_Q, _ = Q.shape  # L_Q=10

# ─── Step 1: expand K so each query can sample independently ───
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# K: (3,4,10,2)
# unsqueeze(-3): inserts a new axis at dim=2 → (3,4,1,10,2)
# expand → (3,4,10,10,2): every query position sees the full K (expand is a view; no memory is copied)

# ─── Step 2: randomly sample key indices ───
index_sample = torch.randint(L_K, (L_Q, sample_k))
# shape (10,3): row i holds the 3 key indices sampled for query i (drawn independently, so repeats are possible)

# ─── Step 3: gather the sampled key vectors ───
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
# the (10,1) and (10,3) index tensors broadcast to (10,3)
# K_sample: (3,4,10,3,2); K_sample[b,h,q] = the 3 sampled keys of query q

# ─── Step 4: dot-product scores ───
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# Q.unsqueeze(-2): (3,4,10,1,2)
# K_sample.transpose(-2,-1): (3,4,10,2,3)
# matmul: → (3,4,10,1,3) → .squeeze() → (3,4,10,3)
# ⚠️ a bare .squeeze() also drops any other size-1 axis (B=1, H=1, L_Q=1 or sample_k=1); .squeeze(-2) is safer

# ─── Step 5: sparsity measurement M ───
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
# max(-1)[0]: (3,4,10) ← max of the 3 sampled scores
# sum(-1)/L_K: (3,4,10) ← divided by the full key count L_K, not by sample_k,
#   so this is not the mean of the sampled scores
# M: (3,4,10); larger values mean a more active query
M_top = M.topk(n_top, sorted=False)[1]  # n_top=3 (topk's argument is k, so pass it positionally)
# M_top: (3,4,3) ← per (batch, head), the positions of the 3 queries with the largest M

# ─── Step 6: full Q×K.T only for the top-u queries ───
Q_reduce = Q[
    torch.arange(B)[:, None, None],  # (3,1,1)
    torch.arange(H)[None, :, None],  # (1,4,1)
    M_top,                           # (3,4,3)
    :,
]
# Q_reduce: (3,4,3,2) ← only the vectors of the top-3 queries
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
# K.transpose(-2,-1): (3,4,10,2) → (3,4,2,10)
# matmul: (3,4,3,2) × (3,4,2,10) = (3,4,3,10)
# Q_K: (3,4,3,10) ← full dot products of the top-3 queries with all 10 keys
return Q_K, M_top
# Q_K: (3,4,3,10) = scores_top
# M_top: (3,4,3) = index
```
Toy numeric trace (head h=0, batch b=0):
```
Q[0,0,0,:] = [0.5, -0.3]                        ← query 0 vector
Keys 2, 7, 9 were sampled:
K_sample[0,0,0,:,:] = [ [0.8,0.1], [0.2,0.9], [-0.4,0.6] ]
Q_K_sample[0,0,0,:] = [0.5×0.8+(-0.3)×0.1, 0.5×0.2+(-0.3)×0.9, 0.5×(-0.4)+(-0.3)×0.6]
                    = [0.37, -0.17, -0.38]
M[0,0,0] = max(0.37,-0.17,-0.38) - (0.37-0.17-0.38)/10
         = 0.37 - (-0.018) = 0.388             ← active

Compare a "lazy" query: Q[0,0,5,:] = [0.1, 0.1] (nearly uniform attention)
Q_K_sample[0,0,5,:] = [0.09, 0.11, 0.02]       (scores against its own sampled keys)
M[0,0,5] = 0.11 - 0.22/10 = 0.11 - 0.022 = 0.088  ← lazy, far below q0

Assume M_top[0,0,:] = [0, 2, 4] (the 3 highest M)
→ Q_K[0,0,:,:] = Q_reduce[0,0,:,:] @ K[0,0,:,:].T    shape (3,10)
```
5.3 _get_initial_context: initialize the context
Full original code
```python
def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        # V_sum = V.sum(dim=-2)
        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()  # ⚠️ contex → context
    else:  # use mask
        assert L_Q == L_V
        contex = V.cumsum(dim=-2)  # ⚠️ contex → context
    return contex
```
What the two paths mean:
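mask_flag=False (encoder self-attention + decoder cross-attention):
A lazy query's true output ≈ uniform attention × V = mean(V). Every position is filled with the same mean vector. The rows of the top-u active queries are later overwritten with exact values.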
```
mean(V) → broadcast to all L_Q positions
context = [ [0.46, 0.45],    ← q0 (lazy-query approximation)
            [0.46, 0.45],    ← q1
            ...
            [0.46, 0.45] ]   ← q9 (all identical)
```
mask_flag=True (decoder masked self-attention):
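Under the causal mask, position t can only see V[0..t]. The lazy approximation for position t is therefore the prefix sum cumsum(V) rather than the global mean.

Note that the code uses a sum, not a prefix mean, so the values are not divided by t+1.
```
context[t] = V[0] + V[1] + ... + V[t]    (each row differs and accumulates more terms as t grows)
```
Annotated version: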
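```python
# V: (3,4,10,2) = (B,H,L_V,D), L_Q=10
# ─── mask_flag=False ───
V_sum = V.mean(dim=-2)
# dim=-2 is the L_V axis; the mean gives (3,4,2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
# unsqueeze(-2): (3,4,2) → (3,4,1,2)
# expand: (3,4,10,2) ← broadcast to L_Q=10 positions
# .clone() is required: expand returns a stride-0 view in which all rows share memory,
#   and the later index assignment in _update_context needs each row to have its own storage
# ─── mask_flag=True ───
assert L_Q == L_V  # self-attention: Q and V have the same length
contex = V.cumsum(dim=-2)
# (3,4,10,2); each position holds the prefix sum of the values up to that position
```
Toy numeric trace (mask_flag=False, head h=0, batch b=0):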
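```
V[0,0,:,:] = [ [0.1,0.8],[0.5,0.3],[0.9,0.2],[0.4,0.6],[0.7,0.1],
               [0.2,0.5],[0.6,0.4],[0.3,0.7],[0.8,0.0],[0.1,0.9] ]
mean = [(0.1+0.5+0.9+0.4+0.7+0.2+0.6+0.3+0.8+0.1)/10,
        (0.8+0.3+0.2+0.6+0.1+0.5+0.4+0.7+0.0+0.9)/10]
     = [0.46, 0.45]
context[0,0,:,:] = all 10 rows filled with [0.46, 0.45]
```
Figure (image not preserved):
- Left (Encoder, mask_flag=False): every position is initialized to the same mean(V). The top-u active queries (shown in coral) are later overwritten with exact values.
- Right (Decoder masked, mask_flag=True): cumsum gives each position a different initial value, reflecting the causal constraint. The top-u active queries are likewise overwritten.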
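Both initialization paths can be reproduced on the toy V above. This is a minimal sketch that assumes a single batch and head (B = H = 1). `initial_context` is a hypothetical free-function stand-in for the method, with `mask_flag` passed as an argument.

```python
import torch


def initial_context(V, L_Q, mask_flag):
    """Free-function sketch of _get_initial_context. V: (B, H, L_V, D)."""
    B, H, L_V, D = V.shape
    if not mask_flag:
        V_mean = V.mean(dim=-2)                                   # (B, H, D)
        return V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()  # clone: each row gets its own memory
    assert L_Q == L_V, "the causal path needs L_Q == L_V"
    return V.cumsum(dim=-2)                                       # prefix sum over positions


V = torch.tensor([[0.1, 0.8], [0.5, 0.3], [0.9, 0.2], [0.4, 0.6], [0.7, 0.1],
                  [0.2, 0.5], [0.6, 0.4], [0.3, 0.7], [0.8, 0.0], [0.1, 0.9]]).view(1, 1, 10, 2)

enc = initial_context(V, L_Q=10, mask_flag=False)
dec = initial_context(V, L_Q=10, mask_flag=True)
print(enc[0, 0, 0])   # ≈ [0.46, 0.45], the same for all 10 rows
print(dec[0, 0, :3])  # ≈ [[0.1, 0.8], [0.6, 1.1], [1.5, 1.3]]

enc[0, 0, 0] = torch.tensor([0.71, 0.23])  # overwrite one row in place, as _update_context does
print(enc[0, 0, :2])  # row 0 changed, row 1 still ≈ [0.46, 0.45]
```

Because of `.clone()`, the in-place write touches only row 0; the other nine rows keep mean(V).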
5.4 _update_context: fill in the true attention for the active queries
Full original code
```python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    attn = torch.softmax(scores, dim=-1)

    context_in[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
    ] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = attn
        return context_in, attns
    else:
        return context_in, None
```
Annotated version:
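```python
# context_in: (3,4,10,2) ← from _get_initial_context; every row = mean(V) on the mask_flag=False path
# V: (3,4,10,2)
# scores: (3,4,3,10) ← dot products of the top-3 queries with all 10 keys (already scaled)
# index: (3,4,3) ← which 3 query positions are active
# ─── mask_flag=True only: apply the causal mask ───
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
# builds the upper-triangular causal mask and keeps only the rows of the top-u queries
# attn_mask.mask: (3,4,3,10); True marks a "future" position to block
scores.masked_fill_(attn_mask.mask, -np.inf)
# blocked positions get probability 0 after softmax
# ─── all paths: softmax ───
attn = torch.softmax(scores, dim=-1)
# dim=-1 is the key axis (L_K=10); each row sums to 1
# attn: (3,4,3,10)
# ─── write back with advanced indexing ───
context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = (
    torch.matmul(attn, V).type_as(context_in)
)
# matmul(attn, V): (3,4,3,10) × (3,4,10,2) = (3,4,3,2)
# writes the true attention outputs of the top-u queries into their rows of context_in
# output_attention=False (TFB default): no attention map is returned
# (the True branch builds attns as (B,H,L_V,L_V), which implicitly assumes L_Q == L_V)
return context_in, None
# context_in: (3,4,10,2); the top-u rows are exact, the rest keep mean(V)
```
Illustration: the context write-back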
```
After _get_initial_context:
index = [0, 2, 4]            ← top-3 active query positions
context = [ [0.46,0.45],     ← q0 (mean(V))
            [0.46,0.45],     ← q1
            [0.46,0.45],     ← q2
            ...
            [0.46,0.45] ]    ← q9

attn @ V results (exact; illustrative values):
q0 → [0.71, 0.23]
q2 → [0.35, 0.58]
q4 → [0.89, 0.12]

After _update_context:
context = [ [0.71,0.23],     ← q0 ✓ exact attention
            [0.46,0.45],     ← q1 keeps mean(V)
            [0.35,0.58],     ← q2 ✓ exact attention
            [0.46,0.45],     ← q3
            [0.89,0.12],     ← q4 ✓ exact attention
            [0.46,0.45],     ← q5
            ...
            [0.46,0.45] ]    ← q9
```
Toy numeric trace (head h=0, batch b=0, continuing the numbers from §5.2/5.3):
```
scores_top[0,0,0,:] after scaling by 1/√2 ≈ 0.707 (illustrative):
  [0.26, 0.04, 0.38, 0.11, 0.05, 0.09, 0.02, 0.08, 0.12, 0.03]
attn[0,0,0,:] = softmax(...) ≈ [0.11, 0.09, 0.13, 0.10, 0.09, ...]   (sums to 1.0)
matmul(attn, V)[0,0,0,:] = attn @ V[0,0,:,:]
  = Σ_k attn[k] × V[k,:]
  = 0.11×[0.1,0.8] + 0.09×[0.5,0.3] + 0.13×[0.9,0.2] + ...
  ≈ [0.47, 0.45]
context_in[0,0,0,:] = [0.47, 0.45]   ← q0 updated from [0.46,0.45] to its exact value
context_in[0,0,2,:] = [...]          ← q2 updated
context_in[0,0,4,:] = [...]          ← q4 updated
context_in[0,0,{1,3,5,6,7,8,9},:] = [0.46,0.45]   ← unchanged
```
These illustrative scores are nearly flat, so the exact output lands close to mean(V). A query with a more peaked score row would move further from it.
6. Drill-down into sub-components
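All of ProbAttention's core logic is covered in this document, so no further drill-down is needed.

ProbMask is used on the mask_flag=True path of _update_context. It builds the upper-triangular causal mask and keeps only the rows of the top-u queries. This makes attention for the active queries still obey the causal "look only to the left" constraint. Source: masking.py:17-28. The logic is simple, so it does not get its own document.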
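The masking step can be sketched in code from the description above; this sketch is not copied from masking.py. `prob_mask_sketch` is a hypothetical name, and it returns the boolean mask directly rather than an object with a `.mask` attribute.

```python
import torch


def prob_mask_sketch(B, H, L_Q, index, scores):
    """Causal-mask rows for the selected queries only.

    index: (B, H, u) active query positions; scores: (B, H, u, L_K).
    Returns a bool tensor shaped like scores; True marks a future key to block.
    """
    L_K = scores.shape[-1]
    causal = torch.ones(L_Q, L_K, dtype=torch.bool).triu(1)  # strictly above the diagonal
    causal = causal[None, None].expand(B, H, L_Q, L_K)        # read-only view, no copy
    b = torch.arange(B)[:, None, None]
    h = torch.arange(H)[None, :, None]
    return causal[b, h, index]                                 # (B, H, u, L_K)


index = torch.tensor([0, 2, 4]).view(1, 1, 3)  # active queries q0, q2, q4
scores = torch.zeros(1, 1, 3, 6)               # flat scores over 6 keys, for illustration
mask = prob_mask_sketch(1, 1, 6, index, scores)
attn = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
print(attn[0, 0, 1])  # q2 sees keys 0..2 only: ≈ [1/3, 1/3, 1/3, 0, 0, 0]
```

Each selected row keeps its diagonal key unmasked, so softmax always has at least one finite entry and never produces NaN.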