Appearance
AutoCorrelation 完整手推例子
配套文档:[[04A-Layer5-AutoCorrelation]](宏观原理与公式链)
PyTorch 函数速查:[[../../pytorch-basics/concept-order/11-PyTorch-Tensor基础操作-切片变形拼接注意力]]
本文用一组固定的 toy 数据,从 forward() 入口一步步追踪到输出,每步给出具体张量数值、对应的 PyTorch 函数含义、以及数学意义。全文只用这一个 toy,数字前后完全连贯。
0. Toy 数据定义(全文通用)
B=1, L=8, H=2, E=2, D=2(D=E=2,V 的特征维与 Q/K 相同)Q 和 K(完全相同,验证自相关)
所有 head 和 channel 共用同一条一维信号:
这是周期为 4 的脉冲序列(t=0,4 各有一个脉冲)。
python
# 完整形状 (B, L, H, E) = (1, 8, 2, 2)
# 对所有 h, e: Q[0, :, h, e] = K[0, :, h, e] = [1,0,0,0,1,0,0,0]V(四条不同曲线,体现 roll 的效果)
python
# V 形状 (B, L, H, D) = (1, 8, 2, 2)
V[0, :, 0, 0] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] # 递增
V[0, :, 0, 1] = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] # 递减
V[0, :, 1, 0] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8] # 周期=2 交替
V[0, :, 1, 1] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8] # 周期=4(恰好与 lag 匹配)为什么 V 用四种曲线
聚合结果取决于 V 本身的周期性:
- 递增/递减曲线能直观看到 roll 把"4步之后的值"混入当前位置
- 周期=2 曲线 roll(-4) 后发生改变(2≠4)
- 周期=4 曲线 roll(-4) 后完全不变(完美匹配检测到的 lag)
1. 步骤 0:L/S 对齐
python
B, L, H, E = queries.shape # 1, 8, 2, 2
_, S, _, D = values.shape # S = 8
# L == S → 走 else 分支
values = values[:, :L, :, :] # 无变化
keys = keys[:, :L, :, :] # 无变化目的:FFT 互相关要求 Q 和 K/V 的时序长度一致,统一对齐到 min(L, S)。
此 toy 里 L=S=8,两个分支均不改变形状。
2. 步骤 1:permute — 把时间维移到最后
python
q_time_major = queries.permute(0, 2, 3, 1).contiguous()
k_time_major = keys.permute(0, 2, 3, 1).contiguous()为什么要 permute
torch.fft.rfft 的 dim=-1 参数是"对最后一维做 FFT"。原始 queries (B,L,H,E) 的最后一维是 E(特征维),而我们需要对 L(时间维)做频域变换。
所以先把时间维 L 换到最后:
数值变化(追踪 b=0, h=0, e=0 这一切片):
# 原始 queries[0, :, 0, 0] 是按 L 排列的一列
queries[0, :, 0, 0] = [1, 0, 0, 0, 1, 0, 0, 0] # shape=(8,),L 在 dim=1 位置
# permute 后,同样的数据现在在最后一维
q_time_major[0, 0, 0, :] = [1, 0, 0, 0, 1, 0, 0, 0] # shape=(8,),L 在 dim=-1 位置permute 不改变数值,只改变轴的解释顺序(底层内存重排由 .contiguous() 完成)。
整体形状:(1,8,2,2) → (1,2,2,8)
contiguous() 的作用
contiguous() 的作用
permute只修改 stride 描述符,不移动实际内存。.contiguous()保证后续rfft能正确访问连续内存。
3. 步骤 2:rfft — 时域 → 频域
python
q_fft = torch.fft.rfft(q_time_major, dim=-1)
k_fft = torch.fft.rfft(k_time_major, dim=-1)rfft 是什么
rfft = Real FFT。对一条实数序列
由于实数输入的 DFT 满足共轭对称
手算 rfft([1,0,0,0,1,0,0,0])
只有 n=0 和 n=4 处有非零值(x[0]=x[4]=1),代入公式:
逐个计算:
| k(频率) | X[k] | 解读 | |
|---|---|---|---|
| 0 | +1 | 2+0j | 直流分量(总能量) |
| 1 | −1 | 0+0j | 无此频率 |
| 2 | +1 | 2+0j | 频率2,周期 L/2=4 ← 主频 |
| 3 | −1 | 0+0j | 无此频率 |
| 4 | +1 | 2+0j | 奈奎斯特频率 |
频率 k=2 对应的物理周期
频率 k 对应的周期 =
。这正是我们设置的信号周期 p=4。FFT 在频域准确捕捉到了这个主频。
整体 shape:(1,2,2,8) → (1,2,2,5)(复数张量)
4. 步骤 3:互相关 — 频域乘法 = 时域互相关
python
res = q_fft * torch.conj(k_fft)互相关定理(核心)
时域互相关定义(暴力计算需
互相关定理告诉我们,这等价于频域操作(
其中
torch.conj 是什么
对每个复数
对我们的 toy(Q=K,均为实数),共轭不改变值:
计算 res
逐元素相乘(Q=K 的自相关情形):
| k | q_fft[k] | conj(k_fft[k]) | res[k] |
|---|---|---|---|
| 0 | 2+0j | 2−0j | 4+0j |
| 1 | 0+0j | 0−0j | 0+0j |
| 2 | 2+0j | 2−0j | 4+0j |
| 3 | 0+0j | 0−0j | 0+0j |
| 4 | 2+0j | 2−0j | 4+0j |
形状不变:(1,2,2,5)(复数)
5. 步骤 4:irfft — 频域 → 时域(互相关函数)
python
corr = torch.fft.irfft(res, dim=-1)irfft 是什么
irfft = Inverse Real FFT。把 rfft 的输出(L//2+1 个复数)变回实数序列。
由于省略了 n 参数,PyTorch 自动推断输出长度:
irfft 内部先补全共轭对称的频谱,再做逆 DFT:
手算 irfft([4,0,4,0,4])
补全对称频谱(N=8 时,rfft 输出 X[0..4],irfft 内部补充 X[5..7]):
只有 k=0,2,4,6 处有非零值(均为4),逆变换:
逐位计算:
| τ | 计算 | corr[τ] | 含义 |
|---|---|---|---|
| 0 | 2 | lag=0:Q 和 K 完全对齐 | |
| 1 | 0 | lag=1:无相关 | |
| 2 | 0 | lag=2:无相关 | |
| 3 | 0 | lag=3:无相关 | |
| 4 | 2 | lag=4:移位一个周期后完美对齐 ✓ | |
| 5–7 | — | 0 | 无相关 |
corr 的物理含义
corr[b,h,e,τ]= 在头 h、通道 e 下,把 K 右移 τ 步后与 Q 的内积(对齐程度)。 峰值在 τ=0 和 τ=4,正好是周期 p=4 的整数倍。
形状:(1,2,2,8)(实数)
6. 步骤 5:两次均值 — 压缩 H 和 E 维
python
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)目的:topk 要在 lag(L)维上找峰值,需要先把 H 和 E 维度合并成单一的"平均相关强度"。
第一次 mean(dim=1):压缩 H 维
# dim=1 是 H 维(H=2 个 head 求均值)
# 由于所有 head 值相同,均值等于原值
结果[0, e, :] = [2,0,0,0,2,0,0,0] 对 e=0,1第二次 mean(dim=1):压缩 E 维
# dim=1 现在是 E 维(E=2 个 channel 求均值)
mean_value[0, :] = [2, 0, 0, 0, 2, 0, 0, 0]mean_value[b, τ] = 样本 b 在延迟 τ 处的跨所有 head 和 channel 的平均互相关强度。
7. 步骤 6:topk — 选主 lag
python
top_k = int(self.factor * math.log(length)) # factor=1, length=8
index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]top_k 的计算
math.log 是自然对数
torch.mean(mean_value, dim=0):batch 均值
mean_value.shape = (1, 8) # B=1
torch.mean(mean_value, dim=0).shape = (8,) # B=1 时均值等于原值
= [2, 0, 0, 0, 2, 0, 0, 0]训练时 batch 均值的代价
当 B>1 时,此步会把所有样本的相关强度平均后再选 top-k,导致所有样本共享同一个 lag 集合。这是 training 路径的近似。Inference 路径用 per-sample topk 避免此问题(见 [[04A-Layer5-AutoCorrelation]] §5.5)。
torch.topk
python
torch.topk(tensor([2,0,0,0,2,0,0,0]), k=2, dim=-1)topk 返回两个张量:
values = tensor([2., 2.])
indices = tensor([0, 4]) # τ=0 和 τ=4 的相关强度最高8. 步骤 7:softmax — 相关强度 → 权重
python
weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
tmp_corr = torch.softmax(weights, dim=-1)收集选中 lag 处的强度
python
mean_value[:, index[0]] = mean_value[:, 0] = tensor([2.0]) # shape (1,)
mean_value[:, index[1]] = mean_value[:, 4] = tensor([2.0]) # shape (1,)
weights = torch.stack([tensor([2.0]), tensor([2.0])], dim=-1)
= tensor([[2.0, 2.0]]) # shape (1, 2) = (B, top_k)softmax 归一化
两个值相等时:
tmp_corr = tensor([[0.5, 0.5]]) # shape (1, 2) = (B, top_k)这里 softmax 的意义
两个 lag 强度完全相等,模型给它们一样的权重。这等价于:用当前时刻和"4步之后的时刻"的等比例混合来作为最终表示。
9. 步骤 8:roll + 加权累加 — 时延聚合
python
values_permuted = values.permute(0, 2, 3, 1).contiguous()
# (1,8,2,2) → (1,2,2,8) = (B,H,D,L)
delays_agg = torch.zeros_like(values_permuted)
for i in range(top_k):
pattern = torch.roll(values_permuted, -int(index[i]), -1)
weight_expanded = tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1)
.repeat(1, head, channel, length)
delays_agg += pattern * weight_expanded先 permute values
与 Q/K 同理,把时间维移到最后:
values (1,8,2,2) → permute(0,2,3,1) → (1,2,2,8)
values_permuted[0, 0, 0, :] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
values_permuted[0, 0, 1, :] = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
values_permuted[0, 1, 0, :] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]
values_permuted[0, 1, 1, :] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]权重广播展开(以 i=0 为例)
tmp_corr[:, 0] = [0.5],shape (1,) → 扩展到 (1, 2, 2, 8):
[0.5] (1,)
↓ unsqueeze(1) (1,1)
↓ unsqueeze(1) (1,1,1)
↓ unsqueeze(1) (1,1,1,1)
↓ repeat(1,2,2,8) (1,2,2,8)
所有 (h,d,t) 位置均填 0.5torch.roll 是什么
torch.roll(x, shifts=-τ, dims=-1) 把序列循环左移 τ 步:
物理含义:位置 t 上的值,变成原来"τ 步之后"的那个值。
i=0:τ=0,roll(-0)
roll 移位 0 步 = 原序列不变:
pattern = values_permuted(无变化)i=1:τ=4,roll(-4)
每条序列循环左移 4 步:
原始: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
t=0 t=1 t=2 t=3 t=4 t=5 t=6 t=7
roll(-4): 把 t=4..7 移到开头,t=0..3 绕到末尾
[0.5, 0.6, 0.7, 0.8, 0.1, 0.2, 0.3, 0.4]
↑原来的 t=4 现在在 t=0 位置roll(-τ) 的直觉
roll(V, -τ)[t]=V[t+τ],相当于"从位置 t 往前看 τ 步的值"。AutoCorrelation 检测到周期=4,意味着当前时刻 t 的信息可以从 t+4(一个完整周期之后)借鉴。roll(-4) 把每个 t+4 的值"提前"放到位置 t,再加权平均。
最终累加结果(逐切片验证)
(h=0, d=0):递增序列
| t | 原序列×0.5 | roll(-4)×0.5 | delays_agg |
|---|---|---|---|
| 0 | 0.1×0.5=0.05 | 0.5×0.5=0.25 | 0.30 |
| 1 | 0.2×0.5=0.10 | 0.6×0.5=0.30 | 0.40 |
| 2 | 0.3×0.5=0.15 | 0.7×0.5=0.35 | 0.50 |
| 3 | 0.4×0.5=0.20 | 0.8×0.5=0.40 | 0.60 |
| 4 | 0.5×0.5=0.25 | 0.1×0.5=0.05 | 0.30 |
| 5 | 0.6×0.5=0.30 | 0.2×0.5=0.10 | 0.40 |
| 6 | 0.7×0.5=0.35 | 0.3×0.5=0.15 | 0.50 |
| 7 | 0.8×0.5=0.40 | 0.4×0.5=0.20 | 0.60 |
结果:[0.3, 0.4, 0.5, 0.6, 0.3, 0.4, 0.5, 0.6]——注意序列变成周期=4的重复!
(h=0, d=1):递减序列
roll(-4):[0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1] → [0.4,0.3,0.2,0.1,0.8,0.7,0.6,0.5]
delays_agg[0,0,1,:] = 0.5×[0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]
+ 0.5×[0.4,0.3,0.2,0.1,0.8,0.7,0.6,0.5]
= [0.6, 0.5, 0.4, 0.3, 0.6, 0.5, 0.4, 0.3]同样变成周期=4 的重复。
(h=1, d=0):周期=2 交替序列
原序列:[0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]
roll(-4):[0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8](周期=2 的序列移4步后还是同一个序列!)
delays_agg[0,1,0,:] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8] # 不变为什么周期=2 的序列也"不变"
roll(-4) 移动了4步,对于周期=2 的序列,移4步 = 移2个完整周期 = 回到原位。所以聚合结果等于原序列。这不是因为模型"知道"这条序列的周期,而是因为 lag=4 恰好是周期2的整数倍的副作用。
(h=1, d=1):周期=4 序列
原序列:[0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]
roll(-4):[0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8](周期=4 移4步后完全相同)
delays_agg[0,1,1,:] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8] # 不变周期匹配时聚合不改变值
对于与检测到的 lag 周期完全匹配的 V 序列,roll(-τ) = 原序列,聚合结果等于原序列(两个权重都乘以相同的序列,结果不变)。
10. 步骤 9:permute 还原 → 输出 V
python
V = delays_agg.permute(0, 3, 1, 2)python
return (V.contiguous(), None)最终输出与输入形状相同:(1, 8, 2, 2)
11. 完整形状追踪表
| 步骤 | 操作 | 形状 | 备注 |
|---|---|---|---|
| 输入 | — | Q/K: (1,8,2,2), V: (1,8,2,2) | (B,L,H,E) |
| 对齐 | 截断/补零 | (1,8,2,2) | L=S=8,无变化 |
| permute | (0,2,3,1) | Q/K: (1,2,2,8) | (B,H,E,L) |
| rfft | dim=-1 | (1,2,2,5) 复数 | L//2+1=5 |
| conj×乘 | 逐元素 | (1,2,2,5) 复数 | res |
| irfft | dim=-1 | (1,2,2,8) 实数 | corr |
| mean×2 | dim=1 两次 | (1,8) | mean_value |
| topk | k=2 | index=(2,) | index=[0,4] |
| stack+softmax | — | (1,2) | tmp_corr=[[0.5,0.5]] |
| V permute | (0,2,3,1) | (1,2,2,8) | (B,H,D,L) |
| roll×2 + 累加 | dim=-1 | (1,2,2,8) | delays_agg |
| permute 还原 | (0,3,1,2) | (1,8,2,2) | 输出 V |
12. 完整结果一览
输入 Q/K 的公共信号: [1,0,0,0,1,0,0,0] (周期=4)
检测到的主 lag: τ=[0, 4]
权重: w=[0.5, 0.5]
输出 delays_agg(各切片):
(h=0, d=0): [0.30, 0.40, 0.50, 0.60, 0.30, 0.40, 0.50, 0.60]
(h=0, d=1): [0.60, 0.50, 0.40, 0.30, 0.60, 0.50, 0.40, 0.30]
(h=1, d=0): [0.40, 0.80, 0.40, 0.80, 0.40, 0.80, 0.40, 0.80] ← 不变
(h=1, d=1): [0.20, 0.40, 0.60, 0.80, 0.20, 0.40, 0.60, 0.80] ← 不变13. 关键直觉总结
为什么用 FFT 互相关而不是点积
| 对比维度 | 标准点积注意力 | AutoCorrelation |
|---|---|---|
| 相似度的语义 | 位置 i 和位置 j 的向量内容相似 | Q 和 K 在时间偏移 τ 下的对齐强度 |
| 计算规模 | ||
| 聚合方式 | 每个 query 位置有独立权重向量 | 所有位置共享同一组 lag 权重 |
| 适合的序列特性 | 任意语义关联 | 周期性时序(季节性、循环模式) |
为什么 roll(-τ) 等价于"聚合 τ 步之后的信息"
对位置 t 来说,roll(-τ) 把"τ 步之后的 V 值"放到当前位置。这样做加权求和,等价于:
roll 操作的物理直觉
对整条序列做 roll,不需要逐时间步单独索引,因为所有位置的偏移量是同一个 τ。标准注意力中 score 矩阵的每一行对应不同的 query,因此每个位置有不同的聚合对象;AutoCorrelation 用周期假设简化了这一点:同一个 lag 对整条序列有效。
irfft 没有显式 n 参数的安全性
代码里 corr = torch.fft.irfft(res, dim=-1) 不传 n,PyTorch 自动推断:
对偶数 L(L=8,12,…)结果正确。若 L 为奇数,则 n=L。实际代码依赖奇偶隐含假设。