Skip to content

ReplicationPad1d 与 unfold:PatchTST 的 patch 切割

Abstract

这篇只讲 PatchTST 的 patch 化入口:

怎样从 (B,C,T) 的连续时间序列,切成 (B*C, patch_num, patch_len) 的 patch token 输入。

0. 文件索引

项目内容
源文件ts_benchmark/baselines/time_series_library/layers/Embed.py
源类PatchEmbedding
源方法forward(self, x)
父文档zdocs/modelread/PatchTST/03-Layer2A-PatchEmbedding.md
输入(B, C, T) = (2, 4, 12)
输出(B*C, patch_num, patch_len) = (8, 6, 4),再进入 Linear

1. 源码对象

python
class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super(PatchEmbedding, self).__init__()
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # do patching
        n_vars = x.shape[1]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        # Input encoding
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x), n_vars

本文重点:

python
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

整体流程先看这张图。这里用一个更小的数字例子:

text
B=1, C=2, T=6, patch_len=3, stride=2, padding=2, d_model=4

2. Level 1:输入为什么是 (B,C,T)

PatchTST 在 forecast 里先做:

python
x_enc = x_enc.permute(0, 2, 1)

所以:

text
原始: (B, T, C)
进入 PatchEmbedding 前: (B, C, T)

toy:

text
(B, C, T) = (2, 4, 12)

ReplicationPad1dunfold(dimension=-1) 都沿最后一维工作,所以这里让最后一维变成时间轴 T

3. Level 2:ReplicationPad1d((0, padding))

源码:

python
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

toy:

text
padding = stride = 2

(0, padding) 的意思是:

text
左边补 0 步
右边补 padding 步

所以:

text
(B, C, T) = (2, 4, 12)
-> (B, C, T + padding) = (2, 4, 14)

复制填充不是补零,而是复制边界值。

对某一个变量:

text
原始:
[t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11]

右端复制 padding=2:
[t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t11, t11]

4. Level 3:为什么 PatchTST 只右端填充

DLinear 的 moving average 是居中窗口,所以左右都要填。

PatchTST 的 patch 切割是从左往右滑窗:

text
Patch 0 从 t0 开始
Patch 1 从 t2 开始
Patch 2 从 t4 开始
...

问题只出在最后一个 patch 可能不够长。

所以 PatchTST 只需要右端补,让最后一个 patch 能完整取到 patch_len 个点。

5. Level 4:unfold(dimension=-1, size=patch_len, step=stride)

源码:

python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)

5.1 x.unfold() 是 PyTorch 内置方法吗?

是。这里的 x.unfold(...) 不是 PatchTST 自己定义的方法,而是 PyTorch Tensor 自带的方法

python
torch.Tensor.unfold(dimension, size, step)

因为 x 是一个 torch.Tensor,所以可以直接写:

python
x.unfold(...)

它的作用是:

沿某一个维度做滑动窗口切片。

先看一个最小一维例子:

python
import torch

x = torch.arange(10)
# x = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

y = x.unfold(dimension=0, size=3, step=2)

输出:

text
y =
tensor([[0, 1, 2],
        [2, 3, 4],
        [4, 5, 6],
        [6, 7, 8]])

解释:

text
原始序列:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

size=3  表示每个窗口取 3 个连续元素
step=2  表示下一个窗口的起点向右移动 2 步

窗口 0 起点 0: [0, 1, 2]
窗口 1 起点 2: [2, 3, 4]
窗口 2 起点 4: [4, 5, 6]
窗口 3 起点 6: [6, 7, 8]

最后不会产生 [8, 9, ?],因为不够 size=3 个元素。PatchTST 之前做 ReplicationPad1d,就是为了让最后一个窗口也够长。

5.2 回到 PatchTST:沿时间维切 patch

参数解释:

参数toy 值含义
dimension=-1最后一维沿时间轴切
size=patch_len4每个 patch 包含 4 个连续时间步
step=stride2相邻 patch 起点间隔 2 步

输入:

text
x.shape = (2, 4, 14)

输出:

text
x.shape = (2, 4, 6, 4)

含义:

text
(B, C, patch_num, patch_len)

6. Level 5:patch_num 怎么算

公式:

text
patch_num = floor((T_padded - patch_len) / stride) + 1

代入:

text
T_padded = 14
patch_len = 4
stride = 2

patch_num = floor((14 - 4) / 2) + 1
          = floor(10 / 2) + 1
          = 6

所以:

text
(2, 4, 14) -> unfold -> (2, 4, 6, 4)

7. Level 6:可视化 patch 切割

只看一个变量:

text
T_padded = [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t11, t11]
             0   1   2   3   4   5   6   7   8   9   10   11   12   13

patch_len=4stride=2

text
Patch 0 起点 0:  [t0,  t1,  t2,  t3]
Patch 1 起点 2:  [t2,  t3,  t4,  t5]
Patch 2 起点 4:  [t4,  t5,  t6,  t7]
Patch 3 起点 6:  [t6,  t7,  t8,  t9]
Patch 4 起点 8:  [t8,  t9,  t10, t11]
Patch 5 起点 10: [t10, t11, t11, t11]

最后一个 patch 能成立,是因为右端复制了两个 t11

下面这张图把同一件事换成具体数字。注意它不是只看 shape,而是直接展示两个变量各自被切成哪些 patch:

8. Level 7:reshape 到 (B*C, patch_num, patch_len)

源码:

python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

输入:

text
(B, C, patch_num, patch_len) = (2, 4, 6, 4)

输出:

text
(B*C, patch_num, patch_len) = (8, 6, 4)

reshape 的宏观逻辑是:把“第几个样本”和“第几个变量”合并成新的样本维,让每个变量独立进入后面的 Linear + Transformer

这一步是 PatchTST 的 channel-independent 核心。

它把:

text
2 个样本 * 4 个变量

合并成:

text
8 条独立的单变量 patch 序列

从此之后,Transformer attention 在每个变量自己的 patch 序列里算,不跨变量算。

9. 一句话总结

PatchTST 的这三步:

python
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

可以翻译成:

先在时间轴右边复制补齐,再沿时间轴滑窗切 patch,最后把 batch 和变量维合并,让每个变量作为独立序列进入 Transformer。

*记录并在线阅读我的笔记*