6.1.3.8. 如何进行量化训练

本文档仅说明在HAT中进行量化训练时需要的操作,关于量化的基本原理和在训练框架中的实现方式请参阅 horizon_plugin_pytorch 的相关文档。

在量化训练中,由浮点模型到定点模型的转换流程如下:

量化训练流程图

其中大部分步骤都已集成在HAT的训练pipeline中,用户只需注意在添加自定义模型时实现 fuse_model 方法来完成模型融合,且实现 set_qconfig 方法对量化方式进行配置即可。在编写模型时需要注意以下几点:

  • HAT只会调用最外层模块的 fuse_model 方法,因此在 fuse_model 的实现中要负责所有子模块的fuse。

  • 优先使用 hat.models.base_modules 中提供的基础模块,这些基础模块已实现 fuse_model 方法,可减少工作量和开发难度。

  • 模型注册,HAT中的各种模块全部采用了注册机制,只有将定义的模型在对应的注册项中进行注册,才可以在config文件中以 dict(type={$class_name}, ...) 的形式使用模型。

  • 需要在最外层模块实现 set_qconfig 方法,如果子模块中有特殊layer需单独设置 QConfig,也需要在该子模块中实现 set_qconfig 方法,此部分细节可见 set_qconfig 书写规范和自定义 qconfig 介绍 章节。

此外,为使模型可转为量化模型,需要满足一些条件,具体见horizon_plugin_pytorch 的相关文档。

6.1.3.8.1. 量化训练流程简介

6.1.3.8.1.1. 添加自定义模型


import torch
from torch import nn

from hat.registry import OBJECT_REGISTRY


# 使用装饰器的方式将模型进行注册
@OBJECT_REGISTRY.register_module
class ExampleNet(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

    def fuse_model(self):
        # 需要调用所有子模块的 fuse_model 方法
        if hasattr(self.submodule, "fuse_model"):
            self.submodule.fuse_model()

        # 具体 fuse 的接口见 horizon_plugin_pytorch 文档
        ...

    def set_qconfig(self):
        # 具体模型量化配置的接口见 horizon_plugin_pytorch 文档
        from hat.utils import qconfig_manager
        # 默认使用 qconfig_manager.get_default_qat_qconfig() 得到的 QConfig
        self.qconfig = qconfig_manager.get_default_qat_qconfig()
        # 对需要特殊处理的子模块,调用子模块的 set_qconfig,
        # 子模块的 set_qconfig  中只需实现对特殊 layer 的 QConfig 设置
        if hasattr(self.submodule, "set_qconfig"):
            self.submodule.set_qconfig()
        # 如果有特殊节点不需要设置QConfig,比如 loss,需要设置其QConfig 为 None
        if self.loss is Not None:
            self.loss.qconfig = None
        ...

6.1.3.8.1.2. 添加 config 文件


ckpt_dir = ...

model = dict(type="ExampleNet")

float_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    data_loader=...,
    optimizer=...,
    batch_processor=...,
    num_epochs=...,
    device=None,
    callbacks=...,
    ...,
)

qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
            ),
            dict(type="Float2QAT"),
        ],
    ),
    data_loader=...,
    optimizer=...,
    batch_processor=...,
    num_epochs=...,
    device=None,
    callbacks=...,
    ...,
)

val_callback = dict(
    type="Validation",
    data_loader=...,
    batch_processor=...,
    callbacks=[val_metric_updater, ...],
)

trace_callback = dict(
    type="SaveTraced",
    save_dir=ckpt_dir,
    trace_inputs=deploy_inputs,
)

ckpt_callback = dict(
    type="Checkpoint",
    save_dir=ckpt_dir,
    name_prefix=training_step + "-",
    strict_match=True,
    mode="max",
)

int_trainer = dict(
    type="Trainer",
    model=deploy_model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
            ),
            dict(type="QAT2Quantize"),
        ],
    ),
    # int_trainer 中实际不包含训练流程
    data_loader=None,
    optimizer=None,
    batch_processor=None,
    num_epochs=0,
    ################################
    device=None,
    callbacks=[
        ckpt_callback,
        trace_callback,
    ],
)

