Appearance
Layer 0 — TFB 接入界面
1. 在父层中的位置
TFB 的 TransformerAdapter 包装 Nonstationary_Transformer,RollingForecast._eval_batch 通过 adapter 调用训练/预测接口。
2. I/O 接口定义
四入参 shape(toy 参数,ETTh1 N=5):
| 参数 | Shape | 含义 |
|---|---|---|
x_enc | (2, 12, 5) | encoder 输入,B × seq_len × enc_in |
x_mark_enc | (2, 12, 4) | encoder 时间标记(hour/day/weekday/month) |
x_dec | (2, 10, 5) | decoder 输入 = label_len(6) + pred_len(4) 的零填充 |
x_mark_dec | (2, 10, 4) | decoder 时间标记 |
输出 shape:(2, 4, 5) = (B, pred_len, enc_in),由 dec_out[:, -pred_len:, :] 切出
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
# adapters_for_transformers.py(Non-stationary 专属参数)
MODEL_HYPER_PARAMS = {
...
"p_hidden_dims": [128, 128], # Projector MLP 隐层维度列表
"p_hidden_layers": 2, # Projector MLP 层数
"output_attention": 0,
"factor": 1,
"task_name": "short_term_forecast",
...
}
# adapters_for_transformers.py(_process 构造 dec_input)
def _process(self, input, target, input_mark, target_mark):
dec_input = torch.zeros_like(target[:, -self.config.horizon :, :]).float()
dec_input = (
torch.cat([target[:, : self.config.label_len, :], dec_input], dim=1)
.float()
.to(input.device)
)
output = self.model(input, input_mark, dec_input, target_mark)
return {"output": output}§5.1 Non-stationary 专属参数:p_hidden_dims / p_hidden_layers
MODEL_HYPER_PARAMS 中两个专属参数:
p_hidden_dims = [128, 128] 是一个列表,表示 Projector MLP 有两个隐层,各 128 维。
p_hidden_layers = 2 与列表长度对应。
这两个参数传入 Nonstationary_Transformer.__init__ 后会直接传给 Projector.__init__:
python
self.tau_learner = Projector(
enc_in=configs.enc_in,
seq_len=configs.seq_len,
hidden_dims=configs.p_hidden_dims, # [128, 128]
hidden_layers=configs.p_hidden_layers, # 2
output_dim=1,
)
self.delta_learner = Projector(
enc_in=configs.enc_in,
seq_len=configs.seq_len,
hidden_dims=configs.p_hidden_dims, # [128, 128]
hidden_layers=configs.p_hidden_layers, # 2
output_dim=configs.seq_len, # seq_len,delta 输出与 key 序列等长
)toy 参数下使用 p_hidden_dims=[32], p_hidden_layers=1 以简化追踪,实际训练时使用默认 [128, 128]。
§5.2 TFB 自动填充参数
multi_forecasting_hyper_param_tune(train_data) 自动设置:
python
column_num = train_data.shape[1] # ETTh1: 5(OT + 4 covariates)
self.config.enc_in = column_num # 5
self.config.dec_in = column_num # 5
self.config.c_out = column_num # 5
setattr(self.config, "label_len", self.config.seq_len // 2) # 例:seq_len=96 → 48enc_in 同时影响两个 Projector 的 2 × enc_in 输入维度,以及 DataEmbedding 的 TokenEmbedding 权重尺寸。
§5.3 _process 构造 dec_input
toy 数值:B=2, pred_len=4, label_len=6, enc_in=5。
target 的 shape 为 (2, 10, 5),其中前 6 步是真实历史,后 4 步是真实未来(但在 dec_input 里被清零)。
torch.zeros_like(target[:, -4:, :]) 取最后 4 步并全部置 0,shape 为 (2, 4, 5)。
target[:, :6, :] 取前 6 步(label 历史段),shape 为 (2, 6, 5)。
torch.cat([(2,6,5), (2,4,5)], dim=1) → dec_input shape 为 (2, 10, 5) = 历史段真实值 + 未来段全零。
Non-stationary 的 forecast() 会实际用到 x_dec(从中取 x_dec[:, -pred_len:, :] 创建 zeros),但真正的 x_dec_new 是在 forecast() 内部重新构造的:
python
x_dec_new = torch.cat([
x_enc[:, -self.label_len:, :], # 归一化后的 encoder 尾段
torch.zeros_like(x_dec[:, -self.pred_len:, :]),
], dim=1)所以 _process 传入的 dec_input 里的真实历史值其实不被直接使用——x_dec_new 是由归一化后的 x_enc 末段拼接而成。
§5.4 task_name 路径
MODEL_HYPER_PARAMS 默认 task_name = "short_term_forecast",forward() 走:
python
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :]两种 task_name 都调同一个 forecast(),无区别。
6. 下钻子组件
| 子组件 | 职责 | 文档 |
|---|---|---|
forecast() 主链 | 归一化 + Projector + Enc/Dec + 反归一化 | [[02-Layer1-forecast主链]] |