Skip to content

Layer 2B — FourierCrossAttention(频域 Cross-Attention)

1. 在父层中的位置

forecast() 中 Decoder 的 cross-attention AutoCorrelationLayer 内嵌了 FourierCrossAttention 作为 inner_correlationAutoCorrelationLayer 完成线性投影和多头拆分后,把 (q, k, v, attn_mask) 传入 FourierCrossAttention.forward()

与 [[03A-Layer2A-FourierBlock]] 的关键区别:Q 来自 decoder(dec_len=10),K/V 来自 encoder(seq_len=12),两者序列长度不同,需要各自独立地选择频率模式。

2. I/O 接口定义

AutoCorrelationLayer 接口(外层包装,Decoder cross-attn):

参数Shape含义
queries(3, 10, 16)Decoder 隐层输出
keys(3, 12, 16)Encoder enc_out
values(3, 12, 16)Encoder enc_out(与 keys 相同源)
输出(3, 10, 16)Cross-attention 输出

FourierCrossAttention.forward 接口(多头投影后):

参数Shape含义
q(3, 10, 8, 2)(B, L_q, H, E),来自 Decoder
k(3, 12, 8, 2)(B, L_kv, H, E),来自 Encoder
v(3, 12, 8, 2)传入但完全不使用(K 同时充当 V)
输出 out(3, 8, 2, 10)(B, H, E, L_q)——同样存在 view quirk

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
class FourierCrossAttention(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=64,
        mode_select_method="random",
        activation="tanh",
        policy=0,
        num_heads=8,
    ):
        super(FourierCrossAttention, self).__init__()
        print(" fourier enhanced cross attention used!")
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.index_q = get_frequency_modes(
            seq_len_q, modes=modes, mode_select_method=mode_select_method
        )
        self.index_kv = get_frequency_modes(
            seq_len_kv, modes=modes, mode_select_method=mode_select_method
        )
        print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q))
        print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv))

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )
        self.weights2 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )

    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(
                weights, torch.zeros_like(weights).to(weights.device)
            )
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real)
                - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag)
                + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)
        xk = k.permute(0, 2, 3, 1)
        xv = v.permute(0, 2, 3, 1)

        xq_ft_ = torch.zeros(
            B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
        )
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            if j >= xq_ft.shape[3]:
                continue
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
        xk_ft_ = torch.zeros(
            B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
        )
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_kv):
            if j >= xk_ft.shape[3]:
                continue
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception(
                "{} actiation function is not implemented".format(self.activation)
            )
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = self.compl_mul1d(
            "bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)
        )
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        )
        return (out, None)

§5.1 宏观逻辑

核心设计意图:标准 Cross-Attention 在时域做 softmax(QKTdk)V,复杂度 O(LqLkv)。当 Lq=10,Lkv=12 时为 O(120)。FourierCrossAttention 在频域做注意力:Q 和 K 各自只选 Mq=Mkv=4 个模式,注意力矩阵从 (Lq×Lkv)=(10×12) 缩减到 (Mq×Mkv)=(4×4),复杂度为 O(M2)(固定 M 时与序列长度无关)。

与 FourierBlock 的关键区别

FourierBlock(self-attn)FourierCrossAttention(cross-attn)
Q/K/V只用 Q(K/V 忽略)Q 来自 dec,K 来自 enc;v 忽略
频率选择一套 index两套 index(index_q, index_kv)
变换Q→M 频率独立线性变换Q×K 频域点积 → tanh → ×K 加权
注意力矩阵(Mq×Mkv)=(4×4) cfloat
n_heads 硬编码⚠️ (写死 8)(使用 num_heads 参数)

用小例子(B=1, L_q=4, L_kv=6, H=2, E=2, modes=2)串起来看:

Q rfft: (1,2,2,4) → (1,2,2,3) cfloat; 选 index_q=[1,2] → xq_ft_(1,2,2,2)
K rfft: (1,2,2,6) → (1,2,2,4) cfloat; 选 index_kv=[0,2] → xk_ft_(1,2,2,2)

注意力矩阵: "bhex,bhey→bhxy"
  xq_ft_(1,2,2,2) × xk_ft_(1,2,2,2) → xqk_ft(1,2,2,2)  [2×2 频域注意力]

