Skip to content

PyTorch Tensor 基础操作:切片、变形、拼接、注意力计算

Abstract

这篇按“Tensor 操作类别”组织。

它解释模型里经常看到的 PyTorch 基础操作:怎么取子张量、怎么换维、怎么 reshape、怎么切 patch、怎么做 attention 里的矩阵运算,以及这些操作在 DLinear / Informer / PatchTST 的哪个位置出现。

0. 文件索引

项目内容
本文主题PyTorch Tensor 基础操作
覆盖范围切片、复制、拼接、换维、变形、加删维度、滑窗、attention 计算、统计操作
主要模型DLinear / Informer / PatchTST
配套文档[[10-PyTorch常见层函数-池化卷积线性归一化Embedding]]

0.1 知识地图

类别函数/语法最重要的理解出现位置
切片x[:, 0:1, :]取第一个时间步且保留维度DLinear moving_avg
切片x[:, -1:, :]取最后时间步且保留维度DLinear moving_avg
复制repeat沿指定维度复制DLinear padding / PatchTST 反标准化
拼接torch.cat沿指定维度拼接DLinear padding / DataEmbedding_inverted
换维permute任意重排维度DLinear / PatchTST
换维transpose交换两个维度Conv1d / Attention
变形reshape改 shape,可合并/拆分维度PatchTST
变形view改 shape,常用于多头拆分AttentionLayer
加维unsqueeze插入 size=1 维标准化反变换 / ProbAttention
删维squeeze删除 size=1 维ProbAttention
滑窗unfold沿一维切 patchPatchTST PatchEmbedding
矩阵乘matmul批量矩阵乘ProbAttention
爱因斯坦求和einsum明确指定维度乘加FullAttention
概率化softmax某一维归一化为概率Attention
选择topk取最大 k 个值/下标Informer ProbAttention
统计mean/var/sqrt标准化Informer/PatchTST
断梯度detach不参与反向传播标准化均值

0.2 本文的具体例子标准

本文的例子标准

本文里的“例子”统一指:给一个很小但不退化的具体 tensor,写出操作前后的数值,写出对应数学公式,并说明每个维度代表什么。

text
1. 给一个很小的具体 tensor
2. 写出操作前后的数值
3. 写出对应数学公式
4. 说明每个维度代表什么

也就是说,本文不只写:

text
(B,T,C) -> (B,C,T)

而是尽量写成:

text
x =
[
  [[1, 10], [2, 20], [3, 30]]
]

x.permute(0,2,1) =
[
  [[1, 2, 3],
   [10,20,30]]
]
不用过度简单的矩阵例子

matmul / einsum / attention,尽量不用单纯 2×2 @ 2×2。这种例子能算,但看不出“共享求和维”和“输出二维网格”的关系。本文第 7 节改用 2×3 @ 3×2L=3,S=3,H=2 的 attention 例子。


1. 切片::0:1-1:

1.1 x[:, 0:1, :]

假设:

text
x.shape = (B,T,C) = (2,6,3)

代码:

python
front = x[:, 0:1, :]

含义:

片段意思
:所有 batch
0:1第 0 个时间步,但保留时间维
:所有变量

shape:

text
(2,6,3) -> (2,1,3)

数学写法:

text
front[b, 0, c] = x[b, 0, c]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 10],
        [2, 20],
        [3, 30],
    ]
])

front = x[:, 0:1, :]

结果:

text
x.shape     = (1,3,2)
front.shape = (1,1,2)

front =
[
  [
    [1, 10]
  ]
]

注意:

text
x[:, 0, :]   -> shape = (1,2),时间维消失
x[:, 0:1, :] -> shape = (1,1,2),时间维保留

在我们模型里的位置:

模型源码位置作用
DLinearAutoformer_EncDec.py -> moving_avg.forward取第一个时间步作为左 padding

1.2 x[:, -1:, :]

-1 是最后一个元素。

代码:

python
end = x[:, -1:, :]

shape:

text
(2,6,3) -> (2,1,3)

数学写法:

text
end[b, 0, c] = x[b, T-1, c]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 10],
        [2, 20],
        [3, 30],
    ]
])

end = x[:, -1:, :]

结果:

text
end =
[
  [
    [3, 30]
  ]
]

作用:

取最后一个时间步作为右 padding。

详细下钻:[[../model-order/01-DLinear/01-Autoformer_EncDec-moving_avg-forward-基础语法注解|01-Autoformer_EncDec-moving_avg-forward-基础语法注解]]


2. 复制与拼接:repeattorch.cat

2.1 repeat

基本作用:

沿每个维度复制 tensor。

例子:

python
x.shape = (2, 1, 3)
y = x.repeat(1, 2, 1)

shape:

text
(2,1,3) -> (2,2,3)

数学写法:

如果:

text
y = x.repeat(r0, r1, r2)

那么 y 是把 x 在每个维度按 r0/r1/r2 复制出来。常见场景是只复制时间维:

text
y[b, t, c] = x[b, 0, c]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 10]
    ]
])

y = x.repeat(1, 3, 1)

结果:

text
x.shape = (1,1,2)
y.shape = (1,3,2)

y =
[
  [
    [1, 10],
    [1, 10],
    [1, 10]
  ]
]

