ruby-spark 1.1.0.1-java

Sign up to get free protection for your applications and to get access to all the features.
Files changed (180) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +252 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +6 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/comparison/prepare.sh +18 -0
  12. data/benchmark/comparison/python.py +156 -0
  13. data/benchmark/comparison/r.r +69 -0
  14. data/benchmark/comparison/ruby.rb +167 -0
  15. data/benchmark/comparison/run-all.sh +160 -0
  16. data/benchmark/comparison/scala.scala +181 -0
  17. data/benchmark/custom_marshal.rb +94 -0
  18. data/benchmark/digest.rb +150 -0
  19. data/benchmark/enumerator.rb +88 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/example/website_search.rb +83 -0
  27. data/ext/ruby_c/extconf.rb +3 -0
  28. data/ext/ruby_c/murmur.c +158 -0
  29. data/ext/ruby_c/murmur.h +9 -0
  30. data/ext/ruby_c/ruby-spark.c +18 -0
  31. data/ext/ruby_java/Digest.java +36 -0
  32. data/ext/ruby_java/Murmur2.java +98 -0
  33. data/ext/ruby_java/RubySparkExtService.java +28 -0
  34. data/ext/ruby_java/extconf.rb +3 -0
  35. data/ext/spark/build.sbt +73 -0
  36. data/ext/spark/project/plugins.sbt +9 -0
  37. data/ext/spark/sbt/sbt +34 -0
  38. data/ext/spark/src/main/scala/Exec.scala +91 -0
  39. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  40. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  41. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  42. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  43. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  44. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  46. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  47. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  48. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  49. data/ext/spark/src/main/scala/RubyRDD.scala +392 -0
  50. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  51. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  52. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  53. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  54. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  55. data/lib/ruby-spark.rb +1 -0
  56. data/lib/spark.rb +198 -0
  57. data/lib/spark/accumulator.rb +260 -0
  58. data/lib/spark/broadcast.rb +98 -0
  59. data/lib/spark/build.rb +43 -0
  60. data/lib/spark/cli.rb +169 -0
  61. data/lib/spark/command.rb +86 -0
  62. data/lib/spark/command/base.rb +158 -0
  63. data/lib/spark/command/basic.rb +345 -0
  64. data/lib/spark/command/pair.rb +124 -0
  65. data/lib/spark/command/sort.rb +51 -0
  66. data/lib/spark/command/statistic.rb +144 -0
  67. data/lib/spark/command_builder.rb +141 -0
  68. data/lib/spark/command_validator.rb +34 -0
  69. data/lib/spark/config.rb +238 -0
  70. data/lib/spark/constant.rb +14 -0
  71. data/lib/spark/context.rb +322 -0
  72. data/lib/spark/error.rb +50 -0
  73. data/lib/spark/ext/hash.rb +41 -0
  74. data/lib/spark/ext/integer.rb +25 -0
  75. data/lib/spark/ext/io.rb +67 -0
  76. data/lib/spark/ext/ip_socket.rb +29 -0
  77. data/lib/spark/ext/module.rb +58 -0
  78. data/lib/spark/ext/object.rb +24 -0
  79. data/lib/spark/ext/string.rb +24 -0
  80. data/lib/spark/helper.rb +10 -0
  81. data/lib/spark/helper/logger.rb +40 -0
  82. data/lib/spark/helper/parser.rb +85 -0
  83. data/lib/spark/helper/serialize.rb +71 -0
  84. data/lib/spark/helper/statistic.rb +93 -0
  85. data/lib/spark/helper/system.rb +42 -0
  86. data/lib/spark/java_bridge.rb +19 -0
  87. data/lib/spark/java_bridge/base.rb +203 -0
  88. data/lib/spark/java_bridge/jruby.rb +23 -0
  89. data/lib/spark/java_bridge/rjb.rb +41 -0
  90. data/lib/spark/logger.rb +76 -0
  91. data/lib/spark/mllib.rb +100 -0
  92. data/lib/spark/mllib/classification/common.rb +31 -0
  93. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  94. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  95. data/lib/spark/mllib/classification/svm.rb +135 -0
  96. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  97. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  98. data/lib/spark/mllib/matrix.rb +120 -0
  99. data/lib/spark/mllib/regression/common.rb +73 -0
  100. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  101. data/lib/spark/mllib/regression/lasso.rb +100 -0
  102. data/lib/spark/mllib/regression/linear.rb +124 -0
  103. data/lib/spark/mllib/regression/ridge.rb +97 -0
  104. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  105. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  106. data/lib/spark/mllib/stat/distribution.rb +12 -0
  107. data/lib/spark/mllib/vector.rb +185 -0
  108. data/lib/spark/rdd.rb +1377 -0
  109. data/lib/spark/sampler.rb +92 -0
  110. data/lib/spark/serializer.rb +79 -0
  111. data/lib/spark/serializer/auto_batched.rb +59 -0
  112. data/lib/spark/serializer/base.rb +63 -0
  113. data/lib/spark/serializer/batched.rb +84 -0
  114. data/lib/spark/serializer/cartesian.rb +13 -0
  115. data/lib/spark/serializer/compressed.rb +27 -0
  116. data/lib/spark/serializer/marshal.rb +17 -0
  117. data/lib/spark/serializer/message_pack.rb +23 -0
  118. data/lib/spark/serializer/oj.rb +23 -0
  119. data/lib/spark/serializer/pair.rb +41 -0
  120. data/lib/spark/serializer/text.rb +25 -0
  121. data/lib/spark/sort.rb +189 -0
  122. data/lib/spark/stat_counter.rb +125 -0
  123. data/lib/spark/storage_level.rb +39 -0
  124. data/lib/spark/version.rb +3 -0
  125. data/lib/spark/worker/master.rb +144 -0
  126. data/lib/spark/worker/spark_files.rb +15 -0
  127. data/lib/spark/worker/worker.rb +200 -0
  128. data/ruby-spark.gemspec +47 -0
  129. data/spec/generator.rb +37 -0
  130. data/spec/inputs/lorem_300.txt +316 -0
  131. data/spec/inputs/numbers/1.txt +50 -0
  132. data/spec/inputs/numbers/10.txt +50 -0
  133. data/spec/inputs/numbers/11.txt +50 -0
  134. data/spec/inputs/numbers/12.txt +50 -0
  135. data/spec/inputs/numbers/13.txt +50 -0
  136. data/spec/inputs/numbers/14.txt +50 -0
  137. data/spec/inputs/numbers/15.txt +50 -0
  138. data/spec/inputs/numbers/16.txt +50 -0
  139. data/spec/inputs/numbers/17.txt +50 -0
  140. data/spec/inputs/numbers/18.txt +50 -0
  141. data/spec/inputs/numbers/19.txt +50 -0
  142. data/spec/inputs/numbers/2.txt +50 -0
  143. data/spec/inputs/numbers/20.txt +50 -0
  144. data/spec/inputs/numbers/3.txt +50 -0
  145. data/spec/inputs/numbers/4.txt +50 -0
  146. data/spec/inputs/numbers/5.txt +50 -0
  147. data/spec/inputs/numbers/6.txt +50 -0
  148. data/spec/inputs/numbers/7.txt +50 -0
  149. data/spec/inputs/numbers/8.txt +50 -0
  150. data/spec/inputs/numbers/9.txt +50 -0
  151. data/spec/inputs/numbers_0_100.txt +101 -0
  152. data/spec/inputs/numbers_1_100.txt +100 -0
  153. data/spec/lib/collect_spec.rb +42 -0
  154. data/spec/lib/command_spec.rb +68 -0
  155. data/spec/lib/config_spec.rb +64 -0
  156. data/spec/lib/context_spec.rb +165 -0
  157. data/spec/lib/ext_spec.rb +72 -0
  158. data/spec/lib/external_apps_spec.rb +45 -0
  159. data/spec/lib/filter_spec.rb +80 -0
  160. data/spec/lib/flat_map_spec.rb +100 -0
  161. data/spec/lib/group_spec.rb +109 -0
  162. data/spec/lib/helper_spec.rb +19 -0
  163. data/spec/lib/key_spec.rb +41 -0
  164. data/spec/lib/manipulation_spec.rb +122 -0
  165. data/spec/lib/map_partitions_spec.rb +87 -0
  166. data/spec/lib/map_spec.rb +91 -0
  167. data/spec/lib/mllib/classification_spec.rb +54 -0
  168. data/spec/lib/mllib/clustering_spec.rb +35 -0
  169. data/spec/lib/mllib/matrix_spec.rb +32 -0
  170. data/spec/lib/mllib/regression_spec.rb +116 -0
  171. data/spec/lib/mllib/vector_spec.rb +77 -0
  172. data/spec/lib/reduce_by_key_spec.rb +118 -0
  173. data/spec/lib/reduce_spec.rb +131 -0
  174. data/spec/lib/sample_spec.rb +46 -0
  175. data/spec/lib/serializer_spec.rb +88 -0
  176. data/spec/lib/sort_spec.rb +58 -0
  177. data/spec/lib/statistic_spec.rb +170 -0
  178. data/spec/lib/whole_text_files_spec.rb +33 -0
  179. data/spec/spec_helper.rb +38 -0
  180. metadata +389 -0