tanh: (1,2,2,2) → (1,2,2,2) 不变

V加权: "bhxy,bhey→bhex"
  xqk_ft(1,2,2,2) × xk_ft_(1,2,2,2) → xqkv_ft(1,2,2,2)

投影: "bhex,heox→bhox"
  xqkv_ft(1,2,2,2) × W(2,2,2,2) → xqkvw(1,2,2,2)

scatter: out_ft(1,2,2,3); 填入 j=1 → slot1, j=2 → slot2; slot0=0
irfft(n=4) → (1,2,2,4) real

完整 shape 变化链(toy 全局参数)

(3,10,16)+(3,12,16)Linear+viewq(3,10,8,2)+k(3,12,8,2)permutexq(3,8,2,10)+xk(3,8,2,12)rfftxq_ft(3,8,2,6)+xk_ft(3,8,2,7)选M=4模式xq_ft_(3,8,2,4)+xk_ft_(3,8,2,4)频域注意力xqk_ft(3,8,4,4)tanh(3,8,4,4)V加权xqkv_ft(3,8,2,4)W投影xqkvw(3,8,2,4)scatter+÷256out_ft(3,8,2,6)irfft(n=10)out(3,8,2,10)view(3,10,16)out_projection(3,10,16)

注意力复杂度分析

标准 Cross-Attention:O(LqLkvdk)=O(10×12×2)=O(240)

FourierCrossAttention:rfft O(LqlogLq)O(33);频域注意力矩阵 O(M2HE)=O(16×8×2)=O(256);irfft O(LqlogLq)

L 时无优势;L 固定 M 时 FourierCrossAttention 注意力矩阵部分为 O(M2)(常数),仅 rfft/irfft 随 L 增长(O(LlogL))——大序列时优于标准 O(LqLkv)

整体数据流 SVG

§5.2 __init__ — 双套频率模式初始化

Decoder cross-attn FourierCrossAttention 初始化参数(来自 FEDformer.__init__):

python
decoder_cross_att = FourierCrossAttention(
    in_channels=configs.d_model,      # 16
    out_channels=configs.d_model,     # 16
    seq_len_q=self.seq_len // 2 + self.pred_len,  # 10
    seq_len_kv=self.seq_len,          # 12
    modes=self.modes,                 # 4(toy)
    mode_select_method=self.mode_select,  # "random"
    num_heads=configs.n_heads,        # 8  ← 正确使用 num_heads!
)

index_q(Q 来自 Decoder,seq_len_q=10):

