ruby-spark 1.0.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
@@ -0,0 +1,73 @@
1
module Spark
  module Mllib
    ##
    # RegressionModel
    #
    # A trained linear model: a weight vector plus a scalar intercept.
    # A prediction is the dot product of the weights with the input,
    # shifted by the intercept.
    #
    class RegressionModel

      attr_reader :weights, :intercept

      # +weights+ is normalized through Vectors.to_vector and +intercept+
      # is coerced with #to_f.
      def initialize(weights, intercept)
        @weights   = Spark::Mllib::Vectors.to_vector(weights)
        @intercept = intercept.to_f
      end

      # Predict the dependent variable for +data+, i.e.
      # weights . data + intercept.
      #
      # +data+ may be anything Vectors.to_vector accepts (e.g. an Array or a
      # SparseVector).
      #
      # == Examples:
      #   lm = RegressionModel.new([1.0, 2.0], 0.1)
      #
      #   (lm.predict([-1.03, 7.777]) - 14.624).abs < 1e-6
      #   # => true
      #
      #   (lm.predict(SparseVector.new(2, {0 => -1.03, 1 => 7.777})) - 14.624).abs < 1e-6
      #   # => true
      #
      def predict(data)
        features = Spark::Mllib::Vectors.to_vector(data)
        linear_part = @weights.dot(features)
        linear_part + @intercept
      end

    end
  end
end
37
+
38
+
39
module Spark
  module Mllib
    ##
    # RegressionMethodBase
    #
    # Parent for regression trainers. Subclasses define a DEFAULT_OPTIONS
    # hash and call +super+ at the start of their own +train+. This method
    # normalizes the caller's +options+ hash IN PLACE, so the subclass reads
    # the final values from that same hash afterwards.
    #
    class RegressionMethodBase

      # Validate +rdd+ and prepare +options+ for training.
      #
      # == Parameters:
      # rdd::
      #   The training data; its first element must be a LabeledPoint.
      #
      # options::
      #   Hash of training options. It is mutated:
      #   - keys are symbolized,
      #   - missing keys are filled in from self::DEFAULT_OPTIONS,
      #   - :initial_weights is converted with Vectors.to_vector; it defaults
      #     to zeros, one per feature of the first point.
      #
      # == Returns:
      # The initial weights vector stored in options[:initial_weights].
      #
      # == Raises:
      # Spark::MllibError when the first element is nil (e.g. empty RDD), is
      # not a LabeledPoint, or when :initial_weights has a different size
      # than the feature vector.
      #
      def self.train(rdd, options)
        # String keys to symbols (in place: subclasses read this same hash)
        options.symbolize_keys!

        # Reverse merge: values supplied by the user win over defaults
        self::DEFAULT_OPTIONS.each do |key, value|
          options[key] = value unless options.key?(key)
        end

        # Validation
        first = rdd.first
        if first.nil?
          raise Spark::MllibError, 'RDD should not be empty (first element is nil)'
        end
        unless first.is_a?(LabeledPoint)
          raise Spark::MllibError, "RDD should contains LabeledPoint, got #{first.class}"
        end

        feature_count = first.features.size

        # Initial weights are optional for the user (not for Spark)
        weights = Vectors.to_vector(options[:initial_weights] || [0.0] * feature_count)

        # Catch a dimension mismatch here instead of deep inside the JVM call
        if weights.size != feature_count
          raise Spark::MllibError,
                "initial_weights has #{weights.size} elements, expected #{feature_count} (number of features)"
        end

        options[:initial_weights] = weights
      end

    end
  end
