Skip to content

Layer 0 — TFB 接入界面

1. 在父层中的位置

TFB 通过 TransformerAdapter 包装 TimeMixerRollingForecast 调用 adapter 的训练/预测接口。

2. I/O 接口定义

四入参 shape(toy 参数):

参数Shape含义
x_enc(2, 24, 3)encoder 输入,B × seq_len × enc_in
x_mark_enc(2, 24, 4)encoder 时间标记(hour/day/weekday/month)
x_dec(2, 18, 3)decoder 输入 = label_len(12) + pred_len(6) 的零填充;TimeMixer 完全不使用
x_mark_dec(2, 18, 4)decoder 时间标记;TimeMixer 完全不使用

输出 shape: (2, 6, 3) = (B, pred_len, enc_in)

x_dec 被构造但从未使用

TransformerAdapter._process() 按照 Informer/Autoformer 风格构造了 dec_input,但 TimeMixer.forecast() 的函数签名接受 x_dec, x_mark_dec 却从不读取这两个参数。这是接口统一设计,TimeMixer 内部只依赖 x_encx_mark_enc

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
# adapters_for_transformers.py:10-54(节选 TimeMixer 相关默认值)
MODEL_HYPER_PARAMS = {
    "top_k": 5,
    "enc_in": 1,
    "dec_in": 1,
    "c_out": 1,
    "e_layers": 2,
    "d_model": 512,
    "d_ff": 2048,
    "embed": "timeF",
    "freq": "h",
    "lradj": "type1",
    "moving_avg": 25,
    "factor": 1,
    "n_heads": 8,
    "activation": "gelu",
    "dropout": 0.1,
    "batch_size": 32,
    "lr": 0.0001,
    "num_epochs": 10,
    "loss": "MSE",
    "patience": 3,
    "down_sampling_windows": 2,      # ← 注意:key 是 "windows"(复数)
    "channel_independence": True,
    "down_sampling_layers": 3,       # ← key 是 "layers"(复数)
    "down_sampling_method": "avg",
    "decomp_method": "moving_avg",
    "use_norm": True,
    "task_name": "short_term_forecast",
}

# adapters_for_transformers.py:81-94
def _process(self, input, target, input_mark, target_mark):
    dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
    dec_input = (
        torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
        .float()
        .to(input.device)
    )
    output = self.model(input, input_mark, dec_input, target_mark)
    return {"output": output}

§5.1 MODEL_HYPER_PARAMS 与脚本参数对照

TFB ETTh1 脚本(horizon=96 为例)传入参数:脚本传 "down_sampling_window": 2(singular "window")、"down_sampling_layer": 3(singular "layer",⚠️ 与代码键名不一致)、"d_model": 16"d_ff": 32"e_layers": 2

⚠️ 参数名不一致问题:

脚本 keyMODEL_HYPER_PARAMS keyTimeMixer 代码访问实际生效值
down_sampling_windowdown_sampling_windowsconfigs.down_sampling_window脚本值 2(脚本 key 正好匹配代码)
down_sampling_layerdown_sampling_layersconfigs.down_sampling_layersMODEL_HYPER_PARAMS 默认值 3(脚本 key 多了's',未能覆盖)
normconfigs.use_normMODEL_HYPER_PARAMS 的 use_norm=True(脚本传 "norm" 不影响 "use_norm")

脚本意图传 down_sampling_layer=3,但 TimeMixer 读 configs.down_sampling_layers,读到的是 MODEL_HYPER_PARAMS 的默认值 3——数字碰巧相同,所以功能正确,但存在隐性 bug

§5.2 TFB 自动填充的参数

multi_forecasting_hyper_param_tune(train_data) 在训练前自动设置:

python
column_num = train_data.shape[1]   # ETTh1.csv: 7
self.config.enc_in = column_num    # 7
self.config.dec_in = column_num    # 7
self.config.c_out  = column_num    # 7
setattr(self.config, "label_len", self.config.seq_len // 2)  # 512//2 = 256

这些参数无法通过脚本直接控制,由数据集列数自动推断。

§5.3 _process 构造 dec_input

形状注解: target shape (2, 18, 3) = (B, label\_len+pred\_len, N),其中 label_len=12, pred_len=6zeros_like(target[:, -6:, :]) 取最后 6 步并清零,shape (2, 6, 3)target[:, :12, :] 取前 12 步(label 历史段),shape (2, 12, 3)torch.cat([(2,12,3), (2,6,3)], dim=1) 在时间轴拼接,得 dec_input shape (2, 18, 3)

toy 数值: B=2, pred_len=6, label_len=12, enc_in=3target[:, -6:, :] 取 target 的最后 6 步得 (2, 6, 3),全部填 0;target[:, :12, :] 取前 12 步得 (2, 12, 3),保留真实历史值;cat 拼接后 dec_input shape (2, 18, 3) = label 段真实值 + 未来段全零。

TimeMixer 的 forecast() 接收了 x_dec=(2,18,3)x_mark_dec=(2,18,4),但函数体内从不引用这两个参数。

§5.4 task_name 路径

MODEL_HYPER_PARAMS 中 task_name = "short_term_forecast",脚本不覆盖此值。TimeMixer.forward 的条件:

python
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
    dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
    return dec_out

两种 task_name 都走同一个 forecast() 函数,实际无区别。

6. 下钻子组件

子组件职责文档
TimeMixer.forecast()多尺度输入→PDM→FMM→预测[[02-Layer1-forecast主链]]

*记录并在线阅读我的笔记*