Skip to content

PatchTST 总览

1. 论文问题与动机

2022 年时,Transformer 在时间序列预测领域的主流用法是逐时间步做注意力:每个时间点是一个 token,注意力矩阵大小 = seq_len × seq_len。这带来两个问题:

  1. 语义太细:单个时间点没有局部上下文,模型无法捕捉"这几步一起构成了一个模式"
  2. 变量间混干扰:多变量时序里,变量间相关性往往比时间依赖弱,强行混合反而让模型学偏

Nie et al. (ICLR 2023) 提出 PatchTST,用两个设计解决这两个问题:

① Patching(局部语义)
   原始时序:   [t0][t1][t2][t3][t4][t5][t6][t7][t8][t9][t10][t11]
   Patch 0:  [t0, t1, t2, t3]       ← 4步一组,捕获局部模式
   Patch 1:       [t2, t3, t4, t5]  ← 步长2,允许重叠
   Patch 2:             [t4, t5, t6, t7]
   ...
   每个 patch 是一个 token,token 数 = patch_num ≪ seq_len

② Channel-Independent(变量独立)
   x[B, seq_len, C]
   → permute → x[B, C, seq_len]
   → reshape → x[B*C, patch_num, d_model]   ← 每个变量当作独立样本
   这样注意力只在同一变量内的 patch 之间发生,变量间不互相干扰

2. 核心创新

输入 x_enc (B, seq_len, C)

① 实例归一化(RevIN 风格)
   减去均值,除以标准差(detach,不回传梯度)

② permute: (B,C,T) → 把变量轴提前

③ PatchEmbedding
   padding → unfold → reshape(B*C,...) → Linear → + pos_embed
   输出 (B*C, patch_num, d_model)

④ Encoder(FullAttention,在 patch 维度做注意力)
   输出 (B*C, patch_num, d_model)

⑤ reshape 回: (B, C, patch_num, d_model)
   permute: (B, C, d_model, patch_num)

⑥ FlattenHead
   Flatten(d_model×patch_num=head_nf) → Linear(head_nf→pred_len)
   输出 (B, C, pred_len) → permute → (B, pred_len, C)

⑦ 反归一化(还原均值和标准差)

3. 论文架构图(原理层)

4. TFB 完整调用链

5. 文档 BFS 树形索引

文件层级覆盖内容关键 tensor
01-Layer0-接入界面Layer 0config映射 + batch I/O + forward分支(B,seq_len,C)(B,pred_len,C)
02-Layer1-forecast主链Layer 19步forecast全链 + 归一化 + channel-independent(B,12,4)→(8,6,16)→(B,3,4)
03-Layer2A-PatchEmbeddingLayer 2Apadding+unfold toy数值 + reshape channel-independent原理(B,C,T)→(B*C,patch_num,d_model)
04-Layer2B-EncoderLayer 2BEncoderLayer顺序+AttentionLayer+FullAttention(8,6,16)→(8,6,16)
04D-Layer4-例子-AttentionLayer到returnAttention 叶子例子x→Q/K/V→scores→softmax→AV→out_projection 可手算矩阵流(1,2,4)→(1,2,2,2)→(1,2,4)
05-收束收束端到端流程图 + tensor汇总全链

6. 全局 toy 参数

参数toy 值真实值(ETTh1)
B24
seq_len12336
pred_len396
enc_in47
d_model16128
patch_len416
stride28
e_layers13
n_heads216
padding2 (=stride)8

派生参数(从以上计算,与其他维度互不相等):

派生参数计算公式toy 值
patch_num(seq_len - patch_len)/stride + 2 = (12-4)/2+26
B*CB × enc_in8
d_keys = d_valuesd_model // n_heads = 16//28
d_ff默认 4 × d_model64
head_nfd_model × patch_num = 16×696

所有维度互不相等(2,3,4,6,8,12,16,64,96),防止 permute/reshape 后 shape 看起来没变而掩盖轴语义错误。

toy 输入示例(x_enc,shape (2, 12, 4),第0个 batch,第0个变量列):

python
x_enc[0, :, 0] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]

最终输出 shape:(2, 3, 4),即 (B, pred_len, enc_in)

7. 推荐阅读路径

快速了解直觉(不深入代码)
总览 §1-§3 → 01-Layer0 §3-§4(顺序图)→ 02-Layer1 §3-§4(顺序图)

完整代码精读
总览 → 01-Layer0 → 02-Layer1 → 03-Layer2A → 04-Layer2B → 04A-Layer3 → 04B-Layer4 → 04D-Layer4例子 → 04C-Layer5 → 05-收束


旧版文档(Level1/2/3/4 结构)已归档至 PatchTST_v1_archive/

*记录并在线阅读我的笔记*