DeepLearning4j-使用Java訓練YOLO模型

DeepLearning4j-使用Java訓練YOLO模型

在這個Yolo v3發(fā)布的大好日子。
Deeplearning4j終于迎來了新的版本更新1.0.0-alpha,在zoo model中引入TinyYolo模型可以訓練自己的數據用于目標檢測。

不得不說,在Yolo v3這種性能和準確率上面都有大幅度提升的情況下,dl4j才引入TinyYolo總有一種49年加入國軍的感覺


一、任務和數據

數據來源自 https://github.com/cosmicad/dataset ,主要目的是識別并定位圖像中的紅細胞。
數據集總共分為兩個部分:

  1. 數據集:JPEGImages
  2. 標簽:Annotations

1.1 數據集

數據集樣張如圖所示:


數據集
數據集

數據集中所有的圖像均為.jpg格式。一共有410張圖片用于模型的訓練。

1.2 標簽

標簽如圖所示,每一個圖片都會有一個對應的xml文件作為訓練標簽。

標簽
標簽

沒一個標簽的數據都是遵守PASCAL VOC的數據格式,文件內容如下:

<annotation verified="no">
  <folder>RBC</folder>
  <filename>BloodImage_00000</filename>   //對應的圖片
  <path>/Users/cosmic/WBC_CLASSIFICATION_ANNO/RBC/BloodImage_00000.jpg</path>  //路徑(不重要)
  <source>                               //數據來源(不重要)
    <database>Unknown</database>
  </source>
  <size>                                 //圖像的寬高和通道數
    <width>640</width>
    <height>480</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>               //是否用于分割(在圖像物體識別中01無所謂)
  <object>                               //需要檢測的物體
    <name>RBC</name>                     //物體類別的標簽,可以使用中文
    <pose>Unspecified</pose>             //拍攝角度
    <truncated>0</truncated>             //是否被截斷(0表示完整)
    <difficult>0</difficult>             //目標是否難以識別(0表示容易識別) 
    <bndbox>                             //bounding-box(包含左上角和右下角xy坐標)  
      <xmin>216</xmin>
      <ymin>359</ymin>
      <xmax>316</xmax>
      <ymax>464</ymax>
    </bndbox>
  </object>
  
  ...                                    //如果需要檢測多個物體,則定義多個<object></object>對象即可
</annotation>

1.3 如何制作自己的數據集

  1. labelImg: https://blog.csdn.net/jesse_mx/article/details/53606897
  2. BBox-Label-Tool: https://github.com/puzzledqs/BBox-Label-Tool

二、模型訓練

2.1 預定義參數用于模型的訓練

// parameters matching the pretrained TinyYOLO model
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;

以上代碼定義的是:

  1. 寬高和圖像的通道數
  2. YOLO模型對圖像分割的尺寸,在這里被分割成為13 x 13
// number classes for the red blood cells (RBC)
int nClasses = 1;

定義我們需要分類的數量,在這里我們只識別紅細胞這一個物體,因為值為1。

// parameters for the Yolo2OutputLayer
int nBoxes = 5;
double lambdaNoObj = 0.5;
double lambdaCoord = 5.0;
double[][] priorBoxes = { { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } };
double detectionThreshold = 0.3;

定義我們模型輸出層的一些參數。

// parameters for the training phase
int batchSize = 2;
int nEpochs = 50;
double learningRate = 1e-3;
double lrMomentum = 0.9;

定義一些我們訓練時模型的參數:

  1. batchSize為2,這里主要是因為我使用CPU運行,而且電腦只有8G運存,因此當你電腦配置更高的時候可以選擇更大的值使得模型獲得更好的訓練結果。
  2. nEpoch為50,總共訓練數據50個輪次。
  3. learningRate,學習率為1e-3。
  4. 學習率衰減動量,應用于Nesterovs更新器。

2.2 數據讀取

String dataDir = new ClassPathResource("/datasets").getFile().getPath();
File imageDir = new File(dataDir, "JPEGImages");

在本項目中數據被存放在resources文件夾下,因此需要獲取類路徑,這里主要是獲取圖像目錄。

log.info("Load data...");

RandomPathFilter pathFilter = new RandomPathFilter(rng) {
    @Override
    protected boolean accept(String name) {
        name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
        try {
            return new File(new URI(name)).exists();
        } catch (URISyntaxException ex) {
            throw new RuntimeException(ex);
        }
    }
};
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.8, 0.2);
InputSplit trainData = data[0];
InputSplit testData = data[1];

