Appearance
Level4A PatchEmbedding 精读
Abstract
覆盖:
Embed.py:PatchEmbedding(第 181-206 行)输入:
(B, enc_in, seq_len)=(2, 3, 9)输出:(B×enc_in, patch_num, d_model)=(6, 4, 8),同时返回n_vars=3这一步做了三件事:①把时间轴切成重叠的 patch;②把 batch 和 channel 合并为 channel-independent 的大 batch;③把每个 patch 投影到 d_model 维。
Toy 参数(本文统一)
B=2, seq_len=9, enc_in=3
d_model=8, patch_len=4, stride=2, padding=2
→ patch_num = int((9-4)/2 + 2) = 4
→ B×enc_in = 6原理→代码映射
| 论文步骤 | 对应代码 | 文件行 | 说明 |
|---|---|---|---|
| ① 右端补 padding | self.padding_patch_layer(x) | Embed.py:202 | ReplicationPad1d 复制最后一个时间步,不是补零 |
| ② 滑窗切 patch | x.unfold(-1, patch_len, stride) | Embed.py:203 | 论文图里的"切片"操作,每次移动 stride 步 |
| ③ 合并 B×enc_in | torch.reshape(x, (B*enc_in, ...)) | Embed.py:204 | channel-independent 的实现方式 |
| ④ patch→d_model | self.value_embedding(x) | Embed.py:205 | Linear(patch_len→d_model),作用在最后一维 |
| ⑤ 加位置编码 | + self.position_embedding(x) | Embed.py:205 | 广播加法,shape 不变 |
最容易卡住的两处
卡点 1:unfold 究竟做了什么?
unfold(dimension=-1, size=4, step=2) = 在最后一维上用大小为 4 的滑窗、每次移动 2 步提取子序列。
关键:相邻 patch 有重叠(overlap = patch_len − stride = 2),边界时间步的信息不会因为 patch 切割而丢失。
卡点 2:reshape 为什么能实现 channel-independent?
把 (B=2, enc_in=3) 合并成 B×enc_in=6,相当于把"2个样本×3个变量"视为"6个独立样本"。后续 Transformer 对这 6 条序列完全独立处理,变量之间没有任何 attention 交互。
代码上就是一行
torch.reshape,但在论文设计层面,这就是 channel-independent 思想的完整实现。
1. 原始代码
python
# Embed.py:181
class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, padding, dropout):
super(PatchEmbedding, self).__init__()
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
self.position_embedding = PositionalEmbedding(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# do patching
n_vars = x.shape[1]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# Input encoding
x = self.value_embedding(x) + self.position_embedding(x)
return self.dropout(x), n_vars2. 步骤拆解
步骤 A:记录 n_vars
python
n_vars = x.shape[1] # x: (2, 3, 9), x.shape[1] = 3
# n_vars = 3为什么要提前记录:下一步 reshape 会把 batch 和 channel 合并,合并后就无法从 shape 里区分出原来是几个 channel 了。n_vars 在 forecast() 的步骤 5(reshape 恢复 batch/channel)时需要用到。
步骤 B:ReplicationPad1d 右端填充
原始代码:
python
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
# padding = stride = 2
x = self.padding_patch_layer(x)注解版:
python
# x: (2, 3, 9) ← (batch, channel, time)
# ReplicationPad1d((0, 2)): 左端补 0 个,右端补 2 个
# 补的方式:复制边缘值(replication),不是补零
x = self.padding_patch_layer(x)
# x: (2, 3, 11)ReplicationPad1d 说明:
nn.ReplicationPad1d 的输入必须是 (*, L) 形状的张量,L 是被填充的维度。参数 (left, right) 表示左端补几个、右端补几个,填充内容是边缘元素的复制。
与 ZeroPad1d 的区别:
ZeroPad1d:补 0ReplicationPad1d:复制边缘值
为什么用复制而非补零:时间序列里,补零会引入不真实的"断崖",复制边缘值更符合信号的平稳性假设,减少边界效应对最后一个 patch 的影响。
toy 数值追踪(只看 batch=0, channel=0 的一维序列):
原始时序: [a0, a1, a2, a3, a4, a5, a6, a7, a8] (length=9)
↑ 最后一个元素 a8
填充后: [a0, a1, a2, a3, a4, a5, a6, a7, a8, a8, a8] (length=11)
↑ 复制了 a8 两次用具体数字:
原始: [1, 3, 5, 2, 4, 6, 3, 5, 7]
填充后: [1, 3, 5, 2, 4, 6, 3, 5, 7, 7, 7]
↑ ↑ 复制边缘值 7shape 变化:(2, 3, 9) → (2, 3, 11)
步骤 C:unfold 切 patch
原始代码:
python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)注解版:
python
# x: (2, 3, 11)
# unfold(dimension=-1, size=4, step=2)
# 在最后一维(time)上滑窗:窗口大小=4,步长=2
# patch_num = floor((11 - 4) / 2) + 1 = floor(3.5) + 1 = 3 + 1 = 4
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
# x: (2, 3, 4, 4) ← (batch, channel, patch_num, patch_len)unfold 详解:
tensor.unfold(dimension, size, step) 在指定维度上滑动一个窗口,返回每个窗口的内容。
dimension=-1:在最后一维(时间轴)上操作size=4:每个窗口(patch)的大小step=2:每次滑动的步长
输出维度计算:patch_num = floor((L - size) / step) + 1
代入:floor((11 - 4) / 2) + 1 = floor(3.5) + 1 = 4
toy 数值追踪(batch=0, channel=0,填充后序列 [1,3,5,2,4,6,3,5,7,7,7]):
滑窗过程(size=4, step=2):
位置 0(步长起点 0): [1, 3, 5, 2] ← 索引 0~3
位置 1(步长起点 2): [5, 2, 4, 6] ← 索引 2~5
位置 2(步长起点 4): [4, 6, 3, 5] ← 索引 4~7
位置 3(步长起点 6): [3, 5, 7, 7] ← 索引 6~9
patch_num = 4 ✓
(步长起点 8 时:索引 8~11,但 11 > 10,越界,停止)结果:这条时序的 4 个 patch = [ [1,3,5,2], [5,2,4,6], [4,6,3,5], [3,5,7,7] ]
shape 变化:(2, 3, 11) → (2, 3, 4, 4) ← (batch, channel, patch_num, patch_len)
步骤 D:reshape 合并 batch 和 channel
原始代码:
python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))注解版:
python
# x: (2, 3, 4, 4) ← (batch, channel, patch_num, patch_len)
# reshape: batch × channel 合并为新的 batch 维
# x.shape[0] * x.shape[1] = 2 × 3 = 6
# x.shape[2] = 4 (patch_num)
# x.shape[3] = 4 (patch_len)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# x: (6, 4, 4) ← (B×enc_in, patch_num, patch_len)channel-independent 的本质:
reshape 后,原本属于"batch 0 的 channel 0"、"batch 0 的 channel 1"、"batch 1 的 channel 0" 等 6 条序列,变成了 6 个独立的"样本"。它们在后续 Encoder 里完全不知道彼此的存在,互相独立地经过 attention 和 FFN。
toy 数值追踪(reshape 前后的数组排列):
reshape 前 (2, 3, 4, 4):
[0]: batch=0, channel=0 → 4 个 patch
[1]: batch=0, channel=1 → 4 个 patch
[2]: batch=0, channel=2 → 4 个 patch
[3]: batch=1, channel=0 → 4 个 patch
[4]: batch=1, channel=1 → 4 个 patch
[5]: batch=1, channel=2 → 4 个 patch
reshape 后 (6, 4, 4):
[0]: 原 batch=0, channel=0 的 4 个 patch:[ [1,3,5,2],[5,2,4,6],[4,6,3,5],[3,5,7,7] ]
[1]: 原 batch=0, channel=1 的 4 个 patch
[2]: 原 batch=0, channel=2 的 4 个 patch
[3]: 原 batch=1, channel=0 的 4 个 patch
[4]: 原 batch=1, channel=1 的 4 个 patch
[5]: 原 batch=1, channel=2 的 4 个 patchshape 变化:(2, 3, 4, 4) → (6, 4, 4)
步骤 E:value_embedding(patch 投影)
原始代码:
python
x = self.value_embedding(x) + self.position_embedding(x)注解版(value_embedding 部分):
python
# x: (6, 4, 4) ← (B×enc_in, patch_num, patch_len)
# nn.Linear(patch_len=4, d_model=8, bias=False)
# Linear 作用在最后一维(patch_len=4),前面 (6, 4) 是 batch 维
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
# x = self.value_embedding(x)
# x: (6, 4, 8) ← (B×enc_in, patch_num, d_model)nn.Linear 的维度语义:
nn.Linear(in_features, out_features) 只对输入的最后一维做线性变换,其他维度全部视为 batch 维度。
这里:
- 输入最后一维:patch_len=4(一个 patch 内的 4 个时间步)
- 输出最后一维:d_model=8(嵌入维度)
- 前面的
(6, 4)=(B×enc_in, patch_num),全部 batch 处理
等价于:对 6×4=24 个 patch,每个 patch 各自独立做一次 (4,) → (8,) 的线性投影。
toy 数值追踪(只追踪 x[0, 0, :],即第 0 个虚拟 batch 的第 0 个 patch):
输入 patch: [1.0, 3.0, 5.0, 2.0] (已归一化后的值)
W 是形状 (8, 4) 的权重矩阵(bias=False)。
设 W 的第 0 行为 [0.1, -0.2, 0.3, -0.1]:
output[0] = 0.1×1.0 + (-0.2)×3.0 + 0.3×5.0 + (-0.1)×2.0
= 0.1 - 0.6 + 1.5 - 0.2 = 0.8
(依此类推计算 8 个输出分量)
输出: [0.8, ..., ...] ← (d_model=8,)shape 变化:(6, 4, 4) → (6, 4, 8)
步骤 F:position_embedding 加法
原始代码:
python
x = self.value_embedding(x) + self.position_embedding(x)注解版(position_embedding 部分):
python
# self.position_embedding = PositionalEmbedding(d_model=8)
# PositionalEmbedding 输出 shape: (1, seq_len, d_model)
# 这里 seq_len 对应 patch_num=4,d_model=8
# → position_embedding(x): (1, 4, 8),广播加到 (6, 4, 8)PositionalEmbedding 用标准的正弦/余弦位置编码,对 patch 的"位置"(第几个 patch)进行编码,让模型知道时间顺序。
shape 分析:
value_embedding(x): (6, 4, 8)
position_embedding(x): (1, 4, 8) ← 广播
相加结果: (6, 4, 8)步骤 G:Dropout
python
return self.dropout(x), n_vars
# x: (6, 4, 8)
# n_vars: 3最终输出 (6, 4, 8) 和 n_vars=3,传回 forecast() 步骤 3。
3. 完整 shape 链汇总
输入 x: (2, 3, 9) ← forecast() permute 后
↓ n_vars = 3
↓ ReplicationPad1d → (2, 3, 11)
↓ unfold(-1, 4, 2) → (2, 3, 4, 4)
↓ reshape → (6, 4, 4)
↓ value_embedding → (6, 4, 8)
↓ + position_embedding
↓ dropout
输出 enc_out: (6, 4, 8), n_vars=34. 下一步
04B-Encoder精读:(6, 4, 8) 输入 Encoder,经过 attention + FFN,再还原成 (2, 7, 3)。