27 | 使用PyTorch完成醫(yī)療圖像識別大項目:實現(xiàn)端到端模型方案

接下來需要再做一些工作,并把我們前面搞好的模型串起來,形成一個端到端的解決方案。這個方案如下,首先是從原始的CT數(shù)據(jù)出發(fā)進行圖像分割,識別可能是結(jié)節(jié)的體素,并對這些體素區(qū)域進行分組,然后用這些分割出的候選結(jié)節(jié)信息進行分類,首先是區(qū)分這是否是一個結(jié)節(jié),針對是結(jié)節(jié)的,再區(qū)分這是否是一個惡性結(jié)節(jié),這樣就完成了整個模型框架。


image.png

由于我們之前訓(xùn)練的兩個模型使用的訓(xùn)練數(shù)據(jù)是不一樣的,我們直接獲取了標(biāo)注的結(jié)節(jié)信息作為分類模型的訓(xùn)練集,而在實際中,我們需要對分割模型的結(jié)果進行分類。這就存在數(shù)據(jù)泄露的問題。也就說在分類模型的訓(xùn)練集中可能有些數(shù)據(jù)是分割模型的驗證集,反過來,在分類模型的驗證集里面可能有分割模型的訓(xùn)練集數(shù)據(jù)。所以之前壓根就沒保存模型,就是為了在這里重新訓(xùn)練一下。獲取LunaDataset跟之前還是一樣的,有區(qū)別的是從segmentationDataset中獲取標(biāo)注數(shù)據(jù)并分割為訓(xùn)練集和驗證集。

重新訓(xùn)練分類模型

先為這一章構(gòu)建數(shù)據(jù)緩存。

run('test14ch.prepcache.LunaPrepCacheApp')

然后訓(xùn)練100個epoch,由于下調(diào)了數(shù)據(jù)樣本量,訓(xùn)練集里面的正負樣本各2.5w條,驗證集保持原樣正樣本154條,負樣本5w+。所以訓(xùn)練起來還算可以,差不多10分鐘一個epoch

run('test14ch.training.ClassificationTrainingApp', f'--epochs=100', 'nodule-nonnodule')

看到第一個epoch結(jié)果,效果還不是很好,不過程序沒問題,就在這里跑著好了,我就去睡覺了。


image.png

一覺醒來,已經(jīng)70+epoch,訓(xùn)練集上的準(zhǔn)確率已經(jīng)99%+,驗證集上對于正樣本的準(zhǔn)確率也達到了94%+,不過中體的precision還是比較低的,因為兩個類別的樣本量差距太大了,有很多負樣本被歸為了陽性結(jié)果,不過這問題不大,我們主要是能把真的陽性篩出來就好了。


image.png

到了80個epoch,訓(xùn)練集效果基本沒變了,驗證集上陽性準(zhǔn)確率下降了一點。
image.png

直接跳到100個epoch,可以看到在訓(xùn)練集上的效果又提升了一丟丟,但是驗證集上,尤其是驗證集的負樣本準(zhǔn)確率下降了不少,這不符合我們的預(yù)期,說明模型有點過擬合了。


image.png

我們就用這里面的最佳模型作為我們最后系統(tǒng)中需要使用的分類模型。(第95epoch的模型,這時候有最好的f1 score)

連接分割和分類模型

新建一個代碼,就叫結(jié)節(jié)分析:nodule_analysis.py,它最核心的地方就是下面這段。

#取一個uid
        for _, series_uid in series_iter:
#然后獲取對應(yīng)的CT數(shù)據(jù)
            ct = getCt(series_uid)
#緊接著跑分割模型
            mask_a = self.segmentCt(ct, series_uid)
#給分割模型預(yù)測到的結(jié)節(jié)數(shù)據(jù)進行分組
            candidateInfo_list = self.groupSegmentationOutput(
                series_uid, ct, mask_a)
#最后跑分類模型,決定是不是結(jié)節(jié)
            classifications_list = self.classifyCandidates(
                ct, candidateInfo_list)

其中,分割部分代碼

    def segmentCt(self, ct, series_uid):
#預(yù)測不需要更新,關(guān)閉自動梯度計算
        with torch.no_grad():
#用來存儲輸出結(jié)果
            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
#初始化數(shù)據(jù)加載器
            seg_dl = self.initSegmentationDl(series_uid) 
#遍歷整個CT
            for input_t, _, _, slice_ndx_list in seg_dl:
#發(fā)送到GPU
                input_g = input_t.to(self.device)
#運行分割模型
                prediction_g = self.seg_model(input_g)
