6.1.3.2. config 文件介绍

使用 HAT 算法包训练模型通常只需使用一条命令就可以了,即:

python3 tools/train.py --stage TRAINING_STEPS --config /PATH/TO/CONFIG

其中 /PATH/TO/CONFIG 就是模型训练对应的 config 文件,它负责定义了模型结构、数据集加载、以及整套的训练流程。

本章节通过介绍config文件中一些固定的全局的关键字,以及它们是如何配置的,让用户对config文件的内容以及作用有个大致的了解。

6.1.3.2.1. 全局关键字

  • training_step:模型训练的各个阶段,包括 float

  • device_ids:模型训练使用的 gpu 列表。

  • cudnn_benchmark:是否打开cudnn benchmark。通常默认为 True

  • seed:是否设置随机数种子。通常默认为 None

  • log_rank_zero_only:简化多卡训练时的日志打印,只在第0卡上输出日志。通常默认为 True

  • model:参与 training 过程中的模型结构。type 表示模型的类型,如 ClassifierSegmentorRetinaNet等等,分别对应分类、分割、检测中的某一类模型。它会在使用过程中被 build 成具体的类,余下的参数都是用于初始化这个类。

  • test_model:参与 test 过程的模型结构,主要用于模型编译。和 model 相比,大多数情况下只需要把损失函数以及后处理部分设置为 None 即可。

  • test_inputs:test 过程的模拟输入。不用关心具体的数值,只要保证格式满足输入要求即可。

  • data_loader:训练阶段的数据集加载流程。它的 type 是一个具体的类 torch.utils.data.DataLoader ,余下的参数都是用于初始化这个类。相关参数的含义也可以参考 pytorch 官网提供的接口文档。这里 dataset 表示读取某个具体的数据集,例如ImageNetMSCOCOVOC等等,它的 transforms 表示在数据读取过程中添加的数据增强操作。

  • val_data_loader:验证模型性能阶段的数据集加载流程。和 data_loader 不同的地方在于 data_path 不同,以及去掉了 transforms 的过程和 sample 的过程。

  • batch_processor:模型在训练过程中每个迭代 step 进行的操作,包括前向计算、梯度回传、参数更新等等。如果包含 batch_transforms 参数,表示一些数据增强的操作是在 gpu 上进行的,这可以大大加快训练速度。

  • val_batch_processor:模型在验证过程中每个迭代 step 进行的操作,只包含前向计算。

  • metric_updater:模型训练过程中更新指标的方法,这个指标是用来验证训练的模型性能是否在提升。它通常是和 float_trainer 下面的 train_metrics 配合着使用。train_metrics 是具体的指标形式,metric_updater 只是提供一种更新方法。

  • val_metric_updater:训练出来的模型在验证性能的过程中更新指标的方法,这个指标用来验证最终训练出来的模型性能到底如何。它通常是和 float_trainer 下面的 val_metrics 配合着使用,和 metric_updater 同理。

  • float_trainer:浮点模型训练流程的配置。type 类型为 distributed_data_parallel_trainer 表示支持分布式训练,其余的参数分别定义了模型、数据集加载、优化器、训练 epoch 长度等等。其中 callbacks 表示训练过程中进行的一系列操作,比如模型保存、学习率更新、精度验证等等。

  • float_solver:直接被 tools/train.py 文件调用的变量,核心主要是 float_trainer

这里把float_solver的内容展开说一下:

float_solver = dict(
    trainer=float_trainer,
    quantize=False,
    # 配置 resume_checkpoint, 即 checkpoint 文件路径
    resume_checkpoint="./tmp_models/last_checkpoint.pth.tar",
    # 配置 resume_optimizer, 即是否恢复 optimizer, 默认为 True
    resume_optimizer=False,
    # 配置 resume_epoch_or_step, 即是否恢复 epoch(step) 计数, 默认
 True
    resume_epoch_or_step=False,
    ...,
)

恢复训练有三种使用场景:

  1. 完全恢复:该场景为恢复意外中断的训练,会恢复上一个checkpoint的所有状态,包括optimizer、LR、epoch、step等。该场景只需配置 resume_checkpoint 字段即可;

  2. 恢复optimizer用于fine-tune:该场景只会恢复optimizer和LR的状态,但epoch、step都会从0开始,用于某些任务的fine-tune。该场景需要配置 resume_checkpoint,并且需要配置 resume_epoch_or_step=False

  3. 只加载模型参数:该场景只会加载模型参数,不会恢复其他任何状态(optimizer、epoch、step、LR)。该场景需要配置 resume_checkpoint,并且需要配置 resume_optimizer=Falseresume_epoch_or_step=False

之所以称这些变量为全局关键字,是因为几乎每个 config 文件中都定义了以上这些变量,且对应的功能基本一致。因此通过对这篇文档的学习,用户可以大致理解任意一个 config 文件实现的功能。

6.1.3.2.2. 如何配置

这里主要介绍数据类型为 dict 的全局关键字的配置。

数据类型为 dict 的全局关键字可以为两种:

  1. 包含 type的,例如 modeldata_loaderfloat_trainer等。

  2. 不包含 type的,例如 float_solverstep2solver 等。

它们的区别在于包含 type 的全局关键字本质可以看作是一个 class,它的 type 值可以是一个 string 变量,也可以是一个具体的 class,如果是 string, 在程序运行中同样会被 build 成一个相应的 class。这个dict中除掉type之外的其它keys的值都用于初始化这个 class。和全局关键字属性类似,这些keys的值可以是一个数值,也可以是一个包含type变量的dict,例如 data_loader 中的 dataset 属性,以及这个 dataset 下面的 transforms 属性。

对于没有 type 变量的全局关键字来说,它就是一个普通类型的 dict 变量,代码在运行过程中会通过其 keys 获取对应的 values

提示

所有已经提供的config配置,可以保证正常运行和复现精度。如果因为环境配置和训练时间等原因,需要修改配置的话,那么相对的训练策略可能也需要更改。直接修改config中的个别配置有时候并不能得到想要的结果。