RandomForest隨機(jī)森林總結(jié)

RandomForest隨機(jī)森林總結(jié)

1.隨機(jī)森林原理介紹

隨機(jī)森林,指的是利用多棵樹對(duì)樣本進(jìn)行訓(xùn)練并預(yù)測的一種分類器。該分類器最早由Leo Breiman和Adele Cutler提出,并被注冊成了商標(biāo)。簡單來說,隨機(jī)森林就是由多棵CART(Classification And Regression Tree)構(gòu)成的。對(duì)于每棵樹,它們使用的訓(xùn)練集是從總的訓(xùn)練集中有放回采樣出來的,這意味著,總的訓(xùn)練集中的有些樣本可能多次出現(xiàn)在一棵樹的訓(xùn)練集中,也可能從未出現(xiàn)在一棵樹的訓(xùn)練集中。在訓(xùn)練每棵樹的節(jié)點(diǎn)時(shí),使用的特征是從所有特征中按照一定比例隨機(jī)地?zé)o放回的抽取的,根據(jù)Leo Breiman的建議,假設(shè)總的特征數(shù)量為M,這個(gè)比例可以是sqrt(M),1/2sqrt(M),2sqrt(M)。

因此,隨機(jī)森林的訓(xùn)練過程可以總結(jié)如下:

(1)給定訓(xùn)練集S,測試集T,特征維數(shù)F。確定參數(shù):使用到的CART的數(shù)量t,每棵樹的深度d,每個(gè)節(jié)點(diǎn)使用到的特征數(shù)量f,終止條件:節(jié)點(diǎn)上最少樣本數(shù)s,節(jié)點(diǎn)上最少的信息增益m

對(duì)于第1-t棵樹,i=1-t:

(2)從S中有放回的抽取大小和S一樣的訓(xùn)練集S(i),作為根節(jié)點(diǎn)的樣本,從根節(jié)點(diǎn)開始訓(xùn)練

(3)如果當(dāng)前節(jié)點(diǎn)上達(dá)到終止條件,則設(shè)置當(dāng)前節(jié)點(diǎn)為葉子節(jié)點(diǎn),如果是分類問題,該葉子節(jié)點(diǎn)的預(yù)測輸出為當(dāng)前節(jié)點(diǎn)樣本集合中數(shù)量最多的那一類c(j),概率p為c(j)占當(dāng)前樣本集的比例;如果是回歸問題,預(yù)測輸出為當(dāng)前節(jié)點(diǎn)樣本集各個(gè)樣本值的平均值。然后繼續(xù)訓(xùn)練其他節(jié)點(diǎn)。如果當(dāng)前節(jié)點(diǎn)沒有達(dá)到終止條件,則從F維特征中無放回的隨機(jī)選取f維特征。利用這f維特征,尋找分類效果最好的一維特征k及其閾值th,當(dāng)前節(jié)點(diǎn)上樣本第k維特征小于th的樣本被劃分到左節(jié)點(diǎn),其余的被劃分到右節(jié)點(diǎn)。繼續(xù)訓(xùn)練其他節(jié)點(diǎn)。有關(guān)分類效果的評(píng)判標(biāo)準(zhǔn)在后面會(huì)講。

(4)重復(fù)(2)(3)直到所有節(jié)點(diǎn)都訓(xùn)練過了或者被標(biāo)記為葉子節(jié)點(diǎn)。

(5)重復(fù)(2),(3),(4)直到所有CART都被訓(xùn)練過。

利用隨機(jī)森林的預(yù)測過程如下:

對(duì)于第1-t棵樹,i=1-t:

(2)重復(fù)執(zhí)行(1)直到所有t棵樹都輸出了預(yù)測值。如果是分類問題,則輸出為所有樹中預(yù)測概率總和最大的那一個(gè)類,即對(duì)每個(gè)c(j)的p進(jìn)行累計(jì);如果是回歸問題,則輸出為所有樹的輸出的平均值。

