Appearance
Layer 2A — __multi_scale_process_inputs() 多尺度输入构建
1. 在父层中的位置
forecast() 第一行调用 __multi_scale_process_inputs(x_enc, x_mark_enc) 生成多尺度列表,是后续所有操作的起点。
2. I/O 接口定义
| 参数 | Shape | 说明 |
|---|---|---|
x_enc | (2, 24, 3) | 原始 encoder 输入 |
x_mark_enc | (2, 24, 4) | 时间标记 |
| 返回 x_enc | [(2,24,3),(2,12,3),(2,6,3)] | down_sampling_layers+1 = 3 个尺度 |
| 返回 x_mark | [(2,24,4),(2,12,4),(2,6,4)] | 对应 mark 列表 |
3. 顺序图
4. 语义分组图
5. 逐步骤精读
§5.0 完整原始代码
python
def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
if self.configs.down_sampling_method == "max":
down_pool = torch.nn.MaxPool1d(
self.configs.down_sampling_window, return_indices=False
)
elif self.configs.down_sampling_method == "avg":
down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)
elif self.configs.down_sampling_method == "conv":
padding = 1 if torch.__version__ >= "1.5.0" else 2
down_pool = nn.Conv1d(
in_channels=self.configs.enc_in,
out_channels=self.configs.enc_in,
kernel_size=3,
padding=padding,
stride=self.configs.down_sampling_window,
padding_mode="circular",
bias=False,
)
else:
return x_enc, x_mark_enc
# B,T,C -> B,C,T
x_enc = x_enc.permute(0, 2, 1)
x_enc_ori = x_enc
x_mark_enc_mark_ori = x_mark_enc
x_enc_sampling_list = []
x_mark_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)
for i in range(self.configs.down_sampling_layers):
x_enc_sampling = down_pool(x_enc_ori)
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling
if x_mark_enc is not None:
x_mark_sampling_list.append(
x_mark_enc_mark_ori[:, :: self.configs.down_sampling_window, :]
)
x_mark_enc_mark_ori = x_mark_enc_mark_ori[
:, :: self.configs.down_sampling_window, :
]
x_enc = x_enc_sampling_list
x_mark_enc = x_mark_sampling_list if x_mark_enc is not None else None
return x_enc, x_mark_enc§5.1 宏观逻辑
目标:以固定窗口大小
用小例子(
首先 permute(0,2,1) 将 (1,8,2) 变为 (1,2,8),因为 AvgPool1d 期望输入格式为 (1,8,2)(permute 还原)直接加入结果列表。循环 (1,2,8) 的时间维 (1,2,4),permute 回 (1,4,2);循环 (1,2,4) → (1,2,2),permute 回 (1,2,2)。结果列表为 [(1,8,2),(1,4,2),(1,2,2)]。
为什么先 permute 再 Pool 再 permute 回来? PyTorch AvgPool1d 对最后一维操作,要求格式
AvgPool vs 可学习下采样
AvgPool1d 无参数:下采样 = 局部均值,纯粹的信息粗化,不引入特征变换。可学习的 Conv 下采样会在压缩时学习特征,把"改变分辨率"和"改变特征语义"两件事混在一起,与"多尺度信息聚合"的设计意图冲突。
§5.2 步骤 1 — 构建 down_pool
python
elif self.configs.down_sampling_method == "avg":
down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window)形状注解: TFB 默认 down_sampling_method="avg",down_sampling_window=2,故 down_pool = AvgPool1d(kernel_size=2, stride=2)。输入格式须为
toy 数值: AvgPool1d(2) 作用于 (2, 3, 24) 时,对时间维每两步取均值:位置 0 = (v[0]+v[1])/2,位置 1 = (v[2]+v[3])/2,…,共 12 个输出,得 (2, 3, 12)。
§5.3 步骤 2 — permute 准备 + 加入原始尺度
python
x_enc = x_enc.permute(0, 2, 1)
x_enc_ori = x_enc
x_enc_sampling_list = []
x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
x_mark_sampling_list.append(x_mark_enc)形状注解: x_enc.permute(0, 2, 1) 将轴顺序由 (2,24,3) → (2,3,24),赋给 x_enc 和 x_enc_ori。x_enc_sampling_list.append(x_enc.permute(0,2,1)) 立刻 permute 回 (2,24,3) 加入列表,即原始尺度直接收录。
toy 数值: 第一行 permute 后 x_enc shape (2,3,24),x_enc_ori 指向同一张量。x_enc.permute(0,2,1) 得 (2,24,3),这是尺度 0。x_enc_sampling_list = [(2,24,3)],x_mark_sampling_list = [(2,24,4)]。
§5.4 步骤 3 — 循环下采样(down_sampling_layers=2 次)
python
for i in range(self.configs.down_sampling_layers):
x_enc_sampling = down_pool(x_enc_ori)
x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
x_enc_ori = x_enc_sampling
if x_mark_enc is not None:
x_mark_sampling_list.append(
x_mark_enc_mark_ori[:, :: self.configs.down_sampling_window, :]
)
x_mark_enc_mark_ori = x_mark_enc_mark_ori[
:, :: self.configs.down_sampling_window, :
]形状注解(down_pool(x_enc_ori=(2,3,24)) → AvgPool1d(2) 在 (2,3,12)。.permute(0,2,1) → (2,12,3) 加入列表。x_enc_ori 更新为 (2,3,12)。mark:x_mark_enc_mark_ori[:, ::2, :] 取步长为 2 的切片,(2,24,4) → (2,12,4) 加入列表,x_mark_enc_mark_ori 更新为 (2,12,4)。
形状注解(down_pool(x_enc_ori=(2,3,12)) → (2,3,6),permute → (2,6,3) 加入列表。mark:(2,12,4)[:, ::2, :] → (2,6,4)。
toy 数值: 循环结束后 x_enc_sampling_list = [(2,24,3), (2,12,3), (2,6,3)],x_mark_sampling_list = [(2,24,4), (2,12,4), (2,6,4)]。三个尺度时间维分别为
mark 为何用 stride 切片而不用 Pool? mark 是时间标记(小时/星期几等离散整数)。对两个时刻的"小时"取均值毫无语义(均值 = 3.5 不是任何真实时刻)。stride 切片 [::2] 每隔一步取一个真实时间戳,保留等间隔采样点的语义。
==切片语法补充:x_mark_enc_mark_ori[:, ::2, :] 即取时间位置:0, 2, 4, 6, ...==