排序是工程中必不可少的功能,很多編程語言SDK都提供了排序相關(guān)的實現(xiàn)。作為軟件工程師,我們在學(xué)習(xí)各類排序算法的同時,是否有思考過,如何去實現(xiàn)一個工業(yè)級的排序算法?如果你是Go語言的作者之一,該如何去實現(xiàn)一種能適應(yīng)多種情況的排序算法?
Go SDK中排序相關(guān)的實現(xiàn)主要在sort/sort.go中,本文主要基于該文件進(jìn)行相關(guān)實現(xiàn)的分析。
首先來看看Go對排序接口的定義,利用Go的interface特性可以輕松實現(xiàn)多種數(shù)據(jù)類型的排序功能。想要調(diào)用sort包的排序功能我們需要實現(xiàn)這個排序接口,排序接口主要定義了三個方法:
-
Len() int: 返回傳入數(shù)據(jù)的總數(shù) -
Less(i, j int) bool: 返回數(shù)組中下標(biāo)為i的數(shù)據(jù)是否小于下標(biāo)為j的數(shù)據(jù) -
Swap(i, j int): 表示執(zhí)行交換數(shù)組中下標(biāo)為i的數(shù)據(jù)和下標(biāo)為j的數(shù)據(jù)
// A type, typically a collection, that satisfies sort.Interface can be
// sorted by the routines in this package. The methods require that the
// elements of the collection be enumerated by an integer index.
type Interface interface {
// Len is the number of elements in the collection.
Len() int
// Less reports whether the element with
// index i should sort before the element with index j.
Less(i, j int) bool
// Swap swaps the elements with indexes i and j.
Swap(i, j int)
}
了解了包中對sort接口的定義后,再來看看sort包對外提供的主要接口Sort,源碼如下:
// Sort sorts data.
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}
如注釋所說,當(dāng)我們調(diào)用Sort方法時,該方法會調(diào)用一次data.Len(),之后會以O(n*log(n))的時間復(fù)雜度調(diào)用data.Less和data.Swap。我們可以看到,Sort內(nèi)部調(diào)用了包私有的quickSort方法,也就是我們熟悉的快排,同時傳了4個參數(shù),學(xué)過快排的同學(xué)都能理解前三個參數(shù)的含義,但是我們還看到了一個陌生的函數(shù)調(diào)用maxDepth(n),這里的depth究竟代表什么呢?所以先探究一下這個函數(shù),代碼如下:
// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
func maxDepth(n int) int {
var depth int
for i := n; i > 0; i >>= 1 {
depth++
}
return depth * 2
}
簡單來說,maxDepth方法返回的深度表示了數(shù)據(jù)的量級,qiuckSort方法會根據(jù)這個量級選擇使用快排還是堆排序,學(xué)過堆排序的同學(xué)都知道,堆排序的時間復(fù)雜度穩(wěn)定在O(nlogn),有時候比快排還穩(wěn)定,但是堆排序?qū)?shù)據(jù)是跳著訪問的,對CPU緩存不友好。
了解了maxDepth方法以后就可以來看看quickSort的源碼了
func quickSort(data Interface, a, b, maxDepth int) {
for b-a > 12 { // Use ShellSort for slices <= 12 elements
if maxDepth == 0 {
heapSort(data, a, b)
return
}
maxDepth--
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
這里代碼的實現(xiàn)方式比較好理解,首先對于數(shù)組元素大于12個的情況會在快排和堆排之間選擇,除此之外的情況會使用希爾排序(間隔為6)和插入排序進(jìn)行排序。
包中對于heapSort的實現(xiàn)中規(guī)中矩,使用從上往下堆化的方式建堆。這里就不詳細(xì)介紹,對于快排的實現(xiàn)方式,有的同學(xué)就發(fā)現(xiàn)不同了,這里調(diào)用了一個尋找分區(qū)點(diǎn)的函數(shù)doPivot,但是doPivot返回了兩個值(這里就利用了Go中函數(shù)可以有多個返回值的特性)。同時這里可以看到返回mlo,mhi以后并沒有繼續(xù)遞歸地在左右分區(qū)查找,而是做了一個比較,原因也正如注釋所說,由于使用了遞歸的方式實現(xiàn)排序,就必須要考慮到棧溢出的問題,所以對分區(qū)的兩半,把數(shù)量多的放到下一次循環(huán)繼續(xù)切分循環(huán),小的直接遞歸。這里也表明了調(diào)用quickSort的最高棧深度為log(b-a),也就是log(n)。
接下來可以看看doPivot函數(shù),為什么會返回兩個分區(qū)點(diǎn)呢?因為mlo到mhi之間的數(shù)已經(jīng)被確定了位置,這里考慮到取中位數(shù)的時候數(shù)組出現(xiàn)大量重復(fù)的數(shù)會影響到排序性能的問題,可以發(fā)現(xiàn)Go作者對這種情況的解決方式充滿著智慧。具體代碼如下:
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
m := int(uint(lo+hi) >> 1) // 首先用位運(yùn)算的方式求中間點(diǎn),防止溢出
if hi-lo > 40 {
// 多數(shù)取中
// Tukey's ``Ninther,'' median of three medians of three.
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
// 接下來要對數(shù)據(jù)達(dá)成以下劃分結(jié)果
// data[lo] = pivot (set up by ChoosePivot)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// 如果data[c <= i < hi-1] > pivot,hi-c<3 這表明數(shù)據(jù)中有重復(fù)的數(shù),
// 這里保守一些,認(rèn)為hi-c<5 為邊界,如果重復(fù)的數(shù)較多,
// 會以直接掃描跳過的方式把pivot左右兩邊的區(qū)間縮小
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle
data.Swap(pivot, b-1)
return b - 1, c
}