Skip to content

Level3 forecast() 主链

Abstract

这一篇对应 00-PatchTST总览与Level树Level 3

聚焦 forecast() 函数的 9 步骨架,每步说明做什么tensor shape 变化。具体数学计算在 Level 4 精读。


1. 顺序图:forecast() 9步执行链

2. 抽象概念树

读图方法:先看 A 组建立"归一化框"的概念(预测在归一化空间完成,最后还原),再看 B 组理解 channel-independent 的三次轴操作如何对称,最后看 C+D 是模型真正做预测的部分。


3. 入口:forward() 分发

python
# PatchTST.py:222
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]

TFB 框架默认 task_name = "short_term_forecast",所以走 forecast() 分支。

最后的 dec_out[:, -self.pred_len:, :] 是"保险切片":FlattenHead 输出的长度已经等于 pred_len,所以这里实际上截取全部,不丢弃任何东西。作用是明确告诉调用者"我只给你最后 pred_len 步"。


2. forecast() 完整代码(加注释版)

python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # ── 步骤 1:Normalize ──────────────────────────────────────
    means = x_enc.mean(1, keepdim=True).detach()   # (2,1,3)
    x_enc = x_enc - means                          # (2,9,3)
    stdev = torch.sqrt(
        torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
    )                                               # (2,1,3)
    x_enc /= stdev                                  # (2,9,3)

    # ── 步骤 2:channel-independent 重排 ────────────────────────
    x_enc = x_enc.permute(0, 2, 1)                 # (2,3,9)

    # ── 步骤 3:PatchEmbedding ───────────────────────────────────
    enc_out, n_vars = self.patch_embedding(x_enc)  # (6,4,8), n_vars=3

    # ── 步骤 4:Encoder ──────────────────────────────────────────
    enc_out, attns = self.encoder(enc_out)          # (6,4,8)

    # ── 步骤 5:reshape 恢复 batch/channel ──────────────────────
    enc_out = torch.reshape(
        enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
    )                                               # (2,3,4,8)

    # ── 步骤 6:permute,d_model 和 patch_num 轴互换 ─────────────
    enc_out = enc_out.permute(0, 1, 3, 2)          # (2,3,8,4)

    # ── 步骤 7:FlattenHead 预测 ─────────────────────────────────
    dec_out = self.head(enc_out)                    # (2,3,7)

    # ── 步骤 8:还原 channel 轴到最后 ────────────────────────────
    dec_out = dec_out.permute(0, 2, 1)             # (2,7,3)

    # ── 步骤 9:Denormalize ──────────────────────────────────────
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out                                  # (2,7,3)

3. 各步骤详解

步骤 1:Normalize

做什么:沿时间维(dim=1)计算每个样本每个变量的均值和标准差,做 z-score 归一化。

输入(2, 9, 3)输出(2, 9, 3)(shape 不变,值被归一化)

中间量 shape:

  • means: (2, 1, 3)(keepdim=True 保持维度,为了后续广播)
  • stdev: (2, 1, 3)

.detach() 的作用:切断 means/stdev 的梯度,它们只作为归一化的统计量,不参与反向传播。

toy 数值示例(只看 batch=0, channel=0 的时间序列):

x_enc[0, :, 0] = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0, 3.0, 5.0, 7.0]  (9个时间步)

mean = (1+3+5+2+4+6+3+5+7)/9 = 36/9 = 4.0
var  = mean([(1-4)²,(3-4)²,(5-4)²,...]) = mean([9,1,1,4,0,4,1,1,9]) = 30/9 ≈ 3.33
std  = sqrt(3.33 + 1e-5) ≈ 1.826

归一化后: [(1-4)/1.826, (3-4)/1.826, ...] = [-1.64, -0.55, 0.55, -1.10, 0.0, 1.10, -0.55, 0.55, 1.64]

步骤 2:permute(0, 2, 1)

做什么:把 (batch, time, channel) 变成 (batch, channel, time)。这是 channel-independent 的准备步骤——把 channel 轴提前,方便后续和 batch 合并。

输入(2, 9, 3)输出(2, 3, 9)

轴变化:

原始 (batch=2, time=9, channel=3)
         ↓ permute(0, 2, 1)
重排 (batch=2, channel=3, time=9)

步骤 3:PatchEmbedding

做什么

  1. 右端补 padding=2 个复制值
  2. 在时间轴上滑窗切 patch(unfold),每个 patch 是 patch_len=4
  3. 把 batch 和 channel 合并(channel-independent 关键步骤)
  4. nn.Linear(patch_len, d_model) 把每个 patch 投影到 d_model 维

输入(2, 3, 9)输出(6, 4, 8),同时返回 n_vars=3

形状演变:

(2, 3, 9)  → padding  → (2, 3, 11)
           → unfold   → (2, 3, 4, 4)   ← patch_num=4 个 patch, 每个长4
           → reshape  → (6, 4, 4)      ← B×enc_in=6 是新 batch
           → Linear   → (6, 4, 8)      ← d_model=8

详细精读见:04A-PatchEmbedding精读


步骤 4:Encoder

