Skip to content

Layer 5 — FullAttention 精读

父层 04B-Layer4-AttentionLayer 的第三步调用 self.inner_attention(Q, K, V, ...)
本文档覆盖 FullAttention.forward 的完整计算(最底层,无子层委托)。


1. 在父层中的位置

AttentionLayer.forward
  └─ out, attn = self.inner_attention(queries, keys, values, attn_mask, ...)  ← 本文档
       └─ scale · Q@K.T → softmax → A@V

2. I/O 接口定义

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
shape(toy)含义
输入 queries(8, 6, 2, 8) = (B*C, patch_num, n_heads, d_keys)已拆多头的 Q
输入 keys(8, 6, 2, 8)已拆多头的 K
输入 values(8, 6, 2, 8) = (B*C, patch_num, n_heads, d_values)已拆多头的 V
输出 V(8, 6, 2, 8)注意力加权后的聚合 value,形状与输入相同
输出 attnNoneoutput_attention=False 时为 None

tau / delta 全为 None,本层完全不使用(ProbSparse 相关参数)。
本层不知道 d_model、不做 projection、不做残差,只做三件事:Q@K.T → softmax → A@V


3. 顺序图(具体层)


4. 语义分组图(索引层)

三件事:① 算出每对 patch 的相似度(Q@K.T)→ ② 归一化成概率(softmax)→ ③ 用概率加权 value(A@V)


5. 逐步解析

5.0 完整原始代码

python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1.0 / sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)

    if self.mask_flag:
        if attn_mask is None:
            attn_mask = TriangularCausalMask(B, L, device=queries.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)

    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)

    if self.output_attention:
        return V.contiguous(), A
    else:
        return V.contiguous(), None

5.1 预处理

本节的作用

从 tensor shape 中提取维度常量,计算 ==缩放点积注意力== 所需的缩放系数 1/E

步骤一 — 解构维度 + 缩放系数

python
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)

queries.shape = (8, 6, 2, 8),拆出 B=8L=6(query patch 数)、H=2(头数)、E=8(d_keys,每个头的向量维度)。values.shape 同理得 S=6(key patch 数)、D=8(d_values)。

self.scale=None,所以 scale = 1/\sqrt{8} \approx 0.354

为什么要除以 E

两个随机单位向量的点积期望为 0,方差为 E(d_keys)。维度越高,点积绝对值越大,softmax 的输入就越极端,梯度进入饱和区趋向消失。

除以 E 将点积的方差归一化回 1,使 softmax 保持在梯度敏感区,让模型能有效学习注意力权重。


5.2 注意力打分

本节的作用

计算每对 (query patch, key patch) 的相似度,经 softmax 归一化为概率分布 A。

步骤二 — einsum Q@K.T → scores

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

下图展示 einsum ① 和 ② 的维度消去结构(步骤五的 einsum ② 也在其中):

下标含义:Q 的 blhe 与 K 的 bshe 共享 e 轴(d_keys=8),einsum 对 e 求和得到每对 (l, s) 的标量点积,输出 bhls(8, 2, 6, 6)

下标含义toy 值在输出里
b独立序列编号(B×C 合并后)8✓ 保留
lquery 的 patch 位置6✓ 保留
h注意力头编号2✓ 保留
ed_keys,向量维度8✗ 求和消去(点积)
skey 的 patch 位置6✓ 保留

固定一个头视角(b=0, h=0):Q[0,:,0,:] shape (6,8)K[0,:,0,:] shape (6,8) 做矩阵乘 (6,8) @ (8,6) = (6,6),即 scores[0,0,:,:]scores[0,0,l,s] 是 patch l 与 patch s 的点积相似度。

步骤三 — mask 分支

python
if self.mask_flag:
    ...

self.mask_flag=False,整块跳过。PatchTST 不加因果掩码——patch 之间==双向关注==,未来 patch 对当前 patch 同样有价值。

步骤四 — 缩放 + softmax + dropout

python
A = self.dropout(torch.softmax(scale * scores, dim=-1))

scale * scores 把得分缩小(乘 0.354)。softmax(dim=-1)s 轴(key patch 方向)归一化,使每行权重之和为 1。

toy 数值追踪(b=0, h=0, l=0 的一行):原始得分 [0.41, -0.12, 0.67, 0.08, -0.23, 0.31] → 乘 0.354 → [0.14, -0.04, 0.24, 0.03, -0.08, 0.11] → softmax → [0.17, 0.14, 0.19, 0.15, 0.13, 0.17](和 = 1.0)。

输出 A: (8, 2, 6, 6)A[b, h, l, :] 是第 l 个 query patch 对 6 个 key patch 的注意力分布。

注意力矩阵热力图(固定 b=0, h=0,6×6 矩阵,颜色越深权重越高):


5.3 信息聚合

本节的作用

以注意力分布 A 为权重,对所有 key patch 的 value 做加权求和,得到每个 query patch 的新表示。

步骤五 — einsum A@V → output

python
V = torch.einsum("bhls,bshd->blhd", A, values)

A 的 bhls 与 values 的 bshd 共享 s 轴(key patch 位置),einsum 对 s 求和消去,即以 A 的权重对所有 key patch 的 value 做==加权平均==,输出 blhd(8, 6, 2, 8)

下标来自含义在输出里
bA 和 values序列编号✓ 保留
hA 和 values头编号✓ 保留
lAquery patch 位置✓ 保留
sA 和 valueskey patch 位置✗ 求和消去(加权平均)
dvaluesd_values✓ 保留

固定一个头视角(b=0, h=0):A[0,0,:,:] shape (6,6)values[0,:,0,:] shape (6,8) 做矩阵乘 (6,6) @ (6,8) = (6,8),即 V[0,:,0,:]

toy 数值追踪(patch 0 的输出):V[0,0,0,d]=s=05A[0,0,0,s]×values[0,s,0,d],即 [0.17, 0.14, 0.19, 0.15, 0.13, 0.17] 分别加权 6 个 key patch 的 value 向量,得到 patch 0 聚合后的 8 维表示。

步骤六 — 返回

python
return V.contiguous(), None

.contiguous() 确保内存连续。

为什么 einsum 之后需要 .contiguous()

torch.einsum 返回的 tensor 在内存中可能是非连续布局(stride 不对齐)。后续 04B-Layer4-AttentionLayer 里的 .view(B, L, -1) 要求内存必须连续,否则会抛 RuntimeError.contiguous() 强制创建一块连续内存的副本,以解决这个问题。

None 代替 attn 矩阵——output_attention=False 时不需要返回注意力权重,仅返回聚合后的 value。

*记录并在线阅读我的笔记*