Spark的MLlib實(shí)現(xiàn)了協(xié)同過濾(Collaborative Filtering)這個(gè)功能。官網(wǎng)文檔鏈接
熟悉推薦算法的同學(xué)可能也有這個(gè)認(rèn)識(shí):協(xié)同過濾主要分為3大類——1、基于User的協(xié)同過濾;2、基于Item的協(xié)同過濾;3、基于Model的協(xié)同過濾。前面兩個(gè)比較簡(jiǎn)單不多描述了,主要講下基于Model的協(xié)同過濾。在網(wǎng)上找到一個(gè)對(duì)基于Model的協(xié)同過濾的算法總結(jié)包括:Aspect Model,pLSA,LDA,聚類,SVD,Matrix Factorization等。不管這句話說的是否嚴(yán)謹(jǐn)(比如還有二分圖模型),總之我認(rèn)為Spark MLlib目前(2.2.0版本)并不能算是完整的協(xié)同過濾。只是做了基于Model的協(xié)同過濾中的矩陣分解內(nèi)容。當(dāng)然做好了矩陣分解,接下來再做別的也就輕松了。
關(guān)于基于Model的矩陣分解,可以參考矩陣分解在協(xié)同過濾推薦算法中的應(yīng)用。Spark的MLlib中使用的是ALS(Alternating Least Squares (ALS) matrix factorization)算法。這個(gè)可以看成是對(duì)FunkSVD的一種求解實(shí)現(xiàn)。不過考慮到有時(shí)候我們輸入的User-Item的rating可能不是某種評(píng)判的數(shù)值打分,而是User對(duì)于Item的某種偏好,此時(shí)使用ALS-WR(alternating-least-squares with weighted-λ-regularization)通過置信度權(quán)重來重新定義目標(biāo)函數(shù),從而得到新的結(jié)果。關(guān)于ALS和ALS-WR可以參考協(xié)同過濾之ALS-WR算法和機(jī)器學(xué)習(xí)(十四)——協(xié)同過濾的ALS算法(2)、主成分分析以及協(xié)同過濾 CF & ALS 及在Spark上的實(shí)現(xiàn)
上面主要是理論基礎(chǔ)部分,熟悉了理論基礎(chǔ)后,我們看下通過Spark的MLlib的落地實(shí)現(xiàn),我們需要做哪些工作。同時(shí)依然建議參考另2篇文章ALS-WR(協(xié)同過濾推薦算法) in ML和深入理解Spark ML:基于ALS矩陣分解的協(xié)同過濾算法與源碼分析
Collaborative filtering
正如前面所講的,我們的工作是要把評(píng)分矩陣用User和Item的latent factors表達(dá)出來。MLlib通過ALS算法來學(xué)習(xí)得到User以及Item的latent factors,在具體的實(shí)現(xiàn)中需要以下參數(shù):
- numBlocks is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10). 用于并行計(jì)算,同時(shí)設(shè)置User和Item的block數(shù)目,還可以使用numUserBlocks和numItemBlocks分別設(shè)置User和Item的block數(shù)目。
- rank is the number of latent factors in the model (defaults to 10). 表示latent factors的長度。對(duì)于這個(gè)值的設(shè)置參見What is recommended number of latent factors for the implicit collaborative filtering using ALS
- maxIter is the maximum number of iterations to run (defaults to 10). 交替計(jì)算User和Item的latent factors的迭代次數(shù)。
- regParam specifies the regularization parameter in ALS (defaults to 1.0). L2正則的系數(shù)lambda
- implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false which means using explicit feedback). 表示原始User和Item的rating矩陣的值是否是評(píng)判的打分值,F(xiàn)alse表示是打分值,True表示是矩陣的值是某種偏好。
- alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference observations (defaults to 1.0). 當(dāng)implicitPrefs為true時(shí),表示對(duì)原始rating的一個(gè)置信度系數(shù),用于和rate相乘,是一個(gè)常值??梢愿鶕?jù)對(duì)于原始數(shù)據(jù)的觀察,統(tǒng)計(jì)先設(shè)置一個(gè)值,然后再進(jìn)行后續(xù)的tuning。
- nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false). 對(duì)應(yīng)于選擇求解最小二乘的方法:if (nonnegative) new NNLSSolver else new CholeskySolver。如果True就是用非負(fù)正則化最小二乘(NNLS),F(xiàn)alse就是用喬里斯基分解(Cholesky)
Note: 基于DataFrame的MLlib API目前只支持integer類型的user和Item的id。其他numeric類型的user和item id列也支持,不過ids必須在integer的取值范圍內(nèi)。這里的numeric類型指的是java.lang.Number,看了下源碼感覺負(fù)值也應(yīng)該是可以的。

