蓄水池抽樣算法(Reservoir Sampling)

蓄水池抽樣算法(Reservoir Sampling)

許多年以后,當(dāng)聽說蓄水池抽樣算法時(shí),邱simple將會(huì)想起,那個(gè)小學(xué)數(shù)學(xué)老師帶他做“小明對(duì)水池邊加水邊放水,求何時(shí)能加滿水”應(yīng)用題的下午。

一、問題

我是在一次失敗的面試經(jīng)歷中聽說蓄水池算法的。之后上網(wǎng)搜了搜,知道是一個(gè)數(shù)據(jù)抽樣算法,寥寥幾行,卻暗藏玄機(jī)。主要用來解決如下問題。

給定一個(gè)數(shù)據(jù)流,數(shù)據(jù)流長(zhǎng)度N很大,且N直到處理完所有數(shù)據(jù)之前都不可知,請(qǐng)問如何在只遍歷一遍數(shù)據(jù)(O(N))的情況下,能夠隨機(jī)選取出m個(gè)不重復(fù)的數(shù)據(jù)。

這個(gè)場(chǎng)景強(qiáng)調(diào)了3件事:

  1. 數(shù)據(jù)流長(zhǎng)度N很大且不可知,所以不能一次性存入內(nèi)存。
  2. 時(shí)間復(fù)雜度為O(N)。
  3. 隨機(jī)選取m個(gè)數(shù),每個(gè)數(shù)被選中的概率為m/N。

第1點(diǎn)限制了不能直接取N內(nèi)的m個(gè)隨機(jī)數(shù),然后按索引取出數(shù)據(jù)。第2點(diǎn)限制了不能先遍歷一遍,然后分塊存儲(chǔ)數(shù)據(jù),再隨機(jī)選取。第3點(diǎn)是數(shù)據(jù)選取絕對(duì)隨機(jī)的保證。講真,在不知道蓄水池算法前,我想破腦袋也不知道該題做何解。

二、核心代碼及原理

蓄水池抽樣算法的核心如下:

int[] reservoir = new int[m];

// init
for (int i = 0; i < reservoir.length; i++)
{
    reservoir[i] = dataStream[i];
}

for (int i = m; i < dataStream.length; i++)
{
    // 隨機(jī)獲得一個(gè)[0, i]內(nèi)的隨機(jī)整數(shù)
    int d = rand.nextInt(i + 1);
    // 如果隨機(jī)整數(shù)落在[0, m-1]范圍內(nèi),則替換蓄水池中的元素
    if (d < m)
    {
        reservoir[d] = dataStream[i];
    }
}

注:這里使用已知長(zhǎng)度的數(shù)組dataStream來表示未知長(zhǎng)度的數(shù)據(jù)流,并假設(shè)數(shù)據(jù)流長(zhǎng)度大于蓄水池容量m。

算法思路大致如下:

  1. 如果接收的數(shù)據(jù)量小于m,則依次放入蓄水池。
  2. 當(dāng)接收到第i個(gè)數(shù)據(jù)時(shí),i >= m,在[0, i]范圍內(nèi)取以隨機(jī)數(shù)d,若d的落在[0, m-1]范圍內(nèi),則用接收到的第i個(gè)數(shù)據(jù)替換蓄水池中的第d個(gè)數(shù)據(jù)。
  3. 重復(fù)步驟2。

算法的精妙之處在于:當(dāng)處理完所有的數(shù)據(jù)時(shí),蓄水池中的每個(gè)數(shù)據(jù)都是以m/N的概率獲得的。

下面用白話文推導(dǎo)驗(yàn)證該算法。假設(shè)數(shù)據(jù)開始編號(hào)為1.

