CNTK2.3 C# CNN訓練MNIST(CNTK C#入門3)

20180110

keywords: CNTK ?C# ?CNN ?MNIST


1、CNN使用需明白的參數

對CNN不太明白的童鞋參見這篇,很生動cnn動態(tài)展示

卷積網絡的基本運算過程如下,z = wx + b,w是權重矩陣,b是偏差矩陣


看了上面的連接應該明白下面這三個參數的意思了,創(chuàng)建CNN網絡時要用到,在此不多說了。

filter、stride、pad


2、CNN網絡創(chuàng)建

var scaledInput = CNTKLib.ElementTimes(Constant.Scalar(0.00390625f, device), input);

? ? ? ? ? ? ? ? double convWScale = 0.26;// parameter initialization hyper parameter

? ? ? ? ? ? ? ? var m1 = new Parameter(new int[] { 5, 5, 1, 4 }, DataType.Float, CNTKLib.GlorotUniformInitializer(convWScale, -1, 2), device);

? ? ? ? ? ? ? ? var c1 = CNTKLib.ReLU(CNTKLib.Convolution(m1, scaledInput, new int[] { 1, 1, 1 }));

? ? ? ? ? ? ? ? var p1 = CNTKLib.Pooling(c1, PoolingType.Max, new int[] { 4, 4 }, new int[] { 4, 4 }, new bool[] { true });

? ? ? ? ? ? ? ? var m2 = new Parameter(new int[] { 4, 4, 4, 8 }, DataType.Float, CNTKLib.GlorotUniformInitializer(convWScale, -1, 2), device);

? ? ? ? ? ? ? ? var c2 = CNTKLib.ReLU(CNTKLib.Convolution(m2, p1, new int[] { 1, 1, 4 }));

? ? ? ? ? ? ? ? var p2 = CNTKLib.Pooling(c2, PoolingType.Max, new int[] { 3, 3 }, new int[] { 3, 3 }, new bool[] { true });

? ? ? ? ? ? ? ? var m3 = new Parameter(new int[] { 3, 3, 8, 16 }, DataType.Float, CNTKLib.GlorotUniformInitializer(convWScale, -1, 2), device);

? ? ? ? ? ? ? ? var c3 = CNTKLib.ReLU(CNTKLib.Convolution(m3, p2, new int[] { 1, 1, 8 }));

? ? ? ? ? ? ? ? var p3 = CNTKLib.Pooling(c3, PoolingType.Max, new int[] { 2, 2 }, new int[] { 2, 2 }, new bool[] { true });

? ? ? ? ? ? ? ? dout = TestHelper.Dense(p3, numClasses, device, Activation.None, classifierName);

上面的程序創(chuàng)建了3個卷積層(c1、c2、c3)、3個pooling層(p1、p2、p3)和一個輸出層(dout),下面單看一個卷積層和一個pooling層如何創(chuàng)建

var m1 = new Parameter(new int[] { 5, 5, 1, 4 }, DataType.Float, CNTKLib.GlorotUniformInitializer(convWScale, -1, 2), device);

? ? ? ? ? ? ? ? var c1 = CNTKLib.ReLU(CNTKLib.Convolution(m1, scaledInput, new int[] { 1, 1, 1 }));

? ? ? ? ? ? ? ? var p1 = CNTKLib.Pooling(c1, PoolingType.Max, new int[] { 4, 4 }, new int[] { 4, 4 }, new bool[] { true });

{ 5, 5, 1, 4 }表示卷積核filter是5*5,輸入1通道,輸出4通道

{ 1, 1, 1 }表示卷積stride是1、1;最后一個1表示輸入通道

前一個 { 4, 4 }表示Pooling filter是4*4,后一個?{ 4, 4 }表示Pooling的stride是4、4,也就是沒有重復的區(qū)域

3、幾點注意事項

1)卷積層的stride,{ 1, 1, 1 }中最后一個1表示輸入通道,如果填錯或不填,程序會莫名其妙的飛掉

var c1 = CNTKLib.ReLU(CNTKLib.Convolution(cp1, scaledInput, new int[] { 1, 1, 1 }));

2)使用下面這句正確率會降低1個點(不使用pad),由此看出 1、使用pad會提高正確率;2、創(chuàng)建卷積時缺省pad為true

var c1 = CNTKLib.ReLU(CNTKLib.Convolution(cp1, scaledInput, new int[] { 1, 1, 1 }, new bool[] { true }, new bool[] { false }));

3)前面的Pooling層filter比后面的Pooling層大,會提高正確率

4、測試結論

1)卷積網絡會比MLP提高2個點的正確率

2)對于mnist,增加卷積層并沒有提高正確率

3)我測試的結果只能達到98.86,為啥沒有超過99呢?

4)單純使用卷積層,不使用Pooling層,計算時間會增加幾倍,結果也會下降2個點??梢奝ooling層的抽象提取作用還是有效的。


下節(jié)討論LSTM長短記憶模型

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內容