Skip to content

Level 7 batch_forecast 主链

Abstract

入口:

python
predicts = model.batch_forecast(horizon, predict_batch_maker)

这一层解释:

预测子块怎样进入 batch_forecast(...),并把 batch maker 提供的 rolling 请求转成真正的模型预测输入。

1. 上一层与当前层位置

上一层是:

当前层是:

  • Level 7:只细分 model.batch_forecast(...) 这一层。

当前层对应上一层中的具体位置是:

1.5 这一层的关系类型说明

这一层是在继续下钻:

  • Level 6 里的 5C 预测子块
  • 其中真正进入模型侧的入口 model.batch_forecast(...)

所以这里的 7P-1 ... 7P-5

  • batch_forecast(...) 函数内部的并列逻辑块
  • 不是新的外层主线节点

2. 当前层第一性

batch_forecast(...) 的第一性是:

从 batch maker 取出一批 rolling 预测请求,完成设备准备、归一化、时间标记补齐,然后把这些请求送入真正的 rolling 前向函数 _perform_rolling_predictions(...)

这一层不直接算 metric,也不组织 rolling 起点。
它负责的是:

  1. 从 batch maker 拿一批输入
  2. 处理 exog 和 norm
  3. 构造 all_mark
  4. _perform_rolling_predictions(...)
  5. 把预测结果反归一化并裁回目标维度

3. 当前命令的最小例子

当前例子里的关键参数:

  • horizon = 24
  • batch_size = 4
  • seq_len = 96
  • label_len = 48
  • enc_in = 7
  • predict_batch_maker 每次会给出一批 rolling 输入

这一层拿到的 input_data 语义是:

  • input_data["input"]
  • input_data["time_stamps"]
  • input_data["covariates"]

4. 先看顺序图

5. 抽象索引树

6. 职责树

6.1 7P-1 批量输入准备

职责:

  • batch_maker 拿一个预测 batch

输出:

  • input_data
  • input_np

6.2 7P-2 数据预处理

职责:

  • 拼接外生变量 exog
  • 对输入做归一化

输出:

  • 处理后的 input_np

6.3 7P-3 时间标记准备

职责:

  • 根据时间戳补齐未来时间标记

输出:

  • all_mark

6.4 7P-4 调内部 rolling 预测

职责:

  • _perform_rolling_predictions(...)

输出:

  • answers

6.5 7P-5 结果还原

职责:

  • 对预测结果反归一化
  • 只保留目标维度

输出:

  • 最终 answers[..., :series_dim]

7. 输入输出接口

7.1 本层入口接口

python
batch_forecast(horizon: int, batch_maker: BatchMaker, **kwargs) -> np.ndarray

7.2 本层关键输入

  • horizon
  • batch_maker
  • self.config
  • self.scaler

7.3 本层关键中间变量

  • input_data
  • input_np
  • series_dim
  • all_mark
  • answers

7.4 本层输出

  • answers: np.ndarray
    • 语义:这一批 rolling 请求的预测结果

8. 函数 / 文件关系图

9. 关键代码对应关系

9.1 从 batch maker 取输入

python
input_data = batch_maker.make_batch(self.config.batch_size, self.config.seq_len)
input_np = input_data["input"]

语义:

从预测子块准备好的 batch maker 中,取出当前这一批 rolling 输入窗口。

9.2 处理 exog

python
exog_data = covariates.get("exog")
if exog_data is not None:
    input_np = np.concatenate((input_np, exog_data), axis=2)

语义:

如果这批预测请求带外生变量,把它拼到特征维上。

9.3 构造时间标记

python
input_index = input_data["time_stamps"]
padding_len = (math.ceil(horizon / self.config.horizon) + 1) * self.config.horizon
all_mark = self._padding_time_stamp_mark(input_index, padding_len)

语义:

根据已有窗口时间戳,再补出未来若干步时间标记,供后续 rolling 预测使用。

9.4 调内部 rolling 预测

python
answers = self._perform_rolling_predictions(horizon, input_np, all_mark, device)

语义:

这一步才真正开始在模型内部按滚动方式生成未来预测。

10. 当前例子的具体落地

10.1 当前批输入

由于:

  • batch_size = 4
  • seq_len = 96

所以当前一批 input_np 的典型 shape 语义是:

  • (4, 96, 7) 或拼 exog 后更宽

10.2 时间标记

input_data["time_stamps"] 提供的是每个 rolling 输入窗口的时间索引,后续被补成:

  • all_mark

它会覆盖:

  • encoder 历史窗口
  • decoder 历史尾部
  • 未来 horizon 的时间标记

10.3 最终返回

这一层返回的是:

  • 当前这批 rolling 请求对应的预测数组

11. 当前层最重要的认知

这一层最重要的认知是:

batch_forecast(...) 不是直接做一次前向,而是先把 batch maker 提供的预测请求变成模型能连续滚动使用的输入格式,再交给 _perform_rolling_predictions(...)

12. 下一层入口

如果继续 DFS,下一层最自然的入口是:

python
_perform_rolling_predictions(horizon, input_np, all_mark, device)

对应:

也就是说,当前层里的:

  • 7P-4 调内部 rolling 预测

已经有下一层详细文档。

13. 只留一句

Level 7 的预测线主链只看 batch_forecast(...) 怎样把 batch 请求推进到 _perform_rolling_predictions(...)

*记录并在线阅读我的笔记*