Skip to content

Level2 数据进入 PatchTST

Abstract

这一篇对应 00-PatchTST总览与Level树Level 2

讲一件事:原始 CSV 数据怎样变成 forward(x_enc, x_mark_enc, x_dec, x_mark_dec) 的四个入参,以及哪些入参 PatchTST 其实根本不用。


1. 调用时序图


2. 四个 forward 入参的含义与 shape

参数名(代码)含义Shape(toy)是否被 PatchTST 使用
x_enc / inputencoder 输入,历史时序(2, 9, 3)✅ 核心输入
x_mark_enc / input_markencoder 时间戳特征(2, 9, 4)❌ 未使用
x_dec / dec_inputdecoder 输入(adapter 构造)(2, 11, 3)❌ 未使用
x_mark_dec / target_markdecoder 时间戳特征(2, 11, 4)❌ 未使用

重要:PatchTST 是纯 encoder-only 模型,forecast() 函数体内从未访问 x_decx_mark_dec。这四个参数的签名是为了和整个框架的统一接口保持兼容,不代表 PatchTST 真正需要它们


3. dec_input 是怎样构造的(即使 PatchTST 不用它)

位置:adapters_for_transformers.py:81

python
def _process(self, input, target, input_mark, target_mark):
    dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
    dec_input = (
        torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
        .float()
        .to(input.device)
    )
    output = self.model(input, input_mark, dec_input, target_mark)
    return {"output": output}

target 的 shape 是 (B, label_len + horizon, enc_in)

用 toy 参数代入:

  • label_len = seq_len // 2 = 9 // 2 = 4(由 deep_forecasting_model_base.py:241 设置)
  • horizon = pred_len = 7
  • target.shape = (2, 4+7, 3) = (2, 11, 3)

dec_input 构造过程:

target[:, :4, :] 取历史尾部 4 步     → shape (2, 4, 3)
zeros_like(target[:, -7:, :])         → shape (2, 7, 3),全零占位
cat([历史尾4步, 全零7步], dim=1)       → shape (2, 11, 3)

语义:这是 Informer/Autoformer 等 decoder 模型需要的"partially observed future"——过去尾部 + 未来零占位。PatchTST 没有 decoder,这个输入对它毫无意义,直接被忽略。


4. label_len 的来源

label_len 不是命令行参数,由框架在训练前自动设置:

python
# deep_forecasting_model_base.py:241
setattr(self.config, "label_len", self.config.seq_len // 2)

例外:MICN 模型设为 seq_len,单步预测模型设为 horizon。PatchTST 走通用分支,得 seq_len // 2

toy 参数下:label_len = 9 // 2 = 4


5. enc_in 的来源

enc_in 也不是命令行参数,由框架在训练前从训练数据推断:

python
# deep_forecasting_model_base.py:234
self.config.enc_in = column_num   # 数据集的列数

这意味着:MODEL_HYPER_PARAMS 里的默认 enc_in=1 只是占位,真正的值在训练数据加载后才确定

toy 参数里 enc_in=3 代表数据集有 3 列。


6. DataLoader 输出的完整 tensor 关系

CSV 原始数据(T 行 × enc_in 列)
    ↓ 滑窗切割(RollingForecast 策略)
batch_x      (2, 9,  3)   ← input:  历史 seq_len=9 步
batch_y      (2, 11, 3)   ← target: label_len=4 + pred_len=7 步
batch_x_mark (2, 9,  4)   ← 对应 batch_x 的时间戳特征(month/day/hour/weekday)
batch_y_mark (2, 11, 4)   ← 对应 batch_y 的时间戳特征

7. 下一步

继续看 forward 骨架:03-Level3-forward主链

*记录并在线阅读我的笔记*