1. LinearRegression 做什么?
它学习一个线性函数(简化表达):
[y^=w⋅x+b][ \hat{y} = w \cdot x + b ][y^=w⋅x+b]
其中:
- (x)(x)(x)是特征向量(Vector)
- (w)(w)(w)是权重向量(模型参数)
- (b)(b)(b)是偏置
- (y^)(\hat{y})(y^)是预测值(prediction)
2. 输入列与输出列
输入列(Input Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
featuresCol | Vector | "features" | 特征向量 |
labelCol | Integer(示例实际用 Double) | "label" | 要预测的连续标签 |
weightCol | Double | "weight" | 样本权重(可选) |
你贴的示例里
label和weight都是 Double,这更符合回归场景。
输出列(Output Columns)
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
predictionCol | Integer(实际为 Double) | "prediction" | 预测值 |
3. 参数详解(Parameters)
LinearRegressionModel(模型)常用参数:
| Key | 默认值 | 说明 |
|---|---|---|
featuresCol | "features" | 特征列名 |
predictionCol | "prediction" | 预测列名 |
LinearRegression(训练器)额外参数:
| Key | 默认值 | 说明 |
|---|---|---|
labelCol | "label" | 标签列名 |
weightCol | null | 权重列名(可选) |
maxIter | 20 | 最大迭代次数 |
reg | 0.0 | 正则化系数(L2/L1 混合由 elasticNet 控制) |
elasticNet | 0.0 | ElasticNet 参数:0=纯L2,1=纯L1 |
learningRate | 0.1 | 学习率 |
globalBatchSize | 32 | 全局 batch 大小 |
tol | 1e-6 | 收敛阈值(迭代停止条件之一) |
工程建议(很实用):
- 特征尺度差异大:先
StandardScaler再 LR,收敛更稳 - 数据量大/噪声大:适当加
reg,防过拟合、也更稳定 - 收敛慢:调大
maxIter或调学习率(learningRate)但要谨慎
4. Java 示例逐段解读(fit + transform)
你给的示例是“训练后再对同一份数据做预测”,方便验证效果。
4.1 构造输入表:features + label + weight
DataStream<Row>inputStream=env.fromElements(Row.of(Vectors.dense(2,1),4.0,1.0),Row.of(Vectors.dense(3,2),7.0,1.0),Row.of(Vectors.dense(4,3),10.0,1.0),Row.of(Vectors.dense(2,4),10.0,1.0),Row.of(Vectors.dense(2,2),6.0,1.0),Row.of(Vectors.dense(4,3),10.0,1.0),Row.of(Vectors.dense(1,2),5.0,1.0),Row.of(Vectors.dense(5,3),11.0,1.0));TableinputTable=tEnv.fromDataStream(inputStream).as("features","label","weight");features是二维向量label是回归目标(Double)weight全是 1.0(表示每个样本权重一样)
4.2 创建训练器并指定权重列
LinearRegressionlr=newLinearRegression().setWeightCol("weight");如果你不需要样本权重,完全可以不设weightCol。
4.3 训练模型 + 预测
LinearRegressionModellrModel=lr.fit(inputTable);TableoutputTable=lrModel.transform(inputTable)[0];fit():在 Table 上训练,得到模型transform():输出表会多一个prediction列
4.4 读取输出:对比 label 与 prediction
doubleexpectedResult=(Double)row.getField(lr.getLabelCol());doublepredictionResult=(Double)row.getField(lr.getPredictionCol());System.out.printf("... Expected Result: %s \tPrediction Result: %s\n",expectedResult,predictionResult);5. 实战用法:最常见的两条链路
链路 A:数值特征 → StandardScaler → LinearRegression
如果你的特征量纲差别大(金额、次数、时长混在一起),强烈推荐:
- VectorAssembler(拼特征)
- StandardScaler(标准化)
- LinearRegression(训练)
链路 B:类别特征(StringIndexer + OneHot)+ 数值特征 → VectorAssembler → LinearRegression
当你的输入既有类别又有数值时:
- StringIndexer:字符串类别 → index
- OneHotEncoder:index → 稀疏向量
- VectorAssembler:数值列 + 类别向量 拼成 features
- LinearRegression:训练回归
6. 小结
Flink ML 的 LinearRegression 使用非常标准化:
- 输入:
features(Vector)+label(Double)+ 可选weight(Double) - 训练:
lr.fit(table) - 预测:
model.transform(table)输出prediction