Skip to content

4A series_decomp 与 moving_avg

Abstract

这一篇是 04-Level4-encoder分解与线性预测总览 里的 4A 子块。

它只讲一件事:

怎样通过滑动平均,把输入序列拆成 trendseasonal

1. 当前层第一性

这一层存在的第一性是:

先把序列里变化慢的部分提出来,再把剩下的快变化部分单独留下。

在 DLinear 里:

  • trend
    • 更像“平滑背景”
  • seasonal
    • 更像“残差 / 快变化项”

2. 上下文

父节点:

下一层:

当前入口接口:

python
seasonal_init, trend_init = self.decompsition(x)

当前出口接口:

python
seasonal_init.shape = trend_init.shape = (B, seq_len, C)

3. 顺序图

4. 抽象树

5. 完整代码

位置:

python
class moving_avg(nn.Module):
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x

class series_decomp(nn.Module):
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

6. 中文注释版完整代码

python
class moving_avg(nn.Module):
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # 为了保持长度不变,先在前后各补一段边界值
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)

        # AvgPool1d 要吃 [B, C, L],所以先 permute
        x = self.avg(x.permute(0, 2, 1))

        # 再 permute 回 [B, L, C]
        x = x.permute(0, 2, 1)
        return x

class series_decomp(nn.Module):
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

7. 固定 toy 例子

固定:

  • kernel_size = 3
  • stride = 1
  • B = 1
  • seq_len = 4
  • C = 2

toy 输入:

text
x =
[
  [1, 10],
  [2, 11],
  [3, 12],
  [4, 13],
]

8. 代码块 1:前后补边

代码:

python
front = x[:, 0:1, :].repeat(1, 1, 1)
end = x[:, -1:, :].repeat(1, 1, 1)
x = torch.cat([front, x, end], dim=1)

8.1 输入/输出语义

输入:

  • 原始序列 x: (1, 4, 2)

输出:

  • 补边后的序列 x_pad: (1, 6, 2)

为什么要这么做:

  • AvgPool1d(kernel_size=3) 会看局部窗口
  • 不补边的话,首尾两端无法对齐长度

8.2 toy 张量逐步演变

原始:

text
[
  [1, 10],
  [2, 11],
  [3, 12],
  [4, 13],
]

补边后:

text
[
  [1, 10],
  [1, 10],
  [2, 11],
  [3, 12],
  [4, 13],
  [4, 13],
]

9. 代码块 2:AvgPool1d 做滑动平均

代码:

python
x = self.avg(x.permute(0, 2, 1))

9.1 输入/输出语义

输入:

  • x_pad: (B, seq_len + 2, C)

经过 permute(0,2,1) 后:

  • (B, C, L)

输出:

  • moving_mean: (B, seq_len, C)

9.2 toy 张量逐步演变

第 1 个变量通道:

text
[1, 1, 2, 3, 4, 4]

长度为 3 的窗口依次是:

text
窗口1: [1, 1, 2] -> 平均 = 4/3
窗口2: [1, 2, 3] -> 平均 = 2
窗口3: [2, 3, 4] -> 平均 = 3
窗口4: [3, 4, 4] -> 平均 = 11/3

第 2 个变量通道:

text
[10, 10, 11, 12, 13, 13]

窗口平均:

text
窗口1: [10, 10, 11] -> 31/3
窗口2: [10, 11, 12] -> 11
窗口3: [11, 12, 13] -> 12
窗口4: [12, 13, 13] -> 38/3

所以:

text
moving_mean =
[
  [4/3, 31/3],
  [2,   11],
  [3,   12],
  [11/3,38/3],
]

这一段在总体里的作用:

把原序列里变化慢的“平滑背景”提出来。

10. 代码块 3:x - moving_mean

代码:

python
res = x - moving_mean

10.1 输入/输出语义

输入:

  • x
    • 原始序列。
  • moving_mean
    • 平滑后的趋势项。

输出:

  • res
    • 原序列减去趋势项后的残差,也就是 seasonal。

10.2 toy 张量逐步演变

text
seasonal =
[
  [1-4/3,   10-31/3],
  [2-2,     11-11],
  [3-3,     12-12],
  [4-11/3,  13-38/3],
]
=
[
  [-1/3, -1/3],
  [0,    0],
  [0,    0],
  [1/3,  1/3],
]

这一段在总体里的作用:

把“去掉平滑背景之后剩下的波动”单独留下,供 seasonal 线性头处理。

11. 当前层最该固定什么

  1. moving_avg 不是随便平滑一下,它就是 DLinear 里 trend 的定义方式。
  2. trend = moving_mean
  3. seasonal = x - moving_mean
  4. 两路张量 shape 都保持 (B, seq_len, C)

12. 下一步

继续看:

补查 torch 算子:

*记录并在线阅读我的笔记*