Skip to content

Layer 0 — 接入界面

覆盖 TFB 如何实例化 Autoformer 并调用它:config 映射链 + _process() 四元组构造 + forward() I/O。


1. 在父层中的位置

TFB 框架顶层调用链:

run_benchmark.py
  └─ pipeline(config)
       └─ RollingForecast.strategy()
            └─ TransformerAdapter._process(input, target, input_mark, target_mark)
                 └─ Autoformer.forward(x_enc, x_mark_enc, x_dec, x_mark_dec)

2. I/O 接口定义

以 toy 参数为基准(B=2, seq_len=12, pred_len=4, label_len=6, enc_in=5, time_dims=4):

参数shape来源
x_enc(2, 12, 5)RollingForecast 切出的 encoder 窗口
x_mark_enc(2, 12, 4)对应的时间特征(月/日/时/分)
x_dec(2, 10, 5)_process() 构造的 decoder 输入
x_mark_dec(2, 10, 4)decoder 时间特征(对应 label+pred 窗口)
输出(2, 4, 5)dec_out[:, -pred_len:, :]

3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 逐步精读

5.0 关键源码

python
MODEL_HYPER_PARAMS = {
    ...
    "moving_avg": 25,
    "factor": 1,
    "n_heads": 8,
    "d_model": 512,
    "d_ff": 2048,
    "e_layers": 2,
    "d_layers": 1,
    ...
}
python
def multi_forecasting_hyper_param_tune(self, train_data):
    ...
    column_num = train_data.shape[1]
    self.config.enc_in = column_num
    self.config.dec_in = column_num
    self.config.c_out = column_num
    setattr(self.config, "label_len", self.config.seq_len // 2)
python
def _process(self, input, target, input_mark, target_mark):
    dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
    dec_input = (
        torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
        .float()
        .to(input.device)
    )
    output = self.model(input, input_mark, dec_input, target_mark)
    return {"output": output}
python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if (self.task_name == "long_term_forecast" or
            self.task_name == "short_term_forecast"):
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]

5.1 参数自动推断

multi_forecasting_hyper_param_tune() 在训练前被 TFB 自动调用,填充三个无法从命令行指定的参数:

  • enc_in = dec_in = c_out = column_num(数据集的变量数,从 DataFrame 形状读取)
  • label_len = seq_len // 2(decoder 历史前缀 = encoder 长度的一半)

toy:column_num=5seq_len=12label_len=6

moving_avg=25 是默认值(针对真实数据集的 kernel size),toy 文档中使用 moving_avg=3 以便手算验证。


5.2 dec_input 构造

python
dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
dec_input = torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)

target 的形状在 TFB 的 RollingForecast 中为 (B, label_len + pred_len, enc_in) = (2, 10, 5)(包含历史前缀和预测窗口)。

  • target[:, -horizon:, :] = target[:, -4:, :] → shape (2, 4, 5),用 zeros_like 得到全零张量
  • target[:, :label_len, :] = target[:, :6, :] → shape (2, 6, 5),历史前缀(已知)
  • cat 结果 → dec_input (2, 10, 5)

图解

target:     [历史6步 | 未来4步]   shape (2, 10, 5)
dec_input:  [历史6步 | 0 0 0 0]   shape (2, 10, 5)
              ↑已知历史    ↑Generative Decoder 待填充

toy 数值:dec_input[0, :, 0] = [v0, v1, v2, v3, v4, v5, 0, 0, 0, 0],前6步为真实历史值,后4步为零占位。


5.3 forward() 分支

python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if (self.task_name == "long_term_forecast" or
            self.task_name == "short_term_forecast"):
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len :, :]

task_name 由 TFB 框架统一设置为 "long_term_forecast"forecast() 返回 (2, 10, 5)(完整 dec_len),[:, -4:, :] 切出最后 4 步预测结果 → (2, 4, 5)

forecast() 内部详见 [[02-Layer1-forecast主链]]


6. 下钻子组件

子组件职责下层文档
Autoformer.forecast()双路预测主链:decomp 初始化 + enc + dec + 合并[[02-Layer1-forecast主链]]

*记录并在线阅读我的笔记*