Skip to content

DLinear Torch 函数解释

Abstract

这篇不讲模型主线,只解释 DLinear 文档里反复出现的 torch / nn 函数。

1. permute

典型代码:

python
x = x.permute(0, 2, 1)

作用:

  • 重新排列维度顺序,不改数据本身。

在 DLinear 里为什么用:

  • 原始 x(B, seq_len, C)
  • 线性层要作用在最后一维
  • 所以先变成 (B, C, seq_len),让 seq_len 处在最后一维

toy 例子:

text
原始:
(1, 4, 2)
[
  [1, 10],
  [2, 11],
  [3, 12],
  [4, 13],
]

permute(0, 2, 1) 后:
(1, 2, 4)
[
  [1, 2, 3, 4],
  [10,11,12,13],
]

2. repeat

典型代码:

python
front = x[:, 0:1, :].repeat(1, k, 1)

作用:

  • 按指定维度重复张量。

在 DLinear 里为什么用:

  • moving_avg 做首尾补边。

toy 例子:

text
x[:, 0:1, :] = [ [1, 10] ]
repeat(1, 2, 1) 后:
[
  [1, 10],
  [1, 10],
]

3. torch.cat

典型代码:

python
torch.cat([front, x, end], dim=1)

作用:

  • 按指定维度拼接多个张量。

在 DLinear 里为什么用:

  • moving_avg 里拼补边
  • _process(...) 里拼 label_len 历史段和未来零占位

4. AvgPool1d

典型代码:

python
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=0)

作用:

  • 沿长度维做滑动平均。

在 DLinear 里为什么用:

  • 近似提取趋势项 trend

toy 例子:

text
输入: [1, 1, 2, 3, 4, 4]
kernel_size = 3

输出:
[ (1+1+2)/3,
  (1+2+3)/3,
  (2+3+4)/3,
  (3+4+4)/3 ]
=
[4/3, 2, 3, 11/3]

5. nn.Linear

典型代码:

python
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)

作用:

  • 对最后一维做线性映射。

在 DLinear 里为什么用:

  • 把长度为 seq_len 的历史向量直接映射成长度为 pred_len 的未来向量。

toy 例子:

text
输入向量: [-1/3, 0, 0, 1/3]
权重:
[ [1,0,0,0],
 [0,0,0,1] ]

输出:
[-1/3, 1/3]

6. torch.zeros_like

典型代码:

python
dec_input = torch.zeros_like(target[:, -pred_len:, :])

作用:

  • 按已有张量的 shape 和 dtype 造一个全零张量。

在 DLinear 里为什么用:

  • 统一 _process(...) 接口时,给未来部分留零占位。

7. ModuleList

典型代码:

python
self.Linear_Seasonal = nn.ModuleList()

作用:

  • 存多个子模块,让 PyTorch 正确注册参数。

在 DLinear 里为什么用:

  • individual=True 时,每个变量各自一套线性层。

8. nn.Parameter

典型代码:

python
self.Linear_Seasonal.weight = nn.Parameter(
    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)

作用:

  • 把一个 tensor 明确声明为可训练参数。

在 DLinear 里为什么用:

  • 用均匀初始化方式初始化线性头权重。

9. 最短建议

如果你在 DLinear 文档里又看到这些函数:

  • permute
  • repeat
  • cat
  • AvgPool1d
  • Linear

优先先回答两个问题:

  1. 它在改哪一维?
  2. 它在总体里的作用是什么?

*记录并在线阅读我的笔记*