7.5.5. 算子融合

训练工具支持的算子融合可分为两大类:1. 吸收 BN;2. 融合 Add、ReLU(6)

7.5.5.1. 吸收 BN

吸收 BN 的目的是为了减少模型的计算量。因为 BN 是线性变换过程,因此,当 BNConv 一起出现的时候,可以把 BN 的参数吸收到 Conv 的参数中,从而在部署的模型中消除 BN 的计算。

吸收的计算过程如下:

通过吸收 BN ,可以把 Conv2d + BN2d 简化为 Conv2d

../../../_images/absorb_bn.svg

7.5.5.2. 融合 Add、ReLU(6)

和 CUDA Kernel Fusion 中将 CUDA Kernel 融合以提高计算速度不同,训练工具支持的融合更加偏重量化层面

BPU 硬件针对常见的模型基本结构做了优化,在计算 Conv -> Add -> ReLU 这种算子组合时,可使算子间的数据传递保留高精度的状态,提高模型整体的数值精度。因此在对模型进行量化时,我们可以将 Conv -> Add -> ReLU 视为一个整体

由于训练工具对模型进行量化改造时以 torch.nn.Module 为单位,为了在量化时将 Conv -> Add -> ReLU 视为一个整体,需要将它们合并为一个 Module

算子融合除了可以使中间结果保留高精度状态之外,也可以省去将中间结果转化为低精度表示的过程,因此执行速度和不融合相比也会更快

(由于算子融合既可以提高模型精度,又可以提高模型速度,一般应该对所有可融合的部分进行融合)

7.5.5.3. 实现原理

得益于 FX 可以获取计算图的优势,训练工具可以自动化地对模型的计算图进行分析,根据预定义的 fusion pattern 对可融合部分进行匹配,并通过 submodule 替换实现融合的操作。下面举例进行说明

(吸收 BN 和融合 Add、ReLU(6) 可以通过相同的机制完成,因此在融合时不需要进行区分)

import torch
from torch import nn
from torch.quantization import DeQuantStub
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization import fuse_fx


class ModelForFusion(torch.nn.Module):
    def __init__(
        self,
    ):
        super(ModelForFusion, self).__init__()
        self.quantx = QuantStub()
        self.quanty = QuantStub()
        self.conv = nn.Conv2d(3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, x, y):
        x = self.quantx(x)
        y = self.quanty(y)
        x = self.conv(x)
        x = self.bn(x)
        x = x + y
        x = self.relu(x)
        x = self.dequant(x)

        return x


float_model = ModelForFusion()
fused_model = fuse_fx(float_model)

print(fused_model)
"""
ModelForFusion(
  (quantx): QuantStub()
  (quanty): QuantStub()
  (conv): Identity()
  (bn): Identity()
  (relu): Identity()
  (dequant): DeQuantStub()
  (_generated_add_0): ConvAddReLU2d(
    (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU()
  )
)



def forward(self, x, y):
    quantx = self.quantx(x);  x = None
    quanty = self.quanty(y);  y = None
    _generated_add_0 = self._generated_add_0
    add_1 = self._generated_add_0(quantx, quanty);  quantx = quanty = None
    dequant = self.dequant(add_1);  add_1 = None
    return dequant
"""

可以看到,对模型执行算子融合操作后,BN 被吸收进 Conv 中,且 Conv、Add、ReLU 被融合进一个 Module 中(_generated_add_0)。原本的 submodule 被替换为 Identity,且不在 forward 代码中调用

(FX 自动地将模型中 x = x + y 的加号替换为了名为 _generated_add_0Module 形式,以支持算子融合和量化的相关操作)

7.5.5.4. 可以融合的算子

目前支持的可融合的算子组合见以下函数定义

import operator
import torch
from torch import nn
from horizon_plugin_pytorch import nn as horizon_nn


def register_fusion_patterns():
    convs = (
        nn.Conv2d,
        nn.ConvTranspose2d,
        nn.Conv3d,
        nn.Linear,
    )
    bns = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    adds = (
        nn.quantized.FloatFunctional.add,
        horizon_nn.quantized.FloatFunctional.add,
        torch.add,
        operator.add,  # 即代码中使用的加号
    )
    relus = (nn.ReLU, nn.ReLU6, nn.functional.relu, nn.functional.relu6)

    for conv in convs:
        for bn in bns:
            for add in adds:
                for relu in relus:
                    # conv bn
                    register_fusion_pattern((bn, conv))(ConvBNAddReLUFusion)

                    # conv relu
                    register_fusion_pattern((relu, conv))(ConvBNAddReLUFusion)

                    # conv add
                    register_fusion_pattern((add, conv, MatchAllNode))(
                        ConvBNAddReLUFusion
                    )  # conv 的输出作为 add 的第一个输入
                    register_fusion_pattern((add, MatchAllNode, conv))(
                        ConvBNAddedReLUFusion
                    )  # conv 的输出作为 add 的第二个输入

                    # conv bn relu
                    register_fusion_pattern((relu, (bn, conv)))(
                        ConvBNAddReLUFusion
                    )

                    # conv bn add
                    register_fusion_pattern((add, (bn, conv), MatchAllNode))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((add, MatchAllNode, (bn, conv)))(
                        ConvBNAddedReLUFusion
                    )

                    # conv add relu
                    register_fusion_pattern((relu, (add, conv, MatchAllNode)))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((relu, (add, MatchAllNode, conv)))(
                        ConvBNAddedReLUFusion
                    )

                    # conv bn add relu
                    register_fusion_pattern(
                        (relu, (add, (bn, conv), MatchAllNode))
                    )(ConvBNAddReLUFusion)
                    register_fusion_pattern(
                        (relu, (add, MatchAllNode, (bn, conv)))
                    )(ConvBNAddedReLUFusion)