Convolutional neural networks have achieved great success in image processing. By combining feature extraction and target training in a single model, they can make full use of the available information and feed the results back into training.
For text recognition, a convolutional neural network likewise makes full use of the text features obtained during feature extraction to compute feature weights, and normalizes the data to be processed. The original text is thus abstracted into a vectorized sample set, which is then fed, together with the trained templates, into the convolutional neural network.
Building on the previous chapter, this chapter uses convolutional neural networks for text classification, covering two approaches: character-based and word-embedding-based convolution. In practice the two representations can be converted into each other. This chapter only introduces basic usage of the models; deeper applications are left for further study.
Processing Character (Non-Word) Text
This section introduces the character-based CNN approach; word-based convolution is covered in the next section. Since words are made up of letters, a word can simply be split into its letter representation,
hello -> ['h', 'e', 'l', 'l', 'o']
As shown, the word hello is manually split into the 5 letters 'h', 'e', 'l', 'l', 'o'. There are two ways to process hello,
- One-hot encoding.
- Character embedding.
Either way, the word "hello" is turned into a [5, n] matrix. This example uses one-hot encoding; a rough sketch contrasting the two options follows.
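As a rough sketch of the two options (the embedding table below is a random stand-in for a learned one, and embedding_dimension is a made-up size, not part of this chapter's pipeline),

import numpy

word = "hello"
alphabet = "abcdefghijklmnopqrstuvwxyz"
indexes = [alphabet.index(character) for character in word]
# Option 1: one-hot encoding, each character becomes a row of length 26
one_hoted = numpy.eye(len(alphabet))[indexes]
print(one_hoted.shape)    # (5, 26)
# Option 2: character embedding, each character is looked up in a trainable
# table; here a random table stands in for a learned one
embedding_dimension = 8
table = numpy.random.normal(size = (len(alphabet), embedding_dimension))
embedded = table[indexes]
print(embedded.shape)    # (5, 8)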
When the character matrix is processed by a convolutional neural network, the data of each split word is convolved according to its length to extract higher-level abstract concepts. The advantage of this approach is that it needs neither pre-trained word vectors nor syntactic or grammatical information; a further benefit is that it generalizes easily to any language. The principle of character-level text classification with a CNN is shown in the figure below,

Reading and Converting the Title Text
In the AG News dataset, every news item has a category label, a title, and a body. Extracting the body was covered in the preceding chapters, so here the news titles are processed directly, as shown below,
| Label | Title |
|---|---|
| 3 | Wall St. Bears Claw Back Into the Black (Reuters) |
| 3 | Carlyle Looks Toward Commercial Aerospace (Reuters) |
| 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) |
| 3 | Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) |
| 3 | Oil prices soar to all-time record, posing new menace to US economy (AFP) |
| 3 | Stocks End Up, But Near Year Lows (Reuters) |
| 3 | Money Funds Fell in Latest Week (AP) |
| 3 | Fed minutes show dissent over inflation (USATODAY.com) |
| 3 | Safety Net (Forbes.com) |
| 3 | Wall St. Bears Claw Back Into the Black |
Since only the titles are processed, there is no need to remove stop words or apply stemming during data cleaning. Spaces do not need to be kept either, because the computation is character-based, so they are simply deleted. The original code is modified as follows,
def stop_words():
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.data.path.append("/tmp/")
nltk.download("stopwords", download_dir = "/tmp/");
stops = nltk.corpus.stopwords.words("English")
print(stops)
return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
    string = re.sub(pattern = r" +", repl = replacement, string = string)
    # Delete the spaces entirely, since the matrix is built per character
    string = re.sub(pattern = " ", repl = "", string = string)
    # Trim the string
    string = string.strip()
    string = string + "eos"
return string
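A quick sanity check of purify on the first title from the table above (the expected output assumes the space-deletion behavior just described),

print(purify("Wall St. Bears Claw Back Into the Black (Reuters)"))
# wallstbearsclawbackintotheblackreuterseos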
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
string = re.sub(pattern = r" +", repl = replacement, string = string)
# Trim the string
string = string.strip()
    # Separate the string with spaces; a list of words is yielded
strings = string.split(" ")
strings = [word for word in strings if word not in stops]
strings = [nltk.PorterStemmer().stem(word) for word in strings]
strings.append("eos")
strings = ["bos"] + strings
return strings
def setup():
with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
labels = []
titles = []
descriptions = []
trains = csv.reader(handler)
for line in trains:
labels.append(jax.numpy.int32(line[0]))
titles.append(purify(line[1]))
descriptions.append(purify_stops(line[2]))
return labels, titles, descriptions
文本獨熱編碼處理
Next, the generated strings are one-hot encoded. The approach is simple: first build an alphabet of the 26 letters,
def one_hot(strings):
alphabet = "abcdefghijklmnopqrstuvwxyz"
Each character is looked up in the alphabet, and the position it occupies is set to 1 while all others are set to 0. For example, the character "c" is the 3rd entry of the alphabet, so the resulting character vector is,
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Other characters are handled in the same way. The code is as follows,
alphabet = "abcdefghijklmnopqrstuvwxyz"
def one_hot(characters):
array = numpy.array(characters)
length = len(alphabet) + 1
    # numpy.eye(N, M = None, k = 0, dtype = float) creates a 2-dimensional array
    # whose diagonal elements are 1s and all others are 0s.
eyes = numpy.eye(length)[array]
return eyes
The next step is to convert a string into a sequence of numbers following the alphabet order,
def indexes_of(characters):
indexes = []
for character in characters:
index = alphabet.index(character)
indexes.append(index)
return indexes
def train():
string = "hello"
indexes = indexes_of(string)
print("string =", string, ", indexes =", indexes)
if __name__ == "__main__":
train()
The generated result is,
string = hello , indexes = [7, 4, 11, 11, 14]
Putting the code together,
import numpy
def one_hot(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
array = numpy.array(characters)
length = len(alphabet)
# jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
# the elements in diagonal will be filled out with 1s, others are 0s.
eyes = numpy.eye(length)[array]
return eyes
def indexes_of(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
indexes = []
for character in characters:
index = alphabet.index(character)
indexes.append(index)
return indexes
def indexes_matrix(string):
indexes = indexes_of(string)
matrix = one_hot(indexes)
return matrix
def train():
#labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
string = "hello"
indexes = indexes_matrix(string)
print("string =", string, ", indexes =", indexes)
if __name__ == "__main__":
train()
The printed output is,
string = hello , indexes = [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0.]]
As you can see, the word "hello" is converted into a [5, 27] matrix (one column for each entry of the 27-character alphabet, i.e. the 26 letters plus the space) for the next processing step. With the methods defined above, the next step is to one-hot encode the news titles. The code is as follows,
import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
array = numpy.array(characters)
length = len(alphabet)
# jax.numpy.eye(N, M = None, K = 0, dtype) to create a 2-dimension array that
# the elements in diagonal will be filled out with 1s, others are 0s.
eyes = numpy.eye(length)[array]
return eyes
def indexes_of(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
indexes = []
for character in characters:
index = alphabet.index(character)
indexes.append(index)
return indexes
def indexes_matrix(string):
indexes = indexes_of(string)
matrix = one_hot(indexes)
return matrix
def train():
labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
for title in titles[: 10]:
indexes = indexes_matrix(title)
print("string =", title, ", indexes.shape =", indexes.shape)
if __name__ == "__main__":
train()
The printed output is,
string = wallstbearsclawbackintotheblackreuterseos , indexes.shape = (41, 27)
string = carlylelookstowardcommercialaerospacereuterseos , indexes.shape = (47, 27)
string = oilandeconomycloudstocksoutlookreuterseos , indexes.shape = (41, 27)
string = iraqhaltsoilexportsfrommainsouthernpipelinereuterseos , indexes.shape = (53, 27)
string = oilpricessoartoalltimerecordposingnewmenacetouseconomyafpeos , indexes.shape = (60, 27)
string = stocksendupbutnearyearlowsreuterseos , indexes.shape = (36, 27)
string = moneyfundsfellinlatestweekapeos , indexes.shape = (31, 27)
string = fedminutesshowdissentoverinflationusatodaycomeos , indexes.shape = (48, 27)
string = safetynetforbescomeos , indexes.shape = (21, 27)
string = wallstbearsclawbackintotheblackeos , indexes.shape = (34, 27)
However, a new problem appears here: titles of different lengths yield matrices with different numbers of rows. Although a convolutional neural network can handle variable-length strings, in this example we still want matrices of identical size as input.
Padding When Generating the Text Matrix
A simple way to handle matrices of different lengths is to normalize them: truncate the long ones and pad the short ones. Note that from this section on the space character is kept (it is part of the 27-character alphabet), so titles retain their word boundaries. The code is as follows,
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
length = len(string)
if length > maximum_length:
string = string[: maximum_length]
matrix = indexes_matrix(string)
return matrix
else:
matrix = indexes_matrix(string)
length = maximum_length - length
matrix_padded = numpy.zeros([length, len(alphabet)])
matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
return matrix
def train():
'''
string = "hello"
indexes = indexes_matrix(string)
print("string =", string, ", indexes =", indexes)
'''
labels, titles, descriptions = AgNewsCsvReader.setup()
#print(labels[: 5], titles[: 5], titles[: 5])
for title in titles[: 10]:
indexes = align_string_matrix(title)
print("string =", title, ", indexes.shape =", indexes.shape)
if __name__ == "__main__":
train()
In this code, strings of different lengths are handled as follows (a quick check of both branches is sketched after this list),
- Strings longer than maximum_length (default 64; adjust as needed) are truncated to the leading part before matrix conversion.
- Strings shorter than maximum_length are converted first; a zero-filled padding matrix is then generated and concatenated to it (numpy.concatenate).
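As referenced above, a minimal check of both branches, assuming align_string_matrix and the 27-character alphabet defined earlier,

matrix = align_string_matrix("hello", maximum_length = 8)
print(matrix.shape)    # (8, 27): 5 one-hot rows followed by 3 all-zero rows
matrix = align_string_matrix("a" * 100, maximum_length = 8)
print(matrix.shape)    # (8, 27): truncated to the first 8 characters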
The printed output is,
string = wall st bears claw back into the black reuterseos , indexes.shape = (64, 27)
string = carlyle looks toward commercial aerospace reuterseos , indexes.shape = (64, 27)
string = oil and economy cloud stocks outlook reuterseos , indexes.shape = (64, 27)
string = iraq halts oil exports from main southern pipeline reuterseos , indexes.shape = (64, 27)
string = oil prices soar to all time record posing new menace to us economy afpeos , indexes.shape = (64, 27)
string = stocks end up but near year lows reuterseos , indexes.shape = (64, 27)
string = money funds fell in latest week apeos , indexes.shape = (64, 27)
string = fed minutes show dissent over inflation usatoday comeos , indexes.shape = (64, 27)
string = safety net forbes comeos , indexes.shape = (64, 27)
string = wall st bears claw back into the blackeos , indexes.shape = (64, 27)
Building the One-Hot Label Matrix
The category labels can likewise be represented with one-hot encoding. The code is as follows,
def one_hot_numbers(numbers):
array = numpy.array(numbers)
maximum = numpy.max(array) + 1
eyes = numpy.eye(maximum)[array]
return eyes
def train():
'''
string = "hello"
indexes = indexes_matrix(string)
print("string =", string, ", indexes =", indexes)
'''
labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
for title in titles[: 10]:
indexes = align_string_matrix(title)
print("string =", title, ", indexes.shape =", indexes.shape)
one_hoted_labels = one_hot_numbers(labels)
print("one_hoted_labels.shape = ", one_hoted_labels.shape)
if __name__ == "__main__":
train()
The complete code so far is as follows,
../52/AgNewsCsvReader.py
import csv
import re
import jax
import ssl
import nltk
def stop_words():
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.data.path.append("/tmp/")
nltk.download("stopwords", download_dir = "/tmp/");
stops = nltk.corpus.stopwords.words("English")
print(stops)
return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
string = re.sub(pattern = r" +", repl = replacement, string = string)
    # The space-deletion step used earlier is now disabled; spaces are kept
    # string = re.sub(pattern = " ", repl = "", string = string)
# Trim the string
string = string.strip()
string = string + "eos"
return string
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
string = re.sub(pattern = r" +", repl = replacement, string = string)
# Trim the string
string = string.strip()
    # Separate the string with spaces; a list of words is yielded
strings = string.split(" ")
strings = [word for word in strings if word not in stops]
strings = [nltk.PorterStemmer().stem(word) for word in strings]
strings.append("eos")
strings = ["bos"] + strings
return strings
def setup():
with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
labels = []
titles = []
descriptions = []
trains = csv.reader(handler)
trains = list(trains)
for i in range(len(trains)):
            line = trains[i]
labels.append(jax.numpy.int32(line[0]))
titles.append(purify(line[1]))
descriptions.append(purify_stops(line[2]))
return labels, titles, descriptions
CharactersConvolutionalNeuralNetwork.py
import numpy
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
array = numpy.array(characters)
length = len(alphabet)
    # numpy.eye(N, M = None, k = 0, dtype = float) creates a 2-dimensional array
    # whose diagonal elements are 1s and all others are 0s.
eyes = numpy.eye(length)[array]
return eyes
def one_hot_numbers(numbers):
array = numpy.array(numbers)
maximum = numpy.max(array) + 1
eyes = numpy.eye(maximum)[array]
return eyes
def indexes_of(characters, alphabet = None):
alphabet = ("abcdefghijklmnopqrstuvwxyz" if alphabet == None else alphabet)
indexes = []
for character in characters:
index = alphabet.index(character)
indexes.append(index)
return indexes
def indexes_matrix(string, alphabet = "abcdefghijklmnopqrstuvwxyz "):
indexes = indexes_of(string, alphabet)
matrix = one_hot(indexes, alphabet)
return matrix
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
length = len(string)
if length > maximum_length:
string = string[: maximum_length]
matrix = indexes_matrix(string)
return matrix
else:
matrix = indexes_matrix(string)
length = maximum_length - length
matrix_padded = numpy.zeros([length, len(alphabet)])
matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
return matrix
def train():
'''
string = "hello"
indexes = indexes_matrix(string)
print("string =", string, ", indexes =", indexes)
'''
labels, titles, descriptions = AgNewsCsvReader.setup()
    #print(labels[: 5], titles[: 5], descriptions[: 5])
trains = []
    for title in titles:
matrix = align_string_matrix(title)
trains.append(matrix)
trains = numpy.expand_dims(trains, axis = -1)
labels = one_hot_numbers(labels)
print("trains.shape =", trains.shape, ", labels.shape =", labels.shape)
if __name__ == "__main__":
train()
In this code, the full text data is first loaded with the csv library; the text and labels are then read in line by line, each is converted into a one-hot matrix, and the corresponding lists are finally converted into NumPy arrays. The printed output is,
trains.shape = (120000, 64, 27, 1) , labels.shape = (120000, 5)
This produces the one-hot encoded matrix lists for the training samples and the labels,
- The training set has shape [120000, 64, 27, 1]: the first number is the total number of samples, the second and third are the dimensions of each generated matrix, and the final 1 means only a single channel is used.
- The labels have shape [120000, 5], a two-dimensional matrix: 120000 is the total number of samples and 5 is the number of classes. Note that one-hot indices start from 0 while the class labels start from 1, so a label-0 column is automatically added, as the small check below illustrates.
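A minimal check of one_hot_numbers (assuming the function defined above) makes the extra column visible: labels 1 through 4 yield 5 columns, and column 0 is never set,

print(one_hot_numbers([1, 2, 3, 4]))
# [[0. 1. 0. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 0. 1. 0.]
#  [0. 0. 0. 0. 1.]]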
With this, the text data processing is complete.
Implementing Text Classification with a One-Dimensional Convolutional (conv1d) Model
With the text processing done, we now turn to the design of the CNN-based classification model. As the character-text classification architecture mentioned at the start of this chapter suggests, many model designs are possible; along similar lines, a text classification model consisting of 5 layers is designed,
| Layer | Configuration |
|---|---|
| 1 | Conv 3 × 3, stride 1 × 1 |
| 2 | Conv 5 × 5, stride 1 × 1 |
| 3 | Flatten |
| 4 | Fully Connected 32 |
| 5 | Fully Connected 5 |
The first two layers are convolutions followed by a Flatten; the last two are fully connected layers for the classification task. The code is as follows,
def cnn(number_classes):
return jax.example_libraries.stax.serial(
jax.example_libraries.stax.Conv(1, (3, 3)),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Conv(1, (5, 5)),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Flatten,
jax.example_libraries.stax.Dense(32),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Dense(number_classes),
jax.example_libraries.stax.LogSoftmax
)
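Before the full training run, it can help to sanity-check the model shapes. A minimal sketch, assuming the cnn function above (in stax, the init function returns the output shape along with the initial parameters),

import jax
import jax.example_libraries.stax

prng = jax.random.PRNGKey(15)
init_random_params, predict = cnn(number_classes = 5)
output_shape, params = init_random_params(prng, input_shape = (-1, 64, 27, 1))
print(output_shape)    # (-1, 5)
# Apply the untrained model to a dummy batch of 2 one-hot title matrices
dummy = jax.numpy.zeros((2, 64, 27, 1))
log_probabilities = predict(params, dummy)
print(log_probabilities.shape)    # (2, 5)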
The complete training code is shown below,
../52/AgNewsCsvReader.py
import csv
import re
import jax
import ssl
import nltk
def stop_words():
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.data.path.append("/tmp/")
nltk.download("stopwords", download_dir = "/tmp/");
stops = nltk.corpus.stopwords.words("English")
print(stops)
return stops
def purify(string: str, pattern: str = r"[^a-z]", replacement: str = " "):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
string = re.sub(pattern = r" +", repl = replacement, string = string)
# string = re.sub(pattern = " ", repl = "", string = string)
# Trim the string
string = string.strip()
string = string + " eos"
return string
def purify_stops(string: str, pattern: str = r"[^a-z0-9]", replacement: str = " ", stops = stop_words()):
string = string.lower()
string = re.sub(pattern = pattern, repl = replacement, string = string)
    # Replace the consecutive spaces with a single space
string = re.sub(pattern = r" +", repl = replacement, string = string)
# Trim the string
string = string.strip()
    # Separate the string with spaces; a list of words is yielded
strings = string.split(" ")
strings = [word for word in strings if word not in stops]
strings = [nltk.PorterStemmer().stem(word) for word in strings]
strings.append("eos")
strings = ["bos"] + strings
return strings
def setup():
with open("../../Shares/ag_news_csv/train.csv", "r") as handler:
train_labels = []
train_titles = []
train_descriptions = []
trains = csv.reader(handler)
trains = list(trains)
for i in range(len(trains)):
            line = trains[i]
train_labels.append(jax.numpy.int32(line[0]))
train_titles.append(purify(line[1]))
train_descriptions.append(purify_stops(line[2]))
with open("../../Shares/ag_news_csv/test.csv", "r") as handler:
test_labels = []
test_titles = []
test_descriptions = []
tests = csv.reader(handler)
tests = list(tests)
for i in range(len(tests)):
            line = tests[i]
test_labels.append(jax.numpy.int32(line[0]))
test_titles.append(purify(line[1]))
test_descriptions.append(purify_stops(line[2]))
return (train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions)
def main():
(train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = setup()
    print((len(train_labels), len(train_titles), len(train_descriptions)), (len(test_labels), len(test_titles), len(test_descriptions)))
if __name__ == "__main__":
main()
CharactersConvolutionalNeuralNetwork.py
import numpy
import jax
import jax.example_libraries.stax
import jax.example_libraries.optimizers
import sys
sys.path.append("../52/")
import AgNewsCsvReader
def one_hot(characters, alphabet):
array = numpy.array(characters)
length = len(alphabet)
    # numpy.eye(N, M = None, k = 0, dtype = float) creates a 2-dimensional array
    # whose diagonal elements are 1s and all others are 0s.
eyes = numpy.eye(length)[array]
return eyes
def one_hot_numbers(numbers):
array = numpy.array(numbers)
maximum = numpy.max(array) + 1
eyes = numpy.eye(maximum)[array]
return eyes
def indexes_of(characters, alphabet):
indexes = []
for character in characters:
index = alphabet.index(character)
indexes.append(index)
return indexes
def indexes_matrix(string, alphabet):
indexes = indexes_of(string, alphabet)
matrix = one_hot(indexes, alphabet)
return matrix
def align_string_matrix(string, maximum_length = 64, alphabet = "abcdefghijklmnopqrstuvwxyz "):
length = len(string)
if length > maximum_length:
string = string[: maximum_length]
matrix = indexes_matrix(string, alphabet)
return matrix
else:
matrix = indexes_matrix(string, alphabet)
length = maximum_length - length
matrix_padded = numpy.zeros([length, len(alphabet)])
matrix = numpy.concatenate([matrix, matrix_padded], axis = 0)
return matrix
def cnn(number_classes):
return jax.example_libraries.stax.serial(
jax.example_libraries.stax.Conv(1, (3, 3)),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Conv(1, (5, 5)),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Flatten,
jax.example_libraries.stax.Dense(32),
jax.example_libraries.stax.Relu,
jax.example_libraries.stax.Dense(number_classes),
jax.example_libraries.stax.LogSoftmax
)
def setup():
prng = jax.random.PRNGKey(15)
(train_labels, train_titles, train_descriptions), (test_labels, test_titles, test_descriptions) = AgNewsCsvReader.setup()
train_texts = []
for title in train_titles:
matrix = align_string_matrix(title)
train_texts.append(matrix)
train_texts = numpy.expand_dims(train_texts, axis = -1)
train_labels = one_hot_numbers(train_labels)
test_texts = []
for title in test_titles:
matrix = align_string_matrix(title)
test_texts.append(matrix)
test_texts = numpy.expand_dims(test_texts, axis = -1)
test_labels = one_hot_numbers(test_labels)
number_classes = 5
    input_shape = [-1, 64, 27, 1]
batch_size = 100
epochs = 5
init_random_params, predict = cnn(number_classes)
optimizer_init_function, optimizer_update_function, get_params_function = jax.example_libraries.optimizers.adam(step_size = 2.17e-4)
_, init_params = init_random_params(prng, input_shape = input_shape)
optimizer_state = optimizer_init_function(init_params)
    return (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels))
def verify_accuracy(params, batch, predict_function):
inputs, targets = batch
predictions = predict_function(params, inputs)
class_ = jax.numpy.argmax(predictions, axis = 1)
targets = jax.numpy.argmax(targets, axis = 1)
    return jax.numpy.sum(class_ == targets)
def loss_function(params, batch, predict_function):
inputs, targets = batch
predictions = predict_function(params, inputs)
losses = -targets * predictions
losses = jax.numpy.sum(losses, axis = 1)
losses = jax.numpy.mean(losses)
return losses
def update_function(i, optimizer_state, batch, get_params_function, optimizer_update_function, predict_function):
params = get_params_function(optimizer_state)
loss_function_grad = jax.grad(loss_function)
gradients = loss_function_grad(params, batch, predict_function)
return optimizer_update_function(i, gradients, optimizer_state)
def train():
    (prng, number_classes, batch_size, epochs, init_params, optimizer_state), (init_random_params, optimizer_init_function, predict, optimizer_update_function, get_params_function), ((train_texts, train_labels), (test_texts, test_labels)) = setup()
print("train_texts.shape =", train_texts.shape, ", train_labels.shape =", train_labels.shape, ", test_texts.shape =", test_texts.shape, ", test_labels.shape =", test_labels.shape)
train_batch_number = int(len(train_texts) / batch_size)
test_batch_number = int(len(test_texts) / batch_size)
for i in range(epochs):
print(f"Epoch {i} started")
for j in range(train_batch_number):
start = j * batch_size
end = (j + 1) * batch_size
batch = (train_texts[start: end], train_labels[start: end])
optimizer_state = update_function(i, optimizer_state, batch, get_params_function, optimizer_update_function, predict)
if (j + 1) % 10 == 0:
params = get_params_function(optimizer_state)
                losses = loss_function(params, batch, predict)
print("Losses now is =", losses)
params = get_params_function(optimizer_state)
print(f"Epoch {i} compeleted")
accuracies = []
predictions = 0.0
for j in range(test_batch_number):
start = j * batch_size
end = (j + 1) * batch_size
batch = (test_texts[start: end], test_labels[start: end])
            predictions += verify_accuracy(params, batch, predict)
        accuracies.append(predictions / float(len(test_texts)))
        print("Test accuracies =", accuracies)
if __name__ == "__main__":
train()
The code first obtains the training and test sets, and then defines the loss function and the optimizer, much as in the ResNet example, so the details are not repeated here.
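For reference, since the network ends with LogSoftmax, loss_function implements the standard cross-entropy between the one-hot targets $y$ and the predicted log-probabilities $\log p$,

$$\mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} \left( -\sum_{c=0}^{C-1} y_{n,c} \log p_{n,c} \right)$$

where $N$ is the batch size and $C$ is the number of classes; this is exactly the -targets * predictions sum-then-mean computed in the code.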
Conclusion
Based on the AG News titles and category labels, this chapter built a text classification model out of convolutional and fully connected layers. Note that this example is only meant to illustrate the approach; its accuracy is not necessarily good.