Skip to content

4C-2 ProbAttention

Abstract

这一篇是:

04C-1 AttentionLayerinner_attention 的当前实现,也就是 ProbAttention

只讲:

ProbAttention 怎样不对所有 query 都完整算一遍 QK,而是先抽样、选 top 查询,再只更新这些查询对应的上下文。

1. 上下文

上一层:

这一层的入口代码是:

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)

这一层的输出是:

python
context.shape = (B, L_Q, H, D)

2. 当前层第一性

这一层存在的第一性是:

不是让每个 query 都完整地和所有 key 做 attention,而是先找“最值得细算”的 query,再只对这些 query 做精确更新。

3. 本层输入输出含义

3.1 输入

  • queries
    • 形状 (B, L_Q, H, D)
  • keys
    • 形状 (B, L_K, H, D)
  • values
    • 形状 (B, L_K, H, D)
  • factor
    • 控制抽样规模和 top 查询数量
  • mask_flag
    • 是否使用因果 mask

3.2 输出

  • context
    • 形状 (B, L_Q, H, D)
  • attn
    • 注意力权重,若 output_attention=True 才返回

4. 顺序图

5. 抽象树

6. 固定 toy 例子

为了算得动,固定:

  • B = 1
  • H = 1
  • L_Q = L_K = 4
  • D = 2
  • factor = 1
python
Q = [
    [1, 0],
    [0, 1],
    [1, 1],
    [2, 0],
]  # (1, 1, 4, 2)

K = [
    [1, 0],
    [0, 1],
    [1, 1],
    [1, -1],
]

V = [
    [10, 11],
    [20, 21],
    [30, 31],
    [40, 41],
]

7. 代码块 1:ProbAttention.forward(...)

位置:

完整代码:

python
class ProbAttention(nn.Module):
    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1.0 / sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )

        return context.contiguous(), attn

8. 子块 A:_prob_QK(...)

这一块的完整代码是:

python
def _prob_QK(self, Q, K, sample_k, n_top):
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    Q_reduce = Q[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
    ]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

    return Q_K, M_top

中文注释版:

python
def _prob_QK(self, Q, K, sample_k, n_top):
    # Q: (B, H, L_Q, D)
    # K: (B, H, L_K, D)
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    # 第一步:把 K 扩成“每个 query 都有一份可抽样 key 候选”的形状
    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)

    # 第二步:对每个 query 随机抽 sample_k 个 key 下标
    index_sample = torch.randint(L_K, (L_Q, sample_k))

    # 第三步:拿到每个 query 对应的抽样 key
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

    # 第四步:先只和抽样到的 key 算一个近似分数
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    # 第五步:用 max - mean 估计“哪个 query 更值得精算”
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    # 第六步:只保留 top 查询,再和全部 key 做完整 QK
    Q_reduce = Q[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
    ]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

    return Q_K, M_top

8.1 对应到当前 forward(...) 的位置

ProbAttention.forward(...) 里,这一段对应的是:

python
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

也就是说,这一块的输入和输出语义是:

  • 输入:所有 query / key
  • 输出:
    • scores_top:只针对 top 查询保留的完整分数
    • index:被选中的 top 查询下标

8.2 toy 张量逐步演变图

text
输入:
  Q = 4 个 query
  K = 4 个 key

步骤 1: 先抽样
  固定 sample_k = 2
  假设对 4 个 query 抽到的 key 下标分别是:
    q0 -> [0, 2]
    q1 -> [1, 3]
    q2 -> [0, 2]
    q3 -> [2, 3]

步骤 2: 对抽样到的 key 算近似分数
  q0 只和 k0/k2 算
  q1 只和 k1/k3 算
  q2 只和 k0/k2 算
  q3 只和 k2/k3 算

步骤 3: 用 max - mean 得到每个 query 的稀疏度分数 M

步骤 4: 只保留 top 查询下标 index
  toy 里固定最后选中:
    index = [0, 2]

步骤 5: 只让 q0/q2 再去和全部 key 做完整 QK

输出:
  scores_top =
    q0 对 [k0,k1,k2,k3] 的完整分数
    q2 对 [k0,k1,k2,k3] 的完整分数
  index = [0, 2]

8.3 一个可算的 toy 小算例

为了直观,假设抽样后发现:

  • query 0 最有代表性
  • query 2 次之

那就只对这两个 query 和所有 key 完整算分数。

例如 query 0 = [1, 0],与四个 key 点积:

text
q0·k0 = [1,0]·[1,0]  = 1
q0·k1 = [1,0]·[0,1]  = 0
q0·k2 = [1,0]·[1,1]  = 1
q0·k3 = [1,0]·[1,-1] = 1

所以:

text
scores_top[q0] = [1, 0, 1, 1]

再看 query 2 = [1, 1]

text
q2·k0 = 1
q2·k1 = 1
q2·k2 = 2
q2·k3 = 0

所以:

text
scores_top[q2] = [1, 1, 2, 0]

再往前补一步,把 M = max - mean 也算出来:

q0

text
max = 1
mean = (1 + 0 + 1 + 1) / 4 = 0.75
M(q0) = 1 - 0.75 = 0.25

q2

text
max = 2
mean = (1 + 1 + 2 + 0) / 4 = 1
M(q2) = 2 - 1 = 1

所以 q2q0 更稀疏、更值得保留。