end
@@ -0,0 +1,41 @@
1
module Spark
  module Mllib
    ##
    # LabeledPoint
    #
    # One training example: a label together with its feature vector.
    #
    # == Parameters:
    # label::
    #   Label for this data point (stored as a Float).
    #
    # features::
    #   Features for this point; anything Vectors.to_vector accepts.
    #
    class LabeledPoint

      attr_reader :label, :features

      def initialize(label, features)
        @label = label.to_f
        @features = Spark::Mllib::Vectors.to_vector(features)
      end

      # Build a LabeledPoint from its JVM counterpart; the Java feature
      # vector is converted through Spark.jb.java_to_ruby.
      def self.from_java(object)
        java_label = object.label
        ruby_features = Spark.jb.java_to_ruby(object.features)
        LabeledPoint.new(java_label, ruby_features)
      end

      # Marshal support: the point is serialized as a [label, features] pair.
      def marshal_dump
        [@label, @features]
      end

      # Restore from the pair written by #marshal_dump; #initialize runs
      # again so the same coercions are applied.
      def marshal_load(array)
        dumped_label, dumped_features = array
        initialize(dumped_label, dumped_features)
      end

    end
  end
end
@@ -0,0 +1,100 @@
1
##
# LassoModel
#
# Linear model returned by LassoWithSGD.train, i.e. a regression model
# trained with L1-regularization using Stochastic Gradient Descent.
# Training solves the l1-regularized least squares regression formulation
#
#   f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
#
# Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
# its corresponding right hand side label y.
# See also the documentation for the precise formulation.
#
# == Examples:
#
#   Spark::Mllib.import
#
#   # Dense vectors
#   data = [
#     LabeledPoint.new(0.0, [0.0]),
#     LabeledPoint.new(1.0, [1.0]),
#     LabeledPoint.new(3.0, [2.0]),
#     LabeledPoint.new(2.0, [3.0])
#   ]
#   lrm = LassoWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
#
#   (lrm.predict([0.0]) - 0).abs < 0.5
#   # => true
#
#   (lrm.predict([1.0]) - 1).abs < 0.5
#   # => true
#
#   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
#   # => true
#
#
#   # Sparse vectors
#   data = [
#     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
#     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
#     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
#     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
#   ]
#   lrm = LassoWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
#
#   (lrm.predict([0.0]) - 0).abs < 0.5
#   # => true
#
#   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
#   # => true
#
class Spark::Mllib::LassoModel < Spark::Mllib::RegressionModel
end
51
+
52
module Spark
  module Mllib
    ##
    # LassoWithSGD
    #
    # Trainer for LassoModel: delegates to RubyMLLibAPI's
    # 'trainLassoModelWithSGD' through the Java bridge.
    #
    class LassoWithSGD < RegressionMethodBase

      # Read (never written) by RegressionMethodBase.train for the reverse merge.
      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        reg_param: 0.01,
        mini_batch_fraction: 1.0,
        initial_weights: nil
      }.freeze

      # Train a Lasso regression model on the given data.
      #
      # == Parameters:
      # rdd::
      #   The training data (RDD instance of LabeledPoint).
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration (default: 1.0).
      #
      # initial_weights::
      #   The initial weights (default: nil, meaning all zeros).
      #
      # == Returns:
      # A LassoModel.
      #
      def self.train(rdd, options={})
        # Normalizes +options+ in place (symbol keys, defaults, weights vector)
        super

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLassoModelWithSGD', rdd,
                                           options[:iterations].to_i,
                                           options[:step].to_f,
                                           options[:reg_param].to_f,
                                           options[:mini_batch_fraction].to_f,
                                           options[:initial_weights])

        LassoModel.new(weights, intercept)
      end

    end
  end
