Skip to content

Layer 2A — __multi_scale_process_inputs() 多尺度输入构建

1. 在父层中的位置

forecast() 第一行调用 __multi_scale_process_inputs(x_enc, x_mark_enc) 生成多尺度列表,是后续所有操作的起点。

2. I/O 接口定义

参数Shape说明
x_enc(2, 24, 3)原始 encoder 输入
x_mark_enc(2, 24, 4)时间标记
返回 x_enc[(2,24,3),(2,12,3),(2,6,3)]down_sampling_layers+1 = 3 个尺度
返回 x_mark[(2,24,4),(2,12,4),(2,6,4)]对应 mark 列表

3. 顺序图

4. 语义分组图

5. 逐步骤精读

§5.0 完整原始代码

python
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
    if self.configs.down_sampling_method == "max":
        down_pool = torch.nn.MaxPool1d(
            self.configs.down_sampling_window, return_indices=False
        )
    elif self.configs.down_sampling_method == "avg":
        down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)
    elif self.configs.down_sampling_method == "conv":
        padding = 1 if torch.__version__ >= "1.5.0" else 2
        down_pool = nn.Conv1d(
            in_channels=self.configs.enc_in,
            out_channels=self.configs.enc_in,
            kernel_size=3,
            padding=padding,
            stride=self.configs.down_sampling_window,
            padding_mode="circular",
            bias=False,
        )
    else:
        return x_enc, x_mark_enc
    # B,T,C -> B,C,T
    x_enc = x_enc.permute(0, 2, 1)

    x_enc_ori = x_enc
    x_mark_enc_mark_ori = x_mark_enc

    x_enc_sampling_list = []
    x_mark_sampling_list = []
    x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
    x_mark_sampling_list.append(x_mark_enc)

    for i in range(self.configs.down_sampling_layers):
        x_enc_sampling = down_pool(x_enc_ori)

        x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
        x_enc_ori = x_enc_sampling

        if x_mark_enc is not None:
            x_mark_sampling_list.append(
                x_mark_enc_mark_ori[:, :: self.configs.down_sampling_window, :]
            )
            x_mark_enc_mark_ori = x_mark_enc_mark_ori[
                :, :: self.configs.down_sampling_window, :
            ]

    x_enc = x_enc_sampling_list
    x_mark_enc = x_mark_sampling_list if x_mark_enc is not None else None

    return x_enc, x_mark_enc

§5.1 宏观逻辑

目标:以固定窗口大小 w=2 对时序逐级下采样,生成 layers+1 个不同分辨率的副本,构成多尺度列表。

用小例子(B=1,N=2,T=8,w=2,layers=2)串起:

首先 permute(0,2,1)(1,8,2) 变为 (1,2,8),因为 AvgPool1d 期望输入格式为 (B,C,T)(通道在前,时间在后)。然后把原始 (1,8,2)(permute 还原)直接加入结果列表。循环 i=0:Pool 作用在 (1,2,8) 的时间维 T=8T=4,得 (1,2,4),permute 回 (1,4,2);循环 i=1:Pool 作用在 (1,2,4)T=2,得 (1,2,2),permute 回 (1,2,2)。结果列表为 [(1,8,2),(1,4,2),(1,2,2)]

为什么先 permute 再 Pool 再 permute 回来? PyTorch AvgPool1d 对最后一维操作,要求格式 (B,C,T),而模型内部统一使用 (B,T,C) 格式。permute 是格式适配层,不改变数据内容,只调整轴顺序。

AvgPool vs 可学习下采样

AvgPool1d 无参数:下采样 = 局部均值,纯粹的信息粗化,不引入特征变换。可学习的 Conv 下采样会在压缩时学习特征,把"改变分辨率"和"改变特征语义"两件事混在一起,与"多尺度信息聚合"的设计意图冲突。

§5.2 步骤 1 — 构建 down_pool

python
elif self.configs.down_sampling_method == "avg":
    down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)

形状注解: TFB 默认 down_sampling_method="avg"down_sampling_window=2,故 down_pool = AvgPool1d(kernel_size=2, stride=2)。输入格式须为 (B,C,T),在 T 维上每 2 个值取均值,T 减半。

toy 数值: AvgPool1d(2) 作用于 (2, 3, 24) 时,对时间维每两步取均值:位置 0 = (v[0]+v[1])/2,位置 1 = (v[2]+v[3])/2,…,共 12 个输出,得 (2, 3, 12)

§5.3 步骤 2 — permute 准备 + 加入原始尺度

python
x_enc = x_enc.permute(0, 2, 1)
x_enc_ori = x_enc
x_enc_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)

形状注解: x_enc.permute(0, 2, 1) 将轴顺序由 (B,T,N)(B,N,T),即 (2,24,3)(2,3,24),赋给 x_encx_enc_orix_enc_sampling_list.append(x_enc.permute(0,2,1)) 立刻 permute 回 (2,24,3) 加入列表,即原始尺度直接收录。

toy 数值: 第一行 permute 后 x_enc shape (2,3,24)x_enc_ori 指向同一张量。x_enc.permute(0,2,1)(2,24,3),这是尺度 0。x_enc_sampling_list = [(2,24,3)]x_mark_sampling_list = [(2,24,4)]

§5.4 步骤 3 — 循环下采样(down_sampling_layers=2 次)

python
for i in range(self.configs.down_sampling_layers):
    x_enc_sampling = down_pool(x_enc_ori)
    x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
    x_enc_ori = x_enc_sampling
    if x_mark_enc is not None:
        x_mark_sampling_list.append(
            x_mark_enc_mark_ori[:, :: self.configs.down_sampling_window, :]
        )
        x_mark_enc_mark_ori = x_mark_enc_mark_ori[
            :, :: self.configs.down_sampling_window, :
        ]

形状注解(i=0): down_pool(x_enc_ori=(2,3,24))AvgPool1d(2)T=24 维操作 → (2,3,12).permute(0,2,1)(2,12,3) 加入列表。x_enc_ori 更新为 (2,3,12)。mark:x_mark_enc_mark_ori[:, ::2, :] 取步长为 2 的切片,(2,24,4)(2,12,4) 加入列表,x_mark_enc_mark_ori 更新为 (2,12,4)

形状注解(i=1): down_pool(x_enc_ori=(2,3,12))(2,3,6),permute → (2,6,3) 加入列表。mark:(2,12,4)[:, ::2, :](2,6,4)

toy 数值: 循环结束后 x_enc_sampling_list = [(2,24,3), (2,12,3), (2,6,3)]x_mark_sampling_list = [(2,24,4), (2,12,4), (2,6,4)]。三个尺度时间维分别为 T=24T=12T=6,比例为 1:1/2:1/4

mark 为何用 stride 切片而不用 Pool? mark 是时间标记(小时/星期几等离散整数)。对两个时刻的"小时"取均值毫无语义(均值 = 3.5 不是任何真实时刻)。stride 切片 [::2] 每隔一步取一个真实时间戳,保留等间隔采样点的语义。

==切片语法补充:x_mark_enc_mark_ori[:, ::2, :] 即取时间位置:0, 2, 4, 6, ...==

*记录并在线阅读我的笔记*