6.1.6.3. FCOS检测模型训练

这篇教程以FCOS-efficientnet为例,告诉大家如何使用HAT算法包训练一个定点的检测模型。在开始量化感知训练,也就是定点模型训练 之前,首先需要训练一个精度较高的纯浮点模型,然后基于这个纯浮点模型做finetune,就可以快速的训练出定点模型。 所以我们从训练一个纯浮点的FCOS-efficientnet模型开始讲起。

6.1.6.3.1. 数据集准备

在开始训练模型之前,第一步是需要准备好数据集,这里我们下载MSCOCO的 train2017.zipval2017.zip 做为网络的训练集和验证集,同时需要下载相应的标签数据 annotations_trainval2017.zip , 解压缩之后数据目录结构如下所示:

tmp_data
|-- mscoco
  -- annotations_trainval2017.zip
  |-- train2017.zip
  |-- val2017.zip
  |-- annotations
  |-- train2017
  |-- val2017

同时,为了提升训练的速度,我们对原始的jpg格式的数据集做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现转换:

python3 tools/datasets/mscoco_packer.py --src-data-dir ./tmp_data/mscoco/ --target-data-dir ./tmp_data/mscoco --split-name train --pack-type lmdb
python3 tools/datasets/mscoco_packer.py --src-data-dir ./tmp_data/mscoco/ --target-data-dir ./tmp_data/mscoco --split-name val --pack-type lmdb

上面这两条命令分别对应着转换训练数据集和验证数据集,打包完成之后,data目录下的文件结构应该如下所示:

tmp_data
|-- mscoco
  |-- annotations
  |-- train2017
  |-- train_lmdb
  |-- val2017
  |-- val_lmdb

train_lmdbval_lmdb 就是打包之后的训练数据集和验证数据集,也是网络最终读取的数据集。

6.1.6.3.2. 浮点模型训练

数据集准备好之后,就可以开始训练浮点型的FCOS-efficientnet检测网络了。在网络训练开始之前,你可以使用以下命令先测试一下网络的计算量和参数数量:

python3 tools/calops.py --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --input-shape "1,3,512,512"

如果你只是单纯的想启动这样的训练任务,只需要运行下面的命令就可以:

python3 tools/train.py --stage float --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py

由于HAT算法包使用了一种巧妙的注册机制,使得每一个训练任务都可以按照这种train.py加上config配置文件的形式启动。 train.py是统一的训练脚本,与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的config配置文件里面。 config文件里面提供了模型构建、数据读取等关键的dict。

6.1.6.3.2.1. 模型构建

fcos的网络结构可以参考 论文 , 这里不做详细介绍。我们通过在config配置文件中定义 model 这样的一个dict型变量,就可以方便的实现对模型的定义和修改。

model = dict(
    type="FCOS",
    backbone=dict(
        type="efficientnet",
        bn_kwargs=bn_kwargs,
        model_type='b0',
        num_classes=1000,
        include_top=False,
        activation='relu',
        use_se_block= False,
    ),
    neck=dict(
        type="BiFPN",
        in_strides=[2, 4,8, 16, 32],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels=dict({2:16, 4: 24, 8:40, 16:112, 32:320}),
        out_channels=64,
        num_outs=5,
        stack=3,
        start_level=2,
        end_level=-1,
        fpn_name="bifpn_sum"
    ),
    head=dict(
        type="FCOSHead",
        num_classes=num_classes,
        in_strides=[8, 16, 32, 64, 128],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels = dict({8:64, 16:64, 32:64, 64:64, 128:64}),
        upscale_bbox_pred=True,
        feat_channels=64,
        stacked_convs=4,
        int8_output=False,
        int16_output=True,
        dequant_output=True
    ),
    targets=dict(
        type="DynamicFcosTarget",
        strides=[8, 16, 32, 64, 128],
        cls_out_channels=80,
        background_label=80,
        topK=10,
        loss_cls=dict(
            type="FocalLoss",
            loss_name="cls",
            num_classes=80 + 1,
            alpha=0.25,
            gamma=2.0,
            loss_weight=1.0,
            reduction="none"
        ),
        loss_reg=dict(
            type="GIoULoss",
            loss_name="reg",
            loss_weight=2.0,
            reduction="none"
        ),
     ),
     post_process=dict(type="FCOSDecoder",
                       num_classes=80,
                       strides=[8, 16, 32, 64, 128],
                       nms_use_centerness=True,
                       nms_sqrt=True,
                       rescale=True,
                       test_cfg =dict(score_thr=0.05,
                       nms_pre=1000,
                       nms=dict(name = 'nms',
                                iou_threshold=0.6,
                                max_per_img=100)
                  )
    ),
    loss_cls=dict(
        type="FocalLoss",
        loss_name="cls",
        num_classes=80 + 1,
        alpha=0.25,
        gamma=2.0,
        loss_weight=1.0,
    ),
    loss_centerness=dict(
        type="CrossEntropyLossV2",
        loss_name="centerness",
        use_sigmoid=True
    ),
    loss_reg=dict(
        type="GIoULoss",
        loss_name="reg",
        loss_weight=1.0,
    )
)

