ruby-spark 1.0.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +185 -0
- data/Rakefile +35 -0
- data/TODO.md +7 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/performance/prepare.sh +18 -0
- data/benchmark/performance/python.py +156 -0
- data/benchmark/performance/r.r +69 -0
- data/benchmark/performance/ruby.rb +167 -0
- data/benchmark/performance/run-all.sh +160 -0
- data/benchmark/performance/scala.scala +181 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +154 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +244 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +304 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +57 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1328 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +24 -0
- data/lib/spark/serializer/base.rb +170 -0
- data/lib/spark/serializer/cartesian.rb +37 -0
- data/lib/spark/serializer/marshal.rb +19 -0
- data/lib/spark/serializer/message_pack.rb +25 -0
- data/lib/spark/serializer/oj.rb +25 -0
- data/lib/spark/serializer/pair.rb +27 -0
- data/lib/spark/serializer/utf8.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +197 -0
- data/ruby-spark.gemspec +36 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +163 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +114 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +13 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +168 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +39 -0
- metadata +301 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
module Spark
  module Mllib
    ##
    # NaiveBayesModel
    #
    # A trained Naive Bayes classifier.
    #
    # The model is described by:
    # labels:: class labels; #predict returns the entry at the winning index
    # pi:: vector of logs of class priors (dimension C)
    # theta:: matrix of logs of class conditional probabilities (CxD)
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0, 0.0]),
    #     LabeledPoint.new(0.0, [0.0, 1.0]),
    #     LabeledPoint.new(1.0, [1.0, 0.0])
    #   ]
    #   model = NaiveBayes.train($sc.parallelize(data))
    #
    #   model.predict([0.0, 1.0])
    #   # => 0.0
    #   model.predict([1.0, 0.0])
    #   # => 1.0
    #
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 0.0})),
    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 1.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {0 => 1.0}))
    #   ]
    #   model = NaiveBayes.train($sc.parallelize(data))
    #
    #   model.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 0.0
    #   model.predict(SparseVector.new(2, {0 => 1.0}))
    #   # => 1.0
    #
    class NaiveBayesModel

      attr_reader :labels, :pi, :theta

      def initialize(labels, pi, theta)
        @labels, @pi, @theta = labels, pi, theta
      end

      # Predict the label of one data point.
      #
      # The argument is converted with Vectors.to_vector, scored as
      # <tt>vector.dot(theta) + pi</tt>, and the label at the position of the
      # highest score is returned (the first one when several scores tie).
      def predict(vector)
        features = Spark::Mllib::Vectors.to_vector(vector)
        scores = (features.dot(theta) + pi).to_a
        labels[scores.index(scores.max)]
      end

    end
  end
end
|
65
|
+
|
66
|
+
|
67
|
+
module Spark
  module Mllib
    class NaiveBayes

      # Trains a Naive Bayes model given an RDD of (label, features) pairs.
      #
      # This is the Multinomial NB (http://tinyurl.com/lsdw6p) which can handle all kinds of
      # discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
      # document classification. By making every vector a 0-1 vector, it can also be used as
      # Bernoulli NB (http://tinyurl.com/p7c96j6). The input feature values must be nonnegative.
      #
      # == Arguments:
      # rdd:: RDD of LabeledPoint.
      # lambda:: The smoothing parameter.
      #
      # Raises Spark::MllibError when the first element of the RDD is not a
      # LabeledPoint (only the first element is inspected).
      #
      def self.train(rdd, lambda=1.0)
        # Validation
        sample = rdd.first
        valid = sample.is_a?(LabeledPoint)
        raise Spark::MllibError, "RDD should contains LabeledPoint, got #{sample.class}" unless valid

        labels, pi, raw_theta = Spark.jb.call(RubyMLLibAPI.new, 'trainNaiveBayes', rdd, lambda)

        # Wrap the returned nested array: one row per entry, columns taken
        # from the first row.
        theta = Spark::Mllib::Matrices.dense(raw_theta.size, raw_theta.first.size, raw_theta)

        NaiveBayesModel.new(labels, pi, theta)
      end

    end
  end