6.1.3.8.2. 训练

只需在使用 tools/train.py 脚本时按顺序指定训练阶段即可,会自动根据训练阶段调用相应的 solver 来执行训练过程:

python3 tools/train.py --stage float ...
python3 tools/train.py --stage qat ...
python3 tools/train.py --stage int_infer ...
  • float:正常的浮点训练。

  • qat:QAT训练(量化感知训练),首先初始化一个浮点模型,加载训练好的浮点模型权重,再将此模点模型转为QAT模型进行训练。

  • int_infer:定点转化预测,此阶段首先初始化一浮点模型,将此浮点模型先转为QAT模型并加载训练好的QAT模型权重,再将 QAT模型转为定点模型。转出的定点模型无法进行训练,只能执行validation得到最终的定点模型精度。

6.1.3.8.3. 恢复训练

可以通过在 config{stage}_trainer 中配置 resume_optimizerresume_epoch_or_step 字段来恢复意外中断的训练,或仅恢复optimizer来进行fine-tune。例如:


float_trainer = dict(
    ...
    resume_optimizer=True,
    resume_epoch_or_step=True,
    ...,
)

恢复训练有三种使用场景:

  1. 完全恢复: 该场景为恢复意外中断的训练,会恢复上一个checkpoint的所有状态,包括optimizer、LR、epoch、step 等。该场景只需配置 resume_optimizer 字段即可;

  2. 恢复optimizer用于fine-tune: 该场景只会恢复optimizer和LR的状态,但epoch、step都会从0开始,用于某些任务的 fine-tune。该场景需要配置 resume_optimizer,并且需要配置resume_epoch_or_step=False

  3. 只加载模型参数: 该场景只会加载模型参数,不会恢复其他任何状态(optimizer、epoch、step、LR)。该场景只需要在 model_convert_pipeline 中配置 LoadCheckpoint ,并且需要配置 resume_optimizer=Falseresume_epoch_or_step=False

6.1.3.8.4. qat_mode

6.1.3.8.4.1. 作用

qat_mode 用于设置QAT阶段是否带BN进行量化训练,配合HAT提供的 FuseBN 接口还可以控制量化训练全程带BN或是中途逐步吸收BN。

6.1.3.8.4.2. 可选项定义

qat_mode可选的设置有如下三种:


class QATMode(object):

    FuseBN = "fuse_bn"
    WithBN = "with_bn"
    WithBNReverseFold = "with_bn_reverse_fold"

6.1.3.8.4.3. 原理介绍

6.1.3.8.4.3.1. fuse_bn

QAT阶段没有BN,HAT默认的量化训练方式。

通过将qat_mode设置为 fuse_bn ,在浮点模型op融合的过程中,BN的weight和bias均被吸收到Conv的weight和bias中,原来的Conv + BN的组合将只剩下 Conv,这一吸收过程理论上是没有误差的。

6.1.3.8.4.3.2. with_bn

QAT 阶段带 BN 进行训练。

通过设置qat_mode为 with_bn ,浮点模型转为QAT模型的时候BN不会吸收进Conv,而是在QAT阶段以 Conv + BN + 输出量化节点 的形式作为一个被融合的量化op存在于量化模型中。最终在量化训练结束转为quantized(也称int infer) 模型的步骤中,BN的weight和bias将自动吸收进conv的量化参数中,吸收之后得到的quantized op和原来的QAT op计算结果保持一致。

在这一模式下,用户还可以选择在QAT中途将BN吸收进Conv。用户手动吸收BN前后QAT模型的forward结果不一致,原因是BN weight吸收至Conv weight之后,在之前量化训练中统计出来的量化参数conv_weight_scale不再适用于当前的conv_weight,在对conv_weight的量化中将产生较大误差,需要继续进行量化训练调整量化参数。

6.1.3.8.4.3.3. with_bn_reverse_fold

QAT 阶段带 BN 进行训练。

