Skip to content

Layer 1 — forecast 主链

父层(Layer 0)的 forward() 直接调用 forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
本文档覆盖 forecast() 的完整 9 步计算序列。

1. 在父层中的位置

PatchTST.forward(x_enc, ...)
  └─ forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)   ← 本文档

forward() 只做分支判断,无实质计算,是透明跳板。


2. I/O 接口定义

python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec) -> Tensor:
shape(toy)含义
输入 x_enc(2, 12, 4) = (B, seq_len, enc_in)历史时序
输出(2, 3, 4) = (B, pred_len, enc_in)预测未来

x_mark_enc / x_dec / x_mark_dec 接收但完全不读取


3. 顺序图(具体层)


4. 语义分组图(索引层)

论文/原理描述代码实现关键原因
"Channel-Independent Patch Transformer"permute→PatchEmbedding(reshape B*C)→Encoder→reshape还原→FlattenHead全链的核心:把多变量预测变成 B*C 条独立单变量 patch 序列的处理
"Patching reduces token数量"unfold(size=patch_len, step=stride)patch_num=6 远小于 seq_len=12,注意力复杂度从 O(144) 降到 O(36)

5. 逐步解析

5.0 完整原始代码

python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev

    # do patching and embedding
    x_enc = x_enc.permute(0, 2, 1)
    # u: [bs * nvars x patch_num x d_model]
    enc_out, n_vars = self.patch_embedding(x_enc)

    # Encoder
    # z: [bs * nvars x patch_num x d_model]
    enc_out, attns = self.encoder(enc_out)
    # z: [bs x nvars x patch_num x d_model]
    enc_out = torch.reshape(
        enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
    )
    # z: [bs x nvars x d_model x patch_num]
    enc_out = enc_out.permute(0, 1, 3, 2)

    # Decoder
    dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
    dec_out = dec_out.permute(0, 2, 1)

    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out

整体数据流:归一化 → permute → PatchEmbedding(合并 B×C)→ Encoder → reshape(拆回 B×C)→ FlattenHead(预测)→ 反归一化。


5.1 实例归一化(步骤①②)

本节的作用

逐样本、逐变量消除均值和方差,让 Transformer 处理的是==分布稳定==的序列,而非原始量纲数据。

步骤① — 计算均值并去均值

python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means

.mean(1, keepdim=True) 对时间轴(dim=1)求均值,keepdim=True 保持维度以便广播,输出 means: (2, 1, 4).detach() 切断梯度,means 不参与反向传播。x_enc - means 广播减法 (2,12,4) - (2,1,4) = (2,12,4)

toy 追踪(batch 0,变量 0,原始序列 = [1,2,...,12]):mean[0,0,0] = (1+2+...+12)/12 = 6.5,去均值后 x_enc[0,:,0] = [-5.5, -4.5, ..., 5.5]

步骤② — 计算标准差并标准化

python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev

torch.var(..., unbiased=False)N 做分母(population variance),+1e-5 防止除零,输出 stdev: (2, 1, 4)x_enc /= stdev 广播除法 (2,12,4) / (2,1,4) = (2,12,4)

toy 追踪(接步骤①):var = (5.5²+4.5²+...+0.5²)×2 / 12 ≈ 11.917stdev ≈ 3.452x_enc[0,:,0] ≈ [-1.593, -1.303, ..., 1.593]

为什么 detach()

如果 meansstdev 参与梯度计算,模型可以通过改变归一化统计量来规避约束,归一化就失去了作用——模型会"学会"把异常值塞进均值项而绕过标准化。.detach() 让这两个统计量成为纯粹的常数,保证归一化的独立性。


5.2 permute:Channel-Independent 的起点(步骤③)

本节的作用

把变量轴(enc_in)从 dim=2 移到 dim=1,为后续 PatchEmbedding 将 B 和 C 合并做准备。

步骤③ — permute(0, 2, 1)

python
x_enc = x_enc.permute(0, 2, 1)

交换 dim=1(seq_len=12)和 dim=2(enc_in=4),输出 (B, enc_in, seq_len) = (2, 4, 12)

toy 追踪:x_enc[0, var, :] 原来是 x_enc[0, :, var] 的时间序列——permute 只是换了索引方式,值不变。

为什么这步必须在 PatchEmbedding 之前?

