ruby-spark 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
data/ext/spark/src/main/scala/MarshalDump.scala
@@ -0,0 +1,113 @@
+ package org.apache.spark.api.ruby.marshal
+
+ import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+
+ import scala.collection.mutable.ArrayBuffer
+ import scala.collection.JavaConverters._
+ import scala.reflect.{ClassTag, classTag}
+
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
+
+
+ /* =================================================================================================
+  * class MarshalDump
+  * =================================================================================================
+  */
+ class MarshalDump(os: DataOutputStream) {
+
+   val NAN_BYTELIST = "nan".getBytes
+   val NEGATIVE_INFINITY_BYTELIST = "-inf".getBytes
+   val INFINITY_BYTELIST = "inf".getBytes
+
+   def dump(data: Any) {
+     data match {
+       case null =>
+         os.writeByte('0')
+
+       case item: Boolean =>
+         val char = if(item) 'T' else 'F'
+         os.writeByte(char)
+
+       case item: Int =>
+         os.writeByte('i')
+         dumpInt(item)
+
+       case item: Array[_] =>
+         os.writeByte('[')
+         dumpArray(item)
+
+       case item: Double =>
+         os.writeByte('f')
+         dumpFloat(item)
+
+       case item: ArrayBuffer[Any] => dump(item.toArray)
+     }
+   }
+
+   def dumpInt(data: Int) {
+     if(data == 0){
+       os.writeByte(0)
+     }
+     else if (0 < data && data < 123) {
+       os.writeByte(data + 5)
+     }
+     else if (-124 < data && data < 0) {
+       os.writeByte((data - 5) & 0xff)
+     }
+     else {
+       val buffer = new Array[Byte](4)
+       var value = data
+
+       var i = 0
+       while(i != 4 && value != 0 && value != -1){
+         buffer(i) = (value & 0xff).toByte
+         value = value >> 8
+
+         i += 1
+       }
+       val length = i
+       if(value < 0){
+         os.writeByte(-length)
+       }
+       else{
+         os.writeByte(length)
+       }
+       os.write(buffer, 0, length)
+     }
+   }
+
+   def dumpArray(array: Array[_]) {
+     dumpInt(array.size)
+
+     for(item <- array) {
+       dump(item)
+     }
+   }
+
+   def dumpFloat(value: Double) {
+     if(value.isPosInfinity){
+       dumpString(INFINITY_BYTELIST)
+     }
+     else if(value.isNegInfinity){
+       dumpString(NEGATIVE_INFINITY_BYTELIST)
+     }
+     else if(value.isNaN){
+       dumpString(NAN_BYTELIST)
+     }
+     else{
+       // dumpString("%.17g".format(value))
+       dumpString(value.toString)
+     }
+   }
+
+   def dumpString(data: String) {
+     dumpString(data.getBytes)
+   }
+
+   def dumpString(data: Array[Byte]) {
+     dumpInt(data.size)
+     os.write(data)
+   }
+
+ }
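
As a rough illustration of how MarshalDump is meant to be used (this sketch, including the object name MarshalDumpSketch, is not part of the gem), the snippet below serializes an Array[Int] into bytes that Ruby's Marshal.load can read. It assumes the caller writes the two-byte Marshal version header (major 4, minor 8) itself, since MarshalDump only emits the payload.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.api.ruby.marshal.MarshalDump

object MarshalDumpSketch {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val os = new DataOutputStream(bytes)

    // Ruby's Marshal format starts with a two-byte version header (4.8);
    // MarshalDump itself writes only the body, so emit the header here.
    os.writeByte(4)
    os.writeByte(8)

    new MarshalDump(os).dump(Array(1, 2, 3))
    os.flush()

    // The resulting bytes can be handed to Ruby and decoded with Marshal.load.
    println(bytes.toByteArray.map("%02x".format(_)).mkString(" "))
  }
}
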
data/ext/spark/src/main/scala/MarshalLoad.scala
@@ -0,0 +1,220 @@
+ package org.apache.spark.api.ruby.marshal
+
+ import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+
+ import scala.collection.mutable.ArrayBuffer
+ import scala.collection.JavaConverters._
+ import scala.reflect.{ClassTag, classTag}
+
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
+
+
+ /* =================================================================================================
+  * class MarshalLoad
+  * =================================================================================================
+  */
+ class MarshalLoad(is: DataInputStream) {
+
+   case class WaitForObject()
+
+   val registeredSymbols = ArrayBuffer[String]()
+   val registeredLinks = ArrayBuffer[Any]()
+
+   def load: Any = {
+     load(is.readUnsignedByte.toChar)
+   }
+
+   def load(dataType: Char): Any = {
+     dataType match {
+       case '0' => null
+       case 'T' => true
+       case 'F' => false
+       case 'i' => loadInt
+       case 'f' => loadAndRegisterFloat
+       case ':' => loadAndRegisterSymbol
+       case '[' => loadAndRegisterArray
+       case 'U' => loadAndRegisterUserObject
+       case _ =>
+         throw new IllegalArgumentException(s"Format is not supported: $dataType.")
+     }
+   }
+
+
+   // ----------------------------------------------------------------------------------------------
+   // Load by type
+
+   def loadInt: Int = {
+     var c = is.readByte.toInt
+
+     if (c == 0) {
+       return 0
+     } else if (4 < c && c < 128) {
+       return c - 5
+     } else if (-129 < c && c < -4) {
+       return c + 5
+     }
+
+     var result: Long = 0
+
+     if (c > 0) {
+       result = 0
+       for( i <- 0 until c ) {
+         result |= (is.readUnsignedByte << (8 * i)).toLong
+       }
+     } else {
+       c = -c
+       result = -1
+       for( i <- 0 until c ) {
+         result &= ~((0xff << (8 * i)).toLong)
+         result |= (is.readUnsignedByte << (8 * i)).toLong
+       }
+     }
+
+     result.toInt
+   }
+
+   def loadAndRegisterFloat: Double = {
+     val result = loadFloat
+     registeredLinks += result
+     result
+   }
+
+   def loadFloat: Double = {
+     val string = loadString
+     string match {
+       case "nan" => Double.NaN
+       case "inf" => Double.PositiveInfinity
+       case "-inf" => Double.NegativeInfinity
+       case _ => string.toDouble
+     }
+   }
+
+   def loadString: String = {
+     new String(loadStringBytes)
+   }
+
+   def loadStringBytes: Array[Byte] = {
+     val size = loadInt
+     val buffer = new Array[Byte](size)
+
+     var readSize = 0
+     while(readSize < size){
+       val read = is.read(buffer, readSize, size-readSize)
+
+       if(read == -1){
+         throw new IllegalArgumentException("Marshal too short.")
+       }
+
+       readSize += read
+     }
+
+     buffer
+   }
+
+   def loadAndRegisterSymbol: String = {
+     val result = loadString
+     registeredSymbols += result
+     result
+   }
+
+   def loadAndRegisterArray: Array[Any] = {
+     val size = loadInt
+     val array = new Array[Any](size)
+
+     registeredLinks += array
+
+     for( i <- 0 until size ) {
+       array(i) = loadNextObject
+     }
+
+     array
+   }
+
+   def loadAndRegisterUserObject: Any = {
+     val klass = loadNextObject.asInstanceOf[String]
+
+     // Register a placeholder before loading the next object
+     registeredLinks += WaitForObject()
+     val index = registeredLinks.size - 1
+
+     val data = loadNextObject
+
+     val result = klass match {
+       case "Spark::Mllib::LabeledPoint" => createLabeledPoint(data)
+       case "Spark::Mllib::DenseVector" => createDenseVector(data)
+       case "Spark::Mllib::SparseVector" => createSparseVector(data)
+       case other =>
+         throw new IllegalArgumentException(s"Object $other is not supported.")
+     }
+
+     registeredLinks(index) = result
+
+     result
+   }
+
+
+   // ----------------------------------------------------------------------------------------------
+   // Other loads
+
+   def loadNextObject: Any = {
+     val dataType = is.readUnsignedByte.toChar
+
+     if(isLinkType(dataType)){
+       readLink(dataType)
+     }
+     else{
+       load(dataType)
+     }
+   }
+
+
+   // ----------------------------------------------------------------------------------------------
+   // To java objects
+
+   def createLabeledPoint(data: Any): LabeledPoint = {
+     val array = data.asInstanceOf[Array[_]]
+     new LabeledPoint(array(0).asInstanceOf[Double], array(1).asInstanceOf[Vector])
+   }
+
+   def createDenseVector(data: Any): DenseVector = {
+     new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
+   }
+
+   def createSparseVector(data: Any): SparseVector = {
+     val array = data.asInstanceOf[Array[_]]
+     val size = array(0).asInstanceOf[Int]
+     val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
+     val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))
+
+     new SparseVector(size, indices, values)
+   }
+
+
+   // ----------------------------------------------------------------------------------------------
+   // Helpers
+
+   def toDouble(data: Any): Double = data match {
+     case x: Int => x.toDouble
+     case x: Double => x
+     case _ => 0.0
+   }
+
+
+   // ----------------------------------------------------------------------------------------------
+   // Cache
+
+   def readLink(dataType: Char): Any = {
+     val index = loadInt
+
+     dataType match {
+       case '@' => registeredLinks(index)
+       case ';' => registeredSymbols(index)
+     }
+   }
+
+   def isLinkType(dataType: Char): Boolean = {
+     dataType == ';' || dataType == '@'
+   }
+
+ }
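
The two classes are symmetric, so a quick way to sanity-check them is to dump a value with MarshalDump and read it back with MarshalLoad from the same byte stream. This is only a sketch (the object name MarshalRoundTripSketch is made up, not part of the gem):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import org.apache.spark.api.ruby.marshal.{MarshalDump, MarshalLoad}

object MarshalRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Dump a mixed array: Int, Double and Boolean are all supported types.
    val out = new ByteArrayOutputStream()
    new MarshalDump(new DataOutputStream(out)).dump(Array(42, 3.14, true))

    // Read the same bytes back; '[' dispatches to loadAndRegisterArray.
    val in = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    val loaded = new MarshalLoad(in).load.asInstanceOf[Array[Any]]

    println(loaded.mkString(", "))  // expected: 42, 3.14, true
  }
}
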
data/ext/spark/src/main/scala/RubyAccumulatorParam.scala
@@ -0,0 +1,69 @@
+ package org.apache.spark.api.ruby
+
+ import java.io._
+ import java.net._
+ import java.util.{List, ArrayList}
+
+ import scala.collection.JavaConversions._
+ import scala.collection.immutable._
+
+ import org.apache.spark._
+ import org.apache.spark.util.Utils
+
+ /**
+  * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
+  * collects a list of marshaled byte arrays that we pass to Ruby through a socket.
+  */
+ private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
+   extends AccumulatorParam[List[Array[Byte]]] {
+
+   // Utils.checkHost(serverHost, "Expected hostname")
+
+   val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+
+   // The socket should not be serialized.
+   // Otherwise: SparkException: Task not serializable
+   @transient var socket: Socket = null
+   @transient var socketOutputStream: DataOutputStream = null
+   @transient var socketInputStream: DataInputStream = null
+
+   def openSocket(){
+     synchronized {
+       if (socket == null || socket.isClosed) {
+         socket = new Socket(serverHost, serverPort)
+
+         socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
+         socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+       }
+     }
+   }
+
+   override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList
+
+   override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]) : List[Array[Byte]] = synchronized {
+     if (serverHost == null) {
+       // This happens on the worker node, where we just want to remember all the updates
+       val1.addAll(val2)
+       val1
+     } else {
+       // This happens on the master, where we pass the updates to Ruby through a socket
+       openSocket()
+
+       socketOutputStream.writeInt(val2.size)
+       for (array <- val2) {
+         socketOutputStream.writeInt(array.length)
+         socketOutputStream.write(array)
+       }
+       socketOutputStream.flush()
+
+       // Wait for acknowledgement
+       // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
+       //
+       // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
+       //   throw new SparkException("Accumulator was not acknowledged")
+       // }
+
+       new ArrayList
+     }
+   }
+ }
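
On the driver, addInPlace pushes each batch of accumulator updates over the socket as a count followed by length-prefixed byte arrays. In the gem the receiving end is the Ruby accumulator server; the hypothetical Scala reader below (names AccumulatorServerSketch, payload, etc. are invented) only illustrates that wire format, not how ruby-spark actually implements the server.

import java.io.DataInputStream
import java.net.ServerSocket

object AccumulatorServerSketch {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0)
    println(s"listening on port ${server.getLocalPort}")

    val socket = server.accept()                 // blocks until the driver connects
    val in = new DataInputStream(socket.getInputStream)

    val count = in.readInt()                     // number of updates in this batch
    for (_ <- 0 until count) {
      val length = in.readInt()                  // length of one marshaled update
      val payload = new Array[Byte](length)
      in.readFully(payload)                      // the marshaled accumulator value
      // a real server would Marshal.load(payload) and merge it into its state here
    }
  }
}
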
data/ext/spark/src/main/scala/RubyBroadcast.scala
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ import org.apache.spark.api.python.PythonBroadcast
+
+ /**
+  * A wrapper for Ruby broadcast variables, which are written to disk by Ruby. It will also
+  * write the data back to disk after deserialization, so Ruby can read it from disk again.
+  *
+  * The class reuses the Python logic and exists only for semantic clarity.
+  */
+ class RubyBroadcast(@transient var _path: String, @transient var id: java.lang.Long) extends PythonBroadcast(_path) {
+
+ }
data/ext/spark/src/main/scala/RubyConstant.scala
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ object RubyConstant {
+   val DATA_EOF = -2
+   val WORKER_ERROR = -1
+   val WORKER_DONE = 0
+   val CREATE_WORKER = 1
+   val KILL_WORKER = 2
+   val KILL_WORKER_AND_WAIT = 3
+   val SUCCESSFULLY_KILLED = 4
+   val UNSUCCESSFUL_KILLING = 5
+   val ACCUMULATOR_ACK = 6
+ }
data/ext/spark/src/main/scala/RubyMLLibAPI.scala
@@ -0,0 +1,55 @@
+ package org.apache.spark.mllib.api.ruby
+
+ import java.util.ArrayList
+
+ import scala.collection.JavaConverters._
+
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.api.java.JavaRDD
+ import org.apache.spark.mllib.linalg._
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.classification.NaiveBayes
+ import org.apache.spark.mllib.clustering.GaussianMixtureModel
+ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+ import org.apache.spark.mllib.api.python.MLLibAPI
+
+
+ class RubyMLLibAPI extends MLLibAPI {
+   // trainLinearRegressionModelWithSGD
+   // trainLassoModelWithSGD
+   // trainRidgeModelWithSGD
+   // trainLogisticRegressionModelWithSGD
+   // trainLogisticRegressionModelWithLBFGS
+   // trainSVMModelWithSGD
+   // trainKMeansModel
+   // trainGaussianMixture
+
+   // Rjb has a problem with theta: Array[Array[Double]]
+   override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
+     val model = NaiveBayes.train(data.rdd, lambda)
+
+     List(
+       Vectors.dense(model.labels),
+       Vectors.dense(model.pi),
+       model.theta.toSeq
+     ).map(_.asInstanceOf[Object]).asJava
+   }
+
+   // In Python, wt is just an Object
+   def predictSoftGMM(
+       data: JavaRDD[Vector],
+       wt: ArrayList[Object],
+       mu: ArrayList[Object],
+       si: ArrayList[Object]): RDD[Array[Double]] = {
+
+     // val weight = wt.asInstanceOf[Array[Double]]
+     val weight = wt.toArray.map(_.asInstanceOf[Double])
+     val mean = mu.toArray.map(_.asInstanceOf[DenseVector])
+     val sigma = si.toArray.map(_.asInstanceOf[DenseMatrix])
+     val gaussians = Array.tabulate(weight.length){
+       i => new MultivariateGaussian(mean(i), sigma(i))
+     }
+     val model = new GaussianMixtureModel(weight, gaussians)
+     model.predictSoft(data)
+   }
+ }
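
For context, predictSoftGMM rebuilds a GaussianMixtureModel from the weights, means and covariances passed in from Ruby and returns per-component membership probabilities. The driver-side sketch below is hypothetical (object name, the made-up parameters, and the use of local mode are all assumptions, not taken from the gem); it only shows the shape of a call:

import java.util.ArrayList
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Vector, Vectors}
import org.apache.spark.mllib.api.ruby.RubyMLLibAPI

object PredictSoftGMMSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("gmm-sketch"))

    // Two 1-dimensional points to score against a two-component mixture.
    val data = sc.parallelize(Seq[Vector](Vectors.dense(0.1), Vectors.dense(5.0))).toJavaRDD()

    // Component weights, means and covariances as the boxed Objects the API expects.
    val wt = new ArrayList[Object](); wt.add(java.lang.Double.valueOf(0.5)); wt.add(java.lang.Double.valueOf(0.5))
    val mu = new ArrayList[Object](); mu.add(new DenseVector(Array(0.0))); mu.add(new DenseVector(Array(5.0)))
    val si = new ArrayList[Object](); si.add(new DenseMatrix(1, 1, Array(1.0))); si.add(new DenseMatrix(1, 1, Array(1.0)))

    val memberships = new RubyMLLibAPI().predictSoftGMM(data, wt, mu, si)
    memberships.collect().foreach(row => println(row.mkString(", ")))

    sc.stop()
  }
}
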