#把結(jié)果存起來
                for i, slice_ndx in enumerate(slice_ndx_list):
                    output_a[slice_ndx] = prediction_g[i].cpu().numpy()
#構(gòu)建掩碼結(jié)果
            mask_a = output_a > 0.5
            mask_a = morphology.binary_erosion(mask_a, iterations=1)
#返回
        return mask_a

接下來跟分割的結(jié)果分組。這里使用了一個scipy.ndimage.measurements的方法,measurements.label用來標(biāo)記連通區(qū)域。舉個簡單的例子,如下左圖是我們的掩碼結(jié)果,measurements.label的功能就是去識別這里面有多少是連通的,并用一個標(biāo)記去修改它們的值。對于左上角,是第一個連通區(qū)域,那么里面的值都改為1,中間這塊是第二個連通區(qū)域,里面的值都改為2,依次類推,就變成了右側(cè)的樣子。 measurements.center_of_mass則是用來計算每個連通區(qū)域的中心點坐標(biāo)。


image.png
    def groupSegmentationOutput(self, series_uid,  ct, clean_a):

        candidateLabel_a, candidate_count = measurements.label(clean_a)
        centerIrc_list = measurements.center_of_mass(
            ct.hu_a.clip(-1000, 1000) + 1001,
            labels=candidateLabel_a,
            index=np.arange(1, candidate_count+1),
        )
#把識別到的數(shù)據(jù)轉(zhuǎn)化分類模型要用的數(shù)據(jù)。
        candidateInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_a,
            )
            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, candidate_count])
            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
            candidateInfo_tup = \
                CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)
            candidateInfo_list.append(candidateInfo_tup)

        return candidateInfo_list

最后是分類模型。

    def classifyCandidates(self, ct, candidateInfo_list):
# 初始化dataloader
        cls_dl = self.initClassificationDl(candidateInfo_list)
        classifications_list = []
        for batch_ndx, batch_tup in enumerate(cls_dl):
            input_t, _, _, series_list, center_list = batch_tup
#發(fā)送到GPU上
            input_g = input_t.to(self.device)
            with torch.no_grad():
#運行分類模型
                _, probability_nodule_g = self.cls_model(input_g)
#這里還有一個分是否惡性的模型,現(xiàn)在我們還沒開發(fā),先留下這個位置
                if self.malignancy_model is not None:
                    _, probability_mal_g = self.malignancy_model(input_g)
                else:
                    probability_mal_g = torch.zeros_like(probability_nodule_g)

            zip_iter = zip(center_list,
                probability_nodule_g[:,1].tolist(),
                probability_mal_g[:,1].tolist())
#轉(zhuǎn)換坐標(biāo)
            for center_irc, prob_nodule, prob_mal in zip_iter:
                center_xyz = irc2xyz(center_irc,
                    direction_a=ct.direction_a,
                    origin_xyz=ct.origin_xyz,
                    vxSize_xyz=ct.vxSize_xyz,
                )
                cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
                classifications_list.append(cls_tup)
        return classifications_list

我們的CT圖像原本有大概3300w個體素,經(jīng)過圖像分割之后,留下大約100w個體素,通過給這些體素分組,可以得到大概1000個候選結(jié)節(jié)信息,然后對這些信息進行分類確認哪些是結(jié)節(jié),哪些不是結(jié)節(jié),經(jīng)過這步之后還剩幾十個確認是結(jié)節(jié),最后一步是確認結(jié)節(jié)的性質(zhì),惡性的通常來說最多也就一兩個。


image.png

這時候在回到main方法中,我們已經(jīng)得到了模型的結(jié)果,

#這個cli_args.run_validation參數(shù)是用來判斷是否跑驗證集數(shù)據(jù)的,如果不是驗證集數(shù)據(jù),而是單個輸入uid,那么就執(zhí)行下面的信息顯示功能
            if not self.cli_args.run_validation:
                print(f"found nodule candidates in {series_uid}:")
                for prob, prob_mal, center_xyz, center_irc in classifications_list:
                    if prob > 0.5:#如果我們找到的結(jié)節(jié)概率超過0.5,就輸出信息到屏幕上,給醫(yī)生看
                        s = f"nodule prob {prob:.3f}, "
                        if self.malignancy_model:
                            s += f"malignancy prob {prob_mal:.3f}, "
                        s += f"center xyz {center_xyz}"
                        print(s)
#這里輸出混淆矩陣
            if series_uid in candidateInfo_dict:
                one_confusion = match_and_score(
                    classifications_list, candidateInfo_dict[series_uid]
                )
                all_confusion += one_confusion
                print_confusion(
                    series_uid, one_confusion, self.malignancy_model is not None
                )

        print_confusion(
            "Total", all_confusion, self.malignancy_model is not None
        )

