Appearance
FITS Layer0 — 接入界面
1. 在父层中的位置
这是整套文档的入口层。FITS 类是 TFB 框架实例化和调用模型的唯一接口——读者在运行 run_benchmark.py 后,框架内部的全部初始化和调用流程都经过这一层。
2. I/O 接口定义
实例化侧
python
# fits.py
class FITS(DeepForecastingModelBase):
def __init__(self, **kwargs):
super(FITS, self).__init__(MODEL_HYPER_PARAMS, **kwargs)
self.config.cut_freq = (
int(self.seq_len // self.config.base_T + 1) * self.config.H_order + 10
)输入:来自 TFB 框架解析 --model-hyper-params 后注入的 **kwargs,最终合并进 self.config。
调用侧(_process)
python
def _process(self, input, target, input_mark, target_mark):
output, low = self.model(input)
return {"output": output}| 参数 | 类型 | 含义 | FITS 实际使用 |
|---|---|---|---|
input | (B, seq_len, enc_in) float | 历史观测序列 | ✅ 唯一输入 |
target | (B, label_len+pred_len, enc_in) float | 目标序列 | ❌ 完全忽略 |
input_mark | (B, seq_len, time_dims) float | 时间特征 | ❌ 完全忽略 |
target_mark | (B, label_len+pred_len, time_dims) float | 目标时间特征 | ❌ 完全忽略 |
返回:{"output": output},output shape = (B, seq_len+pred_len, enc_in) = (2, 16, 3)(toy 值)。
FITS 是信号处理,不是序列建模
FITS 完全不使用时间特征(
input_mark、target_mark),也不使用target(仅做损失计算)。这与 Transformer 系模型形成鲜明对比:Informer/Autoformer 的 Decoder 需要x_dec和x_mark_dec构造历史前缀。FITS 的输入只有原始数值序列。
3. 顺序图(具体层)
4. 语义分组图(索引层)
5. 逐步骤精读
5.1 MODEL_HYPER_PARAMS — 默认参数表
python
MODEL_HYPER_PARAMS = {
"embed": "timeF",
"freq": "h",
"lradj": "type1",
"factor": 1,
"activation": "gelu",
"dropout": 0.1,
"batch_size": 32,
"lr": 0.0001,
"num_epochs": 100,
"num_workers": 0,
"loss": "MSE",
"itr": 1,
"distil": True,
"patience": 3,
"cut_freq": 0,
"train_mode": 1,
"base_T": 24,
"H_order": 2,
"individual": False,
}FITS 专属参数(其他 Transformer 模型没有):
| 参数 | 默认值 | 含义 |
|---|---|---|
cut_freq | 0(占位,__init__ 中重算) | LPF 截止频率;实际值由公式计算覆盖 |
base_T | 24 | 基础周期(小时数);对 hourly 数据 = 24(日周期) |
H_order | 2 | 保留谐波阶数;考虑基频 + 2 倍频 |
individual | False | False=所有通道共享一个 Linear;True=每通道独立 Linear |
train_mode | 1 | 模型代码中未实际使用 |
embed/factor/distil 等参数的来源
embed,factor,distil这些参数是从 TransformerAdapter 时代遗留的,FITS 的FITSModel代码实际上完全不读取它们。框架在初始化时会设置这些属性,但 forward 链路中永远不会访问到。
5.2 cut_freq 计算
python
self.config.cut_freq = (
int(self.seq_len // self.config.base_T + 1) * self.config.H_order + 10
)形状注解:cut_freq 是标量整数,决定 LPF 截断位置 = 模型"看到"多少个低频成分。
计算逻辑拆解(toy: seq_len=12, base_T=6, H_order=2):
但注意:rfft 对长度 12 的序列只产生 [:, 0:16, :] 实际只返回 7 列(PyTorch 不报错,但 Linear 期望 16 列会崩溃)。
调试时需保证 cut_freq ≤ rfft_len
论文实验通常
seq_len=96~720,rfft_len=49~361,cut_freq=20~72,始终满足 cut_freq ≤ rfft_len。 调试用小数据时必须手动验证:。 调试推荐参数见 [[调试形参]]。
实际生产参数示例(seq_len=512, base_T=96, H_order=16):
5.3 individual 分支(__init__ 中)
python
if self.individual:
self.freq_upsampler = nn.ModuleList()
for i in range(self.channels):
self.freq_upsampler.append(
nn.Linear(
self.dominance_freq,
int(self.dominance_freq * self.length_ratio),
).to(torch.cfloat)
)
else:
self.freq_upsampler = nn.Linear(
self.dominance_freq, int(self.dominance_freq * self.length_ratio)
).to(torch.cfloat)individual=False(共享 W,默认,TFB 走这条):
var0 频谱 ──┐
var1 频谱 ──┼── 同一个复数 W ──→ 各通道独立扩展
var2 频谱 ──┘
W.shape: (freq_len_out, dominance_freq) = (5, 4) 复数
参数量: 2 × 5 × 4 = 40(实部+虚部)individual=True(逐通道 W):
var0 频谱 ── W₀ ── var0 扩展
var1 频谱 ── W₁ ── var1 扩展
var2 频谱 ── W₂ ── var2 扩展
每个 Wᵢ.shape: (5, 4) 复数;总参数: enc_in × 2 × 5 × 4 = 3 × 40 = 120对比表:
| individual=False | individual=True | |
|---|---|---|
freq_upsampler 类型 | nn.Linear | nn.ModuleList |
| 参数量 | ||
| 适合场景 | 变量间频域模式相似 | 变量间频域结构差异大 |
| TFB 默认 | ✅ | ❌ |
toy 数值(enc_in=3, dominance_freq=4, freq_len_out=5):
- individual=False:1 个 Linear,参数量 =
(实部 20 + 虚部 20) - individual=True:3 个 Linear,参数量 =
5.4 _process — 极简调用
python
def _process(self, input, target, input_mark, target_mark):
output, low = self.model(input)
return {"output": output}形状注解:
input: (B, seq_len, enc_in) = (2, 12, 3)output: (B, seq_len+pred_len, enc_in) = (2, 16, 3)low: (B, seq_len+pred_len, enc_in) = (2, 16, 3)(被丢弃)
low 分量为什么存在又被丢弃?
low = low_xy * sqrt(x_var)是归一化后的低频重构分量(未加 mean)。论文中提到可以把 high-frequency 残差x - low再送入 DLinear 处理(注释行# dom_xy=self.Dlinear(dom_x)),但 TFB 版本注释掉了这个分支,low仅作为 debug 辅助输出被返回后立刻丢弃。
toy 数值追踪:
- 框架在
forecast_fit中调用_process后,紧接着执行:pythontarget = target[:, -config.horizon:, :series_dim] # (2, 4, 3) output = output[:, -config.horizon:, :series_dim] # (2, 4, 3),截取最后 pred_len=4 步 loss = criterion(output, target) - 因此模型输出的第 0~11 步(历史重构)在训练时完全不参与损失计算,只有第 12~15 步(预测部分)被用于优化。
6. 下钻子组件列表
| 子组件 | 职责 | 下层文档 |
|---|---|---|
FITSModel.forward() | RIN → rfft → LPF → 复数 Linear → 零填充 → irfft → 反归一化 | [[02-Layer1-forward主链]] |