10.1.3.7. FCOS-EfficientNetB0的config构造详细说明

为了帮助您对一个完整的 config 文件有个更加清楚的了解,这篇文档以 FCOS-EfficientNetB0 模型为例,对它的每一个模块增加一个简短的注释供您参考,如下所示:


VERSION = ConfigVersion.v2       # config文件的版本号,默认是v2即可
training_step = os.environ.get("HAT_TRAINING_STEP", "float")    # 用户想训练的阶段,通常通过命令行的--stage设置

task_name = "fcos_efficientnetb0_mscoco"   # 给当前的训练任务设置一个name
num_classes = 80                           # 参与训练的数据集的类别
batch_size_per_gpu = 24                    # 每张卡的batch_size大小
device_ids = [0, 1, 2, 3]                  # 参与训练使用的gpu卡号
ckpt_dir = "./tmp_models/%s" % task_name   # 模型存储的路径
cudnn_benchmark = True                     # 是否设置torch.backends.cudnn.benchmark=True
seed = None                                # 是否设置随机种子
log_rank_zero_only = True                  # 是否只在第0卡上打印log信息
bn_kwargs = {}                             # bn的参数,{}表示使用torch的默认参数
march = March.BAYES                        # 模型最终部署到什么架构的计算平台上,默认为 March.BAYES

# 参与训练的模型的相关配置
model = dict(                       
    type="FCOS",                    # 检测模型的类型,这里使用FCOS检测模型
    backbone=dict(                  # 检测模型的backbone相关配置
        type="efficientnet",        # backbone使用的模型,这里使用efficientnet模型
        bn_kwargs=bn_kwargs,        # backbone的bn配置
        model_type="b0",            # 使用efficientnet模型系列的b0结构
        num_classes=1000,           # efficientnet作为分类模型时分类的类别,这里作为检测模型的backbone时,num_classes其实没有起到作用
        include_top=False,          # 是否包括efficientnet的分类层,因为efficientnet是作为backbone提取特征的,所以不需要分类层
        activation="relu",          # backbone的激活层,这里使用的是relu
        use_se_block=False,         # backbone是否使用se_block模块,默认不使用
    ),                              # PS: 关于backbone每个参数配置的详细解释,可以参考efficientnet的API文档
    neck=dict(                          # 检测模型的neck相关配置
        type="BiFPN",                   # neck使用BiFPN
        in_strides=[2, 4, 8, 16, 32],           # 输入feature对应的stride
        out_strides=[8, 16, 32, 64, 128],       # 输出feature对应的stride
        stride2channels=dict({2: 16, 4: 24, 8: 40, 16: 112, 32: 320}),     # 输入feature的stride和channel的对应关系
        out_channels=64,                        # 输出feature的channel
        num_outs=5,                             # 输出feature的数量
        stack=3,                                # BifpnLayer层的数量
        start_level=2,                          # backbone的第一个输出feature的索引
        end_level=-1,                           # backbone的最后一个输出feature的索引
        fpn_name="bifpn_sum",                   # fpn_name,和权重初始化的方式相关
    ),                                  # PS: 关于neck每个参数配置的详细解释,可以参考BiFPN的API文档
    head=dict(                      # 检测模型的head相关配置
        type="FCOSHead",            # head使用FCOSHead
        num_classes=num_classes,    # 检测数据集的类别
        in_strides=[8, 16, 32, 64, 128],        # 输入feature对应的stride
        out_strides=[8, 16, 32, 64, 128],       # 输出feature对应的stride
        stride2channels=dict({8: 64, 16: 64, 32: 64, 64: 64, 128: 64}),   # 输入feature的stride和channel的对应关系
        upscale_bbox_pred=False,                # 是否需要上采样bbox_pred
        feat_channels=64,           # 输入feature的channel
        stacked_convs=4,            # 连续conv的次数
        int8_output=False,          # 输出是否设置为int8
        int16_output=True,          # 输出是否设置为int16
        dequant_output=True,        # 输出是否需要反量化
    ),                              # PS: 关于head每个参数配置的详细解释,可以参考FCOSHead的API文档
    targets=dict(                           # 检测模型的target相关配置
        type="DynamicFcosTarget",           # target使用DynamicFcosTarget
        strides=[8, 16, 32, 64, 128],       # 输入feature对应的stride
        cls_out_channels=80,                # 分类的类别
        background_label=80,                # 背景的类别标签
        topK=10,                            # 对于每个ground truth最多保留的正样本数              
        loss_cls=dict(                      # 分类的损失函数设置,用于动态生成target
            type="FocalLoss",               # 这里使用FocalLoss损失函数
            loss_name="cls",                
            num_classes=80 + 1,             
            alpha=0.25,                     
            gamma=2.0,
            loss_weight=1.0,
            reduction="none",
        ),                                  # PS: 关于参数配置的详细解释,可以参考FocalLoss的API文档
        loss_reg=dict(                      # 回归的损失函数设置,使用GIoULoss损失函数
            type="GIoULoss", loss_name="reg", loss_weight=2.0, reduction="none"
        ),                                  # PS: 关于参数配置的详细解释,可以参考GIoULoss的API文档
    ),                                      # PS: 关于targets每个参数配置的详细解释,可以参考DynamicFcosTarget的API文档
    post_process=dict(                          # 检测模型的后处理相关配置
        type="FCOSMultiStrideFilter",           # 后处理使用FCOSMultiStrideFilter
        strides=[8, 16, 32, 64, 128],           # 输出feature对应的stride
        threshold=-2.944,                       # FilterModule OP中使用的阈值
        for_compile=False,                      # 模型是否支持compile
        score_threshold=0.05,                   # 过滤框是使用的score阈值
        iou_threshold=0.6,                      # nms时候的iou阈值
        max_shape=(512, 512),                   # 根据max_shape对检测框做clamp            
    ),                                          # PS: 关于后处理每个参数配置的详细解释,可以参考FCOSMultiStrideFilter的API文档
    loss_cls=dict(                      # cls分支使用的损失函数
        type="FocalLoss",
        loss_name="cls",
        num_classes=80 + 1,
        alpha=0.25,
        gamma=2.0,
        loss_weight=1.0,
    ),
    loss_centerness=dict(               # centerness分支使用的损失函数,具体参数可以参考CrossEntropyLoss的API文档
        type="CrossEntropyLoss", loss_name="centerness", use_sigmoid=True
    ),
    loss_reg=dict(                      # reg分支使用的损失函数,具体参数可以参考GIoULoss的API文档
        type="GIoULoss",
        loss_name="reg",
        loss_weight=1.0,
    ),
)

