Skip to content

FITS 总览

ICLR 2024 · Han et al. · Frequency Interpolation Time Series Forecasting


1. 论文动机

已有方法的两个痛点:

  1. 参数量爆炸 Transformer 系模型动辄几百万参数;DLinear 在多变量时(individual=True)每个通道一个权重矩阵,参数量 = 2×seq\_len×C,seq_len 大时也不轻。FITS 目标是把整个模型压到 ~10K 参数,让部署成本接近零。

  2. 高频噪声干扰预测 时序预测的本质是捕捉低频周期趋势(日周期、周周期),而高频成分多为噪声。直接在时域建模等于让网络自己去学"忽视噪声",代价高昂。

FITS 的洞见:

"预测 = 低频成分在频域的线性外插。"

  1. 对历史序列做 FFT,只保留前 cut_freq 个低频成分(Low-Pass Filter)
  2. 用一个==复数域线性层==把频谱从 seq_len 对应的分辨率插值到 seq_len + pred_len 对应的分辨率
  3. IFFT 还原到时域,截取末尾 pred_len 步作为预测

整个模型就是一个 Linear,但它工作在复数频率域,且参数量极小。


2. 核心创新

2.1 频域低通滤波 + 复数线性插值

原始时域信号(seq_len=12 步):
  [2, 4, 6, 8, 10, 8, 6, 4, 2, 0, -2, 0]

rfft → 频域(7 个复数频率成分):
  bin0: 直流     bin1: 基频    bin2: 二次谐波
  bin3: 三次谐波  bin4: ↑LPF截止点 bin5: 0+0j  bin6: 0+0j
         ────────────── 保留 ───────────  ───── 置零 ─────

LPF 截取(保留前 dominance_freq=4 个):
  [bin0, bin1, bin2, bin3]  ← 4个复数频率

复数 Linear(4 → 5)(频域"插值",模拟 seq_len→16 的频谱):
  [bin0', bin1', bin2', bin3', bin4']  ← 5个复数频率

零填充至 full_freq_len=9:
  [bin0', bin1', ..., bin4', 0, 0, 0, 0]  ← 9个复数

irfft → 时域(16 步,= seq_len + pred_len):
  [r0, r1, ..., r11, r12, r13, r14, r15]
   ─────── past (12步) ────────  ─ future (4步) ─

取尾 pred_len=4:
  [r12, r13, r14, r15]  ← 最终预测

2.2 RevIN-like 归一化(Instance Normalization)

在 FFT 之前对每个 batch 内的每条序列做实例归一化:减去时间均值,除以时间标准差。预测后再反归一化。论文称此为 "RIN"(Reversible Instance Normalization)。

2.3 跨模型核心对比表

维度DLinearInformerAutoformerTimesNetFITS
计算域时域时域时域+频域时域→2D频域
核心操作LinearProbSparse AttnAutoCorrelationTimesBlockrfft+Linear+irfft
参数量O(L×C)O(L2) 级别O(L2) 级别O(L2) 级别~10K(与 LC 无关)
时间复杂度O(L)O(LlogL)O(LlogL)O(LlogL)O(LlogL)(rfft)
时间标记使用❌(纯信号处理)
归一化策略RevINRIN(实例归一化)
预测机制直接输出 pred_len直接输出 pred_len直接输出 pred_len直接输出 pred_len输出 seq+pred,截尾取最后 pred_len
Encoder/DecoderEncoder+DecoderEncoder+DecoderEncoder-only无(单次 forward)
位置编码加法加法
多变量建模channel-indep混合混合混合共享 W(默认)或逐通道 W

3. 论文架构(原理层)


4. TFB 调用链

FITS 适配器不走 TransformerAdapter

FITS 直接继承 DeepForecastingModelBase,不需要 --adapter transformer_adapter 参数。 model-name 格式:fits.FITS(模块路径.类名)。


5. 文档 BFS 树形索引

FITS 无子模块,BFS 树只有 2 层(Layer0 + Layer1 = forward 本身)。


6. 论文组件 → 代码对应表

论文组件代码实现精读文档
RIN(Reversible Instance Norm)forward() 前 3 行:x_mean, x_var, x/sqrt(x_var)[[02-Layer1-forward主链]] §5.2
Low-Pass Filter(LPF)low_specx[:, dominance_freq:] = 0 + slice[[02-Layer1-forward主链]] §5.4
复数频域线性层freq_upsampler = nn.Linear(...).to(torch.cfloat)[[02-Layer1-forward主链]] §5.5
频率零填充low_specxy = zeros(...) + 赋值[[02-Layer1-forward主链]] §5.6
IFFT 还原torch.fft.irfft(low_specxy, dim=1)[[02-Layer1-forward主链]] §5.7
能量补偿low_xy * self.length_ratio[[02-Layer1-forward主链]] §5.8
反归一化low_xy * sqrt(x_var) + x_mean[[02-Layer1-forward主链]] §5.9
cut_freq 计算FITS.__init__ 中公式[[01-Layer0-接入界面]] §5.2
individual 分支__init__ if/else + forward if/else[[01-Layer0-接入界面]] §5.3, [[02-Layer1-forward主链]] §5.5

7. 全局 Toy 参数

参数含义
B2batch size
seq_len12输入时序步数
pred_len4预测步数(= horizon)
enc_in3变量(通道)数
dominance_freq4LPF 截止频率(= cut_freq
rfft_len7= seq_len // 2 + 1 = 7,rfft 输出长度
freq_len_out5= int(4 × 16/12) = 5,Linear 输出频率数
full_freq_len9= (seq_len+pred_len) // 2 + 1 = 9,零填充后长度
output_len16= seq_len + pred_len = 16,irfft 输出步数
length_ratio4/3 ≈ 1.333= 16 / 12,频谱扩展比
所有维度数值不同,方便追踪

2, 12, 4, 3, 4, 7, 5, 9, 16 各不相同。rfft_len=7 与 dominance_freq=4 数值不同(7≠4),在 LPF 截断时能清晰看出"丢弃了多少频率"(7-4=3 个被置零)。


8. 推荐阅读路径

快速了解直觉:本文 §1 动机 → §2.1 ASCII 图 → [[02-Layer1-forward主链]] §5.1 宏观逻辑

完整代码精读:[[01-Layer0-接入界面]] → [[02-Layer1-forward主链]] → [[03-收束]]

调试运行:[[调试形参]] → 配 PyCharm/VSCode 后在 forward() 第一行打断点

*记录并在线阅读我的笔记*