@@ -0,0 +1,97 @@
1
module Spark
  module Mllib
    ##
    # NaiveBayesModel
    #
    # Model for Naive Bayes classifiers.
    #
    # Contains three parameters:
    # labels:: the class labels; #predict returns one of these
    # pi:: vector of logs of class priors (dimension C)
    # theta:: matrix of logs of class conditional probabilities (CxD)
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0, 0.0]),
    #     LabeledPoint.new(0.0, [0.0, 1.0]),
    #     LabeledPoint.new(1.0, [1.0, 0.0])
    #   ]
    #   model = NaiveBayes.train($sc.parallelize(data))
    #
    #   model.predict([0.0, 1.0])
    #   # => 0.0
    #   model.predict([1.0, 0.0])
    #   # => 1.0
    #
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 0.0})),
    #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 1.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {0 => 1.0}))
    #   ]
    #   model = NaiveBayes.train($sc.parallelize(data))
    #
    #   model.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 0.0
    #   model.predict(SparseVector.new(2, {0 => 1.0}))
    #   # => 1.0
    #
    class NaiveBayesModel

      attr_reader :labels, :pi, :theta

      def initialize(labels, pi, theta)
        @labels = labels
        @pi = pi
        @theta = theta
      end

      # Predict the label of a single data point.
      #
      # +vector+ is converted with Spark::Mllib::Vectors.to_vector (the examples
      # above pass an Array or a SparseVector). Only one point is handled here;
      # an RDD is not accepted.
      #
      # The class scores are computed as <tt>vector.dot(theta) + pi</tt>. The
      # label at the position of the highest score is returned. On ties the
      # first such label wins, because Array#index returns the first match.
      def predict(vector)
        vector = Spark::Mllib::Vectors.to_vector(vector)
        scores = (vector.dot(theta) + pi).to_a
        labels[scores.index(scores.max)]
      end

    end
  end