PatchEmbedding 会执行 reshape(B*C, ...) 把样本和变量合并。这要求变量轴(C)在 dim=1——只有这样,PyTorch 才能把 (B, C, T) 的前两维连续地合并成 B*C,使每个变量独立成为一条时序样本。如果不 permute,reshape 会把时间轴并入 batch,语义完全错误。

论文/原理描述代码实现关键原因
"Channel-Independent: 每个变量独立走 Transformer"permute(0,2,1) → PatchEmbedding reshape (B*C,...)permute 把 C 放到 dim=1,reshape 时 B*C 合并,让注意力只在同变量 patch 间发生

5.3 PatchEmbedding(步骤④,下钻至 Layer 2A)

本节的作用

(B, C, T) 切成 patch 并投影为 token,同时将 B 和 C 合并,实现 ==Channel-Independent== 处理。

步骤④ — 调用 PatchEmbedding

python
enc_out, n_vars = self.patch_embedding(x_enc)

输入 x_enc: (B, enc_in, seq_len) = (2, 4, 12),输出 enc_out: (B*enc_in, patch_num, d_model) = (8, 6, 16)n_vars = enc_in = 4(用于后续 reshape 还原)。

PatchEmbedding 内部做:右端 padding → unfold 切 patch → reshape 合并 B×C → Linear 投影 + 位置编码。

→ 详见 03-Layer2A-PatchEmbedding


5.4 Encoder(步骤⑤,下钻至 Layer 2B)

本节的作用

在 patch 维度做 Transformer 自注意力,让每个 patch token 聚合全局上下文信息。

步骤⑤ — 调用 Encoder

python
enc_out, attns = self.encoder(enc_out)

输入 enc_out: (B*enc_in, patch_num, d_model) = (8, 6, 16),输出 enc_out: (8, 6, 16),形状不变,内容已被 Transformer 处理。attns = [None]output_attention=False)。

Encoder 在 patch_num 维度做 FullAttention,e_layers=1 层 EncoderLayer + 收尾 LayerNorm。

→ 详见 04-Layer2B-Encoder


5.5 reshape 还原 B×C(步骤⑥)

本节的作用

把步骤④合并的 B×C 拆回独立的 B 和 C 两个维度,恢复多变量结构。

步骤⑥ — reshape(-1, n_vars, ...)

python
enc_out = torch.reshape(
    enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
)

-1 由 PyTorch 自动推断为 B = total_samples / n_vars = 8 / 4 = 2enc_out.shape[-2] = patch_num = 6enc_out.shape[-1] = d_model = 16

toy 追踪:(8, 6, 16) → (2, 4, 6, 16) = (B, enc_in, patch_num, d_model)。步骤④中 enc_out[0](变量 0 的 batch 0)和 enc_out[4](变量 0 的 batch 1)分别变回 enc_out[0, 0]enc_out[1, 0]


5.6 permute d_model 和 patch_num(步骤⑦)

本节的作用

交换 d_model 和 patch_num 两轴,让 FlattenHead 的 Flatten 操作按"特征优先"顺序合并,与 head_nf = d_model × patch_num 的设计一致。

步骤⑦ — permute(0, 1, 3, 2)

python
enc_out = enc_out.permute(0, 1, 3, 2)

交换 dim=2(patch_num=6)和 dim=3(d_model=16),输出 (B, enc_in, d_model, patch_num) = (2, 4, 16, 6)

toy 追踪:enc_out[0, 0](6, 16) 变为 (16, 6)——行索引从 patch 编号变成特征维度。

为什么要交换 d_model 和 patch_num?

FlattenHead 用 nn.Flatten(start_dim=-2) 把最后两维合并成 head_nf。permute 前是 (patch_num=6, d_model=16),permute 后是 (d_model=16, patch_num=6)——合并顺序变了,但 head_nf=96 相同。源码注释 [bs x nvars x d_model x patch_num] 明确了这个布局意图:先特征维、后时间维,权重矩阵 W ∈ R^{pred_len × head_nf} 的每一行对应一个预测步,对 96 个"特征×时间"组合做线性组合。


5.7 FlattenHead(步骤⑧)

本节的作用

把每个变量的 (d_model, patch_num) 二维 patch 表示压平后线性映射到 pred_len,完成从 patch token 到预测值的转换。

步骤⑧a — 调用 FlattenHead

python
dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]

self.headFlattenHead 实例,定义如下:

python
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

self.flatten(x) — 合并末两维

nn.Flatten(start_dim=-2) 从倒数第 2 维(d_model)开始向右合并所有维度:

head\_nf=dmodel×patch\_num=16×6=96

