接下來需要再做一些工作,并把我們前面搞好的模型串起來,形成一個端到端的解決方案。這個方案如下,首先是從原始的CT數(shù)據(jù)出發(fā)進行圖像分割,識別可能是結(jié)節(jié)的體素,并對這些體素區(qū)域進行分組,然后用這些分割出的候選結(jié)節(jié)信息進行分類,首先是區(qū)分這是否是一個結(jié)節(jié),針對是結(jié)節(jié)的,再區(qū)分這是否是一個惡性結(jié)節(jié),這樣就完成了整個模型框架。

由于我們之前訓(xùn)練的兩個模型使用的訓(xùn)練數(shù)據(jù)是不一樣的,我們直接獲取了標(biāo)注的結(jié)節(jié)信息作為分類模型的訓(xùn)練集,而在實際中,我們需要對分割模型的結(jié)果進行分類。這就存在數(shù)據(jù)泄露的問題。也就說在分類模型的訓(xùn)練集中可能有些數(shù)據(jù)是分割模型的驗證集,反過來,在分類模型的驗證集里面可能有分割模型的訓(xùn)練集數(shù)據(jù)。所以之前壓根就沒保存模型,就是為了在這里重新訓(xùn)練一下。獲取LunaDataset跟之前還是一樣的,有區(qū)別的是從segmentationDataset中獲取標(biāo)注數(shù)據(jù)并分割為訓(xùn)練集和驗證集。
重新訓(xùn)練分類模型
先為這一章構(gòu)建數(shù)據(jù)緩存。
run('test14ch.prepcache.LunaPrepCacheApp')
然后訓(xùn)練100個epoch,由于下調(diào)了數(shù)據(jù)樣本量,訓(xùn)練集里面的正負樣本各2.5w條,驗證集保持原樣正樣本154條,負樣本5w+。所以訓(xùn)練起來還算可以,差不多10分鐘一個epoch
run('test14ch.training.ClassificationTrainingApp', f'--epochs=100', 'nodule-nonnodule')
看到第一個epoch結(jié)果,效果還不是很好,不過程序沒問題,就在這里跑著好了,我就去睡覺了。

一覺醒來,已經(jīng)70+epoch,訓(xùn)練集上的準(zhǔn)確率已經(jīng)99%+,驗證集上對于正樣本的準(zhǔn)確率也達到了94%+,不過中體的precision還是比較低的,因為兩個類別的樣本量差距太大了,有很多負樣本被歸為了陽性結(jié)果,不過這問題不大,我們主要是能把真的陽性篩出來就好了。

到了80個epoch,訓(xùn)練集效果基本沒變了,驗證集上陽性準(zhǔn)確率下降了一點。

直接跳到100個epoch,可以看到在訓(xùn)練集上的效果又提升了一丟丟,但是驗證集上,尤其是驗證集的負樣本準(zhǔn)確率下降了不少,這不符合我們的預(yù)期,說明模型有點過擬合了。

我們就用這里面的最佳模型作為我們最后系統(tǒng)中需要使用的分類模型。(第95epoch的模型,這時候有最好的f1 score)
連接分割和分類模型
新建一個代碼,就叫結(jié)節(jié)分析:nodule_analysis.py,它最核心的地方就是下面這段。
#取一個uid
for _, series_uid in series_iter:
#然后獲取對應(yīng)的CT數(shù)據(jù)
ct = getCt(series_uid)
#緊接著跑分割模型
mask_a = self.segmentCt(ct, series_uid)
#給分割模型預(yù)測到的結(jié)節(jié)數(shù)據(jù)進行分組
candidateInfo_list = self.groupSegmentationOutput(
series_uid, ct, mask_a)
#最后跑分類模型,決定是不是結(jié)節(jié)
classifications_list = self.classifyCandidates(
ct, candidateInfo_list)
其中,分割部分代碼
def segmentCt(self, ct, series_uid):
#預(yù)測不需要更新,關(guān)閉自動梯度計算
with torch.no_grad():
#用來存儲輸出結(jié)果
output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
#初始化數(shù)據(jù)加載器
seg_dl = self.initSegmentationDl(series_uid)
#遍歷整個CT
for input_t, _, _, slice_ndx_list in seg_dl:
#發(fā)送到GPU
input_g = input_t.to(self.device)
#運行分割模型
prediction_g = self.seg_model(input_g)
#把結(jié)果存起來
for i, slice_ndx in enumerate(slice_ndx_list):
output_a[slice_ndx] = prediction_g[i].cpu().numpy()
#構(gòu)建掩碼結(jié)果
mask_a = output_a > 0.5
mask_a = morphology.binary_erosion(mask_a, iterations=1)
#返回
return mask_a
接下來跟分割的結(jié)果分組。這里使用了一個scipy.ndimage.measurements的方法,measurements.label用來標(biāo)記連通區(qū)域。舉個簡單的例子,如下左圖是我們的掩碼結(jié)果,measurements.label的功能就是去識別這里面有多少是連通的,并用一個標(biāo)記去修改它們的值。對于左上角,是第一個連通區(qū)域,那么里面的值都改為1,中間這塊是第二個連通區(qū)域,里面的值都改為2,依次類推,就變成了右側(cè)的樣子。 measurements.center_of_mass則是用來計算每個連通區(qū)域的中心點坐標(biāo)。

