Appearance
Layer 2A — FourierBlock(频域 Self-Attention)
1. 在父层中的位置
forecast() 中 Encoder 和 Decoder self-attention 的每个 AutoCorrelationLayer 内嵌了 FourierBlock 作为 inner_correlation。AutoCorrelationLayer 完成 Q/K/V 的线性投影和多头拆分后,把 (q, k, v, attn_mask) 传入 FourierBlock.forward()。FourierBlock 忽略 k、v、mask,只使用 q 做频域变换。
2. I/O 接口定义
AutoCorrelationLayer 接口(外层包装):
| 参数 | Shape | 含义 |
|---|---|---|
queries | (3, 12, 16) | Encoder self-attn:来自 enc_out 经 embedding |
keys | (3, 12, 16) | 同上(FourierBlock 实际忽略) |
values | (3, 12, 16) | 同上(FourierBlock 实际忽略) |
| 输出 | (3, 12, 16) | 注意力输出,shape 不变 |
FourierBlock.forward 接口(投影后):
| 参数 | Shape | 含义 |
|---|---|---|
q | (3, 12, 8, 2) | (B, L, H, E),多头拆分后 |
k | (3, 12, 8, 2) | 传入但不使用 |
v | (3, 12, 8, 2) | 传入但不使用 |
输出 x | (3, 8, 2, 12) | (B, H, E, L)——注意与 AutoCorrelationLayer 期望形状不同! |
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
class AutoCorrelationLayer(nn.Module):
def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
super(AutoCorrelationLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_correlation = correlation
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_correlation(queries, keys, values, attn_mask)
out = out.view(B, L, -1)
return self.out_projection(out), attn
class FourierBlock(nn.Module):
def __init__(
self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random"
):
super(FourierBlock, self).__init__()
print("fourier enhanced block used!")
self.index = get_frequency_modes(
seq_len, modes=modes, mode_select_method=mode_select_method
)
print("modes={}, index={}".format(modes, self.index))
self.scale = 1 / (in_channels * out_channels)
self.weights1 = nn.Parameter(
self.scale * torch.rand(
8, in_channels // 8, out_channels // 8, len(self.index),
dtype=torch.float,
)
)
self.weights2 = nn.Parameter(
self.scale * torch.rand(
8, in_channels // 8, out_channels // 8, len(self.index),
dtype=torch.float,
)
)
def compl_mul1d(self, order, x, weights):
x_flag = True
w_flag = True
if not torch.is_complex(x):
x_flag = False
x = torch.complex(x, torch.zeros_like(x).to(x.device))
if not torch.is_complex(weights):
w_flag = False
weights = torch.complex(
weights, torch.zeros_like(weights).to(weights.device)
)
if x_flag or w_flag:
return torch.complex(
torch.einsum(order, x.real, weights.real)
- torch.einsum(order, x.imag, weights.imag),
torch.einsum(order, x.real, weights.imag)
+ torch.einsum(order, x.imag, weights.real),
)
else:
return torch.einsum(order, x.real, weights.real)
def forward(self, q, k, v, mask):
B, L, H, E = q.shape
x = q.permute(0, 2, 3, 1)
x_ft = torch.fft.rfft(x, dim=-1)
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
for wi, i in enumerate(self.index):
if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
continue
out_ft[:, :, :, wi] = self.compl_mul1d(
"bhi,hio->bho",
x_ft[:, :, :, i],
torch.complex(self.weights1, self.weights2)[:, :, :, wi],
)
x = torch.fft.irfft(out_ft, n=x.size(-1))
return (x, None)§5.1 宏观逻辑
核心设计意图:时间序列的有效信息集中在少数 M 个频率上。与其在所有时间步做
标准 Self-Attention:
FourierBlock:
用小例子(B=1, L=4, H=8, E=2, modes=2):
q (1,4,8,2)
→ permute → (1,8,2,4)
→ rfft(dim=-1) → (1,8,2,3) cfloat [3 = 4//2+1]
→ index = [1, 2](随机选2个模式)
→ out_ft = zeros(1,8,2,3) cfloat
wi=0, i=1: x_ft[:,:,:,1] (1,8,2) × W[:,:,:,0] (8,2,2) → (1,8,2) 填入 out_ft[:,:,:,0]
wi=1, i=2: x_ft[:,:,:,2] (1,8,2) × W[:,:,:,1] (8,2,2) → (1,8,2) 填入 out_ft[:,:,:,1]
out_ft[:,:,:,2] = 0(未选中的频率保持0)
→ irfft(out_ft, n=4) → (1,8,2,4) real完整 shape 变化链(Encoder self-attn,toy 全局参数):
(3,12,16) → Linear×3+view → (3,12,8,2) → permute → (3,8,2,12) → rfft → (3,8,2,7) cfloat → M=4频率线性变换 → out_ft(3,8,2,7) cfloat → irfft(n=12) → (3,8,2,12) → view(3,12,16) [quirk] → out_projection → (3,12,16)
注意力复杂度分析
FourierBlock 中:rfft 复杂度
;M=4 次复数线性变换 ;irfft 同样 。 对比:标准 FullAttention
;FourierBlock 约 (toy 参数下反而更大,因为 L 很小;L 大时 FourierBlock 优势显现)。M 固定时,大 L 下 FourierBlock 为 (rfft 主导),好于 。
架构图:
§5.2 __init__ — 频率模式选择与权重初始化
python
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
modes = min(modes, seq_len // 2)
if mode_select_method == "random":
index = list(range(0, seq_len // 2))
np.random.shuffle(index)
index = index[:modes]
else:
index = list(range(0, modes))
index.sort()
return indexEncoder FourierBlock 初始化时:seq_len=12, modes=32(用户默认), mode_select_method="random"。
actual_modes = min(32, 12//2=6) = 6(用户 modes=32 被截断为 6!实际选取的模式数为 6,而非 32。在我们的 toy 参数 modes=4 时,actual_modes = min(4, 6) = 4)。
index = list(range(0, 6)) → shuffle → 取前 4 → sort → self.index 例如 [1, 2, 4, 5](每次运行可能不同)。
权重初始化(toy:in_channels=out_channels=16,actual_modes=4):
self.weights1: shape (8, 16//8, 16//8, 4) = (8, 2, 2, 4)
self.weights2: shape (8, 2, 2, 4)(虚部参数,与 weights1 同形)
⚠️ 硬编码 n_heads=8
weights1第一维写死为 8,而不是n_heads。因此:
n_heads=8:queries.shape = (B,L,8,2),einsum 中 h=8 匹配,✅ 正常n_heads=4:queries.shape = (B,L,4,2),einsum 中 h=4 ≠ 8,❌ RuntimeError正确写法应为
torch.rand(n_heads, in_channels//n_heads, ...)但源码硬编码了 8。FourierCrossAttention修正了此问题,使用num_heads参数。
Decoder self-attn FourierBlock 的 seq_len 不同:seq_len_dec = seq_len//2 + pred_len = 6+4 = 10,故 actual_modes_dec = min(4, 10//2=5) = 4,rfft 输出 10//2+1=6 个频率,index 从 [0..4] 中选 4 个。
§5.3 forward — permute + rfft
python
B, L, H, E = q.shape
x = q.permute(0, 2, 3, 1)
x_ft = torch.fft.rfft(x, dim=-1)toy 全局参数(Encoder self-attn):B=3, L=12, H=8, E=2。
q.permute(0, 2, 3, 1) 把 L 轴移到最后,使 rfft 沿时间轴做变换。
shape 变化:(3, 12, 8, 2) → (3, 8, 2, 12) — H 和 E 成为"特征轴",L 成为"信号轴"。
rfft(x, dim=-1) 对最后一维(L=12)做实数 FFT,输出
shape 变化:(3, 8, 2, 12) real → (3, 8, 2, 7) cfloat
toy 数值:x_ft[0,0,0,:] 是 batch=0, head=0, enc_var=0 的 7 个复数频率系数。频率 0 对应直流(均值),频率 6 对应奈奎斯特频率(最高频)。
§5.4 forward — M 频率线性变换(核心)
python
out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
for wi, i in enumerate(self.index):
if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
continue
out_ft[:, :, :, wi] = self.compl_mul1d(
"bhi,hio->bho",
x_ft[:, :, :, i],
torch.complex(self.weights1, self.weights2)[:, :, :, wi],
)out_ft = zeros(3, 8, 2, 7) cfloat — 全零初始化,未选中的频率槽保持 0(irfft 后等于没有这些频率成分)。
权重 W:torch.complex(weights1, weights2) → shape (8, 2, 2, 4) cfloat
逐模式变换(假设 index=[1,2,4,5]):
wi=0, i=1:
x_ft[:,:,:,1]→ shape(3, 8, 2) cfloat— 频率 ω₁=1 的所有 batch/head/channel 系数W[:,:,:,0]→ shape(8, 2, 2) cfloat— 第 0 个模式的权重矩阵compl_mul1d("bhi,hio->bho", ...)→ 复数 einsum:(3,8,2) × (8,2,2) → (3,8,2)
物理含义:对频率 ω₁ 处的特征向量
out_ft[:,:,:,0] 被填入结果,其余 out_ft[:,:,:,1:7] 中只有对应选中频率的槽被填,其余保持 0。
compl_mul1d 复数乘法原理:
python
real = einsum(order, x.real, w.real) - einsum(order, x.imag, w.imag) # ac - bd
imag = einsum(order, x.real, w.imag) + einsum(order, x.imag, w.real) # ad + bc
return torch.complex(real, imag)早期 PyTorch 版本对复数 einsum 支持不稳定,这里手动分解实虚部实现复数矩阵乘法。
toy 数值:假设 x_ft[0,0,0,1] = 0.3+0.2j(频率1,batch0,head0,E0),W[0,0,0,0] = 0.1+0.05j(第0个输出模式,head0,E_in=0→E_out=0): real = 0.3×0.1 - 0.2×0.05 = 0.03-0.01 = 0.02 imag = 0.3×0.05 + 0.2×0.1 = 0.015+0.02 = 0.035 out_ft[0,0,0,0] 的 (0,0,0) 分量贡献 = 0.02+0.035j(还需加 E_in=1 的贡献)。
§5.5 forward — irfft 还原时域
python
x = torch.fft.irfft(out_ft, n=x.size(-1))
return (x, None)x.size(-1) 此时是 x(permute 后的张量)的最后一维大小 = L = 12(permute 后 x shape 已是 (3,8,2,12),x.size(-1)=12)。
irfft(out_ft, n=12) 把 (3,8,2,7) cfloat 还原为 (3,8,2,12) real(利用实信号共轭对称性,n=12 指定输出长度)。
返回 (x, None) = ((3,8,2,12), None)。
⚠️ AutoCorrelationLayer 的 shape quirk
AutoCorrelationLayer.forward()接到返回值后执行:pythonout = out.view(B, L, -1)FourierBlock 返回
(3, 8, 2, 12)= (B, H, E, L),总元素 = 3×8×2×12 = 576。view(3, 12, -1)=view(3, 12, 16)也是 576 个元素,reshape 不报错。但内存布局:原始 (B,H,E,L) 按 H-E-L 顺序存储。
view(B,L,HE)把前 16 个元素分配给 L=0,接下来 16 个给 L=1……这 16 个元素在原始张量中是 H=0,E=0,L=0 到 H=0,E=1,L=3(非完整 L 切片)。结果:每个"时间步"位置混合了来自不同 head 和不同时间步的信息。out_projection 学习从这个混合表示中提取有用特征,模型仍然收敛。这是已发布代码中的已知布局语义偏差,文档记录但不改动。
§5.6 AutoCorrelationLayer — out_projection 还原
python
out = out.view(B, L, -1)
return self.out_projection(out), attnout.view(3, 12, -1) → (3, 12, 16) [含 shape quirk]
out_projection:Linear(d_values×n_heads=2×8=16, d_model=16) → (3, 12, 16)
形状恢复为输入 shape,可直接加入残差 + 后续 DecomP + FFN。
6. FourierBlock vs AutoCorrelation 对比
| 维度 | AutoCorrelation(Autoformer) | FourierBlock(FEDformer) |
|---|---|---|
| 频域使用方式 | Q·K^* → IFFT → 时延相关曲线 → top-k lag 时移 V | Q rfft → M 频率独立复数线性变换 → irfft |
| 输出语义 | 时延加权后的 V(全时域聚合) | M 个频率子空间投影到时域 |
| K/V 的使用 | K 用于 IFFT 相关,V 用于时移加权 | K/V 完全忽略,仅用 Q |
| 学习的结构 | 哪些时延最重要(top-k lag) | 每个选中频率的线性变换矩阵 |
| 复杂度 |