【Lesson1】R 機(jī)器學(xué)習(xí)流程及案例實(shí)現(xiàn)

R 機(jī)器學(xué)習(xí)流程及案例實(shí)現(xiàn)

一直在學(xué)習(xí)機(jī)器學(xué)習(xí)的項(xiàng)目;學(xué)的斷斷續(xù)續(xù)。近期需要完成一些數(shù)據(jù)建模與分析,將機(jī)器學(xué)習(xí)重新整理了一遍。這篇文章主要是介紹R數(shù)據(jù)科學(xué)中,構(gòu)建機(jī)器學(xué)習(xí)模型的流程。為了更適合無基礎(chǔ)的人快速了解整個(gè)流程框架,本文省去機(jī)器學(xué)習(xí)模型的原理及公式部分,如果需要了解,請戳 Here

在看完本文以后,讓你們能夠?qū)C(jī)器學(xué)習(xí)模型有一個(gè)基本認(rèn)識,然后根據(jù)現(xiàn)有數(shù)據(jù)去構(gòu)建一個(gè)機(jī)器學(xué)習(xí)模型及其需要的步驟與預(yù)期結(jié)果,最后可以對自己的進(jìn)行操作練習(xí)與實(shí)現(xiàn)。

機(jī)器學(xué)習(xí)-流程

根據(jù)Max Kuhn 的Caret文章,進(jìn)行總結(jié),一般的機(jī)器學(xué)習(xí)流程主要分為以下過程。

image.png

將Data分成Train與Test兩部分。主要花費(fèi)的精力是在Train數(shù)據(jù)集上,因?yàn)樾枰业揭粋€(gè)合適的模型來擬合Train數(shù)據(jù),對模型參數(shù)進(jìn)行不斷調(diào)整,達(dá)到該數(shù)據(jù)的最優(yōu)。同時(shí)還需要考慮resampling,至于為什么要resample,其實(shí)就是:針對本數(shù)據(jù)模型的R^2可以達(dá)到0.99,但是只適用于本數(shù)據(jù),不能外推,所以the goal is not to “predict” the data you have in hand, but to develop a model that will predict new datasets.
有時(shí)候,變量較多,或者變量會存在相關(guān)系,那么就會涉及到變量的處理,Pre-processing(這也是一個(gè)相當(dāng)麻煩的過程)。

1.數(shù)據(jù)拆分Train與Test數(shù)據(jù)集
2.Train數(shù)據(jù)集模型選擇與調(diào)參
3.模型預(yù)測Test數(shù)據(jù)集

在上述模型調(diào)整好以后,嗯,那我們可以對Test數(shù)據(jù)進(jìn)行預(yù)測了。看下模型預(yù)測效果。這里預(yù)測的效果優(yōu)越是需要根據(jù)預(yù)測變量類型來選擇不同的評估指標(biāo),主要分為分類與回歸兩種。然后繪制相應(yīng)的RMSE曲線或者ROC曲線,來展示模型的預(yù)測性能。

當(dāng)然了,在醫(yī)學(xué)上機(jī)器學(xué)習(xí)應(yīng)用遠(yuǎn)不止于此,還需探究變量間的關(guān)聯(lián)性,稱之為explanation ML,在后面篇幅會介紹。。

案例操作

下面以caret舉例,Caret包的優(yōu)點(diǎn):主要使用train函數(shù),集中多個(gè)模型。其中函數(shù)中定義了模型與調(diào)節(jié)參數(shù),所以只要替換模型與參數(shù),即可調(diào)用不同模型。因此省去了因運(yùn)行不同模型而學(xué)習(xí)不同的packages。另外對于預(yù)測變量不管是分類變量還是連續(xù)性變量,Caret都可以構(gòu)建。
本次操作利用pdp包里面的pima數(shù)據(jù)集進(jìn)行演示。該數(shù)據(jù)收集了
392例女性糖尿病患者的臨床指標(biāo),包括年齡,血糖,胰島素及血壓等指標(biāo)。主要是通過臨床指標(biāo)預(yù)測患者是否患糖尿病。

1. 數(shù)據(jù)拆分

將pima數(shù)據(jù)進(jìn)行預(yù)處理,丟棄NA,glucose轉(zhuǎn)成分類變量(glucose > 149=="High")。然后利用createDataPartition()將數(shù)據(jù)分成train(80%)與test (20%)兩個(gè)部分。

