什么是線段樹
線段樹(Segment Tree)也叫區(qū)間樹,其本質上是一種二分搜索樹,不同點在于線段樹中每個節(jié)點不再是存放單純的元素,而是存放了一個可以表示區(qū)間的值,通常是該區(qū)間合并后的值。并且每個區(qū)間會被平均分為2個子區(qū)間,作為它的左右子節(jié)點。比如說根節(jié)點存放了區(qū)間 [1,10],那么就會被分為區(qū)間 [1,5] 作為左子節(jié)點,區(qū)間 [6,10] 作為右子節(jié)點。
例如,我們可以將這樣一個數組所表示的區(qū)間構造成線段樹:

并且指定區(qū)間合并規(guī)則為區(qū)間內的元素求和,那么構造出來的線段樹表示如下:

- 從這顆線段樹可以看到,由于具有二分搜索樹的特性,我們可以快速地在線段樹中找到一個區(qū)間。注意,這里是指按區(qū)間查找,而不是按元素值查找。所以線段樹相對難理解的地方就在于每個節(jié)點既有區(qū)間的概念又有一個元素值。
為什么要使用線段樹
關于線段樹的一個經典問題就是:區(qū)間染色。假設有一面墻,長度為 n,每次選擇一段兒墻進行染色。在 m 次操作后,我們可以在 [i, j] 區(qū)間內看見多少中顏色?
對于這個問題,我們可以使用一個數組來實現:

對于染色操作(更新區(qū)間)我們可以遍歷數組找到目標區(qū)間進行染色,時間復雜度是 。對于查詢操作(查詢區(qū)間)也是遍歷數組即可,同樣時間復雜度為
。顯然用線性結構來解決這類問題的時間復雜度要更高一些,此時線段樹就派上用場了,因為樹形結構的時間復雜度通常在
。
除此之外,線段樹的另一個經典問題就是:區(qū)間查詢。查詢一個區(qū)間 [i, j] 的最大值和最小值,或者區(qū)間數字之和。例如,在實際業(yè)務中很常見的基于區(qū)間的統(tǒng)計查詢:2017年注冊用戶中消費最高的用戶?消費最少的用戶?學習時間最長的用戶?某個太空區(qū)間中天體總量?
對于靜態(tài)區(qū)間數據(區(qū)間內的數據不會發(fā)生變化)來說,是比較好解決的,但以上所提到的問題都是動態(tài)的區(qū)間數據(區(qū)間內的數據在不斷的變化),此時線段樹就是一個比較好的選擇。
通過以上的介紹,我們能總結出線段樹的兩個核心操作:
- 區(qū)間更新:更新區(qū)間中一個元素或者一個區(qū)間的值
- 區(qū)間查詢:查詢一個區(qū)間
[i, j]的最大值、最小值,或者區(qū)間數字之和
線段樹基礎表示
線段樹雖然不像堆那樣是一棵完全二叉樹,但線段樹由于其特性滿足平衡二叉樹(左右子樹高度相差不超過1),所以依然可以使用數組進行表示。我們可以將其看做是一顆滿二叉樹,空節(jié)點就當做葉子節(jié)點即可。如下示例:

既然可以用數組來表示一棵線段樹,那么如果區(qū)間有 n 個元素,此時應該創(chuàng)建多大容量的數組來構建一顆線段樹呢?對于這個問題,我們先來看如何求一棵滿二叉樹的節(jié)點:假設這棵樹有 h 層,那么這棵樹就一共有 個節(jié)點(大約是
)。對于最后一層(
層)來說,就有
個節(jié)點。因此,最后一層的節(jié)點數大致等于前面所有層節(jié)點之和。
了解了如何求滿二叉樹的節(jié)點數量后,回到之前的問題,如果區(qū)間有 n 個元素,此時應該開多大空間的數組?我們可以分成兩種情況:
- 如果
,那么只需要開辟
的數組空間
- 如果
,那么就需要開辟
的數組空間
通常來說,我們的線段樹不考慮添加元素,即區(qū)間固定(區(qū)間內的數據可以是不固定的),那么使用 的靜態(tài)空間即可。這也是普遍構造線段樹時,使用的一個通用值。除非對內存有嚴格要求,否則一般開辟
的數組空間即可。而且對于內存有要求的情況下,一般也不會采用數組來表示,此時鏈式結會是更優(yōu)的選擇。
接下來,我們就實現一下線段樹的基礎結構代碼:
package tree;
/**
* 線段樹 - 基于數組的表示實現
*
* @author 01
* @date 2021-01-27
**/
public class SegmentTree<E> {
/**
* 保存原始數組,即需要被構造成線段樹的區(qū)間
*/
private E[] data;
/**
* 線段樹的數組表示
*/
private E[] tree;
public SegmentTree(E[] arr) {
this.data = (E[]) new Object[arr.length];
System.arraycopy(arr, 0, this.data, 0, arr.length);
// 開辟 4n 的數組空間用于構造線段樹
this.tree = (E[]) new Object[4 * arr.length];
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal");
}
return data[index];
}
/**
* 返回完全二叉樹的數組表示中,一個索引所表示的元素的左子節(jié)點的索引
*/
private int leftChild(int index) {
return 2 * index + 1;
}
/**
* 返回完全二叉樹的數組表示中,一個索引所表示的元素的右子節(jié)點的索引
*/
private int rightChild(int index) {
return 2 * index + 2;
}
}
- 對于這里求某個索引對于的左右子節(jié)點索引的方式,可以參考之前數據結構之優(yōu)先隊列和堆一文中的說明
創(chuàng)建線段樹
在本小節(jié)中,我們來根據之前實現的基礎代碼,完成創(chuàng)建線段樹邏輯的編寫。需要說明一下的是,在本例中,線段樹每個節(jié)點所存儲的元素是區(qū)間合并后的值。具體的實現代碼如下:
/**
* 用戶自定義的區(qū)間合并邏輯
*/
private final Merger<E> merger;
public SegmentTree(E[] arr, Merger<E> merger) {
this.merger = merger;
this.data = (E[]) new Object[arr.length];
System.arraycopy(arr, 0, this.data, 0, arr.length);
// 開辟 4n 的數組空間用于構建線段樹
this.tree = (E[]) new Object[4 * arr.length];
// 構建線段樹,傳入根節(jié)點索引,以及區(qū)間的左右端點
buildSegmentTree(0, 0, data.length - 1);
}
/**
* 在treeIndex的位置創(chuàng)建表示區(qū)間[left...right]的線段樹
*/
private void buildSegmentTree(int treeIndex, int left, int right) {
// 區(qū)間中只有一個元素,代表遞歸到底了
if (left == right) {
tree[treeIndex] = data[left];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 計算中間點,需要避免整型溢出
int mid = left + (right - left) / 2;
// 構建左子樹
buildSegmentTree(leftTreeIndex, left, mid);
// 構建右子樹
buildSegmentTree(rightTreeIndex, mid + 1, right);
// 對于兩個區(qū)間的合并規(guī)則是與業(yè)務相關的,所以要調用用戶自定義的邏輯來完成
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
/**
* 遍歷打印樹中節(jié)點中值信息。
*
* @return String
*/
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append('[');
for (int i = 0; i < tree.length; i++) {
if (tree[i] != null) {
res.append(tree[i]);
} else {
res.append("null");
}
if (i != tree.length - 1) {
res.append(", ");
}
}
res.append(']');
return res.toString();
}
- 在線段樹中根節(jié)點存儲的數據,實際就是左右兩個子節(jié)點數據的合并(遞歸即可),而具體如何合并是由業(yè)務決定的。例如,可以是求和,也可以是求最大值或最小值。另外,這里沒有通過一個對象來表示節(jié)點中的左右區(qū)間,而是通過方法參數的形式表示了這個區(qū)間,數組中只存儲區(qū)間合并后的值。
用戶傳入的 Merger 是一個接口,其定義如下:
package tree;
/**
* 合并器接口
*
* @author 01
* @date 2021-01-27
**/
public interface Merger<E> {
/**
* 用戶自定義的區(qū)間合并邏輯
*
* @param a 區(qū)間a
* @param b 區(qū)間b
* @return 合并后的結果
*/
E merge(E a, E b);
}
最后,我們來編寫一個簡單的測試用例進行一下測試:
package tree;
/**
* 測試SegmentTree
*
* @author 01
*/
public class SegmentTreeTests {
public static void main(String[] args) {
Integer[] nums = {-2, 0, 3, -5, 2, -1};
SegmentTree<Integer> segTree = new SegmentTree<>(
nums, Integer::sum // 對兩個區(qū)間中的值進行求和
);
System.out.println(segTree);
}
}
輸出結果如下:
[-3, 1, -4, -2, 3, -3, -1, -2, 0, null, null, -5, 2, null, null, null, null, null, null, null, null, null, null, null]
- 可以看到,線段樹的根節(jié)點是
-3,因為對整個數組的求和結果就是-3。左子節(jié)點為1,因為-2 + 0 + 3 = 1。右子節(jié)點為-4,同理,因為-5 + 2 + -1 = -4,其余以此類推。結果符合預期,證明我們實現的線段樹沒有問題。
線段樹中的區(qū)間查詢
例如,我們要對如下這棵線段樹查詢 [2, 5] 這個區(qū)間:

由于我們之前傳入的 Merger 實現的是求和邏輯,那么這相當于查詢2 ~ 5區(qū)間所有元素的和。從根節(jié)點開始往下,我們知道分割位置,左節(jié)點查詢 [2, 3],右節(jié)點查詢 [4, 5],找到兩個節(jié)點之后合并就可以了。
具體的實現代碼如下:
/**
* 查詢區(qū)間[queryLeft, queryRight]的值,如[2, 5]
*/
public E query(int queryLeft, int queryRight) {
if (queryLeft < 0 || queryLeft >= data.length ||
queryRight < 0 || queryRight >= data.length ||
queryLeft > queryRight) {
throw new IllegalArgumentException("Index is illegal");
}
return query(0, 0,
data.length - 1, queryLeft, queryRight);
}
/**
* 在以treeIndex為根的線段樹中[left...right]的范圍里,搜索區(qū)間[queryLeft...queryRight]的值
*/
private E query(int treeIndex, int left, int right,
int queryLeft, int queryRight) {
// 找到了目標區(qū)間
if (left == queryLeft && right == queryRight) {
return tree[treeIndex];
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 計算中間點,需要避免整型溢出
int mid = left + (right - left) / 2;
if (queryLeft >= mid + 1) {
// 目標區(qū)間不在左子樹中,查找右子樹
return query(rightTreeIndex, mid + 1, right, queryLeft, queryRight);
} else if (queryRight <= mid) {
// 目標區(qū)間不在右子樹中,查找左子樹
return query(leftTreeIndex, left, mid, queryLeft, queryRight);
}
// 目標區(qū)間一部分在右子樹中,一部分在左子樹中,則兩個子樹都需要找
E leftResult = query(leftTreeIndex, left, mid, queryLeft, mid);
E rightResult = query(rightTreeIndex, mid + 1, right, mid + 1, queryRight);
// 找到目標區(qū)間的值,將其合并后返回
return merger.merge(leftResult, rightResult);
}
進行一個簡單的測試:
public static void main(String[] args) {
Integer[] nums = {-2, 0, 3, -5, 2, -1};
SegmentTree<Integer> segTree = new SegmentTree<>(
nums, Integer::sum // 對兩個區(qū)間中的值進行求和
);
System.out.println(segTree.query(0,2));
System.out.println(segTree.query(2,5));
System.out.println(segTree.query(0,5));
}
輸出結果如下:
1
-1
-3
線段樹中的更新操作
我們使用線段樹來解決區(qū)間相關的問題,主要是針對區(qū)間內的數據是動態(tài)變化的情況,如果是靜態(tài)區(qū)間一般不需要用到線段樹。所以在本小節(jié),我們就來實現線段樹中的更新操作。
實際上線段樹中的更新操作,本質上是在二分查找。因為根據線段樹的特性,待更新的目標節(jié)點肯定是一個葉子節(jié)點,我們只需要找到這個葉子節(jié)點并進行更新即可。我們查找待更新節(jié)點的依據是數組的索引,而數組的索引是從 0 ~ n 有序的,所以在一個有序的區(qū)間中查找某個特定的值,妥妥的就是二分查找了。
知道了我們在更新線段樹中某個節(jié)點時,要找的這個待更新節(jié)點是一個葉子節(jié)點,并且找到這個葉子節(jié)點的過程本質上是一個二分查找,那么這個思路就很清晰了。
首先,將找到葉子節(jié)點的條件作為遞歸的退出條件。然后計算中間點,并將線段樹數組劃分為 [left...mid] 和 [mid+1...right] 兩個區(qū)間。接著判斷要找的數組索引落在哪個區(qū)間,就繼續(xù)往哪個區(qū)間遞歸查找。最后,將區(qū)間的值進行合并。如此一來,就完成了目標節(jié)點的更新操作。
具體的實現代碼如下:
/**
* 將index位置的值,更新為e
*/
public void set(int index, E e) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal");
}
data[index] = e;
set(0, 0, data.length - 1, index, e);
}
/**
* 在以treeIndex為根的線段樹中更新index的值為e
*/
private void set(int treeIndex, int left, int right, int index, E e) {
// 找到了葉子節(jié)點
if (left == right) {
// 進行更新
tree[treeIndex] = e;
return;
}
int mid = left + (right - left) / 2;
// 將線段樹數組劃分為[left...mid]和[mid+1...right]兩個區(qū)間
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index >= mid + 1) {
// index在右子樹
set(rightTreeIndex, mid + 1, right, index, e);
} else {
// index在左子樹
set(leftTreeIndex, left, mid, index, e);
}
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
Leetcode上線段樹相關的問題
在本文的最后,我們來使用自己實現的線段樹解決一個Leetcode上的307號問題:
該問題的主要需求是更新數組下標對應的值,以及查詢數組中某個區(qū)間內的元素總和。像這種對區(qū)間內數據有更新需求的,會使得區(qū)間內數據動態(tài)變化的,就很適合使用線段樹來解決。具體的實現代碼如下:
package tree.solution;
import tree.SegmentTree;
/**
* Leetcode 307. Range Sum Query - Mutable
* https://leetcode.com/problems/range-sum-query-mutable/description/
*/
class NumArray {
private SegmentTree<Integer> segTree;
public NumArray(int[] nums) {
if (nums.length != 0) {
Integer[] data = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
data[i] = nums[i];
}
segTree = new SegmentTree<>(data, Integer::sum);
}
}
public void update(int i, int val) {
if (segTree == null) {
throw new IllegalArgumentException("Error");
}
segTree.set(i, val);
}
public int sumRange(int i, int j) {
if (segTree == null) {
throw new IllegalArgumentException("Error");
}
return segTree.query(i, j);
}
}