Skip to content

DUET · Layer 1 — DUETModel 主链

§1 在父层中的位置

DUET._process() 调用 self.model(input),即 DUETModel.forward(input)


§2 I/O 接口定义

python
def forward(self, input) -> Tuple[Tensor, Tensor]
参数shape含义
input(B, L, N) = (3, 16, 7)历史观测值
返回 output(B, pred_len, N) = (3, 5, 7)预测值(已 denorm)
返回 L_importancescalarMoE 负载均衡损失

§3 顺序图(具体层)


§4 语义分组图(索引层)


§5 逐步骤精读

§5.0 完整原始代码

python
class DUETModel(nn.Module):
    def __init__(self, config):
        super(DUETModel, self).__init__()
        self.cluster = Linear_extractor_cluster(config)
        self.CI = config.CI
        self.n_vars = config.enc_in
        self.mask_generator = Mahalanobis_mask(config.seq_len)
        self.Channel_transformer = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(
                            True,
                            config.factor,
                            attention_dropout=config.dropout,
                            output_attention=config.output_attention,
                        ),
                        config.d_model,
                        config.n_heads,
                    ),
                    config.d_model,
                    config.d_ff,
                    dropout=config.dropout,
                    activation=config.activation,
                )
                for _ in range(config.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(config.d_model),
        )
        self.linear_head = nn.Sequential(
            nn.Linear(config.d_model, config.pred_len), nn.Dropout(config.fc_dropout)
        )

    def forward(self, input):
        # x: [batch_size, seq_len, n_vars]
        if self.CI:
            channel_independent_input = rearrange(input, "b l n -> (b n) l 1")
            reshaped_output, L_importance = self.cluster(channel_independent_input)
            temporal_feature = rearrange(
                reshaped_output, "(b n) l 1 -> b l n", b=input.shape[0]
            )
        else:
            temporal_feature, L_importance = self.cluster(input)

        # B x d_model x n_vars -> B x n_vars x d_model
        temporal_feature = rearrange(temporal_feature, "b d n -> b n d")
        if self.n_vars > 1:
            changed_input = rearrange(input, "b l n -> b n l")
            channel_mask = self.mask_generator(changed_input)
            channel_group_feature, attention = self.Channel_transformer(
                x=temporal_feature, attn_mask=channel_mask
            )
            output = self.linear_head(channel_group_feature)
        else:
            output = temporal_feature
            output = self.linear_head(output)

        output = rearrange(output, "b n d -> b d n")
        output = self.cluster.revin(output, "denorm")
        return output, L_importance

§5.1 宏观逻辑

一句话目标:用两条独立路径分别处理"时序模式多样性"和"通道关系稀疏性",最终在通道 Transformer 里融合,输出预测。

整体数据流 SVG

两路的设计动机

路径处理的问题输入输出
MoE 时序路径分布漂移:不同样本用不同专家(B*N, L, 1)temporal_feature (B, N, d_model)
Mahalanobis 通道路径通道异质:稀疏掩码过滤弱相关通道(B, N, L)mask (B, 1, N, N)

为什么先做 MoE 时序再做通道 Transformer?

MoE 的输出 temporal_feature 是"每个变量的时序特征向量"(维度从 L=16 压缩到 d_model=8)。通道 Transformer 以这些特征向量为 token 做变量间注意力——这个顺序是必然的:需要先有变量表示,才能做变量间交互。

用小例子(B=1, N=3, L=4, d_model=2)串起来

输入: (1, 4, 3)  ← 1 个样本,4 个时间步,3 个变量

CI=True:
  rearrange → (3, 4, 1)  ← 3 个变量独立处理
  cluster → (3, 2, 1)    ← 每个变量得到 d_model=2 维特征
  rearrange back → (1, 2, 3)  ← 注意:这里是 (B, d_model, N)

rearrange (b d n → b n d):
  (1, 2, 3) → (1, 3, 2)   ← (B, N, d_model),每行是一个变量的特征

Mahalanobis_mask:
  changed_input (1,4,3) → (1,3,4)
  → mask (1, 1, 3, 3)  ← 3×3 的通道关系矩阵

Channel_transformer:
  x=(1,3,2) + mask=(1,1,3,3)
  → (1, 3, 2)   ← 变量间信息交换

linear_head: (1,3,2) → (1,3,pred_len)
rearrange: (1,3,p) → (1,p,3)
revin denorm: (1,p,3)

§5.2 rearrange 链详解

rearrange 来自 einops 库,是有命名轴的 reshape/transpose 工具。

Step 1 — CI 展开

rearrange(input, "b l n -> (b n) l 1")

轴含义:b=batch, l=time, n=variables

(3, 16, 7)(21, 16, 1):把 3 个样本 × 7 个变量展开成 21 条"独立样本",每条样本只有 1 个变量(1 是 channel=1 的占位维度)。

