Appearance
DUET · Layer 2B — Mahalanobis_mask
§1 在父层中的位置
DUETModel.forward() 的通道路径:
python
changed_input = rearrange(input, "b l n -> b n l") # (3,7,16)
channel_mask = self.mask_generator(changed_input) # (3,1,7,7)
channel_group_feature, attention = self.Channel_transformer(
x=temporal_feature, attn_mask=channel_mask
)self.mask_generator = Mahalanobis_mask(config.seq_len)。
§2 I/O 接口定义
python
Mahalanobis_mask.forward(X) -> Tensor| 参数 | shape | 含义 |
|---|---|---|
X | (B, N, L) = (3, 7, 16) | 每个变量的原始时序(已转置) |
| 返回 mask | (B, 1, N, N) = (3, 1, 7, 7) | 0/1 通道注意力掩码 |
mask[b, 0, i, j] = 1 表示第 b 个样本中通道 i 应当关注通道 j;= 0 则在注意力计算中被屏蔽(置为接近负无穷)。
__init__ 只有一个可学习参数:
python
frequency_size = input_size // 2 + 1 # = 16//2+1 = 9
self.A = nn.Parameter(torch.randn(frequency_size, frequency_size)) # shape (9, 9)§3 顺序图(具体层)
§4 语义分组图(索引层)
§5 逐步骤精读
§5.0 完整原始代码
python
class Mahalanobis_mask(nn.Module):
def __init__(self, input_size):
super(Mahalanobis_mask, self).__init__()
frequency_size = input_size // 2 + 1
self.A = nn.Parameter(
torch.randn(frequency_size, frequency_size), requires_grad=True
)
def calculate_prob_distance(self, X):
XF = torch.abs(torch.fft.rfft(X, dim=-1))
X1 = XF.unsqueeze(2)
X2 = XF.unsqueeze(1)
diff = X1 - X2
temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)
exp_dist = 1 / (dist + 1e-10)
identity_matrices = 1 - torch.eye(exp_dist.shape[-1])
mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1).to(exp_dist.device)
exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask)
exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True)
exp_max = exp_max.detach()
p = exp_dist / exp_max
identity_matrices = torch.eye(p.shape[-1])
p1 = torch.einsum("bxc,bxc->bxc", p, mask)
diag = identity_matrices.repeat(p.shape[0], 1, 1).to(p.device)
p = (p1 + diag) * 0.99
return p
def bernoulli_gumbel_rsample(self, distribution_matrix):
b, c, d = distribution_matrix.shape
flatten_matrix = rearrange(distribution_matrix, "b c d -> (b c d) 1")
r_flatten_matrix = 1 - flatten_matrix
log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)
log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)
new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)
resample_matrix = gumbel_softmax(new_matrix, hard=True)
resample_matrix = rearrange(
resample_matrix[..., 0], "(b c d) -> b c d", b=b, c=c, d=d
)
return resample_matrix
def forward(self, X):
p = self.calculate_prob_distance(X)
sample = self.bernoulli_gumbel_rsample(p)
mask = sample.unsqueeze(1)
cnt = torch.sum(mask, dim=-1)
return mask⚠️ cnt 是死代码
cnt 是死代码
forward里cnt = torch.sum(mask, dim=-1)计算了每行被激活的通道数,但结果既未被返回也未被使用。这是调试残留代码,不影响正确性。
§5.1 宏观逻辑
一句话目标:在频域计算每对通道的相似度,用可学习的度量矩阵 A 参数化距离函数,再通过 Gumbel-Bernoulli 采样把连续相似度转化为可导的离散 0/1 掩码——有相似频率成分的通道对允许注意力,相似度低的通道对被屏蔽。
为什么用频域而不是时域?
频域幅度
用小例子(B=1, N=3, L=4, freq_size=3, A 是 3×3 单位矩阵):
X (1, 3, 4):
通道 0: [1, 2, 1, 2] ← 频率 2 成分强(交替信号)
通道 1: [1.1, 2.1, 1.1, 2.1] ← 几乎相同
通道 2: [5, 5, 5, 5] ← 纯直流(频率 0 成分强)
rfft → |XF| (1, 3, 3):
通道 0: [6, 0, 2] ← DC=6, freq1=0, freq2=2
通道 1: [6.4, 0, 2] ← 类似
通道 2: [20, 0, 0] ← 只有 DC
diff (1, 3, 3, 3):
diff[0,0,1,:] = [6-6.4, 0, 2-2] = [-0.4, 0, 0] ← 0 和 1 很近
diff[0,0,2,:] = [6-20, 0, 2-0] = [-14, 0, 2] ← 0 和 2 差异大
dist (A=I 时) = ||diff||^2:
dist[0,0,1] = 0.16 + 0 + 0 = 0.16 ← 通道 0-1 很近
dist[0,0,2] = 196 + 0 + 4 = 200 ← 通道 0-2 很远
相似度 1/(dist+ε):
sim[0,0,1] ≈ 6.25 sim[0,0,2] ≈ 0.005
归一化后 p:
p[0,0,1] ≈ 0.99 p[0,0,2] ≈ 0.001
Gumbel 采样(期望):
mask[0,0,1] ≈ 1 (通道 0 关注 1,高概率)
mask[0,0,2] ≈ 0 (通道 0 不关注 2,低概率)§5.2 频域幅度提取
python
XF = torch.abs(torch.fft.rfft(X, dim=-1))torch.fft.rfft(X, dim=-1) 对实数输入做实 FFT:
- 输入
X (3, 7, 16)→ 输出 复数(3, 7, 9) - 输出大小 =
L//2 + 1 = 16//2+1 = 9(rfft 利用共轭对称只输出前半频率)
torch.abs(...) 取复数模(幅度),结果 |XF| (3, 7, 9) 是实数。
toy 数值(取 X[0, 0, :] = [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1],纯高频信号):
rfft([1,-1,...]) 的结果:DC = 0,最高频分量 = 16,其余 ≈ 0
|XF[0,0,:]| ≈ [0, 0, 0, 0, 0, 0, 0, 0, 16]§5.3 两两通道差异计算
python
X1 = XF.unsqueeze(2) # (3, 7, 1, 9)
X2 = XF.unsqueeze(1) # (3, 1, 7, 9)
diff = X1 - X2 # (3, 7, 7, 9)广播减法:
diff[b, i, j, :] = |XF[b, i, :]| - |XF[b, j, :]|即第 b 个样本中通道 i 和通道 j 的频谱幅度之差(长度为 freq_size=9 的向量)。diff 是反对称的:diff[b,i,j] = -diff[b,j,i],因此后续的 dist = ||A·diff||^2 是对称的。
§5.4 Mahalanobis 距离计算
python
temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)第一步 — 线性变换:
einsum("dk,bxck->bxcd", A(9,9), diff(3,7,7,9)) 的含义:
对下标 k(freq_size 轴)收缩,结果 temp 的每个元素:
等价于:在最后两个维度(c, k)上对 diff 做矩阵乘 temp (3, 7, 7, 9)。
第二步 — 平方范数:
einsum("bxcd,bxcd->bxc", temp, temp) = 对最后一维(d)逐元素平方再求和:
这是广义 Mahalanobis 距离的一种形式,度量矩阵为
为什么不用真正的 Mahalanobis 距离?
真正的 Mahalanobis 距离使用协方差矩阵的逆
,需要估计数据分布参数。这里 是直接可学习的参数,更灵活:训练过程会让 自动学到哪些频率轴的差异对预测更有区分力。
toy 数值(取 A = 单位矩阵 I,通道 0 频谱 = [3,1,0,2,0,0,0,0,0],通道 1 = [3,1,0,2,0,0,0,0,0]):
diff[0, 0, 1, :] = [0, 0, 0, 0, ...] (完全相同)
temp = I @ diff = diff = [0, ...]
dist[0, 0, 1] = 0 + 0 + ... = 0
sim[0, 0, 1] = 1/(0+1e-10) = 1e10 ← 极高相似度§5.5 相似度归一化与对角线处理
python
exp_dist = 1 / (dist + 1e-10)
# 对角线置零
identity_matrices = 1 - torch.eye(exp_dist.shape[-1]) # (7,7),对角=0,其余=1
mask = identity_matrices.repeat(exp_dist.shape[0], 1, 1) # (3,7,7)
exp_dist = torch.einsum("bxc,bxc->bxc", exp_dist, mask) # 对角线 × 0 = 0
# 按行归一化到 [0,1]
exp_max, _ = torch.max(exp_dist, dim=-1, keepdim=True) # (3,7,1)
exp_max = exp_max.detach() # 不参与梯度(归一化不需要可导)
p = exp_dist / exp_max # p ∈ [0,1],每行最大值 = 1
# 加回对角线 = 1 再缩放
identity_matrices = torch.eye(p.shape[-1]) # (7,7),对角=1,其余=0
p1 = torch.einsum("bxc,bxc->bxc", p, mask) # 再次零对角(确保干净)
diag = identity_matrices.repeat(p.shape[0], 1, 1)
p = (p1 + diag) * 0.99最终 p (3, 7, 7) 的含义:
| 位置 | 值域 | 语义 |
|---|---|---|
对角 p[b,i,i] | 0.99 | 自相关,始终高概率保留 |
非对角 p[b,i,j] | [0, 0.99] | 通道 i 与 j 的相似度,越高越可能 attend |
为什么乘以 0.99 而不是 1.0?
p[b,i,j] 是 Bernoulli 采样的参数。若 p=1,log(p/(1-p)) = log(∞) 导致数值溢出。乘以 0.99 使所有值严格小于 1,保证 logit 有界:
§5.6 Gumbel-Bernoulli 可微分采样
python
def bernoulli_gumbel_rsample(self, distribution_matrix):
b, c, d = distribution_matrix.shape # (3, 7, 7)
flatten_matrix = rearrange(distribution_matrix, "b c d -> (b c d) 1") # (147, 1)
r_flatten_matrix = 1 - flatten_matrix # (147, 1)
log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix) # log(p/(1-p))
log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix) # log((1-p)/p)
new_matrix = torch.concat([log_flatten_matrix, log_r_flatten_matrix], dim=-1) # (147, 2)
resample_matrix = gumbel_softmax(new_matrix, hard=True) # (147, 2) one-hot
resample_matrix = rearrange(
resample_matrix[..., 0], "(b c d) -> b c d", b=b, c=c, d=d
) # (3, 7, 7) — 取类别 1 的 one-hot 列
return resample_matrix核心思路:把每个 Bernoulli(
二分类 logit 的构造:
new_matrix (147, 2):每行是 [logit_1, logit_0] = [logit, -logit]。
Gumbel Softmax (hard=True):
前向传播:对每行加独立 Gumbel 噪声后取 argmax,返回 one-hot(0 或 1,不可导)。 反向传播:用 softmax 的梯度代替 argmax 的梯度(straight-through)。
这样在反向传播时梯度通过 softmax 流回 log_flatten_matrix,进而流回 p,最终流回 A 的梯度。
结果提取:
resample_matrix[..., 0] = one-hot 的第 0 列 = "是否选了类别 1(= attend)"toy 数值(p=0.9,单个 entry):
logit_1 = log(0.9/0.1) = log(9) ≈ 2.20
logit_0 = log(0.1/0.9) ≈ -2.20
加 Gumbel 噪声(G ~ -log(-log(U)), U ~ Uniform[0,1]):
noisy_logit_1 = 2.20 + G_1 ≈ 2.20 + 0.5 = 2.70
noisy_logit_0 = -2.20 + G_0 ≈ -2.20 + (-0.3) = -2.50
argmax → 类别 1 → resample_matrix = [1, 0]
结果 resample_matrix[..., 0] = 1 ← 此位置保留注意力§5.7 掩码生成与 FullAttention 中的应用
python
mask = sample.unsqueeze(1) # (3, 7, 7) → (3, 1, 7, 7)
cnt = torch.sum(mask, dim=-1) # 死代码,不影响输出
return maskmask (3, 1, 7, 7) 的 H 维(dim=1)为 1,在 FullAttention 中会广播到所有注意力头(H=2)。
FullAttention 中的掩码应用(masked_attention.py 中):
python
large_negative = -math.log(1e10) # ≈ -23.03
attention_mask = torch.where(attn_mask == 0, large_negative, 0)
scores = scores * attn_mask + attention_mask设 scores (3, 2, 7, 7),attn_mask (3, 1, 7, 7) 广播到 (3, 2, 7, 7):
| 位置 | attn_mask | 操作 | 结果 |
|---|---|---|---|
| mask=1 | 保留 | scores × 1 + 0 | 原始注意力分数 |
| mask=0 | 屏蔽 | scores × 0 + (-23.03) | ≈ -23.03 |
经过 softmax 后,
为什么用 -log(1e10) 而不是 -∞?
-log(1e10) 而不是 -∞?PyTorch 中 softmax 对
-inf的处理在某些情况下会产生 NaN(当所有 logits 都是-inf时)。-log(1e10) ≈ -23.03是一个足够小的有限值,softmax 后的权重约为,在浮点精度内等效于 0,但不会触发 NaN。
§6 下钻子组件
本层无需下钻,series_decomp / rfft / gumbel_softmax 均已在本文档内完整解释。
创建:2026-04-24