在 DLinear 的 moving_avg 里,这就是把首尾点复制成 padding。

在我们模型里的位置:

模型源码位置作用
DLinearmoving_avg.forward复制首尾时间步作为 padding
PatchTSTforecast 反标准化(B,1,C) 复制成 (B,pred_len,C)

2.2 torch.cat

基本作用:

沿某个维度把多个 tensor 拼起来。

例子:

python
x = torch.cat([front, x, end], dim=1)

shape:

text
(2,1,3) + (2,6,3) + (2,1,3)
-> dim=1 拼接
(2,8,3)

数学写法:

假设:

text
z = cat([a, b, c], dim=1)
a.shape = (B, T_a, C)
b.shape = (B, T_b, C)
c.shape = (B, T_c, C)

那么:

text
z[b, t, ch] =
  a[b, t, ch]                      if 0 <= t < T_a
  b[b, t - T_a, ch]                if T_a <= t < T_a + T_b
  c[b, t - T_a - T_b, ch]          if T_a + T_b <= t

非常具体的例子:

python
front = torch.tensor([[[1, 10]]])
x = torch.tensor([
    [
        [2, 20],
        [3, 30],
    ]
])
end = torch.tensor([[[4, 40]]])

y = torch.cat([front, x, end], dim=1)

结果:

text
y.shape = (1,4,2)

y =
[
  [
    [1, 10],
    [2, 20],
    [3, 30],
    [4, 40]
  ]
]

在我们模型里的位置:

模型源码位置作用
DLinearmoving_avg.forward左 padding + 原序列 + 右 padding
Informer 变体DataEmbedding_inverted.forward拼接输入值和时间特征

3. 换维:permutetranspose

3.1 permute

基本作用:

任意重排多个维度。

例子:

python
y = x.permute(0, 2, 1)

shape:

text
(B,T,C) -> (B,C,T)
(2,6,3) -> (2,3,6)

数学写法:

text
y = x.permute(0, 2, 1)
y[b, c, t] = x[b, t, c]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 10],
        [2, 20],
        [3, 30],
    ]
])

y = x.permute(0, 2, 1)

结果:

text
x.shape = (1,3,2)
y.shape = (1,2,3)

y =
[
  [
    [1, 2, 3],
    [10, 20, 30]
  ]
]

这里第 0 个变量的时间序列 [1,2,3] 被放到一整行,适配 Conv1d / AvgPool1d(B,C,L) 格式。

在我们模型里的位置:

模型源码位置作用
DLinearmoving_avg.forward转成 AvgPool1d 需要的 (B,C,L)
DLinearDLinear.encoder让 Linear 作用在时间维
PatchTSTPatchTST.forecast转成 (B,C,T) 进入 PatchEmbedding
PatchTSTforecast 后半段(B,C,pred_len) 转回 (B,pred_len,C)

3.2 transpose

基本作用:

只交换两个维度。

例子:

python
y = x.transpose(1, 2)

shape:

text
(B,C,L) -> (B,L,C)

数学写法:

text
y = x.transpose(1, 2)
y[b, l, c] = x[b, c, l]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 2, 3],
        [10, 20, 30],
    ]
])

y = x.transpose(1, 2)

结果:

text
x.shape = (1,2,3)
y.shape = (1,3,2)

y =
[
  [
    [1, 10],
    [2, 20],
    [3, 30]
  ]
]

transpose(1,2)permute(0,2,1) 的二轴交换特例。

在我们模型里的位置:

模型源码位置作用
InformerTokenEmbedding.forwardConv1d 后转回 (B,L,d_model)
Informer/PatchTSTEncoderLayer.forwardFFN 的 Conv1d 前后换维
AttentionProbAttentionK.transpose(-2, -1) 做矩阵乘

详细下钻:[[../model-order/02-PatchTST/03-Conv1d与BCL格式-Informer-PatchTST|03-Conv1d与BCL格式-Informer-PatchTST]]


4. 改形状:reshapeview

4.1 reshape

基本作用:

改变 tensor shape,常用于合并或还原维度。

PatchTST 里:

python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

shape:

text
(B,C,patch_num,patch_len) -> (B*C,patch_num,patch_len)
(2,4,6,4) -> (8,6,4)

作用:

把每个变量当成独立序列,进入 channel-independent attention。

数学写法:

假设:

text
x.shape = (B,C,P,L)
y = reshape(x, (B*C, P, L))

那么新 batch 下标 bc 对应:

text
b = bc // C
c = bc % C
y[bc, p, l] = x[b, c, p, l]

非常具体的例子:

python
x = torch.tensor([
    [
        [[1, 2], [3, 4]],       # batch0, channel0
        [[10, 20], [30, 40]],   # batch0, channel1
    ]
])

y = x.reshape(1 * 2, 2, 2)

结果:

text
x.shape = (1,2,2,2)
y.shape = (2,2,2)

y =
[
  [
    [1, 2],
    [3, 4]
  ],
  [
    [10, 20],
    [30, 40]
  ]
]

PatchTST 的 channel-independent 就是这个思想:

text
原来 batch 里有多个变量
reshape 后每个变量变成一条独立样本

4.2 view

基本作用:

改变 tensor shape,源码里常用于 attention 拆多头。

AttentionLayer 里:

python
queries = self.query_projection(queries).view(B, L, H, -1)

shape:

text
(B,L,d_model) -> Linear -> (B,L,H*d_k) -> view -> (B,L,H,d_k)
(2,6,16) -> (2,6,16) -> (2,6,2,8)

数学写法:

假设:

text
d_model = H * d_k
y = x.view(B, L, H, d_k)

那么:

text
y[b, l, h, e] = x[b, l, h*d_k + e]

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
    ]
])

y = x.view(1, 2, 2, 2)

结果:

text
x.shape = (1,2,4)
y.shape = (1,2,2,2)

token0: [1,2,3,4]
  -> head0 [1,2]
  -> head1 [3,4]

token1: [5,6,7,8]
  -> head0 [5,6]
  -> head1 [7,8]

view 在 attention 里不是“做注意力”,只是把最后一维重新解释成:

text
d_model = n_heads × head_dim

在我们模型里的位置:

模型源码位置作用
InformerAttentionLayer.forwardQ/K/V 拆多头
PatchTSTAttentionLayer.forwardQ/K/V 拆多头

详细下钻:[[../model-order/02-PatchTST/05-Attention基础操作-view-matmul-einsum-softmax-topk|05-Attention基础操作-view-matmul-einsum-softmax-topk]]


5. 加维和删维:unsqueeze / squeeze

5.1 unsqueeze

基本作用:

在指定位置插入一个 size=1 的维度。

例子:

python
x.shape = (2, 4)
y = x.unsqueeze(1)

shape:

text
(2,4) -> (2,1,4)

数学写法:

text
y = x.unsqueeze(1)
y[b, 0, c] = x[b, c]

非常具体的例子:

python
x = torch.tensor([
    [1, 10],
    [2, 20],
])

y = x.unsqueeze(1)

结果:

text
x.shape = (2,2)
y.shape = (2,1,2)

y =
[
  [[1, 10]],
  [[2, 20]]
]

这常用于制造可广播的时间维,比如 (B,C) 变成 (B,1,C)

在我们模型里的位置:

模型源码位置作用
PatchTSTforecast 反标准化(B,C)(B,1,C),再 repeat 到 pred_len
InformerProbAttention为批量矩阵乘或广播增加维度

5.2 squeeze

基本作用:

删除 size=1 的维度。

在 Informer ProbAttention 里:

python
Q_K_sample = torch.matmul(...).squeeze()

作用:

把矩阵乘后多出来的 size=1 维度去掉,得到采样分数。

数学写法:

text
x.shape = (B,1,C)
y = x.squeeze(1)
y[b,c] = x[b,0,c]

非常具体的例子:

python
x = torch.tensor([
    [[1, 10]],
    [[2, 20]],
])

y = x.squeeze(1)

结果:

text
x.shape = (2,1,2)
y.shape = (2,2)

y =
[
  [1, 10],
  [2, 20]
]

注意:

text
squeeze()      会删除所有 size=1 的维度
squeeze(dim)   只删除指定维度

6. 滑窗切片:unfold

基本作用:

沿某个维度滑窗,切出局部窗口。

PatchTST 里:

python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)

toy:

text
x: (B,C,T_padded) = (2,4,14)
patch_len = 4
stride = 2

输出:

text
(2,4,14) -> (2,4,6,4)

其中:

text
6 = patch_num
4 = patch_len

可视化:

text
Patch 0: [t0,  t1,  t2,  t3]
Patch 1: [t2,  t3,  t4,  t5]
Patch 2: [t4,  t5,  t6,  t7]
...

数学写法:

如果:

text
y = x.unfold(dimension=-1, size=patch_len, step=stride)

那么:

text
y[b, c, p, i] = x[b, c, p*stride + i]

其中:

text
p = 第几个 patch
i = patch 内第几个点

非常具体的例子:

python
x = torch.tensor([
    [
        [1, 2, 3, 4, 5, 6]
    ]
])

y = x.unfold(dimension=-1, size=3, step=2)

结果:

text
x.shape = (1,1,6)
y.shape = (1,1,2,3)

y =
[
  [
    [
      [1, 2, 3],   # patch0,从 t0 开始
      [3, 4, 5],   # patch1,从 t2 开始
    ]
  ]
]

因为:

text
patch0 = x[..., 0:3] = [1,2,3]
patch1 = x[..., 2:5] = [3,4,5]

在我们模型里的位置:

模型源码位置作用
PatchTSTEmbed.py -> PatchEmbedding.forward把时间序列切成 patch token

详细下钻:[[../model-order/02-PatchTST/01-ReplicationPad1d与unfold-PatchTST-PatchEmbedding|01-ReplicationPad1d与unfold-PatchTST-PatchEmbedding]]


7. Attention 计算:matmul / einsum / softmax / topk

本节主线

Attention 的数学核心可以拆成四步:

QKQKdksoftmax()AV

PyTorch 代码里对应 matmul / einsum / softmaxtopk 是 Informer ProbAttention 为了减少计算量,先筛 query 的操作。

7.0 先理解 matmul,再理解 einsum

学习顺序

matmul 是标准矩阵乘法;einsum 是你自己写维度规则的广义乘法/求和工具。

