Alink教程(Java版)

第24章 构建推荐系统

本章包括下面各节:
24.1 与推荐相关的组件介绍
24.2 常用推荐算法
24.2.1 协同过滤
24.2.2 交替最小二乘法
24.3 数据探索
24.4 评分预测
24.5 根据用户推荐影片
24.6 计算相似影片
24.7 根据影片推荐用户
24.8 计算相似用户

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Java)》,这里为本章对应的示例代码。

package com.alibaba.alink;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp;
import com.alibaba.alink.operator.batch.recommendation.AlsTrainBatchOp;
import com.alibaba.alink.operator.batch.recommendation.ItemCfTrainBatchOp;
import com.alibaba.alink.operator.batch.recommendation.UserCfTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TsvSourceBatchOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.operator.stream.source.TsvSourceStreamOp;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.dataproc.Lookup;
import com.alibaba.alink.pipeline.recommendation.*;
import org.apache.flink.types.Row;

import java.io.File;

/**
 * Example code for chapter 24 (building a recommender system), run against the
 * MovieLens {@code ml-100k} files located under {@link #DATA_DIR}.
 *
 * <p>The {@code c_N} methods correspond to sections 24.4 - 24.8 of the chapter.
 * Trained models are written to {@link #DATA_DIR} as {@code .ak} files and are
 * reused on later runs; every method that reads a model first makes sure the
 * model file exists, so each {@code c_N} method can be run on its own.
 */
public class Chap24 {

    /** Directory containing the data files; trained models are saved here as well. */
    static final String DATA_DIR = Utils.ROOT_DIR + "movielens" + File.separator + "ml-100k" + File.separator;

    // Data file names, relative to DATA_DIR.
    static final String RATING_FILE = "u.data";
    static final String USER_FILE = "u.user";
    static final String ITEM_FILE = "u.item";
    static final String RATING_TRAIN_FILE = "ua.base";
    static final String RATING_TEST_FILE = "ua.test";

    // Column names shared by the training and recommendation components.
    static final String USER_COL = "user_id";
    static final String ITEM_COL = "item_id";
    static final String RATING_COL = "rating";
    static final String RECOMM_COL = "recomm";

    // Model file names, relative to DATA_DIR.
    static final String ALS_MODEL_FILE = "als_model.ak";
    static final String ITEMCF_MODEL_FILE = "itemcf_model.ak";
    static final String USERCF_MODEL_FILE = "usercf_model.ak";

    /** Schema applied to the rating files (u.data, ua.base, ua.test). */
    static final String RATING_SCHEMA_STRING
        = "user_id long, item_id long, rating float, ts long";

    /** Schema applied to the user file (u.user). */
    static final String USER_SCHEMA_STRING
        = "user_id long, age int, gender string, occupation string, zip_code string";

    /** Schema applied to the item file (u.item). */
    static final String ITEM_SCHEMA_STRING = "item_id long, title string, "
        + "release_date string, video_release_date string, imdb_url string, "
        + "unknown int, action int, adventure int, animation int, "
        + "children int, comedy int, crime int, documentary int, drama int, "
        + "fantasy int, film_noir int, horror int, musical int, mystery int, "
        + "romance int, sci_fi int, thriller int, war int, western int";

    /** Returns a batch source over the full rating file. */
    static TsvSourceBatchOp getSourceRatings() {
        return new TsvSourceBatchOp()
            .setFilePath(DATA_DIR + RATING_FILE)
            .setSchemaStr(RATING_SCHEMA_STRING);
    }

    /** Returns a stream source over the full rating file. */
    static TsvSourceStreamOp getStreamSourceRatings() {
        return new TsvSourceStreamOp()
            .setFilePath(DATA_DIR + RATING_FILE)
            .setSchemaStr(RATING_SCHEMA_STRING);
    }

    /** Returns a batch source over the '|'-delimited user file. */
    static CsvSourceBatchOp getSourceUsers() {
        return new CsvSourceBatchOp()
            .setFieldDelimiter("|")
            .setFilePath(DATA_DIR + USER_FILE)
            .setSchemaStr(USER_SCHEMA_STRING);
    }

    /** Returns a batch source over the '|'-delimited item file. */
    static CsvSourceBatchOp getSourceItems() {
        return new CsvSourceBatchOp()
            .setFieldDelimiter("|")
            .setFilePath(DATA_DIR + ITEM_FILE)
            .setSchemaStr(ITEM_SCHEMA_STRING);
    }

    /** Returns a stream source over the '|'-delimited item file. */
    static CsvSourceStreamOp getStreamSourceItems() {
        return new CsvSourceStreamOp()
            .setFieldDelimiter("|")
            .setFilePath(DATA_DIR + ITEM_FILE)
            .setSchemaStr(ITEM_SCHEMA_STRING);
    }

