ruby-spark 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
data/lib/spark/mllib/classification/naive_bayes.rb
@@ -0,0 +1,97 @@
+ module Spark
+   module Mllib
+     ##
+     # NaiveBayesModel
+     #
+     # Model for Naive Bayes classifiers.
+     #
+     # Contains two parameters:
+     # pi:: vector of logs of class priors (dimension C)
+     # theta:: matrix of logs of class conditional probabilities (C x D)
+     #
+     # == Examples:
+     #
+     #   Spark::Mllib.import
+     #
+     #   # Dense vectors
+     #   data = [
+     #     LabeledPoint.new(0.0, [0.0, 0.0]),
+     #     LabeledPoint.new(0.0, [0.0, 1.0]),
+     #     LabeledPoint.new(1.0, [1.0, 0.0])
+     #   ]
+     #   model = NaiveBayes.train($sc.parallelize(data))
+     #
+     #   model.predict([0.0, 1.0])
+     #   # => 0.0
+     #   model.predict([1.0, 0.0])
+     #   # => 1.0
+     #
+     #
+     #   # Sparse vectors
+     #   data = [
+     #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 0.0})),
+     #     LabeledPoint.new(0.0, SparseVector.new(2, {1 => 1.0})),
+     #     LabeledPoint.new(1.0, SparseVector.new(2, {0 => 1.0}))
+     #   ]
+     #   model = NaiveBayes.train($sc.parallelize(data))
+     #
+     #   model.predict(SparseVector.new(2, {1 => 1.0}))
+     #   # => 0.0
+     #   model.predict(SparseVector.new(2, {0 => 1.0}))
+     #   # => 1.0
+     #
+     class NaiveBayesModel
+
+       attr_reader :labels, :pi, :theta
+
+       def initialize(labels, pi, theta)
+         @labels = labels
+         @pi = pi
+         @theta = theta
+       end
+
+       # Predict values for a single data point or an RDD of points using
+       # the trained model.
+       def predict(vector)
+         vector = Spark::Mllib::Vectors.to_vector(vector)
+         array = (vector.dot(theta) + pi).to_a
+         index = array.index(array.max)
+         labels[index]
+       end
+
+     end
+   end
+ end
+
+
+ module Spark
+   module Mllib
+     class NaiveBayes
+
+       # Trains a Naive Bayes model given an RDD of (label, features) pairs.
+       #
+       # This is the Multinomial NB (http://tinyurl.com/lsdw6p), which can handle all kinds of
+       # discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+       # document classification. By making every vector a 0-1 vector, it can also be used as
+       # Bernoulli NB (http://tinyurl.com/p7c96j6). The input feature values must be nonnegative.
+       #
+       # == Arguments:
+       # rdd:: RDD of LabeledPoint.
+       # lambda:: The smoothing parameter.
+       #
+       def self.train(rdd, lambda=1.0)
+         # Validation
+         first = rdd.first
+         unless first.is_a?(LabeledPoint)
+           raise Spark::MllibError, "RDD should contain LabeledPoint, got #{first.class}"
+         end
+
+         labels, pi, theta = Spark.jb.call(RubyMLLibAPI.new, 'trainNaiveBayes', rdd, lambda)
+         theta = Spark::Mllib::Matrices.dense(theta.size, theta.first.size, theta)
+
+         NaiveBayesModel.new(labels, pi, theta)
+       end
+
+     end
+   end
+ end
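
Note: predict above implements the standard multinomial NB decision rule — score each class by its log prior plus the dot product of its log conditional probabilities with the feature vector, then return the label with the highest score. A minimal pure-Ruby sketch of that rule (made-up numbers, no Spark required; `labels`, `pi`, `theta` mirror the model's attributes):

  labels = [0.0, 1.0]
  pi     = [Math.log(2.0 / 3), Math.log(1.0 / 3)]   # log class priors
  theta  = [[Math.log(0.6), Math.log(0.4)],         # log P(feature | class)
            [Math.log(0.2), Math.log(0.8)]]
  x      = [1.0, 0.0]

  # Score per class: pi[c] + theta[c] . x; predict the argmax.
  scores = theta.each_with_index.map do |row, c|
    pi[c] + row.zip(x).sum { |t, xi| t * xi }
  end
  puts labels[scores.index(scores.max)]   # => 0.0
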
data/lib/spark/mllib/classification/svm.rb
@@ -0,0 +1,135 @@
+ module Spark
+   module Mllib
+     ##
+     # SVMModel
+     #
+     # A support vector machine.
+     #
+     # == Examples:
+     #
+     #   Spark::Mllib.import
+     #
+     #   # Dense vectors
+     #   data = [
+     #     LabeledPoint.new(0.0, [0.0]),
+     #     LabeledPoint.new(1.0, [1.0]),
+     #     LabeledPoint.new(1.0, [2.0]),
+     #     LabeledPoint.new(1.0, [3.0])
+     #   ]
+     #   svm = SVMWithSGD.train($sc.parallelize(data))
+     #
+     #   svm.predict([1.0])
+     #   # => 1
+     #   svm.clear_threshold
+     #   svm.predict([1.0])
+     #   # => 1.25...
+     #
+     #
+     #   # Sparse vectors
+     #   data = [
+     #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
+     #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
+     #     LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
+     #     LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
+     #   ]
+     #   svm = SVMWithSGD.train($sc.parallelize(data))
+     #
+     #   svm.predict(SparseVector.new(2, {1 => 1.0}))
+     #   # => 1
+     #   svm.predict(SparseVector.new(2, {0 => -1.0}))
+     #   # => 0
+     #
+     class SVMModel < ClassificationModel
+
+       def initialize(*args)
+         super
+         @threshold = 0.0
+       end
+
+       # Predict values for a single data point or an RDD of points using
+       # the trained model.
+       def predict(vector)
+         vector = Spark::Mllib::Vectors.to_vector(vector)
+         margin = weights.dot(vector) + intercept
+
+         if threshold.nil?
+           return margin
+         end
+
+         if margin > threshold
+           1
+         else
+           0
+         end
+       end
+
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     class SVMWithSGD < ClassificationMethodBase
+
+       DEFAULT_OPTIONS = {
+         iterations: 100,
+         step: 1.0,
+         reg_param: 0.01,
+         mini_batch_fraction: 1.0,
+         initial_weights: nil,
+         reg_type: 'l2',
+         intercept: false
+       }
+
+       # Train a support vector machine on the given data.
+       #
+       # rdd::
+       #   The training data, an RDD of LabeledPoint.
+       #
+       # iterations::
+       #   The number of iterations (default: 100).
+       #
+       # step::
+       #   The step parameter used in SGD (default: 1.0).
+       #
+       # reg_param::
+       #   The regularizer parameter (default: 0.01).
+       #
+       # mini_batch_fraction::
+       #   Fraction of data to be used for each SGD iteration (default: 1.0).
+       #
+       # initial_weights::
+       #   The initial weights (default: nil).
+       #
+       # reg_type::
+       #   The type of regularizer used for training our model (default: "l2").
+       #
+       #   Allowed values:
+       #   - "l1" for using L1 regularization
+       #   - "l2" for using L2 regularization
+       #   - nil for no regularization
+       #
+       # intercept::
+       #   Boolean flag indicating whether to use the augmented
+       #   representation of the training data (i.e. whether
+       #   bias features are activated or not).
+       #
+       def self.train(rdd, options={})
+         super
+
+         weights, intercept = Spark.jb.call(RubyMLLibAPI.new, 'trainSVMModelWithSGD', rdd,
+                                            options[:iterations].to_i,
+                                            options[:step].to_f,
+                                            options[:reg_param].to_f,
+                                            options[:mini_batch_fraction].to_f,
+                                            options[:initial_weights],
+                                            options[:reg_type],
+                                            options[:intercept])
+
+         SVMModel.new(weights, intercept)
+       end
+
+     end
+   end
+ end
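
SVMModel#predict reduces to a linear decision function: compute the margin w . x + b, then compare it against the threshold — or return the raw margin once the threshold has been cleared, as the docs' `clear_threshold` example shows. A short illustrative sketch with hypothetical weights, outside Spark:

  weights   = [1.25]
  intercept = 0.0
  threshold = 0.0
  x         = [1.0]

  margin = weights.zip(x).sum { |w, xi| w * xi } + intercept
  puts(margin > threshold ? 1 : 0)                               # => 1

  threshold = nil                                                # what clear_threshold does
  puts(threshold.nil? ? margin : (margin > threshold ? 1 : 0))   # => 1.25
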
data/lib/spark/mllib/clustering/gaussian_mixture.rb
@@ -0,0 +1,82 @@
+ module Spark
+   module Mllib
+     ##
+     # GaussianMixtureModel
+     #
+     # A clustering model derived from the Gaussian mixture model method.
+     #
+     # == Examples:
+     #
+     #   Spark::Mllib.import
+     #
+     #   data = [
+     #     DenseVector.new([-0.1, -0.05]),
+     #     DenseVector.new([-0.01, -0.1]),
+     #     DenseVector.new([0.9, 0.8]),
+     #     DenseVector.new([0.75, 0.935]),
+     #     DenseVector.new([-0.83, -0.68]),
+     #     DenseVector.new([-0.91, -0.76])
+     #   ]
+     #
+     #   model = GaussianMixture.train($sc.parallelize(data), 3, convergence_tol: 0.0001, max_iterations: 50, seed: 10)
+     #
+     #   labels = model.predict($sc.parallelize(data)).collect
+     #
+     class GaussianMixtureModel
+
+       attr_reader :weights, :gaussians, :k
+
+       def initialize(weights, gaussians)
+         @weights = weights
+         @gaussians = gaussians
+         @k = weights.size
+       end
+
+       # Find the cluster to which each point in the given RDD has maximum
+       # membership in this model.
+       def predict(rdd)
+         if rdd.is_a?(Spark::RDD)
+           predict_soft(rdd).map('lambda{|x| x.index(x.max)}')
+         else
+           raise ArgumentError, 'Argument must be an RDD.'
+         end
+       end
+
+       # Find the membership of each point in the given RDD to all mixture components.
+       def predict_soft(rdd)
+         Spark.jb.call(RubyMLLibAPI.new, 'predictSoftGMM', rdd, weights, means, sigmas)
+       end
+
+       def means
+         @means ||= @gaussians.map(&:mu)
+       end
+
+       def sigmas
+         @sigmas ||= @gaussians.map(&:sigma)
+       end
+
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     class GaussianMixture
+
+       def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
+         weights, means, sigmas = Spark.jb.call(RubyMLLibAPI.new, 'trainGaussianMixture', rdd,
+                                                k, convergence_tol, max_iterations, Spark.jb.to_long(seed))
+
+         means.map! {|mu| Spark.jb.java_to_ruby(mu)}
+         sigmas.map!{|sigma| Spark.jb.java_to_ruby(sigma)}
+
+         mvgs = Array.new(k) do |i|
+           MultivariateGaussian.new(means[i], sigmas[i])
+         end
+
+         GaussianMixtureModel.new(weights, mvgs)
+       end
+
+     end
+   end
+ end
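
Note how predict builds on predict_soft: the soft prediction yields, for every point, an array of membership weights across the k components, and the hard label is simply the index of the largest weight. A sketch of that final step on hypothetical membership rows:

  memberships = [
    [0.80, 0.15, 0.05],   # point 1: mostly component 0
    [0.10, 0.70, 0.20]    # point 2: mostly component 1
  ]
  labels = memberships.map { |x| x.index(x.max) }
  p labels   # => [0, 1]
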
data/lib/spark/mllib/clustering/kmeans.rb
@@ -0,0 +1,118 @@
+ module Spark
+   module Mllib
+     ##
+     # KMeansModel
+     #
+     # A clustering model derived from the k-means method.
+     #
+     # == Examples:
+     #
+     #   Spark::Mllib.import
+     #
+     #   # Dense vectors
+     #   data = [
+     #     DenseVector.new([0.0, 0.0]),
+     #     DenseVector.new([1.0, 1.0]),
+     #     DenseVector.new([9.0, 8.0]),
+     #     DenseVector.new([8.0, 9.0])
+     #   ]
+     #
+     #   model = KMeans.train($sc.parallelize(data), 2, max_iterations: 10,
+     #                        runs: 30, initialization_mode: "random")
+     #
+     #   model.predict([0.0, 0.0]) == model.predict([1.0, 1.0])
+     #   # => true
+     #   model.predict([8.0, 9.0]) == model.predict([9.0, 8.0])
+     #   # => true
+     #
+     #
+     #   # Sparse vectors
+     #   data = [
+     #     SparseVector.new(3, {1 => 1.0}),
+     #     SparseVector.new(3, {1 => 1.1}),
+     #     SparseVector.new(3, {2 => 1.0}),
+     #     SparseVector.new(3, {2 => 1.1})
+     #   ]
+     #   model = KMeans.train($sc.parallelize(data), 2, initialization_mode: "k-means||")
+     #
+     #   model.predict([0.0, 1.0, 0.0]) == model.predict([0, 1.1, 0.0])
+     #   # => true
+     #   model.predict([0.0, 0.0, 1.0]) == model.predict([0, 0, 1.1])
+     #   # => true
+     #   model.predict(data[0]) == model.predict(data[1])
+     #   # => true
+     #   model.predict(data[2]) == model.predict(data[3])
+     #   # => true
+     #
+     class KMeansModel
+
+       attr_reader :centers
+
+       def initialize(centers)
+         @centers = centers
+       end
+
+       # Find the cluster to which the given vector belongs in this model.
+       def predict(vector)
+         vector = Spark::Mllib::Vectors.to_vector(vector)
+         best = 0
+         best_distance = Float::INFINITY
+
+         @centers.each_with_index do |center, index|
+           distance = vector.squared_distance(center)
+           if distance < best_distance
+             best = index
+             best_distance = distance
+           end
+         end
+
+         best
+       end
+
+       def self.from_java(object)
+         centers = object.clusterCenters
+         centers.map! do |center|
+           Spark.jb.java_to_ruby(center)
+         end
+
+         KMeansModel.new(centers)
+       end
+
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     class KMeans
+
+       # Trains a k-means model using the given set of parameters.
+       #
+       # == Arguments:
+       # rdd::
+       #   The training data, an RDD of Vectors.
+       #
+       # k::
+       #   Number of clusters.
+       #
+       # max_iterations::
+       #   Maximum number of iterations (default: 100).
+       #
+       # runs::
+       #   Number of parallel runs (default: 1). The best model is returned.
+       #
+       # initialization_mode::
+       #   Initialization mode, either "random" or "k-means||" (default).
+       #
+       # seed::
+       #   Random seed value for cluster initialization.
+       #
+       def self.train(rdd, k, max_iterations: 100, runs: 1, initialization_mode: 'k-means||', seed: nil)
+         # Call returns a KMeansModel
+         Spark.jb.call(RubyMLLibAPI.new, 'trainKMeansModel', rdd,
+                       k, max_iterations, runs, initialization_mode, Spark.jb.to_long(seed))
+       end
+
+     end
+   end
+ end
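
KMeansModel#predict is a plain nearest-center search under squared Euclidean distance. The same loop on bare arrays, with hypothetical centers:

  centers = [[0.5, 0.5], [8.5, 8.5]]
  point   = [1.0, 1.0]

  squared_distance = ->(a, b) { a.zip(b).sum { |x, y| (x - y)**2 } }

  best = centers.each_index.min_by { |i| squared_distance.(point, centers[i]) }
  puts best   # => 0
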
data/lib/spark/mllib/matrix.rb
@@ -0,0 +1,120 @@
+ module Spark
+   module Mllib
+     module Matrices
+
+       def self.dense(*args)
+         DenseMatrix.new(*args)
+       end
+
+       def self.sparse(*args)
+         SparseMatrix.new(*args)
+       end
+
+       def self.to_matrix(data)
+         if data.is_a?(SparseMatrix) || data.is_a?(DenseMatrix)
+           data
+         elsif data.is_a?(Array)
+           DenseMatrix.new(data)
+         end
+       end
+
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     # @abstract Parent for all types of matrices
+     class MatrixBase < MatrixAdapter
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     ##
+     # DenseMatrix
+     #
+     #   DenseMatrix.new(2, 3, [[1,2,3], [4,5,6]]).values
+     #   # => [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+     #
+     class DenseMatrix < MatrixBase
+
+       def initialize(rows, cols, values)
+         super(:dense, rows, cols, values.to_a)
+       end
+
+       def to_java
+         JDenseMatrix.new(shape[0], shape[1], values.flatten)
+       end
+
+       def self.from_java(object)
+         rows = object.numRows
+         cols = object.numCols
+         values = object.values
+
+         DenseMatrix.new(rows, cols, values)
+       end
+
+     end
+   end
+ end
+
+ module Spark
+   module Mllib
+     ##
+     # SparseMatrix
+     #
+     # == Arguments:
+     # rows::
+     #   Number of rows.
+     #
+     # cols::
+     #   Number of columns.
+     #
+     # col_pointers::
+     #   The index corresponding to the start of a new column.
+     #
+     # row_indices::
+     #   The row index of each entry. They must be in strictly
+     #   increasing order for each column.
+     #
+     # values::
+     #   Nonzero matrix entries in column-major order.
+     #
+     # == Examples:
+     #
+     #   SparseMatrix.new(3, 3, [0, 2, 3, 6], [0, 2, 1, 0, 1, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).values
+     #
+     #   # => [
+     #   #      [1.0, 0.0, 4.0],
+     #   #      [0.0, 3.0, 5.0],
+     #   #      [2.0, 0.0, 6.0]
+     #   #    ]
+     #
+     class SparseMatrix < MatrixBase
+
+       attr_reader :col_pointers, :row_indices
+
+       def initialize(rows, cols, col_pointers, row_indices, values)
+         super(:sparse, rows, cols)
+
+         @col_pointers = col_pointers
+         @row_indices = row_indices
+         @values = values
+
+         j = 0
+         while j < cols
+           idx = col_pointers[j]
+           idx_end = col_pointers[j+1]
+           while idx < idx_end
+             self[row_indices[idx], j] = values[idx]
+             idx += 1
+           end
+           j += 1
+         end
+       end
+
+     end
+   end
+ end
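
The while loops in SparseMatrix#initialize expand a compressed sparse column (CSC) layout: col_pointers[j] and col_pointers[j+1] delimit the slice of row_indices/values holding column j's nonzeros. The same expansion applied to the documented example, written with bare arrays:

  rows, cols   = 3, 3
  col_pointers = [0, 2, 3, 6]
  row_indices  = [0, 2, 1, 0, 1, 2]
  values       = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

  dense = Array.new(rows) { Array.new(cols, 0.0) }
  cols.times do |j|
    (col_pointers[j]...col_pointers[j + 1]).each do |idx|
      dense[row_indices[idx]][j] = values[idx]
    end
  end
  p dense   # => [[1.0, 0.0, 4.0], [0.0, 3.0, 5.0], [2.0, 0.0, 6.0]]
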