Level4B: A Close Reading of the Encoder

Abstract

Covers: Transformer_EncDec.py:Encoder / EncoderLayer + SelfAttention_Family.py:AttentionLayer / FullAttention + PatchTST.py:FlattenHead

Input (Encoder): (B×enc_in, patch_num, d_model) = (6, 4, 8)
Output (after FlattenHead and the final permute): (B, pred_len, enc_in) = (2, 7, 3)


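Component sequence diagram

Component hierarchy tree

How to read the diagrams: Encoder is the dispatch layer, EncoderLayer is the computation unit of each layer, AttentionLayer handles splitting into heads and merging them back, and FullAttention performs the actual QKV matrix operations. These four levels are nested, and this note reads them from the outside in. FlattenHead is a separate prediction head that sits outside the Encoder.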


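Principle → code mapping

| Paper step | Corresponding code | File:line | Notes |
| --- | --- | --- | --- |
| Multi-Head Self-Attention | AttentionLayer + FullAttention | SelfAttention_Family.py:198/58 | AttentionLayer splits and merges heads; FullAttention computes the scores |
| Q/K/V linear projection | self.query_projection(queries) | SelfAttention_Family.py:211 | Three independent Linear(d_model, d_k×H) |
| Multi-head split | .view(B, L, H, d_k) | SelfAttention_Family.py:218 | view cuts d_model=8 into H=2 × d_k=4 |
| Attention scores | einsum("blhe,bshe->bhls", Q, K) | SelfAttention_Family.py:85 | Equivalent to Q×Kᵀ per head, but einsum computes all heads at once |
| softmax + weighting | einsum("bhls,bshd->blhd", A, V) | SelfAttention_Family.py:94 | Weighted sum of V using the probability weights |
| FFN | Two Conv1d | Transformer_EncDec.py:35 | Conv1d(kernel=1) = an MLP applied independently at each position, equivalent to Linear |
| Prediction head | FlattenHead | PatchTST.py:9 | Flatten (the patch and d_model dims) + Linear→pred_len |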

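The two places where readers most often get stuck

Sticking point 1: why can view + transpose "split into heads"?

Original Q: (6, 4, 8) = (B×enc_in, patch_num, d_model)

.view(6, 4, 2, 4):
  → (6, 4, H=2, d_k=4)
  splits the 8-dim vector at each time step into 2 "heads" of 4 dims each

.transpose(1, 2):
  → (6, 2, 4, 4) = (B, H, L, d_k)
  moves the head dimension forward so each head can compute attention independently

Note: the AttentionLayer code quoted in this note only calls .view and never calls .transpose. The einsum subscripts in FullAttention ("blhe,bshe->bhls") do the reordering to (B, H, L, S) implicitly; the explicit transpose is the form many other implementations use.

Intuition: each head "looks at" a different 4-dim subspace, the heads run in parallel, and their results are concatenated back together.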
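
To make the .view step concrete, here is a minimal plain-Python sketch (no PyTorch; the helper name `split_heads` is made up for this illustration). It assumes the vector is stored contiguously, in which case reshaping d_model=8 into (H=2, d_k=4) simply slices it into consecutive chunks.

```python
def split_heads(vec, n_heads):
    """Mimic .view(H, d_k) on one contiguous vector: consecutive chunks."""
    d_k = len(vec) // n_heads
    assert d_k * n_heads == len(vec), "d_model must be divisible by n_heads"
    return [vec[h * d_k:(h + 1) * d_k] for h in range(n_heads)]


# One patch embedding with d_model=8 (arbitrary toy numbers).
q = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0, 0.7]
heads = split_heads(q, n_heads=2)
print(heads[0])  # [0.5, 1.2, -0.3, 0.8]  -> head 0 sees dims 0..3
print(heads[1])  # [1.5, -1.0, 0.0, 0.7]  -> head 1 sees dims 4..7
```

No data is moved or mixed here: head h is just dims h×d_k through (h+1)×d_k−1 of the projected vector.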

Sticking point 2: what does einsum("blhe,bshe->bhls", Q, K) mean?

Q: (B, L_Q, H, d_k) = (6, 4, 2, 4)    ← b=batch, l=query position, h=head, e=d_k
K: (B, L_K, H, d_k) = (6, 4, 2, 4)    ← b=batch, s=key position,   h=head, e=d_k

einsum spec: sum over the e dimension (a dot product); keep b, l, h, s
Output: (B, H, L_Q, L_K) = (6, 2, 4, 4)  ← per head, the score of each Q against each K

Equivalent logic: for each b,h,l,s: scores[b,h,l,s] = sum(Q[b,l,h,:] * K[b,s,h,:])
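
The loop form above can be written out directly. The following is a plain-Python sketch on nested lists (no PyTorch; `scores_loop` and `rand_tensor` are hypothetical helper names), meant only to show which indices are kept and which one is summed out.

```python
import random


def scores_loop(Q, K):
    """scores[b][h][l][s] = sum over e of Q[b][l][h][e] * K[b][s][h][e]."""
    B, L, H, E = len(Q), len(Q[0]), len(Q[0][0]), len(Q[0][0][0])
    S = len(K[0])
    return [[[[sum(Q[b][l][h][e] * K[b][s][h][e] for e in range(E))
               for s in range(S)]
              for l in range(L)]
             for h in range(H)]
            for b in range(B)]


def rand_tensor(B, L, H, E, rng):
    """Nested lists with layout (B, L, H, E), filled with toy values."""
    return [[[[rng.uniform(-1, 1) for _ in range(E)] for _ in range(H)]
             for _ in range(L)] for _ in range(B)]


rng = random.Random(0)
Q = rand_tensor(6, 4, 2, 4, rng)
K = rand_tensor(6, 4, 2, 4, rng)
scores = scores_loop(Q, K)
shape = (len(scores), len(scores[0]), len(scores[0][0]), len(scores[0][0][0]))
print(shape)  # (6, 2, 4, 4)
```

The inputs are indexed as (b, l, h, e) and (b, s, h, e), but the result is indexed as (b, h, l, s): the output subscript string alone decides the axis order.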

Toy parameters (used throughout this note)

B=2, enc_in=3, d_model=8, patch_num=4
n_heads=2, d_k=d_model//n_heads=4, d_ff=16 (toy value; real scripts use a larger one)
pred_len=7, head_nf=d_model×patch_num=32
B×enc_in=6
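
As a quick sanity check, the derived quantities can be recomputed from the base parameters. This is a small sketch; the variable names `virtual_batch`, `encoder_in_shape` and `final_out_shape` are chosen here for illustration and do not come from the source code.

```python
B, enc_in, d_model, patch_num = 2, 3, 8, 4
n_heads, d_ff, pred_len = 2, 16, 7

d_k = d_model // n_heads        # width of each head
head_nf = d_model * patch_num   # input width of FlattenHead's Linear
virtual_batch = B * enc_in      # channels folded into the batch axis

encoder_in_shape = (virtual_batch, patch_num, d_model)
final_out_shape = (B, pred_len, enc_in)
print(d_k, head_nf, virtual_batch)        # 4 32 6
print(encoder_in_shape, final_out_shape)  # (6, 4, 8) (2, 7, 3)
```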

I. Encoder entry point

1. Original code

python
# Transformer_EncDec.py:52
class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        ...

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        attns = []
        if self.conv_layers is not None:
            ...  # distillation path; PatchTST does not take it
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

2. Annotated version

python
def forward(self, x, attn_mask=None, tau=None, delta=None):
    # x: (6, 4, 8)  ← (B×enc_in, patch_num, d_model)
    attns = []
    # PatchTST is constructed with conv_layers=None → takes the else branch (no distillation)
    for attn_layer in self.attn_layers:
        # the shape is unchanged after each EncoderLayer
        x, attn = attn_layer(x, ...)
        # x: (6, 4, 8)
        attns.append(attn)  # attn: None (when output_attention=False)

    x = self.norm(x)  # LayerNorm(d_model=8), shape unchanged: (6, 4, 8)
    return x, attns   # (6, 4, 8)

PatchTST defaults to e_layers=1 (a single-layer Encoder). With more layers the loop above simply repeats, and the shape is unchanged each time.


II. EncoderLayer

1. Original code

python
# Transformer_EncDec.py:29
class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()  # must run before submodules are assigned on an nn.Module
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn

2. Annotated version

python
def forward(self, x, attn_mask=None, tau=None, delta=None):
    # x: (6, 4, 8)  ← (B×enc_in, patch_num, d_model)

    # ── Self-Attention + residual ────────────────────────────
    new_x, attn = self.attention(x, x, x, ...)
    # new_x: (6, 4, 8); Q/K/V all come from x (self-attention)
    x = x + self.dropout(new_x)
    # x: (6, 4, 8), after adding the residual

    # ── FFN (implemented with Conv1d) + residual ─────────────
    y = x = self.norm1(x)
    # Note: x and y are assigned together; x feeds the residual connection later, y goes through the FFN
    # x: (6, 4, 8), y: (6, 4, 8)

    y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
    # y.transpose(-1, 1): (6, 4, 8) → (6, 8, 4)  ← Conv1d expects the (B, C, L) layout
    # conv1: Conv1d(d_model=8, d_ff=16, kernel=1) → (6, 16, 4)
    # activation (relu/gelu) + dropout → (6, 16, 4)

    y = self.dropout(self.conv2(y).transpose(-1, 1))
    # conv2: Conv1d(d_ff=16, d_model=8, kernel=1) → (6, 8, 4)
    # .transpose(-1, 1): (6, 8, 4) → (6, 4, 8)  ← back to the (B, L, C) layout
    # dropout → (6, 4, 8)

    return self.norm2(x + y), attn
    # x + y: (6, 4, 8), residual connection
    # norm2: LayerNorm(8), shape unchanged
    # output: (6, 4, 8)

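3. Why the FFN uses Conv1d instead of Linear

A standard Transformer FFN uses two nn.Linear layers: Linear(d_model, d_ff) + Linear(d_ff, d_model)

Here they are replaced by Conv1d(kernel=1), which is mathematically equivalent: a Conv1d with kernel_size=1 applies a linear transform to each position independently, which is the same as a Linear whose weights are shared along the sequence-length dimension.

The difference lies in the dimension convention of the two APIs:

  • nn.Linear: acts on the last dimension; input is (B, L, d_model)
  • nn.Conv1d: acts on the middle dimension; input must be (B, C, L), i.e. C (channels) is the second dimension

So before the Conv1d you must first call transpose(-1, 1) to turn (B, L, C) into (B, C, L), and afterwards call transpose(-1, 1) again to turn it back.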
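
The equivalence can be checked from the definitions alone. Below is a plain-Python sketch (no PyTorch; `conv1d_k1`, `linear` and `transpose` are hypothetical helpers written for this note). It assumes both layers share the same weight matrix W of shape (C_out, C_in) and bias b; a real nn.Conv1d stores an extra trailing kernel axis of size 1, which is dropped here.

```python
import random


def conv1d_k1(x_cl, W, b):
    """Kernel-size-1 convolution on one sample laid out as (C_in, L)."""
    C_in, L = len(x_cl), len(x_cl[0])
    return [[b[o] + sum(W[o][c] * x_cl[c][l] for c in range(C_in))
             for l in range(L)]
            for o in range(len(W))]          # result: (C_out, L)


def linear(x_lc, W, b):
    """Linear layer on one sample laid out as (L, C_in)."""
    return [[b[o] + sum(W[o][c] * row[c] for c in range(len(row)))
             for o in range(len(W))]
            for row in x_lc]                 # result: (L, C_out)


def transpose(m):
    """Swap the two axes of a 2-D nested list."""
    return [list(col) for col in zip(*m)]


rng = random.Random(0)
L, d_model, d_ff = 4, 8, 16
x = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(L)]
W = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_ff)]
b = [rng.uniform(-1, 1) for _ in range(d_ff)]

# (L, C) -> transpose -> (C, L) -> conv -> (d_ff, L) -> transpose -> (L, d_ff)
via_conv = transpose(conv1d_k1(transpose(x), W, b))
via_linear = linear(x, W, b)
max_diff = max(abs(p - q) for r1, r2 in zip(via_conv, via_linear)
               for p, q in zip(r1, r2))
print(len(via_conv), len(via_conv[0]), max_diff < 1e-12)  # 4 16 True
```

The two transposes around the convolution are exactly the bookkeeping the EncoderLayer performs; the arithmetic per position is identical.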

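Toy numeric trace (FFN part)

Input y: (6, 4, 8)  e.g. y[0, 0, :] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0, 0.7]

y.transpose(-1, 1): (6, 8, 4)
  y[0, :, 0] = [0.5, 1.2, -0.3, 0.8, 1.5, -1.0, 0.0, 0.7]  ← the original d_model axis
  y[0, :, 1] = the original y[0, 1, :]  ← the values of the patch at index 1
  ...

conv1: kernel_size=1, equivalent to applying Linear(8→16) at each time position
  conv1(y.transpose)[0, :, 0]: 8 inputs → 16 outputs (the transform of one position)

Result (6, 16, 4) → conv2 (16→8) → (6, 8, 4) → transpose back to (6, 4, 8)  [since d_ff=16 ≠ d_model=8, the intermediate width is visible at the (6, 16, 4) stage]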

III. AttentionLayer

1. Original code

python
# SelfAttention_Family.py:198
class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()  # must run before submodules are assigned on an nn.Module
        d_keys = d_keys or (d_model // n_heads)   # = 8 // 2 = 4
        d_values = d_values or (d_model // n_heads) # = 4
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)    # Linear(8, 8)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)      # Linear(8, 8)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)  # Linear(8, 8)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)    # Linear(8, 8)
        self.n_heads = n_heads  # = 2

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(queries, keys, values, attn_mask, ...)
        out = out.view(B, L, -1)

        return self.out_projection(out), attn

2. Annotated version

python
def forward(self, queries, keys, values, attn_mask, ...):
    # PatchTST self-attention: queries=keys=values=x (passed in from EncoderLayer)
    # x: (6, 4, 8)  ← (B×enc_in, patch_num, d_model)
    B, L, _ = queries.shape  # B=6, L=4 (patch_num)
    _, S, _ = keys.shape     # S=4 (S=L for self-attention)
    H = self.n_heads         # H=2

    # Q/K/V projection + split into heads
    queries = self.query_projection(queries).view(B, L, H, -1)
    # query_projection: Linear(8,8), output (6,4,8)
    # .view(6, 4, 2, -1): -1 = 8//2 = 4
    # queries: (6, 4, 2, 4)  ← (B×enc_in, patch_num, n_heads, d_k)

    keys    = self.key_projection(keys).view(B, S, H, -1)
    # keys: (6, 4, 2, 4)

    values  = self.value_projection(values).view(B, S, H, -1)
    # values: (6, 4, 2, 4)

    out, attn = self.inner_attention(queries, keys, values, attn_mask, ...)
    # out: (6, 4, 2, 4)  ← output of FullAttention (see the next section)

    out = out.view(B, L, -1)
    # out.view(6, 4, -1): -1 = n_heads × d_k = 2×4 = 8
    # out: (6, 4, 8)  ← heads merged

    return self.out_projection(out), attn
    # out_projection: Linear(8, 8)
    # output: (6, 4, 8)

IV. FullAttention

1. Original code

python
# SelfAttention_Family.py:58
class FullAttention(nn.Module):
    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape   # B=6, L=4, H=2, E=4 (d_k)
        _, S, _, D = values.shape    # S=4, D=4 (d_v)
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:  # mask_flag=False in PatchTST, so this is skipped
            ...

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        return V.contiguous(), None  # output_attention=False

2. Annotated version

python
def forward(self, queries, keys, values, attn_mask, ...):
    # queries: (6, 4, 2, 4)  ← (B×enc_in, patch_num, n_heads, d_k)
    # keys:    (6, 4, 2, 4)
    # values:  (6, 4, 2, 4)
    B, L, H, E = queries.shape  # 6, 4, 2, 4
    scale = 1.0 / sqrt(E)       # 1/sqrt(4) = 0.5

    # ── Compute attention scores ─────────────────────────────
    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    # Reading the einsum subscripts:
    #   b=batch(6), l=query position(4), h=head(2), e=d_k(4)
    #   b=batch(6), s=key position(4),   h=head(2), e=d_k(4)
    #   → bhls: sum over e (d_k); result shape (B, H, L, S)
    # scores: (6, 2, 4, 4)  ← (B×enc_in, n_heads, patch_num, patch_num)

    # mask_flag=False (PatchTST applies no mask; every patch can attend to every other patch)

    # ── softmax normalization ────────────────────────────────
    A = torch.softmax(scale * scores, dim=-1)  # the original code also applies dropout here
    # scale * scores: (6, 2, 4, 4)
    # softmax(dim=-1): softmax over the last dimension (key position s)
    # A: (6, 2, 4, 4)  ← for each query position, the weights over the 4 keys sum to 1

    # ── Weighted aggregation of values ───────────────────────
    V = torch.einsum("bhls,bshd->blhd", A, values)
    # Reading the einsum subscripts:
    #   b=batch(6), h=head(2), l=query position(4), s=key position(4)
    #   b=batch(6), s=key position(4), h=head(2), d=d_v(4)
    #   → blhd: sum over s (key position); result shape (B, L, H, D)
    # V: (6, 4, 2, 4)  ← (B×enc_in, patch_num, n_heads, d_v)

    return V.contiguous(), None
    # V: (6, 4, 2, 4)

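3. A deeper look at the einsum subscripts

scores = torch.einsum("blhe,bshe->bhls", queries, keys)

Letter by letter:

| Letter | Meaning | Size | In the output |
| --- | --- | --- | --- |
| b | batch (includes channels) | 6 | ✅ kept |
| l | time position of the query (patch index) | 4 | ✅ kept |
| h | attention head | 2 | ✅ kept |
| e | d_k (key dimension of each head) | 4 | ❌ summed out |
| s | time position of the key (patch index) | 4 | ✅ kept |

The dot product of queries[b,l,h,:] and keys[b,s,h,:] (a sum over e) gives the similarity score of the l-th query patch against the s-th key patch in head h.

Equivalent matrix form (with b and h fixed):

scores[b, h] = queries[b, :, h, :] @ keys[b, :, h, :].T
             = (4, 4) @ (4, 4).T = (4, 4)

V = torch.einsum("bhls,bshd->blhd", A, values)

A[b,h,l,s] is the attention weight of the l-th query patch on the s-th key patch (already passed through softmax, so the weights over s sum to 1).

Here the sum runs over s (the key position): the 4 value patches are combined as a weighted sum using the attention weights, which gives the output for the l-th query patch.

Equivalent matrix form (with b and h fixed):

V[b, :, h, :] = A[b, h, :, :] @ values[b, :, h, :]
              = (4, 4) @ (4, 4) = (4, 4)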
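
For a single (b, h) pair the second einsum is an ordinary weighted sum of value rows. A tiny plain-Python sketch follows (2 positions and d_v=2 instead of the 4×4 toy sizes, purely to keep the numbers readable; `weighted_values` is a hypothetical helper name).

```python
def weighted_values(A, values):
    """out[l][d] = sum over s of A[l][s] * values[s][d], for one (b, h) pair."""
    S, D = len(values), len(values[0])
    return [[sum(A[l][s] * values[s][d] for s in range(S)) for d in range(D)]
            for l in range(len(A))]


A = [[1.0, 0.0],       # query 0 attends only to key 0
     [0.5, 0.5]]       # query 1 attends equally to both keys
values = [[2.0, 4.0],
          [6.0, 8.0]]

out = weighted_values(A, values)
print(out)  # [[2.0, 4.0], [4.0, 6.0]]
```

Query 0 copies value row 0 unchanged, while query 1 receives the average of the two value rows.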

4. Toy numeric trace (single-head simplification)

Fix b=0, h=0 (head 0 of virtual batch element 0):

queries[0, :, 0, :]:  4 patches × d_k=4 → matrix Q (4×4)
keys[0, :, 0, :]:     4 patches × d_k=4 → matrix K (4×4)

Let (simplified):
Q = [ [1,0,0,0],
     [0,1,0,0],
     [0,0,1,0],
     [0,0,0,1] ]

K = [ [1,0,0,0],
     [0,1,0,0],
     [0,0,1,0],
     [0,0,0,1] ]

scores = Q @ K.T = I (4×4), the identity matrix
scale = 0.5
scaled_scores = 0.5 × I

softmax(dim=-1) applies softmax to each row:
  each row has 0.5 on the diagonal and 0 elsewhere, e.g. row 0 = [0.5, 0, 0, 0]
  softmax([0.5, 0, 0, 0]) = [e^0.5/(e^0.5+3), 1/(e^0.5+3), ...]
                           ≈ [0.355, 0.215, 0.215, 0.215]

→ every row of A[0, 0, :, :] has 0.355 on the diagonal and 0.215 elsewhere (because Q=K=I)

Interpretation: each patch mainly attends to itself (weight 0.355) and attends a little to each of the other patches (0.215 each)
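
The softmax of 0.5 × I can be reproduced with the standard library alone. A minimal sketch, assuming the same identity Q and K as in the trace (the `softmax` helper is written here and is not taken from the source code); running the numbers gives about 0.355 on the diagonal and about 0.215 elsewhere.

```python
import math


def softmax(row):
    """Numerically stable softmax of a 1-D list."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


d_k = 4
scale = 1.0 / math.sqrt(d_k)   # 0.5
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
A = [softmax([scale * v for v in row]) for row in identity]
for row in A:
    print([round(w, 3) for w in row])
# [0.355, 0.215, 0.215, 0.215]
# [0.215, 0.355, 0.215, 0.215]
# [0.215, 0.215, 0.355, 0.215]
# [0.215, 0.215, 0.215, 0.355]
```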

V. FlattenHead

This part comes right after the Encoder output and corresponds to steps 5-8 of forecast().

1. Original code

python
# PatchTST.py:9
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

2. Annotated version

python
def forward(self, x):
    # Before this function is entered, steps 5-6 of forecast() have already done the reshape and permute:
    # x: (2, 3, 8, 4)  ← (B, enc_in, d_model, patch_num)

    x = self.flatten(x)
    # nn.Flatten(start_dim=-2): merges the last two dimensions
    # (2, 3, 8, 4) → (2, 3, 32)  ← head_nf = d_model × patch_num = 8×4 = 32

    x = self.linear(x)
    # nn.Linear(nf=32, target_window=7)
    # acts on the last dimension: 32 → 7; the leading (2, 3) are batch dimensions
    # (2, 3, 32) → (2, 3, 7)  ← pred_len = 7

    x = self.dropout(x)
    return x  # (2, 3, 7)

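3. Semantics of nn.Flatten(start_dim=-2)

nn.Flatten(start_dim, end_dim) merges all dimensions from start_dim to end_dim (by default, through the last dimension) into a single dimension.

start_dim=-2 means merging starts at the second-to-last dimension:

Input: (2, 3, 8, 4)
  dim: 0=B, 1=enc_in, 2=d_model(-2), 3=patch_num(-1)
                              ↑ start_dim=-2
Merge dim[-2] and dim[-1]: 8×4 = 32
Output: (2, 3, 32)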
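
The shape rule can be expressed as a small function. This is a plain-Python sketch of the shape arithmetic only (the helper `flatten_shape` is hypothetical and does not touch any data); it follows nn.Flatten's convention that negative dims count from the end and end_dim is inclusive.

```python
def flatten_shape(shape, start_dim, end_dim=-1):
    """Shape obtained by merging dims start_dim..end_dim (inclusive) into one."""
    n = len(shape)
    start, end = start_dim % n, end_dim % n   # resolve negative indices
    merged = 1
    for size in shape[start:end + 1]:
        merged *= size
    return shape[:start] + (merged,) + shape[end + 1:]


print(flatten_shape((2, 3, 8, 4), start_dim=-2))  # (2, 3, 32)
print(flatten_shape((2, 3, 8, 4), start_dim=1))   # (2, 96)
```

The second call shows nn.Flatten's default start_dim=1, which would also fold enc_in into the merged axis; using start_dim=-2 keeps each channel separate.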

4. Toy numeric trace

x[0, 0, :, :]:  an (8, 4) matrix → flattened into a vector of length 32
  row 0: [v00, v01, v02, v03]
  row 1: [v10, v11, v12, v13]
  ...
  row 7: [v70, v71, v72, v73]

After flattening: [v00, v01, v02, v03, v10, v11, ..., v70, v71, v72, v73]  (length 32)

nn.Linear(32, 7): let W be the (7, 32) weight matrix
  output[k] = sum_j(W[k,j] × input[j])  for k=0..6

This gives x[0, 0, :]: a vector of length 7 = the pred_len=7 step forecast for channel 0 (in normalized space)
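
The flattening order determines which weight column of the Linear multiplies which cell. A short plain-Python sketch of the row-major order, using string labels in place of numbers (assuming a contiguous (d_model, patch_num) layout as in the trace):

```python
d_model, patch_num = 8, 4

# x[i][p]: feature i of patch p for one (batch, channel) pair, labelled "v{i}{p}".
x = [[f"v{i}{p}" for p in range(patch_num)] for i in range(d_model)]
flat = [cell for row in x for cell in row]   # row-major flatten
print(flat[:5], len(flat))  # ['v00', 'v01', 'v02', 'v03', 'v10'] 32

# Flat index j maps back to (i, p) = divmod(j, patch_num).
j = 13
i, p = divmod(j, patch_num)
print(i, p, flat[j] == x[i][p])  # 3 1 True
```

So column j of W always weights feature j // patch_num of patch j % patch_num, which is how the head can treat different patches differently.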

VI. forecast(): what follows from step 8

After FlattenHead outputs (2, 3, 7), forecast() has two more steps:

Step 8: permute(0, 2, 1)

python
dec_out = dec_out.permute(0, 2, 1)
# (2, 3, 7) → (2, 7, 3)
# turns (B, enc_in, pred_len) back into (B, pred_len, enc_in)
# restores the standard output format of the TFB framework

Step 9: Denormalize (for details see 03-Level3-forward主链 §3.9)

python
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)
# (2, 7, 3); values are mapped from normalized space back to the original scale
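
The two steps can be traced by hand on a tiny example. This plain-Python sketch uses nested lists with B=1, enc_in=2, pred_len=3 and made-up statistics; `permute_021` and `denormalize` are hypothetical helpers, and means/stdev are stored as (B, C) here rather than the (B, 1, C) tensors the source indexes with [:, 0, :].

```python
def permute_021(t):
    """(B, C, T) nested lists -> (B, T, C), like permute(0, 2, 1)."""
    return [[list(step) for step in zip(*sample)] for sample in t]


def denormalize(t, means, stdev):
    """t: (B, T, C); means/stdev: (B, C), broadcast along the T axis."""
    return [[[v * stdev[b][c] + means[b][c] for c, v in enumerate(step)]
             for step in sample]
            for b, sample in enumerate(t)]


# Head output in normalized space, layout (B, enc_in, pred_len).
head_out = [[[0.0, 1.0, -1.0],    # channel 0
             [0.5, 0.5, 0.5]]]    # channel 1
means = [[10.0, 100.0]]
stdev = [[2.0, 4.0]]

out = denormalize(permute_021(head_out), means, stdev)
print(out)  # [[[10.0, 102.0], [12.0, 102.0], [8.0, 102.0]]]
```

Each channel is rescaled with its own mean and standard deviation, and the same pair is reused for every forecast step, which is what the unsqueeze(1).repeat(...) in the original code achieves.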

VII. Full shape chain summary

Encoder input: (6, 4, 8)   ← the output of 04A
  ↓ EncoderLayer × e_layers
    ├── Self-Attention       (6,4,8) → QKV (6,4,2,4) → scores (6,2,4,4) → out (6,4,8)
    └── FFN (Conv1d)         (6,4,8) → transpose (6,8,4) → conv1 (6,16,4) → conv2 (6,8,4) → transpose (6,4,8)
  ↓ LayerNorm
Encoder 输出:  (6, 4, 8)

  ↓ reshape    (6,4,8) → (2,3,4,8)
  ↓ permute    (2,3,4,8) → (2,3,8,4)
  ↓ FlattenHead
    ├── Flatten(start_dim=-2)  (2,3,8,4) → (2,3,32)
    └── Linear(32,7)           (2,3,32) → (2,3,7)
  ↓ permute    (2,3,7) → (2,7,3)
  ↓ Denormalize

Final output: (2, 7, 3)  = (B, pred_len, enc_in)
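
The same chain can be written as a small shape calculator, which makes it easy to re-derive every stage for other hyperparameters. This is a sketch of shape arithmetic only (the function `shape_chain` and its stage names are invented for this note; no tensors are created).

```python
def shape_chain(B, enc_in, patch_num, d_model, n_heads, d_ff, pred_len):
    """Return (stage name, shape) pairs for one EncoderLayer plus the head."""
    N, d_k = B * enc_in, d_model // n_heads
    return [
        ("encoder_in", (N, patch_num, d_model)),
        ("qkv", (N, patch_num, n_heads, d_k)),
        ("scores", (N, n_heads, patch_num, patch_num)),
        ("attn_out", (N, patch_num, d_model)),
        ("conv1", (N, d_ff, patch_num)),
        ("conv2", (N, d_model, patch_num)),
        ("encoder_out", (N, patch_num, d_model)),
        ("reshape", (B, enc_in, patch_num, d_model)),
        ("permute", (B, enc_in, d_model, patch_num)),
        ("flatten", (B, enc_in, d_model * patch_num)),
        ("linear", (B, enc_in, pred_len)),
        ("final", (B, pred_len, enc_in)),
    ]


chain = shape_chain(B=2, enc_in=3, patch_num=4, d_model=8,
                    n_heads=2, d_ff=16, pred_len=7)
for name, shape in chain:
    print(f"{name:12s}{shape}")
```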