end
65
+
66
+
67
module Spark
  module Mllib
    class NaiveBayes

      # Train a Multinomial Naive Bayes model (http://tinyurl.com/lsdw6p).
      #
      # Multinomial NB handles all kinds of discrete data. For example, on
      # TF-IDF vectors it performs document classification. When every vector
      # is a 0-1 vector it behaves as Bernoulli NB (http://tinyurl.com/p7c96j6).
      # Input feature values must be nonnegative.
      #
      # == Arguments:
      # rdd:: RDD of LabeledPoint; only its first element is type-checked.
      # lambda:: The smoothing parameter (default: 1.0).
      #
      # Raises Spark::MllibError when the first element is not a LabeledPoint.
      # Returns a NaiveBayesModel whose theta is wrapped in a dense matrix
      # sized from the returned nested array (rows x columns of the first row).
      def self.train(rdd, lambda=1.0)
        sample = rdd.first
        raise Spark::MllibError, "RDD should contains LabeledPoint, got #{sample.class}" unless sample.is_a?(LabeledPoint)

        labels, pi, raw_theta = Spark.jb.call(RubyMLLibAPI.new, 'trainNaiveBayes', rdd, lambda)

        row_count = raw_theta.size
        column_count = raw_theta.first.size
        theta_matrix = Spark::Mllib::Matrices.dense(row_count, column_count, raw_theta)

        NaiveBayesModel.new(labels, pi, theta_matrix)
      end

    end
  end