end
@@ -0,0 +1,124 @@
1
module Spark
  module Mllib
    ##
    # LinearRegressionModel
    #
    # Model produced by LinearRegressionWithSGD.train: a linear regression
    # fitted with Stochastic Gradient Descent and no regularization. Training
    # minimizes the least squares objective (the mean squared error)
    #
    #   f(weights) = 1/n ||A weights-y||^2^
    #
    # where the data matrix A has n rows (the rows held by the input RDD) and
    # y holds the corresponding right hand side labels.
    # See also the documentation for the precise formulation.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0]),
    #     LabeledPoint.new(1.0, [1.0]),
    #     LabeledPoint.new(3.0, [2.0]),
    #     LabeledPoint.new(2.0, [3.0])
    #   ]
    #   lrm = LinearRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
    #
    #   lrm.intercept # => 0.0
    #   lrm.weights   # => [0.9285714285714286]
    #
    #   lrm.predict([0.0]).abs < 0.5
    #   # => true
    #
    #   (lrm.predict([1.0]) - 1).abs < 0.5
    #   # => true
    #
    #   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
    #   # => true
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
    #     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
    #     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
    #   ]
    #   lrm = LinearRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
    #
    #   lrm.intercept # => 0.0
    #   lrm.weights   # => [0.9285714285714286]
    #
    #   lrm.predict([0.0]).abs < 0.5
    #   # => true
    #
    #   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
    #   # => true
    #
    class LinearRegressionModel < RegressionModel
    end
  end
end
57
+
58
module Spark
  module Mllib
    ##
    # LinearRegressionWithSGD
    #
    # Trainer for LinearRegressionModel: delegates to RubyMLLibAPI's
    # 'trainLinearRegressionModelWithSGD' through the Java bridge.
    #
    class LinearRegressionWithSGD < RegressionMethodBase

      # Read (never written) by RegressionMethodBase.train for the reverse merge.
      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_param: 0.0,
        reg_type: nil,
        intercept: false
      }.freeze

      # Train a linear regression model on the given data.
      #
      # == Parameters:
      # rdd::
      #   The training data (RDD instance of LabeledPoint).
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration (default: 1.0).
      #
      # initial_weights::
      #   The initial weights (default: nil, meaning all zeros).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.0).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: nil).
      #   A String or a Symbol is accepted.
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization (lasso),
      #   - "l2" for using L2 regularization (ridge),
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not). Any truthy value enables it (default: false).
      #
      # == Returns:
      # A LinearRegressionModel.
      #
      def self.train(rdd, options={})
        # Normalizes +options+ in place (symbol keys, defaults, weights vector)
        super

        # NOTE(review): the JVM method presumably takes a String (or null)
        # regType and a boolean intercept (see RubyMLLibAPI.scala). A Symbol
        # such as :l1 or a nil/non-boolean intercept would not convert, so
        # both values are normalized here.
        reg_type = options[:reg_type].nil? ? nil : options[:reg_type].to_s
        use_intercept = options[:intercept] ? true : false

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLinearRegressionModelWithSGD', rdd,
                                           options[:iterations].to_i,
                                           options[:step].to_f,
                                           options[:mini_batch_fraction].to_f,
                                           options[:initial_weights],
                                           options[:reg_param].to_f,
                                           reg_type,
                                           use_intercept)

        LinearRegressionModel.new(weights, intercept)
      end

    end
  end
