Skip to content

AutoCorrelation 完整手推例子

配套文档:[[04A-Layer5-AutoCorrelation]](宏观原理与公式链)
PyTorch 函数速查:[[../../pytorch-basics/concept-order/11-PyTorch-Tensor基础操作-切片变形拼接注意力]]

本文用一组固定的 toy 数据,从 forward() 入口一步步追踪到输出,每步给出具体张量数值、对应的 PyTorch 函数含义、以及数学意义。全文只用这一个 toy,数字前后完全连贯。


0. Toy 数据定义(全文通用)

B=1, L=8, H=2, E=2, D=2(D=E=2,V 的特征维与 Q/K 相同)

Q 和 K(完全相同,验证自相关)

所有 head 和 channel 共用同一条一维信号:

q[l]=k[l]=[1, 0, 0, 0, 1, 0, 0, 0]

这是周期为 4 的脉冲序列(t=0,4 各有一个脉冲)。

python
# 完整形状 (B, L, H, E) = (1, 8, 2, 2)
# 对所有 h, e: Q[0, :, h, e] = K[0, :, h, e] = [1,0,0,0,1,0,0,0]

V(四条不同曲线,体现 roll 的效果)

python
# V 形状 (B, L, H, D) = (1, 8, 2, 2)
V[0, :, 0, 0] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]   # 递增
V[0, :, 0, 1] = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]   # 递减
V[0, :, 1, 0] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]   # 周期=2 交替
V[0, :, 1, 1] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]   # 周期=4(恰好与 lag 匹配)
为什么 V 用四种曲线

聚合结果取决于 V 本身的周期性:

  • 递增/递减曲线能直观看到 roll 把"4步之后的值"混入当前位置
  • 周期=2 曲线 roll(-4) 后发生改变(2≠4)
  • 周期=4 曲线 roll(-4) 后完全不变(完美匹配检测到的 lag)

1. 步骤 0:L/S 对齐

python
B, L, H, E = queries.shape   # 1, 8, 2, 2
_, S, _, D = values.shape    # S = 8
# L == S → 走 else 分支
values = values[:, :L, :, :]  # 无变化
keys   = keys[:, :L, :, :]    # 无变化

目的:FFT 互相关要求 Q 和 K/V 的时序长度一致,统一对齐到 min(L, S)

此 toy 里 L=S=8,两个分支均不改变形状。


2. 步骤 1:permute — 把时间维移到最后

python
q_time_major = queries.permute(0, 2, 3, 1).contiguous()
k_time_major = keys.permute(0, 2, 3, 1).contiguous()

为什么要 permute

torch.fft.rfftdim=-1 参数是"对最后一维做 FFT"。原始 queries (B,L,H,E) 的最后一维是 E(特征维),而我们需要对 L(时间维)做频域变换。

所以先把时间维 L 换到最后:

(B,L,H,E)permute(0,2,3,1)(B,H,E,L)

数值变化(追踪 b=0, h=0, e=0 这一切片)

# 原始 queries[0, :, 0, 0] 是按 L 排列的一列
queries[0, :, 0, 0] = [1, 0, 0, 0, 1, 0, 0, 0]   # shape=(8,),L 在 dim=1 位置

# permute 后,同样的数据现在在最后一维
q_time_major[0, 0, 0, :] = [1, 0, 0, 0, 1, 0, 0, 0]   # shape=(8,),L 在 dim=-1 位置

permute 不改变数值,只改变轴的解释顺序(底层内存重排由 .contiguous() 完成)。

整体形状:(1,8,2,2)(1,2,2,8)

contiguous() 的作用

permute 只修改 stride 描述符,不移动实际内存。.contiguous() 保证后续 rfft 能正确访问连续内存。


3. 步骤 2:rfft — 时域 → 频域

python
q_fft = torch.fft.rfft(q_time_major, dim=-1)
k_fft = torch.fft.rfft(k_time_major, dim=-1)

rfft 是什么

