Skip to content

Layer 5 — ProbAttention 精读

由 AttentionLayer(04-Layer4-AttentionLayer)的 self.inner_attention(Q, K, V, ...) 调用。
本文档覆盖 ProbAttention.forward 及其三个子函数:_prob_QK_get_initial_context_update_context


1. 在父层中的位置

AttentionLayer.forward
  └─ self.inner_attention(Q, K, V, attn_mask)   ← ProbAttention.forward(本文档)
       ├─ _prob_QK(Q, K, sample_k, n_top)        ← 稀疏查询筛选
       ├─ _get_initial_context(V, L_Q)           ← 初始化上下文
       └─ _update_context(context, V, ...)       ← 填入活跃查询精确值

传入时 Q/K/V 的格式是 (B, L, H, D) = (3, 10, 4, 2)(AttentionLayer 拆好的多头格式)。
ProbAttention 内部首先 transpose(2,1) 切换为 (B, H, L, D) 格式处理,返回 context: (B, H, L_Q, D) = (3, 4, 10, 2)


2. I/O 接口定义

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):

以 EncoderLayer 0 自注意力(toy 基准)为例:

shape(toy)含义
输入 queries(3, 10, 4, 2) = (B, L_Q, H, D)AttentionLayer 拆好的多头 query
输入 keys(3, 10, 4, 2)多头 key
输入 values(3, 10, 4, 2)多头 value
输出 context(3, 4, 10, 2) = (B, H, L_Q, D)聚合后上下文(top-u 精算 + 其余 mean(V))
输出 attnNoneoutput_attention=False 时为 None

⚠️ 输出格式是 (B, H, L_Q, D),而非 AttentionLayer 输入时的 (B, L, H, D)
AttentionLayer 在拿回结果后用 .view(B, L, -1) 合并,需要先 .contiguous()


3. 顺序图(具体层)


4. 语义分组图(索引层)

核心思路:大多数 query 是"懒"的(注意力接近均匀分布),其输出 ≈ mean(V),可以直接填充省去计算;只对少数"活跃" query 精算真实注意力。整体复杂度从 O(L²) 降至 O(L log L)。


5. 逐步解析

5.1 宏观逻辑

论文核心观察:标准 Transformer 的注意力 softmax(QKTd)V 中,大多数 query 的注意力分布接近均匀分布——这类"懒" query 做的是加权平均,而均匀加权恰好等于 mean(V),可以直接代替,不用计算。

M = max − mean 衡量活跃度

M(qi,K)=maxj(qikjTd)1LKj(qikjTd)

M 大 → 注意力集中(有主导 key,活跃);M 小 → 注意力均匀(懒惰,近似 mean(V) 即可)。

两步参数的关键区分(最容易混淆):

samplek=factor×ceil(ln(LK)):粗估用 每个 query 随机采样 samplek 个 key,算 M 分数 目的:快速估计哪些 query 可能活跃

ntop=factor×ceil(ln(LQ)):精算用 只为 M 最大的 ntop 个 query 做完整 Q×K.T 目的:精确计算少数活跃 query 的注意力

用小例子(LQ=LK=10,factor=1,u=3)串起来:

所有 query: q0  q1  q2  q3  q4  q5  q6  q7  q8  q9
               ↓  先用随机采样的 3 个 key 估算活跃度 M
活跃度 M:   0.38 0.09 0.41 0.07 0.35 0.12 0.08 0.15 0.06 0.20
                 ↓  topk(3),选活跃度最高的 3 个
Top-3 选中:  q0  q2  q4

阶段1 — 初始化(全部 mean(V)):
  context[q0..q9] = [0.46, 0.45]  ← 所有位置相同

阶段2 — 精算更新(只更新 3 个):
  context[q0] = softmax(q0·K.T/√D) @ V  ← 精确值
  context[q2] = softmax(q2·K.T/√D) @ V
  context[q4] = softmax(q4·K.T/√D) @ V
  context[q1/q3/q5..q9] = [0.46, 0.45]  ← 保持不变

5.2 完整原始代码(forward)

本节的作用

forward 是 ProbAttention 的入口:格式转置 → U/u 参数计算 → _prob_QK 稀疏筛选 → 缩放 → _get_initial_context 初始化 → _update_context 精算更新 → 返回。

python
# SelfAttention_Family.py:95-206
class ProbAttention(nn.Module):
    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        
    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1.0 / sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale

        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )

        return context.contiguous(), attn

