ruby-spark 1.0.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +185 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +7 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/custom_marshal.rb +94 -0
  12. data/benchmark/digest.rb +150 -0
  13. data/benchmark/enumerator.rb +88 -0
  14. data/benchmark/performance/prepare.sh +18 -0
  15. data/benchmark/performance/python.py +156 -0
  16. data/benchmark/performance/r.r +69 -0
  17. data/benchmark/performance/ruby.rb +167 -0
  18. data/benchmark/performance/run-all.sh +160 -0
  19. data/benchmark/performance/scala.scala +181 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/ext/ruby_c/extconf.rb +3 -0
  27. data/ext/ruby_c/murmur.c +158 -0
  28. data/ext/ruby_c/murmur.h +9 -0
  29. data/ext/ruby_c/ruby-spark.c +18 -0
  30. data/ext/ruby_java/Digest.java +36 -0
  31. data/ext/ruby_java/Murmur2.java +98 -0
  32. data/ext/ruby_java/RubySparkExtService.java +28 -0
  33. data/ext/ruby_java/extconf.rb +3 -0
  34. data/ext/spark/build.sbt +73 -0
  35. data/ext/spark/project/plugins.sbt +9 -0
  36. data/ext/spark/sbt/sbt +34 -0
  37. data/ext/spark/src/main/scala/Exec.scala +91 -0
  38. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  39. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  40. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  41. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  42. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  43. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  44. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  46. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  47. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  48. data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
  49. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  50. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  51. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  52. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  53. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  54. data/lib/ruby-spark.rb +1 -0
  55. data/lib/spark.rb +198 -0
  56. data/lib/spark/accumulator.rb +260 -0
  57. data/lib/spark/broadcast.rb +98 -0
  58. data/lib/spark/build.rb +43 -0
  59. data/lib/spark/cli.rb +169 -0
  60. data/lib/spark/command.rb +86 -0
  61. data/lib/spark/command/base.rb +154 -0
  62. data/lib/spark/command/basic.rb +345 -0
  63. data/lib/spark/command/pair.rb +124 -0
  64. data/lib/spark/command/sort.rb +51 -0
  65. data/lib/spark/command/statistic.rb +144 -0
  66. data/lib/spark/command_builder.rb +141 -0
  67. data/lib/spark/command_validator.rb +34 -0
  68. data/lib/spark/config.rb +244 -0
  69. data/lib/spark/constant.rb +14 -0
  70. data/lib/spark/context.rb +304 -0
  71. data/lib/spark/error.rb +50 -0
  72. data/lib/spark/ext/hash.rb +41 -0
  73. data/lib/spark/ext/integer.rb +25 -0
  74. data/lib/spark/ext/io.rb +57 -0
  75. data/lib/spark/ext/ip_socket.rb +29 -0
  76. data/lib/spark/ext/module.rb +58 -0
  77. data/lib/spark/ext/object.rb +24 -0
  78. data/lib/spark/ext/string.rb +24 -0
  79. data/lib/spark/helper.rb +10 -0
  80. data/lib/spark/helper/logger.rb +40 -0
  81. data/lib/spark/helper/parser.rb +85 -0
  82. data/lib/spark/helper/serialize.rb +71 -0
  83. data/lib/spark/helper/statistic.rb +93 -0
  84. data/lib/spark/helper/system.rb +42 -0
  85. data/lib/spark/java_bridge.rb +19 -0
  86. data/lib/spark/java_bridge/base.rb +203 -0
  87. data/lib/spark/java_bridge/jruby.rb +23 -0
  88. data/lib/spark/java_bridge/rjb.rb +41 -0
  89. data/lib/spark/logger.rb +76 -0
  90. data/lib/spark/mllib.rb +100 -0
  91. data/lib/spark/mllib/classification/common.rb +31 -0
  92. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  93. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  94. data/lib/spark/mllib/classification/svm.rb +135 -0
  95. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  96. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  97. data/lib/spark/mllib/matrix.rb +120 -0
  98. data/lib/spark/mllib/regression/common.rb +73 -0
  99. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  100. data/lib/spark/mllib/regression/lasso.rb +100 -0
  101. data/lib/spark/mllib/regression/linear.rb +124 -0
  102. data/lib/spark/mllib/regression/ridge.rb +97 -0
  103. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  104. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  105. data/lib/spark/mllib/stat/distribution.rb +12 -0
  106. data/lib/spark/mllib/vector.rb +185 -0
  107. data/lib/spark/rdd.rb +1328 -0
  108. data/lib/spark/sampler.rb +92 -0
  109. data/lib/spark/serializer.rb +24 -0
  110. data/lib/spark/serializer/base.rb +170 -0
  111. data/lib/spark/serializer/cartesian.rb +37 -0
  112. data/lib/spark/serializer/marshal.rb +19 -0
  113. data/lib/spark/serializer/message_pack.rb +25 -0
  114. data/lib/spark/serializer/oj.rb +25 -0
  115. data/lib/spark/serializer/pair.rb +27 -0
  116. data/lib/spark/serializer/utf8.rb +25 -0
  117. data/lib/spark/sort.rb +189 -0
  118. data/lib/spark/stat_counter.rb +125 -0
  119. data/lib/spark/storage_level.rb +39 -0
  120. data/lib/spark/version.rb +3 -0
  121. data/lib/spark/worker/master.rb +144 -0
  122. data/lib/spark/worker/spark_files.rb +15 -0
  123. data/lib/spark/worker/worker.rb +197 -0
  124. data/ruby-spark.gemspec +36 -0
  125. data/spec/generator.rb +37 -0
  126. data/spec/inputs/lorem_300.txt +316 -0
  127. data/spec/inputs/numbers/1.txt +50 -0
  128. data/spec/inputs/numbers/10.txt +50 -0
  129. data/spec/inputs/numbers/11.txt +50 -0
  130. data/spec/inputs/numbers/12.txt +50 -0
  131. data/spec/inputs/numbers/13.txt +50 -0
  132. data/spec/inputs/numbers/14.txt +50 -0
  133. data/spec/inputs/numbers/15.txt +50 -0
  134. data/spec/inputs/numbers/16.txt +50 -0
  135. data/spec/inputs/numbers/17.txt +50 -0
  136. data/spec/inputs/numbers/18.txt +50 -0
  137. data/spec/inputs/numbers/19.txt +50 -0
  138. data/spec/inputs/numbers/2.txt +50 -0
  139. data/spec/inputs/numbers/20.txt +50 -0
  140. data/spec/inputs/numbers/3.txt +50 -0
  141. data/spec/inputs/numbers/4.txt +50 -0
  142. data/spec/inputs/numbers/5.txt +50 -0
  143. data/spec/inputs/numbers/6.txt +50 -0
  144. data/spec/inputs/numbers/7.txt +50 -0
  145. data/spec/inputs/numbers/8.txt +50 -0
  146. data/spec/inputs/numbers/9.txt +50 -0
  147. data/spec/inputs/numbers_0_100.txt +101 -0
  148. data/spec/inputs/numbers_1_100.txt +100 -0
  149. data/spec/lib/collect_spec.rb +42 -0
  150. data/spec/lib/command_spec.rb +68 -0
  151. data/spec/lib/config_spec.rb +64 -0
  152. data/spec/lib/context_spec.rb +163 -0
  153. data/spec/lib/ext_spec.rb +72 -0
  154. data/spec/lib/external_apps_spec.rb +45 -0
  155. data/spec/lib/filter_spec.rb +80 -0
  156. data/spec/lib/flat_map_spec.rb +100 -0
  157. data/spec/lib/group_spec.rb +109 -0
  158. data/spec/lib/helper_spec.rb +19 -0
  159. data/spec/lib/key_spec.rb +41 -0
  160. data/spec/lib/manipulation_spec.rb +114 -0
  161. data/spec/lib/map_partitions_spec.rb +87 -0
  162. data/spec/lib/map_spec.rb +91 -0
  163. data/spec/lib/mllib/classification_spec.rb +54 -0
  164. data/spec/lib/mllib/clustering_spec.rb +35 -0
  165. data/spec/lib/mllib/matrix_spec.rb +32 -0
  166. data/spec/lib/mllib/regression_spec.rb +116 -0
  167. data/spec/lib/mllib/vector_spec.rb +77 -0
  168. data/spec/lib/reduce_by_key_spec.rb +118 -0
  169. data/spec/lib/reduce_spec.rb +131 -0
  170. data/spec/lib/sample_spec.rb +46 -0
  171. data/spec/lib/serializer_spec.rb +13 -0
  172. data/spec/lib/sort_spec.rb +58 -0
  173. data/spec/lib/statistic_spec.rb +168 -0
  174. data/spec/lib/whole_text_files_spec.rb +33 -0
  175. data/spec/spec_helper.rb +39 -0
  176. metadata +301 -0