end
@@ -0,0 +1,135 @@
1
module Spark
  module Mllib
    ##
    # SVMModel
    #
    # A support vector machine.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     LabeledPoint.new(0.0, [0.0]),
    #     LabeledPoint.new(1.0, [1.0]),
    #     LabeledPoint.new(1.0, [2.0]),
    #     LabeledPoint.new(1.0, [3.0])
    #   ]
    #   svm = SVMWithSGD.train($sc.parallelize(data))
    #
    #   svm.predict([1.0])
    #   # => 1
    #   svm.clear_threshold
    #   svm.predict([1.0])
    #   # => 1.25...
    #
    #
    #   # Sparse vectors
    #   data = [
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
    #   ]
    #   svm = SVMWithSGD.train($sc.parallelize(data))
    #
    #   svm.predict(SparseVector.new(2, {1 => 1.0}))
    #   # => 1
    #   svm.predict(SparseVector.new(2, {0 => -1.0}))
    #   # => 0
    #
    class SVMModel < ClassificationModel

      # Forwards all arguments to ClassificationModel and then sets the
      # decision threshold to 0.0, so #predict returns 0/1 labels by default.
      def initialize(*args)
        super
        @threshold = 0.0
      end

      # Predict for a single data point.
      #
      # +vector+ is converted with Spark::Mllib::Vectors.to_vector (an Array or
      # a SparseVector in the examples above). An RDD is not accepted.
      #
      # The margin is <tt>weights.dot(vector) + intercept</tt>. The result is:
      # - the raw margin when the threshold is nil (for example after
      #   clear_threshold, as in the examples);
      # - otherwise 1 if margin > threshold, else 0.
      #
      # NOTE(review): +weights+, +intercept+, +threshold+ and clear_threshold are
      # not defined here; presumably they are inherited from ClassificationModel.
      def predict(vector)
        vector = Spark::Mllib::Vectors.to_vector(vector)
        margin = weights.dot(vector) + intercept
        return margin if threshold.nil?

        margin > threshold ? 1 : 0
      end

    end
  end
end
+
70
module Spark
  module Mllib
    class SVMWithSGD < ClassificationMethodBase

      # Fallback values for every option the caller does not supply.
      DEFAULT_OPTIONS = {
        iterations: 100,
        step: 1.0,
        reg_param: 0.01,
        mini_batch_fraction: 1.0,
        initial_weights: nil,
        reg_type: 'l2',
        intercept: false
      }

      # Train a support vector machine using stochastic gradient descent.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of LabeledPoint.
      #
      # options::
      #   Hash with any of the keys below (see DEFAULT_OPTIONS).
      #
      # iterations::
      #   The number of iterations (default: 100). Sent as an Integer.
      #
      # step::
      #   The SGD step parameter (default: 1.0). Sent as a Float.
      #
      # reg_param::
      #   The regularizer parameter (default: 0.01). Sent as a Float.
      #
      # mini_batch_fraction::
      #   Fraction of the data used for each SGD iteration (default: 1.0).
      #
      # initial_weights::
      #   The initial weights (default: nil).
      #
      # reg_type::
      #   The regularizer used for training (default: "l2").
      #   Allowed values:
      #   - "l1" for L1 regularization
      #   - "l2" for L2 regularization
      #   - nil for no regularization
      #
      # intercept::
      #   Whether to use the augmented representation of the training data,
      #   i.e. whether bias features are activated (default: false).
      #
      # Returns an SVMModel built from the weights and intercept returned by
      # the 'trainSVMModelWithSGD' call.
      def self.train(rdd, options={})
        # Calls ClassificationMethodBase.train with the same arguments. Its
        # return value is discarded and +options+ is read afterwards, so this
        # relies on the parent filling in defaults on +options+ in place.
        # NOTE(review): confirm this in ClassificationMethodBase.
        super

        jvm_args = [
          options[:iterations].to_i,
          options[:step].to_f,
          options[:reg_param].to_f,
          options[:mini_batch_fraction].to_f,
          options[:initial_weights],
          options[:reg_type],
          options[:intercept]
        ]

        weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd, *jvm_args)
        SVMModel.new(weights, intercept)
      end

    end
  end