5.3 输入变换 + 参数计算

本节的作用

transpose(2,1) 把 AttentionLayer 传入的 (B,L,H,D) 格式转为 ProbAttention 内部的 (B,H,L,D)U_partu 是 ProbSparse 的两个核心参数,分别控制采样 key 数和活跃 query 数。

python
B, L_Q, H, D = queries.shape  # B=3, L_Q=10, H=4, D=2
_, L_K, _, _ = keys.shape     # L_K=10

queries = queries.transpose(2, 1)  # (3,10,4,2) → (3,4,10,2) = (B,H,L_Q,D)
keys    = keys.transpose(2, 1)     # (3,10,4,2) → (3,4,10,2) = (B,H,L_K,D)
values  = values.transpose(2, 1)   # (3,10,4,2) → (3,4,10,2) = (B,H,L_K,D)

transpose(2,1) 交换 dim1(序列位置 L)和 dim2(注意力头 H):AttentionLayer 传入的是 (B,L,H,D),ProbAttention 内部需要 (B,H,L,D) 才能方便地按头处理。

python
U_part = 1 * np.ceil(np.log(10)).astype("int").item()
# np.log(10)≈2.302 → ceil→3 → U_part=3(每个 query 随机采样的 key 数)

u = 1 * np.ceil(np.log(10)).astype("int").item()
# u=3(选出的 top-u 活跃查询数)

U_part = 3 if 3 < 10 else 10  # = 3 ✓
u      = 3 if 3 < 10 else 10  # = 3 ✓

不同场景的参数推导:

Encoder Layer 0 自注意力 (L_K=L_Q=10, factor=1):
  U_part = ceil(ln(10)) = 3,  u = 3

Encoder Layer 1 自注意力 (distilling后 L=6):
  U_part = ceil(ln(6)) = 2,   u = 2

Decoder cross-attention (L_K=6, L_Q=12):
  U_part = ceil(ln(6))  = 2,  u = ceil(ln(12)) = 3

Decoder masked self-attention (L_Q=L_K=12):
  U_part = ceil(ln(12)) = 3,  u = 3

5.2 _prob_QK — 稀疏查询筛选

完整原始代码

python
def _prob_QK(self, Q, K, sample_k, n_top):
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    # find the Top_k query with sparisty measurement   ⚠️ typo: sparisty → sparsity
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    Q_reduce = Q[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
    ]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

    return Q_K, M_top

注解版:

python
# Q: (3,4,10,2) = (B,H,L_Q,E)
# K: (3,4,10,2) = (B,H,L_K,E)
# sample_k=3, n_top=3

B, H, L_K, E = K.shape  # 3, 4, 10, 2
_, _, L_Q, _ = Q.shape  # L_Q=10

# ─── Step 1: 扩展 K 供每个 query 独立采样 ───
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# K: (3,4,10,2)
# unsqueeze(-3): 在 dim=2 插入 → (3,4,1,10,2)
# expand → (3,4,10,10,2):每个 query 位置各有一份 K 的副本

# ─── Step 2: 随机采样 key 索引 ───
index_sample = torch.randint(L_K, (L_Q, sample_k))
# shape: (10,3),每行是 query i 随机选的 3 个 key 索引

# ─── Step 3: 取采样 key 向量 ───
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
# K_sample: (3,4,10,3,2),K_sample[b,h,q] = query q 的 3 个采样 key

# ─── Step 4: 点积得分 ───
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# Q.unsqueeze(-2): (3,4,10,1,2)
# K_sample.transpose(-2,-1): (3,4,10,2,3)
# matmul: → (3,4,10,1,3) → .squeeze() → (3,4,10,3)
# ⚠️ .squeeze() 无参:当 B=1 或 H=1 时会多压缩维度,应写 .squeeze(-2)

# ─── Step 5: M 稀疏度指标 ───
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
# max(-1)[0]: (3,4,10) ← 3 个采样得分的最大值
# sum(-1)/L_K: (3,4,10) ← 除以全量 key 数(不是采样数),近似均匀均值
# M: (3,4,10),值越大表示越活跃

M_top = M.topk(n_top=3, sorted=False)[1]
# M_top: (3,4,3) ← 每个 (batch, head) 中 M 最大的 3 个 query 位置索引

