Swin Transformer in Practice: Instance Segmentation on Your Own Dataset

Having already used Swin for classification and object detection, I wanted to try segmentation too. In many scenarios segmentation is actually more effective, e.g. water-logging detection, safety harnesses, pits and cracks.
The case comes from a competition:
https://www.dcic-china.com/competitions/10021
This time it truly is "a small test of the ox cleaver", as the pun goes, since it is the smart-agriculture track: the cattle image segmentation competition.

Data and Task

The training samples are cattle instance-segmentation images; contestants must build a model on them and perform instance segmentation of the cattle in the provided test set. The method is not restricted to instance segmentation.
Object detection identifies what is in an image and localizes it;
semantic segmentation assigns a class label to every pixel. Instance segmentation is essentially the combination of the two: detect the objects in the image (object detection), then label each of their pixels (semantic segmentation). Semantic segmentation does not distinguish different instances of the same class (all people are marked red), whereas instance segmentation does (different people get different colors). So strictly speaking the task statement reads like semantic segmentation, except that every individual cow has to be marked out.

A look at the data shows cow-pen scenes. The labels are polygons, the annotation quality is mediocre, the images are somewhat blurry, and the viewpoint is top-down, so COCO's pretrained "cow" class is probably too different to transfer well. The dataset is small: 200 training images containing over 2,000 cows, and 100 test images.
You can inspect and fix the annotations with labeling software such as CVAT (which is web-based).
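For reference, a polygon label in the COCO instance format is stored per object roughly like this (a sketch with illustrative values, not taken from the competition files):

# One COCO-style instance annotation; every cow in an image gets its own entry.
annotation = {
    "id": 1,                  # annotation id, unique across the dataset
    "image_id": 42,           # which image this cow belongs to
    "category_id": 1,         # the single class, "cow"
    "segmentation": [[512.0, 233.5, 540.2, 241.0, 533.7, 260.8]],  # polygon x1,y1,x2,y2,...
    "bbox": [505.0, 230.0, 40.0, 35.0],  # x, y, width, height
    "area": 980.5,
    "iscrowd": 0,
}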



Let's first run a baseline with Swin-Transformer-Object-Detection:
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection

Data Processing

Environment Setup

Deploying the Swin environment is the same as for object detection. PaddleX is used here to process the dataset, so set up that environment as well:

conda create -n paddlex python=3.7
conda activate paddlex
pip install cython
git clone https://github.com/philferriere/cocoapi.git
cd cocoapi/PythonAPI
python3 setup.py build_ext install
pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
pip install paddlex -i https://mirror.baidu.com/pypi/simple
git clone https://github.com/PaddlePaddle/PaddleX.git
cd PaddleX
git checkout develop
python setup.py install
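
Before moving on, a quick import check helps confirm the installs took (a minimal sanity test, assuming the commands above succeeded):

# verify that both toolkits import cleanly
import paddlex
import pycocotools.mask

print(paddlex.__version__)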
With the environment ready, put the dataset into COCO format. The target layout matches the paths used in coco_instance.py later in this post: data/cow/train2017/ and data/cow/val2017/ for the images, and data/cow/annotations/instances_train2017.json / instances_val2017.json for the labels.

Split the dataset (the test split can be omitted):
paddlex --split_dataset --format COCO --dataset_dir 200 --val_value 0.2 --test_value 0.1
This writes train.json and val.json (plus test.json) into the dataset directory. Next, copy the images into the train and val folders:

# -*- coding: utf-8 -*-
import json
import os
import shutil

path = 'E:/workspace/Swin-Transformer-Object-Detection/data/cow/200/'
valDir = 'E:/workspace/Swin-Transformer-Object-Detection/data/cow/200/val2017/'
trainDir = 'E:/workspace/Swin-Transformer-Object-Detection/data/cow/200/train2017/'


def copy_images(ann_file, dst_dir):
    """Copy every image referenced by a COCO annotation file into dst_dir."""
    with open(ann_file, 'r') as fp:
        data = json.load(fp)
    for info in data['images']:
        filename = os.path.basename(info['file_name'])
        shutil.copy(path + info['file_name'], dst_dir + filename)


copy_images('./200/train.json', trainDir)
copy_images('./200/val.json', valDir)

Model Configuration

With the images prepared, it is time to configure Swin-T.

Modifying the Configs

In configs/_base_/models/mask_rcnn_swin_fpn.py, change num_classes in both places (the bbox head and the mask head) to your actual number of classes; here it is set to 1.
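A minimal sketch of the two fields in question, assuming the standard MMDetection Mask R-CNN layout where both heads live under roi_head (in the actual file you edit the two num_classes values in place; everything else stays unchanged):

# configs/_base_/models/mask_rcnn_swin_fpn.py (sketch of the relevant fields)
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=1),   # was 80 for COCO
        mask_head=dict(num_classes=1),   # was 80 for COCO
    ),
)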

In configs/_base_/default_runtime.py, change interval and load_from:

root@k8s-master1:/media/nizhengqi/7a646073-10bf-41e4-93b5-4b89df793ff8/wyh/Swin-Transformer-Object-Detection# cat configs/_base_/default_runtime.py
checkpoint_config = dict(interval=4)
# yapf:disable
log_config = dict(
    interval=40,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = "mask_rcnn_swin_1.pth"
resume_from = None#"mask_rcnn_swin_1.pth"
workflow = [('train', 1),('val',1)]

Modifying the Pretrained Weights

The point is to resize the class-dependent heads of the pretrained checkpoint to your own number of classes; this produces the mask_rcnn_swin_1.pth that load_from points to above. cat changeclass.py:

import torch

model_save_dir = "./"

pretrained_weights = torch.load('mask_rcnn_swin_tiny_patch4_window7.pth')

num_class = 1  # actual number of classes

# resize_() truncates the COCO-sized (80-class) head tensors in place to
# match the new class count (+1 in fc_cls for the background class)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_cls.weight'].resize_(num_class + 1, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_cls.bias'].resize_(num_class + 1)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_reg.weight'].resize_(num_class * 4, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.fc_reg.bias'].resize_(num_class * 4)
pretrained_weights['state_dict']['roi_head.mask_head.conv_logits.weight'].resize_(num_class, 256, 1, 1)
pretrained_weights['state_dict']['roi_head.mask_head.conv_logits.bias'].resize_(num_class)

torch.save(pretrained_weights, "{}/mask_rcnn_swin_{}.pth".format(model_save_dir, num_class))
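
A quick way to confirm the surgery took effect (a minimal check, assuming the new checkpoint was saved to the current directory):

# print the resized head shapes
import torch

state = torch.load('mask_rcnn_swin_1.pth')['state_dict']
print(state['roi_head.bbox_head.fc_cls.weight'].shape)       # expect torch.Size([2, 1024]) for 1 class + background
print(state['roi_head.mask_head.conv_logits.weight'].shape)  # expect torch.Size([1, 256, 1, 1])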

Correspondingly, modify the dataset paths in configs/_base_/datasets/coco_instance.py:

dataset_type = 'CocoDataset'
data_root = 'data/cow/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
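Note that this config reads annotations/instances_train2017.json and annotations/instances_val2017.json, while the PaddleX split produced train.json and val.json. A small copy step bridges the two (paths are an assumption based on the layout used in this post; adjust to yours):

# place the split annotation files where coco_instance.py expects them
import os
import shutil

os.makedirs('data/cow/annotations', exist_ok=True)
shutil.copy('data/cow/200/train.json', 'data/cow/annotations/instances_train2017.json')
shutil.copy('data/cow/200/val.json', 'data/cow/annotations/instances_val2017.json')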

Adjusting the Training Schedule

In configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py, change max_epochs and lr.

Switch the datasets entry in _base_ to coco_instance.py:

_base_ = [
    '../_base_/models/mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=40)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=True,
)
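
The learning rate itself lives in the optimizer dict of the same config; the Swin repo overrides it to AdamW with lr=1e-4, which targets 8-GPU training. When training on fewer GPUs it is common to scale it down proportionally. The value below is an assumption for a single GPU, not a tested setting:

# AdamW as used by the Swin configs; lr scaled down from the 8-GPU default of 1e-4
# (the repo config also sets paramwise_cfg for the Swin backbone; keep that part unchanged)
optimizer = dict(type='AdamW', lr=1.25e-05, betas=(0.9, 0.999), weight_decay=0.05)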

Modify the class labels in mmdet/core/evaluation/class_names.py and mmdet/datasets/coco.py:

def coco_classes():
    return ['cow']

class CocoDataset(CustomDataset):

    CLASSES = ('cow',)

Note that the trailing comma is required even with a single class (otherwise CLASSES is a plain string instead of a tuple), and the change only takes effect after reinstalling: python setup.py install
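
After reinstalling, a one-liner confirms the new class list is the one actually being imported:

# should print ('cow',); if it prints the 80 COCO names, the reinstall did not take
from mmdet.datasets import CocoDataset
print(CocoDataset.CLASSES)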

Training

python tools/train.py configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py

Inference

root@c92561fab718:/workspace# cat infer.py
from mmdet.apis import init_detector, inference_detector
from mmdet.core.mask.utils import encode_mask_results
import torch
import os
import json

# model config file
config_file = './work_dirs/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py'

# trained checkpoint
checkpoint_file = './work_dirs/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco/epoch_40.pth'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO] running inference on {}".format(device))

# build the model from the config file and the checkpoint
model = init_detector(config_file, checkpoint_file, device=device)

results = []
for file in os.listdir("data/cow/images"):
    img = os.path.join("data/cow/images", file)
    # for Mask R-CNN, result is a (bbox_results, mask_results) tuple;
    # with a single class, result[0][0] is an (n, 5) array of boxes+scores
    # and result[1][0] is the list of n binary masks
    result = inference_detector(model, img)
    print('images/' + file, '-', len(result[0][0]), 'instances')

    # run-length-encode the binary masks into COCO RLE format
    mask = encode_mask_results(result[1])

    imageid = "images/" + file
    for i in range(len(result[1][0])):
        results.append({
            "image_id": imageid,
            "category_id": 1,
            "segmentation": {
                "size": mask[0][i]["size"],
                "counts": mask[0][i]["counts"].decode(),
            },
            # cast to a plain float so json.dump can serialize the score
            "score": round(float(result[0][0][i][4]), 3),
        })

with open('./results.json', mode='w', encoding='utf-8') as json_file:
    json.dump(results, json_file, ensure_ascii=False, indent=4)
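
Before submitting, a quick look at the file catches obvious format problems (a minimal check, assuming results.json sits in the current directory):

# load the submission back and inspect one entry
import json

with open('results.json', encoding='utf-8') as f:
    subs = json.load(f)
print(len(subs), 'instances')
print(subs[0]['image_id'], subs[0]['score'], subs[0]['segmentation']['size'])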

The submission only scored a bit over 0.6, so there is still room to tune hyperparameters, augment the data, or try Swin-based segmentation models such as Swin-Unet.
