Skip to content

Layer5 FullAttention

覆盖 AttentionLayer.forward() 中的 self.inner_attention(queries, keys, values, ...) 调用。FullAttention 是 iTransformer 的最底层注意力计算核:einsum 计算 QKT(得到 9×9 变量注意力矩阵)→ scaled softmax → einsum ×V(加权聚合)。iTransformer 中 mask_flag=Falseoutput_attention=False,两个 if 分支均不执行。


§1 在父层中的位置

AttentionLayer.forward() 第三步:

python
out, attn = self.inner_attention(
    queries, keys, values, attn_mask, tau=tau, delta=delta
)
  • 输入:queries/keys/values(3, 9, 4, 2)
  • 输出:out (3, 9, 4, 2)attn = None

§2 I/O 接口定义

参数shape(toy)含义
queries(3, 9, 4, 2)B=3,L=9(token_count),H=4(heads),E=2(d_keys)
keys(3, 9, 4, 2)S=9(key token 数),同 L
values(3, 9, 4, 2)D=2(d_values),同 E
attn_maskNoneiTransformer 不用掩码
输出 out(3, 9, 4, 2)注意力加权聚合结果
输出 attnNoneoutput_attention=False

§3 顺序图


§4 语义分组图


§5 逐步精读

§5.0 完整原始代码

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1.0 / sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)

    if self.mask_flag:
        if attn_mask is None:
            attn_mask = TriangularCausalMask(B, L, device=queries.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)

    if self.output_attention:
        return V.contiguous(), A
    else:
        return V.contiguous(), None

§5.1 宏观逻辑

用小例子(B=1,L=S=3,H=2,E=D=2)串起来:

① 解包形状
  B=1, L=3, H=2, E=2
  scale = 1/√2 ≈ 0.707

② einsum QKᵀ
  对每个 batch 和 head,Q(3×2) × K^T(2×3) → scores(3×3)
  scores 整体 (1, 2, 3, 3)  — 这就是注意力得分矩阵

③ scaled softmax
  scale × scores → softmax(dim=-1)
  每行 3 个数变成 3 个和为 1 的权重
  A (1, 2, 3, 3)

④ einsum A×V
  对每个 batch 和 head,A(3×3) × V(3×2) → Out(3×2)
  Out 整体 (1, 3, 2, 2)  — 每个 query token 的新表示

⑤ 返回 Out.contiguous(), None

"9×9 变量注意力矩阵"的语义

在 iTransformer 中,L=S=9(9 个变量 token),所以 scores 是 (3, 4, 9, 9) 的矩阵。

  • scores[b, h, i, j] = 变量 token i 对变量 token j 的相似度(经 Q/K 线性变换后的点积)
  • A[b, h, i, :] 是第 i 个变量 token 对 9 个变量 token 的注意力分布(softmax 后,行和=1)
  • V_out[b, i, h, :] = j=08A[b,h,i,j]×V[b,j,h,:],即"var_i 的新表示 = 对 9 个变量 V 的加权组合"

这捕捉的是跨变量依赖:var_0 的输出受到 var_1~var_4 及 4 个时间 token 的影响,权重由它们的 Q-K 相似度决定。

为什么 O(token_count²) 比标准 Transformer 更高效?

标准 Transformer 的注意力矩阵大小 = L×L=12×12=144(seq_len × seq_len,对每个 head)。 iTransformer 的注意力矩阵大小 = token_count×token_count=9×9=81(变量数 + 时间维度)。

关键在于:inverted 之后,序列维度从 seq_len=12 压缩为 token_count=N+time_dims=9,注意力矩阵反而缩小了。 对于变量数 N 远小于序列长度 L 的场景(多变量长时序,如 N=7, L=336),复杂度从 O(L2) 降至 O(N2), 效率提升显著:O(3362)=112896 vs O(112)=121(含 time_dims=4 共 11 个 token)。


§5.2 步骤一:解包形状与 scale

python
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
  • B=3, L=9, H=4, E=2(queries 的四个维度)
  • S=9, D=2(values 的 key token 数和 d_values)
  • scale = 1/\sqrt{E} = 1/\sqrt{2} \approx 0.707
scale=1E=1dkeys

self.scale 初始化时默认 None,所以走右侧 1/sqrt(E)。这是 Vaswani 2017 原始 Transformer 的缩放因子,防止点积随 d_keys 增大而导致 softmax 梯度消失。

为什么要 scale?

当 d_keys 较大时,QKT 的量级约为 dkeys(随机初始化下),不除以 dkeys 会让最大值远大于其他值,softmax 后趋近于 one-hot,梯度极小。除以 dkeys 后量级归一,softmax 分布更均匀,梯度更健康。

toy 数值:

scale = 1/√2 ≈ 0.707

§5.3 步骤二:einsum QKᵀ

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

下标含义:

下标维度(toy)含义
b3batch
l9query token 位置
h4head
e2d_keys(求和消去)
s9key token 位置

等价矩阵乘法(对固定 b, h):

scores[b,h]=Q[b,:,h,:]K[b,:,h,:]TR9×9

即:对每个 batch、每个 head,计算 Q(9×2)×KT(2×9)scores(9×9)

输出 scores (3, 4, 9, 9):batch=3,head=4,(query_token, key_token)=(9,9)。

