Appearance
FITS 总览
ICLR 2024 · Han et al. · Frequency Interpolation Time Series Forecasting
1. 论文动机
已有方法的两个痛点:
参数量爆炸 Transformer 系模型动辄几百万参数;DLinear 在多变量时(individual=True)每个通道一个权重矩阵,参数量 =
,seq_len 大时也不轻。FITS 目标是把整个模型压到 ~10K 参数,让部署成本接近零。 高频噪声干扰预测 时序预测的本质是捕捉低频周期趋势(日周期、周周期),而高频成分多为噪声。直接在时域建模等于让网络自己去学"忽视噪声",代价高昂。
FITS 的洞见:
"预测 = 低频成分在频域的线性外插。"
- 对历史序列做 FFT,只保留前
cut_freq个低频成分(Low-Pass Filter) - 用一个==复数域线性层==把频谱从
seq_len对应的分辨率插值到seq_len + pred_len对应的分辨率 - IFFT 还原到时域,截取末尾
pred_len步作为预测
整个模型就是一个 Linear,但它工作在复数频率域,且参数量极小。
2. 核心创新
2.1 频域低通滤波 + 复数线性插值
原始时域信号(seq_len=12 步):
[2, 4, 6, 8, 10, 8, 6, 4, 2, 0, -2, 0]
rfft → 频域(7 个复数频率成分):
bin0: 直流 bin1: 基频 bin2: 二次谐波
bin3: 三次谐波 bin4: ↑LPF截止点 bin5: 0+0j bin6: 0+0j
────────────── 保留 ─────────── ───── 置零 ─────
LPF 截取(保留前 dominance_freq=4 个):
[bin0, bin1, bin2, bin3] ← 4个复数频率
复数 Linear(4 → 5)(频域"插值",模拟 seq_len→16 的频谱):
[bin0', bin1', bin2', bin3', bin4'] ← 5个复数频率
零填充至 full_freq_len=9:
[bin0', bin1', ..., bin4', 0, 0, 0, 0] ← 9个复数
irfft → 时域(16 步,= seq_len + pred_len):
[r0, r1, ..., r11, r12, r13, r14, r15]
─────── past (12步) ──────── ─ future (4步) ─
取尾 pred_len=4:
[r12, r13, r14, r15] ← 最终预测2.2 RevIN-like 归一化(Instance Normalization)
在 FFT 之前对每个 batch 内的每条序列做实例归一化:减去时间均值,除以时间标准差。预测后再反归一化。论文称此为 "RIN"(Reversible Instance Normalization)。
2.3 跨模型核心对比表
| 维度 | DLinear | Informer | Autoformer | TimesNet | FITS |
|---|---|---|---|---|---|
| 计算域 | 时域 | 时域 | 时域+频域 | 时域→2D | 频域 |
| 核心操作 | Linear | ProbSparse Attn | AutoCorrelation | TimesBlock | rfft+Linear+irfft |
| 参数量 | ~10K(与 | ||||
| 时间复杂度 | |||||
| 时间标记使用 | ❌ | ✅ | ✅ | ✅ | ❌(纯信号处理) |
| 归一化策略 | 无 | 无 | 无 | RevIN | RIN(实例归一化) |
| 预测机制 | 直接输出 pred_len | 直接输出 pred_len | 直接输出 pred_len | 直接输出 pred_len | 输出 seq+pred,截尾取最后 pred_len |
| Encoder/Decoder | 无 | Encoder+Decoder | Encoder+Decoder | Encoder-only | 无(单次 forward) |
| 位置编码 | 无 | 加法 | 无 | 加法 | 无 |
| 多变量建模 | channel-indep | 混合 | 混合 | 混合 | 共享 W(默认)或逐通道 W |
3. 论文架构(原理层)
4. TFB 调用链
FITS 适配器不走 TransformerAdapter
FITS 直接继承
DeepForecastingModelBase,不需要--adapter transformer_adapter参数。 model-name 格式:fits.FITS(模块路径.类名)。
5. 文档 BFS 树形索引
FITS 无子模块,BFS 树只有 2 层(Layer0 + Layer1 = forward 本身)。
6. 论文组件 → 代码对应表
| 论文组件 | 代码实现 | 精读文档 |
|---|---|---|
| RIN(Reversible Instance Norm) | forward() 前 3 行:x_mean, x_var, x/sqrt(x_var) | [[02-Layer1-forward主链]] §5.2 |
| Low-Pass Filter(LPF) | low_specx[:, dominance_freq:] = 0 + slice | [[02-Layer1-forward主链]] §5.4 |
| 复数频域线性层 | freq_upsampler = nn.Linear(...).to(torch.cfloat) | [[02-Layer1-forward主链]] §5.5 |
| 频率零填充 | low_specxy = zeros(...) + 赋值 | [[02-Layer1-forward主链]] §5.6 |
| IFFT 还原 | torch.fft.irfft(low_specxy, dim=1) | [[02-Layer1-forward主链]] §5.7 |
| 能量补偿 | low_xy * self.length_ratio | [[02-Layer1-forward主链]] §5.8 |
| 反归一化 | low_xy * sqrt(x_var) + x_mean | [[02-Layer1-forward主链]] §5.9 |
| cut_freq 计算 | FITS.__init__ 中公式 | [[01-Layer0-接入界面]] §5.2 |
| individual 分支 | __init__ if/else + forward if/else | [[01-Layer0-接入界面]] §5.3, [[02-Layer1-forward主链]] §5.5 |
7. 全局 Toy 参数
| 参数 | 值 | 含义 |
|---|---|---|
B | 2 | batch size |
seq_len | 12 | 输入时序步数 |
pred_len | 4 | 预测步数(= horizon) |
enc_in | 3 | 变量(通道)数 |
dominance_freq | 4 | LPF 截止频率(= cut_freq) |
rfft_len | 7 | = seq_len // 2 + 1 = 7,rfft 输出长度 |
freq_len_out | 5 | = int(4 × 16/12) = 5,Linear 输出频率数 |
full_freq_len | 9 | = (seq_len+pred_len) // 2 + 1 = 9,零填充后长度 |
output_len | 16 | = seq_len + pred_len = 16,irfft 输出步数 |
length_ratio | 4/3 ≈ 1.333 | = 16 / 12,频谱扩展比 |
所有维度数值不同,方便追踪
2, 12, 4, 3, 4, 7, 5, 9, 16 各不相同。rfft_len=7 与 dominance_freq=4 数值不同(7≠4),在 LPF 截断时能清晰看出"丢弃了多少频率"(7-4=3 个被置零)。
8. 推荐阅读路径
快速了解直觉:本文 §1 动机 → §2.1 ASCII 图 → [[02-Layer1-forward主链]] §5.1 宏观逻辑
完整代码精读:[[01-Layer0-接入界面]] → [[02-Layer1-forward主链]] → [[03-收束]]
调试运行:[[调试形参]] → 配 PyCharm/VSCode 后在 forward() 第一行打断点