    /** Runs the examples of sections 24.4 through 24.8 in order. */
    public static void main(String[] args) throws Exception {

        c_4();

        c_5();

        c_6();

        c_7();

        c_8();

    }

    /**
     * Section 24.4: rating prediction with ALS.
     *
     * <p>Prints the predicted ratings (with item titles) for the test rows of
     * user 1, then prints regression metrics of the predictions over the whole
     * test set.
     */
    static void c_4() throws Exception {

        ensureAlsModel();

        TsvSourceBatchOp testSet = new TsvSourceBatchOp()
            .setFilePath(DATA_DIR + RATING_TEST_FILE)
            .setSchemaStr(RATING_SCHEMA_STRING);

        // Predict the rating, then attach the item title to each row.
        new PipelineModel(
            new AlsRateRecommender()
                .setUserCol(USER_COL)
                .setItemCol(ITEM_COL)
                .setRecommCol(RECOMM_COL)
                .setModelData(loadModel(ALS_MODEL_FILE)),
            newItemNameLookup()
        )
            .transform(
                testSet.filter("user_id=1")
            )
            .select("user_id, rating, recomm, item_name")
            .orderBy("rating, recomm", 1000)
            .lazyPrint(-1);

        BatchOperator.execute();

        // Compare the predicted rating against the actual rating on the full test set.
        new AlsRateRecommender()
            .setUserCol(USER_COL)
            .setItemCol(ITEM_COL)
            .setRecommCol(RECOMM_COL)
            .setModelData(loadModel(ALS_MODEL_FILE))
            .transform(testSet)
            .link(
                new EvalRegressionBatchOp()
                    .setLabelCol(RATING_COL)
                    .setPredictionCol(RECOMM_COL)
                    .lazyPrintMetrics()
            );
        BatchOperator.execute();

    }

    /**
     * Section 24.5: recommend items for a user with the ItemCF model.
     *
     * <p>Runs the recommendation for user 1 as a batch job and through local
     * predictors, lists the titles user 1 rated above 4, and finally repeats
     * the local recommendation with {@code excludeKnown} enabled.
     */
    static void c_5() throws Exception {

        ensureItemCfModel();

        MemSourceBatchOp testData = new MemSourceBatchOp(new Long[]{1L}, USER_COL);

        new ItemCfItemsPerUserRecommender()
            .setUserCol(USER_COL)
            .setRecommCol(RECOMM_COL)
            .setModelData(loadModel(ITEMCF_MODEL_FILE))
            .transform(testData)
            .print();

        LocalPredictor recommPredictor = new ItemCfItemsPerUserRecommender()
            .setUserCol(USER_COL)
            .setRecommCol(RECOMM_COL)
            .setK(20)
            .setModelData(loadModel(ITEMCF_MODEL_FILE))
            .collectLocalPredictor("user_id long");

        System.out.println(recommPredictor.getOutputSchema());

        LocalPredictor kvPredictor = newItemNameLookup()
            .collectLocalPredictor("item_id long");

        System.out.println(kvPredictor.getOutputSchema());

        // Field 1 of the output row is the recommendation column (field 0 is user_id).
        MTable recommResult = (MTable) recommPredictor.map(Row.of(1L)).getField(1);

        System.out.println(recommResult);

        newItemNameLookup()
            .transform(
                getSourceRatings().filter("user_id=1 AND rating>4")
            )
            .select("item_name")
            .orderBy("item_name", 1000)
            .lazyPrint(-1);

        // Flush the lazy print here instead of relying on a later job to trigger it.
        BatchOperator.execute();

        LocalPredictor recommPredictorExcludingKnown = new ItemCfItemsPerUserRecommender()
            .setUserCol(USER_COL)
            .setRecommCol(RECOMM_COL)
            .setK(20)
            .setExcludeKnown(true)
            .setModelData(loadModel(ITEMCF_MODEL_FILE))
            .collectLocalPredictor("user_id long");

        recommResult = (MTable) recommPredictorExcludingKnown.map(Row.of(1L)).getField(1);

        System.out.println(recommResult);

    }

    /**
     * Section 24.6: find items similar to item 50 with the ItemCF model, both
     * as a batch job and through a local predictor.
     */
    static void c_6() throws Exception {

        ensureItemCfModel();

        MemSourceBatchOp testData = new MemSourceBatchOp(new Long[]{50L}, ITEM_COL);

        new ItemCfSimilarItemsRecommender()
            .setItemCol(ITEM_COL)
            .setRecommCol(RECOMM_COL)
            .setModelData(loadModel(ITEMCF_MODEL_FILE))
            .transform(testData)
            .print();

        LocalPredictor recommPredictor = new ItemCfSimilarItemsRecommender()
            .setItemCol(ITEM_COL)
            .setRecommCol(RECOMM_COL)
            .setK(10)
            .setModelData(loadModel(ITEMCF_MODEL_FILE))
            .collectLocalPredictor("item_id long");

        // Field 1 of the output row is the recommendation column (field 0 is item_id).
        MTable recommResult = (MTable) recommPredictor.map(Row.of(50L)).getField(1);

        System.out.println(recommResult);

    }

