Appearance
Layer 5 — FullAttention 精读
父层 04B-Layer4-AttentionLayer 的第三步调用
self.inner_attention(Q, K, V, ...)。
本文档覆盖FullAttention.forward的完整计算(最底层,无子层委托)。
1. 在父层中的位置
AttentionLayer.forward
└─ out, attn = self.inner_attention(queries, keys, values, attn_mask, ...) ← 本文档
└─ scale · Q@K.T → softmax → A@V2. I/O 接口定义
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):| shape(toy) | 含义 | |
|---|---|---|
输入 queries | (8, 6, 2, 8) = (B*C, patch_num, n_heads, d_keys) | 已拆多头的 Q |
输入 keys | (8, 6, 2, 8) | 已拆多头的 K |
输入 values | (8, 6, 2, 8) = (B*C, patch_num, n_heads, d_values) | 已拆多头的 V |
输出 V | (8, 6, 2, 8) | 注意力加权后的聚合 value,形状与输入相同 |
输出 attn | None | output_attention=False 时为 None |
tau / delta全为None,本层完全不使用(ProbSparse 相关参数)。
本层不知道d_model、不做 projection、不做残差,只做三件事:Q@K.T → softmax → A@V。
3. 顺序图(具体层)
4. 语义分组图(索引层)
三件事:① 算出每对 patch 的相似度(Q@K.T)→ ② 归一化成概率(softmax)→ ③ 用概率加权 value(A@V)。
5. 逐步解析
5.0 完整原始代码
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None5.1 预处理
本节的作用
从 tensor shape 中提取维度常量,计算 ==缩放点积注意力== 所需的缩放系数
。
步骤一 — 解构维度 + 缩放系数
python
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)queries.shape = (8, 6, 2, 8),拆出 B=8、L=6(query patch 数)、H=2(头数)、E=8(d_keys,每个头的向量维度)。values.shape 同理得 S=6(key patch 数)、D=8(d_values)。
self.scale=None,所以 scale = 1/\sqrt{8} \approx 0.354。
为什么要除以 ?
两个随机单位向量的点积期望为 0,方差为
(d_keys)。维度越高,点积绝对值越大,softmax 的输入就越极端,梯度进入饱和区趋向消失。 除以
将点积的方差归一化回 1,使 softmax 保持在梯度敏感区,让模型能有效学习注意力权重。
5.2 注意力打分
本节的作用
计算每对
(query patch, key patch)的相似度,经 softmax 归一化为概率分布 A。
步骤二 — einsum Q@K.T → scores
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)下图展示 einsum ① 和 ② 的维度消去结构(步骤五的 einsum ② 也在其中):
下标含义:Q 的 blhe 与 K 的 bshe 共享 e 轴(d_keys=8),einsum 对 e 求和得到每对 (l, s) 的标量点积,输出 bhls 即 (8, 2, 6, 6)。
| 下标 | 含义 | toy 值 | 在输出里 |
|---|---|---|---|
b | 独立序列编号(B×C 合并后) | 8 | ✓ 保留 |
l | query 的 patch 位置 | 6 | ✓ 保留 |
h | 注意力头编号 | 2 | ✓ 保留 |
e | d_keys,向量维度 | 8 | ✗ 求和消去(点积) |
s | key 的 patch 位置 | 6 | ✓ 保留 |
固定一个头视角(b=0, h=0):Q[0,:,0,:] shape (6,8) 与 K[0,:,0,:] shape (6,8) 做矩阵乘 (6,8) @ (8,6) = (6,6),即 scores[0,0,:,:]。scores[0,0,l,s] 是 patch l 与 patch s 的点积相似度。
步骤三 — mask 分支
python
if self.mask_flag:
...self.mask_flag=False,整块跳过。PatchTST 不加因果掩码——patch 之间==双向关注==,未来 patch 对当前 patch 同样有价值。
步骤四 — 缩放 + softmax + dropout
python
A = self.dropout(torch.softmax(scale * scores, dim=-1))scale * scores 把得分缩小(乘 0.354)。softmax(dim=-1) 对 s 轴(key patch 方向)归一化,使每行权重之和为 1。
toy 数值追踪(b=0, h=0, l=0 的一行):原始得分 [0.41, -0.12, 0.67, 0.08, -0.23, 0.31] → 乘 0.354 → [0.14, -0.04, 0.24, 0.03, -0.08, 0.11] → softmax → [0.17, 0.14, 0.19, 0.15, 0.13, 0.17](和 = 1.0)。
输出 A: (8, 2, 6, 6),A[b, h, l, :] 是第 l 个 query patch 对 6 个 key patch 的注意力分布。
注意力矩阵热力图(固定 b=0, h=0,6×6 矩阵,颜色越深权重越高):
5.3 信息聚合
本节的作用
以注意力分布 A 为权重,对所有 key patch 的 value 做加权求和,得到每个 query patch 的新表示。
步骤五 — einsum A@V → output
python
V = torch.einsum("bhls,bshd->blhd", A, values)A 的 bhls 与 values 的 bshd 共享 s 轴(key patch 位置),einsum 对 s 求和消去,即以 A 的权重对所有 key patch 的 value 做==加权平均==,输出 blhd 即 (8, 6, 2, 8)。
| 下标 | 来自 | 含义 | 在输出里 |
|---|---|---|---|
b | A 和 values | 序列编号 | ✓ 保留 |
h | A 和 values | 头编号 | ✓ 保留 |
l | A | query patch 位置 | ✓ 保留 |
s | A 和 values | key patch 位置 | ✗ 求和消去(加权平均) |
d | values | d_values | ✓ 保留 |
固定一个头视角(b=0, h=0):A[0,0,:,:] shape (6,6) 与 values[0,:,0,:] shape (6,8) 做矩阵乘 (6,6) @ (6,8) = (6,8),即 V[0,:,0,:]。
toy 数值追踪(patch 0 的输出):[0.17, 0.14, 0.19, 0.15, 0.13, 0.17] 分别加权 6 个 key patch 的 value 向量,得到 patch 0 聚合后的 8 维表示。
步骤六 — 返回
python
return V.contiguous(), None.contiguous() 确保内存连续。
为什么 einsum 之后需要 .contiguous()?
.contiguous()?
torch.einsum返回的 tensor 在内存中可能是非连续布局(stride 不对齐)。后续 04B-Layer4-AttentionLayer 里的.view(B, L, -1)要求内存必须连续,否则会抛RuntimeError。.contiguous()强制创建一块连续内存的副本,以解决这个问题。
None代替attn矩阵——output_attention=False时不需要返回注意力权重,仅返回聚合后的 value。