Skip to content

Layer 0 — TFB 接入界面

1. 在父层中的位置

TFB 的 TransformerAdapter 包装 Nonstationary_TransformerRollingForecast._eval_batch 通过 adapter 调用训练/预测接口。

2. I/O 接口定义

四入参 shape(toy 参数,ETTh1 N=5)

参数Shape含义
x_enc(2, 12, 5)encoder 输入,B × seq_len × enc_in
x_mark_enc(2, 12, 4)encoder 时间标记(hour/day/weekday/month)
x_dec(2, 10, 5)decoder 输入 = label_len(6) + pred_len(4) 的零填充
x_mark_dec(2, 10, 4)decoder 时间标记

输出 shape(2, 4, 5) = (B, pred_len, enc_in),由 dec_out[:, -pred_len:, :] 切出

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
# adapters_for_transformers.py(Non-stationary 专属参数)
MODEL_HYPER_PARAMS = {
    ...
    "p_hidden_dims": [128, 128],   # Projector MLP 隐层维度列表
    "p_hidden_layers": 2,          # Projector MLP 层数
    "output_attention": 0,
    "factor": 1,
    "task_name": "short_term_forecast",
    ...
}

# adapters_for_transformers.py(_process 构造 dec_input)
def _process(self, input, target, input_mark, target_mark):
    dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
    dec_input = (
        torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
        .float()
        .to(input.device)
    )
    output = self.model(input, input_mark, dec_input, target_mark)
    return {"output": output}

§5.1 Non-stationary 专属参数:p_hidden_dims / p_hidden_layers

MODEL_HYPER_PARAMS 中两个专属参数:

p_hidden_dims = [128, 128] 是一个列表,表示 Projector MLP 有两个隐层,各 128 维。

p_hidden_layers = 2 与列表长度对应。

这两个参数传入 Nonstationary_Transformer.__init__ 后会直接传给 Projector.__init__

python
self.tau_learner = Projector(
    enc_in=configs.enc_in,
    seq_len=configs.seq_len,
    hidden_dims=configs.p_hidden_dims,    # [128, 128]
    hidden_layers=configs.p_hidden_layers, # 2
    output_dim=1,
)
self.delta_learner = Projector(
    enc_in=configs.enc_in,
    seq_len=configs.seq_len,
    hidden_dims=configs.p_hidden_dims,    # [128, 128]
    hidden_layers=configs.p_hidden_layers, # 2
    output_dim=configs.seq_len,           # seq_len,delta 输出与 key 序列等长
)

toy 参数下使用 p_hidden_dims=[32], p_hidden_layers=1 以简化追踪,实际训练时使用默认 [128, 128]

§5.2 TFB 自动填充参数

multi_forecasting_hyper_param_tune(train_data) 自动设置:

python
column_num = train_data.shape[1]   # ETTh1: 5(OT + 4 covariates)
self.config.enc_in = column_num    # 5
self.config.dec_in = column_num    # 5
self.config.c_out  = column_num    # 5
setattr(self.config, "label_len", self.config.seq_len // 2)  # 例:seq_len=96 → 48

enc_in 同时影响两个 Projector 的 2 × enc_in 输入维度,以及 DataEmbedding 的 TokenEmbedding 权重尺寸。

§5.3 _process 构造 dec_input

toy 数值:B=2, pred_len=4, label_len=6, enc_in=5。

target 的 shape 为 (2, 10, 5),其中前 6 步是真实历史,后 4 步是真实未来(但在 dec_input 里被清零)。

torch.zeros_like(target[:, -4:, :]) 取最后 4 步并全部置 0,shape 为 (2, 4, 5)。

target[:, :6, :] 取前 6 步(label 历史段),shape 为 (2, 6, 5)。

torch.cat([(2,6,5), (2,4,5)], dim=1)dec_input shape 为 (2, 10, 5) = 历史段真实值 + 未来段全零。

Non-stationary 的 forecast() 会实际用到 x_dec(从中取 x_dec[:, -pred_len:, :] 创建 zeros),但真正的 x_dec_new 是在 forecast() 内部重新构造的:

python
x_dec_new = torch.cat([
    x_enc[:, -self.label_len:, :],   # 归一化后的 encoder 尾段
    torch.zeros_like(x_dec[:, -self.pred_len:, :]),
], dim=1)

所以 _process 传入的 dec_input 里的真实历史值其实不被直接使用——x_dec_new 是由归一化后的 x_enc 末段拼接而成。

§5.4 task_name 路径

MODEL_HYPER_PARAMS 默认 task_name = "short_term_forecast"forward() 走:

python
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
    dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
    return dec_out[:, -self.pred_len:, :]

两种 task_name 都调同一个 forecast(),无区别。

6. 下钻子组件

子组件职责文档
forecast() 主链归一化 + Projector + Enc/Dec + 反归一化[[02-Layer1-forecast主链]]

*记录并在线阅读我的笔记*