注:有關(guān)分類效果的評(píng)判標(biāo)準(zhǔn),因?yàn)槭褂玫氖荂ART,因此使用的也是CART的平板標(biāo)準(zhǔn),和C3.0,C4.5都不相同。

對(duì)于分類問題(將某個(gè)樣本劃分到某一類),也就是離散變量問題,CART使用Gini值作為評(píng)判標(biāo)準(zhǔn)。定義為Gini=1-∑(P(i)*P(i)),P(i)為當(dāng)前節(jié)點(diǎn)上數(shù)據(jù)集中第i類樣本的比例。例如:分為2類,當(dāng)前節(jié)點(diǎn)上有100個(gè)樣本,屬于第一類的樣本有70個(gè),屬于第二類的樣本有30個(gè),則Gini=1-0.7×07-0.3×03=0.42,可以看出,類別分布越平均,Gini值越大,類分布越不均勻,Gini值越小。在尋找最佳的分類特征和閾值時(shí),評(píng)判標(biāo)準(zhǔn)為:argmax(Gini-GiniLeft-GiniRight),即尋找最佳的特征f和閾值th,使得當(dāng)前節(jié)點(diǎn)的Gini值減去左子節(jié)點(diǎn)的Gini和右子節(jié)點(diǎn)的Gini值最大。

對(duì)于回歸問題,相對(duì)更加簡單,直接使用argmax(Var-VarLeft-VarRight)作為評(píng)判標(biāo)準(zhǔn),即當(dāng)前節(jié)點(diǎn)訓(xùn)練集的方差Var減去減去左子節(jié)點(diǎn)的方差VarLeft和右子節(jié)點(diǎn)的方差VarRight值最大。

2.OpenCV函數(shù)使用

OpenCV提供了隨機(jī)森林的相關(guān)類和函數(shù)。具體使用方法如下:

(1)首先利用CvRTParams定義自己的參數(shù),其格式如下

CvRTParams::CvRTParams(intmax_depth,intmin_sample_count,floatregression_accuracy,booluse_surrogates,intmax_categories,constfloat* priors,boolcalc_var_importance,intnactive_vars,intmax_num_of_trees_in_the_forest,floatforest_accuracy,inttermcrit_type)

大部分參數(shù)描述都在http://docs.opencv.org/modules/ml/doc/random_trees.html上面有,說一下沒有描述的幾個(gè)參數(shù)的意義

booluse_surrogates:是否使用代理,指的是,如果當(dāng)前的測試樣本缺少某些特征,但是在當(dāng)前節(jié)點(diǎn)上的分類or回歸特征正是缺少的這個(gè)特征,那么這個(gè)樣本就沒法繼續(xù)沿著樹向下走了,達(dá)不到葉子節(jié)點(diǎn)的話,就沒有預(yù)測輸出,這種情況下,可以利用當(dāng)前節(jié)點(diǎn)下面的所有子節(jié)點(diǎn)中的葉子節(jié)點(diǎn)預(yù)測輸出的平均值,作為這個(gè)樣本的預(yù)測輸出。

const float*priors:先驗(yàn)知識(shí),這個(gè)指的是,可以根據(jù)各個(gè)類別樣本數(shù)量的先驗(yàn)分布,對(duì)其進(jìn)行加權(quán)。比如:如果一共有3類,第一類樣本占整個(gè)訓(xùn)練集的80%,其余兩類各占10%,那么這個(gè)數(shù)據(jù)集里面的數(shù)據(jù)就很不平均,如果每類的樣本都加權(quán)的話,就算把所有樣本都預(yù)測成第一類,那么準(zhǔn)確率也有80%,這顯然是不合理的,因此我們需要提高后兩類的權(quán)重,使得后兩類的分類正確率也不會(huì)太低。

floatregression_accuracy:回歸樹的終止條件,如果當(dāng)前節(jié)點(diǎn)上所有樣本的真實(shí)值和預(yù)測值之間的差小于這個(gè)數(shù)值時(shí),停止生產(chǎn)這個(gè)節(jié)點(diǎn),并將其作為葉子節(jié)點(diǎn)。

