Appearance
ReplicationPad1d 与 unfold:PatchTST 的 patch 切割
Abstract
这篇只讲 PatchTST 的 patch 化入口:
怎样从
(B,C,T)的连续时间序列,切成(B*C, patch_num, patch_len)的 patch token 输入。
0. 文件索引
| 项目 | 内容 |
|---|---|
| 源文件 | ts_benchmark/baselines/time_series_library/layers/Embed.py |
| 源类 | PatchEmbedding |
| 源方法 | forward(self, x) |
| 父文档 | zdocs/modelread/PatchTST/03-Layer2A-PatchEmbedding.md |
| 输入 | (B, C, T) = (2, 4, 12) |
| 输出 | (B*C, patch_num, patch_len) = (8, 6, 4),再进入 Linear |
1. 源码对象
python
class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, padding, dropout):
super(PatchEmbedding, self).__init__()
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
self.position_embedding = PositionalEmbedding(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# do patching
n_vars = x.shape[1]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# Input encoding
x = self.value_embedding(x) + self.position_embedding(x)
return self.dropout(x), n_vars本文重点:
python
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))整体流程先看这张图。这里用一个更小的数字例子:
text
B=1, C=2, T=6, patch_len=3, stride=2, padding=2, d_model=42. Level 1:输入为什么是 (B,C,T)
PatchTST 在 forecast 里先做:
python
x_enc = x_enc.permute(0, 2, 1)所以:
text
原始: (B, T, C)
进入 PatchEmbedding 前: (B, C, T)toy:
text
(B, C, T) = (2, 4, 12)ReplicationPad1d 和 unfold(dimension=-1) 都沿最后一维工作,所以这里让最后一维变成时间轴 T。
3. Level 2:ReplicationPad1d((0, padding))
源码:
python
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))toy:
text
padding = stride = 2(0, padding) 的意思是:
text
左边补 0 步
右边补 padding 步所以:
text
(B, C, T) = (2, 4, 12)
-> (B, C, T + padding) = (2, 4, 14)复制填充不是补零,而是复制边界值。
对某一个变量:
text
原始:
[t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11]
右端复制 padding=2:
[t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t11, t11]4. Level 3:为什么 PatchTST 只右端填充
DLinear 的 moving average 是居中窗口,所以左右都要填。
PatchTST 的 patch 切割是从左往右滑窗:
text
Patch 0 从 t0 开始
Patch 1 从 t2 开始
Patch 2 从 t4 开始
...问题只出在最后一个 patch 可能不够长。
所以 PatchTST 只需要右端补,让最后一个 patch 能完整取到 patch_len 个点。
5. Level 4:unfold(dimension=-1, size=patch_len, step=stride)
源码:
python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)5.1 x.unfold() 是 PyTorch 内置方法吗?
是。这里的 x.unfold(...) 不是 PatchTST 自己定义的方法,而是 PyTorch Tensor 自带的方法:
python
torch.Tensor.unfold(dimension, size, step)因为 x 是一个 torch.Tensor,所以可以直接写:
python
x.unfold(...)它的作用是:
沿某一个维度做滑动窗口切片。
先看一个最小一维例子:
python
import torch
x = torch.arange(10)
# x = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
y = x.unfold(dimension=0, size=3, step=2)输出:
text
y =
tensor([[0, 1, 2],
[2, 3, 4],
[4, 5, 6],
[6, 7, 8]])解释:
text
原始序列:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
size=3 表示每个窗口取 3 个连续元素
step=2 表示下一个窗口的起点向右移动 2 步
窗口 0 起点 0: [0, 1, 2]
窗口 1 起点 2: [2, 3, 4]
窗口 2 起点 4: [4, 5, 6]
窗口 3 起点 6: [6, 7, 8]最后不会产生 [8, 9, ?],因为不够 size=3 个元素。PatchTST 之前做 ReplicationPad1d,就是为了让最后一个窗口也够长。
5.2 回到 PatchTST:沿时间维切 patch
参数解释:
| 参数 | toy 值 | 含义 |
|---|---|---|
dimension=-1 | 最后一维 | 沿时间轴切 |
size=patch_len | 4 | 每个 patch 包含 4 个连续时间步 |
step=stride | 2 | 相邻 patch 起点间隔 2 步 |
输入:
text
x.shape = (2, 4, 14)输出:
text
x.shape = (2, 4, 6, 4)含义:
text
(B, C, patch_num, patch_len)6. Level 5:patch_num 怎么算
公式:
text
patch_num = floor((T_padded - patch_len) / stride) + 1代入:
text
T_padded = 14
patch_len = 4
stride = 2
patch_num = floor((14 - 4) / 2) + 1
= floor(10 / 2) + 1
= 6所以:
text
(2, 4, 14) -> unfold -> (2, 4, 6, 4)7. Level 6:可视化 patch 切割
只看一个变量:
text
T_padded = [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t11, t11]
0 1 2 3 4 5 6 7 8 9 10 11 12 13patch_len=4,stride=2:
text
Patch 0 起点 0: [t0, t1, t2, t3]
Patch 1 起点 2: [t2, t3, t4, t5]
Patch 2 起点 4: [t4, t5, t6, t7]
Patch 3 起点 6: [t6, t7, t8, t9]
Patch 4 起点 8: [t8, t9, t10, t11]
Patch 5 起点 10: [t10, t11, t11, t11]最后一个 patch 能成立,是因为右端复制了两个 t11。
下面这张图把同一件事换成具体数字。注意它不是只看 shape,而是直接展示两个变量各自被切成哪些 patch:
8. Level 7:reshape 到 (B*C, patch_num, patch_len)
源码:
python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))输入:
text
(B, C, patch_num, patch_len) = (2, 4, 6, 4)输出:
text
(B*C, patch_num, patch_len) = (8, 6, 4)reshape 的宏观逻辑是:把“第几个样本”和“第几个变量”合并成新的样本维,让每个变量独立进入后面的 Linear + Transformer。
这一步是 PatchTST 的 channel-independent 核心。
它把:
text
2 个样本 * 4 个变量合并成:
text
8 条独立的单变量 patch 序列从此之后,Transformer attention 在每个变量自己的 patch 序列里算,不跨变量算。
9. 一句话总结
PatchTST 的这三步:
python
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))可以翻译成:
先在时间轴右边复制补齐,再沿时间轴滑窗切 patch,最后把 batch 和变量维合并,让每个变量作为独立序列进入 Transformer。