Appearance
Layer 0 — 接入界面
1. 在父层中的位置
TFB 通过 transformer_adapter 工厂函数将 FEDformer 类包装为 TransformerAdapter,再由 RollingForecast 调用 _process() → model.forward()。
2. I/O 接口定义
_process() 输入(来自 RollingForecast):
| 参数 | Shape | 含义 |
|---|---|---|
input (x_enc) | (3, 12, 5) | encoder 历史序列 |
input_mark (x_mark_enc) | (3, 12, 4) | encoder 时间标记(month/day/weekday/hour) |
target | (3, 10, 5) | decoder 目标(label_len+pred_len 段) |
target_mark (x_mark_dec) | (3, 10, 4) | decoder 时间标记 |
FEDformer.forward() 返回:dec_out[:, -pred_len:, :] shape (3, 4, 5)
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码(接入侧)
python
# adapters_for_transformers.py
MODEL_HYPER_PARAMS = {
"top_k": 5, "enc_in": 1, "dec_in": 1, "c_out": 1,
"e_layers": 2, "d_layers": 1, "d_model": 512, "d_ff": 2048,
"embed": "timeF", "freq": "h", "lradj": "type1",
"moving_avg": 25, "num_kernels": 6, "factor": 1, "n_heads": 8,
"seg_len": 6, "win_size": 2, "activation": "gelu",
"output_attention": 0, "patch_len": 16, "stride": 8,
"dropout": 0.1, "batch_size": 32, "lr": 0.0001,
"num_epochs": 10, "num_workers": 0, "loss": "MSE", "itr": 1,
"distil": True, "patience": 3, ...
"task_name": "short_term_forecast",
}
class TransformerAdapter(DeepForecastingModelBase):
def __init__(self, model_name, model_class, **kwargs):
super(TransformerAdapter, self).__init__(MODEL_HYPER_PARAMS, **kwargs)
self._model_name = model_name
self.model_class = model_class
def _init_model(self):
return self.model_class(self.config)
def _process(self, input, target, input_mark, target_mark):
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
dec_input = (
torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
.float().to(input.device)
)
output = self.model(input, input_mark, dec_input, target_mark)
return {"output": output}python
# FEDformer.py
class FEDformer(nn.Module):
def __init__(self, configs, version="fourier", mode_select="random", modes=32):
...
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :]§5.1 宏观逻辑
两侧差异:TFB 接入层的核心工作是把调用规范(_process 的 4 参数)转换为模型所需格式,同时通过 MODEL_HYPER_PARAMS 提供通用默认值。
FEDformer 的特殊性在于 三个初始化参数不可通过命令行配置:
| 参数 | FEDformer.__init__ 默认值 | 说明 |
|---|---|---|
version | "fourier" | TFB 只走 fourier 分支;Wavelets 分支未经适配器暴露 |
mode_select | "random" | 频率选择策略:random = 随机采样,low = 低频优先 |
modes | 32 | 用户设置的频率模式数;实际 = min(32, seq_len//2) |
_init_model() 只调用 self.model_class(self.config),不传 version/mode_select/modes,因此这三者永远使用硬编码的 Python 默认值,不受 --model-hyper-params 控制。
toy 值:modes_user=32, actual_modes = min(32, 12//2=6) = 6(encoder FourierBlock); actual_modes_dec = min(32, 10//2=5) = 5(decoder FourierBlock,seq_len_dec = seq_len//2+pred_len = 10)。
§5.2 MODEL_HYPER_PARAMS 与 FEDformer 相关参数
对 FEDformer 真正有影响的共享超参:
| 参数 | 默认值 | FEDformer 中的用途 |
|---|---|---|
n_heads | 8 | 必须=8(FourierBlock 硬编码检查 h=8) |
d_model | 512 | 必须被 8 整除(d_model/n_heads=d_keys) |
moving_avg | 25 | series_decomp 核大小(奇数) |
e_layers | 2 | Encoder 层数 |
d_layers | 1 | Decoder 层数 |
无关参数(在 FEDformer 中未使用):top_k(ProbSparse 参数)、factor(DSAttention 参数)、p_hidden_dims/p_hidden_layers(Projector 参数)、distil(Informer 蒸馏参数)。
§5.3 _process() — dec_input 构造
python
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
dec_input = (
torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
.float().to(input.device)
)toy 数值:target shape (3, 10, 5),horizon=pred_len=4,label_len=6。
zeros_like(target[:, -4:, :])→ zeros shape (3, 4, 5)target[:, :6, :]→ (3, 6, 5)(历史 label 段,包含真实值)cat([(3,6,5), (3,4,5)], dim=1)→ dec_input shape (3, 10, 5)
dec_input 的语义:前 6 步是真实历史(label 段),后 4 步是全零(模型待预测的占位符)。与 Non-stationary / Informer / Autoformer 的 _process 构造方式相同,FEDformer 没有特殊构造。
注意:FEDformer.forecast() 中 x_dec 实际上只用其形状
forecast()调用F.pad(seasonal_init[:, -label_len:, :], (0, 0, 0, pred_len))构造 seasonal_init,并不使用传入的x_dec内容(不像 Non-stationary 使用 x_dec 取形状)。 实际上 FEDformer 完全忽略 x_dec 的数值,dec_input 的内容对 forecast() 没有影响。
§5.4 forward() — 分支判断与返回
python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :]task_name="short_term_forecast"(TFB 默认)→ 走 forecast 分支forecast()返回dec_outshape (3, 10, 5)(包含 label_len 段 + pred_len 段)[:, -4:, :]截取最后 pred_len=4 步 → (3, 4, 5) 返回给_process()
6. 下钻子组件
| 子组件 | 职责 | 文档 |
|---|---|---|
FEDformer.forecast() | 主预测链:series_decomp + enc + dec + dual-path | [[02-Layer1-forecast主链]] |