一、定義
Trie樹,又稱為單詞查找樹,是一種樹形結(jié)構(gòu)(Trie一詞源于單詞Retrieval-取出)。
Trie樹經(jīng)常被搜索引擎系統(tǒng)用于文本詞頻統(tǒng)計。它的特點是:
利用字符串的公共前綴來減少查詢時間,最大限度地減少無謂的字符串比較。
- 查找命中所需的時間與被查找的鍵的長度成正比;
- 查找未命中只需檢查若干個字符;

1.1 數(shù)據(jù)結(jié)構(gòu)定義
Trie樹的實現(xiàn)方式有很多種,本文中的實現(xiàn)稱為“R向單詞查找樹”(R為字母表大小):
- Trie樹的根結(jié)點不保存字符(也可看成保存空字符"");
- Trie樹的每個結(jié)點含有R條鏈接(R為字母表的大?。?,每個結(jié)點
R.children[i]指向以字符i為根結(jié)點的子樹; - 每個鍵所關(guān)聯(lián)的值保存在該鍵的最后一個字符所在的結(jié)點中(值為空的結(jié)點在Trie樹中沒有對應(yīng)的鍵)。
public class TrieST<V> {
private static final int R = 256; // extended ASCII
private Node root; // root of trie
private int n; // number of keys in trie
// R-way trie node
private class Node {
private V val;
private Node[] children = new Node[R];
}
}

1.2 API定義

二、實現(xiàn)
2.1 查找
Trie樹的查找步驟如下:
- 從根結(jié)點開始一次搜索;
- 取得要查找關(guān)鍵詞的第一個字母,并根據(jù)該字母選擇對應(yīng)的子樹并轉(zhuǎn)到該子樹繼續(xù)進行檢索;
- 在相應(yīng)的子樹上,取得要查找關(guān)鍵詞的第二個字母,并進一步選擇對應(yīng)的子樹進行檢索。
- 迭代過程……
- 在某個結(jié)點處,關(guān)鍵詞的所有字母已被取出,則讀取附在該結(jié)點上的信息,即完成查找。
查找結(jié)果共3種情況:
- 鍵的尾字符對應(yīng)的結(jié)點中保存的值為空;(未命中)
- 鍵的尾字符對應(yīng)的結(jié)點中保存的值非空;(命中)
- 查找結(jié)束于一條空鏈接。(未命中)

查找-源碼實現(xiàn):
public V get(String key) {
if (key == null)
throw new IllegalArgumentException("argument to get() is null");
Node x = get(root, key, 0);
if (x == null)
return null;
return x.val;
}
//在以x為根結(jié)點的Trie樹中,查找鍵key.charAt(d)所在的結(jié)點
private Node get(Node x, String key, int d) {
if (x == null)
return null;
if (d == key.length())
return x;
char c = key.charAt(d);
return get(x.children[c], key, d+1);
}
2.2 插入
在插入之前要進行一次查找,在Trie樹中意味著沿著被查找的鍵的所有字符到達樹中表示尾字符的結(jié)點或一個空鏈接。
結(jié)果共2種情況:
- 在到達鍵的尾字符之前就遇到了一個空鏈接;
- 在遇到空鏈接之前就到達了鍵的尾字符。
插入-源碼實現(xiàn):
public void put(String key, Value val) {
if (key == null)
throw new IllegalArgumentException("first argument to put() is null");
root = put(root, key, val, 0);
}
//在以x為根結(jié)點的Trie樹中,插入鍵key.charAt(d)所在的結(jié)點
//返回插入后新樹的根結(jié)點
private Node put(Node x, String key, Value val, int d) {
if (x == null)
x = new Node();
if (d == key.length()) {
if (x.val == null)
n++;
x.val = val;
return x;
}
char c = key.charAt(d);
x.children[c] = put(x.children[c], key, val, d + 1);
return x;
}
2.3 刪除
刪除一個鍵的流程如下:
- 查找鍵所在結(jié)點,并將值置為null;
- 判斷該結(jié)點是否含有指向子結(jié)點的非空鏈接?
如果有,則直接返回;
如果沒有,則刪除該結(jié)點。若刪除后,其父結(jié)點的所有鏈接也為空,就繼續(xù)刪除它的父結(jié)點,依此類推。
注:在遞歸刪除了某個結(jié)點x之后,如果該結(jié)點的值和所有的鏈接均為空,則返回null,否則返回x。

刪除-源碼實現(xiàn):
public void delete(String key) {
if (key == null)
throw new IllegalArgumentException("argument to delete() is null");
root = delete(root, key, 0);
}
//刪除以x為根結(jié)點的樹中的指定鍵,返回調(diào)整后的新樹的根結(jié)點
private Node delete(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) {
if (x.val != null) n--;
x.val = null;
}
else {
char c = key.charAt(d);
x.next[c] = delete(x.next[c], key, d+1);
}
// remove subtrie rooted at x if it is completely empty
if (x.val != null) return x;
for (int c = 0; c < R; c++)
if (x.next[c] != null)
return x;
return null;
}
2.4 遍歷
遍歷得到樹中的所有鍵。
注:根結(jié)點相當(dāng)于保存空字符 ""。