后來發(fā)現(xiàn)這些參數(shù)在決策樹里面有解釋,英文說明在這里http://docs.opencv.org/modules/ml/doc/decision_trees.html#cvdtreeparams

具體例子如下,網(wǎng)上找了個(gè)別人的例子,自己改成了可以讀取MNIST數(shù)據(jù)并且做分類的形式,如下:


#include <cv.h>//opencv general include file

#include <ml.h>//opencv machine learning include file

#include <stdio.h>

usingnamespacecv;//OpenCV API is in the C++ "cv" namespace

/******************************************************************************/

//global definitions (for speed and ease of use)

//手寫體數(shù)字識(shí)別

#defineNUMBER_OF_TRAINING_SAMPLES 60000

#defineATTRIBUTES_PER_SAMPLE 784

#defineNUMBER_OF_TESTING_SAMPLES 10000

#defineNUMBER_OF_CLASSES 10

//N.B. classes are integer handwritten digits in range 0-9

/******************************************************************************/

//loads the sample database from file (which is a CSV text file)

inlinevoidrevertInt(int&x)

{

x=((x&0x000000ff)<<24)|((x&0x0000ff00)<<8)|((x&0x00ff0000)>>8)|((x&0xff000000)>>24);

};

intread_data_from_csv(constchar* samplePath,constchar*labelPath, Mat data, Mat classes,

intn_samples )

{

FILE* sampleFile=fopen(samplePath,"rb");

FILE* labelFile=fopen(labelPath,"rb");

intmbs=0,number=0,col=0,row=0;

fread(&mbs,4,1,sampleFile);

fread(&number,4,1,sampleFile);

fread(&row,4,1,sampleFile);

fread(&col,4,1,sampleFile);

revertInt(mbs);

revertInt(number);

revertInt(row);

revertInt(col);

fread(&mbs,4,1,labelFile);

fread(&number,4,1,labelFile);

revertInt(mbs);

revertInt(number);

unsignedchartemp;

for(intline =0; line < n_samples; line++)

{

//for each attribute on the line in the file

for(intattribute =0; attribute < (ATTRIBUTES_PER_SAMPLE +1); attribute++)

{

if(attribute <ATTRIBUTES_PER_SAMPLE)

{

//first 64 elements (0-63) in each line are the attributes

fread(&temp,1,1,sampleFile);

//fscanf(f, "%f,", &tmp);

data.at<float>(line, attribute) = static_cast<float>(temp);

//printf("%f,", data.at<float>(line, attribute));

}

elseif(attribute ==ATTRIBUTES_PER_SAMPLE)

{

//attribute 65 is the class label {0 ... 9}

fread(&temp,1,1,labelFile);

//fscanf(f, "%f,", &tmp);

classes.at<float>(line,0) = static_cast<float>(temp);

//printf("%f\n", classes.at<float>(line, 0));

}

}

}

fclose(sampleFile);

fclose(labelFile);

return1;//all OK

}

/******************************************************************************/

intmain(intargc,char**argv )

