Skip to content

Flatten 与标准化统计:PatchTST 的输出头和反标准化

Abstract

这篇讲 PatchTST forward 后半段的两个基础点:

输出头怎样用 Flatten + Linear 从 patch 表示得到预测值,以及标准化/反标准化里的 mean / var / sqrt / detach / repeat 是什么意思。

0. 文件索引

项目内容
源文件ts_benchmark/baselines/time_series_library/models/PatchTST.py
源类FlattenHead / PatchTST
源方法forecast
覆盖函数nn.Flatten / nn.Linear / mean / var / sqrt / detach / unsqueeze / repeat

1. Level 1:PatchTST forecast 的相关源码

python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev

    # do patching and embedding
    x_enc = x_enc.permute(0, 2, 1)
    enc_out, n_vars = self.patch_embedding(x_enc)

    # Encoder
    enc_out, attns = self.encoder(enc_out)
    enc_out = torch.reshape(
        enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
    )
    enc_out = enc_out.permute(0, 1, 3, 2)

    # Decoder
    dec_out = self.head(enc_out)
    dec_out = dec_out.permute(0, 2, 1)

    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out

2. Level 2:标准化里的 mean / var / sqrt

输入:

text
x_enc.shape = (B, T, C)
toy: (2, 12, 4)

代码:

python
means = x_enc.mean(1, keepdim=True).detach()

解释:

text
dim=1 是时间维 T
keepdim=True 保留时间维

shape:

text
(2, 12, 4) -> mean(dim=1, keepdim=True) -> (2, 1, 4)

含义:

对每个样本、每个变量,沿历史时间维求均值。

2.1 mean(1, keepdim=True) 的具体数字例子

mean(1, keepdim=True) 也可以写成:

python
mean(dim=1, keepdim=True)

这里的 1 指第 1 维。

因为 PatchTST 里:

text
x_enc.shape = (B, T, C)

所以:

text
dim=0 是 batch 维
dim=1 是 time 维
dim=2 是 channel / variable 维

因此:

python
x_enc.mean(1, keepdim=True)

意思是:

对每条样本、每个变量,把所有历史时间步求平均。

用一个更小的 toy:

text
x_enc.shape = (B, T, C) = (2, 3, 2)

具体值:

python
x_enc = torch.tensor([
    [
        [1., 10.],   # batch 0, t0
        [2., 20.],   # batch 0, t1
        [3., 30.],   # batch 0, t2
    ],
    [
        [4., 40.],   # batch 1, t0
        [5., 50.],   # batch 1, t1
        [6., 60.],   # batch 1, t2
    ],
])

对第 0 条样本:

text
变量 0: mean([1, 2, 3]) = 2
变量 1: mean([10, 20, 30]) = 20

对第 1 条样本:

text
变量 0: mean([4, 5, 6]) = 5
变量 1: mean([40, 50, 60]) = 50

所以:

python
means = x_enc.mean(1, keepdim=True)

结果:

text
means =
[
    [[2., 20.]],
    [[5., 50.]],
]

shape:

text
(2, 1, 2)

这里中间那个 1 就是被保留下来的时间维。

如果不用 keepdim=True

python
x_enc.mean(1).shape

会得到:

text
(2, 2)

时间维会直接消失。

2.2 为什么 x_enc - means 能减

代码:

python
x_enc = x_enc - means

此时:

text
x_enc.shape = (2, 3, 2)
means.shape = (2, 1, 2)

PyTorch 会自动广播:

text
(2, 1, 2) -> (2, 3, 2)

也就是每个时间步都减同一个历史均值。

第 0 条样本:

text
原始:
[[1, 10],
 [2, 20],
 [3, 30]]

均值:
[[2, 20]]

相减:
[[1-2, 10-20],
 [2-2, 20-20],
 [3-2, 30-20]]
=
[[-1, -10],
 [ 0,   0],
 [ 1,  10]]

标准差:

python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)

shape:

text
torch.var(..., dim=1, keepdim=True): (2, 1, 4)
torch.sqrt(...): (2, 1, 4)

1e-5 是为了避免方差为 0 时除以 0。

3. Level 3:detach() 是什么

python
means = x_enc.mean(1, keepdim=True).detach()

detach() 的意思是:

生成一个不参与梯度传播的 tensor 视图。

这里的语义:

text
均值和标准差用于标准化输入,但不希望模型通过它们反向传播去“学习”均值计算过程。

它不改变 shape。

4. Level 4:FlattenHead

源码:

python
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

输入:

text
enc_out.shape = (B, C, d_model, patch_num)
toy: (2, 4, 16, 6)

执行:

python
x = self.flatten(x)

start_dim=-2 表示从倒数第 2 维开始展平。

所以:

text
(2, 4, 16, 6) -> (2, 4, 96)

因为:

text
16 * 6 = 96

再执行:

python
x = self.linear(x)

如果:

text
nf = 96
target_window = pred_len = 24

则:

text
(2, 4, 96) -> Linear(96, 24) -> (2, 4, 24)

最后:

python
dec_out = dec_out.permute(0, 2, 1)

得到:

text
(2, 4, 24) -> (2, 24, 4)

也就是最终预测格式:

text
(B, pred_len, C)

5. Level 5:反标准化里的 unsqueeze + repeat

源码:

python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

先看:

python
stdev[:, 0, :]

原来:

text
stdev.shape = (B, 1, C) = (2, 1, 4)

取掉中间那个 size=1 的时间维:

text
stdev[:, 0, :].shape = (2, 4)

再:

python
unsqueeze(1)

把时间维加回来:

text
(2, 4) -> (2, 1, 4)

再:

python
repeat(1, self.pred_len, 1)

复制到每个预测时间步:

text
(2, 1, 4) -> (2, 24, 4)

这样它才能和:

text
dec_out.shape = (2, 24, 4)

逐元素相乘、相加。

6. Level 6:可算小例子

某个变量的历史序列:

text
x = [2, 4, 6, 8]

均值:

text
mean = 5

方差:

text
var = ((2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2) / 4
    = (9 + 1 + 1 + 9) / 4
    = 5

标准差:

text
stdev = sqrt(5) ≈ 2.236

标准化:

text
[(2-5)/2.236, (4-5)/2.236, (6-5)/2.236, (8-5)/2.236]
= [-1.342, -0.447, 0.447, 1.342]

如果模型预测标准化空间里的未来值:

text
pred_norm = [0.0, 1.0]

反标准化:

text
pred = pred_norm * stdev + mean
     = [0.0, 1.0] * 2.236 + 5
     = [5.0, 7.236]

7. 常见错误

7.1 忘记 keepdim=True

如果:

python
means = x_enc.mean(1)

shape 会是:

text
(B, C)

后续和 (B,T,C) 相减时虽然可能广播成功,但维度语义不如 (B,1,C) 清楚。

7.2 不理解 repeat(1, pred_len, 1)

它不是复制 batch,也不是复制变量。

它只是在时间维复制到预测长度:

text
(B, 1, C) -> (B, pred_len, C)

8. 一句话总结

PatchTST 后半段可以压成:

FlattenHead 把每个变量的所有 patch 表示展平成一个向量,再用 Linear 映射到预测长度;标准化用历史均值/标准差稳定输入,反标准化再把预测值还原回原始尺度。

*记录并在线阅读我的笔记*