Appearance
FITS Layer1 — forward 主链
1. 在父层中的位置
_process() 调用 self.model(input),进入 FITSModel.forward(x)。这是 FITS 的全部计算逻辑所在,无子模块可供下钻。
2. I/O 接口定义
python
def forward(self, x):
...
return xy, low_xy * torch.sqrt(x_var)| 参数/返回 | shape | 含义 |
|---|---|---|
x(输入) | (B, seq_len, enc_in) = (2, 12, 3) | 原始历史序列 |
xy(第一返回值) | (B, seq_len+pred_len, enc_in) = (2, 16, 3) | 完整重构+预测序列(反归一化后) |
low_xy * sqrt(x_var)(第二返回值) | (2, 16, 3) | 低频重构的尺度还原版(调用方丢弃) |
3. 顺序图(具体层)
4. 语义分组图(索引层)
5. 逐步骤精读
5.0 完整原始代码
python
class FITSModel(nn.Module):
# FITS: Frequency Interpolation Time Series Forecasting
def __init__(self, configs):
super(FITSModel, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.individual = configs.individual
self.channels = configs.enc_in
self.dominance_freq = configs.cut_freq # 720/24
self.length_ratio = (self.seq_len + self.pred_len) / self.seq_len
if self.individual:
self.freq_upsampler = nn.ModuleList()
for i in range(self.channels):
self.freq_upsampler.append(
nn.Linear(
self.dominance_freq,
int(self.dominance_freq * self.length_ratio),
).to(torch.cfloat)
)
else:
self.freq_upsampler = nn.Linear(
self.dominance_freq, int(self.dominance_freq * self.length_ratio)
).to(
torch.cfloat
) # complex layer for frequency upcampling]
# configs.pred_len=configs.seq_len+configs.pred_len
# #self.Dlinear=DLinear.Model(configs)
# configs.pred_len=self.pred_len
def forward(self, x):
# RIN
x_mean = torch.mean(x, dim=1, keepdim=True)
x = x - x_mean
x_var = torch.var(x, dim=1, keepdim=True) + 1e-5
# print(x_var)
x = x / torch.sqrt(x_var)
low_specx = torch.fft.rfft(x, dim=1)
low_specx[:, self.dominance_freq :] = 0 # LPF
low_specx = low_specx[:, 0 : self.dominance_freq, :] # LPF
# print(low_specx.permute(0,2,1))
if self.individual:
low_specxy_ = torch.zeros(
[
low_specx.size(0),
int(self.dominance_freq * self.length_ratio),
low_specx.size(2),
],
dtype=low_specx.dtype,
).to(low_specx.device)
for i in range(self.channels):
low_specxy_[:, :, i] = self.freq_upsampler[i](
low_specx[:, :, i].permute(0, 1)
).permute(0, 1)
else:
low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1)).permute(
0, 2, 1
)
# print(low_specxy_)
low_specxy = torch.zeros(
[
low_specxy_.size(0),
int((self.seq_len + self.pred_len) / 2 + 1),
low_specxy_.size(2),
],
dtype=low_specxy_.dtype,
).to(low_specxy_.device)
low_specxy[:, 0 : low_specxy_.size(1), :] = low_specxy_ # zero padding
low_xy = torch.fft.irfft(low_specxy, dim=1)
low_xy = low_xy * self.length_ratio # energy compemsation for the length change
# dom_x=x-low_x
# dom_xy=self.Dlinear(dom_x)
# xy=(low_xy+dom_xy) * torch.sqrt(x_var) +x_mean # REVERSE RIN
xy = (low_xy) * torch.sqrt(x_var) + x_mean
return xy, low_xy * torch.sqrt(x_var)5.1 宏观逻辑
一句话目标
把历史序列 → 频域 → 只保留低频 → 线性插值到更长序列的频域 → 还原时域,得到"过去+未来"的低频重构,末尾就是预测。
Why before What(用小例子,B=1,L=8,C=1,dom=3):
原信号: x = [1, 2, 3, 2, 1, 0, -1, 0] (L=8)
① 为什么先 RIN?
FFT 对直流分量(均值)敏感;不同序列均值不同导致 Linear 需要记住"基准线",浪费容量。
RIN 把序列平移到零均值后,Linear 只需学习形状,不必学习偏置。
② 为什么 rfft 而非 fft?
rfft 利用实数信号的共轭对称性,只返回前 L//2+1=5 个复数成分(后半段是前半段的共轭)。
irfft 从这 5 个复数完整恢复 8 点实数序列。节省一半存储。
③ 为什么要 LPF(置零+截取)?
时序预测关注趋势和周期,高频成分(bin3,4)多为噪声。
置零:物理上是低通滤波;截取:让 Linear 输入维度固定为 dom=3,而不是 rfft_len=5。
rfft: [c0, c1, c2, c3, c4]
LPF: [c0, c1, c2, 0+0j, 0+0j] → 截取 → [c0, c1, c2]
④ 为什么 Linear 工作在频域(且是复数)?
频域插值等价于时域插值(Sinc 插值的离散版)。把 3 个低频成分映射到 ceil(3 × (8+4)/8)=5 个,
等价于把信号"外推"到更长的时间轴 L=12=8+4。复数 Linear 同时学习幅度和相位的变换。
⑤ 为什么零填充而不让 Linear 直接输出 full_freq_len=7?
full_freq_len=(8+4)//2+1=7。Linear 输出 5 个低频成分,高频部分(bin5,6)设为 0,
等价于滤波器在高频不引入新成分。直接让 Linear 输出 7 个等于让它学习高频成分,违背低通假设。
⑥ 为什么乘以 length_ratio?
irfft 在信号长度变化时能量会被稀释。从 L=8 插值到 L=12,irfft 输出的幅度
会被缩小为原来的 L_new/L_old = 12/8 = 1.5 倍的倒数,需要乘回去补偿。
shape 变化全链:
(1,8,1) → RIN → (1,8,1) → rfft → (1,5,1)cplx → LPF → (1,3,1)cplx
→ Linear → (1,4,1)cplx → zero-pad → (1,7,1)cplx → irfft → (1,12,1) → × ratio → RIN-1 → (1,12,1)图解 — FITSModel 整体架构流:
5.2 RIN — 减均值
python
x_mean = torch.mean(x, dim=1, keepdim=True)
x = x - x_mean形状注解:
x: (B, seq_len, enc_in) = (2, 12, 3)x_mean: (B, 1, enc_in) = (2, 1, 3)(keepdim=True保持第 1 维为 1)x - x_mean:广播,(2, 12, 3) - (2, 1, 3)→(2, 12, 3)
keepdim=True 的设计约束
mean(dim=1, keepdim=True)的结果 dim=1 一定为 1,保证后续x - x_mean的广播语义正确(对每个时间步减去同一个均值)。若 keepdim=False,x_mean为(2, 3),则x - x_mean需要手动 unsqueeze,代码会复杂。
keepdim x_mean shape x - x_mean广播结果True (2, 1, 3) (2,12,3) - (2,1,3)→ 沿时间轴广播,每步减同一均值 ✓False (2, 3) 需要 .unsqueeze(1)才能广播,否则维度不对齐
toy 数值追踪(以 batch=0, channel=0 为例):
取 x[0, :, 0] = [2.0, 4.0, 6.0, 8.0, 10.0, 8.0, 6.0, 4.0, 2.0, 0.0, -2.0, 0.0]
减均值后:x[0, :, 0] = [-2, 0, 2, 4, 6, 4, 2, 0, -2, -4, -6, -4]
5.3 RIN — 除标准差
python
x_var = torch.var(x, dim=1, keepdim=True) + 1e-5
x = x / torch.sqrt(x_var)形状注解:
x_var: (B, 1, enc_in) = (2, 1, 3)(keepdim=True,与 §5.2 同理)torch.sqrt(x_var): (2, 1, 3)x / sqrt(x_var): (2, 12, 3) / (2, 1, 3) → (2, 12, 3)(广播)
torch.var 默认 Bessel 校正
torch.var 默认 Bessel 校正
torch.var(x, dim=1)默认correction=1(Bessel 校正,分母为)。与 Autoformer 系等用 stdev = sqrt(var(correction=1))的做法一致。
toy 数值追踪(接续 §5.2,已减均值):
x[0, :, 0] = [-2, 0, 2, 4, 6, 4, 2, 0, -2, -4, -6, -4]
归一化后:x[0, :, 0] ≈ [-0.538, 0, 0.538, 1.076, 1.613, 1.076, 0.538, 0, -0.538, -1.076, -1.613, -1.076]
5.4 rfft + LPF
python
low_specx = torch.fft.rfft(x, dim=1)
low_specx[:, self.dominance_freq :] = 0 # LPF
low_specx = low_specx[:, 0 : self.dominance_freq, :] # LPFrfft:
torch.fft.rfft 对实数信号做 FFT,只返回非冗余的前半部分频率成分。
- 输入:
x (2, 12, 3)real → 输出:low_specx (2, 7, 3)complex(torch.cfloat) dim=1:沿时间轴做 FFT,每个变量各自独立变换
频率含义:bin
| bin | |||||||
|---|---|---|---|---|---|---|---|
| 周期 | ∞(直流) | 12 步 | 6 步 | 4 步 | 3 步 | 2.4 步 | 2 步 |
| LPF 保留? | ✅ | ✅ | ✅ | ✅ | ❌→0 | ❌→0 | ❌→0 |
LPF 置零:low_specx[:, 4:] = 0 — 将 bin4~6(索引 4、5、6)置为 0+0j
LPF 截取:low_specx[:, 0:4, :] — 取前 4 列,shape (2, 4, 3) complex
图解 — LPF 频谱可视化:
| 操作 | 输入 shape | 输出 shape | 关键变化 |
|---|---|---|---|
| rfft | (2, 12, 3) real | (2, 7, 3) complex | 12步→7个复数频率成分 |
| LPF置零 | (2, 7, 3) | (2, 7, 3)(in-place) | bin4~6 变为 0+0j |
| LPF截取 | (2, 7, 3) | (2, 4, 3) | 仅保留 bin0~bin3 |
toy 数值追踪(batch=0, channel=0 的近似 rfft 值):
对归一化后的 x[0, :, 0] ≈ [-0.538, 0, 0.538, 1.076, 1.613, 1.076, 0.538, 0, -0.538, -1.076, -1.613, -1.076],其 rfft 输出约为:
bin0: (0+0j) ← 直流分量 = sum/L = 0(已减均值)
bin1: (-0.003+0j) ← 12步周期成分(主成分)
bin2: (-5.386+0j) ← 6步周期(实际值由信号决定)
bin3: (0.003+0j) ← 4步周期
[bin4,5,6]: → 置零为 0+0j截取后:low_specx[0, :, 0] ≈ [(0+0j), (-0.003+0j), (-5.386+0j), (0.003+0j)](4个复数)
rfft 的 torch.cfloat 类型
rfft的输出类型是torch.cfloat(32位复数 = 2×float32)。后续的freq_upsamplerLinear 也必须是.to(torch.cfloat),才能对复数张量做线性变换(分别对实部和虚部做同样的线性映射)。
5.5 freq_upsampler — 复数域线性插值
shared 分支(individual=False,TFB 默认):
python
low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1)).permute(0, 2, 1)形状注解:
low_specx.permute(0, 2, 1):(2, 4, 3)→(2, 3, 4)(把 enc_in 换到第1维,让 Linear 作用于最后一维的 dominance_freq=4)self.freq_upsampler(nn.Linear(4, 5, dtype=cfloat)):(2, 3, 4)→(2, 3, 5).permute(0, 2, 1):(2, 3, 5)→(2, 5, 3)还原为(B, freq_len_out, enc_in)格式
图解 — permute 语义变化:
permute 前: (2, 4, 3)
dim0 = B=2 dim1 = 频率=4 dim2 = 通道=3
Linear 作用在 dim2=通道=3 ← 错误!应该作用在频率维
permute(0,2,1) → (2, 3, 4)
dim0 = B=2 dim1 = 通道=3 dim2 = 频率=4
Linear 作用在 dim2=频率=4 ← 正确 ✓
Linear(4→5): (2, 3, 4) → (2, 3, 5)
每个 (B, channel) 对独立做 4→5 的线性变换
permute(0,2,1) → (2, 5, 3) 还原
dim0 = B=2 dim1 = 频率=5 dim2 = 通道=3图解 — 频域插值可视化:
nn.Linear 在复数域的工作方式:
nn.Linear 的权重 W 是 torch.cfloat,公式为:
其中
每个输出复数频率成分是输入 4 个复数成分的线性组合,同时学习幅度(模)和相位(角)的映射。
toy 数值追踪:
输入 low_specx[0, :, 0](4个复数),permute 后成为 low_specx_p[0, 0, :](一行4个复数)。 经过 Linear(4→5) 后:输出 low_specxy_p[0, 0, :](一行5个复数)。 permute 回来:low_specxy_[0, :, 0](5个复数)。
结果 shape:low_specxy_: (2, 5, 3) complex
individual=True 分支(注意 ⚠️ 冗余代码):
python
for i in range(self.channels):
low_specxy_[:, :, i] = self.freq_upsampler[i](
low_specx[:, :, i].permute(0, 1)
).permute(0, 1)⚠️ 源码冗余:low_specx[:, :, i] 的 shape 是 (B, dominance_freq) = (2, 4)(2D 张量)。对 2D 张量调用 .permute(0, 1) 是恒等操作(permute(0,1) = 不变),等价于:
python
low_specxy_[:, :, i] = self.freq_upsampler[i](low_specx[:, :, i])nn.Linear(4, 5) 作用在最后一维 dominance_freq=4 上,输入 (2, 4),输出 (2, 5),赋值给 low_specxy_[:, :, i] 的 shape (2, 5)。结果是正确的,冗余的 permute 不影响正确性。
toy 数值追踪(individual=True 分支,toy: enc_in=3, dominance_freq=4, freq_len_out=5):
i=0: low_specx[:, :, 0] → (2, 4) complex
freq_upsampler[0]: Linear(4→5, cfloat)
输出: (2, 5) complex → 赋值给 low_specxy_[:, :, 0]
i=1: low_specx[:, :, 1] → (2, 4) complex → Linear_1(4→5) → (2, 5)
i=2: low_specx[:, :, 2] → (2, 4) complex → Linear_2(4→5) → (2, 5)
循环 3 次(enc_in=3),low_specxy_ 最终 shape: (2, 5, 3) complex5.6 零填充到 full_freq_len
python
low_specxy = torch.zeros(
[
low_specxy_.size(0),
int((self.seq_len + self.pred_len) / 2 + 1),
low_specxy_.size(2),
],
dtype=low_specxy_.dtype,
).to(low_specxy_.device)
low_specxy[:, 0 : low_specxy_.size(1), :] = low_specxy_ # zero padding形状注解:
- 新建全零张量
low_specxy: (2, 9, 3)complex - 赋值:前
freq_len_out=5行 =low_specxy_,后9-5=4行保持0+0j
图解 — 零填充语义:
low_specxy_ (2, 5, 3):
[c0', c1', c2', c3', c4'] ← 插值后的5个低频成分
low_specxy (2, 9, 3):
[c0', c1', c2', c3', c4', 0+0j, 0+0j, 0+0j, 0+0j]
←────── 低频段 (5个) ──────→ ←── 高频段置零 (4个) ──→
对应 irfft 输出长度:
(full_freq_len - 1) × 2 = (9-1) × 2 = 16 = seq_len + pred_len ✓irfft 输出长度的计算
torch.fft.irfft(x, dim=1)的输出长度 =(x.size(dim) - 1) × 2(默认 n=None)。 这里x.size(1) = 9,所以输出长度 =(9-1) × 2 = 16 = seq_len + pred_len。这正是 full_freq_len 的设计:,其中 。
toy 数值追踪:
low_specxy[0, :, 0] = [c0'≈..., c1'≈..., c2'≈..., c3'≈..., c4'≈..., 0+0j, 0+0j, 0+0j, 0+0j]
全部 9 个复数,后 4 个为零。
5.7 irfft — 还原时域
python
low_xy = torch.fft.irfft(low_specxy, dim=1)形状注解:low_specxy: (2, 9, 3) complex → low_xy: (2, 16, 3) real
irfft 是 rfft 的逆变换:给定半侧频谱,重建完整实数序列。输出长度 = (9-1)×2 = 16。
toy 数值追踪(batch=0, channel=0 近似):
low_xy[0, :, 0] 是长度 16 的实数序列。前 12 个值是历史序列的低频重构,后 4 个是外推预测:
low_xy[0, :, 0] ≈ [rₒ₀, r₁, r₂, ..., r₁₁, r₁₂, r₁₃, r₁₄, r₁₅]
─── 历史重构 12步 ─── ──── 预测 4步 ────实际数值由训练好的 freq_upsampler 权重决定,调试时在断点查看。
5.8 能量补偿
python
low_xy = low_xy * self.length_ratio # energy compemsation for the length change形状注解:(2, 16, 3) × 标量 length_ratio = 4/3 ≈ 1.333 → (2, 16, 3)
为什么需要能量补偿?
irfft在处理不同长度信号时的能量归一化方式是,其中 是输出长度。 当我们把频谱从 seq_len=12的分辨率插值到seq_len+pred_len=16时,irfft 按归一化,而不是原来的 ,导致输出幅度偏小 倍。 补偿因子 length_ratio = 16/12 = 4/3正好纠正这个缩放差异。源码注释拼写有误(
compemsation,⚠️ typo),实际含义正确。
toy 数值追踪:low_xy[0, 0, 0] × 4/3,每个值乘以约 1.333。
5.9 反归一化
python
xy = (low_xy) * torch.sqrt(x_var) + x_mean
return xy, low_xy * torch.sqrt(x_var)形状注解:
low_xy: (2, 16, 3)realx_var: (2, 1, 3)— 与 §5.3 保存的方差(keepdim=True)x_mean: (2, 1, 3)— 与 §5.2 保存的均值- 广播:
(2, 16, 3) × (2, 1, 3) + (2, 1, 3)→(2, 16, 3)
⚠️ 注意 x_mean 和 x_var 保存的是原始 x 的统计量
在 §5.2 中,
x_mean是从原始x计算的。之后x被原地修改(x = x - x_mean,x = x / sqrt(x_var)),但x_mean和x_var张量保持不变,可以在最后用于反归一化。这是通过 Python 变量绑定保证的:x_mean和x_var始终指向原始计算结果,不随x的重赋值而改变。
反归一化公式:
toy 数值追踪(batch=0, channel=0):
x_mean[0, 0, 0] = 4.0(§5.2 计算)sqrt(x_var[0, 0, 0]) ≈ 3.718(§5.3 计算)xy[0, :, 0] ≈ low_xy[0, :, 0] × 3.718 + 4.0(按元素)
前 12 步是历史重构,后 4 步是预测:
xy[0, 12:16, 0] ≈ [r₁₂, r₁₃, r₁₄, r₁₅] × 3.718 + 4.0这 4 个值就是 FITS 对 batch=0, channel=0 的预测结果。
返回值:
xy: (2, 16, 3)(反归一化后的全序列)low_xy * sqrt(x_var): (2, 16, 3)(未加均值的低频重构,_process中被丢弃)
训练循环截取:output[:, -4:, :] → (2, 4, 3),与 target[:, -4:, :] 计算 MSE。
6. 下钻子组件列表
FITS 无需下钻——所有计算在 forward() 一层完成,无独立子模块(freq_upsampler 是 nn.Linear 的直接调用,不开新文档)。