news 2026/1/31 7:10:35

Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Flink ML LinearRegression 用 Table API 训练线性回归并输出预测值

1. LinearRegression 做什么?

它学习一个线性函数(简化表达):

[y^=w⋅x+b][ \hat{y} = w \cdot x + b ][y^=wx+b]

其中:

  • (x)(x)(x)是特征向量(Vector)
  • (w)(w)(w)是权重向量(模型参数)
  • (b)(b)(b)是偏置
  • (y^)(\hat{y})(y^)是预测值(prediction)

2. 输入列与输出列

输入列(Input Columns)

参数名类型默认值说明
featuresColVector"features"特征向量
labelColInteger(示例实际用 Double)"label"要预测的连续标签
weightColDouble"weight"样本权重(可选)

你贴的示例里labelweight都是 Double,这更符合回归场景。

输出列(Output Columns)

参数名类型默认值说明
predictionColInteger(实际为 Double)"prediction"预测值

3. 参数详解(Parameters)

LinearRegressionModel(模型)常用参数:

Key默认值说明
featuresCol"features"特征列名
predictionCol"prediction"预测列名

LinearRegression(训练器)额外参数:

Key默认值说明
labelCol"label"标签列名
weightColnull权重列名(可选)
maxIter20最大迭代次数
reg0.0正则化系数(L2/L1 混合由 elasticNet 控制)
elasticNet0.0ElasticNet 参数:0=纯L2,1=纯L1
learningRate0.1学习率
globalBatchSize32全局 batch 大小
tol1e-6收敛阈值(迭代停止条件之一)

工程建议(很实用):

  • 特征尺度差异大:先StandardScaler再 LR,收敛更稳
  • 数据量大/噪声大:适当加reg,防过拟合、也更稳定
  • 收敛慢:调大maxIter或调学习率(learningRate)但要谨慎

4. Java 示例逐段解读(fit + transform)

你给的示例是“训练后再对同一份数据做预测”,方便验证效果。

4.1 构造输入表:features + label + weight

DataStream<Row>inputStream=env.fromElements(Row.of(Vectors.dense(2,1),4.0,1.0),Row.of(Vectors.dense(3,2),7.0,1.0),Row.of(Vectors.dense(4,3),10.0,1.0),Row.of(Vectors.dense(2,4),10.0,1.0),Row.of(Vectors.dense(2,2),6.0,1.0),Row.of(Vectors.dense(4,3),10.0,1.0),Row.of(Vectors.dense(1,2),5.0,1.0),Row.of(Vectors.dense(5,3),11.0,1.0));TableinputTable=tEnv.fromDataStream(inputStream).as("features","label","weight");
  • features是二维向量
  • label是回归目标(Double)
  • weight全是 1.0(表示每个样本权重一样)

4.2 创建训练器并指定权重列

LinearRegressionlr=newLinearRegression().setWeightCol("weight");

如果你不需要样本权重,完全可以不设weightCol

4.3 训练模型 + 预测

LinearRegressionModellrModel=lr.fit(inputTable);TableoutputTable=lrModel.transform(inputTable)[0];
  • fit():在 Table 上训练,得到模型
  • transform():输出表会多一个prediction

4.4 读取输出:对比 label 与 prediction

doubleexpectedResult=(Double)row.getField(lr.getLabelCol());doublepredictionResult=(Double)row.getField(lr.getPredictionCol());System.out.printf("... Expected Result: %s \tPrediction Result: %s\n",expectedResult,predictionResult);

5. 实战用法:最常见的两条链路

链路 A:数值特征 → StandardScaler → LinearRegression

如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:

  • VectorAssembler(拼特征)
  • StandardScaler(标准化)
  • LinearRegression(训练)

链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression

当你的输入既有类别又有数值时:

  • StringIndexer:字符串类别 → index
  • OneHotEncoder:index → 稀疏向量
  • VectorAssembler:数值列 + 类别向量 拼成 features
  • LinearRegression:训练回归

6. 小结

Flink ML 的 LinearRegression 使用非常标准化:

  • 输入:features(Vector)+label(Double)+ 可选weight(Double)
  • 训练:lr.fit(table)
  • 预测:model.transform(table)输出prediction
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/29 11:07:54

ViGEmBus虚拟游戏手柄驱动:让所有手柄在PC上畅玩游戏的终极指南

ViGEmBus虚拟游戏手柄驱动&#xff1a;让所有手柄在PC上畅玩游戏的终极指南 【免费下载链接】ViGEmBus 项目地址: https://gitcode.com/gh_mirrors/vig/ViGEmBus 你是否曾经遇到过这样的困扰&#xff1a;心爱的手柄连接电脑后&#xff0c;游戏却完全无法识别&#xff1…

作者头像 李华
网站建设 2026/1/29 16:54:27

PyTorch-CUDA-v2.8镜像对GPT系列模型的兼容性测试

PyTorch-CUDA-v2.8镜像对GPT系列模型的兼容性测试 在当前大模型研发如火如荼的背景下&#xff0c;一个稳定、高效且开箱即用的深度学习运行环境&#xff0c;已经成为AI工程师日常开发中的“刚需”。尤其是在训练和部署GPT类大规模语言模型时&#xff0c;动辄数十GB显存占用、复…

作者头像 李华
网站建设 2026/1/30 1:30:02

GDP-D-甘露糖二钠盐 —— 糖基化研究与治疗开发的核心糖核苷酸 148296-46-2

GDP-D-甘露糖二钠盐是糖核苷酸家族中至关重要的成员&#xff0c;在细胞糖基化进程中扮演着不可替代的角色。作为甘露糖残基的关键活化供体&#xff0c;它直接参与蛋白质和脂质的翻译后修饰&#xff0c;影响其结构、稳定性与生物功能。从基础生物化学研究到前沿生物制药开发&…

作者头像 李华