Spark ML (3): Implementing Regression Algorithms (Linear Regression, Logistic Regression)
Published: 2019-05-27


I. Environment

1. spark2.1.0-cdh5.7.0 (self-compiled)

2. cdh5.7.0

3. scala2.11.8

4. centos6.4

II. Preparation

1. Set up a Spark client debugging environment

Reference:

2. Create a Scala project

Reference:

3. Add the pom dependencies

```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>sparktest</groupId>
  <artifactId>sparktest</artifactId>
  <version>1.0-SNAPSHOT</version>
  <inceptionYear>2008</inceptionYear>

  <properties>
    <scala.version>2.11.8</scala.version>
    <kafka.version>0.9.0.0</kafka.version>
    <hbase.version>1.2.0-cdh5.7.0</hbase.version>
    <spark.version>2.1.0</spark.version>
    <hadoop.version>2.6.0-cdh5.7.0</hadoop.version>
  </properties>

  <repositories>
    <repository>
      <id>cloudera</id>
      <url>https://repository.cloudera.com/artifactory/cloudera-repos/</url>
    </repository>
  </repositories>

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>2.1.0</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.11</artifactId>
      <version>2.1.0</version>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>
  </dependencies>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
          <args>
            <arg>-target:jvm-1.5</arg>
          </args>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-eclipse-plugin</artifactId>
        <configuration>
          <downloadSources>true</downloadSources>
          <buildcommands>
            <buildcommand>ch.epfl.lamp.sdt.core.scalabuilder</buildcommand>
          </buildcommands>
          <additionalProjectnatures>
            <projectnature>ch.epfl.lamp.sdt.core.scalanature</projectnature>
          </additionalProjectnatures>
          <classpathContainers>
            <classpathContainer>org.eclipse.jdt.launching.JRE_CONTAINER</classpathContainer>
            <classpathContainer>ch.epfl.lamp.sdt.launching.SCALA_CONTAINER</classpathContainer>
          </classpathContainers>
        </configuration>
      </plugin>
    </plugins>
  </build>

  <reporting>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
        </configuration>
      </plugin>
    </plugins>
  </reporting>
</project>
```

III. Implementation

1. Sample test data

```
position;square;price;direction;type;name
0;190;20000;0;4室2厅2卫;中信城(别墅)
0;190;20000;0;4室2厅2卫;中信城(别墅)
5;400;15000;0;4室3厅3卫;融创上城
0;500;15000;0;5室3厅2卫;中海莱茵东郡
5;500;15000;0;5室3厅4卫;融创上城(别墅)
1;320;15000;1;1室1厅1卫;长江花园
0;143;12000;0;3室2厅2卫;融创上城
0;200;10000;0;4室3厅2卫;中海莱茵东郡(别墅)
0;207;9000;0;4室3厅4卫;中海莱茵东郡
0;130;8500;0;3室2厅2卫;伟峰东第
5;150;7000;0;3室2厅2卫;融创上城
2;178;6000;0;4室2厅2卫;鸿城国际花园
5;190;6000;0;3室2厅2卫;亚泰豪苑C栋
1;150;6000;0;5室1厅2卫;通安新居A区
2;165;6000;0;3室2厅2卫;万科惠斯勒小镇
0;64;5500;0;1室1厅1卫;保利中央公园
2;105;5500;0;2室2厅1卫;虹馆
1;160;5300;0;3室2厅1卫;昊源高格蓝湾
2;170;5100;0;4室2厅2卫;亚泰鼎盛国际
...
```
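Before handing the file to Spark, it can help to sanity-check how one record splits on the `;` separator. A minimal pure-Scala sketch (the `House` case class and `HouseParser` object are illustrative names, not part of the project above):

```scala
// Parse one semicolon-separated record from the sample above into a typed record.
object HouseParser {
  case class House(position: Int, square: Double, price: Double,
                   direction: Int, rooms: String, name: String)

  def parse(line: String): House = {
    val f = line.split(";")
    House(f(0).toInt, f(1).toDouble, f(2).toDouble, f(3).toInt, f(4), f(5))
  }

  def main(args: Array[String]): Unit = {
    // First data row of the sample file
    println(parse("0;190;20000;0;4室2厅2卫;中信城(别墅)"))
  }
}
```

Note that `square` and `price` arrive as strings; the Spark code below performs the same `toDouble` conversion after loading the CSV.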

2. Linear regression

```scala
package sparktest

import org.apache.spark.SparkConf
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

import scala.util.Random

object Main {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("linear").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val file = spark.read.format("csv")
      .option("sep", ";")
      .option("header", "true")
      .load("house.csv")

    import spark.implicits._

    // Shuffle the rows and cast the string columns to Double
    val rand = new Random()
    val data = file.select("square", "price").map(
        row => (row.getAs[String](0).toDouble, row.getString(1).toDouble, rand.nextDouble()))
      .toDF("square", "price", "rand").sort("rand")

    // Assemble the feature column(s) into a single vector column
    val ass = new VectorAssembler().setInputCols(Array("square")).setOutputCol("features")
    val dataset = ass.transform(data)

    // Split into training and test sets
    val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

    // Linear regression
    val lr = new LinearRegression().setStandardization(true).setMaxIter(10)
      .setFeaturesCol("features")
      .setLabelCol("price")
    val model = lr.fit(train)
    model.transform(test).show()
  }
}
```

Result:

```
+------+-------+--------------------+--------+------------------+
|square|  price|                rand|features|        prediction|
+------+-------+--------------------+--------+------------------+
|  64.0| 1400.0|0.006545997025104056|  [64.0]| 1871.626837251842|
| 100.0| 2600.0|0.006070889056102979| [100.0]|1910.2354430535167|
|  10.0|  450.0|0.016279200291292373|  [10.0]|  1813.71392854933|
|  60.0| 1700.0| 0.01773595114007931|  [60.0]| 1867.336992162767|
| 150.0| 7000.0| 0.01799868562447937| [150.0]|1963.8585066669539|
| 320.0|15000.0|0.022266421888484267| [320.0]|  2146.17692295264|
|  42.0| 1200.0| 0.03158087604172155|  [42.0]|1848.0326892619296|
|  88.0| 1800.0|  0.0340325321865349|  [88.0]| 1897.365907786292|
|  60.0| 1600.0| 0.05842848976518067|  [60.0]| 1867.336992162767|
| 154.0| 3700.0| 0.08695690147815338| [154.0]| 1968.148351756029|
|  96.0| 2500.0| 0.08956069761188501|  [96.0]|1905.9455979644417|
|  63.0| 1400.0|  0.1058435529752908|  [63.0]|1870.5543759795733|
| 100.0| 2200.0| 0.12881102655257837| [100.0]|1910.2354430535167|
|  20.0|  500.0| 0.13298961275676147|  [20.0]|1824.4385412720173|
| 148.0| 2800.0|  0.1347286681517027| [148.0]|1961.7135841224165|
|  92.0| 2950.0|  0.1418523082181563|  [92.0]|1901.6557528753667|
|  96.0| 2300.0| 0.14272158486886666|  [96.0]|1905.9455979644417|
|  70.0| 1500.0| 0.14371869433210183|  [70.0]|1878.0616048854545|
|  18.0|  600.0|  0.1523863581129299|  [18.0]|  1822.29361872748|
|  20.0|  800.0| 0.17536250365755057|  [20.0]|1824.4385412720173|
+------+-------+--------------------+--------+------------------+
only showing top 20 rows
```
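With a single feature, the model Spark fits is just a line `price ≈ a + b * square`. For intuition, here is the ordinary-least-squares closed form in plain Scala, no Spark required. The object name `OlsSketch` and the sample numbers are illustrative only (loosely shaped like the `(square, price)` pairs above), and this ignores the regularization and iterative solver Spark actually uses:

```scala
// Illustrative only: closed-form ordinary least squares for one feature,
// i.e. the line that a single-feature LinearRegression fit approximates.
object OlsSketch {
  /** Returns (intercept, slope) minimizing squared error of y ≈ a + b*x. */
  def fit(x: Seq[Double], y: Seq[Double]): (Double, Double) = {
    val n  = x.size.toDouble
    val mx = x.sum / n
    val my = y.sum / n
    val slope = x.zip(y).map { case (xi, yi) => (xi - mx) * (yi - my) }.sum /
                x.map(xi => (xi - mx) * (xi - mx)).sum
    (my - slope * mx, slope)
  }

  def main(args: Array[String]): Unit = {
    // Made-up sample resembling (square, price) rows from the dataset
    val square = Seq(10.0, 60.0, 100.0, 150.0)
    val price  = Seq(450.0, 1600.0, 2600.0, 7000.0)
    val (a, b) = fit(square, price)
    println(f"price ≈ $a%.2f + $b%.2f * square")
  }
}
```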

3. Logistic regression

```scala
package sparktest

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

import scala.util.Random

object Main {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("logistic").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    val file = spark.read.format("csv")
      .option("sep", ";")
      .option("header", "true")
      .load("house.csv")

    import spark.implicits._

    // Shuffle the rows and cast the string columns to Double
    val rand = new Random()
    val data = file.select("square", "price").map(
        row => (row.getAs[String](0).toDouble, row.getString(1).toDouble, rand.nextDouble()))
      .toDF("square", "price", "rand").sort("rand")

    // Assemble the feature column(s) into a single vector column
    val ass = new VectorAssembler().setInputCols(Array("square")).setOutputCol("features")
    val dataset = ass.transform(data)

    // Split into training and test sets
    val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

    // Logistic regression
    val lr = new LogisticRegression().setLabelCol("price").setFeaturesCol("features")
      .setRegParam(0.3).setElasticNetParam(0.8).setMaxIter(10)
    val model = lr.fit(train)
    model.transform(test).show()

    val s = model.summary.totalIterations
    println(s"iter: ${s}")
  }
}
```

Result:

```
+------+------+--------------------+--------+--------------------+--------------------+----------+
|square| price|                rand|features|       rawPrediction|         probability|prediction|
+------+------+--------------------+--------+--------------------+--------------------+----------+
|  43.0|1600.0|0.001716305364737214|  [43.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0|0.010588427013326074|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  61.0|1300.0|0.043301012076277345|  [61.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0| 0.05231439852503761|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  20.0| 600.0| 0.05386280768045393|  [20.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  89.0|2800.0|  0.0650227911769532|  [89.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1500.0| 0.06793901574354433|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  90.0|2500.0| 0.07541330585084804|  [90.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  65.0|1300.0| 0.07727780227514891|  [65.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  25.0|1300.0| 0.09515681816587895|  [25.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  50.0|1400.0| 0.08681645057310305|  [50.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0| 0.10042920576336689|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  41.0|1500.0| 0.11564005495013441|  [41.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  92.0|2950.0| 0.11751726539452112|  [92.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  71.0|1600.0| 0.12520507959550664|  [71.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
| 110.0|2700.0| 0.13631041935375054| [110.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  10.0| 450.0|  0.1429523132917182|  [10.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  65.0|1700.0| 0.15676743340088062|  [65.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  90.0|2100.0| 0.18187817541586593|  [90.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  62.0|1500.0| 0.19149303955455987|  [62.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
+------+------+--------------------+--------+--------------------+--------------------+----------+
only showing top 20 rows
```
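Note that `LogisticRegression` in spark.ml is a classifier: each distinct `price` value is treated as a class label, which is why every test row above collapses to the same prediction. At predict time the model simply applies the logistic (sigmoid) function to a linear score. A minimal pure-Scala sketch of that computation, with a made-up weight and intercept (the `LogisticSketch` name and all numbers are illustrative, not taken from the trained model above):

```scala
// Minimal sketch of what a binary logistic-regression model computes at
// predict time: probability = sigmoid(w . x + b).
object LogisticSketch {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /** Probability of the positive class for feature vector x. */
  def predictProba(w: Seq[Double], b: Double, x: Seq[Double]): Double =
    sigmoid(w.zip(x).map { case (wi, xi) => wi * xi }.sum + b)

  def main(args: Array[String]): Unit = {
    val w = Seq(0.05) // hypothetical weight for the "square" feature
    val b = -5.0      // hypothetical intercept
    println(predictProba(w, b, Seq(60.0)))  // small square -> low probability
    println(predictProba(w, b, Seq(150.0))) // large square -> high probability
  }
}
```

For a continuous label like `price`, `LinearRegression` (section 2) is the appropriate estimator; logistic regression only makes sense here after bucketing the price into discrete classes.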

Reprinted from: http://ugygi.baihongyu.com/
