7.4.3. 量化感知训练指南

量化感知训练通过在模型中插入一些伪量化节点,从而使得通过量化感知训练得到的模型转换成定点模型时尽可能减少精度损失。 量化感知训练和传统的模型训练无异,开发者可以从零开始,搭建一个伪量化模型,然后对该伪量化模型进行训练。 由于部署的硬件平台有诸多限制,对于开发者来说,搞清这些限制,并且根据这些限制搭建伪量化模型门槛较高。量化感知训练工具通过在开发者提供的浮点模型上根据部署平台的限制自动插入伪量化量化算子的方法,降低开发者开发量化模型的门槛。

量化感知训练由于施加了各种限制,因此,一般来说,量化感知训练比纯浮点模型的训练更加困难。量化感知训练工具的目标是降低量化感知训练的难度,降低量化模型部署的工程难度。

7.4.3.1. 流程和示例

虽然量化感知训练工具不强制要求用户从一个预训练的浮点模型开始,但是,经验表明,通常从预训练的高精度浮点模型开始量化感知训练能大大降低量化感知训练的难度。

from horizon_plugin_pytorch.quantization import get_default_qconfig
# 将模型转为 QAT 状态
default_qat_8bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant="fake_quant",
    weight_fake_quant="fake_quant",
    activation_observer="min_max",
    weight_observer="min_max",
    activation_qkwargs=None,
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)
default_qat_out_8bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant=None,
    weight_fake_quant="fake_quant",
    activation_observer=None,
    weight_observer="min_max",
    activation_qkwargs=None,
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)
qat_model = prepare_qat_fx(
    float_model,
    {
        "": default_qat_8bit_fake_quant_qconfig,
        "module_name": {
            "classifier": default_qat_out_8bit_fake_quant_qconfig,
        },
    },
).to(device)
# 加载 Calibration 模型中的量化参数
qat_model.load_state_dict(calib_model.state_dict())
# 进行量化感知训练
# 作为一个 filetune 过程,量化感知训练一般需要设定较小的学习率
optimizer = torch.optim.SGD(
    qat_model.parameters(), lr=0.0001, weight_decay=2e-4
)

for nepoch in range(epoch_num):
    # 注意此处对 QAT 模型 training 状态的控制方法
    qat_model.train()
    set_fake_quantize(qat_model, FakeQuantState.QAT)

    train_one_epoch(
        qat_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )

    # 注意此处对 QAT 模型 eval 状态的控制方法
    qat_model.eval()
    set_fake_quantize(qat_model, FakeQuantState.VALIDATION)

    # 测试 qat 模型精度
    top1, top5 = evaluate(
        qat_model,
        eval_data_loader,
        device,
    )
    print(
        "QAT model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

# 测试 quantized 模型精度
quantized_model = convert_fx(qat_model.eval()).to(device)

top1, top5 = evaluate(
    quantized_model,
    eval_data_loader,
    device,
)
print(
    "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)

注意

由于部署平台的底层限制,QAT 模型无法完全代表最终上板精度,请务必监控 quantized 模型精度,确保 quantized 模型精度正常,否则可能出现模型上板掉点问题。

由上述示例代码可以看到,与传统的纯浮点模型训练相比,量化感知训练多了两个步骤:

  1. prepare_qat_fx

  2. 加载 Calibration 模型参数

7.4.3.1.1. prepare_qat_fx

这一步骤的目标是对浮点网络进行变换,插入伪量化节点。

7.4.3.1.2. 加载 Calibration 模型参数

通过加载 Calibration 得到的伪量化参数,来获得一个较好的初始化。

7.4.3.1.3. 训练迭代

至此,完成了伪量化模型的搭建和参数的初始化,然后就可以进行常规的训练迭代和模型参数更新,并且监控 quantized 模型精度。

7.4.3.2. 伪量化算子

量化感知训练和传统的浮点模型的训练主要区别在于插入了伪量化算子,并且,不同量化感知训练算法也是通过伪量化算子来体现的,因此,这里介绍一下伪量化算子。

注解

由于 BPU 只支持对称量化,因此,这里以对称量化为例介绍。

7.4.3.2.1. 伪量化过程

以 int8 量化感知训练为例,一般来说,伪量化算子的计算过程如下:

fake_quant_x = clip(round(x / scale),-128, 127) * scale

和 Conv2d 通过训练来优化 weight, bias 参数类似,伪量化算子要通过训练来优化 scale 参数。 然而,由于 round 作为阶梯函数,其梯度为 0,从而导致了伪量化算子无法直接通过梯度反向传播的方式进行训练。解决这一问题,通常有两种方案:基于统计的方法和基于“学习”的方法。

7.4.3.2.2. 基于统计的方法

量化地目标是把 Tensor 中的浮点数通过 scale 参数均匀地映射到 int8 表示的 [-128, 127] 的范围上。既然是均匀映射,那么很容易得到 scale 的计算方法:

def compute_scale(x: Tensor):
    xmin, xmax = x.max(), maxv = x.min()
    return max(xmin.abs(), xmax.abs()) / 256.0

由于 Tensor 中数据分布不均匀以及外点问题,又衍生了不同的计算 xmin 和 xmax 的方法。可以参考 MovingAverageMinMaxObserver 等。

在工具中的使用方法请参考 default_qat_8bit_fake_quant_qconfig 及其相关接口。

7.4.3.2.3. 基于学习的方法

虽然 round 的梯度为 0,研究者通过实验发现,在该场景下,如果直接设置其梯度为 1 也可以使得模型收敛到预期的精度。

def round_ste(x: Tensor):
    return (x.round() - x).detach() + x

在工具中的使用方法请参考 default_qat_8bit_lsq_quant_qconfig 及其相关接口。

有兴趣进一步了解的用户可以参考如下论文:Learned Step Size Quantization