# 和model的定义基本相同,deploy_model主要用于模型编译和上板,因此没有loss部分。deploy_model通常用在int_infer阶段
deploy_model = dict(           
    type="FCOS",
    backbone=dict(
        type="efficientnet",
        bn_kwargs=bn_kwargs,
        model_type="b0",
        num_classes=1000,
        include_top=False,
        activation="relu",
        use_se_block=False,
    ),
    neck=dict(
        type="BiFPN",
        in_strides=[2, 4, 8, 16, 32],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels=dict({2: 16, 4: 24, 8: 40, 16: 112, 32: 320}),
        out_channels=64,
        num_outs=5,
        stack=3,
        start_level=2,
        end_level=-1,
        fpn_name="bifpn_sum",
    ),
    head=dict(
        type="FCOSHead",
        num_classes=num_classes,
        in_strides=[8, 16, 32, 64, 128],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels=dict({8: 64, 16: 64, 32: 64, 64: 64, 128: 64}),
        upscale_bbox_pred=False,
        feat_channels=64,
        stacked_convs=4,
        int8_output=False,
        int16_output=True,
        dequant_output=False,
    ),
    post_process=dict(
        type="FCOSMultiStrideFilter",
        strides=[8, 16, 32, 64, 128],
        threshold=-2.944,
        for_compile=True,
        max_shape=(512, 512),
    ),
)
# 编译deploy_model时使用的输入
deploy_inputs = dict(img=torch.randn((1, 3, 512, 512))) 

# 使deploy_model从float转成quantize的过程,用于验证模型是否可以编译
deploy_model_convert_pipeline = dict(     
    type="ModelConvertPipeline",           
    qat_mode="fuse_bn",                     # qat模型,可以选择fuse_bn, with_bn 和with_bn_reverse_fold
    converters=[
        dict(type="Float2QAT"),             # 模型由float变成qat
        dict(type="QAT2Quantize"),          # 模型由qat变成quantize
    ],
)

