Skip to content

Level4A PatchEmbedding 精读

Abstract

覆盖:Embed.py:PatchEmbedding(第 181-206 行)

输入:(B, enc_in, seq_len) = (2, 3, 9) 输出:(B×enc_in, patch_num, d_model) = (6, 4, 8),同时返回 n_vars=3

这一步做了三件事:①把时间轴切成重叠的 patch;②把 batch 和 channel 合并为 channel-independent 的大 batch;③把每个 patch 投影到 d_model 维。


Toy 参数(本文统一)

B=2, seq_len=9, enc_in=3
d_model=8, patch_len=4, stride=2, padding=2
→ patch_num = int((9-4)/2 + 2) = 4
→ B×enc_in = 6

原理→代码映射

论文步骤对应代码文件行说明
① 右端补 paddingself.padding_patch_layer(x)Embed.py:202ReplicationPad1d 复制最后一个时间步,不是补零
② 滑窗切 patchx.unfold(-1, patch_len, stride)Embed.py:203论文图里的"切片"操作,每次移动 stride 步
③ 合并 B×enc_intorch.reshape(x, (B*enc_in, ...))Embed.py:204channel-independent 的实现方式
④ patch→d_modelself.value_embedding(x)Embed.py:205Linear(patch_len→d_model),作用在最后一维
⑤ 加位置编码+ self.position_embedding(x)Embed.py:205广播加法,shape 不变

最容易卡住的两处

卡点 1:unfold 究竟做了什么?

unfold(dimension=-1, size=4, step=2) = 在最后一维上用大小为 4 的滑窗、每次移动 2 步提取子序列。

关键:相邻 patch 有重叠(overlap = patch_len − stride = 2),边界时间步的信息不会因为 patch 切割而丢失。

卡点 2:reshape 为什么能实现 channel-independent?

(B=2, enc_in=3) 合并成 B×enc_in=6,相当于把"2个样本×3个变量"视为"6个独立样本"。后续 Transformer 对这 6 条序列完全独立处理,变量之间没有任何 attention 交互。

代码上就是一行 torch.reshape,但在论文设计层面,这就是 channel-independent 思想的完整实现。


1. 原始代码

python
# Embed.py:181
class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super(PatchEmbedding, self).__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching
        n_vars = x.shape[1]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x), n_vars

2. 步骤拆解

步骤 A:记录 n_vars

python
n_vars = x.shape[1]   # x: (2, 3, 9), x.shape[1] = 3
# n_vars = 3

为什么要提前记录:下一步 reshape 会把 batch 和 channel 合并,合并后就无法从 shape 里区分出原来是几个 channel 了。n_varsforecast() 的步骤 5(reshape 恢复 batch/channel)时需要用到。


步骤 B:ReplicationPad1d 右端填充

原始代码

python
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
# padding = stride = 2
x = self.padding_patch_layer(x)

注解版

python
# x: (2, 3, 9)  ← (batch, channel, time)
# ReplicationPad1d((0, 2)): 左端补 0 个,右端补 2 个
# 补的方式:复制边缘值(replication),不是补零
x = self.padding_patch_layer(x)
# x: (2, 3, 11)

ReplicationPad1d 说明

nn.ReplicationPad1d 的输入必须是 (*, L) 形状的张量,L 是被填充的维度。参数 (left, right) 表示左端补几个、右端补几个,填充内容是边缘元素的复制。

ZeroPad1d 的区别:

  • ZeroPad1d:补 0
  • ReplicationPad1d:复制边缘值

为什么用复制而非补零:时间序列里,补零会引入不真实的"断崖",复制边缘值更符合信号的平稳性假设,减少边界效应对最后一个 patch 的影响。

toy 数值追踪(只看 batch=0, channel=0 的一维序列):

原始时序: [a0, a1, a2, a3, a4, a5, a6, a7, a8]  (length=9)
                                                  ↑ 最后一个元素 a8

填充后:   [a0, a1, a2, a3, a4, a5, a6, a7, a8, a8, a8]  (length=11)
                                                       ↑ 复制了 a8 两次

用具体数字:

原始:    [1, 3, 5, 2, 4, 6, 3, 5, 7]
填充后:  [1, 3, 5, 2, 4, 6, 3, 5, 7, 7, 7]
                                    ↑ ↑ 复制边缘值 7

shape 变化:(2, 3, 9) → (2, 3, 11)


步骤 C:unfold 切 patch

原始代码

python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)

注解版

python
# x: (2, 3, 11)
# unfold(dimension=-1, size=4, step=2)
# 在最后一维(time)上滑窗:窗口大小=4,步长=2
# patch_num = floor((11 - 4) / 2) + 1 = floor(3.5) + 1 = 3 + 1 = 4
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
# x: (2, 3, 4, 4)  ← (batch, channel, patch_num, patch_len)

unfold 详解

tensor.unfold(dimension, size, step) 在指定维度上滑动一个窗口,返回每个窗口的内容。

  • dimension=-1:在最后一维(时间轴)上操作
  • size=4:每个窗口(patch)的大小
  • step=2:每次滑动的步长

输出维度计算:patch_num = floor((L - size) / step) + 1

代入:floor((11 - 4) / 2) + 1 = floor(3.5) + 1 = 4

toy 数值追踪(batch=0, channel=0,填充后序列 [1,3,5,2,4,6,3,5,7,7,7]):

滑窗过程(size=4, step=2):

位置 0(步长起点 0): [1, 3, 5, 2]   ← 索引 0~3
位置 1(步长起点 2): [5, 2, 4, 6]   ← 索引 2~5
位置 2(步长起点 4): [4, 6, 3, 5]   ← 索引 4~7
位置 3(步长起点 6): [3, 5, 7, 7]   ← 索引 6~9