讀取訓練數據,并且將數據劃分為訓練集和測試集。

ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir)); 

recordReaderTrain.initialize(trainData);

ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
    new VocLabelProvider(dataDir));
recordReaderTest.initialize(testData);

// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));

RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));

構建訓練集和測試集的迭代器,并且創(chuàng)建數據預處理器,使得圖像數據在訓練時被縮放至0~1范圍內。

2.3 模型構建

ComputationGraph model;
String modelFilename = "model_rbc.zip";
ComputationGraph pretrained = (ComputationGraph) new TinyYOLO().initPretrained();
INDArray priors = Nd4j.create(priorBoxes);

首先會從網絡上面下載預訓練模型,下載地址為用戶目錄下的.deeplearning4j目錄下,內容如圖所示:

預訓練模型
預訓練模型

接下來使用fine tune對模型結構進行更改:

 FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder().seed(seed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(1.0).updater(new Adam.Builder().learningRate(learningRate).build())
                .updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build()).activation(Activation.IDENTITY)
                .trainingWorkspaceMode(WorkspaceMode.SEPARATE).inferenceWorkspaceMode(WorkspaceMode.SEPARATE).build();

以上代碼主要做了這幾件事情:

  1. 使用隨機梯度下降優(yōu)化算法
  2. 使用 RenormalizeL2PerLayer 梯度標準化算法,用于防止梯度消失和梯度爆炸,具體內容可看:https://blog.csdn.net/u011669700/article/details/78974518
  3. 使用Nesterovs更新器,配置學習率和動量
  4. 設定訓練模式,具體可看:https://blog.csdn.net/u011669700/article/details/78846452

之后使用遷移學習對于模型架構記性修改:

model = new TransferLearning.GraphBuilder(pretrained).fineTuneConfiguration(fineTuneConf).removeVertexKeepConnections("conv2d_9")
                .addLayer("convolution2d_9",
                    new ConvolutionLayer.Builder(1, 1).nIn(1024).nOut(nBoxes * (5 + nClasses)).stride(1, 1).convolutionMode(ConvolutionMode.Same)
                        .weightInit(WeightInit.UNIFORM).hasBias(false).activation(Activation.IDENTITY).build(),
                    "leaky_re_lu_8")
                .addLayer("outputs", new Yolo2OutputLayer.Builder().lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord).boundingBoxPriors(priors).build(),
                    "convolution2d_9")
                .setOutputs("outputs")
                .build();

主要是配置識別的種類數目。

2.4 模型訓練

model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < nEpochs; i++) {
    train.reset();
    while (train.hasNext()) {
        model.fit(train.next());
    }
    log.info("*** Completed epoch {} ***", i);
}
ModelSerializer.writeModel(model, modelFilename, true);

模型訓練完成之后,序列化保存在本地。

2.5 模型檢測可視化

// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("RedBloodCellDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
while (test.hasNext() && frame.isVisible()) {
    org.nd4j.linalg.dataset.DataSet ds = test.next();
    RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
    INDArray features = ds.getFeatures();
    INDArray results = model.outputSingle(features);
    List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
    File file = new File(metadata.getURI());
    log.info(file.getName() + ": " + objs);

    Mat mat = imageLoader.asMat(features);
    Mat convertedMat = new Mat();
    mat.convertTo(convertedMat, CV_8U, 255, 0);
    int w = metadata.getOrigW() * 2;
    int h = metadata.getOrigH() * 2;
    Mat image = new Mat();
    resize(convertedMat, image, new Size(w, h));
    for (DetectedObject obj : objs) {
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        String label = labels.get(obj.getPredictedClass());
        int x1 = (int) Math.round(w * xy1[0] / gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / gridHeight);
        rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
        putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN);
    }
    frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection");
    frame.setCanvasSize(w, h);
    frame.showImage(converter.convert(image));
    frame.waitKey();
}
frame.dispose();

三、實驗結果

結果展示
結果展示

因為數據量少,訓練輪次小導致結果不是很好,有興趣的可以自己嘗試繼續(xù)訓練。

四、代碼地址

代碼地址已經放在github上面,自行下載即可: https://github.com/sjsdfg/dl4j-tutorials

在包styletransfer下,可以隨意運行。


更多文檔可以查看 https://github.com/sjsdfg/deeplearning4j-issues。
你的star是我持續(xù)分享的動力

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內容