前面提到了用CNN來(lái)做OCR。這篇文章介紹另一種做OCR的方法,就是通過(guò)LSTM+CTC。這種方法的好處是他可以事先不用知道一共有幾個(gè)字符需要識(shí)別。之前我試過(guò)不用CTC,只用LSTM,效果一直不行,后來(lái)下決心加上CTC,效果一下就上去了。
CTC是序列標(biāo)志的一個(gè)重要算法,它主要解決了label對(duì)齊的問(wèn)題。有很多實(shí)現(xiàn)。百度IDL在16年初公開了一個(gè)GPU的實(shí)現(xiàn),號(hào)稱速度比之前的theano-ctc, stanford-ctc都要快。Mxnet目前還沒有ctc的實(shí)現(xiàn),因此決定吧warpctc集成進(jìn)mxnet。
根據(jù)issue里作者們的建議,決定和集成torch一樣,寫一個(gè)plugin,因此C++代碼放在plugin/warpctc目錄中。整個(gè)集成任務(wù)其實(shí)就是寫一個(gè)wrapctc的op。代碼在 plugin/warpctc/warpctc-inl.h.
CTC這一層其實(shí)和SoftmaxOutput很像。其實(shí)他們的forward的實(shí)現(xiàn)就是一模一樣的。唯一的差別就是backward中g(shù)rad的實(shí)現(xiàn),在這里需要調(diào)用warpctc的compute_ctc_loss函數(shù)來(lái)計(jì)算梯度。實(shí)際上warpctc的主要接口也就是這個(gè)函數(shù)。
下面說(shuō)說(shuō)具體怎么用lstm+ctc來(lái)做ocr的任務(wù)。詳細(xì)的代碼在 examples/warpctc/lstm_ocr.py。這里只說(shuō)說(shuō)大體思路。
假設(shè)我們要解決的是4位數(shù)字的識(shí)別,圖片是80*30的圖片。那么我們就將每張圖片按列切分成80個(gè)30維的向量。然后作為一個(gè)lstm的80個(gè)輸入。一個(gè)lstm的輸出和輸入數(shù)目應(yīng)該是相同的。而我們的預(yù)測(cè)目標(biāo)卻只有4個(gè)數(shù)字。而不是80個(gè)數(shù)字。在沒有用ctc時(shí)我想了兩個(gè)解決方案。第一個(gè)是用encode-decode模式。也就是80個(gè)輸入做encode,然后decode成4個(gè)輸出。實(shí)測(cè)效果很挫。第二個(gè)是把4個(gè)label每個(gè)copy20遍,從而變成80個(gè)label。實(shí)測(cè)也很挫。沒辦法,最后只能用ctc loss了。
用ctc loss的體會(huì)就是,如果input的長(zhǎng)度遠(yuǎn)遠(yuǎn)大于label的長(zhǎng)度,比如我這里是80和4的關(guān)系。那么一開始的收斂會(huì)比較慢。在其中有一段時(shí)間cost幾乎不變。此刻一定要有耐心,最終一定會(huì)收斂的。在ocr識(shí)別的這個(gè)例子上最終可以收斂到95%的精度。
目前代碼還在等待merge。pull request。
---------------
歡迎關(guān)注 微信公眾號(hào)【ResysChina】