@@ -0,0 +1,113 @@
1
+ package org.apache.spark.api.ruby.marshal
2
+
3
+ import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
4
+
5
+ import scala.collection.mutable.ArrayBuffer
6
+ import scala.collection.JavaConverters._
7
+ import scala.reflect.{ClassTag, classTag}
8
+
9
+ import org.apache.spark.mllib.regression.LabeledPoint
10
+ import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
11
+
12
+
13
/**
 * Serializes a small subset of JVM values into Ruby's `Marshal` binary object format.
 *
 * Supported values and the type byte written for each:
 *  - `null`                        -> `0`
 *  - `Boolean`                     -> `T` / `F`
 *  - `Int`                         -> `i` followed by a Marshal-packed integer
 *  - `Double`                      -> `f` followed by a length-prefixed textual float
 *  - `Array[_]` / `ArrayBuffer[_]` -> `[` followed by the length and each element, recursively
 *
 * Only the object bytes are written; this class does not emit the two-byte Marshal
 * version header.
 *
 * @param os stream the encoded bytes are written to
 */
class MarshalDump(os: DataOutputStream) {

  val NAN_BYTELIST = "nan".getBytes("UTF-8")
  val NEGATIVE_INFINITY_BYTELIST = "-inf".getBytes("UTF-8")
  val INFINITY_BYTELIST = "inf".getBytes("UTF-8")

