Appearance
PyTorch Tensor 基础操作:切片、变形、拼接、注意力计算
Abstract
这篇按“Tensor 操作类别”组织。
它解释模型里经常看到的 PyTorch 基础操作:怎么取子张量、怎么换维、怎么 reshape、怎么切 patch、怎么做 attention 里的矩阵运算,以及这些操作在 DLinear / Informer / PatchTST 的哪个位置出现。
0. 文件索引
| 项目 | 内容 |
|---|---|
| 本文主题 | PyTorch Tensor 基础操作 |
| 覆盖范围 | 切片、复制、拼接、换维、变形、加删维度、滑窗、attention 计算、统计操作 |
| 主要模型 | DLinear / Informer / PatchTST |
| 配套文档 | [[10-PyTorch常见层函数-池化卷积线性归一化Embedding]] |
0.1 知识地图
| 类别 | 函数/语法 | 最重要的理解 | 出现位置 |
|---|---|---|---|
| 切片 | x[:, 0:1, :] | 取第一个时间步且保留维度 | DLinear moving_avg |
| 切片 | x[:, -1:, :] | 取最后时间步且保留维度 | DLinear moving_avg |
| 复制 | repeat | 沿指定维度复制 | DLinear padding / PatchTST 反标准化 |
| 拼接 | torch.cat | 沿指定维度拼接 | DLinear padding / DataEmbedding_inverted |
| 换维 | permute | 任意重排维度 | DLinear / PatchTST |
| 换维 | transpose | 交换两个维度 | Conv1d / Attention |
| 变形 | reshape | 改 shape,可合并/拆分维度 | PatchTST |
| 变形 | view | 改 shape,常用于多头拆分 | AttentionLayer |
| 加维 | unsqueeze | 插入 size=1 维 | 标准化反变换 / ProbAttention |
| 删维 | squeeze | 删除 size=1 维 | ProbAttention |
| 滑窗 | unfold | 沿一维切 patch | PatchTST PatchEmbedding |
| 矩阵乘 | matmul | 批量矩阵乘 | ProbAttention |
| 爱因斯坦求和 | einsum | 明确指定维度乘加 | FullAttention |
| 概率化 | softmax | 某一维归一化为概率 | Attention |
| 选择 | topk | 取最大 k 个值/下标 | Informer ProbAttention |
| 统计 | mean/var/sqrt | 标准化 | Informer/PatchTST |
| 断梯度 | detach | 不参与反向传播 | 标准化均值 |
0.2 本文的具体例子标准
本文的例子标准
本文里的“例子”统一指:给一个很小但不退化的具体 tensor,写出操作前后的数值,写出对应数学公式,并说明每个维度代表什么。
text
1. 给一个很小的具体 tensor
2. 写出操作前后的数值
3. 写出对应数学公式
4. 说明每个维度代表什么也就是说,本文不只写:
text
(B,T,C) -> (B,C,T)而是尽量写成:
text
x =
[
[[1, 10], [2, 20], [3, 30]]
]
x.permute(0,2,1) =
[
[[1, 2, 3],
[10,20,30]]
]不用过度简单的矩阵例子
对
matmul / einsum / attention,尽量不用单纯2×2 @ 2×2。这种例子能算,但看不出“共享求和维”和“输出二维网格”的关系。本文第 7 节改用2×3 @ 3×2和L=3,S=3,H=2的 attention 例子。
1. 切片::、0:1、-1:
1.1 x[:, 0:1, :]
假设:
text
x.shape = (B,T,C) = (2,6,3)代码:
python
front = x[:, 0:1, :]含义:
| 片段 | 意思 |
|---|---|
: | 所有 batch |
0:1 | 第 0 个时间步,但保留时间维 |
: | 所有变量 |
shape:
text
(2,6,3) -> (2,1,3)数学写法:
text
front[b, 0, c] = x[b, 0, c]非常具体的例子:
python
x = torch.tensor([
[
[1, 10],
[2, 20],
[3, 30],
]
])
front = x[:, 0:1, :]结果:
text
x.shape = (1,3,2)
front.shape = (1,1,2)
front =
[
[
[1, 10]
]
]注意:
text
x[:, 0, :] -> shape = (1,2),时间维消失
x[:, 0:1, :] -> shape = (1,1,2),时间维保留在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| DLinear | Autoformer_EncDec.py -> moving_avg.forward | 取第一个时间步作为左 padding |
1.2 x[:, -1:, :]
-1 是最后一个元素。
代码:
python
end = x[:, -1:, :]shape:
text
(2,6,3) -> (2,1,3)数学写法:
text
end[b, 0, c] = x[b, T-1, c]非常具体的例子:
python
x = torch.tensor([
[
[1, 10],
[2, 20],
[3, 30],
]
])
end = x[:, -1:, :]结果:
text
end =
[
[
[3, 30]
]
]作用:
取最后一个时间步作为右 padding。
详细下钻:[[../model-order/01-DLinear/01-Autoformer_EncDec-moving_avg-forward-基础语法注解|01-Autoformer_EncDec-moving_avg-forward-基础语法注解]]
2. 复制与拼接:repeat 和 torch.cat
2.1 repeat
基本作用:
沿每个维度复制 tensor。
例子:
python
x.shape = (2, 1, 3)
y = x.repeat(1, 2, 1)shape:
text
(2,1,3) -> (2,2,3)数学写法:
如果:
text
y = x.repeat(r0, r1, r2)那么 y 是把 x 在每个维度按 r0/r1/r2 复制出来。常见场景是只复制时间维:
text
y[b, t, c] = x[b, 0, c]非常具体的例子:
python
x = torch.tensor([
[
[1, 10]
]
])
y = x.repeat(1, 3, 1)结果:
text
x.shape = (1,1,2)
y.shape = (1,3,2)
y =
[
[
[1, 10],
[1, 10],
[1, 10]
]
]在 DLinear 的 moving_avg 里,这就是把首尾点复制成 padding。
在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| DLinear | moving_avg.forward | 复制首尾时间步作为 padding |
| PatchTST | forecast 反标准化 | 把 (B,1,C) 复制成 (B,pred_len,C) |
2.2 torch.cat
基本作用:
沿某个维度把多个 tensor 拼起来。
例子:
python
x = torch.cat([front, x, end], dim=1)shape:
text
(2,1,3) + (2,6,3) + (2,1,3)
-> dim=1 拼接
(2,8,3)数学写法:
假设:
text
z = cat([a, b, c], dim=1)
a.shape = (B, T_a, C)
b.shape = (B, T_b, C)
c.shape = (B, T_c, C)那么:
text
z[b, t, ch] =
a[b, t, ch] if 0 <= t < T_a
b[b, t - T_a, ch] if T_a <= t < T_a + T_b
c[b, t - T_a - T_b, ch] if T_a + T_b <= t非常具体的例子:
python
front = torch.tensor([[[1, 10]]])
x = torch.tensor([
[
[2, 20],
[3, 30],
]
])
end = torch.tensor([[[4, 40]]])
y = torch.cat([front, x, end], dim=1)结果:
text
y.shape = (1,4,2)
y =
[
[
[1, 10],
[2, 20],
[3, 30],
[4, 40]
]
]在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| DLinear | moving_avg.forward | 左 padding + 原序列 + 右 padding |
| Informer 变体 | DataEmbedding_inverted.forward | 拼接输入值和时间特征 |
3. 换维:permute 和 transpose
3.1 permute
基本作用:
任意重排多个维度。
例子:
python
y = x.permute(0, 2, 1)shape:
text
(B,T,C) -> (B,C,T)
(2,6,3) -> (2,3,6)数学写法:
text
y = x.permute(0, 2, 1)
y[b, c, t] = x[b, t, c]非常具体的例子:
python
x = torch.tensor([
[
[1, 10],
[2, 20],
[3, 30],
]
])
y = x.permute(0, 2, 1)结果:
text
x.shape = (1,3,2)
y.shape = (1,2,3)
y =
[
[
[1, 2, 3],
[10, 20, 30]
]
]这里第 0 个变量的时间序列 [1,2,3] 被放到一整行,适配 Conv1d / AvgPool1d 的 (B,C,L) 格式。
在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| DLinear | moving_avg.forward | 转成 AvgPool1d 需要的 (B,C,L) |
| DLinear | DLinear.encoder | 让 Linear 作用在时间维 |
| PatchTST | PatchTST.forecast | 转成 (B,C,T) 进入 PatchEmbedding |
| PatchTST | forecast 后半段 | 把 (B,C,pred_len) 转回 (B,pred_len,C) |
3.2 transpose
基本作用:
只交换两个维度。
例子:
python
y = x.transpose(1, 2)shape:
text
(B,C,L) -> (B,L,C)数学写法:
text
y = x.transpose(1, 2)
y[b, l, c] = x[b, c, l]非常具体的例子:
python
x = torch.tensor([
[
[1, 2, 3],
[10, 20, 30],
]
])
y = x.transpose(1, 2)结果:
text
x.shape = (1,2,3)
y.shape = (1,3,2)
y =
[
[
[1, 10],
[2, 20],
[3, 30]
]
]transpose(1,2) 是 permute(0,2,1) 的二轴交换特例。
在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| Informer | TokenEmbedding.forward | Conv1d 后转回 (B,L,d_model) |
| Informer/PatchTST | EncoderLayer.forward | FFN 的 Conv1d 前后换维 |
| Attention | ProbAttention | K.transpose(-2, -1) 做矩阵乘 |
详细下钻:[[../model-order/02-PatchTST/03-Conv1d与BCL格式-Informer-PatchTST|03-Conv1d与BCL格式-Informer-PatchTST]]
4. 改形状:reshape 和 view
4.1 reshape
基本作用:
改变 tensor shape,常用于合并或还原维度。
PatchTST 里:
python
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))shape:
text
(B,C,patch_num,patch_len) -> (B*C,patch_num,patch_len)
(2,4,6,4) -> (8,6,4)作用:
把每个变量当成独立序列,进入 channel-independent attention。
数学写法:
假设:
text
x.shape = (B,C,P,L)
y = reshape(x, (B*C, P, L))那么新 batch 下标 bc 对应:
text
b = bc // C
c = bc % C
y[bc, p, l] = x[b, c, p, l]非常具体的例子:
python
x = torch.tensor([
[
[[1, 2], [3, 4]], # batch0, channel0
[[10, 20], [30, 40]], # batch0, channel1
]
])
y = x.reshape(1 * 2, 2, 2)结果:
text
x.shape = (1,2,2,2)
y.shape = (2,2,2)
y =
[
[
[1, 2],
[3, 4]
],
[
[10, 20],
[30, 40]
]
]PatchTST 的 channel-independent 就是这个思想:
text
原来 batch 里有多个变量
reshape 后每个变量变成一条独立样本4.2 view
基本作用:
改变 tensor shape,源码里常用于 attention 拆多头。
AttentionLayer 里:
python
queries = self.query_projection(queries).view(B, L, H, -1)shape:
text
(B,L,d_model) -> Linear -> (B,L,H*d_k) -> view -> (B,L,H,d_k)
(2,6,16) -> (2,6,16) -> (2,6,2,8)数学写法:
假设:
text
d_model = H * d_k
y = x.view(B, L, H, d_k)那么:
text
y[b, l, h, e] = x[b, l, h*d_k + e]非常具体的例子:
python
x = torch.tensor([
[
[1, 2, 3, 4],
[5, 6, 7, 8],
]
])
y = x.view(1, 2, 2, 2)结果:
text
x.shape = (1,2,4)
y.shape = (1,2,2,2)
token0: [1,2,3,4]
-> head0 [1,2]
-> head1 [3,4]
token1: [5,6,7,8]
-> head0 [5,6]
-> head1 [7,8]view 在 attention 里不是“做注意力”,只是把最后一维重新解释成:
text
d_model = n_heads × head_dim在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| Informer | AttentionLayer.forward | Q/K/V 拆多头 |
| PatchTST | AttentionLayer.forward | Q/K/V 拆多头 |
详细下钻:[[../model-order/02-PatchTST/05-Attention基础操作-view-matmul-einsum-softmax-topk|05-Attention基础操作-view-matmul-einsum-softmax-topk]]
5. 加维和删维:unsqueeze / squeeze
5.1 unsqueeze
基本作用:
在指定位置插入一个 size=1 的维度。
例子:
python
x.shape = (2, 4)
y = x.unsqueeze(1)shape:
text
(2,4) -> (2,1,4)数学写法:
text
y = x.unsqueeze(1)
y[b, 0, c] = x[b, c]非常具体的例子:
python
x = torch.tensor([
[1, 10],
[2, 20],
])
y = x.unsqueeze(1)结果:
text
x.shape = (2,2)
y.shape = (2,1,2)
y =
[
[[1, 10]],
[[2, 20]]
]这常用于制造可广播的时间维,比如 (B,C) 变成 (B,1,C)。
在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| PatchTST | forecast 反标准化 | (B,C) 变 (B,1,C),再 repeat 到 pred_len |
| Informer | ProbAttention | 为批量矩阵乘或广播增加维度 |
5.2 squeeze
基本作用:
删除 size=1 的维度。
在 Informer ProbAttention 里:
python
Q_K_sample = torch.matmul(...).squeeze()作用:
把矩阵乘后多出来的 size=1 维度去掉,得到采样分数。
数学写法:
text
x.shape = (B,1,C)
y = x.squeeze(1)
y[b,c] = x[b,0,c]非常具体的例子:
python
x = torch.tensor([
[[1, 10]],
[[2, 20]],
])
y = x.squeeze(1)结果:
text
x.shape = (2,1,2)
y.shape = (2,2)
y =
[
[1, 10],
[2, 20]
]注意:
text
squeeze() 会删除所有 size=1 的维度
squeeze(dim) 只删除指定维度6. 滑窗切片:unfold
基本作用:
沿某个维度滑窗,切出局部窗口。
PatchTST 里:
python
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)toy:
text
x: (B,C,T_padded) = (2,4,14)
patch_len = 4
stride = 2输出:
text
(2,4,14) -> (2,4,6,4)其中:
text
6 = patch_num
4 = patch_len可视化:
text
Patch 0: [t0, t1, t2, t3]
Patch 1: [t2, t3, t4, t5]
Patch 2: [t4, t5, t6, t7]
...数学写法:
如果:
text
y = x.unfold(dimension=-1, size=patch_len, step=stride)那么:
text
y[b, c, p, i] = x[b, c, p*stride + i]其中:
text
p = 第几个 patch
i = patch 内第几个点非常具体的例子:
python
x = torch.tensor([
[
[1, 2, 3, 4, 5, 6]
]
])
y = x.unfold(dimension=-1, size=3, step=2)结果:
text
x.shape = (1,1,6)
y.shape = (1,1,2,3)
y =
[
[
[
[1, 2, 3], # patch0,从 t0 开始
[3, 4, 5], # patch1,从 t2 开始
]
]
]因为:
text
patch0 = x[..., 0:3] = [1,2,3]
patch1 = x[..., 2:5] = [3,4,5]在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| PatchTST | Embed.py -> PatchEmbedding.forward | 把时间序列切成 patch token |
详细下钻:[[../model-order/02-PatchTST/01-ReplicationPad1d与unfold-PatchTST-PatchEmbedding|01-ReplicationPad1d与unfold-PatchTST-PatchEmbedding]]
7. Attention 计算:matmul / einsum / softmax / topk
本节主线
Attention 的数学核心可以拆成四步:
PyTorch 代码里对应
matmul / einsum / softmax。topk是 Informer ProbAttention 为了减少计算量,先筛 query 的操作。
7.0 先理解 matmul,再理解 einsum
学习顺序
matmul是标准矩阵乘法;einsum是你自己写维度规则的广义乘法/求和工具。所以先把
matmul的行列点积、矩阵向量、batch 矩阵乘法看懂,再看einsum("blhe,bshe->bhls")会更自然。
matmul 的核心规则:
也就是:
text
前一个张量的最后一维 = 后一个张量的倒数第二维einsum 的核心规则:
text
1. 给每个维度起字母名
2. 重复但不出现在箭头右边的字母会被求和消掉
3. 箭头右边写什么,输出就保留什么维度对照:
| 操作 | matmul 写法 | einsum 写法 | 含义 |
|---|---|---|---|
| 向量点积 | torch.matmul(x, y) | torch.einsum("i,i->", x, y) | i 维乘加后消失 |
| 矩阵乘法 | torch.matmul(A, B) | torch.einsum("ij,jk->ik", A, B) | j 维乘加后消失 |
| batch 矩阵乘法 | torch.matmul(A, B) | torch.einsum("bij,bjk->bik", A, B) | 每个 b 独立做矩阵乘法 |
| attention score | 不直接写成普通 2D matmul | torch.einsum("blhe,bshe->bhls", Q, K) | 每个 b,h 下,所有 query-key 做点积 |
7.1 torch.matmul
基本作用:
批量矩阵乘。
ProbAttention 里:
python
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))shape:
text
Q_reduce: (B,H,n_top,d_k)
K^T: (B,H,d_k,L_K)
输出: (B,H,n_top,L_K)数学写法:
7.1.1 向量和向量:点积
代码:
python
x = torch.tensor([1., 2., 3.])
y = torch.tensor([4., 5., 6.])
z = torch.matmul(x, y)结果:
text
x.shape = (3,)
y.shape = (3,)
z.shape = ()
z = 1*4 + 2*5 + 3*6 = 32数学上:
这里输出是 scalar,因为两个向量的共同维度 i 被求和消掉。
7.1.2 矩阵和向量:每一行和向量做点积
代码:
python
A = torch.tensor([
[1., 2., 3.],
[4., 5., 6.],
])
x = torch.tensor([10., 20., 30.])
y = torch.matmul(A, x)结果:
text
A.shape = (2,3)
x.shape = (3,)
y.shape = (2,)
y[0] = 1*10 + 2*20 + 3*30 = 140
y[1] = 4*10 + 5*20 + 6*30 = 320
y = [140, 320]语义:
text
矩阵 A 的每一行,都和向量 x 做一次点积。7.1.3 向量和矩阵:向量和每一列做点积
代码:
python
x = torch.tensor([10., 20.])
B = torch.tensor([
[1., 2., 3.],
[4., 5., 6.],
])
y = torch.matmul(x, B)结果:
text
x.shape = (2,)
B.shape = (2,3)
y.shape = (3,)
y[0] = 10*1 + 20*4 = 90
y[1] = 10*2 + 20*5 = 120
y[2] = 10*3 + 20*6 = 150
y = [90, 120, 150]语义:
text
向量 x 去乘矩阵 B 的每一列。7.1.4 矩阵和矩阵:行乘列
非常具体的例子:
python
A = torch.tensor([
[1., 2., 3.],
[4., 5., 6.],
])
B = torch.tensor([
[10., 20.],
[30., 40.],
[50., 60.],
])
Y = torch.matmul(A, B)结果:
text
A.shape = (2,3)
B.shape = (3,2)
Y.shape = (2,2)
Y =
[
[1*10 + 2*30 + 3*50, 1*20 + 2*40 + 3*60],
[4*10 + 5*30 + 6*50, 4*20 + 5*40 + 6*60],
]
=
[
[220, 280],
[490, 640],
]读矩阵乘法时只盯一个核心
A的列数必须等于B的行数。这个共享维度3被求和消掉,输出保留A的行数和B的列数:(2,3) @ (3,2) -> (2,2)。
7.1.5 高维张量:前面的维度当 batch
batch matmul
高维
matmul不是把所有维度都混在一起乘。它把最后两维当矩阵乘法维度,前面的维度当 batch 维度。
批量矩阵乘的具体例子:
python
A = torch.tensor([
[
[1., 2., 3.],
[4., 5., 6.],
],
[
[2., 0., 1.],
[1., 3., 2.],
],
])
B = torch.tensor([
[
[10., 20.],
[30., 40.],
[50., 60.],
],
[
[1., 2.],
[3., 4.],
[5., 6.],
],
])
Y = torch.matmul(A, B)结果:
text
A.shape = (2,2,3)
B.shape = (2,3,2)
Y.shape = (2,2,2)
第 0 个 batch:
[
[220, 280],
[490, 640],
]
第 1 个 batch:
[
[2*1 + 0*3 + 1*5, 2*2 + 0*4 + 1*6],
[1*1 + 3*3 + 2*5, 1*2 + 3*4 + 2*6],
]
=
[
[7, 10],
[20, 26],
]在 ProbAttention 里,torch.matmul(Q_reduce, K.transpose(-2, -1)) 就是在每个 B,H 组合下做这样的批量矩阵乘。
对应到 attention:
text
Q_reduce.shape = (B,H,n_top,d_k)
K.transpose(-2,-1).shape = (B,H,d_k,L_K)
torch.matmul 后:
(B,H,n_top,L_K)其中:
text
B,H 是 batch-like 维度
n_top,d_k @ d_k,L_K 才是真正的矩阵乘法7.2 torch.einsum
1. 核心法则:标签匹配(Pattern Matching)
einsum 的公式由逗号隔开的输入标签和箭头后的输出标签组成。例如:"abcd, adef -> abcef"。
求和规则(收缩维度): 如果一个字母在左侧(输入)出现多次,但在右侧(输出)消失了,那么这个维度就会被相乘并求和(类似于点积)。
保留规则(自由维度): 如果一个字母在右侧(输出)出现了,它就会被保留在结果中。
对齐规则(批次维度): 如果一个字母在左右两侧都出现了,它就像“批次(Batch)”一样。
在 Transformer 的 einsum('b h q d, b h k d -> b h q k', Q, K) 这个操作中,einsum 实际上帮你完成了一系列复杂的维度对齐和转置。
我们可以把这个过程拆解为三个“自动行为”来理解:
1. 自动对齐的“批次循环”(b 和 h)
当 b 和 h 同时出现在两个输入和输出的相同位置时,einsum 内部逻辑相当于写了两个嵌套循环:
python
# 伪代码:理解 einsum 内部如何处理 b 和 h
for b in range(Batch):
for h in range(Head):
# 在每一个特定的 (b, h) 索引下,取出的是两个 2D 矩阵
matrix_Q = Q[b, h] # 形状 (q, d)
matrix_K = K[b, h] # 形状 (k, d)
# 对这两个 2D 矩阵进行后续操作...
result[b, h] = some_op(matrix_Q, matrix_K)所以,b 和 h 被自动理解为“独立运行的平行维度”,它们不参与乘法叠加,只是作为索引存在。
2. 自动处理“维度转置”(d 的收缩)
这是 einsum 最精妙的地方。
- 标准矩阵乘法 (
matmul) 的要求: 左矩阵的列数必须等于右矩阵的行数。即(q, d) @ (d, k)。 - 你的输入:
Q是(q, d),K也是(k, d)(最后一位都是d)。
如果你用 torch.matmul(Q, K),程序会报错,因为维度不匹配(d 对不上 k)。你必须手动写成 torch.matmul(Q, K.transpose(-1, -2))。
但在 einsum 中: 你定义了 ... q d, ... k d -> ... q k。 einsum 看到 d 在两个输入中都出现了,但在输出中消失了,它就明白:“哦!d 是我要进行相乘并求和的那个维度。”
它在内部会自动找到两个张量中标签为 d 的维度,把它们对齐,不管 d 是在行还是在列。这就是你说的**“自动转换好维度”**。
3. 内部运算逻辑:QD × DK → QK
正如你所猜想的,einsum 内部执行的逻辑就是:
- 逐元素相乘: 取出
Q[b, h, q, d]和K[b, h, k, d]。注意,此时q和k是不同的索引,所以它其实是在做一个类似“广播”的动作,生成一个中间状态(b, h, q, k, d)。 - 求和(Summation): 因为输出标签里没有
d,它会对所有d维上的结果进行累加。
数学表达就是:
这恰恰就是矩阵乘法
进阶:Transformer 中的下一步 (V 的加权)
理解了上面那个,你就能秒懂 Attention 的下一步:注意力分数乘以 Value (
attn:(b, h, q, k)(刚才计算的结果)V:(b, h, k, d)(Value 张量)- 目标: 得到
(b, h, q, d)
python
# einsum 表达式
torch.einsum('b h q k, b h k d -> b h q d', attn, V)这里的逻辑:
b, h依然是批次。k在左右都有,但在输出中消失对 k维度进行收缩(求和)。q和d被保留。- 这本质上是执行了
(q, k) @ (k, d)的矩阵乘法。
总结:为什么 Transformer 喜欢用 einsum?
- 免去转置烦恼: 你不需要纠结什么时候该
transpose(-1, -2),只需要给维度起个名字(如q, k, d)。 - 逻辑清晰: 看到
...q k, ...k d -> ...q d,一眼就能看出是在用“分数(k)”对“特征(d)”做加权平均。 - 代码稳健: 即使以后你把维度顺序改了(比如把 Head 换到最后:
b q h d),只要einsum的标签对应正确,代码逻辑依然成立,不需要改动复杂的transpose或view。
2. FullAttention的例子
FullAttention 里:
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)shape:
text
queries: (B,L,H,E)
keys: (B,S,H,E)
scores: (B,H,L,S)含义:
对
E维做点积,得到每个 query-key 对的相似度。
数学写法:
非常具体的例子:
text
B=1, L=3, S=3, H=2, E=2
queries:
head0:
query0 [1, 0]
query1 [0, 1]
query2 [1, 1]
head1:
query0 [2, 0]
query1 [0, 2]
query2 [1, 1]
keys:
head0:
key0 [1, 0]
key1 [0, 1]
key2 [1, 1]
head1:
key0 [1, 1]
key1 [2, 0]
key2 [0, 2]逐项计算:
text
head0:
query0 对 key0 = [1,0] · [1,0] = 1
query0 对 key1 = [1,0] · [0,1] = 0
query0 对 key2 = [1,0] · [1,1] = 1
query1 对 key0 = [0,1] · [1,0] = 0
query1 对 key1 = [0,1] · [0,1] = 1
query1 对 key2 = [0,1] · [1,1] = 1
query2 对 key0 = [1,1] · [1,0] = 1
query2 对 key1 = [1,1] · [0,1] = 1
query2 对 key2 = [1,1] · [1,1] = 2
head1:
query0 对 key0 = [2,0] · [1,1] = 2
query0 对 key1 = [2,0] · [2,0] = 4
query0 对 key2 = [2,0] · [0,2] = 0
query1 对 key0 = [0,2] · [1,1] = 2
query1 对 key1 = [0,2] · [2,0] = 0
query1 对 key2 = [0,2] · [0,2] = 4
query2 对 key0 = [1,1] · [1,1] = 2
query2 对 key1 = [1,1] · [2,0] = 2
query2 对 key2 = [1,1] · [0,2] = 2所以:
text
scores.shape = (1,2,3,3)
scores =
[
head0:
[
[1, 0, 1],
[0, 1, 1],
[1, 1, 2],
],
head1:
[
[2, 4, 0],
[2, 0, 4],
[2, 2, 2],
],
]einsum("blhe,bshe->bhls") 的读法
einsum("blhe,bshe->bhls") 的读法
b和h是分组维,保留下来;l是 query token 维,保留下来;s是 key token 维,保留下来;e同时出现在输入但不出现在输出,所以被乘加求和。
7.3 torch.softmax
Attention 里:
python
A = torch.softmax(scores, dim=-1)如果:
text
scores.shape = (B,H,L,S)那么 dim=-1 是 key 维 S。
含义:
每个 query 对所有 key 的权重和为 1。
数学写法:
非常具体的例子:
python
scores = torch.tensor([1.0, 0.0, 2.0])
A = torch.softmax(scores, dim=-1)计算:
text
exp(1) ≈ 2.718
exp(0) = 1
exp(2) ≈ 7.389
sum = 2.718 + 1 + 7.389 = 11.107
softmax([1,0,2])
= [2.718/11.107, 1/11.107, 7.389/11.107]
≈ [0.245, 0.090, 0.665]如果:
text
scores =
[
[1, 0, 1],
[0, 1, 1],
[1, 1, 2],
]沿最后一维 softmax:
text
softmax([1,0,1]) ≈ [0.422, 0.155, 0.422]
softmax([0,1,1]) ≈ [0.155, 0.422, 0.422]
softmax([1,1,2]) ≈ [0.212, 0.212, 0.576]Attention 里 softmax 的方向
scores.shape = (B,H,L,S)时,dim=-1是S,也就是 key 维。含义是:每个 query token 对所有 key token 的权重和为 1。
7.4 topk
Informer 的 ProbAttention 里:
python
M_top = M.topk(n_top, sorted=False)[1]作用:
选出稀疏度最高的 query 下标,只对这些 query 做更完整的 attention。
数学写法:
text
values, indices = topk(x, k)表示取 x 中最大的 k 个值和对应下标。
非常具体的例子:
python
x = torch.tensor([0.2, 3.0, 1.5, 4.0, 2.2])
values, indices = x.topk(3, sorted=True)结果:
text
values = [4.0, 3.0, 2.2]
indices = [3, 1, 4]如果 sorted=True,结果会按值从大到小排序。Informer 里常用 [1] 只取下标:
python
M_top = M.topk(n_top, sorted=False)[1]这里 M_top 是被选中的 query 位置。
7.5 Attention 完整手算例子:QK^T -> softmax -> A V
本节目的
这个例子不是为了贴近真实参数大小,而是为了看清每一个 token 的信息怎样经过
scores、A、V_out流动。例子选择L=3,S=3,H=2,E=2,D=2,比2×2更能体现“多 query、多 key、多 head”的关系。
这个例子对应 PatchTST 的 FullAttention 核心:
python
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
A = torch.softmax(scale * scores, dim=-1)
V = torch.einsum("bhls,bshd->blhd", A, values)固定:
text
B = 1
L = 3
S = 3
H = 2
E = D = 2
scale = 1 / sqrt(E) = 1 / sqrt(2) ≈ 0.707设:
text
head0:
Q0 = [1, 0] K0 = [1, 0] V0 = [10, 0]
Q1 = [0, 1] K1 = [0, 1] V1 = [0, 20]
Q2 = [1, 1] K2 = [1, 1] V2 = [30, 40]
head1:
Q0 = [2, 0] K0 = [1, 1] V0 = [1, 10]
Q1 = [0, 2] K1 = [2, 0] V1 = [2, 20]
Q2 = [1, 1] K2 = [0, 2] V2 = [3, 30]第一步:算 scores。
text
scores = QK^T
head0 scores:
key0 key1 key2
query0 1 0 1
query1 0 1 1
query2 1 1 2
head1 scores:
key0 key1 key2
query0 2 4 0
query1 2 0 4
query2 2 2 2第二步:scale。
text
scale * scores =
head0:
key0 key1 key2
query0 0.707 0 0.707
query1 0 0.707 0.707
query2 0.707 0.707 1.414
head1:
key0 key1 key2
query0 1.414 2.828 0
query1 1.414 0 2.828
query2 1.414 1.414 1.414第三步:对 key 维 softmax。
text
head0:
softmax([0.707, 0, 0.707]) ≈ [0.401, 0.198, 0.401]
softmax([0, 0.707, 0.707]) ≈ [0.198, 0.401, 0.401]
softmax([0.707, 0.707, 1.414]) ≈ [0.248, 0.248, 0.503]
head1:
softmax([1.414, 2.828, 0 ]) ≈ [0.187, 0.768, 0.045]
softmax([1.414, 0, 2.828]) ≈ [0.187, 0.045, 0.768]
softmax([1.414, 1.414, 1.414]) = [0.333, 0.333, 0.333]所以注意力权重:
text
head0 A =
key0 key1 key2
query0 0.401 0.198 0.401
query1 0.198 0.401 0.401
query2 0.248 0.248 0.503
head1 A =
key0 key1 key2
query0 0.187 0.768 0.045
query1 0.187 0.045 0.768
query2 0.333 0.333 0.333第四步:用 A 加权 values。
text
head0:
out query0
= 0.401*[10,0] + 0.198*[0,20] + 0.401*[30,40]
= [16.04, 20.00]
out query1
= 0.198*[10,0] + 0.401*[0,20] + 0.401*[30,40]
= [14.01, 24.06]
out query2
= 0.248*[10,0] + 0.248*[0,20] + 0.503*[30,40]
= [17.57, 25.08]
head1:
out query0
= 0.187*[1,10] + 0.768*[2,20] + 0.045*[3,30]
= [1.858, 18.58]
out query1
= 0.187*[1,10] + 0.045*[2,20] + 0.768*[3,30]
= [2.581, 25.81]
out query2
= 0.333*[1,10] + 0.333*[2,20] + 0.333*[3,30]
= [2.000, 20.00]最终:
text
V_out =
[
token0:
head0 [16.04, 20.00]
head1 [1.858, 18.58]
token1:
head0 [14.01, 24.06]
head1 [2.581, 25.81]
token2:
head0 [17.57, 25.08]
head1 [2.000, 20.00]
]shape 对应:
text
queries: (1,3,2,2)
keys: (1,3,2,2)
values: (1,3,2,2)
scores: (1,2,3,3)
A: (1,2,3,3)
V_out: (1,3,2,2)这个例子的语义是:
text
每个 query token 根据自己和 key token 的相似度,
从 value token 中加权拿信息。和 AttentionLayer 的关系
AttentionLayer 的关系
FullAttention输出的是(B,L,H,D)。回到AttentionLayer后,还会执行out.view(B,L,-1),把H×D合并回d_model,再经过out_projection返回(B,L,d_model)。
详细下钻:[[../model-order/02-PatchTST/05-Attention基础操作-view-matmul-einsum-softmax-topk|05-Attention基础操作-view-matmul-einsum-softmax-topk]]
8. 统计操作:mean / var / sqrt / detach
PatchTST 里:
python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev输入:
text
x_enc.shape = (B,T,C)均值:
text
mean(dim=1, keepdim=True)
(B,T,C) -> (B,1,C)8.1 mean(1, keepdim=True) 到底按什么求均值
x_enc.mean(1, keepdim=True) 等价于:
python
x_enc.mean(dim=1, keepdim=True)这里 dim=1 指的是第 1 维。
如果:
text
x_enc.shape = (B, T, C)那么三个维度编号是:
| 维度编号 | 语义 |
|---|---|
dim=0 | batch 维,几条样本 |
dim=1 | time 维,几个历史时间步 |
dim=2 | channel 维,几个变量 |
所以:
python
x_enc.mean(dim=1, keepdim=True)意思是:
对每条样本、每个变量,沿时间维
T求平均。
一个具体数字例子
假设:
text
B = 2
T = 3
C = 2也就是:
text
x_enc.shape = (2, 3, 2)具体数值:
python
x_enc = torch.tensor([
[
[1., 10.], # batch 0, t0, 两个变量
[2., 20.], # batch 0, t1
[3., 30.], # batch 0, t2
],
[
[4., 40.], # batch 1, t0
[5., 50.], # batch 1, t1
[6., 60.], # batch 1, t2
],
])对 dim=1 求均值,就是把每个 batch 内的 t0/t1/t2 平均掉。
第 0 条样本:
text
变量 0: mean([1, 2, 3]) = 2
变量 1: mean([10, 20, 30]) = 20第 1 条样本:
text
变量 0: mean([4, 5, 6]) = 5
变量 1: mean([40, 50, 60]) = 50所以:
python
means = x_enc.mean(dim=1, keepdim=True)结果是:
text
means =
[
[[2., 20.]],
[[5., 50.]],
]shape 是:
text
(2, 1, 2)中间维度为什么是 1?
因为 keepdim=True 表示:
虽然时间维被平均掉了,但仍然保留这个维度的位置,只是长度从
T=3变成1。
如果不用 keepdim=True:
python
x_enc.mean(dim=1)shape 会是:
text
(2, 2)也就是时间维直接消失。
为什么后面可以 x_enc - means
原始:
text
x_enc.shape = (2, 3, 2)
means.shape = (2, 1, 2)PyTorch 会广播:
text
(2, 1, 2) -> (2, 3, 2)相当于每个时间步都减同一个历史均值。
第 0 条样本:
text
原始:
[[1, 10],
[2, 20],
[3, 30]]
均值:
[[2, 20]]
相减:
[[1-2, 10-20],
[2-2, 20-20],
[3-2, 30-20]]
=
[[-1, -10],
[ 0, 0],
[ 1, 10]]方差:
text
var(dim=1, keepdim=True)
(B,T,C) -> (B,1,C)8.2 var(dim=1, keepdim=True, unbiased=False) 的具体例子
数学写法:
text
var[b, 0, c] = (1/T) * Σ_t (x[b,t,c] - mean[b,0,c])²注意这里 unbiased=False,所以除以 T,不是除以 T-1。
继续用上面的第 0 条样本:
text
原始:
[[1, 10],
[2, 20],
[3, 30]]
均值:
[[2, 20]]变量 0:
text
var = ((1-2)^2 + (2-2)^2 + (3-2)^2) / 3
= (1 + 0 + 1) / 3
= 0.6667变量 1:
text
var = ((10-20)^2 + (20-20)^2 + (30-20)^2) / 3
= (100 + 0 + 100) / 3
= 66.6667所以第 0 条样本的方差是:
text
var[0] = [[0.6667, 66.6667]]sqrt:
方差开方得到标准差。
8.3 sqrt 的具体例子
PatchTST 里:
python
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)数学写法:
text
stdev[b,0,c] = sqrt(var[b,0,c] + 1e-5)继续上面的结果:
text
var[0] = [[0.6667, 66.6667]]那么:
text
stdev[0,0,0] = sqrt(0.6667 + 1e-5) ≈ 0.8165
stdev[0,0,1] = sqrt(66.6667 + 1e-5) ≈ 8.1650所以:
text
stdev[0] ≈ [[0.8165, 8.1650]]1e-5 的作用:
text
防止方差为 0 时除以 0。detach:
不让均值这条计算参与反向传播。
8.4 detach 的具体例子
代码:
python
means = x_enc.mean(1, keepdim=True).detach()数学数值上:
text
detach 不改变 means 的值。例如:
text
mean before detach =
[
[[2., 20.]],
[[5., 50.]],
]
mean after detach =
[
[[2., 20.]],
[[5., 50.]],
]变化只发生在计算图:
text
不 detach:
x_enc -> mean -> loss
反向传播时 loss 会沿 mean 这条路影响 x_enc
detach:
x_enc -> mean -> detach -> loss
反向传播在 detach 处截断在 PatchTST 这种标准化里,detach 表示:
text
均值/标准差作为当前输入窗口的统计量使用,
但不希望模型通过反向传播去“利用”这条统计量计算路径。8.5 标准化完整具体例子
代码:
python
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev对第 0 条样本:
text
原始:
[[1, 10],
[2, 20],
[3, 30]]
均值:
[[2, 20]]
减均值:
[[-1, -10],
[ 0, 0],
[ 1, 10]]
标准差:
[[0.8165, 8.1650]]
除以标准差:
[[-1/0.8165, -10/8.1650],
[ 0/0.8165, 0/8.1650],
[ 1/0.8165, 10/8.1650]]
≈
[[-1.225, -1.225],
[ 0.000, 0.000],
[ 1.225, 1.225]]在我们模型里的位置:
| 模型 | 源码位置 | 作用 |
|---|---|---|
| Informer | Informer.short_forecast | 标准化和反标准化 |
| PatchTST | PatchTST.forecast | 标准化和反标准化 |
详细下钻:[[../model-order/02-PatchTST/02-Flatten与标准化统计-PatchTST输出头|02-Flatten与标准化统计-PatchTST输出头]]
9. 一句话总结
这一组 Tensor 操作可以这样记:
text
切片: 从大 tensor 取局部,并决定是否保留维度
repeat/cat: 复制和拼接,常用于 padding
permute/transpose: 调整维度顺序,常用于适配 Conv1d/Pool1d
reshape/view: 改 shape,常用于合并变量或拆多头
unsqueeze/squeeze: 加删 size=1 维,常用于广播和 matmul
unfold: 滑窗切 patch
matmul/einsum/softmax/topk: attention 的基础运算
mean/var/sqrt/detach: 标准化的基础运算