library(tidyverse)
library(caret)
library(pdp)
### get data
data(pima)
df=pima %>% na.omit() %>% as.tbl() %>% 
  mutate(glucose=as.factor(ifelse(glucose>143,"High","Low")))
### splitdata
set.seed(13)
samp = createDataPartition(df$diabetes, p = 0.8, list = FALSE)
train = df[samp,]
test = df[-samp,]

2. 模型構(gòu)建

這里使用train()函數(shù),因變量為diabetes,自變量默認(rèn)選擇全部,需要提前使用trainControl()設(shè)置resampling方法,里面涉及"boot", "cv", "LOOCV", "LGOCV"等一系列方法,這里我們設(shè)置為5-fold cross validation--method = "cv", number = 5。
因?yàn)閐iabetes是二分類變量,我們采用gbm算法,然后用AUC來評估訓(xùn)練模型的優(yōu)越性。

myControl = trainControl(method = "cv", 
                         classProbs=T,
                         number = 5,
                         summaryFunction=prSummary,
                         verboseIter = FALSE)
set.seed(12)
model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  #tuneGrid = gbm.grid,
                  metric = "ROC")

需要提示的是,這里為了減少運(yùn)行時(shí)間,并沒有進(jìn)行tuning 參數(shù)調(diào)節(jié)。gbm模型主要涉及三個(gè)參數(shù),可以把參數(shù)放入gird,然后一個(gè)一個(gè)測試,得出每個(gè)參數(shù)對應(yīng)調(diào)節(jié)下的AUC值,根據(jù)最大的AUC,選擇對應(yīng)的模型參數(shù)。當(dāng)然如果不設(shè)置grid,train會自動選擇最適參數(shù)。

gbm.grid <- expand.grid(interaction.depth = c(1,2,8),
                         n.trees = c(50, 100, 150, 200, 250, 300),
                         shrinkage = 0.1,
                         n.minobsinnode = 20)
 head(gbm.grid)
 
 model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  tuneGrid = gbm.grid,
                  metric = "ROC")

接下來,我們看下model_gbm,這里面儲存了我們所要的信息。gbm最合適參數(shù)


image.png

3. 模型預(yù)測

### Predict
pred = predict(model_gbm,newdata=test)
confusionMatrix(pred,test$diabetes)
Confusion Matrix and Statistics

          Reference
Prediction neg pos
       neg  47   9
       pos   5  17
                                          
               Accuracy : 0.8205          
                 95% CI : (0.7172, 0.8983)
    No Information Rate : 0.6667          
    P-Value [Acc > NIR] : 0.001942        
                                          
                  Kappa : 0.58            
                                          
 Mcnemar's Test P-Value : 0.422678        
                                          
            Sensitivity : 0.9038          
            Specificity : 0.6538          
         Pos Pred Value : 0.8393          
         Neg Pred Value : 0.7727          
             Prevalence : 0.6667          
         Detection Rate : 0.6026          
   Detection Prevalence : 0.7179          
      Balanced Accuracy : 0.7788          
                                          
       'Positive' Class : neg      

4. 變量重要性與解釋

這里顯示, "insulin" "glucose" 與 "mass" 對模型結(jié)果影響較大。具體怎么樣的影響需要借助于邊際效應(yīng)的關(guān)系。pdp-案例:Explaining Black-Box Machine Learning Models - Code Part 1: tabular data + caret + iml

 varImp(model_gbm)
 plot(varImp(model_gbm))
image.png

4. 多個(gè)模型比較

有時(shí)候需要多個(gè)模型放在一起比較。

set.seed(12)
model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  #tuneGrid = gbm.grid,
                  metric = "ROC")
model_svm = train(diabetes ~ ., 
                 data=train,
                 method = "svmRadial",
                 trControl = myControl,
                 tuneLength = 8,
                 metric = "ROC")
                                   
model_rda = train(diabetes ~ ., 
                 data=train,
                 method = "rda", 
                 trControl = myControl,
                 tuneLength = 4,
                 metric = "ROC")
                                    
# compare all
all=resamples(list(GBM = model_gbm,SVM=model_svm,RDA = model_rda))
summary(all)
Call:
summary.resamples(object = all)

Models: GBM, SVM, RDA 
Number of resamples: 5 

