ruby-spark 1.0.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
@@ -0,0 +1,41 @@
1
# The rjb gem (required below) needs JAVA_HOME to find the JVM. Fail fast
# with a clear message when the variable is missing or blank, since an
# empty value cannot point at a Java installation either.
if ENV.fetch('JAVA_HOME', '').strip.empty?
  raise Spark::ConfigurationError, 'Environment variable JAVA_HOME is not set'
end
4
+
5
+ require 'rjb'
6
+
7
module Spark
  module JavaBridge
    # Java bridge implemented on top of the rjb gem.
    class RJB < Base

      # Initialise the base bridge, load the JVM with this bridge's jars on
      # the classpath and enable Rjb's primitive conversion.
      def initialize(*args)
        super
        Rjb.load(jars)
        Rjb.primitive_conversion = true
      end

      # Import the Java class named +klass+ and bind it to the top-level
      # constant +name+. Warnings printed while importing are suppressed.
      # Returns the imported class proxy.
      def import(name, klass)
        java_class = silence_warnings { Rjb.import(klass) }
        Object.const_set(name, java_class)
      end

      # True when +object+ is an Rjb Java proxy.
      def java_object?(object)
        object.is_a?(Rjb::Rjb_JavaProxy)
      end

      private

      # Classpath string: the superclass's jar list (presumably an Array of
      # paths) joined with ';' on Windows and ':' everywhere else.
      def jars
        super.join(windows? ? ';' : ':')
      end

      # Run the block with $VERBOSE set to nil and restore the previous
      # value afterwards, even if the block raises. Returns the block's value.
      def silence_warnings
        previous_verbosity = $VERBOSE
        $VERBOSE = nil
        yield
      ensure
        $VERBOSE = previous_verbosity
      end

    end
  end
end
@@ -0,0 +1,76 @@
1
+ # Necessary libraries
2
+ Spark.load_lib
3
+
4
module Spark
  # Ruby-side facade over the Java logger named 'Ruby'. Each message method
  # checks its level first, so suppressed messages are never passed to Java.
  #
  # NOTE(review): JLogger, JLevel and JPriority are not defined here; they are
  # presumably log4j's Logger, Level and Priority imported by the Java bridge
  # (Spark.load_lib) — confirm.
  class Logger

    attr_reader :jlogger

    def initialize
      @jlogger = JLogger.getLogger('Ruby')
    end

    # The Java level object for 'OFF'.
    def level_off
      JLevel.toLevel('OFF')
    end

    # Disable all Spark log: set this logger, the 'org' and 'akka' loggers
    # and the root logger to OFF.
    def disable
      jlogger.setLevel(level_off)
      JLogger.getLogger('org').setLevel(level_off)
      JLogger.getLogger('akka').setLevel(level_off)
      JLogger.getRootLogger.setLevel(level_off)
    end

    # True unless logging has been switched off (the negation of #disabled?).
    def enabled?
      !disabled?
    end

    # True when not even FATAL messages are enabled for this logger. With
    # log4j-style levels only OFF ranks above FATAL, which is what #disable
    # sets.
    def disabled?
      !level_enabled?('fatal')
    end

    # Log +message+ at INFO if that level is enabled.
    def info(message)
      jlogger.info(message) if info?
    end

    # Log +message+ at DEBUG if that level is enabled.
    def debug(message)
      jlogger.debug(message) if debug?
    end

    # Log +message+ at TRACE if that level is enabled.
    def trace(message)
      jlogger.trace(message) if trace?
    end

    # Log +message+ at WARN if that level is enabled. Also available as #warn.
    def warning(message)
      jlogger.warn(message) if warning?
    end

    # Log +message+ at ERROR if that level is enabled.
    def error(message)
      jlogger.error(message) if error?
    end

    def info?
      level_enabled?('info')
    end

    def debug?
      level_enabled?('debug')
    end

    def trace?
      level_enabled?('trace')
    end

    def warning?
      level_enabled?('warn')
    end

    def error?
      level_enabled?('error')
    end

    # True if messages of level +type+ ('info', 'warn', ...) would be logged.
    # +type+ is upcased before being converted to a Java priority.
    def level_enabled?(type)
      jlogger.isEnabledFor(JPriority.toPriority(type.upcase))
    end

    alias_method :warn, :warning

  end
