7.4.2. Calibration Guide

An important step in quantization is determining the quantization parameters; reasonable initial values can significantly improve model accuracy and speed up convergence. Calibration is the process of inserting Observers into the float model and running a small amount of training data through the model's forward pass to collect the data distribution at each point, in order to determine reasonable quantization parameters. Although quantization-aware training can be performed without Calibration, Calibration generally helps QAT and never hurts it, so we recommend treating this step as mandatory.

7.4.2.1. Workflow and Example

The overall workflow of Calibration and QAT is shown in the figure below:

../../../_images/calibration_v2_workflow.svg

Each step is described below:

  1. Build and train the float model. See the "Get the Float Model" subsection of the horizon_plugin_pytorch quick start chapter.

  2. Insert Observer nodes into the float model. See the "Calibration" subsection of the horizon_plugin_pytorch quick start chapter. Before converting the float model with prepare_qat_fx, you need to set a qconfig on the model:

    model.qconfig = horizon.quantization.get_default_qconfig()
    

    get_default_qconfig can set different observers for weight and activation. The observers currently available for calibration are "min_max", "percentile", "mse", "kl", and "mix". Unless you have special requirements, we recommend the default "min_max" for weight_observer and "mse" for activation_observer. See the algorithm descriptions below for special usage and debugging tips.

    The fake_quant parameters have no effect on Calibration results; keep their defaults.

    def get_default_qconfig(
        activation_fake_quant: Optional[str] = "fake_quant",
        weight_fake_quant: Optional[str] = "fake_quant",
        activation_observer: Optional[str] = "min_max",
        weight_observer: Optional[str] = "min_max",
        activation_qkwargs: Optional[Dict] = None,
        weight_qkwargs: Optional[Dict] = None,
    ):
    
  3. Set the fake quantize state to CALIBRATION:

    horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.CALIBRATION)
    

    Fake quantize has three states; before QAT, calibration, and validation, the model's fake quantize must be set to the corresponding state. In the calibration state, only the statistics of each operator's inputs and outputs are observed. In the QAT state, fake quantization is performed in addition to observing statistics. In the validation state, statistics are not observed and only fake quantization is performed.

    class FakeQuantState(Enum):
        QAT = "qat"
        CALIBRATION = "calibration"
        VALIDATION = "validation"
    
  4. Calibration. Feed the prepared calibration data to the model; during the forward pass, the observers collect the relevant statistics.

  5. Set the model to eval mode and set the fake quantize state to VALIDATION:

    model.eval()
    horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
    
  6. Validate the calibration results. If they are satisfactory, you can convert the model directly to a fixed-point model or continue with quantization-aware training on top of it; if not, adjust the parameters in the calibration qconfig and run calibration again.
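To make the mechanics of the steps above concrete, the following toy sketch shows what an observer does during calibration. It is illustrative only and uses no real horizon_plugin_pytorch APIs: a minimal min/max-style observer watches batches during the forward passes of step 4 and then derives a symmetric int8 scale. In the real workflow, the observers are inserted automatically by prepare_qat_fx.

```python
# Toy sketch of the calibration flow: an observer watches values flowing
# through the model and derives a symmetric int8 scale from them.
# This mimics the idea only; it is not the library implementation.

class ToyMinMaxObserver:
    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        # step 4: collect statistics during the forward pass
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def scale(self):
        # symmetric int8: map [-amax, amax] onto [-127, 127]
        amax = max(abs(self.min_val), abs(self.max_val))
        return amax / 127.0

observer = ToyMinMaxObserver()
for batch in [[-0.5, 0.3, 1.2], [0.8, -1.0, 0.1]]:  # calibration data
    observer.observe(batch)

print(round(observer.scale(), 6))  # amax = 1.2 -> scale = 0.009449
```

In the real flow this bookkeeping happens inside the Observer modules, and the resulting scales are consumed by the fake quantize nodes.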

7.4.2.2. Common Algorithms

Note

For the parameter description of each observer, see the API documentation at the end of this document.

Algorithm     Speed rank    Accuracy rank    Ease-of-use rank
min_max       1             5                1
percentile    2             4                4
mse           5             1                2
kl            4             2                3
mix           3             2                1

The table above compares the performance of the common calibration methods; smaller numbers are better. Speed is the calibration time on the same data, accuracy is how well the method calibrates most models, and ease of use is the complexity of tuning its parameters.

For the same model, different methods and parameters can yield very different accuracy and speed. Recent research also shows that no single method achieves the best accuracy on every model; the parameters must be tuned per model. We therefore recommend trying all of these calibration methods.

  1. min_max. This method only tracks the moving average of the minimum and maximum, and is used to quickly determine common parameters such as batch size and average_constant. There is not much tuning involved.

  2. percentile. This method has the highest accuracy ceiling of all, but is also the most tedious to tune. If another method, or this method with its default parameters, already meets the accuracy requirement, we do not recommend spending much time on tuning. percentile has two tunable parameters: bins and percentile. More bins means smaller gaps between max candidates and finer tuning granularity, but also more computation time. We suggest settling percentile first, then adjusting bins, iterating between the two to narrow the range until the result is satisfactory. In the vast majority of cases, bins = 2048 provides ample granularity and does not need separate tuning. Below is the tuning path for one model:

    Step    percentile    bins    Accuracy
    1       99.99         2048    53.75
    2       99.99         4096    54.38
    3       99.995        4096    16.25
    4       99.985        4096    32.67
    5       99.9875       4096    57.06
    6       99.9875       8192    62.84
    7       99.98875      8192    57.62
    8       99.988125     8192    63.15

    In this example, careful tuning improved accuracy by about 10%. The inputs and outputs of different ops in a model can differ greatly, so a single set of global percentile parameters may not satisfy every op. When high accuracy is required, first find good global parameters with the method above, then use the debug tool to locate the ops with the largest errors and set percentile parameters for those ops individually; see the qconfig settings for how to do this. Below are several common data distributions that tend to cause large errors:

    ../../../_images/calibration_percentile_longtail.png

    An extremely long-tailed distribution: percentile should be set smaller; in this figure, 99.9 is a good value.

    ../../../_images/calibration_percentile_bimodal.png

    The value range is too wide and the distribution is not concentrated in one place. In this case both keeping and discarding the tail cause large accuracy loss; avoid this situation during float training by adjusting parameters such as weight decay.

    ../../../_images/calibration_percentile_ln.png

    The output distribution of layernorm shows several extremely concentrated regions. Here, adjusting percentile in the usual way has no effect on the quantization result; the percentile adjustment step needs to be increased.
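As a concrete illustration of what the percentile and bins parameters control, here is a self-contained sketch (not the library implementation): it builds a histogram of absolute values and clips at the value below which the requested fraction of the samples falls.

```python
# Illustrative percentile-based threshold selection (not PercentileObserver):
# histogram the absolute values, then clip at the given cumulative percentile.

def percentile_threshold(values, percentile=99.99, bins=2048):
    amax = max(abs(v) for v in values)
    width = amax / bins                   # more bins -> finer candidate grid
    hist = [0] * bins
    for v in values:
        idx = min(int(abs(v) / width), bins - 1)
        hist[idx] += 1
    target = percentile / 100.0 * len(values)
    cum = 0
    for i, count in enumerate(hist):
        cum += count
        if cum >= target:
            # right edge of the first bin that reaches the target mass
            return (i + 1) * width
    return amax

# a long-tailed sample: most mass near 0, one outlier at 100
data = [0.1] * 9999 + [100.0]
print(percentile_threshold(data, percentile=99.0))   # -> 0.146484375 (outlier ignored)
print(percentile_threshold(data, percentile=100.0))  # -> 100.0 (outlier kept)
```

A lower percentile discards the long tail, which is exactly the behavior recommended above for long-tailed distributions.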

  3. mse. The only tunable parameter is stride. With the default stride of 1, it tries each of the 100 percentiles of the maximum step by step and picks the value whose quantize-dequantize error (L2 distance) is smallest. This method is time-consuming on large models; increasing stride within a reasonable range reduces runtime while preserving accuracy, but too large a stride will hurt accuracy. Note that tuning this parameter only optimizes runtime; it does not significantly improve accuracy.
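The search mse performs can be sketched as follows. This is an illustrative reimplementation, not MSEObserver itself: candidate thresholds are k% of the observed maximum, stepped by stride, and the one minimizing the L2 reconstruction error is kept.

```python
# Illustrative sketch of the mse linear search (not MSEObserver itself).

def quant_dequant(v, scale):
    # symmetric int8 quantize-dequantize round trip
    q = max(-127, min(127, round(v / scale)))
    return q * scale

def mse_threshold(values, stride=1):
    amax = max(abs(v) for v in values)
    best_t, best_err = amax, float("inf")
    for k in range(1, 101, stride):       # candidate: k% of the observed max
        t = amax * k / 100.0
        scale = t / 127.0
        err = sum((v - quant_dequant(v, scale)) ** 2 for v in values)
        if err < best_err:
            best_t, best_err = t, err
    return best_t

data = [0.01 * i for i in range(-500, 501)]
# stride=1 scores all 100 candidates; stride=10 scores only 10 of them,
# shrinking the search space (and runtime) at some risk to accuracy
t_fine = mse_threshold(data, stride=1)
t_coarse = mse_threshold(data, stride=10)
```

This shows why a larger stride only trades search resolution for speed: the error metric itself is unchanged.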

  4. kl. There are two tunable parameters: bin and update_interval. Because this method is very time-consuming, we do not recommend changing the default bin. update_interval defaults to 1 and is the number of forward steps between KL computations; increasing it reduces runtime (without affecting accuracy), but update_interval must not exceed the total number of calibration steps, otherwise no valid quantization parameters will be produced. We generally recommend setting update_interval to the number of calibration steps: the earlier forward steps then only collect data to update the histogram, and KL and scale are computed only at the last step. This minimizes the KL overhead, and since the final histogram contains the statistics of all input data, accuracy is unaffected.
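For intuition, the KL selection can be sketched as below. This is a heavily simplified, illustrative version of entropy calibration and differs from the real KLObserver in its implementation details: for each candidate truncation point it folds the clipped tail into the last bin, simulates quantization by merging bins, and keeps the threshold with the smallest KL divergence.

```python
# Much-simplified sketch of KL (entropy) threshold selection.
import math

def kl_divergence(p, q):
    # KL(P || Q); bins with p == 0 contribute nothing
    eps = 1e-12
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

def kl_threshold(hist, bin_width, num_quant_bins=128):
    best_t, best_kl = None, float("inf")
    for i in range(num_quant_bins, len(hist) + 1):
        ref = list(hist[:i])
        ref[-1] += sum(hist[i:])          # fold the clipped tail into the last bin
        ratio = i / num_quant_bins        # >= 1, so every merged group is non-empty
        q = [0.0] * i
        for j in range(num_quant_bins):   # merge i bins down to num_quant_bins,
            lo, hi = int(j * ratio), int((j + 1) * ratio)
            avg = sum(ref[lo:hi]) / (hi - lo)
            for k in range(lo, hi):       # ...then expand back for comparison
                q[k] = avg
        sp, sq = sum(ref), sum(q)
        kl = kl_divergence([x / sp for x in ref], [x / sq for x in q])
        if kl < best_kl:
            best_kl, best_t = kl, i * bin_width
    return best_t

# mass concentrated in the first 128 bins plus a thin long tail:
hist = [100] * 128 + [1] * 384
print(kl_threshold(hist, bin_width=1.0))  # -> 128.0 (tail is clipped)
```

Because only the final KL computation matters, deferring it (a large update_interval) changes nothing as long as the histogram has seen all the data, which is why the recommendation above is safe.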

  5. mix. This is a mixed calibration method. For each location to be observed, it tries different parameters of the percentile method and selects the one with the smallest quantize-dequantize error (L2 distance). It is highly automated and has no parameters to tune.
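The mix idea can be sketched as a small search over percentile candidates. This is illustrative only, not MixObserver: each candidate percentile yields a threshold, and the one with the smallest quantize-dequantize L2 error wins.

```python
# Illustrative sketch of the mix selection (not MixObserver itself).

def qdq_error(values, threshold):
    scale = threshold / 127.0           # symmetric int8
    err = 0.0
    for v in values:
        q = max(-127, min(127, round(v / scale)))
        err += (v - q * scale) ** 2
    return err

def pct_threshold(values, percentile):
    s = sorted(abs(v) for v in values)
    idx = min(int(len(s) * percentile / 100.0), len(s) - 1)
    return s[idx]

def mix_threshold(values, candidates=(99.5, 99.9, 99.99, 100.0)):
    best_t, best_err = None, float("inf")
    for p in candidates:                # try each percentile setting...
        t = pct_threshold(values, p)
        err = qdq_error(values, t)      # ...and score it by reconstruction error
        if err < best_err:
            best_t, best_err = t, err
    return best_t

data = [0.01 * i for i in range(-1000, 1001)]
t = mix_threshold(data)
```

Because the error metric does the selection automatically, there is nothing for the user to tune, which matches the description above.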

7.4.2.3. Tuning Tips

  1. More calibration data is better, but because of diminishing returns, once the amount of data is large enough, the accuracy gains become very limited. If the training set is small, all of it can be used for calibration; if it is large, pick a subset of suitable size based on the calibration runtime. We suggest calibrating for at least 10 to 100 steps.

  2. Augmentations such as horizontal flipping are fine; do not use augmentations like mosaic. Try to calibrate with training data processed by the inference-stage preprocessing.

  3. Make the batch size as large as possible; if the data is noisy or the model has many outliers, reduce it moderately. This parameter should be determined while trying the min_max method.

  4. average_constant controls how much each step affects the running minimum and maximum: the smaller it is, the smaller the influence of the current step and the larger the influence of the historical moving average. Tune it between 0.01 and 0.5 depending on the amount of data. With sufficient data (more than 100 steps), use 0.01; with less data, increase it as appropriate; in the extreme case of only 2 steps of data, use 0.5. Determine this parameter while trying the min_max method, then keep it for all other methods.
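The update rule described above is, in formula form, new = (1 - average_constant) * old + average_constant * current. The snippet below (illustrative only, not library code) shows why 0.5 is appropriate when only 2 steps of data are available.

```python
# Moving-average update used for the min/max statistics:
#   stat <- (1 - average_constant) * stat + average_constant * current
# (illustrative formula only, not library code)

def ema(values, average_constant):
    stat = values[0]                    # first observation initializes the stat
    for v in values[1:]:
        stat = (1 - average_constant) * stat + average_constant * v
    return stat

batch_maxes = [1.0, 10.0]               # only 2 calibration steps of data
small = ema(batch_maxes, 0.01)          # ~1.09: barely reacts to step 2
large = ema(batch_maxes, 0.5)           # 5.5: weighs both steps
```

With only 2 steps, average_constant = 0.01 would leave the statistic almost entirely determined by the first batch, which is why a larger constant is recommended for little data.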

  5. When the calibrated model's accuracy is good, freezing the feature map quantization parameters during QAT can give better results; when it is poor, the calibrated quantization parameters should not be frozen. There is no clear-cut standard for "good" or "poor"; you need to experiment. For example, if the model's accuracy is 100 and calibration yields 50, that is clearly not good; but if calibration yields 95, whether that is good enough to freeze the feature map quantization parameters must be tested. The usual practice is to run experiments both with and without freezing and compare.

  6. Try the min_max method first. It is the fastest and is used to get the calibration pipeline running and to determine the batch size and average_constant parameters. Then try percentile, kl, mse, and mix, and pick whichever works best.

7.4.2.4. Observer API Reference

class horizon_plugin_pytorch.quantization.observer_v2.KLObserver(bins: int = 512, update_interval: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

KL observer.

KL observer based on histogram. Histogram is calculated online and won’t be saved.

Parameters
  • bins – Number of histogram bins.

  • update_interval – Interval of computing KL entropy and updating min/max. KLObserver constantly collects histograms of activations, but only performs the KL calculation when update_interval is reached. If it is set to 1, KL entropy is computed at every forward step. A larger interval takes less time and does no harm to calibration accuracy. Setting it to the total number of calibration steps achieves the best performance. update_interval must be no greater than the total calibration steps, otherwise no min/max will be computed.

  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – Whether to sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.MSEObserver(stride: int = 1, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

MSE observer.

Observer module for computing the quantization parameters based on the Mean Square Error (MSE) between the original tensor and the quantized one.

This observer linear searches the quantization scales that minimize MSE.

Parameters
  • stride – Searching stride. A larger value gives a smaller search space, which means less computing time but possibly poorer accuracy. Default is 1. Values no greater than 20 are suggested.

  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – Whether to sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.MinMaxObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Min max observer.

This observer computes the quantization parameters based on minimums and maximums of the incoming tensors. The module records the moving average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – Whether to sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

Record the running minimum and maximum of x.

class horizon_plugin_pytorch.quantization.observer_v2.MixObserver(averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Mix observer.

This observer computes the quantization parameters based on multiple calibration methods and selects the quantization parameters with the smallest quantization error.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – Whether to sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.observer_v2.PercentileObserver(percentile: float = 99.99, bins: int = 2048, averaging_constant: float = 0.01, ch_axis: int = -1, dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8', qscheme: torch.qscheme = torch.per_tensor_symmetric, quant_min: int = None, quant_max: int = None, is_sync_quantize: bool = False, factory_kwargs: Dict = None)

Percentile observer.

Percentile observer based on histogram. Histogram is calculated online and won’t be saved. The minimum and maximum are moving averaged to compute the quantization parameters.

Parameters
  • percentile – Index percentile of the histogram.

  • bins – Number of histogram bins.

  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – Whether to sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class horizon_plugin_pytorch.quantization.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)

MovingAverageMinMax Observer.

Observer module for computing the quantization parameters based on the moving average of the min and max values.

This observer computes the quantization parameters based on the moving averages of minimums and maximums of the incoming tensors. The module records the average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used; only the per_tensor_symmetric scheme is supported.

  • reduce_range – Reduces the range of the quantized data type by 1 bit.

  • quant_min – Minimum quantization value.

  • quant_max – Maximum quantization value.

  • is_sync_quantize – Whether to use sync quantize.

  • factory_kwargs – Arguments for registering data buffers.

forward(x_orig)

Record the running minimum and maximum of x.

class horizon_plugin_pytorch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, quant_min=None, quant_max=None, is_sync_quantize=False, factory_kwargs=None)

MovingAveragePerChannelMinMax Observer.

Observer module for computing the quantization parameters based on the running per channel min and max values.

This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used; only per_channel_symmetric is supported.

  • quant_min – Minimum quantization value.

  • quant_max – Maximum quantization value.

  • is_sync_quantize – Whether to use sync quantize.

  • factory_kwargs – Arguments for registering data buffers.

forward(x_orig)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.