ruby-spark 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +185 -0
- data/Rakefile +35 -0
- data/TODO.md +7 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/performance/prepare.sh +18 -0
- data/benchmark/performance/python.py +156 -0
- data/benchmark/performance/r.r +69 -0
- data/benchmark/performance/ruby.rb +167 -0
- data/benchmark/performance/run-all.sh +160 -0
- data/benchmark/performance/scala.scala +181 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +154 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +244 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +304 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +57 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1328 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +24 -0
- data/lib/spark/serializer/base.rb +170 -0
- data/lib/spark/serializer/cartesian.rb +37 -0
- data/lib/spark/serializer/marshal.rb +19 -0
- data/lib/spark/serializer/message_pack.rb +25 -0
- data/lib/spark/serializer/oj.rb +25 -0
- data/lib/spark/serializer/pair.rb +27 -0
- data/lib/spark/serializer/utf8.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +197 -0
- data/ruby-spark.gemspec +36 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +163 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +114 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +13 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +168 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +39 -0
- metadata +301 -0
data/lib/spark/java_bridge/rjb.rb
ADDED
@@ -0,0 +1,41 @@

```ruby
if !ENV.has_key?('JAVA_HOME')
  raise Spark::ConfigurationError, 'Environment variable JAVA_HOME is not set'
end

require 'rjb'

module Spark
  module JavaBridge
    class RJB < Base

      def initialize(*args)
        super
        Rjb.load(jars)
        Rjb.primitive_conversion = true
      end

      def import(name, klass)
        Object.const_set(name, silence_warnings { Rjb.import(klass) })
      end

      def java_object?(object)
        object.is_a?(Rjb::Rjb_JavaProxy)
      end

      private

      def jars
        separator = windows? ? ';' : ':'
        super.join(separator)
      end

      def silence_warnings
        old_verbose, $VERBOSE = $VERBOSE, nil
        yield
      ensure
        $VERBOSE = old_verbose
      end

    end
  end
end
```
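For orientation, a minimal usage sketch of this bridge. It assumes `Spark::JavaBridge::Base` (see `data/lib/spark/java_bridge/base.rb`, +203 lines in this release) supplies the `jars` list and the `windows?` helper; the constructor argument and the log4j class name below are illustrative, not taken from this diff:

```ruby
require 'spark'

# Hypothetical: Base#initialize is assumed here to accept the Spark home path.
bridge = Spark::JavaBridge::RJB.new('/path/to/spark')

# Defines ::JLogger as a proxy for the named Java class via Rjb.import.
bridge.import(:JLogger, 'org.apache.log4j.Logger')

logger = JLogger.getLogger('Ruby')
bridge.java_object?(logger)   # => true for Rjb proxy objects
```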
data/lib/spark/logger.rb
ADDED
@@ -0,0 +1,76 @@

```ruby
# Necessary libraries
Spark.load_lib

module Spark
  class Logger

    attr_reader :jlogger

    def initialize
      @jlogger = JLogger.getLogger('Ruby')
    end

    def level_off
      JLevel.toLevel('OFF')
    end

    # Disable all Spark log
    def disable
      jlogger.setLevel(level_off)
      JLogger.getLogger('org').setLevel(level_off)
      JLogger.getLogger('akka').setLevel(level_off)
      JLogger.getRootLogger.setLevel(level_off)
    end

    def enabled?
      !disabled?
    end

    def info(message)
      jlogger.info(message) if info?
    end

    def debug(message)
      jlogger.debug(message) if debug?
    end

    def trace(message)
      jlogger.trace(message) if trace?
    end

    def warning(message)
      jlogger.warn(message) if warning?
    end

    def error(message)
      jlogger.error(message) if error?
    end

    def info?
      level_enabled?('info')
    end

    def debug?
      level_enabled?('debug')
    end

    def trace?
      level_enabled?('trace')
    end

    def warning?
      level_enabled?('warn')
    end

    def error?
      level_enabled?('error')
    end

    def level_enabled?(type)
      jlogger.isEnabledFor(JPriority.toPriority(type.upcase))
    end

    alias_method :warn, :warning

  end
end
```
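A short usage sketch, assuming `Spark.load_lib` has already run and the Java bridge has imported the log4j proxies this class relies on (`JLogger`, `JLevel`, `JPriority`):

```ruby
logger = Spark::Logger.new

logger.info?              # => true if the log4j 'Ruby' logger accepts INFO
logger.info('Starting')   # forwarded to log4j only when the level allows it
logger.warn('Careful')    # #warn is an alias for #warning

logger.disable            # sets the 'Ruby', 'org', 'akka' and root loggers to OFF
```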
data/lib/spark/mllib.rb
ADDED
@@ -0,0 +1,100 @@

```ruby
module Spark
  # MLlib is Spark’s scalable machine learning library consisting of common learning algorithms and utilities,
  # including classification, regression, clustering, collaborative filtering, dimensionality reduction,
  # as well as underlying optimization primitives.
  module Mllib

    def self.autoload(klass, location, import=true)
      if import
        @for_importing ||= []
        @for_importing << klass
      end

      super(klass, location)
    end

    def self.autoload_without_import(klass, location)
      autoload(klass, location, false)
    end

    # Base classes
    autoload_without_import :VectorBase, 'spark/mllib/vector'
    autoload_without_import :MatrixBase, 'spark/mllib/matrix'
    autoload_without_import :RegressionMethodBase, 'spark/mllib/regression/common'
    autoload_without_import :ClassificationMethodBase, 'spark/mllib/classification/common'

    # Linear algebra
    autoload :Vectors, 'spark/mllib/vector'
    autoload :DenseVector, 'spark/mllib/vector'
    autoload :SparseVector, 'spark/mllib/vector'
    autoload :Matrices, 'spark/mllib/matrix'
    autoload :DenseMatrix, 'spark/mllib/matrix'
    autoload :SparseMatrix, 'spark/mllib/matrix'

    # Regression
    autoload :LabeledPoint, 'spark/mllib/regression/labeled_point'
    autoload :RegressionModel, 'spark/mllib/regression/common'
    autoload :LinearRegressionModel, 'spark/mllib/regression/linear'
    autoload :LinearRegressionWithSGD, 'spark/mllib/regression/linear'
    autoload :LassoModel, 'spark/mllib/regression/lasso'
    autoload :LassoWithSGD, 'spark/mllib/regression/lasso'
    autoload :RidgeRegressionModel, 'spark/mllib/regression/ridge'
    autoload :RidgeRegressionWithSGD, 'spark/mllib/regression/ridge'

    # Classification
    autoload :ClassificationModel, 'spark/mllib/classification/common'
    autoload :LogisticRegressionWithSGD, 'spark/mllib/classification/logistic_regression'
    autoload :LogisticRegressionWithLBFGS, 'spark/mllib/classification/logistic_regression'
    autoload :SVMModel, 'spark/mllib/classification/svm'
    autoload :SVMWithSGD, 'spark/mllib/classification/svm'
    autoload :NaiveBayesModel, 'spark/mllib/classification/naive_bayes'
    autoload :NaiveBayes, 'spark/mllib/classification/naive_bayes'

    # Clustering
    autoload :KMeans, 'spark/mllib/clustering/kmeans'
    autoload :KMeansModel, 'spark/mllib/clustering/kmeans'
    autoload :GaussianMixture, 'spark/mllib/clustering/gaussian_mixture'
    autoload :GaussianMixtureModel, 'spark/mllib/clustering/gaussian_mixture'

    # Stat
    autoload :MultivariateGaussian, 'spark/mllib/stat/distribution'

    def self.prepare
      return if @prepared

      # if narray?
      #   require 'spark/mllib/narray/vector'
      #   require 'spark/mllib/narray/matrix'
      # elsif mdarray?
      #   require 'spark/mllib/mdarray/vector'
      #   require 'spark/mllib/mdarray/matrix'
      # else
      #   require 'spark/mllib/matrix/vector'
      #   require 'spark/mllib/matrix/matrix'
      # end

      require 'spark/mllib/ruby_matrix/vector_adapter'
      require 'spark/mllib/ruby_matrix/matrix_adapter'

      @prepared = true
      nil
    end

    def self.import(to=Object)
      @for_importing.each do |klass|
        to.const_set(klass, const_get(klass))
      end
      nil
    end

    def self.narray?
      Gem::Specification::find_all_by_name('narray').any?
    end

    def self.mdarray?
      Gem::Specification::find_all_by_name('mdarray').any?
    end
  end
end

Spark::Mllib.prepare
```
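Because the overridden `Mllib.autoload` records each constant before delegating to Ruby's `Module#autoload`, `Mllib.import` can later copy the whole registry into another namespace. A minimal sketch of that flow; the vector values are illustrative:

```ruby
require 'spark/mllib'   # also runs Spark::Mllib.prepare (see the last line above)

Spark::Mllib.import     # copies the registered constants into Object
v = DenseVector.new([1.0, 2.0])   # resolves to Spark::Mllib::DenseVector

# Or target a namespace other than Object:
module MyApp; end
Spark::Mllib.import(MyApp)
MyApp::LabeledPoint     # => Spark::Mllib::LabeledPoint
```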
data/lib/spark/mllib/classification/common.rb
ADDED
@@ -0,0 +1,31 @@

```ruby
module Spark
  module Mllib
    class ClassificationModel

      attr_reader :weights, :intercept, :threshold

      def initialize(weights, intercept)
        @weights = Spark::Mllib::Vectors.to_vector(weights)
        @intercept = intercept.to_f
        @threshold = nil
      end

      def threshold=(value)
        @threshold = value.to_f
      end

      def clear_threshold
        @threshold = nil
      end

    end
  end
end

module Spark
  module Mllib
    class ClassificationMethodBase < RegressionMethodBase

    end
  end
end
```
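Since `@threshold` starts as `nil` and `nil` means "return raw scores" in subclasses such as `LogisticRegressionModel` (next file), the accessors amount to a switch between raw and thresholded prediction. A brief sketch, assuming `spark/mllib` has been required so `Vectors.to_vector` can coerce the array; the weights are illustrative:

```ruby
model = Spark::Mllib::ClassificationModel.new([0.5, -0.5], 0.1)

model.threshold          # => nil (raw-score mode by default in the base class)
model.threshold = 0.7    # the writer coerces the value with #to_f
model.threshold          # => 0.7
model.clear_threshold    # back to nil, i.e. raw-score mode
```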
data/lib/spark/mllib/classification/logistic_regression.rb
ADDED
@@ -0,0 +1,223 @@

```ruby
module Spark
  module Mllib
    ##
    # LogisticRegressionModel
    #
    # A linear binary classification model derived from logistic regression.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #       LabeledPoint.new(0.0, [0.0, 1.0]),
    #       LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    #   lrm.clear_threshold
    #   lrm.predict([0.0, 1.0])
    #   # => 0.123...
    #
    #
    #   # Sparse vectors
    #   data = [
    #       LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    #       LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    #       LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
    #       LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
    #   ]
    #   lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
    #
    #   lrm.predict([0.0, 1.0])
    #   # => 1
    #   lrm.predict([1.0, 0.0])
    #   # => 0
    #   lrm.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 1
    #   lrm.predict(SparseVector.new(2, {0 => 1.0}))
    #   # => 0
    #
    #
    #   # LogisticRegressionWithLBFGS
    #   data = [
    #       LabeledPoint.new(0.0, [0.0, 1.0]),
    #       LabeledPoint.new(1.0, [1.0, 0.0]),
    #   ]
    #   lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))
    #
    #   lrm.predict([1.0, 0.0])
    #   # => 1
    #   lrm.predict([0.0, 1.0])
    #   # => 0
    #
    class LogisticRegressionModel < ClassificationModel

      def initialize(*args)
        super
        @threshold = 0.5
      end

      # Predict values for a single data point or an RDD of points using
      # the model trained.
      def predict(vector)
        vector = Spark::Mllib::Vectors.to_vector(vector)
        margin = weights.dot(vector) + intercept
        score = 1.0 / (1.0 + Math.exp(-margin))

        if threshold.nil?
          return score
        end

        if score > threshold
          1
        else
          0
        end
      end

    end
  end
end

module Spark
  module Mllib
    class LogisticRegressionWithSGD < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration.
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      def self.train(rdd, options={})
        super

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithSGD', rdd,
                                           options[:iterations].to_i,
                                           options[:step].to_f,
                                           options[:mini_batch_fraction].to_f,
                                           options[:initial_weights],
                                           options[:reg_param].to_f,
                                           options[:reg_type],
                                           options[:intercept])

        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end

module Spark
  module Mllib
    class LogisticRegressionWithLBFGS < ClassificationMethodBase

      DEFAULT_OPTIONS = {
        iterations: 100,
        initial_weights: nil,
        reg_param: 0.01,
        reg_type: 'l2',
        intercept: false,
        corrections: 10,
        tolerance: 0.0001
      }

      # Train a logistic regression model on the given data.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      # corrections::
      #   The number of corrections used in the LBFGS update (default: 10).
      #
      # tolerance::
      #   The convergence tolerance of iterations for L-BFGS (default: 0.0001).
      #
      def self.train(rdd, options={})
        super

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainLogisticRegressionModelWithLBFGS', rdd,
                                           options[:iterations].to_i,
                                           options[:initial_weights],
                                           options[:reg_param].to_f,
                                           options[:reg_type],
                                           options[:intercept],
                                           options[:corrections].to_i,
                                           options[:tolerance].to_f)

        LogisticRegressionModel.new(weights, intercept)
      end

    end
  end
end
```
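The `predict` method above is standard logistic regression: a dot-product margin pushed through a sigmoid, then optionally compared against the threshold. The same arithmetic in plain Ruby, with made-up weights, no Spark required:

```ruby
weights   = [2.0, -1.0]
vector    = [1.0, 0.0]
intercept = 0.0

margin = weights.zip(vector).sum { |w, x| w * x } + intercept   # => 2.0
score  = 1.0 / (1.0 + Math.exp(-margin))                        # => 0.8807970779778823
score > 0.5 ? 1 : 0   # => 1, matching #predict with the default 0.5 threshold
```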