10.1.4.9. PwcNet光流预测模型训练

这篇教程主要是告诉大家如何利用HAT在光流数据集 FlyingChairs 上从头开始训练一个 PwcNet 模型,包括浮点、量化和定点模型。

FlyingChairs 是光流预测中用的比较多的数据集,很多先进的光流预测研究都会优先基于这个数据集做好验证。 开始训练模型之前,第一步是准备好数据集,这里我们下载官方的数据集 FlyingChairs.zip 作为训练和验证集。 同时需要下载相应的标签数据 FlyingChairs_train_val.txt , 解压缩之后数据目录结构如下所示:

tmp_data
|-- FlyingChairs
  |-- FlyingChairs_release
    |-- data
    |-- README.txt
  |-- FlyingChairs_train_val.txt
  |-- FlyingChairs.zip

10.1.4.9.1. 训练流程

如果你只是想简单的把 PwcNet 的模型训练起来,那么可以首先阅读一下这一章的内容。 和其他任务一样,对于所有的训练,评测任务,HAT统一采用 tools + config 的形式来完成。 在准备好原始数据集之后,可以通过下面的流程,方便地完成整个训练的流程。

10.1.4.9.1.1. 数据集准备

为了提升训练速度,我们对原始的数据集做了一个打包,将其转换为 LMDB 格式的数据集。只需要运行下面的脚本,就可以成功实现转换:

python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name train --pack-type lmdb  --num-workers 10
python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name val --pack-type lmdb  --num-workers 10

上面这两条命令分别对应转换训练数据集和验证数据集,打包完成之后,目录下的文件结构应该如下所示:

tmp_data
|-- FlyingChairs
  |-- FlyingChairs_release
    |-- data
    |-- README.txt
  |-- FlyingChairs_train_val.txt
  |-- FlyingChairs.zip
  |-- train_lmdb
  |-- val_lmdb

train_lmdbval_lmdb 就是打包之后的训练数据集和验证数据集,接下来就可以开始训练模型。

10.1.4.9.1.2. 模型训练

在网络开始训练之前,你可以使用以下命令先计算一下网络的计算量和参数数量:

python3 tools/calops.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

下一步就可以开始训练。训练也可以通过下面的脚本来完成,在训练之前需要确认配置中数据集路径是否已经切换到已经打包好的数据集路径。

python3 tools/train.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/train.py --stage "calibration" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/train.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/train.py --stage "int_infer" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

由于HAT算法包使用了注册机制,使得每一个训练任务都可以按照这种 train.py 加上 config 配置文件的形式启动。 train.py 是统一的训练脚本, 与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的 config 配置文件里面。 上面的命令中 --stage 后面的参数可以是 "float""calibration""qat""int_infer",分别可以完成浮点模型、量化模型的训练以及量化模型到定点模型的转化, 其中量化模型的训练依赖于上一步浮点训练产出的浮点模型,定点模型的转化依赖于量化训练产生的量化模型。

10.1.4.9.1.3. 模型验证

在完成训练之后,可以得到训练完成的浮点、量化或定点模型。和训练方法类似,我们可以用相同方法来对训好的模型做指标验证, 得到为 FloatCalibrationQATQuantized 的指标,分别为浮点、量化和完全定点的指标。

python3 tools/predict.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

python3 tools/predict.py --stage "calibration" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

python3 tools/predict.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

python3 tools/predict.py --stage "int_infer" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

和训练模型时类似, --stage 后面的参数为 "float""calibration""qat""int_infer" 时, 分别可以完成对训练好的浮点模型、量化模型、定点模型的验证。

10.1.4.9.1.4. 模型推理

HAT提供了 infer.py 脚本提供了对定点模型的推理结果进行可视化展示。

python3 tools/infer.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py --model-inputs img1:${img1-path},img2:${img2-path} --save-path ${save_path}

10.1.4.9.1.5. 仿真上板精度验证

除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:

python3 tools/align_bpu_validation.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

10.1.4.9.1.6. 定点模型检查和编译

