Skip to content

Informer ProbAttention 完整代码精读

Abstract

本文只讲一个对象:SelfAttention_Family.py 里的 ProbAttention

目标是把 ProbAttention.forward(...) 从入口到返回讲清楚:输入是什么、为什么要采样、怎么选 top query、未选中的 query 怎么近似、被选中的 query 怎么精确更新、最后返回什么。

关联文档:


0. ProbAttention 的第一性

Full Attention 做的是:

Attention(Q,K,V)=softmax(QKd)V

如果 query 长度是 LQ,key 长度是 LK,那么完整分数矩阵是:

QKRLQ×LK

复杂度大致是:

O(LQLK)

Informer 的 ProbSparse 思路是:

核心思想

不是所有 query 都值得精确计算完整 attention。先用少量采样 key 粗略估计每个 query 的“稀疏性/尖锐程度”,只挑 top query 做精确 attention;其他 query 用一个便宜的默认 context 近似。

换成代码语言:

text
1. 对每个 query,随机采样一部分 key,粗略算 QK
2. 用 M = max(sample_score) - mean(sample_score) 衡量 query 是否尖锐
3. 选 M 最大的 u 个 query
4. 只对这 u 个 query 计算完整 QK
5. 先给所有 query 一个初始 context
6. 用精确 attention 输出覆盖 top query 的 context
7. 返回 context

1. 在 Informer 主链中的位置

Informer 里并不是直接调用 ProbAttention,而是被 AttentionLayer 包了一层。

text
EncoderLayer.forward
└─ self.attention(x, x, x, ...)
   └─ AttentionLayer.forward
      ├─ Linear 生成 Q/K/V
      ├─ view 成多头格式: (B,L,H,D)
      └─ self.inner_attention(...)
         └─ ProbAttention.forward   ← 本文

Decoder 中也会用到:

text
DecoderLayer.forward
├─ self_attention: ProbAttention(mask_flag=True)
└─ cross_attention: ProbAttention(mask_flag=False)
两种 mask_flag

Encoder self-attention 和 decoder cross-attention 一般是 mask_flag=False

Decoder self-attention 是 mask_flag=True,防止未来信息泄露。


2. 输入输出接口

ProbAttention.forward(...) 的入口:

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):

进入 ProbAttention 前,AttentionLayer 已经把 Q/K/V 变成:

text
queries: (B, L_Q, H, D)
keys:    (B, L_K, H, D)
values:  (B, L_K, H, D)

ProbAttention.forward(...) 内部第一步会转置:

text
queries.transpose(2, 1): (B, H, L_Q, D)
keys.transpose(2, 1):    (B, H, L_K, D)
values.transpose(2, 1):  (B, H, L_K, D)

输出:

text
context: (B, H, L_Q, D)
attn:    None 或 (B, H, L_Q, L_K)

然后 AttentionLayer.forward(...) 会接着:

python
out = out.view(B, L, -1)
return self.out_projection(out), attn
维度方向容易混

AttentionLayerProbAttention 的输入是 (B,L,H,D)

ProbAttention 内部为了方便按 head 做 attention,会转成 (B,H,L,D)

所以读 _prob_QK_get_initial_context_update_context 时,都要按 (B,H,L,D) 理解。


3. 总流程图


4. 完整真实代码

位置:

text
ts_benchmark/baselines/time_series_library/layers/SelfAttention_Family.py
class ProbAttention
python
class ProbAttention(nn.Module):
    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        # real U = U_part(factor*ln(L_k))*L_q
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

        # find the Top_k query with sparisty measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
        ]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else:  # use mask
            # requires that L_Q == L_V, i.e. for self-attention only
            assert L_Q == L_V
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

        context_in[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
            attns[
                torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
            ] = attn
            return context_in, attns
        else:
            return context_in, None

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()  # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()  # c*ln(L_q)

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor
        scale = self.scale or 1.0 / sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )

        return context.contiguous(), attn

5. forward:主控流程

先看 forward,因为它决定执行顺序。

