Appearance
Informer ProbAttention 完整代码精读
Abstract
本文只讲一个对象:
SelfAttention_Family.py里的ProbAttention。目标是把
ProbAttention.forward(...)从入口到返回讲清楚:输入是什么、为什么要采样、怎么选 top query、未选中的 query 怎么近似、被选中的 query 怎么精确更新、最后返回什么。
关联文档:
0. ProbAttention 的第一性
Full Attention 做的是:
如果 query 长度是
复杂度大致是:
Informer 的 ProbSparse 思路是:
核心思想
不是所有 query 都值得精确计算完整 attention。先用少量采样 key 粗略估计每个 query 的“稀疏性/尖锐程度”,只挑 top query 做精确 attention;其他 query 用一个便宜的默认 context 近似。
换成代码语言:
text
1. 对每个 query,随机采样一部分 key,粗略算 QK
2. 用 M = max(sample_score) - mean(sample_score) 衡量 query 是否尖锐
3. 选 M 最大的 u 个 query
4. 只对这 u 个 query 计算完整 QK
5. 先给所有 query 一个初始 context
6. 用精确 attention 输出覆盖 top query 的 context
7. 返回 context1. 在 Informer 主链中的位置
Informer 里并不是直接调用 ProbAttention,而是被 AttentionLayer 包了一层。
text
EncoderLayer.forward
└─ self.attention(x, x, x, ...)
└─ AttentionLayer.forward
├─ Linear 生成 Q/K/V
├─ view 成多头格式: (B,L,H,D)
└─ self.inner_attention(...)
└─ ProbAttention.forward ← 本文Decoder 中也会用到:
text
DecoderLayer.forward
├─ self_attention: ProbAttention(mask_flag=True)
└─ cross_attention: ProbAttention(mask_flag=False)两种 mask_flag
Encoder self-attention 和 decoder cross-attention 一般是
mask_flag=False。Decoder self-attention 是
mask_flag=True,防止未来信息泄露。
2. 输入输出接口
ProbAttention.forward(...) 的入口:
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):进入 ProbAttention 前,AttentionLayer 已经把 Q/K/V 变成:
text
queries: (B, L_Q, H, D)
keys: (B, L_K, H, D)
values: (B, L_K, H, D)ProbAttention.forward(...) 内部第一步会转置:
text
queries.transpose(2, 1): (B, H, L_Q, D)
keys.transpose(2, 1): (B, H, L_K, D)
values.transpose(2, 1): (B, H, L_K, D)输出:
text
context: (B, H, L_Q, D)
attn: None 或 (B, H, L_Q, L_K)然后 AttentionLayer.forward(...) 会接着:
python
out = out.view(B, L, -1)
return self.out_projection(out), attn维度方向容易混
AttentionLayer给ProbAttention的输入是(B,L,H,D)。
ProbAttention内部为了方便按 head 做 attention,会转成(B,H,L,D)。所以读
_prob_QK、_get_initial_context、_update_context时,都要按(B,H,L,D)理解。
3. 总流程图
4. 完整真实代码
位置:
text
ts_benchmark/baselines/time_series_library/layers/SelfAttention_Family.py
class ProbAttentionpython
class ProbAttention(nn.Module):
def __init__(
self,
mask_flag=True,
factor=5,
scale=None,
attention_dropout=0.1,
output_attention=False,
):
super(ProbAttention, self).__init__()
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# real U = U_part(factor*ln(L_k))*L_q
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# find the Top_k query with sparisty measurement
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
return Q_K, M_top
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
# V_sum = V.sum(dim=-2)
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else: # use mask
# requires that L_Q == L_V, i.e. for self-attention only
assert L_Q == L_V
contex = V.cumsum(dim=-2)
return contex
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
attns[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = attn
return context_in, attns
else:
return context_in, None
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item() # c*ln(L_k)
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item() # c*ln(L_q)
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
# add scale factor
scale = self.scale or 1.0 / sqrt(D)
if scale is not None:
scores_top = scores_top * scale
# get the context
context = self._get_initial_context(values, L_Q)
# update the context with selected top_k queries
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask
)
return context.contiguous(), attn5. forward:主控流程
先看 forward,因为它决定执行顺序。
中文注释版:
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
# AttentionLayer 传进来的是 (B, L, H, D)
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
# 转成 ProbAttention 内部使用的 (B, H, L, D)
queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)
# 采样 key 的数量:factor * ceil(log L_K)
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
# 精确计算 attention 的 query 数量:factor * ceil(log L_Q)
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()
# 防止采样数量超过真实长度
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
# 选出 top query,并只对这些 query 计算完整 QK
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
# 缩放点积,和标准 attention 一样除以 sqrt(D)
scale = self.scale or 1.0 / sqrt(D)
if scale is not None:
scores_top = scores_top * scale
# 给所有 query 一个初始 context
context = self._get_initial_context(values, L_Q)
# 只把 top query 的 context 用精确 attention 结果覆盖
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask
)
return context.contiguous(), attn5.1 当前调试参数下的大致 shape
如果使用 调试形参 里的小例子:
text
B = 4
seq_len = 24
d_model = 32
n_heads = 2
head_dim = 16Encoder 侧进入 ProbAttention 时:
text
queries: (4, 24, 2, 16)
keys: (4, 24, 2, 16)
values: (4, 24, 2, 16)转置后:
text
queries: (4, 2, 24, 16)
keys: (4, 2, 24, 16)
values: (4, 2, 24, 16)如果 factor=3:
text
U_part = min(3 * ceil(log(24)), 24)
= min(3 * 4, 24)
= 12
u = min(3 * ceil(log(24)), 24)
= 12含义:
text
每个 query 随机看 12 个 key 来估计稀疏性。
最后选 12 个 query 做完整 attention。5.2 np.ceil:采样规模从公式落到代码
对应代码:
python
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q对应数学公式:
变量含义:
| 符号/变量 | 代码位置 | 含义 |
|---|---|---|
self.factor | ProbSparse 的采样系数,越大越接近 full attention,越小越省计算 | |
keys.shape[1] | key 序列长度 | |
queries.shape[1] | query 序列长度 | |
U_part / sample_k | 每个 query 随机抽多少个 key 来估计重要性 | |
u / n_top | 最后保留多少个 query 做完整 attention |
逐函数解释:
text
np.log(L_K)表示自然对数
text
np.ceil(...)表示向上取整。比如:
text
log(24) = 3.178...
ceil(log(24)) = 4向上取整的原因是:采样数量必须是整数,而且不能因为小数截断导致采样过少。
text
.astype("int").item()把 NumPy 标量转成 Python 整数。这里不是模型计算,只是为了得到后面 torch.randint 和 topk 需要的整数参数。
具体例子:
text
L_K = 24
factor = 3
U_part = 3 * ceil(log(24))
= 3 * 4
= 12所以每个 query 随机看 12 个 key。
toy 例子里:
text
L_K = 5
factor = 1
U_part = 1 * ceil(log(5))
= 1 * 2
= 2所以 toy 例子中每个 query 只随机看 2 个 key。
最后两行:
python
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q作用是上限截断:
text
如果算出来要采样 12 个 key,但真实只有 5 个 key,那最多只能采样 5 个。
如果算出来要选 12 个 query,但真实只有 4 个 query,那最多只能选 4 个。6. _prob_QK:选 top query
_prob_QK 是 ProbAttention 的核心。
它做两件事:
text
1. 用随机采样 key 粗略判断哪些 query 更重要
2. 只对重要 query 计算完整 QK6.1 完整代码
python
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# real U = U_part(factor*ln(L_k))*L_q
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# find the Top_k query with sparisty measurement
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
return Q_K, M_top6.2 toy 例子
为了能手算,固定:
text
B = 1
H = 1
L_Q = 4
L_K = 5
D = 2
sample_k = 2
n_top = 2设 Q:
text
q0 = (2, 0)
q1 = (0, 1)
q2 = (3, 1)
q3 = (0, 1)设 K:
text
k0 = (1, 0)
k1 = (0, 1)
k2 = (2, 0)
k3 = (0, 1)
k4 = (1, 1)真实代码里 index_sample = torch.randint(...) 是随机的。为了讲清楚,toy 例子固定采样:
text
q0 采样 key: k0, k2
q1 采样 key: k1, k4
q2 采样 key: k0, k4
q3 采样 key: k2, k3也就是:
text
index_sample:
q0 -> (0, 2)
q1 -> (1, 4)
q2 -> (0, 4)
q3 -> (2, 3)6.3 K_expand
代码:
python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)原来:
text
K: (B,H,L_K,D) = (1,1,5,2)unsqueeze(-3) 后:
text
K: (1,1,1,5,2)expand(B,H,L_Q,L_K,E) 后:
text
K_expand: (1,1,4,5,2)语义:
text
给每个 query 准备一份 key 池。
不是复制真实数据,而是广播视图。更精确地说:
text
K[b,h,k,e] 变成 K_expand[b,h,q,k,e]数学对应是:
也就是说,q 这一维只是为了让每个 query 都能在自己的位置上索引不同的 sampled key。它不改变 key 的数值。
expand 的关键点:
| 操作 | shape 变化 | 含义 |
|---|---|---|
K | (B,H,L_K,D) | 原始 key |
K.unsqueeze(-3) | (B,H,1,L_K,D) | 在 L_K 前面插入 query 维 |
.expand(B,H,L_Q,L_K,D) | (B,H,L_Q,L_K,D) | 把长度为 1 的 query 维广播成 L_Q |
Note
expand通常不真实复制数据,而是创建广播视图。这里的目标不是“复制一份 K”,而是让索引表达式可以写成K_expand[:, :, query_index, sampled_key_index, :]。
6.4 K_sample
代码:
python
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]这个索引的语义是:
text
对每个 query,只取 sample_k 个随机 key。toy 中得到:
text
q0 的 K_sample: k0=(1,0), k2=(2,0)
q1 的 K_sample: k1=(0,1), k4=(1,1)
q2 的 K_sample: k0=(1,0), k4=(1,1)
q3 的 K_sample: k2=(2,0), k3=(0,1)shape:
text
K_sample: (B,H,L_Q,sample_k,D) = (1,1,4,2,2)6.5 Q_K_sample
代码:
python
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()局部 shape:
text
Q.unsqueeze(-2): (1,1,4,1,2)
K_sample.transpose(-2,-1): (1,1,4,2,2)
Q_K_sample: (1,1,4,2)逐 query 手算:
text
q0=(2,0):
q0·k0 = (2,0)·(1,0) = 2
q0·k2 = (2,0)·(2,0) = 4
sample score = (2,4)
q1=(0,1):
q1·k1 = (0,1)·(0,1) = 1
q1·k4 = (0,1)·(1,1) = 1
sample score = (1,1)
q2=(3,1):
q2·k0 = (3,1)·(1,0) = 3
q2·k4 = (3,1)·(1,1) = 4
sample score = (3,4)
q3=(0,1):
q3·k2 = (0,1)·(2,0) = 0
q3·k3 = (0,1)·(0,1) = 1
sample score = (0,1)6.6 稀疏性指标 M = max - mean
代码:
python
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)代码细节
这里 mean 的分母用的是
L_K,不是sample_k。所以它不是采样分数的普通均值,而是论文/实现里用于估计稀疏性的写法。
toy 中 L_K = 5:
text
q0: max(2,4) - (2+4)/5 = 4 - 1.2 = 2.8
q1: max(1,1) - (1+1)/5 = 1 - 0.4 = 0.6
q2: max(3,4) - (3+4)/5 = 4 - 1.4 = 2.6
q3: max(0,1) - (0+1)/5 = 1 - 0.2 = 0.8所以:
text
M = (2.8, 0.6, 2.6, 0.8)选择 top 2:
python
M_top = M.topk(n_top, sorted=False)[1]topk 返回的是一个二元组:
python
values, indices = M.topk(n_top, sorted=False)其中:
text
values = 最大的 n_top 个分数
indices = 这些分数在原 query 维度上的下标当前代码只要下标,所以写:
python
M_top = M.topk(n_top, sorted=False)[1]sorted=False 的意思是:返回的 top query 不保证按分数从大到小排序。对后续计算没有影响,因为后面只是用这些下标去取对应的 Q,不是依赖顺序做递推。
toy 中:
text
M = (2.8, 0.6, 2.6, 0.8)
n_top = 2所以 top 2 的 query 下标是:
text
M_top = (0, 2)如果 sorted=False,也可能返回:
text
M_top = (2, 0)这两种都表示选中了 q0 和 q2,只是顺序不同。
公式对应:
其中 M_top。
得到:
text
M_top = (0, 2)含义:
text
q0 和 q2 被认为是最值得精确计算 attention 的 query。
q1 和 q3 暂时只用默认 context 近似。6.7 对 top query 计算完整 QK
代码:
python
Q_reduce = Q[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
]
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))公式对应:
其中:
text
I = M_top
r = top query 内部编号也就是说,Q_reduce 不是新的 query,而是从原始 Q 里按 M_top 下标挑出来的少量重要 query。
toy 中:
text
Q_reduce = q0, q2完整 key:
text
k0=(1,0), k1=(0,1), k2=(2,0), k3=(0,1), k4=(1,1)完整 QK:
text
q0=(2,0):
q0·k0 = 2
q0·k1 = 0
q0·k2 = 4
q0·k3 = 0
q0·k4 = 2
score row = (2,0,4,0,2)
q2=(3,1):
q2·k0 = 3
q2·k1 = 1
q2·k2 = 6
q2·k3 = 1
q2·k4 = 4
score row = (3,1,6,1,4)返回:
text
Q_K.shape = (B,H,n_top,L_K) = (1,1,2,5)
M_top.shape = (B,H,n_top) = (1,1,2)7. _get_initial_context:给所有 query 一个默认输出
_prob_QK 只算了 top query。那未选中的 query 怎么办?
答案:先给所有 query 一个便宜的初始 context。
7.1 完整代码
python
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
# V_sum = V.sum(dim=-2)
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
else: # use mask
# requires that L_Q == L_V, i.e. for self-attention only
assert L_Q == L_V
contex = V.cumsum(dim=-2)
return contex7.2 mask_flag=False
Encoder self-attention 和 decoder cross-attention 通常走这里。
设 V:
text
v0=(0.1,0.8)
v1=(0.5,0.3)
v2=(0.9,0.2)
v3=(0.4,0.6)
v4=(0.7,0.1)求均值:
text
mean(V)
= ((0.1+0.5+0.9+0.4+0.7)/5, (0.8+0.3+0.2+0.6+0.1)/5)
= (0.52, 0.40)扩展到所有 query:
text
context for q0 = (0.52,0.40)
context for q1 = (0.52,0.40)
context for q2 = (0.52,0.40)
context for q3 = (0.52,0.40)shape:
text
context: (B,H,L_Q,D) = (1,1,4,2)含义:
text
未被选中的 query 先用 value 的全局均值作为近似输出。7.3 mask_flag=True
Decoder self-attention 走这里。
代码:
python
contex = V.cumsum(dim=-2)dim=-2 对应的是时间/token 维:
text
V: (B,H,L_V,D)
↑
dim=-2所以 cumsum(dim=-2) 的数学含义是:
它不是普通求和,而是“前缀和”。每个时间点只累计自己和自己之前的 value。
如果:
text
v0=(1,0)
v1=(0,2)
v2=(3,1)
v3=(1,1)那么:
text
cumsum:
q0 context = v0 = (1,0)
q1 context = v0 + v1 = (1,2)
q2 context = v0 + v1 + v2 = (4,3)
q3 context = v0 + v1 + v2 + v3 = (5,4)含义:
text
因果 self-attention 下,默认 context 不能看未来,只能用当前位置之前的累计信息。这里还要注意真实代码前面有:
python
assert L_Q == L_V原因是 decoder self-attention 中 query 和 value 来自同一段 decoder 输入。只有长度一致时,q_t 才能自然对应到 V 的前缀 v_0 ... v_t。
Important
cumsum(V)只是给所有 query 的初始近似 context。后面_update_context(...)仍然会对 top query 用 masked attention 的精确结果覆盖它。
8. _update_context:用精确 attention 覆盖 top query
8.1 完整代码
python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
attns[
torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
] = attn
return context_in, attns
else:
return context_in, None8.2 scale 后 scores
forward 中已经做了:
python
scores_top = scores_top * (1.0 / sqrt(D))toy 中 D=2:
text
scale = 1 / sqrt(2) ≈ 0.707原始 top scores:
text
q0: (2,0,4,0,2)
q2: (3,1,6,1,4)scale 后:
text
q0: (1.414, 0, 2.828, 0, 1.414)
q2: (2.121, 0.707, 4.242, 0.707, 2.828)8.3 softmax
代码:
python
attn = torch.softmax(scores, dim=-1)dim=-1 是 key 维。
近似得到:
text
attn(q0) ≈ (0.145, 0.035, 0.596, 0.035, 0.145)
attn(q2) ≈ (0.087, 0.021, 0.724, 0.021, 0.177)8.4 attn @ V
代码:
python
torch.matmul(attn, V)用 V:
text
v0=(0.1,0.8)
v1=(0.5,0.3)
v2=(0.9,0.2)
v3=(0.4,0.6)
v4=(0.7,0.1)计算 q0:
text
out(q0)
= 0.145*v0 + 0.035*v1 + 0.596*v2 + 0.035*v3 + 0.145*v4
≈ (0.684, 0.269)计算 q2:
text
out(q2)
= 0.087*v0 + 0.021*v1 + 0.724*v2 + 0.021*v3 + 0.177*v4
≈ (0.803, 0.246)8.5 覆盖 context
初始 context:
text
q0=(0.52,0.40)
q1=(0.52,0.40)
q2=(0.52,0.40)
q3=(0.52,0.40)top query index:
text
index = (0,2)覆盖后:
text
q0=(0.684,0.269) ← 精确更新
q1=(0.520,0.400) ← 保持近似
q2=(0.803,0.246) ← 精确更新
q3=(0.520,0.400) ← 保持近似这就是 ProbAttention 的输出 context。
9. output_attention=True 时返回什么
默认:
python
output_attention=False所以:
text
return context_in, None如果打开:
python
output_attention=True代码会先构造一个均匀 attention:
python
attns = torch.ones([B, H, L_V, L_V]) / L_V然后把 top query 的真实 attention 填进去:
python
attns[..., index, :] = attn含义:
text
未被选中的 query,attention 被近似为均匀分布。
被选中的 top query,attention 是真实 softmax(scores)。形状细节
这里
attns写成[B,H,L_V,L_V]。在 self-attention 中L_Q=L_V没问题。对 cross-attention,如果
L_Q和L_V不同,打开output_attention=True时要额外小心这个实现假设。
10. Encoder / Decoder 中的差异
10.1 Encoder self-attention
创建位置在 Informer.__init__:
python
ProbAttention(
False,
configs.factor,
attention_dropout=configs.dropout,
output_attention=configs.output_attention,
)特征:
text
mask_flag=False
Q=K=V=encoder embedding
初始 context = mean(V)
top query 用精确 attention 更新10.2 Decoder self-attention
创建位置:
python
ProbAttention(
True,
configs.factor,
attention_dropout=configs.dropout,
output_attention=False,
)特征:
text
mask_flag=True
Q=K=V=decoder embedding
初始 context = cumsum(V)
ProbMask 防止未来泄露10.3 Decoder cross-attention
创建位置:
python
ProbAttention(
False,
configs.factor,
attention_dropout=configs.dropout,
output_attention=False,
)特征:
text
mask_flag=False
Q=decoder hidden
K/V=encoder output
初始 context = mean(encoder values)
top decoder query 从 encoder memory 中精确取信息11. 和 FullAttention 的对比
| 项目 | FullAttention | ProbAttention |
|---|---|---|
| 是否计算所有 query-key 分数 | 是 | 否 |
| score shape | (B,H,L_Q,L_K) | 只对 top query 精确算 (B,H,u,L_K) |
| 未选中 query 怎么办 | 不存在,全部精确算 | 用初始 context 近似 |
| 核心额外步骤 | 无 | _prob_QK 选 top query |
| 复杂度直觉 | 约 | |
| Informer 论文创新点 | 否 | 是 |
12. 公式和代码变量对照表
| 数学表达 | 代码变量/代码行 | 作用 |
|---|---|---|
queries, keys, values | attention 的三组输入 | |
B,H,L_Q,D / B,H,L_K,E | batch、head、序列长度、head 内特征维 | |
self.factor | ProbSparse 采样系数 | |
U_part / sample_k | 每个 query 抽样多少个 key | |
u / n_top | 选多少个重要 query | |
K.unsqueeze(-3).expand(...) | 为“每个 query 抽不同 key”准备广播视图 | |
Q_K_sample | sampled key 上的粗略打分 | |
M | 判断 query 是否稀疏/重要 | |
M_top / index | 重要 query 的下标 | |
Q_reduce | 从全部 query 中取出 top query | |
Q_K / scores_top | top query 的完整 attention score | |
attn = torch.softmax(scores, dim=-1) | 完整 attention 权重 | |
torch.matmul(attn, V) | top query 的精确输出 | |
V.mean(dim=-2) | 非 causal 情况下的默认 context | |
V.cumsum(dim=-2) | causal 情况下的默认 context |
13. 调试断点建议
按下面顺序打断点:
AttentionLayer.forward(...)- 看输入
queries.shape是否是(B,L,H,D)。
- 看输入
ProbAttention.forward(...)- 看转置前后 shape。
_prob_QK(...)- 看
sample_k、n_top、index_sample、M_top。
- 看
_get_initial_context(...)- 看
mask_flag=False时是不是mean(V)扩展。
- 看
_update_context(...)- 看
index对应哪些 query 被覆盖。
- 看
调试时重点观察:
text
queries.shape
keys.shape
values.shape
U_part
u
index_sample.shape
Q_K_sample.shape
M.shape
M_top / index
scores_top.shape
context.shape14. 一句话总结
ProbAttention 的关键不是“另一种 softmax”,而是:
Summary
先用随机采样 key 估计每个 query 的重要性,只对 top query 做完整 attention;其他 query 用便宜的默认 context 近似。
最核心的数据流:
text
Q,K,V
-> sample keys
-> sampled QK
-> M = max - mean
-> top query index
-> top query full QK
-> initial context for all query
-> top query context = softmax(full QK) @ V
-> return context