算法專題:Merge Sort

說起歸并排序(Merge Sort),其在排序界的地位可不低,畢竟O(nlogn)比較排序的三大排序方法,就是Quick Sort, Merge Sort和Heap Sort。歸并排序是典型的分而治之方法,先來看看其最簡單的遞歸實現:

def merge_sort(lst):
    """Sortsthe input list using the merge sort algorithm.

    # >>> lst = [4, 5, 1, 6, 3]
    # >>> merge_sort(lst)
    [1, 3, 4, 5, 6]
    """
    if len(lst) <= 1:
        return lst
    mid = len(lst) // 2
    left = merge_sort(lst[:mid])
    right = merge_sort(lst[mid:])
    return merge(left, right)

def merge(left, right):
    """Takestwo sorted lists and returns a single sorted list by comparing the
    elements one at a time.

    # >>> left = [1, 5, 6]
    # >>> right = [2, 3, 4]
    # >>> merge(left, right)
    [1, 2, 3, 4, 5, 6]
    """
    if not left:
        return right
    if not right:
        return left
    if left[0] < right[0]:
        return [left[0]] + merge(left[1:], right)
    return [right[0]] + merge(left, right[1:])

很明顯,歸并排序是典型的分而治之(Divide and Conquer,D&C)算法,思想就是先把兩半數據分別排序,然后再歸并到一起。這樣T(n) = 2T(n/2) + O(n),由Master Theorem可以得到其時間復雜度是O(nlogn)。
再看具體的實現。排序主體函數用的是遞歸,歸并算法一般都是這樣;而merge部分其實也可以用迭代來完成:

def merge_2(left, right):
    p1 = p2 = 0
    temp = []
    while p1 < len(left) and p2 < len(right):
        if left[p1] <= right[p2]:
            temp.append(left[p1])
            p1 += 1
        else:
            temp.append(right[p2])
            p2 += 1
    while p1 < len(left):
        temp.append(left[p1])
        p1 += 1
    while p2 < len(right):
        temp.append(right[p2])
        p2 += 1
    return temp 

單純就此情景而言,迭代的merge顯得冗長而且效率沒有提升。但是其好處就是適用性廣,因為有很多merge sort的變形,不太方便遞歸調用merge函數。
變形:merge sort有很多tweak的應用,大部分是需要考慮數組前后關系。

例1. 給出一個數組nums,需要對每個數其index之后比它大的數的個數求和。例如給出[7, 4, 5, 2, 8, 9, 0, 1],返回11,因為7后面有8,9兩個比它大的,4有3個,5有2個,2有2個,8有1個,0有1一個,總共2+3+2+2+1+1=11。

【解】
Method 1:一個naive的方法就是對于每一個數,遍歷搜索其后面所有比其大的數,顯然時間復雜度是O(n^2)。

Method 2:還有一個方法就是考慮Segment Tree,先構建從min到max的線段樹,O(max(n) - min(n)),初始count都是0。然后從反方向考慮,考慮前面比其小的有多少個。也就是說對于某個n[i],考慮[min(n), n[i]-1]這個區(qū)間里面有多少count。完了再把n[i]的count++。就這個例子而言,先放7,然后4進來的時候搜索[0,3]區(qū)間,因為是要比4小,再把4的count設置為1;這樣5進來的時候,就能搜索到4的存在。
這個算法后面的步驟是O(nlogn),但是需要構造一個線段樹,假如max很大很大,就不太合適。當然也可以argue說我就干脆構造一個囊括最小到最大的32位int的線段樹,這還是O(1)呢XD。

Method 3:這個解法考慮使用merge sort的tweak。因為要求每個n[i]之后比它大的數,可以用分而治之的思想,即考慮前一半有多少個,后一半有多少個,然后再考慮之間有多少個。
就這個例子而言,考慮最后一次merge之前的樣子:[2,4,5,7][0,1,8,9],此時兩半里面都已經計算完畢,只需要計算merge時候產生的結果。很明顯結果是8,因為前一半的4個數都比8和9要小。但是如何計算呢?
考慮到后一半已經排好序,假如對后一半使用binary search,自然可以得到第一個大于前一半某一個數的index,從而獲得所有大于這個數的個數。也就是說這個merge是O(nlogn)。那么總體就是T(n) = T(n/2) + O(nlogn),由Master Theorem可知復雜度是O(n(logn)^2)。