patch_num = 4 ✓
(步长起点 8 时:索引 8~11,但 11 > 10,越界,停止)

结果:这条时序的 4 个 patch = [ [1,3,5,2], [5,2,4,6], [4,6,3,5], [3,5,7,7] ]

shape 变化:(2, 3, 11) → (2, 3, 4, 4) ← (batch, channel, patch_num, patch_len)


步骤 D:reshape 合并 batch 和 channel

原始代码

python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

注解版

python
# x: (2, 3, 4, 4)  ← (batch, channel, patch_num, patch_len)
# reshape: batch × channel 合并为新的 batch 维
# x.shape[0] * x.shape[1] = 2 × 3 = 6
# x.shape[2] = 4  (patch_num)
# x.shape[3] = 4  (patch_len)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# x: (6, 4, 4)  ← (B×enc_in, patch_num, patch_len)

channel-independent 的本质

reshape 后,原本属于"batch 0 的 channel 0"、"batch 0 的 channel 1"、"batch 1 的 channel 0" 等 6 条序列,变成了 6 个独立的"样本"。它们在后续 Encoder 里完全不知道彼此的存在,互相独立地经过 attention 和 FFN。

toy 数值追踪(reshape 前后的数组排列):

reshape 前 (2, 3, 4, 4):
  [0]: batch=0, channel=0 → 4 个 patch
  [1]: batch=0, channel=1 → 4 个 patch
  [2]: batch=0, channel=2 → 4 个 patch
  [3]: batch=1, channel=0 → 4 个 patch
  [4]: batch=1, channel=1 → 4 个 patch
  [5]: batch=1, channel=2 → 4 个 patch

reshape 后 (6, 4, 4):
  [0]: 原 batch=0, channel=0 的 4 个 patch:[ [1,3,5,2],[5,2,4,6],[4,6,3,5],[3,5,7,7] ]
  [1]: 原 batch=0, channel=1 的 4 个 patch
  [2]: 原 batch=0, channel=2 的 4 个 patch
  [3]: 原 batch=1, channel=0 的 4 个 patch
  [4]: 原 batch=1, channel=1 的 4 个 patch
  [5]: 原 batch=1, channel=2 的 4 个 patch

shape 变化:(2, 3, 4, 4) → (6, 4, 4)


步骤 E:value_embedding(patch 投影)

原始代码

python
x = self.value_embedding(x) + self.position_embedding(x)

注解版(value_embedding 部分)

python
# x: (6, 4, 4)  ← (B×enc_in, patch_num, patch_len)
# nn.Linear(patch_len=4, d_model=8, bias=False)
# Linear 作用在最后一维(patch_len=4),前面 (6, 4) 是 batch 维
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
# x = self.value_embedding(x)
# x: (6, 4, 8)  ← (B×enc_in, patch_num, d_model)

nn.Linear 的维度语义

nn.Linear(in_features, out_features) 只对输入的最后一维做线性变换,其他维度全部视为 batch 维度。

这里:

  • 输入最后一维:patch_len=4(一个 patch 内的 4 个时间步)
  • 输出最后一维:d_model=8(嵌入维度)
  • 前面的 (6, 4) = (B×enc_in, patch_num),全部 batch 处理

等价于:对 6×4=24 个 patch,每个 patch 各自独立做一次 (4,) → (8,) 的线性投影。

toy 数值追踪(只追踪 x[0, 0, :],即第 0 个虚拟 batch 的第 0 个 patch):

输入 patch: [1.0, 3.0, 5.0, 2.0]  (已归一化后的值)

W 是形状 (8, 4) 的权重矩阵(bias=False)。
设 W 的第 0 行为 [0.1, -0.2, 0.3, -0.1]:

output[0] = 0.1×1.0 + (-0.2)×3.0 + 0.3×5.0 + (-0.1)×2.0
          = 0.1 - 0.6 + 1.5 - 0.2 = 0.8

(依此类推计算 8 个输出分量)
输出: [0.8, ..., ...]  ← (d_model=8,)

shape 变化:(6, 4, 4) → (6, 4, 8)


步骤 F:position_embedding 加法

原始代码

python
x = self.value_embedding(x) + self.position_embedding(x)

注解版(position_embedding 部分)

python
# self.position_embedding = PositionalEmbedding(d_model=8)
# PositionalEmbedding 输出 shape: (1, seq_len, d_model)
# 这里 seq_len 对应 patch_num=4,d_model=8
# → position_embedding(x): (1, 4, 8),广播加到 (6, 4, 8)

PositionalEmbedding 用标准的正弦/余弦位置编码,对 patch 的"位置"(第几个 patch)进行编码,让模型知道时间顺序。

shape 分析

value_embedding(x):     (6, 4, 8)
position_embedding(x):  (1, 4, 8)  ← 广播
相加结果:                (6, 4, 8)

步骤 G:Dropout

python
return self.dropout(x), n_vars
# x: (6, 4, 8)
# n_vars: 3

最终输出 (6, 4, 8)n_vars=3,传回 forecast() 步骤 3。


3. 完整 shape 链汇总

输入  x:  (2, 3, 9)     ← forecast() permute 后
  ↓ n_vars = 3
  ↓ ReplicationPad1d    → (2, 3, 11)
  ↓ unfold(-1, 4, 2)    → (2, 3, 4, 4)
  ↓ reshape             → (6, 4, 4)
  ↓ value_embedding     → (6, 4, 8)
  ↓ + position_embedding
  ↓ dropout
输出  enc_out: (6, 4, 8),  n_vars=3

4. 下一步

04B-Encoder精读(6, 4, 8) 输入 Encoder,经过 attention + FFN,再还原成 (2, 7, 3)

*记录并在线阅读我的笔记*