7.6.3. Fake Quantize Operator

class horizon_plugin_pytorch.quantization.FakeQuantize(observer: type = <class 'horizon_plugin_pytorch.quantization.observer.MovingAverageMinMaxObserver'>, saturate: bool = None, in_place: bool = False, compat_mask: bool = True, channel_len: int = 1, fast_training=True, **observer_kwargs)

Simulates the quantize and dequantize operations at training time.

The output of this module is given by

fake_quant_x = clamp(floor(x / scale + 0.5), quant_min, quant_max) * scale

  • scale defines the scale factor used for quantization.

  • zero_point specifies the quantized value to which 0 in floating point maps.

  • quant_min specifies the minimum allowable quantized value.

  • quant_max specifies the maximum allowable quantized value.

  • fake_quant_enabled controls the application of fake quantization on tensors; note that statistics can still be updated.

  • observer_enabled controls statistics collection on tensors.

  • dtype specifies the quantized dtype that is being emulated with fake quantization; the allowable values are qint8 and qint16. The values of quant_min and quant_max should be chosen to be consistent with the dtype.
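The formula above can be sketched in plain Python for a single value (an illustrative re-implementation, not the plugin's actual kernel; the default range shown assumes qint8):

```python
import math

def fake_quantize(x, scale, quant_min=-128, quant_max=127):
    """Illustrative fake quantization of one float value (qint8 range).

    Mirrors: clamp(floor(x / scale + 0.5), quant_min, quant_max) * scale
    """
    q = math.floor(x / scale + 0.5)        # quantize (round half away from zero for x > 0)
    q = max(quant_min, min(quant_max, q))  # clamp to the representable integer range
    return q * scale                       # dequantize back to float

fake_quantize(0.374, 0.01)  # floor(37.4 + 0.5) = 37, so the result is ~0.37
fake_quantize(10.0, 0.01)   # 1000 clamps to quant_max = 127, so the result is ~1.27
```

Values that overflow the integer range are saturated at quant_min or quant_max, which is where the quantization error for out-of-range inputs comes from.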

Parameters
  • observer – Module for observing statistics on input tensors and calculating scale and zero-point.

  • saturate – Whether to zero out the gradient for values outside the quantization range.

  • in_place – Whether to perform fake quantization in place.

  • compat_mask – Whether to pack the bool mask into a bitfield when saturate = True.

  • channel_len – Size of the data along the channel dimension.

  • fast_training – Whether to use fast training mode. If True, scale computation and fake quantization are done in one step.

  • observer_kwargs – Arguments for the observer module.

observer

User provided module that collects statistics on the input tensor and provides a method to calculate scale and zero-point.

extra_repr()

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
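The reason for calling the module instance rather than forward() directly can be shown with a minimal sketch of how __call__ dispatches to hooks (a simplification; the real torch.nn.Module machinery handles many more hook types):

```python
class MiniModule:
    """Simplified stand-in for torch.nn.Module's __call__ machinery."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        out = self.forward(*args)         # run the user-defined computation
        for hook in self._forward_hooks:  # then fire every registered hook
            hook(self, args, out)
        return out

    def forward(self, x):
        return x * 2

m = MiniModule()
seen = []
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3)          # runs forward AND the hook, so seen becomes [6]
m.forward(3)  # bypasses __call__; the hook is silently skipped
```

This is why observer statistics or other hook-driven behavior may silently fail to run if forward() is invoked directly.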

set_qparams(scale: Union[torch.Tensor, Sequence, float], zero_point: Optional[Union[torch.Tensor, Sequence, int]] = None)

Set the quantization parameters. If zero_point is not given, symmetric quantization (zero_point = 0) is assumed.
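For context, a symmetric scale for qint8 is typically derived from the observed dynamic range as follows (an illustrative computation, not the plugin's observer code); the resulting value is the kind of argument one would pass to set_qparams:

```python
def symmetric_scale(min_val, max_val, quant_max=127):
    """Derive a symmetric qint8 scale from an observed [min_val, max_val] range.

    With symmetric quantization the zero point is fixed at 0, so the scale
    only needs to cover the largest absolute value in the observed range.
    """
    amax = max(abs(min_val), abs(max_val))
    return amax / quant_max

scale = symmetric_scale(-0.42, 1.0)  # 1.0 / 127, roughly 0.00787
# fake_quant.set_qparams(scale)      # hypothetical usage on a FakeQuantize instance
```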

classmethod with_args(**kwargs)

Wrapper that allows creation of class factories.

This can be useful when there is a need to create classes with the same constructor arguments, but different instances. Can be used in conjunction with _callable_args.

Example:

>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False