Skip to content

Autoformer_EncDec.moving_avg.forward 基础语法注解

Abstract

这篇只讲一件事:

读懂 Autoformer_EncDec.pyclass moving_avgforward(self, x),尤其是它为什么要用 x[:, 0:1, :]repeat(...)torch.cat(...) 给时间序列两端做 padding。

0. 文件索引

项目内容
源文件ts_benchmark/baselines/time_series_library/layers/Autoformer_EncDec.py
源类class moving_avg(nn.Module)
源方法def forward(self, x)
输入张量x: (B, T, C)
当前 toy example(B, T, C) = (2, 6, 3)
当前 kernelkernel_size = 3
当前 padding 步数pad = (kernel_size - 1) // 2 = 1
这篇讲的代码块front / end / torch.cat 这三行边界复制 padding
暂不展开self.avg(x.permute(0, 2, 1)) 和第二次 permute,后续单独讲
Note

仓库里多个 baseline 目录下都有结构相同或非常接近的 Autoformer_EncDec.py。 本文先以 time_series_library 版本为坐标;如果别的 baseline 里 moving_avg.forward 代码相同,这篇解释可以直接迁移。

0.1 本文件知识点索引

这篇不是泛泛讲 PyTorch,而是从真实源码中的一小段 forward 代码出发,绑定下面这些基础语法:

知识点对应源码解决的问题
三维张量语义x: (B, T, C)每一维分别代表什么
全量切片:为什么 batch 维和 feature 维全部保留
范围切片0:1为什么取第一个时间步但不丢掉时间维
负索引-1:为什么能取最后一个时间步
复制张量.repeat(1, pad, 1)为什么能把边界时间步复制成 padding
拼接张量torch.cat(..., dim=1)为什么时间长度会变长
时间维编号dim=1为什么第 1 维就是时间维

0.2 以后同类注解文档的固定规则

以后如果继续写“某个源码方法里的基础语法注解”,开头都按这个规则组织:

  1. frontmatter 里写清 source_filesource_classsource_methodinput_shapetoy_shapeknowledge_points
  2. 正文先放“文件索引”,说明这篇挂在哪个真实源码对象上
  3. 再放“知识点索引”,说明这篇解决哪些基础语法问题
  4. 先讲源码整体目的,再进入局部语法
  5. 每个语法点都绑定回源码,不单独漂浮讲 Python / PyTorch
  6. 只要涉及张量变化,就写出逐步 shape 流
  7. toy example 必须跟源码使用同一套维度语义

0.3 源码对象

源码位置:

python
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x

本文先只解释这三行:

python
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
x = torch.cat([front, x, end], dim=1)

为了降低理解难度,当前 toy example 固定为:

python
# x:     (B, T, C) = (2, 6, 3)
# pad = (kernel_size - 1) // 2 = (3 - 1) // 2 = 1

front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)

x = torch.cat([front, x, end], dim=1)

它的核心意思是:

  1. x 里取出第一个时间步,得到 front
  2. x 里取出最后一个时间步,得到 end
  3. front 拼到原序列左边
  4. end 拼到原序列右边
  5. 时间长度从 6 变成 8

1. 当前 toy example

先固定张量语义:

text
x.shape = (2, 6, 3)

三个维度分别是:

维度位置名字含义当前大小
第 0 维Bbatch size,一次有几条样本2
第 1 维Ttime length,每条样本有几个时间步6
第 2 维Cchannel / feature,每个时间步有几个变量3

也就是说:

text
x[b, t, c]

表示:

b 条样本,在第 t 个时间步,第 c 个特征的数值。

为了更直观,可以把 x 想成:

text
第 0 条样本:
  t0: [a000, a001, a002]
  t1: [a010, a011, a012]
  t2: [a020, a021, a022]
  t3: [a030, a031, a032]
  t4: [a040, a041, a042]
  t5: [a050, a051, a052]

第 1 条样本:
  t0: [a100, a101, a102]
  t1: [a110, a111, a112]
  t2: [a120, a121, a122]
  t3: [a130, a131, a132]
  t4: [a140, a141, a142]
  t5: [a150, a151, a152]

2. Level 1:x[:, 0:1, :] 是什么