本模式与 with_bn 的不同之处在于在BN吸收之前,量化训练阶段计算conv_weight_scale时会考虑BN的weight(具体的计算方式不在此详述),目的是为了吸收BN weight之后conv_weight_scale仍然适用于新的conv_weight。

该模式用意是为分步吸收BN提供一种无损的吸收方式:在量化训练中途吸收BN,吸收前后模型forward结果理论上完全一致,用户可以在量化训练结束前逐步吸收模型中所有的BN并且保证每次吸收之后loss不会有太大的波动。

在该模式下如果有BN在量化训练结束时仍未被吸收,在QAT模型转quantized模型的过程中剩余的BN将自动被吸收,这一吸收操作理论上是无损的。

6.1.3.8.4.4. 用法

6.1.3.8.4.4.1. 设置 qat_mode

用户只需要在 model_convert_pipeline 中设置 qat_mode 即可。

例如:


model_convert_pipeline=dict(
    type="ModelConvertPipeline",
    qat_mode="with_bn",
    converters=[
        dict(type="Float2QAT"),
        dict(
            type="LoadCheckpoint",
            checkpoint_path=os.path.join(
                ckpt_dir, "qat-checkpoint-best.pth.tar"
            ),
        ),
    ],
)

6.1.3.8.4.4.2. 查看当前 qat_mode


from horizon_plugin_pytorch.qat_mode import get_qat_mode
qat_mod = get_qat_mode()

6.1.3.8.4.4.3. 设置逐步吸收 BN

with_bnwith_bn_reverse_fold 两种模式下,用户可以将 FuseBN 设置为回调函数用于在指定的epoch或是step吸收指定module中的BN。

FuseBN定义:


class FuseBN(OnlineModelTrick):
    Args:
        module: sub model names to fuse BN.
        step_or_epoch: when to fuseBN, same length as module.
        update_by: by step or by epoch.
        inplace: if fuse BN inplace
    def __init__(
        self,
        modules: List[List[str]],
        step_or_epoch: List[int],
        update_by: str,
        inplace: bool = False,
    )

在config文件中使用FuseBN Example:


from hat.callbacks import FuseBN

# 定义回调函数
# 命名为 backbone 的 module 中的 BN 将在第 1000 个 step 被吸收
# 命名为 neck 的 module 中的 BN 将在第 1500 个 step 被吸收
fuse_bn_callback = FuseBN(
   modules=[['backbone'], ['neck']],
   step_or_epoch=[1000, 1500],
   update_by='step',
)

# 将回调函数加入到 trainer 中
qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
    data_loader=...,
    optimizer=...,
    batch_processor=...,
    num_epochs=...,
    device=None,
    callbacks=[
        callbacks0,
        ..., 
        fuse_bn_callback,
        callbacks99
    ],
    ...,
)

6.1.3.8.4.5. qat_mode 总结

qat_mode

BN 何时被吸收

如何吸收BN

理论上吸收后模型 forward 结果是否有变化

fuse_bn

一定在浮点模型 op 融合过程

执行 fuse_module 之后吸收完成

with_bn

可以在量化训练中途

通过设置回调函数在指定 epoch 或 batch 吸收

with_bn

可以在 QAT 模型转 quantized 模型过程

随 QAT 转 quantized 自动完成

with_bn_reverse_fold

可以在量化训练中途

通过设置回调函数在指定 epoch 或 batch 吸收

with_bn_reverse_fold

可以在 QAT 模型转 quantized 模型过程

随 QAT 转 quantized 自动完成

一般训练流程是浮点训练到理想精度然后量化训练,该流程只需要使用 fuse_bn 即可。如果是没有浮点训练一开始就是量化训练,为了确保模型能收敛,才需要使用带BN的量化训练模式。

注解

本文中之所以说“理论上吸收前后无损”或“无变化”,是由于在实际计算中吸收前后两次浮点计算的结果有较低的概率会在小数点较靠后的数位上不一致,微小的变化加上量化操作导致吸收BN后Conv的输出相比吸收前Conv + BN的输出在部分数值上可能会产生一个输出scale的绝对误差。