中文注释版:

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    # AttentionLayer 传进来的是 (B, L, H, D)
    B, L_Q, H, D = queries.shape
    _, L_K, _, _ = keys.shape

    # 转成 ProbAttention 内部使用的 (B, H, L, D)
    queries = queries.transpose(2, 1)
    keys = keys.transpose(2, 1)
    values = values.transpose(2, 1)

    # 采样 key 的数量:factor * ceil(log L_K)
    U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()

    # 精确计算 attention 的 query 数量:factor * ceil(log L_Q)
    u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

    # 防止采样数量超过真实长度
    U_part = U_part if U_part < L_K else L_K
    u = u if u < L_Q else L_Q

    # 选出 top query,并只对这些 query 计算完整 QK
    scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

    # 缩放点积,和标准 attention 一样除以 sqrt(D)
    scale = self.scale or 1.0 / sqrt(D)
    if scale is not None:
        scores_top = scores_top * scale

    # 给所有 query 一个初始 context
    context = self._get_initial_context(values, L_Q)

    # 只把 top query 的 context 用精确 attention 结果覆盖
    context, attn = self._update_context(
        context, values, scores_top, index, L_Q, attn_mask
    )

    return context.contiguous(), attn

5.1 当前调试参数下的大致 shape

如果使用 调试形参 里的小例子:

text
B = 4
seq_len = 24
d_model = 32
n_heads = 2
head_dim = 16

Encoder 侧进入 ProbAttention 时:

text
queries: (4, 24, 2, 16)
keys:    (4, 24, 2, 16)
values:  (4, 24, 2, 16)

转置后:

text
queries: (4, 2, 24, 16)
keys:    (4, 2, 24, 16)
values:  (4, 2, 24, 16)

如果 factor=3

text
U_part = min(3 * ceil(log(24)), 24)
       = min(3 * 4, 24)
       = 12

u = min(3 * ceil(log(24)), 24)
  = 12

含义:

text
每个 query 随机看 12 个 key 来估计稀疏性。
最后选 12 个 query 做完整 attention。

5.2 np.ceil:采样规模从公式落到代码

对应代码:

python
U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q

对应数学公式:

Upart=min(clnLK,LK)u=min(clnLQ,LQ)

变量含义:

符号/变量代码位置含义
cself.factorProbSparse 的采样系数,越大越接近 full attention,越小越省计算
LKkeys.shape[1]key 序列长度
LQqueries.shape[1]query 序列长度
UpartU_part / sample_k每个 query 随机抽多少个 key 来估计重要性
uu / n_top最后保留多少个 query 做完整 attention

逐函数解释:

text
np.log(L_K)

表示自然对数 lnLK。Informer 论文的 ProbSparse 不是采样固定比例的 key,而是采样对数级别的 key,所以长序列时不会线性变贵。

text
np.ceil(...)

表示向上取整。比如:

text
log(24) = 3.178...
ceil(log(24)) = 4

向上取整的原因是:采样数量必须是整数,而且不能因为小数截断导致采样过少。

text
.astype("int").item()

把 NumPy 标量转成 Python 整数。这里不是模型计算,只是为了得到后面 torch.randinttopk 需要的整数参数。

具体例子:

text
L_K = 24
factor = 3

U_part = 3 * ceil(log(24))
       = 3 * 4
       = 12

所以每个 query 随机看 12 个 key。

toy 例子里:

text
L_K = 5
factor = 1

U_part = 1 * ceil(log(5))
       = 1 * 2
       = 2

所以 toy 例子中每个 query 只随机看 2 个 key。

最后两行:

python
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q

作用是上限截断:

text
如果算出来要采样 12 个 key,但真实只有 5 个 key,那最多只能采样 5 个。
如果算出来要选 12 个 query,但真实只有 4 个 query,那最多只能选 4 个。

6. _prob_QK:选 top query

_prob_QK 是 ProbAttention 的核心。

它做两件事:

text
1. 用随机采样 key 粗略判断哪些 query 更重要
2. 只对重要 query 计算完整 QK

6.1 完整代码

