Unet學習整理

Unet

一、原理:


Unet網絡分為兩個部分:

第一部分:特征提取,即上圖中的左側,有點類似VGG網絡,由簡單的卷積和池化下采樣構成。圖中采用3*3和1*1的卷積核進行卷積操作,3*3用于提取特征,1*1用于改變維度(通道數)。另外每經過一次池化,就變成另一個尺度,包括input的圖像總計5個尺度。

第二部分:上采樣及特征融合,即上圖中的右側。此處的上采樣通過轉置卷積進行,然后進行特征融合,但此處的特征融合和FCN的方法不一樣(見下方)。融合之前要先將編碼器一側的特征圖crop(裁剪):原始Unet的卷積不加padding,編碼器一側的特征圖比解碼器一側大,需要裁剪成相同大小才能融合。這里的融合方式是拼接。

特征融合:

1. Unet:拼接。將特征在channel維度拼接在一起,形成更厚的特征,對應于TensorFlow的tf.concat()函數,比較占顯存。

2. FCN:對應點相加,并不形成更厚的特征,對應于TensorFlow中的tf.add()函數。

Unet網絡的輸入與輸出部分:


Unet最初是為醫學圖像中的細胞分割而設計的。由于原圖很大、顯存有限,分割時不可能將整張原圖直接輸入網絡,所以必須切成一張一張的小patch。由于Unet的網絡結構(卷積不加padding,輸出比輸入小),切小patch時適合有overlap的切圖:可以看圖,紅框是要分割的區域,但是在切圖時要包含周圍區域。overlap的另一個重要原因是周圍的overlap部分可以為分割區域的邊緣部分提供紋理等信息。可以看黃框的邊緣,分割結果并沒有因為切成小patch而變差。在后續使用的時候,由于本人使用的是512*512大小的圖片,所以這步就不需要進行。

優點:

(1):多次下采樣,提供多個尺度,實現了網絡對圖像特征的多尺度識別;

(2):在上采樣部分進行了特征融合,并且是在多個不同的尺度上分別融合:每一層轉置卷積(上采樣)后的輸出,與特征提取部分同一個尺度的卷積輸出進行融合。相比之下,FCN只在少數幾層進行融合(例如FCN-8s只融合了pool3和pool4),而且采用相加而非拼接。

代碼分享:

#導入相應模塊

from __future__ import print_function

import datetime
import os

import cv2
import numpy as np
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.layers import (Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,
                          AveragePooling2D, Dropout, BatchNormalization)
from keras.layers.advanced_activations import LeakyReLU, ReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam


# Hyper-parameters and paths.
# NOTE: the Chinese names on the right-hand side (and `batch_size`) are
# placeholders from the original post, not names defined in this file.
# Replace each one with a real value before running.

# Side length in pixels of the square images/masks (used by Input() and by the
# generators' reshape). The model pools 5 times by a factor of 2 and then
# concatenates encoder and decoder features of matching size, so this should
# be a multiple of 32.
PIXEL =圖片大小
# Number of samples in each batch yielded by the generators.
BATCH_SIZE = batch_size
# Learning rate intended for the Adam optimizer. Check that the compile call
# actually passes it.
lr =學習率
# Number of training epochs.
EPOCH =訓練epoch
# Channel count ("dimension") each training image is reshaped to. The
# generators read files with cv2.imread(..., 1), i.e. as 3-channel colour
# images, so in practice this must be 3 for the reshape to succeed.
train_img_CHANNEL =訓練圖片的緯度
# Channel count each mask is reshaped to. Masks are also read with
# cv2.imread(..., 1).
train_mask_CHANNEL =訓練圖片mask的緯度
# Number of training images. The training generator draws random file indices
# below this bound. (The garbled "(shù)" inside the original placeholder made
# this line a SyntaxError.)
train_NUM =訓練圖片數量
# Directories holding training images and masks. A trailing path separator
# (e.g. 'data/train/img/') is the safest form.
train_img ='訓練集image路徑'
train_mask ='訓練集mask路徑'
# Directories holding validation images and masks, read by test_generator.
test_img ='測試集image路徑'
test_mask ='測試集mask路徑'

