ruby-spark 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
@@ -0,0 +1,41 @@
1
+ if !ENV.has_key?('JAVA_HOME')
2
+ raise Spark::ConfigurationError, 'Environment variable JAVA_HOME is not set'
3
+ end
4
+
5
+ require 'rjb'
6
+
7
+ module Spark
8
+ module JavaBridge
9
+ class RJB < Base
10
+
11
+ def initialize(*args)
12
+ super
13
+ Rjb.load(jars)
14
+ Rjb.primitive_conversion = true
15
+ end
16
+
17
+ def import(name, klass)
18
+ Object.const_set(name, silence_warnings { Rjb.import(klass) })
19
+ end
20
+
21
+ def java_object?(object)
22
+ object.is_a?(Rjb::Rjb_JavaProxy)
23
+ end
24
+
25
+ private
26
+
27
+ def jars
28
+ separator = windows? ? ';' : ':'
29
+ super.join(separator)
30
+ end
31
+
32
+ def silence_warnings
33
+ old_verbose, $VERBOSE = $VERBOSE, nil
34
+ yield
35
+ ensure
36
+ $VERBOSE = old_verbose
37
+ end
38
+
39
+ end
40
+ end
41
+ end
@@ -0,0 +1,76 @@
1
# Necessary libraries
Spark.load_lib

module Spark
  # Thin wrapper around the log4j logger exposed through the Java bridge
  # (JLogger / JLevel / JPriority are imported by Spark.load_lib).
  class Logger

    attr_reader :jlogger

    def initialize
      @jlogger = JLogger.getLogger('Ruby')
    end

    # The log4j OFF level object.
    def level_off
      JLevel.toLevel('OFF')
    end

    # Disable all Spark log output: the Ruby logger, the org.* and akka.*
    # hierarchies, and the root logger.
    def disable
      [jlogger, JLogger.getLogger('org'), JLogger.getLogger('akka'), JLogger.getRootLogger].each do |logger|
        logger.setLevel(level_off)
      end
    end

    # NOTE(review): `disabled?` is not defined anywhere in this file, so
    # calling `enabled?` raises NoMethodError — confirm where `disabled?`
    # was meant to be implemented.
    def enabled?
      !disabled?
    end

    def info(message)
      return unless info?
      jlogger.info(message)
    end

    def debug(message)
      return unless debug?
      jlogger.debug(message)
    end

    def trace(message)
      return unless trace?
      jlogger.trace(message)
    end

    def warning(message)
      return unless warning?
      jlogger.warn(message)
    end

    def error(message)
      return unless error?
      jlogger.error(message)
    end

    def info?
      level_enabled?('info')
    end

    def debug?
      level_enabled?('debug')
    end

    def trace?
      level_enabled?('trace')
    end

    def warning?
      level_enabled?('warn')
    end

    def error?
      level_enabled?('error')
    end

    # True when the underlying logger would emit messages of +type+.
    def level_enabled?(type)
      jlogger.isEnabledFor(JPriority.toPriority(type.upcase))
    end

    alias_method :warn, :warning

  end
end
@@ -0,0 +1,100 @@
1
module Spark
  # MLlib is Spark's scalable machine learning library consisting of common
  # learning algorithms and utilities, including classification, regression,
  # clustering, collaborative filtering, dimensionality reduction, as well
  # as underlying optimization primitives.
  module Mllib

    # Registers +klass+ for lazy loading from +location+. Classes registered
    # with +import+ = true are additionally exported by Mllib.import.
    def self.autoload(klass, location, import=true)
      (@for_importing ||= []) << klass if import
      super(klass, location)
    end

    # Lazy-loads a class without exporting it via Mllib.import.
    def self.autoload_without_import(klass, location)
      autoload(klass, location, false)
    end

    # Base classes
    autoload_without_import :VectorBase, 'spark/mllib/vector'
    autoload_without_import :MatrixBase, 'spark/mllib/matrix'
    autoload_without_import :RegressionMethodBase, 'spark/mllib/regression/common'
    autoload_without_import :ClassificationMethodBase, 'spark/mllib/classification/common'

    # Linear algebra
    autoload :Vectors, 'spark/mllib/vector'
    autoload :DenseVector, 'spark/mllib/vector'
    autoload :SparseVector, 'spark/mllib/vector'
    autoload :Matrices, 'spark/mllib/matrix'
    autoload :DenseMatrix, 'spark/mllib/matrix'
    autoload :SparseMatrix, 'spark/mllib/matrix'

    # Regression
    autoload :LabeledPoint, 'spark/mllib/regression/labeled_point'
    autoload :RegressionModel, 'spark/mllib/regression/common'
    autoload :LinearRegressionModel, 'spark/mllib/regression/linear'
    autoload :LinearRegressionWithSGD, 'spark/mllib/regression/linear'
    autoload :LassoModel, 'spark/mllib/regression/lasso'
    autoload :LassoWithSGD, 'spark/mllib/regression/lasso'
    autoload :RidgeRegressionModel, 'spark/mllib/regression/ridge'
    autoload :RidgeRegressionWithSGD, 'spark/mllib/regression/ridge'

    # Classification
    autoload :ClassificationModel, 'spark/mllib/classification/common'
    autoload :LogisticRegressionWithSGD, 'spark/mllib/classification/logistic_regression'
    autoload :LogisticRegressionWithLBFGS, 'spark/mllib/classification/logistic_regression'
    autoload :SVMModel, 'spark/mllib/classification/svm'
    autoload :SVMWithSGD, 'spark/mllib/classification/svm'
    autoload :NaiveBayesModel, 'spark/mllib/classification/naive_bayes'
    autoload :NaiveBayes, 'spark/mllib/classification/naive_bayes'

    # Clustering
    autoload :KMeans, 'spark/mllib/clustering/kmeans'
    autoload :KMeansModel, 'spark/mllib/clustering/kmeans'
    autoload :GaussianMixture, 'spark/mllib/clustering/gaussian_mixture'
    autoload :GaussianMixtureModel, 'spark/mllib/clustering/gaussian_mixture'

    # Stat
    autoload :MultivariateGaussian, 'spark/mllib/stat/distribution'

    # Loads the Ruby-matrix based vector/matrix adapters. Idempotent.
    # (NArray and MDArray backends are not wired in yet — narray?/mdarray?
    # below detect their availability but nothing selects them.)
    def self.prepare
      return if @prepared

      require 'spark/mllib/ruby_matrix/vector_adapter'
      require 'spark/mllib/ruby_matrix/matrix_adapter'

      @prepared = true
      nil
    end

    # Copies every class registered for importing onto +to+ (Object by
    # default) so they can be referenced without the Spark::Mllib prefix.
    def self.import(to=Object)
      @for_importing.each do |klass|
        to.const_set(klass, const_get(klass))
      end
      nil
    end

    # True when the narray gem is installed.
    def self.narray?
      Gem::Specification.find_all_by_name('narray').any?
    end

    # True when the mdarray gem is installed.
    def self.mdarray?
      Gem::Specification.find_all_by_name('mdarray').any?
    end
  end