遍歷-源碼實現(xiàn):
public Iterable<String> keys() {
return keysWithPrefix("");
}
//查找所有以@prefix為前綴的鍵
public Iterable<String> keysWithPrefix(String prefix) {
Queue<String> results = new Queue<String>();
//查找鍵prefix所在的結(jié)點x
Node x = get(root, prefix, 0);
//在以結(jié)點x為根,查找符合前綴的鍵
collect(x, new StringBuilder(prefix), results);
return results;
}
//在以結(jié)點x為根的子樹中,查找鍵prefix
//注:prefix包含了所有從root到x的字符
private void collect(Node x, StringBuilder prefix, Queue<String> results) {
if (x == null) return;
if (x.val != null)
results.enqueue(prefix.toString());
for (char c = 0; c < R; c++) {
prefix.append(c);
collect(x.next[c], prefix, results);
prefix.deleteCharAt(prefix.length() - 1);
}
}
2.5 完整源碼
public class TrieST<Value> {
private static final int R = 256; // extended ASCII
private Node root; // root of trie
private int n; // number of keys in trie
// R-way trie node
private static class Node {
private Object val;
private Node[] next = new Node[R];
}
public TrieST() {
}
/**
* Returns the value associated with the given key.
* @param key the key
* @return the value associated with the given key if the key is in the symbol table
* and {@code null} if the key is not in the symbol table
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public Value get(String key) {
if (key == null) throw new IllegalArgumentException("argument to get() is null");
Node x = get(root, key, 0);
if (x == null) return null;
return (Value) x.val;
}
/**
* Does this symbol table contain the given key?
* @param key the key
* @return {@code true} if this symbol table contains {@code key} and
* {@code false} otherwise
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public boolean contains(String key) {
if (key == null) throw new IllegalArgumentException("argument to contains() is null");
return get(key) != null;
}
private Node get(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) return x;
char c = key.charAt(d);
return get(x.next[c], key, d+1);
}
/**
* Inserts the key-value pair into the symbol table, overwriting the old value
* with the new value if the key is already in the symbol table.
* If the value is {@code null}, this effectively deletes the key from the symbol table.
* @param key the key
* @param val the value
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public void put(String key, Value val) {
if (key == null) throw new IllegalArgumentException("first argument to put() is null");
if (val == null) delete(key);
else root = put(root, key, val, 0);
}
private Node put(Node x, String key, Value val, int d) {
if (x == null) x = new Node();
if (d == key.length()) {
if (x.val == null) n++;
x.val = val;
return x;
}
char c = key.charAt(d);
x.next[c] = put(x.next[c], key, val, d+1);
return x;
}
/**
* Returns the number of key-value pairs in this symbol table.
* @return the number of key-value pairs in this symbol table
*/
public int size() {
return n;
}
/**
* Is this symbol table empty?
* @return {@code true} if this symbol table is empty and {@code false} otherwise
*/
public boolean isEmpty() {
return size() == 0;
}
/**
* Returns all keys in the symbol table as an {@code Iterable}.
* To iterate over all of the keys in the symbol table named {@code st},
* use the foreach notation: {@code for (Key key : st.keys())}.
* @return all keys in the symbol table as an {@code Iterable}
*/
public Iterable<String> keys() {
return keysWithPrefix("");
}
/**
* Returns all of the keys in the set that start with {@code prefix}.
* @param prefix the prefix
* @return all of the keys in the set that start with {@code prefix},
* as an iterable
*/
public Iterable<String> keysWithPrefix(String prefix) {
Queue<String> results = new Queue<String>();
Node x = get(root, prefix, 0);
collect(x, new StringBuilder(prefix), results);
return results;
}
private void collect(Node x, StringBuilder prefix, Queue<String> results) {
if (x == null) return;
if (x.val != null) results.enqueue(prefix.toString());
for (char c = 0; c < R; c++) {
prefix.append(c);
collect(x.next[c], prefix, results);
prefix.deleteCharAt(prefix.length() - 1);
}
}
/**
* Returns all of the keys in the symbol table that match {@code pattern},
* where . symbol is treated as a wildcard character.
* @param pattern the pattern
* @return all of the keys in the symbol table that match {@code pattern},
* as an iterable, where . is treated as a wildcard character.
*/
public Iterable<String> keysThatMatch(String pattern) {
Queue<String> results = new Queue<String>();
collect(root, new StringBuilder(), pattern, results);
return results;
}
private void collect(Node x, StringBuilder prefix, String pattern, Queue<String> results) {
if (x == null) return;
int d = prefix.length();
if (d == pattern.length() && x.val != null)
results.enqueue(prefix.toString());
if (d == pattern.length())
return;
char c = pattern.charAt(d);
if (c == '.') {
for (char ch = 0; ch < R; ch++) {
prefix.append(ch);
collect(x.next[ch], prefix, pattern, results);
prefix.deleteCharAt(prefix.length() - 1);
}
}
else {
prefix.append(c);
collect(x.next[c], prefix, pattern, results);
prefix.deleteCharAt(prefix.length() - 1);
}
}
/**
* Returns the string in the symbol table that is the longest prefix of {@code query},
* or {@code null}, if no such string.
* @param query the query string
* @return the string in the symbol table that is the longest prefix of {@code query},
* or {@code null} if no such string
* @throws IllegalArgumentException if {@code query} is {@code null}
*/
public String longestPrefixOf(String query) {
if (query == null) throw new IllegalArgumentException("argument to longestPrefixOf() is null");
int length = longestPrefixOf(root, query, 0, -1);
if (length == -1) return null;
else return query.substring(0, length);
}
// returns the length of the longest string key in the subtrie
// rooted at x that is a prefix of the query string,
// assuming the first d character match and we have already
// found a prefix match of given length (-1 if no such match)
private int longestPrefixOf(Node x, String query, int d, int length) {
if (x == null) return length;
if (x.val != null) length = d;
if (d == query.length()) return length;
char c = query.charAt(d);
return longestPrefixOf(x.next[c], query, d+1, length);
}
/**
* Removes the key from the set if the key is present.
* @param key the key
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public void delete(String key) {
if (key == null) throw new IllegalArgumentException("argument to delete() is null");
root = delete(root, key, 0);
}
private Node delete(Node x, String key, int d) {
if (x == null) return null;
if (d == key.length()) {
if (x.val != null) n--;
x.val = null;
}
else {
char c = key.charAt(d);
x.next[c] = delete(x.next[c], key, d+1);
}
// remove subtrie rooted at x if it is completely empty
if (x.val != null) return x;
for (int c = 0; c < R; c++)
if (x.next[c] != null)
return x;
return null;
}
/**
* Unit tests the {@code TrieST} data type.
*
* @param args the command-line arguments
*/
public static void main(String[] args) {
// build symbol table from standard input
TrieST<Integer> st = new TrieST<Integer>();
for (int i = 0; !StdIn.isEmpty(); i++) {
String key = StdIn.readString();
st.put(key, i);
}
// print results
if (st.size() < 100) {
StdOut.println("keys(\"\"):");
for (String key : st.keys()) {
StdOut.println(key + " " + st.get(key));
}
StdOut.println();
}
StdOut.println("longestPrefixOf(\"shellsort\"):");
StdOut.println(st.longestPrefixOf("shellsort"));
StdOut.println();
StdOut.println("longestPrefixOf(\"quicksort\"):");
StdOut.println(st.longestPrefixOf("quicksort"));
StdOut.println();
StdOut.println("keysWithPrefix(\"shor\"):");
for (String s : st.keysWithPrefix("shor"))
StdOut.println(s);
StdOut.println();
StdOut.println("keysThatMatch(\".he.l.\"):");
for (String s : st.keysThatMatch(".he.l."))
StdOut.println(s);
}
}
三、性能分析
時間復(fù)雜度
Trie樹的形狀與鍵的插入(刪除)順序無關(guān)。
查找效率僅與樹的高度有關(guān),而樹的高度由鍵的長度決定。
當(dāng)字母表的大小為R,在一棵由N個隨機鍵構(gòu)造的單詞查找樹中,未命中查找平均所需檢查的結(jié)點數(shù)量~logRN。空間復(fù)雜度
R向單詞查找樹中,鏈接總數(shù)在RN~RNw之間
(R:字母表大小,N:鍵總數(shù),w:鍵平均長度)
故R向單詞查找樹不適合處理字母表R很大的鍵。
四、三向單詞查找樹
4.1 定義
在R向單詞查找樹中,當(dāng)字母表R很大時,會消耗大量空間??梢酝ㄟ^一種稱為“三向單詞查找樹”的數(shù)據(jù)結(jié)構(gòu)進行優(yōu)化。
在三向單詞查找樹中,每個結(jié)點都含有一個字符、三個鏈接、一個值。這三條鏈接分別對應(yīng)小于、等于和大于結(jié)點字符的所有鍵,只有在沿著中間鏈接前進時才會根據(jù)字符找到表中的鍵。
這種實現(xiàn)方式相當(dāng)于將R向單詞查找樹中的每個結(jié)點實現(xiàn)為以非空鏈接所對應(yīng)的字符作為鍵的二叉查找樹。