在HAT中集成的量化训练工具链主要是为了地平线的计算平台准备的,因此,对于量化模型的检查和编译是必须的。 我们在HAT中提供了模型检查的接口,可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行:

python3 tools/model_checker.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

在模型训练完成后,可以通过 compile_perf 脚本将量化模型编译成可以上板运行的 hbm 文件,同时该工具也能预估在 BPU 上的运行性能:

python3 tools/compile_perf.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

以上就是从数据准备到生成量化可部署模型的全过程。

10.1.4.9.1.7. ONNX模型导出

如果想要导出onnx模型, 运行下面的命令即可:

python3 tools/export_onnx.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

10.1.4.9.2. 训练细节

在这个说明中,我们对模型训练需要注意的一些事项进行说明,主要为 config 的一些相关设置。

10.1.4.9.2.1. 模型构建

PwcNet 的网络结构可以参考 论文社区TensorFlow版本 ,这里不做详细介绍。 我们通过在 config 配置文件中定义 model 这样的一个 dict 型变量,就可以方便的实现对模型的定义和修改。

from torch import nn
loss_weights = [0.005, 0.01, 0.02, 0.08, 0.32]
out_channels = [16, 32, 64, 96, 128, 196]
flow_pred_lvl = 2
pyr_lvls = 6
use_bn = True
bn_kwargs = {}
use_res = True
use_dense = True

model = dict(
  type="PwcnetTask",
  backbone=dict(
      type="PwcNet",
      out_channels=out_channels,
      use_bn=use_bn,
      bn_kwargs=bn_kwargs,
      pyr_lvls=pyr_lvls,
      flow_pred_lvl=flow_pred_lvl,
      act_type=nn.ReLU(),
  ),
  head=dict(
      type="PwcnetHead",
      in_channels=out_channels,
      bn_kwargs=bn_kwargs,
      use_bn=use_bn,
      md=4,
      use_res=use_res,
      use_dense=use_dense,
      pyr_lvls=pyr_lvls,
      flow_pred_lvl=flow_pred_lvl,
      act_type=nn.ReLU(),
  ),
  loss=dict(type="LnNormLoss", norm_order=2, power=1, reduction="mean"),
  loss_weights=loss_weights,
)

模型除了 backbone 之外,还有 head``和 ``losses 模块,在PwcNet中, backbone``主要是提取两张图像的特征, ``head 主要是由特征来得到预测的光流图。 losses 部分采样论文中的LnNormLoss来作为训练的 loss, loss_weights``是特征层对应的 ``loss 和权重。

10.1.4.9.2.2. 数据增强

model 的定义一样,数据增强的流程是通过在 config 配置文件中定义 data_loaderval_data_loader 这两个 dict 来实现的, 分别对应着训练集和验证集的处理流程。以 data_loader 为例,数据增强使用了 RandomCropRandomFlipSegRandomAffineFlowRandomAffineScale

data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="FlyingChairs",
        data_path="./tmp_data/FlyingChairs/train_lmdb/",
        transforms=[
            dict(
                type="RandomCrop",
                size=(256, 448),
            ),
            dict(
                type="RandomFlip",
                px=0.5,
                py=0.5,
            ),
            dict(
                type="ToTensor",
                to_yuv=False,
            ),
            dict(
                type="SegRandomAffine",
                degrees=0,
                translate=(0.05, 0.05),
                scale=(1.0, 1.0),
                interpolation=InterpolationMode.BILINEAR,
                label_fill_value=0,
                translate_p=0.5,
                scale_p=0.0,
            ),
            dict(
                type="FlowRandomAffineScale",
                scale_p=0.5,
                scale_r=0.05,
            ),
        ],
        to_rgb=True,
    ),
    sampler=dict(type=torch.utils.data.DistributedSampler),
    batch_size=batch_size_per_gpu,
    pin_memory=True,
    shuffle=True,
    num_workers=4,
    collate_fn=hat.data.collates.collate_2d,
)

