Skip to content

Layer 2B — TimesBlock 精读

本层作用

TimesBlock.forward(x) 是 TimesNet 的核心:输入已经是完整长度 hidden 序列 (B, seq_len + pred_len, d_model),本层先用 FFT 找 top_k 个主周期,再分别按每个周期把 1D 时间轴折成 2D 周期网格,做二维卷积,最后把多个周期分支按频域强度加权融合。


1. 在父层中的位置

text
TimesNet.forecast()
  └─ for i in range(self.layer):
       enc_out = self.layer_norm(self.model[i](enc_out))
                                   └─ TimesBlock.forward(enc_out)  ← 本文档
                                        ├─ FFT_for_Period(x, self.k) → 详见 [[03B1-Layer3-FFT_for_Period]]
                                        ├─ period 分支循环
                                        ├─ padding + reshape + permute
                                        ├─ self.conv(out)             → 详见 [[03B2-Layer3-Inception_Block_V1]]
                                        └─ stack + softmax + residual

2. I/O 接口定义

入口函数:

python
def forward(self, x):

全局 toy 参数:B=3, T=13, d_model=6, top_k=2, num_kernels=3

变量toy shape含义
x(3,13,6)predict_linear 后的完整长度 hidden 序列
period_list(2,)FFT 选出的两个周期,例如 [4, 2]
period_weight(3,2)每个 batch 在两个主频上的幅值,用于分支融合
out分支内变化单个周期分支的中间张量
res before stack两个 (3,13,6)两个周期分支各自产生一个全长度 hidden 序列
res after stack(3,13,6,2)最后一维保存 top_k 个周期分支
res final(3,13,6)周期加权融合后加 residual 的输出

3. 顺序图(具体层)


4. 语义分组图(索引层)


5. 逐步精读

5.0 完整原始代码

位置:ts_benchmark/baselines/time_series_library/models/TimesNet.py