所以先把 matmul 的行列点积、矩阵向量、batch 矩阵乘法看懂,再看 einsum("blhe,bshe->bhls") 会更自然。

matmul 的核心规则:

(m,n)@(n,p)(m,p)

也就是:

text
前一个张量的最后一维 = 后一个张量的倒数第二维

einsum 的核心规则:

text
1. 给每个维度起字母名
2. 重复但不出现在箭头右边的字母会被求和消掉
3. 箭头右边写什么,输出就保留什么维度

对照:

操作matmul 写法einsum 写法含义
向量点积torch.matmul(x, y)torch.einsum("i,i->", x, y)i 维乘加后消失
矩阵乘法torch.matmul(A, B)torch.einsum("ij,jk->ik", A, B)j 维乘加后消失
batch 矩阵乘法torch.matmul(A, B)torch.einsum("bij,bjk->bik", A, B)每个 b 独立做矩阵乘法
attention score不直接写成普通 2D matmultorch.einsum("blhe,bshe->bhls", Q, K)每个 b,h 下,所有 query-key 做点积

7.1 torch.matmul

基本作用:

批量矩阵乘。

ProbAttention 里:

python
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

shape:

text
Q_reduce: (B,H,n_top,d_k)
K^T:      (B,H,d_k,L_K)
输出:     (B,H,n_top,L_K)

数学写法:

Y[,i,j]=kA[,i,k]B[,k,j]

7.1.1 向量和向量:点积

代码:

python
x = torch.tensor([1., 2., 3.])
y = torch.tensor([4., 5., 6.])

z = torch.matmul(x, y)

结果:

text
x.shape = (3,)
y.shape = (3,)
z.shape = ()

z = 1*4 + 2*5 + 3*6 = 32

数学上:

z=ixiyi

这里输出是 scalar,因为两个向量的共同维度 i 被求和消掉。

7.1.2 矩阵和向量:每一行和向量做点积

代码:

python
A = torch.tensor([
    [1., 2., 3.],
    [4., 5., 6.],
])

x = torch.tensor([10., 20., 30.])

y = torch.matmul(A, x)

结果:

text
A.shape = (2,3)
x.shape = (3,)
y.shape = (2,)

y[0] = 1*10 + 2*20 + 3*30 = 140
y[1] = 4*10 + 5*20 + 6*30 = 320

y = [140, 320]

语义:

text
矩阵 A 的每一行,都和向量 x 做一次点积。

7.1.3 向量和矩阵:向量和每一列做点积

代码:

python
x = torch.tensor([10., 20.])

B = torch.tensor([
    [1., 2., 3.],
    [4., 5., 6.],
])

y = torch.matmul(x, B)

结果:

text
x.shape = (2,)
B.shape = (2,3)
y.shape = (3,)

y[0] = 10*1 + 20*4 = 90
y[1] = 10*2 + 20*5 = 120
y[2] = 10*3 + 20*6 = 150

y = [90, 120, 150]

语义:

text
向量 x 去乘矩阵 B 的每一列。

7.1.4 矩阵和矩阵:行乘列

非常具体的例子:

python
A = torch.tensor([
    [1., 2., 3.],
    [4., 5., 6.],
])

B = torch.tensor([
    [10., 20.],
    [30., 40.],
    [50., 60.],
])

Y = torch.matmul(A, B)

结果:

text
A.shape = (2,3)
B.shape = (3,2)
Y.shape = (2,2)

Y =
[
  [1*10 + 2*30 + 3*50, 1*20 + 2*40 + 3*60],
  [4*10 + 5*30 + 6*50, 4*20 + 5*40 + 6*60],
]
=
[
  [220, 280],
  [490, 640],
]
读矩阵乘法时只盯一个核心

A 的列数必须等于 B 的行数。这个共享维度 3 被求和消掉,输出保留 A 的行数和 B 的列数:(2,3) @ (3,2) -> (2,2)

7.1.5 高维张量:前面的维度当 batch

batch matmul

高维 matmul 不是把所有维度都混在一起乘。它把最后两维当矩阵乘法维度,前面的维度当 batch 维度。

批量矩阵乘的具体例子:

python
A = torch.tensor([
    [
        [1., 2., 3.],
        [4., 5., 6.],
    ],
    [
        [2., 0., 1.],
        [1., 3., 2.],
    ],
])

B = torch.tensor([
    [
        [10., 20.],
        [30., 40.],
        [50., 60.],
    ],
    [
        [1., 2.],
        [3., 4.],
        [5., 6.],
    ],
])

Y = torch.matmul(A, B)

结果:

text
A.shape = (2,2,3)
B.shape = (2,3,2)
Y.shape = (2,2,2)

第 0 个 batch:
[
  [220, 280],
  [490, 640],
]

第 1 个 batch:
[
  [2*1 + 0*3 + 1*5, 2*2 + 0*4 + 1*6],
  [1*1 + 3*3 + 2*5, 1*2 + 3*4 + 2*6],
]
=
[
  [7, 10],
  [20, 26],
]

在 ProbAttention 里,torch.matmul(Q_reduce, K.transpose(-2, -1)) 就是在每个 B,H 组合下做这样的批量矩阵乘。

