記一次pytorch訓練模型及搭建(pytorch圖片數據載入,模型訓練)

有一天在做關于物體特征點定位的工作,突發奇想,想要通過pytorch建一個模型進行特征點定位。努力敲了半天代碼,終于實現了,可惜由于自己采的數據集過小(或者是其他原因,歡迎大神賜教),導致定位結果誤差很大。總的來說還算成功。
我采集的數據集:我在杯子上點了個黑點,然后手工標定得到json文件,想要通過模型定位黑點坐標,奈何效果不太理想:

在這里插入圖片描述

先潑代碼:

在這里插入圖片描述

首先是把json文件中的標簽與樣本圖片一一對應起來,以便Dataset模塊載入(dir_xy.py):

#作者:Rayne
#作用:對應json文件中的坐標及文件夾中圖片路徑,以便Dataset模塊載入
import os
import json

def get_img_path(img_path):
    """Return the paths of all entries directly inside ``img_path``.

    The entries are sorted by name. ``os.listdir`` yields them in arbitrary,
    filesystem-dependent order, which would make the (unshuffled) test-set
    order, and therefore the logged test losses, differ between machines.

    Args:
        img_path: directory holding the sample images.

    Returns:
        list of paths built as ``os.path.join(img_path + '/', name)``.
    """
    # Keep the explicit '/' so every returned path contains a forward slash
    # right before the file name, exactly as the path strings looked before.
    return [os.path.join(img_path + '/', name) for name in sorted(os.listdir(img_path))]

def get_label(label_path):  # e.g. ./label.json
    """Build a ``{imageName: [x, y]}`` lookup from a JSON annotation file.

    The file must contain a list of records. For every record the first
    point of its first shape, ``record['Data']['svgArr'][0]['data'][0]``,
    supplies ``x`` and ``y``. When two records share an ``imageName`` the
    later one wins.
    """
    with open(label_path, 'r', encoding='UTF-8') as handle:
        records = json.load(handle)

    def to_entry(record):
        point = record['Data']['svgArr'][0]['data'][0]
        coords = [point['x'], point['y']]
        return record['imageName'], coords

    return dict(to_entry(record) for record in records)

def get_all(img_path, label_path):
    """Pair every image in ``img_path`` with its label from ``label_path``.

    Args:
        img_path: directory containing the images.
        label_path: JSON annotation file readable by ``get_label``.

    Returns:
        ``(file_path, label)``: the image paths from ``get_img_path`` and a
        parallel list holding each image's ``[x, y]`` label.

    Raises:
        KeyError: if an image in ``img_path`` has no entry in the label file.
    """
    file_path = get_img_path(img_path)
    labels = get_label(label_path)
    label = []
    for file in file_path:
        # Look the label up by bare file name so any directory depth (or a
        # trailing slash in img_path) works; picking a fixed path component
        # only worked for two-level paths such as 'data/train'.
        name = os.path.basename(file)
        try:
            label.append(labels[name])
        except KeyError:
            raise KeyError('no label for image {!r} in {}'.format(name, label_path)) from None
    return file_path, label

其次是載入文件夾中的數據(ImageLoader.py):

#作者:Rayne
#作用:載入圖片文件及標簽,標簽是[x,y]形式的list,對應特征點的x,y坐標

import torch.utils.data as data
import torch
from PIL import Image
import numpy as np