第i個(gè)接收到的數(shù)據(jù)最后能夠留在蓄水池中的概率=第i個(gè)數(shù)據(jù)進(jìn)入過蓄水池的概率*之后第i個(gè)數(shù)據(jù)不被替換的概率(第i+1到第N次處理數(shù)據(jù)都不會(huì)被替換)。

  1. 當(dāng)i<=m時(shí),數(shù)據(jù)直接放進(jìn)蓄水池,所以第i個(gè)數(shù)據(jù)進(jìn)入過蓄水池的概率=1
  2. 當(dāng)i>m時(shí),在[1,i]內(nèi)選取隨機(jī)數(shù)d,如果d<=m,則使用第i個(gè)數(shù)據(jù)替換蓄水池中第d個(gè)數(shù)據(jù),因此第i個(gè)數(shù)據(jù)進(jìn)入過蓄水池的概率=m/i
  3. 當(dāng)i<=m時(shí),程序從接收到第m+1個(gè)數(shù)據(jù)時(shí)開始執(zhí)行替換操作,第m+1次處理會(huì)替換池中數(shù)據(jù)的為m/(m+1),會(huì)替換掉第i個(gè)數(shù)據(jù)的概率為1/m,則第m+1次處理替換掉第i個(gè)數(shù)據(jù)的概率為(m/(m+1))*(1/m)=1/(m+1),不被替換的概率為1-1/(m+1)=m/(m+1)。依次,第m+2次處理不替換掉第i個(gè)數(shù)據(jù)概率為(m+1)/(m+2)...第N次處理不替換掉第i個(gè)數(shù)據(jù)的概率為(N-1)/N。所以,之后第i個(gè)數(shù)據(jù)不被替換的概率=m/(m+1)*(m+1)/(m+2)*...*(N-1)/N=m/N。
  4. 當(dāng)i>m時(shí),程序從接收到第i+1個(gè)數(shù)據(jù)時(shí)開始有可能替換第i個(gè)數(shù)據(jù)。則參考上述第3點(diǎn),之后第i個(gè)數(shù)據(jù)不被替換的概率=i/N
  5. 結(jié)合第1點(diǎn)和第3點(diǎn)可知,當(dāng)i<=m時(shí),第i個(gè)接收到的數(shù)據(jù)最后留在蓄水池中的概率=1*m/N=m/N。結(jié)合第2點(diǎn)和第4點(diǎn)可知,當(dāng)i>m時(shí),第i個(gè)接收到的數(shù)據(jù)留在蓄水池中的概率=m/i*i/N=m/N。綜上可知,每個(gè)數(shù)據(jù)最后被選中留在蓄水池中的概率為m/N。

這個(gè)算法建立在統(tǒng)計(jì)學(xué)基礎(chǔ)上,很巧妙地獲得了“m/N”這個(gè)概率。

三、深入一些——分布式蓄水池抽樣(Distributed/Parallel Reservoir Sampling)

一塊CPU的計(jì)算能力再?gòu)?qiáng),也總有內(nèi)存和磁盤IO拖他的后腿。因此為提高數(shù)據(jù)吞吐量,分布式的硬件搭配軟件是現(xiàn)在的主流。

如果遇到超大的數(shù)據(jù)量,即使是O(N)的時(shí)間復(fù)雜度,蓄水池抽樣程序完成抽樣任務(wù)也將耗時(shí)很久。因此分布式的蓄水池抽樣算法應(yīng)運(yùn)而生。運(yùn)作原理如下:

  1. 假設(shè)有K臺(tái)機(jī)器,將大數(shù)據(jù)集分成K個(gè)數(shù)據(jù)流,每臺(tái)機(jī)器使用單機(jī)版蓄水池抽樣處理一個(gè)數(shù)據(jù)流,抽樣m個(gè)數(shù)據(jù),并最后記錄處理的數(shù)據(jù)量為N1, N2, ..., Nk, ..., NK(假設(shè)m<Nk)。N1+N2+...+NK=N。
  2. 取[1, N]一個(gè)隨機(jī)數(shù)d,若d<N1,則在第一臺(tái)機(jī)器的蓄水池中等概率不放回地(1/m)選取一個(gè)數(shù)據(jù);若N1<=d<(N1+N2),則在第二臺(tái)機(jī)器的蓄水池中等概率不放回地選取一個(gè)數(shù)據(jù);一次類推,重復(fù)m次,則最終從N大數(shù)據(jù)集中選出m個(gè)數(shù)據(jù)。