做什么e_layers 层 Transformer EncoderLayer,每层包含 Self-Attention + FFN + LayerNorm。PatchTST 不用 distillation(无 ConvLayer),是"干净"的标准 Encoder。

输入(6, 4, 8)(B×enc_in, patch_num, d_model) 输出(6, 4, 8)(shape 不变,值经过 Transformer 变换)

attention 是在 patch_num=4 这个维度上做的,即 4 个 patch 之间互相 attend。 每个 channel 的 4 个 patch 独立计算,不同 channel 之间完全不交互。

详细精读见:04B-Encoder精读


步骤 5:reshape 恢复 batch/channel

做什么:把之前合并的 B×enc_in=6 拆回 (B=2, enc_in=3)

输入(6, 4, 8)输出(2, 3, 4, 8)

python
enc_out = torch.reshape(enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
# -1 自动推断为 B=2
# n_vars=3(之前 PatchEmbedding 返回的)
# enc_out.shape[-2] = 4(patch_num)
# enc_out.shape[-1] = 8(d_model)

n_vars 在步骤 3 里由 PatchEmbedding.forward 返回:n_vars = x.shape[1],即 enc_in=3


步骤 6:permute(0, 1, 3, 2)

做什么:把 patch_numd_model 两个轴互换,准备送进 FlattenHead。

输入(2, 3, 4, 8)(batch, channel, patch_num, d_model) 输出(2, 3, 8, 4)(batch, channel, d_model, patch_num)

轴含义变化:
  位置 2: patch_num(4) → d_model(8)
  位置 3: d_model(8)   → patch_num(4)

为什么要换轴:FlattenHead 用 nn.Flatten(start_dim=-2) 把最后两维合并,之后用 nn.Linear(head_nf, pred_len) 作用在合并后的最后一维。d_model × patch_num 的顺序没有数学意义,但这步 permute 让 FlattenHead 的 flatten 方向和 head_nf = d_model × patch_num 的计算一致。


步骤 7:FlattenHead

做什么:把 (d_model, patch_num) 展平后用线性层映射到 pred_len

输入(2, 3, 8, 4)输出(2, 3, 7)

内部步骤:

(2, 3, 8, 4)
    ↓ Flatten(start_dim=-2)    → (2, 3, 32)   ← head_nf = 8×4 = 32
    ↓ Linear(32, 7)            → (2, 3, 7)    ← pred_len = 7
    ↓ Dropout

nn.Linear(32, 7) 作用在最后一维(32),前面的 (2, 3) 作为 batch 维度处理:即对每个 batch、每个 channel,独立做一次线性预测。


步骤 8:permute(0, 2, 1)

做什么:把 (batch, channel, pred_len) 变成 (batch, pred_len, channel),还原成标准时序格式。

输入(2, 3, 7)输出(2, 7, 3)

这步是步骤 2 的逆操作,让输出格式和 TFB 框架统一的 (B, pred_len, enc_in) 一致。


步骤 9:Denormalize

做什么:用步骤 1 保存的 mean/stdev,把预测结果还原到原始数据的尺度。

输入(2, 7, 3)输出(2, 7, 3)(shape 不变,值还原到原始量纲)

代码:

python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

stdev[:, 0, :]:取 (2, 1, 3) 中 dim=1 的第 0 个,得 (2, 3)

为什么用 [:, 0, :] 而不是 squeeze(1):效果相同,都去掉了 keepdim 的那个维度。用 [:,0,:] 更明确地表示"取第 0 号时间步的统计量"——由于 keepdim=True,整个 seq_len 维只有一个值,取第 0 号就是取全部。

.unsqueeze(1).repeat(1, self.pred_len, 1)(2, 3) 扩展成 (2, 7, 3),对 pred_len 个时间步全部应用同一个 stdev/mean。

toy 数值(接步骤 1 的例子,batch=0, channel=0):

步骤 1 求得:mean=4.0, stdev≈1.826

假设预测输出 dec_out[0, :, 0] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0]  (归一化空间)

denormalize:
  × stdev: [0.5×1.826, 1.2×1.826, ...] = [0.913, 2.191, -0.548, 1.461, 2.739, -1.826, 0.0]
  + mean:  [0.913+4.0, 2.191+4.0, ...] = [4.913, 6.191,  3.452, 5.461, 6.739,  2.174, 4.0]

4. 完整 shape 链汇总

x_enc 输入:       (2,  9,  3)
  ↓ normalize
  ↓ permute       (2,  3,  9)
  ↓ patch_emb     (6,  4,  8)   ← B×enc_in=6 是 channel-independent 的关键
  ↓ encoder       (6,  4,  8)
  ↓ reshape       (2,  3,  4,  8)
  ↓ permute       (2,  3,  8,  4)
  ↓ head          (2,  3,  7)
  ↓ permute       (2,  7,  3)
  ↓ denormalize
dec_out 输出:     (2,  7,  3)

5. 下一步

深入精读 PatchEmbedding 每一行代码:04A-PatchEmbedding精读

深入精读 Encoder/Attention/FFN:04B-Encoder精读

*记录并在线阅读我的笔记*