# Training generator: yields (X, Y) batches as 4-D arrays
# (batch, PIXEL, PIXEL, channels).
def train_generator(train_img, train_mask, BATCH_SIZE):
    """Yield random (image, mask) training batches indefinitely.

    Images and masks are paired by position in the *sorted* directory
    listings. Both directories must therefore hold files whose names sort
    into the same order (typically identical file names).

    Args:
        train_img: Directory containing the training images.
        train_mask: Directory containing the matching masks.
        BATCH_SIZE: Number of samples in each yielded batch.

    Yields:
        ``(X, Y)``, where ``X`` has shape
        ``(BATCH_SIZE, PIXEL, PIXEL, train_img_CHANNEL)`` and ``Y`` has shape
        ``(BATCH_SIZE, PIXEL, PIXEL, train_mask_CHANNEL)``.

    Raises:
        IOError: If OpenCV cannot read an image or mask file.
    """
    def _load(directory, name, channels):
        # Read one file as a colour image (flag 1) and reshape it to
        # (PIXEL, PIXEL, channels).
        path = os.path.join(directory, name)
        image = cv2.imread(path, 1)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('cannot read image file: ' + path)
        return np.array(image).reshape(PIXEL, PIXEL, channels)

    # os.listdir() returns entries in arbitrary order, which may differ between
    # the two directories. Sorting keeps image i paired with mask i. The
    # listings do not change between batches, so read them only once.
    X_train_files = sorted(os.listdir(train_img))
    Y_train_files = sorted(os.listdir(train_mask))
    # Every index in [0, train_NUM) is eligible, including the first file.
    candidates = np.arange(train_NUM)

    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            index = np.random.choice(candidates)
            X.append(_load(train_img, X_train_files[index], train_img_CHANNEL))
            Y.append(_load(train_mask, Y_train_files[index], train_mask_CHANNEL))
        yield np.array(X), np.array(Y)

# Validation generator: yields (X, Y) batches as 4-D arrays
# (batch, PIXEL, PIXEL, channels).
def test_generator(test_img, test_mask, BATCH_SIZE):
    """Yield random (image, mask) validation batches indefinitely.

    Images and masks are paired by position in the *sorted* directory
    listings. Both directories must therefore hold files whose names sort
    into the same order (typically identical file names).

    Args:
        test_img: Directory containing the validation images.
        test_mask: Directory containing the matching masks.
        BATCH_SIZE: Number of samples in each yielded batch.

    Yields:
        ``(X, Y)``, where ``X`` has shape
        ``(BATCH_SIZE, PIXEL, PIXEL, train_img_CHANNEL)`` and ``Y`` has shape
        ``(BATCH_SIZE, PIXEL, PIXEL, train_mask_CHANNEL)``.

    Raises:
        ValueError: If the directories contain no image/mask pairs.
        IOError: If OpenCV cannot read an image or mask file.
    """
    def _load(directory, name, channels):
        # Read one file as a colour image (flag 1) and reshape it to
        # (PIXEL, PIXEL, channels).
        path = os.path.join(directory, name)
        image = cv2.imread(path, 1)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('cannot read image file: ' + path)
        return np.array(image).reshape(PIXEL, PIXEL, channels)

    # os.listdir() returns entries in arbitrary order, which may differ between
    # the two directories. Sorting keeps image i paired with mask i. The
    # listings do not change between batches, so read them only once.
    X_test_files = sorted(os.listdir(test_img))
    Y_test_files = sorted(os.listdir(test_mask))
    # Sample from the pairs actually on disk. The test set is generally a
    # different size from the training set, so train_NUM is not a valid bound.
    num_pairs = min(len(X_test_files), len(Y_test_files))
    if num_pairs == 0:
        raise ValueError('no image/mask pairs found in ' + test_img + ' and ' + test_mask)
    candidates = np.arange(num_pairs)

    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            index = np.random.choice(candidates)
            X.append(_load(test_img, X_test_files[index], train_img_CHANNEL))
            Y.append(_load(test_mask, Y_test_files[index], train_mask_CHANNEL))
        yield np.array(X), np.array(Y)

# Build the model.
#
# Encoder: one 3x3 conv, then four BN/conv/dropout levels, with a 2x2 average
# pool after each (5 pools in total), so PIXEL must be a multiple of 32.
# Decoder: nearest-neighbour upsampling, with skip concatenations at the
# 2x2, 4x4 and 8x8 levels. The "# N" comments give the spatial size when
# PIXEL == 32.