actual_modes_q = min(4, 10//2=5) = 4list(range(0, 5)) = [0,1,2,3,4] → shuffle → 取前 4 → sort,例如 index_q = [1, 2, 3, 4]

index_kv(K 来自 Encoder,seq_len_kv=12):

actual_modes_kv = min(4, 12//2=6) = 4list(range(0, 6)) = [0,1,2,3,4,5] → shuffle → 取前 4 → sort,例如 index_kv = [0, 2, 4, 5]

双套 index 与 FourierBlock 的区别

FourierBlock 只有一套 index(因为 Q 和 K 的 seq_len 相同)。FourierCrossAttention 为 Q 和 K 维护独立的 index,各自从不同长度的频谱(Lq//2+1=6 vs Lkv//2+1=7)中选取 M 个模式,然后通过 (Mq×Mkv) 的频域注意力矩阵进行对齐。

权重初始化(toy:in_channels=out_channels=16, num_heads=8, actual_modes_q=4):

self.weights1 = nn.Parameter(scale * torch.rand(8, 2, 2, 4)) = (8, 2, 2, 4)

self.weights2:同形 (8, 2, 2, 4)

注意:权重第四维是 len(index_q)=4,不是 len(index_kv)——权重对应 Q 的模式数(输出空间大小)。

§5.3 forward — permute + 双路 rfft + 模式采样

python
B, L, H, E = q.shape
xq = q.permute(0, 2, 3, 1)
xk = k.permute(0, 2, 3, 1)
xv = v.permute(0, 2, 3, 1)

B, L, H, E = q.shapeQ 提取维度:B=3, L=10(dec_len), H=8, E=2。

permute 后:

xq = (3, 10, 8, 2)(3, 8, 2, 10),时间轴移到最后。

xk = (3, 12, 8, 2)(3, 8, 2, 12)

xv = (3, 12, 8, 2)(3, 8, 2, 12)

⚠️ xv 被计算但从未使用

xv 完整地执行了 permute,但后续代码中 xv 从未出现——所有计算只用 xqxk。K 在 V 加权步骤(§5.6)中同时充当 V 角色(使用 xk_ft_),传入的 v 张量被完全忽略。

这意味着 FourierCrossAttention 在计算图中等价于 KQK 注意力:Attn(Q,K,K),而非标准的 Attn(Q,K,V)

Q 路径——rfft + 模式采样

python
xq_ft_ = torch.zeros(
    B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
)
xq_ft = torch.fft.rfft(xq, dim=-1)
for i, j in enumerate(self.index_q):
    if j >= xq_ft.shape[3]:
        continue
    xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

rfft(xq, dim=-1)(3, 8, 2, 10)(3, 8, 2, 6) cfloat10//2+1=6 个独立频率)。

xq_ft_:预分配 (3, 8, 2, 4) cfloat,全零初始化。

采样循环(假设 index_q=[1,2,3,4]):

  • i=0, j=1xq_ft_[:,:,:,0] = xq_ft[:,:,:,1](频率 ω=1 的系数)
  • i=1, j=2xq_ft_[:,:,:,1] = xq_ft[:,:,:,2]
  • i=2, j=3xq_ft_[:,:,:,2] = xq_ft[:,:,:,3]
  • i=3, j=4xq_ft_[:,:,:,3] = xq_ft[:,:,:,4]

xq_ft_.shape = (3, 8, 2, 4),第四维对应 4 个选出的 Q 频率模式。

K 路径——rfft + 模式采样

python
xk_ft_ = torch.zeros(
    B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
)
xk_ft = torch.fft.rfft(xk, dim=-1)
for i, j in enumerate(self.index_kv):
    if j >= xk_ft.shape[3]:
        continue
    xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

rfft(xk, dim=-1)(3, 8, 2, 12)(3, 8, 2, 7) cfloat12//2+1=7 个独立频率)。

xk_ft_:预分配 (3, 8, 2, 4) cfloat,全零初始化。

采样循环(假设 index_kv=[0,2,4,5]):

  • i=0, j=0xk_ft_[:,:,:,0] = xk_ft[:,:,:,0](直流分量)
  • i=1, j=2xk_ft_[:,:,:,1] = xk_ft[:,:,:,2]
  • i=2, j=4xk_ft_[:,:,:,2] = xk_ft[:,:,:,4]
  • i=3, j=5xk_ft_[:,:,:,3] = xk_ft[:,:,:,5]

xk_ft_.shape = (3, 8, 2, 4),第四维对应 4 个选出的 K 频率模式。

toy 数值(xq_ft):设 Q 对应的 seasonal_init 中一条时序(head=0, E=0, batch=0)12步均值接近零,则 xq_ft[0,0,0,:] ≈ [small_real, complex1, complex2, complex3, complex4, complex5],xq_ft_[0,0,0,:] = [complex1, complex2, complex3, complex4]。

§5.4 forward — 频域注意力矩阵 xqk_ft

python
xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)

xq_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, x=modes_q=4

xk_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, y=modes_kv=4

einsum "bhex,bhey->bhxy":对 e 维(E=2)收缩,输出 (3, 8, 4, 4)

数学含义:

xqk\_ft[b,h,x,y]=e=0E1q^¯ωxQ(b,h,e)k^ωyK(b,h,e)

(复数乘法展开为 compl_mul1d 的 ac-bd / ad+bc 分解,等价于频域内积)

物理含义:对每对 (Q 模式 x, K 模式 y),计算跨 E 维的复数内积,得到该频率对之间的"相关强度"——构成一个 (Mq×Mkv)=(4×4) 的频域注意力矩阵。

xqk_ft.shape = (3, 8, 4, 4) cfloat

toy 数值:设 xq_ft_[0,0,:,0] = [0.3+0.1j, 0.2+0.05j](E=0和E=1的分量),xk_ft_[0,0,:,0] = [0.4+0.2j, 0.1+0.3j],则:

xqk\_ft[0,0,0,0] 的实部 =0.3×0.40.1×0.2+0.2×0.10.05×0.3=0.120.02+0.020.015=0.105

xqk\_ft[0,0,0,0] 的虚部 =0.3×0.2+0.1×0.4+0.2×0.3+0.05×0.1=0.06+0.04+0.06+0.005=0.165

xqk\_ft[0,0,0,0]0.105+0.165j

§5.5 forward — tanh 激活

python
if self.activation == "tanh":
    xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
elif self.activation == "softmax":
    xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
    xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))