end

Spark::Mllib.prepare
@@ -0,0 +1,31 @@
1
module Spark
  module Mllib
    # Common state shared by linear classification models: a weight vector,
    # an intercept and an optional decision threshold.
    class ClassificationModel

      attr_reader :weights, :intercept, :threshold

      def initialize(weights, intercept)
        @weights   = Spark::Mllib::Vectors.to_vector(weights)
        @intercept = intercept.to_f
        @threshold = nil
      end

      # Sets the decision threshold (coerced to Float).
      def threshold=(value)
        @threshold = value.to_f
      end

      # Removes the threshold so prediction returns raw scores.
      def clear_threshold
        @threshold = nil
      end

    end
  end
end
24
+
25
module Spark
  module Mllib
    # Base class for classification training front-ends; option handling
    # is inherited unchanged from RegressionMethodBase.
    class ClassificationMethodBase < RegressionMethodBase

    end
  end
end
@@ -0,0 +1,223 @@
1
module Spark
  module Mllib
    ##
    # LogisticRegressionModel
    #
    # A linear binary classification model derived from logistic regression.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #       LabeledPoint.new(0.0, [0.0, 1.0]),
    #       LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    #   lrm.clear_threshold
    #   lrm.predict([0.0, 1.0])
    #   # => 0.123...
    #
    #
    #   # Sparse vectors
    #   data = [
    #       LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    #       LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    #       LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
    #       LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([0.0, 1.0])
    #   # => 1
    #   lrm.predict([1.0, 0.0])
    #   # => 0
    #   lrm.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 1
    #   lrm.predict(SparseVector.new(2, {0 => 1.0}))
    #   # => 0
    #
    #
    #   # LogisticRegressionWithLBFGS
    #   data = [
    #       LabeledPoint.new(0.0, [0.0, 1.0]),
    #       LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    class LogisticRegressionModel < ClassificationModel

      def initialize(*args)
        super
        # Binary decision threshold; clear_threshold switches predict to raw scores.
        @threshold = 0.5
      end

      # Predict the label for a single data point using the trained model.
      # Returns the raw sigmoid score when no threshold is set, otherwise
      # 1 or 0 depending on which side of the threshold the score falls.
      def predict(vector)
        features = Spark::Mllib::Vectors.to_vector(vector)
        margin = weights.dot(features) + intercept
        score = 1.0 / (1.0 + Math.exp(-margin))

        return score if threshold.nil?
        score > threshold ? 1 : 0
      end

    end
  end
end
88
+
89
module Spark
  module Mllib
    # Trainer for logistic regression via stochastic gradient descent.
    class LogisticRegressionWithSGD < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration.
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      def self.train(rdd, options={})
        super

        # Positional arguments in the exact order expected by the
        # Scala-side trainLogisticRegressionModelWithSGD method.
        args = [
          options[:iterations].to_i,
          options[:step].to_f,
          options[:mini_batch_fraction].to_f,
          options[:initial_weights],
          options[:reg_param].to_f,
          options[:reg_type],
          options[:intercept]
        ]

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithSGD', rdd, *args)

        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end
156
+
157
module Spark
  module Mllib
    # Trainer for logistic regression via the L-BFGS optimizer.
    class LogisticRegressionWithLBFGS < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false,
        corrections: 10,
        tolerance: 0.0001
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      # corrections::
      #   The number of corrections used in the LBFGS update (default: 10).
      #
      # tolerance::
      #   The convergence tolerance of iterations for L-BFGS (default: 0.0001).
      #
      def self.train(rdd, options={})
        super

        # Positional arguments in the exact order expected by the
        # Scala-side trainLogisticRegressionModelWithLBFGS method.
        args = [
          options[:iterations].to_i,
          options[:initial_weights],
          options[:reg_param].to_f,
          options[:reg_type],
          options[:intercept],
          options[:corrections].to_i,
          options[:tolerance].to_f
        ]

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithLBFGS', rdd, *args)

        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end