DUET · Layer 3 — SparseDispatcher
§1 Position in the parent layer
Right after computing the sparse gating matrix `gates`, `Linear_extractor_cluster.forward()` constructs `SparseDispatcher(num_experts, gates)` and calls its three methods:
```python
dispatcher = SparseDispatcher(self.num_experts, gates)
expert_inputs = dispatcher.dispatch(x_norm)        # dispatch
gates_per_expert = dispatcher.expert_to_gates()    # per-expert gates (meant for the experts; not actually used in this code)
expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
y = dispatcher.combine(expert_outputs)             # combine
```
§2 I/O interface definition
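| Method | Arguments | Returns |
|---|---|---|
| `__init__(num_experts, gates)` | `gates`: sparse matrix of shape `(B*N, E)` | nothing; precomputes the index buffers |
| `dispatch(inp)` | `inp`: `(B*N, L, 1)` | `list[E]`; item i has shape `(n_i, L, 1)` |
| `combine(expert_out)` | `list[E]`; item i has shape `(n_i, d_model, 1)` | `(B*N, d_model, 1)` |
| `expert_to_gates()` | none | `list[E]`; item i has shape `(n_i, 1)` |

With the global toy parameters: `B*N=21`, `E=6`, `k=1`, `L=16`, `d_model=8`. Each sample is routed to exactly one expert.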
§3 Sequence diagram (concrete layer)
§4 Semantic grouping diagram (index layer)
§5 Step-by-step walkthrough
§5.0 Complete original code
```python
class SparseDispatcher(object):
    def __init__(self, num_experts, gates):
        self._gates = gates
        self._num_experts = num_experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        _, self._expert_index = sorted_experts.split(1, dim=1)
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        self._part_sizes = (gates > 0).sum(0).tolist()
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        # apply exp to expert outputs, so we are not longer in log space
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            # stitched = stitched.mul(self._nonzero_gates)
            stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
        shape = list(expert_out[-1].shape)
        shape[0] = self._gates.size(0)
        zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self):
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
```
⚠️ Two leftover comments
- The comment at the top of `combine`, `# apply exp to expert outputs, so we are not longer in log space`, is a stale comment carried over from the original T2T code; the current implementation performs no exp operation.
- The commented-out `stitched = stitched.mul(self._nonzero_gates)` is the older form, which only broadcasts correctly when each expert returns a 1-D vector per sample (so `stitched` is 2-D); it has been replaced by the more general einsum version.

Both lines are comments only and do not affect execution.
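§5.1 High-level logic
Goal in one sentence: turn the "which sample goes to which expert, with what weight" information held in the sparse gating matrix into three precomputed buffers (`_batch_index`, `_part_sizes`, `_nonzero_gates`), so that `dispatch` and `combine` each reduce to a few batched tensor operations instead of a loop over `B*N` samples.
Why precompute the indices instead of looking them up again on every dispatch/combine?
`dispatch` and `combine` are both essentially sparse gather/scatter operations. If `forward()` ran `for b in range(B*N): send_to_expert(x[b], which_expert[b])` directly, it would iterate 21 times and could not be batched. After precomputation:
- dispatch = 1 row-index lookup (`inp[_batch_index]`) + 1 split
- combine = 1 cat + 1 einsum + 1 index_add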
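For comparison, the per-sample loop that the precomputation avoids can be sketched in plain Python. This is only an illustrative baseline, not code from the project: `group_by_expert_naive` is a hypothetical helper, and the 4×3 `gates` matrix is the same toy example traced in the rest of this section.

```python
# Naive baseline: one Python-level iteration per sample (B*N iterations).
# `group_by_expert_naive` is a hypothetical helper, not part of the original code.
gates = [
    [0.0, 0.0, 0.7],
    [0.9, 0.0, 0.0],
    [0.0, 0.0, 0.5],
    [0.0, 0.8, 0.0],
]


def group_by_expert_naive(gates):
    """Return, per expert, the list of (sample_index, gate_value) routed to it."""
    num_experts = len(gates[0])
    groups = [[] for _ in range(num_experts)]
    for b, row in enumerate(gates):      # loops over every sample
        for e, g in enumerate(row):      # and over every expert column
            if g != 0:
                groups[e].append((b, g))
    return groups


groups = group_by_expert_naive(gates)
for e, members in enumerate(groups):
    print(f"Expert {e} <- {members}")
```

The loop produces the same grouping that `SparseDispatcher` encodes in its index buffers, but it touches one sample at a time and so gives the framework nothing to batch.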
Minimal example (4 samples, 3 experts, k=1), traced through:
```
Input gates (4×3):
row 0: [  0,   0, 0.7]  → sample 0 → Expert 2
row 1: [0.9,   0,   0]  → sample 1 → Expert 0
row 2: [  0,   0, 0.5]  → sample 2 → Expert 2
row 3: [  0, 0.8,   0]  → sample 3 → Expert 1

Expected grouping:
Expert 0 ← [sample 1]            gate=0.9
Expert 1 ← [sample 3]            gate=0.8
Expert 2 ← [sample 0, sample 2]  gate=0.7, 0.5
```
§5.2 `__init__` in detail
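Step 1: find the coordinates of all non-zero elements
`torch.nonzero(gates)` returns the coordinates of every non-zero value in `gates`.
For the `gates` (4×3) above, the non-zero elements sit at (0,2), (1,0), (2,2) and (3,1):
```
nonzero(gates) = [[0, 2],
                  [1, 0],
                  [2, 2],
                  [3, 1]]
```
Shape `(4, 2)`: 4 non-zero values, each row being `[sample index, expert index]`.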
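The coordinates listed above can be reproduced directly. A minimal sketch, assuming only that PyTorch is installed; the `gates` tensor is the 4×3 toy matrix from §5.1:

```python
import torch

gates = torch.tensor([
    [0.0, 0.0, 0.7],
    [0.9, 0.0, 0.0],
    [0.0, 0.0, 0.5],
    [0.0, 0.8, 0.0],
])

# One row per non-zero entry, in row-major order: [sample index, expert index]
nz = torch.nonzero(gates)
print(nz.tolist())
print(tuple(nz.shape))
```

With k=1 the number of rows equals the number of samples, because every sample contributes exactly one non-zero gate.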
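Step 2: sort by expert index
`torch.nonzero(gates).sort(0)` sorts each column independently along dim=0 (the row axis):
```
sort(0) on [[0,2],[1,0],[2,2],[3,1]]:
  column 0: [0,1,2,3] → unchanged,  indices=[0,1,2,3]
  column 1: [2,0,2,1] → [0,1,2,2],  indices=[1,3,0,2]
sorted_experts       = [[0,0],[1,1],[2,2],[3,2]]  ← values of each column after its own sort
index_sorted_experts = [[0,1],[1,3],[2,0],[3,2]]  ← original row number, per column
```
Why can `sort(0)` extract the expert ordering?
`index_sorted_experts[:, 1] = [1, 3, 0, 2]` is the row permutation that puts column 1 (the expert index) in ascending order. Applying this permutation to the original nonzero array yields rows grouped by expert id. This is the core trick of `__init__`. Note that `sort` is not stable by default, so the relative order of rows that share an expert (rows 0 and 2 here) is not guaranteed; `[1, 3, 2, 0]` would be an equally valid permutation, and the grouping by expert holds either way.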
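The column-wise sort and the resulting row permutation can be checked on the toy coordinates. This is a sketch rather than the original code: it passes `stable=True` (an assumption added here) so that tied rows keep their original order and the printed permutation is reproducible, whereas the original constructor calls `sort(0)` without that flag.

```python
import torch

# nonzero(gates) for the 4x3 toy example: rows are [sample index, expert index]
nz = torch.tensor([[0, 2], [1, 0], [2, 2], [3, 1]])

# Each column is sorted independently along dim 0.
# stable=True is an assumption of this sketch; it fixes the order of tied rows.
sorted_vals, perm = torch.sort(nz, dim=0, stable=True)

expert_perm = perm[:, 1]     # row permutation that sorts the expert column
grouped = nz[expert_perm]    # original rows, now grouped by expert index

print(sorted_vals.tolist())
print(expert_perm.tolist())
print(grouped.tolist())
```

Only column 1 of the permutation is used afterwards; column 0 of `sorted_vals` is sorted on its own and no longer pairs up with column 1, which is why the original code discards it.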
Step 3: extract `_expert_index` and `_batch_index`
`_, _expert_index = sorted_experts.split(1, dim=1)`:
- `_` = column 0 of `sorted_experts` = `[[0],[1],[2],[3]]` (the sorted sample indices, discarded)
- `_expert_index` = column 1 of `sorted_experts` = `[[0],[1],[2],[2]]`

`_batch_index = nonzero(gates)[index_sorted_experts[:, 1], 0]` applies the permutation `[1,3,0,2]` to column 0 (the sample index) of the original nonzero array:
```
nonzero(gates)[[1,3,0,2], 0] = [nonzero[1,0], nonzero[3,0], nonzero[0,0], nonzero[2,0]]
                             = [1, 3, 0, 2]
```
So `_batch_index = [1, 3, 0, 2]`, which reads as:
```
Slot 0 → sample 1, goes to Expert 0
Slot 1 → sample 3, goes to Expert 1
Slot 2 → sample 0, goes to Expert 2
Slot 3 → sample 2, goes to Expert 2
```
Step 4: compute `_part_sizes`
`(gates > 0).sum(0).tolist()` gives the number of samples each expert receives:
```
column 0 (Expert 0): 1 non-zero → n_0 = 1
column 1 (Expert 1): 1 non-zero → n_1 = 1
column 2 (Expert 2): 2 non-zero → n_2 = 2
_part_sizes = [1, 1, 2]
```
Step 5: extract `_nonzero_gates`
```python
gates_exp = gates[self._batch_index.flatten()]   # take the rows in slot order
self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
```
`gates[_batch_index] = gates[[1,3,0,2]]`:
```
= [[0.9, 0,   0  ],   ← sample 1
   [0,   0.8, 0  ],   ← sample 3
   [0,   0,   0.7],   ← sample 0
   [0,   0,   0.5]]   ← sample 2
shape (4, 3)
```
`torch.gather(gates_exp, 1, _expert_index)` picks, for each row, the gate value in that row's expert column:
```
_expert_index  = [[0],[1],[2],[2]]
result         = [gates_exp[0,0], gates_exp[1,1], gates_exp[2,2], gates_exp[3,2]]
               = [0.9, 0.8, 0.7, 0.5]
_nonzero_gates = [[0.9],[0.8],[0.7],[0.5]]   shape (4, 1)
```
§5.3 `dispatch` in detail
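`dispatch` only reads buffers that `__init__` has already prepared. As a recap before looking at it, the sketch below rebuilds all three buffers for the 4×3 example in one standalone function. `build_index_buffers` is a hypothetical name introduced here, column slicing stands in for the original `split(1, dim=1)`, and `stable=True` is an added assumption that makes the slot order reproducible.

```python
import torch


def build_index_buffers(gates):
    """Recompute the buffers of SparseDispatcher.__init__ (hypothetical standalone helper)."""
    nz = torch.nonzero(gates)                                   # (nnz, 2): [sample, expert]
    sorted_experts, index_sorted = torch.sort(nz, dim=0, stable=True)
    expert_index = sorted_experts[:, 1:2]                       # (nnz, 1): ascending expert ids
    batch_index = nz[index_sorted[:, 1], 0]                     # (nnz,): sample id per slot
    part_sizes = (gates > 0).sum(0).tolist()                    # number of samples per expert
    gates_exp = gates[batch_index]                              # (nnz, E): gate rows in slot order
    nonzero_gates = torch.gather(gates_exp, 1, expert_index)    # (nnz, 1): gate value per slot
    return batch_index, part_sizes, nonzero_gates


gates = torch.tensor([
    [0.0, 0.0, 0.7],
    [0.9, 0.0, 0.0],
    [0.0, 0.0, 0.5],
    [0.0, 0.8, 0.0],
])
batch_index, part_sizes, nonzero_gates = build_index_buffers(gates)
print(batch_index.tolist(), part_sizes, nonzero_gates.flatten().tolist())
```

The gate values print with float32 rounding, so they should be compared with a tolerance rather than with exact equality.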
```python
def dispatch(self, inp):
    inp_exp = inp[self._batch_index].squeeze(1)
    return torch.split(inp_exp, self._part_sizes, dim=0)
```
Step 1: reorder the input into slot order
`inp[_batch_index] = inp[[1,3,0,2]]`, shape `(4, L, 1)`, reorders the samples into expert-grouped order:
```
slot 0: inp[1] ← for Expert 0
slot 1: inp[3] ← for Expert 1
slot 2: inp[0] ← for Expert 2
slot 3: inp[2] ← for Expert 2
```
⚠️ `.squeeze(1)` is a no-op
`squeeze(1)` only removes dim 1 when its size is 1. For the time-series input of shape `(B*N, L, 1)`, dim 1 has size L = 16 ≠ 1, so `squeeze(1)` changes nothing. The call is a leftover from the original 2D `[batch, depth]` design; with the current 3D input it has no effect and is harmless.
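The no-op can be confirmed on dummy tensors. A small sketch with made-up shapes: the first tensor mirrors the `(slots, L, 1)` layout discussed above, and the second is an arbitrary tensor whose dim 1 has size 1, included only for contrast.

```python
import torch

x_series = torch.zeros(4, 16, 1)    # (slots, L, 1): dim 1 has size L = 16
x_unit_dim = torch.zeros(4, 1, 16)  # contrast case: dim 1 has size 1

series_shape = tuple(x_series.squeeze(1).shape)      # nothing to remove
unit_dim_shape = tuple(x_unit_dim.squeeze(1).shape)  # dim 1 is dropped

print(series_shape, unit_dim_shape)
```

By the same rule, the call would start to change the shape if the window length L were ever set to 1.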
Step 2: split by `_part_sizes`
`torch.split(inp_exp, [1,1,2], dim=0)` cuts along dim=0 into chunks of the given sizes:
```
expert_inputs[0] shape (1, L, 1): [inp[1]]          ← sub-batch for Expert 0
expert_inputs[1] shape (1, L, 1): [inp[3]]          ← sub-batch for Expert 1
expert_inputs[2] shape (2, L, 1): [inp[0], inp[2]]  ← sub-batch for Expert 2
```
If an expert receives no samples in the current batch (e.g. `_part_sizes[e] = 0`), `expert_inputs[e]` is an empty tensor of shape `(0, L, 1)`. `Linear_extractor.forward()` guards against this with `if x.shape[0] == 0: return empty_tensor` and returns an empty tensor directly.
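The empty-expert case can be reproduced with a routing in which one expert receives nothing. The sketch below uses a hypothetical 4×3 `gates` matrix (different from the running example) where Expert 1 is unused, fills sample b with the constant value b so the rows are easy to recognise, and again assumes `stable=True` for a reproducible slot order.

```python
import torch

L = 16
# Hypothetical routing: samples 1 and 3 -> Expert 0, samples 0 and 2 -> Expert 2.
gates = torch.tensor([
    [0.0, 0.0, 0.7],
    [0.9, 0.0, 0.0],
    [0.0, 0.0, 0.5],
    [0.6, 0.0, 0.0],
])
inp = torch.arange(4.0).view(4, 1, 1).repeat(1, L, 1)   # sample b is filled with the value b

nz = torch.nonzero(gates)
_, perm = torch.sort(nz, dim=0, stable=True)            # stable=True: assumption of this sketch
batch_index = nz[perm[:, 1], 0]
part_sizes = (gates > 0).sum(0).tolist()

# Same two operations as dispatch(): index rows into slot order, then split per expert.
inp_exp = inp[batch_index].squeeze(1)
expert_inputs = torch.split(inp_exp, part_sizes, dim=0)

shapes = [tuple(t.shape) for t in expert_inputs]
print(part_sizes, shapes)
```

`torch.split` accepts a zero in the size list, so the unused expert simply gets a tensor with zero rows, and the list still has one entry per expert.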
§5.4 `combine` in detail
```python
def combine(self, expert_out, multiply_by_gates=True):
    stitched = torch.cat(expert_out, 0)
    if multiply_by_gates:
        stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
    shape = list(expert_out[-1].shape)
    shape[0] = self._gates.size(0)
    zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
    combined = zeros.index_add(0, self._batch_index, stitched.float())
    return combined
```
Step 1: concatenate all expert outputs
Assume each expert outputs d_model=2 features (shape `(n_i, 2, 1)`):
```
expert_out[0] shape (1, 2, 1): [[y0_s, y0_t]]     ← Expert 0's result for inp[1]
expert_out[1] shape (1, 2, 1): [[y1_s, y1_t]]     ← Expert 1's result for inp[3]
expert_out[2] shape (2, 2, 1): [[y2a_s, y2a_t],
                                [y2b_s, y2b_t]]   ← Expert 2's results for inp[0] / inp[2]

stitched = cat([out0, out1, out2], 0)   shape (4, 2, 1)
stitched[0] = [y0_s,  y0_t]    ← slot 0 (sample 1, Expert 0)
stitched[1] = [y1_s,  y1_t]    ← slot 1 (sample 3, Expert 1)
stitched[2] = [y2a_s, y2a_t]   ← slot 2 (sample 0, Expert 2)
stitched[3] = [y2b_s, y2b_t]   ← slot 3 (sample 2, Expert 2)
```
Step 2: scale by the gate weights
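`torch.einsum("i...,ij->i...", stitched(4,2,1), _nonzero_gates(4,1))`:
- `i` runs over the 4 slots
- `...` covers `(2, 1)`, the feature dimensions of the output
- `j` covers dim 1 of `_nonzero_gates` (size 1); summing over it amounts to a scalar multiplication

Equivalently, for each slot:
```
stitched[0] *= 0.9 → [0.9·y0_s,  0.9·y0_t]
stitched[1] *= 0.8 → [0.8·y1_s,  0.8·y1_t]
stitched[2] *= 0.7 → [0.7·y2a_s, 0.7·y2a_t]
stitched[3] *= 0.5 → [0.5·y2b_s, 0.5·y2b_t]
```
Why is the einsum written as `"i...,ij->i..."` instead of a plain mul?
The ellipsis lets the same expression handle expert outputs of any rank. A plain `stitched.mul(_nonzero_gates)` broadcasts against the trailing dimensions, so a `(4, 1)` gate tensor does not line up with the slot axis of a `(4, 2, 1)` output. The value of k does not change this: for k=1, `_nonzero_gates` is `(B*N, 1)`; for k=2, every sample is routed to 2 experts, `stitched` has 2×B*N rows, and `_nonzero_gates` is `(2×B*N, 1)`. In both cases `j` has size 1, so the sum over `j` is trivial and the einsum reduces to a per-slot scalar multiplication.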
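The difference between the einsum and a plain `mul` shows up as soon as the expert output is 3D. A minimal sketch with made-up values for `stitched`; the gate values are those of the 4×3 example:

```python
import torch

stitched = torch.arange(8, dtype=torch.float32).view(4, 2, 1)   # (slots, d_model, 1)
nonzero_gates = torch.tensor([[0.9], [0.8], [0.7], [0.5]])      # (slots, 1)

# einsum: slot axis i is shared, j (size 1) is summed away, "..." passes through.
scaled = torch.einsum("i...,ij->i...", stitched, nonzero_gates)

# Explicit per-slot scaling, used here only as a reference result.
reference = stitched * nonzero_gates.view(-1, 1, 1)

# Plain mul aligns trailing dims: (4, 2, 1) against (4, 1) pairs d_model=2 with slots=4.
try:
    stitched.mul(nonzero_gates)
    mul_failed = False
except RuntimeError:
    mul_failed = True

print(tuple(scaled.shape), torch.allclose(scaled, reference), mul_failed)
```

With these shapes the plain `mul` fails outright. If d_model happened to equal the number of slots it would broadcast without an error and scale the wrong axis, which is the harder failure to notice.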
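Step 3: scatter back to the original positions
```python
shape = list(expert_out[-1].shape)   # = [n_last, 2, 1]
shape[0] = self._gates.size(0)       # → [4, 2, 1]
zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)   # here: (4, 2, 1)
combined = zeros.index_add(0, self._batch_index, stitched.float())        # _batch_index = [1, 3, 0, 2]
```
Semantics of `index_add(0, index, src)`: `zeros[index[i]] += src[i]` for every i:
```
zeros[_batch_index[0]=1] += stitched[0] → combined[1] = 0.9·[y0_s,  y0_t]
zeros[_batch_index[1]=3] += stitched[1] → combined[3] = 0.8·[y1_s,  y1_t]
zeros[_batch_index[2]=0] += stitched[2] → combined[0] = 0.7·[y2a_s, y2a_t]
zeros[_batch_index[3]=2] += stitched[3] → combined[2] = 0.5·[y2b_s, y2b_t]
```
The final `combined` has shape `(4, 2, 1)` and is back in the original sample order:
```
combined[0] = 0.7 × expert2_output(inp[0])   ← sample 0, handled by Expert 2, gate=0.7
combined[1] = 0.9 × expert0_output(inp[1])   ← sample 1, handled by Expert 0, gate=0.9
combined[2] = 0.5 × expert2_output(inp[2])   ← sample 2, handled by Expert 2, gate=0.5
combined[3] = 0.8 × expert1_output(inp[3])   ← sample 3, handled by Expert 1, gate=0.8
```
Accumulation semantics of index_add when k>1
For k=1, each sample appears exactly once in `_batch_index`, so index_add degenerates to an assignment. For k=2, each sample appears twice (it is processed by 2 experts), and index_add sums the two weighted expert outputs, which implements the MoE weighted-sum formula:

$$\text{combined}[b] = \sum_{e \,:\, \text{gates}[b,e] > 0} \text{gates}[b,e] \cdot \text{expert}_e(\text{inp}[b])$$

This is the core equation by which SparseDispatcher implements sparse MoE.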
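The accumulation for k=2 can be seen on a two-sample case. Everything in this sketch is hypothetical: the 2×3 gate values, the hand-written slot order, and the stand-in expert outputs, where expert e simply returns the constant e + 1.

```python
import torch

# Hypothetical k=2 routing: 2 samples, 3 experts, two non-zero gates per row.
gates = torch.tensor([
    [0.6, 0.0, 0.4],   # sample 0 -> experts 0 and 2
    [0.0, 0.7, 0.3],   # sample 1 -> experts 1 and 2
])

# Slots in expert order: (E0, s0), (E1, s1), (E2, s0), (E2, s1)
batch_index = torch.tensor([0, 1, 0, 1])
nonzero_gates = torch.tensor([[0.6], [0.7], [0.4], [0.3]])

# Stand-in expert outputs of shape (slots, d_model=2, 1): expert e returns e + 1.
stitched = torch.tensor([1.0, 2.0, 3.0, 3.0]).view(4, 1, 1).repeat(1, 2, 1)

weighted = torch.einsum("i...,ij->i...", stitched, nonzero_gates)
combined = torch.zeros(2, 2, 1).index_add(0, batch_index, weighted)

# Sample 0: 0.6 * 1 + 0.4 * 3 = 1.8;  sample 1: 0.7 * 2 + 0.3 * 3 = 2.3
print(combined[:, 0, 0].tolist())
```

Each sample index occurs twice in `batch_index`, so both weighted rows land in the same output row and are added rather than overwritten.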
§5.5 `expert_to_gates` in detail
```python
def expert_to_gates(self):
    return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
```
Splitting `_nonzero_gates` `(4, 1)` by `_part_sizes=[1,1,2]`:
```
gates_per_expert[0] shape (1, 1): [[0.9]]        ← gate of Expert 0's sample
gates_per_expert[1] shape (1, 1): [[0.8]]        ← gate of Expert 1's sample
gates_per_expert[2] shape (2, 1): [[0.7],[0.5]]  ← gates of Expert 2's two samples
```
In the actual call in linear_extractor_cluster.py, the result of `gates = dispatcher.expert_to_gates()` is assigned to the variable `gates` but is never used afterwards (the next line goes straight into the expert forward pass). This is redundant code and does not affect correctness.
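Because `expert_to_gates` splits with the same `_part_sizes` as `dispatch`, row r of an expert's gate chunk belongs to row r of that expert's input chunk. A short sketch with the buffers of the 4×3 example written out by hand; splitting `batch_index` here is only a way to show which sample each gate row refers to.

```python
import torch

# Buffers of the 4x3 example, hard-coded in slot order.
nonzero_gates = torch.tensor([[0.9], [0.8], [0.7], [0.5]])
batch_index = torch.tensor([1, 3, 0, 2])
part_sizes = [1, 1, 2]

gates_per_expert = torch.split(nonzero_gates, part_sizes, dim=0)
samples_per_expert = torch.split(batch_index, part_sizes, dim=0)

gate_shapes = [tuple(g.shape) for g in gates_per_expert]
sample_ids = [s.tolist() for s in samples_per_expert]
for e, (shape, ids) in enumerate(zip(gate_shapes, sample_ids)):
    print(f"Expert {e}: samples {ids}, gates shape {shape}")
```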
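§6 Scale under the global toy parameters
Global toy: `B*N=21`, `num_experts=6`, `k=1`.
`torch.nonzero(gates)` → shape `(21, 2)` (exactly 1 non-zero per sample).
Under a perfectly even allocation, `_part_sizes` would be about `[4, 4, 4, 3, 3, 3]` (21 / 6 = 3.5); the actual values are determined by the gate routing, so some experts may receive more samples than others.
An expert receiving an empty batch (`_part_sizes[e] = 0`) is entirely legal; `Linear_extractor.forward()` guards against it with `if x.shape[0] == 0`.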
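The even-allocation figure follows from integer division: 21 = 6 × 3 + 3, so three experts take one extra sample. A plain-Python sketch, where `balanced_part_sizes` is a hypothetical helper and not part of the original code:

```python
def balanced_part_sizes(num_samples, num_experts):
    """Sizes of a perfectly even split (hypothetical helper, for illustration only)."""
    base, extra = divmod(num_samples, num_experts)
    # The first `extra` experts take one additional sample each.
    return [base + 1] * extra + [base] * (num_experts - extra)


sizes = balanced_part_sizes(21, 6)
print(sizes, sum(sizes))
```

Whatever the router actually decides, the entries of `_part_sizes` must still sum to `k * B*N`, which is 21 here.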
Created: 2026-04-24