Chapter 50: The CIFAR100 Dataset and a ResNet Implementation

The previous chapter covered the ResNet model and its components, and also introduced JAX's built-in model components under jax.example_libraries.stax. With that groundwork in place, it is time to write code: this chapter uses ResNet to classify the CIFAR100 dataset.

What is the CIFAR100 dataset

CIFAR10 and CIFAR100 are both datasets of small labeled images. Compared with CIFAR10, CIFAR100 contains 100 classes with 600 images each: 500 training images and 100 test images per class. The 100 classes are further grouped into 20 superclasses.

Superclass Classes
aquatic mammals beaver, dolphin, otter, seal, whale
fish aquarium fish, flatfish, ray, shark, trout
flowers orchids, poppies, roses, sunflowers, tulips
food containers bottles, bowls, cans, cups, plates
fruit and vegetables apples, mushrooms, oranges, pears, sweet peppers
household electrical devices clock, computer keyboard, lamp, telephone, television
household furniture bed, chair, couch, table, wardrobe
insects bee, beetle, butterfly, caterpillar, cockroach
large carnivores bear, leopard, lion, tiger, wolf
large man-made outdoor things bridge, castle, house, road, skyscraper
large natural outdoor scenes cloud, forest, mountain, plain, sea
large omnivores and herbivores camel, cattle, chimpanzee, elephant, kangaroo
medium-sized mammals fox, porcupine, possum, raccoon, skunk
non-insect invertebrates crab, lobster, snail, spider, worm
people baby, boy, girl, man, woman
reptiles crocodile, dinosaur, lizard, snake, turtle
small mammals hamster, mouse, rabbit, shrew, squirrel
trees maple, oak, palm, pine, willow
vehicles 1 bicycle, bus, motorcycle, pickup truck, train
vehicles 2 lawn-mower, rocket, streetcar, tank, tractor

Each image carries a "fine" label (the class it belongs to) and a "coarse" label (its superclass), and measures 32 x 32 pixels.

Figure 1: CIFAR categories

The dataset can be obtained in several ways:

  • Download the archive directly from the official site. Three packaged versions are available:

Version Size md5sum
CIFAR-100 python version 161 MB eb9058c3a382ffc7106e4002c42a8d85
CIFAR-100 Matlab version 175 MB 6a4bfa1dcd5c9453dda6bb54194911f4
CIFAR-100 binary version (suitable for C programs) 161 MB 03b5dce01913d631647c71ecec9e9cb8

Here we choose the python version.

  • Download it with tensorflow_datasets (or tf.keras.datasets).

Each approach is described below.

Generating a dataset from the downloaded CIFAR100 files

After downloading and extracting the CIFAR-100 python version, the following file structure appears:

train
test
meta
file.txt~

Here, meta holds the dataset information, train is the training set, and test is the test set. The dataset can be read with the following code:


import pickle

def setup():
    
    def load(fileName: str):
        
        with open(file = fileName, mode = "rb") as handler:
            
            data = pickle.load(file = handler, encoding = "latin1")
            
        return data
    
    trains = load("../../Shares/cifar-100-python/train")
    tests = load("../../Shares/cifar-100-python/test")
    metas = load("../../Shares/cifar-100-python/meta")
    
    return trains, tests, metas
    
def train():
    
    trains, tests, metas = setup()
    
    for key in trains.keys():
        
        print(f"key = {key}, len(trains[key]) = {len(trains[key])}")
    
    print("--------------------------------------------------")
    
    for key in tests.keys():
        
        print(f"key = {key}, len(tests[key]) = {len(tests[key])}")
    
    print("--------------------------------------------------")
    
    for key in metas.keys():
        
        print(f"key = {key}, len(metas[key]) = {len(metas[key])}")
    
def main():
    
    train()

if __name__ == "__main__":
    
    main()

Running it prints the following:


key = filenames, len(trains[key]) = 50000
key = batch_label, len(trains[key]) = 21
key = fine_labels, len(trains[key]) = 50000
key = coarse_labels, len(trains[key]) = 50000
key = data, len(trains[key]) = 50000
--------------------------------------------------
key = filenames, len(tests[key]) = 10000
key = batch_label, len(tests[key]) = 20
key = fine_labels, len(tests[key]) = 10000
key = coarse_labels, len(tests[key]) = 10000
key = data, len(tests[key]) = 10000
--------------------------------------------------
key = fine_label_names, len(metas[key]) = 100
key = coarse_label_names, len(metas[key]) = 20

The fields are as follows:

  • filenames: a list of length 50000; each entry is the file name of one image.
  • batch_label: a string describing the batch.
  • fine_labels: the class each image belongs to.
  • coarse_labels: the superclass each image belongs to.
  • data: a 50000 x 3072 two-dimensional array; each row holds the pixel values of one image.
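A row of the data array can be turned back into an image. Each row has 3072 values: the first 1024 are the red channel, the next 1024 green, and the last 1024 blue, each stored row-major as 32 x 32. The sketch below uses a synthetic row as a stand-in for an entry of trains["data"]:

```python
import numpy as np

# Reshape one flat CIFAR row (3072 values, channel-first) into a
# 32 x 32 x 3 HWC image suitable for display or NHWC models.
def row_to_image(row):
    return row.reshape(3, 32, 32).transpose(1, 2, 0)  # -> (32, 32, 3)

# A synthetic row stands in here for trains["data"][0].
row = np.arange(3072)
image = row_to_image(row)
print(image.shape)  # (32, 32, 3)
```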
Using tensorflow_datasets

import tensorflow as tf
import tensorflow_datasets as tfds

def setup():
    
    (trains, tests), meta = tfds.load("cifar100", data_dir = "/tmp/", split = [tfds.Split.TRAIN, tfds.Split.TEST], with_info = True, batch_size = -1)
    
    #tfds.show_examples(trains, meta)
        
    trains = tfds.as_numpy(trains)
    tests = tfds.as_numpy(tests)
    
    train_images, train_labels = trains["image"], trains["label"]
    test_images, test_labels = tests["image"], tests["label"]
    
    return (train_images, train_labels), (test_images, test_labels)
    
def train():
    
    (train_images, train_labels), (test_images, test_labels) = setup()
    
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape))
    
def main():
    
    train()
    
if __name__ == "__main__":
    
    main()

Running it prints the following:


((50000, 32, 32, 3), (50000,)) ((10000, 32, 32, 3), (10000,))
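Before training, the images and labels are usually preprocessed. The sketch below is an assumption, not part of the original text: it scales pixels to [0, 1] and one-hot encodes the 100 fine labels, using synthetic stand-ins for train_images and train_labels.

```python
import numpy as np

# Minimal preprocessing sketch: normalize uint8 pixels and build
# one-hot targets for the 100 CIFAR100 classes.
def preprocess(images, labels, number_classes=100):
    images = images.astype(np.float32) / 255.0
    one_hot = np.eye(number_classes, dtype=np.float32)[labels]
    return images, one_hot

# Synthetic stand-ins for train_images / train_labels.
images = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
labels = np.array([0, 5, 99, 42])
x, y = preprocess(images, labels)
print(x.shape, y.shape)  # (4, 32, 32, 3) (4, 100)
```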

Using keras.datasets

import tensorflow as tf

def setup():
    
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()

    return (train_images, train_labels), (test_images, test_labels)

Running it prints the following:


((50000, 32, 32, 3), (50000, 1)) ((10000, 32, 32, 3), (10000, 1))
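One difference is worth noting: the keras loader returns labels shaped (N, 1), while the tensorflow_datasets path above yields (N,). Squeezing the trailing axis, shown here on a synthetic stand-in, makes the two pipelines interchangeable.

```python
import numpy as np

# Drop the trailing singleton axis so keras labels match the
# (N,)-shaped labels produced by tensorflow_datasets.
train_labels = np.zeros((8, 1), dtype=np.int64)  # shape as returned by keras
train_labels = train_labels.squeeze(axis=1)
print(train_labels.shape)  # (8,)
```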

Implementing the ResNet residual module

The ResNet architecture was introduced in the previous chapter. Its innovation is a "modular" way of stacking the network so that features passed through a module are not lost.

As the figure below shows, each module is essentially three convolution paths stacked together to form a bottleneck design. Each residual module uses three convolution layers: 1 x 1, 3 x 3, and 1 x 1. The 1 x 1 layers first reduce and then restore the channel dimension, so the 3 x 3 layer operates on a bottleneck with smaller input and output sizes.

The code implementing this three-layer convolution structure is as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

In this code, the input first passes through a jax.example_libraries.stax.Conv() convolution whose output channel count is a quarter of the block's output dimension; this shrinks the data in preparation for the following [3, 3] convolution. jax.example_libraries.stax.BatchNorm() is a batch-normalization layer, and jax.example_libraries.stax.Relu is the activation layer.

In addition, three previously unseen combinators appear here; their purpose is to combine different computation paths. jax.example_libraries.stax.FanOut(2) duplicates the input; jax.example_libraries.stax.parallel(Main, Identity) runs the Main path and the Identity path side by side; and jax.example_libraries.stax.FanInSum merges the parallel results by summing them.
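As a tiny self-contained sanity check (constructed for this text, not from the original), the three combinators can be wired together on their own: FanOut(2) duplicates the input, parallel() routes the copies through two Identity branches, and FanInSum adds them back, so the whole pipeline computes x + x.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Duplicate the input, pass both copies through identity branches,
# then sum them: the network as a whole computes 2 * x.
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Identity, stax.Identity),
    stax.FanInSum,
)

rng = jax.random.PRNGKey(0)
_, params = init_fun(rng, (-1, 4))  # none of these layers has parameters

x = jnp.arange(8.0).reshape(2, 4)
print(apply_fun(params, x))  # same result as 2 * x
```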

During forward propagation, the ResNet module adds a connection called a "shortcut", an express lane for the data. The shortcut simply performs an identity mapping, so it introduces no extra parameters and adds no computational complexity, as shown in the figure below.
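A toy illustration (not from the original text) makes the point concrete: the block computes F(x) + x, so even when the learned branch F contributes nothing at all, the input still flows through unchanged.

```python
import numpy as np

# The residual pattern: FanOut -> (Main, Identity) -> FanInSum,
# expressed directly as f(x) + x.
def residual_block(x, f):
    return f(x) + x

x = np.array([1.0, 2.0, 3.0])
zero_branch = lambda v: np.zeros_like(v)  # a main path that learned nothing
print(residual_block(x, zero_branch))  # [1. 2. 3.] -- the identity survives
```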

Moreover, the entire network can still be trained end to end with backpropagation. In the IdentityBlock code above, this shortcut is exactly the jax.example_libraries.stax.Identity branch of the parallel combination.

Sometimes, beyond deciding whether to transform the input, ResNet also changes the dimensionality of the data inside a block. When the input dimension differs from the required output dimension (input_channel is not equal to out_dim), the input must be padded. Padding fills the data out to the required size and is controlled through the padding parameter.
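The channel-padding idea can be sketched as follows (an assumption for illustration, not code from the original text): when the identity branch carries fewer channels than the main branch produces, zero-pad the channel axis so the two can be summed.

```python
import numpy as np

# Zero-pad the last (channel) axis of an NHWC tensor up to
# the requested number of output channels.
def pad_channels(x, out_channels):
    extra = out_channels - x.shape[-1]
    return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, extra)])

x = np.ones((2, 4, 4, 16))   # NHWC inputs with 16 channels
y = pad_channels(x, 64)      # padded up to 64 channels
print(y.shape)  # (2, 4, 4, 64)
```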

Implementing the ResNet network

The ResNet network structure is shown in the figure below.

The figure shows five depths of ResNet: 18, 34, 50, 101, and 152. Every variant is divided into five stages: conv1, conv2_x, conv3_x, conv4_x, and conv5_x.

These are implemented below. Note that training the full ResNet calls for a fairly powerful GPU. For ease of demonstration, the code below deviates slightly from the original paper; for example, the initial 7 x 7 stride-2 convolution is replaced with a 3 x 3 convolution to suit the 32 x 32 CIFAR inputs. Please keep this in mind.

The complete ResNet50 code is as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ConvolutionalBlock(kernel_size, filters, strides = (1, 1)):
    
    kernel_size_ = kernel_size
    filters1, filters2, filters3 = filters
    
    Main = jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(filters1, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        # Stride is applied once, in the first conv; the shortcut matches it
        jax.example_libraries.stax.Conv(filters3, (1, 1), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm()
    )
    
    Shortcut = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides, padding = "SAME")
    )
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(
            Main,
            Shortcut
        ),
        
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu)

def ResNet50(number_classes):
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(64, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.MaxPool((3, 3), strides = (2, 2)),
        
        ConvolutionalBlock(3, [64, 64, 256]),
        
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        
        ConvolutionalBlock(3, [128, 128, 512]),
        
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        
        ConvolutionalBlock(3, [256, 256, 1024]),
        
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        
        ConvolutionalBlock(3, [512, 512, 2048]),
        
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        
        jax.example_libraries.stax.AvgPool((7, 7)),
        
        jax.example_libraries.stax.Flatten,
        
        jax.example_libraries.stax.Dense(number_classes),
        
        jax.example_libraries.stax.LogSoftmax
    )

Conclusion

This chapter introduced the structure of the CIFAR100 dataset and the implementation of the ResNet residual module and the full network, laying the groundwork for the hands-on training that follows.