  /**
   * Writes `data` (type byte plus payload) to the stream.
   *
   * @throws IllegalArgumentException if `data` is of an unsupported type
   */
  def dump(data: Any): Unit = data match {
    case null =>
      os.writeByte('0')

    case item: Boolean =>
      os.writeByte(if (item) 'T' else 'F')

    case item: Int =>
      os.writeByte('i')
      dumpInt(item)

    case item: Array[_] =>
      os.writeByte('[')
      dumpArray(item)

    case item: Double =>
      os.writeByte('f')
      dumpFloat(item)

    case item: ArrayBuffer[_] =>
      dump(item.toArray[Any])

    case other =>
      throw new IllegalArgumentException(s"Object ${other.getClass.getName} is not supported.")
  }

  /**
   * Writes `data` in Marshal's packed-integer encoding (no type byte).
   *
   * Small values (-123..122) are packed into a single byte. Larger values are written as a
   * signed length byte n (|n| in 1..4, negative for negative numbers) followed by the |n|
   * low-order bytes of the value, little-endian.
   */
  def dumpInt(data: Int): Unit = {
    if (data == 0) {
      os.writeByte(0)
    } else if (0 < data && data < 123) {
      os.writeByte(data + 5)
    } else if (-124 < data && data < 0) {
      os.writeByte((data - 5) & 0xff)
    } else {
      val buffer = new Array[Byte](4)
      var value = data
      var length = 0
      // Emit low bytes until the remaining (arithmetically shifted) value is pure sign
      // extension: 0 for positive numbers, -1 for negative ones. This mirrors Ruby's w_long.
      do {
        buffer(length) = (value & 0xff).toByte
        value >>= 8
        length += 1
      } while (length < 4 && value != 0 && value != -1)

      os.writeByte(if (value < 0) -length else length)
      os.write(buffer, 0, length)
    }
  }

