代碼閱讀筆記1

代碼位置

代碼結構:

1.master文件夾
(1) dice_loss.py
(2) eval.py
(3) predict.py: ** 完全沒涉及pruning后的網(wǎng)絡**
(4) pruning.py:
(5) submit.py:** 完全沒涉及pruning后的網(wǎng)絡**
(6) train.py:** 完全沒涉及pruning后的網(wǎng)絡**
(7) 文件夾unet
<1> prune_layers.py
<2> prune_unet_model.py : class PruneUNet
<3> prune_unet_parts.py : 對p_double_conv, p_inconv, p_outconv, p_down, p_up 進行了定義

<4> unet_model.py
<5> unet_parts.py
(8) 文件夾util
<1> load.py
<2> util.py
<3> data_vis.py
<4> crf.py

閱讀目的:

能用,能跑,放自己的數(shù)據(jù)能跑

閱讀筆記

A. pruning.py閱讀
  1. pruning.py中的line 91的net.train()的理解
    net的定義:
net = PruneUNet(n_channels=3, n_classes=1) 
  1. 中心for循環(huán)的理解
    for循環(huán)做了以下四件事情:
    對每個epoch:
    (1) reset the generator for training data and validation data.
    這里實際上對每一個epoch,在開始時對數(shù)據(jù)都做了traditional augmentation,但是我們的數(shù)據(jù)量足夠多,不需要這么做。待改進
    (2) 取validation dataset的前四項,進行prediction和計算accuracy。
    這里因為validation dataset是自動生成的,所以雖然都是前4項,但是validation dataset是不一樣的。
    (3) 用PruneUNet訓練training dataset的前兩個batch:
    PruneUNet位于unet文件夾的prune_unet_model.py
  • model.eval() :Pytorch會自動把BN和Dropout固定住,不會取平均,而是用訓練好的值
  • model.train():讓model變成訓練模式,此時 dropout和batch normalization的操作在訓練時起到防止網(wǎng)絡過擬合的作用。
    訓練,算loss,反向傳播
    然后進行prune,
    對每個epoch,都要循環(huán)num_prune_iterations次,每一次運行一遍net.prune。net.prune的具體內(nèi)容見B,總結下來是去掉一個channel。

疑問:如果一直執(zhí)行net.prune,都是去掉最小值,但是去掉最小值之后如果不刪掉對應prune_feature_map 里的值,那么每次刪掉的module里的filter不是一樣的嗎?
回答:每一次找到對應layer_idx和filter_idx之后,對conv2d層執(zhí)行prune,都需要運行位于prune_layer.py中的函數(shù)prune_feature_map,在這個函數(shù)中,執(zhí)行了下面兩步:

indices = Variable(torch.LongTensor([i for i in range(self.out_channels) if i != map_index]))
self.weight = nn.Parameter(self.weight.index_select(0, indices).data)

對bias和對weight有一樣的操作。
最后將輸出channel減一。
這里重點理解這個index_select函數(shù):
函數(shù)格式:

index_select(
    dim,
    index)

參數(shù)含義:

dim:表示從第幾維挑選數(shù)據(jù),類型為int值;index:表示從第一個參數(shù)維度中的哪個位置挑選數(shù)據(jù),類型為torch.Tensor類的實例;

(4) 繼續(xù)對第一次循環(huán)里的validation的數(shù)據(jù)用pruned的代碼進行預測和計算loss。
(5) if save_cp時,保存net.state_dict()

B. prune_unet_model.py閱讀

PruneUNet這個class中定義了4個函數(shù):
__ init __,forward,set_pruning和prune。
其中 __ init __里,所有的down和up layer 以及output layer 都是pruned layer。
其中prune是具體進性layer prune的函數(shù),做了以下事情:
(1) 去掉model里的大的block
(2) 找到泰勒估計中,最小估計值所對應的layer和filter的位置,用prune_feature_map函數(shù)進行prune。
(3) 如果下一層不是最后一層,對應去drop掉下一層的輸入channel
(4) down layer的channel改變之后,對應up layer的channel也要改變,這里用hard code去寫。
進一步去看:line68的taylor_estimates_by_module 和 estimates_by_f_map是怎么計算得到的。
對每一個module list的module,在line64進行了module.taylor_estimates,去進行排序。

先取出每個module_list 的module.taylor_estimates和idx,
再從module.taylor_estimates里取出f_map_idx和對應的估計值estimate。

C. prune_layers.py閱讀

在prune_layers.py中,定義了class PrunableConv2d(nn.Conv2d)class PrunableBatchNorm2d(nn.BatchNorm2d),對PrunableConv2d(nn.Conv2d),定義了屬性taylor_estimates

D. 提問:
問題1:

pruning.py基于前幾個training batch和幾個epoch和手動輸入的num_prune_iterations對unet進行pruning,那么如何用prune好的網(wǎng)絡對我們的數(shù)據(jù)進行計算呢?

num_prune_iterations = 100,
epochs=5

這里又涉及兩個問題:(a) prune完需要retrain嗎? (b) 如何用pruned的網(wǎng)絡進行inference?
對問題(a),其實在pruning.py里,有反向傳播更新梯度值和權重值的過程了,未必要重新去train。
對問題(b),需要繼續(xù)閱讀代碼。
代碼里并沒有寫,可能需要自己在prediction的代碼里導入pruned_unet

問題2: 如何計算每個module的taylor_estimates

prune_layers.pyclass PrunableConv2d(nn.Conv2d)中有一個函數(shù)_calculate_taylor_estimate(self, _, grad_input, grad_output)專門計算taylor_estimates。
這里有注釋:# skip dim 1 as it is kernel size
其中,_recent_activations是forward之后該conv2d層的output。

mul_(value)
mul()的直接運算形式,即直接執(zhí)行并且返回修改后的張量

# skip dim 1 as it is kernel size
estimates = self._recent_activations.mul_(grad_output[0])
estimates = estimates.mean(dim=(0, 2, 3))        
# normalization
self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates))

修改代碼據(jù)為己用:

A. pruning.py改動記錄

  1. 把optimizer從SGD改成Adam,和自己的UNet保持一致。(已完成)
  2. line61criterion = nn.BCELoss()改成自己定義的Diceloss
  3. line183的net的定義n_channel從3改成1
  4. summary(net, (3, 640, 640))注釋掉這一步可視化,因為暫時不探究其參數(shù)含義。
    B.
最后編輯于
?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內(nèi)容