7.5.7. 自动校准(实验性功能)

量化感知训练工具链目前已经集成了多种 calibration 策略,如 mse、kl、percentile、min-max 等。对于大多数模型,mse 都能取得不错的 calibration 精度。但如果您相关知识储备丰富,有意愿探索精度更高的 calibration 流水线,目前的 calibration 接口可能无法满足您的要求。对此,我们特别探索开发了自动校准的接口,您可以通过该接口自定义需要搜索的 calibration 策略和超参数,基于模型输出相似度逐层搜索最优的量化参数。

本接口与我们 Calibration 流程中的 Mix Observer 有所不同,具体如下:

  1. Mix Observer 在搜索某一算子的量化参数时,只将该算子的输出相似度作为评价指标。而本接口将模型最终输出的相似度作为评价指标来搜索最优量化参数。

  2. Mix Observer 在搜索某一算子的量化参数时,前面的算子都是浮点计算,没有考虑累积的量化误差。而本接口在搜索某一算子的量化参数时,其前面所有的激活和权重都是量化的。

我们开展的消融实验表明上述两点都对模型的 Calibration 精度有一定提升作用。

需要注意的是,由于该策略基于对模型各层的逐层搜索,并以模型最终输出作为量化参数的评价指标,需要耗费较多的时长。

7.5.7.1. 基本原理

  1. 记录浮点模型所有 DeQuantize 算子的输出

  2. 以拓扑排序逐个遍历各个待量化的算子:

    1. 校准某个算子时, 将其 weight(如果有) 和 activation 进行量化,遍历用户指定的 calibration 策略,记录模型对应的 DeQuantize 输出

    2. 对量化输出和浮点输出计算 L2 距离 ,更新最优量化参数

    3. 遍历完所有的 calibration 策略后,将最优量化参数应用到该算子上,开始搜索下一个算子

7.5.7.2. 接口定义

def auto_calibrate(
    calib_model: torch.nn.Module,
    batches: Union[list, tuple, DataLoader],
    num_batches: int = 10,
    batch_process_func: Callable = None,
    observer_list: list = ("percentile", "mse", "kl", "min_max"),
    percentile_list: list = None,
):
    pass

进一步的接口说明请参考接口的 docstring。

7.5.7.3. 使用方法

我们支持两种给数据的方式。

7.5.7.3.1. list/tuple(推荐)

由于该接口需要频繁读取数据,我们推荐您将数据打包成 list/tuple 送入接口。这种情况下,我们会将数据全部搬移到显存/内存(取决于模型参数的 device )上,减少频繁读取数据带来的访存瓶颈。相对于传 DataLoader 的方式,该方式具有较大的性能优势。

calib_model = horizon.quantization.prepare_qat_fx(float_model)
batches = []
n = 0
for image, label in dataloader:
    if n >= 10:
        break
    batches.append(image)
    n += 1

horizon.quantization.auto_calibration(
    calib_model,
    batches,
    10, # num_batches,该方式下不起作用,保持默认即可。list 中所有的 batch 都会被用来校准
    None, # batch_process_func,由于 batches 中的数据已经满足要求,此处保持默认即可
    ["percentile", "min_max"], # 自定义搜索的 calibration 策略
    [99.99, 99.999, 99.9995, 999.9999], # 自定义的 percentile 参数
)

# eval
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.VALIDATION
)
for image, label in eval_dataloader:
    pred = calib_model(image)
    pass

7.5.7.3.2. torch.utils.data.DataLoader

尽管传 list 的方式在性能上有较大优势,但因为需要将用来校准的数据全部加载到显存/内存上,对计算设备存在一定要求。因此,我们也支持您直接传 torch DataLoader 以方便您在某些场景和任务下的使用。由于 DataLoader 只会在必要的时候加载数据,相比 list 占用显存更小,同时也带来了的较高的访存压力。在我们的实验中,DataLoader 方式的性能表现与 list 方式存在较大差距,请您酌情使用。

calib_model = horizon.quantization.prepare_qat_fx(float_model)

horizon.quantization.auto_calibration(
    calib_model,
    dataloader, # 直接传 dataloader
    10, # num_batches,只用 dataloader 中的 10 个 batch 进行校准
    lambda x: x[0], # batch_process_func,由于该 dataloader 返回的 batch 是 Tuple[image, label] 的格式,所以需要索引后才能送入模型
    ["percentile", "min_max"], # 自定义搜索的 calibration 策略
    [99.99, 99.999, 99.9995, 999.9999], # 自定义的 percentile 参数
)

# eval
calib_model.eval()
horizon.quantization.set_fake_quantize(
    calib_model, horizon.quantization.FakeQuantState.VALIDATION
)
for image, label in eval_dataloader:
    pred = calib_model(image)
    pass