  /** Writes the element count followed by every element of `array`, each via [[dump]]. */
  def dumpArray(array: Array[_]): Unit = {
    dumpInt(array.length)
    array.foreach(dump)
  }

  /**
   * Writes `value` as a length-prefixed string: `inf`, `-inf`, `nan`, or the
   * `Double.toString` representation for finite values.
   */
  def dumpFloat(value: Double): Unit = {
    if (value.isPosInfinity) {
      dumpString(INFINITY_BYTELIST)
    } else if (value.isNegInfinity) {
      dumpString(NEGATIVE_INFINITY_BYTELIST)
    } else if (value.isNaN) {
      dumpString(NAN_BYTELIST)
    } else {
      // dumpString("%.17g".format(value))
      dumpString(value.toString)
    }
  }

  /** Writes `data` as UTF-8 bytes prefixed with their packed length. */
  def dumpString(data: String): Unit = {
    dumpString(data.getBytes("UTF-8"))
  }

  /** Writes the packed length of `data` followed by the raw bytes. */
  def dumpString(data: Array[Byte]): Unit = {
    dumpInt(data.length)
    os.write(data)
  }

}
@@ -0,0 +1,220 @@
1
+ package org.apache.spark.api.ruby.marshal
2
+
3
+ import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
4
+
5
+ import scala.collection.mutable.ArrayBuffer
6
+ import scala.collection.JavaConverters._
7
+ import scala.reflect.{ClassTag, classTag}
8
+
9
+ import org.apache.spark.mllib.regression.LabeledPoint
10
+ import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
11
+
12
+
13
/**
 * Deserializes a subset of Ruby's `Marshal` binary object format from `is`.
 *
 * Supported type bytes: `0` (nil), `T`/`F`, `i` (packed Int), `f` (textual float),
 * `:` (symbol), `[` (array) and `U` (user object; only Spark MLlib LabeledPoint,
 * DenseVector and SparseVector). Object links (`@`) and symbol links (`;`) are resolved
 * for nested objects read through [[loadNextObject]].
 *
 * The version header is not consumed here; `load` expects the next byte to be a type byte.
 *
 * @param is stream positioned at a Marshal type byte
 */
class MarshalLoad(is: DataInputStream) {

  /** Placeholder kept in `registeredLinks` while a user object's payload is still being read. */
  case class WaitForObject()

  /** Symbols in the order they were read; `;` references index into this. */
  val registeredSymbols = ArrayBuffer[String]()
  /** Floats, arrays and user objects in the order they were read; `@` references index into this. */
  val registeredLinks = ArrayBuffer[Any]()

  /** Reads one type byte and the object that follows it. */
  def load: Any = {
    load(is.readUnsignedByte.toChar)
  }

  /**
   * Reads the payload of an object whose type byte `dataType` has already been consumed.
   *
   * @throws IllegalArgumentException for unsupported type bytes
   */
  def load(dataType: Char): Any = {
    dataType match {
      case '0' => null
      case 'T' => true
      case 'F' => false
      case 'i' => loadInt
      case 'f' => loadAndRegisterFloat
      case ':' => loadAndRegisterSymbol
      case '[' => loadAndRegisterArray
      case 'U' => loadAndRegisterUserObject
      case _ =>
        throw new IllegalArgumentException(s"Format is not supported: $dataType.")
    }
  }


  // ----------------------------------------------------------------------------------------------
  // Load by type

  /**
   * Reads a Marshal packed integer.
   *
   * A first byte c of 0 means 0; c > 4 or c < -4 encodes a small value directly (offset by 5).
   * Otherwise |c| little-endian bytes follow; for negative c they are overlaid on an
   * all-ones value so the result stays negative.
   */
  def loadInt: Int = {
    val c = is.readByte.toInt

    if (c == 0) {
      0
    } else if (c > 4) {
      c - 5
    } else if (c < -4) {
      c + 5
    } else if (c > 0) {
      var result = 0
      for (i <- 0 until c) {
        result |= is.readUnsignedByte << (8 * i)
      }
      result
    } else {
      var result = -1
      for (i <- 0 until -c) {
        val shift = 8 * i
        result = (result & ~(0xff << shift)) | (is.readUnsignedByte << shift)
      }
      result
    }
  }