# ─── Step 6: 只为 top-u query 计算完整 Q×K.T ───
Q_reduce = Q[
    torch.arange(B)[:, None, None],   # (3,1,1)
    torch.arange(H)[None, :, None],   # (1,4,1)
    M_top,                             # (3,4,3)
    :
]
# Q_reduce: (3,4,3,2) ← 只取 top-3 query 的向量

Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
# K.transpose(-2,-1): (3,4,10,2) → (3,4,2,10)
# matmul: (3,4,3,2) × (3,4,2,10) = (3,4,3,10)
# Q_K: (3,4,3,10) ← top-3 query 对全 10 个 key 的完整点积

return Q_K, M_top
# Q_K: (3,4,3,10) = scores_top
# M_top: (3,4,3) = index

toy 数值追踪(head h=0, batch b=0):

Q[0,0,0,:] = [0.5, -0.3]  ← query 0 向量
随机采样到 key 2, 7, 9:
  K_sample[0,0,0,:,:] = [ [0.8,0.1], [0.2,0.9], [-0.4,0.6] ]

Q_K_sample[0,0,0,:] = [0.5×0.8+(-0.3)×0.1, 0.5×0.2+(-0.3)×0.9, 0.5×(-0.4)+(-0.3)×0.6]
                     = [0.37, -0.17, -0.38]

M[0,0,0] = max(0.37,-0.17,-0.38) - (0.37-0.17-0.38)/10
          = 0.37 - (-0.018) = 0.388  ← 活跃

对比"懒"query: Q[0,0,5,:] = [0.1, 0.1](注意力均匀)
Q_K_sample[0,0,5,:] = [0.09, 0.11, 0.02]
M[0,0,5] = 0.11 - 0.022 = 0.088  ← 懒惰,远小于 q0

假设 M_top[0,0,:] = [0, 2, 4](活跃度最高的 3 个)
→ Q_K[0,0,:,:] = Q_reduce[0,0,:,:] @ K[0,0,:,:].T  形状 (3,10)

5.3 _get_initial_context — 初始化上下文

完整原始代码

python
def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        # V_sum = V.sum(dim=-2)
        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()  # ⚠️ contex → context
    else:  # use mask
        assert L_Q == L_V
        contex = V.cumsum(dim=-2)   # ⚠️ contex → context
    return contex

两条路径的含义:

mask_flag=False(encoder 自注意力 + decoder cross-attention)

"懒" query 的真实输出 ≈ 均匀注意力 × V = mean(V)。所有位置填同一个均值向量,后续 top-u 活跃 query 的位置被精算值覆盖。

mean(V) → 广播到所有 L_Q 个位置

context = [ [0.46, 0.45],   ← q0(懒查询近似)
           [0.46, 0.45],   ← q1
           ...
           [0.46, 0.45] ]   ← q9  (全部相同)

mask_flag=True(decoder masked 自注意力)

因果 mask 下位置 t 只能看 V[0..t],因此"懒"位置的近似是 cumsum(V)(前缀和),而不是全局 mean。

context[t] = V[0] + V[1] + ... + V[t]  (每行不同,越靠后越大)

注解版:

python
# V: (3,4,10,2) = (B,H,L_V,D),  L_Q=10

# ─── mask_flag=False ───
V_sum = V.mean(dim=-2)
# dim=-2 即 L_V 维,求均值 → (3,4,2)

contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
# unsqueeze(-2): (3,4,2) → (3,4,1,2)
# expand: (3,4,10,2) ← 广播到 L_Q=10 个位置
# .clone(): 必须 clone!expand 是 view,后续 index 赋值需要真实内存

# ─── mask_flag=True ───
assert L_Q == L_V       # 自注意力时 Q/V 序列等长
contex = V.cumsum(dim=-2)
# (3,4,10,2),每个位置是到该位置的 value 前缀和

toy 数值追踪(mask_flag=False,head h=0, batch b=0):

V[0,0,:,:] = [ [0.1,0.8],[0.5,0.3],[0.9,0.2],[0.4,0.6],[0.7,0.1],
              [0.2,0.5],[0.6,0.4],[0.3,0.7],[0.8,0.0],[0.1,0.9] ]

mean = [(0.1+0.5+0.9+0.4+0.7+0.2+0.6+0.3+0.8+0.1)/10,
        (0.8+0.3+0.2+0.6+0.1+0.5+0.4+0.7+0.0+0.9)/10]
     = [0.46, 0.45]

context[0,0,:,:] = 全部 10 行填 [0.46, 0.45]