m/N的概率驗(yàn)證如下:

  1. 第k臺(tái)機(jī)器中的蓄水池?cái)?shù)據(jù)被選取的概率為m/Nk。
  2. 從第k臺(tái)機(jī)器的蓄水池中選取一個(gè)數(shù)據(jù)放進(jìn)最終蓄水池的概率為Nk/N。
  3. 第k臺(tái)機(jī)器蓄水池的一個(gè)數(shù)據(jù)被選中的概率為1/m。(不放回選取時(shí)等概率的)
  4. 重復(fù)m次選取,則每個(gè)數(shù)據(jù)被選中的概率為m*(m/Nk*Nk/N*1/m)=m/N

四、算法驗(yàn)證

寫一份完整的代碼,用來驗(yàn)證蓄水池抽樣的隨機(jī)性。數(shù)據(jù)集大小為1000,蓄水池容量為10,做10_0000次抽樣。如果程序正確,那么每個(gè)數(shù)被抽中的次數(shù)接近1000次。

package cn.edu.njupt.qyz;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ReservoirSampling {
    
    static ExecutorService exec = Executors.newFixedThreadPool(4);
    
    // 抽樣任務(wù),用作模擬并行抽樣
    private static class SampleTask implements Callable<int[]>
    {
        // 輸入該任務(wù)的數(shù)據(jù)
        private int[] innerData;
        // 蓄水池容量
        private int m;
        
        SampleTask (int m, int[] innerData)
        {
            this.innerData = innerData;
            this.m = m;
        }

        @Override
        public int[] call() throws Exception
        {
            int[] reservoir = sample(this.m, this.innerData);
            return reservoir;
        }
        
    }
    
    // 并行抽樣
    public static int[] mutiSample(int m, int[] dataStream) throws InterruptedException, ExecutionException
    {
        Random rand = new Random();
        
        
        int[] reservoir = initReservoir(m, dataStream);
        
        // 生成3個(gè)范圍內(nèi)隨機(jī)數(shù),將數(shù)據(jù)切成4份
        List<Integer> list = getRandInt(rand, dataStream.length); 
        int s1 = list.get(0);
        int s2 = list.get(1);
        int s3 = list.get(2);
        // 每個(gè)任務(wù)處理的數(shù)據(jù)量
        double n1 = s1 - 0;
        double n2 = s2 - s1;
        double n3 = s3 - s2;
        double n4 = dataStream.length - s3;
        
        // 并行抽樣
        Future<int[]> f1 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, 0, s1)));
        Future<int[]> f2 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s1, s2)));
        Future<int[]> f3 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s2, s3)));
        Future<int[]> f4 = exec.submit(new SampleTask(m, Arrays.copyOfRange(dataStream, s3, dataStream.length)));
        List<Integer> r1 = getList(f1.get());
        List<Integer> r2 = getList(f2.get());
        List<Integer> r3 = getList(f3.get());
        List<Integer> r4 = getList(f4.get());
        
        // 進(jìn)行m次抽樣
        for (int i = 0; i < m; i++)
        {
            int p = rand.nextInt(dataStream.length);
            // 根據(jù)隨機(jī)數(shù)落在的范圍選擇元素
            if (p < s1)
            {
                reservoir[i] = getRandEle(r1, rand.nextInt(r1.size()));
            }
            else if (p < s2)
            {
                reservoir[i] = getRandEle(r2, rand.nextInt(r2.size()));
            }
            else if (p < s3)
            {
                reservoir[i] = getRandEle(r3, rand.nextInt(r3.size()));
            }
            else
            {
                reservoir[i] = getRandEle(r4, rand.nextInt(r4.size()));
            }
        }
        
        return reservoir;
    }
    
    // 根據(jù)輸入返回隨機(jī)位置的元素,并且刪除該元素,模擬不放回
    private static int getRandEle(List<Integer> list, int idx)
    {
        return list.remove(idx);
    }
    
    // 獲取bound范圍內(nèi)的3個(gè)隨機(jī)數(shù),用來分割數(shù)據(jù)集
    private static List<Integer> getRandInt(Random rand, int bound)
    {
        Set<Integer> set = new TreeSet<>();
        List<Integer> list = new ArrayList<>();
        
        while (set.size() < 3)
        {
            set.add(rand.nextInt(bound));
        }
        for (int e: set)
        {
            list.add(e);
        }
        return list;
    }
    // 數(shù)據(jù)轉(zhuǎn)換成List
    private static List<Integer> getList(int[] arr)
    {
        List<Integer> list = new LinkedList<>();
        for (int a : arr)
        {
            list.add(a);
        }
        return list;
    }
    
    // 單機(jī)版蓄水池抽樣
    public static int[] sample(int m, int[] dataStream)
    {
        // 隨機(jī)數(shù)生成器,以系統(tǒng)當(dāng)前nano時(shí)間作為種子
        Random rand = new Random();
        
        int[] reservoir = initReservoir(m, dataStream);
        
        // init
        for (int i = 0; i < reservoir.length; i++)
        {
            reservoir[i] = dataStream[i];
        }

        for (int i = m; i < dataStream.length; i++)
        {
            // 隨機(jī)獲得一個(gè)[0, i]內(nèi)的隨機(jī)整數(shù)
            int d = rand.nextInt(i + 1);
            // 如果隨機(jī)整數(shù)在[0, m-1]范圍內(nèi),則替換蓄水池中的元素
            if (d < m)
            {
                reservoir[d] = dataStream[i];
            }
        }
        return reservoir;
    }
    
    private static int[] initReservoir (int m, int[] dataStream)
    {
        int[] reservoir;
        
        if (m > dataStream.length)
        {
            reservoir = new int[dataStream.length];
        }
        else
        {
            reservoir = new int[m];
        }
        return reservoir;
    }
    
    // 單機(jī)版測(cè)試
    public void test()
    {
        // 樣本長(zhǎng)度
        int len = 1000;
        // 蓄水池容量
        int m = 10;
        // 抽樣次數(shù),用作驗(yàn)證抽樣的隨機(jī)性
        int iterTime = 100000;
        // 每個(gè)數(shù)字被抽到的次數(shù)
        int[] freq = new int[len];
        // 樣本
        int[] dataStream = new int[len];
        
        // init dataStream
        for (int i = 0; i < dataStream.length; i++)
        {
            dataStream[i] = i;
        }
        
        // count freq
        for (int k = 0; k < iterTime; k++)
        {
            // 進(jìn)行抽樣
            int[] reservoir = sample(m, dataStream);
            // 計(jì)算出現(xiàn)次數(shù)
            for (int i = 0; i < reservoir.length; i++)
            {
                int ele = reservoir[i];
                freq[ele] += 1; 
            }
        }
        
        printStaticInfo(freq);
    }
    
    // 測(cè)試并行抽樣
    public void mutiTest() throws InterruptedException, ExecutionException
    {
        // 樣本長(zhǎng)度
        int len = 1000;
        // 蓄水池容量
        int m = 10;
        // 抽樣次數(shù),用作驗(yàn)證抽樣的隨機(jī)性
        int iterTime = 10_0000;
        // 每個(gè)數(shù)字被抽到的次數(shù)
        int[] freq = new int[len];
        // 樣本
        int[] dataStream = new int[len];
        
        // init dataStream
        for (int i = 0; i < dataStream.length; i++)
        {
            dataStream[i] = i;
        }
        
        // count freq
        for (int k = 0; k < iterTime; k++)
        {
            // 進(jìn)行抽樣
            int[] reservoir = mutiSample(m, dataStream);
            // 計(jì)算出現(xiàn)次數(shù)
            for (int i = 0; i < reservoir.length; i++)
            {
                int ele = reservoir[i];
                freq[ele] += 1; 
            }
        }
        printStaticInfo(freq);
    }
    // 打印統(tǒng)計(jì)信息
    private void printStaticInfo (int[] freq)
    {
        // 期望、方差和標(biāo)準(zhǔn)差
        double avg = 0;
        double var = 0;
        double sigma = 0;
        // print
        for (int i = 0; i < freq.length; i++)
        {
            if (i % 10 == 9) System.out.println();
            System.out.print(freq[i] + ", ");
            avg += ((double)(freq[i]) / freq.length);
            var += (double)(freq[i] * freq[i]) / freq.length;
        }
        
        // 輸出統(tǒng)計(jì)信息
        System.out.println("\n===============================");
        var = var - avg * avg;
        sigma = Math.sqrt(var);
        System.out.println("Average: " + avg);
        System.out.println("Variance: " + var);
        System.out.println("Standard deviation: " + sigma);
    }
    
    public static void main (String[] args) throws InterruptedException, ExecutionException
    {
        ReservoirSampling rs = new ReservoirSampling();
        rs.mutiTest();
    }
}