數(shù)據(jù)結(jié)構(gòu)定義:
public class TST<V> {
private int n; // size
private Node<V> root; // root of TST
private static class Node<V> {
private char c;
private Node<V> left, mid, right;
private V val;
}
}
4.2 實現(xiàn)
4.2.1 查找
查找步驟:
- 比較鍵的首字符與樹的根結(jié)點字符的大小。
如果鍵首字符較小,則選擇左鏈接;
如果較大,則選擇右鏈接;
如果相等,則選擇中鏈接。 - 遞歸地重復(fù)步驟1;
- 直到遇到一個空鏈接或到達鍵的末尾。
如果為空鏈接,則未命中;
如果到達鍵的末尾,且結(jié)點值為空,則未命中;
如果到達鍵的末尾,且結(jié)點值非空,則命中。

查找-源碼實現(xiàn):
public V get(String key) {
if (key == null)
throw new IllegalArgumentException("calls get() with null argument");
if (key.length() == 0)
throw new IllegalArgumentException("key must have length >= 1");
Node<V> x = get(root, key, 0);
if (x == null)
return null;
return x.val;
}
// 在以x為根結(jié)點的樹中,查找鍵key[d]
private Node<V> get(Node<V> x, String key, int d) {
if (x == null)
return null;
if (key.length() == 0)
throw new IllegalArgumentException("key must have length >= 1");
char c = key.charAt(d);
if (c < x.c)
return get(x.left, key, d);
else if (c > x.c)
return get(x.right, key, d);
else {
if (d < key.length() - 1)
return get(x.mid, key, d + 1);
else
return x;
}
}
4.2.2 插入
插入-源碼實現(xiàn):
public void put(String key, V val) {
if (key == null) {
throw new IllegalArgumentException("calls put() with null key");
}
if (!contains(key))
n++;
root = put(root, key, val, 0);
}
//在以x為根結(jié)點的樹中,插入結(jié)點key[d],返回新樹的根結(jié)點
private Node<V> put(Node<V> x, String key, V val, int d) {
char c = key.charAt(d);
if (x == null) {
x = new Node<V>();
x.c = c;
}
if (c < x.c)
x.left = put(x.left, key, val, d);
else if (c > x.c)
x.right = put(x.right, key, val, d);
else {
if (d < key.length() - 1)
x.mid = put(x.mid, key, val, d + 1);
else
x.val = val;
}
return x;
}
4.2.3 完整源碼
public class TST<Value> {
private int n; // size
private Node<Value> root; // root of TST
private static class Node<Value> {
private char c; // character
private Node<Value> left, mid, right; // left, middle, and right subtries
private Value val; // value associated with string
}
public TST() {
}
/**
* Returns the number of key-value pairs in this symbol table.
* @return the number of key-value pairs in this symbol table
*/
public int size() {
return n;
}
/**
* Does this symbol table contain the given key?
* @param key the key
* @return {@code true} if this symbol table contains {@code key} and
* {@code false} otherwise
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public boolean contains(String key) {
if (key == null) {
throw new IllegalArgumentException("argument to contains() is null");
}
return get(key) != null;
}
/**
* Returns the value associated with the given key.
* @param key the key
* @return the value associated with the given key if the key is in the symbol table
* and {@code null} if the key is not in the symbol table
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public Value get(String key) {
if (key == null) {
throw new IllegalArgumentException("calls get() with null argument");
}
if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
Node<Value> x = get(root, key, 0);
if (x == null) return null;
return x.val;
}
// return subtrie corresponding to given key
private Node<Value> get(Node<Value> x, String key, int d) {
if (x == null) return null;
if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
char c = key.charAt(d);
if (c < x.c) return get(x.left, key, d);
else if (c > x.c) return get(x.right, key, d);
else if (d < key.length() - 1) return get(x.mid, key, d+1);
else return x;
}
/**
* Inserts the key-value pair into the symbol table, overwriting the old value
* with the new value if the key is already in the symbol table.
* If the value is {@code null}, this effectively deletes the key from the symbol table.
* @param key the key
* @param val the value
* @throws IllegalArgumentException if {@code key} is {@code null}
*/
public void put(String key, Value val) {
if (key == null) {
throw new IllegalArgumentException("calls put() with null key");
}
if (!contains(key)) n++;
root = put(root, key, val, 0);
}
private Node<Value> put(Node<Value> x, String key, Value val, int d) {
char c = key.charAt(d);
if (x == null) {
x = new Node<Value>();
x.c = c;
}
if (c < x.c) x.left = put(x.left, key, val, d);
else if (c > x.c) x.right = put(x.right, key, val, d);
else if (d < key.length() - 1) x.mid = put(x.mid, key, val, d+1);
else x.val = val;
return x;
}
/**
* Returns the string in the symbol table that is the longest prefix of {@code query},
* or {@code null}, if no such string.
* @param query the query string
* @return the string in the symbol table that is the longest prefix of {@code query},
* or {@code null} if no such string
* @throws IllegalArgumentException if {@code query} is {@code null}
*/
public String longestPrefixOf(String query) {
if (query == null) {
throw new IllegalArgumentException("calls longestPrefixOf() with null argument");
}
if (query.length() == 0) return null;
int length = 0;
Node<Value> x = root;
int i = 0;
while (x != null && i < query.length()) {
char c = query.charAt(i);
if (c < x.c) x = x.left;
else if (c > x.c) x = x.right;
else {
i++;
if (x.val != null) length = i;
x = x.mid;
}
}
return query.substring(0, length);
}
/**
* Returns all keys in the symbol table as an {@code Iterable}.
* To iterate over all of the keys in the symbol table named {@code st},
* use the foreach notation: {@code for (Key key : st.keys())}.
* @return all keys in the symbol table as an {@code Iterable}
*/
public Iterable<String> keys() {
Queue<String> queue = new Queue<String>();
collect(root, new StringBuilder(), queue);
return queue;
}
/**
* Returns all of the keys in the set that start with {@code prefix}.
* @param prefix the prefix
* @return all of the keys in the set that start with {@code prefix},
* as an iterable
* @throws IllegalArgumentException if {@code prefix} is {@code null}
*/
public Iterable<String> keysWithPrefix(String prefix) {
if (prefix == null) {
throw new IllegalArgumentException("calls keysWithPrefix() with null argument");
}
Queue<String> queue = new Queue<String>();
Node<Value> x = get(root, prefix, 0);
if (x == null) return queue;
if (x.val != null) queue.enqueue(prefix);
collect(x.mid, new StringBuilder(prefix), queue);
return queue;
}
// all keys in subtrie rooted at x with given prefix
private void collect(Node<Value> x, StringBuilder prefix, Queue<String> queue) {
if (x == null) return;
collect(x.left, prefix, queue);
if (x.val != null) queue.enqueue(prefix.toString() + x.c);
collect(x.mid, prefix.append(x.c), queue);
prefix.deleteCharAt(prefix.length() - 1);
collect(x.right, prefix, queue);
}
/**
* Returns all of the keys in the symbol table that match {@code pattern},
* where . symbol is treated as a wildcard character.
* @param pattern the pattern
* @return all of the keys in the symbol table that match {@code pattern},
* as an iterable, where . is treated as a wildcard character.
*/
public Iterable<String> keysThatMatch(String pattern) {
Queue<String> queue = new Queue<String>();
collect(root, new StringBuilder(), 0, pattern, queue);
return queue;
}
private void collect(Node<Value> x, StringBuilder prefix, int i, String pattern, Queue<String> queue) {
if (x == null) return;
char c = pattern.charAt(i);
if (c == '.' || c < x.c) collect(x.left, prefix, i, pattern, queue);
if (c == '.' || c == x.c) {
if (i == pattern.length() - 1 && x.val != null) queue.enqueue(prefix.toString() + x.c);
if (i < pattern.length() - 1) {
collect(x.mid, prefix.append(x.c), i+1, pattern, queue);
prefix.deleteCharAt(prefix.length() - 1);
}
}
if (c == '.' || c > x.c) collect(x.right, prefix, i, pattern, queue);
}
/**
* Unit tests the {@code TST} data type.
*
* @param args the command-line arguments
*/
public static void main(String[] args) {
// build symbol table from standard input
TST<Integer> st = new TST<Integer>();
for (int i = 0; !StdIn.isEmpty(); i++) {
String key = StdIn.readString();
st.put(key, i);
}
// print results
if (st.size() < 100) {
StdOut.println("keys(\"\"):");
for (String key : st.keys()) {
StdOut.println(key + " " + st.get(key));
}
StdOut.println();
}
StdOut.println("longestPrefixOf(\"shellsort\"):");
StdOut.println(st.longestPrefixOf("shellsort"));
StdOut.println();
StdOut.println("longestPrefixOf(\"shell\"):");
StdOut.println(st.longestPrefixOf("shell"));
StdOut.println();
StdOut.println("keysWithPrefix(\"shor\"):");
for (String s : st.keysWithPrefix("shor"))
StdOut.println(s);
StdOut.println();
StdOut.println("keysThatMatch(\".he.l.\"):");
for (String s : st.keysThatMatch(".he.l."))
StdOut.println(s);
}
}
4.3 性能分析
三向單詞查找樹是R向單詞查找樹的緊湊表示,其每個結(jié)點只含有3條鏈接。
- 時間復(fù)雜度
查找未命中平均需要比較~InN次。 - 空間復(fù)雜度
三向單詞查找樹中,鏈接總數(shù)在3N~3Nw之間
(N:鍵總數(shù),w:鍵平均長度)
五、各類字符串查找算法比較
在空間足夠的情況下,R向單詞查找樹的速度是最快的,能夠在常數(shù)次字符比較內(nèi)完成查找。
但是對于大型字母表,R向單詞查找樹的空間通常無法滿足需求,此時三向單詞查找樹是較好的選擇,它對字符的比較次數(shù)是對數(shù)級別的。