python
class TimesBlock(nn.Module):
    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design
        self.conv = nn.Sequential(
            Inception_Block_V1(
                configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
            ),
            nn.GELU(),
            Inception_Block_V1(
                configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
            ),
        )

    def forward(self, x):
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # padding
            if (self.seq_len + self.pred_len) % period != 0:
                length = (((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros(
                    [x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]
                ).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = self.seq_len + self.pred_len
                out = x
            # reshape
            out = (
                out.reshape(B, length // period, period, N)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, : (self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)
        # adaptive aggregation
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        # residual connection
        res = res + x
        return res

5.1 宏观逻辑:TimesBlock 到底在做什么

论文直觉

如果一个序列有周期 p,那么把长度维按 p 切块后,二维网格的列表示“周期内位置”,行表示“第几个周期块”。Conv2d 在这个网格上滑动,就能同时观察周期内局部结构与跨周期演化。

用一条最小序列看 period=4 的折叠:

text
1D hidden 序列:
[1,2,3,4,5,6,7,8,9,10,11,12,13]

T=13 不能被 4 整除,先补到 16:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0,0,0]

reshape(length//period=4, period=4):
row0: [1,  2,  3,  4]
row1: [5,  6,  7,  8]
row2: [9, 10, 11, 12]
row3: [13, 0,  0,  0]

完整 shape 链:

text
x: (3,13,6)
FFT_for_Period: period_list=[4,2], period_weight=(3,2)

period=4 分支:
  padding 到 length=16
  (3,13,6) 到 (3,16,6)
  reshape 到 (3,4,4,6)
  permute 到 (3,6,4,4)
  Conv2d 后保持 (3,6,4,4)
  reshape back 到 (3,16,6)
  crop 到 (3,13,6)

period=2 分支:
  padding 到 length=14
  (3,13,6) 到 (3,14,6)
  reshape 到 (3,7,2,6)
  permute 到 (3,6,7,2)
  Conv2d 后保持 (3,6,7,2)
  reshape back 到 (3,14,6)
  crop 到 (3,13,6)

stack 两个分支: (3,13,6,2)
softmax 权重融合: (3,13,6)
residual: (3,13,6)

顺序不能换:必须先 padding,再 reshape。如果不先补齐,T=13 不能整除 period=4,就无法 reshape 成完整二维网格。


5.2 S1 — FFT_for_Period 选择周期

python
B, T, N = x.size()
period_list, period_weight = FFT_for_Period(x, self.k)

x.size() 得到:

text
B = 3
T = 13
N = 6

FFT_for_Period(x, self.k) 返回:

text
period_list:   (2,)    例如 [4, 2]
period_weight: (3,2)   每个 batch 在两个主频上的幅值

本层只需要把 period_list 当成后续分支的循环条件。period_list 怎么由 rfftabstopk 得到,详见 [[03B1-Layer3-FFT_for_Period]]。

toy 假设:

text
period_list = [4, 2]

period_weight =
batch0: [2.0, 1.0]
batch1: [1.5, 1.2]
batch2: [0.7, 1.4]

这表示 batch0 的 period=4 分支更强,batch2 的 period=2 分支更强。


5.3 S2 — 遍历每个周期分支

python
res = []
for i in range(self.k):
    period = period_list[i]

self.k = top_k = 2,所以循环执行两次:

text
i=0: period = 4
i=1: period = 2

每个周期分支都会独立完成:

text
padding
到 reshape + permute
到 self.conv(out)
到 reshape back + crop
到 append 到 res

res 最终保存两个分支的输出,每个分支都是 (3,13,6)


5.4 S2a/S2b — period=4 分支:padding 与 2D 网格

python
if (self.seq_len + self.pred_len) % period != 0:
    length = (((self.seq_len + self.pred_len) // period) + 1) * period
    padding = torch.zeros(
        [x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]
    ).to(x.device)
    out = torch.cat([x, padding], dim=1)
else:
    length = self.seq_len + self.pred_len
    out = x

period=4

text
self.seq_len + self.pred_len = 8 + 5 = 13
13 % 4 = 1,不整除
length = ((13 // 4) + 1) * 4 = (3 + 1) * 4 = 16
padding length = 16 - 13 = 3

padding: (3,3,6)
out:     (3,16,6)

继续 reshape:

python
out = (
    out.reshape(B, length // period, period, N)
    .permute(0, 3, 1, 2)
    .contiguous()
)

shape:

text
out:                              (3,16,6)
reshape(B, length//period, period, N):
                                  (3,4,4,6)
permute(0,3,1,2):
                                  (3,6,4,4)

toy 数值只看 batch=0, hidden=0

text
x[0,:,0] = [1,2,3,4,5,6,7,8,9,10,11,12,13]

padding 后:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0,0,0]

reshape 到 (周期段数=4, period=4):
row0: [1,  2,  3,  4]
row1: [5,  6,  7,  8]
row2: [9, 10, 11, 12]
row3: [13, 0,  0,  0]

permute(0,3,1,2) 之后,hidden=0 这张二维网格变成 Conv2d 的一个输入 channel。完整 tensor 的语义是:

text
(B, N, H, W) = (3, 6, 4, 4)
B = batch
N = hidden channel
H = 周期段数
W = 周期内位置

5.5 S2a/S2b — period=2 分支:另一种周期视角

period=2

text
13 % 2 = 1,不整除
length = ((13 // 2) + 1) * 2 = (6 + 1) * 2 = 14
padding length = 14 - 13 = 1

padding: (3,1,6)
out:     (3,14,6)

shape:

text
out:                              (3,14,6)
reshape(B, length//period, period, N):
                                  (3,7,2,6)
permute(0,3,1,2):
                                  (3,6,7,2)

toy 数值仍看 batch=0, hidden=0

text
x[0,:,0] = [1,2,3,4,5,6,7,8,9,10,11,12,13]

padding 后:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0]

reshape 到 (周期段数=7, period=2):
row0: [1,  2]
row1: [3,  4]
row2: [5,  6]
row3: [7,  8]
row4: [9, 10]
row5: [11,12]
row6: [13, 0]

同一条 1D 序列在 period=4period=2 下被折成两种不同二维视角。TimesNet 不提前押注哪个周期一定正确,而是让两个分支都卷积,再用 period_weight 融合。


5.6 S2c/S2d — self.conv 与 reshape back

python
out = self.conv(out)
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
res.append(out[:, : (self.seq_len + self.pred_len), :])

self.conv 是:

python
self.conv = nn.Sequential(
    Inception_Block_V1(
        configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
    ),
    nn.GELU(),
    Inception_Block_V1(
        configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
    ),
)

period=4 分支:

text
进入 self.conv:             (3,6,4,4)
Inception_Block_V1 6到7:    (3,7,4,4)
GELU:                       (3,7,4,4)
Inception_Block_V1 7到6:    (3,6,4,4)
permute(0,2,3,1):           (3,4,4,6)
reshape(B,-1,N):            (3,16,6)
crop 前 13 步:               (3,13,6)

period=2 分支:

text
进入 self.conv:             (3,6,7,2)
conv 后保持:                 (3,6,7,2)
permute(0,2,3,1):           (3,7,2,6)
reshape(B,-1,N):            (3,14,6)
crop 前 13 步:               (3,13,6)

Inception_Block_V1 内部如何用 kernel_size=1,3,5 并行卷积并 torch.stack(res_list, dim=-1).mean(-1),详见 [[03B2-Layer3-Inception_Block_V1]]。

为什么 crop 前 13 步

padding 只是为了让长度能被周期整除。模型真实需要的长度仍然是 T=seq_len+pred_len=13,所以卷积后必须 out[:, :13, :] 裁剪回原长度。


5.7 S3 — stack 周期分支

python
res = torch.stack(res, dim=-1)

循环结束前:

text
res[0]: period=4 分支输出 (3,13,6)
res[1]: period=2 分支输出 (3,13,6)

torch.stack(res, dim=-1) 新增最后一维保存分支编号:

text
stack 后: (3,13,6,2)

toy 只看一个位置 (batch=0, time=5, hidden=2)

text
period=4 分支输出: branch4[0,5,2] = 10.0
period=2 分支输出: branch2[0,5,2] = 20.0

stack 后:
res[0,5,2,:] = [10.0, 20.0]

5.8 S4 — softmax 周期权重并广播

python
period_weight = F.softmax(period_weight, dim=1)
period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)

输入 period_weight 来自 FFT 幅值:

text
period_weight: (3,2)

对 batch0,假设 FFT 幅值为 [2.0, 1.0]

softmax([2,1])=[e2e2+e1,e1e2+e1][0.731,0.269]

shape 广播链:

text
period_weight:          (3,2)
unsqueeze(1):           (3,1,2)
unsqueeze(1):           (3,1,1,2)
repeat(1,T,N,1):        (3,13,6,2)

这样每个 batch 的两个周期权重会复制到所有时间步和所有 hidden channel 上。


5.9 S5 — 加权求和 + residual

python
res = torch.sum(res * period_weight, -1)
res = res + x
return res

res * period_weight 逐元素相乘:

text
res:           (3,13,6,2)
period_weight: (3,13,6,2)
相乘后:        (3,13,6,2)
sum(-1):       (3,13,6)

toy 接续上一个位置:

text
res[0,5,2,:] = [10.0, 20.0]
period_weight[0,5,2,:] = [0.731, 0.269]

weighted =
10.0 * 0.731 + 20.0 * 0.269
= 7.31 + 5.38
= 12.69

最后加残差:

text
原始输入 x[0,5,2] = 6.0
输出 res[0,5,2] = 12.69 + 6.0 = 18.69
residual 的作用

TimesBlock 的卷积分支学习的是周期增强后的变化量。res + x 保留输入 hidden 表示的原始信息,避免周期分支学不好时破坏主干信号。


6. 下钻子组件

子组件父层调用位置职责下层文档
FFT_for_Periodperiod_list, period_weight = FFT_for_Period(x, self.k)rfft 的频域幅值选主周期[[03B1-Layer3-FFT_for_Period]]
Inception_Block_V1out = self.conv(out)多 kernel 二维卷积,保持二维网格大小[[03B2-Layer3-Inception_Block_V1]]

7. 出口接回上层

TimesBlock.forward() 返回:

text
res: (3,13,6)

回到 [[02-Layer1-forecast主链]] 后立刻进入:

python
enc_out = self.layer_norm(self.model[i](enc_out))

LayerNorm 保持 shape (3,13,6),随后 projection 把 hidden 维 6 映射回变量维 4

*记录并在线阅读我的笔记*