Skip to content

DUET · Layer 2B — Mahalanobis_mask

§1 在父层中的位置

DUETModel.forward() 的通道路径:

python
changed_input = rearrange(input, "b l n -> b n l")      # (3,7,16)
channel_mask = self.mask_generator(changed_input)        # (3,1,7,7)
channel_group_feature, attention = self.Channel_transformer(
    x=temporal_feature, attn_mask=channel_mask
)

self.mask_generator = Mahalanobis_mask(config.seq_len)


§2 I/O 接口定义

python
Mahalanobis_mask.forward(X) -> Tensor
参数shape含义
X(B, N, L) = (3, 7, 16)每个变量的原始时序(已转置)
返回 mask(B, 1, N, N) = (3, 1, 7, 7)0/1 通道注意力掩码

mask[b, 0, i, j] = 1 表示第 b 个样本中通道 i 应当关注通道 j;= 0 则在注意力计算中被屏蔽(置为接近负无穷)。

__init__ 只有一个可学习参数:

python
frequency_size = input_size // 2 + 1   # = 16//2+1 = 9
self.A = nn.Parameter(torch.randn(frequency_size, frequency_size))  # shape (9, 9)

§3 顺序图(具体层)


§4 语义分组图(索引层)


§5 逐步骤精读

§5.0 完整原始代码

python
class Mahalanobis_mask(nn.Module):
    def __init__(self, input_size):
        super(Mahalanobis_mask, self).__init__()
        frequency_size = input_size // 2 + 1
        self.A = nn.Parameter(
            torch.randn(frequency_size, frequency_size), requires_grad=True
        )

    def calculate_prob_distance(self, X):
        XF = torch.abs(torch.fft.rfft(X, dim=-1))
        X1 = XF.unsqueeze(2)
        X2 = XF.unsqueeze(1)
        diff = X1 - X2
        temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
        dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)
        exp_dist = 1 / (dist + 1e-10)
        identity_matrices = 1 - torch.eye(exp_dist.shape[-1])
        mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1).to(exp_dist.device)
        exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask)
        exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True)
        exp_max = exp_max.detach()
        p = exp_dist / exp_max
        identity_matrices = torch.eye(p.shape[-1])
        p1 = torch.einsum("bxc,bxc->bxc", p, mask)
        diag = identity_matrices.repeat(p.shape[0], 1, 1).to(p.device)
        p = (p1 + diag) * 0.99
        return p

    def bernoulli_gumbel_rsample(self, distribution_matrix):
        b, c, d = distribution_matrix.shape
        flatten_matrix = rearrange(distribution_matrix, "b c d -> (b c d) 1")
        r_flatten_matrix = 1 - flatten_matrix
        log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)
        log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)
        new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)
        resample_matrix = gumbel_softmax(new_matrix, hard=True)
        resample_matrix = rearrange(
            resample_matrix[..., 0], "(b c d) -> b c d", b=b, c=c, d=d
        )
        return resample_matrix

    def forward(self, X):
        p = self.calculate_prob_distance(X)
        sample = self.bernoulli_gumbel_rsample(p)
        mask = sample.unsqueeze(1)
        cnt = torch.sum(mask, dim=-1)
        return mask
⚠️ cnt 是死代码

forwardcnt = torch.sum(mask, dim=-1) 计算了每行被激活的通道数,但结果既未被返回也未被使用。这是调试残留代码,不影响正确性。


§5.1 宏观逻辑

一句话目标:在频域计算每对通道的相似度,用可学习的度量矩阵 A 参数化距离函数,再通过 Gumbel-Bernoulli 采样把连续相似度转化为可导的离散 0/1 掩码——有相似频率成分的通道对允许注意力,相似度低的通道对被屏蔽。

为什么用频域而不是时域?

频域幅度 |XF[f]| 反映该通道"多强烈地含有频率 f 的成分",对时序的平移、相位偏移不敏感。两个有相同周期模式但时间对齐不同的变量,在时域距离很大,在频域幅度距离很小。这使相似度计算更鲁棒。

用小例子(B=1, N=3, L=4, freq_size=3, A 是 3×3 单位矩阵)

X (1, 3, 4):
  通道 0: [1, 2, 1, 2]   ← 频率 2 成分强(交替信号)
  通道 1: [1.1, 2.1, 1.1, 2.1]   ← 几乎相同
  通道 2: [5, 5, 5, 5]   ← 纯直流(频率 0 成分强)

rfft → |XF| (1, 3, 3):
  通道 0: [6, 0, 2]   ← DC=6, freq1=0, freq2=2
  通道 1: [6.4, 0, 2]  ← 类似
  通道 2: [20, 0, 0]  ← 只有 DC