end
@@ -0,0 +1,82 @@
1
module Spark
  module Mllib
    ##
    # GaussianMixtureModel
    #
    # A clustering model derived from the Gaussian Mixture Model method.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   data = [
    #     DenseVector.new([-0.1, -0.05]),
    #     DenseVector.new([-0.01, -0.1]),
    #     DenseVector.new([0.9, 0.8]),
    #     DenseVector.new([0.75, 0.935]),
    #     DenseVector.new([-0.83, -0.68]),
    #     DenseVector.new([-0.91, -0.76])
    #   ]
    #
    #   model = GaussianMixture.train($sc.parallelize(data), 3, convergence_tol: 0.0001, max_iterations: 50, seed: 10)
    #
    #   labels = model.predict($sc.parallelize(data)).collect
    #
    class GaussianMixtureModel

      # weights:: mixture weights, one per component
      # gaussians:: component objects responding to +mu+ and +sigma+
      # k:: number of components (weights.size)
      attr_reader :weights, :gaussians, :k

      def initialize(weights, gaussians)
        @weights = weights
        @gaussians = gaussians
        @k = weights.size
      end

      # For every point in +rdd+, return the index of the component with the
      # highest membership value computed by #predict_soft.
      #
      # Raises ArgumentError unless +rdd+ is a Spark::RDD (checked in #predict_soft).
      def predict(rdd)
        # The mapping function is given as source text (a String), not a block.
        predict_soft(rdd).map('lambda{|x| x.index(x.max)}')
      end

      # Membership of each point in +rdd+ in every mixture component. The
      # result comes from the JVM call 'predictSoftGMM'.
      #
      # Raises ArgumentError unless +rdd+ is a Spark::RDD. The check lives here
      # so that #predict and direct callers are validated the same way.
      def predict_soft(rdd)
        raise ArgumentError, 'Argument must be a RDD.' unless rdd.is_a?(Spark::RDD)

        Spark.jb.call(RubyMLLibAPI.new, 'predictSoftGMM', rdd, weights, means, sigmas)
      end

      # Memoized +mu+ of every component.
      def means
        @means ||= @gaussians.map(&:mu)
      end

      # Memoized +sigma+ of every component.
      def sigmas
        @sigmas ||= @gaussians.map(&:sigma)
      end

    end
  end
end
61
+
62
module Spark
  module Mllib
    class GaussianMixture

      # Train a Gaussian Mixture Model with +k+ components on +rdd+.
      #
      # == Arguments:
      # rdd:: The training data (an RDD of vectors, as in the GaussianMixtureModel example).
      # k:: Number of mixture components.
      # convergence_tol:: Convergence tolerance forwarded to the JVM trainer (default: 0.001).
      # max_iterations:: Maximum iteration count forwarded to the JVM trainer (default: 100).
      # seed:: Optional random seed, converted with Spark.jb.to_long (default: nil).
      #
      # Returns a GaussianMixtureModel with the returned weights and +k+
      # MultivariateGaussian components. Component i is built from the i-th
      # returned mean and sigma after Spark.jb.java_to_ruby conversion.
      def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
        java_seed = Spark.jb.to_long(seed)
        weights, java_means, java_sigmas = Spark.jb.call(
          RubyMLLibAPI.new, 'trainGaussianMixture', rdd,
          k, convergence_tol, max_iterations, java_seed
        )

        ruby_means  = java_means.map  { |mu| Spark.jb.java_to_ruby(mu) }
        ruby_sigmas = java_sigmas.map { |sigma| Spark.jb.java_to_ruby(sigma) }

        components = Array.new(k) { |i| MultivariateGaussian.new(ruby_means[i], ruby_sigmas[i]) }
        GaussianMixtureModel.new(weights, components)
      end

    end
  end