end
|
@@ -0,0 +1,135 @@
|
|
1
|
+
module Spark
  module Mllib
    ##
    # SVMModel
    #
    # A support vector machine.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0]),
    #     LabeledPoint.new(1.0, [1.0]),
    #     LabeledPoint.new(1.0, [2.0]),
    #     LabeledPoint.new(1.0, [3.0])
    #   ]
    #   svm = SVMWithSGD.train($sc.parallelize(data))
    #
    #   svm.predict([1.0])
    #   # => 1
    #   svm.clear_threshold
    #   svm.predict([1.0])
    #   # => 1.25...
    #
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
    #   ]
    #   svm = SVMWithSGD.train($sc.parallelize(data))
    #
    #   svm.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 1
    #   svm.predict(SparseVector.new(2, {0 => -1.0}))
    #   # => 0
    #
    class SVMModel < ClassificationModel

      # Forwards all arguments to the parent constructor, then sets the
      # decision threshold to 0.0.
      def initialize(*args)
        super
        @threshold = 0.0
      end

      # Predict the value for a single data point.
      #
      # Computes the margin <tt>weights.dot(vector) + intercept</tt>. With a
      # nil threshold the raw margin is returned; otherwise 1 when the margin
      # is strictly above the threshold and 0 when it is not.
      #
      # NOTE(review): +weights+, +intercept+ and +threshold+ are presumably
      # readers provided by ClassificationModel -- confirm there.
      def predict(vector)
        features = Spark::Mllib::Vectors.to_vector(vector)
        margin = weights.dot(features) + intercept

        return margin if threshold.nil?

        margin > threshold ? 1 : 0
      end

    end
  end
end
|
69
|
+
|
70
|
+
module Spark
  module Mllib
    class SVMWithSGD < ClassificationMethodBase

      # Option values declared for SVM training.
      # NOTE(review): presumably merged into the caller's options by the
      # parent's .train -- confirm in ClassificationMethodBase.
      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        reg_param: 0.01,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_type: 'l2',
        intercept: false
      }

      # Train a support vector machine on the given data.
      #
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # iterations::
      #   The number of iterations (default: 100).
      #
      # step::
      #   The step parameter used in SGD (default: 1.0).
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01).
      #
      # mini_batch_fraction::
      #   Fraction of data to be used for each SGD iteration.
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_type::
      #   The type of regularizer used for training our model (default: "l2").
      #
      #   Allowed values:
      #   - "l1" for using L1 regularization
      #   - "l2" for using L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Boolean parameter which indicates the use
      #   or not of the augmented representation for
      #   training data (i.e. whether bias features
      #   are activated or not).
      #
      def self.train(rdd, options={})
        # Bare super forwards rdd and options to the parent implementation.
        super

        # Numeric options are coerced before crossing the Java bridge; the
        # remaining ones are passed through untouched.
        java_args = [
          options[:iterations].to_i,
          options[:step].to_f,
          options[:reg_param].to_f,
          options[:mini_batch_fraction].to_f,
          options[:initial_weights],
          options[:reg_type],
          options[:intercept]
        ]

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd, *java_args)

        SVMModel.new(weights, intercept)
      end

    end
  end
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
module Spark
  module Mllib
    ##
    # GaussianMixtureModel
    #
    # A clustering model derived from the Gaussian Mixture Model method.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   data = [
    #     DenseVector.new([-0.1, -0.05]),
    #     DenseVector.new([-0.01, -0.1]),
    #     DenseVector.new([0.9, 0.8]),
    #     DenseVector.new([0.75, 0.935]),
    #     DenseVector.new([-0.83, -0.68]),
    #     DenseVector.new([-0.91, -0.76])
    #   ]
    #
    #   model = GaussianMixture.train($sc.parallelize(data), 3, convergence_tol: 0.0001, max_iterations: 50, seed: 10)
    #
    #   labels = model.predict($sc.parallelize(data)).collect
    #
    class GaussianMixtureModel

      attr_reader :weights, :gaussians, :k

      # The number of components +k+ is taken from the size of +weights+.
      def initialize(weights, gaussians)
        @weights = weights
        @gaussians = gaussians
        @k = weights.size
      end

      # Find the cluster in which each point of the RDD has maximum
      # membership in this model.
      #
      # Raises ArgumentError when the argument is not a Spark::RDD.
      def predict(rdd)
        raise ArgumentError, 'Argument must be a RDD.' unless rdd.is_a?(Spark::RDD)

        predict_soft(rdd).map('lambda{|x| x.index(x.max)}')
      end

      # Find the membership of each point of the RDD to all mixture components.
      def predict_soft(rdd)
        Spark.jb.call(RubyMLLibAPI.new, 'predictSoftGMM', rdd, weights, means, sigmas)
      end

      # The +mu+ of every gaussian, computed once and cached.
      def means
        @means ||= @gaussians.map { |gaussian| gaussian.mu }
      end

      # The +sigma+ of every gaussian, computed once and cached.
      def sigmas
        @sigmas ||= @gaussians.map { |gaussian| gaussian.sigma }
      end

    end
  end