用我們建好的模型來預(yù)測一下數(shù)據(jù)看看,運行速度還是挺快的

python -m test14ch.nodule_analysis 1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864

輸出的結(jié)果如下,總共發(fā)現(xiàn)19個結(jié)節(jié),其中17個是假陽性,1個良性,1個惡性。不過這里我還沒有把惡性分類器加上,看了一下代碼,這個惡性標(biāo)記應(yīng)該是從原始的標(biāo)注數(shù)據(jù)里來的?


image.png

識別惡性結(jié)節(jié)

這個地方我們先獲取一份關(guān)于惡性腫瘤的標(biāo)注信息。這個數(shù)據(jù)來自我們之前已經(jīng)安裝的LIDC工具包,如果你還沒安裝可以用下面這行shell安裝

pip install pylidc

這里面有醫(yī)生關(guān)于惡性結(jié)節(jié)的標(biāo)注,每個醫(yī)生會標(biāo)注非常不可能、不可能、不確定、可疑、非??梢蓭追N情況,同時,對于一個結(jié)節(jié)可能會有多個醫(yī)生進行標(biāo)注。

ROC與AUC

這里插播一個小知識,就是需要學(xué)一個新的評估指標(biāo),ROC-AUC。要了解ROC曲線(Receiver Operating Characteristic curve),我們先回到混淆矩陣上來。


image.png

根據(jù)混淆矩陣,可以算出真陽性率(True Positive Rate,TPR)和假陽性率(False Positive Rate,F(xiàn)PR)
其中,


image.png

而ROC曲線就是假設(shè)我們對判斷陽性取不同的閾值時以兩個值為橫縱坐標(biāo)得到的一條曲線。下圖中,我們假設(shè)使用一個結(jié)節(jié)的直徑大小作為判斷是否惡性結(jié)節(jié)的標(biāo)準(zhǔn),那么該曲線就是當(dāng)取不同的直徑大小作為判斷閾值時,所得到的ROC曲線。而AUC(Area Under ROC)即為ROC曲線下面的面積。ROC曲線是用來衡量分類器的分類能力,AUC表示,隨機抽取一個正樣本和一個負樣本,分類器正確給出正樣本的score高于負樣本的概率。
image.png

因此,如果AUC越大,則表示模型的效果越好。

import torch
%matplotlib inline
from matplotlib import pyplot

import test14ch.dsets
import test14ch.model
#這里獲取帶有惡性結(jié)節(jié)標(biāo)記的數(shù)據(jù)集
ds = test14ch.dsets.MalignantLunaDataset(val_stride=10, isValSet_bool=True)  
nodules = ds.ben_list + ds.mal_list
#獲取是否惡性結(jié)節(jié)的狀態(tài)和直徑
is_mal = torch.tensor([n.isMal_bool for n in nodules]) 
diam  = torch.tensor([n.diameter_mm for n in nodules])
#惡性和良性數(shù)目
num_mal = is_mal.sum()  
num_ben = len(is_mal) - num_mal
#設(shè)置閾值,取結(jié)節(jié)直徑的最大最小值,并分成100份
threshold = torch.linspace(diam.max(), diam.min(), steps=100)
#使用直徑來判斷是否惡性結(jié)節(jié)
predictions = (diam[None] >= threshold[:, None])  
計算真陽率和假陽率
tp_diam = (predictions & is_mal[None]).sum(1).float() / num_mal  
fp_diam = (predictions & ~is_mal[None]).sum(1).float() / num_ben
#計算auc
fp_diam_diff =  fp_diam[1:] - fp_diam[:-1]
tp_diam_avg  = (tp_diam[1:] + tp_diam[:-1])/2
auc_diam = (fp_diam_diff * tp_diam_avg).sum()
#fill用于后面繪圖使用
fp_fill = torch.ones((fp_diam.shape[0] + 1,))
fp_fill[:-1] = fp_diam

tp_fill = torch.zeros((tp_diam.shape[0] + 1,))
tp_fill[:-1] = tp_diam

print(threshold)
print(fp_diam)
print(tp_diam)

for i in range(threshold.shape[0]):
    print(i, threshold[i], fp_diam[i], tp_diam[i])

pyplot.figure(figsize=(7,5), dpi=1200)
for i in [62, 88]:
    pyplot.scatter(fp_diam[i], tp_diam[i], color='red')
    print(f'diam: {round(threshold[i].item(), 2)}, x: {round(fp_diam[i].item(), 2)}, y: {round(tp_diam[i].item(), 2)}')