python
def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
    # Q [B, H, L, D]
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    # calculate the sampled Q_K
    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    # real U = U_part(factor*ln(L_k))*L_q
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    # find the Top_k query with sparisty measurement
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    # use the reduced Q to calculate Q_K
    Q_reduce = Q[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
    ]  # factor*ln(L_q)
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

    return Q_K, M_top

6.2 toy 例子

为了能手算,固定:

text
B = 1
H = 1
L_Q = 4
L_K = 5
D = 2
sample_k = 2
n_top = 2

设 Q:

text
q0 = (2, 0)
q1 = (0, 1)
q2 = (3, 1)
q3 = (0, 1)

设 K:

text
k0 = (1, 0)
k1 = (0, 1)
k2 = (2, 0)
k3 = (0, 1)
k4 = (1, 1)

真实代码里 index_sample = torch.randint(...) 是随机的。为了讲清楚,toy 例子固定采样:

text
q0 采样 key: k0, k2
q1 采样 key: k1, k4
q2 采样 key: k0, k4
q3 采样 key: k2, k3

也就是:

text
index_sample:
q0 -> (0, 2)
q1 -> (1, 4)
q2 -> (0, 4)
q3 -> (2, 3)

6.3 K_expand

代码:

python
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)

原来:

text
K: (B,H,L_K,D) = (1,1,5,2)

unsqueeze(-3) 后:

text
K: (1,1,1,5,2)

expand(B,H,L_Q,L_K,E) 后:

text
K_expand: (1,1,4,5,2)

语义:

text
给每个 query 准备一份 key 池。
不是复制真实数据,而是广播视图。

更精确地说:

text
K[b,h,k,e] 变成 K_expand[b,h,q,k,e]

数学对应是:

Kexpand[b,h,q,k,e]=K[b,h,k,e]

也就是说,q 这一维只是为了让每个 query 都能在自己的位置上索引不同的 sampled key。它不改变 key 的数值。

expand 的关键点:

操作shape 变化含义
K(B,H,L_K,D)原始 key
K.unsqueeze(-3)(B,H,1,L_K,D)L_K 前面插入 query 维
.expand(B,H,L_Q,L_K,D)(B,H,L_Q,L_K,D)把长度为 1 的 query 维广播成 L_Q
Note

expand 通常不真实复制数据,而是创建广播视图。这里的目标不是“复制一份 K”,而是让索引表达式可以写成 K_expand[:, :, query_index, sampled_key_index, :]

6.4 K_sample

代码:

python
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

这个索引的语义是:

text
对每个 query,只取 sample_k 个随机 key。

toy 中得到:

text
q0 的 K_sample: k0=(1,0), k2=(2,0)
q1 的 K_sample: k1=(0,1), k4=(1,1)
q2 的 K_sample: k0=(1,0), k4=(1,1)
q3 的 K_sample: k2=(2,0), k3=(0,1)

shape:

text
K_sample: (B,H,L_Q,sample_k,D) = (1,1,4,2,2)

6.5 Q_K_sample

代码:

python
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

局部 shape:

text
Q.unsqueeze(-2):              (1,1,4,1,2)
K_sample.transpose(-2,-1):    (1,1,4,2,2)
Q_K_sample:                   (1,1,4,2)

逐 query 手算:

text
q0=(2,0):
  q0·k0 = (2,0)·(1,0) = 2
  q0·k2 = (2,0)·(2,0) = 4
  sample score = (2,4)

q1=(0,1):
  q1·k1 = (0,1)·(0,1) = 1
  q1·k4 = (0,1)·(1,1) = 1
  sample score = (1,1)

q2=(3,1):
  q2·k0 = (3,1)·(1,0) = 3
  q2·k4 = (3,1)·(1,1) = 4
  sample score = (3,4)

q3=(0,1):
  q3·k2 = (0,1)·(2,0) = 0
  q3·k3 = (0,1)·(0,1) = 1
  sample score = (0,1)

6.6 稀疏性指标 M = max - mean

代码:

python
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
代码细节

这里 mean 的分母用的是 L_K,不是 sample_k