# 训练数据集的加载过程
data_loader = dict(                               
    type=torch.utils.data.DataLoader,             # 使用torch原生的DataLoader接口
    dataset=dict(                                 # 获取dataset的过程
        type="Coco",                              # 对应coco的dataset读取接口
        data_path="./tmp_data/mscoco/train_lmdb/",          # 数据集的路径
        transforms=[                                        # 数据的transforms过程
            dict(                                           
                type="Resize",                              # Resize操作
                img_scale=(512, 512),                       # resize之后的图像大小
                ratio_range=(0.5, 2.0),                     # 图像缩放的范围
                keep_ratio=True,                            # 缩放的过程是否保持长宽比
            ),
            dict(type="RandomCrop", size=(512, 512)),       # RandomCrop的操作
            dict(                                           # Pad的操作
                type="Pad",
                divisor=512,                                # Pad之后的图像长宽是512的倍数
            ),
            dict(                                           # RandomFlip的操作
                type="RandomFlip",
                px=0.5,                                     # x方向翻转的概率
                py=0,                                       # y方向翻转的概率
            ),
            dict(type="AugmentHSV", hgain=0.015, sgain=0.7, vgain=0.4),   # AugmentHSV的操作
            dict(
                type="ToTensor",                            # 数据由numpy转成Tensor
                to_yuv=True,                                # 图像是否转成yuv格式
            ),
            dict(                                           # 数据归一化操作,由[0,255]转成[-1,1]
                type="Normalize",
                mean=128.0,
                std=128.0,
            ),
        ],
    ),
    sampler=dict(type=torch.utils.data.DistributedSampler),    # DDP训练模式下数据集的采样方式
    batch_size=batch_size_per_gpu,                             # 单卡的batch_size大小
    shuffle=True,                                              # 是否打乱数据
    num_workers=8,                                             # 数据读取的进程数
    pin_memory=True,                                           # 是否使用pin_memory
    collate_fn=hat.data.collates.collate_2d,                   # 把多张图片数据打包成一个batch数据的方式
)                                                 # PS:关于DataLoader每个参数的详细解释,可以参考torch的官方文档

# 验证数据集的加载流程,和训练数据集同理      
val_data_loader = dict(                                   
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="Coco",
        data_path="./tmp_data/mscoco/val_lmdb/",
        transforms=[
            dict(
                type="Resize",
                img_scale=(512, 512),
                keep_ratio=True,
            ),
            dict(
                type="Pad",
                size=(512, 512),
            ),
            dict(
                type="ToTensor",
                to_yuv=True,
            ),
            dict(
                type="Normalize",
                mean=128.0,
                std=128.0,
            ),
        ],
    ),
    batch_size=batch_size_per_gpu,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=hat.data.collates.collate_2d,
)

# outputs通常为model的输出,此函数用于获取model输出中的loss部分,从而进行后面的梯度更新
def loss_collector(outputs: dict):
    losses = []
    for _, loss in outputs.items():
        losses.append(loss)
    return losses

# 更新loss,通常用于在模型训练的过程中打印loss,可以和下面被调用的地方一起理解
def update_loss(metrics, batch, model_outs):
    for metric in metrics:
        metric.update(model_outs)

# 定义在训练过程中打印loss的函数
loss_show_update = dict(
    type="MetricUpdater",
    metric_update_func=update_loss,
    step_log_freq=1,
    epoch_log_freq=1,
    log_prefix="loss_ " + task_name,
)

# 训练数据集每次迭代的处理方式
batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=True,          # 是否进行梯度更新
    loss_collector=loss_collector,  # 获取loss的方式
)
# 验证数据集每次迭代的处理方式
val_batch_processor = dict(
    type="MultiBatchProcessor", 
    need_grad_update=False,         # 是否进行梯度更新
)

# 模型指标的更新方式,这里的指标是mAP
def update_metric(metrics, batch, model_outs):
    for metric in metrics:
        metric.update(model_outs)

# 模型验证过程中,验证指标的更新方式
val_metric_updater = dict(
    type="MetricUpdater",
    metric_update_func=update_metric,
    step_log_freq=500,
    epoch_log_freq=1,
    log_prefix="Validation " + task_name,
)

# 设置根据某个频率打印训练的log信息,比如loss值
stat_callback = dict(
    type="StatsMonitor",
    log_freq=1,
)

# 对模型进行trace,保存相应的pt文件
trace_callback = dict(
    type="SaveTraced",
    save_dir=ckpt_dir,
    trace_inputs=deploy_inputs,
)

# 保存模型权重
ckpt_callback = dict(
    type="Checkpoint",
    save_dir=ckpt_dir,
    name_prefix=training_step + "-",
    save_interval=1,
    strict_match=True,
    mode="max",
    monitor_metric_key="mAP",
)

# 训练结束对模型进行验证
val_callback = dict(
    type="Validation",
    data_loader=val_data_loader,
    batch_processor=val_batch_processor,
    callbacks=[val_metric_updater],
    val_model=None,
    init_with_train_model=False,
    val_interval=1,
    val_on_train_end=True,
)