TFB 默认 activation="tanh",走第一分支。

xqk_ft.real.tanh():对 (3,8,4,4) 实部逐元素应用 tanh,将值压缩到 (1,1)

xqk_ft.imag.tanh():对虚部同样逐元素应用 tanh

torch.complex(...) 重新组装为复数张量,shape 不变:(3, 8, 4, 4) cfloat

tanh vs softmax 的设计意图

标准 Attention 用 softmax 使注意力权重沿 K 维归一化为概率分布。频域中,softmax 作用在 abs(xqk_ft) 上(取模),丢弃相位信息,只保留幅度权重(见 elif 分支)。

tanh 方案保留了复数结构(实虚部分开压缩),保留相位信息,也就保留了频率成分的"方向"(正弦/余弦分量比例)。代码变量名中的 typo 也暗示这是非正式实验:"{} actiation function is not implemented" — "actiation" 是 "activation" 的拼写错误(原始代码保留,不影响逻辑)。

toy 数值(续接 §5.4):xqk_ft[0,0,0,0] ≈ 0.105 + 0.165j,tanh 后:

tanh(0.105)0.1044tanh(0.165)0.1641

即 tanh 激活后 xqk_ft[0,0,0,0] ≈ 0.1044 + 0.1641j(绝对值略有缩小,接近原值,因为输入已较小)。

§5.6 forward — V 加权 xqkv_ft(K 充当 V)

python
xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

xqk_ft shape (3, 8, 4, 4) → b=B, h=H, x=modes_q=4, y=modes_kv=4

xk_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, y=modes_kv=4

einsum "bhxy,bhey->bhex":对 y 维(modes_kv=4)收缩,输出 (3, 8, 2, 4)

数学含义:

xqkv\_ft[b,h,e,x]=y=0Mkv1xqk\_ft[b,h,x,y]k^ωyK(b,h,e)

物理含义:对每个 Q 模式 x,用注意力权重 xqk\_ft[...,x,y] 加权对所有 K 频率模式 y 的特征 k^ωy(h,e) 求和——相当于以 K 充当 V 的频域加权聚合,得到每个 Q 模式的 attended 表示。

xqkv_ft.shape = (3, 8, 2, 4) cfloat

toy 数值(简化,只看 [0,0,0,0]):

xqk_ft[0,0,0,:] ≈ [0.1044+0.1641j, 0.05+0.08j, 0.12+0.09j, 0.07+0.11j](4 个 K 模式的注意力权重)。

xk_ft_[0,0,0,:] ≈ [0.4+0.2j, 0.3+0.15j, 0.25+0.1j, 0.35+0.2j](4 个 K 模式的 E=0 特征)。

xqkv\_ft[0,0,0,0]=y [复数乘法,4项相加] ≈ 一个包含聚合信息的复数值。

§5.7 forward — 可学习权重投影 xqkvw

python
xqkvw = self.compl_mul1d(
    "bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)
)

xqkv_ft shape (3, 8, 2, 4) → b=B, h=H, e=E_in=2, x=modes_q=4

W = torch.complex(weights1, weights2) shape (8, 2, 2, 4) → h=H, e=E_in=2, o=E_out=2, x=modes_q=4

einsum "bhex,heox->bhox":对 e 维(E_in=2)收缩,输出 (3, 8, 2, 4)

数学含义:

xqkvw[b,h,o,x]=e=0E1xqkv\_ft[b,h,e,x]W(h,e,o,x)

