7.5.2. FX Quantization 原理介绍

阅读此文档前,建议先阅读 torch.fx — PyTorch documentation,以对 torch 的 FX 机制有初步的了解。

FX 采用符号执行的方式,可以在 nn.Module 或 function 的层面对模型建图,从而实现自动化的 fuse 以及其他基于图的优化。

7.5.2.1. 量化流程

7.5.2.1.1. Fuse(可选)

FX 可以感知计算图,所以可以实现自动化的算子融合,用户不再需要手动指定需要融合的算子,直接调用接口即可。

fused_model = horizon.quantization.fuse_fx(model)
  • 注意 fuse_fx 没有 inplace 参数,因为内部需要对模型做 symbolic trace 生成一个 GraphModule,所以无法做到 inplace 的修改

  • fused_modelmodel 会共享几乎所有属性(包括子模块、算子等),因此在 fuse 之后请不要对 model 做任何修改,否则可能影响到 fused_model

  • 用户不必显式调用 fuse_fx 接口,因为后续的 prepare_qat_fx 接口内部集成了 fuse 的过程

7.5.2.1.2. Prepare

用户在调用 prepare_qat_fx 接口之前必须根据目标硬件平台设置全局的 march。接口内部会先执行 fuse 过程(即使模型已经 fuse 过了),再将模型中符合条件的算子替换为 horizon.nn.qat 中的实现。

  • 用户可以根据需要选择合适的 qconfig(Calibtaion 或 QAT,注意两种 qconfig 不能混用)

  • fuse_fx 类似,此接口不支持 inplace 参数,且在 prepare_qat_fx 之后请不要对输入的模型做任何修改

horizon.march.set_march(March.XXX)
qat_model = horizon.quantization.prepare_qat_fx(
    model,
    {
        "": horizon.qconfig.default_calib_8bit_fake_quant_qconfig,
        "module_name": {
            "<module_name>": custom_qconfig,
        },
    },)

7.5.2.1.3. Convert

  • fuse_fx 类似,此接口不支持 inplace 参数,且在 convert_fx 之后请不要对输入的模型做任何修改

quantized_model = horizon.quantization.convert_fx(qat_model)

7.5.2.1.4. Eager Mode 兼容性

大部分情况下,FX 量化的接口可以直接替换 eager mode 量化的接口(prepare_qat -> prepare_qat_fx, convert -> convert_fx),但是不能和 eager mode 的接口混用。部分模型在以下情况下需要对代码结构做一定的修改。

  • FX 不支持的操作:torch 的 symbolic trace 支持的操作是有限的,例如不支持将非静态变量作为判断条件、默认不支持 torch 以外的 pkg(如 numpy)等,且未执行到的条件分支将被丢弃

  • 不想被 FX 处理的操作:如果模型的前后处理中使用了 torch 的 op,FX 在 trace 时会将他们视为模型的一部分,产生不符合预期的行为(例如将 torch 的某些 function 调用替换为 FloatFunctional)。

以上两种情况,都可以采用 wrap 的方法来避免,下面以 RetinaNet 为例进行说明。

from horizon_plugin_pytorch.fx.fx_helper import wrap as fx_wrap

class RetinaNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        neck: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
        anchors: Optional[nn.Module] = None,
        targets: Optional[nn.Module] = None,
        post_process: Optional[nn.Module] = None,
        loss_cls: Optional[nn.Module] = None,
        loss_reg: Optional[nn.Module] = None,
    ):
        super(RetinaNet, self).__init__()

        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.anchors = anchors
        self.targets = targets
        self.post_process = post_process
        self.loss_cls = loss_cls
        self.loss_reg = loss_reg

    def rearrange_head_out(self, inputs: List[torch.Tensor], num: int):
        outputs = []
        for t in inputs:
            outputs.append(t.permute(0, 2, 3, 1).reshape(t.shape[0], -1, num))
        return torch.cat(outputs, dim=1)

    def forward(self, data: Dict):
        feat = self.backbone(data["img"])
        feat = self.neck(feat) if self.neck else feat
        cls_scores, bbox_preds = self.head(feat)

        if self.post_process is None:
            return cls_scores, bbox_preds

        # 将不需要建图的操作封装为一个 method 即可,FX 将不再关注 method 内部的逻辑,
        # 仅将它原样保留(method 中调用的 module 仍可被设置 qconfig,被
        # prepare_qat_fx 和 convert_fx 替换)
        return self._post_process( data, feat, cls_scores, bbox_preds)

    @fx_wrap()  # fx_wrap 支持直接装饰 class method
    def _post_process(self, data, feat, cls_scores, bbox_preds)
        anchors = self.anchors(feat)

        # 对 self.training 的判断必须封装起来,否则在 symbolic trace 之后,此判断
        # 逻辑会被丢掉
        if self.training:
            cls_scores = self.rearrange_head_out(
                cls_scores, self.head.num_classes
            )
            bbox_preds = self.rearrange_head_out(bbox_preds, 4)
            gt_labels = [
                torch.cat(
                    [data["gt_bboxes"][i], data["gt_classes"][i][:, None] + 1],
                    dim=-1,
                )
                for i in range(len(data["gt_classes"]))
            ]
            gt_labels = [gt_label.float() for gt_label in gt_labels]
            _, labels = self.targets(anchors, gt_labels)
            avg_factor = labels["reg_label_mask"].sum()
            if avg_factor == 0:
                avg_factor += 1
            cls_loss = self.loss_cls(
                pred=cls_scores.sigmoid(),
                target=labels["cls_label"],
                weight=labels["cls_label_mask"],
                avg_factor=avg_factor,
            )
            reg_loss = self.loss_reg(
                pred=bbox_preds,
                target=labels["reg_label"],
                weight=labels["reg_label_mask"],
                avg_factor=avg_factor,
            )
            return {
                "cls_loss": cls_loss,
                "reg_loss": reg_loss,
            }
        else:
            preds = self.post_process(
                anchors,
                cls_scores,
                bbox_preds,
                [torch.tensor(shape) for shape in data["resized_shape"]],
            )
            assert (
                "pred_bboxes" not in data.keys()
            ), "pred_bboxes has been in data.keys()"
            data["pred_bboxes"] = preds
            return data