Skip to content

Layer 2A — FourierBlock(频域 Self-Attention)

1. 在父层中的位置

forecast() 中 Encoder 和 Decoder self-attention 的每个 AutoCorrelationLayer 内嵌了 FourierBlock 作为 inner_correlationAutoCorrelationLayer 完成 Q/K/V 的线性投影和多头拆分后,把 (q, k, v, attn_mask) 传入 FourierBlock.forward()。FourierBlock 忽略 k、v、mask,只使用 q 做频域变换。

2. I/O 接口定义

AutoCorrelationLayer 接口(外层包装):

参数Shape含义
queries(3, 12, 16)Encoder self-attn:来自 enc_out 经 embedding
keys(3, 12, 16)同上(FourierBlock 实际忽略)
values(3, 12, 16)同上(FourierBlock 实际忽略)
输出(3, 12, 16)注意力输出,shape 不变

FourierBlock.forward 接口(投影后):

参数Shape含义
q(3, 12, 8, 2)(B, L, H, E),多头拆分后
k(3, 12, 8, 2)传入但不使用
v(3, 12, 8, 2)传入但不使用
输出 x(3, 8, 2, 12)(B, H, E, L)——注意与 AutoCorrelationLayer 期望形状不同!

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super(AutoCorrelationLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        out, attn = self.inner_correlation(queries, keys, values, attn_mask)
        out = out.view(B, L, -1)
        return self.out_projection(out), attn


class FourierBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random"
    ):
        super(FourierBlock, self).__init__()
        print("fourier enhanced block used!")
        self.index = get_frequency_modes(
            seq_len, modes=modes, mode_select_method=mode_select_method
        )
        print("modes={}, index={}".format(modes, self.index))
        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(
                8, in_channels // 8, out_channels // 8, len(self.index),
                dtype=torch.float,
            )
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(
                8, in_channels // 8, out_channels // 8, len(self.index),
                dtype=torch.float,
            )
        )

    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(
                weights, torch.zeros_like(weights).to(weights.device)
            )
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real)
                - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag)
                + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)
        x_ft = torch.fft.rfft(x, dim=-1)
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, wi] = self.compl_mul1d(
                "bhi,hio->bho",
                x_ft[:, :, :, i],
                torch.complex(self.weights1, self.weights2)[:, :, :, wi],
            )
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)

§5.1 宏观逻辑

核心设计意图:时间序列的有效信息集中在少数 M 个频率上。与其在所有时间步做 O(L2) 的 QK 点积,不如先变换到频域,只在 M 个频率上各独立做一次复数线性变换,复杂度 O(ML)

标准 Self-Attention:Attn(Q,K,V)=softmax(QKTdk)V,复杂度 O(L2dk)