rfft = Real FFT。对一条实数序列 x[0..L1] 计算离散傅里叶变换:

X[k]=n=0L1x[n]e2πikn/L,k=0,1,,L1

由于实数输入的 DFT 满足共轭对称 X[k]=X[Lk],只需保留不冗余的前半段:

rfft 输出长度=L2+1=82+1=5 个复数

手算 rfft([1,0,0,0,1,0,0,0])

只有 n=0 和 n=4 处有非零值(x[0]=x[4]=1),代入公式:

X[k]=e2πik0/8+e2πik4/8=1+eπik

逐个计算:

k(频率)eπikX[k]解读
0+12+0j直流分量(总能量)
1−10+0j无此频率
2+12+0j频率2,周期 L/2=4 ← 主频
3−10+0j无此频率
4+12+0j奈奎斯特频率
q\_fft[0,h,e,:]=[2+0j, 0+0j, 2+0j, 0+0j, 2+0j]
频率 k=2 对应的物理周期

频率 k 对应的周期 = L/k=8/2=4。这正是我们设置的信号周期 p=4。FFT 在频域准确捕捉到了这个主频。

整体 shape:(1,2,2,8)(1,2,2,5)(复数张量)


4. 步骤 3:互相关 — 频域乘法 = 时域互相关

python
res = q_fft * torch.conj(k_fft)

互相关定理(核心)

时域互相关定义(暴力计算需 O(L2)):

CQK[τ]=t=0L1Q[t]K[(t+τ)modL]

互相关定理告诉我们,这等价于频域操作(O(LlogL)):

CQK=IFFT(FFT(Q)×FFT(K))

其中 复共轭(complex conjugate)。

torch.conj 是什么

对每个复数 z=a+bi,取共轭:

z¯=abi

对我们的 toy(Q=K,均为实数),共轭不改变值:

conj(k\_fft)=conj([2,0,2,0,2])=[2,0,2,0,2](全实数,共轭等于自身)

计算 res

逐元素相乘(Q=K 的自相关情形):

res[k]=q\_fft[k]×k\_fft[k]
kq_fft[k]conj(k_fft[k])res[k]
02+0j2−0j4+0j
10+0j0−0j0+0j
22+0j2−0j4+0j
30+0j0−0j0+0j
42+0j2−0j4+0j
res[0,h,e,:]=[4+0j, 0+0j, 4+0j, 0+0j, 4+0j]

形状不变:(1,2,2,5)(复数)


5. 步骤 4:irfft — 频域 → 时域(互相关函数)

python
corr = torch.fft.irfft(res, dim=-1)

irfft 是什么

irfft = Inverse Real FFT。把 rfft 的输出(L//2+1 个复数)变回实数序列。

由于省略了 n 参数,PyTorch 自动推断输出长度:

n=2×(input\_size1)=2×(51)=8=L

irfft 内部先补全共轭对称的频谱,再做逆 DFT:

x[n]=1Lk=0L1Xfull[k]e2πikn/L

手算 irfft([4,0,4,0,4])

补全对称频谱(N=8 时,rfft 输出 X[0..4],irfft 内部补充 X[5..7]):

Xfull=[4,0,4,0,4,0,4,0]

只有 k=0,2,4,6 处有非零值(均为4),逆变换:

corr[n]=48(1+e2πi2n/8+e2πi4n/8+e2πi6n/8)=12(1+eπin/2+eπin+e3πin/2)

逐位计算:

τ计算corr[τ]含义
012(1+1+1+1)2lag=0:Q 和 K 完全对齐
112(1+i1i)0lag=1:无相关
212(11+11)0lag=2:无相关
312(1i1+i)0lag=3:无相关
412(1+1+1+1)2lag=4:移位一个周期后完美对齐 ✓
5–70无相关
corr[0,h,e,:]=[2, 0, 0, 0, 2, 0, 0, 0]
corr 的物理含义

corr[b,h,e,τ] = 在头 h、通道 e 下,把 K 右移 τ 步后与 Q 的内积(对齐程度)。 峰值在 τ=0 和 τ=4,正好是周期 p=4 的整数倍。

形状:(1,2,2,8)(实数)


6. 步骤 5:两次均值 — 压缩 H 和 E 维

python
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)