因为最终跑在 BPU 上的模型使用的是 YUV444 的图像输入,而一般的训练图像输入都采用 RGB 的形式, 所以HAT提供 BgrToYuv444 的数据增强来将 RGB 转到 YUV444 的格式。 为了优化训练过程,HAT使用了 batch_processor,可将一些增强处理放在 batch_processor 中优化训练:

def loss_collector(outputs: dict):
    return outputs["losses"]

batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=True,
    batch_transforms=[
        dict(type="BgrToYuv444", rgb_input=True),
        dict(
            type="TorchVisionAdapter",
            interface="Normalize",
            mean=128.0,
            std=128.0,
        ),
        dict(
            type="Scale",
            scales=tuple(1 / np.array(train_scales)),
            mode="bilinear",
        ),
    ],
    loss_collector=loss_collector,
)

其中 loss_collector 是一个获取当前批量数据的 loss 的函数。

验证集的数据转换相对简单很多,如下所示:

val_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="FlyingChairs",
        data_path="./tmp_data/FlyingChairs/val_lmdb/",
        transforms=[
            dict(
                type="ToTensor",
                to_yuv=False,
            ),
        ],
        to_rgb=True,
    ),
    batch_size=batch_size_per_gpu,
    shuffle=False,
    num_workers=data_num_workers,
    pin_memory=True,
    collate_fn=hat.data.collates.collate_2d,
)
val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,
    batch_transforms=[
        dict(type="BgrToYuv444", rgb_input=True),
        dict(
            type="TorchVisionAdapter",
            interface="Normalize",
            mean=128.0,
            std=128.0,
        ),
    ],
    loss_collector=None,
)

10.1.4.9.2.3. 训练策略

FlyingChairs 数据集上训练浮点模型使用 Cosine 的学习策略配合 Warmup,以及对 weight 的参数施加L2 norm。 configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py 文件中的 float_trainercalibration_trainer, qat_trainerint_trainer 分别对应浮点、量化、定点模型的训练策略。下面以 float_trainer 训练策略示例:

float_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.Adam,
        params={"weight": dict(weight_decay=4e-4)},
        lr=lr,
    ),
    batch_processor=batch_processor,
    stop_by="epoch",
    num_epochs=max_epoch,
    device=None,
    callbacks=[
        stat_callback,
        loss_metirc_show_update,
        dict(
            type="CosLrUpdater",
            warmup_by="epoch",
            warmup_len=10,
            step_log_interval=1000,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=[
        dict(type="LossShow"),
        dict(type="EndPointError"),
    ],
    val_metrics=[
        dict(type="EndPointError"),
    ],
    sync_bn=True,
)

10.1.4.9.2.4. 量化训练

关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等, 请阅读 量化感知训练 章节的内容。这里主要讲一下HAT的光流预测中如何定义和使用量化模型。

在模型准备的好情况下,包括量化已有的一些模块完成之后,HAT在训练脚本中统一使用下面的脚本将浮点模型映射到定点模型上来。

model.fuse_model()
model.set_qconfig()
horizon.quantization.prepare_qat(model, inplace=True)

量化训练的整体策略可以直接沿用浮点训练的策略,但学习率和训练长度需要适当调整。因为有浮点预训练模型,所以量化训练的学习率 Lr 可以很小, 一般可以从 0.001 或 0.0001 开始,并可以搭配 StepLrUpdater 做1-2次 scale=0.1Lr 调整; 同时训练的长度不用很长。此外 weight decay 也会对训练结果有一定影响。

PwcNet 示例模型的量化训练策略可见 configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py 文件。

10.1.4.9.2.5. 模型检查编译和仿真上板精度验证

对于HAT来说,量化模型的意义在于可以在BPU上直接运行。因此,对于量化模型的检查和编译是必须的。 前文提到的 compile_perf 脚本也可以让用户定义好量化模型之后,先检查能否在BPU上正常运行, 并可通过 align_bpu_validation 脚本获取模型上板精度。用法同前文。