对应到 attention:

text
Q_reduce.shape = (B,H,n_top,d_k)
K.transpose(-2,-1).shape = (B,H,d_k,L_K)

torch.matmul 后:
(B,H,n_top,L_K)

其中:

text
B,H 是 batch-like 维度
n_top,d_k @ d_k,L_K 才是真正的矩阵乘法

7.2 torch.einsum

1. 核心法则:标签匹配(Pattern Matching)

einsum 的公式由逗号隔开的输入标签和箭头后的输出标签组成。例如:"abcd, adef -> abcef"。

  • 求和规则(收缩维度): 如果一个字母在左侧(输入)出现多次,但在右侧(输出)消失了,那么这个维度就会被相乘并求和(类似于点积)。

  • 保留规则(自由维度): 如果一个字母在右侧(输出)出现了,它就会被保留在结果中。

  • 对齐规则(批次维度): 如果一个字母在左右两侧都出现了,它就像“批次(Batch)”一样。

在 Transformer 的 einsum('b h q d, b h k d -> b h q k', Q, K) 这个操作中,einsum 实际上帮你完成了一系列复杂的维度对齐和转置。

我们可以把这个过程拆解为三个“自动行为”来理解:

1. 自动对齐的“批次循环”(b 和 h)

bh 同时出现在两个输入和输出的相同位置时,einsum 内部逻辑相当于写了两个嵌套循环:

python
# 伪代码:理解 einsum 内部如何处理 b 和 h
for b in range(Batch):
    for h in range(Head):
        # 在每一个特定的 (b, h) 索引下,取出的是两个 2D 矩阵
        matrix_Q = Q[b, h] # 形状 (q, d)
        matrix_K = K[b, h] # 形状 (k, d)
        
        # 对这两个 2D 矩阵进行后续操作...
        result[b, h] = some_op(matrix_Q, matrix_K)

所以,bh 被自动理解为“独立运行的平行维度”,它们不参与乘法叠加,只是作为索引存在。


2. 自动处理“维度转置”(d 的收缩)

这是 einsum 最精妙的地方。

  • 标准矩阵乘法 (matmul) 的要求: 左矩阵的列数必须等于右矩阵的行数。即 (q, d) @ (d, k)
  • 你的输入: Q(q, d)K 也是 (k, d)(最后一位都是 d)。

如果你用 torch.matmul(Q, K),程序会报错,因为维度不匹配(d 对不上 k)。你必须手动写成 torch.matmul(Q, K.transpose(-1, -2))

但在 einsum 中: 你定义了 ... q d, ... k d -> ... q keinsum 看到 d 在两个输入中都出现了,但在输出中消失了,它就明白:“哦!d 是我要进行相乘并求和的那个维度。”

它在内部会自动找到两个张量中标签为 d 的维度,把它们对齐,不管 d 是在行还是在列。这就是你说的**“自动转换好维度”**。


3. 内部运算逻辑:QD × DK → QK

正如你所猜想的,einsum 内部执行的逻辑就是:

  1. 逐元素相乘: 取出 Q[b, h, q, d]K[b, h, k, d]。注意,此时 qk 是不同的索引,所以它其实是在做一个类似“广播”的动作,生成一个中间状态 (b, h, q, k, d)
  2. 求和(Summation): 因为输出标签里没有 d,它会对所有 d 维上的结果进行累加。

数学表达就是:

Outputb,h,q,k=dQb,h,q,d×Kb,h,k,d

这恰恰就是矩阵乘法 Q×KT 的定义。


进阶:Transformer 中的下一步 (V 的加权)

理解了上面那个,你就能秒懂 Attention 的下一步:注意力分数乘以 Value (V)

  • attn: (b, h, q, k) (刚才计算的结果)
  • V: (b, h, k, d) (Value 张量)
  • 目标: 得到 (b, h, q, d)
python
# einsum 表达式
torch.einsum('b h q k, b h k d -> b h q d', attn, V)

这里的逻辑:

  1. b, h 依然是批次。
  2. k 在左右都有,但在输出中消失 k 维度进行收缩(求和)
  3. qd 被保留。
  4. 这本质上是执行了 (q, k) @ (k, d) 的矩阵乘法。

总结:为什么 Transformer 喜欢用 einsum
  1. 免去转置烦恼: 你不需要纠结什么时候该 transpose(-1, -2),只需要给维度起个名字(如 q, k, d)。
  2. 逻辑清晰: 看到 ...q k, ...k d -> ...q d,一眼就能看出是在用“分数(k)”对“特征(d)”做加权平均。
  3. 代码稳健: 即使以后你把维度顺序改了(比如把 Head 换到最后:b q h d),只要 einsum 的标签对应正确,代码逻辑依然成立,不需要改动复杂的 transposeview

2. FullAttention的例子

FullAttention 里:

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

shape:

text
queries: (B,L,H,E)
keys:    (B,S,H,E)
scores:  (B,H,L,S)

含义:

E 维做点积,得到每个 query-key 对的相似度。

数学写法:

scores[b,h,l,s]=equeries[b,l,h,e]keys[b,s,h,e]

非常具体的例子:

text
B=1, L=3, S=3, H=2, E=2