目的:topk 要在 lag(L)维上找峰值,需要先把 H 和 E 维度合并成单一的"平均相关强度"。

第一次 mean(dim=1):压缩 H 维

corr (1,2,2,8)mean dim=1(1,2,8)
# dim=1 是 H 维(H=2 个 head 求均值)
# 由于所有 head 值相同,均值等于原值
结果[0, e, :] = [2,0,0,0,2,0,0,0]  对 e=0,1

第二次 mean(dim=1):压缩 E 维

(1,2,8)mean dim=1(1,8)=mean\_value
# dim=1 现在是 E 维(E=2 个 channel 求均值)
mean_value[0, :] = [2, 0, 0, 0, 2, 0, 0, 0]

mean_value[b, τ] = 样本 b 在延迟 τ 处的跨所有 head 和 channel 的平均互相关强度


7. 步骤 6:topk — 选主 lag

python
top_k = int(self.factor * math.log(length))   # factor=1, length=8
index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]

top_k 的计算

top\_k=1×ln(8)=2.08=2

math.log 是自然对数 ln(以 e 为底),不是 log2

torch.mean(mean_value, dim=0):batch 均值

mean_value.shape = (1, 8)   # B=1
torch.mean(mean_value, dim=0).shape = (8,)   # B=1 时均值等于原值
= [2, 0, 0, 0, 2, 0, 0, 0]
训练时 batch 均值的代价

当 B>1 时,此步会把所有样本的相关强度平均后再选 top-k,导致所有样本共享同一个 lag 集合。这是 training 路径的近似。Inference 路径用 per-sample topk 避免此问题(见 [[04A-Layer5-AutoCorrelation]] §5.5)。

torch.topk

python
torch.topk(tensor([2,0,0,0,2,0,0,0]), k=2, dim=-1)

topk 返回两个张量:

values  = tensor([2., 2.])
indices = tensor([0, 4])   # τ=0 和 τ=4 的相关强度最高
index=[0, 4]

8. 步骤 7:softmax — 相关强度 → 权重

python
weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
tmp_corr = torch.softmax(weights, dim=-1)

收集选中 lag 处的强度

python
mean_value[:, index[0]] = mean_value[:, 0] = tensor([2.0])   # shape (1,)
mean_value[:, index[1]] = mean_value[:, 4] = tensor([2.0])   # shape (1,)

weights = torch.stack([tensor([2.0]), tensor([2.0])], dim=-1)
        = tensor([[2.0, 2.0]])   # shape (1, 2) = (B, top_k)

softmax 归一化

wi=ecijecj

两个值相等时:

softmax([2.0, 2.0])=[e2e2+e2, e2e2+e2]=[0.5, 0.5]
tmp_corr = tensor([[0.5, 0.5]])   # shape (1, 2) = (B, top_k)
这里 softmax 的意义

两个 lag 强度完全相等,模型给它们一样的权重。这等价于:用当前时刻和"4步之后的时刻"的等比例混合来作为最终表示。


9. 步骤 8:roll + 加权累加 — 时延聚合

python
values_permuted = values.permute(0, 2, 3, 1).contiguous()
# (1,8,2,2) → (1,2,2,8) = (B,H,D,L)

delays_agg = torch.zeros_like(values_permuted)

for i in range(top_k):
    pattern = torch.roll(values_permuted, -int(index[i]), -1)
    weight_expanded = tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1)
                           .repeat(1, head, channel, length)
    delays_agg += pattern * weight_expanded

先 permute values

与 Q/K 同理,把时间维移到最后:

values (1,8,2,2) → permute(0,2,3,1) → (1,2,2,8)