end
@@ -0,0 +1,118 @@
1
module Spark
  module Mllib
    ##
    # KMeansModel
    #
    # A clustering model derived from the k-means method.
    #
    # == Examples:
    #
    #   Spark::Mllib.import
    #
    #   # Dense vectors
    #   data = [
    #     DenseVector.new([0.0,0.0]),
    #     DenseVector.new([1.0,1.0]),
    #     DenseVector.new([9.0,8.0]),
    #     DenseVector.new([8.0,9.0])
    #   ]
    #
    #   model = KMeans.train($sc.parallelize(data), 2, max_iterations: 10,
    #                        runs: 30, initialization_mode: "random")
    #
    #   model.predict([0.0, 0.0]) == model.predict([1.0, 1.0])
    #   # => true
    #   model.predict([8.0, 9.0]) == model.predict([9.0, 8.0])
    #   # => true
    #
    #
    #   # Sparse vectors
    #   data = [
    #     SparseVector.new(3, {1 => 1.0}),
    #     SparseVector.new(3, {1 => 1.1}),
    #     SparseVector.new(3, {2 => 1.0}),
    #     SparseVector.new(3, {2 => 1.1})
    #   ]
    #   model = KMeans.train($sc.parallelize(data), 2, initialization_mode: "k-means||")
    #
    #   model.predict([0.0, 1.0, 0.0]) == model.predict([0, 1.1, 0.0])
    #   # => true
    #   model.predict([0.0, 0.0, 1.0]) == model.predict([0, 0, 1.1])
    #   # => true
    #   model.predict(data[0]) == model.predict(data[1])
    #   # => true
    #   model.predict(data[2]) == model.predict(data[3])
    #   # => true
    #
    class KMeansModel

      attr_reader :centers

      def initialize(centers)
        @centers = centers
      end

      # Return the index in +centers+ of the center closest to +vector+, using
      # vector.squared_distance.
      #
      # +vector+ is converted with Spark::Mllib::Vectors.to_vector. On ties the
      # lowest index wins. With no centers, or when no distance compares below
      # Infinity, the result is 0.
      def predict(vector)
        vector = Spark::Mllib::Vectors.to_vector(vector)
        best_index = 0
        best_distance = Float::INFINITY

        @centers.each_with_index do |center, index|
          distance = vector.squared_distance(center)
          # Strict < keeps the first center on ties and skips NaN distances.
          next unless distance < best_distance

          best_index = index
          best_distance = distance
        end

        best_index
      end

      # Build a KMeansModel from a JVM model object exposing +clusterCenters+.
      # Each center is converted with Spark.jb.java_to_ruby.
      #
      # A new Array is built instead of using map!, so the collection returned
      # by +clusterCenters+ is not mutated and does not have to support map!
      # (frozen arrays or non-Array enumerables).
      def self.from_java(object)
        centers = object.clusterCenters.map { |center| Spark.jb.java_to_ruby(center) }
        KMeansModel.new(centers)
      end

    end
  end
end
84
+
85
module Spark
  module Mllib
    class KMeans

      # Train a k-means model with the given parameters.
      #
      # == Arguments:
      # rdd::
      #   The training data, an RDD of Vectors.
      #
      # k::
      #   Number of clusters.
      #
      # max_iterations::
      #   Maximum number of iterations (default: 100).
      #
      # runs::
      #   Number of parallel runs; the best model is returned (default: 1).
      #
      # initialization_mode::
      #   Initialization mode, either "random" or "k-means||" (default).
      #
      # seed::
      #   Random seed for cluster initialization, converted with
      #   Spark.jb.to_long (default: nil).
      #
      # Returns whatever the 'trainKMeansModel' call returns (a KMeansModel).
      def self.train(rdd, k, max_iterations: 100, runs: 1, initialization_mode: 'k-means||', seed: nil)
        api = RubyMLLibAPI.new
        java_seed = Spark.jb.to_long(seed)

        Spark.jb.call(api, 'trainKMeansModel', rdd, k, max_iterations, runs, initialization_mode, java_seed)
      end

    end
  end
end
@@ -0,0 +1,120 @@
1
module Spark
  module Mllib
    # Factory helpers for the matrix classes.
    module Matrices

      # Shortcut for DenseMatrix.new(*args).
      def self.dense(*args)
        DenseMatrix.new(*args)
      end

      # Shortcut for SparseMatrix.new(*args).
      def self.sparse(*args)
        SparseMatrix.new(*args)
      end

      # Coerce +data+ into a matrix.
      #
      # - A SparseMatrix or DenseMatrix is returned unchanged.
      # - An Array of row Arrays becomes a DenseMatrix. Rows come from
      #   data.size and columns from the first row's size. The previous code
      #   called DenseMatrix.new(data) with one argument, which never matched
      #   DenseMatrix#initialize(rows, cols, values).
      # - Anything else returns nil.
      #
      # Raises ArgumentError when +data+ is an Array whose elements are not all
      # Arrays, because the column count cannot be derived.
      def self.to_matrix(data)
        case data
        when SparseMatrix, DenseMatrix
          data
        when Array
          unless data.all? { |row| row.is_a?(Array) }
            raise ArgumentError, 'Matrix data must be an Array of row Arrays.'
          end

          cols = data.empty? ? 0 : data.first.size
          DenseMatrix.new(data.size, cols, data)
        end
      end

    end
  end
