本章包括下面各节:
24.1 与推荐相关的组件介绍
24.2 常用推荐算法
24.2.1 协同过滤
24.2.2 交替最小二乘法
24.3 数据探索
24.4 评分预测
24.5 根据用户推荐影片
24.6 计算相似影片
24.7 根据影片推荐用户
24.8 计算相似用户
详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。
package com.alibaba.alink;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp;
import com.alibaba.alink.operator.batch.recommendation.AlsTrainBatchOp;
import com.alibaba.alink.operator.batch.recommendation.ItemCfTrainBatchOp;
import com.alibaba.alink.operator.batch.recommendation.UserCfTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TsvSourceBatchOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.operator.stream.source.TsvSourceStreamOp;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.dataproc.Lookup;
import com.alibaba.alink.pipeline.recommendation.*;
import org.apache.flink.types.Row;
import java.io.File;
public class Chap24 {
static final String DATA_DIR = Utils.ROOT_DIR + "movielens" + File.separator + "ml-100k" + File.separator;
static final String RATING_FILE = "u.data";
static final String USER_FILE = "u.user";
static final String ITEM_FILE = "u.item";
static final String RATING_TRAIN_FILE = "ua.base";
static final String RATING_TEST_FILE = "ua.test";
static final String USER_COL = "user_id";
static final String ITEM_COL = "item_id";
static final String RATING_COL = "rating";
static final String RECOMM_COL = "recomm";
static final String ALS_MODEL_FILE = "als_model.ak";
static final String ITEMCF_MODEL_FILE = "itemcf_model.ak";
static final String USERCF_MODEL_FILE = "usercf_model.ak";
static final String RATING_SCHEMA_STRING
= "user_id long, item_id long, rating float, ts long";
static final String USER_SCHEMA_STRING
= "user_id long, age int, gender string, occupation string, zip_code string";
static final String ITEM_SCHEMA_STRING = "item_id long, title string, "
+ "release_date string, video_release_date string, imdb_url string, "
+ "unknown int, action int, adventure int, animation int, "
+ "children int, comedy int, crime int, documentary int, drama int, "
+ "fantasy int, film_noir int, horror int, musical int, mystery int, "
+ "romance int, sci_fi int, thriller int, war int, western int";
static TsvSourceBatchOp getSourceRatings() {
return new TsvSourceBatchOp()
.setFilePath(DATA_DIR + RATING_FILE)
.setSchemaStr(RATING_SCHEMA_STRING);
}
static TsvSourceStreamOp getStreamSourceRatings() {
return new TsvSourceStreamOp()
.setFilePath(DATA_DIR + RATING_FILE)
.setSchemaStr(RATING_SCHEMA_STRING);
}
static CsvSourceBatchOp getSourceUsers() {
return new CsvSourceBatchOp()
.setFieldDelimiter("|")
.setFilePath(DATA_DIR + USER_FILE)
.setSchemaStr(USER_SCHEMA_STRING);
}
static CsvSourceBatchOp getSourceItems() {
return new CsvSourceBatchOp()
.setFieldDelimiter("|")
.setFilePath(DATA_DIR + ITEM_FILE)
.setSchemaStr(ITEM_SCHEMA_STRING);
}
static CsvSourceStreamOp getStreamSourceItems() {
return new CsvSourceStreamOp()
.setFieldDelimiter("|")
.setFilePath(DATA_DIR + ITEM_FILE)
.setSchemaStr(ITEM_SCHEMA_STRING);
}
public static void main(String[] args) throws Exception {
c_4();
c_5();
c_6();
c_7();
c_8();
}
static void c_4() throws Exception {
TsvSourceBatchOp train_set = new TsvSourceBatchOp()
.setFilePath(DATA_DIR + RATING_TRAIN_FILE)
.setSchemaStr(RATING_SCHEMA_STRING);
TsvSourceBatchOp test_set = new TsvSourceBatchOp()
.setFilePath(DATA_DIR + RATING_TEST_FILE)
.setSchemaStr(RATING_SCHEMA_STRING);
if (!new File(DATA_DIR + ALS_MODEL_FILE).exists()) {
train_set
.link(
new AlsTrainBatchOp()
.setUserCol(USER_COL)
.setItemCol(ITEM_COL)
.setRateCol(RATING_COL)
.setLambda(0.1)
.setRank(10)
.setNumIter(10)
)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + ALS_MODEL_FILE)
);
BatchOperator.execute();
}
new PipelineModel
(
new AlsRateRecommender()
.setUserCol(USER_COL)
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ALS_MODEL_FILE)
),
new Lookup()
.setSelectedCols(ITEM_COL)
.setOutputCols("item_name")
.setModelData(getSourceItems())
.setMapKeyCols("item_id")
.setMapValueCols("title")
)
.transform(
test_set.filter("user_id=1")
)
.select("user_id, rating, recomm, item_name")
.orderBy("rating, recomm", 1000)
.lazyPrint(-1);
BatchOperator.execute();
new AlsRateRecommender()
.setUserCol(USER_COL)
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ALS_MODEL_FILE)
)
.transform(test_set)
.link(
new EvalRegressionBatchOp()
.setLabelCol(RATING_COL)
.setPredictionCol(RECOMM_COL)
.lazyPrintMetrics()
);
BatchOperator.execute();
}
static void c_5() throws Exception {
if (!new File(DATA_DIR + ITEMCF_MODEL_FILE).exists()) {
getSourceRatings()
.link(
new ItemCfTrainBatchOp()
.setUserCol(USER_COL)
.setItemCol(ITEM_COL)
.setRateCol(RATING_COL)
)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
);
BatchOperator.execute();
}
MemSourceBatchOp test_data = new MemSourceBatchOp(new Long[]{1L}, "user_id");
new ItemCfItemsPerUserRecommender()
.setUserCol(USER_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
)
.transform(test_data)
.print();
LocalPredictor recomm_predictor = new ItemCfItemsPerUserRecommender()
.setUserCol(USER_COL)
.setRecommCol(RECOMM_COL)
.setK(20)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
)
.collectLocalPredictor("user_id long");
System.out.println(recomm_predictor.getOutputSchema());
LocalPredictor kv_predictor = new Lookup()
.setSelectedCols(ITEM_COL)
.setOutputCols("item_name")
.setModelData(getSourceItems())
.setMapKeyCols("item_id")
.setMapValueCols("title")
.collectLocalPredictor("item_id long");
System.out.println(kv_predictor.getOutputSchema());
MTable recommResult = (MTable) recomm_predictor.map(Row.of(1L)).getField(1);
System.out.println(recommResult);
new Lookup()
.setSelectedCols(ITEM_COL)
.setOutputCols("item_name")
.setModelData(getSourceItems())
.setMapKeyCols("item_id")
.setMapValueCols("title")
.transform(
getSourceRatings().filter("user_id=1 AND rating>4")
)
.select("item_name")
.orderBy("item_name", 1000)
.lazyPrint(-1);
LocalPredictor recomm_predictor_2 = new ItemCfItemsPerUserRecommender()
.setUserCol(USER_COL)
.setRecommCol(RECOMM_COL)
.setK(20)
.setExcludeKnown(true)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
)
.collectLocalPredictor("user_id long");
recommResult = (MTable) recomm_predictor_2.map(Row.of(1L)).getField(1);
System.out.println(recommResult);
}
static void c_6() throws Exception {
MemSourceBatchOp test_data = new MemSourceBatchOp(new Long[]{50L}, ITEM_COL);
new ItemCfSimilarItemsRecommender()
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
)
.transform(test_data)
.print();
LocalPredictor recomm_predictor = new ItemCfSimilarItemsRecommender()
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setK(10)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
)
.collectLocalPredictor("item_id long");
LocalPredictor kv_predictor = new Lookup()
.setSelectedCols(ITEM_COL)
.setOutputCols("item_name")
.setModelData(getSourceItems())
.setMapKeyCols("item_id")
.setMapValueCols("title")
.collectLocalPredictor("item_id long");
MTable recommResult = (MTable) recomm_predictor.map(Row.of(50L)).getField(1);
System.out.println(recommResult);
}
static void c_7() throws Exception {
if (!new File(DATA_DIR + USERCF_MODEL_FILE).exists()) {
getSourceRatings()
.link(
new UserCfTrainBatchOp()
.setUserCol(USER_COL)
.setItemCol(ITEM_COL)
.setRateCol(RATING_COL)
)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + USERCF_MODEL_FILE)
);
BatchOperator.execute();
}
MemSourceBatchOp test_data = new MemSourceBatchOp(new Long[]{50L}, ITEM_COL);
new UserCfUsersPerItemRecommender()
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + USERCF_MODEL_FILE)
)
.transform(test_data)
.print();
getSourceRatings()
.filter("user_id IN (276,429,222,864,194,650,896,303,749,301) AND item_id=50")
.print();
new UserCfUsersPerItemRecommender()
.setItemCol(ITEM_COL)
.setRecommCol(RECOMM_COL)
.setExcludeKnown(true)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + USERCF_MODEL_FILE)
)
.transform(test_data)
.print();
}
static void c_8() throws Exception {
MemSourceBatchOp test_data = new MemSourceBatchOp(new Long[]{1L}, USER_COL);
new UserCfSimilarUsersRecommender()
.setUserCol(USER_COL)
.setRecommCol(RECOMM_COL)
.setModelData(
new AkSourceBatchOp()
.setFilePath(DATA_DIR + USERCF_MODEL_FILE)
)
.transform(test_data)
.print();
getSourceUsers()
.filter("user_id IN (1, 916,864,268,92,435,457,738,429,303,276)")
.print();
}
}