Alink教程(Python版)

第6章 用户定义函数(UDF/UDTF)

本章包括下面各节:
6.1 用户定义标量函数(UDF)
6.1.1 示例数据及问题
6.1.2 UDF的定义
6.1.3 使用UDF处理批式数据
6.1.4 使用UDF处理流式数据
6.2 用户定义表值函数(UDTF)
6.2.1 示例数据及问题
6.2.2 UDTF的定义
6.2.3 使用UDTF处理批式数据
6.2.4 使用UDTF处理流式数据

详细内容请阅读纸质书《Alink权威指南:基于Flink的机器学习实例入门(Python)》,这里为本章对应的示例代码。

from pyalink.alink import *
useLocalEnv(1)

from utils import *
import os
import pandas as pd

# Directory holding the "movielens/ml-100k" files, with a trailing separator.
# NOTE(review): ROOT_DIR is not defined in this file; presumably it comes
# from one of the star imports above -- confirm.
DATA_DIR = f"{ROOT_DIR}movielens{os.sep}ml-100k{os.sep}"

# File names, relative to DATA_DIR.
RATING_FILE = "u.data"
ITEM_FILE = "u.item"

# Schema string passed to the ratings sources.
RATING_SCHEMA_STRING = "user_id long, item_id long, rating float, ts long"

# Schema string passed to the item sources; adjacent literals are joined
# by the parser into one string.
ITEM_SCHEMA_STRING = (
    "item_id long, title string, "
    "release_date string, video_release_date string, imdb_url string, "
    "unknown int, action int, adventure int, animation int, "
    "children int, comedy int, crime int, documentary int, drama int, "
    "fantasy int, film_noir int, horror int, musical int, mystery int, "
    "romance int, sci_fi int, thriller int, war int, western int"
)


def getSourceRatings():
    """Return a TsvSourceBatchOp configured for the ratings file.

    The operator is pointed at DATA_DIR + RATING_FILE and given
    RATING_SCHEMA_STRING as its schema.
    """
    ratings_path = DATA_DIR + RATING_FILE
    return (
        TsvSourceBatchOp()
        .setFilePath(ratings_path)
        .setSchemaStr(RATING_SCHEMA_STRING)
    )


def getStreamSourceRatings():
    """Return a TsvSourceStreamOp configured for the ratings file.

    Stream counterpart of the batch ratings source: same file path
    (DATA_DIR + RATING_FILE) and same schema (RATING_SCHEMA_STRING).
    """
    ratings_path = DATA_DIR + RATING_FILE
    return (
        TsvSourceStreamOp()
        .setFilePath(ratings_path)
        .setSchemaStr(RATING_SCHEMA_STRING)
    )

def getSourceItems():
    """Return a CsvSourceBatchOp configured for the item file.

    The item file uses "|" as its field delimiter; the operator reads
    DATA_DIR + ITEM_FILE with ITEM_SCHEMA_STRING as its schema.
    """
    items_path = DATA_DIR + ITEM_FILE
    return (
        CsvSourceBatchOp()
        .setFieldDelimiter("|")
        .setFilePath(items_path)
        .setSchemaStr(ITEM_SCHEMA_STRING)
    )


def getStreamSourceItems():
    """Return a CsvSourceStreamOp configured for the item file.

    Stream counterpart of the batch item source: "|" field delimiter,
    file path DATA_DIR + ITEM_FILE, schema ITEM_SCHEMA_STRING.
    """
    items_path = DATA_DIR + ITEM_FILE
    return (
        CsvSourceStreamOp()
        .setFieldDelimiter("|")
        .setFilePath(items_path)
        .setSchemaStr(ITEM_SCHEMA_STRING)
    )

import datetime

@udf(input_types=[DataTypes.BIGINT()], result_type=DataTypes.TIMESTAMP(3))
def from_unix_timestamp(ts):
    """Convert a Unix timestamp into a ``datetime.datetime``.

    Declared as a scalar UDF with one BIGINT input and a TIMESTAMP(3)
    result.

    Args:
        ts: POSIX timestamp (seconds since the epoch), as accepted by
            ``datetime.datetime.fromtimestamp``.

    Returns:
        A naive ``datetime.datetime``. No ``tz`` argument is passed, so the
        value is expressed in the local time of the executing process.
    """
    converted = datetime.datetime.fromtimestamp(ts)
    return converted

#c_1_3
# Batch usage of from_unix_timestamp: via UDFBatchOp, via select(), and via
# a SQL query.

ratings = getSourceRatings()

ratings.firstN(5).print()

# Link a UDFBatchOp that reads column "ts" and writes its result to an
# output column that is also named "ts".
(
    ratings
    .link(
        UDFBatchOp()
        .setFunc(from_unix_timestamp)
        .setSelectedCols(["ts"])
        .setOutputCol("ts")
    )
    .firstN(5)
    .print()
)

# Register the function under a name so that SQL expressions can call it.
BatchOperator.registerFunction("from_unix_timestamp", from_unix_timestamp)