Method 4:但是,上面這個merge方法沒有利用前一半也排好序的條件,因此可以做到更好。
考慮兩個指針p1和p2,分別指向前一半和后一半。p1初始是在2,p2在0,因為此時n[p2] < n[p1]因此p2增加,直至指向8。那么因為第二半是遞增的,p2后面的數肯定也滿足,因此這時候就可以獲得大于n[p1]的個數:第二半的長度-p2。然后呢?p1遞增指向4,假如p2重新回到0然后掃描,這個復雜度就是O(n^2),比上面的二分查找還要差。
因此做一些調整,不遞增p1,而是遞增p2。也就是說換一個思路,不是從第二半里面找比第一半大的,而是從第一半里面找比第二半小的:剛開始還是p1指向2,p2指向0,然后因為n[p1] >= n[p2],因為第一半遞增,后面的肯定也比n[p2]要大,因此沒必要往后看,可以直接計算個數:p1-s,s是遞歸使用的開始的index,這里是0,也就是說對于n[p2]沒有比其更小的。
然后遞增p2,但需要注意的是p1不用復位,這是很關鍵的一點。為什么?因為p1停止的條件,要么就是已經掃完整個一半了,要么就是現在的n[p1]比n[p2-1]要大,也就是說現在的p1之前的都比n[p2-1]要小,而n[p2]>n[p2-1],因此前面那些根本就不需要比較就能知道結論,可以直接沿用之前的p1的位置。
在這個例子里面,比較明顯的就是8和9.對于8,p1將會遞增至第一半的長度,也就是說整個第一半都比8要小,那么對于9而言,比8大,因此整個第一半也都比9小,無需再從頭比較。
再舉一個一般性一點的例子:[2,4,5,7][0,1,6,8],對于6,p1將會停在7上面,計數是3;p2遞增后,對于8,可以知道p1前面都是比6小的,那么肯定也就比8小,因此直接從p1在7上面開始,最后計數是4。
這樣一來,merge函數兩個指針就不需要走回頭路,效率O(n),整體效率是O(nlogn),空間復雜度O(n)。當然,具體實現的時候,還是要把兩半真正的merge排好序,因為上面的計算都是在兩邊都排好序的情況下進行的。當只有一個元素的時候可以直接返回0。代碼如下:

# count number that larger than it and after it
def dc2(self, n, s, e):
    if s >= e:
        return 0
    m = (s + e) // 2
    ans = self.dc2(n, s, m) + self.dc2(n, m + 1, e)
    p1 = s
    for q in range(m + 1, e + 1):
        while p1 <= m and n[q] > n[p1]:
            p1 += 1
        ans += p1 - s
    # merge
    temp = []
    p1, p2 = s, m + 1
    while p1 <= m and p2 <= e:
        if n[p1] <= n[p2]:
            temp.append(n[p1])
            p1 += 1
        else:
            temp.append(n[p2])
            p2 += 1
    while p1 <= m:
        temp.append(n[p1])
        p1 += 1
    while p2 <= m:
        temp.append(n[p2])
        p2 += 1
    for i in range(len(temp)):
        n[i + s] = temp[i]
    return ans 

例2. 給出一個數組n和一個范圍[a, b],求n有多少個子區(qū)間的和在[a,b]之內。假設數組n的元素和a,b都是整數。例如給出[2,3,4,1],和范圍[3,5],那么子區(qū)間[3][4][2,3][4,1]都滿足條件,返回4。

【解】
Method 1:naive方法就是找出所有的子區(qū)間,然后看有多少個滿足條件。復雜度非常高。

Method 2:看到子區(qū)間之和,當然想到prefix sum。也就是說可以造一個數組s,每一個元素s[i] = n[0]+...+n[i]。那么所有的子區(qū)間除了n[0]都可以用s的后一個元素減去前一個元素獲得。
也就是說,問題轉換成為:給出一個數組s,計算有多少對ij,使得s[i] - s[j] in [a,b]而且i < j?
假如s是升序的,那么好說;但s是無序的。Naive方法就是對每一個s[i],都掃一遍后面元素看看能不能滿足在區(qū)間a+s[i],b+s[i]里面,假如滿足那么減去s[i]就在要求的區(qū)間里面。當然最后還需要比較一下單個的s元素。這個做法復雜度O(n^2)。

