Skip to content

Layer 3A — MultiScaleSeasonMixing(底向上季节混合)

1. 在父层中的位置

PDM 块 forward()out_season_list = self.mixing_multi_scale_season(season_list) 调用。

2. I/O 接口定义

参数Shape说明
season_list[(6,8,24),(6,8,12),(6,8,6)]各尺度季节成分,(BN, d, Ti)
返回[(6,24,8),(6,12,8),(6,6,8)]混合后,permute 回 (BN, Ti, d)

3. 顺序图

4. 逐步骤精读

§5.0 完整原始代码

python
class MultiScaleSeasonMixing(nn.Module):
    def __init__(self, configs):
        super(MultiScaleSeasonMixing, self).__init__()

        self.down_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window**i),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                )
                for i in range(configs.down_sampling_layers)
            ]
        )

    def forward(self, season_list):
        # mixing high->low
        out_high = season_list[0]
        out_low = season_list[1]
        out_season_list = [out_high.permute(0, 2, 1)]

        for i in range(len(season_list) - 1):
            out_low_res = self.down_sampling_layers[i](out_high)
            out_low = out_low + out_low_res
            out_high = out_low
            if i + 2 <= len(season_list) - 1:
                out_low = season_list[i + 2]
            out_season_list.append(out_high.permute(0, 2, 1))

        return out_season_list

§5.1 宏观逻辑

设计直觉:细粒度(高分辨率)的季节成分包含最丰富的振荡细节,应当向粗粒度"渗透"——让粗粒度尺度借鉴细粒度的振荡模式,而不是只看自己平滑后的近似。传播方向:scale0(T=24)→ scale1(T=12)→ scale2(T=6),即从细到粗,故称"底向上(bottom-up)"。

用小例子(BN=2d=4,三个尺度 T=8,4,2)串起:out_high=(2,4,8)out_low=(2,4,4)down_layers[0] 是 Linear(84) + GELU + Linear(44),作用于 out_high 的最后维(T轴),输出 (2,4,4),加到 out_low:粗粒度 scale1 吸收了 scale0 的细节。然后 out_high = out_low(此时 scale1 已含 scale0 信息),out_low = scale2(2,4,2)down_layers[1] 是 Linear(42) + GELU + Linear(22),将含 scale0 信息的 scale1 进一步压缩后加到 scale2——scale0 信息经两跳传播到了最粗尺度。

bottom-up 传播链示意(d 维在 lane 内,T 维是 Linear 操作轴):

scale0 (T=24) ─── down_layers[0] ─── ⊕ ──→ scale1' (T=12)

                                 scale1 (T=12, 原始)

scale1' (T=12) ─── down_layers[1] ─── ⊕ ──→ scale2' (T=6)

                                   scale2 (T=6, 原始)

§5.2 down_sampling_layers 的 Linear 维度

python
self.down_sampling_layers = torch.nn.ModuleList(
	[
		nn.Sequential(
			torch.nn.Linear(
				configs.seq_len // (configs.down_sampling_window**i),
				configs.seq_len // (configs.down_sampling_window ** (i + 1)),
			),
			nn.GELU(),
			torch.nn.Linear(
				configs.seq_len // (configs.down_sampling_window ** (i + 1)),
				configs.seq_len // (configs.down_sampling_window ** (i + 1)),
			),
		)
		for i in range(configs.down_sampling_layers)
	]
)

toy 参数(seq_len=24, window=2, down_sampling_layers=2):

down_sampling_layers[0]i=0 构建:Linear(24//20, 24//21) = Linear(24,12),GELU,Linear(12,12)。这个 Sequential 接受最后维为 24 的张量,输出最后维为 12 的张量。

down_sampling_layers[1]i=1 构建:Linear(24//21, 24//22) = Linear(12,6),GELU,Linear(6,6)

两层 Linear 的含义:第一层压缩(TiTi+1),第二层做非线性精炼(Ti+1Ti+1,维度不变)。

§5.3 toy 数值完整追踪

输入: season_list = [(6,8,24), (6,8,12), (6,8,6)]

初始化:

out_high = season_list[0] → shape (6, 8, 24)(scale0)。out_low = season_list[1] → shape (6, 8, 12)(scale1)。out_season_list = [out_high.permute(0,2,1)] = [(6,24,8)],scale0 原封不动(仅 permute)加入结果。

循环 i=0(共 2 次,len(season_list)-1=2):

down_sampling_layers[0](Linear2412 + GELU + Linear1212)作用于 out_high=(6,8,24) 的最后维,输出 out_low_res shape (6, 8, 12)out_low = (6,8,12) + (6,8,12) = (6,8,12)(scale1 原始 + 来自 scale0 的压缩残差)。out_high = out_low = (6,8,12)out_high 指针更新为混合后的 scale1)。i+2=2 ≤ len-1=2,故 out_low = season_list[2] = (6,8,6)(下一轮的接收方 scale2)。out_season_list.append(out_high.permute(0,2,1)) → append (6,12,8)out_season_list = [(6,24,8),(6,12,8)]

循环 i=1

down_sampling_layers[1](Linear126 + GELU + Linear66)作用于 out_high=(6,8,12) 的最后维,输出 out_low_res shape (6, 8, 6)out_low = (6,8,6) + (6,8,6) = (6,8,6)(scale2 原始 + 来自混合后 scale1 的压缩残差)。out_high = out_low = (6,8,6)i+2=3 > len-1=2,不更新 out_lowout_season_list.append(out_high.permute(0,2,1)) → append (6,6,8)out_season_list = [(6,24,8),(6,12,8),(6,6,8)]

返回: [(6,24,8),(6,12,8),(6,6,8)]

信息传播分析:

  • out_season_list[0] = (6,24,8):scale0 原始,未被混合
  • out_season_list[1] = (6,12,8):scale1 原始 ⊕ scale0 经 down_layers[0] 压缩的残差
  • out_season_list[2] = (6,6,8):scale2 原始 ⊕ (scale1 ⊕ scale0 信息) 经 down_layers[1] 再次压缩的残差

scale0 的细粒度振荡信息经两跳链式传播抵达 scale2。

out_high = out_low 是指针更新,不是拷贝

循环内 out_high = out_lowout_high 指向上一步更新后的 out_low(已融合细粒度信息的粗粒度尺度),而不是原始 season_list 中的对应尺度。这是链式传播能实现的关键——不是"scale0 直接给 scale2",而是"scale0 给 scale1,混合后的 scale1 再给 scale2"。

*记录并在线阅读我的笔记*