Appearance
Layer 0 — 接入界面
覆盖 TFB 如何实例化 Autoformer 并调用它:config 映射链 +
_process()四元组构造 +forward()I/O。
1. 在父层中的位置
TFB 框架顶层调用链:
run_benchmark.py
└─ pipeline(config)
└─ RollingForecast.strategy()
└─ TransformerAdapter._process(input, target, input_mark, target_mark)
└─ Autoformer.forward(x_enc, x_mark_enc, x_dec, x_mark_dec)2. I/O 接口定义
以 toy 参数为基准(B=2, seq_len=12, pred_len=4, label_len=6, enc_in=5, time_dims=4):
| 参数 | shape | 来源 |
|---|---|---|
x_enc | (2, 12, 5) | RollingForecast 切出的 encoder 窗口 |
x_mark_enc | (2, 12, 4) | 对应的时间特征(月/日/时/分) |
x_dec | (2, 10, 5) | _process() 构造的 decoder 输入 |
x_mark_dec | (2, 10, 4) | decoder 时间特征(对应 label+pred 窗口) |
| 输出 | (2, 4, 5) | dec_out[:, -pred_len:, :] |
3. 顺序图(具体层)
4. 语义分组图(索引层)
5. 逐步精读
5.0 关键源码
python
MODEL_HYPER_PARAMS = {
...
"moving_avg": 25,
"factor": 1,
"n_heads": 8,
"d_model": 512,
"d_ff": 2048,
"e_layers": 2,
"d_layers": 1,
...
}python
def multi_forecasting_hyper_param_tune(self, train_data):
...
column_num = train_data.shape[1]
self.config.enc_in = column_num
self.config.dec_in = column_num
self.config.c_out = column_num
setattr(self.config, "label_len", self.config.seq_len // 2)python
def _process(self, input, target, input_mark, target_mark):
dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
dec_input = (
torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
.float()
.to(input.device)
)
output = self.model(input, input_mark, dec_input, target_mark)
return {"output": output}python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if (self.task_name == "long_term_forecast" or
self.task_name == "short_term_forecast"):
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :]5.1 参数自动推断
multi_forecasting_hyper_param_tune() 在训练前被 TFB 自动调用,填充三个无法从命令行指定的参数:
enc_in = dec_in = c_out = column_num(数据集的变量数,从 DataFrame 形状读取)label_len = seq_len // 2(decoder 历史前缀 = encoder 长度的一半)
toy:column_num=5,seq_len=12 → label_len=6。
moving_avg=25 是默认值(针对真实数据集的 kernel size),toy 文档中使用 moving_avg=3 以便手算验证。
5.2 dec_input 构造
python
dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
dec_input = torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)target 的形状在 TFB 的 RollingForecast 中为 (B, label_len + pred_len, enc_in) = (2, 10, 5)(包含历史前缀和预测窗口)。
target[:, -horizon:, :]=target[:, -4:, :]→ shape(2, 4, 5),用zeros_like得到全零张量target[:, :label_len, :]=target[:, :6, :]→ shape(2, 6, 5),历史前缀(已知)cat结果 →dec_input (2, 10, 5)
图解:
target: [历史6步 | 未来4步] shape (2, 10, 5)
dec_input: [历史6步 | 0 0 0 0] shape (2, 10, 5)
↑已知历史 ↑Generative Decoder 待填充toy 数值:dec_input[0, :, 0] = [v0, v1, v2, v3, v4, v5, 0, 0, 0, 0],前6步为真实历史值,后4步为零占位。
5.3 forward() 分支
python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if (self.task_name == "long_term_forecast" or
self.task_name == "short_term_forecast"):
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len :, :]task_name 由 TFB 框架统一设置为 "long_term_forecast"。forecast() 返回 (2, 10, 5)(完整 dec_len),[:, -4:, :] 切出最后 4 步预测结果 → (2, 4, 5)。
→ forecast() 内部详见 [[02-Layer1-forecast主链]]
6. 下钻子组件
| 子组件 | 职责 | 下层文档 |
|---|---|---|
Autoformer.forecast() | 双路预测主链:decomp 初始化 + enc + dec + 合并 | [[02-Layer1-forecast主链]] |