除了上面文檔中的參數(shù),還有一些別的參數(shù)設(shè)置也有必要列出來(下面的Dataset<Row>即為DataFrame):
- userCol:用戶列的名字,String類型。對(duì)應(yīng)于后續(xù)調(diào)用fit()操作時(shí)輸入的Dataset<Row>入?yún)r(shí)用戶id所在schema中的name
- itemCol:item列的名字,String類型。對(duì)應(yīng)于后續(xù)調(diào)用fit()操作時(shí)輸入的Dataset<Row>入?yún)r(shí)item id所在schema中的name
- ratingCol:rating列的名字,String類型。對(duì)應(yīng)于后續(xù)調(diào)用fit()操作時(shí)輸入的Dataset<Row>入?yún)r(shí)rating值所在schema中的name
- predictionCol:String類型。做transform()操作時(shí)輸出的預(yù)測(cè)值在Dataset<Row>結(jié)果的schema中的name,默認(rèn)是“prediction”
- coldStartStrategy:String類型。有兩個(gè)取值"nan" or "drop"。這個(gè)參數(shù)指示用在prediction階段時(shí)遇到未知或者新加入的user或item時(shí)的處理策略。尤其是在交叉驗(yàn)證或者生產(chǎn)場(chǎng)景中,遇到?jīng)]有在訓(xùn)練集中出現(xiàn)的user/item id時(shí)。"nan"表示對(duì)于未知id的prediction結(jié)果為NaN。"drop"表示對(duì)于transform()的入?yún)ataFrame中出現(xiàn)未知ids的行,將會(huì)在包含prediction的返回DataFrame中被drop。默認(rèn)值是"nan"
Explicit和implicit feedback
標(biāo)準(zhǔn)的協(xié)同過濾中的矩陣分解(matrix factorization)都是對(duì)user-item的打分矩陣做因子分解,比如用戶對(duì)電影的打分,也稱為顯式反饋(explicit feedback)。
不過在現(xiàn)實(shí)情況中,很多user-item都不是某種特定意義的評(píng)分,而是一些比如用戶的購買記錄、搜索關(guān)鍵字,甚至是鼠標(biāo)的移動(dòng)。我們將這些間接用戶行為稱之為隱式反饋(implicit feedback)。
在Spark中處理隱式反饋的算法是ALS-WR。可以重點(diǎn)看下前面給出的參考鏈接中的算法結(jié)果,觀察損失函數(shù),就可以知道大致過程。
正則化系數(shù)
這里指的是在ALS算法中L2正則項(xiàng)的系數(shù),用來防止過擬合,也能使矩陣的因子分解后的U和V矩陣的值不會(huì)太震蕩,方便接下來對(duì)U和V矩陣再做進(jìn)一步的利用。
而且Spark通過ALS-WR算法使得 regParam 較少的被數(shù)據(jù)集的規(guī)模所影響。這樣可以使得在樣本子集中學(xué)習(xí)得到的最佳參數(shù)可以應(yīng)用在數(shù)據(jù)全集上而且獲得相似的性能。
冷啟動(dòng)策略
我們使用訓(xùn)練后的 ALSModel 對(duì)test數(shù)據(jù)進(jìn)行預(yù)測(cè),不過可能會(huì)遇到?jīng)]有出現(xiàn)在訓(xùn)練模型中的user或者item id,這是由以下兩種情況產(chǎn)生引起的:
- 在生成中:本來就會(huì)有新的user或者item上線,是之前訓(xùn)練時(shí)不曾有的(這也稱之為“cold start problem”)
- 在交叉驗(yàn)證階段:不管是用Spark的 CrossValidator 或者 TrainValidationSplit 都有可能出現(xiàn)驗(yàn)證集中的id是訓(xùn)練集中沒有出現(xiàn)過的。
默認(rèn)Spark使用NaN來表示對(duì)于未知id的rate的預(yù)測(cè)結(jié)果,這樣在生產(chǎn)中可以提示系統(tǒng)有新的id加入,作為接下來是否采取措施的依據(jù)。
不過在交叉驗(yàn)證階段,NaN會(huì)妨礙接下來的評(píng)分度量 (比如使用 RegressionEvaluator ),此時(shí)可以選擇"drop"來使得出現(xiàn)NaN的行都丟掉。方便調(diào)參時(shí)做模型選擇。
舉個(gè)栗子
下面這個(gè)栗子也是官網(wǎng)文檔中的栗子。首先看下數(shù)據(jù)的模樣:

然后是代碼:
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example on$
import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
// $example off$
public class JavaALSExample {
// $example on$
public static class Rating implements Serializable {
private int userId;
private int movieId;
private float rating;
private long timestamp;
public Rating() {}
public Rating(int userId, int movieId, float rating, long timestamp) {
this.userId = userId;
this.movieId = movieId;
this.rating = rating;
this.timestamp = timestamp;
}
public int getUserId() {
return userId;
}
public int getMovieId() {
return movieId;
}
public float getRating() {
return rating;
}
public long getTimestamp() {
return timestamp;
}
public static Rating parseRating(String str) {
String[] fields = str.split("::");
if (fields.length != 4) {
throw new IllegalArgumentException("Each line must contain 4 fields");
}
int userId = Integer.parseInt(fields[0]);
int movieId = Integer.parseInt(fields[1]);
float rating = Float.parseFloat(fields[2]);
long timestamp = Long.parseLong(fields[3]);
return new Rating(userId, movieId, rating, timestamp);
}
}
// $example off$
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaALSExample")
.getOrCreate();
// $example on$
JavaRDD<Rating> ratingsRDD = spark
.read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
.map(Rating::parseRating);
Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];
// Build the recommendation model using ALS on the training data
ALS als = new ALS()
.setMaxIter(5)
.setRegParam(0.01)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating");
ALSModel model = als.fit(training);
model.userFactors();
model.itemFactors();
// Evaluate the model by computing the RMSE on the test data
// Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
model.setColdStartStrategy("drop");
Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction");
Double rmse = evaluator.evaluate(predictions);
System.out.println("Root-mean-square error = " + rmse);
// Generate top 10 movie recommendations for each user
Dataset<Row> userRecs = model.recommendForAllUsers(10);
// Generate top 10 user recommendations for each movie
Dataset<Row> movieRecs = model.recommendForAllItems(10);
// Generate top 10 movie recommendations for a specified set of users
//todo: Those API @Since("2.3.0")
// Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3);
// Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, 10);
// // Generate top 10 user recommendations for a specified set of movies
// Dataset<Row> movies = ratings.select(als.getItemCol()).distinct().limit(3);
// Dataset<Row> movieSubSetRecs = model.recommendForItemSubset(movies, 10);
// $example off$
userRecs.show();
movieRecs.show();
// userSubsetRecs.show();
// movieSubSetRecs.show();
spark.stop();
}
}
代碼還是不難的,建議在IDEA中閱讀看下。實(shí)際使用時(shí)還需要加上tuning環(huán)節(jié)來對(duì)rank,maxIter,regParam ,alpha 甚至numBlocks進(jìn)行調(diào)參。