單機(jī)版輸出和并行版的輸出類似,截取片段如下:

948, 1006, 1014, 1019, 1033, 1040, 948, 1014, 1000, 951, 
1014, 987, 1049, 1043, 1034, 983, 1006, 974, 1060, 1009, 
986, 1021, 1024, 963, 1041, 1028, 988, 1011, 975, 980, 
1055, 1017, 1010, 1018, 1013, 983, 942, 1056, 1003, 1063, 
1004, 1004, 999, 976, 957, 935, 1061, 1018, 1002, 1018, 
1019, 946, 985, 1057, 1012, 965, 978, 1040, 1026, 1064, 
1026, 1018, 980, 996, 1025, 1028, 1006, 944, 986, 981, 
923, 1015, 991, 1019, 1024, 1143, 989, 985, 1022, 1019, 
1004, 1000, 989, 972, 1041, 988, 1050, 932, 975, 1037, 
1016, 983, 1051, 1003, 983, 986, 1017, 1009, 936, 993, 
965, 976, 1001, 1000, 988, 1030, 1050, 1024, 981, 985, 
935, 1023, 996, 1007, 1013, 1046, 1003, 1006, 973, 989, 
943, 
===============================
Average: 1000.0000000000002
Variance: 1011.8799999983748
Standard deviation: 31.81006130139291