8.4 这一段的 input / output 语义

  • 输入 Q/K
    • 多头 query/key
  • 输出 scores_top
    • 只对 top 查询保留的完整分数
  • 输出 index
    • 哪些查询被选中做精算

9. 子块 B:_get_initial_context(...)

这一块的完整代码是:

python
def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
    else:
        assert L_Q == L_V
        contex = V.cumsum(dim=-2)
    return contex

中文注释版:

python
def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        # 非因果情形:先用所有 value 的平均值做一个默认上下文
        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
    else:
        # 因果情形:用前缀和,保证每个位置只能看到过去
        assert L_Q == L_V
        contex = V.cumsum(dim=-2)
    return contex

9.1 对应到当前 forward(...) 的位置

python
context = self._get_initial_context(values, L_Q)

这里的语义是:

  • _prob_QK 还没真正更新 context 之前
  • 先给每个 query 一个“默认版本的上下文”

9.2 toy 张量逐步演变图

mask_flag=False 时,源码是:

python
V_sum = V.mean(dim=-2)
context = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()

也就是先拿所有 value 的平均值作为每个 query 的默认上下文。

text
V =
[
  [10, 11],
  [20, 21],
  [30, 31],
  [40, 41],
]

平均值 = [25, 26]

初始化 context =
[
  [25, 26],
  [25, 26],
  [25, 26],
  [25, 26],
]

9.3 这一段的 input / output 语义

  • 输入 V
    • value 向量集合
  • 输出 context
    • 所有 query 的默认上下文底稿

10. 子块 C:_update_context(...)

这一块的完整代码是:

python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    attn = torch.softmax(scores, dim=-1)

    context_in[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
    ] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = attn
        return context_in, attns
    else:
        return context_in, None

中文注释版:

python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        # 因果情形下对非法位置做 mask
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    # 对 top 查询的完整分数做 softmax
    attn = torch.softmax(scores, dim=-1)

    # 只把这些 top 查询位置写回 context
    context_in[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
    ] = torch.matmul(attn, V).type_as(context_in)

    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = attn
        return context_in, attns
    else:
        return context_in, None

10.1 对应到当前 forward(...) 的位置

python
context, attn = self._update_context(
    context, values, scores_top, index, L_Q, attn_mask
)

语义是:

  • 前面拿到的是默认 context
  • 这里才把 top 查询位置替换成真正 attention 结果

10.2 toy 张量逐步演变图

text
初始 context =
[
  [25, 26],
  [25, 26],
  [25, 26],
  [25, 26],
]

当前只更新 query 0 和 query 2

10.3 一个可算的 toy 小算例

对 query 0,分数是:

text
scores_top[q0] = [1, 0, 1, 1]

softmax 后假设得到:

text
w0 = [0.30, 0.10, 0.30, 0.30]

那么新的上下文就是:

text
context[0] = 0.30*[10,11] + 0.10*[20,21] + 0.30*[30,31] + 0.30*[40,41]
           = [3,3.3] + [2,2.1] + [9,9.3] + [12,12.3]
           = [26, 27]

对 query 2,分数是:

text
scores_top[q2] = [1, 1, 2, 0]

softmax 后假设得到:

text
w2 = [0.20, 0.20, 0.50, 0.10]

则:

text
context[2] = 0.20*[10,11] + 0.20*[20,21] + 0.50*[30,31] + 0.10*[40,41]
           = [2,2.2] + [4,4.2] + [15,15.5] + [4,4.1]
           = [25, 26]

于是更新后:

text
context =
[
  [26, 27],
  [25, 26],
  [25, 26],
  [25, 26],
]

如果 query 2 算出不同值,它也会被写回自己的位置。

还可以把“只更新 top 查询”写得更直白一点:

text
更新前:
  q0 -> [25, 26]
  q1 -> [25, 26]
  q2 -> [25, 26]
  q3 -> [25, 26]

更新后:
  q0 -> [26, 27]   # 被精确更新
  q1 -> [25, 26]   # 保持默认值
  q2 -> [25, 26]   # toy 里更新后碰巧与默认值一样
  q3 -> [25, 26]   # 保持默认值

10.4 这一段的 input / output 语义

  • 输入 context
    • 默认上下文底稿
  • 输入 scores_top/index
    • 需要精确更新的查询及其分数
  • 输出 context
    • 只有 top 查询位置被精确替换过的新上下文

11. 这一层和普通 FullAttention 的本质区别

普通 FullAttention 更像:

text
每个 query
-> 和所有 key 算完整分数
-> softmax
-> 加权求和 V

ProbAttention 更像:

text
所有 query
-> 先抽样估计稀疏度
-> 只挑 top 查询
-> 所有 query 先给默认 context
-> 只对 top 查询做完整 attention 更新

所以最本质的区别是:

不是所有 query 都做完整 attention,而是只有最值得细算的 query 才做。

12. 当前层真正要固定什么

  1. ProbAttention 的关键不是“换公式”,而是“少算很多 query”
  2. 它先:
    • 抽样
    • 选 top 查询
    • 给所有 query 一个默认 context
    • 再只更新 top 查询
  3. 最终输出形状仍然和普通 attention 一样:
    • (B, L_Q, H, D)

13. 下一步

返回上层继续看:

*记录并在线阅读我的笔记*