FourierBlock:y^ωi=h,eq^ωi(h,e)Wωi(h,e,o),复杂度 O(ML)(rfft 为 O(LlogL),但 M 次线性变换为 O(MHE2),整体约 O(ML)

用小例子(B=1, L=4, H=8, E=2, modes=2):

q (1,4,8,2)
  → permute → (1,8,2,4)
  → rfft(dim=-1) → (1,8,2,3) cfloat  [3 = 4//2+1]
  → index = [1, 2](随机选2个模式)
  → out_ft = zeros(1,8,2,3) cfloat
     wi=0, i=1: x_ft[:,:,:,1] (1,8,2) × W[:,:,:,0] (8,2,2) → (1,8,2) 填入 out_ft[:,:,:,0]
     wi=1, i=2: x_ft[:,:,:,2] (1,8,2) × W[:,:,:,1] (8,2,2) → (1,8,2) 填入 out_ft[:,:,:,1]
     out_ft[:,:,:,2] = 0(未选中的频率保持0)
  → irfft(out_ft, n=4) → (1,8,2,4) real

完整 shape 变化链(Encoder self-attn,toy 全局参数)

(3,12,16) → Linear×3+view → (3,12,8,2) → permute → (3,8,2,12) → rfft → (3,8,2,7) cfloat → M=4频率线性变换 → out_ft(3,8,2,7) cfloat → irfft(n=12) → (3,8,2,12) → view(3,12,16) [quirk] → out_projection → (3,12,16)

注意力复杂度分析

FourierBlock 中:rfft 复杂度 O(LlogL)=O(12×3.643);M=4 次复数线性变换 O(MHE2)=O(4×8×4)=O(128);irfft 同样 O(LlogL)

对比:标准 FullAttention O(L2)=O(144);FourierBlock 约 O(2×43+128)O(214)(toy 参数下反而更大,因为 L 很小;L 大时 FourierBlock 优势显现)。M 固定时,大 L 下 FourierBlock 为 O(LlogL)(rfft 主导),好于 O(L2)

架构图

§5.2 __init__ — 频率模式选择与权重初始化

python
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
    modes = min(modes, seq_len // 2)
    if mode_select_method == "random":
        index = list(range(0, seq_len // 2))
        np.random.shuffle(index)
        index = index[:modes]
    else:
        index = list(range(0, modes))
    index.sort()
    return index

Encoder FourierBlock 初始化时:seq_len=12, modes=32(用户默认), mode_select_method="random"

actual_modes = min(32, 12//2=6) = 6(用户 modes=32 被截断为 6!实际选取的模式数为 6,而非 32。在我们的 toy 参数 modes=4 时,actual_modes = min(4, 6) = 4)。

index = list(range(0, 6)) → shuffle → 取前 4 → sort → self.index 例如 [1, 2, 4, 5](每次运行可能不同)。

权重初始化(toy:in_channels=out_channels=16,actual_modes=4):

self.weights1: shape (8, 16//8, 16//8, 4) = (8, 2, 2, 4)

self.weights2: shape (8, 2, 2, 4)(虚部参数,与 weights1 同形)

⚠️ 硬编码 n_heads=8

weights1 第一维写死为 8,而不是 n_heads。因此:

  • n_heads=8queries.shape = (B,L,8,2),einsum 中 h=8 匹配,✅ 正常
  • n_heads=4queries.shape = (B,L,4,2),einsum 中 h=4 ≠ 8,❌ RuntimeError

正确写法应为 torch.rand(n_heads, in_channels//n_heads, ...) 但源码硬编码了 8。 FourierCrossAttention 修正了此问题,使用 num_heads 参数。

Decoder self-attn FourierBlock 的 seq_len 不同:seq_len_dec = seq_len//2 + pred_len = 6+4 = 10,故 actual_modes_dec = min(4, 10//2=5) = 4,rfft 输出 10//2+1=6 个频率,index 从 [0..4] 中选 4 个。

§5.3 forward — permute + rfft

python
B, L, H, E = q.shape
x = q.permute(0, 2, 3, 1)
x_ft = torch.fft.rfft(x, dim=-1)

toy 全局参数(Encoder self-attn):B=3, L=12, H=8, E=2。

q.permute(0, 2, 3, 1) 把 L 轴移到最后,使 rfft 沿时间轴做变换。

shape 变化:(3, 12, 8, 2)(3, 8, 2, 12) — H 和 E 成为"特征轴",L 成为"信号轴"。

rfft(x, dim=-1) 对最后一维(L=12)做实数 FFT,输出 L//2+1=7 个复数频率(rfft 利用实信号共轭对称性,仅保留 L//2+1 个独立频率):

shape 变化:(3, 8, 2, 12) real(3, 8, 2, 7) cfloat

toy 数值:x_ft[0,0,0,:] 是 batch=0, head=0, enc_var=0 的 7 个复数频率系数。频率 0 对应直流(均值),频率 6 对应奈奎斯特频率(最高频)。

§5.4 forward — M 频率线性变换(核心)

python
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
for wi, i in enumerate(self.index):
    if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
        continue
    out_ft[:, :, :, wi] = self.compl_mul1d(
        "bhi,hio->bho",
        x_ft[:, :, :, i],
        torch.complex(self.weights1, self.weights2)[:, :, :, wi],
    )

out_ft = zeros(3, 8, 2, 7) cfloat — 全零初始化,未选中的频率槽保持 0(irfft 后等于没有这些频率成分)。

权重 W:torch.complex(weights1, weights2) → shape (8, 2, 2, 4) cfloat

逐模式变换(假设 index=[1,2,4,5]):

wi=0, i=1

  • x_ft[:,:,:,1] → shape (3, 8, 2) cfloat — 频率 ω₁=1 的所有 batch/head/channel 系数
  • W[:,:,:,0] → shape (8, 2, 2) cfloat — 第 0 个模式的权重矩阵
  • compl_mul1d("bhi,hio->bho", ...) → 复数 einsum:(3,8,2) × (8,2,2) → (3,8,2)

物理含义:对频率 ω₁ 处的特征向量 (b,h,i) 做线性混合,输出 (b,h,o),权重 Wω1(h,i,o)=W1,ω1+jW2,ω1 是复数(实部和虚部分别学习对信号实部和虚部的贡献)。

out_ft[:,:,:,0] 被填入结果,其余 out_ft[:,:,:,1:7] 中只有对应选中频率的槽被填,其余保持 0。

compl_mul1d 复数乘法原理

(a+jb)(c+jd)=(acbd)+j(ad+bc)

python
real = einsum(order, x.real, w.real) - einsum(order, x.imag, w.imag)  # ac - bd
imag = einsum(order, x.real, w.imag) + einsum(order, x.imag, w.real)  # ad + bc
return torch.complex(real, imag)

早期 PyTorch 版本对复数 einsum 支持不稳定,这里手动分解实虚部实现复数矩阵乘法。

toy 数值:假设 x_ft[0,0,0,1] = 0.3+0.2j(频率1,batch0,head0,E0),W[0,0,0,0] = 0.1+0.05j(第0个输出模式,head0,E_in=0→E_out=0): real = 0.3×0.1 - 0.2×0.05 = 0.03-0.01 = 0.02 imag = 0.3×0.05 + 0.2×0.1 = 0.015+0.02 = 0.035 out_ft[0,0,0,0] 的 (0,0,0) 分量贡献 = 0.02+0.035j(还需加 E_in=1 的贡献)。

§5.5 forward — irfft 还原时域

python
x = torch.fft.irfft(out_ft, n=x.size(-1))
return (x, None)

x.size(-1) 此时是 x(permute 后的张量)的最后一维大小 = L = 12(permute 后 x shape 已是 (3,8,2,12),x.size(-1)=12)。

irfft(out_ft, n=12)(3,8,2,7) cfloat 还原为 (3,8,2,12) real(利用实信号共轭对称性,n=12 指定输出长度)。

返回 (x, None) = ((3,8,2,12), None)

⚠️ AutoCorrelationLayer 的 shape quirk

AutoCorrelationLayer.forward() 接到返回值后执行:

python
out = out.view(B, L, -1)

FourierBlock 返回 (3, 8, 2, 12) = (B, H, E, L),总元素 = 3×8×2×12 = 576。 view(3, 12, -1) = view(3, 12, 16) 也是 576 个元素,reshape 不报错。

但内存布局:原始 (B,H,E,L) 按 H-E-L 顺序存储。view(B,L,HE) 把前 16 个元素分配给 L=0,接下来 16 个给 L=1……这 16 个元素在原始张量中是 H=0,E=0,L=0 到 H=0,E=1,L=3(非完整 L 切片)。

结果:每个"时间步"位置混合了来自不同 head 和不同时间步的信息。out_projection 学习从这个混合表示中提取有用特征,模型仍然收敛。这是已发布代码中的已知布局语义偏差,文档记录但不改动。

§5.6 AutoCorrelationLayer — out_projection 还原

python
out = out.view(B, L, -1)
return self.out_projection(out), attn

out.view(3, 12, -1)(3, 12, 16) [含 shape quirk]

out_projection:Linear(d_values×n_heads=2×8=16, d_model=16) → (3, 12, 16)

形状恢复为输入 shape,可直接加入残差 + 后续 DecomP + FFN。


6. FourierBlock vs AutoCorrelation 对比

维度AutoCorrelation(Autoformer)FourierBlock(FEDformer)
频域使用方式Q·K^* → IFFT → 时延相关曲线 → top-k lag 时移 VQ rfft → M 频率独立复数线性变换 → irfft
输出语义时延加权后的 V(全时域聚合)M 个频率子空间投影到时域
K/V 的使用K 用于 IFFT 相关,V 用于时移加权K/V 完全忽略,仅用 Q
学习的结构哪些时延最重要(top-k lag)每个选中频率的线性变换矩阵 Wω
复杂度O(LlogL)(IFFT)+ O(topkL)(roll)O(LlogL)(rfft/irfft)+ O(MHE2)(线性变换)

*记录并在线阅读我的笔记*