diff (1, 3, 3, 3):
  diff[0,0,1,:] = [6-6.4, 0, 2-2] = [-0.4, 0, 0]   ← 0 和 1 很近
  diff[0,0,2,:] = [6-20, 0, 2-0]  = [-14, 0, 2]    ← 0 和 2 差异大

dist (A=I 时) = ||diff||^2:
  dist[0,0,1] = 0.16 + 0 + 0 = 0.16     ← 通道 0-1 很近
  dist[0,0,2] = 196 + 0 + 4 = 200       ← 通道 0-2 很远

相似度 1/(dist+ε):
  sim[0,0,1] ≈ 6.25   sim[0,0,2] ≈ 0.005

归一化后 p:
  p[0,0,1] ≈ 0.99   p[0,0,2] ≈ 0.001

Gumbel 采样(期望):
  mask[0,0,1] ≈ 1   (通道 0 关注 1,高概率)
  mask[0,0,2] ≈ 0   (通道 0 不关注 2,低概率)

§5.2 频域幅度提取

python
XF = torch.abs(torch.fft.rfft(X, dim=-1))

torch.fft.rfft(X, dim=-1) 对实数输入做实 FFT:

  • 输入 X (3, 7, 16) → 输出 复数 (3, 7, 9)
  • 输出大小 = L//2 + 1 = 16//2+1 = 9(rfft 利用共轭对称只输出前半频率)

torch.abs(...) 取复数模(幅度),结果 |XF| (3, 7, 9) 是实数。

toy 数值(取 X[0, 0, :] = [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1],纯高频信号):

rfft([1,-1,...]) 的结果:DC = 0,最高频分量 = 16,其余 ≈ 0
|XF[0,0,:]| ≈ [0, 0, 0, 0, 0, 0, 0, 0, 16]

§5.3 两两通道差异计算

python
X1 = XF.unsqueeze(2)   # (3, 7, 1, 9)
X2 = XF.unsqueeze(1)   # (3, 1, 7, 9)
diff = X1 - X2          # (3, 7, 7, 9)

广播减法:

diff[b, i, j, :] = |XF[b, i, :]| - |XF[b, j, :]|

即第 b 个样本中通道 i 和通道 j 的频谱幅度之差(长度为 freq_size=9 的向量)。diff 是反对称的:diff[b,i,j] = -diff[b,j,i],因此后续的 dist = ||A·diff||^2 是对称的。


§5.4 Mahalanobis 距离计算

python
temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)

第一步 — 线性变换

einsum("dk,bxck->bxcd", A(9,9), diff(3,7,7,9)) 的含义:

对下标 k(freq_size 轴)收缩,结果 temp 的每个元素:

temp[b,x,c,d]=kA[d,k]diff[b,x,c,k]

等价于:在最后两个维度(c, k)上对 diff 做矩阵乘 A,结果 temp (3, 7, 7, 9)

第二步 — 平方范数

einsum("bxcd,bxcd->bxc", temp, temp) = 对最后一维(d)逐元素平方再求和:

dist[b,i,j]=dtemp[b,i,j,d]2=A(|XFi||XFj|)22

这是广义 Mahalanobis 距离的一种形式,度量矩阵为 M=AA(正半定)。A 是可学习的,训练时自动调整哪些频率分量对通道相似度更重要。

为什么不用真正的 Mahalanobis 距离?

真正的 Mahalanobis 距离使用协方差矩阵的逆 (Σ1),需要估计数据分布参数。这里 A 是直接可学习的参数,更灵活:训练过程会让 A 自动学到哪些频率轴的差异对预测更有区分力。

toy 数值(取 A = 单位矩阵 I,通道 0 频谱 = [3,1,0,2,0,0,0,0,0],通道 1 = [3,1,0,2,0,0,0,0,0]):

diff[0, 0, 1, :] = [0, 0, 0, 0, ...] (完全相同)
temp = I @ diff = diff = [0, ...]
dist[0, 0, 1] = 0 + 0 + ... = 0
sim[0, 0, 1] = 1/(0+1e-10) = 1e10   ← 极高相似度

§5.5 相似度归一化与对角线处理

python
exp_dist = 1 / (dist + 1e-10)

# 对角线置零
identity_matrices = 1 - torch.eye(exp_dist.shape[-1])   # (7,7),对角=0,其余=1
mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1)  # (3,7,7)
exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask)  # 对角线 × 0 = 0

# 按行归一化到 [0,1]
exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True)   # (3,7,1)
exp_max = exp_max.detach()   # 不参与梯度(归一化不需要可导)
p = exp_dist / exp_max        # p ∈ [0,1],每行最大值 = 1