# 浮点模型训练的相关设置
float_trainer = dict(  
    type="distributed_data_parallel_trainer",      # 设置DDP训练
    model=model,                                   # 参与训练的模型
    data_loader=data_loader,                       # 参与训练的数据集 
    optimizer=dict(                                # 优化器的设置
        type=torch.optim.SGD,
        params={"weight": dict(weight_decay=4e-5)},
        lr=0.14,
        momentum=0.937,
        nesterov=True,
    ),
    batch_processor=batch_processor,               # 每次迭代的处理方式
    num_epochs=300,                                # 模型训练的epoch次数
    device=None,                                   # 模型训练的device
    callbacks=[                                    # 模型训练过程中会调用的callbacks
        stat_callback,
        loss_show_update,
        dict(type="ExponentialMovingAverage"),
        dict(
            type="CosLrUpdater",
            warmup_len=2,
            warmup_by="epoch",
            step_log_interval=1,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(                            # 训练过程中的metric,主要用于打印loss
        type="LossShow",
    ),
    sync_bn=True,                                  # 是否同步BN
    val_metrics=dict(                              # 验证过程中的metric,主要用于打印指标
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
)


calibration_data_loader = copy.deepcopy(data_loader)    # 参与calibration的数据集
calibration_data_loader.pop("sampler")                  # calibration只能单卡跑,因此不需要sample
calibration_batch_processor = copy.deepcopy(val_batch_processor)    # calibration过程中每个迭代数据的处理方式

# calibration过程的相关设置
calibration_trainer = dict(
    type="Calibrator",
    model=model,
    model_convert_pipeline=dict(                # 用于把模型从浮点模型转成calibration的模型
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="LoadCheckpoint",              # calibration的开始之前需要load浮点的模型
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
            ),
            dict(type="Float2Calibration"),       # 把浮点模型转成calibration的模型
        ],
    ),
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    num_steps=10,             # calibration的迭代次数
    device=None,
    callbacks=[
        stat_callback,
        val_callback,
        ckpt_callback,
    ],
    val_metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    log_interval=1,
)

# qat模型的训练配置,参数含义可以参考float_trainer
qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(        # 用于把模型从浮点模型转成qat的模型
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),     # 把模型从浮点模型转成qat的模型
            dict(                       # 加载calibration之后的模型
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(            
                    ckpt_dir, "calibration-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.SGD,
        params={"weight": dict(weight_decay=4e-5)},
        lr=0.001,        # 学习率通常设置为float训练的十分之一
        momentum=0.9,
    ),
    batch_processor=batch_processor,
    num_epochs=10,      # qat的训练epoch次数通常远小于float的训练次数
    device=None,
    callbacks=[
        stat_callback,
        loss_show_update,
        dict(
            type="StepDecayLrUpdater",
            lr_decay_id=[4],
            step_log_interval=500,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(
        type="LossShow",
    ),
    val_metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
)


# int_infer模型的训练配置,通常这一阶段不会进行训练,只通过调用callback保存quantize的模型参数和模型的pt文件
int_infer_trainer = dict(
    type="Trainer",
    model=deploy_model,                     # 这里是deploy_model
    model_convert_pipeline=dict(            # 用于把模型从浮点模型转成quantize的模型
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),          # 把float模型转成qat模型
            dict(                            # 加载qat模型的参数
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,
            ),
            dict(type="QAT2Quantize"),       # 把qat模型转成quantize模型
        ],
    ),
    data_loader=None,
    optimizer=None,
    batch_processor=None,
    num_epochs=0,                            # epoch设置为0,跳过训练
    device=None,
    callbacks=[
        ckpt_callback,                       # 用于保存quantize模型的参数
        trace_callback,                      # 用于保存quantize模型的pt文件
    ],
)

# 模型编译相关的配置选项
compile_dir = os.path.join(ckpt_dir, "compile")
compile_cfg = dict(
    march=march,
    name="fcos_effb0_test_model",
    out_dir=compile_dir,
    hbm=os.path.join(compile_dir, "model.hbm"),
    layer_details=True,
    input_source=["pyramid"],
    opt="O3",
)

# 浮点模型预测相关的配置
float_predictor = dict(
    type="Predictor",              # 类型为Predictor
    model=model,                   # 参与Predictor的模型
    model_convert_pipeline=dict(      # 对于浮点模型而言,predict之前需要加载浮点的模型参数
        type="ModelConvertPipeline",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
    data_loader=[val_data_loader],     # 参与predict的数据集
    batch_processor=val_batch_processor,   # 数据集每次迭代的处理方式
    device=None,
    metrics=dict(                    # predict过程中和结束的时候打印的指标
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)

# qat模型预测相关的配置,参数含义和float一致
qat_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(         # 用于把模型从浮点模型转成qat的模型   
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),      # 把浮点模型转成qat模型
            dict(                        # predict之前需要加载qat的模型参数
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,
            ),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)

# quantize模型预测相关的配置,参数含义和float一致
int_infer_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(          # 用于把模型从浮点模型转成quantize模型   
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),            # 浮点模型转成qat
            dict(                              # 预测之前需要加载qat的模型参数
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,
            ),
            dict(type="QAT2Quantize"),         # qat模型转成quantize模型
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)