ruby-spark 1.0.0
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +185 -0
- data/Rakefile +35 -0
- data/TODO.md +7 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/performance/prepare.sh +18 -0
- data/benchmark/performance/python.py +156 -0
- data/benchmark/performance/r.r +69 -0
- data/benchmark/performance/ruby.rb +167 -0
- data/benchmark/performance/run-all.sh +160 -0
- data/benchmark/performance/scala.scala +181 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +154 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +244 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +304 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +57 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1328 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +24 -0
- data/lib/spark/serializer/base.rb +170 -0
- data/lib/spark/serializer/cartesian.rb +37 -0
- data/lib/spark/serializer/marshal.rb +19 -0
- data/lib/spark/serializer/message_pack.rb +25 -0
- data/lib/spark/serializer/oj.rb +25 -0
- data/lib/spark/serializer/pair.rb +27 -0
- data/lib/spark/serializer/utf8.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +197 -0
- data/ruby-spark.gemspec +36 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +163 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +114 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +13 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +168 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +39 -0
- metadata +301 -0
data/lib/spark/mllib/classification/naive_bayes.rb
@@ -0,0 +1,97 @@
+module Spark
+  module Mllib
+    ##
+    # NaiveBayesModel
+    #
+    # Model for Naive Bayes classifiers.
+    #
+    # Contains two parameters:
+    # pi:: vector of logs of class priors (dimension C)
+    # theta:: matrix of logs of class conditional probabilities (CxD)
+    #
+    # == Examples:
+    #
+    #   Spark::Mllib.import
+    #
+    #   # Dense vectors
+    #   data = [
+    #     LabeledPoint.new(0.0, [0.0, 0.0]),
+    #     LabeledPoint.new(0.0, [0.0, 1.0]),
+    #     LabeledPoint.new(1.0, [1.0, 0.0])
+    #   ]
+    #   model = NaiveBayes.train($sc.parallelize(data))
+    #
+    #   model.predict([0.0, 1.0])
+    #   # => 0.0
+    #   model.predict([1.0, 0.0])
+    #   # => 1.0
+    #
+    #
+    #   # Sparse vectors
+    #   data = [
+    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 0.0})),
+    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 1.0})),
+    #     LabeledPoint.new(1.0, SparseVector.new(2, {0 => 1.0}))
+    #   ]
+    #   model = NaiveBayes.train($sc.parallelize(data))
+    #
+    #   model.predict(SparseVector.new(2, {1 => 1.0}))
+    #   # => 0.0
+    #   model.predict(SparseVector.new(2, {0 => 1.0}))
+    #   # => 1.0
+    #
+    class NaiveBayesModel
+
+      attr_reader :labels, :pi, :theta
+
+      def initialize(labels, pi, theta)
+        @labels = labels
+        @pi = pi
+        @theta = theta
+      end
+
+      # Predict values for a single data point or an RDD of points using
+      # the trained model.
+      def predict(vector)
+        vector = Spark::Mllib::Vectors.to_vector(vector)
+        array = (vector.dot(theta) + pi).to_a
+        index = array.index(array.max)
+        labels[index]
+      end
+
+    end
+  end
+end
+
+
+module Spark
+  module Mllib
+    class NaiveBayes
+
+      # Trains a Naive Bayes model given an RDD of (label, features) pairs.
+      #
+      # This is the Multinomial NB (http://tinyurl.com/lsdw6p) which can handle all kinds of
+      # discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+      # document classification. By making every vector a 0-1 vector, it can also be used as
+      # Bernoulli NB (http://tinyurl.com/p7c96j6). The input feature values must be nonnegative.
+      #
+      # == Arguments:
+      # rdd:: RDD of LabeledPoint.
+      # lambda:: The smoothing parameter.
+      #
+      def self.train(rdd, lambda=1.0)
+        # Validation
+        first = rdd.first
+        unless first.is_a?(LabeledPoint)
+          raise Spark::MllibError, "RDD should contain LabeledPoint, got #{first.class}"
+        end
+
+        labels, pi, theta = Spark.jb.call(RubyMLLibAPI.new, 'trainNaiveBayes', rdd, lambda)
+        theta = Spark::Mllib::Matrices.dense(theta.size, theta.first.size, theta)
+
+        NaiveBayesModel.new(labels, pi, theta)
+      end
+
+    end
+  end
+end
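The interesting piece of NaiveBayesModel#predict above is the argmax over theta . x + pi, i.e. picking the class with the highest log-posterior. A minimal standalone sketch of that step, using plain Ruby arrays in place of MLlib vectors and matrices (all numbers are invented for illustration):

    # Log class priors (pi) and log conditional probabilities (theta),
    # made-up values for a 2-class, 2-feature model.
    pi    = [Math.log(2.0 / 3), Math.log(1.0 / 3)]
    theta = [[Math.log(0.7), Math.log(0.3)],
             [Math.log(0.2), Math.log(0.8)]]

    x = [1.0, 0.0]  # feature vector to classify

    # Log-posterior (up to a shared constant) per class: theta[c] . x + pi[c]
    scores = theta.each_with_index.map do |row, c|
      row.zip(x).sum { |t, xi| t * xi } + pi[c]
    end

    scores.index(scores.max)  # => index of the predicted class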
data/lib/spark/mllib/classification/svm.rb
@@ -0,0 +1,135 @@
+module Spark
+  module Mllib
+    ##
+    # SVMModel
+    #
+    # A support vector machine.
+    #
+    # == Examples:
+    #
+    #   Spark::Mllib.import
+    #
+    #   # Dense vectors
+    #   data = [
+    #     LabeledPoint.new(0.0, [0.0]),
+    #     LabeledPoint.new(1.0, [1.0]),
+    #     LabeledPoint.new(1.0, [2.0]),
+    #     LabeledPoint.new(1.0, [3.0])
+    #   ]
+    #   svm = SVMWithSGD.train($sc.parallelize(data))
+    #
+    #   svm.predict([1.0])
+    #   # => 1
+    #   svm.clear_threshold
+    #   svm.predict([1.0])
+    #   # => 1.25...
+    #
+    #
+    #   # Sparse vectors
+    #   data = [
+    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
+    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
+    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
+    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
+    #   ]
+    #   svm = SVMWithSGD.train($sc.parallelize(data))
+    #
+    #   svm.predict(SparseVector.new(2, {1 => 1.0}))
+    #   # => 1
+    #   svm.predict(SparseVector.new(2, {0 => -1.0}))
+    #   # => 0
+    #
+    class SVMModel < ClassificationModel
+
+      def initialize(*args)
+        super
+        @threshold = 0.0
+      end
+
+      # Predict values for a single data point or an RDD of points using
+      # the trained model.
+      def predict(vector)
+        vector = Spark::Mllib::Vectors.to_vector(vector)
+        margin = weights.dot(vector) + intercept
+
+        if threshold.nil?
+          return margin
+        end
+
+        if margin > threshold
+          1
+        else
+          0
+        end
+      end
+
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    class SVMWithSGD < ClassificationMethodBase
+
+      DEFAULT_OPTIONS = {
+        iterations: 100,
+        step: 1.0,
+        reg_param: 0.01,
+        mini_batch_fraction: 1.0,
+        initial_weights: nil,
+        reg_type: 'l2',
+        intercept: false
+      }
+
+      # Train a support vector machine on the given data.
+      #
+      # rdd::
+      #   The training data, an RDD of LabeledPoint.
+      #
+      # iterations::
+      #   The number of iterations (default: 100).
+      #
+      # step::
+      #   The step parameter used in SGD (default: 1.0).
+      #
+      # reg_param::
+      #   The regularizer parameter (default: 0.01).
+      #
+      # mini_batch_fraction::
+      #   Fraction of data to be used for each SGD iteration.
+      #
+      # initial_weights::
+      #   The initial weights (default: nil).
+      #
+      # reg_type::
+      #   The type of regularizer used for training our model (default: "l2").
+      #
+      #   Allowed values:
+      #   - "l1" for using L1 regularization
+      #   - "l2" for using L2 regularization
+      #   - nil for no regularization
+      #
+      # intercept::
+      #   Boolean parameter which indicates whether to use
+      #   the augmented representation for
+      #   training data (i.e. whether bias features
+      #   are activated or not).
+      #
+      def self.train(rdd, options={})
+        super
+
+        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd,
+                                           options[:iterations].to_i,
+                                           options[:step].to_f,
+                                           options[:reg_param].to_f,
+                                           options[:mini_batch_fraction].to_f,
+                                           options[:initial_weights],
+                                           options[:reg_type],
+                                           options[:intercept])
+
+        SVMModel.new(weights, intercept)
+      end
+
+    end
+  end
+end
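Note how SVMModel#predict switches between a raw margin and a 0/1 label depending on @threshold; clear_threshold (presumably defined on the parent ClassificationModel, which is not part of this hunk) sets it to nil. A small Spark-free sketch of that decision rule, with made-up weights:

    weights   = [1.25]   # hypothetical trained weights
    intercept = 0.0

    # margin = weights . x + intercept, for input vector x = [1.0]
    margin = weights.zip([1.0]).sum { |w, xi| w * xi } + intercept

    margin                 # => 1.25 (returned as-is once the threshold is cleared)
    margin > 0.0 ? 1 : 0   # => 1    (returned with the default threshold of 0.0)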
data/lib/spark/mllib/clustering/gaussian_mixture.rb
@@ -0,0 +1,82 @@
+module Spark
+  module Mllib
+    ##
+    # GaussianMixtureModel
+    #
+    # A clustering model derived from the Gaussian Mixture Model method.
+    #
+    # == Examples:
+    #
+    #   Spark::Mllib.import
+    #
+    #   data = [
+    #     DenseVector.new([-0.1, -0.05]),
+    #     DenseVector.new([-0.01, -0.1]),
+    #     DenseVector.new([0.9, 0.8]),
+    #     DenseVector.new([0.75, 0.935]),
+    #     DenseVector.new([-0.83, -0.68]),
+    #     DenseVector.new([-0.91, -0.76])
+    #   ]
+    #
+    #   model = GaussianMixture.train($sc.parallelize(data), 3, convergence_tol: 0.0001, max_iterations: 50, seed: 10)
+    #
+    #   labels = model.predict($sc.parallelize(data)).collect
+    #
+    class GaussianMixtureModel
+
+      attr_reader :weights, :gaussians, :k
+
+      def initialize(weights, gaussians)
+        @weights = weights
+        @gaussians = gaussians
+        @k = weights.size
+      end
+
+      # Find the cluster to which the points in 'x' have maximum membership
+      # in this model.
+      def predict(rdd)
+        if rdd.is_a?(Spark::RDD)
+          predict_soft(rdd).map('lambda{|x| x.index(x.max)}')
+        else
+          raise ArgumentError, 'Argument must be an RDD.'
+        end
+      end
+
+      # Find the membership of each point in 'x' to all mixture components.
+      def predict_soft(rdd)
+        Spark.jb.call(RubyMLLibAPI.new, 'predictSoftGMM', rdd, weights, means, sigmas)
+      end
+
+      def means
+        @means ||= @gaussians.map(&:mu)
+      end
+
+      def sigmas
+        @sigmas ||= @gaussians.map(&:sigma)
+      end
+
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    class GaussianMixture
+
+      def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
+        weights, means, sigmas = Spark.jb.call(RubyMLLibAPI.new, 'trainGaussianMixture', rdd,
+                                               k, convergence_tol, max_iterations, Spark.jb.to_long(seed))
+
+        means.map! {|mu| Spark.jb.java_to_ruby(mu)}
+        sigmas.map! {|sigma| Spark.jb.java_to_ruby(sigma)}
+
+        mvgs = Array.new(k) do |i|
+          MultivariateGaussian.new(means[i], sigmas[i])
+        end
+
+        GaussianMixtureModel.new(weights, mvgs)
+      end
+
+    end
+  end
+end
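GaussianMixtureModel#predict turns the soft memberships from predict_soft into hard labels by taking an argmax per point, which is what the serialized 'lambda{|x| x.index(x.max)}' mapper does on the workers. The same reduction in plain Ruby, with invented membership rows:

    # Each row: one point's membership probabilities over 3 components.
    soft = [
      [0.90, 0.07, 0.03],  # point 1: mostly component 0
      [0.10, 0.15, 0.75]   # point 2: mostly component 2
    ]

    labels = soft.map { |x| x.index(x.max) }
    labels  # => [0, 2]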
data/lib/spark/mllib/clustering/kmeans.rb
@@ -0,0 +1,118 @@
+module Spark
+  module Mllib
+    ##
+    # KMeansModel
+    #
+    # A clustering model derived from the k-means method.
+    #
+    # == Examples:
+    #
+    #   Spark::Mllib.import
+    #
+    #   # Dense vectors
+    #   data = [
+    #     DenseVector.new([0.0,0.0]),
+    #     DenseVector.new([1.0,1.0]),
+    #     DenseVector.new([9.0,8.0]),
+    #     DenseVector.new([8.0,9.0])
+    #   ]
+    #
+    #   model = KMeans.train($sc.parallelize(data), 2, max_iterations: 10,
+    #                        runs: 30, initialization_mode: "random")
+    #
+    #   model.predict([0.0, 0.0]) == model.predict([1.0, 1.0])
+    #   # => true
+    #   model.predict([8.0, 9.0]) == model.predict([9.0, 8.0])
+    #   # => true
+    #
+    #
+    #   # Sparse vectors
+    #   data = [
+    #     SparseVector.new(3, {1 => 1.0}),
+    #     SparseVector.new(3, {1 => 1.1}),
+    #     SparseVector.new(3, {2 => 1.0}),
+    #     SparseVector.new(3, {2 => 1.1})
+    #   ]
+    #   model = KMeans.train($sc.parallelize(data), 2, initialization_mode: "k-means||")
+    #
+    #   model.predict([0.0, 1.0, 0.0]) == model.predict([0, 1.1, 0.0])
+    #   # => true
+    #   model.predict([0.0, 0.0, 1.0]) == model.predict([0, 0, 1.1])
+    #   # => true
+    #   model.predict(data[0]) == model.predict(data[1])
+    #   # => true
+    #   model.predict(data[2]) == model.predict(data[3])
+    #   # => true
+    #
+    class KMeansModel
+
+      attr_reader :centers
+
+      def initialize(centers)
+        @centers = centers
+      end
+
+      # Find the cluster to which x belongs in this model.
+      def predict(vector)
+        vector = Spark::Mllib::Vectors.to_vector(vector)
+        best = 0
+        best_distance = Float::INFINITY
+
+        @centers.each_with_index do |center, index|
+          distance = vector.squared_distance(center)
+          if distance < best_distance
+            best = index
+            best_distance = distance
+          end
+        end
+
+        best
+      end
+
+      def self.from_java(object)
+        centers = object.clusterCenters
+        centers.map! do |center|
+          Spark.jb.java_to_ruby(center)
+        end
+
+        KMeansModel.new(centers)
+      end
+
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    class KMeans
+
+      # Trains a k-means model using the given set of parameters.
+      #
+      # == Arguments:
+      # rdd::
+      #   The training data, an RDD of Vectors.
+      #
+      # k::
+      #   Number of clusters.
+      #
+      # max_iterations::
+      #   Max number of iterations.
+      #
+      # runs::
+      #   Number of parallel runs, defaults to 1. The best model is returned.
+      #
+      # initialization_mode::
+      #   Initialization mode, either "random" or "k-means||" (default).
+      #
+      # seed::
+      #   Random seed value for cluster initialization.
+      #
+      def self.train(rdd, k, max_iterations: 100, runs: 1, initialization_mode: 'k-means||', seed: nil)
+        # Call returns KMeansModel
+        Spark.jb.call(RubyMLLibAPI.new, 'trainKMeansModel', rdd,
+                      k, max_iterations, runs, initialization_mode, Spark.jb.to_long(seed))
+      end
+
+    end
+  end
+end
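KMeansModel#predict is a straightforward nearest-center search under squared Euclidean distance. A self-contained sketch of the same loop, with squared_distance reimplemented for plain arrays (the centers are made up):

    # Squared Euclidean distance between two plain arrays.
    def squared_distance(a, b)
      a.zip(b).sum { |x, y| (x - y)**2 }
    end

    # Index of the center closest to vector, as in KMeansModel#predict.
    def nearest_center(centers, vector)
      distances = centers.map { |c| squared_distance(c, vector) }
      distances.index(distances.min)
    end

    centers = [[0.5, 0.5], [8.5, 8.5]]
    nearest_center(centers, [1.0, 1.0])  # => 0
    nearest_center(centers, [9.0, 8.0])  # => 1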
data/lib/spark/mllib/matrix.rb
@@ -0,0 +1,120 @@
+module Spark
+  module Mllib
+    module Matrices
+
+      def self.dense(*args)
+        DenseMatrix.new(*args)
+      end
+
+      def self.sparse(*args)
+        SparseMatrix.new(*args)
+      end
+
+      def self.to_matrix(data)
+        if data.is_a?(SparseMatrix) || data.is_a?(DenseMatrix)
+          data
+        elsif data.is_a?(Array)
+          DenseMatrix.new(data)
+        end
+      end
+
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    # @abstract Parent for all types of matrices
+    class MatrixBase < MatrixAdapter
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    ##
+    # DenseMatrix
+    #
+    #   DenseMatrix.new(2, 3, [[1,2,3], [4,5,6]]).values
+    #   # => [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+    #
+    class DenseMatrix < MatrixBase
+
+      def initialize(rows, cols, values)
+        super(:dense, rows, cols, values.to_a)
+      end
+
+      def to_java
+        JDenseMatrix.new(shape[0], shape[1], values.flatten)
+      end
+
+      def self.from_java(object)
+        rows = object.numRows
+        cols = object.numCols
+        values = object.values
+
+        DenseMatrix.new(rows, cols, values)
+      end
+
+    end
+  end
+end
+
+module Spark
+  module Mllib
+    ##
+    # SparseMatrix
+    #
+    # == Arguments:
+    # rows::
+    #   Number of rows.
+    #
+    # cols::
+    #   Number of columns.
+    #
+    # col_pointers::
+    #   The index corresponding to the start of a new column.
+    #
+    # row_indices::
+    #   The row index of the entry. They must be in strictly
+    #   increasing order for each column.
+    #
+    # values::
+    #   Nonzero matrix entries in column-major order.
+    #
+    # == Examples:
+    #
+    #   SparseMatrix.new(3, 3, [0, 2, 3, 6], [0, 2, 1, 0, 1, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).values
+    #
+    #   # => [
+    #   #   [1.0, 0.0, 4.0],
+    #   #   [0.0, 3.0, 5.0],
+    #   #   [2.0, 0.0, 6.0]
+    #   # ]
+    #
+    class SparseMatrix < MatrixBase
+
+      attr_reader :col_pointers, :row_indices
+
+      def initialize(rows, cols, col_pointers, row_indices, values)
+        super(:sparse, rows, cols)
+
+        @col_pointers = col_pointers
+        @row_indices = row_indices
+        @values = values
+
+        j = 0
+        while j < cols
+          idx = col_pointers[j]
+          idx_end = col_pointers[j+1]
+          while idx < idx_end
+            self[row_indices[idx], j] = values[idx]
+            idx += 1
+          end
+          j += 1
+        end
+      end
+
+    end
+  end
+end
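SparseMatrix#initialize above decodes the standard CSC (compressed sparse column) layout: the entries of column j occupy values[col_pointers[j]...col_pointers[j + 1]]. The same expansion written against a plain nested array, reusing the 3x3 example from the documentation:

    rows, cols   = 3, 3
    col_pointers = [0, 2, 3, 6]
    row_indices  = [0, 2, 1, 0, 1, 2]
    values       = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

    # Start from an all-zero dense matrix.
    dense = Array.new(rows) { Array.new(cols, 0.0) }

    cols.times do |j|
      # Column j's nonzeros sit between col_pointers[j] and col_pointers[j + 1].
      (col_pointers[j]...col_pointers[j + 1]).each do |idx|
        dense[row_indices[idx]][j] = values[idx]
      end
    end

    dense  # => [[1.0, 0.0, 4.0], [0.0, 3.0, 5.0], [2.0, 0.0, 6.0]]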