此外,為了對(duì)比單機(jī)版與并行版(4線程)的性能差異,使用10_0000大小的數(shù)據(jù)集,蓄水池容量10,進(jìn)行100_0000次重復(fù)抽樣,對(duì)比兩者的運(yùn)行時(shí)間。結(jié)果如下

---------單機(jī)版----------

===============================
Average: 100.00000000000125
Variance: 100.31497999751264
Standard deviation: 10.015736617818613
---------并行版----------

===============================
Average: 100.00000000000169
Variance: 100.63045999737915
Standard deviation: 10.031473470900432
單機(jī)版耗時(shí):2006s
并行版耗時(shí):1265s

從輸出結(jié)果可以看出,算法保證了數(shù)據(jù)選取的隨機(jī)性。且并行版算法能夠有效提高數(shù)據(jù)吞吐量。

五、應(yīng)用場(chǎng)景

蓄水池抽樣的O(N)時(shí)間復(fù)雜度,O(m)空間復(fù)雜度令其適用于對(duì)流數(shù)據(jù)、大數(shù)據(jù)集的等概率抽樣。比如一個(gè)大文本數(shù)據(jù),隨機(jī)輸出其中的幾行。

六、總結(jié)

象征性總結(jié):優(yōu)雅巧妙的算法——蓄水池抽樣。

七、參考文獻(xiàn)

  1. 數(shù)據(jù)工程師必知算法:蓄水池抽樣
  2. 【算法34】蓄水池抽樣算法 (Reservoir Sampling Algorithm)
  3. 分布式/并行蓄水池抽樣 (Distributed/Parallel Reservoir Sampling)
  4. Distributed/Parallel Reservoir Sampling
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容