end
24
+
25
module Spark
  module Mllib
    # Abstract common parent of DenseMatrix and SparseMatrix. It adds nothing
    # of its own; all behavior comes from MatrixAdapter.
    class MatrixBase < MatrixAdapter; end
  end
end
32
+
33
module Spark
  module Mllib
    ##
    # DenseMatrix
    #
    #   DenseMatrix.new(2, 3, [[1,2,3], [4,5,6]]).values
    #   # => [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    #
    class DenseMatrix < MatrixBase

      # rows:: number of rows
      # cols:: number of columns
      # values:: entries; converted with +to_a+ and handed to the parent
      #          initializer with the :dense type tag (the example above
      #          passes an Array of rows)
      def initialize(rows, cols, values)
        super(:dense, rows, cols, values.to_a)
      end

      # Convert to a JDenseMatrix of the same shape with +values+ flattened.
      #
      # NOTE(review): if JDenseMatrix is Spark's mllib DenseMatrix, its
      # constructor expects column-major values, while flattening an Array of
      # rows yields row-major order. Confirm this against MatrixAdapter and
      # from_java before changing it, because the two directions must stay
      # consistent.
      def to_java
        flat_values = values.flatten
        JDenseMatrix.new(shape[0], shape[1], flat_values)
      end

      # Build a DenseMatrix from a JVM matrix exposing numRows, numCols and
      # values. The flat +values+ array is passed through unchanged; its
      # ordering is interpreted by the parent initializer.
      def self.from_java(object)
        DenseMatrix.new(object.numRows, object.numCols, object.values)
      end

    end
  end
end
62
+
63
module Spark
  module Mllib
    ##
    # SparseMatrix
    #
    # A matrix given in compressed sparse column (CSC) form.
    #
    # == Arguments:
    # rows::
    #   Number of rows.
    #
    # cols::
    #   Number of columns.
    #
    # col_pointers::
    #   cols + 1 offsets. The entries of column j are stored at positions
    #   col_pointers[j] ... col_pointers[j + 1] (exclusive) of row_indices/values.
    #
    # row_indices::
    #   The row index of each stored entry. They must be in strictly
    #   increasing order within each column.
    #
    # values::
    #   Nonzero matrix entries in column-major order.
    #
    # Raises ArgumentError when col_pointers has fewer than cols + 1 entries, or
    # when row_indices/values are shorter than col_pointers[cols].
    #
    # == Examples:
    #
    #   SparseMatrix.new(3, 3, [0, 2, 3, 6], [0, 2, 1, 0, 1, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).values
    #
    #   # => [
    #   #   [1.0, 0.0, 4.0],
    #   #   [0.0, 3.0, 5.0],
    #   #   [2.0, 0.0, 6.0]
    #   # ]
    #
    class SparseMatrix < MatrixBase

      attr_reader :col_pointers, :row_indices

      def initialize(rows, cols, col_pointers, row_indices, values)
        # Fail fast with a clear message. Without these checks, malformed input
        # surfaced as "comparison of Integer with nil failed" or a nil row index.
        if col_pointers.size < cols + 1
          raise ArgumentError, "col_pointers must have #{cols + 1} entries (cols + 1), got #{col_pointers.size}"
        end

        stored = col_pointers[cols]
        if row_indices.size < stored || values.size < stored
          raise ArgumentError, "row_indices and values need at least #{stored} entries, got #{row_indices.size} and #{values.size}"
        end

        super(:sparse, rows, cols)

        @col_pointers = col_pointers
        @row_indices = row_indices
        @values = values

        # Write every stored entry of column `col` into its (row, col) cell.
        cols.times do |col|
          (col_pointers[col]...col_pointers[col + 1]).each do |idx|
            self[row_indices[idx], col] = values[idx]
          end
        end
      end

    end
  end
end