pyplot.fill(fp_fill, tp_fill, facecolor='#0077bb', alpha=0.25)
pyplot.plot(fp_diam, tp_diam, label=f'diameter baseline, AUC={auc_diam:.3f}')
pyplot.title(f'ROC diameter baseline, AUC={auc_diam:.3f}')
pyplot.ylabel('true positive rate')
pyplot.xlabel('false positive rate')
pyplot.savefig('roc_diameter_baseline.png')

最后繪制的圖像就是我們上面展示過的圖像,這個就作為我們預(yù)測是否惡性的baseline,接下來使用模型finetune來訓(xùn)練一個預(yù)測是否惡性的模型。


image.png
finetune

這里使用的是之前的分類模型,我們在分類模型的基礎(chǔ)上進行微調(diào)。下圖顯示了微調(diào)的方案,如果微調(diào)深度為1,那么只是把最后的全連接層重新訓(xùn)練,這里保持模型的主干和尾部都維持之前的權(quán)重不動,只把全連接層的權(quán)重重新初始化,然后把模型目標(biāo)改成區(qū)分是否惡性結(jié)節(jié),讓模型去學(xué)習(xí)新的權(quán)重。如果微調(diào)深度為2,那么把主干最后一個卷積塊的參數(shù)也都重置,再進行訓(xùn)練。


image.png

在訓(xùn)練分類的代碼里,通過簡單的修改就可以實現(xiàn)微調(diào)。

#判斷是否開啟微調(diào)模式
        if self.cli_args.finetune:
#加載模型
            d = torch.load(self.cli_args.finetune, map_location='cpu')
#獲取所有層
            model_blocks = [
                n for n, subm in model.named_children()
                if len(list(subm.parameters())) > 0
            ]
#獲取需要微調(diào)的層
            finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
            log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
#加載保存的狀態(tài)
            model.load_state_dict(
                {
                    k: v for k,v in d['model_state'].items()
                    if k.split('.')[0] not in model_blocks[-1]
                },
                strict=False,
            )
#只在需要finetune的層上進行梯度計算
            for n, p in model.named_parameters():
                if n.split('.')[0] not in finetune_blocks:
                    p.requires_grad_(False)

然后啟動模型重新訓(xùn)練,這里我們使用區(qū)分惡性的數(shù)據(jù)集??梢钥吹竭@里使用的是MalignantLunaDataset,第一次只調(diào)整最后的全連接層,訓(xùn)練40個epoch。

run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
    '--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
    'finetune-head')

這個訓(xùn)練速度稍微快一點,第一個epoch效果不怎么好,訓(xùn)練集上兩個類別準(zhǔn)確率67%,驗證集上平均下來只有60%左右。


image.png

到40個epoch


image.png

由于中間我也沒盯著,后來看保存的最佳模型,是第29個epoch的模型
[圖片上傳失敗...(image-7370ad-1656594433704)]

根據(jù)代碼可以看到,評判模型的得分,我們這里用的是auc,可以看到第29個epoch在驗證集上的auc要略高一點,所以這里最佳模型是第29個epoch的模型。


image.png

輸出了它的效果跟之前的AUC對比一下,結(jié)果發(fā)現(xiàn)訓(xùn)了半天模型還不如之前就用結(jié)節(jié)的直徑來判斷得到的AUC更好。
image.png

不行,把finetune深度改成2,又訓(xùn)了40輪。
run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
    '--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
    '--finetune-depth=2',
    'finetune-depth2')

直接看結(jié)果,這里的best model竟然是第4個epoch的模型,可以看到訓(xùn)練集上的準(zhǔn)確率基本在93+%,驗證集惡性準(zhǔn)確率比較低只有73%,良性的準(zhǔn)確率為87%。


image.png

到40個epoch的時候,可以看到在訓(xùn)練集上準(zhǔn)確率都99%+了,但是驗證集效果下滑,出現(xiàn)了過擬合現(xiàn)象。


image.png

結(jié)果這個AUC是有提升了,但是還沒有超越使用直徑直接分類的效果。不如我們實際的模型就用直徑直接分類好了。當(dāng)然,還有很多優(yōu)化方案我們可以嘗試,比如說做模型集成,做樣本的增強,給訓(xùn)練數(shù)據(jù)提取更多特征,比如說做一個平滑的標(biāo)簽,甚至是使用更復(fù)雜的模型等等,但是在實際的項目中,我們可能說“由于時間緊迫,我們的第一版就先上線”,畢竟當(dāng)下的效果已經(jīng)能夠滿足業(yè)務(wù)需求,并且整體的邏輯已經(jīng)完成,終于可以給這個模型訓(xùn)練階段畫上一個句號。
image.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容