这样 Linear_extractor_cluster 会以 CI 模式独立处理每条序列,变量间不共享梯度路径。

样本 0, 变量 0  →  行 0
样本 0, 变量 1  →  行 1
...
样本 0, 变量 6  →  行 6
样本 1, 变量 0  →  行 7
...
样本 2, 变量 6  →  行 20

Step 2 — 恢复 temporal_feature

rearrange(reshaped_output, "(b n) l 1 -> b l n", b=input.shape[0])

(21, 8, 1)(3, 8, 7):把 21 条输出重新组织回 B=3 个样本,每个样本 N=7 个变量的特征(d_model=8)。

⚠️ 注释中 shape 注释有误

源码注释写 # B x d_model x n_vars,确实对应 (3, 8, 7)(B×d_model×N)。但变量名叫 temporal_feature,容易让人误以为最后一维是时间步。实际上经过 MoE 后,"l" 轴已经从 seq_len=16 变成了 d_model=8(expert 的输出维度),时间轴被"压缩"成了特征维度。

Step 3 — 转置为变量 token 格式

rearrange(temporal_feature, "b d n -> b n d")

(3, 8, 7)(3, 7, 8):把 (B, d_model, N) 转置为 (B, N, d_model)

这是为了符合 Transformer 的输入格式 (B, seq_len, d_model)——这里"seq_len"实际上是 N=7 个变量(每个变量是一个 token)。

Step 4 — 通道 Transformer 路径的输入

rearrange(input, "b l n -> b n l")

(3, 16, 7)(3, 7, 16):把原始输入转置为 (B, N, L),每个变量是一行长度 L 的序列,送入 Mahalanobis_mask 计算通道相似度。

Step 5 — 输出格式还原

rearrange(output, "b n d -> b d n")

linear_head 输出 (3, 7, 5)rearrange(3, 5, 7):把 N 维和 pred_len 维交换,变成标准的 (B, pred_len, N) 格式。

§5.3 CI 分支

python
if self.CI:
    channel_independent_input = rearrange(input, "b l n -> (b n) l 1")
    reshaped_output, L_importance = self.cluster(channel_independent_input)
    temporal_feature = rearrange(
        reshaped_output, "(b n) l 1 -> b l n", b=input.shape[0]
    )
else:
    temporal_feature, L_importance = self.cluster(input)

CI=True(默认):把 B×N 条序列各自独立处理,每条序列只有 1 个通道(channel=1)。专家学到的是单变量时序规律。

CI=False:直接把 (B, L, N) 传入 cluster,专家看到所有变量(N 个通道)。专家学到的是跨变量联合规律。

模式cluster 输入cluster 输出特性
CI=True(B*N, L, 1)(B*N, d_model, 1)变量独立,参数量小,不学跨变量
CI=False(B, L, N)(B, d_model, N)变量联合,参数量大,能学跨变量

TFB 默认 CI=True

§5.4 n_vars=1 分支

python
if self.n_vars > 1:
    ...  # 完整双路
else:
    output = temporal_feature
    output = self.linear_head(output)

单变量时序(N=1)跳过 Mahalanobis_maskChannel_transformer(无需建模通道关系),直接用 linear_head 输出。

§5.5 RevIN denorm

python
output = self.cluster.revin(output, "denorm")

revinLinear_extractor_cluster 内部的 RevIN 对象(affine=True)。在 MoE 路径的 forward 中先做了 norm,这里统一做 denorm,恢复原始量纲。

RevIN.denorm 步骤(affine=True):

xxbiasweight+ε2xx×σ+μ

其中 σ,μnorm 时保存的实例统计量(.detach() 不参与梯度),weight, bias 是可学习仿射参数 (N,) 形状。

toy 数值(output (3, 5, 7) denorm 后):

revin.stdev[0, :, 0] = [2.0, 1.5, 3.0, ...](第 0 个样本各变量的 std),revin.mean[0, :, 0] = [5.0, -2.0, 8.0, ...]。对 output 中 output[0, :, 0](第 0 个样本第 0 个变量的 5 步预测),denorm 后乘以 2.0 再加 5.0,还原到原始尺度。


§6 下钻子组件

子组件职责下层文档
Linear_extractor_clusterMoE 时序特征提取:RevIN + 门控路由 + 稀疏分发 + 专家线性层[[03A-Layer2A-MoE时序路径]]
Mahalanobis_mask频域通道相似度 → Gumbel Bernoulli 采样 → 0/1 掩码[[03B-Layer2B-MahalanobisMask]]
Channel_transformer (Encoder)带 Mahalanobis 掩码的变量 Transformer[[03C-Layer2C-ChannelTransformer]]

创建:2026-04-24

*记录并在线阅读我的笔记*