
Layer 1: forecast() main chain

1. Position in the parent layer

After checking task_name == "short_term_forecast", FEDformer.forward() calls self.forecast(), which is the backbone of the whole prediction.

2. I/O interface definition

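| Parameter | Shape | Meaning |
| --- | --- | --- |
| x_enc | (3, 12, 5) | encoder input (raw time series) |
| x_mark_enc | (3, 12, 4) | encoder time features |
| x_dec | (3, 10, 5) | decoder input (⚠️ its values are ignored by forecast(); only its shape acts as an implicit constraint) |
| x_mark_dec | (3, 10, 4) | decoder time features |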

Output: dec_out with shape (3, 10, 5); forward() then slices [:, -4:, :] to obtain (3, 4, 5).

3. Sequence diagram

4. Semantic grouping diagram

5. Step-by-step walkthrough

§5.0 Full original code

python
class FEDformer(nn.Module):
    def __init__(self, configs, version="fourier", mode_select="random", modes=32):
        super(FEDformer, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.version = version
        self.mode_select = mode_select
        self.modes = modes

        self.decomp = series_decomp(configs.moving_avg)
        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout,
        )
        self.dec_embedding = DataEmbedding(
            configs.dec_in, configs.d_model, configs.embed, configs.freq, configs.dropout,
        )

        if self.version == "Wavelets":
            ...  # Wavelets branch (TFB does not take this path)
        else:
            encoder_self_att = FourierBlock(
                in_channels=configs.d_model, out_channels=configs.d_model,
                seq_len=self.seq_len, modes=self.modes,
                mode_select_method=self.mode_select,
            )
            decoder_self_att = FourierBlock(
                in_channels=configs.d_model, out_channels=configs.d_model,
                seq_len=self.seq_len // 2 + self.pred_len,
                modes=self.modes, mode_select_method=self.mode_select,
            )
            decoder_cross_att = FourierCrossAttention(
                in_channels=configs.d_model, out_channels=configs.d_model,
                seq_len_q=self.seq_len // 2 + self.pred_len,
                seq_len_kv=self.seq_len,
                modes=self.modes, mode_select_method=self.mode_select,
                num_heads=configs.n_heads,
            )
        self.encoder = Encoder(
            [EncoderLayer(
                AutoCorrelationLayer(encoder_self_att, configs.d_model, configs.n_heads),
                configs.d_model, configs.d_ff,
                moving_avg=configs.moving_avg, dropout=configs.dropout,
                activation=configs.activation,
            ) for l in range(configs.e_layers)],
            norm_layer=my_Layernorm(configs.d_model),
        )
        self.decoder = Decoder(
            [DecoderLayer(
                AutoCorrelationLayer(decoder_self_att, configs.d_model, configs.n_heads),
                AutoCorrelationLayer(decoder_cross_att, configs.d_model, configs.n_heads),
                configs.d_model, configs.c_out, configs.d_ff,
                moving_avg=configs.moving_avg, dropout=configs.dropout,
                activation=configs.activation,
            ) for l in range(configs.d_layers)],
            norm_layer=my_Layernorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True),
        )

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # decomp init
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        seasonal_init, trend_init = self.decomp(x_enc)
        # decoder input
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = F.pad(
            seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)
        )
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # dec
        seasonal_part, trend_part = self.decoder(
            dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init
        )
        # final
        dec_out = trend_part + seasonal_part
        return dec_out

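§5.1 Big-picture logic

Core design intent: FEDformer inherits Autoformer's progressive decomposition idea. It maintains two separate paths, seasonal and trend, but replaces every Auto-Correlation block in the Encoder and Decoder with frequency-domain attention.

Two-path structure:

  • Seasonal path (blue): runs through both the Encoder and the Decoder; every decomposition keeps the seasonal part, and the final projection is my_Layernorm + Linear(d_model→c_out)
  • Trend path (orange): starts from trend_init; the 3 decompositions in each DecoderLayer extract trend1/2/3, whose sum is projected by Conv1d into residual_trend, so that trend_part = trend_init + Σ residual_trend

A small example (B=1, seq_len=4, label_len=2, pred_len=2, enc_in=2, d_model=4, n_heads=8) ties the steps together: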

x_enc = [[1,2],[3,4],[2,5],[4,3]]  shape (1,4,2)

mean = x_enc.mean(dim=1) = [[2.5, 3.5]] → unsqueeze → (1,1,2) → repeat(1,2,1) → (1,2,2)

series_decomp(x_enc):
  seasonal_init = x_enc - moving_avg → (1,4,2)
  trend_init = moving_avg(x_enc) → (1,4,2)

trend_init = cat[trend_init[:,-2:,:], mean] = cat[(1,2,2),(1,2,2)] = (1,4,2)
seasonal_init = F.pad(seasonal_init[:,-2:,:], (0,0,0,2))
             = cat[(1,2,2), zeros(1,2,2)] = (1,4,2)

Full shape-change chain

(3,12,5) → [series_decomp] → (3,12,5)×2 → [build dec_init] → (3,10,5)×2 → [enc_embedding→Encoder] → enc_out(3,12,16) → [dec_embedding→Decoder] → seasonal_part(3,10,5) + trend_part(3,10,5) → [sum] → (3,10,5) → [forward slices [:,-4:,:]] → (3,4,5)
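The chain above is plain shape arithmetic, so it can be checked without PyTorch. The following is a minimal sketch, not library code: forecast_shapes and its dictionary keys are hypothetical names introduced here. It also reports whether label_len + pred_len matches seq_len // 2 + pred_len, the length the decoder-side Fourier blocks are constructed with in §5.0; what happens when the two differ is not covered by these notes.

```python
def forecast_shapes(batch, seq_len, label_len, pred_len, enc_in, d_model, c_out):
    """Shape bookkeeping for forecast(); hypothetical helper, plain tuples only."""
    dec_len = label_len + pred_len  # label segment + forecast horizon
    shapes = {
        "x_enc": (batch, seq_len, enc_in),
        "decomp": (batch, seq_len, enc_in),      # seasonal and trend share this shape
        "dec_init": (batch, dec_len, enc_in),    # seasonal_init and trend_init
        "enc_out": (batch, seq_len, d_model),
        "dec_emb": (batch, dec_len, d_model),
        "dec_out": (batch, dec_len, c_out),      # trend_part + seasonal_part
        "forward_out": (batch, pred_len, c_out), # after forward() slices [:, -pred_len:, :]
    }
    # The constructor in section 5.0 sizes the decoder Fourier blocks with this length.
    fourier_dec_len = seq_len // 2 + pred_len
    return shapes, fourier_dec_len == dec_len


shapes, lengths_agree = forecast_shapes(
    batch=3, seq_len=12, label_len=6, pred_len=4, enc_in=5, d_model=16, c_out=5
)
for name, shape in shapes.items():
    print(f"{name:12s} {shape}")
print("decoder length matches Fourier block length:", lengths_agree)
```

With the running configuration the two lengths agree because label_len equals seq_len // 2.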

Overall data-flow SVG

§5.2 Step ①: computing mean (placeholder fill value for the trend)

python
mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)

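x_enc.mean(dim=1) averages over the time axis (seq_len=12), giving shape (3, 5).

.unsqueeze(1) → (3, 1, 5).

.repeat(1, 4, 1) → (3, 4, 5), where 4 = pred_len.

Purpose: the historical mean fills the "future" placeholder segment of trend_init. Compared with zero initialization it stays close to the historical level, so the Decoder starts from a reasonable trend baseline.

Toy values: if x_enc[0,:,0] = [1,3,5,7,9,11,8,6,4,2,4,6], then mean[0,0,0] = 66/12 = 5.5, so mean[0,:,0] = [5.5, 5.5, 5.5, 5.5] (the same value repeated pred_len times).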
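The toy values can be reproduced for a single batch element and a single channel with plain Python lists. This is only a sketch of what the mean / unsqueeze / repeat chain computes; mean_placeholder is a hypothetical helper, not part of the model.

```python
def mean_placeholder(series, pred_len):
    """Time-axis mean of one channel, repeated pred_len times (hypothetical helper)."""
    m = sum(series) / len(series)  # torch.mean(x_enc, dim=1) for this channel
    return [m] * pred_len          # unsqueeze(1) + repeat(1, pred_len, 1)


x = [1, 3, 5, 7, 9, 11, 8, 6, 4, 2, 4, 6]
mean = mean_placeholder(x, pred_len=4)
print(mean)  # [5.5, 5.5, 5.5, 5.5]
```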

§5.3 Step ②: series_decomp decomposition

python
seasonal_init, trend_init = self.decomp(x_enc)

Inside series_decomp(moving_avg=7):

  1. moving_avg(x_enc, kernel=7): replicate-pad 3 steps on each side ((7-1)//2 = 3), apply AvgPool1d, then permute back → trend_init with shape (3, 12, 5)
  2. seasonal_init = x_enc - trend_init → shape (3, 12, 5)

Toy values: with the §5.2 series x_enc[0,:,0] = [1,3,5,7,...], the first window after replicate padding is [1,1,1,1,3,5,7], so trend_init[0,0,0] = 19/7 ≈ 2.71 (pulled toward 1 by the 3 padded steps) and seasonal_init[0,0,0] ≈ 1.0 - 2.71 = -1.71.
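The decomposition can be sketched on one channel with plain lists. This is a minimal re-implementation under stated assumptions (odd kernel size, stride 1, replicate padding of (kernel_size - 1) // 2 steps per side), not the library's series_decomp module; it is applied to the §5.2 toy series.

```python
def moving_avg(series, kernel_size):
    """Stride-1 moving average with replicate padding (assumes an odd kernel size)."""
    pad = (kernel_size - 1) // 2
    padded = [series[0]] * pad + list(series) + [series[-1]] * pad
    return [sum(padded[i:i + kernel_size]) / kernel_size for i in range(len(series))]


def series_decomp(series, kernel_size):
    """Split a series into (seasonal, trend), where trend is the moving average."""
    trend = moving_avg(series, kernel_size)
    seasonal = [x - t for x, t in zip(series, trend)]
    return seasonal, trend


x = [1, 3, 5, 7, 9, 11, 8, 6, 4, 2, 4, 6]
seasonal, trend = series_decomp(x, kernel_size=7)
print(round(trend[0], 3), round(seasonal[0], 3))  # 2.714 -1.714
```

By construction seasonal + trend reproduces the input exactly, which is why the two paths can simply be summed again at the end of forecast().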

§5.4 Steps ③④: building the Decoder initialization inputs

Trend initialization

python
trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)

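trend_init[:, -6:, :] → (3, 6, 5) (the tail of the real historical trend).

cat([(3,6,5), mean (3,4,5)], dim=1) → trend_init with shape (3, 10, 5).

Semantics: the first 6 steps are the historical trend (from series_decomp); the last 4 steps are mean placeholders, the baseline from which the model learns to predict the future trend.

Seasonal initialization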

python
seasonal_init = F.pad(
    seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)
)

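seasonal_init[:, -6:, :] → (3, 6, 5).

F.pad(..., (0, 0, 0, 4)) on a 3-D tensor: the pad tuple is read from the last dimension backwards, so the first pair (0, 0) leaves the channel axis untouched and the second pair (0, 4) appends 4 zero steps at the end of dim=1 (the time axis) → (3, 10, 5).

Semantics: the first 6 steps are the historical seasonal component (the label segment); the last 4 steps are all zeros, which the Decoder learns to turn into the future seasonal fluctuation.

Toy values: if seasonal_init[0,-1,0] = 0.8 before padding (the last historical seasonal value), it sits at index 5 after padding, and seasonal_init[0,6:,0] = [0,0,0,0].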
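Both initializations in this section amount to "keep the last label_len steps, then append pred_len placeholder steps". The sketch below shows that on one channel with plain lists; build_decoder_init is a hypothetical helper, and the moving average assumes an odd kernel with replicate padding as described in §5.3.

```python
def moving_avg(series, kernel_size):
    """Stride-1 moving average with replicate padding (assumes an odd kernel size)."""
    pad = (kernel_size - 1) // 2
    padded = [series[0]] * pad + list(series) + [series[-1]] * pad
    return [sum(padded[i:i + kernel_size]) / kernel_size for i in range(len(series))]


def build_decoder_init(series, label_len, pred_len, kernel_size):
    """Return (seasonal_init, trend_init) for one channel (hypothetical helper)."""
    trend = moving_avg(series, kernel_size)
    seasonal = [x - t for x, t in zip(series, trend)]
    mean = sum(series) / len(series)
    # torch.cat([trend[:, -label_len:, :], mean], dim=1)
    trend_init = trend[-label_len:] + [mean] * pred_len
    # F.pad(seasonal[:, -label_len:, :], (0, 0, 0, pred_len)): zeros appended on the time axis
    seasonal_init = seasonal[-label_len:] + [0.0] * pred_len
    return seasonal_init, trend_init


x = [1, 3, 5, 7, 9, 11, 8, 6, 4, 2, 4, 6]
seasonal_init, trend_init = build_decoder_init(x, label_len=6, pred_len=4, kernel_size=7)
print(len(seasonal_init), len(trend_init))    # 10 10
print(seasonal_init[-4:], trend_init[-4:])    # [0.0, 0.0, 0.0, 0.0] [5.5, 5.5, 5.5, 5.5]
```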

§5.5 Steps ⑤⑥: Encoder

python
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=None)

enc_embedding (DataEmbedding): x_enc (3,12,5) + x_mark_enc (3,12,4) → enc_out (3, 12, 16). It combines TokenEmbedding (Conv1d), TemporalEmbedding, and PositionalEmbedding.

encoder (e_layers=2 EncoderLayers); each layer:

  1. AutoCorrelationLayer(FourierBlock) → series_decomp → seasonal residual (trend discarded)
  2. FFN (Conv1d 1×1) → series_decomp → seasonal residual (trend discarded)

Input (3,12,16) → 2 EncoderLayers → my_Layernorm → enc_out (3, 12, 16) (shape unchanged, no distilling).
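The control flow of one EncoderLayer can be sketched on a single channel. This is an illustration under assumptions, not the shared Autoformer code: the attention and FFN arguments are hypothetical stand-ins (here a stub that returns zeros), and dropout is omitted. The point is only that both decompositions keep the seasonal part and throw the trend away.

```python
def moving_avg(series, kernel_size):
    """Stride-1 moving average with replicate padding (assumes an odd kernel size)."""
    pad = (kernel_size - 1) // 2
    padded = [series[0]] * pad + list(series) + [series[-1]] * pad
    return [sum(padded[i:i + kernel_size]) / kernel_size for i in range(len(series))]


def series_decomp(series, kernel_size):
    trend = moving_avg(series, kernel_size)
    return [x - t for x, t in zip(series, trend)], trend


def encoder_layer(x, attention, ffn, kernel_size):
    """One encoder layer on one channel; attention and ffn are stand-in callables."""
    x = [a + b for a, b in zip(x, attention(x))]   # residual around attention
    x, _ = series_decomp(x, kernel_size)           # keep seasonal, discard trend
    y = ffn(x)
    out, _ = series_decomp([a + b for a, b in zip(x, y)], kernel_size)
    return out


def zero_stub(x):
    """Hypothetical placeholder for FourierBlock attention or the FFN."""
    return [0.0] * len(x)


flat = encoder_layer([2.0] * 12, zero_stub, zero_stub, kernel_size=7)
wavy = encoder_layer([1, 3, 5, 7, 9, 11, 8, 6, 4, 2, 4, 6], zero_stub, zero_stub, kernel_size=7)
print(flat)        # a constant input is pure trend, so nothing survives
print(len(wavy))   # 12: the sequence length is unchanged
```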

§5.6 Steps ⑦⑧: Decoder

python
dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
seasonal_part, trend_part = self.decoder(
    dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init
)

dec_embedding: seasonal_init (3,10,5) + x_mark_dec (3,10,4) → dec_out (3, 10, 16).

decoder (d_layers=1 DecoderLayer):

  • self-attention: AutoCorrelationLayer(FourierBlock) (decoder-side seq_len = dec_len = 10)
  • cross-attention: AutoCorrelationLayer(FourierCrossAttention) (Q from the decoder with dec_len=10, K/V from the encoder with seq_len=12)
  • 3 series_decomp calls extract trend1/2/3
  • residual_trend = trend1 + trend2 + trend3 (projected by Conv1d from d_model=16 back to c_out=5)
  • accumulation across layers: trend = trend + residual_trend, giving the final trend_part (3, 10, 5)
  • the seasonal stream passes straight through → my_Layernorm → projection Linear(16→5) → seasonal_part (3, 10, 5)
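The trend bookkeeping in the bullets above can be sketched on a single channel. This is a simplified illustration, not the shared Autoformer DecoderLayer: the attention and FFN callables are hypothetical zero-returning stubs, dropout is omitted, and the Conv1d projection, my_Layernorm and the final Linear are treated as identity so that the arithmetic stays visible.

```python
def moving_avg(series, k):
    """Stride-1 moving average with replicate padding (assumes an odd kernel size)."""
    pad = (k - 1) // 2
    padded = [series[0]] * pad + list(series) + [series[-1]] * pad
    return [sum(padded[i:i + k]) / k for i in range(len(series))]


def decomp(series, k):
    trend = moving_avg(series, k)
    return [x - t for x, t in zip(series, trend)], trend


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def decoder_layer(x, enc_out, self_attn, cross_attn, ffn, k):
    x, trend1 = decomp(add(x, self_attn(x)), k)
    x, trend2 = decomp(add(x, cross_attn(x, enc_out)), k)
    x, trend3 = decomp(add(x, ffn(x)), k)
    residual_trend = add(add(trend1, trend2), trend3)  # Conv1d projection omitted here
    return x, residual_trend


def decoder(x, enc_out, trend, layers, k):
    for self_attn, cross_attn, ffn in layers:
        x, residual_trend = decoder_layer(x, enc_out, self_attn, cross_attn, ffn, k)
        trend = add(trend, residual_trend)             # trend = trend + residual_trend
    return x, trend                                    # norm + Linear projection omitted


def zero(x, *rest):
    """Hypothetical stand-in for the Fourier attention blocks and the FFN."""
    return [0.0] * len(x)


seasonal_init = [0.5, -1.0, 2.0, 0.0, -0.5, 1.0, 0.0, 0.0, 0.0, 0.0]
trend_init = [5.0] * 10
enc_out = [0.0] * 12
seasonal_part, trend_part = decoder(seasonal_init, enc_out, trend_init, [(zero, zero, zero)], k=7)
print([round(s + t, 6) for s, t in zip(seasonal_part, trend_part)])
```

With the stubs contributing nothing, each decomposition only moves content from the seasonal stream to the trend stream, so seasonal_part + trend_part equals seasonal_init + trend_init. In the real model the attention and FFN outputs are what change that sum.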

The only difference between the FEDformer Decoder and the Autoformer Decoder

The Autoformer Decoder uses AutoCorrelation as inner_correlation; FEDformer uses FourierBlock (self-attn) and FourierCrossAttention (cross-attn). The code for EncoderLayer, DecoderLayer, series_decomp, and my_Layernorm is fully shared.

§5.7 Step ⑨: merging

python
dec_out = trend_part + seasonal_part
return dec_out

trend_part (3, 10, 5) + seasonal_part (3, 10, 5) → dec_out (3, 10, 5).

forward() slices [:, -4:, :] → (3, 4, 5) and returns it.

Toy values (last time step t₉): seasonal_part[0,9,0] = 0.3 (seasonal fluctuation), trend_part[0,9,0] = 5.8 (trend baseline), dec_out[0,9,0] = 6.1 (predicted value).
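The merge and the slice taken by forward() can be sketched for one channel as follows. Only the last-step values (5.8 and 0.3) come from the toy example; every other number in the two lists is made up for illustration, and combine_and_slice is a hypothetical helper.

```python
def combine_and_slice(trend_part, seasonal_part, pred_len):
    """Sum the two paths, then keep only the forecast horizon (hypothetical helper)."""
    dec_out = [t + s for t, s in zip(trend_part, seasonal_part)]  # trend_part + seasonal_part
    return dec_out, dec_out[-pred_len:]                           # [:, -pred_len:, :]


# Ten decoder steps: 6 label steps followed by 4 forecast steps (illustrative values).
trend_part = [4.9, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8]
seasonal_part = [0.1, -0.2, 0.4, 0.0, -0.3, 0.2, 0.1, -0.1, 0.2, 0.3]

dec_out, forecast = combine_and_slice(trend_part, seasonal_part, pred_len=4)
print(len(dec_out), len(forecast))  # 10 4
print(round(forecast[-1], 6))       # 6.1
```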

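6. Drill-down sub-components

| Sub-component | Responsibility | Document |
| --- | --- | --- |
| FourierBlock | Encoder self-attn + Decoder self-attn: FFT → linear transform on M frequency modes → irfft | [[03A-Layer2A-FourierBlock]] |
| FourierCrossAttention | Decoder cross-attn: frequency-domain Q×K attention | [[03B-Layer2B-FourierCrossAttention]] |
| series_decomp / EncoderLayer / DecoderLayer | Skeleton shared with Autoformer | [[Autoformer]] document |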