end
|
61
|
+
|
62
|
+
module Spark
  module Mllib
    class GaussianMixture

      # Train a Gaussian mixture model with +k+ components.
      #
      # The bridge call returns weights, means and sigmas. Means and sigmas
      # are converted in place with +java_to_ruby+, and the first +k+
      # mean/sigma pairs are wrapped into MultivariateGaussian objects.
      #
      # == Arguments:
      # rdd:: Training data handed to the Java bridge.
      # k:: Number of mixture components.
      # convergence_tol:: Passed through to the bridge (default: 0.001).
      # max_iterations:: Passed through to the bridge (default: 100).
      # seed:: Converted with +Spark.jb.to_long+ before the call (default: nil).
      #
      def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
        java_seed = Spark.jb.to_long(seed)
        weights, means, sigmas = Spark.jb.call(
          RubyMLLibAPI.new, 'trainGaussianMixture', rdd,
          k, convergence_tol, max_iterations, java_seed
        )

        means.map! { |java_mu| Spark.jb.java_to_ruby(java_mu) }
        sigmas.map! { |java_sigma| Spark.jb.java_to_ruby(java_sigma) }

        components = Array.new(k) { |i| MultivariateGaussian.new(means[i], sigmas[i]) }

        GaussianMixtureModel.new(weights, components)
      end

    end
  end
end
|
@@ -0,0 +1,118 @@
|
|
1
|
+
module Spark
  module Mllib
    ##
    # KMeansModel
    #
    # A clustering model derived from the k-means method.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     DenseVector.new([0.0,0.0]),
    #     DenseVector.new([1.0,1.0]),
    #     DenseVector.new([9.0,8.0]),
    #     DenseVector.new([8.0,9.0])
    #   ]
    #
    #   model = KMeans.train($sc.parallelize(data), 2, max_iterations: 10,
    #                        runs: 30, initialization_mode: "random")
    #
    #   model.predict([0.0, 0.0]) == model.predict([1.0, 1.0])
    #   # => true
    #   model.predict([8.0, 9.0]) == model.predict([9.0, 8.0])
    #   # => true
    #
    #
    #   # Sparse vectors
    #   data = [
    #     SparseVector.new(3, {1 => 1.0}),
    #     SparseVector.new(3, {1 => 1.1}),
    #     SparseVector.new(3, {2 => 1.0}),
    #     SparseVector.new(3, {2 => 1.1})
    #   ]
    #   model = KMeans.train($sc.parallelize(data), 2, initialization_mode: "k-means||")
    #
    #   model.predict([0.0, 1.0, 0.0]) == model.predict([0, 1.1, 0.0])
    #   # => true
    #   model.predict([0.0, 0.0, 1.0]) == model.predict([0, 0, 1.1])
    #   # => true
    #   model.predict(data[0]) == model.predict(data[1])
    #   # => true
    #   model.predict(data[2]) == model.predict(data[3])
    #   # => true
    #
    class KMeansModel

      attr_reader :centers

      def initialize(centers)
        @centers = centers
      end

      # Find the cluster to which the given point belongs in this model.
      #
      # Returns the index of the center with the smallest squared distance to
      # the point. The first center wins a tie, and 0 is returned when there
      # are no centers at all.
      def predict(vector)
        point = Spark::Mllib::Vectors.to_vector(vector)
        start = [0, Float::INFINITY]

        nearest = @centers.each_with_index.inject(start) do |(best, best_distance), (center, index)|
          distance = point.squared_distance(center)
          # Strict comparison keeps the earliest center on equal distances.
          distance < best_distance ? [index, distance] : [best, best_distance]
        end

        nearest.first
      end

      # Build a model from an object exposing +clusterCenters+; every center
      # is converted in place with +Spark.jb.java_to_ruby+.
      def self.from_java(object)
        ruby_centers = object.clusterCenters
        ruby_centers.map! { |java_center| Spark.jb.java_to_ruby(java_center) }

        KMeansModel.new(ruby_centers)
      end

    end
  end
end
|
84
|
+
|
85
|
+
module Spark
  module Mllib
    class KMeans

      # Trains a k-means model using the given set of parameters.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of Vectors.
      #
      # k::
      #   Number of clusters.
      #
      # max_iterations::
      #   Max number of iterations.
      #
      # runs::
      #   Number of parallel runs, defaults to 1. The best model is returned.
      #
      # initialization_mode::
      #   Initialization mode, either "random" or "k-means||" (default).
      #
      # seed::
      #   Random seed value for cluster initialization.
      #
      def self.train(rdd, k, max_iterations: 100, runs: 1, initialization_mode: 'k-means||', seed: nil)
        # The seed is converted with to_long before crossing the bridge.
        java_args = [k, max_iterations, runs, initialization_mode, Spark.jb.to_long(seed)]

        # Call returns KMeansModel
        Spark.jb.call(RubyMLLibAPI.new, 'trainKMeansModel', rdd, *java_args)
      end

    end
  end
