10.1.3.3. 如何开启 AMP

AMP全称为Automatic Mixed Precision,即自动混合精度。AMP开启后,pytorch可以自动地在模型执行时将一些算子(如卷积和全连接)使用 float16 进行计算,以达到提升计算速度、减少显存占用的效果。详见 pytorch官方文档

HAT中已经为AMP做好相关的工作,用户只需要在定义config文件中的 batch_processor 字段时将 enable_amp 参数设置为 True 即可。

注解

在模型验证时为得到准确的指标,一般是不需开启AMP的,在定义 val_batch_processor 字段时请将 enable_amp 参数设置为 False


# configs/example.py

# 使用 BasicBatchProcessor
batch_processor = dict(
    type='BasicBatchProcessor',
    need_grad_update=...,
    batch_transforms=...,
    enable_amp=True,
)

# 使用 MultiBatchProcessor
batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=...,
    batch_transforms=...,
    loss_collector=...,
    enable_amp=True,
)