其中, model 下面的 type 表示定义的模型名称,剩余的变量表示模型的其他组成部分。 这样定义模型的好处在于我们可以很方便的替换我们想要的结构。例如,如果我们想训练一个backbone为resnet50的模型, 只需要将 model 下面的 backbone 替换掉就可以。

6.1.6.3.2.2. 数据增强

model 的定义一样,数据增强的流程是通过在config配置文件中定义 data_loaderval_data_loader 这两个dict来实现的,分别对应着训练集和验证集的处理流程。以 data_loader 为例:

 dataset=dict(
    type="Coco",
    data_path="./tmp_data/coco/train_lmdb/",
    transforms=[
        dict(
            type="Resize",
            img_scale=(512, 512),
            ratio_range=(0.5, 2.0),
            keep_ratio=True,
        ),
        dict(
            type="RandomCrop",
            size=(512, 512)
        ),
        dict(
            type="Pad",
            divisor=512,
        ),
        dict(
            type="RandomFlip",
            px=0.5,
            py=0,
        ),
        dict(
            type = "AugmentHSV",
            hgain=0.015,
            sgain=0.7,
            vgain=0.4
        ),
        dict(
            type="ToTensor",
            to_yuv=True,
        ),
        dict(
            type="Normalize",
            mean=128.0,
            std=128.0,
        ),
    ],
    ),
    sampler=dict(type=torch.utils.data.DistributedSampler),
    batch_size=batch_size_per_gpu,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=dict(type="Collate2D"),
)

其中type直接用的pytorch自带的接口torch.utils.data.DataLoader,表示的是将 batch_size 大小的图片组合到一起。 这里面唯一需要关注的可能是 dataset 这个变量, CocoFromLMDB 表示从lmdb数据集中读取图片,路径也就是我们在 数据集准备 中提到的路径。 transforms 下面包含着一系列的数据增强。 val_data_loader 中除了图片翻转(RandomFlip),其他的数据变换和 data_loader 一致。 你也可以通过在 transforms 中插入新的dict实现自己希望的数据增强操作。

6.1.6.3.2.3. 训练策略

为了训练一个精度高的模型,好的训练策略是必不可少的。对于每一个训练任务而言,相应的训练策略同样都定义在其中的config文件中,从 float_trainer 这个变量 就可以看出来。

float_trainer = dict(
     type='distributed_data_parallel_trainer',
     model=model,
     data_loader=data_loader,
     optimizer=dict(
                    type=torch.optim.SGD,
                    params={"weight": dict(weight_decay=4e-5)},
                    lr=0.14,
                    momentum=0.937,
                    nesterov=True,
     ),
     batch_processor=batch_processor,
     num_epochs=300,
     device=None,
     callbacks=[
          stat_callback,
          loss_show_update,
          dict(type="ExponentialMovingAverage"),
          dict(
               type="CosLrUpdater",
               warmup_len=2,
               warmup_by="epoch",
               stage_log_interval=1,
          ),
          val_callback,
          ckpt_callback,
     ],
     train_metrics=dict(
          type="LossShow",
     ),
     sync_bn=True,
     val_metrics=dict(
          type="COCODetectionMetric",
          ann_file="./tmp_data/coco/annotations/instances_val2017.json",
    )
)

float_trainer 从大局上定义了我们的训练方式,包括使用多卡分布式训练(distributed_data_parallel_trainer),模型训练的epoch次数,以及优化器的选择。 同时 callbacks 中体现了模型在训练过程中使用到的小策略以及用户想实现的操作,包括学习率的变换方式(WarmupStepLrUpdater),在训练过程中验证模型的 指标(Validation),以及保存(Checkpoint)模型的操作。当然,如果你有自己希望模型在训练过程中实现的操作,也可以按照这种dict的方式添加。 float_trainer 负责将整个训练的逻辑给串联起来,其中也会负责模型的pretrain。

注解

如果需要复现精度,config中的训练策略最好不要修改,否则可能会有意外的训练情况出现。

通过上面的介绍,你应该对config文件的功能有了一个比较清楚的认识。然后通过前面提到的训练脚本,就可以训练一个高精度的纯浮点的检测模型。 当然训练一个好的检测模型不是我们最终的目的,它只是做为一个pretrain为我们后面训练定点模型服务的。

6.1.6.3.3. 量化模型训练

