Unet
I. How it works:

The Unet network has two parts:
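Part 1: feature extraction, the left half of the figure above. It is somewhat like a VGG network: plain convolutions, with pooling for downsampling. The figure uses 3×3 and 1×1 kernels: the 3×3 convolutions extract features and the 1×1 convolutions change the number of channels. Each pooling step moves the features to a new scale, giving 5 scales in total including the input image.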
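The two operations named above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not Keras code: the helpers `pool2x2` and `conv1x1` are hypothetical names. 2×2 pooling halves the height and width and keeps the channels. A 1×1 convolution keeps the spatial size and applies the same channel-mixing matrix at every pixel. The original paper uses max pooling, while the author's code below uses `AveragePooling2D`, so both modes are shown.

```python
import numpy as np

def pool2x2(x, mode="max"):
    """2x2 pooling with stride 2 on an (H, W, C) array; H and W must be even."""
    h, w, c = x.shape
    blocks = x.reshape(h // 2, 2, w // 2, 2, c)   # group pixels into 2x2 blocks
    return blocks.max(axis=(1, 3)) if mode == "max" else blocks.mean(axis=(1, 3))

def conv1x1(x, weight, bias):
    """1x1 convolution: one (C_in -> C_out) linear map applied at every pixel."""
    return x @ weight + bias                      # (H, W, C_in) @ (C_in, C_out) -> (H, W, C_out)

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64, 16))
print(pool2x2(x).shape)                                               # (32, 32, 16)
print(conv1x1(x, rng.standard_normal((16, 32)), np.zeros(32)).shape)  # (64, 64, 32)

# Side lengths of the 5 scales for a 512x512 input and 4 pooling steps
print([512 // 2 ** k for k in range(5)])                              # [512, 256, 128, 64, 32]
```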
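Part 2: upsampling and feature fusion, the right half of the figure. Upsampling is done with transposed convolutions. The upsampled features are then fused with the encoder features, but this fusion differs from FCN's (see below): it is a concatenation. Before the concatenation, the encoder feature map has to be cropped. The original network uses unpadded convolutions, so each encoder map is slightly larger than the decoder map it is paired with.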
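The following pure-Python sketch shows why the crop is needed. It tracks the side length through a network built like the original paper's: two unpadded 3×3 convolutions per level (each removes 2 pixels), 2×2 pooling, 2× upsampling, and 4 pooling levels. The function name `unet_valid_shapes` is hypothetical. For each skip connection, it reports how many pixels must be cropped from each side of the encoder map.

```python
def unet_valid_shapes(input_size, depth=4):
    """Side lengths through a U-Net with unpadded 3x3 convs, 2x2 pooling and 2x upsampling."""
    size = input_size
    skips = []
    for _ in range(depth):
        size -= 4                      # two unpadded 3x3 convs remove 2 pixels each
        skips.append(size)             # encoder output kept for the skip connection
        if size % 2:
            raise ValueError("2x2 pooling needs an even side length")
        size //= 2                     # 2x2 pooling
    size -= 4                          # bottleneck convs
    crops = []
    for skip in reversed(skips):
        size *= 2                      # 2x upsampling
        crops.append((skip, size, (skip - size) // 2))  # (encoder size, decoder size, crop per side)
        size -= 4                      # two 3x3 convs after the concatenation
    return size, crops

out, crops = unet_valid_shapes(572)
print(out)  # 388
for enc, dec, per_side in crops:
    print(f"encoder {enc} -> crop {per_side} px per side -> {dec}")
```

With the paper's 572×572 input, this gives a 388×388 output. The deepest skip crops 64 to 56 and the shallowest crops 568 to 392. The code later in this post uses `padding='same'`, so its encoder and decoder maps already match and no crop is needed.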
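Feature fusion:
1. Unet: concatenation. The features are joined along the channel dimension, producing a thicker feature map. This corresponds to TensorFlow's tf.concat() and uses more GPU memory.
2. FCN: element-wise addition, which does not make the feature map thicker. This corresponds to TensorFlow's tf.add().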
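The difference between the two fusion styles can be seen in a short NumPy sketch. The shapes (a batch of one 64×64 map with 128 channels, NHWC layout) are illustrative assumptions. Concatenation doubles the channel count and the memory of the fused map. Addition keeps the shape, but it needs both inputs to have identical shapes, including the channel count.

```python
import numpy as np

enc = np.ones((1, 64, 64, 128), dtype=np.float32)  # encoder feature map (NHWC)
dec = np.ones((1, 64, 64, 128), dtype=np.float32)  # upsampled decoder feature map

fused_cat = np.concatenate([enc, dec], axis=-1)    # Unet-style, like tf.concat(..., axis=-1)
fused_add = enc + dec                              # FCN-style, like tf.add(enc, dec)

print(fused_cat.shape, fused_cat.nbytes)  # (1, 64, 64, 256) 4194304
print(fused_add.shape, fused_add.nbytes)  # (1, 64, 64, 128) 2097152
```

The convolution that follows a concatenation also receives twice as many input channels, so it needs twice as many weights. This adds to the extra memory cost.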
Input and output of the Unet network:

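Unet was originally designed for cell segmentation in biomedical images. Such images are too large to feed to the network whole, so they are cut into small patches. Because of its architecture, Unet suits patches cut with overlap. Its unpadded convolutions make the output smaller than the input, so to predict a given region, the input patch must also include the area around it. In the figure, the red box is the region to be segmented, and the input patch extends beyond it. Another important reason for the overlap is that the surrounding region supplies texture and other context for pixels near the edge of the segmented region. Along the edges of the yellow box, the segmentation is not degraded by cutting the image into patches. In my own later use the images were 512×512, so this step was not needed.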
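The overlap-tile idea can be sketched in NumPy under a few assumptions. The image is split into output tiles of size `tile`. Each input window adds `margin` pixels of context on every side. Pixels beyond the image border are filled by mirroring, as the original paper does. The helper `overlap_tiles` and its parameter names are hypothetical.

```python
import numpy as np

def overlap_tiles(img, tile, margin):
    """Split an (H, W) or (H, W, C) image into windows of size tile + 2*margin.
    The central tile x tile part of each window is the region to predict; the
    margin is surrounding context. Borders are mirrored (reflect padding)."""
    h, w = img.shape[:2]
    extra_h, extra_w = (-h) % tile, (-w) % tile        # pad so the tiles cover the image exactly
    pad = [(margin, margin + extra_h), (margin, margin + extra_w)] + [(0, 0)] * (img.ndim - 2)
    padded = np.pad(img, pad, mode="reflect")
    windows, origins = [], []
    for y in range(0, h + extra_h, tile):
        for x in range(0, w + extra_w, tile):
            windows.append(padded[y:y + tile + 2 * margin, x:x + tile + 2 * margin])
            origins.append((y, x))                      # top-left corner of the predicted tile in img
    return windows, origins

# With the paper's sizes (388-pixel output tiles, 92 pixels of context per side)
windows, origins = overlap_tiles(np.zeros((512, 512, 3)), tile=388, margin=92)
print(len(windows), windows[0].shape)  # 4 (572, 572, 3)
```

After prediction, each output tile would be written back at its entry in `origins`, and anything predicted past the original height and width would be discarded.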
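Advantages:
(1) Repeated downsampling provides multiple scales, so the network recognizes image features at several scales;
(2) The upsampling path fuses features from several different scales. After each transposed convolution, the output is fused with the output of the feature-extraction convolutions at the same scale in the encoder. FCN, by comparison, fuses only at the output end, by adding a few upsampled class-score maps.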
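One upsample-and-fuse step from point (2) can be sketched in NumPy. The sketch uses a 2×2 transposed convolution with stride 2, a common way to implement this upsampling. It then concatenates the result with an encoder map at the same scale. The helper `up_conv2x2`, the shapes and the random weights are illustrative assumptions. Kernel orientation may differ from Keras' `Conv2DTranspose`, but that only relabels the weights.

```python
import numpy as np

def up_conv2x2(x, weight):
    """2x2 transposed convolution with stride 2.
    x: (H, W, C_in), weight: (2, 2, C_in, C_out) -> (2H, 2W, C_out).
    Because stride == kernel size, every input pixel writes its own 2x2 output block."""
    h, w, _ = x.shape
    out = np.zeros((2 * h, 2 * w, weight.shape[-1]), dtype=x.dtype)
    for a in range(2):
        for b in range(2):
            out[a::2, b::2] = x @ weight[a, b]   # (H, W, C_in) @ (C_in, C_out)
    return out

rng = np.random.default_rng(0)
deep = rng.standard_normal((16, 16, 256))        # decoder input at the coarser scale
enc = rng.standard_normal((32, 32, 128))         # encoder conv output at the finer scale
up = up_conv2x2(deep, rng.standard_normal((2, 2, 256, 128)) * 0.01)
fused = np.concatenate([enc, up], axis=-1)       # same-scale fusion by channel concatenation
print(up.shape, fused.shape)                     # (32, 32, 128) (32, 32, 256)
```

With a 1-channel input and an all-ones kernel, this operation reduces to nearest-neighbour 2× upsampling. That makes it a convenient sanity check.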
Code:
# Import the required modules (unused imports from the original are dropped)
from __future__ import print_function
import os
import datetime
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, AveragePooling2D, UpSampling2D, Dropout, BatchNormalization
from keras.optimizers import Adam
import cv2
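# Settings: the values marked "example" and the paths are placeholders to adjust for your data
PIXEL = 512              # image height and width; must be divisible by 32 (the model pools 5 times)
BATCH_SIZE = 4           # example value
lr = 1e-3                # learning rate
EPOCH = 50               # example value: number of training epochs
train_img_CHANNEL = 3    # channels of the training images (cv2.imread(path, 1) always returns 3)
train_mask_CHANNEL = 3   # channels of the masks (read the same way; the model also outputs 3)
train_NUM = 1000         # example value: number of training images
train_img = 'path/to/train/images/'   # placeholder: training-image directory
train_mask = 'path/to/train/masks/'   # placeholder: training-mask directory
test_img = 'path/to/test/images/'     # placeholder: test-image directory
test_mask = 'path/to/test/masks/'     # placeholder: test-mask directory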
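# Training generator: yields (X, Y) batches, each a 4-D array (batch, height, width, channels)
def train_generator(train_img, train_mask, BATCH_SIZE):
    # os.listdir order is arbitrary; sort both listings so image i and mask i match
    # (assumes each image and its mask sort into the same position)
    X_train_files = sorted(os.listdir(train_img))
    Y_train_files = sorted(os.listdir(train_mask))
    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            # Sample over all files (the original np.arange(1, train_NUM) never picked index 0)
            index = np.random.randint(len(X_train_files))
            # Assumes every image is already PIXEL x PIXEL
            img = cv2.imread(os.path.join(train_img, X_train_files[index]), 1)
            X.append(img.reshape(PIXEL, PIXEL, train_img_CHANNEL))
            mask = cv2.imread(os.path.join(train_mask, Y_train_files[index]), 1)
            Y.append(mask.reshape(PIXEL, PIXEL, train_mask_CHANNEL))
        yield np.array(X), np.array(Y)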
# Test generator: yields (X, Y) batches, each a 4-D array (batch, height, width, channels)
def test_generator(test_img, test_mask, BATCH_SIZE):
    # Sort both listings so image i and mask i match
    X_test_files = sorted(os.listdir(test_img))
    Y_test_files = sorted(os.listdir(test_mask))
    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            # Sample from the test set's own size (the original reused train_NUM here)
            index = np.random.randint(len(X_test_files))
            img = cv2.imread(os.path.join(test_img, X_test_files[index]), 1)
            X.append(img.reshape(PIXEL, PIXEL, train_img_CHANNEL))
            mask = cv2.imread(os.path.join(test_mask, Y_test_files[index]), 1)
            Y.append(mask.reshape(PIXEL, PIXEL, train_mask_CHANNEL))
        yield np.array(X), np.array(Y)
# Build the model
inputs = Input((PIXEL, PIXEL, 3))
conv1 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1)# 16
conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2)# 8
conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3)# 4
conv4 = BatchNormalization(momentum=0.99)(pool3)
conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = BatchNormalization(momentum=0.99)(conv4)
conv4 = Conv2D(256, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = Dropout(0.02)(conv4)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
conv5 = BatchNormalization(momentum=0.99)(pool4)
conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization(momentum=0.99)(conv5)
conv5 = Conv2D(512, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = Dropout(0.02)(conv5)
# Bottleneck. The original reassigned pool4 twice (the second time from pool3), so conv4 and
# conv5 were never used; pool from conv5 so the deepest encoder block is part of the model
pool5 = AveragePooling2D(pool_size=(2, 2))(conv5)  # 1 (PIXEL/32)
conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
# Decoder: UpSampling2D + 3x3 Conv2D stands in for the transposed convolution described above;
# each upsampled map is concatenated with the encoder conv output at the same scale
up7 = UpSampling2D(size=(2, 2))(conv7)  # 2 (PIXEL/16)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
merge7 = concatenate([conv5, conv7], axis=3)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
up8 = UpSampling2D(size=(2, 2))(conv8)  # 4 (PIXEL/8)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
merge8 = concatenate([conv4, conv8], axis=3)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
up9 = UpSampling2D(size=(2, 2))(conv9)  # 8 (PIXEL/4)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
merge9 = concatenate([conv3, conv9], axis=3)
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
up10 = UpSampling2D(size=(2, 2))(conv10)  # 16 (PIXEL/2)
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up10)
conv11 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
up11 = UpSampling2D(size=(2, 2))(conv11)  # 32 (PIXEL)
conv11 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up11)
# 1x1 convolution maps to the 3 output channels of the mask
conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
model = Model(inputs=inputs, outputs=conv12)
model.summary()  # prints the summary itself; wrapping it in print() only adds "None"
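# MSE on the raw 0-255 mask values follows the original; note that 'accuracy'
# says little about a regression loss such as MSE.
# learning_rate= needs Keras >= 2.3 (older versions use lr=)
model.compile(optimizer=Adam(learning_rate=lr), loss='mse', metrics=['accuracy'])
# Keras 2 argument names (epochs, validation_steps); with TF2 / Keras >= 2.4,
# pass the same arguments to model.fit instead of fit_generator
history = model.fit_generator(train_generator(train_img, train_mask, BATCH_SIZE),
                              steps_per_epoch=600,
                              epochs=EPOCH,
                              validation_data=test_generator(test_img, test_mask, BATCH_SIZE),
                              validation_steps=max(1, 20 // BATCH_SIZE))  # about 20 validation samples, as in the original
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print('Training finished at', end_time)
model.save('unet_model.h5')  # placeholder path for the saved model (HDF5 format)
loss_history = np.array(history.history['loss'])
np.save('loss_history.npy', loss_history)  # placeholder path for the per-epoch training loss (.npy)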
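The generators above pair images and masks by sorted position. This only works if both folders contain exactly corresponding files. A stricter option is to match files by name stem and fail loudly when a mask is missing. The helper `paired_files` below is a hypothetical sketch of this, and it assumes each mask shares its image's file name up to the extension (for example `a.jpg` and `a.png`).

```python
import os

def paired_files(image_names, mask_names):
    """Match image and mask files by name stem (extension ignored).
    Returns sorted (image, mask) pairs; raises if any image has no mask."""
    masks = {os.path.splitext(m)[0]: m for m in mask_names}
    pairs, missing = [], []
    for name in sorted(image_names):
        stem = os.path.splitext(name)[0]
        if stem in masks:
            pairs.append((name, masks[stem]))
        else:
            missing.append(name)
    if missing:
        raise ValueError(f"no mask found for: {missing}")
    return pairs

print(paired_files(["b.jpg", "a.jpg"], ["a.png", "b.png", "c.png"]))
# [('a.jpg', 'a.png'), ('b.jpg', 'b.png')]
```

In the generators, `pairs = paired_files(os.listdir(train_img), os.listdir(train_mask))` would replace the two sorted listings, and each sampled index would then read `pairs[index][0]` and `pairs[index][1]`.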