queries:

head0:
  query0 [1, 0]
  query1 [0, 1]
  query2 [1, 1]

head1:
  query0 [2, 0]
  query1 [0, 2]
  query2 [1, 1]

keys:

head0:
  key0 [1, 0]
  key1 [0, 1]
  key2 [1, 1]

head1:
  key0 [1, 1]
  key1 [2, 0]
  key2 [0, 2]

逐项计算:

text
head0:
query0 对 key0 = [1,0] · [1,0] = 1
query0 对 key1 = [1,0] · [0,1] = 0
query0 对 key2 = [1,0] · [1,1] = 1

query1 对 key0 = [0,1] · [1,0] = 0
query1 对 key1 = [0,1] · [0,1] = 1
query1 对 key2 = [0,1] · [1,1] = 1

query2 对 key0 = [1,1] · [1,0] = 1
query2 对 key1 = [1,1] · [0,1] = 1
query2 对 key2 = [1,1] · [1,1] = 2

head1:
query0 对 key0 = [2,0] · [1,1] = 2
query0 对 key1 = [2,0] · [2,0] = 4
query0 对 key2 = [2,0] · [0,2] = 0

query1 对 key0 = [0,2] · [1,1] = 2
query1 对 key1 = [0,2] · [2,0] = 0
query1 对 key2 = [0,2] · [0,2] = 4

query2 对 key0 = [1,1] · [1,1] = 2
query2 对 key1 = [1,1] · [2,0] = 2
query2 对 key2 = [1,1] · [0,2] = 2

所以:

text
scores.shape = (1,2,3,3)

scores =
[
  head0:
    [
      [1, 0, 1],
      [0, 1, 1],
      [1, 1, 2],
    ],
  head1:
    [
      [2, 4, 0],
      [2, 0, 4],
      [2, 2, 2],
    ],
]
einsum("blhe,bshe->bhls") 的读法

bh 是分组维,保留下来;l 是 query token 维,保留下来;s 是 key token 维,保留下来;e 同时出现在输入但不出现在输出,所以被乘加求和。

7.3 torch.softmax

Attention 里:

python
A = torch.softmax(scores, dim=-1)

如果:

text
scores.shape = (B,H,L,S)

那么 dim=-1 是 key 维 S

含义:

每个 query 对所有 key 的权重和为 1。

数学写法:

softmax(zi)=exp(zi)jexp(zj)

非常具体的例子:

python
scores = torch.tensor([1.0, 0.0, 2.0])
A = torch.softmax(scores, dim=-1)

计算:

text
exp(1) ≈ 2.718
exp(0) = 1
exp(2) ≈ 7.389
sum = 2.718 + 1 + 7.389 = 11.107

softmax([1,0,2])
= [2.718/11.107, 1/11.107, 7.389/11.107]
≈ [0.245, 0.090, 0.665]

如果:

text
scores =
[
  [1, 0, 1],
  [0, 1, 1],
  [1, 1, 2],
]

沿最后一维 softmax:

text
softmax([1,0,1]) ≈ [0.422, 0.155, 0.422]
softmax([0,1,1]) ≈ [0.155, 0.422, 0.422]
softmax([1,1,2]) ≈ [0.212, 0.212, 0.576]
Attention 里 softmax 的方向

scores.shape = (B,H,L,S) 时,dim=-1S,也就是 key 维。含义是:每个 query token 对所有 key token 的权重和为 1。

7.4 topk

Informer 的 ProbAttention 里:

python
M_top = M.topk(n_top, sorted=False)[1]

作用:

选出稀疏度最高的 query 下标,只对这些 query 做更完整的 attention。

数学写法:

text
values, indices = topk(x, k)

表示取 x 中最大的 k 个值和对应下标。

非常具体的例子:

python
x = torch.tensor([0.2, 3.0, 1.5, 4.0, 2.2])
values, indices = x.topk(3, sorted=True)

结果:

text
values  = [4.0, 3.0, 2.2]
indices = [3, 1, 4]

如果 sorted=True,结果会按值从大到小排序。Informer 里常用 [1] 只取下标:

python
M_top = M.topk(n_top, sorted=False)[1]

这里 M_top 是被选中的 query 位置。

7.5 Attention 完整手算例子:QK^T -> softmax -> A V

本节目的

这个例子不是为了贴近真实参数大小,而是为了看清每一个 token 的信息怎样经过 scoresAV_out 流动。例子选择 L=3,S=3,H=2,E=2,D=2,比 2×2 更能体现“多 query、多 key、多 head”的关系。

这个例子对应 PatchTST 的 FullAttention 核心:

python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
A = torch.softmax(scale * scores, dim=-1)
V = torch.einsum("bhls,bshd->blhd", A, values)

固定:

text
B = 1
L = 3
S = 3
H = 2
E = D = 2
scale = 1 / sqrt(E) = 1 / sqrt(2) ≈ 0.707

设:

text
head0:
Q0 = [1, 0]   K0 = [1, 0]   V0 = [10, 0]
Q1 = [0, 1]   K1 = [0, 1]   V1 = [0, 20]
Q2 = [1, 1]   K2 = [1, 1]   V2 = [30, 40]