代码:

python
front = x[:, 0:1, :]

这个表达式是在做三维切片:

python
x[第0维怎么取, 第1维怎么取, 第2维怎么取]

对应到 (B, T, C)

python
x[B维怎么取, T维怎么取, C维怎么取]

所以:

python
x[:, 0:1, :]

可以拆成:

片段作用对哪个维度生效
:全部 batch 都要第 0 维 B
0:1只取第 0 个时间步,但是保留时间维第 1 维 T
:全部特征都要第 2 维 C

因此:

text
x[:, 0:1, :].shape = (2, 1, 3)

注意这里不是 (2, 3),而是 (2, 1, 3)

原因是 0:1 是切片写法,会保留原来的维度。


3. Level 2:为什么用 0:1,不是 0

这两个写法非常容易混:

python
x[:, 0, :]
x[:, 0:1, :]

它们取到的数值都来自第 0 个时间步,但 shape 不一样。

写法含义输出 shape
x[:, 0, :]取第 0 个时间步,并删除时间维(2, 3)
x[:, 0:1, :]取第 0 个时间步,但保留时间维(2, 1, 3)

在这段 padding 代码里,必须保留时间维。

因为后面要执行:

python
torch.cat([front, x, end], dim=1)

frontxend 必须都是三维张量:

text
front.shape = (2, 1, 3)
x.shape     = (2, 6, 3)
end.shape   = (2, 1, 3)

如果写成:

python
front = x[:, 0, :]

那么:

text
front.shape = (2, 3)

它就和 x.shape = (2, 6, 3) 维度数量不一致,不能沿时间维拼接。


4. Level 3:: 是什么意思

在 Python / PyTorch 切片里,: 表示:

这一维全部取出来。

例如:

python
x[:, 0:1, :]

第一个 :

python
x[全部 batch, 0:1, :]

含义是:

text
batch 0 要
batch 1 也要

最后一个 :

python
x[:, 0:1, 全部 feature]

含义是:

text
feature 0 要
feature 1 要
feature 2 也要

所以 x[:, 0:1, :] 不是只取一个数字,而是取一个子张量。


5. Level 4:-1: 是什么意思

代码:

python
end = x[:, -1:, :]

这里的 -1 是 Python 负索引。

在 Python 里:

索引含义
0第一个
1第二个
-1最后一个
-2倒数第二个

所以:

python
x[:, -1:, :]

表示:

对每条样本,取最后一个时间步,并保留时间维。

当前 T = 6,时间步编号是:

text
0, 1, 2, 3, 4, 5

最后一个时间步就是 5

因此:

python
x[:, -1:, :]

等价于:

python
x[:, 5:6, :]

输出 shape:

text
x[:, -1:, :].shape = (2, 1, 3)

6. Level 5:repeat(1, 1, 1) 是什么

代码:

python
front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)

repeat(...) 的作用是:

沿每个维度复制张量。

对于一个三维张量:

text
(B, T, C)

调用:

python
repeat(a, b, c)

意思是:

参数作用在哪一维含义
a第 0 维 Bbatch 维复制几次
b第 1 维 T时间维复制几次
c第 2 维 C特征维复制几次

当前:

python
repeat(1, 1, 1)

意思是:

text
B 维复制 1 次
T 维复制 1 次
C 维复制 1 次

也就是不改变 shape。

所以:

text
x[:, 0:1, :].shape              = (2, 1, 3)
x[:, 0:1, :].repeat(1, 1, 1)    = (2, 1, 3)

在这个具体例子里,repeat(1, 1, 1) 看起来有点多余。

但是它保留了一个通用模式:

python
front = x[:, 0:1, :].repeat(1, pad, 1)
end = x[:, -1:, :].repeat(1, pad, 1)

如果 pad = 2,那么:

python
front = x[:, 0:1, :].repeat(1, 2, 1)

shape 会变成:

text
(2, 1, 3) -> (2, 2, 3)

意思是把第一个时间步复制成左边的 2 个 padding 时间步。


7. Level 6:torch.cat([...], dim=1) 是什么

代码:

python
x = torch.cat([front, x, end], dim=1)

torch.cat 的作用是:

把多个张量沿指定维度拼接起来。

这里有三个张量:

text
front.shape = (2, 1, 3)
x.shape     = (2, 6, 3)
end.shape   = (2, 1, 3)

参数:

python
dim=1

表示沿第 1 维拼接。

因为 x 的维度语义是:

text
(B, T, C)

所以第 1 维就是时间维 T

拼接后:

text
(2, 1, 3)
(2, 6, 3)
(2, 1, 3)

沿时间维相加:

text
T = 1 + 6 + 1 = 8

其他维度保持不变:

text
B = 2
C = 3

最终:

text
torch.cat([front, x, end], dim=1).shape = (2, 8, 3)

8. 完整张量流

完整代码:

python
# x: (B, T, C) = (2, 6, 3)
# pad = 1

front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)

x = torch.cat([front, x, end], dim=1)

逐步 shape:

text
原始:
  x.shape = (2, 6, 3)

取左端:
  x[:, 0:1, :].shape = (2, 1, 3)
  front.shape        = (2, 1, 3)

取右端:
  x[:, -1:, :].shape = (2, 1, 3)
  end.shape          = (2, 1, 3)

拼接:
  torch.cat([front, x, end], dim=1)
  (2, 1, 3) + (2, 6, 3) + (2, 1, 3)
  = (2, 8, 3)

按时间步看:

text
原始 x:
  [t0, t1, t2, t3, t4, t5]

front:
  [t0]

end:
  [t5]

拼接后:
  [t0, t0, t1, t2, t3, t4, t5, t5]

这就是边界复制 padding:

text
左边复制第一个时间步
右边复制最后一个时间步

9. 一个可运行的小例子

用很小的数字看得更清楚。

python
import torch

x = torch.tensor([
    [
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32],
        [40, 41, 42],
        [50, 51, 52],
        [60, 61, 62],
    ],
    [
        [100, 101, 102],
        [200, 201, 202],
        [300, 301, 302],
        [400, 401, 402],
        [500, 501, 502],
        [600, 601, 602],
    ],
])

front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)

y = torch.cat([front, x, end], dim=1)

print(x.shape)
print(front.shape)
print(end.shape)
print(y.shape)
print(y)

输出 shape:

text
torch.Size([2, 6, 3])
torch.Size([2, 1, 3])
torch.Size([2, 1, 3])
torch.Size([2, 8, 3])

输出内容会变成:

text
第 0 条样本:
  [[10, 11, 12],   <- 复制出来的左 padding
   [10, 11, 12],
   [20, 21, 22],
   [30, 31, 32],
   [40, 41, 42],
   [50, 51, 52],
   [60, 61, 62],
   [60, 61, 62]]   <- 复制出来的右 padding

第 1 条样本:
  [[100, 101, 102],
   [100, 101, 102],
   [200, 201, 202],
   [300, 301, 302],
   [400, 401, 402],
   [500, 501, 502],
   [600, 601, 602],
   [600, 601, 602]]

10. 最容易混的点

10.1 x[:, 0, :]x[:, 0:1, :]

python
x[:, 0, :].shape

结果:

text
(2, 3)

时间维被挤掉了。

python
x[:, 0:1, :].shape

结果:

text
(2, 1, 3)

时间维还在。

这段代码需要和原始 x 拼接,所以要用 0:1

10.2 dim=1 为什么是时间维

因为当前张量格式是:

text
(B, T, C)

维度编号从 0 开始:

编号维度含义
dim=0Bbatch 维
dim=1T时间维
dim=2C特征维

所以:

python
torch.cat([front, x, end], dim=1)

是在时间方向上把序列接长。

10.3 repeat(1, 1, 1) 为什么没变

因为每一维都只复制 1 次。

真正有扩展效果的是:

python
repeat(1, pad, 1)

pad = 2 时:

text
(2, 1, 3) -> (2, 2, 3)

11. 一句话总结

这段代码:

python
front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)
x = torch.cat([front, x, end], dim=1)

可以翻译成中文:

对每条时间序列,取第一个时间步复制到最左边,取最后一个时间步复制到最右边,然后沿时间维拼起来,让序列长度从 6 变成 8

*记录并在线阅读我的笔记*