Skip to content

Level 8 _perform_rolling_predictions

Abstract

入口:

python
_perform_rolling_predictions(horizon, input_np, all_mark, device)

这一层解释:

预测线里,模型怎样在内部反复调用 _process(...),把一次次 horizon 预测拼接成最终长度为 horizon 的输出。

1. 上一层与当前层位置

上一层是:

当前层是:

  • Level 8:预测线最底部的 rolling 内循环。

1.5 这一层的关系类型说明

这一层是在继续下钻:

  • Level 7 里的 7P-4 调内部 rolling 预测

所以这里的 8P-1 ... 8P-5

  • _perform_rolling_predictions(...) 内部的并列逻辑块
  • 不是新的外层流程节点

2. 当前层第一性

_perform_rolling_predictions(...) 的第一性是:

在模型内部不断重复“构造 rolling 输入 -> 调 _process(...) -> 取出新预测 -> 拼回输入”,直到累积的预测长度达到目标 horizon

这一层和训练线的 Level 8 很像,但语义不同:

  • 训练线 Level 8:单次 batch 前向
  • 预测线 Level 8:多轮 rolling 前向拼接

3. 当前命令的最小例子

当前例子的关键参数:

  • horizon = 24
  • self.config.horizon = 24
  • seq_len = 96
  • label_len = 48

在当前 DLinear 例子里,由于:

  • 目标预测长度就是 24
  • 每次 _process(...) 也会给出 24 步预测

所以很多情况下循环只跑一轮就够。

但代码写成 rolling 的形式,是为了兼容:

  • 目标 horizon 比单次输出更长的情况

4. 先看顺序图

5. 抽象索引树

6. 职责树

6.1 8P-1 初始 rolling 数据构造

职责:

  • 根据 input_npall_mark 构造第一轮的:
    • input_np
    • target_np
    • input_mark_np
    • target_mark_np

6.2 8P-2 单轮前向

职责:

  • 把 numpy 批数据转成 tensor
  • _process(...)
  • 得到 output

6.3 8P-3 累积答案

职责:

  • output 中取出当前轮生成的一段预测
  • 追加到 answers

6.4 8P-4 更新下一轮输入

职责:

  • 把本轮输出拼回输入序列
  • 更新下一轮 rolling 使用的数据块

6.5 8P-5 合并最终答案

职责:

  • 把多轮 answer 按时间维拼接
  • 截取最终需要的 horizon 长度

7. 输入输出接口

7.1 本层入口接口

python
_perform_rolling_predictions(
    horizon: int,
    input_np: np.ndarray,
    all_mark: np.ndarray,
    device: torch.device,
)

7.2 本层关键输入

  • horizon
  • input_np
  • all_mark
  • device

7.3 本层关键中间变量

  • rolling_time
  • target_np
  • input_mark_np
  • target_mark_np
  • answers
  • output

7.4 本层输出

  • answers[:, -horizon:, :]

8. 函数 / 文件关系图

9. 关键代码对应关系

9.1 初始 rolling 数据

python
input_np, target_np, input_mark_np, target_mark_np = self._get_rolling_data(
    input_np, None, all_mark, rolling_time
)

语义:

把第一轮预测所需的 encoder / decoder 输入全部准备好。

9.2 单轮前向

python
out_loss = self._process(input, dec_input, input_mark, target_mark)
output = out_loss["output"]

这里再次汇入你训练线已经学过的那条底层主链:

  • TransformerAdapter._process(...)
  • DLinear.forward(...)

9.3 取当前轮答案

python
answer = output.cpu().numpy().reshape(real_batch_size, -1, column_num)[
    :, -self.config.horizon :, :
]
answers.append(answer)

语义:

从本轮输出中取出最后 config.horizon 步,作为当前轮生成的新预测片段。

9.4 更新下一轮输入

python
output = output.cpu().numpy()[:, -self.config.horizon :, :]
(
    input_np,
    target_np,
    input_mark_np,
    target_mark_np,
) = self._get_rolling_data(input_np, output, all_mark, rolling_time)

语义:

把当前轮的预测结果拼回输入,生成下一轮 rolling 所需的新输入块。

10. 当前例子的具体落地

10.1 当前例子为什么循环可能只跑一次

因为:

  • horizon = 24
  • self.config.horizon = 24

而单轮 _process(...) 返回的 output 本身就有:

  • 24 步预测

所以一轮就已经达到目标长度。

10.2 为什么还要写成 rolling 形式

因为框架要兼容:

  • 目标预测长度更长
  • 单轮输出长度较短

这时就需要多轮滚动拼接。

11. 当前层最重要的认知

这一层最重要的认知是:

预测线的真正“内循环”不在 _eval_batch(...),而在 _perform_rolling_predictions(...)。它通过重复调用 _process(...) 来累积未来预测。

12. 下一层入口

如果继续 DFS,这一层后面最自然的两个方向是:

  1. TransformerAdapter._process(...) 的预测语义
  2. _get_rolling_data(...) 怎样构造 decoder 输入和时间标记

但对当前 benchmark 主链理解来说,可以先停在这一层。

13. 只留一句

Level 8 的预测线核心是:不断调用 _process(...),把每一轮新预测拼回输入,直到得到完整 horizon。

*记录并在线阅读我的笔记*