def _conv(x, filters, kernel_size=3):
    """Apply a ReLU Conv2D with 'same' padding and He-normal init to ``x``."""
    return Conv2D(filters, kernel_size, activation='relu', padding='same',
                  kernel_initializer='he_normal')(x)


def _encoder_block(x, filters):
    """Apply BN -> 3x3 conv -> BN -> 1x1 conv -> dropout to ``x``.

    The 3x3 convolution extracts features; the 1x1 convolution mixes channels.
    """
    x = BatchNormalization(momentum=0.99)(x)
    x = _conv(x, filters, 3)
    x = BatchNormalization(momentum=0.99)(x)
    x = _conv(x, filters, 1)
    return Dropout(0.02)(x)


# Use the configured channel counts so the model matches what the generators
# yield (the original hard-coded 3 for both input and output).
inputs = Input((PIXEL, PIXEL, train_img_CHANNEL))

conv1 = _conv(inputs, 8)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1)  # 16
conv2 = _encoder_block(pool1, 64)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2)  # 8
conv3 = _encoder_block(pool2, 128)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3)  # 4
conv4 = _encoder_block(pool3, 256)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)  # 2
conv5 = _encoder_block(pool4, 512)
# Pool conv5 (not pool3 again). The original reassigned
# pool4 = AveragePooling2D(...)(pool3), which left conv4 and conv5
# disconnected from the output.
pool5 = AveragePooling2D(pool_size=(2, 2))(conv5)  # 1

conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = _conv(conv6, 256)
conv7 = _conv(conv6, 256)

up7 = UpSampling2D(size=(2, 2))(conv7)  # 2
conv7 = _conv(up7, 256)
merge7 = concatenate([pool4, conv7], axis=3)
conv8 = _conv(merge7, 128)

up8 = UpSampling2D(size=(2, 2))(conv8)  # 4
conv8 = _conv(up8, 128)
merge8 = concatenate([pool3, conv8], axis=3)
conv9 = _conv(merge8, 64)

up9 = UpSampling2D(size=(2, 2))(conv9)  # 8
conv9 = _conv(up9, 64)
merge9 = concatenate([pool2, conv9], axis=3)
conv10 = _conv(merge9, 32)

up10 = UpSampling2D(size=(2, 2))(conv10)  # 16
conv10 = _conv(up10, 32)
conv11 = _conv(conv10, 16)

up11 = UpSampling2D(size=(2, 2))(conv11)  # 32
conv11 = _conv(up11, 8)
# 1x1 projection to the mask channel count, so the output shape matches Y.
conv12 = _conv(conv11, train_mask_CHANNEL, 1)

# `inputs=` / `outputs=` are the Keras 2 keyword names. The legacy
# `input=` / `output=` spellings are rejected by newer Keras versions.
model = Model(inputs=inputs, outputs=conv12)
# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line.
model.summary()

# Compile, train, then persist the model and the per-epoch training loss.
# Use the configured learning rate `lr` (the original hard-coded 1e-3 and left
# `lr` unused).
# NOTE(review): 'accuracy' is of limited meaning for an MSE regression loss;
# consider dropping it or using a segmentation metric.
model.compile(optimizer=Adam(lr=lr), loss='mse', metrics=['accuracy'])

history = model.fit_generator(
    train_generator(train_img, train_mask, BATCH_SIZE),
    steps_per_epoch=600,
    # Keras 2 names for the legacy `nb_epoch` / `nb_val_samples` arguments.
    epochs=EPOCH,
    validation_data=test_generator(test_img, test_mask, BATCH_SIZE),
    validation_steps=20,
)

# The original format string contained a garbled '?' between date and time.
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print('Training finished at', end_time)

# Both paths below are placeholders from the original post. Replace them with
# a real .h5 model path and a real .npy history path.
model.save('模型保存路徑,h5格式')

mse = np.array(history.history['loss'])
np.save('歷史loss保存路徑,npy格式', mse)

最后編輯于
©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
禁止轉載,如需轉載請通過簡信或評論聯(lián)系作者。

友情鏈接更多精彩內容