Level4B: Encoder Close Reading
Abstract
Covers:
Transformer_EncDec.py: Encoder / EncoderLayer
SelfAttention_Family.py: AttentionLayer / FullAttention
PatchTST.py: FlattenHead
Input (Encoder): (B×enc_in, patch_num, d_model) = (6, 4, 8)
Output (after FlattenHead and the final permute): (B, pred_len, enc_in) = (2, 7, 3)
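Component sequence diagram
Component hierarchy tree
How to read the diagrams: Encoder is the dispatch layer, EncoderLayer is the compute unit of each layer, AttentionLayer handles splitting and merging the heads, and FullAttention performs the actual QKV matrix operations. These four levels are nested and are read here from the outside in. FlattenHead is a separate prediction head that sits outside the Encoder.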
Principle → code mapping
| Paper step | Corresponding code | File:line | Notes |
|---|---|---|---|
| Multi-Head Self-Attention | AttentionLayer + FullAttention | SelfAttention_Family.py:198/58 | AttentionLayer splits and merges the heads; FullAttention computes the scores |
| Q/K/V linear projection | self.query_projection(queries), etc. | SelfAttention_Family.py:211 | Three independent Linear(d_model, d_k×H) layers |
| Multi-head split | .view(B, L, H, d_k) | SelfAttention_Family.py:218 | view splits d_model=8 into H=2 × d_k=4 |
| Attention scores | einsum("blhe,bshe->bhls", Q, K) | SelfAttention_Family.py:85 | Equivalent to Q×Kᵀ per head, computed for all heads in a single einsum |
| softmax + weighting | einsum("bhls,bshd->blhd", A, V) | SelfAttention_Family.py:94 | Weighted sum over V using the probability weights |
| FFN | Two Conv1d layers | Transformer_EncDec.py:35 | Conv1d(kernel=1) applies the same MLP at every position independently, equivalent to Linear |
| Prediction head | FlattenHead | PatchTST.py:9 | Flatten (the patch and d_model dims) + Linear → pred_len |
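The two most common sticking points
Sticking point 1: why can view + transpose "split into heads"?
Original Q: (6, 4, 8) = (B×enc_in, patch_num, d_model)
.view(6, 4, 2, 4):
→ (6, 4, H=2, d_k=4)
splits the 8-dim vector at each position into 2 "heads" of 4 dims each
.transpose(1, 2):
→ (6, 2, 4, 4) = (B, H, L, d_k)
moves the head dimension forward so that each head can compute attention independently
Note: the AttentionLayer code in this article only calls .view and has no explicit .transpose; the "bhls" output order of the einsum in FullAttention plays that role.
Intuition: each head "looks at" a different 4-dim subspace; the heads run in parallel, and their results are concatenated back together.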
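The split described above can be checked on a random tensor. This is a minimal sketch using the toy sizes from this article; the variable names are illustrative and not taken from the repository.

```python
import torch

torch.manual_seed(0)
B, L, d_model, H = 6, 4, 8, 2            # toy sizes: B here is B×enc_in
d_k = d_model // H

q = torch.randn(B, L, d_model)           # stand-in for projected queries, (6, 4, 8)
q_heads = q.view(B, L, H, d_k)           # (6, 4, 2, 4); view does not copy data

# Head 0 is the first d_k features of every position, head 1 the next d_k.
head0_matches = torch.equal(q_heads[:, :, 0, :], q[:, :, :d_k])
head1_matches = torch.equal(q_heads[:, :, 1, :], q[:, :, d_k:])

# Conventional layout with the head axis in front of the position axis.
q_bhld = q_heads.transpose(1, 2)         # (6, 2, 4, 4) = (B, H, L, d_k)

# Merging the heads again recovers the original tensor exactly.
q_merged = q_heads.reshape(B, L, H * d_k)
roundtrip_ok = torch.equal(q_merged, q)

print(q_heads.shape, q_bhld.shape, head0_matches, head1_matches, roundtrip_ok)
```

The point of the check is that "splitting into heads" is only a reinterpretation of the last axis: no values change, which is why merging is a plain reshape.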
Sticking point 2: what does einsum("blhe,bshe->bhls", Q, K) mean?
Q: (B, L_Q, H, d_k) = (6, 4, 2, 4) ← b=batch, l=query position, h=head, e=d_k
K: (B, L_K, H, d_k) = (6, 4, 2, 4) ← b=batch, s=key position, h=head, e=d_k
The einsum spec: sum over the e dimension (a dot product) and keep b, l, h, s
Output: (B, H, L_Q, L_K) = (6, 2, 4, 4) ← the score of every Q against every K, per head
Equivalent logic: for each b,h,l,s: scores[b,h,l,s] = sum(Q[b,l,h,:] * K[b,s,h,:])
Toy parameters (used throughout this article)
B=2, enc_in=3, d_model=8, patch_num=4
n_heads=2, d_k=d_model//n_heads=4, d_ff=16 (toy value; real scripts use a larger one)
pred_len=7, head_nf=d_model×patch_num=32
B×enc_in=6
I. Encoder entry point
1. Original code
python
# Transformer_EncDec.py:52
class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        ...

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        attns = []
        if self.conv_layers is not None:
            ...  # distillation path; PatchTST does not take it
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns
2. Annotated version
python
def forward(self, x, attn_mask=None, tau=None, delta=None):
    # x: (6, 4, 8) ← (B×enc_in, patch_num, d_model)
    attns = []
    # PatchTST builds the Encoder with conv_layers=None → else branch (no distillation)
    for attn_layer in self.attn_layers:
        # every EncoderLayer leaves the shape unchanged
        x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
        # x: (6, 4, 8)
        attns.append(attn)  # attn: None (when output_attention=False)
    x = self.norm(x)  # LayerNorm(d_model=8), shape unchanged: (6, 4, 8)
    return x, attns  # (6, 4, 8)
PatchTST defaults to e_layers=1 (a single-layer Encoder). With more layers the loop above simply repeats, and the shape is unchanged after every pass.
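The loop can be exercised without the real attention stack. The sketch below is an assumption-laden stand-in: ToyLayer and ToyEncoder are hypothetical names, and ToyLayer only mimics the call signature and the (x, attn) return value of an EncoderLayer.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for EncoderLayer: same signature, shape preserved."""

    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # Returns (x, attn); attn is None, as with output_attention=False.
        return self.norm(x + self.linear(x)), None


class ToyEncoder(nn.Module):
    """Same loop as the Encoder excerpt, without the conv_layers branch."""

    def __init__(self, attn_layers, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        attns = []
        for attn_layer in self.attn_layers:
            x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
            attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns


d_model, e_layers = 8, 3
encoder = ToyEncoder([ToyLayer(d_model) for _ in range(e_layers)], nn.LayerNorm(d_model))
x = torch.randn(6, 4, d_model)            # (B×enc_in, patch_num, d_model)
out, attns = encoder(x)
print(out.shape, len(attns))
```

With e_layers=3 the output shape is still (6, 4, 8) and attns holds one entry per layer.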
II. EncoderLayer
1. Original code
python
# Transformer_EncDec.py:29
class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = x + self.dropout(new_x)
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        return self.norm2(x + y), attn
2. Annotated version
python
def forward(self, x, attn_mask=None, tau=None, delta=None):
    # x: (6, 4, 8) ← (B×enc_in, patch_num, d_model)
    # -- Self-Attention + residual --
    new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
    # new_x: (6, 4, 8); Q/K/V all come from x (self-attention)
    x = x + self.dropout(new_x)
    # x: (6, 4, 8) after the residual add
    # -- FFN (implemented with Conv1d) + residual --
    y = x = self.norm1(x)
    # Note: x and y are assigned together; x feeds the later residual, y goes through the FFN
    # x: (6, 4, 8), y: (6, 4, 8)
    y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
    # y.transpose(-1, 1): (6, 4, 8) → (6, 8, 4) ← Conv1d expects (B, C, L)
    # conv1: Conv1d(d_model=8, d_ff=16, kernel=1) → (6, 16, 4)
    # activation (relu/gelu) + dropout → (6, 16, 4)
    y = self.dropout(self.conv2(y).transpose(-1, 1))
    # conv2: Conv1d(d_ff=16, d_model=8, kernel=1) → (6, 8, 4)
    # .transpose(-1, 1): (6, 8, 4) → (6, 4, 8) ← back to (B, L, C)
    # dropout → (6, 4, 8)
    return self.norm2(x + y), attn
    # x + y: (6, 4, 8), residual connection
    # norm2: LayerNorm(8), shape unchanged
    # output: (6, 4, 8)
3. Why the FFN uses Conv1d instead of Linear
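The standard Transformer FFN uses two nn.Linear layers: Linear(d_model, d_ff) followed by Linear(d_ff, d_model).
This code uses Conv1d(kernel=1) instead, which is mathematically equivalent: a Conv1d with kernel_size=1 applies a linear map to each position independently, i.e. a Linear whose weights are shared along the sequence-length dimension.
The difference is in the dimension convention of the two APIs:
nn.Linear: acts on the last dimension; the input is (B, L, d_model)
nn.Conv1d: acts on the middle dimension; the input must be (B, C, L), with C (channels) in the second position
So the input has to be transposed with transpose(-1, 1) from (B, L, C) to (B, C, L) before the Conv1d, and transposed back with transpose(-1, 1) afterwards.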
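The equivalence claimed above can be verified numerically. This is a minimal sketch, assuming the toy sizes from this article; it copies the Conv1d parameters into a Linear layer and compares the two outputs. The variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, d_model, d_ff = 6, 4, 8, 16        # toy sizes: B here is B×enc_in

conv = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
linear = nn.Linear(d_model, d_ff)

# Conv1d weight has shape (d_ff, d_model, 1); dropping the kernel axis gives
# the (d_ff, d_model) matrix that nn.Linear stores.
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

y = torch.randn(B, L, d_model)                            # (6, 4, 8)
out_conv = conv(y.transpose(-1, 1)).transpose(-1, 1)      # (6, 4, 16), via (B, C, L)
out_linear = linear(y)                                    # (6, 4, 16), directly

same = torch.allclose(out_conv, out_linear, atol=1e-5)
print(out_conv.shape, out_linear.shape, same)
```

The two transposes exist only to satisfy the (B, C, L) layout of Conv1d; the arithmetic per position is the same matrix-vector product.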
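Toy numeric trace (FFN part):
Input y: (6, 4, 8), e.g. y[0, 0, :] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0, 0.7]
y_t = y.transpose(-1, 1): (6, 8, 4)
y_t[0, :, 0] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0, 0.7] ← the original d_model axis of patch 0
y_t[0, :, 1] = y[0, 1, :] ← the values of the original patch index 1
...
conv1: kernel_size=1, equivalent to applying Linear(8→16) at every position
conv1(y_t)[0, :, 0]: 8 inputs → 16 outputs (the transform of one position)
The conv1 output is (6, 16, 4); conv2 maps it back to (6, 8, 4), and the final transpose restores (6, 4, 8). Because d_ff=16 ≠ d_model=8 here, the intermediate shape is easy to tell apart from the input shape.
III. AttentionLayer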
1. Original code
python
# SelfAttention_Family.py:198
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()
        d_keys = d_keys or (d_model // n_heads)  # = 8 // 2 = 4
        d_values = d_values or (d_model // n_heads)  # = 4
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)  # Linear(8, 8)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)  # Linear(8, 8)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)  # Linear(8, 8)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)  # Linear(8, 8)
        self.n_heads = n_heads  # = 2

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        out, attn = self.inner_attention(queries, keys, values, attn_mask, ...)
        out = out.view(B, L, -1)
        return self.out_projection(out), attn
2. Annotated version
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    # PatchTST self-attention: queries = keys = values = x (from EncoderLayer)
    # x: (6, 4, 8) ← (B×enc_in, patch_num, d_model)
    B, L, _ = queries.shape  # B=6, L=4 (patch_num)
    _, S, _ = keys.shape  # S=4 (S = L for self-attention)
    H = self.n_heads  # H=2
    # Q/K/V projection + split into heads
    queries = self.query_projection(queries).view(B, L, H, -1)
    # query_projection: Linear(8, 8), output (6, 4, 8)
    # .view(6, 4, 2, -1): -1 = 8 // 2 = 4
    # queries: (6, 4, 2, 4) ← (B×enc_in, patch_num, n_heads, d_k)
    keys = self.key_projection(keys).view(B, S, H, -1)
    # keys: (6, 4, 2, 4)
    values = self.value_projection(values).view(B, S, H, -1)
    # values: (6, 4, 2, 4)
    out, attn = self.inner_attention(queries, keys, values, attn_mask, ...)
    # out: (6, 4, 2, 4) ← output of FullAttention (see the next section)
    out = out.view(B, L, -1)
    # out.view(6, 4, -1): -1 = n_heads × d_k = 2×4 = 8
    # out: (6, 4, 8) ← heads merged
    return self.out_projection(out), attn
    # out_projection: Linear(8, 8)
    # output: (6, 4, 8)
IV. FullAttention
1. Original code
python
# SelfAttention_Family.py:58
class FullAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape  # B=6, L=4, H=2, E=4 (d_k)
        _, S, _, D = values.shape  # S=4, D=4 (d_v)
        scale = self.scale or 1.0 / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:  # mask_flag=False in PatchTST, so this is skipped
            ...
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        return V.contiguous(), None  # output_attention=False
2. Annotated version
python
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    # queries: (6, 4, 2, 4) ← (B×enc_in, patch_num, n_heads, d_k)
    # keys: (6, 4, 2, 4)
    # values: (6, 4, 2, 4)
    B, L, H, E = queries.shape  # 6, 4, 2, 4
    scale = self.scale or 1.0 / sqrt(E)  # 1 / sqrt(4) = 0.5
    # -- attention scores --
    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    # reading the einsum subscripts:
    # first operand:  b=batch(6), l=query position(4), h=head(2), e=d_k(4)
    # second operand: b=batch(6), s=key position(4), h=head(2), e=d_k(4)
    # → bhls: sum over e (d_k); result shape (B, H, L, S)
    # scores: (6, 2, 4, 4) ← (B×enc_in, n_heads, patch_num, patch_num)
    # mask_flag=False (PatchTST applies no mask; every patch can attend to every patch)
    # -- softmax normalization --
    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    # scale * scores: (6, 2, 4, 4)
    # softmax(dim=-1): softmax over the last dim (key position s)
    # A: (6, 2, 4, 4) ← weights of each query position over the 4 keys; each row sums to 1 before dropout
    # -- weighted aggregation of the values --
    V = torch.einsum("bhls,bshd->blhd", A, values)
    # reading the einsum subscripts:
    # first operand:  b=batch(6), h=head(2), l=query position(4), s=key position(4)
    # second operand: b=batch(6), s=key position(4), h=head(2), d=d_v(4)
    # → blhd: sum over s (key position); result shape (B, L, H, D)
    # V: (6, 4, 2, 4) ← (B×enc_in, patch_num, n_heads, d_v)
    return V.contiguous(), None
    # V: (6, 4, 2, 4)
3. einsum subscripts in depth
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
Letter by letter:
| Letter | Meaning | Size | In the output |
|---|---|---|---|
| b | batch (including channel) | 6 | ✅ kept |
| l | query position in the sequence (patch index) | 4 | ✅ kept |
| h | attention head | 2 | ✅ kept |
| e | d_k (key dimension of each head) | 4 | ❌ summed out |
| s | key position in the sequence (patch index) | 4 | ✅ kept |
queries[b,l,h,:] and keys[b,s,h,:] are combined by a dot product (a sum over e), giving the similarity score of query patch l against key patch s within head h.
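This reading of the subscripts can be checked against two references: an explicit dot product for a single entry, and a per-head matrix product. The sketch below assumes the toy sizes from this article and uses random tensors; the index values chosen for the single entry are arbitrary.

```python
import torch

torch.manual_seed(0)
B, L, S, H, E = 6, 4, 4, 2, 4            # toy sizes: B here is B×enc_in

queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)

scores = torch.einsum("blhe,bshe->bhls", queries, keys)   # (6, 2, 4, 4)

# Reference 1: explicit dot product over e for one (b, h, l, s) entry.
b, h, l, s = 3, 1, 2, 0
entry = (queries[b, l, h, :] * keys[b, s, h, :]).sum()

# Reference 2: per-head Q @ K^T after moving the head axis forward.
q_bhle = queries.permute(0, 2, 1, 3)                      # (B, H, L, E)
k_bhse = keys.permute(0, 2, 1, 3)                         # (B, H, S, E)
scores_matmul = q_bhle @ k_bhse.transpose(-1, -2)         # (B, H, L, S)

entry_ok = torch.allclose(scores[b, h, l, s], entry, atol=1e-5)
matmul_ok = torch.allclose(scores, scores_matmul, atol=1e-5)
print(scores.shape, entry_ok, matmul_ok)
```

The einsum form avoids the explicit permute and transpose while producing the same (B, H, L, S) tensor.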
Equivalent matrix form (fix b and h):
scores[b, h] = queries[b, :, h, :] @ keys[b, :, h, :].T
= (4, 4) @ (4, 4).T = (4, 4)
V = torch.einsum("bhls,bshd->blhd", A, values)
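A[b,h,l,s] is the attention weight of query patch l on key patch s (already softmaxed, so the weights over s sum to 1).
The sum here runs over s (the key position): the 4 value patches are combined as a weighted sum under the attention weights, giving the output for query patch l.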
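The weighted sum over s can be checked the same way. This sketch assumes the toy sizes from this article and uses a random tensor as a stand-in for the scaled scores; it also shows why the (B, L, H, D) output layout lets the heads be merged with a plain view.

```python
import torch

torch.manual_seed(0)
B, L, S, H, D = 6, 4, 4, 2, 4            # toy sizes: B here is B×enc_in

scores = torch.randn(B, H, L, S)          # stand-in for scale * scores
values = torch.randn(B, S, H, D)

A = torch.softmax(scores, dim=-1)         # weights over the key positions s
out = torch.einsum("bhls,bshd->blhd", A, values)          # (6, 4, 2, 4)

# Reference: per-head A @ V, then move the head axis back behind L.
v_bhsd = values.permute(0, 2, 1, 3)                       # (B, H, S, D)
out_matmul = (A @ v_bhsd).permute(0, 2, 1, 3)             # (B, L, H, D)

rows_sum_to_one = torch.allclose(A.sum(dim=-1), torch.ones(B, H, L), atol=1e-5)
matmul_ok = torch.allclose(out, out_matmul, atol=1e-5)

# With heads next to the feature axis, merging is a single view.
merged = out.contiguous().view(B, L, H * D)               # (6, 4, 8)
print(out.shape, merged.shape, rows_sum_to_one, matmul_ok)
```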
Equivalent matrix form (fix b and h):
V[b, :, h, :] = A[b, h, :, :] @ values[b, :, h, :]
= (4, 4) @ (4, 4) = (4, 4)
4. Toy numeric trace (single-head simplification)
Fix b=0, h=0 (head 0 of virtual batch element 0):
queries[0, :, 0, :]: 4 patches × d_k=4 → matrix Q (4×4)
keys[0, :, 0, :]: 4 patches × d_k=4 → matrix K (4×4)
Let (simplified):
Q = [ [1,0,0,0],
[0,1,0,0],
[0,0,1,0],
[0,0,0,1] ]
K = [ [1,0,0,0],
[0,1,0,0],
[0,0,1,0],
[0,0,0,1] ]
scores = Q @ K.T = I (4×4), the identity matrix
scale = 0.5
scaled_scores = 0.5 × I
softmax(dim=-1) is applied to each row:
row i has 0.5 on the diagonal and 0 elsewhere, e.g. row 0 = [0.5, 0, 0, 0]
softmax([0.5, 0, 0, 0]) = [e^0.5/(e^0.5+3), 1/(e^0.5+3), ...]
≈ [0.355, 0.215, 0.215, 0.215]
→ every row of A[0, 0, :, :] has 0.355 on the diagonal and 0.215 elsewhere (because Q = K = I)
In practice: each patch attends mostly to itself (weight 0.355) and a little to every other patch (0.215 each)
V. FlattenHead
This part comes right after the Encoder output and corresponds to steps 5-8 of forecast().
1. Original code
python
# PatchTST.py:9
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x
2. Annotated version
python
def forward(self, x):
    # Before this function is called, steps 5-6 of forecast() have already done the reshape and permute:
    # x: (2, 3, 8, 4) ← (B, enc_in, d_model, patch_num)
    x = self.flatten(x)
    # nn.Flatten(start_dim=-2): merges the last two dimensions
    # (2, 3, 8, 4) → (2, 3, 32) ← head_nf = d_model × patch_num = 8×4 = 32
    x = self.linear(x)
    # nn.Linear(nf=32, target_window=7)
    # acts on the last dim: 32 → 7; the leading (2, 3) are treated as batch dims
    # (2, 3, 32) → (2, 3, 7) ← pred_len = 7
    x = self.dropout(x)
    return x  # (2, 3, 7)
3. Semantics of nn.Flatten(start_dim=-2)
nn.Flatten(start_dim, end_dim) merges all dimensions from start_dim through end_dim (which defaults to the last dimension) into a single dimension.
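The merge and the element order it produces can be inspected directly. This is a minimal sketch with the toy sizes from this article; the input is filled with consecutive integers purely so that each element can be traced, and the Linear layer is freshly initialized rather than taken from a trained model.

```python
import torch
import torch.nn as nn

B, enc_in, d_model, patch_num, pred_len = 2, 3, 8, 4, 7   # toy sizes

# Fill x with 0..191 so every element is identifiable after flattening.
x = torch.arange(B * enc_in * d_model * patch_num, dtype=torch.float32)
x = x.view(B, enc_in, d_model, patch_num)                 # (2, 3, 8, 4)

flat = nn.Flatten(start_dim=-2)(x)                        # (2, 3, 32)

# Row-major order: the 4 patch values of d_model row 0 come first, then row 1.
first_row_first = torch.equal(flat[0, 0, :patch_num], x[0, 0, 0, :])
second_row_next = torch.equal(flat[0, 0, patch_num:2 * patch_num], x[0, 0, 1, :])

linear = nn.Linear(d_model * patch_num, pred_len)         # Linear(32, 7)
out = linear(flat)                                        # (2, 3, 7)
print(flat.shape, out.shape, first_row_first, second_row_next)
```

Only the last two dimensions are merged; B and enc_in pass through untouched, so the same Linear(32, 7) is applied to every channel.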
start_dim=-2 means merging starts at the second-to-last dimension:
Input: (2, 3, 8, 4)
dim: 0=B, 1=enc_in, 2=d_model(-2), 3=patch_num(-1)
↑ start_dim=-2 (dim 2)
Merge dim[-2] and dim[-1]: 8×4 = 32
Output: (2, 3, 32)
4. Toy numeric trace
x[0, 0, :, :]: an (8, 4) matrix → flattened into a vector of length 32
row 0: [v00, v01, v02, v03]
row 1: [v10, v11, v12, v13]
...
row 7: [v70, v71, v72, v73]
Flattened: [v00, v01, v02, v03, v10, v11, ..., v70, v71, v72, v73] (length 32)
nn.Linear(32, 7): let W be the (7, 32) weight matrix and b the bias
output[k] = sum_j(W[k,j] × input[j]) + b[k] for k=0..6
This gives x[0, 0, :]: a vector of length 7 = the pred_len=7 step forecast for channel 0 (in normalized space)
VI. forecast() from step 8 onward
After FlattenHead outputs (2, 3, 7), forecast() has two more steps:
Step 8: permute(0, 2, 1)
python
dec_out = dec_out.permute(0, 2, 1)
# (2, 3, 7) → (2, 7, 3)
# turns (B, enc_in, pred_len) back into (B, pred_len, enc_in),
# the standard output format of the TFB framework
Step 9: Denormalize (see 03-Level3-forward主链 §3.9 for details)
python
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)
# (2, 7, 3); values are mapped from normalized space back to the original scale
VII. Full shape chain summary
Encoder input: (6, 4, 8) ← output of 04A
↓ EncoderLayer × e_layers
    ├── Self-Attention (6,4,8) → QKV (6,4,2,4) → scores (6,2,4,4) → out (6,4,8)
    └── FFN (Conv1d) (6,4,8) → transpose (6,8,4) → conv1 (6,16,4) → conv2 (6,8,4) → transpose (6,4,8)
↓ LayerNorm
Encoder output: (6, 4, 8)
↓ reshape (6,4,8) → (2,3,4,8)
↓ permute (2,3,4,8) → (2,3,8,4)
↓ FlattenHead
    ├── Flatten(start_dim=-2) (2,3,8,4) → (2,3,32)
    └── Linear(32,7) (2,3,32) → (2,3,7)
↓ permute (2,3,7) → (2,7,3)
↓ Denormalize
Final output: (2, 7, 3) = (B, pred_len, enc_in)
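The whole chain can be replayed as one script to confirm every shape. This is a sketch under stated assumptions, not the repository's code: layers are created inline with random weights, dropout is omitted, a single encoder layer is used, and means and stdev are hypothetical placeholder statistics (zeros and ones) standing in for the values computed during normalization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, enc_in, patch_num, d_model = 2, 3, 4, 8
n_heads, d_ff, pred_len = 2, 16, 7
d_k = d_model // n_heads
shapes = {}

x = torch.randn(B * enc_in, patch_num, d_model)           # Encoder input
shapes["encoder_in"] = tuple(x.shape)
N, L, _ = x.shape

# Self-attention: projections, view into heads, two einsums, merge heads.
wq, wk, wv, wo = (nn.Linear(d_model, d_model) for _ in range(4))
q = wq(x).view(N, L, n_heads, d_k)
k = wk(x).view(N, L, n_heads, d_k)
v = wv(x).view(N, L, n_heads, d_k)
shapes["qkv"] = tuple(q.shape)
scores = torch.einsum("blhe,bshe->bhls", q, k)
shapes["scores"] = tuple(scores.shape)
A = torch.softmax(scores / d_k ** 0.5, dim=-1)
attn_out = torch.einsum("bhls,bshd->blhd", A, v).contiguous().view(N, L, -1)
x = nn.LayerNorm(d_model)(x + wo(attn_out))               # residual + norm1

# FFN with Conv1d(kernel_size=1), wrapped in two transposes.
conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
hidden = F.relu(conv1(x.transpose(-1, 1)))
shapes["ffn_hidden"] = tuple(hidden.shape)
y = conv2(hidden).transpose(-1, 1)
x = nn.LayerNorm(d_model)(x + y)                          # residual + norm2
x = nn.LayerNorm(d_model)(x)                              # Encoder-level norm
shapes["encoder_out"] = tuple(x.shape)

# reshape, permute, FlattenHead, final permute.
x = x.reshape(B, enc_in, patch_num, d_model)
x = x.permute(0, 1, 3, 2)
shapes["head_in"] = tuple(x.shape)
flat = nn.Flatten(start_dim=-2)(x)
shapes["flat"] = tuple(flat.shape)
dec_out = nn.Linear(d_model * patch_num, pred_len)(flat)
shapes["head_out"] = tuple(dec_out.shape)
dec_out = dec_out.permute(0, 2, 1)

# De-normalize with hypothetical per-channel statistics.
means = torch.zeros(B, 1, enc_in)
stdev = torch.ones(B, 1, enc_in)
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
shapes["final"] = tuple(dec_out.shape)

for name, shape in shapes.items():
    print(f"{name:12s} {shape}")
```

Each entry printed from shapes corresponds to one arrow in the summary above, which makes it easy to spot where a changed hyperparameter alters the chain.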