BFPRT詳解(top-k問題)

top-K問題指的是從一個數(shù)組中找出里面前k大或是前k小的問題。解決這類問題可以有以下的集中方法。

1、排序。排序之后取前k大或是前k小,時間復雜度是O(nlog(n));

2、使用堆。與之對應的是最大堆和最小堆,時間復雜度是O(nlog(k));

3、使用快排中的partition,將數(shù)組分成小于等于大于三部分,根據(jù)k除去一部分數(shù)據(jù),在對剩下的數(shù)據(jù)進行partition,直至找到前k大或是前k小的數(shù),時間復雜度是O(n),不過這個時間復雜度是概率統(tǒng)計下的結果,并不是嚴格的O(n)的。

下面介紹的BFPRT算法求解top-k問題的時間復雜度是嚴格的O(n)的。

一、BFPRT流程

我們將BFPRT算法看成一個函數(shù)bfprt(int[] arr,int k),返回值是第k大或是第k小的值。

  1. 將數(shù)組按照5個為一組進行分組,最后剩下的不足5個的為一組,此項操作時間復雜度是O(n);

  2. 將每組的5個數(shù)進行排序,因為每組只有5個數(shù)進行排序,一個組排序的時間復雜度是O(1),這個操作的時間復雜度是O(n);

  3. 將每組的中位數(shù)取出來構成一個新的數(shù)組newArr,這個數(shù)組的長度大約是n/5,所以這個操作的時間復雜度是O(n);

  4. 求出新數(shù)組的中位數(shù),即遞歸調用bfprt(newArr,newArr.length/2),假設原來的問題時間復雜度是T(n),則這個操作的時間是T(n/5);

  5. 使用步驟4求出的中位數(shù)進行partition,這一步驟最少可以排除掉arr的3/10,在對剩下的進行bfprt,這個操作的時間是T(7n/10);

整個過程的時間復雜度是T(n) = T(n/5) + T(7n/10) + O(n) = O(n),這個的證明過程大家可以看算法導論。

二、BFPRT算法實現(xiàn)

public static void main(String[] args) {
        int[] arr = { 6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9 };

        int[] res = minKthNumsByBFPRT(arr,11);
        System.out.println(Arrays.toString(res));
        Arrays.sort(arr);
        System.out.println(Arrays.toString(arr));
    }

    public static int[] minKthNumsByBFPRT(int[] arr,int k){
        if(arr == null || k < 1 || arr.length < 1 ||arr.length < k){
            return null;
        }
        // 使用BFPRT算法得到第k小的數(shù)kthNum
        int kthNum = getKthMInNumByBFPRT(arr,0,arr.length - 1,k);
        int[] res = new int[k];
        int index = 0;
        // 將小于kthNum的數(shù)放到結果數(shù)組中
        for(int i = 0; i < arr.length; i++){
            if(arr[i] < kthNum){
                res[index++] = arr[i];
            }
        }
        //小于kthNum的數(shù)個數(shù)小于k個時,剩下的數(shù)全部用kthNum填充
        while(index < k){
            res[index++] = kthNum;
        }
        // 返回最小的前k的數(shù)的數(shù)組
        return res;
    }

    private static int getKthMInNumByBFPRT(int[] arr, int left, int right, int k) {
        if(left == right){
            return arr[left];
        }
        // 將原數(shù)組進行復制
        int[] copy = copyArr(arr);
        // 得到中位數(shù)的中位數(shù)median
        int median = getMedianOfMedian(copy,left,right);
        // 按median進行劃分小于等于大于三部分,中間等于區(qū)域的索引范圍
        int[] range = partition(copy,left,right,median);
        // 在等于區(qū)域的范圍內,直接返回
        if(k >= range[0] && k <= range[1]){
            return copy[k];
        // k在等于區(qū)域的左邊,即在小于區(qū)域,再對左邊的區(qū)域進行BFPRT算法即可
        }else if(k < range[0]){
            return getKthMInNumByBFPRT(copy,left,range[0]-1,k);
        // k在等于區(qū)域的右邊,即在大于區(qū)域,再對右邊的區(qū)域進行BFPRT算法即可
        }else{
            return getKthMInNumByBFPRT(copy,range[1]+1,right,k);
        }
    }

    // 對copy數(shù)組使用median進行partition操作,大于median放左邊,等于median放中間,大于median放右邊
    private static int[] partition(int[] copy, int left, int right, int median) {
        int less = left - 1;
        int more = right + 1;
        int cur = left;
        while(cur < more){
            if(copy[cur] < median){
                swap(copy,++less,cur++);
            }else if(copy[cur] > median){
                swap(copy,cur,--more);
            }else{
                cur++;
            }
        }
        return new int[]{less+1,more-1};
    }

    private static int getMedianOfMedian(int[] copy, int left, int right) {
        int len = right - left + 1;
        // 檢查區(qū)間長度是否能被5整除,如果不足5個剩下的數(shù)作為一組
        int offset = len%5==0?0:1;
        // median是存放每個數(shù)組中位數(shù)的數(shù)組
        int[] median = new int[len/5+offset];
        int index = 0;
        for(int i = left;i <= right;i = i+5){
            // 取最小值是因為最后一組可能沒有5個數(shù)
            int end = Math.min(i + 4,right);
            // 采用插入排序
            insertSort(copy,i,end);
            // 取每組的中位數(shù)
            median[index++] = copy[(i+end)>>1];
        }
        // 求中位數(shù)組成的數(shù)組的中位數(shù)
        return getKthMInNumByBFPRT(median,0,median.length-1,median.length/2);
    }

    private static void insertSort(int[] copy, int left, int right) {
        for(int i = left+1;i<=right;i++){
            for (int j = i-1; j >=left ; j--) {
                if(copy[j] > copy[j+1]){
                    swap(copy,j,j+1);
                }else{
                    break;
                }
            }
        }
    }

    private static void swap(int[] copy, int j, int i) {
        int temp = copy[j];
        copy[j]=  copy[i];
        copy[i] = temp;
    }

    private static int[] copyArr(int[] arr) {
        int[] res = new int[arr.length];
        for (int i = 0; i < arr.length; i++) {
            res[i] = arr[i];
        }
        return res;
    }

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容