Skip to content

03B1-Layer3-FFT_for_Period

本文件位置

上层:[[03B-Layer2B-TimesBlock]]
入口代码:period_list, period_weight = FFT_for_Period(x, self.k)
入口函数:FFT_for_Period(x, k=2)
出口:periodperiod_weight

1. 本层顺序树

1.1 语义分组图

2. 输入输出接口

变量toy shape含义
x(3,13,6)TimesBlock 输入 hidden 序列
xf(3,7,6)rfft 输出,实数 FFT 只保留非负频率,长度为 T//2+1=7
frequency_list(7,)每个频率在 batch 和 channel 上的平均幅值
top_list(2,)幅值最大的频率索引
period(2,)周期长度,计算为 T // top_list
period_weight(3,2)每个 batch 对 top 频率的幅值,用于后续 softmax 加权

3. 对照源码

位置:ts_benchmark/baselines/time_series_library/models/TimesNet.py

python
def FFT_for_Period(x, k=2):
    # [B, T, C]
    xf = torch.fft.rfft(x, dim=1)
    # find period by amplitudes
    frequency_list = abs(xf).mean(0).mean(-1)
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    period = x.shape[1] // top_list
    return period, abs(xf).mean(-1)[:, top_list]

4. FFT 的数学含义

对每个 batch 和 channel,沿时间维计算:

Xb,f,c=t=0T1xb,t,cexp(2πiftT)

abs(xf) 得到幅值:

Ab,f,c=|Xb,f,c|

代码中的全局频率强度:

scoref=1BCbcAb,f,c

然后选出 score_f 最大的 k 个频率索引。

5. toy 数值例子

令:

text
B=3, T=13, C=6, k=2
rfft频率索引: [0,1,2,3,4,5,6]

假设 abs(xf).mean(0).mean(-1) 得到:

text
frequency_list before zero:
[9.5, 1.2, 0.4, 3.0, 0.7, 2.1, 0.5]

frequency_list[0] = 0 后:
[0.0, 1.2, 0.4, 3.0, 0.7, 2.1, 0.5]

torch.topk(frequency_list, k=2) 返回:

text
values:   [3.0, 2.1]
indices:  [3, 5]

周期计算:

text
period = T // top_list
period = 13 // [3,5]
period = [4,2]

period = T // top_list 是整数下取整。若 T 不能被频率索引整除,后续 TimesBlock.forward() 会通过 padding 把长度补到周期整数倍。

返回的 period_weight 来自:

python
abs(xf).mean(-1)[:, top_list]

toy 值:

text
abs(xf).mean(-1): (3,7)

batch0: [9.0, 1.0, 0.2, 2.8, 0.5, 2.0, 0.1]
batch1: [8.0, 1.4, 0.5, 3.2, 0.6, 1.9, 0.4]
batch2: [7.0, 1.2, 0.5, 3.0, 1.0, 2.4, 1.0]

取 top_list=[3,5] 后:
period_weight =
[[2.8, 2.0],
 [3.2, 1.9],
 [3.0, 2.4]]

6. torch.fft.rffttopknumpy 的职责

代码职责输出
torch.fft.rfft(x, dim=1)沿时间维把序列变成频谱(B,T//2+1,C)
abs(xf)复数频谱取幅值(B,T//2+1,C)
.mean(0).mean(-1)跨 batch 和 channel 得到全局频率强度(T//2+1,)
torch.topk(frequency_list, k)找最强的 k 个频率索引values 与 indices
.detach().cpu().numpy()把频率索引转成 numpy,后面用于 Python 分支和整数计算(k,)
x.shape[1] // top_list频率索引转周期长度(k,)

7. 出口接回上层

text
period_list = [4,2]
period_weight = (3,2)
回到 [[03B-Layer2B-TimesBlock]]
下一步: 对每个 period 分别 padding、reshape、2D conv

*记录并在线阅读我的笔记*