Appearance
Layer 2B — TimesBlock 精读
本层作用
TimesBlock.forward(x)是 TimesNet 的核心:输入已经是完整长度 hidden 序列(B, seq_len + pred_len, d_model),本层先用 FFT 找top_k个主周期,再分别按每个周期把 1D 时间轴折成 2D 周期网格,做二维卷积,最后把多个周期分支按频域强度加权融合。
1. 在父层中的位置
text
TimesNet.forecast()
└─ for i in range(self.layer):
enc_out = self.layer_norm(self.model[i](enc_out))
└─ TimesBlock.forward(enc_out) ← 本文档
├─ FFT_for_Period(x, self.k) → 详见 [[03B1-Layer3-FFT_for_Period]]
├─ period 分支循环
├─ padding + reshape + permute
├─ self.conv(out) → 详见 [[03B2-Layer3-Inception_Block_V1]]
└─ stack + softmax + residual2. I/O 接口定义
入口函数:
python
def forward(self, x):全局 toy 参数:B=3, T=13, d_model=6, top_k=2, num_kernels=3。
| 变量 | toy shape | 含义 |
|---|---|---|
x | (3,13,6) | predict_linear 后的完整长度 hidden 序列 |
period_list | (2,) | FFT 选出的两个周期,例如 [4, 2] |
period_weight | (3,2) | 每个 batch 在两个主频上的幅值,用于分支融合 |
out | 分支内变化 | 单个周期分支的中间张量 |
res before stack | 两个 (3,13,6) | 两个周期分支各自产生一个全长度 hidden 序列 |
res after stack | (3,13,6,2) | 最后一维保存 top_k 个周期分支 |
res final | (3,13,6) | 周期加权融合后加 residual 的输出 |
3. 顺序图(具体层)
4. 语义分组图(索引层)
5. 逐步精读
5.0 完整原始代码
位置:ts_benchmark/baselines/time_series_library/models/TimesNet.py
python
class TimesBlock(nn.Module):
def __init__(self, configs):
super(TimesBlock, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.k = configs.top_k
# parameter-efficient design
self.conv = nn.Sequential(
Inception_Block_V1(
configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
),
nn.GELU(),
Inception_Block_V1(
configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
),
)
def forward(self, x):
B, T, N = x.size()
period_list, period_weight = FFT_for_Period(x, self.k)
res = []
for i in range(self.k):
period = period_list[i]
# padding
if (self.seq_len + self.pred_len) % period != 0:
length = (((self.seq_len + self.pred_len) // period) + 1) * period
padding = torch.zeros(
[x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]
).to(x.device)
out = torch.cat([x, padding], dim=1)
else:
length = self.seq_len + self.pred_len
out = x
# reshape
out = (
out.reshape(B, length // period, period, N)
.permute(0, 3, 1, 2)
.contiguous()
)
# 2D conv: from 1d Variation to 2d Variation
out = self.conv(out)
# reshape back
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
res.append(out[:, : (self.seq_len + self.pred_len), :])
res = torch.stack(res, dim=-1)
# adaptive aggregation
period_weight = F.softmax(period_weight, dim=1)
period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
res = torch.sum(res * period_weight, -1)
# residual connection
res = res + x
return res5.1 宏观逻辑:TimesBlock 到底在做什么
论文直觉
如果一个序列有周期
,那么把长度维按 切块后,二维网格的列表示“周期内位置”,行表示“第几个周期块”。 Conv2d在这个网格上滑动,就能同时观察周期内局部结构与跨周期演化。
用一条最小序列看 period=4 的折叠:
text
1D hidden 序列:
[1,2,3,4,5,6,7,8,9,10,11,12,13]
T=13 不能被 4 整除,先补到 16:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0,0,0]
reshape(length//period=4, period=4):
row0: [1, 2, 3, 4]
row1: [5, 6, 7, 8]
row2: [9, 10, 11, 12]
row3: [13, 0, 0, 0]完整 shape 链:
text
x: (3,13,6)
FFT_for_Period: period_list=[4,2], period_weight=(3,2)
period=4 分支:
padding 到 length=16
(3,13,6) 到 (3,16,6)
reshape 到 (3,4,4,6)
permute 到 (3,6,4,4)
Conv2d 后保持 (3,6,4,4)
reshape back 到 (3,16,6)
crop 到 (3,13,6)
period=2 分支:
padding 到 length=14
(3,13,6) 到 (3,14,6)
reshape 到 (3,7,2,6)
permute 到 (3,6,7,2)
Conv2d 后保持 (3,6,7,2)
reshape back 到 (3,14,6)
crop 到 (3,13,6)
stack 两个分支: (3,13,6,2)
softmax 权重融合: (3,13,6)
residual: (3,13,6)顺序不能换:必须先 padding,再 reshape。如果不先补齐,T=13 不能整除 period=4,就无法 reshape 成完整二维网格。
5.2 S1 — FFT_for_Period 选择周期
python
B, T, N = x.size()
period_list, period_weight = FFT_for_Period(x, self.k)x.size() 得到:
text
B = 3
T = 13
N = 6FFT_for_Period(x, self.k) 返回:
text
period_list: (2,) 例如 [4, 2]
period_weight: (3,2) 每个 batch 在两个主频上的幅值本层只需要把 period_list 当成后续分支的循环条件。period_list 怎么由 rfft、abs、topk 得到,详见 [[03B1-Layer3-FFT_for_Period]]。
toy 假设:
text
period_list = [4, 2]
period_weight =
batch0: [2.0, 1.0]
batch1: [1.5, 1.2]
batch2: [0.7, 1.4]这表示 batch0 的 period=4 分支更强,batch2 的 period=2 分支更强。
5.3 S2 — 遍历每个周期分支
python
res = []
for i in range(self.k):
period = period_list[i]self.k = top_k = 2,所以循环执行两次:
text
i=0: period = 4
i=1: period = 2每个周期分支都会独立完成:
text
padding
到 reshape + permute
到 self.conv(out)
到 reshape back + crop
到 append 到 resres 最终保存两个分支的输出,每个分支都是 (3,13,6)。
5.4 S2a/S2b — period=4 分支:padding 与 2D 网格
python
if (self.seq_len + self.pred_len) % period != 0:
length = (((self.seq_len + self.pred_len) // period) + 1) * period
padding = torch.zeros(
[x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]
).to(x.device)
out = torch.cat([x, padding], dim=1)
else:
length = self.seq_len + self.pred_len
out = x对 period=4:
text
self.seq_len + self.pred_len = 8 + 5 = 13
13 % 4 = 1,不整除
length = ((13 // 4) + 1) * 4 = (3 + 1) * 4 = 16
padding length = 16 - 13 = 3
padding: (3,3,6)
out: (3,16,6)继续 reshape:
python
out = (
out.reshape(B, length // period, period, N)
.permute(0, 3, 1, 2)
.contiguous()
)shape:
text
out: (3,16,6)
reshape(B, length//period, period, N):
(3,4,4,6)
permute(0,3,1,2):
(3,6,4,4)toy 数值只看 batch=0, hidden=0:
text
x[0,:,0] = [1,2,3,4,5,6,7,8,9,10,11,12,13]
padding 后:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0,0,0]
reshape 到 (周期段数=4, period=4):
row0: [1, 2, 3, 4]
row1: [5, 6, 7, 8]
row2: [9, 10, 11, 12]
row3: [13, 0, 0, 0]permute(0,3,1,2) 之后,hidden=0 这张二维网格变成 Conv2d 的一个输入 channel。完整 tensor 的语义是:
text
(B, N, H, W) = (3, 6, 4, 4)
B = batch
N = hidden channel
H = 周期段数
W = 周期内位置5.5 S2a/S2b — period=2 分支:另一种周期视角
对 period=2:
text
13 % 2 = 1,不整除
length = ((13 // 2) + 1) * 2 = (6 + 1) * 2 = 14
padding length = 14 - 13 = 1
padding: (3,1,6)
out: (3,14,6)shape:
text
out: (3,14,6)
reshape(B, length//period, period, N):
(3,7,2,6)
permute(0,3,1,2):
(3,6,7,2)toy 数值仍看 batch=0, hidden=0:
text
x[0,:,0] = [1,2,3,4,5,6,7,8,9,10,11,12,13]
padding 后:
[1,2,3,4,5,6,7,8,9,10,11,12,13,0]
reshape 到 (周期段数=7, period=2):
row0: [1, 2]
row1: [3, 4]
row2: [5, 6]
row3: [7, 8]
row4: [9, 10]
row5: [11,12]
row6: [13, 0]同一条 1D 序列在 period=4 和 period=2 下被折成两种不同二维视角。TimesNet 不提前押注哪个周期一定正确,而是让两个分支都卷积,再用 period_weight 融合。
5.6 S2c/S2d — self.conv 与 reshape back
python
out = self.conv(out)
out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
res.append(out[:, : (self.seq_len + self.pred_len), :])self.conv 是:
python
self.conv = nn.Sequential(
Inception_Block_V1(
configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
),
nn.GELU(),
Inception_Block_V1(
configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
),
)对 period=4 分支:
text
进入 self.conv: (3,6,4,4)
Inception_Block_V1 6到7: (3,7,4,4)
GELU: (3,7,4,4)
Inception_Block_V1 7到6: (3,6,4,4)
permute(0,2,3,1): (3,4,4,6)
reshape(B,-1,N): (3,16,6)
crop 前 13 步: (3,13,6)对 period=2 分支:
text
进入 self.conv: (3,6,7,2)
conv 后保持: (3,6,7,2)
permute(0,2,3,1): (3,7,2,6)
reshape(B,-1,N): (3,14,6)
crop 前 13 步: (3,13,6)Inception_Block_V1 内部如何用 kernel_size=1,3,5 并行卷积并 torch.stack(res_list, dim=-1).mean(-1),详见 [[03B2-Layer3-Inception_Block_V1]]。
为什么 crop 前 13 步
padding 只是为了让长度能被周期整除。模型真实需要的长度仍然是
,所以卷积后必须 out[:, :13, :]裁剪回原长度。
5.7 S3 — stack 周期分支
python
res = torch.stack(res, dim=-1)循环结束前:
text
res[0]: period=4 分支输出 (3,13,6)
res[1]: period=2 分支输出 (3,13,6)torch.stack(res, dim=-1) 新增最后一维保存分支编号:
text
stack 后: (3,13,6,2)toy 只看一个位置 (batch=0, time=5, hidden=2):
text
period=4 分支输出: branch4[0,5,2] = 10.0
period=2 分支输出: branch2[0,5,2] = 20.0
stack 后:
res[0,5,2,:] = [10.0, 20.0]5.8 S4 — softmax 周期权重并广播
python
period_weight = F.softmax(period_weight, dim=1)
period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)输入 period_weight 来自 FFT 幅值:
text
period_weight: (3,2)对 batch0,假设 FFT 幅值为 [2.0, 1.0]:
shape 广播链:
text
period_weight: (3,2)
unsqueeze(1): (3,1,2)
unsqueeze(1): (3,1,1,2)
repeat(1,T,N,1): (3,13,6,2)这样每个 batch 的两个周期权重会复制到所有时间步和所有 hidden channel 上。
5.9 S5 — 加权求和 + residual
python
res = torch.sum(res * period_weight, -1)
res = res + x
return resres * period_weight 逐元素相乘:
text
res: (3,13,6,2)
period_weight: (3,13,6,2)
相乘后: (3,13,6,2)
sum(-1): (3,13,6)toy 接续上一个位置:
text
res[0,5,2,:] = [10.0, 20.0]
period_weight[0,5,2,:] = [0.731, 0.269]
weighted =
10.0 * 0.731 + 20.0 * 0.269
= 7.31 + 5.38
= 12.69最后加残差:
text
原始输入 x[0,5,2] = 6.0
输出 res[0,5,2] = 12.69 + 6.0 = 18.69residual 的作用
TimesBlock 的卷积分支学习的是周期增强后的变化量。
res + x保留输入 hidden 表示的原始信息,避免周期分支学不好时破坏主干信号。
6. 下钻子组件
| 子组件 | 父层调用位置 | 职责 | 下层文档 |
|---|---|---|---|
FFT_for_Period | period_list, period_weight = FFT_for_Period(x, self.k) | 用 rfft 的频域幅值选主周期 | [[03B1-Layer3-FFT_for_Period]] |
Inception_Block_V1 | out = self.conv(out) | 多 kernel 二维卷积,保持二维网格大小 | [[03B2-Layer3-Inception_Block_V1]] |
7. 出口接回上层
TimesBlock.forward() 返回:
text
res: (3,13,6)回到 [[02-Layer1-forecast主链]] 后立刻进入:
python
enc_out = self.layer_norm(self.model[i](enc_out))LayerNorm 保持 shape (3,13,6),随后 projection 把 hidden 维 6 映射回变量维 4。