输入 (2, 4, 16, 6) → 输出 (2, 4, 96),B 和 enc_in 两维不受影响。

toy 追踪:enc_out[0, 0] 原形状 (16, 6) → 展平为 (96,) = [feat_0_patch_0, ..., feat_0_patch_5, feat_1_patch_0, ..., feat_15_patch_5],按 d_model 行顺序展开(每行 6 个 patch 值连续排列)。

self.linear(x) — 线性投影到 pred_len

y=xWT+b,WR3×96,bR3

Linear 只作用于最后一维,前面所有维度视为 batch——把 (2, 4, 96) 理解为 2×4=8 条长度为 96 的行向量,每条独立乘 WT 得长度为 3 的输出,重组回 (2, 4, 3)

toy 追踪(batch 0,变量 0):第 t 步预测值 =W[t,:]x[0,0,:]+b[t],即对 96 个"特征×patch"组合的加权求和。

self.dropout(x) — 推理时无操作

本模型 head_dropout=0,等同于恒等映射;训练时若非零则随机置零部分输出,形状始终保持 (2, 4, 3) 不变。

步骤⑧b — permute 还原输出格式

python
dec_out = dec_out.permute(0, 2, 1)

交换 dim=1(enc_in=4)和 dim=2(target_window=3):(2, 4, 3)(2, 3, 4) = (B, pred_len, enc_in),与输入 x_enc(B, seq_len, enc_in) 布局一致。

toy 追踪:permute 前 dec_out[0, var, t] 按变量→时间步索引,permute 后变为 dec_out[0, t, var] 按时间步→变量,与 TFB 框架期望的输出格式对齐。


5.8 反归一化(步骤⑨)

本节的作用

把步骤①②保存的 meansstdev 广播到 pred_len 维度,将模型输出从标准化空间逐元素还原为原始量纲的预测值。

步骤⑨ — 构造广播矩阵 + 逐元素还原

python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

两行结构完全对称。以 stdev 为例,完整 shape 变换链:

(2,1,4B,1,C)[:,0,:](2,4B,C)unsqueeze(1)(2,1,4B,1,C)repeat(1,3,1)(2,3,4B,Tpred,C)

[:,0,:] 取 dim=1 的唯一元素(keepdim 后只有 1 步,等价于 squeeze),形状从 (2,1,4) 变为 (2,4).unsqueeze(1) 重新插入时间轴得 (2,1,4).repeat(1, pred_len, 1) 沿时间轴复制 3 份得 (2,3,4),与 dec_out 形状完全对齐。

为什么 dim=1 一定是 1,以及 [:,0,:] 什么时候会不安全

步骤①②的计算是 x_enc.mean(1, keepdim=True)torch.var(x_enc, dim=1, keepdim=True, ...)——对整个时间维(seq_len=12)做完全聚合,keepdim=True 让被压缩的维度保留为长度 1,因此输出必然是 (B, 1, enc_in),dim=1 永远是 1。

[:,0,:] 取 dim=1 的唯一元素,这在当前代码里绝对安全。但若统计策略改变,dim=1 就不再是 1:

统计策略means shapedim=1
对整段时间维聚合(当前)(B, 1, C)固定为 1
按 patch 分段统计(B, patch_num, C)= patch_num
滑动窗口逐位统计(B, seq_len, C)= seq_len
keepdim=False(B, C)维度消失

一旦按 patch 统计,每个预测步 τ 就需要对应一个具体的 μb,p,c,反归一化就不能再写成简单的广播乘加,而要先决定"哪个 patch 的统计量还原哪个预测步"。当前代码用全局统计量的隐含假设是:整段输入的均值/方差代表该变量的长期分布,对所有未来步同等适用。

元素级还原公式:

y^b,t,c=ynorm,b,t,c×σb,c+μb,c

其中 σb,c=stdev[b,0,c]μb,c=means[b,0,c],每个变量独立还原,互不影响。

toy 追踪(batch 0,变量 0):σ0,03.452μ0,0=6.5,则对 t{0,1,2} 均有:

y^0,t,0=ynorm,0,t,0×3.452+6.5

6. 下钻子组件

子组件职责文档
self.patch_embedding(x_enc)padding+unfold+channel-independent reshape+Linear+pos_embed03-Layer2A-PatchEmbedding
self.encoder(enc_out)EncoderLayer × e_layers:FullAttention + FFN + LayerNorm04-Layer2B-Encoder

*记录并在线阅读我的笔记*