Appearance
4C-2 ProbAttention
Abstract
这一篇是:
04C-1 AttentionLayer里inner_attention的当前实现,也就是ProbAttention。只讲:
ProbAttention怎样不对所有 query 都完整算一遍QK,而是先抽样、选 top 查询,再只更新这些查询对应的上下文。
1. 上下文
上一层:
这一层的入口代码是:
python
out, attn = self.inner_attention(
queries, keys, values, attn_mask, tau=tau, delta=delta
)这一层的输出是:
python
context.shape = (B, L_Q, H, D)2. 当前层第一性
这一层存在的第一性是:
不是让每个 query 都完整地和所有 key 做 attention,而是先找“最值得细算”的 query,再只对这些 query 做精确更新。
3. 本层输入输出含义
3.1 输入
queries- 形状
(B, L_Q, H, D)
- 形状
keys- 形状
(B, L_K, H, D)
- 形状
values- 形状
(B, L_K, H, D)
- 形状
factor- 控制抽样规模和 top 查询数量
mask_flag- 是否使用因果 mask
3.2 输出
context- 形状
(B, L_Q, H, D)
- 形状
attn- 注意力权重,若
output_attention=True才返回
- 注意力权重,若
4. 顺序图
5. 抽象树
6. 固定 toy 例子
为了算得动,固定:
B = 1H = 1L_Q = L_K = 4D = 2factor = 1
python
Q = [
[1, 0],
[0, 1],
[1, 1],
[2, 0],
] # (1, 1, 4, 2)
K = [
[1, 0],
[0, 1],
[1, 1],
[1, -1],
]
V = [
[10, 11],
[20, 21],
[30, 31],
[40, 41],
]7. 代码块 1:ProbAttention.forward(...)
位置:
完整代码:
python
class ProbAttention(nn.Module):
def __init__(
self,
mask_flag=True,
factor=5,
scale=None,
attention_dropout=0.1,
output_attention=False,
):
super(ProbAttention, self).__init__()
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
scale = self.scale or 1.0 / sqrt(D)
if scale is not None:
scores_top = scores_top * scale
context = self._get_initial_context(values, L_Q)
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask
)
return context.contiguous(), attn8. 子块 A:_prob_QK(...)
这一块的完整代码是:
python
def _prob_QK(self, Q, K, sample_k, n_top):
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
]
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
return Q_K, M_top中文注释版:
python
def _prob_QK(self, Q, K, sample_k, n_top):
# Q: (B, H, L_Q, D)
# K: (B, H, L_K, D)
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# 第一步:把 K 扩成“每个 query 都有一份可抽样 key 候选”的形状
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# 第二步:对每个 query 随机抽 sample_k 个 key 下标
index_sample = torch.randint(L_K, (L_Q, sample_k))
# 第三步:拿到每个 query 对应的抽样 key
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
# 第四步:先只和抽样到的 key 算一个近似分数
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# 第五步:用 max - mean 估计“哪个 query 更值得精算”
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# 第六步:只保留 top 查询,再和全部 key 做完整 QK
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
]
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))
return Q_K, M_top8.1 对应到当前 forward(...) 的位置
在 ProbAttention.forward(...) 里,这一段对应的是:
python
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)也就是说,这一块的输入和输出语义是:
- 输入:所有 query / key
- 输出:
scores_top:只针对 top 查询保留的完整分数index:被选中的 top 查询下标
8.2 toy 张量逐步演变图
text
输入:
Q = 4 个 query
K = 4 个 key
步骤 1: 先抽样
固定 sample_k = 2
假设对 4 个 query 抽到的 key 下标分别是:
q0 -> [0, 2]
q1 -> [1, 3]
q2 -> [0, 2]
q3 -> [2, 3]
步骤 2: 对抽样到的 key 算近似分数
q0 只和 k0/k2 算
q1 只和 k1/k3 算
q2 只和 k0/k2 算
q3 只和 k2/k3 算
步骤 3: 用 max - mean 得到每个 query 的稀疏度分数 M
步骤 4: 只保留 top 查询下标 index
toy 里固定最后选中:
index = [0, 2]
步骤 5: 只让 q0/q2 再去和全部 key 做完整 QK
输出:
scores_top =
q0 对 [k0,k1,k2,k3] 的完整分数
q2 对 [k0,k1,k2,k3] 的完整分数
index = [0, 2]8.3 一个可算的 toy 小算例
为了直观,假设抽样后发现:
- query 0 最有代表性
- query 2 次之
那就只对这两个 query 和所有 key 完整算分数。
例如 query 0 = [1, 0],与四个 key 点积:
text
q0·k0 = [1,0]·[1,0] = 1
q0·k1 = [1,0]·[0,1] = 0
q0·k2 = [1,0]·[1,1] = 1
q0·k3 = [1,0]·[1,-1] = 1所以:
text
scores_top[q0] = [1, 0, 1, 1]再看 query 2 = [1, 1]:
text
q2·k0 = 1
q2·k1 = 1
q2·k2 = 2
q2·k3 = 0所以:
text
scores_top[q2] = [1, 1, 2, 0]再往前补一步,把 M = max - mean 也算出来:
对 q0:
text
max = 1
mean = (1 + 0 + 1 + 1) / 4 = 0.75
M(q0) = 1 - 0.75 = 0.25对 q2:
text
max = 2
mean = (1 + 1 + 2 + 0) / 4 = 1
M(q2) = 2 - 1 = 1所以 q2 比 q0 更稀疏、更值得保留。
8.4 这一段的 input / output 语义
- 输入
Q/K- 多头 query/key
- 输出
scores_top- 只对 top 查询保留的完整分数
- 输出
index- 哪些查询被选中做精算
9. 子块 B:_get_initial_context(...)
这一块的完整代码是:
python
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else:
assert L_Q == L_V
contex = V.cumsum(dim=-2)
return contex中文注释版:
python
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
# 非因果情形:先用所有 value 的平均值做一个默认上下文
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else:
# 因果情形:用前缀和,保证每个位置只能看到过去
assert L_Q == L_V
contex = V.cumsum(dim=-2)
return contex9.1 对应到当前 forward(...) 的位置
python
context = self._get_initial_context(values, L_Q)这里的语义是:
_prob_QK还没真正更新 context 之前- 先给每个 query 一个“默认版本的上下文”
9.2 toy 张量逐步演变图
在 mask_flag=False 时,源码是:
python
V_sum = V.mean(dim=-2)
context = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()也就是先拿所有 value 的平均值作为每个 query 的默认上下文。
text
V =
[
[10, 11],
[20, 21],
[30, 31],
[40, 41],
]
平均值 = [25, 26]
初始化 context =
[
[25, 26],
[25, 26],
[25, 26],
[25, 26],
]9.3 这一段的 input / output 语义
- 输入
V- value 向量集合
- 输出
context- 所有 query 的默认上下文底稿
10. 子块 C:_update_context(...)
这一块的完整代码是:
python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1)
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
attns[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = attn
return context_in, attns
else:
return context_in, None中文注释版:
python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
# 因果情形下对非法位置做 mask
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
# 对 top 查询的完整分数做 softmax
attn = torch.softmax(scores, dim=-1)
# 只把这些 top 查询位置写回 context
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
attns[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = attn
return context_in, attns
else:
return context_in, None10.1 对应到当前 forward(...) 的位置
python
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask
)语义是:
- 前面拿到的是默认
context - 这里才把 top 查询位置替换成真正 attention 结果
10.2 toy 张量逐步演变图
text
初始 context =
[
[25, 26],
[25, 26],
[25, 26],
[25, 26],
]
当前只更新 query 0 和 query 210.3 一个可算的 toy 小算例
对 query 0,分数是:
text
scores_top[q0] = [1, 0, 1, 1]softmax 后假设得到:
text
w0 = [0.30, 0.10, 0.30, 0.30]那么新的上下文就是:
text
context[0] = 0.30*[10,11] + 0.10*[20,21] + 0.30*[30,31] + 0.30*[40,41]
= [3,3.3] + [2,2.1] + [9,9.3] + [12,12.3]
= [26, 27]对 query 2,分数是:
text
scores_top[q2] = [1, 1, 2, 0]softmax 后假设得到:
text
w2 = [0.20, 0.20, 0.50, 0.10]则:
text
context[2] = 0.20*[10,11] + 0.20*[20,21] + 0.50*[30,31] + 0.10*[40,41]
= [2,2.2] + [4,4.2] + [15,15.5] + [4,4.1]
= [25, 26]于是更新后:
text
context =
[
[26, 27],
[25, 26],
[25, 26],
[25, 26],
]如果 query 2 算出不同值,它也会被写回自己的位置。
还可以把“只更新 top 查询”写得更直白一点:
text
更新前:
q0 -> [25, 26]
q1 -> [25, 26]
q2 -> [25, 26]
q3 -> [25, 26]
更新后:
q0 -> [26, 27] # 被精确更新
q1 -> [25, 26] # 保持默认值
q2 -> [25, 26] # toy 里更新后碰巧与默认值一样
q3 -> [25, 26] # 保持默认值10.4 这一段的 input / output 语义
- 输入
context- 默认上下文底稿
- 输入
scores_top/index- 需要精确更新的查询及其分数
- 输出
context- 只有 top 查询位置被精确替换过的新上下文
11. 这一层和普通 FullAttention 的本质区别
普通 FullAttention 更像:
text
每个 query
-> 和所有 key 算完整分数
-> softmax
-> 加权求和 VProbAttention 更像:
text
所有 query
-> 先抽样估计稀疏度
-> 只挑 top 查询
-> 所有 query 先给默认 context
-> 只对 top 查询做完整 attention 更新所以最本质的区别是:
不是所有 query 都做完整 attention,而是只有最值得细算的 query 才做。
12. 当前层真正要固定什么
ProbAttention的关键不是“换公式”,而是“少算很多 query”- 它先:
- 抽样
- 选 top 查询
- 给所有 query 一个默认 context
- 再只更新 top 查询
- 最终输出形状仍然和普通 attention 一样:
(B, L_Q, H, D)
13. 下一步
返回上层继续看: