Skip to content

Level 7 forecast_fit 训练循环

Abstract

入口:

python
for epoch in range(config.num_epochs):
    for i, (input, target, input_mark, target_mark) in enumerate(self.train_data_loader):

这一层解释:

forecast_fit(...) 怎样真正进入训练循环,并在每个 batch 上完成一次前向、算 loss、反向传播和参数更新。

1. 上一层与当前层位置

上一层是:

当前层是:

  • Level 7:只细分 forecast_fit(...) 的训练循环本体。

当前层对应上一层中的具体位置是:

这一层先不继续展开:

  • validate(...)
  • batch_forecast(...)
  • _perform_rolling_predictions(...)

1.5 这一层的关系类型说明

这一层是在继续下钻:

  • Level 6 训练主链里的最后一段
  • 也就是 forecast_fit(...) 真正进入 epoch-batch 循环之后的内部步骤

所以这里的 7A / 7B / 7C / 7D / 7E

  • 训练循环内部的并列逻辑块
  • 不是新的外层主线节点

2. 当前层第一性

训练循环的第一性是:

self.train_data_loader 吐出的每个 batch,做一次标准深度学习训练步骤:

取 batch -> 前向 -> 对齐监督目标 -> 算损失 -> backward -> optimizer.step

在当前代码里,这个通用训练步骤被包在:

python
for epoch ...
    for batch ...

里反复执行。

3. 当前命令的最小例子

当前例子的关键参数是:

  • num_epochs = 1
  • batch_size = 4
  • seq_len = 96
  • label_len = 48
  • horizon = 24
  • enc_in = dec_in = c_out = 7

所以训练循环里一个 batch 的典型 shape 是:

  • input.shape = (4, 96, 7)
  • target.shape = (4, 72, 7)
  • input_mark.shape = (4, 96, 4)
  • target_mark.shape = (4, 72, 4)

4. 先看顺序图

5. 抽象索引树

6. 职责树

6.1 7A 取 batch

职责:

  • self.train_data_loader 取一个 batch
  • 获得四元组:
    • input
    • target
    • input_mark
    • target_mark

6.2 7B 模型前向

职责:

  • self._process(...)
  • 得到:
    • output
    • 可选 additional_loss

6.3 7C 监督目标对齐

职责:

  • target 中截出真正监督的未来窗口
  • 保证 outputtarget 在 loss 计算时形状对齐

6.4 7D 损失与反向传播

职责:

  • 计算任务 loss
  • 加上可选额外 loss
  • backward()
  • optimizer.step()

6.5 7E 训练外层控制

职责:

  • 学习率调整
  • validation
  • early stopping

7. 输入输出接口

7.1 本层入口接口

训练循环批处理入口是:

python
for i, (input, target, input_mark, target_mark) in enumerate(self.train_data_loader)

7.2 本层关键输入

  • self.train_data_loader
  • criterion
  • optimizer
  • config
  • series_dim

7.3 本层关键中间变量

  • out_loss
  • output
  • additional_loss
  • loss
  • total_loss

7.4 本层输出

这一层没有单独 return 某个对象。
它的输出主要是副作用:

  • 模型参数被更新
  • check_point 可能被刷新
  • early_stopping 状态可能被更新

8. 函数 / 文件关系图

9. 关键代码对应关系

9.1 取 batch

python
for i, (input, target, input_mark, target_mark) in enumerate(self.train_data_loader):

这一步语义是:

从阶段一你已经学过的 DataLoader 里取一个训练 batch。

9.2 前向入口

python
out_loss = self._process(input, target, input_mark, target_mark)

这里是训练循环最关键的一行。
它把控制权交给 adapter。

9.3 取 output

python
output = out_loss["output"]
if "additional_loss" in out_loss:
    additional_loss = out_loss["additional_loss"]

说明:

  • 标准输出在 out_loss["output"]
  • 某些模型还可能带额外 loss

9.4 监督目标对齐

python
target = target[:, -config.horizon :, :series_dim]
output = output[:, -config.horizon :, :series_dim]

这一段的语义是:

只对未来 horizon 这段窗口做监督。

在当前例子里:

  • target 原始是 (4, 72, 7)
  • 截完后变成 (4, 24, 7)
  • output 也对齐成 (4, 24, 7)

9.5 损失与更新

python
loss = criterion(output, target)
total_loss = loss + additional_loss
total_loss.backward()
optimizer.step()

这就是标准训练步骤。

10. 当前例子的具体落地

10.1 一个 batch 的输入

  • input.shape = (4, 96, 7)
  • target.shape = (4, 72, 7)
  • input_mark.shape = (4, 96, 4)
  • target_mark.shape = (4, 72, 4)

10.2 前向后的输出

在当前 DLinear 例子里:

  • output.shape = (4, 24, 7)

10.3 loss 对齐

训练前:

  • target.shape = (4, 72, 7)

监督时截断为:

  • target[:, -24:, :].shape = (4, 24, 7)

所以 loss 真正比较的是:

  • output.shape = (4, 24, 7)
  • target.shape = (4, 24, 7)

11. 当前层最重要的认知

这一层最重要的认知是:

forecast_fit(...) 的训练循环并不直接知道 DLinear 的内部结构;它只统一依赖 self._process(...) 提供 output,然后再在外层完成 loss 和 optimizer 更新。

这说明:

  • 训练循环是通用外壳
  • adapter 和底层模型负责前向语义

12. 下一层入口

如果继续 DFS,下一层最自然的入口是:

python
self._process(input, target, input_mark, target_mark)

也就是:

  • TransformerAdapter._process(...)
  • DLinear.forward(...)

对应下层文档:

13. 只留一句

Level 7 只看训练循环怎样对每个 batch 完成一次“前向 -> 对齐目标 -> loss -> backward -> step”。

*记录并在线阅读我的笔记*