所以它不是采样分数的普通均值,而是论文/实现里用于估计稀疏性的写法。

toy 中 L_K = 5

text
q0: max(2,4) - (2+4)/5 = 4 - 1.2 = 2.8
q1: max(1,1) - (1+1)/5 = 1 - 0.4 = 0.6
q2: max(3,4) - (3+4)/5 = 4 - 1.4 = 2.6
q3: max(0,1) - (0+1)/5 = 1 - 0.2 = 0.8

所以:

text
M = (2.8, 0.6, 2.6, 0.8)

选择 top 2:

python
M_top = M.topk(n_top, sorted=False)[1]

topk 返回的是一个二元组:

python
values, indices = M.topk(n_top, sorted=False)

其中:

text
values  = 最大的 n_top 个分数
indices = 这些分数在原 query 维度上的下标

当前代码只要下标,所以写:

python
M_top = M.topk(n_top, sorted=False)[1]

sorted=False 的意思是:返回的 top query 不保证按分数从大到小排序。对后续计算没有影响,因为后面只是用这些下标去取对应的 Q,不是依赖顺序做递推。

toy 中:

text
M = (2.8, 0.6, 2.6, 0.8)
n_top = 2

所以 top 2 的 query 下标是:

text
M_top = (0, 2)

如果 sorted=False,也可能返回:

text
M_top = (2, 0)

这两种都表示选中了 q0q2,只是顺序不同。

公式对应:

I=TopKIndex(M,u)

其中 I 就是代码里的 M_top

得到:

text
M_top = (0, 2)

含义:

text
q0 和 q2 被认为是最值得精确计算 attention 的 query。
q1 和 q3 暂时只用默认 context 近似。

6.7 对 top query 计算完整 QK

代码:

python
Q_reduce = Q[
    torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
]
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

公式对应:

Qreduce[b,h,r,:]=Q[b,h,I[b,h,r],:]

其中:

text
I = M_top
r = top query 内部编号

也就是说,Q_reduce 不是新的 query,而是从原始 Q 里按 M_top 下标挑出来的少量重要 query。

toy 中:

text
Q_reduce = q0, q2

完整 key:

text
k0=(1,0), k1=(0,1), k2=(2,0), k3=(0,1), k4=(1,1)

完整 QK:

text
q0=(2,0):
  q0·k0 = 2
  q0·k1 = 0
  q0·k2 = 4
  q0·k3 = 0
  q0·k4 = 2
  score row = (2,0,4,0,2)

q2=(3,1):
  q2·k0 = 3
  q2·k1 = 1
  q2·k2 = 6
  q2·k3 = 1
  q2·k4 = 4
  score row = (3,1,6,1,4)

返回:

text
Q_K.shape = (B,H,n_top,L_K) = (1,1,2,5)
M_top.shape = (B,H,n_top)   = (1,1,2)

7. _get_initial_context:给所有 query 一个默认输出

_prob_QK 只算了 top query。那未选中的 query 怎么办?

答案:先给所有 query 一个便宜的初始 context。

7.1 完整代码

python
def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        # V_sum = V.sum(dim=-2)
        V_sum = V.mean(dim=-2)
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
    else:  # use mask
        # requires that L_Q == L_V, i.e. for self-attention only
        assert L_Q == L_V
        contex = V.cumsum(dim=-2)
    return contex

7.2 mask_flag=False

Encoder self-attention 和 decoder cross-attention 通常走这里。

设 V:

text
v0=(0.1,0.8)
v1=(0.5,0.3)
v2=(0.9,0.2)
v3=(0.4,0.6)
v4=(0.7,0.1)

求均值:

text
mean(V)
= ((0.1+0.5+0.9+0.4+0.7)/5, (0.8+0.3+0.2+0.6+0.1)/5)
= (0.52, 0.40)

扩展到所有 query:

text
context for q0 = (0.52,0.40)
context for q1 = (0.52,0.40)
context for q2 = (0.52,0.40)
context for q3 = (0.52,0.40)

shape:

text
context: (B,H,L_Q,D) = (1,1,4,2)

含义:

text
未被选中的 query 先用 value 的全局均值作为近似输出。

7.3 mask_flag=True

Decoder self-attention 走这里。

代码:

python
contex = V.cumsum(dim=-2)

dim=-2 对应的是时间/token 维:

text
V: (B,H,L_V,D)

      dim=-2

所以 cumsum(dim=-2) 的数学含义是:

context[b,h,t,:]=i=0tV[b,h,i,:]

它不是普通求和,而是“前缀和”。每个时间点只累计自己和自己之前的 value。

如果:

text
v0=(1,0)
v1=(0,2)
v2=(3,1)
v3=(1,1)

那么:

text
cumsum:
q0 context = v0             = (1,0)
q1 context = v0 + v1        = (1,2)
q2 context = v0 + v1 + v2   = (4,3)
q3 context = v0 + v1 + v2 + v3 = (5,4)

含义:

text
因果 self-attention 下,默认 context 不能看未来,只能用当前位置之前的累计信息。

这里还要注意真实代码前面有:

python
assert L_Q == L_V

原因是 decoder self-attention 中 query 和 value 来自同一段 decoder 输入。只有长度一致时,q_t 才能自然对应到 V 的前缀 v_0 ... v_t

Important

cumsum(V) 只是给所有 query 的初始近似 context。后面 _update_context(...) 仍然会对 top query 用 masked attention 的精确结果覆盖它。


8. _update_context:用精确 attention 覆盖 top query

8.1 完整代码

python
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

    context_in[
        torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
    ] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = attn
        return context_in, attns
    else:
        return context_in, None

8.2 scale 后 scores

forward 中已经做了:

python
scores_top = scores_top * (1.0 / sqrt(D))

toy 中 D=2

text
scale = 1 / sqrt(2) ≈ 0.707

原始 top scores:

text
q0: (2,0,4,0,2)
q2: (3,1,6,1,4)

scale 后:

text
q0: (1.414, 0, 2.828, 0, 1.414)
q2: (2.121, 0.707, 4.242, 0.707, 2.828)

8.3 softmax

代码:

python
attn = torch.softmax(scores, dim=-1)

dim=-1 是 key 维。

近似得到:

text
attn(q0) ≈ (0.145, 0.035, 0.596, 0.035, 0.145)
attn(q2) ≈ (0.087, 0.021, 0.724, 0.021, 0.177)

8.4 attn @ V

代码:

python
torch.matmul(attn, V)

用 V:

text
v0=(0.1,0.8)
v1=(0.5,0.3)
v2=(0.9,0.2)
v3=(0.4,0.6)
v4=(0.7,0.1)

计算 q0:

text
out(q0)
= 0.145*v0 + 0.035*v1 + 0.596*v2 + 0.035*v3 + 0.145*v4
≈ (0.684, 0.269)

计算 q2:

text
out(q2)
= 0.087*v0 + 0.021*v1 + 0.724*v2 + 0.021*v3 + 0.177*v4
≈ (0.803, 0.246)

8.5 覆盖 context

初始 context:

text
q0=(0.52,0.40)
q1=(0.52,0.40)
q2=(0.52,0.40)
q3=(0.52,0.40)

top query index:

text
index = (0,2)

覆盖后:

text
q0=(0.684,0.269)  ← 精确更新
q1=(0.520,0.400)  ← 保持近似
q2=(0.803,0.246)  ← 精确更新
q3=(0.520,0.400)  ← 保持近似

这就是 ProbAttention 的输出 context。


9. output_attention=True 时返回什么

默认:

python
output_attention=False

所以:

text
return context_in, None

如果打开:

python
output_attention=True

代码会先构造一个均匀 attention:

python
attns = torch.ones([B, H, L_V, L_V]) / L_V

然后把 top query 的真实 attention 填进去:

python
attns[..., index, :] = attn

含义:

text
未被选中的 query,attention 被近似为均匀分布。
被选中的 top query,attention 是真实 softmax(scores)。
形状细节

这里 attns 写成 [B,H,L_V,L_V]。在 self-attention 中 L_Q=L_V 没问题。