{

for(inti=0; i< argc; i++)

std::cout<<argv[i]<<std::endl;

//lets just check the version first

printf ("OpenCV version %s (%d.%d.%d)\n",

CV_VERSION,

CV_MAJOR_VERSION, CV_MINOR_VERSION, CV_SUBMINOR_VERSION);

//定義訓(xùn)練數(shù)據(jù)與標(biāo)簽矩陣

Mat training_data =Mat(NUMBER_OF_TRAINING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);

Mat training_classifications= Mat(NUMBER_OF_TRAINING_SAMPLES,1, CV_32FC1);

//定義測試數(shù)據(jù)矩陣與標(biāo)簽

Mat testing_data =Mat(NUMBER_OF_TESTING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1);

Mat testing_classifications= Mat(NUMBER_OF_TESTING_SAMPLES,1, CV_32FC1);

//define all the attributes as numerical

//alternatives are CV_VAR_CATEGORICAL or CV_VAR_ORDERED(=CV_VAR_NUMERICAL)

//that can be assigned on a per attribute basis

Mat var_type= Mat(ATTRIBUTES_PER_SAMPLE +1,1, CV_8U );

var_type.setTo(Scalar(CV_VAR_NUMERICAL) );//all inputs are numerical

//this is a classification problem (i.e. predict a discrete number of class

//outputs) so reset the last (+1) output var_type element to CV_VAR_CATEGORICAL

var_type.at<uchar>(ATTRIBUTES_PER_SAMPLE,0) =CV_VAR_CATEGORICAL;

doubleresult;//value returned from a prediction

//加載訓(xùn)練數(shù)據(jù)集和測試數(shù)據(jù)集

if(read_data_from_csv(argv[1],argv[2], training_data, training_classifications, NUMBER_OF_TRAINING_SAMPLES) &&

read_data_from_csv(argv[3],argv[4], testing_data, testing_classifications, NUMBER_OF_TESTING_SAMPLES))

{

/********************************步驟1:定義初始化Random Trees的參數(shù)******************************/

floatpriors[] = {1,1,1,1,1,1,1,1,1,1};//weights of each classification for classes

CvRTParamsparams= CvRTParams(20,//max depth

50,//min sample count

0,//regression accuracy: N/A here

false,//compute surrogate split, no missing data

15,//max number of categories (use sub-optimal algorithm for larger numbers)

priors,//the array of priors

false,//calculate variable importance

50,//number of variables randomly selected at node and used to find the best split(s).

100,//max number of trees in the forest

0.01f,//forest accuracy

CV_TERMCRIT_ITER |? ? CV_TERMCRIT_EPS//termination cirteria

);

/****************************步驟2:訓(xùn)練 Random Decision Forest(RDF)分類器*********************/

printf("\nUsing training database: %s\n\n", argv[1]);

CvRTrees* rtree =newCvRTrees;

booltrain_result=rtree->train(training_data, CV_ROW_SAMPLE, training_classifications,

Mat(), Mat(), var_type, Mat(),params);

//float train_error=rtree->get_train_error();

//printf("train error:%f\n",train_error);

//perform classifier testing and report results

Mat test_sample;

intcorrect_class =0;

intwrong_class =0;

intfalse_positives [NUMBER_OF_CLASSES] = {0,0,0,0,0,0,0,0,0,0};

printf("\nUsing testing database: %s\n\n", argv[2]);

for(inttsample =0; tsample < NUMBER_OF_TESTING_SAMPLES; tsample++)

{

//extract a row from the testing matrix

test_sample =testing_data.row(tsample);

/********************************步驟3:預(yù)測*********************************************/

result= rtree->predict(test_sample, Mat());

printf("Testing Sample %i -> class result (digit %d)\n", tsample, (int) result);

//if the prediction and the (true) testing classification are the same

//(N.B. openCV uses a floating point decision tree implementation!)

if(fabs(result - testing_classifications.at<float>(tsample,0))

>=FLT_EPSILON)

{

//if they differ more than floating point error => wrong class

wrong_class++;

false_positives[(int) result]++;

}

else

{

//otherwise correct

correct_class++;

}

}

printf("\nResults on the testing database: %s\n"

"\tCorrect classification: %d (%g%%)\n"

"\tWrong classifications: %d (%g%%)\n",

argv[2],

correct_class, (double) correct_class*100/NUMBER_OF_TESTING_SAMPLES,

wrong_class, (double) wrong_class*100/NUMBER_OF_TESTING_SAMPLES);

for(inti =0; i < NUMBER_OF_CLASSES; i++)

{

printf("\tClass (digit %d) false postives? ? %d (%g%%)\n", i,

false_positives[i],

(double) false_positives[i]*100/NUMBER_OF_TESTING_SAMPLES);

}

//all matrix memory free by destructors

//all OK : main returns 0

return0;

}

//not OK : main returns -1

return-1;

}

