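Flatten and Normalization Statistics: PatchTST's Output Head and De-normalization
Abstract
This note covers two basics in the second half of the PatchTST forward pass: how the output head uses `Flatten + Linear` to turn patch representations into predictions, and what `mean / var / sqrt / detach / repeat` do in normalization and de-normalization.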
0. File index
| Item | Content |
|---|---|
| Source file | ts_benchmark/baselines/time_series_library/models/PatchTST.py |
| Source classes | FlattenHead / PatchTST |
| Source method | forecast |
| Functions covered | nn.Flatten / nn.Linear / mean / var / sqrt / detach / unsqueeze / repeat |
1. Level 1: The relevant source of PatchTST forecast
```python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev
    # do patching and embedding
    x_enc = x_enc.permute(0, 2, 1)
    enc_out, n_vars = self.patch_embedding(x_enc)
    # Encoder
    enc_out, attns = self.encoder(enc_out)
    enc_out = torch.reshape(
        enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
    )
    enc_out = enc_out.permute(0, 1, 3, 2)
    # Decoder
    dec_out = self.head(enc_out)
    dec_out = dec_out.permute(0, 2, 1)
    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out
```
2. Level 2: mean / var / sqrt in normalization
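Input:
```text
x_enc.shape = (B, T, C)
toy: (2, 12, 4)
```
Code:
```python
means = x_enc.mean(1, keepdim=True).detach()
```
Explanation:
```text
dim=1 is the time dimension T
keepdim=True keeps the time dimension (with size 1)
```
shape:
```text
(2, 12, 4) -> mean(dim=1, keepdim=True) -> (2, 1, 4)
```
Meaning:
For each sample and each variable, take the mean along the historical time dimension.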
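2.1 A concrete numeric example of mean(1, keepdim=True)
`mean(1, keepdim=True)` can also be written as:
```python
mean(dim=1, keepdim=True)
```
The 1 here is the index of dimension 1 (zero-indexed, so the second dimension).
In PatchTST:
```text
x_enc.shape = (B, T, C)
```
So:
```text
dim=0 is the batch dimension
dim=1 is the time dimension
dim=2 is the channel / variable dimension
```
Therefore:
```python
x_enc.mean(1, keepdim=True)
```
means:
For each sample and each variable, average over all historical time steps.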
Take a smaller toy:
```text
x_enc.shape = (B, T, C) = (2, 3, 2)
```
Concrete values:
```python
x_enc = torch.tensor([
    [
        [1., 10.],  # batch 0, t0
        [2., 20.],  # batch 0, t1
        [3., 30.],  # batch 0, t2
    ],
    [
        [4., 40.],  # batch 1, t0
        [5., 50.],  # batch 1, t1
        [6., 60.],  # batch 1, t2
    ],
])
```
For sample 0:
```text
variable 0: mean([1, 2, 3]) = 2
variable 1: mean([10, 20, 30]) = 20
```
For sample 1:
```text
variable 0: mean([4, 5, 6]) = 5
variable 1: mean([40, 50, 60]) = 50
```
So:
```python
means = x_enc.mean(1, keepdim=True)
```
Result:
```text
means =
[
  [[2., 20.]],
  [[5., 50.]],
]
```
shape:
```text
(2, 1, 2)
```
The 1 in the middle is the time dimension that was kept.
Without `keepdim=True`:
```python
x_enc.mean(1).shape
```
gives:
```text
(2, 2)
```
The time dimension disappears entirely.
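The toy above can be checked directly. The following is a minimal sketch that reuses the same values; the names `means_keep` and `means_drop` are introduced here only for illustration.

```python
import torch

# Toy input from the text: (B, T, C) = (2, 3, 2)
x_enc = torch.tensor([
    [[1., 10.], [2., 20.], [3., 30.]],   # sample 0: t0, t1, t2
    [[4., 40.], [5., 50.], [6., 60.]],   # sample 1: t0, t1, t2
])

means_keep = x_enc.mean(1, keepdim=True)  # time dimension kept with size 1
means_drop = x_enc.mean(1)                # time dimension removed

print(means_keep.shape)  # torch.Size([2, 1, 2])
print(means_drop.shape)  # torch.Size([2, 2])
print(means_keep)        # per-sample, per-variable means: (2, 20) and (5, 50)
```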
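2.2 Why x_enc - means works
Code:
```python
x_enc = x_enc - means
```
At this point:
```text
x_enc.shape = (2, 3, 2)
means.shape = (2, 1, 2)
```
PyTorch broadcasts automatically:
```text
(2, 1, 2) -> (2, 3, 2)
```
That is, the same historical mean is subtracted at every time step.
Sample 0:
```text
original:
[[1, 10],
 [2, 20],
 [3, 30]]
mean:
[[2, 20]]
difference:
[[1-2, 10-20],
 [2-2, 20-20],
 [3-2, 30-20]]
=
[[-1, -10],
 [ 0,   0],
 [ 1,  10]]
```
Standard deviation:
```python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
```
`x_enc` has already been centered at this point; variance is unaffected by that shift. `unbiased=False` divides by T rather than T - 1.
shape (for the (2, 12, 4) toy from Level 2):
```text
torch.var(..., dim=1, keepdim=True): (2, 1, 4)
torch.sqrt(...): (2, 1, 4)
```
The `1e-5` term avoids division by zero when the variance is 0.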
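As a rough check of the broadcasting and of the `1e-5` term, the sketch below runs the same normalization on the (2, 3, 2) toy and on a constant series. The constant series `const` and the other variable names are assumptions made for illustration; they do not appear in the source file.

```python
import torch

x_enc = torch.tensor([
    [[1., 10.], [2., 20.], [3., 30.]],
    [[4., 40.], [5., 50.], [6., 60.]],
])

means = x_enc.mean(1, keepdim=True)   # (2, 1, 2)
centered = x_enc - means              # (2, 1, 2) broadcasts against (2, 3, 2)
stdev = torch.sqrt(torch.var(centered, dim=1, keepdim=True, unbiased=False) + 1e-5)
normalized = centered / stdev

print(centered[0])   # rows: [-1, -10], [0, 0], [1, 10]
print(stdev.shape)   # torch.Size([2, 1, 2])

# Hypothetical constant series: variance is exactly 0, so without the
# 1e-5 term the division below would be 0 / 0.
const = torch.full((1, 3, 1), 7.0)
const_centered = const - const.mean(1, keepdim=True)
const_std = torch.sqrt(torch.var(const_centered, dim=1, keepdim=True, unbiased=False) + 1e-5)
const_norm = const_centered / const_std
print(torch.isfinite(const_norm).all())  # tensor(True)
```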
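3. Level 3: What detach() does
```python
means = x_enc.mean(1, keepdim=True).detach()
```
`detach()` means:
Return a tensor that shares storage with the original but is cut off from the autograd graph, so no gradient flows back through it.
The semantics here:
```text
The statistics are used to normalize the input; the model should not backpropagate through them to "learn" the mean computation.
```
In the source above only `means` is detached; `stdev` is computed without `detach()`. `detach()` does not change the shape.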
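To make the effect of `detach()` visible, here is a small sketch with a hypothetical input `x` that tracks gradients (in the actual model the input usually does not require gradients, so this is only an illustration). The loss `(x - mean).sum()` is chosen because its gradient differs clearly between the two cases.

```python
import torch

# Hypothetical input that requires gradients, for illustration only
x = torch.tensor([[1., 2., 3.]], requires_grad=True)

mean_attached = x.mean(1, keepdim=True)           # still part of the autograd graph
mean_detached = x.mean(1, keepdim=True).detach()  # cut off from the graph

print(mean_attached.requires_grad)  # True
print(mean_detached.requires_grad)  # False
print(mean_attached.shape == mean_detached.shape)  # True: detach() keeps the shape

# With the attached mean, the gradient through the mean cancels the direct term.
(grad_attached,) = torch.autograd.grad((x - mean_attached).sum(), x)
# With the detached mean, the mean acts as a constant.
(grad_detached,) = torch.autograd.grad((x - mean_detached).sum(), x)

print(grad_attached)  # all zeros
print(grad_detached)  # all ones
```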
4. Level 4: FlattenHead
Source:
```python
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x
```
Input:
```text
enc_out.shape = (B, C, d_model, patch_num)
toy: (2, 4, 16, 6)
```
Run:
```python
x = self.flatten(x)
```
`start_dim=-2` means flattening starts at the second-to-last dimension, so `d_model` and `patch_num` are merged.
So:
```text
(2, 4, 16, 6) -> (2, 4, 96)
```
because:
```text
16 * 6 = 96
```
Then run:
```python
x = self.linear(x)
```
If:
```text
nf = 96
target_window = pred_len = 24
```
then:
```text
(2, 4, 96) -> Linear(96, 24) -> (2, 4, 24)
```
Finally:
```python
dec_out = dec_out.permute(0, 2, 1)
```
which gives:
```text
(2, 4, 24) -> (2, 24, 4)
```
that is, the final prediction format:
```text
(B, pred_len, C)
```
5. Level 5: unsqueeze + repeat in de-normalization
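Source:
```python
dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
```
Start with:
```python
stdev[:, 0, :]
```
Originally:
```text
stdev.shape = (B, 1, C) = (2, 1, 4)
```
Indexing removes the size-1 time dimension in the middle:
```text
stdev[:, 0, :].shape = (2, 4)
```
Then:
```python
unsqueeze(1)
```
adds the time dimension back:
```text
(2, 4) -> (2, 1, 4)
```
Then:
```python
repeat(1, self.pred_len, 1)
```
copies it to every predicted time step:
```text
(2, 1, 4) -> (2, 24, 4)
```
Now its shape matches:
```text
dec_out.shape = (2, 24, 4)
```
so the multiplication and addition are element-wise. Broadcasting a (B, 1, C) tensor against (B, pred_len, C) would produce the same values; the `repeat` only makes the shape explicit.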
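The shape steps above can be traced with stand-in tensors. This is a sketch under assumptions: `dec_out`, `stdev`, and `means` are filled with random values rather than real model outputs, and `out_broadcast` is added only to compare the explicit `repeat` with plain broadcasting.

```python
import torch

B, pred_len, C = 2, 24, 4
torch.manual_seed(0)

# Random stand-ins, not real model outputs
dec_out = torch.randn(B, pred_len, C)   # predictions in normalized space
stdev = torch.rand(B, 1, C) + 0.5       # per-sample, per-variable scale
means = torch.randn(B, 1, C)            # per-sample, per-variable mean

# Same indexing chain as in forecast()
stdev_rep = stdev[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
means_rep = means[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
print(stdev[:, 0, :].shape)  # torch.Size([2, 4])
print(stdev_rep.shape)       # torch.Size([2, 24, 4])

out_repeat = dec_out * stdev_rep + means_rep
out_broadcast = dec_out * stdev + means  # relies on (B, 1, C) broadcasting

print(torch.allclose(out_repeat, out_broadcast))  # True
```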
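6. Level 6: A small example you can compute by hand
The historical series of one variable:
```text
x = [2, 4, 6, 8]
```
Mean:
```text
mean = 5
```
Variance (dividing by n = 4, which matches unbiased=False):
```text
var = ((2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2) / 4
    = (9 + 1 + 1 + 9) / 4
    = 5
```
Standard deviation (ignoring the 1e-5 term):
```text
stdev = sqrt(5) ≈ 2.236
```
Normalization:
```text
[(2-5)/2.236, (4-5)/2.236, (6-5)/2.236, (8-5)/2.236]
= [-1.342, -0.447, 0.447, 1.342]
```
Suppose the model predicts these future values in normalized space:
```text
pred_norm = [0.0, 1.0]
```
De-normalization:
```text
pred = pred_norm * stdev + mean
     = [0.0, 1.0] * 2.236 + 5
     = [5.0, 7.236]
```
7. Common mistakes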
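7.1 Forgetting keepdim=True
If:
```python
means = x_enc.mean(1)
```
the shape becomes:
```text
(B, C)
```
Broadcasting aligns trailing dimensions, so (B, C) is treated as (1, B, C). Subtracting it from (B, T, C) raises an error unless B == T or one of them is 1. When B == T it runs silently but subtracts sample j's mean from time step j, which is wrong. Keeping (B, 1, C) avoids both problems and makes the dimension semantics clear.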
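The following sketch illustrates what happens when the time dimension is dropped. The tensors `x` and `y` are hypothetical examples built with `torch.arange`; the second one deliberately has B == T to show the silent case.

```python
import torch

# (B, T, C) = (2, 3, 2): subtracting a (2, 2) mean cannot broadcast
x = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
try:
    x - x.mean(1)
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True

# (B, T, C) = (3, 3, 2): B == T, so the subtraction runs but is misaligned
y = torch.arange(3 * 3 * 2, dtype=torch.float32).reshape(3, 3, 2)
wrong = y - y.mean(1)                # (3, 2) is treated as (1, 3, 2)
right = y - y.mean(1, keepdim=True)  # (3, 1, 2) broadcasts over time

print(wrong.shape == right.shape)     # True: same shape, so no error is raised
print(torch.allclose(wrong, right))   # False: the values differ
```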
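7.2 Misreading repeat(1, pred_len, 1)
It does not copy along the batch dimension, and it does not copy along the variable dimension.
It only copies along the time dimension, up to the prediction length:
```text
(B, 1, C) -> (B, pred_len, C)
```
8. One-sentence summary
The second half of PatchTST boils down to:
FlattenHead flattens all patch representations of each variable into one vector, then a Linear layer maps it to the prediction length; normalization stabilizes the input with the historical mean and standard deviation, and de-normalization restores the predictions to the original scale.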
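The whole path can be traced end to end with toy shapes. This is only a sketch: `FlattenHead` is rewritten here in the same structure as the quoted class, and `enc_out` is a random stand-in for the patch embedding and encoder, which are not reproduced. All sizes are the toy values used in this note.

```python
import torch
import torch.nn as nn


class FlattenHead(nn.Module):
    """Same structure as the FlattenHead quoted in Level 4."""

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: (B, C, d_model, patch_num)
        return self.dropout(self.linear(self.flatten(x)))


B, T, C, d_model, patch_num, pred_len = 2, 12, 4, 16, 6, 24
torch.manual_seed(0)
x_enc = torch.randn(B, T, C)

# Normalization, as in forecast()
means = x_enc.mean(1, keepdim=True).detach()
x_norm = x_enc - means
stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_norm = x_norm / stdev

# Random stand-in for the patch embedding + encoder output (assumption)
enc_out = torch.randn(B, C, d_model, patch_num)

head = FlattenHead(C, d_model * patch_num, pred_len)
flat = head.flatten(enc_out)
print(flat.shape)     # torch.Size([2, 4, 96])

dec_out = head(enc_out).permute(0, 2, 1)  # (B, C, pred_len) -> (B, pred_len, C)

# De-normalization, as in forecast()
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
print(dec_out.shape)  # torch.Size([2, 24, 4])
```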