Introduction to Machine Learning Algorithms in Flink ML (Part 1)
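Flink ML is a subproject of the Apache Flink ecosystem. It provides machine learning (ML) APIs and infrastructure that simplify building ML pipelines. Users can implement ML algorithms with the standard ML APIs and then use the infrastructure to build ML pipelines for training and inference jobs. This helps users build and deploy machine learning models that make predictions and run inference on real-time data streams.
The Flink ML algorithm library contains the following commonly used machine learning algorithms: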
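I. Classification: Classification belongs to supervised learning. Given samples whose classes are already known, a model learns to classify samples whose classes are unknown. This requires the set of classes to be defined in advance.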
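1. KNN: KNN (K-nearest neighbors) is a classification algorithm. Its basic assumption is that if most of the K nearest neighbors of a given sample carry the same label, the sample very likely carries that label too.
Advantages of KNN: there is no explicit training phase, the algorithm is simple and easy to understand, it works for both regression and classification, it makes no additional assumptions about the data, and it has few parameters to tune. Its accuracy is often good, although it is usually not competitive with stronger supervised learning models.
Disadvantages of KNN: prediction time grows with the amount of data, it is poorly suited to high-dimensional data, and it performs badly on class-imbalanced data sets.
Typical applications include text classification and text mining, forest inventory and estimation of forest variables, functional genomics research on gene expression profiles, and data preprocessing.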
Example code:
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a Knn model and uses it for classification. */
public class KnnExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(2.0, 3.0), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(200.1, 300.1), 2.0),
                        Row.of(Vectors.dense(200.2, 300.2), 2.0),
                        Row.of(Vectors.dense(200.3, 300.3), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.6, 300.6), 2.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.8, 3.2), 3.0),
                        Row.of(Vectors.dense(300., 3.2), 4.0),
                        Row.of(Vectors.dense(2.2, 3.2), 1.0),
                        Row.of(Vectors.dense(2.4, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(4.0, 4.1), 5.0),
                        Row.of(Vectors.dense(300, 42), 2.0));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");

        // Creates a Knn object and initializes its parameters.
        Knn knn = new Knn().setK(4);

        // Trains the Knn Model.
        KnnModel knnModel = knn.fit(trainTable);

        // Uses the Knn Model for predictions.
        Table outputTable = knnModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol());
            double expectedResult = (Double) row.getField(knn.getLabelCol());
            double predictionResult = (Double) row.getField(knn.getPredictionCol());
            System.out.printf(
                    "Features: %-15s Expected Result: %s Prediction Result: %s\n",
                    features, expectedResult, predictionResult);
        }
    }
}
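To make the voting rule behind KNN concrete, here is a minimal plain-Java sketch that does not depend on Flink ML. The class `KnnVoteSketch` and its methods are hypothetical names chosen for illustration, squared Euclidean distance is assumed as the metric, and ties between labels are broken arbitrarily.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

/** Plain-Java illustration of K-nearest-neighbor voting (hypothetical helper, not Flink ML code). */
final class KnnVoteSketch {

    /** Returns the majority label among the k training points closest to the query. */
    static double predict(double[][] features, double[] labels, double[] query, int k) {
        Integer[] order = new Integer[features.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort training indices by squared Euclidean distance to the query (assumed metric).
        Arrays.sort(order, Comparator.comparingDouble(i -> squaredDistance(features[i], query)));

        // Count the labels of the k nearest neighbors.
        Map<Double, Integer> votes = new HashMap<>();
        for (int i = 0; i < Math.min(k, order.length); i++) {
            votes.merge(labels[order[i]], 1, Integer::sum);
        }

        // Pick the label with the most votes; ties are broken arbitrarily in this sketch.
        double best = Double.NaN;
        int bestCount = -1;
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] features = {{2.0, 3.0}, {2.1, 3.1}, {2.3, 3.2}, {200.1, 300.1}, {200.2, 300.2}};
        double[] labels = {1.0, 1.0, 1.0, 2.0, 2.0};
        System.out.println(predict(features, labels, new double[] {4.0, 4.1}, 3)); // prints 1.0
        System.out.println(predict(features, labels, new double[] {300.0, 42.0}, 3)); // prints 2.0
    }
}
```

Because every prediction scans and sorts the whole training set, the sketch also shows why prediction cost grows with the amount of data.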
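2. Linear SVC: The linear support vector classifier is an algorithm that looks for a hyperplane separating the classes with the largest possible margin, that is, the largest distance between the hyperplane and the nearest samples of each class.
Advantages of the linear support vector classifier: it is effective in high-dimensional spaces, even when the number of dimensions exceeds the number of samples, and its decision function depends only on a subset of the training points (the support vectors), so it is also memory efficient. Disadvantage: it cannot handle data that is not linearly separable. A linear SVC tends to be the faster choice when the number of samples is larger than the number of features. For very large data sets, other libraries offer further ways to speed up training; in scikit-learn, for example, SGDClassifier or kernel approximation methods such as the Nystroem transformer serve this purpose.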
Example code:
import org.apache.flink.ml.classification.linearsvc.LinearSVC;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LinearSVC model and uses it for classification. */
public class LinearSVCExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data: (features, label, weight).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LinearSVC object and initializes its parameters.
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");

        // Trains the LinearSVC Model.
        LinearSVCModel linearSVCModel = linearSVC.fit(inputTable);

        // Uses the LinearSVC Model for predictions.
        Table outputTable = linearSVCModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(linearSVC.getFeaturesCol());
            double expectedResult = (Double) row.getField(linearSVC.getLabelCol());
            double predictionResult = (Double) row.getField(linearSVC.getPredictionCol());
            DenseVector rawPredictionResult =
                    (DenseVector) row.getField(linearSVC.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s Expected Result: %s Prediction Result: %s Raw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}
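The decision rule of a linear SVM and the hinge loss it is usually trained with can be sketched in a few lines of plain Java. This is an illustration only: `LinearSvcSketch` is a hypothetical class, the weights are supplied by hand rather than learned, and the threshold of 0 and the 0/1 label encoding are assumptions of the sketch, not a statement about the internals of Flink ML.

```java
/** Plain-Java illustration of a linear SVM decision rule and hinge loss (hypothetical helper). */
final class LinearSvcSketch {

    /** Raw score w.x + b; its sign tells on which side of the hyperplane x lies. */
    static double score(double[] w, double b, double[] x) {
        double s = b;
        for (int i = 0; i < w.length; i++) {
            s += w[i] * x[i];
        }
        return s;
    }

    /** Predicts 1.0 when the score is non-negative and 0.0 otherwise (assumed threshold 0). */
    static double predict(double[] w, double b, double[] x) {
        return score(w, b, x) >= 0.0 ? 1.0 : 0.0;
    }

    /** Hinge loss max(0, 1 - y * score), with the 0/1 label mapped to y in {-1, +1}. */
    static double hingeLoss(double[] w, double b, double[] x, double label) {
        double y = label == 1.0 ? 1.0 : -1.0;
        return Math.max(0.0, 1.0 - y * score(w, b, x));
    }

    public static void main(String[] args) {
        double[] w = {1.0, 0.0}; // hand-picked weights for illustration
        double b = -8.0;
        System.out.println(predict(w, b, new double[] {11.0, 2.0})); // prints 1.0
        System.out.println(predict(w, b, new double[] {5.0, 2.0})); // prints 0.0
        System.out.println(hingeLoss(w, b, new double[] {8.5, 2.0}, 1.0)); // prints 0.5
    }
}
```

The last call shows a point that is classified correctly but lies inside the margin, which is why it still receives a positive loss.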
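3. Logistic Regression: Logistic regression is a special case of the generalized linear model. It is a machine learning method for binary classification that estimates how likely something is. Despite the word "regression" in its name, it is a probability-based classification method, and in practice it is one of the most widely used machine learning algorithms. Its goal is to predict the probability of a binary output variable from a linear combination of the input features. It uses the sigmoid function S(x) = 1 / (1 + e^(-x)) to turn the output of the linear function into a probability, which is then used for classification. Although it is formulated for binary classification, logistic regression can be extended to multiclass classification.
Logistic regression is widely used because it is very efficient, computationally cheap, easy to understand, does not strictly require scaling of the input features, needs little tuning, and outputs well-calibrated predicted probabilities. It also has drawbacks: it cannot solve nonlinear problems, because its decision surface is linear; it is sensitive to multicollinearity; it struggles with imbalanced data; and its accuracy is often limited, because its form is so simple (very similar to a linear model) that it has difficulty fitting the true distribution of the data.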
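The mapping from a linear score to a probability through the sigmoid function S(x) = 1 / (1 + e^(-x)) can be written directly in plain Java. The class `SigmoidSketch`, the hand-picked weights, and the 0.5 decision threshold are assumptions made for this illustration.

```java
/** Plain-Java illustration of the sigmoid link used by logistic regression (hypothetical helper). */
final class SigmoidSketch {

    /** S(z) = 1 / (1 + e^(-z)), which maps any real number into the interval (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Probability of the positive class for features x under weights w and intercept b. */
    static double probability(double[] w, double b, double[] x) {
        double z = b;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    /** Predicts 1.0 when the probability reaches the threshold, otherwise 0.0. */
    static double classify(double[] w, double b, double[] x, double threshold) {
        return probability(w, b, x) >= threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] w = {1.0, -1.0}; // hand-picked weights for illustration
        System.out.println(sigmoid(0.0)); // prints 0.5
        System.out.println(classify(w, 0.0, new double[] {3.0, 1.0}, 0.5)); // prints 1.0
        System.out.println(classify(w, 0.0, new double[] {1.0, 3.0}, 0.5)); // prints 0.0
    }
}
```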
Example code:
Basic example
import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LogisticRegression model and uses it for classification. */
public class LogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data: (features, label, weight).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LogisticRegression object and initializes its parameters.
        LogisticRegression lr = new LogisticRegression().setWeightCol("weight");

        // Trains the LogisticRegression Model.
        LogisticRegressionModel lrModel = lr.fit(inputTable);

        // Uses the LogisticRegression Model for predictions.
        Table outputTable = lrModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
            double expectedResult = (Double) row.getField(lr.getLabelCol());
            double predictionResult = (Double) row.getField(lr.getPredictionCol());
            DenseVector rawPredictionResult = (DenseVector) row.getField(lr.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s Expected Result: %s Prediction Result: %s Raw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}
Online training on an unbounded stream
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Simple program that trains an OnlineLogisticRegression model and uses it for classification. */
public class OnlineLogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data. Both are infinite streams that
        // periodically send out the provided data to trigger model update and prediction.
        List<Row> trainData1 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.1, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(0.3, 2.), 0.),
                        Row.of(Vectors.dense(0.4, 2.), 0.),
                        Row.of(Vectors.dense(0.5, 2.), 0.),
                        Row.of(Vectors.dense(11., 12.), 1.),
                        Row.of(Vectors.dense(12., 11.), 1.),
                        Row.of(Vectors.dense(13., 12.), 1.),
                        Row.of(Vectors.dense(14., 12.), 1.),
                        Row.of(Vectors.dense(15., 12.), 1.));

        List<Row> trainData2 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.2, 3.), 0.),
                        Row.of(Vectors.dense(0.8, 1.), 0.),
                        Row.of(Vectors.dense(0.7, 1.), 0.),
                        Row.of(Vectors.dense(0.6, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(14., 17.), 1.),
                        Row.of(Vectors.dense(15., 10.), 1.),
                        Row.of(Vectors.dense(16., 16.), 1.),
                        Row.of(Vectors.dense(17., 10.), 1.),
                        Row.of(Vectors.dense(18., 13.), 1.));

        List<Row> predictData =
                Arrays.asList(
                        Row.of(Vectors.dense(0.8, 2.7), 0.0),
                        Row.of(Vectors.dense(15.5, 11.2), 1.0));

        RowTypeInfo typeInfo =
                new RowTypeInfo(
                        new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.DOUBLE},
                        new String[] {"features", "label"});

        // PeriodicSourceFunction is a helper class from the Flink ML examples package.
        SourceFunction<Row> trainSource =
                new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
        DataStream<Row> trainStream = env.addSource(trainSource, typeInfo);
        Table trainTable = tEnv.fromDataStream(trainStream).as("features");

        SourceFunction<Row> predictSource =
                new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
        DataStream<Row> predictStream = env.addSource(predictSource, typeInfo);
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates an online LogisticRegression object and initializes its parameters and
        // initial model data.
        Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
        Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
        OnlineLogisticRegression olr =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initModelDataTable);

        // Trains the online LogisticRegression Model.
        OnlineLogisticRegressionModel onlineModel = olr.fit(trainTable);

        // Uses the online LogisticRegression Model for predictions.
        Table outputTable = onlineModel.transform(predictTable)[0];

        // Extracts and displays the results. As the training data stream continuously triggers
        // updates of the internal model data, raw prediction results for the same prediction
        // data set change over time.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(olr.getFeaturesCol());
            Double expectedResult = (Double) row.getField(olr.getLabelCol());
            Double predictionResult = (Double) row.getField(olr.getPredictionCol());
            DenseVector rawPredictionResult = (DenseVector) row.getField(olr.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s Expected Result: %s Prediction Result: %s Raw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}
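4. Naive Bayes: Naive Bayes is a classification method based on Bayes' theorem and the assumption that features are conditionally independent given the class. In many settings a naive Bayes classifier is comparable to decision tree and neural network classifiers. It scales to large databases, is simple and fast, and often achieves high classification accuracy.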
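Written out, Bayes' theorem combined with the conditional independence assumption gives the following decision rule. The notation is introduced here for illustration: y is a class label and x_1, ..., x_n are the feature values of one sample.

```latex
P(y \mid x_1, \dots, x_n) \;\propto\; P(y) \prod_{i=1}^{n} P(x_i \mid y),
\qquad
\hat{y} \;=\; \arg\max_{y} \; P(y) \prod_{i=1}^{n} P(x_i \mid y)
```

The prior P(y) and the per-feature likelihoods P(x_i | y) are estimated from the training data, typically with smoothing so that an unseen feature value does not force the whole product to zero.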
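Advantages of naive Bayes:
The model is rooted in classical mathematical theory and has stable classification performance.
It performs well on small data sets, handles multiclass tasks, and suits incremental training; in particular, when the data does not fit in memory it can be trained batch by batch.
It is not very sensitive to missing data, the algorithm is fairly simple, and it is commonly used for text classification.
Disadvantages:
Prior probabilities have to be estimated.
Classification decisions carry a certain error rate.
It is sensitive to how the input data is represented.
Because it assumes that sample attributes are independent, it performs poorly when the attributes are correlated.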
Naive Bayes has a broad range of applications. Text classification, spam filtering, and sentiment analysis are among the most common, and naive Bayes holds an established place in text classification. It can also be used to distinguish individuals belonging to mutually exclusive groups and to estimate the matrices used in decision-theoretic frameworks.
Example code:
import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a NaiveBayes model and uses it for classification. */
public class NaiveBayesExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(0, 0.), 11),
                        Row.of(Vectors.dense(1, 0), 10),
                        Row.of(Vectors.dense(1, 1.), 10));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(0, 1.)),
                        Row.of(Vectors.dense(0, 0.)),
                        Row.of(Vectors.dense(1, 0)),
                        Row.of(Vectors.dense(1, 1.)));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates a NaiveBayes object and initializes its parameters.
        NaiveBayes naiveBayes =
                new NaiveBayes()
                        .setSmoothing(1.0)
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setModelType("multinomial");

        // Trains the NaiveBayes Model.
        NaiveBayesModel naiveBayesModel = naiveBayes.fit(trainTable);

        // Uses the NaiveBayes Model for predictions.
        Table outputTable = naiveBayesModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(naiveBayes.getFeaturesCol());
            double predictionResult = (Double) row.getField(naiveBayes.getPredictionCol());
            System.out.printf("Features: %s Prediction Result: %s\n", features, predictionResult);
        }
    }
}
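II. Regression
Regression is one of the two main applications of supervised learning and produces continuous results. For example, a model trained on samples describing various attributes of people can answer the question "given a person's data, what will that person's financial capacity be 20 years from now". The result is continuous and is usually described by a regression curve: as the input independent variables change, the output dependent variable is not discretely distributed. A regression curve does not have to be a straight line; a polynomial curve is a regression curve as well.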
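1. Linear Regression: Linear regression is a commonly used regression algorithm. It fits a straight line or a plane to the data points of a data set in order to predict a continuous variable.
The basic idea of linear regression is to assume a linear relationship between the variables in the data set:
y = wx + b
Here y is the dependent variable, x is the independent variable, and w and b are the parameters to be determined, representing the slope and the intercept. The task of linear regression is to find, from the given data set, the w and b that minimize the error between the predicted value ŷ and the true value y.
Depending on the number of independent variables x, linear regression comes in two forms:
Simple linear regression: when there is a single x, that is, a single feature or attribute. The algorithm then finds a straight line in the two-dimensional plane that fits the data points.
Multiple linear regression: when there are several x, that is, several features or attributes. The algorithm then finds a plane or hyperplane in a higher-dimensional space that fits the data points.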
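The algorithm can be described as follows:
Input: data set D = {(x1,y1),(x2,y2),…,(xn,yn)}, where xi is a vector of independent variables and yi is a scalar dependent variable
Output: parameters w and b
Steps:
Initialize: choose random initial values for w and b
Iterate until the maximum number of iterations or a convergence condition is reached:
    Compute the predicted value ŷ = wx + b
    Compute the error between the predicted and the true value, e = ŷ - y
    Compute the sum of squared errors or the mean squared error as the loss function L(w,b)
    Update w and b with an optimization method such as gradient descent so that L(w,b) decreases (ordinary least squares can alternatively solve for w and b in closed form)
Return the current w and b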
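The steps above can be sketched in plain Java for the single-feature case. This is a minimal illustration, not the optimizer used by Flink ML: the class `LinearRegressionSketch` is hypothetical, the loss is the mean squared error, the parameters start at zero instead of random values so that runs are reproducible, and the learning rate and iteration count are arbitrary choices.

```java
/** Plain-Java illustration of simple linear regression by gradient descent (hypothetical helper). */
final class LinearRegressionSketch {

    /** Fits y ~ w * x + b by batch gradient descent on the mean squared error; returns {w, b}. */
    static double[] fit(double[] x, double[] y, double learningRate, int maxIter) {
        double w = 0.0; // zero initialization (assumption made for reproducibility)
        double b = 0.0;
        int n = x.length;
        for (int iter = 0; iter < maxIter; iter++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = (w * x[i] + b) - y[i]; // e = predicted - true
                gradW += 2.0 * error * x[i] / n; // dL/dw of the mean squared error
                gradB += 2.0 * error / n; // dL/db of the mean squared error
            }
            w -= learningRate * gradW;
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        // Points that lie exactly on y = 2x + 1.
        double[] x = {1.0, 2.0, 3.0, 4.0};
        double[] y = {3.0, 5.0, 7.0, 9.0};
        double[] params = fit(x, y, 0.05, 5000);
        System.out.printf("w = %.4f, b = %.4f%n", params[0], params[1]);
    }
}
```

With data generated from y = 2x + 1 the fitted parameters approach w = 2 and b = 1. A learning rate that is too large makes the updates diverge, which is why it is kept small here.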
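Linear regression is a simple and effective regression algorithm that can be used to predict continuous variables such as house prices, sales, or income. It also has some drawbacks:
It performs poorly on data sets with nonlinear relationships, because it assumes the relationship is linear.
It is sensitive to outliers and noise, which can distort the parameter estimates and the accuracy of predictions.
Multicollinear features can make the parameters unstable or lead to overfitting, so regularization or the selection of a suitable feature subset is needed.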
Example:
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.regression.linearregression.LinearRegression;
import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LinearRegression model and uses it for regression. */
public class LinearRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data: (features, label, weight).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                        Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                        Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                        Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                        Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                        Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                        Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                        Row.of(Vectors.dense(5, 3), 11.0, 1.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LinearRegression object and initializes its parameters.
        LinearRegression lr = new LinearRegression().setWeightCol("weight");

        // Trains the LinearRegression Model.
        LinearRegressionModel lrModel = lr.fit(inputTable);

        // Uses the LinearRegression Model for predictions.
        Table outputTable = lrModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
            double expectedResult = (Double) row.getField(lr.getLabelCol());
            double predictionResult = (Double) row.getField(lr.getPredictionCol());
            System.out.printf(
                    "Features: %s Expected Result: %s Prediction Result: %s\n",
                    features, expectedResult, predictionResult);
        }
    }
}
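III. Clustering
Clustering belongs to unsupervised learning. Clusters are defined by some distance or similarity between samples: similar (or nearby) samples are grouped into the same cluster, while dissimilar (or distant) samples are placed in other clusters.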
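1. K-means: K-means is an iterative clustering algorithm. Its goal is to partition a data set into K clusters so that the data points within each cluster are as similar as possible and the points in different clusters are as different as possible.
Advantages of K-means:
The algorithm is simple and easy to implement.
It remains scalable and efficient on large data sets.
It clusters well when the clusters are dense and clearly separated from one another.
K-means is simple and effective, but it also has limitations:
The value of K has to be chosen in advance, and in practice a good K is often hard to determine.
It is sensitive to the choice of the initial cluster centers; different initial centers can lead to different clustering results.
It is sensitive to noise and outliers, which can degrade clustering quality.
It works poorly for non-convex clusters or clusters of very different sizes, because it implicitly assumes that each cluster is roughly spherical.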
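K-means has many applications, for example:
Document classification: documents on the same topic can be grouped together, and columns for different topics can be generated automatically.
User segmentation: users can be divided into groups according to their behavior, preferences, attributes, and other features, enabling personalized recommendation, marketing, and services.
Image segmentation: the pixels of an image can be clustered by color or gray level to segment, compress, or enhance the image.
Anomaly detection: points that lie far from every cluster center, or that end up in very small clusters, can be flagged as anomalies or noise and then filtered or handled separately.
Data reduction: similar points in high-dimensional data can be grouped and represented by their cluster center, which reduces the amount of data to process and store.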
Example code:
import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a KMeans model and uses it for clustering. */
public class KMeansExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(0.0, 0.0),
                        Vectors.dense(0.0, 0.3),
                        Vectors.dense(0.3, 0.0),
                        Vectors.dense(9.0, 0.0),
                        Vectors.dense(9.0, 0.6),
                        Vectors.dense(9.6, 0.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates a K-means object and initializes its parameters.
        KMeans kmeans = new KMeans().setK(2).setSeed(1L);

        // Trains the K-means Model.
        KMeansModel kmeansModel = kmeans.fit(inputTable);

        // Uses the K-means Model for predictions.
        Table outputTable = kmeansModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(kmeans.getFeaturesCol());
            int clusterId = (Integer) row.getField(kmeans.getPredictionCol());
            System.out.printf("Features: %s Cluster ID: %s\n", features, clusterId);
        }
    }
}
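The iteration behind K-means, alternating an assignment step and a centroid update step, can be sketched in plain Java without Flink ML. The class `KMeansSketch` is a hypothetical helper. As a simplification, the caller supplies the initial centroids instead of having them chosen at random, and the centroid array is updated in place.

```java
import java.util.Arrays;

/** Plain-Java illustration of the K-means iteration (hypothetical helper, not Flink ML code). */
final class KMeansSketch {

    /**
     * Runs K-means from the given initial centroids and returns the cluster index of each point.
     * The centroids array is modified in place and holds the final cluster centers afterwards.
     */
    static int[] cluster(double[][] points, double[][] centroids, int maxIter) {
        int[] assignment = new int[points.length];
        for (int iter = 0; iter < maxIter; iter++) {
            // Assignment step: attach every point to its nearest centroid.
            boolean changed = false;
            for (int p = 0; p < points.length; p++) {
                int nearest = 0;
                for (int c = 1; c < centroids.length; c++) {
                    if (dist2(points[p], centroids[c]) < dist2(points[p], centroids[nearest])) {
                        nearest = c;
                    }
                }
                if (assignment[p] != nearest) {
                    assignment[p] = nearest;
                    changed = true;
                }
            }
            // Stop once an iteration leaves all assignments unchanged.
            if (iter > 0 && !changed) {
                break;
            }
            // Update step: move each centroid to the mean of the points assigned to it.
            for (int c = 0; c < centroids.length; c++) {
                double[] sum = new double[points[0].length];
                int count = 0;
                for (int p = 0; p < points.length; p++) {
                    if (assignment[p] == c) {
                        for (int d = 0; d < sum.length; d++) {
                            sum[d] += points[p][d];
                        }
                        count++;
                    }
                }
                if (count > 0) {
                    for (int d = 0; d < sum.length; d++) {
                        centroids[c][d] = sum[d] / count;
                    }
                }
            }
        }
        return assignment;
    }

    static double dist2(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] points = {{0.0, 0.0}, {0.0, 0.3}, {0.3, 0.0}, {9.0, 0.0}, {9.0, 0.6}, {9.6, 0.0}};
        double[][] centroids = {{0.0, 0.0}, {9.0, 0.0}}; // hand-picked initial centers
        System.out.println(Arrays.toString(cluster(points, centroids, 10))); // [0, 0, 0, 1, 1, 1]
    }
}
```

Passing different initial centroids is an easy way to observe the sensitivity to initialization mentioned among the limitations.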
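2. AgglomerativeClustering
AgglomerativeClustering is a hierarchical clustering algorithm. Its goal is to partition a data set into clusters at different levels, forming a tree-shaped cluster structure. The algorithm works bottom-up: it starts with every data point as its own cluster and then repeatedly merges the two closest clusters into a new one until a preset number of clusters is reached or some stopping condition is met.
The key question in AgglomerativeClustering is how to compute the distance or similarity between clusters. Three common linkage criteria are:
Single linkage: the distance between two clusters is the distance between their two closest data points, one from each cluster. This tends to produce chain-like clusters.
Complete linkage: the distance between two clusters is the distance between their two farthest data points, one from each cluster. This tends to produce compact clusters.
Average linkage: the distance between two clusters is the average of the distances between all pairs of data points with one point from each cluster. This tends to produce balanced clusters.
The example code below uses a fourth criterion, Ward linkage (LINKAGE_WARD), which merges the pair of clusters whose union increases the within-cluster variance the least.
The algorithm can be described as follows:
Input: data set D, number of clusters K or a stopping condition
Output: cluster partition C
Steps:
Initialize: treat every data point as one cluster and build an n × n distance matrix, where n is the number of data points
Iterate until K clusters remain or the stopping condition is met:
    Find the two clusters Ci and Cj with the smallest distance
    Merge Ci and Cj into a new cluster Ck
    Update the distance matrix: delete the rows and columns of Ci and Cj and add the row and column of Ck
Return the current cluster partition C
AgglomerativeClustering is a commonly used hierarchical clustering algorithm. It reveals the hierarchical structure of a data set and does not require the number of clusters to be fixed in advance. It also has some drawbacks:
The computational cost is high: a straightforward implementation needs O(n^3) time and O(n^2) space.
It is sensitive to noise and outliers, which can degrade clustering quality.
Depending on the linkage, it can handle clusters of very different shapes or sizes poorly; complete and Ward linkage, for instance, favor compact, roughly spherical clusters.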
Example code:
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
public class AgglomerativeClusteringExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(1, 1),
                        Vectors.dense(1, 4),
                        Vectors.dense(1, 0),
                        Vectors.dense(4, 1.5),
                        Vectors.dense(4, 4),
                        Vectors.dense(4, 0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates an AgglomerativeClustering object and initializes its parameters.
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
                        .setPredictionCol("prediction");

        // Uses the AgglomerativeClustering object for clustering.
        Table[] outputs = agglomerativeClustering.transform(inputTable);

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputs[0].execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features =
                    (DenseVector) row.getField(agglomerativeClustering.getFeaturesCol());
            int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol());
            System.out.printf("Features: %s Cluster ID: %s\n", features, clusterId);
        }
    }
}
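The bottom-up merge loop of agglomerative clustering can be sketched in plain Java. This sketch uses single linkage with Euclidean distance, not the Ward linkage of the Flink ML example, and for brevity it recomputes cluster distances on every pass instead of maintaining a distance matrix. The class `AgglomerativeSketch` is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java illustration of single-linkage agglomerative clustering (hypothetical helper). */
final class AgglomerativeSketch {

    /** Merges the two closest clusters repeatedly until only k clusters remain. */
    static List<List<double[]>> cluster(double[][] points, int k) {
        // Initialization: every point starts as its own cluster.
        List<List<double[]>> clusters = new ArrayList<>();
        for (double[] p : points) {
            List<double[]> single = new ArrayList<>();
            single.add(p);
            clusters.add(single);
        }
        while (clusters.size() > Math.max(k, 1)) {
            // Find the pair of clusters with the smallest single-linkage distance.
            int bestI = 0;
            int bestJ = 1;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int i = 0; i < clusters.size(); i++) {
                for (int j = i + 1; j < clusters.size(); j++) {
                    double d = singleLinkage(clusters.get(i), clusters.get(j));
                    if (d < bestDist) {
                        bestDist = d;
                        bestI = i;
                        bestJ = j;
                    }
                }
            }
            // Merge cluster bestJ into cluster bestI (bestJ > bestI, so bestI stays valid).
            clusters.get(bestI).addAll(clusters.remove(bestJ));
        }
        return clusters;
    }

    /** Single linkage: distance between the two closest points, one from each cluster. */
    static double singleLinkage(List<double[]> a, List<double[]> b) {
        double min = Double.POSITIVE_INFINITY;
        for (double[] p : a) {
            for (double[] q : b) {
                min = Math.min(min, Math.hypot(p[0] - q[0], p[1] - q[1]));
            }
        }
        return min;
    }

    public static void main(String[] args) {
        double[][] points = {{0, 0}, {0, 1}, {5, 0}, {5, 1.5}, {20, 0}};
        for (List<double[]> c : cluster(points, 2)) {
            System.out.println("cluster of size " + c.size());
        }
    }
}
```

For the five two-dimensional points in `main`, the four points on the left end up in one cluster and the distant point (20, 0) remains alone. Swapping `singleLinkage` for a maximum or an average over the same point pairs would give complete or average linkage.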
IV. Evaluation
Evaluation algorithms in machine learning are methods for assessing the performance and effectiveness of machine learning models.
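1. Binary Classification Evaluator
The Binary Classification Evaluator assesses the performance of binary classification models. Common metrics for such models include accuracy, precision, recall, F1 score, and AUC. The example below computes the area under the ROC curve, the area under the precision-recall curve, and the Kolmogorov-Smirnov (KS) statistic.
The basic idea is to compute evaluation metrics from the model's predictions and the true labels and to judge the quality of the model by the values of those metrics. It can be described as follows:
Input: data set D = {(p1,y1),(p2,y2),…,(pn,yn)}, where pi is the raw prediction of the model for sample i and yi is the true label
Output: evaluation metrics
Steps:
Initialize: create a BinaryClassificationEvaluator object and set parameters such as rawPredictionCol, labelCol, weightCol, and metricsNames
Compute: call the transform method with the data set D to obtain the values of the evaluation metrics
Return: output the values of the evaluation metrics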
Example:
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluatorParams;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

/**
 * Simple program that creates a BinaryClassificationEvaluator instance and uses it for evaluation.
 */
public class BinaryClassificationEvaluatorExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data: (label, rawPrediction).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(1.0, Vectors.dense(0.1, 0.9)),
                        Row.of(1.0, Vectors.dense(0.2, 0.8)),
                        Row.of(1.0, Vectors.dense(0.3, 0.7)),
                        Row.of(0.0, Vectors.dense(0.25, 0.75)),
                        Row.of(0.0, Vectors.dense(0.4, 0.6)),
                        Row.of(1.0, Vectors.dense(0.35, 0.65)),
                        Row.of(1.0, Vectors.dense(0.45, 0.55)),
                        Row.of(0.0, Vectors.dense(0.6, 0.4)),
                        Row.of(0.0, Vectors.dense(0.7, 0.3)),
                        Row.of(1.0, Vectors.dense(0.65, 0.35)),
                        Row.of(0.0, Vectors.dense(0.8, 0.2)),
                        Row.of(1.0, Vectors.dense(0.9, 0.1)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("label", "rawPrediction");

        // Creates a BinaryClassificationEvaluator object and initializes its parameters.
        BinaryClassificationEvaluator evaluator =
                new BinaryClassificationEvaluator()
                        .setMetricsNames(
                                BinaryClassificationEvaluatorParams.AREA_UNDER_PR,
                                BinaryClassificationEvaluatorParams.KS,
                                BinaryClassificationEvaluatorParams.AREA_UNDER_ROC);

        // Uses the BinaryClassificationEvaluator object for evaluations.
        Table outputTable = evaluator.transform(inputTable)[0];

        // Extracts and displays the results.
        Row evaluationResult = outputTable.execute().collect().next();
        System.out.printf(
                "Area under the precision-recall curve: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_PR));
        System.out.printf(
                "Area under the receiver operating characteristic curve: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC));
        System.out.printf(
                "Kolmogorov-Smirnov value: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.KS));
    }
}
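To illustrate what the area under the ROC curve measures, the following plain-Java sketch computes it by comparing every positive sample with every negative sample: the result is the fraction of such pairs in which the positive sample receives the higher score, with ties counting as half. This quadratic pairwise formulation is chosen for clarity and is not claimed to be how Flink ML computes the metric. The class `AucSketch` is hypothetical, and `scores` is assumed to hold the predicted score of the positive class.

```java
/** Plain-Java illustration of the area under the ROC curve (hypothetical helper). */
final class AucSketch {

    /**
     * Returns the fraction of (positive, negative) pairs ranked correctly by the scores.
     * Labels are 1.0 for positive and anything else for negative; ties count as half.
     */
    static double areaUnderRoc(double[] labels, double[] scores) {
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1.0) {
                continue; // i must be a positive sample
            }
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] == 1.0) {
                    continue; // j must be a negative sample
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        // Undefined when one of the two classes is missing.
        return pairs == 0 ? Double.NaN : wins / pairs;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 1.0, 0.0, 0.0};
        double[] scores = {0.9, 0.4, 0.6, 0.1};
        System.out.println(areaUnderRoc(labels, scores)); // prints 0.75
    }
}
```

A value of 1.0 means every positive sample outranks every negative one, while 0.5 corresponds to random ranking.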
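V. Recommendation
1. Swing
Swing is an algorithm used for candidate recall in recommender systems. It is an original algorithm that Alibaba used early on and that has proven very effective as a recall method in many of Alibaba's business scenarios. It is based on the view that the user-item-user structure is more stable than the one-sided structure of itemCF and better reflects the similarity between items.
The basic idea of Swing is that two items are more similar the more users like both of them and the less those users otherwise overlap. For every pair of users who both like the two items, the algorithm adds a term that is the reciprocal of the number of items this pair of users has in common (plus a smoothing constant); the sum over all such user pairs is the similarity of the two items.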
Example code:
package org.apache.flink.ml.examples.recommendation;

import org.apache.flink.ml.recommendation.swing.Swing;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates a Swing instance and uses it to generate recommendations for items.
 */
public class SwingExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data: (user, item).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(0L, 10L),
                        Row.of(0L, 11L),
                        Row.of(0L, 12L),
                        Row.of(1L, 13L),
                        Row.of(1L, 12L),
                        Row.of(2L, 10L),
                        Row.of(2L, 11L),
                        Row.of(2L, 12L),
                        Row.of(3L, 13L),
                        Row.of(3L, 12L));
        Table inputTable = tEnv.fromDataStream(inputStream).as("user", "item");

        // Creates a Swing object and initializes its parameters.
        Swing swing = new Swing().setUserCol("user").setItemCol("item").setMinUserBehavior(1);

        // Transforms the data.
        Table[] outputTable = swing.transform(inputTable);

        // Extracts and displays the result of the Swing algorithm.
        for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            long mainItem = row.getFieldAs(0);
            String itemRankScore = row.getFieldAs(1);
            System.out.printf("item: %d, top-k similar items: %s\n", mainItem, itemRankScore);
        }
    }
}
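The similarity idea behind Swing can be sketched in plain Java. The sketch implements only the basic, unweighted form: for each pair of users who both like items i and j, it adds 1 / (alpha + number of items the two users share). The class `SwingSketch`, the map layout, and the smoothing constant `alpha` are assumptions made for this illustration; the Flink ML implementation may apply additional weighting and filtering.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Plain-Java illustration of the basic Swing item similarity (hypothetical helper). */
final class SwingSketch {

    /** Sums 1 / (alpha + |items shared by u and v|) over all user pairs that like both items. */
    static double similarity(
            Map<Long, Set<Long>> itemsByUser, long itemI, long itemJ, double alpha) {
        // Collect the item sets of all users who like both items.
        List<Set<Long>> common = new ArrayList<>();
        for (Set<Long> items : itemsByUser.values()) {
            if (items.contains(itemI) && items.contains(itemJ)) {
                common.add(items);
            }
        }
        double score = 0.0;
        for (int a = 0; a < common.size(); a++) {
            for (int b = a + 1; b < common.size(); b++) {
                // The more items two users share overall, the less their agreement counts.
                Set<Long> overlap = new HashSet<>(common.get(a));
                overlap.retainAll(common.get(b));
                score += 1.0 / (alpha + overlap.size());
            }
        }
        return score;
    }

    public static void main(String[] args) {
        Map<Long, Set<Long>> itemsByUser = new HashMap<>();
        itemsByUser.put(0L, new HashSet<>(Arrays.asList(10L, 11L, 12L)));
        itemsByUser.put(1L, new HashSet<>(Arrays.asList(13L, 12L)));
        itemsByUser.put(2L, new HashSet<>(Arrays.asList(10L, 11L, 12L)));
        itemsByUser.put(3L, new HashSet<>(Arrays.asList(13L, 12L)));
        System.out.println(similarity(itemsByUser, 10L, 11L, 1.0)); // prints 0.25
        System.out.println(similarity(itemsByUser, 10L, 13L, 1.0)); // prints 0.0
    }
}
```

In the sample data, items 10 and 11 are both liked by users 0 and 2, who share three items, which gives 1 / (1 + 3) = 0.25. Items 10 and 13 have no user in common, so their similarity is 0.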