今天的任務(wù)是依照這篇介紹的方法,使用GoogleNet和AlexNet遷移學(xué)習(xí)ECG。
Signal Classification with Wavelet Analysis and Convolutional Neural Networks
整個(gè)實(shí)現(xiàn)流程包括以下幾步:
- 下載三個(gè)ECG Dataset;
- 整理數(shù)據(jù)集,包括降采樣、截?cái)?、?biāo)簽,存儲(chǔ)到一個(gè)ECG_Data的structure里;
- Plot原始數(shù)據(jù);
- 使用CWT,得到scalogram,作為該樣本的輸入特征圖;
- 劃分訓(xùn)練集和測(cè)試集;
- 使用GoogleNet進(jìn)行訓(xùn)練;
- 使用AlexNet進(jìn)行訓(xùn)練;
下載ECG Dataset
實(shí)驗(yàn)選用的數(shù)據(jù)集為以下三個(gè):
The MIT-BIH Atrial Fibrillation Database
The MIT-BIH ST Change Database
The MIT-BIH Supraventricular Arrhythmia Database
由于網(wǎng)站提供的下載整個(gè)數(shù)據(jù)包的鏈接失效,所以使用wget工具來下載數(shù)據(jù)庫(kù),指令如下:
wget -r -np http://physionet.org/physiobank/database/afdb/
wget -r -np http://physionet.org/physiobank/database/stdb/
wget -r -np http://physionet.org/physiobank/database/svdb/
下載完成后,共得到23+28+78=129條記錄,每個(gè)記錄都包含雙通道的心電信號(hào)。
整理數(shù)據(jù)集
- 使用WFDB工具箱中的rdsamp函數(shù),新建一個(gè)id_list,數(shù)組內(nèi)存儲(chǔ)數(shù)據(jù)庫(kù)的樣本編號(hào),按數(shù)組依次讀取信號(hào);
- 使用resample函數(shù),將各個(gè)樣本的采樣率都降為128Hz;
- 直接截取信號(hào)的前65536個(gè)點(diǎn)(即512s信號(hào));
- 分離信號(hào)的兩個(gè)通道,存儲(chǔ)在ECGData.data中,同時(shí)在ECGData.label中存儲(chǔ)對(duì)應(yīng)的病癥標(biāo)簽,建立Data Structure,大小為(248, 65536)和(248, 1)。
stdb的第313-317,319-323文件只有一列數(shù)據(jù)。
經(jīng)過整理后的ECGData Structure組成如下:
1-46:AF記錄,共46條;
47-92:ST記錄,共46條;
93-248,SV記錄,共157條。
%% Load AFDB
start=1
flag={'AF'};
id_list=[04015,04043,04048,04126,04746,04908,04936,05091,05121,05261,06426,06453,06995,07162,07859,07879,07910,08215,08219,08378,08405,08434,08455];
for id = 1:23
signal=rdsamp(['/database/afdb/', num2str(id_list(id), '%05d')]);
resamp_signal=resample(signal, 128, 250);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
%% Load STDB
start=47
flag={'ST'};
% 313~317 319~323
id_list=[300,301,302,303,304,305,306,307,308,309,310,311,312,318,324,325,326,327];
for id = 1:18
signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 360);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
id_list=[313,314,315,316,317,319,320,321,322,323];
for id = 1:10
signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 360);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
start=start+1;
id
end
%% Load SVDB
start=93
flag={'SV'};
id_list=[800,801,802,803,804,805,806,807,808,809,810,811,812,820,821,822,823,824,825,826,827,828,829,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,870,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894];
for id = 1:78
signal=rdsamp(['/database/svdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 128);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
%% Rebuild
ECGData.Data=ECGData.Data';
ECGData.Labels=ECGData.Labels';
Plot原始數(shù)據(jù)
調(diào)用例程中的helperPlotReps()函數(shù),看原始數(shù)據(jù)。很奇怪的是第三類問題好像采樣率有些奇怪,但是找不到問題的原因。

特征提取與數(shù)據(jù)集劃分
首先使用cwtfilterbank函數(shù)對(duì)原始信號(hào)進(jìn)行CWT變換,得到的結(jié)果如圖所示。

然后使用helpCreateRGBfromTF()對(duì)整個(gè)數(shù)據(jù)集進(jìn)行變換,并使用splitEachLabel進(jìn)行訓(xùn)練集和測(cè)試集的分割,分割得到大小為199的訓(xùn)練集和大小為49的測(cè)試集,存儲(chǔ)在ImageDatastore里。
helperCreateRGBfromTF(ECGData,parentDir,dataDir)
allImages = imageDatastore(fullfile(parentDir,dataDir),...
'IncludeSubfolders',true,...
'LabelSource','foldernames');
rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
使用GoogleNet進(jìn)行訓(xùn)練
GoogleNet是使用ImageNet訓(xùn)練的對(duì)于1000分類的深層CNN網(wǎng)絡(luò),其結(jié)構(gòu)如圖所示,為了進(jìn)行遷移學(xué)習(xí),我們將最后四層修改為針對(duì)三分類問題的輸出。
lgraph = removeLayers(lgraph,{'pool5-drop_7x7_s1','loss3-classifier','prob','output'});
numClasses = numel(categories(imgsTrain.Labels));
newLayers = [
dropoutLayer(0.6,'Name','newDropout')
fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',5,'BiasLearnRateFactor',5)
softmaxLayer('Name','softmax')
classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);
lgraph = connectLayers(lgraph,'pool5-7x7_s1','newDropout');
inputSize = net.Layers(1).InputSize;

同時(shí),設(shè)置GoogleNet訓(xùn)練的一些參數(shù),開始訓(xùn)練,訓(xùn)練結(jié)果如圖所示。
options = trainingOptions('sgdm',...
'MiniBatchSize',15,...
'MaxEpochs',20,...
'InitialLearnRate',1e-4,...
'ValidationData',imgsValidation,...
'ValidationFrequency',10,...
'ValidationPatience',Inf,...
'Verbose',1,...
'ExecutionEnvironment','cpu',...
'Plots','training-progress');
rng default
trainedGN = trainNetwork(imgsTrain,lgraph,options);
trainedGN.Layers(end-2:end)
cNames = trainedGN.Layers(end).ClassNames


使用GoogleNet訓(xùn)練的最終正確率為:71.429%
同時(shí)我們還觀測(cè)了GoogleNet的激活函數(shù)、對(duì)于AF病癥的激活函數(shù)以及最強(qiáng)的AF通道,如下圖所示。



使用AlexNet進(jìn)行訓(xùn)練
AlexNet共有25層,如下圖所示。

針對(duì)本問題,我們修改了AlexNet的最后三層,同時(shí)改變圖像的形狀以匹配AlexNet的輸入。
%% Load
alex=alexnet;
layers = alex.Layers
%% Modify AlexNet Network Parameters
layers(23) = fullyConnectedLayer(3);
layers(25) = classificationLayer;
%% Prepare RGB Data for AlexNet
inputSize = alex.Layers(1).InputSize;
augimgsTrain = augmentedImageDatastore(inputSize(1:2),imgsTrain);
augimgsValidation = augmentedImageDatastore(inputSize(1:2),imgsValidation);
同時(shí),設(shè)置AlexNet訓(xùn)練的一些參數(shù),開始訓(xùn)練,訓(xùn)練結(jié)果如圖所示。


使用AlexNet訓(xùn)練的最終正確率為75.51%