Skip to content

FITS Layer0 — 接入界面


1. 在父层中的位置

这是整套文档的入口层。FITS 类是 TFB 框架实例化和调用模型的唯一接口——读者在运行 run_benchmark.py 后,框架内部的全部初始化和调用流程都经过这一层。


2. I/O 接口定义

实例化侧

python
# fits.py
class FITS(DeepForecastingModelBase):
    def __init__(self, **kwargs):
        super(FITS, self).__init__(MODEL_HYPER_PARAMS, **kwargs)
        self.config.cut_freq = (
            int(self.seq_len // self.config.base_T + 1) * self.config.H_order + 10
        )

输入:来自 TFB 框架解析 --model-hyper-params 后注入的 **kwargs,最终合并进 self.config

调用侧(_process

python
def _process(self, input, target, input_mark, target_mark):
    output, low = self.model(input)
    return {"output": output}
参数类型含义FITS 实际使用
input(B, seq_len, enc_in) float历史观测序列✅ 唯一输入
target(B, label_len+pred_len, enc_in) float目标序列❌ 完全忽略
input_mark(B, seq_len, time_dims) float时间特征❌ 完全忽略
target_mark(B, label_len+pred_len, time_dims) float目标时间特征❌ 完全忽略

返回{"output": output}output shape = (B, seq_len+pred_len, enc_in) = (2, 16, 3)(toy 值)。

FITS 是信号处理,不是序列建模

FITS 完全不使用时间特征(input_marktarget_mark),也不使用 target(仅做损失计算)。这与 Transformer 系模型形成鲜明对比:Informer/Autoformer 的 Decoder 需要 x_decx_mark_dec 构造历史前缀。FITS 的输入只有原始数值序列。


3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 逐步骤精读

5.1 MODEL_HYPER_PARAMS — 默认参数表

python
MODEL_HYPER_PARAMS = {
    "embed": "timeF",
    "freq": "h",
    "lradj": "type1",
    "factor": 1,
    "activation": "gelu",
    "dropout": 0.1,
    "batch_size": 32,
    "lr": 0.0001,
    "num_epochs": 100,
    "num_workers": 0,
    "loss": "MSE",
    "itr": 1,
    "distil": True,
    "patience": 3,
    "cut_freq": 0,
    "train_mode": 1,
    "base_T": 24,
    "H_order": 2,
    "individual": False,
}

FITS 专属参数(其他 Transformer 模型没有):

参数默认值含义
cut_freq0(占位,__init__ 中重算)LPF 截止频率;实际值由公式计算覆盖
base_T24基础周期(小时数);对 hourly 数据 = 24(日周期)
H_order2保留谐波阶数;考虑基频 + 2 倍频
individualFalseFalse=所有通道共享一个 Linear;True=每通道独立 Linear
train_mode1模型代码中未实际使用
embed/factor/distil 等参数的来源

embed, factor, distil 这些参数是从 TransformerAdapter 时代遗留的,FITS 的 FITSModel 代码实际上完全不读取它们。框架在初始化时会设置这些属性,但 forward 链路中永远不会访问到。

5.2 cut_freq 计算

python
self.config.cut_freq = (
    int(self.seq_len // self.config.base_T + 1) * self.config.H_order + 10
)

形状注解cut_freq 是标量整数,决定 LPF 截断位置 = 模型"看到"多少个低频成分。

计算逻辑拆解(toy: seq_len=12, base_T=6, H_order=2)

cut\_freq=(126+1)×2+10=(2+1)×2+10=16

但注意:rfft 对长度 12 的序列只产生 12//2+1=7 个频率成分。如果 cut_freq=16 > rfft_len=7,则 LPF 切片 [:, 0:16, :] 实际只返回 7 列(PyTorch 不报错,但 Linear 期望 16 列会崩溃)。

调试时需保证 cut_freq ≤ rfft_len

论文实验通常 seq_len=96~720,rfft_len=49~361,cut_freq=20~72,始终满足 cut_freq ≤ rfft_len。 调试用小数据时必须手动验证:rfft\_len=seq\_len//2+1cut\_freq。 调试推荐参数见 [[调试形参]]。

实际生产参数示例seq_len=512, base_T=96, H_order=16):

cut\_freq=(51296+1)×16+10=(5+1)×16+10=106rfft\_len=512//2+1=257106

5.3 individual 分支(__init__ 中)

python
if self.individual:
    self.freq_upsampler = nn.ModuleList()
    for i in range(self.channels):
        self.freq_upsampler.append(
            nn.Linear(
                self.dominance_freq,
                int(self.dominance_freq * self.length_ratio),
            ).to(torch.cfloat)
        )
else:
    self.freq_upsampler = nn.Linear(
        self.dominance_freq, int(self.dominance_freq * self.length_ratio)
    ).to(torch.cfloat)

individual=False(共享 W,默认,TFB 走这条)

var0 频谱 ──┐
var1 频谱 ──┼── 同一个复数 W ──→ 各通道独立扩展
var2 频谱 ──┘
W.shape: (freq_len_out, dominance_freq) = (5, 4) 复数
参数量: 2 × 5 × 4 = 40(实部+虚部)

individual=True(逐通道 W)

var0 频谱 ── W₀ ── var0 扩展
var1 频谱 ── W₁ ── var1 扩展
var2 频谱 ── W₂ ── var2 扩展
每个 Wᵢ.shape: (5, 4) 复数;总参数: enc_in × 2 × 5 × 4 = 3 × 40 = 120

对比表:

individual=Falseindividual=True
freq_upsampler 类型nn.Linearnn.ModuleList
参数量2×freq\_out×dom\_freq×enc\_in
适合场景变量间频域模式相似变量间频域结构差异大
TFB 默认

toy 数值(enc_in=3, dominance_freq=4, freq_len_out=5)

  • individual=False:1 个 Linear,参数量 = 2×5×4=40(实部 20 + 虚部 20)
  • individual=True:3 个 Linear,参数量 = 3×40=120

5.4 _process — 极简调用

python
def _process(self, input, target, input_mark, target_mark):
    output, low = self.model(input)
    return {"output": output}

形状注解

  • input: (B, seq_len, enc_in) = (2, 12, 3)
  • output: (B, seq_len+pred_len, enc_in) = (2, 16, 3)
  • low: (B, seq_len+pred_len, enc_in) = (2, 16, 3)(被丢弃)
low 分量为什么存在又被丢弃?

low = low_xy * sqrt(x_var) 是归一化后的低频重构分量(未加 mean)。论文中提到可以把 high-frequency 残差 x - low 再送入 DLinear 处理(注释行 # dom_xy=self.Dlinear(dom_x)),但 TFB 版本注释掉了这个分支,low 仅作为 debug 辅助输出被返回后立刻丢弃。

toy 数值追踪

  • 框架在 forecast_fit 中调用 _process 后,紧接着执行:
    python
    target = target[:, -config.horizon:, :series_dim]   # (2, 4, 3)
    output = output[:, -config.horizon:, :series_dim]   # (2, 4, 3),截取最后 pred_len=4 步
    loss = criterion(output, target)
  • 因此模型输出的第 0~11 步(历史重构)在训练时完全不参与损失计算,只有第 12~15 步(预测部分)被用于优化。

6. 下钻子组件列表

子组件职责下层文档
FITSModel.forward()RIN → rfft → LPF → 复数 Linear → 零填充 → irfft → 反归一化[[02-Layer1-forward主链]]

*记录并在线阅读我的笔记*