  /** Reads a float and registers it as a linkable object, as Ruby does. */
  def loadAndRegisterFloat: Double = {
    val result = loadFloat
    registeredLinks += result
    result
  }

  /** Reads a textual float; `nan`, `inf` and `-inf` are mapped to their special values. */
  def loadFloat: Double = {
    val string = loadString
    string match {
      case "nan" => Double.NaN
      case "inf" => Double.PositiveInfinity
      case "-inf" => Double.NegativeInfinity
      case _ => string.toDouble
    }
  }

  /** Reads a length-prefixed byte string and decodes it as UTF-8 (platform-independent). */
  def loadString: String = {
    new String(loadStringBytes, "UTF-8")
  }

  /**
   * Reads a packed length followed by that many raw bytes.
   *
   * @throws IllegalArgumentException if the length is negative or the stream ends early
   */
  def loadStringBytes: Array[Byte] = {
    val size = loadInt
    if (size < 0) {
      throw new IllegalArgumentException(s"Invalid string length: $size.")
    }

    val buffer = new Array[Byte](size)

    // InputStream.read may return fewer bytes than requested, so loop until full.
    var readSize = 0
    while (readSize < size) {
      val read = is.read(buffer, readSize, size - readSize)

      if (read == -1) {
        throw new IllegalArgumentException("Marshal too short.")
      }

      readSize += read
    }

    buffer
  }

  /** Reads a symbol name and registers it so later `;` links can refer to it. */
  def loadAndRegisterSymbol: String = {
    val result = loadString
    registeredSymbols += result
    result
  }

  /** Reads an array; it is registered before its elements so elements may link back to it. */
  def loadAndRegisterArray: Array[Any] = {
    val size = loadInt
    val array = new Array[Any](size)

    registeredLinks += array

    for (i <- 0 until size) {
      array(i) = loadNextObject
    }

    array
  }

  /**
   * Reads a user object (`U`): a class-name symbol followed by its marshalled data, and
   * converts it to the matching Spark MLlib type.
   *
   * @throws IllegalArgumentException for unsupported class names
   */
  def loadAndRegisterUserObject: Any = {
    val klass = loadNextObject.asInstanceOf[String]

    // Reserve the link slot before reading the payload so link indices match Ruby's numbering.
    registeredLinks += WaitForObject()
    val index = registeredLinks.size - 1

    val data = loadNextObject

    val result = klass match {
      case "Spark::Mllib::LabeledPoint" => createLabeledPoint(data)
      case "Spark::Mllib::DenseVector" => createDenseVector(data)
      case "Spark::Mllib::SparseVector" => createSparseVector(data)
      case other =>
        throw new IllegalArgumentException(s"Object $other is not supported.")
    }

    registeredLinks(index) = result

    result
  }


  // ----------------------------------------------------------------------------------------------
  // Other loads

  /** Reads the next object, resolving `@` / `;` link bytes to previously read values. */
  def loadNextObject: Any = {
    val dataType = is.readUnsignedByte.toChar

    if (isLinkType(dataType)) {
      readLink(dataType)
    } else {
      load(dataType)
    }
  }


  // ----------------------------------------------------------------------------------------------
  // To java objects

  /** Builds a LabeledPoint from `[label, features]`; an Int label is accepted like vector values. */
  def createLabeledPoint(data: Any): LabeledPoint = {
    val array = data.asInstanceOf[Array[_]]
    new LabeledPoint(toDouble(array(0)), array(1).asInstanceOf[Vector])
  }

  /** Builds a DenseVector from an array of Int/Double values. */
  def createDenseVector(data: Any): DenseVector = {
    new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
  }