def default_loader(path):
    """Load the image at ``path`` and return it as an RGB ``PIL.Image``.

    The file is opened in a ``with`` block, so its handle is closed as soon
    as ``convert`` has produced the decoded RGB copy. This avoids relying on
    lazy loading or garbage collection to release it.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


###############################################
class myImageFloder(data.Dataset):
    """Dataset of images, each paired with one normalised keypoint.

    Item ``i`` is ``(image, target)`` where

    * ``image`` is a float32 tensor of shape ``(3, H, W)`` (channels first,
      the layout ``nn.Conv2d`` expects) with pixel values scaled to [0, 1];
    * ``target`` is a float32 tensor ``[x / width, y / height]``.

    Both tensors are moved to ``device`` (CUDA by default).
    """

    def __init__(self, img, label, loader=default_loader, device='cuda',
                 width=540.0, height=384.0):
        """
        Args:
            img: sequence of image paths.
            label: sequence of ``[x, y]`` pixel coordinates, parallel to ``img``.
            loader: callable turning a path into an RGB image (anything
                ``np.ascontiguousarray`` converts to an ``(H, W, 3)`` array).
            device: where returned tensors are placed; ``None`` keeps them on
                the CPU (for CPU-only runs or multi-worker DataLoaders).
            width: divisor applied to the x coordinate.
            height: divisor applied to the y coordinate.

        Raises:
            ValueError: if ``img`` and ``label`` differ in length.
        """
        if len(img) != len(label):
            raise ValueError(
                'got {} images but {} labels'.format(len(img), len(label)))
        self.img = img
        self.label = label
        self.loader = loader
        self.device = device
        self.width = width
        self.height = height

    def __getitem__(self, index):
        path = self.img[index]
        point = self.label[index]

        # Decode the image and scale 8-bit pixel values into [0, 1].
        pixels = np.ascontiguousarray(self.loader(path), dtype=np.float32) / 255.0
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            raise ValueError(
                'expected an (H, W, 3) image from {!r}, got shape {}'.format(
                    path, pixels.shape))

        # Images decode as (H, W, C) but convolutions need (C, H, W).
        # permute() moves the axes and keeps every pixel intact. A plain
        # .view(3, ...) would reinterpret the interleaved RGB memory and
        # scramble both channels and spatial positions.
        image = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()
        target = torch.tensor(
            [point[0] / self.width, point[1] / self.height], dtype=torch.float32)

        if self.device is not None:
            image = image.to(self.device)
            target = target.to(self.device)
        return image, target

    def __len__(self):
        return len(self.img)

然后是模型搭建(model.py),后來我改用了遷移學習:

#作者:Rayne
#作用:博主花費20分鐘搭建的模型
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small six-stage CNN that regresses one normalised (x, y) keypoint.

    Each stage is ``conv3x3 -> ReLU -> 2x2 max-pool``. Three fully connected
    layers then map the features to two sigmoid outputs in (0, 1). ``fc1``
    expects ``192 * 24`` features per sample, i.e. a 6x4 (or 4x6) final
    feature map, which is what a 540x384 (or 384x540) input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 12, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(12, 24, 3)
        self.conv4 = nn.Conv2d(24, 48, 3)
        self.conv5 = nn.Conv2d(48, 96, 3)
        self.conv6 = nn.Conv2d(96, 192, 3)
        self.fc1 = nn.Linear(192 * 24, 48)
        self.fc2 = nn.Linear(48, 12)
        self.fc3 = nn.Linear(12, 2)

    def forward(self, x):
        """Map an ``(N, 3, H, W)`` batch to ``(N, 2)`` values in (0, 1)."""
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = self.pool(F.relu(conv(x)))
        # Flatten each sample separately and keep the batch dimension.
        # Flattening the whole batch and re-viewing it as (-1, 192*24) would
        # silently mix features of different samples when the input size is
        # wrong, instead of failing loudly in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

然后是訓練函數(train_test.py):

#作者:Rayne
#作用:載入圖片及標簽,定義訓練函數,打印訓練結果。
import model
import torch
import torch.nn as nn
from data import dir_xy, ImageLoader
import torch.optim as optim

# Module-level loss histories; train() appends to them, so they accumulate
# across calls within one process.
train_hist, test_hist = [], []
def train(net=None, criterion=None, optimizer=None, TrainImgLoader=None, TestImgLoader=None, epochs=20,
          log_every=10):
    """Train ``net`` for ``epochs`` epochs, reporting train and test loss.

    Every ``log_every`` batches, the mean loss over those batches is printed
    and appended to the module-level ``train_hist`` / ``test_hist``. After
    the last epoch both histories are plotted with ``plter``.

    Args:
        net: model to train. It is put in ``train()`` mode for the training
            pass and ``eval()`` mode for the test pass of every epoch.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``net``'s parameters.
        TrainImgLoader: iterable of ``(inputs, labels)`` training batches.
        TestImgLoader: iterable of ``(images, labels)`` test batches.
        epochs: number of passes over ``TrainImgLoader``.
        log_every: number of batches averaged into each logged loss value.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got {}'.format(log_every))

    for epoch in range(epochs):  # loop over the dataset multiple times
        # Training pass: train() mode lets layers such as BatchNorm use and
        # update batch statistics.
        net.train()
        # Reset every epoch so a partial window never leaks into the next one.
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(TrainImgLoader):
            optimizer.zero_grad()
            outputs = net(inputs)  # output shape: [-1, 2]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                mean_loss = running_loss / log_every
                print('第{}圈train Loss: {}'.format(epoch, mean_loss))
                train_hist.append(mean_loss)
                running_loss = 0.0

        # Test pass: eval() keeps BatchNorm running statistics frozen, so test
        # batches neither update them nor are normalised by their own stats.
        net.eval()
        test_loss = 0.0
        with torch.no_grad():
            for i, (images, labels) in enumerate(TestImgLoader):
                outputs = net(images)
                test_loss += criterion(outputs, labels).item()
                if i % log_every == log_every - 1:
                    mean_loss = test_loss / log_every
                    print('第{}圈test Loss: {}'.format(epoch, mean_loss))
                    # Store the mean (the value printed), matching train_hist,
                    # so both curves are on the same scale.
                    test_hist.append(mean_loss)
                    test_loss = 0.0

    plter(epochs=epochs, train_loss=train_hist, test_loss=test_hist)
    print('Finished Training')

