小黑的Python日記:Unet簡(jiǎn)單實(shí)現(xiàn)裂縫分割

大噶好,我系小黑喵

裂縫數(shù)據(jù)集

數(shù)據(jù)集地址:https://github.com/cuilimeng/CrackForest-dataset
結(jié)構(gòu):

  --project
    main.py
     --image
        --train
           --data
           --groundTruth
        --val
           --data
           --groundTruth

我手動(dòng)將數(shù)據(jù)集做成這個(gè)格式,其中trian84張,val34張,都保存為了jpg圖像。

Unet

論文地址:https://arxiv.org/pdf/1505.04597.pdf
代碼來(lái)源:https://github.com/JavisPeng/u_net_liver
上面代碼中,作者將Unet運(yùn)用于liver識(shí)別,和裂縫一樣,都只有一個(gè)mask,因而我們可以直接使用上述代碼。

Unet結(jié)構(gòu)

需要修改dataset.py為自己的數(shù)據(jù)集,其他小小改動(dòng)即可。

#dataset.py
import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(rootdata, roottarget):
    """Pair each image under *rootdata* with its mask under *roottarget*.

    Assumes the mask file shares the image's filename. Returns a list of
    (image_path, mask_path) tuples. Filenames are sorted so the dataset
    order is deterministic (os.listdir order is filesystem-dependent).
    """
    return [
        (os.path.join(rootdata, name), os.path.join(roottarget, name))
        for name in sorted(os.listdir(rootdata))
    ]


class MyDataset(data.Dataset):
    """Crack-segmentation dataset yielding (image, mask) pairs.

    Image and mask paths are discovered by make_dataset(); both are
    loaded as single-channel grayscale ('L') PIL images and optionally
    passed through their own transform.
    """

    def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
        self.imgs = make_dataset(rootdata, roottarget)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]
        # Note: convert('L') produces 8-bit grayscale, not a binary image.
        sample = Image.open(img_path).convert('L')
        target = Image.open(mask_path).convert('L')
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.imgs)
#main.py
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import MyDataset

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input images: convert to tensor, then rescale [0, 1] -> [-1, 1].
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Masks only need the tensor conversion — no normalization.
y_transforms = transforms.ToTensor()


def train_model(model, criterion, optimizer, dataload, num_epochs=10):
    """Run a supervised training loop and save the final weights.

    Args:
        model: network to optimize; batches are moved to its current device.
        criterion: loss taking (outputs, labels).
        optimizer: optimizer already bound to model.parameters().
        dataload: DataLoader yielding (input, target) batches.
        num_epochs: number of full passes over the data.

    Returns the trained model. After the last epoch the weights are saved
    to 'weights_<last_epoch>.pth'; as in the original code, the save moves
    the model to CPU as a side effect (callers such as test() rely on it).
    """
    # Derive the device from the model's parameters instead of relying on a
    # module-level global, so the function is self-contained.
    device = next(model.parameters()).device
    # Hoisted loop-invariant: batches per epoch (ceil division).
    batches_per_epoch = (len(dataload.dataset) - 1) // dataload.batch_size + 1
    last_epoch = -1
    for epoch in range(num_epochs):
        last_epoch = epoch
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0.0
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward / backward / update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" %
                  (step, batches_per_epoch, loss.item()))
        # epoch_loss is the accumulated (not averaged) loss, as in the blog.
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    # Bug fix: the original referenced the loop variable `epoch` after the
    # loop, raising NameError when num_epochs == 0. Save only if we trained.
    if last_epoch >= 0:
        torch.save(model.cpu().state_dict(), 'weights_%d.pth' % last_epoch)
    return model


#訓(xùn)練模型
def train():
    """Train on image/train: build the dataset and loader, then run the loop."""
    train_set = MyDataset(
        "image/train/data", "image/train/gt",
        transform=x_transforms, target_transform=y_transforms)
    # Single-sample batches; shuffled each epoch, loaded with 4 workers.
    loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, loader)


#顯示模型的輸出結(jié)果
def test():
    """Visualize the model's predictions on the validation set, one per frame."""
    val_set = MyDataset(
        "image/val/data", "image/val/gt",
        transform=x_transforms, target_transform=y_transforms)
    loader = DataLoader(val_set, batch_size=1)
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode: imshow frames update without blocking
    with torch.no_grad():
        for x, _ in loader:
            prediction = torch.squeeze(model(x)).numpy()
            plt.imshow(prediction)
            plt.pause(0.01)
        plt.show()


if __name__ == '__main__':
    # Set to True to resume from a previously saved checkpoint.
    pretrained = False
    # 1-channel input (grayscale image) -> 1-channel output (crack mask).
    model = Unet(1, 1).to(device)
    if pretrained:
        model.load_state_dict(torch.load('./weights_4.pth'))
    # BCELoss expects probabilities in [0, 1] — NOTE(review): assumes the
    # Unet's final layer applies a sigmoid; confirm in unet.py.
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    train()
    test()

unet.py不需要變動(dòng)

結(jié)果

訓(xùn)練了10個(gè)epoch后:累加loss大概到3
前幾張預(yù)測(cè)圖片:


上為預(yù)測(cè),下為groundTruth

對(duì)于100多張的數(shù)據(jù)集,這個(gè)效果還行。
也算是填了一個(gè)以前的坑。


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容