Skip to content

Layer 3B — MultiScaleTrendMixing(顶向下趋势混合)

1. 在父层中的位置

PDM 块 forward()out_trend_list = self.mixing_multi_scale_trend(trend_list) 调用。

2. I/O 接口定义

参数Shape说明
trend_list[(6,8,24),(6,8,12),(6,8,6)]各尺度趋势成分,(BN, d, Ti)
返回[(6,24,8),(6,12,8),(6,6,8)]混合后,permute 回 (BN, Ti, d),已从逆序还原

3. 顺序图

4. 逐步骤精读

§5.0 完整原始代码

python
class MultiScaleTrendMixing(nn.Module):
    def __init__(self, configs):
        super(MultiScaleTrendMixing, self).__init__()

        self.up_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window**i),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window**i),
                        configs.seq_len // (configs.down_sampling_window**i),
                    ),
                )
                for i in reversed(range(configs.down_sampling_layers))
            ]
        )

    def forward(self, trend_list):
        # mixing low->high
        trend_list_reverse = trend_list.copy()
        trend_list_reverse.reverse()
        out_low = trend_list_reverse[0]
        out_high = trend_list_reverse[1]
        out_trend_list = [out_low.permute(0, 2, 1)]

        for i in range(len(trend_list_reverse) - 1):
            out_high_res = self.up_sampling_layers[i](out_low)
            out_high = out_high + out_high_res
            out_low = out_high
            if i + 2 <= len(trend_list_reverse) - 1:
                out_high = trend_list_reverse[i + 2]
            out_trend_list.append(out_low.permute(0, 2, 1))

        out_trend_list.reverse()
        return out_trend_list

§5.1 宏观逻辑

设计直觉:粗粒度尺度(T=6)已被 AvgPool 多次均值平滑,趋势信号最干净。通过线性上采样(扩展时间维)把粗粒度趋势"广播"给细粒度,让细粒度的趋势成分对齐到全局背景。传播方向:scale2(T=6)→ scale1(T=12)→ scale0(T=24),故称"顶向下(top-down)"。

top-down 传播链示意:

scale2 (T=6)  ─── up_layers[0] ─── ⊕ ──→ scale1' (T=12)

                               scale1 (T=12, 原始)

scale1' (T=12) ─── up_layers[1] ─── ⊕ ──→ scale0' (T=24)

                                scale0 (T=24, 原始)

这与 SeasonMixing 完全对称:Season 用 down_layers(压缩 T),Trend 用 up_layers(扩展 T);Season 从 scale0 出发,Trend 从 scale2 出发。

§5.2 up_sampling_layers 初始化的关键细节

构建时 for i in reversed(range(configs.down_sampling_layers)),当 down_sampling_layers=2,即 reversed(range(2)) = [1, 0]

up_sampling_layers[0]i=1 构建:Linear(24//21+1, 24//21) = Linear(6,12),GELU,Linear(12,12)。用于 scale2 → scale1 的上采样。

up_sampling_layers[1]i=0 构建:Linear(24//20+1, 24//20) = Linear(12,24),GELU,Linear(24,24)。用于 scale1 → scale0 的上采样。

为什么用 reversed 保证 up_sampling_layers 的列表顺序与 forward 中的循环顺序一致:循环 i=0up_sampling_layers[0](做 612),循环 i=1up_sampling_layers[1](做 1224)。如果不 reversedi=0 会对应 Linear(12,24),作用于 T=6 的输入,维度不匹配报错。

§5.3 toy 数值完整追踪

输入: trend_list = [(6,8,24), (6,8,12), (6,8,6)]

逆序:

trend_list_reverse = [(6,8,6), (6,8,12), (6,8,24)].copy().reverse(),不修改原 list)。out_low = trend_list_reverse[0](6,8,6)(最粗尺度 scale2)。out_high = trend_list_reverse[1](6,8,12)(scale1)。out_trend_list = [out_low.permute(0,2,1)] = [(6,6,8)],scale2 原封不动加入结果(仅 permute)。

循环 i=0(共 2 次):

up_sampling_layers[0](Linear612 + GELU + Linear1212)作用于 out_low=(6,8,6) 的最后维(T=6),输出 out_high_res shape (6, 8, 12)out_high = (6,8,12) + (6,8,12) = (6,8,12)(scale1 原始 + 来自 scale2 的上采样残差)。out_low = out_high = (6,8,12)out_low 更新为混合后的 scale1)。i+2=2 ≤ len-1=2,故 out_high = trend_list_reverse[2] = (6,8,24)(下一轮的接收方 scale0 原始)。out_trend_list.append(out_low.permute(0,2,1)) → append (6,12,8)out_trend_list = [(6,6,8),(6,12,8)]

循环 i=1

up_sampling_layers[1](Linear1224 + GELU + Linear2424)作用于 out_low=(6,8,12) 的最后维(T=12),输出 out_high_res shape (6, 8, 24)out_high = (6,8,24) + (6,8,24) = (6,8,24)(scale0 原始 + 来自混合后 scale1 的上采样残差)。out_low = out_high = (6,8,24)i+2=3 > len-1=2,不更新 out_highout_trend_list.append(out_low.permute(0,2,1)) → append (6,24,8)out_trend_list = [(6,6,8),(6,12,8),(6,24,8)]

还原顺序:

out_trend_list.reverse()[(6,24,8),(6,12,8),(6,6,8)],与原始 trend_list 顺序(细→粗)一致,才能与 PDM 的 zip 循环正确配对。

返回: [(6,24,8),(6,12,8),(6,6,8)]

信息传播分析:

  • out_trend_list[0] = (6,24,8):scale0 原始 ⊕ (scale1 ⊕ scale2 信息) 经 up_layers[1] 上采样的残差
  • out_trend_list[1] = (6,12,8):scale1 原始 ⊕ scale2 经 up_layers[0] 上采样的残差
  • out_trend_list[2] = (6,6,8):scale2 原始,未被混合

scale2 的粗粒度趋势经两跳抵达 scale0。

5. Season vs Trend 对比表

维度SeasonMixingTrendMixing
传播方向细→粗(bottom-up)粗→细(top-down)
层参数down_sampling_layersup_sampling_layers
Linear 方向压缩 T(长→短)扩展 T(短→长)
__init__ 循环for i in range(layers)for i in reversed(range(layers))
起点(首入 list 元素)scale0(最细)scale2(最粗,逆序后 index=0)
终点(未被混合的尺度)scale0(结果[0]不经残差)scale2(结果[2]不经残差)
最后是否 reverse❌ 直接返回✅ 需 reverse 还原细→粗顺序

*记录并在线阅读我的笔记*