end
@@ -0,0 +1,100 @@
1
module Spark
  # MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities,
  # including classification, regression, clustering, collaborative filtering, dimensionality reduction,
  # as well as underlying optimization primitives.
  #
  # Classes are registered with Module#autoload, so their files are required
  # only on first use. Most of them are also recorded for Mllib.import, which
  # copies them into another namespace (Object by default).
  module Mllib

    # Register +klass+ to be autoloaded from +location+ (Module#autoload).
    # When +import+ is true (the default) the name is also remembered so
    # that Mllib.import exports it.
    def self.autoload(klass, location, import = true)
      if import
        @for_importing ||= []
        @for_importing << klass
      end

      super(klass, location)
    end

    # Register an autoload that Mllib.import will not export.
    def self.autoload_without_import(klass, location)
      autoload(klass, location, false)
    end

    # Base classes
    autoload_without_import :VectorBase, 'spark/mllib/vector'
    autoload_without_import :MatrixBase, 'spark/mllib/matrix'
    autoload_without_import :RegressionMethodBase, 'spark/mllib/regression/common'
    autoload_without_import :ClassificationMethodBase, 'spark/mllib/classification/common'

    # Linear algebra
    autoload :Vectors, 'spark/mllib/vector'
    autoload :DenseVector, 'spark/mllib/vector'
    autoload :SparseVector, 'spark/mllib/vector'
    autoload :Matrices, 'spark/mllib/matrix'
    autoload :DenseMatrix, 'spark/mllib/matrix'
    autoload :SparseMatrix, 'spark/mllib/matrix'

    # Regression
    autoload :LabeledPoint, 'spark/mllib/regression/labeled_point'
    autoload :RegressionModel, 'spark/mllib/regression/common'
    autoload :LinearRegressionModel, 'spark/mllib/regression/linear'
    autoload :LinearRegressionWithSGD, 'spark/mllib/regression/linear'
    autoload :LassoModel, 'spark/mllib/regression/lasso'
    autoload :LassoWithSGD, 'spark/mllib/regression/lasso'
    autoload :RidgeRegressionModel, 'spark/mllib/regression/ridge'
    autoload :RidgeRegressionWithSGD, 'spark/mllib/regression/ridge'

    # Classification
    autoload :ClassificationModel, 'spark/mllib/classification/common'
    autoload :LogisticRegressionWithSGD, 'spark/mllib/classification/logistic_regression'
    autoload :LogisticRegressionWithLBFGS, 'spark/mllib/classification/logistic_regression'
    autoload :SVMModel, 'spark/mllib/classification/svm'
    autoload :SVMWithSGD, 'spark/mllib/classification/svm'
    autoload :NaiveBayesModel, 'spark/mllib/classification/naive_bayes'
    autoload :NaiveBayes, 'spark/mllib/classification/naive_bayes'

    # Clustering
    autoload :KMeans, 'spark/mllib/clustering/kmeans'
    autoload :KMeansModel, 'spark/mllib/clustering/kmeans'
    autoload :GaussianMixture, 'spark/mllib/clustering/gaussian_mixture'
    autoload :GaussianMixtureModel, 'spark/mllib/clustering/gaussian_mixture'

    # Stat
    autoload :MultivariateGaussian, 'spark/mllib/stat/distribution'

    # Require the Ruby Matrix-based vector/matrix adapters once; later calls
    # return immediately. Returns nil.
    def self.prepare
      return if @prepared

      # if narray?
      #   require 'spark/mllib/narray/vector'
      #   require 'spark/mllib/narray/matrix'
      # elsif mdarray?
      #   require 'spark/mllib/mdarray/vector'
      #   require 'spark/mllib/mdarray/matrix'
      # else
      #   require 'spark/mllib/matrix/vector'
      #   require 'spark/mllib/matrix/matrix'
      # end

      require 'spark/mllib/ruby_matrix/vector_adapter'
      require 'spark/mllib/ruby_matrix/matrix_adapter'

      @prepared = true
      nil
    end

    # Copy every importable Mllib class into +to+ (Object by default), so
    # it can be used without the Spark::Mllib prefix. Looking the constant up
    # triggers its autoload.
    #
    # A name that +to+ already binds to the very same object is skipped, so
    # repeated calls do not emit "already initialized constant" warnings.
    #
    # Returns nil.
    def self.import(to = Object)
      Array(@for_importing).each do |klass|
        value = const_get(klass)
        next if to.const_defined?(klass, false) && to.const_get(klass, false).equal?(value)

        to.const_set(klass, value)
      end
      nil
    end

    # True if the narray gem is installed. Within this module it is only
    # referenced by the commented-out code in prepare.
    def self.narray?
      Gem::Specification.find_all_by_name('narray').any?
    end

    # True if the mdarray gem is installed. Within this module it is only
    # referenced by the commented-out code in prepare.
    def self.mdarray?
      Gem::Specification.find_all_by_name('mdarray').any?
    end
  end