values_permuted[0, 0, 0, :] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
values_permuted[0, 0, 1, :] = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
values_permuted[0, 1, 0, :] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]
values_permuted[0, 1, 1, :] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]

权重广播展开(以 i=0 为例)

tmp_corr[:, 0] = [0.5],shape (1,) → 扩展到 (1, 2, 2, 8)

[0.5]             (1,)
  ↓ unsqueeze(1)  (1,1)
  ↓ unsqueeze(1)  (1,1,1)
  ↓ unsqueeze(1)  (1,1,1,1)
  ↓ repeat(1,2,2,8)  (1,2,2,8)

所有 (h,d,t) 位置均填 0.5

torch.roll 是什么

torch.roll(x, shifts=-τ, dims=-1) 把序列循环左移 τ 步:

roll(x,τ)[t]=x[(t+τ)modL]

物理含义:位置 t 上的值,变成原来"τ 步之后"的那个值。

i=0:τ=0,roll(-0)

roll 移位 0 步 = 原序列不变:

pattern = values_permuted(无变化)

i=1:τ=4,roll(-4)

每条序列循环左移 4 步:

原始:   [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
          t=0  t=1  t=2  t=3  t=4  t=5  t=6  t=7

roll(-4):  把 t=4..7 移到开头,t=0..3 绕到末尾
          [0.5, 0.6, 0.7, 0.8, 0.1, 0.2, 0.3, 0.4]
           ↑原来的 t=4 现在在 t=0 位置
roll(-τ) 的直觉

roll(V, -τ)[t] = V[t+τ],相当于"从位置 t 往前看 τ 步的值"。

AutoCorrelation 检测到周期=4,意味着当前时刻 t 的信息可以从 t+4(一个完整周期之后)借鉴。roll(-4) 把每个 t+4 的值"提前"放到位置 t,再加权平均。

最终累加结果(逐切片验证)

(h=0, d=0):递增序列

Vagg=0.5×[0.1,,0.8]+0.5×roll(4)
t原序列×0.5roll(-4)×0.5delays_agg
00.1×0.5=0.050.5×0.5=0.250.30
10.2×0.5=0.100.6×0.5=0.300.40
20.3×0.5=0.150.7×0.5=0.350.50
30.4×0.5=0.200.8×0.5=0.400.60
40.5×0.5=0.250.1×0.5=0.050.30
50.6×0.5=0.300.2×0.5=0.100.40
60.7×0.5=0.350.3×0.5=0.150.50
70.8×0.5=0.400.4×0.5=0.200.60

结果:[0.3, 0.4, 0.5, 0.6, 0.3, 0.4, 0.5, 0.6]——注意序列变成周期=4的重复!

(h=0, d=1):递减序列

roll(-4):[0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1][0.4,0.3,0.2,0.1,0.8,0.7,0.6,0.5]

delays_agg[0,0,1,:] = 0.5×[0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]
                    + 0.5×[0.4,0.3,0.2,0.1,0.8,0.7,0.6,0.5]
                    = [0.6, 0.5, 0.4, 0.3, 0.6, 0.5, 0.4, 0.3]

同样变成周期=4 的重复。

(h=1, d=0):周期=2 交替序列

原序列:[0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]

roll(-4):[0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8](周期=2 的序列移4步后还是同一个序列!)

delays\_agg=0.5×原序列+0.5×原序列=原序列
delays_agg[0,1,0,:] = [0.4, 0.8, 0.4, 0.8, 0.4, 0.8, 0.4, 0.8]   # 不变
为什么周期=2 的序列也"不变"

roll(-4) 移动了4步,对于周期=2 的序列,移4步 = 移2个完整周期 = 回到原位。所以聚合结果等于原序列。这不是因为模型"知道"这条序列的周期,而是因为 lag=4 恰好是周期2的整数倍的副作用。

(h=1, d=1):周期=4 序列

原序列:[0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]

roll(-4):[0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8](周期=4 移4步后完全相同)

delays_agg[0,1,1,:] = [0.2, 0.4, 0.6, 0.8, 0.2, 0.4, 0.6, 0.8]   # 不变
周期匹配时聚合不改变值

对于与检测到的 lag 周期完全匹配的 V 序列,roll(-τ) = 原序列,聚合结果等于原序列(两个权重都乘以相同的序列,结果不变)。


10. 步骤 9:permute 还原 → 输出 V

python
V = delays_agg.permute(0, 3, 1, 2)
(B,H,D,L)=(1,2,2,8)permute(0,3,1,2)(B,L,H,D)=(1,8,2,2)
python
return (V.contiguous(), None)

最终输出与输入形状相同:(1, 8, 2, 2)


11. 完整形状追踪表

步骤操作形状备注
输入Q/K: (1,8,2,2), V: (1,8,2,2)(B,L,H,E)
对齐截断/补零(1,8,2,2)L=S=8,无变化
permute(0,2,3,1)Q/K: (1,2,2,8)(B,H,E,L)
rfftdim=-1(1,2,2,5) 复数L//2+1=5
conj×乘逐元素(1,2,2,5) 复数res
irfftdim=-1(1,2,2,8) 实数corr
mean×2dim=1 两次(1,8)mean_value
topkk=2index=(2,)index=[0,4]
stack+softmax(1,2)tmp_corr=[[0.5,0.5]]
V permute(0,2,3,1)(1,2,2,8)(B,H,D,L)
roll×2 + 累加dim=-1(1,2,2,8)delays_agg
permute 还原(0,3,1,2)(1,8,2,2)输出 V

12. 完整结果一览

输入 Q/K 的公共信号: [1,0,0,0,1,0,0,0]  (周期=4)

检测到的主 lag: τ=[0, 4]
权重: w=[0.5, 0.5]

输出 delays_agg(各切片):
  (h=0, d=0): [0.30, 0.40, 0.50, 0.60, 0.30, 0.40, 0.50, 0.60]
  (h=0, d=1): [0.60, 0.50, 0.40, 0.30, 0.60, 0.50, 0.40, 0.30]
  (h=1, d=0): [0.40, 0.80, 0.40, 0.80, 0.40, 0.80, 0.40, 0.80]  ← 不变
  (h=1, d=1): [0.20, 0.40, 0.60, 0.80, 0.20, 0.40, 0.60, 0.80]  ← 不变

13. 关键直觉总结

为什么用 FFT 互相关而不是点积

对比维度标准点积注意力AutoCorrelation
相似度的语义位置 i 和位置 j 的向量内容相似Q 和 K 在时间偏移 τ 下的对齐强度
计算规模L×L 的 score 矩阵L 维的 corr 向量(topk 后只用 lnL 个)
聚合方式每个 query 位置有独立权重向量所有位置共享同一组 lag 权重
适合的序列特性任意语义关联周期性时序(季节性、循环模式)

为什么 roll(-τ) 等价于"聚合 τ 步之后的信息"

roll(V,τ)[t]=V[(t+τ)modL]

对位置 t 来说,roll(-τ) 把"τ 步之后的 V 值"放到当前位置。这样做加权求和,等价于:

Out[t]=kwkV[(t+τk)modL]
roll 操作的物理直觉

对整条序列做 roll,不需要逐时间步单独索引,因为所有位置的偏移量是同一个 τ。标准注意力中 score 矩阵的每一行对应不同的 query,因此每个位置有不同的聚合对象;AutoCorrelation 用周期假设简化了这一点:同一个 lag 对整条序列有效。

irfft 没有显式 n 参数的安全性

代码里 corr = torch.fft.irfft(res, dim=-1) 不传 n,PyTorch 自动推断:

ninferred=2×(input\_size1)=2×(L2+11)=2×L2=L

偶数 L(L=8,12,…)结果正确。若 L 为奇数,则 L/2L/2,推断出的 n 会比 L 少1,此时应显式传 n=L。实际代码依赖奇偶隐含假设。

*记录并在线阅读我的笔记*