FEDformer Debugging Parameters
§1 PyCharm Run Configuration
Script path: D:\1sudyta\1ai-self\aistyle\TFB\scripts\run_benchmark.py
Parameters:
--config-path rolling_forecast_config.json
--data-name-list ETTh1.csv
--strategy-args {"horizon": 4}
--model-name time_series_library.FEDformer
--model-hyper-params {"batch_size":3,"seq_len":12,"horizon":4,"d_model":16,"d_ff":32,"n_heads":8,"e_layers":2,"d_layers":1,"factor":1,"moving_avg":7,"output_attention":0,"num_epochs":1,"patience":3,"lr":0.001,"loss":"MSE","dropout":0.0,"embed":"timeF"}
--adapter transformer_adapter
--deterministic full
--gpus 0
--num-workers 1
--timeout 60000
--save-path debug/FEDformer
Working directory: D:\1sudyta\1ai-self\aistyle\TFB
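The two JSON-valued parameters are easy to mistype by hand. A minimal sketch, using only the standard library, that generates the same argument list from a Python dict. The helper name `build_args` is hypothetical and not part of TFB; the flag names and values are copied from the list above.

```python
import json
import shlex

# Hyper-parameters exactly as listed in the Parameters field above.
HYPER_PARAMS = {
    "batch_size": 3, "seq_len": 12, "horizon": 4, "d_model": 16, "d_ff": 32,
    "n_heads": 8, "e_layers": 2, "d_layers": 1, "factor": 1, "moving_avg": 7,
    "output_attention": 0, "num_epochs": 1, "patience": 3, "lr": 0.001,
    "loss": "MSE", "dropout": 0.0, "embed": "timeF",
}


def build_args(hyper_params, horizon):
    """Hypothetical helper: return the run_benchmark.py argument list."""
    return [
        "--config-path", "rolling_forecast_config.json",
        "--data-name-list", "ETTh1.csv",
        "--strategy-args", json.dumps({"horizon": horizon}),
        "--model-name", "time_series_library.FEDformer",
        # Compact separators reproduce the space-free JSON used above.
        "--model-hyper-params", json.dumps(hyper_params, separators=(",", ":")),
        "--adapter", "transformer_adapter",
        "--deterministic", "full",
        "--gpus", "0",
        "--num-workers", "1",
        "--timeout", "60000",
        "--save-path", "debug/FEDformer",
    ]


args = build_args(HYPER_PARAMS, horizon=4)
print(shlex.join(args))  # POSIX-shell quoting
```

`shlex.join` produces POSIX-shell quoting; the PyCharm Parameters field and Windows shells may need different quoting, so treat the printed line as a starting point.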
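n_heads must be 8
The first dimension of FourierBlock.weights1/weights2 is hard-coded as the literal 8 rather than taken from the n_heads argument. If n_heads ≠ 8, the h dimension in compl_mul1d's einsum("bhi,hio->bho") does not match and the call raises an error. This parameter therefore cannot be tuned.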
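The mismatch can be reproduced outside the model. A minimal sketch using NumPy's `einsum` with the same subscript string; the real code uses complex PyTorch tensors and would raise a `RuntimeError` rather than NumPy's `ValueError`. The helper `try_heads` is hypothetical.

```python
import numpy as np

B, D_MODEL = 3, 16
# Weight for a single frequency: the first dimension is the literal 8.
weights = np.random.rand(8, D_MODEL // 8, D_MODEL // 8)


def compl_mul1d(x, w):
    # Same contraction as in the text: batch b, head h, in-channel i, out-channel o.
    return np.einsum("bhi,hio->bho", x, w)


def try_heads(n_heads):
    """Return the output shape, or an error label if the einsum fails."""
    x = np.random.rand(B, n_heads, D_MODEL // n_heads)  # one frequency slice (B, H, E)
    try:
        return compl_mul1d(x, weights).shape
    except ValueError as err:
        return f"error: {type(err).__name__}"


print(try_heads(8))  # h matches the hard-coded 8
print(try_heads(4))  # h = 4 against 8: the contraction is rejected
```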
§2 VSCode launch.json
json
{
    "name": "FEDformer Debug",
    "type": "debugpy",
    "request": "launch",
    "program": "${workspaceFolder}/scripts/run_benchmark.py",
    "args": [
        "--config-path", "rolling_forecast_config.json",
        "--data-name-list", "ETTh1.csv",
        "--strategy-args", "{\"horizon\": 4}",
        "--model-name", "time_series_library.FEDformer",
        "--model-hyper-params", "{\"batch_size\":3,\"seq_len\":12,\"horizon\":4,\"d_model\":16,\"d_ff\":32,\"n_heads\":8,\"e_layers\":2,\"d_layers\":1,\"factor\":1,\"moving_avg\":7,\"output_attention\":0,\"num_epochs\":1,\"patience\":3,\"lr\":0.001,\"loss\":\"MSE\",\"dropout\":0.0,\"embed\":\"timeF\"}",
        "--adapter", "transformer_adapter",
        "--deterministic", "full",
        "--gpus", "0",
        "--num-workers", "1",
        "--timeout", "60000",
        "--save-path", "debug/FEDformer"
    ],
    "cwd": "${workspaceFolder}",
    "console": "integratedTerminal"
}
§3 First Principles Behind the Parameter Choices
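| Parameter | Value | Reason |
|---|---|---|
| data-name-list | ETTh1.csv | N=5 variables; matches the toy parameter enc_in=5, which is inferred automatically |
| batch_size | 3 | Toy parameter B=3; numerically distinct from other dimensions such as d_keys=2 and modes=4 |
| seq_len | 12 | Differs from d_model, d_keys, and pred_len |
| horizon | 4 | = pred_len; label_len is set automatically to seq_len//2 = 6 |
| d_model | 16 | Small enough; with n_heads=8, d_keys = E = 2 |
| d_ff | 32 | = d_model × 2 |
| n_heads | 8 (required) | The first dimension of FourierBlock.weights1 is hard-coded to 8; changing it triggers an einsum dimension error |
| moving_avg | 7 | Toy parameter; odd (required for the AvgPool1d kernel size); must be < seq_len=12 |
| e_layers | 2 | Covers the two-layer FourierBlock loop in the Encoder |
| d_layers | 1 | A single DecoderLayer, to keep things simple |
| output_attention | 0 | FourierBlock and FourierCrossAttention both return (out, None); the flag has no effect but stays compatible with the interface |
| dropout | 0.0 | Disables randomness |
| num_epochs | 1 | A single pass through the forward path |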
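Several values in the table are derived from the command-line parameters rather than set directly. A minimal sketch that recomputes them with the formulas stated above; the function name `derived_dims` and the dictionary keys are illustrative and do not appear in TFB.

```python
def derived_dims(seq_len, horizon, d_model, n_heads, moving_avg):
    """Recompute the derived sizes quoted in the table (illustrative helper)."""
    # Constraints named in the table.
    assert moving_avg % 2 == 1, "moving_avg should be odd"
    assert moving_avg < seq_len, "moving_avg should be smaller than seq_len"
    assert d_model % n_heads == 0, "d_model should be divisible by n_heads"

    label_len = seq_len // 2
    pred_len = horizon
    return {
        "pred_len": pred_len,
        "label_len": label_len,
        "dec_len": label_len + pred_len,     # decoder input length
        "d_keys": d_model // n_heads,        # per-head width E
        "decomp_pad": (moving_avg - 1) // 2, # padding on each side in series_decomp
    }


dims = derived_dims(seq_len=12, horizon=4, d_model=16, n_heads=8, moving_avg=7)
print(dims)
```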
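Three non-configurable parameters
The following three parameters are Python default arguments of FEDformer.__init__. _init_model() never passes them in, so they cannot be changed from the command line:
| Parameter | Hard-coded default | Effective value (seq_len=12) |
|---|---|---|
| version | "fourier" | FourierBlock path (not the Wavelets path) |
| mode_select | "random" | Randomly selects M frequencies (not low-frequency first) |
| modes | 32 | min(32, seq_len//2) = min(32, 6) = 6 (note: the toy document uses modes=4 as a further simplification) |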
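The clamp min(32, seq_len//2) can be sketched as below. This is modeled on the behavior described in the table, not copied from the library: the real get_frequency_modes may differ in details such as its default argument values or the random number generator it uses.

```python
import random


def get_frequency_modes(seq_len, modes=32, mode_select_method="random"):
    """Sketch of the mode selection: clamp, pick, then sort the indices."""
    modes = min(modes, seq_len // 2)      # effective number of modes
    if mode_select_method == "random":
        index = list(range(seq_len // 2))
        random.shuffle(index)
        index = index[:modes]             # keep a random subset
    else:
        index = list(range(modes))        # lowest frequencies first
    index.sort()
    return index


print(get_frequency_modes(12))  # encoder side, seq_len = 12
print(len(get_frequency_modes(10)))  # a length-10 sequence gives 5 modes
```

Under this sketch, seq_len=12 makes the clamp equal to the number of candidates, so the "random" branch keeps all six indices and the selection is effectively deterministic.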
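§4 Quick Reference: What Each Parameter Means
| Parameter | Access path in code | What it determines |
|---|---|---|
| seq_len | configs.seq_len | Encoder input length; upper bound for FourierBlock index sampling = seq_len//2 = 6 |
| pred_len | configs.pred_len (set automatically from horizon) | Number of times the mean is repeated in forecast(); zero-padding length of seasonal_init |
| label_len | Inferred automatically = seq_len//2 = 6 | History segment of dec_input; slice start for trend_init/seasonal_init |
| enc_in | Inferred automatically from the number of data columns | Number of input channels of the Embedding; output dimension c_out |
| d_model | configs.d_model | Output dimension of DataEmbedding_wo_pos; linear projections in AutoCorrelationLayer |
| n_heads | configs.n_heads | Number of heads in AutoCorrelationLayer; num_heads of FourierCrossAttention |
| moving_avg | configs.moving_avg | kernel_size of the AvgPool1d in series_decomp; padding = (moving_avg-1)//2 = 3 |
| modes | Python default 32 (not configurable) | Sampling upper bound in get_frequency_modes; effective value = min(32, seq_len//2) |
| version | Python default "fourier" (not configurable) | Selects the FourierBlock path or the Wavelets path |
| mode_select | Python default "random" (not configurable) | Frequency selection strategy; "random" introduces a regularization effect |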
§5 Loop Coverage Check
| Loop / branch | How it is covered | Covered by the current parameters? |
|---|---|---|
| Encoder loop (e_layers=2) | e_layers=2 | ✅ Runs EncoderLayer twice (each with FourierBlock + 2×decomp) |
| Decoder loop (d_layers=1) | d_layers=1 | ✅ Runs DecoderLayer once (FourierBlock + FourierCrossAttention + 3×decomp) |
| FourierBlock Encoder self-attn | Built into EncoderLayer | ✅ |
| FourierBlock Decoder self-attn | Built into DecoderLayer | ✅ |
| FourierCrossAttention Decoder cross-attn | Built into DecoderLayer | ✅ |
| series_decomp padding (front + end) | moving_avg=7 → pad=3 | ✅ |
| trend_init mean fill (pred_len segment) | pred_len=4 > 0 | ✅ |
| seasonal_init zero padding (0,0,0,4) | pred_len=4 | ✅ |
| FourierBlock for wi, i in enumerate(self.index) | modes=6 (effective value) | ✅ |
| FourierCrossAttention mode sampling (index_q and index_kv sampled separately) | seq_len_q=10, seq_len_kv=12 | ✅ Two different indices |
| forecast() path, reached from forward() | task_name="short_term_forecast" | ✅ |
§6 Shape Tracing Quick Reference
ETTh1 (N=5), batch_size=3, seq_len=12, pred_len=4, label_len=6:
| Step | Shape |
|---|---|
| x_enc input | (3, 12, 5) |
| series_decomp(x_enc) seasonal | (3, 12, 5) |
| series_decomp(x_enc) trend | (3, 12, 5) |
| x_enc.mean(dim=1) | (3, 5) |
| mean.unsqueeze(1).repeat(1,4,1) | (3, 4, 5) |
| trend_init = cat([trend[:, -6:, :], mean], dim=1) | (3, 10, 5) |
| seasonal_init = F.pad(seasonal[:, -6:, :], (0,0,0,4)) | (3, 10, 5) |
| enc_embedding output | (3, 12, 16) |
| Encoder output enc_out | (3, 12, 16) |
| dec_embedding output | (3, 10, 16) |
| Decoder seasonal_part | (3, 10, 5) |
| Decoder trend_part | (3, 10, 5) |
| dec_out = trend_part + seasonal_part | (3, 10, 5) |
| [:, -4:, :] slice | (3, 4, 5) |
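The first rows of this trace can be checked without loading the model. A minimal NumPy sketch of the decomposition and the two decoder initializations, assuming series_decomp is a moving average with edge-replicated padding of (kernel_size-1)//2 on each side; the real module is implemented with PyTorch's AvgPool1d.

```python
import numpy as np


def series_decomp(x, kernel_size):
    """Moving-average decomposition along the time axis; x is (B, L, N)."""
    pad = (kernel_size - 1) // 2
    front = np.repeat(x[:, :1, :], pad, axis=1)   # replicate first time step
    end = np.repeat(x[:, -1:, :], pad, axis=1)    # replicate last time step
    padded = np.concatenate([front, x, end], axis=1)
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel_size, axis=1)
    trend = windows.mean(axis=-1)                 # (B, L, N)
    return x - trend, trend                       # seasonal, trend


B, SEQ_LEN, N, PRED_LEN, LABEL_LEN = 3, 12, 5, 4, 6
x_enc = np.random.rand(B, SEQ_LEN, N)

seasonal, trend = series_decomp(x_enc, kernel_size=7)
mean = np.repeat(x_enc.mean(axis=1, keepdims=True), PRED_LEN, axis=1)
trend_init = np.concatenate([trend[:, -LABEL_LEN:, :], mean], axis=1)
# Zero-pad the time axis at the end, like F.pad(..., (0, 0, 0, pred_len)).
seasonal_init = np.pad(seasonal[:, -LABEL_LEN:, :], ((0, 0), (0, PRED_LEN), (0, 0)))

for name, arr in [("seasonal", seasonal), ("trend", trend), ("mean", mean),
                  ("trend_init", trend_init), ("seasonal_init", seasonal_init)]:
    print(name, arr.shape)
```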
Inside FourierBlock (Encoder self-attn, seq_len=12):
| Step | Shape |
|---|---|
| q input to FourierBlock | (3, 12, 8, 2) |
| q.permute(0,2,3,1) | (3, 8, 2, 12) |
| rfft(x, dim=-1) | (3, 8, 2, 7) (complex) |
| out_ft (initialized to zeros) | (3, 8, 2, 7) (complex) |
| After compl_mul1d fills M=6 frequencies | (3, 8, 2, 7) (complex, remaining entries zero) |
| irfft(out_ft, n=12) → returned | (3, 8, 2, 12) |
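The same shapes can be reproduced with NumPy's FFT routines. A minimal sketch of the steps in the table, assuming random complex weights of shape (8, 2, 2, M) and the index [0..5]; the real block stores its weights as learnable PyTorch parameters.

```python
import numpy as np

B, L, H, E, M = 3, 12, 8, 2, 6
rng = np.random.default_rng(0)

q = rng.standard_normal((B, L, H, E))
# Stand-in for the learnable weights: one (E x E) matrix per head and mode.
weights = rng.standard_normal((H, E, E, M)) + 1j * rng.standard_normal((H, E, E, M))
index = list(range(M))                      # selected frequency indices

x = q.transpose(0, 2, 3, 1)                 # (3, 8, 2, 12)
x_ft = np.fft.rfft(x, axis=-1)              # (3, 8, 2, 7), complex
out_ft = np.zeros((B, H, E, L // 2 + 1), dtype=complex)
for wi, i in enumerate(index):
    # Per-frequency channel mixing, written into slot wi as in the loop above.
    out_ft[..., wi] = np.einsum("bhi,hio->bho", x_ft[..., i], weights[..., wi])
out = np.fft.irfft(out_ft, n=L, axis=-1)    # (3, 8, 2, 12), real

print(x.shape, x_ft.shape, out_ft.shape, out.shape)
```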
Inside FourierCrossAttention (Decoder cross-attn):
| Step | Shape |
|---|---|
| q input (from the decoder) | (3, 10, 8, 2) |
| k input (from the encoder) | (3, 12, 8, 2) |
| xq.permute(0,2,3,1) | (3, 8, 2, 10) |
| xk.permute(0,2,3,1) | (3, 8, 2, 12) |
| rfft(xq) | (3, 8, 2, 6) (complex) |
| rfft(xk) | (3, 8, 2, 7) (complex) |
| xq_ft_ (M_q sampled modes) | (3, 8, 2, 5) (M_q = min(32, 10//2) = 5) |
| xk_ft_ (M_kv sampled modes) | (3, 8, 2, 6) (M_kv = min(32, 12//2) = 6) |
| xqk_ft = compl_mul1d("bhex,bhey->bhxy") | (3, 8, 5, 6) |
| tanh(real) + j·tanh(imag) | (3, 8, 5, 6) |
| xqkv_ft = compl_mul1d("bhxy,bhey->bhex") | (3, 8, 2, 5) |
| xqkvw = compl_mul1d("bhex,heox->bhox") | (3, 8, 2, 5) |
| out_ft (after scatter) | (3, 8, 2, 6) (complex) |
| irfft(out_ft/256, n=10) | (3, 8, 2, 10) |
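A NumPy sketch of the cross-attention trace. It assumes that index_q and index_kv each hold min(32, L//2) indices for their own sequence length (5 for the query side, 6 for the key side), that the weights' last dimension equals len(index_q), and that the value tensor is never used. Verify these assumptions against the actual FourierCorrelation.py at the breakpoints in §7.

```python
import numpy as np

B, H, E = 3, 8, 2
L_Q, L_KV = 10, 12
D_MODEL = 16
rng = np.random.default_rng(0)

index_q = list(range(min(32, L_Q // 2)))     # assumed: 5 modes on the query side
index_kv = list(range(min(32, L_KV // 2)))   # assumed: 6 modes on the key side

xq = rng.standard_normal((B, L_Q, H, E)).transpose(0, 2, 3, 1)    # (3, 8, 2, 10)
xk = rng.standard_normal((B, L_KV, H, E)).transpose(0, 2, 3, 1)   # (3, 8, 2, 12)

xq_ft_ = np.fft.rfft(xq, axis=-1)[..., index_q]    # (3, 8, 2, 5)
xk_ft_ = np.fft.rfft(xk, axis=-1)[..., index_kv]   # (3, 8, 2, 6)

xqk_ft = np.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)            # (3, 8, 5, 6)
xqk_ft = np.tanh(xqk_ft.real) + 1j * np.tanh(xqk_ft.imag)        # complex tanh activation
xqkv_ft = np.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)           # (3, 8, 2, 5)

shape_w = (H, E, E, len(index_q))
weights = rng.standard_normal(shape_w) + 1j * rng.standard_normal(shape_w)
xqkvw = np.einsum("bhex,heox->bhox", xqkv_ft, weights)           # (3, 8, 2, 5)

out_ft = np.zeros((B, H, E, L_Q // 2 + 1), dtype=complex)        # (3, 8, 2, 6)
out_ft[..., index_q] = xqkvw                                     # scatter back
out = np.fft.irfft(out_ft / (D_MODEL * D_MODEL), n=L_Q, axis=-1) # (3, 8, 2, 10)

print(xqk_ft.shape, xqkv_ft.shape, xqkvw.shape, out_ft.shape, out.shape)
```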
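Note on the effective value of modes
The toy document uses modes=4 to keep the trace short. In actual debugging the effective value is min(32, seq_len//2) for each sequence length: M = min(32, 12//2) = 6 for the Encoder FourierBlock and for index_kv, and M_q = min(32, 10//2) = 5 for index_q, so xqk_ft becomes (3,8,5,6). To reproduce the toy document's (3,8,4,4) exactly, temporarily patch get_frequency_modes in the source so that it returns index[:4].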
§7 Key Breakpoints
| # | File | Location | Purpose |
|---|---|---|---|
| 1 | adapters_for_transformers.py | First line of _process | Confirm x_enc/x_dec shapes (3,12,5)/(3,10,5) |
| 2 | FEDformer.py | First line of forecast | Enter the main chain |
| 3 | FEDformer.py | After seasonal_init, trend_init = self.decomp(x_enc) | Verify the decomp outputs (3,12,5)×2 |
| 4 | FEDformer.py | After trend_init = torch.cat([...]) | Verify trend_init (3,10,5) |
| 5 | FEDformer.py | After seasonal_init = F.pad(...) | Verify seasonal_init (3,10,5) |
| 6 | FEDformer.py | After enc_out, attns = self.encoder(...) | Verify enc_out (3,12,16) |
| 7 | FourierCorrelation.py | First line of FourierBlock.forward | Verify q shape (3,12,8,2) |
| 8 | FourierCorrelation.py | After x_ft = torch.fft.rfft(x, dim=-1) | Verify x_ft shape (3,8,2,7) and that it is a complex tensor |
| 9 | FourierCorrelation.py | Inside for wi, i in enumerate(self.index) | Check the length of self.index (the effective number of modes) |
| 10 | FourierCorrelation.py | First line of FourierCrossAttention.forward | Verify q (3,10,8,2) and k (3,12,8,2) |
| 11 | FourierCorrelation.py | After xqk_ft = self.compl_mul1d(...) | Verify xqk_ft shape (3,8,M_q,M_kv) |
| 12 | FourierCorrelation.py | After xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", ...) | Verify that xv is never referenced (the xv variable can be inspected on this line) |
| 13 | FourierCorrelation.py | After out = torch.fft.irfft(out_ft / ...) | Verify out shape (3,8,2,10) |
| 14 | Autoformer_EncDec.py | Decoder.forward, at trend = trend + residual_trend | Verify trend accumulation (shape stays (3,10,5)) |
| 15 | FEDformer.py | After dec_out = trend_part + seasonal_part | Verify the output (3,10,5) |
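Comparing shapes by eye at fifteen breakpoints is error-prone. A small sketch of a helper that can be pasted into the debugger console and called at each stop; the name `check_shape` and the `EXPECTED` dictionary are illustrative, and the expected values are the ones quoted in the table for the parameters in §1. It works with any object that exposes a `.shape` attribute.

```python
# Expected shapes taken from the breakpoint table (batch 3, seq_len 12, N=5).
EXPECTED = {
    "x_enc": (3, 12, 5),
    "x_dec": (3, 10, 5),
    "trend_init": (3, 10, 5),
    "seasonal_init": (3, 10, 5),
    "enc_out": (3, 12, 16),
    "x_ft": (3, 8, 2, 7),
    "dec_out": (3, 10, 5),
}


def check_shape(name, tensor):
    """Return a one-line report comparing tensor.shape with the expected shape."""
    actual = tuple(tensor.shape)
    expected = EXPECTED[name]
    status = "OK" if actual == expected else f"MISMATCH, expected {expected}"
    return f"{name}: {actual} {status}"
```

For example, evaluating `check_shape("enc_out", enc_out)` at breakpoint 6 reports whether the Encoder output has the expected shape.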
§8 Comparison with the Autoformer Debugging Parameters
| Dimension | Autoformer | FEDformer |
|---|---|---|
| adapter | transformer_adapter | transformer_adapter |
| dataset | ETTh1.csv (N=5) | ETTh1.csv (N=5) |
| seq_len | 12 | 12 |
| pred_len | 4 | 4 |
| batch_size | 2 | 3 (toy B=3) |
| d_model | 8 | 16 |
| n_heads | 4 (tunable) | 8 (required) |
| moving_avg | 3 | 7 |
| Key non-configurable parameters | None | version / mode_select / modes |
| Attention mechanism | AutoCorrelation (time-domain cross-correlation) | FourierBlock (frequency-domain linear transform) |
| cross-attn | AutoCorrelation | FourierCrossAttention (Q×K frequency-domain attention) |
| Is v used? | ✅ Both K and V are used | ❌ v is ignored entirely in cross-attn |
| Shape inflection point | AutoCorrelation irfft (2,4,2,12) | FourierBlock irfft (3,8,2,12); FCA scatter→irfft (3,8,2,10) |
| Hard-coded n_heads | None | FourierBlock weights dimension fixed at 8 |