前言
互聯(lián)網(wǎng)發(fā)展至今,搜索引擎仍然是獲取信息最重要的途徑之一,而搜索結(jié)果的排序是搜索引擎的核心技術(shù)之一,常見的排序算法有 PageRank、向量空間模型( 如:TF-IDF)、概率模型(如:BM25)、機(jī)器學(xué)習(xí)排序等,今天準(zhǔn)備通過實(shí)例介紹一下人工神經(jīng)網(wǎng)絡(luò)在搜索結(jié)果排序中的應(yīng)用。
人工神經(jīng)網(wǎng)絡(luò)
人工神經(jīng)網(wǎng)絡(luò)是一種模仿生物神經(jīng)網(wǎng)絡(luò)(動物的中樞神經(jīng)系統(tǒng),特別是大腦)的結(jié)構(gòu)和功能的數(shù)學(xué)模型或計(jì)算模型,用于對函數(shù)進(jìn)行估計(jì)或近似。深入的概念以及定義讀者可以自行谷歌,在本文中用到的神經(jīng)網(wǎng)絡(luò)成為多層感知器網(wǎng)絡(luò),它由多層的神經(jīng)元組成,每一層的神經(jīng)元輸入是上一層的輸出,本層的輸出是下一層的輸入,依次相連,理論上可以有 N 層,通常采用三層網(wǎng)絡(luò),即:輸入層、隱藏層、輸出層。在搜索排序中,輸入層為查詢的詞,輸出層為結(jié)果文檔,如圖:

如圖 1-1 所示,輸入層傳入查詢詞,經(jīng)過隱藏層轉(zhuǎn)化后輸出,上圖中每個線條有不同的強(qiáng)度,因此輸出層的每個文檔也有不同的強(qiáng)度,即相關(guān)度。如圖:

圖 1-2 是一個訓(xùn)練后的神經(jīng)網(wǎng)絡(luò)的實(shí)例,每個輸入節(jié)點(diǎn)到隱藏節(jié)點(diǎn)、隱藏節(jié)點(diǎn)到輸出節(jié)點(diǎn)之間都有一個不同強(qiáng)弱的權(quán)重,對于任意的輸入,根據(jù)自身的權(quán)重游走到隱藏層,隱藏層累加所有的輸入后經(jīng)過特定函數(shù)轉(zhuǎn)化,游走到輸出層,輸出層累加所有的輸入即為結(jié)果。
訓(xùn)練神經(jīng)網(wǎng)絡(luò)的核心工作就是創(chuàng)建一個隱藏層,并訓(xùn)練出輸入節(jié)點(diǎn)到隱藏節(jié)點(diǎn)、隱藏節(jié)點(diǎn)到輸出節(jié)點(diǎn)之間的權(quán)重關(guān)系。有了該網(wǎng)絡(luò)后,對于任意的輸入,模型都能給出一個合理的預(yù)測結(jié)果。
數(shù)據(jù)準(zhǔn)備
數(shù)據(jù)結(jié)構(gòu)
為了訓(xùn)練神經(jīng)網(wǎng)絡(luò),我們需要以下幾方面的數(shù)據(jù)
- 用于訓(xùn)練的數(shù)據(jù),包含查詢詞—查詢結(jié)果—期望值,如:
//格式:查詢詞,查詢結(jié)果,結(jié)果期望
兒童 感冒,[doc:兒童感冒,doc:玩具,doc:感冒藥],[1.0,0.0,1.0]
玩具,[doc:兒童感冒,doc:玩具,doc:感冒藥],[0.0,1.0,0.0]
......
一般會把數(shù)據(jù)映射成 id,假設(shè)映射關(guān)系如下:
查詢詞 id 映射:
1 <-> 兒童
2 <-> 感冒
3 <-> 玩具
查詢結(jié)果 id 映射:
1 <-> doc:兒童感冒
2 <-> doc:玩具
3 <-> doc:感冒藥
則訓(xùn)練數(shù)據(jù)最終為:
//格式:查詢詞,查詢結(jié)果,結(jié)果期望
1 2,[1,2,3],[1.0,0.0,1.0]
2,[1,2,3],[0.0,1.0,0.0]
......
- 隱藏層節(jié)點(diǎn),用 mysql 存儲,表結(jié)構(gòu)為:
create table `ann`.`hidden_node` (
`id` int(11) NOT NULL auto_increment primary key,
`node_name` varchar(255) NOT NULL
)ENGINE=InnoDB DEFAULT CHARSET=utf8;
- 輸入節(jié)點(diǎn)與隱藏節(jié)點(diǎn)的關(guān)系表,用 mysql 存儲,表結(jié)構(gòu)為:
create table `ann`.`word2hidden` (
`fromid` int(11) NOT NULL, //輸入節(jié)點(diǎn) id,這里為詞 id
`toid` int(11) NOT NULL, //目標(biāo)節(jié)點(diǎn) id,這里為隱藏節(jié)點(diǎn)的 id
`strength` double NOT NULL //信號強(qiáng)度
)ENGINE=InnoDB DEFAULT CHARSET=utf8;
- 隱藏節(jié)點(diǎn)與輸出節(jié)點(diǎn)的關(guān)系表,用 mysql 存儲,表結(jié)構(gòu)為:
create table `ann`.`hidden2doc` (
`fromid` int(11) NOT NULL, //輸入節(jié)點(diǎn) id,這里為隱藏節(jié)點(diǎn)的 id
`toid` int(11) NOT NULL, //目標(biāo)節(jié)點(diǎn) id,這里為結(jié)果文檔 id
`strength` double NOT NULL
)ENGINE=InnoDB DEFAULT CHARSET=utf8;
數(shù)據(jù)操作 API
建立好數(shù)據(jù)結(jié)構(gòu)之后。還需要寫對數(shù)據(jù)操作的 API,主要是兩個操作:
- 獲取兩個節(jié)點(diǎn)間的信號強(qiáng)度,稱為 getStrength 方法,主要負(fù)責(zé)從數(shù)據(jù)庫中查詢出信號強(qiáng)度,如果這兩個節(jié)點(diǎn)還未建立任何的關(guān)系,則賦予一個默認(rèn)值,如果是第 0 層(layer=0,即輸入節(jié)點(diǎn)與隱藏節(jié)點(diǎn)之間的聯(lián)系),則賦默認(rèn)值 -0.2;如果是第 1 層(layer=0,即隱藏節(jié)點(diǎn)與輸出節(jié)點(diǎn)之間的聯(lián)系)則賦默認(rèn)值 0.0,java 實(shí)現(xiàn)代碼如下:
private double getStrength(int fromid, int toid, int layer) throws SQLException {
String table = layer == 0 ? "word2hidden" : "hidden2doc";
double default_strength = layer == 0 ? -0.2 : 0.0;
String sql = "select strength from " + table + " where fromid=? and toid=?";
PreparedStatement pstat = conn.prepareStatement(sql);
pstat.setInt(1, fromid);
pstat.setInt(2, toid);
ResultSet rs = pstat.executeQuery();
if (rs.next()) {
return rs.getDouble(1);
} else {
return default_strength;
}
}
- 設(shè)置兩個節(jié)點(diǎn)間的信號強(qiáng)度,稱為 setStrength 方法,java 實(shí)現(xiàn)代碼如下:
private void setStrength(int fromid, int toid, int layer, double strength) throws SQLException {
String table = layer == 0 ? "word2hidden" : "hidden2doc";
String sql = "select strength from " + table +
" where fromid=" + fromid +
" and toid=" + toid;
ResultSet rs = conn.createStatement().executeQuery(sql);
if (rs.next()) {
conn.createStatement().execute("update " + table +
" set strength=" + strength +
" where fromid=" + fromid +
" and toid=" + toid);
} else {
conn.createStatement().execute("insert into " + table +
" (fromid, toid, strength)" +
" values(" + fromid + "," + toid + "," + strength + ")");
}
}
模型訓(xùn)練
隱藏層建立
一般情況下,在構(gòu)建神經(jīng)網(wǎng)絡(luò)時,我們會預(yù)先建立好所有的節(jié)點(diǎn)。不過在本例中,我們只在需要的時候建立新的隱藏節(jié)點(diǎn),更為簡單高效。
每當(dāng)我們傳入以前從未見過的查詢時,我們就建立一個隱藏節(jié)點(diǎn),隨后,給每個查詢中從詞與該隱藏節(jié)點(diǎn)之間、隱藏節(jié)點(diǎn)與輸出節(jié)點(diǎn)之間賦予默認(rèn)值。
我們用一個名為 generateHiddenNode 的方法來實(shí)現(xiàn)此功能,java 代碼如下:
private void generateHiddenNode(List<Integer> queryWordIds,
List<Integer> docIds) throws SQLException {
if (queryWordIds.size() > 3) { //暫不支持過長的查詢詞組
return;
}
//判斷該查詢是否已經(jīng)存在隱藏節(jié)點(diǎn)
String key = queryWordIds.stream()
.map(id -> String.valueOf(id))
.sorted()
.collect(Collectors.joining("_"));
ResultSet rs = conn.createStatement()
.executeQuery("select id from hidden_node where node_name='" + key + "'");
if (!rs.next()) {
//新建隱藏節(jié)點(diǎn)
PreparedStatement ps = conn.prepareStatement("insert into hidden_node values (NULL,?)",
Statement.RETURN_GENERATED_KEYS);
ps.setString(1, key);
ps.executeUpdate();
rs = ps.getGeneratedKeys();
rs.next();
int id = rs.getInt(1);
for (int queryWordId : queryWordIds) {
//查詢詞組與該隱藏節(jié)點(diǎn)的默認(rèn)權(quán)重為 1.0 / 詞個數(shù)
setStrength(queryWordId, id, 0, 1.0 / queryWordIds.size());
}
for (int docId : docIds) {
//隱藏節(jié)點(diǎn)與查詢結(jié)果的默認(rèn)權(quán)重為 0.1
setStrength(id, docId, 1, 0.1);
}
}
}
前饋法預(yù)測結(jié)果
建立隱藏層后,其實(shí)我們已經(jīng)擁有了一個最基本的神經(jīng)網(wǎng)絡(luò)(雖然此時所有值都是默認(rèn)賦值的,不「智能」,這個問題后面會解決),因此,對于任意給定的查詢詞與查詢結(jié)果,我們能為每個結(jié)果給出一個預(yù)測結(jié)果,步驟如下:
- 獲取與本次查詢相關(guān)的所有隱藏節(jié)點(diǎn),這是一項(xiàng)性能優(yōu)化措施,因?yàn)樵诒纠?,只有跟查詢詞或者查詢結(jié)果相關(guān)的隱藏節(jié)點(diǎn)才會對最終的預(yù)測結(jié)果又影響,因此我們只需要找出相關(guān)的因此節(jié)點(diǎn)進(jìn)行后續(xù)的計(jì)算皆可以了,具體實(shí)現(xiàn)為 getAllRelatedHiddens 方法:
private List<Integer> getAllRelatedHiddens(List<Integer> queryWordIds,
List<Integer> docIds) throws SQLException {
Set<Integer> hidden_ids = new HashSet<>();
//與輸入節(jié)點(diǎn)相關(guān)的隱藏節(jié)點(diǎn)
String queryWordIds_str = queryWordIds.stream()
.map(id -> String.valueOf(id))
.collect(Collectors.joining(","));
ResultSet rs = conn.createStatement()
.executeQuery("select toid from word2hidden where fromid in (" + queryWordIds_str + ")");
while (rs.next()) {
hidden_ids.add(rs.getInt(1));
}
//與輸出節(jié)點(diǎn)相關(guān)的隱藏節(jié)點(diǎn)
String resultIds_str = docIds.stream()
.map(id -> String.valueOf(id))
.collect(Collectors.joining(","));
rs = conn.createStatement()
.executeQuery("select toid from hidden2doc where fromid in (" + resultIds_str + ")");
while (rs.next()) {
hidden_ids.add(rs.getInt(1));
}
return new ArrayList<>(hidden_ids);
}
- 構(gòu)建本次查詢的神經(jīng)網(wǎng)絡(luò)的權(quán)重矩陣
現(xiàn)在我們已經(jīng)有了輸入節(jié)點(diǎn)(查詢詞組)、相關(guān)的隱藏節(jié)點(diǎn)、輸出節(jié)點(diǎn)(查詢結(jié)果)、輸入節(jié)點(diǎn)與隱藏節(jié)點(diǎn)的信號強(qiáng)度,隱藏節(jié)點(diǎn)與輸出節(jié)點(diǎn)的信號強(qiáng)度,我們可以以此構(gòu)建兩個權(quán)重矩陣,分別為:
- 輸入權(quán)重矩陣:記錄第 i 個查詢詞與第 j 個隱藏節(jié)點(diǎn)之間的信號強(qiáng)度
- 輸出權(quán)重矩陣:記錄第 j 個隱藏節(jié)點(diǎn)與第 k 個查詢結(jié)果之間的信號強(qiáng)度
代碼實(shí)現(xiàn)如下:
private List<Integer> queryWordIds;//查詢詞組
private List<Integer> hidden_ids;//相關(guān)隱藏節(jié)點(diǎn)
private List<Integer> docIds;//查詢結(jié)果
private double[][] input_weight;//輸入層的權(quán)重矩陣
private double[][] output_weight;//輸出層的權(quán)重矩陣
private double[] word_val;//每個詞的輸出信號
private double[] hidden_val;//每個隱藏層的輸出信號
private double[] doc_val;//每個輸出文檔的輸出信號
private void setupNetWork(List<Integer> queryWordIds, List<Integer> docIds) throws SQLException {
//初始化參數(shù),供后續(xù)使用
this.queryWordIds = queryWordIds;
this.docIds = docIds;
hidden_ids = getAllRelatedHiddens(queryWordIds, docIds);
word_val = new double[queryWordIds.size()];
hidden_val = new double[hidden_ids.size()];
doc_val = new double[docIds.size()];
//構(gòu)造 word -> hidden 權(quán)重矩陣
input_weight = new double[queryWordIds.size()][hidden_ids.size()];
for (int i = 0; i < queryWordIds.size(); i++) {
for(int j=0;j<hidden_ids.size();j++){
input_weight[i][j] = getStrength(queryWordIds.get(i), hidden_ids.get(j), 0);
}
}
//構(gòu)造 hidden -> docid 權(quán)重矩陣
output_weight = new double[hidden_ids.size()][docIds.size()];
for(int j=0;j<hidden_ids.size();j++){
for(int k=0;k<docIds.size();k++){
output_weight[j][k] = getStrength(hidden_ids.get(j), docIds.get(k), 1);
}
}
}
- 構(gòu)造前饋算法
有了權(quán)重矩陣后,我們就可以一層一層地依次往下計(jì)算,直到得到最終結(jié)果。在這里為: - 查詢單詞依據(jù)輸入權(quán)重矩陣,向隱藏節(jié)點(diǎn)輸出信號,隱藏節(jié)點(diǎn)匯總所有輸入信號強(qiáng)度,并通過S 型函數(shù)反饋輸入,輸出反饋后自身節(jié)點(diǎn)的信號
- 依據(jù)輸出權(quán)重矩陣,隱藏節(jié)點(diǎn)向輸出層輸出信號,輸出節(jié)點(diǎn)匯總所有輸入信號強(qiáng)度,并通過S 型函數(shù)反饋輸入,形成結(jié)果預(yù)測
整個前饋算法進(jìn)行結(jié)果預(yù)測的過程如圖:

其中:默認(rèn)查詢詞的信號強(qiáng)度是 1.0
S 型函數(shù)是神經(jīng)元負(fù)責(zé)對輸入進(jìn)行反饋的函數(shù),在這里我們使用反雙面正切變換函數(shù)(tanh),其函數(shù)圖像為:

前饋算法代碼如下:
private void feedForward(){
for(int i=0;i<word_val.length;i++){
word_val[i] = 1.0;
}
for(int j=0;j<hidden_val.length;j++){
double sum = 0.0;
for(int i=0;i<word_val.length;i++){
sum += word_val[i] * input_weight[i][j];
}
hidden_val[j] = Math.tanh(sum);
}
for(int k=0;k<doc_val.length;k++){
double sum = 0.0;
for(int j=0;j<hidden_val.length;j++){
sum += hidden_val[j] * output_weight[j][k];
}
doc_val[k] = Math.tanh(sum);
}
}
反向傳播法調(diào)整權(quán)重矩陣
在跑完上面的前饋法預(yù)測出結(jié)果后,我們會發(fā)現(xiàn)所有的結(jié)果值都相同,因?yàn)橐陨辖⒌乃袡?quán)重都是一樣的默認(rèn)值,因此預(yù)測的結(jié)果也是毫無價值的。接下去的工作就是調(diào)整各個節(jié)點(diǎn)之間的權(quán)重,使之?dāng)M合真實(shí)情況。
調(diào)整權(quán)重的核心工作是計(jì)算預(yù)測值與真實(shí)值之間的誤差,用這個誤差反向調(diào)整各個節(jié)點(diǎn)連接之間的權(quán)重。由于之前我們了 S 型函數(shù)進(jìn)行信號反饋,其特點(diǎn)是結(jié)果在 0 附近變化特別快,而結(jié)果在 1 或 -1 附近時變化又特別緩慢,因此我們設(shè)計(jì)了一個 dtanh 函數(shù)來進(jìn)行誤差反向調(diào)整時加權(quán),用以適應(yīng) tanh 函數(shù)的特性,dtanh 函數(shù)如下:
public static double dtanh(double y){
return 1.0 - y * y;
}
下面我們介紹反向傳播法調(diào)整權(quán)重矩陣的具體步驟:
- 對于輸出層的每個節(jié)點(diǎn):
- 計(jì)算當(dāng)前輸出結(jié)果與期望結(jié)果之間的差距
- 利用 dtanh 函數(shù)確定節(jié)點(diǎn)的總輸入需要如何改變
- 根據(jù)總輸入需要改變的量,調(diào)整所有隱藏節(jié)點(diǎn)到輸出節(jié)點(diǎn)間的權(quán)重
- 對于隱藏層中的每個節(jié)點(diǎn):
- 計(jì)算由該隱藏層導(dǎo)致的輸出層每個節(jié)點(diǎn)的誤差,求和,即隱藏層需要改變的輸出結(jié)果
- 利用 dtanh 函數(shù)確定節(jié)點(diǎn)的總輸出需要如何改變
- 根據(jù)每個隱藏節(jié)點(diǎn)總輸出需要改變的量,調(diào)整所有輸入節(jié)點(diǎn)到隱藏節(jié)點(diǎn)間的權(quán)重
畫成流程圖如下:

如上圖所示,紅色虛線代表計(jì)算本節(jié)點(diǎn)需要作出改變的值,紅色的實(shí)線代表一次權(quán)重調(diào)整后的值,每經(jīng)過一次調(diào)整,整個網(wǎng)絡(luò)就向真實(shí)情況逼近一點(diǎn)。
代碼實(shí)現(xiàn)如下:
private void backPropagate(double[] targets){
//輸出層誤差
double[] output_deltas = new double[docIds.size()];
for(int k=0;k<docIds.size();k++){
double error = targets[k] - doc_val[k];
output_deltas[k] = dtanh(doc_val[k]) * error;
}
//隱藏層誤差
double[] hidden_deltas = new double[hidden_ids.size()];
for(int j=0;j<hidden_ids.size();j++){
double error = 0.0;
for(int k=0;k<docIds.size();k++){
error += output_deltas[k] * output_weight[j][k];
}
hidden_deltas[j] = dtanh(hidden_val[j]) * error;
}
//更新輸出權(quán)重矩陣
for(int j=0;j<hidden_ids.size();j++){
for(int k=0;k<docIds.size();k++){
double change = output_deltas[k] * hidden_val[j];
output_weight[j][k] = output_weight[j][k] + change;
}
}
//跟新輸入權(quán)重矩陣
for(int i=0;i<queryWordIds.size();i++){
for(int j=0;j<hidden_ids.size();j++){
double change = hidden_deltas[j] * word_val[i];
input_weight[i][j] = input_weight[i][j] + change;
}
}
}
在調(diào)整完權(quán)重矩陣后,我們還需要將節(jié)點(diǎn)間關(guān)系數(shù)據(jù)保存到數(shù)據(jù)庫,代碼如下:
private void updateDatabase() throws SQLException {
for(int i=0;i<queryWordIds.size();i++) {
for (int j = 0; j < hidden_ids.size(); j++) {
setStrength(queryWordIds.get(i), hidden_ids.get(j), 0, input_weight[i][j]);
}
}
for(int j=0;j<hidden_ids.size();j++) {
for (int k = 0; k < docIds.size(); k++) {
setStrength(hidden_ids.get(j), docIds.get(k), 1, output_weight[j][k]);
}
}
}
訓(xùn)練流程匯總
首先我們回顧下訓(xùn)練神經(jīng)網(wǎng)絡(luò)的所有步驟,依次為:
- 建立隱藏層
- 前饋法預(yù)測結(jié)果
- 反向傳播法調(diào)整權(quán)重矩陣
下面我們就可以寫一個方法,進(jìn)行一次完整的訓(xùn)練過程,代碼如下:
public void train(List<Integer> queryWordIds, List<Integer> docIds, double[] targets) throws SQLException {
//增加隱藏節(jié)點(diǎn)
generateHiddenNode(queryWordIds, docIds);
//構(gòu)建神經(jīng)網(wǎng)絡(luò)的權(quán)重矩陣
setupNetWork(queryWordIds, docIds);
//前饋法預(yù)測
feedForward();
//反向傳播法調(diào)整權(quán)重矩陣
backPropagate(targets);
//將新的連接關(guān)系存入數(shù)據(jù)庫
updateDatabase();
}
為了方便觀察結(jié)果,我們添加一個 getResult 方法獲取當(dāng)前神經(jīng)網(wǎng)絡(luò)對輸入的預(yù)測情況,代碼如下:
public double[] getResult(List<Integer> queryWordIds, List<Integer> docIds) throws SQLException {
setupNetWork(queryWordIds, docIds);
feedForward();
return doc_val;
}
最后,就可以利用歷史數(shù)據(jù)進(jìn)行神經(jīng)網(wǎng)絡(luò)的訓(xùn)練了,如:
public static void main(String[] args) throws Exception {
String[] docs = new String[]{"兒童感冒", "玩具", "感冒藥"};
//doc:兒童感冒, doc:玩具, doc:感冒藥
List<Integer> docIds = Arrays.asList(1, 2, 3);
//"兒童 感冒"
List<Integer> queryWordIds_1 = Arrays.asList(1, 2);
double[] targets_1 = new double[]{1.0, 0.0, 1.0};
//"玩具"
List<Integer> queryWordIds_2 = Arrays.asList(3);
double[] targets_2 = new double[]{0.0, 1.0, 0.0};
//"感冒"
List<Integer> queryWordIds_3 = Arrays.asList(2);
double[] targets_3 = new double[]{0.0, 0.0, 1.0};
AnnRank annRank = new AnnRank();
for(int i=0;i<10;i++){
annRank.train(queryWordIds_1, docIds, targets_1);
annRank.train(queryWordIds_2, docIds, targets_2);
annRank.train(queryWordIds_3, docIds, targets_3);
}
//"兒童 玩具"
double[] rs = annRank.getResult(Arrays.asList(3), docIds);
System.out.println("========== query: 兒童 玩具 =========");
for(int i=0;i<docs.length;i++){
System.out.println(docs[i] + ":" + rs[i]);
}
//"兒童"
rs = annRank.getResult(Arrays.asList(1), docIds);
System.out.println("========== query: 兒童 =========");
for(int i=0;i<docs.length;i++){
System.out.println(docs[i] + ":" + rs[i]);
}
}
輸出:
========== query: 兒童 玩具 =========
兒童感冒:-0.04868859068035461
玩具:0.9509734260901455
感冒藥:-0.20469923136438772
========== query: 兒童 =========
兒童感冒:0.933995702651645
玩具:-0.37891769815951537
感冒藥:0.7154603176065658
總結(jié)
本文簡單介紹了人工神經(jīng)網(wǎng)絡(luò)在搜索排序中的應(yīng)用,在實(shí)際與搜索引擎的結(jié)合中,還有一些工程上的問題需要解決(如頻繁數(shù)據(jù)庫更新的效率問題),有興趣的讀者可以通過實(shí)踐進(jìn)行深入的了解。