Appearance
03B1-Layer3-FFT_for_Period
本文件位置
上层:[[03B-Layer2B-TimesBlock]]
入口代码:period_list, period_weight = FFT_for_Period(x, self.k)
入口函数:FFT_for_Period(x, k=2)
出口:period与period_weight。
1. 本层顺序树
1.1 语义分组图
2. 输入输出接口
| 变量 | toy shape | 含义 |
|---|---|---|
x | (3,13,6) | TimesBlock 输入 hidden 序列 |
xf | (3,7,6) | rfft 输出,实数 FFT 只保留非负频率,长度为 T//2+1=7 |
frequency_list | (7,) | 每个频率在 batch 和 channel 上的平均幅值 |
top_list | (2,) | 幅值最大的频率索引 |
period | (2,) | 周期长度,计算为 T // top_list |
period_weight | (3,2) | 每个 batch 对 top 频率的幅值,用于后续 softmax 加权 |
3. 对照源码
位置:ts_benchmark/baselines/time_series_library/models/TimesNet.py
python
def FFT_for_Period(x, k=2):
# [B, T, C]
xf = torch.fft.rfft(x, dim=1)
# find period by amplitudes
frequency_list = abs(xf).mean(0).mean(-1)
frequency_list[0] = 0
_, top_list = torch.topk(frequency_list, k)
top_list = top_list.detach().cpu().numpy()
period = x.shape[1] // top_list
return period, abs(xf).mean(-1)[:, top_list]4. FFT 的数学含义
对每个 batch 和 channel,沿时间维计算:
abs(xf) 得到幅值:
代码中的全局频率强度:
然后选出 score_f 最大的 k 个频率索引。
5. toy 数值例子
令:
text
B=3, T=13, C=6, k=2
rfft频率索引: [0,1,2,3,4,5,6]假设 abs(xf).mean(0).mean(-1) 得到:
text
frequency_list before zero:
[9.5, 1.2, 0.4, 3.0, 0.7, 2.1, 0.5]
frequency_list[0] = 0 后:
[0.0, 1.2, 0.4, 3.0, 0.7, 2.1, 0.5]torch.topk(frequency_list, k=2) 返回:
text
values: [3.0, 2.1]
indices: [3, 5]周期计算:
text
period = T // top_list
period = 13 // [3,5]
period = [4,2]period = T // top_list 是整数下取整。若 TimesBlock.forward() 会通过 padding 把长度补到周期整数倍。
返回的 period_weight 来自:
python
abs(xf).mean(-1)[:, top_list]toy 值:
text
abs(xf).mean(-1): (3,7)
batch0: [9.0, 1.0, 0.2, 2.8, 0.5, 2.0, 0.1]
batch1: [8.0, 1.4, 0.5, 3.2, 0.6, 1.9, 0.4]
batch2: [7.0, 1.2, 0.5, 3.0, 1.0, 2.4, 1.0]
取 top_list=[3,5] 后:
period_weight =
[[2.8, 2.0],
[3.2, 1.9],
[3.0, 2.4]]6. torch.fft.rfft、topk、numpy 的职责
| 代码 | 职责 | 输出 |
|---|---|---|
torch.fft.rfft(x, dim=1) | 沿时间维把序列变成频谱 | (B,T//2+1,C) |
abs(xf) | 复数频谱取幅值 | (B,T//2+1,C) |
.mean(0).mean(-1) | 跨 batch 和 channel 得到全局频率强度 | (T//2+1,) |
torch.topk(frequency_list, k) | 找最强的 k 个频率索引 | values 与 indices |
.detach().cpu().numpy() | 把频率索引转成 numpy,后面用于 Python 分支和整数计算 | (k,) |
x.shape[1] // top_list | 频率索引转周期长度 | (k,) |
7. 出口接回上层
text
period_list = [4,2]
period_weight = (3,2)
回到 [[03B-Layer2B-TimesBlock]]
下一步: 对每个 period 分别 padding、reshape、2D conv