Skip to content

Linear 最后一维规则:DLinear、Informer、PatchTST 的投影层

Abstract

这篇只讲一个基础规则:

nn.Linear(in_features, out_features) 永远作用在输入 tensor 的最后一维,前面的维度都被当作 batch 维保留。

0. 文件索引

项目内容
覆盖函数nn.Linear
覆盖模型DLinear / Informer / PatchTST
输入核心规则(..., in_features) -> (..., out_features)
关键误区Linear 不一定只吃二维矩阵,它可以吃三维、四维,只要最后一维对得上

1. Level 1:最核心规则

定义:

python
layer = nn.Linear(in_features, out_features)

输入:

text
x.shape = (..., in_features)

输出:

text
y.shape = (..., out_features)

前面的 ... 原样保留。

例如:

text
(2, 6, 4) -> Linear(4, 16) -> (2, 6, 16)

含义是:

2 * 6 = 12 个长度为 4 的向量,逐个使用同一个 Linear,投影成长度 16 的向量。

2. Level 2:DLinear 里的 Linear(seq_len, pred_len)

源码位置:

python
# DLinear.py
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

DLinear 的 encoder 里会先做:

python
seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)

shape:

text
seasonal_init: (B, T, C) -> (B, C, T)
toy: (2, 6, 3) -> (2, 3, 6)

然后:

python
seasonal_output = self.Linear_Seasonal(seasonal_init)

如果:

text
seq_len = 6
pred_len = 2

那么:

text
Linear(6, 2)
输入: (B, C, seq_len) = (2, 3, 6)
输出: (B, C, pred_len) = (2, 3, 2)

这里 Linear 作用在最后一维 T=6,不是作用在变量维 C=3

所以 DLinear 的含义是:

对每个变量,把历史长度 seq_len 的序列直接线性映射成未来长度 pred_len

3. Level 3:Informer / PatchTST 注意力里的 Q/K/V 投影

源码位置:

python
# SelfAttention_Family.py
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)

toy:

text
B = 2
L = 6
d_model = 16
n_heads = 2
d_keys = 8

输入:

text
queries.shape = (2, 6, 16)

执行:

python
queries = self.query_projection(queries)

输出:

text
(2, 6, 16) -> Linear(16, 16) -> (2, 6, 16)

然后再:

python
queries = queries.view(B, L, H, -1)

变成:

text
(2, 6, 16) -> (2, 6, 2, 8)

注意:

text
Linear 负责投影
view 负责拆多头

不要把这两步混在一起。

4. Level 4:PatchTST 的 Linear(patch_len, d_model)

源码位置:

python
# Embed.py -> PatchEmbedding
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

PatchTST 先用 unfold 得到 patch:

text
x.shape = (B*C, patch_num, patch_len)
toy: (8, 6, 4)

再执行:

python
x = self.value_embedding(x)

如果:

text
patch_len = 4
d_model = 16

那么:

text
(8, 6, 4) -> Linear(4, 16) -> (8, 6, 16)

含义是:

每个 patch 是一个长度为 patch_len 的局部时间片段,Linear 把它投影成一个 d_model 维 token。

5. Level 5:PatchTST 的 FlattenHead

源码:

python
class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

toy:

text
x.shape = (B, C, d_model, patch_num) = (2, 4, 16, 6)

Flatten(start_dim=-2)

text
(2, 4, 16, 6) -> (2, 4, 96)

因为:

text
16 * 6 = 96 = nf

然后:

text
Linear(nf=96, target_window=24)
(2, 4, 96) -> (2, 4, 24)

最后父层再:

python
dec_out = dec_out.permute(0, 2, 1)

得到:

text
(2, 4, 24) -> (2, 24, 4)

6. Level 6:可算小例子

假设:

python
linear = nn.Linear(3, 2, bias=False)

权重:

text
weight =
[
  [1, 0, 0],
  [0, 1, 1],
]

输入一个向量:

text
x = [10, 20, 30]

输出:

text
y0 = 1*10 + 0*20 + 0*30 = 10
y1 = 0*10 + 1*20 + 1*30 = 50

y = [10, 50]

如果输入是:

text
x.shape = (2, 6, 3)

那么这个 Linear 会对 2*6=12 个长度为 3 的向量分别做同样的计算,输出:

text
y.shape = (2, 6, 2)

7. 常见错误

7.1 以为 Linear 只能处理二维

错误理解:

text
Linear 只能吃 (batch, feature)

正确理解:

text
Linear 可以吃 (..., feature)

只要最后一维等于 in_features

7.2 在 DLinear 里忘了先 permute

DLinear 想让 Linear 从历史时间长度映射到预测长度。

所以要先:

text
(B, T, C) -> (B, C, T)

让最后一维变成 T=seq_len

如果不 permute,最后一维是 C,Linear 就会错误地在变量维上投影。

8. 一句话总结

nn.Linear 的统一理解是:

前面的维度都保留,只把最后一维从 in_features 投影成 out_features

*记录并在线阅读我的笔记*