DUET · Layer 3 — SparseDispatcher

§1 Position in the parent layer

After computing the sparse gating matrix gates, Linear_extractor_cluster.forward() immediately constructs SparseDispatcher(num_experts, gates) and calls its three methods:

python
dispatcher = SparseDispatcher(self.num_experts, gates)
expert_inputs = dispatcher.dispatch(x_norm)          # dispatch
gates_per_expert = dispatcher.expert_to_gates()      # per-expert gate values (meant for the experts; not actually used in this code)
expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
y = dispatcher.combine(expert_outputs)               # combine

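§2 I/O interface

Method | Arguments | Returns
__init__(num_experts, gates) | gates: sparse matrix, shape (B*N, E) | none (precomputes the indices)
dispatch(inp) | inp: shape (B*N, L, 1) | list[E]; item i has shape (n_i, L, 1)
combine(expert_out) | list[E]; item i has shape (n_i, d_model, 1) | tensor of shape (B*N, d_model, 1)
expert_to_gates() | none | list[E]; item i has shape (n_i, 1)

Under the global toy parameters: B*N=21, E=6, k=1, L=16, d_model=8. Each sample is routed to exactly one expert, so $\sum_i n_i = 21$.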


§3 Sequence diagram (concrete layer)


§4 Semantic grouping diagram (index layer)


§5 Step-by-step walkthrough

§5.0 Full original code

python
import torch


class SparseDispatcher(object):

    def __init__(self, num_experts, gates):
        self._gates = gates
        self._num_experts = num_experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        _, self._expert_index = sorted_experts.split(1, dim=1)
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        self._part_sizes = (gates > 0).sum(0).tolist()
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        # apply exp to expert outputs, so we are not longer in log space
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            # stitched = stitched.mul(self._nonzero_gates)
            stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
        shape = list(expert_out[-1].shape)
        shape[0] = self._gates.size(0)
        zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def expert_to_gates(self):
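        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

⚠️ Two leftover comments
  1. The comment at the top of combine, # apply exp to expert outputs, so we are not longer in log space, is a stale comment carried over from the original T2T code; the current implementation performs no exp.
  2. The commented-out stitched = stitched.mul(self._nonzero_gates) is the old form (valid only when each slot's output is a 1D vector); the einsum version generalizes it. Both lines are inert and do not affect execution.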

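§5.1 High-level logic

Goal in one sentence: precompute the "who goes to which expert, and with what weight" information held in the sparse gating matrix into three indices (_batch_index / _part_sizes / _nonzero_gates), so that dispatch and combine each reduce to a fixed handful of batched tensor operations instead of a loop over B*N samples.

Why precompute the indices instead of looking them up again on every dispatch/combine?

dispatch and combine are both, at bottom, sparse gather/scatter operations. If forward() ran for b in range(B*N): send_to_expert(x[b], which_expert[b]) directly, it would iterate 21 times and could not be batched. After precomputation:

  • dispatch = 1 row-indexing operation (inp[_batch_index]) + 1 split
  • combine = 1 cat + 1 einsum + 1 index_add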
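
The contrast between the per-sample loop and the batched path can be checked on a small case. The sketch below is not the DUET code: the experts are hypothetical stand-ins that multiply their input by (e + 1), the per-sample output is 2D so plain broadcasting replaces the einsum, and the names slot_gates and order are introduced only for illustration.

```python
import torch

gates = torch.tensor([[0.0, 0.0, 0.7],
                      [0.9, 0.0, 0.0],
                      [0.0, 0.0, 0.5],
                      [0.0, 0.8, 0.0]])
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)

# Hypothetical stand-in experts: expert e multiplies its input by (e + 1).
experts = [lambda t, scale=e + 1: t * scale for e in range(3)]

# Reference path: one Python-level iteration per sample (valid for k=1 only).
y_loop = torch.zeros_like(x)
for b in range(gates.size(0)):
    e = int(gates[b].argmax())
    y_loop[b] = gates[b, e] * experts[e](x[b])

# Batched path: build the indices once, then use a few tensor-level calls.
nz = torch.nonzero(gates)                     # rows of [sample, expert]
order = torch.sort(nz[:, 1]).indices          # permutation that groups rows by expert
batch_index = nz[order, 0]
expert_index = nz[order, 1]
part_sizes = (gates > 0).sum(0).tolist()
slot_gates = gates[batch_index, expert_index].unsqueeze(1)

parts = torch.split(x[batch_index], part_sizes, dim=0)                 # dispatch
outs = [experts[e](parts[e]) for e in range(len(experts))]
stitched = torch.cat(outs, 0) * slot_gates
y_batched = torch.zeros_like(x).index_add(0, batch_index, stitched)    # combine

print(torch.allclose(y_loop, y_batched))
```

The remaining Python loop in the batched path runs once per expert (3 here, 6 in the toy setting), not once per sample.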

Tracing a minimal example (4 samples, 3 experts, k=1)

Input gates (4×3):
Row 0: [  0,    0,  0.7]  → sample 0 → Expert 2
Row 1: [0.9,    0,    0]  → sample 1 → Expert 0
Row 2: [  0,    0,  0.5]  → sample 2 → Expert 2
Row 3: [  0,  0.8,    0]  → sample 3 → Expert 1

Expected grouping:
Expert 0 ← [sample 1]            gate=0.9
Expert 1 ← [sample 3]            gate=0.8
Expert 2 ← [sample 0, sample 2]  gate=0.7, 0.5

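§5.2 __init__ in detail

Step 1: find the coordinates of all nonzero entries

torch.nonzero(gates) returns the coordinates of every nonzero value in gates.

For the gates (4×3) above, the nonzero entries sit at (0,2), (1,0), (2,2), (3,1):

nonzero(gates) = [[0, 2],
                  [1, 0],
                  [2, 2],
                  [3, 1]]

shape (4, 2): 4 nonzero values, each row being [sample index, expert index]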
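
A quick way to confirm Step 1 is to run torch.nonzero on the 4×3 example matrix. This is only a check of the example above, using the same values.

```python
import torch

gates = torch.tensor([[0.0, 0.0, 0.7],
                      [0.9, 0.0, 0.0],
                      [0.0, 0.0, 0.5],
                      [0.0, 0.8, 0.0]])

# One row per nonzero entry: [sample index, expert index], in row-major order.
nz = torch.nonzero(gates)

print(nz.tolist())
print(tuple(nz.shape))
```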

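Step 2: sort by expert index

torch.nonzero(gates).sort(0) sorts each column independently along dim=0 (down the rows):

sort(0) applied to [[0,2],[1,0],[2,2],[3,1]]:

  Column 0 sorted: [0,1,2,3] → unchanged, indices=[0,1,2,3]
  Column 1 sorted: [2,0,2,1] → [0,1,2,2], indices=[1,3,0,2]

sorted_experts  = [[0,0],[1,1],[2,2],[3,2]]   ← values of the two columns, each sorted on its own
index_sorted_experts = [[0,1],[1,3],[2,0],[3,2]]  ← original row numbers, per column

Why does sort(0) recover the expert ordering?

index_sorted_experts[:, 1] = [1, 3, 0, 2] is the row permutation that puts column 1 (the expert indices) into ascending order. Applying that permutation to the original nonzero array groups its rows by expert. This is the core trick of __init__. Because the two columns are sorted independently, a row of sorted_experts is no longer a valid (sample, expert) pair; only column 1 and its permutation are used. Note also that torch.sort does not guarantee the relative order of equal elements unless stable=True is passed, so the rows inside one expert's group may come out in a different order; the grouping itself is unaffected.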
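
The per-column behavior of sort(0) is easy to misread, so here is a small reproduction of the numbers above. One deliberate difference from the original code: the sketch passes stable=True so that the tie between the two Expert 2 rows is resolved deterministically; the original calls .sort(0) without it.

```python
import torch

nz = torch.tensor([[0, 2],
                   [1, 0],
                   [2, 2],
                   [3, 1]])

# Each column is sorted on its own; indices holds, per column, the original row numbers.
# stable=True is an addition made for this sketch (not present in the original code).
values, indices = torch.sort(nz, dim=0, stable=True)

perm = indices[:, 1]      # row permutation that sorts the expert column
grouped = nz[perm]        # original (sample, expert) rows, now grouped by expert

print(values.tolist())
print(indices.tolist())
print(grouped.tolist())
```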

Step 3: extract _expert_index and _batch_index

_, _expert_index = sorted_experts.split(1, dim=1):

  • _ = column 0 of sorted_experts = [[0],[1],[2],[3]] (the sorted sample indices, discarded)
  • _expert_index = column 1 of sorted_experts = [[0],[1],[2],[2]]

_batch_index = nonzero(gates)[index_sorted_experts[:, 1], 0]

Apply the permutation [1,3,0,2] to column 0 (the sample indices) of the original nonzero array:

nonzero(gates)[[1,3,0,2], 0] = [nonzero[1,0], nonzero[3,0], nonzero[0,0], nonzero[2,0]]
                              = [1, 3, 0, 2]

So _batch_index = [1, 3, 0, 2], which reads as:

Slot 0 → sample 1, goes to Expert 0
Slot 1 → sample 3, goes to Expert 1
Slot 2 → sample 0, goes to Expert 2
Slot 3 → sample 2, goes to Expert 2
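
The slot table above implies two invariants: expert indices are non-decreasing across slots, and every (sample, expert) pair points at a nonzero gate. The sketch below derives both indices as in Step 3 and checks those invariants. The helper check_slots is hypothetical, written for this note, and stable=True is again an addition that makes the slot order deterministic.

```python
import torch

gates = torch.tensor([[0.0, 0.0, 0.7],
                      [0.9, 0.0, 0.0],
                      [0.0, 0.0, 0.5],
                      [0.0, 0.8, 0.0]])

nz = torch.nonzero(gates)
sorted_experts, index_sorted_experts = torch.sort(nz, dim=0, stable=True)

_, expert_index = sorted_experts.split(1, dim=1)     # shape (4, 1)
batch_index = nz[index_sorted_experts[:, 1], 0]      # shape (4,)


def check_slots(gates, batch_index, expert_index):
    """Hypothetical helper: verify that slots are grouped by expert and routed correctly."""
    experts = expert_index.flatten()
    grouped = bool((experts[1:] >= experts[:-1]).all())
    routed = bool((gates[batch_index, experts] != 0).all())
    return grouped and routed


slots = list(zip(batch_index.tolist(), expert_index.flatten().tolist()))
print(slots)
print(check_slots(gates, batch_index, expert_index))
```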

Step 4: compute _part_sizes

(gates > 0).sum(0).tolist() = the number of samples each expert receives:

Column 0 (Expert 0): 1 nonzero   → n_0 = 1
Column 1 (Expert 1): 1 nonzero   → n_1 = 1
Column 2 (Expert 2): 2 nonzeros  → n_2 = 2
_part_sizes = [1, 1, 2]
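
Because the slots are already grouped by expert, _part_sizes is all that is needed to know which contiguous slot range belongs to which expert. The sketch below computes the sizes and the implied slot ranges; the offsets list is an illustration added here and is not stored by SparseDispatcher.

```python
import torch

gates = torch.tensor([[0.0, 0.0, 0.7],
                      [0.9, 0.0, 0.0],
                      [0.0, 0.0, 0.5],
                      [0.0, 0.8, 0.0]])

# Count the positive entries in each expert column.
part_sizes = (gates > 0).sum(0).tolist()

# Running offsets: expert e owns slots offsets[e] .. offsets[e + 1] - 1.
offsets = [0]
for n in part_sizes:
    offsets.append(offsets[-1] + n)

for e, n in enumerate(part_sizes):
    print(f"Expert {e}: {n} sample(s), slots {list(range(offsets[e], offsets[e + 1]))}")
```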

Step 5: extract _nonzero_gates

python
gates_exp = gates[self._batch_index.flatten()]      # take the rows in slot order
self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

gates[_batch_index] = gates[[1,3,0,2]]

= [[0.9, 0,   0  ],   ← sample 1
   [0,   0.8, 0  ],   ← sample 3
   [0,   0,   0.7],   ← sample 0
   [0,   0,   0.5]]   ← sample 2
shape (4, 3)

torch.gather(gates_exp, 1, _expert_index) picks, from each row, the gate value in that row's expert column:

_expert_index = [[0],[1],[2],[2]]

result = [gates_exp[0,0], gates_exp[1,1], gates_exp[2,2], gates_exp[3,2]]
       = [0.9, 0.8, 0.7, 0.5]
_nonzero_gates = [[0.9],[0.8],[0.7],[0.5]]  shape (4, 1)
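
The gather in Step 5 can be reproduced directly with the index values derived above. The second expression, using advanced indexing, is an equivalent formulation shown only for comparison; it is not what the original code does.

```python
import torch

gates = torch.tensor([[0.0, 0.0, 0.7],
                      [0.9, 0.0, 0.0],
                      [0.0, 0.0, 0.5],
                      [0.0, 0.8, 0.0]])
batch_index = torch.tensor([1, 3, 0, 2])
expert_index = torch.tensor([[0], [1], [2], [2]])

gates_exp = gates[batch_index]                              # rows in slot order, shape (4, 3)
nonzero_gates = torch.gather(gates_exp, 1, expert_index)    # one column per slot, shape (4, 1)

# Equivalent formulation for comparison (an assumption of this sketch, not the original code).
alt = gates[batch_index, expert_index.flatten()].unsqueeze(1)

print(nonzero_gates.flatten().tolist())
print(torch.equal(nonzero_gates, alt))
```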

§5.3 dispatch in detail

python
def dispatch(self, inp):
    inp_exp = inp[self._batch_index].squeeze(1)
    return torch.split(inp_exp, self._part_sizes, dim=0)

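Step 1: reorder the input into slot order

inp[_batch_index] = inp[[1,3,0,2]], shape (4, L, 1): the samples reordered into expert-grouped order:

slot 0: inp[1]   ← for Expert 0
slot 1: inp[3]   ← for Expert 1
slot 2: inp[0]   ← for Expert 2
slot 3: inp[2]   ← for Expert 2

⚠️ .squeeze(1) is a no-op

squeeze(1) removes dim 1 only when its size is 1. For time-series input of shape (B*N, L, 1), dim 1 has size L = 16 ≠ 1, so squeeze(1) changes nothing. The call is a leftover from the original 2D [batch, depth] design; with the current 3D input it has no effect and does no harm. It would, however, silently drop the dimension if L were ever 1.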
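
The squeeze(1) behavior can be confirmed on dummy tensors. The shapes below follow the toy setting (L=16) plus one contrived case with L=1 to show when the call would stop being a no-op.

```python
import torch

batch_index = torch.tensor([1, 3, 0, 2])

# Toy-sized input: dim 1 has size L = 16, so squeeze(1) leaves the shape alone.
inp = torch.zeros(4, 16, 1)
kept = inp[batch_index].squeeze(1)

# Contrived case with L = 1: here squeeze(1) does remove the dimension.
inp_short = torch.zeros(4, 1, 1)
dropped = inp_short[batch_index].squeeze(1)

print(tuple(kept.shape))
print(tuple(dropped.shape))
```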

Step 2: split by _part_sizes

torch.split(inp_exp, [1,1,2], dim=0) splits along dim=0 into chunks of the given sizes:

expert_inputs[0] shape (1, L, 1): [inp[1]]           ← Expert 0's sub-batch
expert_inputs[1] shape (1, L, 1): [inp[3]]           ← Expert 1's sub-batch
expert_inputs[2] shape (2, L, 1): [inp[0], inp[2]]   ← Expert 2's sub-batch

If an expert receives no samples in the current batch (e.g., _part_sizes[e] = 0), expert_inputs[e] is an empty tensor of shape (0, L, 1). Linear_extractor.forward() guards against this with if x.shape[0] == 0: return empty_tensor, so it returns an empty tensor directly.
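
The empty-expert case can be sketched as follows. The part sizes [1, 0, 3] are made up so that one expert receives nothing, and guarded_expert is a hypothetical stand-in for the guard described above; it is not the real Linear_extractor and its non-empty branch is only a shape-preserving placeholder.

```python
import torch

inp_exp = torch.arange(4 * 16, dtype=torch.float32).reshape(4, 16, 1)
part_sizes = [1, 0, 3]        # made-up sizes: the middle expert receives no samples

parts = torch.split(inp_exp, part_sizes, dim=0)
shapes = [tuple(p.shape) for p in parts]


def guarded_expert(x, d_model=8):
    """Hypothetical stand-in for an expert with an empty-batch guard."""
    if x.shape[0] == 0:
        return x.new_zeros(0, d_model, 1)     # empty output with the right trailing shape
    return x[:, :d_model, :]                  # placeholder for the real projection


outs = [guarded_expert(p) for p in parts]
stitched = torch.cat(outs, 0)                 # empty pieces concatenate without issue

print(shapes)
print(tuple(stitched.shape))
```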


§5.4 combine in detail

python
def combine(self, expert_out, multiply_by_gates=True):
    stitched = torch.cat(expert_out, 0)
    if multiply_by_gates:
        stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
    shape = list(expert_out[-1].shape)
    shape[0] = self._gates.size(0)
    zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)
    combined = zeros.index_add(0, self._batch_index, stitched.float())
    return combined

Step 1: concatenate all expert outputs

Suppose each expert outputs d_model=2 features (shape (n_i, 2, 1)):

expert_out[0] shape (1, 2, 1): [[y0_s, y0_t]]      ← Expert 0's result for inp[1]
expert_out[1] shape (1, 2, 1): [[y1_s, y1_t]]      ← Expert 1's result for inp[3]
expert_out[2] shape (2, 2, 1): [[y2a_s, y2a_t],
                                  [y2b_s, y2b_t]]   ← Expert 2's results for inp[0]/inp[2]

stitched = cat([out0, out1, out2], 0) shape (4, 2, 1)
stitched[0] = [y0_s, y0_t]    ← slot 0 (sample 1, Expert 0)
stitched[1] = [y1_s, y1_t]    ← slot 1 (sample 3, Expert 1)
stitched[2] = [y2a_s, y2a_t]  ← slot 2 (sample 0, Expert 2)
stitched[3] = [y2b_s, y2b_t]  ← slot 3 (sample 2, Expert 2)

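Step 2: scale by the gate weights

torch.einsum("i...,ij->i...", stitched(4,2,1), _nonzero_gates(4,1)):

  • i ranges over the 4 slots
  • ... covers (2, 1), the feature dimensions of the output
  • j covers dim 1 of _nonzero_gates (size 1); summing over it amounts to a scalar multiplication

Equivalent to: for each slot i, stitched[i,...] ← stitched[i,...] × _nonzero_gates[i,0]

stitched[0] *= 0.9  → [0.9·y0_s, 0.9·y0_t]
stitched[1] *= 0.8  → [0.8·y1_s, 0.8·y1_t]
stitched[2] *= 0.7  → [0.7·y2a_s, 0.7·y2a_t]
stitched[3] *= 0.5  → [0.5·y2b_s, 0.5·y2b_t]

Why is the einsum written as "i...,ij->i..." rather than a plain mul?

The einsum makes the scaling independent of the rank of the expert output. A plain stitched.mul(self._nonzero_gates) relies on broadcasting from the trailing dimensions, which lines up an (n, 1) gate tensor correctly against a 2D (n, d) output but not against the 3D (n, d_model, 1) output used here. In the einsum, the leading index i pairs each slot with its gate and ... carries whatever trailing dimensions exist. The value of k does not matter: for k=1, _nonzero_gates is (B*N, 1); for k=2, each sample is routed to 2 experts, stitched has 2×B*N rows, _nonzero_gates is (2×B*N, 1), and the same expression applies. j is formally a summed index, but because its size is always 1 the sum has a single term and acts as a per-slot scalar multiplication.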
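
To see what the einsum buys over a plain multiplication, the sketch below compares it with an explicitly reshaped broadcast and then tries the plain mul on the same 3D tensor. The random stitched values are placeholders; only the shapes match the example.

```python
import torch

torch.manual_seed(0)
stitched = torch.randn(4, 2, 1)                                  # (slots, d_model, 1)
nonzero_gates = torch.tensor([[0.9], [0.8], [0.7], [0.5]])       # (slots, 1)

# The form used in combine: i pairs slot with gate, "..." carries the trailing dims.
scaled = torch.einsum("i...,ij->i...", stitched, nonzero_gates)

# The same result with the gate reshaped by hand to broadcast over the trailing dims.
reference = stitched * nonzero_gates.view(-1, 1, 1)

# Plain broadcasting aligns (4, 1) against the last two dims of (4, 2, 1) and fails.
try:
    stitched * nonzero_gates
    plain_mul_ok = True
except RuntimeError:
    plain_mul_ok = False

print(torch.allclose(scaled, reference))
print(plain_mul_ok)
```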

Step 3: scatter back to the original positions

python
shape = list(expert_out[-1].shape)   # = [n_last, 2, 1]
shape[0] = self._gates.size(0)        # → [4, 2, 1]
zeros = torch.zeros(*shape, requires_grad=True, device=stitched.device)   # shape (4, 2, 1)
combined = zeros.index_add(0, self._batch_index, stitched.float())        # _batch_index = [1,3,0,2]

Semantics of index_add(0, index, src): zeros[index[i]] += src[i] for every i:

zeros[_batch_index[0]=1] += stitched[0]  → combined[1] = 0.9·[y0_s, y0_t]
zeros[_batch_index[1]=3] += stitched[1]  → combined[3] = 0.8·[y1_s, y1_t]
zeros[_batch_index[2]=0] += stitched[2]  → combined[0] = 0.7·[y2a_s, y2a_t]
zeros[_batch_index[3]=2] += stitched[3]  → combined[2] = 0.5·[y2b_s, y2b_t]

The final combined has shape (4, 2, 1) and is back in the original sample order:

combined[0] = 0.7 × expert2_output(inp[0])   ← sample 0, handled by Expert 2, gate=0.7
combined[1] = 0.9 × expert0_output(inp[1])   ← sample 1, handled by Expert 0, gate=0.9
combined[2] = 0.5 × expert2_output(inp[2])   ← sample 2, handled by Expert 2, gate=0.5
combined[3] = 0.8 × expert1_output(inp[3])   ← sample 3, handled by Expert 1, gate=0.8
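
The scatter step can be traced with recognizable numbers. In the sketch below the stitched rows are arbitrary labels (10s for slot 0, 20s for slot 1, and so on) rather than real expert outputs, so that the reordering is visible in the result.

```python
import torch

batch_index = torch.tensor([1, 3, 0, 2])

# Slot-ordered rows with easy-to-spot values, shape (4, 2, 1).
stitched = torch.tensor([[10.0, 11.0],
                         [20.0, 21.0],
                         [30.0, 31.0],
                         [40.0, 41.0]]).unsqueeze(-1)

# Out-of-place scatter-add: combined[batch_index[i]] += stitched[i].
combined = torch.zeros(4, 2, 1).index_add(0, batch_index, stitched)

print(combined.squeeze(-1).tolist())
```

Reading the result back through the same index, combined[batch_index], returns the slot-ordered rows again, which is a convenient sanity check for k=1.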
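
Accumulation semantics of index_add when k>1

For k=1, every sample appears exactly once in _batch_index, so index_add reduces to assignment. For k=2, every sample appears twice (it is handled by 2 experts), and index_add sums the two experts' weighted outputs, which implements the MoE weighted-sum formula:

$$y_b = \sum_{e=1}^{k} g_{b,e}\,\mathrm{expert}_e(x_b)$$

where e ranges over the k experts selected for sample b and g_{b,e} is the corresponding gate value. This is the core equation by which SparseDispatcher implements sparse MoE.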
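
A small k=2 case makes the accumulation concrete. Everything in this sketch is illustrative: the 2×3 gates matrix is made up, the experts are hypothetical stand-ins that multiply by (e + 1), and the output is 2D so plain broadcasting is used for the gate scaling.

```python
import torch

# Made-up k=2 gating: each of the 2 samples has exactly two nonzero gates.
gates = torch.tensor([[0.6, 0.0, 0.4],
                      [0.0, 0.3, 0.7]])
x = torch.tensor([[1.0], [2.0]])

# Hypothetical stand-in experts: expert e multiplies its input by (e + 1).
experts = [lambda t, scale=e + 1: t * scale for e in range(3)]

nz = torch.nonzero(gates)
order = torch.sort(nz[:, 1]).indices
batch_index, expert_index = nz[order, 0], nz[order, 1]
part_sizes = (gates > 0).sum(0).tolist()
slot_gates = gates[batch_index, expert_index].unsqueeze(1)

parts = torch.split(x[batch_index], part_sizes, dim=0)
stitched = torch.cat([experts[e](parts[e]) for e in range(3)], 0) * slot_gates

# Each sample index occurs twice in batch_index, so index_add sums two contributions.
y = torch.zeros_like(x).index_add(0, batch_index, stitched)

# Dense evaluation of the weighted sum for comparison (zero gates contribute nothing).
expected = torch.stack([sum(gates[b, e] * experts[e](x[b]) for e in range(3))
                        for b in range(2)])

print(y.flatten().tolist())
print(torch.bincount(batch_index).tolist())
```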


§5.5 expert_to_gates in detail

python
def expert_to_gates(self):
    return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

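_nonzero_gates (4, 1) is split according to _part_sizes=[1,1,2]:

gates_per_expert[0] shape (1, 1): [[0.9]]   ← gate value of Expert 0's sample
gates_per_expert[1] shape (1, 1): [[0.8]]   ← gate value of Expert 1's sample
gates_per_expert[2] shape (2, 1): [[0.7],[0.5]]  ← gate values of Expert 2's two samples

In the actual call in linear_extractor_cluster.py, the result of gates = dispatcher.expert_to_gates() is assigned to the variable gates but is never used afterwards (the next line goes straight into the expert forward pass). This is redundant code and does not affect correctness.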


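§6 Scale under the global toy parameters

Global toy: B*N=21, num_experts=6, k=1.

torch.nonzero(gates) → shape (21, 2) (exactly 1 nonzero per sample).

Under a perfectly even assignment, _part_sizes would be roughly [4, 4, 4, 3, 3, 3] (21 / 6 = 3.5). The actual values are decided by the gating, so some experts may receive more samples than others.

An expert receiving an empty batch (_part_sizes[e] = 0) is entirely legitimate; Linear_extractor.forward() guards against it with if x.shape[0] == 0.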
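
The toy-scale numbers can be checked with a synthetic top-1 gating matrix. The router below (random logits, softmax, keep the maximum) is a hypothetical stand-in for the real gating network; only the sizes B*N=21 and E=6 come from the notes.

```python
import torch

torch.manual_seed(0)
num_samples, num_experts = 21, 6          # B*N and E from the toy setting

# Hypothetical router: random logits, softmax, keep only the top-1 probability per row.
logits = torch.randn(num_samples, num_experts)
probs = torch.softmax(logits, dim=1)
top_val, top_idx = probs.max(dim=1)

gates = torch.zeros_like(probs)
gates[torch.arange(num_samples), top_idx] = top_val

nz = torch.nonzero(gates)
part_sizes = (gates > 0).sum(0).tolist()

print(tuple(nz.shape))
print(part_sizes, sum(part_sizes))
```

The individual entries of part_sizes depend on the random draw, and a zero entry is possible; their sum is always 21 for k=1.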


Created: 2026-04-24
