Appearance
Layer 2B — FourierCrossAttention(频域 Cross-Attention)
1. 在父层中的位置
forecast() 中 Decoder 的 cross-attention AutoCorrelationLayer 内嵌了 FourierCrossAttention 作为 inner_correlation。AutoCorrelationLayer 完成线性投影和多头拆分后,把 (q, k, v, attn_mask) 传入 FourierCrossAttention.forward()。
与 [[03A-Layer2A-FourierBlock]] 的关键区别:Q 来自 decoder(dec_len=10),K/V 来自 encoder(seq_len=12),两者序列长度不同,需要各自独立地选择频率模式。
2. I/O 接口定义
AutoCorrelationLayer 接口(外层包装,Decoder cross-attn):
| 参数 | Shape | 含义 |
|---|---|---|
queries | (3, 10, 16) | Decoder 隐层输出 |
keys | (3, 12, 16) | Encoder enc_out |
values | (3, 12, 16) | Encoder enc_out(与 keys 相同源) |
| 输出 | (3, 10, 16) | Cross-attention 输出 |
FourierCrossAttention.forward 接口(多头投影后):
| 参数 | Shape | 含义 |
|---|---|---|
q | (3, 10, 8, 2) | (B, L_q, H, E),来自 Decoder |
k | (3, 12, 8, 2) | (B, L_kv, H, E),来自 Encoder |
v | (3, 12, 8, 2) | 传入但完全不使用(K 同时充当 V) |
输出 out | (3, 8, 2, 10) | (B, H, E, L_q)——同样存在 view quirk |
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
class FourierCrossAttention(nn.Module):
def __init__(
self,
in_channels,
out_channels,
seq_len_q,
seq_len_kv,
modes=64,
mode_select_method="random",
activation="tanh",
policy=0,
num_heads=8,
):
super(FourierCrossAttention, self).__init__()
print(" fourier enhanced cross attention used!")
self.activation = activation
self.in_channels = in_channels
self.out_channels = out_channels
self.index_q = get_frequency_modes(
seq_len_q, modes=modes, mode_select_method=mode_select_method
)
self.index_kv = get_frequency_modes(
seq_len_kv, modes=modes, mode_select_method=mode_select_method
)
print("modes_q={}, index_q={}".format(len(self.index_q), self.index_q))
print("modes_kv={}, index_kv={}".format(len(self.index_kv), self.index_kv))
self.scale = 1 / (in_channels * out_channels)
self.weights1 = nn.Parameter(
self.scale
* torch.rand(
num_heads,
in_channels // num_heads,
out_channels // num_heads,
len(self.index_q),
dtype=torch.float,
)
)
self.weights2 = nn.Parameter(
self.scale
* torch.rand(
num_heads,
in_channels // num_heads,
out_channels // num_heads,
len(self.index_q),
dtype=torch.float,
)
)
def compl_mul1d(self, order, x, weights):
x_flag = True
w_flag = True
if not torch.is_complex(x):
x_flag = False
x = torch.complex(x, torch.zeros_like(x).to(x.device))
if not torch.is_complex(weights):
w_flag = False
weights = torch.complex(
weights, torch.zeros_like(weights).to(weights.device)
)
if x_flag or w_flag:
return torch.complex(
torch.einsum(order, x.real, weights.real)
- torch.einsum(order, x.imag, weights.imag),
torch.einsum(order, x.real, weights.imag)
+ torch.einsum(order, x.imag, weights.real),
)
else:
return torch.einsum(order, x.real, weights.real)
def forward(self, q, k, v, mask):
B, L, H, E = q.shape
xq = q.permute(0, 2, 3, 1)
xk = k.permute(0, 2, 3, 1)
xv = v.permute(0, 2, 3, 1)
xq_ft_ = torch.zeros(
B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
)
xq_ft = torch.fft.rfft(xq, dim=-1)
for i, j in enumerate(self.index_q):
if j >= xq_ft.shape[3]:
continue
xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
xk_ft_ = torch.zeros(
B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
)
xk_ft = torch.fft.rfft(xk, dim=-1)
for i, j in enumerate(self.index_kv):
if j >= xk_ft.shape[3]:
continue
xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
if self.activation == "tanh":
xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
elif self.activation == "softmax":
xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
else:
raise Exception(
"{} actiation function is not implemented".format(self.activation)
)
xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
xqkvw = self.compl_mul1d(
"bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)
)
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
for i, j in enumerate(self.index_q):
if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
continue
out_ft[:, :, :, j] = xqkvw[:, :, :, i]
out = torch.fft.irfft(
out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
)
return (out, None)§5.1 宏观逻辑
核心设计意图:标准 Cross-Attention 在时域做
与 FourierBlock 的关键区别:
| FourierBlock(self-attn) | FourierCrossAttention(cross-attn) | |
|---|---|---|
| Q/K/V | 只用 Q(K/V 忽略) | Q 来自 dec,K 来自 enc;v 忽略 |
| 频率选择 | 一套 index | 两套 index(index_q, index_kv) |
| 变换 | Q→M 频率独立线性变换 | Q×K 频域点积 → tanh → ×K 加权 |
| 注意力矩阵 | 无 | |
| n_heads 硬编码 | ⚠️ 是(写死 8) | ✅ 否(使用 num_heads 参数) |
用小例子(B=1, L_q=4, L_kv=6, H=2, E=2, modes=2)串起来看:
Q rfft: (1,2,2,4) → (1,2,2,3) cfloat; 选 index_q=[1,2] → xq_ft_(1,2,2,2)
K rfft: (1,2,2,6) → (1,2,2,4) cfloat; 选 index_kv=[0,2] → xk_ft_(1,2,2,2)
注意力矩阵: "bhex,bhey→bhxy"
xq_ft_(1,2,2,2) × xk_ft_(1,2,2,2) → xqk_ft(1,2,2,2) [2×2 频域注意力]
tanh: (1,2,2,2) → (1,2,2,2) 不变
V加权: "bhxy,bhey→bhex"
xqk_ft(1,2,2,2) × xk_ft_(1,2,2,2) → xqkv_ft(1,2,2,2)
投影: "bhex,heox→bhox"
xqkv_ft(1,2,2,2) × W(2,2,2,2) → xqkvw(1,2,2,2)
scatter: out_ft(1,2,2,3); 填入 j=1 → slot1, j=2 → slot2; slot0=0
irfft(n=4) → (1,2,2,4) real完整 shape 变化链(toy 全局参数):
(3,10,16)+(3,12,16) → Linear+view → q(3,10,8,2)+k(3,12,8,2) → permute → xq(3,8,2,10)+xk(3,8,2,12) → rfft → xq_ft(3,8,2,6)+xk_ft(3,8,2,7) → 选M=4模式 → xq_ft_(3,8,2,4)+xk_ft_(3,8,2,4) → 频域注意力 → xqk_ft(3,8,4,4) → tanh → (3,8,4,4) → V加权 → xqkv_ft(3,8,2,4) → W投影 → xqkvw(3,8,2,4) → scatter+÷256 → out_ft(3,8,2,6) → irfft(n=10) → out(3,8,2,10) → view(3,10,16) → out_projection → (3,10,16)
注意力复杂度分析
标准 Cross-Attention:
; FourierCrossAttention:rfft
;频域注意力矩阵 ;irfft 。 小
时无优势; 固定 时 FourierCrossAttention 注意力矩阵部分为 (常数),仅 rfft/irfft 随 增长( )——大序列时优于标准 。
整体数据流 SVG:
§5.2 __init__ — 双套频率模式初始化
Decoder cross-attn FourierCrossAttention 初始化参数(来自 FEDformer.__init__):
python
decoder_cross_att = FourierCrossAttention(
in_channels=configs.d_model, # 16
out_channels=configs.d_model, # 16
seq_len_q=self.seq_len // 2 + self.pred_len, # 10
seq_len_kv=self.seq_len, # 12
modes=self.modes, # 4(toy)
mode_select_method=self.mode_select, # "random"
num_heads=configs.n_heads, # 8 ← 正确使用 num_heads!
)index_q(Q 来自 Decoder,seq_len_q=10):
actual_modes_q = min(4, 10//2=5) = 4;list(range(0, 5)) = [0,1,2,3,4] → shuffle → 取前 4 → sort,例如 index_q = [1, 2, 3, 4]。
index_kv(K 来自 Encoder,seq_len_kv=12):
actual_modes_kv = min(4, 12//2=6) = 4;list(range(0, 6)) = [0,1,2,3,4,5] → shuffle → 取前 4 → sort,例如 index_kv = [0, 2, 4, 5]。
双套 index 与 FourierBlock 的区别
FourierBlock 只有一套 index(因为 Q 和 K 的 seq_len 相同)。FourierCrossAttention 为 Q 和 K 维护独立的 index,各自从不同长度的频谱(
vs )中选取 个模式,然后通过 的频域注意力矩阵进行对齐。
权重初始化(toy:in_channels=out_channels=16, num_heads=8, actual_modes_q=4):
self.weights1 = nn.Parameter(scale * torch.rand(8, 2, 2, 4)) = (8, 2, 2, 4)。
self.weights2:同形 (8, 2, 2, 4)。
注意:权重第四维是 len(index_q)=4,不是 len(index_kv)——权重对应 Q 的模式数(输出空间大小)。
§5.3 forward — permute + 双路 rfft + 模式采样
python
B, L, H, E = q.shape
xq = q.permute(0, 2, 3, 1)
xk = k.permute(0, 2, 3, 1)
xv = v.permute(0, 2, 3, 1)B, L, H, E = q.shape 从 Q 提取维度:B=3, L=10(dec_len), H=8, E=2。
permute 后:
xq = (3, 10, 8, 2) → (3, 8, 2, 10),时间轴移到最后。
xk = (3, 12, 8, 2) → (3, 8, 2, 12)。
xv = (3, 12, 8, 2) → (3, 8, 2, 12)。
⚠️ xv 被计算但从未使用
xv完整地执行了 permute,但后续代码中xv从未出现——所有计算只用xq和xk。K 在 V 加权步骤(§5.6)中同时充当 V 角色(使用xk_ft_),传入的v张量被完全忽略。这意味着 FourierCrossAttention 在计算图中等价于
KQK注意力:,而非标准的 。
Q 路径——rfft + 模式采样:
python
xq_ft_ = torch.zeros(
B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
)
xq_ft = torch.fft.rfft(xq, dim=-1)
for i, j in enumerate(self.index_q):
if j >= xq_ft.shape[3]:
continue
xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]rfft(xq, dim=-1):(3, 8, 2, 10) → (3, 8, 2, 6) cfloat(
xq_ft_:预分配 (3, 8, 2, 4) cfloat,全零初始化。
采样循环(假设 index_q=[1,2,3,4]):
i=0, j=1:xq_ft_[:,:,:,0] = xq_ft[:,:,:,1](频率 ω=1 的系数)i=1, j=2:xq_ft_[:,:,:,1] = xq_ft[:,:,:,2]i=2, j=3:xq_ft_[:,:,:,2] = xq_ft[:,:,:,3]i=3, j=4:xq_ft_[:,:,:,3] = xq_ft[:,:,:,4]
xq_ft_.shape = (3, 8, 2, 4),第四维对应 4 个选出的 Q 频率模式。
K 路径——rfft + 模式采样:
python
xk_ft_ = torch.zeros(
B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
)
xk_ft = torch.fft.rfft(xk, dim=-1)
for i, j in enumerate(self.index_kv):
if j >= xk_ft.shape[3]:
continue
xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]rfft(xk, dim=-1):(3, 8, 2, 12) → (3, 8, 2, 7) cfloat(
xk_ft_:预分配 (3, 8, 2, 4) cfloat,全零初始化。
采样循环(假设 index_kv=[0,2,4,5]):
i=0, j=0:xk_ft_[:,:,:,0] = xk_ft[:,:,:,0](直流分量)i=1, j=2:xk_ft_[:,:,:,1] = xk_ft[:,:,:,2]i=2, j=4:xk_ft_[:,:,:,2] = xk_ft[:,:,:,4]i=3, j=5:xk_ft_[:,:,:,3] = xk_ft[:,:,:,5]
xk_ft_.shape = (3, 8, 2, 4),第四维对应 4 个选出的 K 频率模式。
toy 数值(xq_ft):设 Q 对应的 seasonal_init 中一条时序(head=0, E=0, batch=0)12步均值接近零,则 xq_ft[0,0,0,:] ≈ [small_real, complex1, complex2, complex3, complex4, complex5],xq_ft_[0,0,0,:] = [complex1, complex2, complex3, complex4]。
§5.4 forward — 频域注意力矩阵 xqk_ft
python
xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)xq_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, x=modes_q=4
xk_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, y=modes_kv=4
einsum "bhex,bhey->bhxy":对 e 维(E=2)收缩,输出 (3, 8, 4, 4)。
数学含义:
(复数乘法展开为 compl_mul1d 的 ac-bd / ad+bc 分解,等价于频域内积)
物理含义:对每对 (Q 模式
xqk_ft.shape = (3, 8, 4, 4) cfloat。
toy 数值:设 xq_ft_[0,0,:,0] = [0.3+0.1j, 0.2+0.05j](E=0和E=1的分量),xk_ft_[0,0,:,0] = [0.4+0.2j, 0.1+0.3j],则:
即
§5.5 forward — tanh 激活
python
if self.activation == "tanh":
xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
elif self.activation == "softmax":
xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))TFB 默认 activation="tanh",走第一分支。
xqk_ft.real.tanh():对
xqk_ft.imag.tanh():对虚部同样逐元素应用
torch.complex(...) 重新组装为复数张量,shape 不变:(3, 8, 4, 4) cfloat。
tanh vs softmax 的设计意图
标准 Attention 用 softmax 使注意力权重沿 K 维归一化为概率分布。频域中,softmax 作用在
abs(xqk_ft)上(取模),丢弃相位信息,只保留幅度权重(见 elif 分支)。tanh 方案保留了复数结构(实虚部分开压缩),保留相位信息,也就保留了频率成分的"方向"(正弦/余弦分量比例)。代码变量名中的
typo也暗示这是非正式实验:"{} actiation function is not implemented"— "actiation" 是 "activation" 的拼写错误(原始代码保留,不影响逻辑)。
toy 数值(续接 §5.4):xqk_ft[0,0,0,0] ≈ 0.105 + 0.165j,tanh 后:
即 tanh 激活后 xqk_ft[0,0,0,0] ≈ 0.1044 + 0.1641j(绝对值略有缩小,接近原值,因为输入已较小)。
§5.6 forward — V 加权 xqkv_ft(K 充当 V)
python
xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)xqk_ft shape (3, 8, 4, 4) → b=B, h=H, x=modes_q=4, y=modes_kv=4
xk_ft_ shape (3, 8, 2, 4) → b=B, h=H, e=E=2, y=modes_kv=4
einsum "bhxy,bhey->bhex":对 y 维(modes_kv=4)收缩,输出 (3, 8, 2, 4)。
数学含义:
物理含义:对每个 Q 模式
xqkv_ft.shape = (3, 8, 2, 4) cfloat
toy 数值(简化,只看 [0,0,0,0]):
设 xqk_ft[0,0,0,:] ≈ [0.1044+0.1641j, 0.05+0.08j, 0.12+0.09j, 0.07+0.11j](4 个 K 模式的注意力权重)。
xk_ft_[0,0,0,:] ≈ [0.4+0.2j, 0.3+0.15j, 0.25+0.1j, 0.35+0.2j](4 个 K 模式的 E=0 特征)。
§5.7 forward — 可学习权重投影 xqkvw
python
xqkvw = self.compl_mul1d(
"bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)
)xqkv_ft shape (3, 8, 2, 4) → b=B, h=H, e=E_in=2, x=modes_q=4
W = torch.complex(weights1, weights2) shape (8, 2, 2, 4) → h=H, e=E_in=2, o=E_out=2, x=modes_q=4
einsum "bhex,heox->bhox":对 e 维(E_in=2)收缩,输出 (3, 8, 2, 4)。
数学含义:
每个 Q 频率模式
xqkvw.shape = (3, 8, 2, 4) cfloat
与 FourierBlock 的区别:FourierBlock 的权重形状也是 (8, 2, 2, 4),einsum 也是对 E 维收缩。但 FourierBlock 的输入是直接从 Q 频谱采样的系数("bhi,hio->bho"),而这里输入是已经过 Q×K 注意力加权的结果。
§5.8 forward — scatter + irfft + 归一化
python
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
for i, j in enumerate(self.index_q):
if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
continue
out_ft[:, :, :, j] = xqkvw[:, :, :, i]
out = torch.fft.irfft(
out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
)
return (out, None)scatter 回频率轴:
out_ft = zeros(3, 8, 2, L//2+1) = zeros(3, 8, 2, 6) cfloat(q.shape[1],
逆向映射(scatter,与采样时方向相反):
采样时:xq_ft_[:,:,:,i] = xq_ft[:,:,:,j] (全谱 j → 密集 i)
scatter:out_ft[:,:,:,j] = xqkvw[:,:,:,i] (密集 i → 全谱 j)假设 index_q=[1,2,3,4]:
i=0, j=1:out_ft[:,:,:,1] = xqkvw[:,:,:,0]i=1, j=2:out_ft[:,:,:,2] = xqkvw[:,:,:,1]i=2, j=3:out_ft[:,:,:,3] = xqkvw[:,:,:,2]i=3, j=4:out_ft[:,:,:,4] = xqkvw[:,:,:,3]out_ft[:,:,:,0]= 0,out_ft[:,:,:,5]= 0(未选中频率保持零)
out_ft.shape = (3, 8, 2, 6) cfloat,仅 4/6 个频率槽被填充。
归一化:
out_ft / self.in_channels / self.out_channels = out_ft / 16 / 16 = out_ft / 256
FourierBlock 无此归一化;FourierCrossAttention 在 irfft 前额外除以 scale = 1/(in_channels * out_channels) 相呼应——两者组合使输出能量归一化。
irfft 还原时域:
irfft(out_ft_scaled, n=xq.size(-1)) = irfft(..., n=10):(3, 8, 2, 6) cfloat → (3, 8, 2, 10) real
返回 (out, None) = ((3, 8, 2, 10), None)。
AutoCorrelationLayer view quirk(与 FourierBlock 完全相同机制):
out.view(B, L, -1) = view(3, 10, -1) = view(3, 10, 16)。
总元素:
内存布局同样是 (B,H,E,L) 按 H-E-L 顺序存储,view 后每个"时间步"实际混合了来自不同 head 和时间步的信息。out_projection(Linear
最终输出:out_projection(out.view(3,10,-1)) → (3, 10, 16),即 Decoder cross-attention 的输出,继续流入第 2 次 series_decomp。
6. FourierCrossAttention vs FourierBlock 完整对比
| 维度 | FourierBlock(self-attn) | FourierCrossAttention(cross-attn) |
|---|---|---|
| Q 来源 | Encoder 或 Decoder 自身 | Decoder(dec_len=10) |
| K/V 来源 | 同 Q(self-attn),但忽略 | K 来自 Encoder(enc_len=12),v 忽略 |
| 频率选择 | 一套 index(seq_len 单一) | 双套 index(index_q / index_kv,不同 seq_len) |
| 核心操作 | ||
| 注意力矩阵 | 无 | |
| 权重维度 | 硬编码 8 ⚠️ | 正确使用 num_heads ✅ |
| 额外归一化 | 无 | irfft 前除以 |
| einsum 链 | 1 次 "bhi,hio→bho" | 3 次:"bhex,bhey→bhxy" / "bhxy,bhey→bhex" / "bhex,heox→bhox" |