# 加回对角线 = 1 再缩放
identity_matrices = torch.eye(p.shape[-1])   # (7,7),对角=1,其余=0
p1 = torch.einsum("bxc,bxc->bxc", p, mask)  # 再次零对角(确保干净)
diag = identity_matrices.repeat(p.shape[0], 1, 1)
p = (p1 + diag) * 0.99

最终 p (3, 7, 7) 的含义:

位置值域语义
对角 p[b,i,i]0.99自相关,始终高概率保留
非对角 p[b,i,j][0, 0.99]通道 i 与 j 的相似度,越高越可能 attend

为什么乘以 0.99 而不是 1.0?

p[b,i,j] 是 Bernoulli 采样的参数。若 p=1,log(p/(1-p)) = log(∞) 导致数值溢出。乘以 0.99 使所有值严格小于 1,保证 logit 有界:logit(0.99)=log(99)4.6


§5.6 Gumbel-Bernoulli 可微分采样

python
def bernoulli_gumbel_rsample(self, distribution_matrix):
    b, c, d = distribution_matrix.shape       # (3, 7, 7)

    flatten_matrix = rearrange(distribution_matrix, "b c d -> (b c d) 1")  # (147, 1)
    r_flatten_matrix = 1 - flatten_matrix                                    # (147, 1)

    log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)       # log(p/(1-p))
    log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)     # log((1-p)/p)

    new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)  # (147, 2)
    resample_matrix = gumbel_softmax(new_matrix, hard=True)                         # (147, 2) one-hot

    resample_matrix = rearrange(
        resample_matrix[..., 0], "(b c d) -> b c d", b=b, c=c, d=d
    )  # (3, 7, 7) — 取类别 1 的 one-hot 列
    return resample_matrix

核心思路:把每个 Bernoulli(p) 变量转化为二分类 categorical,利用 Gumbel Softmax 实现可微分的硬采样(straight-through estimator)。

二分类 logit 的构造

logit1=logp1p,logit0=log1pp=logit1

new_matrix (147, 2):每行是 [logit_1, logit_0] = [logit, -logit]

Gumbel Softmax (hard=True)

前向传播:对每行加独立 Gumbel 噪声后取 argmax,返回 one-hot(0 或 1,不可导)。 反向传播:用 softmax 的梯度代替 argmax 的梯度(straight-through)。

这样在反向传播时梯度通过 softmax 流回 log_flatten_matrix,进而流回 p,最终流回 A 的梯度。

结果提取

resample_matrix[..., 0] = one-hot 的第 0 列 = "是否选了类别 1(= attend)"

toy 数值(p=0.9,单个 entry):

logit_1 = log(0.9/0.1) = log(9) ≈ 2.20
logit_0 = log(0.1/0.9) ≈ -2.20

加 Gumbel 噪声(G ~ -log(-log(U)), U ~ Uniform[0,1]):
  noisy_logit_1 = 2.20 + G_1 ≈ 2.20 + 0.5 = 2.70
  noisy_logit_0 = -2.20 + G_0 ≈ -2.20 + (-0.3) = -2.50

argmax → 类别 1 → resample_matrix = [1, 0]
结果 resample_matrix[..., 0] = 1   ← 此位置保留注意力

§5.7 掩码生成与 FullAttention 中的应用

python
mask = sample.unsqueeze(1)   # (3, 7, 7) → (3, 1, 7, 7)
cnt = torch.sum(mask, dim=-1)  # 死代码,不影响输出
return mask

mask (3, 1, 7, 7) 的 H 维(dim=1)为 1,在 FullAttention 中会广播到所有注意力头(H=2)。

FullAttention 中的掩码应用masked_attention.py 中):

python
large_negative = -math.log(1e10)   # ≈ -23.03
attention_mask = torch.where(attn_mask == 0, large_negative, 0)
scores = scores * attn_mask + attention_mask

scores (3, 2, 7, 7)attn_mask (3, 1, 7, 7) 广播到 (3, 2, 7, 7)

位置attn_mask操作结果
mask=1保留scores × 1 + 0原始注意力分数
mask=0屏蔽scores × 0 + (-23.03)≈ -23.03

经过 softmax 后,23.03 对应的注意力权重 e23.03/normalizer0,被屏蔽通道对实际上不参与信息传播。

为什么用 -log(1e10) 而不是 -∞

PyTorch 中 softmax 对 -inf 的处理在某些情况下会产生 NaN(当所有 logits 都是 -inf 时)。-log(1e10) ≈ -23.03 是一个足够小的有限值,softmax 后的权重约为 e23.031e10,在浮点精度内等效于 0,但不会触发 NaN。


§6 下钻子组件

本层无需下钻,series_decomp / rfft / gumbel_softmax 均已在本文档内完整解释。


创建:2026-04-24

*记录并在线阅读我的笔记*