6.1.6.8. Locating and Handling Accuracy Problems in Reference Models

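The reference models provided in Samples are the recommended solutions for the Horizon X3, J3 and J5 computing platforms. All of the included code was trained on open-source datasets. This section describes problems you may encounter when switching to your own dataset, and how to handle them.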

6.1.6.8.1. Training Does Not Converge

This problem is usually caused by the dataset. Check it with the following steps:

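Check that the dataset was packed correctly. Read the packed data back directly and confirm, by visualizing it or comparing it numerically, that it is identical to the data before packing. For example, the packed LMDB data can be loaded with a dataset config such as: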

dataset=dict(
    type="Coco",
    data_path="./tmp_data/mscoco/train_lmdb/",
    transforms=[
        dict(
            type="Resize",
            img_scale=(512, 512),
            ratio_range=(0.5, 2.0),
            keep_ratio=True,
        ),
    ]
)
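As a minimal sketch of the numerical comparison, the hypothetical helper below checks the annotations read back from the packed dataset against the original annotations of the same image. It assumes each record is a dict with "gt_bboxes" (lists of [x1, y1, x2, y2]) and "gt_classes" (integer labels) in the same order; these key names follow the script in the next step, but the actual layout of your packed records may differ. Compare the records before geometric transforms such as Resize are applied, since those change the box coordinates.

```python
import math
from typing import Dict, List, Sequence


def compare_annotations(
    original: Dict[str, Sequence],
    packed: Dict[str, Sequence],
    atol: float = 1e-3,
) -> List[str]:
    """Return human-readable mismatches between two annotation records.

    Hypothetical helper: both records are assumed to hold "gt_bboxes" as
    [x1, y1, x2, y2] lists and "gt_classes" as integer labels, in the same order.
    """
    problems = []
    o_boxes, p_boxes = original["gt_bboxes"], packed["gt_bboxes"]
    o_cls, p_cls = list(original["gt_classes"]), list(packed["gt_classes"])
    # A different number of boxes usually means annotations were dropped or merged
    if len(o_boxes) != len(p_boxes):
        problems.append(f"box count differs: {len(o_boxes)} vs {len(p_boxes)}")
    # Class ids must survive packing unchanged (watch for off-by-one label maps)
    if o_cls != p_cls:
        problems.append(f"classes differ: {o_cls} vs {p_cls}")
    # Coordinates are compared with a small absolute tolerance
    for i, (ob, pb) in enumerate(zip(o_boxes, p_boxes)):
        if any(not math.isclose(a, b, abs_tol=atol) for a, b in zip(ob, pb)):
            problems.append(f"box {i} differs: {list(ob)} vs {list(pb)}")
    return problems
```

An empty list means the record survived packing; any entry points at the field to investigate. Running this over every image, rather than a few, tends to catch rare corruptions such as a label map that is shifted only for some classes.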

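Check that the target and decoder are correct: after the ground truth is encoded by the target and then decoded back, it should be identical to the original data. The following script performs this check for DynamicFcosTarget: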

import torch
import torchvision
from hat.models.task_modules.fcos.target import DynamicFcosTarget
from hat.models.losses import FocalLoss, GIoULoss
from hat.data.datasets.mscoco import Coco
from hat.data.transforms import Resize, ToTensor
from hat.data.collates import collate_2d

# Target generator; keep these settings identical to the config being debugged
target = DynamicFcosTarget(
            strides=[8, 16, 32, 64, 128],
            cls_out_channels=80,
            background_label=80,
            topK=10,
            loss_cls=FocalLoss(
                loss_name="cls",
                num_classes=80 + 1,
                alpha=0.25,
                gamma=2.0,
                loss_weight=1.0,
                reduction="none",
            ),
            loss_reg=GIoULoss(
                loss_name="reg",
                loss_weight=2.0,
                reduction="none"
            ),
        )

# Load one sample (DataLoader default batch_size=1) from the packed dataset
data_loader = torch.utils.data.DataLoader(
    dataset = Coco(
        data_path="./tmp_data/mscoco/val_lmdb/",
        transforms=torchvision.transforms.Compose(
            [Resize(
                img_scale=(512, 512),
                ratio_range=(0.5, 2.0),
                keep_ratio=True),
             ToTensor(
                to_yuv=False),
            ],
        )
    ),
    collate_fn=collate_2d,
)

batch_data = next(iter(data_loader))
gt_bbox = batch_data["gt_bboxes"][0]
gt_cls = batch_data["gt_classes"][0].view(-1, 1)
gts = torch.cat([gt_bbox, gt_cls], dim=1)

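# Placeholder head outputs for the five levels (strides 8, 16, 32, 64, 128):
# cls (80 channels), reg (4 channels) and centerness (1 channel).
# For a 512x512 input the feature maps are 64, 32, 16, 8 and 4 wide.
feats = [
    [
        torch.zeros(1, 80, 64, 64),
        torch.zeros(1, 80, 32, 32),
        torch.zeros(1, 80, 16, 16),
        torch.zeros(1, 80, 8, 8),
        torch.zeros(1, 80, 4, 4)
    ],
    [
        torch.ones(1, 4, 64, 64),
        torch.ones(1, 4, 32, 32),
        torch.ones(1, 4, 16, 16),
        torch.ones(1, 4, 8, 8),
        torch.ones(1, 4, 4, 4)
    ],
    [
        torch.zeros(1, 1, 64, 64),
        torch.zeros(1, 1, 32, 32),
        torch.zeros(1, 1, 16, 16),
        torch.zeros(1, 1, 8, 8),
        torch.zeros(1, 1, 4, 4)
    ],
]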

cls_target, giou_target, centerness_target = target(batch_data, feats)

cls = cls_target["target"]
reg = giou_target["target"]

image_size = (512, 512)
grid_xy_list = list()
strides = list()
for feat in feats[0]:
    # Centre of every feature-map cell, in feature-map units
    grid_y, grid_x = torch.meshgrid(torch.arange(feat.shape[2]), torch.arange(feat.shape[3]))
    grid_xy = torch.stack([grid_x, grid_y], dim=2).float().to(device=feat.device) + 0.5
    # Stride of this level, e.g. 512 / 64 = 8
    s = image_size[0] / feat.shape[2]
    grid_xy = grid_xy.view(-1, 2)
    stride = torch.ones((grid_xy.shape[0], 4)) * s
    strides.append(stride)
    grid_xy_list.append(grid_xy)
# Concatenate all levels; the level order must match the layout of the targets
grid_xy_list = torch.cat(grid_xy_list)
strides = torch.cat(strides)

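# Keep foreground points only (background_label=80)
index = cls < 80
grid_xy = grid_xy_list[index]
strides = strides[index]
cls = cls[index].view(-1, 1)
# Decode according to how the regression target encodes the bbox
reg[..., 0:2] = grid_xy + reg[..., 0:2]
reg[..., 2:4] = reg[..., 2:4] + grid_xy
# Scale from feature-map units back to input-image pixels
reg = reg * strides
# Several points can be assigned to the same box, so deduplicate
restored = torch.unique(torch.cat([reg, cls], dim=1), dim=0)
# The restored boxes and classes should match the ground truth
print(restored)
print(gts)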
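The same round trip can be shown without any framework. The sketch below assumes the standard FCOS encoding, in which each positive point (the cell centre (idx + 0.5) * stride, as in the script above) regresses its distances (l, t, r, b) to the four box edges, normalized by the stride as many FCOS implementations do. This is an illustration of the idea only; the encoding used by DynamicFcosTarget may differ, which is why the decode step in the script has to follow that target's bbox layout. All function names here are hypothetical.

```python
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in input-image pixels
LTRB = Tuple[float, float, float, float]  # edge distances in stride units


def point_center(row: int, col: int, stride: int) -> Tuple[float, float]:
    """Image-space centre (x, y) of feature-map cell (row, col)."""
    return (col + 0.5) * stride, (row + 0.5) * stride


def encode_ltrb(box: Box, cx: float, cy: float, stride: int) -> Optional[LTRB]:
    """Distances from (cx, cy) to the box edges in stride units; None if outside."""
    x1, y1, x2, y2 = box
    l, t, r, b = cx - x1, cy - y1, x2 - cx, y2 - cy
    if min(l, t, r, b) <= 0:
        return None  # point lies outside the box, so it is not a positive sample
    return l / stride, t / stride, r / stride, b / stride


def decode_ltrb(ltrb: LTRB, cx: float, cy: float, stride: int) -> Box:
    """Inverse of encode_ltrb: back to (x1, y1, x2, y2) in pixels."""
    l, t, r, b = (v * stride for v in ltrb)
    return cx - l, cy - t, cx + r, cy + b


def round_trip(box: Box, stride: int, fmap_size: int) -> List[Box]:
    """Encode the box at every positive point of one level, then decode it."""
    decoded = []
    for row in range(fmap_size):
        for col in range(fmap_size):
            cx, cy = point_center(row, col, stride)
            ltrb = encode_ltrb(box, cx, cy, stride)
            if ltrb is not None:
                decoded.append(decode_ltrb(ltrb, cx, cy, stride))
    return decoded
```

For example, `round_trip((100.0, 120.0, 260.0, 300.0), 32, 16)` visits the 16x16 map of the stride-32 level and returns one decoded copy of the box per positive point; every copy should equal the input box. If a real target fails this kind of check, the box layout, the stride scaling or the +0.5 centre offset are the usual suspects.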

6.1.6.8.2. Accuracy Does Not Reach the Target

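First check whether the target network and the reference model have computational costs of the same order of magnitude. If they differ significantly, use a larger backbone. The computational cost of a config can be measured with tools/calops.py, for example: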

python3 tools/calops.py --config configs/detection/fcos/fcos_efficientnetb0_mscoco.py --input-shape "1,3,512,512"
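To get a feel for where the cost comes from, and to compare two reported numbers, the sketch below uses the standard multiply-accumulate (MAC) count of a Conv2d layer, (C_in / groups) × k × k × C_out × H_out × W_out, and a hypothetical helper that treats two counts as "the same order of magnitude" when they differ by less than a factor of 10. That threshold is an assumption for illustration. The unit reported by calops.py (ops, MACs or FLOPs, where FLOPs ≈ 2 × MACs) is not assumed here, so compare both models with the same tool and the same input shape.

```python
import math


def conv2d_macs(c_in: int, c_out: int, k: int, h_out: int, w_out: int,
                groups: int = 1) -> int:
    """Multiply-accumulate count of one k x k Conv2d layer (bias ignored)."""
    return (c_in // groups) * k * k * c_out * h_out * w_out


def same_order_of_magnitude(ops_a: float, ops_b: float) -> bool:
    """Hypothetical check: True if the two counts differ by less than 10x."""
    return abs(math.log10(ops_a) - math.log10(ops_b)) < 1.0
```

For instance, a 3x3 convolution from 32 to 64 channels on a 256x256 output map costs 32 × 9 × 64 × 65536 = 1,207,959,552 MACs, while the depthwise version (groups equal to the channel count) costs only C_in times less per output channel. This is why swapping a backbone can move the total cost by an order of magnitude.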