def groupSegmentationOutput(self, series_uid, ct, clean_a):
candidateLabel_a, candidate_count = measurements.label(clean_a)
centerIrc_list = measurements.center_of_mass(
ct.hu_a.clip(-1000, 1000) + 1001,
labels=candidateLabel_a,
index=np.arange(1, candidate_count+1),
)
#把識別到的數(shù)據(jù)轉(zhuǎn)化分類模型要用的數(shù)據(jù)。
candidateInfo_list = []
for i, center_irc in enumerate(centerIrc_list):
center_xyz = irc2xyz(
center_irc,
ct.origin_xyz,
ct.vxSize_xyz,
ct.direction_a,
)
assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, candidate_count])
assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
candidateInfo_tup = \
CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)
candidateInfo_list.append(candidateInfo_tup)
return candidateInfo_list
最后是分類模型。
def classifyCandidates(self, ct, candidateInfo_list):
# 初始化dataloader
cls_dl = self.initClassificationDl(candidateInfo_list)
classifications_list = []
for batch_ndx, batch_tup in enumerate(cls_dl):
input_t, _, _, series_list, center_list = batch_tup
#發(fā)送到GPU上
input_g = input_t.to(self.device)
with torch.no_grad():
#運行分類模型
_, probability_nodule_g = self.cls_model(input_g)
#這里還有一個分是否惡性的模型,現(xiàn)在我們還沒開發(fā),先留下這個位置
if self.malignancy_model is not None:
_, probability_mal_g = self.malignancy_model(input_g)
else:
probability_mal_g = torch.zeros_like(probability_nodule_g)
zip_iter = zip(center_list,
probability_nodule_g[:,1].tolist(),
probability_mal_g[:,1].tolist())
#轉(zhuǎn)換坐標(biāo)
for center_irc, prob_nodule, prob_mal in zip_iter:
center_xyz = irc2xyz(center_irc,
direction_a=ct.direction_a,
origin_xyz=ct.origin_xyz,
vxSize_xyz=ct.vxSize_xyz,
)
cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
classifications_list.append(cls_tup)
return classifications_list
我們的CT圖像原本有大概3300w個體素,經(jīng)過圖像分割之后,留下大約100w個體素,通過給這些體素分組,可以得到大概1000個候選結(jié)節(jié)信息,然后對這些信息進行分類確認哪些是結(jié)節(jié),哪些不是結(jié)節(jié),經(jīng)過這步之后還剩幾十個確認是結(jié)節(jié),最后一步是確認結(jié)節(jié)的性質(zhì),惡性的通常來說最多也就一兩個。

這時候在回到main方法中,我們已經(jīng)得到了模型的結(jié)果,
#這個cli_args.run_validation參數(shù)是用來判斷是否跑驗證集數(shù)據(jù)的,如果不是驗證集數(shù)據(jù),而是單個輸入uid,那么就執(zhí)行下面的信息顯示功能
if not self.cli_args.run_validation:
print(f"found nodule candidates in {series_uid}:")
for prob, prob_mal, center_xyz, center_irc in classifications_list:
if prob > 0.5:#如果我們找到的結(jié)節(jié)概率超過0.5,就輸出信息到屏幕上,給醫(yī)生看
s = f"nodule prob {prob:.3f}, "
if self.malignancy_model:
s += f"malignancy prob {prob_mal:.3f}, "
s += f"center xyz {center_xyz}"
print(s)
#這里輸出混淆矩陣
if series_uid in candidateInfo_dict:
one_confusion = match_and_score(
classifications_list, candidateInfo_dict[series_uid]
)
all_confusion += one_confusion
print_confusion(
series_uid, one_confusion, self.malignancy_model is not None
)
print_confusion(
"Total", all_confusion, self.malignancy_model is not None
)
用我們建好的模型來預(yù)測一下數(shù)據(jù)看看,運行速度還是挺快的
python -m test14ch.nodule_analysis 1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864
輸出的結(jié)果如下,總共發(fā)現(xiàn)19個結(jié)節(jié),其中17個是假陽性,1個良性,1個惡性。不過這里我還沒有把惡性分類器加上,看了一下代碼,這個惡性標(biāo)記應(yīng)該是從原始的標(biāo)注數(shù)據(jù)里來的?