end
99
+
100
+ Spark::Mllib.prepare
@@ -0,0 +1,31 @@
1
module Spark
  module Mllib
    # Common state of linear classification models: a weight vector, an
    # intercept and an optional decision threshold.
    class ClassificationModel

      attr_reader :weights, :intercept, :threshold

      # weights   - anything Spark::Mllib::Vectors.to_vector accepts
      # intercept - the bias term, coerced with #to_f
      # The threshold starts unset (nil).
      def initialize(weights, intercept)
        @weights = Spark::Mllib::Vectors.to_vector(weights)
        @intercept = intercept.to_f
        @threshold = nil
      end

      # Set the decision threshold (coerced with #to_f). Assigning nil clears
      # it, exactly like #clear_threshold, instead of turning into 0.0 via
      # nil.to_f.
      def threshold=(value)
        @threshold = value.nil? ? nil : value.to_f
      end

      # Remove the threshold; LogisticRegressionModel#predict then returns
      # the raw score instead of a 0/1 label.
      def clear_threshold
        @threshold = nil
      end

    end
  end
end
24
+
25
module Spark
  module Mllib
    # Parent of the classification trainers (LogisticRegressionWithSGD,
    # LogisticRegressionWithLBFGS). It adds nothing itself; all shared
    # training behaviour is inherited from RegressionMethodBase.
    class ClassificationMethodBase < RegressionMethodBase; end
  end
end
@@ -0,0 +1,223 @@
1
module Spark
  module Mllib
    ##
    # LogisticRegressionModel
    #
    # A linear binary classification model derived from logistic regression.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0, 1.0]),
    #     LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    #   lrm.clear_threshold
    #   lrm.predict([0.0, 1.0])
    #   # => 0.123...
    #
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([0.0, 1.0])
    #   # => 1
    #   lrm.predict([1.0, 0.0])
    #   # => 0
    #   lrm.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 1
    #   lrm.predict(SparseVector.new(2, {0 => 1.0}))
    #   # => 0
    #
    #
    #   # LogisticRegressionWithLBFGS
    #   data = [
    #     LabeledPoint.new(0.0, [0.0, 1.0]),
    #     LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    class LogisticRegressionModel < ClassificationModel

      # Same arguments as ClassificationModel#initialize; the decision
      # threshold defaults to 0.5.
      def initialize(*args)
        super(*args)
        @threshold = 0.5
      end

      # Classify a single data point (anything Vectors.to_vector accepts).
      #
      # The score is the logistic function of weights . x + intercept.
      # Returns that score (a Float) when no threshold is set. Otherwise
      # returns 1 if the score is strictly above the threshold, and 0 if not.
      def predict(vector)
        features = Spark::Mllib::Vectors.to_vector(vector)
        probability = 1.0 / (1.0 + Math.exp(-(weights.dot(features) + intercept)))
        return probability if threshold.nil?

        probability > threshold ? 1 : 0
      end

    end
  end
end
88
+
89
module Spark
  module Mllib
    # Trains a LogisticRegressionModel with stochastic gradient descent by
    # calling RubyMLLibAPI's 'trainLogisticRegressionModelWithSGD' on the JVM.
    class LogisticRegressionWithSGD < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration.
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      # Returns a LogisticRegressionModel built from the weights and intercept
      # returned by the JVM.
      def self.train(rdd, options = {})
        # Parent hook, called with the same (rdd, options). The reads below
        # presumably rely on it filling +options+ from DEFAULT_OPTIONS in
        # place — confirm in RegressionMethodBase.
        super

        api = RubyMLLibAPI.new
        java_args = [
          options[:iterations].to_i,
          options[:step].to_f,
          options[:mini_batch_fraction].to_f,
          options[:initial_weights],
          options[:reg_param].to_f,
          options[:reg_type],
          options[:intercept]
        ]

        weights, intercept = Spark.jb.call(api, 'trainLogisticRegressionModelWithSGD', rdd, *java_args)
        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end
156
+
157
module Spark
  module Mllib
    # Trains a LogisticRegressionModel with L-BFGS by calling RubyMLLibAPI's
    # 'trainLogisticRegressionModelWithLBFGS' on the JVM.
    class LogisticRegressionWithLBFGS < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false,
        corrections: 10,
        tolerance: 0.0001
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      # corrections::
      #   The number of corrections used in the LBFGS update (default: 10).
      #
      # tolerance::
      #   The convergence tolerance of iterations for L-BFGS (default: 0.0001).
      #
      # Returns a LogisticRegressionModel built from the weights and intercept
      # returned by the JVM.
      def self.train(rdd, options = {})
        # Parent hook, called with the same (rdd, options). The reads below
        # presumably rely on it filling +options+ from DEFAULT_OPTIONS in
        # place — confirm in RegressionMethodBase.
        super

        api = RubyMLLibAPI.new
        java_args = [
          options[:iterations].to_i,
          options[:initial_weights],
          options[:reg_param].to_f,
          options[:reg_type],
          options[:intercept],
          options[:corrections].to_i,
          options[:tolerance].to_f
        ]

        weights, intercept = Spark.jb.call(api, 'trainLogisticRegressionModelWithLBFGS', rdd, *java_args)
        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end