    /**
     * Section 24.7: recommend users for item 50 with the UserCF model.
     *
     * <p>Prints the recommendation, the existing ratings of item 50 by a fixed
     * list of users, and the recommendation again with {@code excludeKnown}
     * enabled.
     */
    static void c_7() throws Exception {

        ensureUserCfModel();

        MemSourceBatchOp testData = new MemSourceBatchOp(new Long[]{50L}, ITEM_COL);

        new UserCfUsersPerItemRecommender()
            .setItemCol(ITEM_COL)
            .setRecommCol(RECOMM_COL)
            .setModelData(loadModel(USERCF_MODEL_FILE))
            .transform(testData)
            .print();

        getSourceRatings()
            .filter("user_id IN (276,429,222,864,194,650,896,303,749,301) AND item_id=50")
            .print();

        new UserCfUsersPerItemRecommender()
            .setItemCol(ITEM_COL)
            .setRecommCol(RECOMM_COL)
            .setExcludeKnown(true)
            .setModelData(loadModel(USERCF_MODEL_FILE))
            .transform(testData)
            .print();

    }

    /**
     * Section 24.8: find users similar to user 1 with the UserCF model, then
     * print the profile rows of user 1 and a fixed list of other users.
     */
    static void c_8() throws Exception {

        ensureUserCfModel();

        MemSourceBatchOp testData = new MemSourceBatchOp(new Long[]{1L}, USER_COL);

        new UserCfSimilarUsersRecommender()
            .setUserCol(USER_COL)
            .setRecommCol(RECOMM_COL)
            .setModelData(loadModel(USERCF_MODEL_FILE))
            .transform(testData)
            .print();

        getSourceUsers()
            .filter("user_id IN (1, 916,864,268,92,435,457,738,429,303,276)")
            .print();
    }

    /** Returns whether the given model file already exists under {@link #DATA_DIR}. */
    private static boolean modelExists(String modelFile) {
        return new File(DATA_DIR + modelFile).exists();
    }

    /** Returns a new source reading the given model file from {@link #DATA_DIR}. */
    private static AkSourceBatchOp loadModel(String modelFile) {
        return new AkSourceBatchOp()
            .setFilePath(DATA_DIR + modelFile);
    }

    /** Returns a new Lookup that maps {@code item_id} to the item title as {@code item_name}. */
    private static Lookup newItemNameLookup() {
        return new Lookup()
            .setSelectedCols(ITEM_COL)
            .setOutputCols("item_name")
            .setModelData(getSourceItems())
            .setMapKeyCols("item_id")
            .setMapValueCols("title");
    }

    /** Trains the ALS model on the training split and saves it, unless the model file exists. */
    private static void ensureAlsModel() throws Exception {
        if (modelExists(ALS_MODEL_FILE)) {
            return;
        }

        new TsvSourceBatchOp()
            .setFilePath(DATA_DIR + RATING_TRAIN_FILE)
            .setSchemaStr(RATING_SCHEMA_STRING)
            .link(
                new AlsTrainBatchOp()
                    .setUserCol(USER_COL)
                    .setItemCol(ITEM_COL)
                    .setRateCol(RATING_COL)
                    .setLambda(0.1)
                    .setRank(10)
                    .setNumIter(10)
            )
            .link(
                new AkSinkBatchOp()
                    .setFilePath(DATA_DIR + ALS_MODEL_FILE)
            );
        BatchOperator.execute();
    }

    /** Trains the ItemCF model on the full rating file and saves it, unless the model file exists. */
    private static void ensureItemCfModel() throws Exception {
        if (modelExists(ITEMCF_MODEL_FILE)) {
            return;
        }

        getSourceRatings()
            .link(
                new ItemCfTrainBatchOp()
                    .setUserCol(USER_COL)
                    .setItemCol(ITEM_COL)
                    .setRateCol(RATING_COL)
            )
            .link(
                new AkSinkBatchOp()
                    .setFilePath(DATA_DIR + ITEMCF_MODEL_FILE)
            );
        BatchOperator.execute();
    }

    /** Trains the UserCF model on the full rating file and saves it, unless the model file exists. */
    private static void ensureUserCfModel() throws Exception {
        if (modelExists(USERCF_MODEL_FILE)) {
            return;
        }

        getSourceRatings()
            .link(
                new UserCfTrainBatchOp()
                    .setUserCol(USER_COL)
                    .setItemCol(ITEM_COL)
                    .setRateCol(RATING_COL)
            )
            .link(
                new AkSinkBatchOp()
                    .setFilePath(DATA_DIR + USERCF_MODEL_FILE)
            );
        BatchOperator.execute();
    }

}