Appearance
Level3 forecast() 主链
Abstract
这一篇对应 00-PatchTST总览与Level树 的
Level 3。聚焦
forecast()函数的 9 步骨架,每步说明做什么和tensor shape 变化。具体数学计算在 Level 4 精读。
1. 顺序图:forecast() 9步执行链
2. 抽象概念树
读图方法:先看 A 组建立"归一化框"的概念(预测在归一化空间完成,最后还原),再看 B 组理解 channel-independent 的三次轴操作如何对称,最后看 C+D 是模型真正做预测的部分。
3. 入口:forward() 分发
python
# PatchTST.py:222
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :] # [B, L, D]TFB 框架默认 task_name = "short_term_forecast",所以走 forecast() 分支。
最后的 dec_out[:, -self.pred_len:, :] 是"保险切片":FlattenHead 输出的长度已经等于 pred_len,所以这里实际上截取全部,不丢弃任何东西。作用是明确告诉调用者"我只给你最后 pred_len 步"。
2. forecast() 完整代码(加注释版)
python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# ── 步骤 1:Normalize ──────────────────────────────────────
means = x_enc.mean(1, keepdim=True).detach() # (2,1,3)
x_enc = x_enc - means # (2,9,3)
stdev = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
) # (2,1,3)
x_enc /= stdev # (2,9,3)
# ── 步骤 2:channel-independent 重排 ────────────────────────
x_enc = x_enc.permute(0, 2, 1) # (2,3,9)
# ── 步骤 3:PatchEmbedding ───────────────────────────────────
enc_out, n_vars = self.patch_embedding(x_enc) # (6,4,8), n_vars=3
# ── 步骤 4:Encoder ──────────────────────────────────────────
enc_out, attns = self.encoder(enc_out) # (6,4,8)
# ── 步骤 5:reshape 恢复 batch/channel ──────────────────────
enc_out = torch.reshape(
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
) # (2,3,4,8)
# ── 步骤 6:permute,d_model 和 patch_num 轴互换 ─────────────
enc_out = enc_out.permute(0, 1, 3, 2) # (2,3,8,4)
# ── 步骤 7:FlattenHead 预测 ─────────────────────────────────
dec_out = self.head(enc_out) # (2,3,7)
# ── 步骤 8:还原 channel 轴到最后 ────────────────────────────
dec_out = dec_out.permute(0, 2, 1) # (2,7,3)
# ── 步骤 9:Denormalize ──────────────────────────────────────
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
return dec_out # (2,7,3)3. 各步骤详解
步骤 1:Normalize
做什么:沿时间维(dim=1)计算每个样本每个变量的均值和标准差,做 z-score 归一化。
输入:(2, 9, 3)输出:(2, 9, 3)(shape 不变,值被归一化)
中间量 shape:
means:(2, 1, 3)(keepdim=True 保持维度,为了后续广播)stdev:(2, 1, 3)
.detach() 的作用:切断 means/stdev 的梯度,它们只作为归一化的统计量,不参与反向传播。
toy 数值示例(只看 batch=0, channel=0 的时间序列):
x_enc[0, :, 0] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0, 3.0, 5.0, 7.0] (9个时间步)
mean = (1+3+5+2+4+6+3+5+7)/9 = 36/9 = 4.0
var = mean([(1-4)²,(3-4)²,(5-4)²,...]) = mean([9,1,1,4,0,4,1,1,9]) = 30/9 ≈ 3.33
std = sqrt(3.33 + 1e-5) ≈ 1.826
归一化后: [(1-4)/1.826, (3-4)/1.826, ...] = [-1.64, -0.55, 0.55, -1.10, 0.0, 1.10, -0.55, 0.55, 1.64]步骤 2:permute(0, 2, 1)
做什么:把 (batch, time, channel) 变成 (batch, channel, time)。这是 channel-independent 的准备步骤——把 channel 轴提前,方便后续和 batch 合并。
输入:(2, 9, 3)输出:(2, 3, 9)
轴变化:
原始 (batch=2, time=9, channel=3)
↓ permute(0, 2, 1)
重排 (batch=2, channel=3, time=9)步骤 3:PatchEmbedding
做什么:
- 右端补
padding=2个复制值 - 在时间轴上滑窗切 patch(unfold),每个 patch 是
patch_len=4步 - 把 batch 和 channel 合并(channel-independent 关键步骤)
- 用
nn.Linear(patch_len, d_model)把每个 patch 投影到 d_model 维
输入:(2, 3, 9)输出:(6, 4, 8),同时返回 n_vars=3
形状演变:
(2, 3, 9) → padding → (2, 3, 11)
→ unfold → (2, 3, 4, 4) ← patch_num=4 个 patch, 每个长4
→ reshape → (6, 4, 4) ← B×enc_in=6 是新 batch
→ Linear → (6, 4, 8) ← d_model=8详细精读见:04A-PatchEmbedding精读
步骤 4:Encoder
做什么:e_layers 层 Transformer EncoderLayer,每层包含 Self-Attention + FFN + LayerNorm。PatchTST 不用 distillation(无 ConvLayer),是"干净"的标准 Encoder。
输入:(6, 4, 8)(B×enc_in, patch_num, d_model) 输出:(6, 4, 8)(shape 不变,值经过 Transformer 变换)
attention 是在
patch_num=4这个维度上做的,即 4 个 patch 之间互相 attend。 每个 channel 的 4 个 patch 独立计算,不同 channel 之间完全不交互。
详细精读见:04B-Encoder精读
步骤 5:reshape 恢复 batch/channel
做什么:把之前合并的 B×enc_in=6 拆回 (B=2, enc_in=3)。
输入:(6, 4, 8)输出:(2, 3, 4, 8)
python
enc_out = torch.reshape(enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
# -1 自动推断为 B=2
# n_vars=3(之前 PatchEmbedding 返回的)
# enc_out.shape[-2] = 4(patch_num)
# enc_out.shape[-1] = 8(d_model)n_vars 在步骤 3 里由 PatchEmbedding.forward 返回:n_vars = x.shape[1],即 enc_in=3。
步骤 6:permute(0, 1, 3, 2)
做什么:把 patch_num 和 d_model 两个轴互换,准备送进 FlattenHead。
输入:(2, 3, 4, 8)(batch, channel, patch_num, d_model) 输出:(2, 3, 8, 4)(batch, channel, d_model, patch_num)
轴含义变化:
位置 2: patch_num(4) → d_model(8)
位置 3: d_model(8) → patch_num(4)为什么要换轴:FlattenHead 用
nn.Flatten(start_dim=-2)把最后两维合并,之后用nn.Linear(head_nf, pred_len)作用在合并后的最后一维。d_model × patch_num的顺序没有数学意义,但这步 permute 让 FlattenHead 的 flatten 方向和head_nf = d_model × patch_num的计算一致。
步骤 7:FlattenHead
做什么:把 (d_model, patch_num) 展平后用线性层映射到 pred_len。
输入:(2, 3, 8, 4)输出:(2, 3, 7)
内部步骤:
(2, 3, 8, 4)
↓ Flatten(start_dim=-2) → (2, 3, 32) ← head_nf = 8×4 = 32
↓ Linear(32, 7) → (2, 3, 7) ← pred_len = 7
↓ Dropoutnn.Linear(32, 7) 作用在最后一维(32),前面的 (2, 3) 作为 batch 维度处理:即对每个 batch、每个 channel,独立做一次线性预测。
步骤 8:permute(0, 2, 1)
做什么:把 (batch, channel, pred_len) 变成 (batch, pred_len, channel),还原成标准时序格式。
输入:(2, 3, 7)输出:(2, 7, 3)
这步是步骤 2 的逆操作,让输出格式和 TFB 框架统一的 (B, pred_len, enc_in) 一致。
步骤 9:Denormalize
做什么:用步骤 1 保存的 mean/stdev,把预测结果还原到原始数据的尺度。
输入:(2, 7, 3)输出:(2, 7, 3)(shape 不变,值还原到原始量纲)
代码:
python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))stdev[:, 0, :]:取 (2, 1, 3) 中 dim=1 的第 0 个,得 (2, 3)。
为什么用
[:, 0, :]而不是squeeze(1):效果相同,都去掉了 keepdim 的那个维度。用[:,0,:]更明确地表示"取第 0 号时间步的统计量"——由于keepdim=True,整个 seq_len 维只有一个值,取第 0 号就是取全部。
.unsqueeze(1).repeat(1, self.pred_len, 1) 把 (2, 3) 扩展成 (2, 7, 3),对 pred_len 个时间步全部应用同一个 stdev/mean。
toy 数值(接步骤 1 的例子,batch=0, channel=0):
步骤 1 求得:mean=4.0, stdev≈1.826
假设预测输出 dec_out[0, :, 0] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0] (归一化空间)
denormalize:
× stdev: [0.5×1.826, 1.2×1.826, ...] = [0.913, 2.191, -0.548, 1.461, 2.739, -1.826, 0.0]
+ mean: [0.913+4.0, 2.191+4.0, ...] = [4.913, 6.191, 3.452, 5.461, 6.739, 2.174, 4.0]4. 完整 shape 链汇总
x_enc 输入: (2, 9, 3)
↓ normalize
↓ permute (2, 3, 9)
↓ patch_emb (6, 4, 8) ← B×enc_in=6 是 channel-independent 的关键
↓ encoder (6, 4, 8)
↓ reshape (2, 3, 4, 8)
↓ permute (2, 3, 8, 4)
↓ head (2, 3, 7)
↓ permute (2, 7, 3)
↓ denormalize
dec_out 输出: (2, 7, 3)5. 下一步
深入精读 PatchEmbedding 每一行代码:04A-PatchEmbedding精读
深入精读 Encoder/Attention/FFN:04B-Encoder精读