def plter(train_loss, test_loss, epochs):
    """Plot the logged train and test losses against training progress.

    The two histories generally contain different numbers of points, so
    plotting each against its own list index would misalign them. Instead,
    each series is spread evenly over ``[0, epochs]``: point ``k`` of a
    series with ``m`` points is drawn at ``(k + 1) * epochs / m``. This
    assumes every epoch contributes the same number of points.

    Args:
        train_loss: logged training losses, in order.
        test_loss: logged test losses, in order.
        epochs: number of epochs the histories span.
    """
    import matplotlib.pyplot as plt

    def epoch_axis(values):
        count = len(values)
        return [(k + 1) * epochs / count for k in range(count)]

    _, ax = plt.subplots()
    ax.plot(epoch_axis(train_loss), train_loss, label='train')
    ax.plot(epoch_axis(test_loss), test_loss, label='test')
    ax.set_xlabel(xlabel='epoch')
    ax.set_ylabel(ylabel='MSE')
    ax.set_title('Epochs VS MSE')
    ax.legend()
    plt.show()

最后是主函數:

#作者:Rayne
#作用:定義優化器,模型,損失函數等并進行訓練
import model
import torch
import torch.nn as nn
from data import dir_xy, ImageLoader
import torch.optim as optim
import train_test
import torchvision


def train():
    """Fine-tune an ImageNet-pretrained ResNet-18 to regress one keypoint.

    Replaces the network's final layer with a 2-output linear head and sets
    up MSE loss and Adam. Loads ``data/train`` and ``data/test`` together
    with their JSON labels, then hands everything to ``train_test.train``.
    """
    # float32 is already PyTorch's default dtype; stated explicitly here
    # instead of via the deprecated torch.set_default_tensor_type.
    torch.set_default_dtype(torch.float32)

    # Alternative: the small hand-written CNN, model.Net(), which has the same
    # 2-output contract.
    # NOTE: torchvision >= 0.13 prefers weights=... over pretrained=True;
    # pretrained=True is kept for compatibility with older versions.
    net = torchvision.models.resnet18(pretrained=True)
    # Swap the ImageNet classification head for a 2-output regression head.
    net.fc = nn.Linear(net.fc.in_features, 2)
    net = net.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.00001)

    train_paths, train_xy = dir_xy.get_all(img_path='data/train', label_path='data/label/train.json')
    train_loader = torch.utils.data.DataLoader(
        ImageLoader.myImageFloder(train_paths, train_xy), batch_size=10, shuffle=True)

    test_paths, test_xy = dir_xy.get_all(img_path='data/test', label_path='data/label/test.json')
    test_loader = torch.utils.data.DataLoader(ImageLoader.myImageFloder(test_paths, test_xy))

    train_test.train(net=net, criterion=criterion, optimizer=optimizer,
                     TrainImgLoader=train_loader, TestImgLoader=test_loader)


# Runs training as soon as this script executes. Because the call is not
# behind an `if __name__ == '__main__':` guard, it also runs on import.
# NOTE(review): consider adding that guard if this module is ever imported.
train()

結果:

在這里插入圖片描述

損失下降到后面,尤其是10個循環(epoch)之后就不太明顯了。我相信擁有更多的數據后會得到更好的結果。希望可以幫到你~

©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容