(
    ratings
    .select("user_id, item_id, rating, from_unix_timestamp(ts) AS ts")
    .firstN(5)
    .print()
)

# Register the table name referenced by the SQL query below.
ratings.registerTableName("ratings")

(
    BatchOperator
    .sqlQuery("SELECT user_id, item_id, rating, from_unix_timestamp(ts) AS ts FROM ratings")
    .firstN(5)
    .print()
)

#c_1_4
# Stream usage of from_unix_timestamp. Each print() is followed by
# StreamOperator.execute().

ratings = getStreamSourceRatings()

# Restrict the stream to user 1 and item ids below 5 (presumably to keep
# the printed output short).
ratings = ratings.filter("user_id=1 AND item_id<5")

ratings.print()

StreamOperator.execute()

# Link a UDFStreamOp that reads column "ts" and writes its result to an
# output column that is also named "ts".
(
    ratings
    .link(
        UDFStreamOp()
        .setFunc(from_unix_timestamp)
        .setSelectedCols(["ts"])
        .setOutputCol("ts")
    )
    .print()
)

StreamOperator.execute()

# Register the function under a name so that SQL expressions can call it.
StreamOperator.registerFunction("from_unix_timestamp", from_unix_timestamp)

(
    ratings
    .select("user_id, item_id, rating, from_unix_timestamp(ts) AS ts")
    .print()
)

StreamOperator.execute()

# Register the table name referenced by the SQL query below.
ratings.registerTableName("ratings")

(
    StreamOperator
    .sqlQuery("SELECT user_id, item_id, rating, from_unix_timestamp(ts) AS ts FROM ratings")
    .print()
)

StreamOperator.execute()

@udtf(input_types=[DataTypes.STRING()], result_types=[DataTypes.STRING(), DataTypes.INT()])
def doc_word_count(s):
    """Yield one ``(word, count)`` pair per distinct word in ``s``.

    Declared as a table function with one STRING input and
    (STRING, INT) output rows.

    Args:
        s: Text to tokenize. Tokens are produced by ``str.split()`` with no
            arguments, i.e. split on runs of whitespace. ``None`` is treated
            as an empty document and produces no rows.

    Yields:
        Tuples ``(word, count)`` in order of each word's first occurrence.
    """
    # Local import keeps the function self-contained without touching the
    # file-level import block.
    from collections import Counter

    # A missing value has no ``split``; emit nothing instead of raising.
    if s is None:
        return

    # Counter is a dict subclass, so iteration follows insertion order,
    # which matches a manually maintained dict.
    word_counts = Counter(s.split())
    for word, count in word_counts.items():
        yield word, count

#c_2_3
# Batch usage of doc_word_count: via UDTFBatchOp and via a SQL
# LATERAL TABLE query.

items = getSourceItems()

items.select("item_id, title").lazyPrint(10, "<- original data ->")

# Apply the table function to "title", naming its two outputs "word" and
# "cnt" and keeping "item_id" as a reserved column.
words = items.link(
    UDTFBatchOp()
    .setFunc(doc_word_count)
    .setSelectedCols(["title"])
    .setOutputCols(["word", "cnt"])
    .setReservedCols(["item_id"])
)

words.lazyPrint(20, "<- after word count ->")

# Sum the counts per word, then order by the summed count, descending,
# with a limit argument of 20.
(
    words
    .groupBy("word", "word, SUM(cnt) AS cnt")
    .orderBy("cnt", 20, order='desc')
    .print()
)

# Register the function under a name so that SQL queries can call it.
BatchOperator.registerFunction("doc_word_count", doc_word_count)

# Register the table name referenced by the SQL query below.
items.registerTableName("items")

(
    BatchOperator
    .sqlQuery(
        "SELECT item_id, word, cnt FROM items, "
        "LATERAL TABLE(doc_word_count(title)) as T(word, cnt)"
    )
    .firstN(20)
    .print()
)

#c_2_4
# Stream usage of doc_word_count. Each print() is followed by
# StreamOperator.execute().

items = getStreamSourceItems()

# Keep only the two columns used below and the items with id below 4.
items = items.select("item_id, title").filter("item_id<4")

items.print()
StreamOperator.execute()

# Apply the table function to "title", naming its two outputs "word" and
# "cnt" and keeping "item_id" as a reserved column.
words = items.link(
    UDTFStreamOp()
    .setFunc(doc_word_count)
    .setSelectedCols(["title"])
    .setOutputCols(["word", "cnt"])
    .setReservedCols(["item_id"])
)

words.print()
StreamOperator.execute()

# Register the function under a name so that SQL queries can call it.
StreamOperator.registerFunction("doc_word_count", doc_word_count)

# Register the table name referenced by the SQL query below.
items.registerTableName("items")

(
    StreamOperator
    .sqlQuery(
        "SELECT item_id, word, cnt FROM items, "
        "LATERAL TABLE(doc_word_count(title)) as T(word, cnt)"
    )
    .print()
)

StreamOperator.execute()