AUC 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.8499955 0.8508692 0.8611407 0.8696634 0.8868533 0.8994585    0
SVM 0.8300370 0.8355535 0.8563194 0.8584288 0.8608459 0.9093879    0
RDA 0.8252053 0.8387715 0.8963407 0.8772405 0.9124427 0.9134421    0

F 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7804878 0.8235294 0.8297872 0.8193452 0.8314607 0.8314607    0
SVM 0.8043478 0.8089888 0.8181818 0.8208631 0.8222222 0.8505747    0
RDA 0.7380952 0.8048780 0.8181818 0.8073135 0.8275862 0.8478261    0

Precision 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7500000 0.7872340 0.7872340 0.7876843 0.8000000 0.8139535    0
SVM 0.7400000 0.7659574 0.7708333 0.7763243 0.7826087 0.8222222    0
RDA 0.7380952 0.7800000 0.7826087 0.7851408 0.8000000 0.8250000    0

Recall 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7619048 0.8333333 0.8809524 0.8571429 0.8809524 0.9285714    0
SVM 0.8571429 0.8571429 0.8809524 0.8714286 0.8809524 0.8809524    0
RDA 0.7380952 0.7857143 0.8571429 0.8333333 0.8571429 0.9285714    0

模型AUC

可以看出AUC最大的為gbm模型0.8739。

# ROC
# Build custom AUC function to extract AUC
# from the caret model object
library(pROC) 
test_roc = function(model, data) {
  roc(data$diabetes,
      predict(model, data, type = "prob")[, "pos"])
  
}

# Examine results for test set
model_list = list(GBM = model_gbm,SVM=model_svm,RDA = model_rda)

model_list_roc = model_list %>%
  map(test_roc, data = test)

model_list_roc %>%
  map(auc)

# plot
df_roc=c()
for (i in 1:length(model_list)) {
  a=test_roc(model_list[[i]],test)
  b=tibble(tpr=a$sensitivities,
           fpr=1-a$specificities,
           model=names(model_list)[i])
  
  df_roc=rbind(df_roc,b)
}

ggplot(data=df_roc,aes(x = fpr, y = tpr, group = model)) +
  geom_line(aes(color = model), size = 1) +
  geom_abline(intercept = 0, slope = 1, 
              color = "gray", size = 1)+
  labs(title = ("ROC Curves for all models"),
       x="False Positive Rate (1 - Specificity)",
       y="True Positive Rate (Sensivity or Recall)")

image.png

結(jié)語

這是Caret的使用,后續(xù)會介紹如何使用Tidymodel,將更簡化操作,輸入輸出步驟。
未完待續(xù)。

Caret 參考

  1. Caret resampling介紹
  2. Caret基礎(chǔ)介紹-Rebecca
  3. A Brief Introduction to caret 變量為連續(xù)性
  4. Caret Tune 參數(shù) 循環(huán)設(shè)置
  5. Kaggle Caret 實(shí)戰(zhàn)
  6. Data Science and Predictive Analytics
  7. Evaluating Model Performance by Building Cross-Validation from Scratch【為什么要resampling 】

next

Using XGBoost with Tidymodels 結(jié)合Caret
Caret 案例Machine Learning for Insurance Claims
Caret 預(yù)測Amesing huose-多個(gè)caret模型
Predict the Residential Sale Price of Properties in Ames
Multivariate Adaptive Regression Splines
Ames housing prediction
Tidymodels: tidy machine learning in R

pdp

pdp-案例:Explaining Black-Box Machine Learning Models - Code Part 1: tabular data + caret + iml
Chapter 5: Model-Agnostic Methods
Shining a light on the “Black Box” of machine learning
Gradient Boosting Machines
Partial dependence plots for tidymodels-based xgboost
【VIP】--Variable importance plots: an introduction to vip
【pdp】: An R Package for
VIP: Classification of Student Success with Caret

Handling Class Imbalance data

主要兩種,1.resample方法增加精度。2.采用PROC評估。

  1. 【W(wǎng)eighting and sampling】-Handling Class Imbalance with R and Caret - An Introduction
  2. 【PROC】-Handling Class Imbalance with R and Caret - Caveats when using the AUC

Tidymodel with R

https://www.tidymodels.org/learn/
https://www.tmwr.org/
https://algotech.netlify.app/blog/tidymodels/

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容