end
@@ -0,0 +1,97 @@
1
##
# RidgeRegressionModel
#
# Linear model returned by RidgeRegressionWithSGD.train, i.e. a regression
# model trained with L2-regularization using Stochastic Gradient Descent.
# Training solves the l2-regularized least squares regression formulation
#
#   f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
#
# Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
# its corresponding right hand side label y.
# See also the documentation for the precise formulation.
#
# == Examples:
#
#   Spark::Mllib.import
#
#   # Dense vectors
#   data = [
#     LabeledPoint.new(0.0, [0.0]),
#     LabeledPoint.new(1.0, [1.0]),
#     LabeledPoint.new(3.0, [2.0]),
#     LabeledPoint.new(2.0, [3.0])
#   ]
#   lrm = RidgeRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
#
#   (lrm.predict([0.0]) - 0).abs < 0.5
#   # => true
#
#   (lrm.predict([1.0]) - 1).abs < 0.5
#   # => true
#
#   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
#   # => true
#
#   # Sparse vectors
#   data = [
#     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
#     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
#     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
#     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
#   ]
#   lrm = RidgeRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
#
#   (lrm.predict([0.0]) - 0).abs < 0.5
#   # => true
#
#   (lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1).abs < 0.5
#   # => true
#
class Spark::Mllib::RidgeRegressionModel < Spark::Mllib::RegressionModel
end
48
+
49
module Spark
  module Mllib
    ##
    # RidgeRegressionWithSGD
    #
    # Trainer for RidgeRegressionModel: delegates to RubyMLLibAPI's
    # 'trainRidgeModelWithSGD' through the Java bridge.
    #
    class RidgeRegressionWithSGD < RegressionMethodBase

      # Read (never written) by RegressionMethodBase.train for the reverse merge.
      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        reg_param: 0.01,
        mini_batch_fraction: 1.0,
        initial_weights: nil
      }.freeze

      # Train a ridge regression model on the given data.
      #
      # == Parameters:
      # rdd::
      #   The training data (RDD instance of LabeledPoint).
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration (default: 1.0).
      #
      # initial_weights::
      #   The initial weights (default: nil, meaning all zeros).
      #
      # == Returns:
      # A RidgeRegressionModel.
      #
      def self.train(rdd, options={})
        # Normalizes +options+ in place (symbol keys, defaults, weights vector)
        super

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainRidgeModelWithSGD', rdd,
                                           options[:iterations].to_i,
                                           options[:step].to_f,
                                           options[:reg_param].to_f,
                                           options[:mini_batch_fraction].to_f,
                                           options[:initial_weights])

        RidgeRegressionModel.new(weights, intercept)
      end

    end
  end
end
@@ -0,0 +1,53 @@
1
+ require 'matrix'
2
+
3
module Spark
  module Mllib
    ##
    # MatrixAdapter
    #
    # Builds a stdlib ::Matrix from Spark-style arguments
    # (type, rows, cols, values).
    #
    class MatrixAdapter < ::Matrix

      # Two entry points share this constructor:
      # - 2 arguments (rows, column_count): ::Matrix itself creating a matrix
      #   of this class (e.g. as the result of an operation); delegated to
      #   the original ::Matrix#initialize.
      # - otherwise: Spark-style (type, rows, cols[, values]), handled by
      #   #initialize below.
      def self.new(*args)
        object = allocate

        if args.size == 2
          object.__send__(:original_initialize, *args)
        else
          object.__send__(:initialize, *args)
        end

        object
      end

      alias_method :original_initialize, :initialize

      # == Parameters:
      # type::
      #   :dense or :sparse.
      #
      # rows, cols::
      #   Matrix dimensions.
      #
      # values::
      #   For :dense, either a flat row-by-row Array of rows * cols elements
      #   (2x2 => [1,2,3,4]) or an Array of row Arrays (2x2 => [[1,2], [3,4]]).
      #   Ignored for :sparse, which builds a rows x cols matrix of 0.0.
      #
      # == Raises:
      # Spark::MllibError for an unknown type, missing dense values, or a
      # flat dense Array whose size is not rows * cols.
      #
      def initialize(type, rows, cols, values=nil)
        case type
        when :dense
          values = dense_rows(rows, cols, values)
        when :sparse
          values = Array.new(rows) { Array.new(cols) { 0.0 } }
        else
          raise Spark::MllibError, 'Unknown matrix type.'
        end

        super(values, cols)
      end

      def shape
        [row_count, column_count]
      end

      def values
        @values || to_a
      end

      private

      # Normalize dense input into an Array of freshly allocated row Arrays,
      # so the matrix never shares row objects with the caller.
      def dense_rows(rows, cols, data)
        raise Spark::MllibError, 'Dense matrix requires values.' if data.nil?

        # Decide by element type, not by size: a nested single-column matrix
        # (2x1 => [[1], [2]]) has exactly rows * cols elements too.
        return data.map(&:dup) if data.first.is_a?(Array)

        unless data.size == rows * cols
          raise Spark::MllibError, "Expected #{rows * cols} values for a #{rows}x#{cols} matrix, got #{data.size}."
        end

        # each_slice(0) raises, so zero-column matrices are built explicitly
        return Array.new(rows) { [] } if cols.zero?

        data.each_slice(cols).to_a
      end

    end
  end
end
+ end