Method 3:在子區(qū)間prefix sum的基礎上,考慮merge sort的tweak。假設n=[7, 4, 5, 2, 8, 9, 0, 1], a=0, b=7。
考慮最后一次merge之前的情況:[2,4,5,7][0,1,8,9].從上一題得到啟發(fā),假如對第一半里的每一個數s[i],在第二半里面二分查找第一個大于等于s[i]+a的index1,假如index1不存在那就不需要再找了,沒有符合條件的;和第一個大于s[i]+b的index2,假如不存在那么index2=e也就是end的index。那么自然就可以得到個數index2 - index1。這樣merge的復雜度是O(nlogn),總體O(n(logn)^2)。

Method 4:
在Method 3的基礎上改進。類似于例1,Method 3的問題還是在于沒有利用第一半排好序的條件。
考慮三個指針,p1p2和q,p1p2都指向第一半,p2指向第二半。 因為要利用第一半排序的條件,因此還是固定遞增q。對于s[q],需要s[q] - s[i] 在區(qū)間[a,b]當中。也就是說s[q] - s[i] >= a, s[i] <= s[q] - a; s[q] - s[i] <= b, s[i] >= s[q] - b。
因此兩個指針p1p2,p1不斷遞增直至不滿足s[p1] <= s[q] - a,p2不斷遞增直至不滿足s[p2] < s[q] - b。那么,p1之前的都是滿足s[q] - s[i] >= a的,p2之后的都是滿足s[q] - s[i] <= b,p1p2之間的就是滿足條件的,即count+=p1-p2.
然后遞增q,因為s[q] >= s[q-1],因此之前p1p2的位置可以延續(xù),即s[q] - a >= s[q-1] - a >= s[i],也就是說p1之前p2之前的元素還是滿足那些條件。因此,這個merge函數的復雜度是O(n),總體時間復雜度O(nlogn),空間復雜度O(n)。注意單個區(qū)間的情況已經被涵蓋了。代碼如下:

class Solution:
    # count numbers of subarray sum in range of [a,b]
    def countSubarraySum(self, nums, a, b):
        if not nums:
            return 0
        n = [0] * len(nums)
        for i in range(len(nums)):
            if i != 0:
                n[i] = nums[i] + n[i - 1]
            else:
                n[i] = nums[i]
        return self.dc(n, a, b, 0, len(n) - 1)

    # count number of prefix sum that x[i] - x[j] in [a, b] and i > j plus itself in [a, b]
    def dc(self, n, a, b, s, e):
        if s > e:
            return 0
        if s == e:
            return a <= n[s] <= b
        m = (s + e) // 2
        ans = self.dc(n, a, b, s, m) + self.dc(n, a, b, m + 1, e)
        p1 = p2 = s
        for q in range(m + 1, e + 1):
            while p1 <= m and n[q] - n[p1] >= a:
                p1 += 1
            while p2 <= m and n[q] - n[p2] > b:
                p2 += 1
            if p2 <= p1:
                ans += p1 - p2
        # merge
        temp = []
        p1, p2 = s, m + 1
        while p1 <= m and p2 <= e:
            if n[p1] <= n[p2]:
                temp.append(n[p1])
                p1 += 1
            else:
                temp.append(n[p2])
                p2 += 1
        while p1 <= m:
            temp.append(n[p1])
            p1 += 1
        while p2 <= m:
            temp.append(n[p2])
            p2 += 1
        for i in range(len(temp)):
            n[i + s] = temp[i]
        return ans
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

  • 背景 一年多以前我在知乎上答了有關LeetCode的問題, 分享了一些自己做題目的經驗。 張土汪:刷leetcod...
    土汪閱讀 12,890評論 0 33
  • 貪心算法 貪心算法總是作出在當前看來最好的選擇。也就是說貪心算法并不從整體最優(yōu)考慮,它所作出的選擇只是在某種意義上...
    fredal閱讀 9,419評論 3 52
  • 概述 排序有內部排序和外部排序,內部排序是數據記錄在內存中進行排序,而外部排序是因排序的數據很大,一次不能容納全部...
    蟻前閱讀 5,297評論 0 52
  • 概述:排序有內部排序和外部排序,內部排序是數據記錄在內存中進行排序,而外部排序是因排序的數據很大,一次不能容納全部...
    每天刷兩次牙閱讀 3,818評論 0 15
  • 1.插入排序—直接插入排序(Straight Insertion Sort) 基本思想: 將一個記錄插入到已排序好...
    依依玖玥閱讀 1,347評論 0 2

友情鏈接更多精彩內容