Appearance
Layer 0 — TFB 接入界面
1. 在父层中的位置
TFB 通过 TransformerAdapter 包装 TimeMixer,RollingForecast 调用 adapter 的训练/预测接口。
2. I/O 接口定义
四入参 shape(toy 参数):
| 参数 | Shape | 含义 |
|---|---|---|
x_enc | (2, 24, 3) | encoder 输入,B × seq_len × enc_in |
x_mark_enc | (2, 24, 4) | encoder 时间标记(hour/day/weekday/month) |
x_dec | (2, 18, 3) | decoder 输入 = label_len(12) + pred_len(6) 的零填充;TimeMixer 完全不使用 |
x_mark_dec | (2, 18, 4) | decoder 时间标记;TimeMixer 完全不使用 |
输出 shape: (2, 6, 3) = (B, pred_len, enc_in)
x_dec 被构造但从未使用
TransformerAdapter._process()按照 Informer/Autoformer 风格构造了dec_input,但TimeMixer.forecast()的函数签名接受x_dec, x_mark_dec却从不读取这两个参数。这是接口统一设计,TimeMixer 内部只依赖x_enc和x_mark_enc。
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
# adapters_for_transformers.py:10-54(节选 TimeMixer 相关默认值)
MODEL_HYPER_PARAMS = {
"top_k": 5,
"enc_in": 1,
"dec_in": 1,
"c_out": 1,
"e_layers": 2,
"d_model": 512,
"d_ff": 2048,
"embed": "timeF",
"freq": "h",
"lradj": "type1",
"moving_avg": 25,
"factor": 1,
"n_heads": 8,
"activation": "gelu",
"dropout": 0.1,
"batch_size": 32,
"lr": 0.0001,
"num_epochs": 10,
"loss": "MSE",
"patience": 3,
"down_sampling_windows": 2, # ← 注意:key 是 "windows"(复数)
"channel_independence": True,
"down_sampling_layers": 3, # ← key 是 "layers"(复数)
"down_sampling_method": "avg",
"decomp_method": "moving_avg",
"use_norm": True,
"task_name": "short_term_forecast",
}
# adapters_for_transformers.py:81-94
def _process(self, input, target, input_mark, target_mark):
dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
dec_input = (
torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
.float()
.to(input.device)
)
output = self.model(input, input_mark, dec_input, target_mark)
return {"output": output}§5.1 MODEL_HYPER_PARAMS 与脚本参数对照
TFB ETTh1 脚本(horizon=96 为例)传入参数:脚本传 "down_sampling_window": 2(singular "window")、"down_sampling_layer": 3(singular "layer",⚠️ 与代码键名不一致)、"d_model": 16、"d_ff": 32、"e_layers": 2。
⚠️ 参数名不一致问题:
| 脚本 key | MODEL_HYPER_PARAMS key | TimeMixer 代码访问 | 实际生效值 |
|---|---|---|---|
down_sampling_window | down_sampling_windows | configs.down_sampling_window | 脚本值 2(脚本 key 正好匹配代码) |
down_sampling_layer | down_sampling_layers | configs.down_sampling_layers | MODEL_HYPER_PARAMS 默认值 3(脚本 key 多了's',未能覆盖) |
norm | — | configs.use_norm | MODEL_HYPER_PARAMS 的 use_norm=True(脚本传 "norm" 不影响 "use_norm") |
脚本意图传 down_sampling_layer=3,但 TimeMixer 读 configs.down_sampling_layers,读到的是 MODEL_HYPER_PARAMS 的默认值 3——数字碰巧相同,所以功能正确,但存在隐性 bug。
§5.2 TFB 自动填充的参数
multi_forecasting_hyper_param_tune(train_data) 在训练前自动设置:
python
column_num = train_data.shape[1] # ETTh1.csv: 7
self.config.enc_in = column_num # 7
self.config.dec_in = column_num # 7
self.config.c_out = column_num # 7
setattr(self.config, "label_len", self.config.seq_len // 2) # 512//2 = 256这些参数无法通过脚本直接控制,由数据集列数自动推断。
§5.3 _process 构造 dec_input
形状注解: target shape (2, 18, 3) = label_len=12, pred_len=6。zeros_like(target[:, -6:, :]) 取最后 6 步并清零,shape (2, 6, 3)。target[:, :12, :] 取前 12 步(label 历史段),shape (2, 12, 3)。torch.cat([(2,12,3), (2,6,3)], dim=1) 在时间轴拼接,得 dec_input shape (2, 18, 3)。
toy 数值: B=2, pred_len=6, label_len=12, enc_in=3。target[:, -6:, :] 取 target 的最后 6 步得 (2, 6, 3),全部填 0;target[:, :12, :] 取前 12 步得 (2, 12, 3),保留真实历史值;cat 拼接后 dec_input shape (2, 18, 3) = label 段真实值 + 未来段全零。
TimeMixer 的 forecast() 接收了 x_dec=(2,18,3) 和 x_mark_dec=(2,18,4),但函数体内从不引用这两个参数。
§5.4 task_name 路径
MODEL_HYPER_PARAMS 中 task_name = "short_term_forecast",脚本不覆盖此值。TimeMixer.forward 的条件:
python
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out两种 task_name 都走同一个 forecast() 函数,实际无区别。
6. 下钻子组件
| 子组件 | 职责 | 文档 |
|---|---|---|
TimeMixer.forecast() | 多尺度输入→PDM→FMM→预测 | [[02-Layer1-forecast主链]] |