当我们有了纯浮点模型之后,就可以开始训练相应的定点模型了。和浮点训练的方式一样,我们只需要通过运行下面的脚本就可以训练定点模型了。 不过这里需要说明的是,FCOS在量化训练过程中建议加上calibration的流程。calibration可以为QAT的量化训练提供一个更好的初始化参数。

python3 tools/train.py --stage calibration --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py

python3 tools/train.py --stage qat --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py

可以看到,我们的配置文件没有改变,只改变了 stage 的类型。此时我们使用的训练策略来自于config文件中的qat_trainer。

qat_trainer = dict(
     type='distributed_data_parallel_trainer',
     model=model,
     data_loader=data_loader,
     optimizer=dict(
         type=torch.optim.SGD,
         params={"weight": dict(weight_decay=4e-5)},
         lr=0.001,
         momentum=0.9,
     ),
     batch_processor=batch_processor,
     num_epochs=10,
     device=None,
     callbacks=[
          stat_callback,
          loss_show_update,
          dict(
               type="StepDecayLrUpdater",
               lr_decay_id=[4],
               stage_log_interval=500,
          ),
          val_callback,
          ckpt_callback,
     ],
     train_metrics=dict(
          type="LossShow",
     ,
     val_metrics=dict(
          type="COCODetectionMetric",
          ann_file="./tmp_data/coco/annotations/instances_val2017.json",
     )
 )

6.1.6.3.3.1. quantize参数的值不同

当我们训练量化模型的时候,需要设置quantize=True,此时相应的浮点模型会被转换成量化模型,相关代码如下:

model.fuse_model()
model.set_qconfig()
horizon.quantization.prepare_qat(model, inplace=True)

关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读《Horizon Plugin PyTorch》手册中的 浮点模型准备算子融合 两节中的内容。

6.1.6.3.3.2. 训练策略不同

正如我们之前所说,量化训练其实是在纯浮点训练基础上的finetue。因此量化训练的时候,我们的初始学习率设置为浮点训练的十分之一, 训练的epoch次数也大大减少,最重要的是 model 定义的时候,我们的 pretrained 需要设置成已经训练出来的纯浮点模型的地址。

做完这些简单的调整之后,就可以开始训练我们的量化模型了。

6.1.6.3.3.3. 模型验证

模型训练完成之后,我们还可以验证训练出来的模型性能。由于我们提供了float和qat两阶段的训练过程,相应的我们可以验证这两个阶段训练出来的模型性能, 只需要相应的运行以下两条命令即可:

python3 tools/predict.py --stage float --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --ckpt ${float-checkpoint-path}
python3 tools/predict.py --stage qat --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --ckpt ${qat-checkpoint-path}

同时,我们还提供了quantization模型的性能测试,只需要运行以下命令:

python3 tools/predict.py --stage int_infer configs/detection/fcos/fcos_efficientnetb0_mscoco.py --ckpt ${int-infer-checkpoint-path}

这个显示出来的精度才是最终的int8模型的真正精度,当然这个精度和qat验证阶段的精度应该是保持十分接近的。

6.1.6.3.3.4. 量化训练

关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读 量化感知训练 章节的内容。这里主要讲一下HAT的分类中如何定义和使用量化模型。

在模型准备的好情况下,包括量化已有的一些模块完成之后,HAT在训练脚本中统一使用下面的脚本将浮点模型映射到定点模型上来。

qconfig_manager.set_qconfig_mode(qconfig_manager.QconfigMode.QAT)
model.set_qconfig()
model = horizon.quantization.prepare_qat_fx(model)

量化训练的策略并不统一,这里简单描述分类模型训练中的常见策略。

量化训练的整体策略可以直接沿用浮点训练的策略,但学习率和训练长度需要适当调整。因为有浮点预训练模型,所以量化训练的学习率 Lr 可以很小, 一般可以从0.001或0.0001开始,并可以搭配 StepLrUpdater 做1-2次 scale=0.1Lr 调整;同时训练的长度不用很长。 此外 weight decay 也会对训练结果有一定影响。

6.1.6.3.3.5. 模型检查编译和仿真上板精度验证

对于HAT来说,量化模型的意义在于可以在 BPU 上直接运行。因此,对于量化模型的检查和编译是必须的。 前文提到的 compile_perf 脚本也可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行, 并可通过 align_bpu_validation 脚本获取模型上板精度。用法同前文。

6.1.6.3.3.6. 结果可视化

如果你希望可以看到训练出来的模型对于单张图片的检测效果,我们的tools文件夹下面同样提供了单张图片预测及可视化的脚本,你只需要运行以下脚本即可:

python3 tools/infer.py --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --dataset fcos_coco --input-size 512x512x3 --input-images ${img-path} --input-format yuv --is-plot