toy 数值(batch=0,head=0):

Q[0, :, 0, :]  shape (9, 2)  — 9 个变量 token,Head 0 的 2 维 query
K[0, :, 0, :]  shape (9, 2)  — 同,key

scores[0, 0, i, j] = Q[0,i,0,0]×K[0,j,0,0] + Q[0,i,0,1]×K[0,j,0,1]
                    = 2 维内积,标量

scores[0, 0, :, :]  shape (9, 9)
  [i,j] = var_i 的 query 与 var_j 的 key 的相似度

§5.4 步骤三:mask_flag=False,跳过

python
if self.mask_flag:
    ...

iTransformer 的 FullAttention 初始化时 mask_flag=False,此 if 块永远不执行

为什么变量间注意力不需要因果掩码?

标准时序 Transformer(token = 时间步)中,因果掩码防止 t 时刻的 query 看到 t+1 时刻的 key——这是"不能用未来预测过去"的约束。

iTransformer(token = 变量)中,token 之间没有时间先后关系:

  • var_i 关注 var_j 与 var_j 关注 var_i 是对等的,两者都只持有历史窗口 [t-seq_len+1, t] 内的信息
  • 不存在"未来变量"的概念,var_3 的 token 不包含 var_0 未来的任何信息

因此全局互相关(完整 9×9 矩阵)是正确的,引入因果掩码反而会错误地屏蔽有效的跨变量信号。


§5.5 步骤四:scaled softmax

python
A = self.dropout(torch.softmax(scale * scores, dim=-1))
A[b,h,i,:]=softmax(scores[b,h,i,:]E)
  • scale * scores (3,4,9,9)softmax(dim=-1) → 每行 9 个数归一化,行和=1
  • self.dropout 训练时随机置零部分权重,推理时恒等

输出 A (3, 4, 9, 9):每行是一个概率分布,表示"该 query token 对 9 个 key token 的关注程度"。

toy 数值(batch=0,head=0,query token=var_0,即 row 0):

scale × scores[0, 0, 0, :] = 0.707 × [s_0, s_1, ..., s_8]   (9个原始相似度分数)

softmax 后:
  A[0, 0, 0, :] = [a_0, a_1, ..., a_8],各元素 ∈ (0,1),∑=1.0

语义:var_0 对 [var_0, var_1, var_2, var_3, var_4, time_0, time_1, time_2, time_3]
     的关注权重。若 a_1 最大,说明 var_1 对 var_0 的预测最相关。

§5.6 步骤五:einsum A×V

python
V = torch.einsum("bhls,bshd->blhd", A, values)

下标含义:

下标消去情况含义
b保留batch
h保留head
l保留query token 位置
s消去(求和)key token 位置(对 9 个 value token 加权求和)
d保留d_values=2

等价矩阵乘法(对固定 b, h):

Out[b,:,h,:]=A[b,h,:,:]V[b,:,h,:]R9×2

即:A(9×9)×V(9×2)Out(9×2)

输出 V (3, 9, 4, 2) — 形状与输入 queries 完全相同。

toy 数值(batch=0,head=0,query token=var_0,即 row 0):

A[0, 0, 0, :] = [a_0, ..., a_8]   (var_0 对 9 个 token 的权重)
V[0, :, 0, :]  shape (9, 2)       (9 个 token 在 head 0 的 value 向量)

Out[0, 0, 0, :] = Σ_{j=0}^{8} a_j × V[0, j, 0, :]
               = a_0 × [v0_0, v0_1]    ← var_0 自身的 value
               + a_1 × [v1_0, v1_1]    ← var_1 的 value,权重 a_1
               + ...
               + a_8 × [v8_0, v8_1]    ← time_3 的 value,权重 a_8

Out[0, 0, 0, :] = 2 维向量,编码了 var_0 关注所有变量后的综合表示

§5.7 步骤六:返回

python
if self.output_attention:
    return V.contiguous(), A
else:
    return V.contiguous(), None

iTransformer 中 output_attention=False,走 else 分支,attn=None

.contiguous() 确保返回的张量在内存中连续存储(部分 PyTorch 操作要求连续)。


§6 注意力矩阵全景(9×9)

                  Key:  var_0  var_1  var_2  var_3  var_4  time_0  time_1  time_2  time_3
Query: var_0           [高    低    低    低    低    中     中      中      中   ]  ← 行和=1
Query: var_1           [低    高    低    低    低    中     中      中      中   ]
Query: var_2           [低    低    高    低    低    中     中      中      中   ]
Query: var_3           [低    低    低    高    低    中     中      中      中   ]
Query: var_4           [低    低    低    低    高    中     中      中      中   ]
Query: time_0          [均匀  均匀  均匀  均匀  均匀  均匀   均匀    均匀    均匀 ]
Query: time_1          [...]
Query: time_2          [...]
Query: time_3          [...]

对角线(自注意力)通常权重最高(变量自己最"了解"自己);时间 token 行权重均匀(时间特征对所有变量的关系是均匀的)。这种结构在 [[05-收束]] 的模型对比中与 PatchTST(时间 patch 间注意力)形成鲜明对比。


§7 注意力矩阵热力图

*记录并在线阅读我的笔记*