end
|
@@ -0,0 +1,120 @@
|
|
1
|
+
module Spark
  module Mllib
    # Factory helpers for building matrices.
    module Matrices

      # Build a DenseMatrix; arguments are forwarded to DenseMatrix.new
      # unchanged (rows, cols, values).
      def self.dense(*args)
        DenseMatrix.new(*args)
      end

      # Build a SparseMatrix; arguments are forwarded to SparseMatrix.new
      # unchanged (rows, cols, col_pointers, row_indices, values).
      def self.sparse(*args)
        SparseMatrix.new(*args)
      end

      # Coerce +data+ into a matrix.
      #
      # * A SparseMatrix or DenseMatrix is returned as is.
      # * An Array of row Arrays becomes a DenseMatrix with one row per
      #   entry and as many columns as the first row has.
      # * Any other object yields nil.
      #
      # Raises ArgumentError for an Array that is empty or whose entries are
      # not all Arrays, since no shape can be derived from it.
      def self.to_matrix(data)
        case data
        when SparseMatrix, DenseMatrix
          data
        when Array
          unless !data.empty? && data.all? { |row| row.is_a?(Array) }
            raise ArgumentError, 'Array must be a non-empty list of row arrays.'
          end

          # DenseMatrix.new needs an explicit shape; passing only the data
          # (as before) always failed with a wrong-arity ArgumentError.
          DenseMatrix.new(data.size, data.first.size, data)
        end
      end

    end
  end
end
|
24
|
+
|
25
|
+
module Spark
  module Mllib
    # @abstract Common parent of all matrix types (DenseMatrix, SparseMatrix).
    class MatrixBase < MatrixAdapter; end
  end
end
|
32
|
+
|
33
|
+
module Spark
  module Mllib
    ##
    # DenseMatrix
    #
    #   DenseMatrix.new(2, 3, [[1,2,3], [4,5,6]]).values
    #   # => [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    #
    class DenseMatrix < MatrixBase

      # rows:: Number of rows.
      # cols:: Number of columns.
      # values:: Matrix entries; converted with +to_a+ before being handed to
      #          the parent constructor.
      def initialize(rows, cols, values)
        super(:dense, rows, cols, values.to_a)
      end

      # Build a JDenseMatrix from this matrix's shape and flattened values.
      #
      # NOTE(review): +values.flatten+ lists nested rows one after another;
      # confirm this matches the element ordering JDenseMatrix expects.
      def to_java
        row_count, col_count = shape[0], shape[1]
        JDenseMatrix.new(row_count, col_count, values.flatten)
      end

      # Build a DenseMatrix from an object exposing +numRows+, +numCols+
      # and +values+.
      def self.from_java(object)
        DenseMatrix.new(object.numRows, object.numCols, object.values)
      end

    end
  end
end
|
62
|
+
|
63
|
+
module Spark
  module Mllib
    ##
    # SparseMatrix
    #
    # == Arguments:
    # rows::
    #   Number of rows.
    #
    # cols::
    #   Number of columns.
    #
    # col_pointers::
    #   The index corresponding to the start of a new column.
    #
    # row_indices::
    #   The row index of the entry. They must be in strictly
    #   increasing order for each column.
    #
    # values::
    #   Nonzero matrix entries in column major.
    #
    # == Examples:
    #
    #   SparseMatrix.new(3, 3, [0, 2, 3, 6], [0, 2, 1, 0, 1, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).values
    #
    #   # => [
    #   #      [1.0, 0.0, 4.0],
    #   #      [0.0, 3.0, 5.0],
    #   #      [2.0, 0.0, 6.0]
    #   #    ]
    #
    class SparseMatrix < MatrixBase

      attr_reader :col_pointers, :row_indices

      # Stores the column-compressed description and writes every listed
      # entry into the matrix through <tt>self[row, col] = value</tt>.
      def initialize(rows, cols, col_pointers, row_indices, values)
        super(:sparse, rows, cols)

        @col_pointers = col_pointers
        @row_indices = row_indices
        @values = values

        # Entries of column +col+ occupy the index range
        # col_pointers[col] ... col_pointers[col + 1] of row_indices/values.
        cols.times do |col|
          first = col_pointers[col]
          stop = col_pointers[col + 1]

          first.upto(stop - 1) do |idx|
            self[row_indices[idx], col] = values[idx]
          end
        end
      end

    end
  end
end
|