10.1.3.10. 数据校准

在量化训练(QAT)中,一个重要的步骤是确定量化参数 scale ,一个合理的 scale 能够显著提升模型训练结果和加快模型的收敛速度。Calibration是通过用浮点模型在训练集上跑少数batch的数据(只跑forward过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_valuemax_value ,然后可以用这些 min_valuemax_value 去获取scale。当QAT的训练精度上不去的时候,在QAT的开始之前使用calibration做量化参数的微调,获取scale,可以为QAT提供更好的量化初始化参数,提升收敛速度和精度。

10.1.3.10.1. 如何定义 Calibration 模型

  • 默认不需要对现有模型做任何修改

    类似于定义量化模型时需要设置 QAT QConfig ,Calibration时也需要对模型设置 Calibration QConfig 。不过, Calibration QConfig 的设置相对来说比较简单,HAT框架已经实现对模型 Calibration QConfig 的默认设置,用户无需对模型做任何修改,即可使用Calibration。

  • 自定义模型子模块 Calibration QConfig

    在上文的默认情况下,会为模型的所有Module(继承自nn.Module)设置 Calibration QConfig 。因此,Calibration时也就会对所有 Module 的特征分布进行统计。如果有特殊需求,可以在模型内自定义实现 set_calibration_qconfig 方法:

    
    class Classifier(nn.Module):
        def __init__(self,):
            ...
      
        def forward(self, x):
            ...
          
        # 自定义要做 Calibration 的模块
        def set_calibration_qconfig(self, ):
          
            # 比如可以设置 Loss 的 qconfig 为 None,就会不再对 Loss 做 Calibration,
            # 可以一定程度减少统计量,提升 Calibration 速度,降低显存占用
            if self.loss is not None:
                self.loss.qconfig = None
    
    

10.1.3.10.2. 浮点模型做 Calibration

HAT中集成了Calibration功能,浮点模型做Calibration命令和正常训练相似,只需执行以下命令即可:

python3 tools/train.py --stage calibration ...

需要注意的是 config 文件中 calibration_trainer 中的一些配置:


# Note: The transforms of the dataset during calibration can be
# consistent with that during training or validation, or customized.
# Default used `val_batch_processor`.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader.pop('sampler')  # Calibration do not support DDP or DP
calibration_batch_processor = copy.deepcopy(val_batch_processor)

calibration_trainer = dict(
    type="Calibrator",
    model=model,
    # 1. 设置 data_loader 和 batch_processor
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    # 2. 设置 calibration 迭代的 batch 数目
    num_stages=30,
    ...   
)

1. 数据集的设置:

做Calibration的数据集(dataset)不能是测试集(可以是训练集或其他数据),但是做Calibration时用于数据增强的transforms 可以和正常训练时的transforms保持一致,但是也可以设置成和validation的transforms一致,也可以自定义transforms。(哪种实验效果最好,暂时没有定论,都可以尝试。)

2. Calibration 迭代的图片数目(可供参考):

  • classification:图片张数一般可以500~1500张就可以取得不错的效果。

  • segmentation&&detection:图片张数可以100~300张左右。

注解

这些图片张数具体数目也不是固定的,上方的建议只是从已有的实验中总结的经验,可根据实际情况调整。

10.1.3.10.3. 使用Calibration模型做QAT训练


qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        # (可选) 设置 QAT 训练时 scale 更新系数
        qconfig_params=dict(
            activation_qkwargs=dict(
                averaging_constant=0,
            ),
            weight_qkwargs=dict(
                averaging_constant=1,
            ),
        ),
        converters=[
            dict(type="Float2QAT"),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "calibration-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
)

QAT时averaging_constant参数设置:

量化时scale参数的更新规则是 scale = (1 - averaging_constant) * scale + averaging_constant * current_scale

在已有的一些实验中(主要是图像分类任务实验)发现,做完calibration后,把activation的scale固定住,不进行更新,即设置activation的 averaging_constant=0 , 并设置weight的 averaging_constant=1 ,效果可能会相对略好一些。

注解

这种设置并不适用于所有任务,在lidar任务中,固定scale,精度也可能会变差。可根据实际情况调整。

接下来只需要执行正常的QAT训练命令,即可启动QAT训练:

python3 tools/train.py --stage qat ...