MNIST樣本可以在這個(gè)網(wǎng)址http://yann.lecun.com/exdb/mnist/下載,改一下路徑可以直接跑的。

3.如何自己設(shè)計(jì)隨機(jī)森林程序

有時(shí)現(xiàn)有的庫無法滿足要求,就需要自己設(shè)計(jì)一個(gè)分類器算法,這部分講一下如何設(shè)計(jì)自己的隨機(jī)森林分類器,代碼實(shí)現(xiàn)就不貼了,因?yàn)樵诠ぷ髦杏玫搅?,因此比較敏感。

首先,要有一個(gè)RandomForest類,里面保存整個(gè)樹需要的一些參數(shù),包括但不限于:訓(xùn)練樣本數(shù)量、測試樣本數(shù)量、特征維數(shù)、每個(gè)節(jié)點(diǎn)隨機(jī)提取的特征維數(shù)、CART樹的數(shù)量、樹的最大深度、類別數(shù)量(如果是分類問題)、一些終止條件、指向所有樹的指針,指向訓(xùn)練集和測試集的指針,指向訓(xùn)練集label的指針等。還要有一些函數(shù),至少要有train和predict吧。train里面直接調(diào)用每棵樹的train方法即可,predict同理,但要對(duì)每棵樹的預(yù)測輸出做處理,得到森林的預(yù)測輸出。

其次,要有一個(gè)sample類,這個(gè)類可不是用來存儲(chǔ)訓(xùn)練集和對(duì)應(yīng)label的,這是因?yàn)?,每棵樹、每個(gè)節(jié)點(diǎn)都有自己的樣本集和,如果你的存儲(chǔ)每個(gè)樣本集和的話,需要的內(nèi)存實(shí)在是太過巨大了,假設(shè)樣本數(shù)量為M,特征維數(shù)為N,則整個(gè)訓(xùn)練集大小為M×N,而每棵樹的每層都有這么多樣本,樹的深度為D,共有S棵樹的話,則需要存儲(chǔ)M×N×D×S的存儲(chǔ)空間。這實(shí)在是太大了。因此,每個(gè)節(jié)點(diǎn)訓(xùn)練時(shí)用到的訓(xùn)練樣本和特征,我們都用序號(hào)數(shù)組來代替,sample類就是干這個(gè)的。sample的函數(shù)基本需要兩個(gè)就行,第一個(gè)是從現(xiàn)有訓(xùn)練集有放回的隨機(jī)抽取一個(gè)新的訓(xùn)練集,當(dāng)然,只包含樣本的序號(hào)。第二個(gè)函數(shù)是從現(xiàn)有的特征中無放回的隨機(jī)抽取一定數(shù)量的特征,同理,也是特征序號(hào)即可。

然后,需要一個(gè)Tree類,代表每棵樹,里面保存樹的一些參數(shù)以及一個(gè)指向所有節(jié)點(diǎn)的指針。

最后,需要一個(gè)Node類,代表樹的每個(gè)節(jié)點(diǎn)。

需要說明的是,保存樹的方式可以是最普通的數(shù)組,也可是是vector。Node的保存方式同理,但是個(gè)人不建議用鏈表的方式,在程序設(shè)計(jì)以及函數(shù)處理上太麻煩,但是在省空間上并沒有太多的體現(xiàn)。

目前先寫這么多,最后這部分我還會(huì)再擴(kuò)充一些。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 我,來了 跨山越海地來了 我們很近了 隔了一條叫做滄浪的河 聽得見小巷挑著擔(dān)的農(nóng)人叫賣聲清亮 看得見豆?jié){油條小攤的...
    紫草茵茵閱讀 327評(píng)論 0 3
  • 杜琪峰導(dǎo)演,趙薇、古天樂、鐘漢良領(lǐng)銜主演。 這樣的卡司陣容的確很強(qiáng)大。我不是主演中任何一個(gè)人的粉,但在影片上映前的...
    季霖閱讀 1,336評(píng)論 3 4

友情鏈接更多精彩內(nèi)容