head1:
Q0 = [2, 0]   K0 = [1, 1]   V0 = [1, 10]
Q1 = [0, 2]   K1 = [2, 0]   V1 = [2, 20]
Q2 = [1, 1]   K2 = [0, 2]   V2 = [3, 30]

第一步:算 scores。

text
scores = QK^T

head0 scores:

             key0   key1   key2
query0        1      0      1
query1        0      1      1
query2        1      1      2

head1 scores:

             key0   key1   key2
query0        2      4      0
query1        2      0      4
query2        2      2      2

第二步:scale。

text
scale * scores =

head0:
             key0    key1    key2
query0      0.707   0       0.707
query1      0       0.707   0.707
query2      0.707   0.707   1.414

head1:
             key0    key1    key2
query0      1.414   2.828   0
query1      1.414   0       2.828
query2      1.414   1.414   1.414

第三步:对 key 维 softmax。

text
head0:
softmax([0.707, 0,     0.707]) ≈ [0.401, 0.198, 0.401]
softmax([0,     0.707, 0.707]) ≈ [0.198, 0.401, 0.401]
softmax([0.707, 0.707, 1.414]) ≈ [0.248, 0.248, 0.503]

head1:
softmax([1.414, 2.828, 0    ]) ≈ [0.187, 0.768, 0.045]
softmax([1.414, 0,     2.828]) ≈ [0.187, 0.045, 0.768]
softmax([1.414, 1.414, 1.414]) = [0.333, 0.333, 0.333]

所以注意力权重:

text
head0 A =

             key0    key1    key2
query0      0.401   0.198   0.401
query1      0.198   0.401   0.401
query2      0.248   0.248   0.503

head1 A =

             key0    key1    key2
query0      0.187   0.768   0.045
query1      0.187   0.045   0.768
query2      0.333   0.333   0.333

第四步:用 A 加权 values。

text
head0:

out query0
= 0.401*[10,0] + 0.198*[0,20] + 0.401*[30,40]
= [16.04, 20.00]

out query1
= 0.198*[10,0] + 0.401*[0,20] + 0.401*[30,40]
= [14.01, 24.06]

out query2
= 0.248*[10,0] + 0.248*[0,20] + 0.503*[30,40]
= [17.57, 25.08]

head1:

out query0
= 0.187*[1,10] + 0.768*[2,20] + 0.045*[3,30]
= [1.858, 18.58]

out query1
= 0.187*[1,10] + 0.045*[2,20] + 0.768*[3,30]
= [2.581, 25.81]

out query2
= 0.333*[1,10] + 0.333*[2,20] + 0.333*[3,30]
= [2.000, 20.00]

最终:

text
V_out =
[
  token0:
    head0 [16.04, 20.00]
    head1 [1.858, 18.58]

  token1:
    head0 [14.01, 24.06]
    head1 [2.581, 25.81]

  token2:
    head0 [17.57, 25.08]
    head1 [2.000, 20.00]
]

shape 对应:

text
queries: (1,3,2,2)
keys:    (1,3,2,2)
values:  (1,3,2,2)
scores:  (1,2,3,3)
A:       (1,2,3,3)
V_out:   (1,3,2,2)

这个例子的语义是:

text
每个 query token 根据自己和 key token 的相似度,
从 value token 中加权拿信息。
AttentionLayer 的关系

FullAttention 输出的是 (B,L,H,D)。回到 AttentionLayer 后,还会执行 out.view(B,L,-1),把 H×D 合并回 d_model,再经过 out_projection 返回 (B,L,d_model)

详细下钻:[[../model-order/02-PatchTST/05-Attention基础操作-view-matmul-einsum-softmax-topk|05-Attention基础操作-view-matmul-einsum-softmax-topk]]


8. 统计操作:mean / var / sqrt / detach

PatchTST 里:

python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev

输入:

text
x_enc.shape = (B,T,C)

均值:

text
mean(dim=1, keepdim=True)
(B,T,C) -> (B,1,C)

8.1 mean(1, keepdim=True) 到底按什么求均值

x_enc.mean(1, keepdim=True) 等价于:

python
x_enc.mean(dim=1, keepdim=True)

这里 dim=1 指的是第 1 维。

如果:

text
x_enc.shape = (B, T, C)

那么三个维度编号是:

维度编号语义
dim=0batch 维,几条样本
dim=1time 维,几个历史时间步
dim=2channel 维,几个变量

所以:

python
x_enc.mean(dim=1, keepdim=True)

意思是:

对每条样本、每个变量,沿时间维 T 求平均。

一个具体数字例子

假设:

text
B = 2
T = 3
C = 2

也就是:

text
x_enc.shape = (2, 3, 2)

具体数值:

python
x_enc = torch.tensor([
    [
        [1., 10.],   # batch 0, t0, 两个变量
        [2., 20.],   # batch 0, t1
        [3., 30.],   # batch 0, t2
    ],
    [
        [4., 40.],   # batch 1, t0
        [5., 50.],   # batch 1, t1
        [6., 60.],   # batch 1, t2
    ],
])

dim=1 求均值,就是把每个 batch 内的 t0/t1/t2 平均掉。

第 0 条样本:

text
变量 0: mean([1, 2, 3]) = 2
变量 1: mean([10, 20, 30]) = 20

