PyTorch多標簽CNN端到端驗證碼識別

這其實是一個多標簽分類問題:每個驗證碼圖片有4個字符(標簽),並且順序固定;只要將卷積神經網絡的最後一層稍加修改就能實現多標簽分類。

如下圖所示,我們的驗證碼一共有4個數字,將4個數字轉換成40位one-hot形式:輸出層的[0-9]輸出值對應第一個字符的one-hot編碼,[10-19]輸出值對應第二個字符的one-hot編碼,[20-29]輸出值對應第三個字符,[30-39]輸出值對應第四個字符;並使用PyTorch的多標簽分類損失函數nn.MultiLabelSoftMarginLoss作為損失函數。

image.png

訓練集800張圖片,測試集200張,每張圖片大小為20×60(高×寬)。

模型結構:
CNN (
(conv1): Sequential (
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU ()
(2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(conv2): Sequential (
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU ()
(2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(out): Linear (624 -> 40)
)

# coding: utf-8
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import math
import csv
import cv2

# Read the labels: each CSV row is (image file stem, digit string).
# `with` guarantees the file is closed even if parsing fails, and
# newline='' is how the csv module expects its input files to be opened.
with open('GenPics/label.csv', newline='') as csvfile:
    # csv.reader yields [] for blank lines; skip them instead of crashing
    # with IndexError on row[0].
    lables = [[row[0], row[1]] for row in csv.reader(csvfile) if row]

# Load every captcha image (as grayscale) together with its label string.
X = []
y = []

picnum = len(lables)
print("picnum : ", picnum)
for img_stem, label in lables:
    img_name = "GenPics/" + img_stem + '.jpg'
    img = cv2.imread(img_name, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of much later
    # inside np.array / torch.from_numpy.
    if img is None:
        raise FileNotFoundError("cannot read image: " + img_name)
    X.append(img)
    y.append(label)

# Flatten the labels into one integer per character, in order: the first
# four characters of every label, so len(tmp) == 4 * picnum.
tmp = []
for label in y:
    if len(label) < 4:
        raise ValueError("label must have at least 4 digits: %r" % (label,))
    tmp.extend(int(ch) for ch in label[:4])

# Convert images to a normalised float tensor and labels to one-hot form.
X = np.array(X)
X = torch.from_numpy(X)
X = torch.unsqueeze(X, dim=1)          # add the channel axis: (N, 1, H, W)
X = X.type(torch.FloatTensor) / 255.   # scale 8-bit pixel values to [0, 1]
yt = torch.LongTensor(tmp)
yt = torch.unsqueeze(yt, 1)            # (4 * N, 1) column of digit indices
# One one-hot row per character. The row count is taken from the data
# instead of being hard-coded (4000 only fits exactly 1000 four-digit labels).
batch_size = yt.size(0)
yt_onehot = torch.zeros(batch_size, 10)
yt_onehot.scatter_(1, yt, 1)
# Regroup the 4 consecutive one-hot rows of each image into one 40-wide row.
yt_onehot = yt_onehot.view(-1, 40)
y = yt_onehot

# Split the data: the first `train_num` samples train, the rest test.
train_num = 800
train_x = X[:train_num]
train_y = y[:train_num]
# Kept as a plain tensor: `Variable(..., volatile=True)` is a deprecated
# no-op in current PyTorch. Disable autograd with torch.no_grad() where the
# test set is evaluated instead.
test_x = X[train_num:]
test_y = y[train_num:]

# Define the model: two conv/ReLU/max-pool stages followed by one linear
# layer that emits num_chars * num_classes scores (4 x 10 = 40 by default).
class CNN(nn.Module):
    """Multi-label CNN for fixed-length digit captchas.

    Args:
        in_height: Input image height in pixels (default 20).
        in_width: Input image width in pixels (default 60).
        num_chars: Characters per captcha (default 4).
        num_classes: Classes per character (default 10 digits).

    ``forward`` maps a ``(batch, 1, in_height, in_width)`` float tensor to
    ``(batch, num_chars * num_classes)`` raw scores; the slice
    ``[k*num_classes:(k+1)*num_classes]`` holds the scores of character k.

    Raises:
        ValueError: if the input size is too small for two conv/pool stages.
    """

    def __init__(self, in_height=20, in_width=60, num_chars=4, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Derive the flattened feature size from the input size instead of
        # hard-coding 16*3*13. Each stage is an unpadded 3x3 conv (size - 2)
        # followed by a 2x2 max-pool (floor halving): 20x60 -> 9x29 -> 3x13.
        h, w = in_height, in_width
        for _ in range(2):
            h, w = (h - 2) // 2, (w - 2) // 2
        if h <= 0 or w <= 0:
            raise ValueError(
                "input %dx%d is too small for two conv/pool stages"
                % (in_height, in_width))
        self.out = nn.Linear(16 * h * w, num_chars * num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)   # flatten to (batch, 16 * h * w)
        output = self.out(x)
        return output


cnn = CNN()
print(cnn)
# Output recorded with the PyTorch release originally used (newer releases
# format module reprs differently):
# CNN (
#   (conv1): Sequential (
#     (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
#     (1): ReLU ()
#     (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
#   )
#   (conv2): Sequential (
#     (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
#     (1): ReLU ()
#     (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
#   )
#   (out): Linear (624 -> 40)
# )

# Optimizer and loss: Adam, with the 40 one-hot outputs scored by
# nn.MultiLabelSoftMarginLoss.
batsize = 8
epochs = 10
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
loss_func = nn.MultiLabelSoftMarginLoss()

# Mini-batch training loop (batches taken in dataset order).
# Ceiling division: the last batch may be shorter than batsize.
iters = (train_x.shape[0] + batsize - 1) // batsize
for epoch in range(epochs):
    losses = []
    for i in range(iters):
        tx = train_x[i*batsize: (i+1)*batsize]
        ty = train_y[i*batsize: (i+1)*batsize]
        out = cnn(tx)
        loss = loss_func(out, ty)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain Python float; the previous
        # loss.data.mean() stored a tensor per batch in the list.
        losses.append(loss.item())
    print('[%d/%d] Loss: %.3f' % (epoch+1, epochs, np.mean(losses)))
# Loss per epoch recorded in the original run:
# [1/10] Loss: 0.352
# [2/10] Loss: 0.322
# [3/10] Loss: 0.244
# [4/10] Loss: 0.100
# [5/10] Loss: 0.053
# [6/10] Loss: 0.040
# [7/10] Loss: 0.035
# [8/10] Loss: 0.031
# [9/10] Loss: 0.028
# [10/10] Loss: 0.026

# Exact-match accuracy on the test set: a captcha counts as correct only if
# all four decoded digits equal its label string.
with torch.no_grad():   # inference only: don't build an autograd graph
    test_output = cnn(test_x)
# (N, 40) -> (N, 4, 10), then argmax over each character's 10 digit scores.
pred_digits = test_output.view(-1, 4, 10).argmax(dim=2)
correct_num = 0
for i, digits in enumerate(pred_digits.tolist()):
    c = '%s%s%s%s' % tuple(digits)
    # Test samples start at index 800 of the label list (see the split).
    if c == lables[800+i][1]:
        correct_num += 1
if len(test_output) > 0:
    print("Test accurate :", float(correct_num) / len(test_output))
else:
    print("Test accurate : no test samples")
# Recorded in the original run: Test accurate : 0.98

# Predict the four digits of a single captcha image.
img_path = 'test2.jpg'
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
if img is None:   # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError("cannot read image: " + img_path)
# Same preprocessing as the training data: (H, W) -> (1, 1, H, W), scaled
# to [0, 1]. The default CNN's Linear layer is sized for 20x60 inputs.
imgArr = torch.from_numpy(np.array(img)).unsqueeze(0).unsqueeze(0)
imgArr = imgArr.type(torch.FloatTensor) / 255.
with torch.no_grad():   # inference only: don't build an autograd graph
    pred_img = cnn(imgArr)
# (1, 40) -> (4, 10); argmax picks the best digit for each character.
c = ''.join(str(d) for d in pred_img.view(4, 10).argmax(dim=1).tolist())
print(c)
# Recorded in the original run: 5955

# Display the image for a visual check of the prediction.
import matplotlib.pyplot as plt
img = plt.imread(img_path)
plt.imshow(img)
plt.show()
image.png

參考引用:https://github.com/junliangliu/captcha

最後編輯於
©著作權歸作者所有,轉載或內容合作請聯繫作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳並發佈,文章內容僅代表作者本人觀點,簡書系信息發佈平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容