对 cross-attention,如果 L_QL_V 不同,打开 output_attention=True 时要额外小心这个实现假设。


10. Encoder / Decoder 中的差异

10.1 Encoder self-attention

创建位置在 Informer.__init__

python
ProbAttention(
    False,
    configs.factor,
    attention_dropout=configs.dropout,
    output_attention=configs.output_attention,
)

特征:

text
mask_flag=False
Q=K=V=encoder embedding
初始 context = mean(V)
top query 用精确 attention 更新

10.2 Decoder self-attention

创建位置:

python
ProbAttention(
    True,
    configs.factor,
    attention_dropout=configs.dropout,
    output_attention=False,
)

特征:

text
mask_flag=True
Q=K=V=decoder embedding
初始 context = cumsum(V)
ProbMask 防止未来泄露

10.3 Decoder cross-attention

创建位置:

python
ProbAttention(
    False,
    configs.factor,
    attention_dropout=configs.dropout,
    output_attention=False,
)

特征:

text
mask_flag=False
Q=decoder hidden
K/V=encoder output
初始 context = mean(encoder values)
top decoder query 从 encoder memory 中精确取信息

11. 和 FullAttention 的对比

项目FullAttentionProbAttention
是否计算所有 query-key 分数
score shape(B,H,L_Q,L_K)只对 top query 精确算 (B,H,u,L_K)
未选中 query 怎么办不存在,全部精确算用初始 context 近似
核心额外步骤_prob_QK 选 top query
复杂度直觉O(LQLK)O(LlogL) 级别
Informer 论文创新点

12. 公式和代码变量对照表

数学表达代码变量/代码行作用
Q,K,Vqueries, keys, valuesattention 的三组输入
B,H,L,DB,H,L_Q,D / B,H,L_K,Ebatch、head、序列长度、head 内特征维
cself.factorProbSparse 采样系数
Upart=min(clnLK,LK)U_part / sample_k每个 query 抽样多少个 key
u=min(clnLQ,LQ)u / n_top选多少个重要 query
Kexpand[b,h,q,k,e]=K[b,h,k,e]K.unsqueeze(-3).expand(...)为“每个 query 抽不同 key”准备广播视图
S~q,j=qkjQ_K_samplesampled key 上的粗略打分
Mq=maxjS~q,j1LKjS~q,jM判断 query 是否稀疏/重要
I=TopKIndex(M,u)M_top / index重要 query 的下标
QIQ_reduce从全部 query 中取出 top query
SI=QIKQ_K / scores_toptop query 的完整 attention score
softmax(SI/D)attn = torch.softmax(scores, dim=-1)完整 attention 权重
softmax(SI/D)Vtorch.matmul(attn, V)top query 的精确输出
mean(V)V.mean(dim=-2)非 causal 情况下的默认 context
i=0tViV.cumsum(dim=-2)causal 情况下的默认 context

13. 调试断点建议

按下面顺序打断点:

  1. AttentionLayer.forward(...)
    • 看输入 queries.shape 是否是 (B,L,H,D)
  2. ProbAttention.forward(...)
    • 看转置前后 shape。
  3. _prob_QK(...)
    • sample_kn_topindex_sampleM_top
  4. _get_initial_context(...)
    • mask_flag=False 时是不是 mean(V) 扩展。
  5. _update_context(...)
    • index 对应哪些 query 被覆盖。

调试时重点观察:

text
queries.shape
keys.shape
values.shape
U_part
u
index_sample.shape
Q_K_sample.shape
M.shape
M_top / index
scores_top.shape
context.shape

14. 一句话总结

ProbAttention 的关键不是“另一种 softmax”,而是:

Summary

先用随机采样 key 估计每个 query 的重要性,只对 top query 做完整 attention;其他 query 用便宜的默认 context 近似。

最核心的数据流:

text
Q,K,V
-> sample keys
-> sampled QK
-> M = max - mean
-> top query index
-> top query full QK
-> initial context for all query
-> top query context = softmax(full QK) @ V
-> return context

*记录并在线阅读我的笔记*