  /** Builds a SparseVector from `[size, indices, values]`. */
  def createSparseVector(data: Any): SparseVector = {
    val array = data.asInstanceOf[Array[_]]
    val size = array(0).asInstanceOf[Int]
    val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
    val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))

    new SparseVector(size, indices, values)
  }


  // ----------------------------------------------------------------------------------------------
  // Helpers

  /** Converts Int/Double to Double; any other value (including null) becomes 0.0. */
  def toDouble(data: Any): Double = data match {
    case x: Int => x.toDouble
    case x: Double => x
    case _ => 0.0
  }


  // ----------------------------------------------------------------------------------------------
  // Cache

  /**
   * Reads a link index and returns the referenced object (`@`) or symbol (`;`).
   *
   * @throws IllegalArgumentException if `dataType` is not a link type byte
   */
  def readLink(dataType: Char): Any = {
    val index = loadInt

    dataType match {
      case '@' => registeredLinks(index)
      case ';' => registeredSymbols(index)
      case other => throw new IllegalArgumentException(s"Not a link type: $other.")
    }
  }

  /** True for the object-link (`@`) and symbol-link (`;`) type bytes. */
  def isLinkType(dataType: Char): Boolean = {
    dataType == ';' || dataType == '@'
  }

}
@@ -0,0 +1,69 @@
1
+ package org.apache.spark.api.ruby
2
+
3
+ import java.io._
4
+ import java.net._
5
+ import java.util.{List, ArrayList}
6
+
7
+ import scala.collection.JavaConversions._
8
+ import scala.collection.immutable._
9
+
10
+ import org.apache.spark._
11
+ import org.apache.spark.util.Utils
12
+
13
/**
 * Internal `AccumulatorParam` for Ruby accumulators. A value is a list of serialized byte
 * arrays produced by Ruby.
 *
 * `serverHost` is `@transient`, so it is `null` in deserialized copies of this param (as
 * shipped to executors). There `addInPlace` just appends updates. Where `serverHost` is set
 * (the driver), `addInPlace` forwards each batch of updates to the Ruby accumulator server
 * at `serverHost:serverPort` over a socket and returns an empty list.
 */
private class RubyAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
  extends AccumulatorParam[List[Array[Byte]]] {

  // Utils.checkHost(serverHost, "Expected hostname")

  val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)

  // The socket and its streams must not be serialized with this param;
  // otherwise: SparkException: Task not serializable
  @transient var socket: Socket = null
  @transient var socketOutputStream: DataOutputStream = null
  @transient var socketInputStream: DataInputStream = null

  /** Opens (or re-opens, if closed) the socket to the Ruby accumulator server. */
  def openSocket(): Unit = synchronized {
    if (socket == null || socket.isClosed) {
      socket = new Socket(serverHost, serverPort)

      socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
      socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
    }
  }

  override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList[Array[Byte]]()

  /**
   * Without a server host: appends `val2` to `val1` and returns `val1`.
   * With a server host: writes `val2` to the socket as
   * `[count][len1][bytes1][len2][bytes2]...`, flushes, and returns a fresh empty list.
   */
  override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]): List[Array[Byte]] = synchronized {
    if (serverHost == null) {
      // Deserialized copy (no server host): just remember all the updates
      val1.addAll(val2)
      val1
    } else {
      // Pass the updates to Ruby through a socket
      import scala.collection.JavaConverters._

      openSocket()

      socketOutputStream.writeInt(val2.size)
      for (array <- val2.asScala) {
        socketOutputStream.writeInt(array.length)
        socketOutputStream.write(array)
      }
      socketOutputStream.flush()

      // Wait for acknowledgement
      // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
      //
      // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
      //   throw new SparkException("Accumulator was not acknowledged")
      // }

      new ArrayList[Array[Byte]]()
    }
  }
}
@@ -0,0 +1,13 @@
1
+ package org.apache.spark.api.ruby
2
+
3
+ import org.apache.spark.api.python.PythonBroadcast
4
+
5
/**
 * Wrapper for a Ruby broadcast variable whose value Ruby has written to disk at `_path`.
 *
 * All behaviour comes from [[org.apache.spark.api.python.PythonBroadcast]], which the original
 * description says also writes the data back to disk after deserialization so Ruby can read
 * it. This subclass exists only to give the Ruby API its own, semantically named type.
 *
 * @param _path file written by Ruby; passed through to PythonBroadcast
 * @param id    broadcast id kept alongside the path (transient, like `_path`)
 */
