Appearance
Layer 1 — forecast 主链
父层(Layer 0)的
forward()直接调用forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)。
本文档覆盖forecast()的完整 9 步计算序列。
1. 在父层中的位置
PatchTST.forward(x_enc, ...)
└─ forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) ← 本文档forward() 只做分支判断,无实质计算,是透明跳板。
2. I/O 接口定义
python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec) -> Tensor:| shape(toy) | 含义 | |
|---|---|---|
输入 x_enc | (2, 12, 4) = (B, seq_len, enc_in) | 历史时序 |
| 输出 | (2, 3, 4) = (B, pred_len, enc_in) | 预测未来 |
x_mark_enc / x_dec / x_mark_dec接收但完全不读取。
3. 顺序图(具体层)
4. 语义分组图(索引层)
| 论文/原理描述 | 代码实现 | 关键原因 |
|---|---|---|
| "Channel-Independent Patch Transformer" | permute→PatchEmbedding(reshape B*C)→Encoder→reshape还原→FlattenHead | 全链的核心:把多变量预测变成 B*C 条独立单变量 patch 序列的处理 |
| "Patching reduces token数量" | unfold(size=patch_len, step=stride) | patch_num=6 远小于 seq_len=12,注意力复杂度从 O(144) 降到 O(36) |
5. 逐步解析
5.0 完整原始代码
python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
# do patching and embedding
x_enc = x_enc.permute(0, 2, 1)
# u: [bs * nvars x patch_num x d_model]
enc_out, n_vars = self.patch_embedding(x_enc)
# Encoder
# z: [bs * nvars x patch_num x d_model]
enc_out, attns = self.encoder(enc_out)
# z: [bs x nvars x patch_num x d_model]
enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
)
# z: [bs x nvars x d_model x patch_num]
enc_out = enc_out.permute(0, 1, 3, 2)
# Decoder
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]
dec_out = dec_out.permute(0, 2, 1)
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
return dec_out整体数据流:归一化 → permute → PatchEmbedding(合并 B×C)→ Encoder → reshape(拆回 B×C)→ FlattenHead(预测)→ 反归一化。
5.1 实例归一化(步骤①②)
本节的作用
逐样本、逐变量消除均值和方差,让 Transformer 处理的是==分布稳定==的序列,而非原始量纲数据。
步骤① — 计算均值并去均值
python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means.mean(1, keepdim=True) 对时间轴(dim=1)求均值,keepdim=True 保持维度以便广播,输出 means: (2, 1, 4)。.detach() 切断梯度,means 不参与反向传播。x_enc - means 广播减法 (2,12,4) - (2,1,4) = (2,12,4)。
toy 追踪(batch 0,变量 0,原始序列 = [1,2,...,12]):mean[0,0,0] = (1+2+...+12)/12 = 6.5,去均值后 x_enc[0,:,0] = [-5.5, -4.5, ..., 5.5]。
步骤② — 计算标准差并标准化
python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdevtorch.var(..., unbiased=False) 用 +1e-5 防止除零,输出 stdev: (2, 1, 4)。x_enc /= stdev 广播除法 (2,12,4) / (2,1,4) = (2,12,4)。
toy 追踪(接步骤①):var = (5.5²+4.5²+...+0.5²)×2 / 12 ≈ 11.917,stdev ≈ 3.452,x_enc[0,:,0] ≈ [-1.593, -1.303, ..., 1.593]。
为什么 detach()?
detach()?如果
means和stdev参与梯度计算,模型可以通过改变归一化统计量来规避约束,归一化就失去了作用——模型会"学会"把异常值塞进均值项而绕过标准化。.detach()让这两个统计量成为纯粹的常数,保证归一化的独立性。
5.2 permute:Channel-Independent 的起点(步骤③)
本节的作用
把变量轴(enc_in)从 dim=2 移到 dim=1,为后续 PatchEmbedding 将 B 和 C 合并做准备。
步骤③ — permute(0, 2, 1)
python
x_enc = x_enc.permute(0, 2, 1)交换 dim=1(seq_len=12)和 dim=2(enc_in=4),输出 (B, enc_in, seq_len) = (2, 4, 12)。
toy 追踪:x_enc[0, var, :] 原来是 x_enc[0, :, var] 的时间序列——permute 只是换了索引方式,值不变。
为什么这步必须在 PatchEmbedding 之前?
PatchEmbedding 会执行
reshape(B*C, ...)把样本和变量合并。这要求变量轴(C)在 dim=1——只有这样,PyTorch 才能把(B, C, T)的前两维连续地合并成B*C,使每个变量独立成为一条时序样本。如果不 permute,reshape 会把时间轴并入 batch,语义完全错误。
| 论文/原理描述 | 代码实现 | 关键原因 |
|---|---|---|
| "Channel-Independent: 每个变量独立走 Transformer" | permute(0,2,1) → PatchEmbedding reshape (B*C,...) | permute 把 C 放到 dim=1,reshape 时 B*C 合并,让注意力只在同变量 patch 间发生 |
5.3 PatchEmbedding(步骤④,下钻至 Layer 2A)
本节的作用
把
(B, C, T)切成 patch 并投影为 token,同时将 B 和 C 合并,实现 ==Channel-Independent== 处理。
步骤④ — 调用 PatchEmbedding
python
enc_out, n_vars = self.patch_embedding(x_enc)输入 x_enc: (B, enc_in, seq_len) = (2, 4, 12),输出 enc_out: (B*enc_in, patch_num, d_model) = (8, 6, 16),n_vars = enc_in = 4(用于后续 reshape 还原)。
PatchEmbedding 内部做:右端 padding → unfold 切 patch → reshape 合并 B×C → Linear 投影 + 位置编码。
→ 详见 03-Layer2A-PatchEmbedding。
5.4 Encoder(步骤⑤,下钻至 Layer 2B)
本节的作用
在 patch 维度做 Transformer 自注意力,让每个 patch token 聚合全局上下文信息。
步骤⑤ — 调用 Encoder
python
enc_out, attns = self.encoder(enc_out)输入 enc_out: (B*enc_in, patch_num, d_model) = (8, 6, 16),输出 enc_out: (8, 6, 16),形状不变,内容已被 Transformer 处理。attns = [None](output_attention=False)。
Encoder 在 patch_num 维度做 FullAttention,e_layers=1 层 EncoderLayer + 收尾 LayerNorm。
→ 详见 04-Layer2B-Encoder。
5.5 reshape 还原 B×C(步骤⑥)
本节的作用
把步骤④合并的 B×C 拆回独立的 B 和 C 两个维度,恢复多变量结构。
步骤⑥ — reshape(-1, n_vars, ...)
python
enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
)-1 由 PyTorch 自动推断为 B = total_samples / n_vars = 8 / 4 = 2;enc_out.shape[-2] = patch_num = 6;enc_out.shape[-1] = d_model = 16。
toy 追踪:(8, 6, 16) → (2, 4, 6, 16) = (B, enc_in, patch_num, d_model)。步骤④中 enc_out[0](变量 0 的 batch 0)和 enc_out[4](变量 0 的 batch 1)分别变回 enc_out[0, 0] 和 enc_out[1, 0]。
5.6 permute d_model 和 patch_num(步骤⑦)
本节的作用
交换 d_model 和 patch_num 两轴,让 FlattenHead 的 Flatten 操作按"特征优先"顺序合并,与
head_nf = d_model × patch_num的设计一致。
步骤⑦ — permute(0, 1, 3, 2)
python
enc_out = enc_out.permute(0, 1, 3, 2)交换 dim=2(patch_num=6)和 dim=3(d_model=16),输出 (B, enc_in, d_model, patch_num) = (2, 4, 16, 6)。
toy 追踪:enc_out[0, 0] 从 (6, 16) 变为 (16, 6)——行索引从 patch 编号变成特征维度。
为什么要交换 d_model 和 patch_num?
FlattenHead 用
nn.Flatten(start_dim=-2)把最后两维合并成head_nf。permute 前是(patch_num=6, d_model=16),permute 后是(d_model=16, patch_num=6)——合并顺序变了,但head_nf=96相同。源码注释[bs x nvars x d_model x patch_num]明确了这个布局意图:先特征维、后时间维,权重矩阵W ∈ R^{pred_len × head_nf}的每一行对应一个预测步,对 96 个"特征×时间"组合做线性组合。
5.7 FlattenHead(步骤⑧)
本节的作用
把每个变量的
(d_model, patch_num)二维 patch 表示压平后线性映射到 pred_len,完成从 patch token 到预测值的转换。
步骤⑧a — 调用 FlattenHead
python
dec_out = self.head(enc_out) # z: [bs x nvars x target_window]self.head 是 FlattenHead 实例,定义如下:
python
class FlattenHead(nn.Module):
def __init__(self, n_vars, nf, target_window, head_dropout=0):
super().__init__()
self.n_vars = n_vars
self.flatten = nn.Flatten(start_dim=-2)
self.linear = nn.Linear(nf, target_window)
self.dropout = nn.Dropout(head_dropout)
def forward(self, x): # x: [bs x nvars x d_model x patch_num]
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)
return xself.flatten(x) — 合并末两维
nn.Flatten(start_dim=-2) 从倒数第 2 维(d_model)开始向右合并所有维度:
输入 (2, 4, 16, 6) → 输出 (2, 4, 96),B 和 enc_in 两维不受影响。
toy 追踪:enc_out[0, 0] 原形状 (16, 6) → 展平为 (96,) = [feat_0_patch_0, ..., feat_0_patch_5, feat_1_patch_0, ..., feat_15_patch_5],按 d_model 行顺序展开(每行 6 个 patch 值连续排列)。
self.linear(x) — 线性投影到 pred_len
Linear 只作用于最后一维,前面所有维度视为 batch——把 (2, 4, 96) 理解为 (2, 4, 3)。
toy 追踪(batch 0,变量 0):第
self.dropout(x) — 推理时无操作
本模型 head_dropout=0,等同于恒等映射;训练时若非零则随机置零部分输出,形状始终保持 (2, 4, 3) 不变。
步骤⑧b — permute 还原输出格式
python
dec_out = dec_out.permute(0, 2, 1)交换 dim=1(enc_in=4)和 dim=2(target_window=3):(2, 4, 3) → (2, 3, 4) = (B, pred_len, enc_in),与输入 x_enc 的 (B, seq_len, enc_in) 布局一致。
toy 追踪:permute 前 dec_out[0, var, t] 按变量→时间步索引,permute 后变为 dec_out[0, t, var] 按时间步→变量,与 TFB 框架期望的输出格式对齐。
5.8 反归一化(步骤⑨)
本节的作用
把步骤①②保存的
means和stdev广播到 pred_len 维度,将模型输出从标准化空间逐元素还原为原始量纲的预测值。
步骤⑨ — 构造广播矩阵 + 逐元素还原
python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))两行结构完全对称。以 stdev 为例,完整 shape 变换链:
[:,0,:] 取 dim=1 的唯一元素(keepdim 后只有 1 步,等价于 squeeze),形状从 (2,1,4) 变为 (2,4)。.unsqueeze(1) 重新插入时间轴得 (2,1,4)。.repeat(1, pred_len, 1) 沿时间轴复制 3 份得 (2,3,4),与 dec_out 形状完全对齐。
为什么 dim=1 一定是 1,以及 [:,0,:] 什么时候会不安全
[:,0,:] 什么时候会不安全步骤①②的计算是
x_enc.mean(1, keepdim=True)和torch.var(x_enc, dim=1, keepdim=True, ...)——对整个时间维(seq_len=12)做完全聚合,keepdim=True让被压缩的维度保留为长度 1,因此输出必然是(B, 1, enc_in),dim=1 永远是 1。
[:,0,:]取 dim=1 的唯一元素,这在当前代码里绝对安全。但若统计策略改变,dim=1 就不再是 1:
统计策略 meansshapedim=1 对整段时间维聚合(当前) (B, 1, C)固定为 1 按 patch 分段统计 (B, patch_num, C)= patch_num 滑动窗口逐位统计 (B, seq_len, C)= seq_len keepdim=False(B, C)维度消失 一旦按 patch 统计,每个预测步
就需要对应一个具体的 ,反归一化就不能再写成简单的广播乘加,而要先决定"哪个 patch 的统计量还原哪个预测步"。当前代码用全局统计量的隐含假设是:整段输入的均值/方差代表该变量的长期分布,对所有未来步同等适用。
元素级还原公式:
其中
toy 追踪(batch 0,变量 0):
6. 下钻子组件
| 子组件 | 职责 | 文档 |
|---|---|---|
self.patch_embedding(x_enc) | padding+unfold+channel-independent reshape+Linear+pos_embed | 03-Layer2A-PatchEmbedding |
self.encoder(enc_out) | EncoderLayer × e_layers:FullAttention + FFN + LayerNorm | 04-Layer2B-Encoder |