Appearance
Level2 数据进入 PatchTST
Abstract
这一篇对应 00-PatchTST总览与Level树 的
Level 2。讲一件事:原始 CSV 数据怎样变成
forward(x_enc, x_mark_enc, x_dec, x_mark_dec)的四个入参,以及哪些入参 PatchTST 其实根本不用。
1. 调用时序图
2. 四个 forward 入参的含义与 shape
| 参数名(代码) | 含义 | Shape(toy) | 是否被 PatchTST 使用 |
|---|---|---|---|
x_enc / input | encoder 输入,历史时序 | (2, 9, 3) | ✅ 核心输入 |
x_mark_enc / input_mark | encoder 时间戳特征 | (2, 9, 4) | ❌ 未使用 |
x_dec / dec_input | decoder 输入(adapter 构造) | (2, 11, 3) | ❌ 未使用 |
x_mark_dec / target_mark | decoder 时间戳特征 | (2, 11, 4) | ❌ 未使用 |
重要:PatchTST 是纯 encoder-only 模型,
forecast()函数体内从未访问x_dec和x_mark_dec。这四个参数的签名是为了和整个框架的统一接口保持兼容,不代表 PatchTST 真正需要它们。
3. dec_input 是怎样构造的(即使 PatchTST 不用它)
位置:adapters_for_transformers.py:81
python
def _process(self, input, target, input_mark, target_mark):
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
dec_input = (
torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
.float()
.to(input.device)
)
output = self.model(input, input_mark, dec_input, target_mark)
return {"output": output}target 的 shape 是 (B, label_len + horizon, enc_in)。
用 toy 参数代入:
label_len = seq_len // 2 = 9 // 2 = 4(由deep_forecasting_model_base.py:241设置)horizon = pred_len = 7target.shape = (2, 4+7, 3) = (2, 11, 3)
dec_input 构造过程:
target[:, :4, :] 取历史尾部 4 步 → shape (2, 4, 3)
zeros_like(target[:, -7:, :]) → shape (2, 7, 3),全零占位
cat([历史尾4步, 全零7步], dim=1) → shape (2, 11, 3)语义:这是 Informer/Autoformer 等 decoder 模型需要的"partially observed future"——过去尾部 + 未来零占位。PatchTST 没有 decoder,这个输入对它毫无意义,直接被忽略。
4. label_len 的来源
label_len 不是命令行参数,由框架在训练前自动设置:
python
# deep_forecasting_model_base.py:241
setattr(self.config, "label_len", self.config.seq_len // 2)例外:MICN 模型设为 seq_len,单步预测模型设为 horizon。PatchTST 走通用分支,得 seq_len // 2。
toy 参数下:label_len = 9 // 2 = 4
5. enc_in 的来源
enc_in 也不是命令行参数,由框架在训练前从训练数据推断:
python
# deep_forecasting_model_base.py:234
self.config.enc_in = column_num # 数据集的列数这意味着:MODEL_HYPER_PARAMS 里的默认 enc_in=1 只是占位,真正的值在训练数据加载后才确定。
toy 参数里 enc_in=3 代表数据集有 3 列。
6. DataLoader 输出的完整 tensor 关系
CSV 原始数据(T 行 × enc_in 列)
↓ 滑窗切割(RollingForecast 策略)
batch_x (2, 9, 3) ← input: 历史 seq_len=9 步
batch_y (2, 11, 3) ← target: label_len=4 + pred_len=7 步
batch_x_mark (2, 9, 4) ← 对应 batch_x 的时间戳特征(month/day/hour/weekday)
batch_y_mark (2, 11, 4) ← 对应 batch_y 的时间戳特征7. 下一步
继续看 forward 骨架:03-Level3-forward主链