7.5.6. Adaround(实验性功能)

Adaround 是一种业界前沿的 PTQ 量化方法,通过逐层学习模型权重是向上取整还是向下取整,可以取得比传统的四舍五入策略更好的量化精度。在我们的实验中,Adaround 在不少任务中(分类、分割、BEV等)都可以以较小的性能代价有效地提升模型 calibration 精度,成为了现有 calibration 流程的有效补充。

7.5.6.1. 基本原理

Adaround 旨在通过学习更好的取整范式降低权重的量化误差,因此其优化对象是带权重的算子。目前仅支持 Conv 和 Linear。Adaround 会以拓扑顺序逐层优化 Conv/Linear,基于单算子量化误差最小化学习向上/向下取整的 mask,最后 inplace 地修改 weight 完成优化。

7.5.6.2. 接口定义

def weight_reconstruction(
    calib_model: torch.nn.Module,
    batches: Union[list, tuple, DataLoader],
    batch_process_func: Callable = None,
    custom_config_dict: dict = None,
):
    pass

其中,custom_config_dict 为 adaround 相关的一些配置参数。 包含了以下参数:

    custom_config_dict = {
        "num_batches": 10,
        "num_steps": 100,
        "exclude_prefix": [],
        "warm_up": 0.2,
        "weight": 0.01,
        "b_range": [20, 2],
    }

num_batches: 仅当您传的数据是 Dataloader 时才有效,代表了 Dataloader 中参与 adaround 优化的 batch 数量。如果您传的数据是 list/tuple 的格式,则该参数不起作用,adaround 会使用 list 中全部的 batch。一般使用默认值 10 即可。

num_step: 每个 Conv/Linear 的优化次数,次数越大理论效果越好。这是您在 adaround 的调参过程中需要关注的主要参数。

exclude_prefix: 如果您有部分 module 不想被 adaround 优化,可在此添加其 prefix,所有以该 prefix 开头的 module 都不会被优化。在我们的实验中,绝大部分模型都不需要设置该参数,可以稳定地提升 calibration 精度。但极个别检测模型存在优化其检测 head 反而导致精度下降的情况,此时可通过该参数过滤。

warm_up: [0, 1] 之间的参数,表示 warm_up 所占比率,前 warm_up * num_step 的优化不会施加对 round 的正则,使得优化可以完全以精度最优为准则进行。对精度的影响不大,一般保持为默认值 0.2 即可。

weight: round loss 的正则化权重系数,weight 越大则 round loss 在 loss 中的统治地位越强。这是您在 adaround 的调参过程中可关注的次要参数,默认值是 0.01,可适当在默认值上下调节, 如根据 loss 相对大小情况尝试 0.1、 0.001 等。

b_range: b 是决定 round loss 平滑程度的一个参数,b_range 控制其范围。一般不需要调节,保持默认值 [20, 2]即可。这意味着该参数一开始是 20,并随 step 数线性衰减到 2。

batch_size * num_step 是算子在优化过程中实际跑的样本数(样本由于随机采样会有重复),一个可供参考的取值是让 batch_size * num_step 在 10000~20000 左右。

注意

  1. num_step 是影响 Adaround 精度的主要参数,您在调整超参时一般只需关注该参数即可。

  2. 在我们的实验中,Adaround 在大部分任务中都可以通过简单调节 num_step 参数稳定地提升 calibration 精度,但在检测任务中,可能需要仔细设置 exclude_prefix 过滤 head 中的部分层才能实现精度的提升。当您在检测任务中遇到 Adaround 导致模型 calibration 精度下降的情况时,我们建议您直接选择量化感知训练(QAT)提升量化精度。

其余参数说明请参考接口的 docstring。

7.5.6.3. 使用方法

我们支持两种给数据的方式。

7.5.6.3.1. list/tuple(推荐)

由于该接口需要频繁读取数据,我们推荐您将数据打包成 list/tuple 送入接口。这种情况下,我们会将数据全部搬移到显存/内存(取决于模型参数的 device )上,减少频繁读取数据带来的访存瓶颈。相对于传 DataLoader 的方式,该方式具有较大的性能优势。

# 先走正常的 calibration 流程
calib_model = horizon.quantization.prepare_qat_fx(float_model)
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.CALIBRATION
)
for image, label in dataloader:
    calib_model(image)

# 准备 adaround 所需数据
batches = []
n = 0
for image, label in dataloader:
    if n >= 10:
        break
    batches.append(image)
    n += 1

# 自定义 adaround 配置。用户自定义不优化模型中的 head。
custom_config_dict = {"num_steps": 100, "exclude_prefix": ["head",]}

horizon.quantization.weight_reconstruction(
    calib_model,
    batches,
    None, # batch_process_func,由于 batches 中的数据已经满足要求,此处保持默认即可
    custom_config_dict,
)

# eval
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.VALIDATION
)
for image, label in eval_dataloader:
    pred = calib_model(image)
    pass

7.5.6.3.2. torch.utils.data.DataLoader

尽管传 list 的方式在性能上有较大优势,但因为需要将用来校准的数据全部加载到显存/内存上,对计算设备存在一定要求。因此,我们也支持您直接传 torch DataLoader 以方便您在某些场景和任务下的使用。由于 DataLoader 只会在必要的时候加载数据,相比 list 占用显存更小,同时也带来了的较高的访存压力。在我们的实验中,DataLoader 方式的性能表现与 list 方式存在较大差距,请您酌情使用。

# 先走正常的 calibration 流程
calib_model = horizon.quantization.prepare_qat_fx(float_model)
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.CALIBRATION
)
for image, label in dataloader:
    calib_model(image)

# 自定义 adaround 配置,这里和上面不同的是设置了 num_batches 为 16,表示 dataloader 中实际只有 16 个 batch 会参与优化
custom_config_dict = {"num_batches": 16, "num_steps": 100, "exclude_prefix": ["head",]}

horizon.quantization.mix_calibration(
    calib_model,
    dataloader, # 直接传 dataloader
    lambda x: x[0], # batch_process_func,由于该 dataloader 返回的 batch 是 Tuple[image, label] 的格式,所以需要索引后才能送入模型
    custom_config_dict,
)

# eval
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.VALIDATION
)
for image, label in eval_dataloader:
    pred = calib_model(image)
    pass