第 1 条样本:

text
变量 0: mean([4, 5, 6]) = 5
变量 1: mean([40, 50, 60]) = 50

所以:

python
means = x_enc.mean(dim=1, keepdim=True)

结果是:

text
means =
[
    [[2., 20.]],
    [[5., 50.]],
]

shape 是:

text
(2, 1, 2)

中间维度为什么是 1

因为 keepdim=True 表示:

虽然时间维被平均掉了,但仍然保留这个维度的位置,只是长度从 T=3 变成 1

如果不用 keepdim=True

python
x_enc.mean(dim=1)

shape 会是:

text
(2, 2)

也就是时间维直接消失。

为什么后面可以 x_enc - means

原始:

text
x_enc.shape = (2, 3, 2)
means.shape = (2, 1, 2)

PyTorch 会广播:

text
(2, 1, 2) -> (2, 3, 2)

相当于每个时间步都减同一个历史均值。

第 0 条样本:

text
原始:
[[1, 10],
 [2, 20],
 [3, 30]]

均值:
[[2, 20]]

相减:
[[1-2, 10-20],
 [2-2, 20-20],
 [3-2, 30-20]]
=
[[-1, -10],
 [ 0,   0],
 [ 1,  10]]

方差:

text
var(dim=1, keepdim=True)
(B,T,C) -> (B,1,C)

8.2 var(dim=1, keepdim=True, unbiased=False) 的具体例子

数学写法:

text
var[b, 0, c] = (1/T) * Σ_t (x[b,t,c] - mean[b,0,c])²

注意这里 unbiased=False,所以除以 T,不是除以 T-1

继续用上面的第 0 条样本:

text
原始:
[[1, 10],
 [2, 20],
 [3, 30]]

均值:
[[2, 20]]

变量 0:

text
var = ((1-2)^2 + (2-2)^2 + (3-2)^2) / 3
    = (1 + 0 + 1) / 3
    = 0.6667

变量 1:

text
var = ((10-20)^2 + (20-20)^2 + (30-20)^2) / 3
    = (100 + 0 + 100) / 3
    = 66.6667

所以第 0 条样本的方差是:

text
var[0] = [[0.6667, 66.6667]]

sqrt

方差开方得到标准差。

8.3 sqrt 的具体例子

PatchTST 里:

python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)

数学写法:

text
stdev[b,0,c] = sqrt(var[b,0,c] + 1e-5)

继续上面的结果:

text
var[0] = [[0.6667, 66.6667]]

那么:

text
stdev[0,0,0] = sqrt(0.6667 + 1e-5) ≈ 0.8165
stdev[0,0,1] = sqrt(66.6667 + 1e-5) ≈ 8.1650

所以:

text
stdev[0] ≈ [[0.8165, 8.1650]]

1e-5 的作用:

text
防止方差为 0 时除以 0。

detach

不让均值这条计算参与反向传播。

8.4 detach 的具体例子

代码:

python
means = x_enc.mean(1, keepdim=True).detach()

数学数值上:

text
detach 不改变 means 的值。

例如:

text
mean before detach =
[
  [[2., 20.]],
  [[5., 50.]],
]

mean after detach =
[
  [[2., 20.]],
  [[5., 50.]],
]

变化只发生在计算图:

text
不 detach:
x_enc -> mean -> loss
反向传播时 loss 会沿 mean 这条路影响 x_enc

detach:
x_enc -> mean -> detach -> loss
反向传播在 detach 处截断

在 PatchTST 这种标准化里,detach 表示:

text
均值/标准差作为当前输入窗口的统计量使用,
但不希望模型通过反向传播去“利用”这条统计量计算路径。

8.5 标准化完整具体例子

代码:

python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev

对第 0 条样本:

text
原始:
[[1, 10],
 [2, 20],
 [3, 30]]

均值:
[[2, 20]]

减均值:
[[-1, -10],
 [ 0,   0],
 [ 1,  10]]

标准差:
[[0.8165, 8.1650]]

除以标准差:
[[-1/0.8165, -10/8.1650],
 [ 0/0.8165,   0/8.1650],
 [ 1/0.8165,  10/8.1650]]


[[-1.225, -1.225],
 [ 0.000,  0.000],
 [ 1.225,  1.225]]

在我们模型里的位置:

模型源码位置作用
InformerInformer.short_forecast标准化和反标准化
PatchTSTPatchTST.forecast标准化和反标准化

详细下钻:[[../model-order/02-PatchTST/02-Flatten与标准化统计-PatchTST输出头|02-Flatten与标准化统计-PatchTST输出头]]


9. 一句话总结

这一组 Tensor 操作可以这样记:

text
切片: 从大 tensor 取局部,并决定是否保留维度
repeat/cat: 复制和拼接,常用于 padding
permute/transpose: 调整维度顺序,常用于适配 Conv1d/Pool1d
reshape/view: 改 shape,常用于合并变量或拆多头
unsqueeze/squeeze: 加删 size=1 维,常用于广播和 matmul
unfold: 滑窗切 patch
matmul/einsum/softmax/topk: attention 的基础运算
mean/var/sqrt/detach: 标准化的基础运算

*记录并在线阅读我的笔记*