左侧(Encoder,mask_flag=False):所有位置初始化为同一 mean(V),top-u 活跃查询(珊瑚色)后续被精算值覆盖。
右侧(Decoder Masked,mask_flag=True):cumsum 使每个位置初始值不同,体现因果约束;top-u 活跃查询同样被覆盖。


5.4 _update_context — 填入活跃查询的真实注意力

完整原始代码

python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    attn = torch.softmax(scores, dim=-1)

    context_in[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
    ] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = attn
        return context_in, attns
    else:
        return context_in, None

注解版:

python
# context_in: (3,4,10,2) ← 来自 _get_initial_context,全部填了 mean(V)
# V:          (3,4,10,2)
# scores:     (3,4,3,10) ← top-3 query 对全 10 个 key 的点积(已乘 scale)
# index:      (3,4,3)    ← 哪 3 个 query 位置是活跃的

# ─── mask_flag=True: 施加因果掩码 ───
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
# 构造上三角因果掩码,只取 top-u 行对应位置
# attn_mask.mask: (3,4,3,10),True 表示"未来位置"需屏蔽
scores.masked_fill_(attn_mask.mask, -np.inf)
# 被屏蔽位置 softmax 后概率为 0

# ─── 所有路径:softmax ───
attn = torch.softmax(scores, dim=-1)
# dim=-1 即 key 维(L_K=10),每行和为 1
# attn: (3,4,3,10)

# ─── 高级索引写回 ───
context_in[arange(B)[:,None,None], arange(H)[None,:,None], index, :]
    = torch.matmul(attn, V).type_as(context_in)
# matmul(attn, V): (3,4,3,10) × (3,4,10,2) = (3,4,3,2)
# 将 top-u 个 query 的真实注意力输出写入 context_in 对应位置

# output_attention=False(TFB 默认)
return context_in, None
# context_in: (3,4,10,2),top-u 位置精算,其余保持 mean(V)

图解 — context 写回过程:

_get_initial_context 后:
index = [0, 2, 4]  ← top-3 活跃 query 位置
context = [ [0.46,0.45],  ← q0 (mean(V))
           [0.46,0.45],  ← q1
           [0.46,0.45],  ← q2
           ...
           [0.46,0.45] ]  ← q9

attn @ V 结果(精算):
  q0 → [0.71, 0.23]
  q2 → [0.35, 0.58]
  q4 → [0.89, 0.12]

_update_context 后:
context = [ [0.71,0.23],  ← q0 ✓ 精确注意力
           [0.46,0.45],  ← q1   保持 mean(V)
           [0.35,0.58],  ← q2 ✓ 精确注意力
           [0.46,0.45],  ← q3
           [0.89,0.12],  ← q4 ✓ 精确注意力
           [0.46,0.45],  ← q5
           ...
           [0.46,0.45] ]  ← q9

toy 数值追踪(head h=0, batch b=0,延续 §5.2/5.3 数字):

scores_top[0,0,0,:] 经 scale×0.707 后(示意):
  [0.26, 0.04, 0.38, 0.11, 0.05, 0.09, 0.02, 0.08, 0.12, 0.03]

attn[0,0,0,:] = softmax(...) → [0.12, 0.09, 0.14, 0.10, 0.09, ...](和=1.0)

matmul(attn, V)[0,0,0,:] = attn @ V[0,0,:,:]
  = Σ_k attn[k] × V[k,:]
  = 0.12×[0.1,0.8] + 0.09×[0.5,0.3] + 0.14×[0.9,0.2] + ...
  ≈ [0.48, 0.35](示意)

context_in[0,0,0,:] = [0.48, 0.35]  ← q0 从 [0.46,0.45] 更新为精确值
context_in[0,0,2,:] = [...]         ← q2 更新
context_in[0,0,4,:] = [...]         ← q4 更新
context_in[0,0,{1,3,5,6,7,8,9},:] = [0.46,0.45]  ← 保持不变

6. 下钻子组件

ProbAttention 的所有核心逻辑已在本文档完整覆盖,无需进一步下钻。

ProbMask(在 _update_context 的 mask_flag=True 路径中使用):构造因果上三角掩码,只取 top-u 行对应位置,使活跃 query 的注意力计算仍遵守"只能看左边"的 causal 约束。源码 masking.py:17-28,逻辑简单,无需单独文档。

*记录并在线阅读我的笔记*