Appearance
DLinear Torch 函数解释
Abstract
这篇不讲模型主线,只解释 DLinear 文档里反复出现的 torch / nn 函数。
1. permute
典型代码:
python
x = x.permute(0, 2, 1)作用:
- 重新排列维度顺序,不改数据本身。
在 DLinear 里为什么用:
- 原始
x是(B, seq_len, C) - 线性层要作用在最后一维
- 所以先变成
(B, C, seq_len),让seq_len处在最后一维
toy 例子:
text
原始:
(1, 4, 2)
[
[1, 10],
[2, 11],
[3, 12],
[4, 13],
]
permute(0, 2, 1) 后:
(1, 2, 4)
[
[1, 2, 3, 4],
[10,11,12,13],
]2. repeat
典型代码:
python
front = x[:, 0:1, :].repeat(1, k, 1)作用:
- 按指定维度重复张量。
在 DLinear 里为什么用:
- 为
moving_avg做首尾补边。
toy 例子:
text
x[:, 0:1, :] = [ [1, 10] ]
repeat(1, 2, 1) 后:
[
[1, 10],
[1, 10],
]3. torch.cat
典型代码:
python
torch.cat([front, x, end], dim=1)作用:
- 按指定维度拼接多个张量。
在 DLinear 里为什么用:
moving_avg里拼补边_process(...)里拼label_len历史段和未来零占位
4. AvgPool1d
典型代码:
python
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=0)作用:
- 沿长度维做滑动平均。
在 DLinear 里为什么用:
- 近似提取趋势项
trend
toy 例子:
text
输入: [1, 1, 2, 3, 4, 4]
kernel_size = 3
输出:
[ (1+1+2)/3,
(1+2+3)/3,
(2+3+4)/3,
(3+4+4)/3 ]
=
[4/3, 2, 3, 11/3]5. nn.Linear
典型代码:
python
self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)作用:
- 对最后一维做线性映射。
在 DLinear 里为什么用:
- 把长度为
seq_len的历史向量直接映射成长度为pred_len的未来向量。
toy 例子:
text
输入向量: [-1/3, 0, 0, 1/3]
权重:
[ [1,0,0,0],
[0,0,0,1] ]
输出:
[-1/3, 1/3]6. torch.zeros_like
典型代码:
python
dec_input = torch.zeros_like(target[:, -pred_len:, :])作用:
- 按已有张量的 shape 和 dtype 造一个全零张量。
在 DLinear 里为什么用:
- 统一
_process(...)接口时,给未来部分留零占位。
7. ModuleList
典型代码:
python
self.Linear_Seasonal = nn.ModuleList()作用:
- 存多个子模块,让 PyTorch 正确注册参数。
在 DLinear 里为什么用:
- 当
individual=True时,每个变量各自一套线性层。
8. nn.Parameter
典型代码:
python
self.Linear_Seasonal.weight = nn.Parameter(
(1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
)作用:
- 把一个 tensor 明确声明为可训练参数。
在 DLinear 里为什么用:
- 用均匀初始化方式初始化线性头权重。
9. 最短建议
如果你在 DLinear 文档里又看到这些函数:
permuterepeatcatAvgPool1dLinear
优先先回答两个问题:
- 它在改哪一维?
- 它在总体里的作用是什么?