7.6.4. QAT

horizon_plugin_pytorch.quantization.convert(module: torch.nn.modules.module.Module, mapping: Optional[Dict[Type[torch.nn.modules.module.Module], Type[torch.nn.modules.module.Module]]] = None, inplace: bool = False, remove_qconfig: bool = True, fast_mode: bool = False)

Convert modules.

Convert submodules of the input module to a different module type according to mapping, by calling the from_float method of the target module class. If remove_qconfig is True, the qconfig attributes are removed at the end.

Parameters
  • module – input module

  • mapping – a dictionary that maps source module types to target module types; it can be overridden to allow swapping user-defined modules

  • inplace – carry out model transformations in place; the original module is mutated

  • fast_mode – whether to accelerate the forward pass of the quantized model. If set to True, the quantized model cannot be compiled
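
The mapping-driven swap described above can be illustrated with a small, framework-free sketch. Everything in it (`Module`, `FloatConv`, `QuantizedConv`, `toy_convert`) is a hypothetical stand-in written for this example, not the plugin's implementation; it only mirrors the documented idea of looking up each submodule's type in `mapping`, calling `from_float` on the target class, and optionally clearing `qconfig` afterwards. Unlike the real function, this toy always works in place.

```python
from typing import Dict, Optional, Type


class Module:
    """Tiny stand-in for torch.nn.Module (hypothetical, illustration only)."""

    def __init__(self) -> None:
        self.children: Dict[str, "Module"] = {}
        self.qconfig: Optional[str] = None


class FloatConv(Module):
    """Hypothetical source module type."""


class QuantizedConv(Module):
    """Hypothetical target module type."""

    @classmethod
    def from_float(cls, mod: Module) -> "QuantizedConv":
        # A real from_float would also carry over weights and scales.
        new = cls()
        new.qconfig = mod.qconfig
        return new


def strip_qconfig(module: Module) -> None:
    """Clear qconfig on the module and all of its descendants."""
    module.qconfig = None
    for child in module.children.values():
        strip_qconfig(child)


def toy_convert(
    module: Module,
    mapping: Dict[Type[Module], Type[Module]],
    remove_qconfig: bool = True,
) -> Module:
    """Swap children whose type is in `mapping`; recurse into the others."""
    for name, child in list(module.children.items()):
        target = mapping.get(type(child))
        if target is not None:
            module.children[name] = target.from_float(child)
        else:
            toy_convert(child, mapping, remove_qconfig=False)
    if remove_qconfig:
        strip_qconfig(module)
    return module


root = Module()
root.children["conv"] = FloatConv()
root.children["conv"].qconfig = "int8"
converted = toy_convert(root, {FloatConv: QuantizedConv})
```

After the call, `converted.children["conv"]` is a `QuantizedConv` and its `qconfig` has been cleared, because `remove_qconfig` defaults to True in this sketch as it does in the documented signature.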

horizon_plugin_pytorch.quantization.convert_fx(graph_module: torch.fx.graph_module.GraphModule, inplace: bool = False, convert_custom_config_dict: Optional[Dict[str, Any]] = None, _remove_qconfig: bool = True, fast_mode: bool = False) → horizon_plugin_pytorch.quantization.fx.graph_module.QuantizedGraphModule

Convert a calibrated or trained model to a quantized model.

Parameters
  • graph_module – A prepared and calibrated/trained model (GraphModule)

  • inplace – Carry out model transformations in-place, the original module is mutated.

  • convert_custom_config_dict

    dictionary for custom configurations for convert function:

    convert_custom_config_dict = {
        # We automatically preserve all attributes; this option is
        # just in case and not likely to be used.
        "preserved_attributes": ["preserved_attr"],
    }
    

  • _remove_qconfig – whether to remove the qconfig attributes from the model after conversion. For internal use only.

  • fast_mode – whether to accelerate the forward pass of the quantized model. If set to True, the quantized model cannot be compiled.

Returns

A quantized model (GraphModule)

Example:

# prepared_model: the model after prepare_fx/prepare_qat_fx and
# calibration/training
quantized_model = convert_fx(prepared_model)

horizon_plugin_pytorch.quantization.fuse_fx(model: torch.nn.modules.module.Module, fuse_custom_config_dict: Optional[Dict[str, Any]] = None) → horizon_plugin_pytorch.quantization.fx.graph_module.GraphModuleWithAttr

Fuse modules like conv+add+bn+relu etc.

Fusion rules are defined in horizon_plugin_pytorch.quantization.fx.fusion_pattern.py

Parameters
  • model – a torch.nn.Module model

  • fuse_custom_config_dict

    Dictionary for custom configurations for fuse_fx, e.g.

    fuse_custom_config_dict = {
        # We automatically preserve all attributes; this option is
        # just in case and not likely to be used.
        "preserved_attributes": ["preserved_attr"],
    }
    

Example:

from horizon_plugin_pytorch.quantization import fuse_fx
m = fuse_fx(m)

horizon_plugin_pytorch.quantization.fuse_known_modules(mod_list, is_qat=False, additional_fuser_method_mapping=None)

Fuse modules.

Return a list of modules that fuses the operations specified in the input module list.

Fuses only the following sequences of modules: conv, bn; conv, bn, relu; conv, relu; conv, bn, add; conv, bn, add, relu; conv, add; conv, add, relu; linear, bn; linear, bn, relu; linear, relu; linear, bn, add; linear, bn, add, relu; linear, add; linear, add, relu. For these sequences, the first element in the output module list performs the fused operation. The remaining elements are set to nn.Identity().
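
The output contract described above (a list of the same length, with the fused operation first and identities after it) can be sketched without any framework. The `Op` class, the `ALLOWED` table, and `toy_fuse_known_modules` below are hypothetical names made up for this illustration; each op is a plain unary function, and raising `ValueError` for an unlisted sequence is an assumption of the sketch, not documented behavior.

```python
from typing import Callable, List

# The sequences listed above, written as tuples of op kinds.
ALLOWED = {
    ("conv", "bn"), ("conv", "bn", "relu"), ("conv", "relu"),
    ("conv", "bn", "add"), ("conv", "bn", "add", "relu"),
    ("conv", "add"), ("conv", "add", "relu"),
    ("linear", "bn"), ("linear", "bn", "relu"), ("linear", "relu"),
    ("linear", "bn", "add"), ("linear", "bn", "add", "relu"),
    ("linear", "add"), ("linear", "add", "relu"),
}


class Op:
    """Hypothetical stand-in for a module: a kind tag plus a unary function."""

    def __init__(self, kind: str, fn: Callable[[float], float]) -> None:
        self.kind = kind
        self.fn = fn

    def __call__(self, x: float) -> float:
        return self.fn(x)


def toy_fuse_known_modules(mod_list: List[Op]) -> List[Op]:
    """Return [fused, identity, ...] with the same length as mod_list."""
    kinds = tuple(m.kind for m in mod_list)
    if kinds not in ALLOWED:
        # Assumption of this sketch: unsupported sequences are rejected.
        raise ValueError(f"unsupported sequence: {kinds}")
    fns = [m.fn for m in mod_list]

    def fused_fn(x: float) -> float:
        for fn in fns:
            x = fn(x)
        return x

    fused = Op("+".join(kinds), fused_fn)
    identities = [Op("identity", lambda x: x) for _ in mod_list[1:]]
    return [fused] + identities


def run(mods: List[Op], x: float) -> float:
    """Apply the ops in order, as a sequential container would."""
    for m in mods:
        x = m(x)
    return x


mods = [
    Op("conv", lambda x: 2.0 * x),
    Op("bn", lambda x: x - 3.0),
    Op("relu", lambda x: max(x, 0.0)),
]
fused_mods = toy_fuse_known_modules(mods)
```

Because the list keeps its length, each returned element can be written back to the slot of the corresponding original module, and running the fused list gives the same result as running the original one.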

horizon_plugin_pytorch.quantization.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)

Fuses a list of modules into a single module.

Fuses only the following sequences of modules: conv, bn; conv, bn, relu; conv, relu; conv, bn, add; conv, bn, add, relu; conv, add; conv, add, relu; linear, bn; linear, bn, relu; linear, relu; linear, bn, add; linear, bn, add, relu; linear, add; linear, add, relu. For these sequences, the first element in the output module list performs the fused operation. The remaining elements are set to nn.Identity().

Parameters
  • model – Model containing the modules to be fused

  • modules_to_fuse – a list of lists of module names to fuse. Can also be a single list of strings if there is only one group of modules to fuse.

  • inplace – bool specifying whether fusion happens in place on the model. By default a new model is returned.

  • fuser_func – function that takes a list of modules and returns a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to fuse_known_modules.

  • fuse_custom_config_dict – custom configuration for fusion

# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}

Returns

Model with fused modules. A new copy is created if inplace=False.

Examples:

>>> m = M().eval()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [['conv1', 'bn1', 'relu1'],
...                    ['submodule.conv', 'submodule.relu']]
>>> fused_m = fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = M().eval()
>>> # Alternatively, provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

horizon_plugin_pytorch.quantization.prepare_qat(model: torch.nn.modules.module.Module, mapping: Optional[Dict[Type[torch.nn.modules.module.Module], Type[torch.nn.modules.module.Module]]] = None, inplace: bool = False, optimize_graph: bool = False, hybrid: bool = False, optimize_kwargs: Optional[Dict[str, Tuple]] = None, example_inputs: Any = None, qconfig_setter: Optional[Union[Tuple[horizon_plugin_pytorch.quantization.qconfig_template.QconfigSetterBase, ...], horizon_plugin_pytorch.quantization.qconfig_template.QconfigSetterBase]] = None, verbose: int = 0)

Prepare qat.

Prepare a copy of the model for quantization-aware training and convert it to its quantized version.

Quantization configuration should be assigned preemptively to individual submodules in .qconfig attribute.

Parameters
  • model – input model. It is modified in place only if inplace is True; otherwise a copy is prepared.

  • mapping – dictionary that maps float module types to the quantized module types that replace them.

  • inplace – carry out model transformations in place; the original module is mutated

  • optimize_graph – whether to preprocess the original model for special purposes. Currently the only supported optimization uses torch.fx to fix the cat input scale (only used on Bernoulli).

  • hybrid – whether to generate a hybrid model in which some intermediate operations are computed in float. This functionality currently has the following constraints: 1. The hybrid model cannot pass check_model and cannot be compiled. 2. Some quantized operations cannot directly accept input from a float operation, so the user must manually insert a QuantStub.

  • optimize_kwargs

    a dict of options for graph optimization, in the following format:

    optimize_kwargs = {
        # optional, specify which type of optimization to do. Only
        # support "unify_inputs_scale" now
        "opt_types": ("unify_inputs_scale",),
    
        # optional, optimize modules whose qualified name starts
        # with one of these prefixes
        "module_prefixes": ("backbone.conv",),
    
        # optional, modules of these types will be optimized
        "module_types": (horizon.nn.qat.conv2d,),
    
        # optional, functions to optimize
        "functions": (torch.clamp,),
    
        # optional, methods to optimize. Only support
        # FloatFunctional methods now
        "methods": ("add",),
    }
    

  • example_inputs – model inputs, used to trace the model or check the model structure.

  • qconfig_setter – Qconfig setter. Only needed when using qconfig template.

  • verbose

    whether to check the model structure. It has two levels: 0: do nothing; 1: check the model structure, covering:

    1. whether the model has shared ops

    2. whether the model has unfused operations

    3. the model quantization config

horizon_plugin_pytorch.quantization.prepare_qat_fx(model: Union[torch.nn.modules.module.Module, torch.fx.graph_module.GraphModule], qconfig_dict: Optional[Dict[str, Any]] = None, prepare_custom_config_dict: Optional[Dict[str, Any]] = None, optimize_graph: bool = False, hybrid: bool = False, hybrid_dict: Optional[Dict[str, List]] = None, opset_version: str = 'hbdk3', example_inputs: Any = None, qconfig_setter: Optional[Union[Tuple[horizon_plugin_pytorch.quantization.qconfig_template.QconfigSetterBase, ...], horizon_plugin_pytorch.quantization.qconfig_template.QconfigSetterBase]] = None, verbose: int = 0) → horizon_plugin_pytorch.quantization.fx.graph_module.ObservedGraphModule

Prepare a model for quantization aware training.

Parameters
  • model – a torch.nn.Module model or a GraphModule model (possibly produced by fuse_fx)

  • qconfig_dict

    qconfig_dict is a dictionary with the following configurations:

    qconfig_dict = {
        # optional, global config
        "": qconfig,
    
        # optional, used for module types
        "module_type": [
            (torch.nn.Conv2d, qconfig),
            ...,
        ],
    
        # optional, used for module names
        "module_name": [
            ("foo.bar", qconfig),
            ...,
        ],
        # priority (in increasing order):
        #   global, module_type, module_name, module.qconfig
        # qconfig == None means quantization should be
        # skipped for anything matching the rule.
        # The qconfig of function or method is the same as the
        # qconfig of its parent module, if it needs to be set
        # separately, please wrap this function as a module.
    }
    

  • prepare_custom_config_dict

    customization configuration dictionary for quantization tool:

    prepare_custom_config_dict = {
        # We automatically preserve all attributes; this option is
        # just in case and not likely to be used.
        "preserved_attributes": ["preserved_attr"],
    }
    

  • optimize_graph – whether to preprocess the original model for special purposes. Currently the only supported optimization uses torch.fx to fix the cat input scale (only used on Bernoulli).

  • hybrid – whether to prepare the model in hybrid mode. The default is False, in which case the model runs entirely on BPU. It should be True if the model is quantized by model convert or contains some CPU ops. In hybrid mode, ops that are not supported by BPU and ops that are specified by the user run on CPU. Setting qconfig: the qconfig in hybrid mode is the same as in non-hybrid mode. The input of a BPU op must be quantized, so the activation qconfig of its previous non-QuantStub op must not be None, even if that op is a CPU op. Specifying CPU ops: define the CPU module_name or module_type in hybrid_dict.

  • hybrid_dict

    hybrid_dict is a dictionary that defines user-specified CPU ops:

    hybrid_dict = {
        # optional, used for module types
        "module_type": [torch.nn.Conv2d, ...],
    
        # optional, used for module names
        "module_name": ["foo.bar", ...],
    }
    # priority (in increasing order): module_type, module_name
    # To set a function or method as CPU op, wrap it as a module.
    

  • opset_version – specifies the version of the opset, which determines the behavior of hybrid mode. Ops that are in the quantized opset are treated as quantized ops and run on BPU, while ops that are not in the quantized opset but are in the float opset are marked as hybrid (float) ops and run on CPU. Valid options are “hbdk3” and “hbdk4”.

  • example_inputs – model inputs, used to trace the model or check the model structure.

  • qconfig_setter – Qconfig setter. Only needed when using qconfig template.

  • verbose

    whether to check the model structure. Levels: 0: do nothing; 1: check the QAT model structure, covering:

    1. whether the model has shared ops

    2. whether the model has unfused operations

    3. the model quantization config

Returns

A GraphModule with fake quant modules (configured by qconfig_dict), ready for quantization aware training

Example:

import torch
from horizon_plugin_pytorch.quantization import get_default_qat_qconfig
from horizon_plugin_pytorch.quantization import prepare_qat_fx

qconfig = get_default_qat_qconfig()

def train_loop(model, train_data):
    model.train()
    for image, target in train_data:
        ...

qconfig_dict = {"": qconfig}
# float_model: the float torch.nn.Module to be quantized
prepared_model = prepare_qat_fx(float_model, qconfig_dict)
# Run QAT training; train_data is the user's training data loader
train_loop(prepared_model, train_data)
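
The priority order stated in the qconfig_dict comments (global, then module_type, then module_name, then the module's own qconfig attribute, in increasing order) can be sketched as a small lookup. The `resolve_qconfig` helper, the `Conv2d` and `Linear` placeholder classes, and the string qconfigs are all hypothetical and exist only for this illustration; exact-name and exact-type matching is an assumption of the sketch, and the real tool may match differently.

```python
from typing import Any, Dict

_UNSET = object()  # sentinel: the module has no qconfig attribute of its own


class Conv2d:
    """Placeholder for torch.nn.Conv2d (hypothetical)."""


class Linear:
    """Placeholder for torch.nn.Linear (hypothetical)."""


def resolve_qconfig(
    qconfig_dict: Dict[str, Any],
    module_name: str,
    module_type: type,
    module_qconfig: Any = _UNSET,
) -> Any:
    """Pick a qconfig; later rules override earlier ones."""
    result = qconfig_dict.get("")  # global config, lowest priority
    for typ, qconfig in qconfig_dict.get("module_type", []):
        if typ is module_type:
            result = qconfig
    for name, qconfig in qconfig_dict.get("module_name", []):
        if name == module_name:
            result = qconfig
    if module_qconfig is not _UNSET:
        result = module_qconfig  # module.qconfig, highest priority
    return result


toy_qconfig_dict = {
    "": "global_q",
    "module_type": [(Conv2d, "conv_q")],
    # None means quantization is skipped for the matching module.
    "module_name": [("head.conv", None)],
}

fc_q = resolve_qconfig(toy_qconfig_dict, "backbone.fc", Linear)
conv_q = resolve_qconfig(toy_qconfig_dict, "backbone.conv", Conv2d)
head_q = resolve_qconfig(toy_qconfig_dict, "head.conv", Conv2d)
own_q = resolve_qconfig(toy_qconfig_dict, "head.conv", Conv2d, "own_q")
```

Here `backbone.fc` falls back to the global entry, `backbone.conv` picks up the type rule, `head.conv` is skipped by its name rule, and a qconfig set directly on the module overrides all three.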

Extended tracer and wrap of torch.fx.

This file defines a tracer that inherits from torch.fx.Tracer and an extended wrap that allows wrapping user-defined modules or methods, which helps users optimize their own modules with torch.fx.

horizon_plugin_pytorch.fx.fx_helper.wrap(skip_compile: bool = False)

Extend torch.fx.wrap.

This function can be:

1) called or used as a decorator on a string to register a builtin function as a “leaf function”

2) called or used as a decorator on a function to register this function as a “leaf function”

3) called or used as a decorator on a subclass of torch.nn.Module to register this module as a “leaf module”, and register all user-defined methods in this class as “leaf methods”

4) called or used as a decorator on a class method to register it as “leaf method”

Parameters

skip_compile – Whether the wrapped part should not be compiled.

Returns

The actual decorator.

Return type

wrap_inner
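
The calling pattern described for wrap (a factory that takes skip_compile and returns an inner decorator accepting a string, a function, a class, or a method) can be sketched with plain Python registries. This is not the plugin's implementation: `toy_wrap`, the registry sets, and the qualified-name heuristic for detecting methods are assumptions made for illustration, and the sketch does not interact with torch.fx at all.

```python
import inspect
from typing import Any, Callable, Set

# Hypothetical registries standing in for the tracer's leaf tables.
LEAF_FUNCTIONS: Set[str] = set()
LEAF_MODULES: Set[str] = set()
LEAF_METHODS: Set[str] = set()
SKIP_COMPILE: Set[str] = set()


def toy_wrap(skip_compile: bool = False) -> Callable[[Any], Any]:
    """Return a decorator that registers its argument as a leaf."""

    def wrap_inner(obj: Any) -> Any:
        if isinstance(obj, str):
            # Case 1: name of a builtin function.
            key = obj
            LEAF_FUNCTIONS.add(key)
        elif inspect.isclass(obj):
            # Case 3: a class, plus its user-defined (non-dunder) methods.
            key = obj.__name__
            LEAF_MODULES.add(key)
            for attr, value in vars(obj).items():
                if inspect.isfunction(value) and not attr.startswith("__"):
                    LEAF_METHODS.add(f"{key}.{attr}")
        elif inspect.isfunction(obj):
            parts = obj.__qualname__.split(".")
            if len(parts) >= 2 and parts[-2] != "<locals>":
                # Case 4: defined inside a class body, so treat as a method.
                key = ".".join(parts[-2:])
                LEAF_METHODS.add(key)
            else:
                # Case 2: a plain function.
                key = obj.__name__
                LEAF_FUNCTIONS.add(key)
        else:
            raise TypeError(f"cannot wrap object of type {type(obj)!r}")
        if skip_compile:
            SKIP_COMPILE.add(key)
        return obj  # the wrapped object itself is returned unchanged

    return wrap_inner


toy_wrap()("len")  # called on a string


@toy_wrap()
def my_op(x):
    return x + 1


@toy_wrap(skip_compile=True)
class Block:
    def helper(self, x):
        return x


class Other:
    @toy_wrap()
    def post(self, x):
        return x
```

Returning the object unchanged keeps the decorated function, class, or method usable exactly as before; only the registries record that it should be treated as a leaf.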