7.4.6. 量化部署 PT 模型的跨设备 Inference 说明

量化部署的 pt 模型要求 trace 时使用的 device 和后续 infer 时使用的 device 一致。

若用户试图直接通过 to(device) 操作修改 pt 模型的 device,可能会出现模型 forward 报错的问题,torch 官方对此进行了解释,见 TorchScript-Frequently Asked Questions — PyTorch documentation

下面举例说明:

import torch


class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        y = torch.ones(x.shape, device=x.device)
        z = torch.zeros_like(x)

        return y + z


script_mod = torch.jit.trace(
    Net(), torch.rand(2, 3, 3, 3, device=torch.device("cpu"))
)
script_mod.to(torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu)):
#   %4 : int = prim::Constant[value=0]()
#   %5 : int = aten::size(%x, %4)
#   %6 : Long(device=cpu) = prim::NumToTensor(%5)
#   %16 : int = aten::Int(%6)
#   %7 : int = prim::Constant[value=1]()
#   %8 : int = aten::size(%x, %7)
#   %9 : Long(device=cpu) = prim::NumToTensor(%8)
#   %17 : int = aten::Int(%9)
#   %10 : int = prim::Constant[value=2]()
#   %11 : int = aten::size(%x, %10)
#   %12 : Long(device=cpu) = prim::NumToTensor(%11)
#   %18 : int = aten::Int(%12)
#   %13 : int = prim::Constant[value=3]()
#   %14 : int = aten::size(%x, %13)
#   %15 : Long(device=cpu) = prim::NumToTensor(%14)
#   %19 : int = aten::Int(%15)
#   %20 : int[] = prim::ListConstruct(%16, %17, %18, %19)
#   %21 : NoneType = prim::Constant()
#   %22 : NoneType = prim::Constant()
#   %23 : Device = prim::Constant[value="cpu"]()
#   %24 : bool = prim::Constant[value=0]()
#   %y : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::ones(%20, %21, %22, %23, %24)
#   %26 : int = prim::Constant[value=6]()
#   %27 : int = prim::Constant[value=0]()
#   %28 : Device = prim::Constant[value="cpu"]()
#   %29 : bool = prim::Constant[value=0]()
#   %30 : NoneType = prim::Constant()
#   %z : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::zeros_like(%x, %26, %27, %28, %29, %30)
#   %32 : int = prim::Constant[value=1]()
#   %33 : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::add(%y, %z, %32)
#   return (%33)

可以看到,在调用 to(torch.device("cuda")) 后,模型的 graph 中记录的 aten::onesaten::zeros_like 的 device 参数仍为 prim::Constant[value="cpu"](),因此在模型 forward 时,它们的输出仍为 cpu Tensor。这是因为 to(device) 只能移动模型中的 buffer(weight、bias 等),无法修改 ScriptModule 的 graph。

torch 官方对以上限制给出的解决方案是,在 trace 前就确定好 pt 模型将要在哪个 device 上执行,并在对应的 device 上 trace 即可。

针对以上限制,训练工具建议根据具体场景选择以下解决方案:

7.4.6.1. PT 模型执行使用的 device 和 trace 不一致

对于可以确定 pt 模型将仅在 GPU 上执行,只需要修改卡号的情况,我们首先推荐使用 cuda:0,即零号卡进行 trace。在使用模型时,用户可以通过 torch.cuda.set_device 接口,将物理上的任意卡映射为逻辑上的“零卡”,此时使用 cuda:0 trace 出的模型实际将在指定的物理卡上运行。

若 trace 时使用的 device 和执行时使用的 device 存在 CPU、GPU 的不一致,用户可以使用 horizon_plugin_pytorch.jit.to_device 接口实现 pt 模型的 device 迁移。此接口会寻找模型 graph 中的 device 参数,并将它们替换为需要的值。效果如下:

from horizon_plugin_pytorch.jit import to_device

script_mod = to_device(script_mod, torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x.1 : Tensor):
#   %38 : bool = prim::Constant[value=0]()
#   %60 : Device = prim::Constant[value="cuda"]()
#   %34 : NoneType = prim::Constant()
#   %3 : int = prim::Constant[value=0]()
#   %10 : int = prim::Constant[value=1]()
#   %17 : int = prim::Constant[value=2]()
#   %24 : int = prim::Constant[value=3]()
#   %41 : int = prim::Constant[value=6]()
#   %4 : int = aten::size(%x.1, %3)
#   %5 : Tensor = prim::NumToTensor(%4)
#   %8 : int = aten::Int(%5)
#   %11 : int = aten::size(%x.1, %10)
#   %12 : Tensor = prim::NumToTensor(%11)
#   %15 : int = aten::Int(%12)
#   %18 : int = aten::size(%x.1, %17)
#   %19 : Tensor = prim::NumToTensor(%18)
#   %22 : int = aten::Int(%19)
#   %25 : int = aten::size(%x.1, %24)
#   %26 : Tensor = prim::NumToTensor(%25)
#   %32 : int = aten::Int(%26)
#   %33 : int[] = prim::ListConstruct(%8, %15, %22, %32)
#   %y.1 : Tensor = aten::ones(%33, %34, %34, %60, %38)
#   %z.1 : Tensor = aten::zeros_like(%x.1, %41, %3, %60, %38, %34)
#   %50 : Tensor = aten::add(%y.1, %z.1, %10)
#   return (%50)

7.4.6.2. 多卡并行推理

在此场景下,用户需要通过 trace 或 to_device 的方式取得 cuda:0 上的 pt 模型,并且为每块卡单独开启一个进程,通过 torch.cuda.set_device 的方式为每个进程设置不同的默认卡。一个简单的示例如下:

import os
import torch
import signal
import torch.distributed as dist
import torch.multiprocessing as mp
from horizon_plugin_pytorch.jit import to_device

model_path = "path_to_pt_model_file"

def main_func(rank, world_size, device_ids):
    torch.cuda.set_device(device_ids[rank])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = to_device(torch.jit.load(model_path), torch.device("cuda"))

    # 数据加载,模型 forward,精度计算等内容此处省略

def launch(device_ids):
    try:
        world_size = len(device_ids)
        mp.spawn(
            main_func,
            args=(world_size, device_ids),
            nprocs=world_size,
            join=True,
        )
    # 当按下 Ctrl+c 时,关闭所有子进程
    except KeyboardInterrupt:
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

launch([0, 1, 2, 3])

上述操作对 pt 模型的处理和 torch.nn.parallel.DistributedDataParallel 的做法一致,数据加载和模型精度计算相关内容请参考 Getting Started with Distributed Data Parallel — PyTorch Tutorials