Skip to content

Layer 0 — 接入界面

1. 在父层中的位置

TFB 通过 transformer_adapter 工厂函数将 FEDformer 类包装为 TransformerAdapter,再由 RollingForecast 调用 _process()model.forward()

2. I/O 接口定义

_process() 输入(来自 RollingForecast):

参数Shape含义
input (x_enc)(3, 12, 5)encoder 历史序列
input_mark (x_mark_enc)(3, 12, 4)encoder 时间标记(month/day/weekday/hour)
target(3, 10, 5)decoder 目标(label_len+pred_len 段)
target_mark (x_mark_dec)(3, 10, 4)decoder 时间标记

FEDformer.forward() 返回dec_out[:, -pred_len:, :] shape (3, 4, 5)

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码(接入侧)

python
# adapters_for_transformers.py

MODEL_HYPER_PARAMS = {
    "top_k": 5, "enc_in": 1, "dec_in": 1, "c_out": 1,
    "e_layers": 2, "d_layers": 1, "d_model": 512, "d_ff": 2048,
    "embed": "timeF", "freq": "h", "lradj": "type1",
    "moving_avg": 25, "num_kernels": 6, "factor": 1, "n_heads": 8,
    "seg_len": 6, "win_size": 2, "activation": "gelu",
    "output_attention": 0, "patch_len": 16, "stride": 8,
    "dropout": 0.1, "batch_size": 32, "lr": 0.0001,
    "num_epochs": 10, "num_workers": 0, "loss": "MSE", "itr": 1,
    "distil": True, "patience": 3, ...
    "task_name": "short_term_forecast",
}

class TransformerAdapter(DeepForecastingModelBase):
    def __init__(self, model_name, model_class, **kwargs):
        super(TransformerAdapter, self).__init__(MODEL_HYPER_PARAMS, **kwargs)
        self._model_name = model_name
        self.model_class = model_class

    def _init_model(self):
        return self.model_class(self.config)

    def _process(self, input, target, input_mark, target_mark):
        dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
        dec_input = (
            torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
            .float().to(input.device)
        )
        output = self.model(input, input_mark, dec_input, target_mark)
        return {"output": output}
python
# FEDformer.py

class FEDformer(nn.Module):
    def __init__(self, configs, version="fourier", mode_select="random", modes=32):
        ...
    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]

§5.1 宏观逻辑

两侧差异:TFB 接入层的核心工作是把调用规范(_process 的 4 参数)转换为模型所需格式,同时通过 MODEL_HYPER_PARAMS 提供通用默认值。

FEDformer 的特殊性在于 三个初始化参数不可通过命令行配置

参数FEDformer.__init__ 默认值说明
version"fourier"TFB 只走 fourier 分支;Wavelets 分支未经适配器暴露
mode_select"random"频率选择策略:random = 随机采样,low = 低频优先
modes32用户设置的频率模式数;实际 = min(32, seq_len//2)

_init_model() 只调用 self.model_class(self.config),不传 version/mode_select/modes,因此这三者永远使用硬编码的 Python 默认值,不受 --model-hyper-params 控制。

toy 值:modes_user=32, actual_modes = min(32, 12//2=6) = 6(encoder FourierBlock); actual_modes_dec = min(32, 10//2=5) = 5(decoder FourierBlock,seq_len_dec = seq_len//2+pred_len = 10)。

§5.2 MODEL_HYPER_PARAMS 与 FEDformer 相关参数

对 FEDformer 真正有影响的共享超参:

参数默认值FEDformer 中的用途
n_heads8必须=8(FourierBlock 硬编码检查 h=8)
d_model512必须被 8 整除(d_model/n_heads=d_keys)
moving_avg25series_decomp 核大小(奇数)
e_layers2Encoder 层数
d_layers1Decoder 层数

无关参数(在 FEDformer 中未使用):top_k(ProbSparse 参数)、factor(DSAttention 参数)、p_hidden_dims/p_hidden_layers(Projector 参数)、distil(Informer 蒸馏参数)。

§5.3 _process() — dec_input 构造

python
dec_input = torch.zeros_like(target[:, -self.config.horizon:, :]).float()
dec_input = (
    torch.cat([target[:, :self.config.label_len, :], dec_input], dim=1)
    .float().to(input.device)
)

toy 数值:target shape (3, 10, 5),horizon=pred_len=4label_len=6

  • zeros_like(target[:, -4:, :]) → zeros shape (3, 4, 5)
  • target[:, :6, :] → (3, 6, 5)(历史 label 段,包含真实值)
  • cat([(3,6,5), (3,4,5)], dim=1)dec_input shape (3, 10, 5)

dec_input 的语义:前 6 步是真实历史(label 段),后 4 步是全零(模型待预测的占位符)。与 Non-stationary / Informer / Autoformer 的 _process 构造方式相同,FEDformer 没有特殊构造。

注意:FEDformer.forecast() 中 x_dec 实际上只用其形状

forecast() 调用 F.pad(seasonal_init[:, -label_len:, :], (0, 0, 0, pred_len)) 构造 seasonal_init,并不使用传入的 x_dec 内容(不像 Non-stationary 使用 x_dec 取形状)。 实际上 FEDformer 完全忽略 x_dec 的数值,dec_input 的内容对 forecast() 没有影响。

§5.4 forward() — 分支判断与返回

python
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    if self.task_name == "long_term_forecast" or self.task_name == "short_term_forecast":
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]
  • task_name="short_term_forecast"(TFB 默认)→ 走 forecast 分支
  • forecast() 返回 dec_out shape (3, 10, 5)(包含 label_len 段 + pred_len 段)
  • [:, -4:, :] 截取最后 pred_len=4 步 → (3, 4, 5) 返回给 _process()

6. 下钻子组件

子组件职责文档
FEDformer.forecast()主预测链:series_decomp + enc + dec + dual-path[[02-Layer1-forecast主链]]

*记录并在线阅读我的笔记*