ruby-spark 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
data/lib/spark/mllib/regression/common.rb
@@ -0,0 +1,73 @@
+ module Spark
+   module Mllib
+     ##
+     # RegressionModel
+     #
+     # A linear model that has a vector of coefficients and an intercept.
+     #
+     class RegressionModel
+
+       attr_reader :weights, :intercept
+
+       def initialize(weights, intercept)
+         @weights = Spark::Mllib::Vectors.to_vector(weights)
+         @intercept = intercept.to_f
+       end
+
+       # Predict the value of the dependent variable given a vector +data+
+       # containing values for the independent variables.
+       #
+       # == Examples:
+       #   lm = RegressionModel.new([1.0, 2.0], 0.1)
+       #
+       #   lm.predict([-1.03, 7.777]) - 14.624 < 1e-6
+       #   # => true
+       #
+       #   lm.predict(SparseVector.new(2, {0 => -1.03, 1 => 7.777})) - 14.624 < 1e-6
+       #   # => true
+       #
+       def predict(data)
+         data = Spark::Mllib::Vectors.to_vector(data)
+         @weights.dot(data) + @intercept
+       end
+
+     end
+   end
+ end
+
+
+ module Spark
+   module Mllib
+     ##
+     # RegressionMethodBase
+     #
+     # Parent for regression methods
+     #
+     class RegressionMethodBase
+
+       def self.train(rdd, options)
+         # String keys to symbols
+         options.symbolize_keys!
+
+         # Reverse merge: keep user-supplied values, fill in missing defaults
+         self::DEFAULT_OPTIONS.each do |key, value|
+           if options.has_key?(key)
+             # value from user
+           else
+             options[key] = value
+           end
+         end
+
+         # Validation
+         first = rdd.first
+         unless first.is_a?(LabeledPoint)
+           raise Spark::MllibError, "RDD should contain a LabeledPoint, got #{first.class}"
+         end
+
+         # Initial weights are optional for the user (not for Spark)
+         options[:initial_weights] = Vectors.to_vector(options[:initial_weights] || [0.0] * first.features.size)
+       end
+
+     end
+   end
+ end
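
The reverse merge above means user-supplied options always win and DEFAULT_OPTIONS only fills the gaps. A standalone sketch of the same semantics in plain Ruby (hypothetical values, no Spark required):

  defaults = { iterations: 100, step: 1.0, initial_weights: nil }
  options  = { iterations: 10 }

  # Copy a default only where the user supplied nothing
  defaults.each do |key, value|
    options[key] = value unless options.has_key?(key)
  end

  options # => {:iterations=>10, :step=>1.0, :initial_weights=>nil}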
data/lib/spark/mllib/regression/labeled_point.rb
@@ -0,0 +1,41 @@
+ module Spark
+   module Mllib
+     ##
+     # LabeledPoint
+     #
+     # The features and label of a data point.
+     #
+     # == Parameters:
+     # label::
+     #   Label for this data point.
+     #
+     # features::
+     #   Vector of features for this point.
+     #
+     class LabeledPoint
+
+       attr_reader :label, :features
+
+       def initialize(label, features)
+         @label = label.to_f
+         @features = Spark::Mllib::Vectors.to_vector(features)
+       end
+
+       def self.from_java(object)
+         LabeledPoint.new(
+           object.label,
+           Spark.jb.java_to_ruby(object.features)
+         )
+       end
+
+       def marshal_dump
+         [@label, @features]
+       end
+
+       def marshal_load(array)
+         initialize(array[0], array[1])
+       end
+
+     end
+   end
+ end
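
Because LabeledPoint defines marshal_dump and marshal_load, a point survives a Marshal round trip, which is how rows travel between the JVM and the Ruby workers. A short usage sketch (assumes Spark::Mllib.import has been called and that the vector class supports to_a):

  point = LabeledPoint.new(1, [0.5, 2.0])
  point.label # => 1.0 (the label is coerced with to_f)

  copy = Marshal.load(Marshal.dump(point))
  copy.label         # => 1.0
  copy.features.to_a # => [0.5, 2.0]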
data/lib/spark/mllib/regression/lasso.rb
@@ -0,0 +1,100 @@
+ ##
+ # LassoModel
+ #
+ # Train a regression model with L1-regularization using Stochastic Gradient Descent.
+ # This solves the l1-regularized least squares regression formulation
+ #   f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
+ # Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ # its corresponding right hand side label y.
+ # See also the documentation for the precise formulation.
+ #
+ # == Examples:
+ #
+ #   Spark::Mllib.import
+ #
+ #   # Dense vectors
+ #   data = [
+ #     LabeledPoint.new(0.0, [0.0]),
+ #     LabeledPoint.new(1.0, [1.0]),
+ #     LabeledPoint.new(3.0, [2.0]),
+ #     LabeledPoint.new(2.0, [3.0])
+ #   ]
+ #   lrm = LassoWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.predict([0.0]) - 0 < 0.5
+ #   # => true
+ #
+ #   lrm.predict([1.0]) - 1 < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ #
+ #   # Sparse vectors
+ #   data = [
+ #     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
+ #     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
+ #     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
+ #     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
+ #   ]
+ #   lrm = LassoWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.predict([0.0]) - 0 < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ class Spark::Mllib::LassoModel < Spark::Mllib::RegressionModel
+ end
+
+ module Spark
+   module Mllib
+     class LassoWithSGD < RegressionMethodBase
+
+       DEFAULT_OPTIONS = {
+         iterations: 100,
+         step: 1.0,
+         reg_param: 0.01,
+         mini_batch_fraction: 1.0,
+         initial_weights: nil
+       }
+
+       # Train a Lasso regression model on the given data.
+       #
+       # == Parameters:
+       # rdd::
+       #   The training data (RDD instance).
+       #
+       # iterations::
+       #   The number of iterations (default: 100).
+       #
+       # step::
+       #   The step parameter used in SGD (default: 1.0).
+       #
+       # reg_param::
+       #   The regularizer parameter (default: 0.01).
+       #
+       # mini_batch_fraction::
+       #   Fraction of data to be used for each SGD iteration (default: 1.0).
+       #
+       # initial_weights::
+       #   The initial weights (default: nil).
+       #
+       def self.train(rdd, options={})
+         super
+
+         weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLassoModelWithSGD', rdd,
+                                            options[:iterations].to_i,
+                                            options[:step].to_f,
+                                            options[:reg_param].to_f,
+                                            options[:mini_batch_fraction].to_f,
+                                            options[:initial_weights])
+
+         LassoModel.new(weights, intercept)
+       end
+
+     end
+   end
+ end
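
LassoWithSGD simply fixes the regularizer to L1. Roughly the same model can be requested through the general trainer in the next file by selecting reg_type explicitly (a sketch reusing the data array from the examples above; the two calls are not verified here to produce identical weights):

  # Dedicated trainer: L1 regularization is implied
  lasso = LassoWithSGD.train($sc.parallelize(data), initial_weights: [1.0])

  # General trainer: select L1 (or 'l2' for ridge) explicitly
  linear = LinearRegressionWithSGD.train($sc.parallelize(data),
                                         initial_weights: [1.0],
                                         reg_param: 0.01,
                                         reg_type: 'l1')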
data/lib/spark/mllib/regression/linear.rb
@@ -0,0 +1,124 @@
+ ##
+ # LinearRegressionModel
+ #
+ # Train a linear regression model with no regularization using Stochastic Gradient Descent.
+ # This solves the least squares regression formulation
+ #   f(weights) = 1/n ||A weights-y||^2^
+ # (which is the mean squared error).
+ # Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ # its corresponding right hand side label y.
+ # See also the documentation for the precise formulation.
+ #
+ # == Examples:
+ #
+ #   Spark::Mllib.import
+ #
+ #   # Dense vectors
+ #   data = [
+ #     LabeledPoint.new(0.0, [0.0]),
+ #     LabeledPoint.new(1.0, [1.0]),
+ #     LabeledPoint.new(3.0, [2.0]),
+ #     LabeledPoint.new(2.0, [3.0])
+ #   ]
+ #   lrm = LinearRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.intercept # => 0.0
+ #   lrm.weights   # => [0.9285714285714286]
+ #
+ #   lrm.predict([0.0]) < 0.5
+ #   # => true
+ #
+ #   lrm.predict([1.0]) - 1 < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ #   # Sparse vectors
+ #   data = [
+ #     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
+ #     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
+ #     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
+ #     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
+ #   ]
+ #   lrm = LinearRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.intercept # => 0.0
+ #   lrm.weights   # => [0.9285714285714286]
+ #
+ #   lrm.predict([0.0]) < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ class Spark::Mllib::LinearRegressionModel < Spark::Mllib::RegressionModel
+ end
+
+ module Spark
+   module Mllib
+     class LinearRegressionWithSGD < RegressionMethodBase
+
+       DEFAULT_OPTIONS = {
+         iterations: 100,
+         step: 1.0,
+         mini_batch_fraction: 1.0,
+         initial_weights: nil,
+         reg_param: 0.0,
+         reg_type: nil,
+         intercept: false
+       }
+
+       # Train a linear regression model on the given data.
+       #
+       # == Parameters:
+       # rdd::
+       #   The training data (RDD instance).
+       #
+       # iterations::
+       #   The number of iterations (default: 100).
+       #
+       # step::
+       #   The step parameter used in SGD (default: 1.0).
+       #
+       # mini_batch_fraction::
+       #   Fraction of data to be used for each SGD iteration (default: 1.0).
+       #
+       # initial_weights::
+       #   The initial weights (default: nil).
+       #
+       # reg_param::
+       #   The regularizer parameter (default: 0.0).
+       #
+       # reg_type::
+       #   The type of regularizer used for training our model (default: nil).
+       #
+       #   Allowed values:
+       #   - "l1" for using L1 regularization (lasso),
+       #   - "l2" for using L2 regularization (ridge),
+       #   - nil for no regularization
+       #
+       # intercept::
+       #   Boolean parameter indicating whether to use
+       #   the augmented representation of the training
+       #   data (i.e. whether bias features are
+       #   activated). (default: false)
+       #
+       def self.train(rdd, options={})
+         super
+
+         weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLinearRegressionModelWithSGD', rdd,
+                                            options[:iterations].to_i,
+                                            options[:step].to_f,
+                                            options[:mini_batch_fraction].to_f,
+                                            options[:initial_weights],
+                                            options[:reg_param].to_f,
+                                            options[:reg_type],
+                                            options[:intercept])
+
+         LinearRegressionModel.new(weights, intercept)
+       end
+
+     end
+   end
+ end
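
The input validation inherited from RegressionMethodBase makes all of these trainers fail fast on malformed input. A sketch of the error path (hypothetical session, assumes a running context in $sc):

  rdd = $sc.parallelize([1.0, 2.0, 3.0]) # plain Floats, not LabeledPoints

  begin
    LinearRegressionWithSGD.train(rdd)
  rescue Spark::MllibError => e
    e.message # => "RDD should contain a LabeledPoint, got Float"
  end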
data/lib/spark/mllib/regression/ridge.rb
@@ -0,0 +1,97 @@
+ ##
+ # RidgeRegressionModel
+ #
+ # Train a regression model with L2-regularization using Stochastic Gradient Descent.
+ # This solves the l2-regularized least squares regression formulation
+ #   f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
+ # Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
+ # its corresponding right hand side label y.
+ # See also the documentation for the precise formulation.
+ #
+ # == Examples:
+ #
+ #   Spark::Mllib.import
+ #
+ #   data = [
+ #     LabeledPoint.new(0.0, [0.0]),
+ #     LabeledPoint.new(1.0, [1.0]),
+ #     LabeledPoint.new(3.0, [2.0]),
+ #     LabeledPoint.new(2.0, [3.0])
+ #   ]
+ #   lrm = RidgeRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.predict([0.0]) - 0 < 0.5
+ #   # => true
+ #
+ #   lrm.predict([1.0]) - 1 < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ #   data = [
+ #     LabeledPoint.new(0.0, SparseVector.new(1, {0 => 0.0})),
+ #     LabeledPoint.new(1.0, SparseVector.new(1, {0 => 1.0})),
+ #     LabeledPoint.new(3.0, SparseVector.new(1, {0 => 2.0})),
+ #     LabeledPoint.new(2.0, SparseVector.new(1, {0 => 3.0}))
+ #   ]
+ #   lrm = RidgeRegressionWithSGD.train($sc.parallelize(data), initial_weights: [1.0])
+ #
+ #   lrm.predict([0.0]) - 0 < 0.5
+ #   # => true
+ #
+ #   lrm.predict(SparseVector.new(1, {0 => 1.0})) - 1 < 0.5
+ #   # => true
+ #
+ class Spark::Mllib::RidgeRegressionModel < Spark::Mllib::RegressionModel
+ end
+
+ module Spark
+   module Mllib
+     class RidgeRegressionWithSGD < RegressionMethodBase
+
+       DEFAULT_OPTIONS = {
+         iterations: 100,
+         step: 1.0,
+         reg_param: 0.01,
+         mini_batch_fraction: 1.0,
+         initial_weights: nil
+       }
+
+       # Train a ridge regression model on the given data.
+       #
+       # == Parameters:
+       # rdd::
+       #   The training data (RDD instance).
+       #
+       # iterations::
+       #   The number of iterations (default: 100).
+       #
+       # step::
+       #   The step parameter used in SGD (default: 1.0).
+       #
+       # reg_param::
+       #   The regularizer parameter (default: 0.01).
+       #
+       # mini_batch_fraction::
+       #   Fraction of data to be used for each SGD iteration (default: 1.0).
+       #
+       # initial_weights::
+       #   The initial weights (default: nil).
+       #
+       def self.train(rdd, options={})
+         super
+
+         weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainRidgeModelWithSGD', rdd,
+                                            options[:iterations].to_i,
+                                            options[:step].to_f,
+                                            options[:reg_param].to_f,
+                                            options[:mini_batch_fraction].to_f,
+                                            options[:initial_weights])
+
+         RidgeRegressionModel.new(weights, intercept)
+       end
+
+     end
+   end
+ end
data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb
@@ -0,0 +1,53 @@
+ require 'matrix'
+
+ module Spark
+   module Mllib
+     class MatrixAdapter < ::Matrix
+
+       def self.new(*args)
+         object = self.allocate
+
+         if args.size == 2
+           # Two arguments mean the original Matrix initializer:
+           # rows, column count
+           object.__send__(:original_initialize, *args)
+         else
+           object.__send__(:initialize, *args)
+         end
+
+         object
+       end
+
+       alias_method :original_initialize, :initialize
+
+       def initialize(type, rows, cols, values=nil)
+         case type
+         when :dense
+           values = values.dup
+           if rows * cols == values.size
+             # Values are in one flat row-major array
+             # 2x2 => [1,2,3,4]
+             values = values.each_slice(cols).to_a
+           else
+             # Values are already nested: 2x2 => [[1,2], [3,4]]
+           end
+         when :sparse
+           values = Array.new(rows) { Array.new(cols) { 0.0 } }
+         else
+           raise Spark::MllibError, 'Unknown matrix type.'
+         end
+
+         super(values, cols)
+       end
+
+       def shape
+         [row_count, column_count]
+       end
+
+       def values
+         @values || to_a
+       end
+
+     end
+   end
+ end
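
A usage sketch of the two :dense construction paths and of :sparse, which here just allocates an all-zero matrix of the requested shape (assumes the gem is loaded):

  require 'matrix'

  # Flat row-major values are resliced into rows: 2x2 => [1, 2, 3, 4]
  m = Spark::Mllib::MatrixAdapter.new(:dense, 2, 2, [1.0, 2.0, 3.0, 4.0])
  m.shape  # => [2, 2]
  m.values # => [[1.0, 2.0], [3.0, 4.0]]

  # Already-nested rows pass through unchanged
  n = Spark::Mllib::MatrixAdapter.new(:dense, 2, 2, [[1.0, 2.0], [3.0, 4.0]])

  # :sparse starts from a zero matrix of the given shape
  z = Spark::Mllib::MatrixAdapter.new(:sparse, 2, 3)
  z.shape # => [2, 3]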