識別惡性結(jié)節(jié)
這個地方我們先獲取一份關(guān)于惡性腫瘤的標(biāo)注信息。這個數(shù)據(jù)來自我們之前已經(jīng)安裝的LIDC工具包,如果你還沒安裝可以用下面這行shell安裝
pip install pylidc
這里面有醫(yī)生關(guān)于惡性結(jié)節(jié)的標(biāo)注,每個醫(yī)生會標(biāo)注非常不可能、不可能、不確定、可疑、非??梢蓭追N情況,同時,對于一個結(jié)節(jié)可能會有多個醫(yī)生進行標(biāo)注。
ROC與AUC
這里插播一個小知識,就是需要學(xué)一個新的評估指標(biāo),ROC-AUC。要了解ROC曲線(Receiver Operating Characteristic curve),我們先回到混淆矩陣上來。

根據(jù)混淆矩陣,可以算出真陽性率(True Positive Rate,TPR)和假陽性率(False Positive Rate,F(xiàn)PR)
其中,

而ROC曲線就是假設(shè)我們對判斷陽性取不同的閾值時以兩個值為橫縱坐標(biāo)得到的一條曲線。下圖中,我們假設(shè)使用一個結(jié)節(jié)的直徑大小作為判斷是否惡性結(jié)節(jié)的標(biāo)準(zhǔn),那么該曲線就是當(dāng)取不同的直徑大小作為判斷閾值時,所得到的ROC曲線。而AUC(Area Under ROC)即為ROC曲線下面的面積。ROC曲線是用來衡量分類器的分類能力,AUC表示,隨機抽取一個正樣本和一個負樣本,分類器正確給出正樣本的score高于負樣本的概率。

因此,如果AUC越大,則表示模型的效果越好。
import torch
%matplotlib inline
from matplotlib import pyplot
import test14ch.dsets
import test14ch.model
#這里獲取帶有惡性結(jié)節(jié)標(biāo)記的數(shù)據(jù)集
ds = test14ch.dsets.MalignantLunaDataset(val_stride=10, isValSet_bool=True)
nodules = ds.ben_list + ds.mal_list
#獲取是否惡性結(jié)節(jié)的狀態(tài)和直徑
is_mal = torch.tensor([n.isMal_bool for n in nodules])
diam = torch.tensor([n.diameter_mm for n in nodules])
#惡性和良性數(shù)目
num_mal = is_mal.sum()
num_ben = len(is_mal) - num_mal
#設(shè)置閾值,取結(jié)節(jié)直徑的最大最小值,并分成100份
threshold = torch.linspace(diam.max(), diam.min(), steps=100)
#使用直徑來判斷是否惡性結(jié)節(jié)
predictions = (diam[None] >= threshold[:, None])
計算真陽率和假陽率
tp_diam = (predictions & is_mal[None]).sum(1).float() / num_mal
fp_diam = (predictions & ~is_mal[None]).sum(1).float() / num_ben
#計算auc
fp_diam_diff = fp_diam[1:] - fp_diam[:-1]
tp_diam_avg = (tp_diam[1:] + tp_diam[:-1])/2
auc_diam = (fp_diam_diff * tp_diam_avg).sum()
#fill用于后面繪圖使用
fp_fill = torch.ones((fp_diam.shape[0] + 1,))
fp_fill[:-1] = fp_diam
tp_fill = torch.zeros((tp_diam.shape[0] + 1,))
tp_fill[:-1] = tp_diam
print(threshold)
print(fp_diam)
print(tp_diam)
for i in range(threshold.shape[0]):
print(i, threshold[i], fp_diam[i], tp_diam[i])
pyplot.figure(figsize=(7,5), dpi=1200)
for i in [62, 88]:
pyplot.scatter(fp_diam[i], tp_diam[i], color='red')
print(f'diam: {round(threshold[i].item(), 2)}, x: {round(fp_diam[i].item(), 2)}, y: {round(tp_diam[i].item(), 2)}')
pyplot.fill(fp_fill, tp_fill, facecolor='#0077bb', alpha=0.25)
pyplot.plot(fp_diam, tp_diam, label=f'diameter baseline, AUC={auc_diam:.3f}')
pyplot.title(f'ROC diameter baseline, AUC={auc_diam:.3f}')
pyplot.ylabel('true positive rate')
pyplot.xlabel('false positive rate')
pyplot.savefig('roc_diameter_baseline.png')
最后繪制的圖像就是我們上面展示過的圖像,這個就作為我們預(yù)測是否惡性的baseline,接下來使用模型finetune來訓(xùn)練一個預(yù)測是否惡性的模型。

