
Spark MLlib: the SGD (Stochastic Gradient Descent) Algorithm

Code:

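package mllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.HashMap

/**
  * Stochastic gradient descent (SGD) example
  * Created by 汪本成 on 2016/8/5.
  */
object SGD {

  // Keep unnecessary log output off the console
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

  // Program entry: Spark configuration and context
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass.getSimpleName.filter(_ != '$'))

  println(this.getClass.getSimpleName.filter(_ != '$'))

  // Note: the SGD loop below runs entirely on the driver; sc is created but never used
  val sc = new SparkContext(conf)

  // HashMap that stores the data set (x -> y)
  val data = new HashMap[Int, Int]()
  // Generate the data set
  def getData(): HashMap[Int, Int] = {
    for (i <- 1 to 50) {
      data += (i -> (2 * i)) // every sample follows y = 2x
    }
    data
  }

  // Initial guess: a = 0
  var a: Double = 0
  // Step size (learning rate), written α in the analysis
  var b: Double = 0.1

  // Update rule: a <- a - b * (a * x - y)
  // The exact gradient of the squared loss (a * x - y)^2 / 2 is (a * x - y) * x;
  // this simplified rule omits the factor x but has the same fixed point a = y / x = 2.
  def sgd(x: Double, y: Double): Unit = {
    a = a - b * ((a * x) - y)
  }

  def main(args: Array[String]): Unit = {
    // Get the data set
    val dataSource = getData()
    println("data: ")
    // HashMap iteration order is not insertion order, so the samples come out scrambled
    dataSource.foreach(each => println(each + " "))
    println("\nresult: ")
    var num = 1
    // Iterate: one update per sample, printing a before each update
    dataSource.foreach { case (x, y) =>
      println(num + ":" + a + "(" + x + "," + y + ")")
      sgd(x, y)
      num = num + 1
    }
    // Show the result ("最终结果a为" means "final value of a:")
    println("最终结果a为 " + a)
    // Release Spark resources
    sc.stop()
  }

}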
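The update in sgd() is a simplified rule: for the squared loss L(a) = (a*x - y)^2 / 2, the exact gradient with respect to a is (a*x - y)*x. For comparison, here is a minimal sketch of plain SGD with that exact gradient, assuming the same data (y = 2x for x = 1..50), a fresh seeded random shuffle of the samples on every epoch, and no Spark at all (the original loop does not use sc either). The object name ExactGradientSGD, the seed 42 and the step size alpha = 1e-4 are illustrative assumptions, not values from the original post.

```scala
import scala.util.Random

// Hypothetical comparison sketch (not part of the original post):
// plain SGD on y = a * x with squared loss L(a) = (a * x - y)^2 / 2,
// using the exact gradient dL/da = (a * x - y) * x.
object ExactGradientSGD {

  // Same data set as getData(): x = 1..50, y = 2x
  val samples: Vector[(Double, Double)] =
    (1 to 50).map(i => (i.toDouble, 2.0 * i)).toVector

  /** Runs `epochs` passes over the data, visiting the samples in a seeded random order. */
  def train(alpha: Double, epochs: Int, seed: Long = 42L): Double = {
    val rng = new Random(seed)
    var a = 0.0 // initial guess, as in the original
    // The inner generator is re-evaluated per epoch, so each epoch gets a new shuffle
    for (_ <- 1 to epochs; (x, y) <- rng.shuffle(samples)) {
      a -= alpha * (a * x - y) * x // one stochastic gradient step
    }
    a
  }

  def main(args: Array[String]): Unit = {
    // Every step contracts the error when alpha < 2 / 50^2 = 8e-4
    println(train(alpha = 1e-4, epochs = 10))
  }
}
```

Here each step multiplies the error a - 2 by (1 - alpha*x^2). With alpha = 1e-4 every factor lies between 0.75 and 0.9999, so the error shrinks on every step whatever the order, and ten epochs bring a to 2 up to floating-point rounding. Any alpha below 2/50^2 = 8e-4 keeps every factor inside (-1, 1).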
      

Output:

[IntelliJ IDEA 15.0.5 launch command omitted: JDK 1.8.0_77, the Scala library jars and the spark-1.6.1-bin-hadoop2.6 jars (including spark-assembly-1.6.1-hadoop2.6.0.jar) on the classpath, main class mllib.SGD]
SGD
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-assembly-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-examples-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
16/08/05 00:48:28 INFO Slf4jLogger: Slf4jLogger started
16/08/05 00:48:28 INFO Remoting: Starting remoting
16/08/05 00:48:28 INFO Remoting: Remoting started; listening on addresses :[akka.tcp://…:24009]
data: 
(23,46) 
(50,100) 
(32,64) 
(41,82) 
(17,34) 
(8,16) 
(35,70) 
(44,88) 
(26,52) 
(11,22) 
(29,58) 
(38,76) 
(47,94) 
(20,40) 
(2,4) 
(5,10) 
(14,28) 
(46,92) 
(40,80) 
(49,98) 
(4,8) 
(13,26) 
(22,44) 
(31,62) 
(16,32) 
(7,14) 
(43,86) 
(25,50) 
(34,68) 
(10,20) 
(37,74) 
(1,2) 
(19,38) 
(28,56) 
(45,90) 
(27,54) 
(36,72) 
(18,36) 
(9,18) 
(21,42) 
(48,96) 
(3,6) 
(12,24) 
(30,60) 
(39,78) 
(15,30) 
(42,84) 
(24,48) 
(6,12) 
(33,66) 

result: 
1:0.0(23,46)
2:4.6000000000000005(50,100)
3:-8.400000000000002(32,64)
4:24.880000000000006(41,82)
5:-68.92800000000003(17,34)
6:51.649600000000035(8,16)
7:11.929920000000003(35,70)
8:-22.82480000000001(44,88)
9:86.40432000000006(26,52)
10:-133.04691200000013(11,22)
11:15.504691199999996(29,58)
12:-23.65891328(38,76)
13:73.84495718400001(47,94)
14:-263.82634158080003(20,40)
15:267.82634158080003(2,4)
16:214.66107326464004(5,10)
17:108.33053663232002(14,28)
18:-40.53221465292802(46,92)
19:155.1159727505409(40,80)
20:-457.3479182516227(49,98)
21:1793.4568811813288(4,8)
22:1076.8741287087973(13,26)
23:-320.46223861263934(22,44)
24:388.95468633516725(31,62)
25:-810.6048413038511(16,32)
26:489.56290478231085(7,14)
27:148.2688714346932(43,86)
28:-480.6872757344877(25,50)
29:726.0309136017315(34,68)
30:-1735.6741926441557(10,20)
31:2.0000000000002274(37,74)
32:1.999999999999386(1,2)
33:1.9999999999994476(19,38)
34:2.000000000000497(28,56)
35:1.9999999999991056(45,90)
36:2.00000000000313(27,54)
37:1.9999999999946787(36,72)
38:2.000000000013835(18,36)
39:1.999999999988932(9,18)
40:1.999999999998893(21,42)
41:2.0000000000012172(48,96)
42:1.9999999999953737(3,6)
43:1.9999999999967615(12,24)
44:2.000000000000648(30,60)
45:1.999999999998704(39,78)
46:2.0000000000037588(15,30)
47:1.9999999999981206(42,84)
48:2.0000000000060134(24,48)
49:1.999999999991581(6,12)
50:1.9999999999966325(33,66)
最终结果a为 2.0000000000077454
16/08/05 00:48:28 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.

Process finished with exit code 0
           

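Analysis:

The step size α (the variable b in the code) matters a lot. With α = 0.1, the correct result usually appears after about 30 iterations; with α = 0.5, after about 15; with α = 1, there is still no result after 50 iterations. The reason lies in the update rule a ← a - α(ax - y): since every sample satisfies y = 2x, each step multiplies the error a - 2 by (1 - αx), so the error shrinks only when 0 < αx < 2 and is wiped out in a single step when αx = 1. In the run above (α = 0.1), a swings between large positive and negative values until iteration 30, which processes the sample (10, 20): there αx = 1, so a jumps straight to 2 (2.0000000000002274 at iteration 31), and afterwards only rounding noise around 2 remains. Likewise, α = 0.5 hits αx = 1 on the sample (2, 4) at iteration 15, while with α = 1 every sample with x ≥ 3 multiplies the error by at least 2 in magnitude. The exact iteration counts therefore depend on the order in which the HashMap yields the samples.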
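To check these observations, the run can be replayed deterministically. The sketch below is a hypothetical helper (ReplayRun, trajectory and firstConverged are illustrative names, not from the original post): it applies the same formula as sgd() to the samples in the exact order printed in the data: log, because the iteration order of scala.collection.mutable.HashMap is not the insertion order and may differ between Scala versions. "Converged" is assumed to mean |a - 2| < 1e-6.

```scala
// Hypothetical replay of the logged run: the same update rule as sgd(),
// applied to the samples in the order printed under "data:" (y = 2x each).
object ReplayRun {

  // x values in the HashMap iteration order shown in the log
  val order: Vector[Int] = Vector(
    23, 50, 32, 41, 17, 8, 35, 44, 26, 11, 29, 38, 47, 20, 2, 5, 14,
    46, 40, 49, 4, 13, 22, 31, 16, 7, 43, 25, 34, 10, 37, 1, 19, 28,
    45, 27, 36, 18, 9, 21, 48, 3, 12, 30, 39, 15, 42, 24, 6, 33)

  /** Value of a after each update, starting from a = 0 (same formula as sgd()). */
  def trajectory(alpha: Double): Vector[Double] =
    order.scanLeft(0.0) { (a, x) =>
      val y = 2.0 * x
      a - alpha * ((a * x) - y)
    }.tail

  /** Number of updates until |a - 2| < tol for the first time, or None if never. */
  def firstConverged(alpha: Double, tol: Double = 1e-6): Option[Int] = {
    val k = trajectory(alpha).indexWhere(a => math.abs(a - 2.0) < tol)
    if (k >= 0) Some(k + 1) else None
  }

  def main(args: Array[String]): Unit = {
    for (alpha <- Seq(0.1, 0.5, 1.0)) {
      println(s"alpha = $alpha: converged after ${firstConverged(alpha)} updates, " +
        s"final a = ${trajectory(alpha).last}")
    }
  }
}
```

With this order, firstConverged returns Some(30) for α = 0.1 (the 30th update uses (10, 20)), Some(15) for α = 0.5 (the 15th update uses (2, 4)) and None for α = 1.0, in line with the numbers in the analysis.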