7.5.3. RGB888 数据部署

7.5.3.1. 场景

BPU 中图像金字塔的输出图像是 centered YUV444 的格式,其数据范围是 [-128, 127],但在训练阶段中,您的训练数据集有可能是 RGB 格式的,因此您需要对训练集的图片格式进行处理,避免出现训练的模型只能接受 RGB 的数据输入而无法正常上板推理的情况。通常,我们推荐您在训练时,在图像预处理阶段将 RGB 格式的图片转为 YUV 格式,与推理时 BPU 的数据流对齐。

由于编译器目前不支持颜色空间转换,用户可以手动插入颜色空间转换节点,从而绕过编译器的限制。

7.5.3.2. YUV 格式简介

YUV 一般用来描述模拟电视系统的颜色空间,在 BT.601 中 YUV 主要有两种制式:YUV studio swing(Y:16~235,UV:16~240)和 YUV full swing(YUV:0~255)。

BPU 支持的 YUV 格式是 full swing,因此在调用我们的工具中 YUV 的相关函数时,应确保指定了 full 作为 swing 格式。

7.5.3.3. 在训练时对 RGB 输入进行预处理

在训练时,您可以使用 horizon.functional.rgb2centered_yuvhorizon.functional.bgr2centered_yuv 将 RGB 图像转换为 BPU 所支持的 YUV 格式。以 rgb2centered_yuv 为例,该函数的定义如下:

def rgb2centered_yuv(input: Tensor, swing: str = "studio") -> Tensor:
    """Convert color space.

    Convert images from RGB format to centered YUV444 BT.601

    Args:
        input: input image in RGB format, ranging 0~255
        swing: "studio" for YUV studio swing (Y: -112~107,
                U, V: -112~112)
                "full" for YUV full swing (Y, U, V: -128~127).
                default is "studio"

    Returns:
        output: centered YUV image
    """

函数输入为 RGB 图像,输出为 centered YUV 图像。其中,centered YUV 是指减去了 128 的偏置的 YUV 图像,这是 BPU 图像金字塔输出的标准图像格式。对于 full swing 而言,其范围应为 -128~127。您可以通过 swing 参数控制 full 和 studio 的取向。为了和 BPU 数据流格式对齐,请您将 swing 设为 “full”

7.5.3.4. 在推理时对 YUV 输入进行实时转换

在任何情况下,我们都推荐您使用上述介绍的方案,即在训练时就将 RGB 图像转成 YUV 格式,这样可以避免在推理时引入额外的性能开销和精度损失。但如果您已经使用了 RGB 图像训练了模型,我们也提供了补救措施,通过在推理的时候在模型输入处插入颜色空间转换算子,将输入的 YUV 图像实时转换为 RGB 格式,从而支持 RGB 模型的上板部署,避免您重新训练模型给您带来时间成本和资源上的损失。由于该算子随模型运行在 BPU 上,底层采用定点运算实现,因而不可避免地会引入一定的精度损失,因此仅作为补救方案,请您尽可能按照我们所推荐的方式对数据进行处理。

7.5.3.4.1. 算子定义

您可以在推理模型的开头(QuantStub 的后面)插入 horizon.functional.centered_yuv2rgbhorizon.functional.centered_yuv2bgr 算子实现该功能。以 centered_yuv2rgb 为例,其定义为:

def centered_yuv2rgb(
    input: QTensor,
    swing: str = "studio",
    mean: Union[List[float], Tensor] = (128.0,),
    std: Union[List[float], Tensor] = (128.0,),
    q_scale: Union[float, Tensor] = 1.0 / 128.0,
) -> QTensor:

swing 为 YUV 的格式,可选项为 “full” 和 “studio”。为了和 BPU 的 YUV 数据格式对齐,请您将 swing 设为 “full”mean, std 均为您在训练时 RGB 图像所使用的归一化均值、标准差,支持 list 和 torch.Tensor 两种输入类型,支持单通道或三通道的归一化参数。如您的归一化均值为 [128, 0, -128] 时,您可以传入一个 [128., 0., -128.] 的 list 或 torch.tensor([128., 0., -128.])。 q_scale 为您在量化感知训练阶段所用的 QuantStub 的 scale 数值。支持 float 和 torch.Tensor 两种数据类型。

该算子完成了以下操作:

  1. 根据给定的 swing 所对应的转换公式将输入图像转换成 RGB 格式

  2. 使用给定的 meanstd 对 RGB 图像进行归一化

  3. 使用给定的 q_scale 对 RGB 图像进行量化

由于该算子已经包括了对 RGB 图像的量化操作,因此在插入这个算子后您需要手动地将模型 QuantStub 的 scale 参数更改为 1。

插入该算子后的部署模型如下图所示:

../../../_images/yuv1.svg

注意

该算子为部署专用算子,请勿在训练阶段使用该算子。

7.5.3.4.2. 使用方法

在您使用 RGB 图像完成量化感知训练后,您需要:

  1. 获取量化感知训练时模型 QuantStub 所使用的 scale 值,以及 RGB 图像所使用的归一化参数;

  2. 调用 convert_fx 接口将 qat 模型转换为 quantized 模型;

  3. 在模型的 QuantStub 后面插入 centered_yuv2rgb 算子,算子需要传入步骤 1 中所获取的参数;

  4. 将 QuantStub 的 scale 参数修改成 1。

示例:

import torch
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    prepare_qat_fx,
    convert_fx,
)
from horizon_plugin_pytorch.functional import centered_yuv2rgb
from horizon_plugin_pytorch.quantization.qconfig import (
    default_qat_8bit_fake_quant_qconfig,
)
from horizon_plugin_pytorch import March, set_march

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()

    def forward(self, input):
        x = self.quant(input)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def set_qconfig(self):
        self.qconfig = default_qat_8bit_fake_quant_qconfig


data = torch.rand(1, 3, 28, 28)
net = Net()

set_march(March.XXX)

net.set_qconfig()
qat_net = prepare_qat_fx(net)
qat_net(data)
quantized_net = convert_fx(qat_net)
traced = quantized_net
print("Before centered_yuv2rgb")
traced.graph.print_tabular()

# Replace QuantStub nodes with centered_yuv2rgb
patterns = ["quant"]
for n in traced.graph.nodes:
    if any(n.target == pattern for pattern in patterns):
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(centered_yuv2rgb, (n,), {"swing": "full"})
            n.replace_all_uses_with(new_node)
            new_node.args = (n,)

traced.quant.scale.fill_(1.0)
traced.recompile()
print("\nAfter centered_yuv2rgb")
traced.graph.print_tabular()

对比前后 Graph 可以看到修改后的图中插入了颜色空间转换节点:

Before centered_yuv2rgb
opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  input_1  input     ()          {}
call_module  quant    quant     (input_1,)  {}
call_module  conv     conv      (quant,)    {}
output       output   output    (conv,)     {}

After centered_yuv2rgb
opcode         name              target                                         args                 kwargs
-------------  ----------------  ---------------------------------------------  -------------------  -----------------
placeholder    input_1           input                                          ()                   {}
call_module    quant             quant                                          (input_1,)           {}
call_function  centered_yuv2rgb  <function centered_yuv2rgb at 0x7fa1c2b48040>  (quant,)             {'swing': 'full'}
call_module    conv              conv                                           (centered_yuv2rgb,)  {}
output         output            output                                         (conv,)              {}