finetune
這里使用的是之前的分類模型,我們在分類模型的基礎(chǔ)上進行微調(diào)。下圖顯示了微調(diào)的方案,如果微調(diào)深度為1,那么只是把最后的全連接層重新訓(xùn)練,這里保持模型的主干和尾部都維持之前的權(quán)重不動,只把全連接層的權(quán)重重新初始化,然后把模型目標(biāo)改成區(qū)分是否惡性結(jié)節(jié),讓模型去學(xué)習(xí)新的權(quán)重。如果微調(diào)深度為2,那么把主干最后一個卷積塊的參數(shù)也都重置,再進行訓(xùn)練。

在訓(xùn)練分類的代碼里,通過簡單的修改就可以實現(xiàn)微調(diào)。
#判斷是否開啟微調(diào)模式
if self.cli_args.finetune:
#加載模型
d = torch.load(self.cli_args.finetune, map_location='cpu')
#獲取所有層
model_blocks = [
n for n, subm in model.named_children()
if len(list(subm.parameters())) > 0
]
#獲取需要微調(diào)的層
finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
#加載保存的狀態(tài)
model.load_state_dict(
{
k: v for k,v in d['model_state'].items()
if k.split('.')[0] not in model_blocks[-1]
},
strict=False,
)
#只在需要finetune的層上進行梯度計算
for n, p in model.named_parameters():
if n.split('.')[0] not in finetune_blocks:
p.requires_grad_(False)
然后啟動模型重新訓(xùn)練,這里我們使用區(qū)分惡性的數(shù)據(jù)集??梢钥吹竭@里使用的是MalignantLunaDataset,第一次只調(diào)整最后的全連接層,訓(xùn)練40個epoch。
run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
'--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
'finetune-head')
這個訓(xùn)練速度稍微快一點,第一個epoch效果不怎么好,訓(xùn)練集上兩個類別準(zhǔn)確率67%,驗證集上平均下來只有60%左右。

到40個epoch

由于中間我也沒盯著,后來看保存的最佳模型,是第29個epoch的模型
[圖片上傳失敗...(image-7370ad-1656594433704)]
根據(jù)代碼可以看到,評判模型的得分,我們這里用的是auc,可以看到第29個epoch在驗證集上的auc要略高一點,所以這里最佳模型是第29個epoch的模型。

輸出了它的效果跟之前的AUC對比一下,結(jié)果發(fā)現(xiàn)訓(xùn)了半天模型還不如之前就用結(jié)節(jié)的直徑來判斷得到的AUC更好。

不行,把finetune深度改成2,又訓(xùn)了40輪。
run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
'--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
'--finetune-depth=2',
'finetune-depth2')
直接看結(jié)果,這里的best model竟然是第4個epoch的模型,可以看到訓(xùn)練集上的準(zhǔn)確率基本在93+%,驗證集惡性準(zhǔn)確率比較低只有73%,良性的準(zhǔn)確率為87%。

到40個epoch的時候,可以看到在訓(xùn)練集上準(zhǔn)確率都99%+了,但是驗證集效果下滑,出現(xiàn)了過擬合現(xiàn)象。

結(jié)果這個AUC是有提升了,但是還沒有超越使用直徑直接分類的效果。不如我們實際的模型就用直徑直接分類好了。當(dāng)然,還有很多優(yōu)化方案我們可以嘗試,比如說做模型集成,做樣本的增強,給訓(xùn)練數(shù)據(jù)提取更多特征,比如說做一個平滑的標(biāo)簽,甚至是使用更復(fù)雜的模型等等,但是在實際的項目中,我們可能說“由于時間緊迫,我們的第一版就先上線”,畢竟當(dāng)下的效果已經(jīng)能夠滿足業(yè)務(wù)需求,并且整體的邏輯已經(jīng)完成,終于可以給這個模型訓(xùn)練階段畫上一個句號。