class RubyBroadcast(
    @transient var _path: String,
    @transient var id: java.lang.Long)
  extends PythonBroadcast(_path)
@@ -0,0 +1,13 @@
1
+ package org.apache.spark.api.ruby
2
+
3
/**
 * Integer control codes used by the Ruby integration.
 * NOTE(review): the names suggest they are exchanged with Ruby worker processes over
 * sockets — confirm against RubyWorker / the Ruby side before relying on that.
 */
object RubyConstant {
  // Data-stream / worker status codes
  val DATA_EOF: Int = -2
  val WORKER_ERROR: Int = -1
  val WORKER_DONE: Int = 0

  // Worker management commands and their replies
  val CREATE_WORKER: Int = 1
  val KILL_WORKER: Int = 2
  val KILL_WORKER_AND_WAIT: Int = 3
  val SUCCESSFULLY_KILLED: Int = 4
  val UNSUCCESSFUL_KILLING: Int = 5

  // Accumulator acknowledgement
  val ACCUMULATOR_ACK: Int = 6
}
@@ -0,0 +1,55 @@
1
+ package org.apache.spark.mllib.api.ruby
2
+
3
+ import java.util.ArrayList
4
+
5
+ import scala.collection.JavaConverters._
6
+
7
+ import org.apache.spark.rdd.RDD
8
+ import org.apache.spark.api.java.JavaRDD
9
+ import org.apache.spark.mllib.linalg._
10
+ import org.apache.spark.mllib.regression.LabeledPoint
11
+ import org.apache.spark.mllib.classification.NaiveBayes
12
+ import org.apache.spark.mllib.clustering.GaussianMixtureModel
13
+ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
14
+ import org.apache.spark.mllib.api.python.MLLibAPI
15
+
16
+
17
/**
 * MLlib entry points for the Ruby bridge. Everything comes from the Python [[MLLibAPI]];
 * this class overrides or adds only what the Ruby side cannot consume directly.
 */
class RubyMLLibAPI extends MLLibAPI {
  // Inherited unchanged from MLLibAPI:
  //   trainLinearRegressionModelWithSGD, trainLassoModelWithSGD, trainRidgeModelWithSGD,
  //   trainLogisticRegressionModelWithSGD, trainLogisticRegressionModelWithLBFGS,
  //   trainSVMModelWithSGD, trainKMeansModel, trainGaussianMixture

  /**
   * Trains a NaiveBayes model and returns `[labels, pi, theta]` as a Java list.
   *
   * Rjb has a problem with `theta: Array[Array[Double]]`, so theta is returned wrapped
   * as a Scala `Seq` of rows instead.
   */
  override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double): java.util.List[Object] = {
    val trained = NaiveBayes.train(data.rdd, lambda)

    val labels: Object = Vectors.dense(trained.labels)
    val priors: Object = Vectors.dense(trained.pi)
    val theta: Object = trained.theta.toSeq

    Seq(labels, priors, theta).asJava
  }

  /**
   * Rebuilds a GaussianMixtureModel from its parts and returns per-point soft memberships.
   *
   * On the Python side `wt` is just an Object; here each list is expected to hold boxed
   * Doubles (`wt`), DenseVectors (`mu`) and DenseMatrices (`si`), one entry per component.
   */
  def predictSoftGMM(
      data: JavaRDD[Vector],
      wt: ArrayList[Object],
      mu: ArrayList[Object],
      si: ArrayList[Object]): RDD[Array[Double]] = {

    // val weight = wt.asInstanceOf[Array[Double]]
    val weights = wt.toArray.map(_.asInstanceOf[Double])
    val means = mu.toArray.map(_.asInstanceOf[DenseVector])
    val sigmas = si.toArray.map(_.asInstanceOf[DenseMatrix])

    val components = new Array[MultivariateGaussian](weights.length)
    for (k <- weights.indices) {
      components(k) = new MultivariateGaussian(means(k), sigmas(k))
    }

    new GaussianMixtureModel(weights, components).predictSoft(data.rdd)
  }
}
+ }