每个 Q 频率模式 x,对每个 head h,做 EinEout 的复数线性变换。W 是复数可学习参数(W1+jW2),每个频率模式有独立的 E×E 变换矩阵。

xqkvw.shape = (3, 8, 2, 4) cfloat

与 FourierBlock 的区别:FourierBlock 的权重形状也是 (8, 2, 2, 4),einsum 也是对 E 维收缩。但 FourierBlock 的输入是直接从 Q 频谱采样的系数("bhi,hio->bho"),而这里输入是已经过 Q×K 注意力加权的结果。

§5.8 forward — scatter + irfft + 归一化

python
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
for i, j in enumerate(self.index_q):
    if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
        continue
    out_ft[:, :, :, j] = xqkvw[:, :, :, i]
out = torch.fft.irfft(
    out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
)
return (out, None)

scatter 回频率轴

out_ft = zeros(3, 8, 2, L//2+1) = zeros(3, 8, 2, 6) cfloatL=10 来自 q.shape[1]10//2+1=6)。

逆向映射(scatter,与采样时方向相反):

采样时:xq_ft_[:,:,:,i] = xq_ft[:,:,:,j]  (全谱 j → 密集 i)
scatter:out_ft[:,:,:,j] = xqkvw[:,:,:,i]  (密集 i → 全谱 j)

假设 index_q=[1,2,3,4]:

  • i=0, j=1out_ft[:,:,:,1] = xqkvw[:,:,:,0]
  • i=1, j=2out_ft[:,:,:,2] = xqkvw[:,:,:,1]
  • i=2, j=3out_ft[:,:,:,3] = xqkvw[:,:,:,2]
  • i=3, j=4out_ft[:,:,:,4] = xqkvw[:,:,:,3]
  • out_ft[:,:,:,0] = 0,out_ft[:,:,:,5] = 0(未选中频率保持零)

out_ft.shape = (3, 8, 2, 6) cfloat,仅 4/6 个频率槽被填充。

归一化

out_ft / self.in_channels / self.out_channels = out_ft / 16 / 16 = out_ft / 256

FourierBlock 无此归一化;FourierCrossAttention 在 irfft 前额外除以 (in\_channels×out\_channels),与权重初始化的 scale = 1/(in_channels * out_channels) 相呼应——两者组合使输出能量归一化。

irfft 还原时域

irfft(out_ft_scaled, n=xq.size(-1)) = irfft(..., n=10)(3, 8, 2, 6) cfloat(3, 8, 2, 10) real

返回 (out, None) = ((3, 8, 2, 10), None)

AutoCorrelationLayer view quirk(与 FourierBlock 完全相同机制):

out.view(B, L, -1) = view(3, 10, -1) = view(3, 10, 16)

总元素:3×8×2×10=4803×10×16=480 ✓。

内存布局同样是 (B,H,E,L) 按 H-E-L 顺序存储,view 后每个"时间步"实际混合了来自不同 head 和时间步的信息。out_projection(Linear 1616)学习从混合表示提取有用特征,模型收敛。

最终输出:out_projection(out.view(3,10,-1)) → (3, 10, 16),即 Decoder cross-attention 的输出,继续流入第 2 次 series_decomp


6. FourierCrossAttention vs FourierBlock 完整对比

维度FourierBlock(self-attn)FourierCrossAttention(cross-attn)
Q 来源Encoder 或 Decoder 自身Decoder(dec_len=10)
K/V 来源同 Q(self-attn),但忽略K 来自 Encoder(enc_len=12),v 忽略
频率选择一套 index(seq_len 单一)双套 index(index_q / index_kv,不同 seq_len)
核心操作M 个频率各自独立线性变换Q×K 频域注意力矩阵 +tanh+K 加权
注意力矩阵(Mq×Mkv)=(4×4)
权重维度硬编码 8 ⚠️正确使用 num_heads
额外归一化irfft 前除以 in\_ch×out\_ch
einsum 链1 次 "bhi,hio→bho"3 次:"bhex,bhey→bhxy" / "bhxy,bhey→bhex" / "bhex,heox→bhox"

*记录并在线阅读我的笔记*