ruby-spark 1.1.0.1-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +252 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +6 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/comparison/prepare.sh +18 -0
  12. data/benchmark/comparison/python.py +156 -0
  13. data/benchmark/comparison/r.r +69 -0
  14. data/benchmark/comparison/ruby.rb +167 -0
  15. data/benchmark/comparison/run-all.sh +160 -0
  16. data/benchmark/comparison/scala.scala +181 -0
  17. data/benchmark/custom_marshal.rb +94 -0
  18. data/benchmark/digest.rb +150 -0
  19. data/benchmark/enumerator.rb +88 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/example/website_search.rb +83 -0
  27. data/ext/ruby_c/extconf.rb +3 -0
  28. data/ext/ruby_c/murmur.c +158 -0
  29. data/ext/ruby_c/murmur.h +9 -0
  30. data/ext/ruby_c/ruby-spark.c +18 -0
  31. data/ext/ruby_java/Digest.java +36 -0
  32. data/ext/ruby_java/Murmur2.java +98 -0
  33. data/ext/ruby_java/RubySparkExtService.java +28 -0
  34. data/ext/ruby_java/extconf.rb +3 -0
  35. data/ext/spark/build.sbt +73 -0
  36. data/ext/spark/project/plugins.sbt +9 -0
  37. data/ext/spark/sbt/sbt +34 -0
  38. data/ext/spark/src/main/scala/Exec.scala +91 -0
  39. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  40. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  41. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  42. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  43. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  44. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  46. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  47. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  48. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  49. data/ext/spark/src/main/scala/RubyRDD.scala +392 -0
  50. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  51. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  52. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  53. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  54. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  55. data/lib/ruby-spark.rb +1 -0
  56. data/lib/spark.rb +198 -0
  57. data/lib/spark/accumulator.rb +260 -0
  58. data/lib/spark/broadcast.rb +98 -0
  59. data/lib/spark/build.rb +43 -0
  60. data/lib/spark/cli.rb +169 -0
  61. data/lib/spark/command.rb +86 -0
  62. data/lib/spark/command/base.rb +158 -0
  63. data/lib/spark/command/basic.rb +345 -0
  64. data/lib/spark/command/pair.rb +124 -0
  65. data/lib/spark/command/sort.rb +51 -0
  66. data/lib/spark/command/statistic.rb +144 -0
  67. data/lib/spark/command_builder.rb +141 -0
  68. data/lib/spark/command_validator.rb +34 -0
  69. data/lib/spark/config.rb +238 -0
  70. data/lib/spark/constant.rb +14 -0
  71. data/lib/spark/context.rb +322 -0
  72. data/lib/spark/error.rb +50 -0
  73. data/lib/spark/ext/hash.rb +41 -0
  74. data/lib/spark/ext/integer.rb +25 -0
  75. data/lib/spark/ext/io.rb +67 -0
  76. data/lib/spark/ext/ip_socket.rb +29 -0
  77. data/lib/spark/ext/module.rb +58 -0
  78. data/lib/spark/ext/object.rb +24 -0
  79. data/lib/spark/ext/string.rb +24 -0
  80. data/lib/spark/helper.rb +10 -0
  81. data/lib/spark/helper/logger.rb +40 -0
  82. data/lib/spark/helper/parser.rb +85 -0
  83. data/lib/spark/helper/serialize.rb +71 -0
  84. data/lib/spark/helper/statistic.rb +93 -0
  85. data/lib/spark/helper/system.rb +42 -0
  86. data/lib/spark/java_bridge.rb +19 -0
  87. data/lib/spark/java_bridge/base.rb +203 -0
  88. data/lib/spark/java_bridge/jruby.rb +23 -0
  89. data/lib/spark/java_bridge/rjb.rb +41 -0
  90. data/lib/spark/logger.rb +76 -0
  91. data/lib/spark/mllib.rb +100 -0
  92. data/lib/spark/mllib/classification/common.rb +31 -0
  93. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  94. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  95. data/lib/spark/mllib/classification/svm.rb +135 -0
  96. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  97. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  98. data/lib/spark/mllib/matrix.rb +120 -0
  99. data/lib/spark/mllib/regression/common.rb +73 -0
  100. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  101. data/lib/spark/mllib/regression/lasso.rb +100 -0
  102. data/lib/spark/mllib/regression/linear.rb +124 -0
  103. data/lib/spark/mllib/regression/ridge.rb +97 -0
  104. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  105. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  106. data/lib/spark/mllib/stat/distribution.rb +12 -0
  107. data/lib/spark/mllib/vector.rb +185 -0
  108. data/lib/spark/rdd.rb +1377 -0
  109. data/lib/spark/sampler.rb +92 -0
  110. data/lib/spark/serializer.rb +79 -0
  111. data/lib/spark/serializer/auto_batched.rb +59 -0
  112. data/lib/spark/serializer/base.rb +63 -0
  113. data/lib/spark/serializer/batched.rb +84 -0
  114. data/lib/spark/serializer/cartesian.rb +13 -0
  115. data/lib/spark/serializer/compressed.rb +27 -0
  116. data/lib/spark/serializer/marshal.rb +17 -0
  117. data/lib/spark/serializer/message_pack.rb +23 -0
  118. data/lib/spark/serializer/oj.rb +23 -0
  119. data/lib/spark/serializer/pair.rb +41 -0
  120. data/lib/spark/serializer/text.rb +25 -0
  121. data/lib/spark/sort.rb +189 -0
  122. data/lib/spark/stat_counter.rb +125 -0
  123. data/lib/spark/storage_level.rb +39 -0
  124. data/lib/spark/version.rb +3 -0
  125. data/lib/spark/worker/master.rb +144 -0
  126. data/lib/spark/worker/spark_files.rb +15 -0
  127. data/lib/spark/worker/worker.rb +200 -0
  128. data/ruby-spark.gemspec +47 -0
  129. data/spec/generator.rb +37 -0
  130. data/spec/inputs/lorem_300.txt +316 -0
  131. data/spec/inputs/numbers/1.txt +50 -0
  132. data/spec/inputs/numbers/10.txt +50 -0
  133. data/spec/inputs/numbers/11.txt +50 -0
  134. data/spec/inputs/numbers/12.txt +50 -0
  135. data/spec/inputs/numbers/13.txt +50 -0
  136. data/spec/inputs/numbers/14.txt +50 -0
  137. data/spec/inputs/numbers/15.txt +50 -0
  138. data/spec/inputs/numbers/16.txt +50 -0
  139. data/spec/inputs/numbers/17.txt +50 -0
  140. data/spec/inputs/numbers/18.txt +50 -0
  141. data/spec/inputs/numbers/19.txt +50 -0
  142. data/spec/inputs/numbers/2.txt +50 -0
  143. data/spec/inputs/numbers/20.txt +50 -0
  144. data/spec/inputs/numbers/3.txt +50 -0
  145. data/spec/inputs/numbers/4.txt +50 -0
  146. data/spec/inputs/numbers/5.txt +50 -0
  147. data/spec/inputs/numbers/6.txt +50 -0
  148. data/spec/inputs/numbers/7.txt +50 -0
  149. data/spec/inputs/numbers/8.txt +50 -0
  150. data/spec/inputs/numbers/9.txt +50 -0
  151. data/spec/inputs/numbers_0_100.txt +101 -0
  152. data/spec/inputs/numbers_1_100.txt +100 -0
  153. data/spec/lib/collect_spec.rb +42 -0
  154. data/spec/lib/command_spec.rb +68 -0
  155. data/spec/lib/config_spec.rb +64 -0
  156. data/spec/lib/context_spec.rb +165 -0
  157. data/spec/lib/ext_spec.rb +72 -0
  158. data/spec/lib/external_apps_spec.rb +45 -0
  159. data/spec/lib/filter_spec.rb +80 -0
  160. data/spec/lib/flat_map_spec.rb +100 -0
  161. data/spec/lib/group_spec.rb +109 -0
  162. data/spec/lib/helper_spec.rb +19 -0
  163. data/spec/lib/key_spec.rb +41 -0
  164. data/spec/lib/manipulation_spec.rb +122 -0
  165. data/spec/lib/map_partitions_spec.rb +87 -0
  166. data/spec/lib/map_spec.rb +91 -0
  167. data/spec/lib/mllib/classification_spec.rb +54 -0
  168. data/spec/lib/mllib/clustering_spec.rb +35 -0
  169. data/spec/lib/mllib/matrix_spec.rb +32 -0
  170. data/spec/lib/mllib/regression_spec.rb +116 -0
  171. data/spec/lib/mllib/vector_spec.rb +77 -0
  172. data/spec/lib/reduce_by_key_spec.rb +118 -0
  173. data/spec/lib/reduce_spec.rb +131 -0
  174. data/spec/lib/sample_spec.rb +46 -0
  175. data/spec/lib/serializer_spec.rb +88 -0
  176. data/spec/lib/sort_spec.rb +58 -0
  177. data/spec/lib/statistic_spec.rb +170 -0
  178. data/spec/lib/whole_text_files_spec.rb +33 -0
  179. data/spec/spec_helper.rb +38 -0
  180. metadata +389 -0
@@ -0,0 +1,69 @@
+ package org.apache.spark.api.ruby
+
+ import java.io._
+ import java.net._
+ import java.util.{List, ArrayList}
+
+ import scala.collection.JavaConversions._
+ import scala.collection.immutable._
+
+ import org.apache.spark._
+ import org.apache.spark.util.Utils
+
+ /**
+  * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
+  * collects a list of serialized byte strings that we pass to Ruby through a socket.
+  */
+ private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
+   extends AccumulatorParam[List[Array[Byte]]] {
+
+   // Utils.checkHost(serverHost, "Expected hostname")
+
+   val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+
+   // The socket must not be serialized.
+   // Otherwise: SparkException: Task not serializable
+   @transient var socket: Socket = null
+   @transient var socketOutputStream: DataOutputStream = null
+   @transient var socketInputStream: DataInputStream = null
+
+   def openSocket(){
+     synchronized {
+       if (socket == null || socket.isClosed) {
+         socket = new Socket(serverHost, serverPort)
+
+         socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
+         socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+       }
+     }
+   }
+
+   override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList
+
+   override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]) : List[Array[Byte]] = synchronized {
+     if (serverHost == null) {
+       // This happens on the worker node, where we just want to remember all the updates
+       val1.addAll(val2)
+       val1
+     } else {
+       // This happens on the master, where we pass the updates to Ruby through a socket
+       openSocket()
+
+       socketOutputStream.writeInt(val2.size)
+       for (array <- val2) {
+         socketOutputStream.writeInt(array.length)
+         socketOutputStream.write(array)
+       }
+       socketOutputStream.flush()
+
+       // Wait for acknowledgement
+       // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
+       //
+       // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
+       //   throw new SparkException("Accumulator was not acknowledged")
+       // }
+
+       new ArrayList
+     }
+   }
+ }
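
The accumulator updates above travel over the socket as a count followed by length-prefixed byte arrays. A minimal sketch of the matching read side, assuming the same DataInputStream framing (readUpdates is an illustrative name, not part of the gem):

    import java.io.DataInputStream

    // Illustrative helper: reads one batch of accumulator updates using the
    // framing produced by addInPlace above - an Int count, then for each
    // update an Int length followed by that many bytes.
    def readUpdates(in: DataInputStream): Seq[Array[Byte]] = {
      val count = in.readInt()
      (0 until count).map { _ =>
        val length = in.readInt()
        val bytes = new Array[Byte](length)
        in.readFully(bytes)
        bytes
      }
    }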
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ import org.apache.spark.api.python.PythonBroadcast
+
+ /**
+  * A wrapper for a Ruby broadcast variable, which is written to disk by Ruby. It also
+  * writes the data to disk after deserialization, so that Ruby can read it back.
+  *
+  * The class reuses the Python logic and exists only for semantic clarity.
+  */
+ class RubyBroadcast(@transient var _path: String, @transient var id: java.lang.Long) extends PythonBroadcast(_path) {
+
+ }
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ object RubyConstant {
+   val DATA_EOF = -2
+   val WORKER_ERROR = -1
+   val WORKER_DONE = 0
+   val CREATE_WORKER = 1
+   val KILL_WORKER = 2
+   val KILL_WORKER_AND_WAIT = 3
+   val SUCCESSFULLY_KILLED = 4
+   val UNSUCCESSFUL_KILLING = 5
+   val ACCUMULATOR_ACK = 6
+ }
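
These constants are the control words of the worker protocol, written as 4-byte integers on the socket, so positive values on the data stream can double as payload lengths. A small dispatch sketch mirroring how the StreamReader in RubyRDD.scala (later in this diff) interprets them (describeFrame is illustrative only):

    import java.io.DataInputStream

    // Illustrative only: positive values on the stream are data lengths,
    // non-positive values are the control flags defined in RubyConstant.
    def describeFrame(in: DataInputStream): String = in.readInt() match {
      case length if length > 0      => s"data frame of $length bytes"
      case RubyConstant.DATA_EOF     => "end of data"
      case RubyConstant.WORKER_DONE  => "worker finished"
      case RubyConstant.WORKER_ERROR => "worker raised an error"
      case other                     => s"control flag $other"
    }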
@@ -0,0 +1,55 @@
+ package org.apache.spark.mllib.api.ruby
+
+ import java.util.ArrayList
+
+ import scala.collection.JavaConverters._
+
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.api.java.JavaRDD
+ import org.apache.spark.mllib.linalg._
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.classification.NaiveBayes
+ import org.apache.spark.mllib.clustering.GaussianMixtureModel
+ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+ import org.apache.spark.mllib.api.python.MLLibAPI
+
+
+ class RubyMLLibAPI extends MLLibAPI {
+   // trainLinearRegressionModelWithSGD
+   // trainLassoModelWithSGD
+   // trainRidgeModelWithSGD
+   // trainLogisticRegressionModelWithSGD
+   // trainLogisticRegressionModelWithLBFGS
+   // trainSVMModelWithSGD
+   // trainKMeansModel
+   // trainGaussianMixture
+
+   // Rjb has a problem with theta: Array[Array[Double]]
+   override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
+     val model = NaiveBayes.train(data.rdd, lambda)
+
+     List(
+       Vectors.dense(model.labels),
+       Vectors.dense(model.pi),
+       model.theta.toSeq
+     ).map(_.asInstanceOf[Object]).asJava
+   }
+
+   // In the Python API, wt is just an Object
+   def predictSoftGMM(
+       data: JavaRDD[Vector],
+       wt: ArrayList[Object],
+       mu: ArrayList[Object],
+       si: ArrayList[Object]): RDD[Array[Double]] = {
+
+     // val weight = wt.asInstanceOf[Array[Double]]
+     val weight = wt.toArray.map(_.asInstanceOf[Double])
+     val mean = mu.toArray.map(_.asInstanceOf[DenseVector])
+     val sigma = si.toArray.map(_.asInstanceOf[DenseMatrix])
+     val gaussians = Array.tabulate(weight.length){
+       i => new MultivariateGaussian(mean(i), sigma(i))
+     }
+     val model = new GaussianMixtureModel(weight, gaussians)
+     model.predictSoft(data)
+   }
+ }
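
A sketch of how trainNaiveBayes might be invoked from Scala. It assumes an existing SparkContext named sc (not shown in this diff); the labels, pi, and theta components come back as a java.util.List of plain Objects so that Rjb/JRuby can unpack them:

    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Sketch only: `sc` is assumed to be an existing SparkContext.
    val points = Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
    )
    val data: JavaRDD[LabeledPoint] = JavaRDD.fromRDD(sc.parallelize(points))

    // Returns java.util.List(labels: Vector, pi: Vector, theta: Seq[Array[Double]])
    val result = new RubyMLLibAPI().trainNaiveBayes(data, 1.0)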
@@ -0,0 +1,21 @@
+ package org.apache.spark.mllib.api.ruby
+
+ import java.util.ArrayList
+
+ import org.apache.spark.mllib.util.LinearDataGenerator
+ import org.apache.spark.mllib.regression.LabeledPoint
+
+ object RubyMLLibUtilAPI {
+
+   // Ruby has trouble creating an Array[Double], so the weights arrive as strings
+   def generateLinearInput(
+       intercept: Double,
+       weights: ArrayList[String],
+       nPoints: Int,
+       seed: Int,
+       eps: Double = 0.1): Seq[LabeledPoint] = {
+
+     LinearDataGenerator.generateLinearInput(intercept, weights.toArray.map(_.toString.toDouble), nPoints, seed, eps)
+   }
+
+ }
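
A short usage sketch of generateLinearInput from Scala; the weights are deliberately passed as strings and converted back to doubles inside the helper:

    import java.util.ArrayList

    // Sketch only: generates 100 labeled points around y = 0.0 + 1.5*x1 - 0.5*x2.
    val weights = new ArrayList[String]()
    weights.add("1.5")
    weights.add("-0.5")
    val points = RubyMLLibUtilAPI.generateLinearInput(0.0, weights, 100, 42)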
@@ -0,0 +1,34 @@
+ package org.apache.spark.ui.ruby
+
+ // import javax.servlet.http.HttpServletRequest
+
+ // import scala.xml.Node
+
+ // import org.apache.spark.ui.{WebUIPage, UIUtils}
+ // import org.apache.spark.util.Utils
+
+ // private[ui] class RubyPage(parent: RubyTab, rbConfig: Array[Tuple2[String, String]]) extends WebUIPage("") {
+
+ //   def render(request: HttpServletRequest): Seq[Node] = {
+ //     val content = UIUtils.listingTable(header, row, rbConfig)
+ //     UIUtils.headerSparkPage("Ruby Config", content, parent)
+ //   }
+
+ //   private def header = Seq(
+ //     "Number"
+ //   )
+
+ //   private def row(keyValue: (String, String)): Seq[Node] = {
+ //     // scalastyle:off
+ //     keyValue match {
+ //       case (key, value) =>
+ //         <tr>
+ //           <td>{key}</td>
+ //           <td>{value}</td>
+ //         </tr>
+ //     }
+ //     // scalastyle:on
+ //   }
+ // }
+
+ class RubyPage {}
@@ -0,0 +1,392 @@
+ package org.apache.spark.api.ruby
+
+ import java.io._
+ import java.net._
+ import java.util.{List, ArrayList, Collections}
+
+ import scala.util.Try
+ import scala.reflect.ClassTag
+ import scala.collection.JavaConversions._
+
+ import org.apache.spark._
+ import org.apache.spark.{SparkEnv, Partition, SparkException, TaskContext}
+ import org.apache.spark.api.ruby._
+ import org.apache.spark.api.ruby.marshal._
+ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+ import org.apache.spark.api.python.PythonRDD
+ import org.apache.spark.broadcast.Broadcast
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.util.Utils
+ import org.apache.spark.InterruptibleIterator
+
+
+ /* =================================================================================================
+  * Class RubyRDD
+  * =================================================================================================
+  */
+
+ class RubyRDD(
+     @transient parent: RDD[_],
+     command: Array[Byte],
+     broadcastVars: ArrayList[Broadcast[RubyBroadcast]],
+     accumulator: Accumulator[List[Array[Byte]]])
+   extends RDD[Array[Byte]](parent){
+
+   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+
+   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+   override def getPartitions: Array[Partition] = firstParent.partitions
+
+   override val partitioner = None
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+
+     val env = SparkEnv.get
+
+     // Get a worker and its id
+     val (worker, workerId) = RubyWorker.create(env)
+
+     // Start a thread to feed the process input from our parent's iterator
+     val writerThread = new WriterThread(env, worker, split, context)
+
+     context.addTaskCompletionListener { context =>
+       writerThread.shutdownOnTaskCompletion()
+       writerThread.join()
+
+       // Clean up the worker socket. This will also cause the worker to exit.
+       try {
+         RubyWorker.remove(worker, workerId)
+         worker.close()
+       } catch {
+         case e: Exception => logWarning("Failed to close worker socket", e)
+       }
+     }
+
+     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+     // Send data
+     writerThread.start()
+
+     // For forceful termination of the worker
+     new MonitorThread(workerId, worker, context).start()
+
+     // Return an iterator that reads data from the worker's output stream
+     val stdoutIterator = new StreamReader(stream, writerThread, context)
+
+     // An iterator that wraps around an existing iterator to provide task killing functionality.
+     new InterruptibleIterator(context, stdoutIterator)
+
+   } // end compute
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+     extends Thread("stdout writer for worker") {
+
+     @volatile private var _exception: Exception = null
+
+     setDaemon(true)
+
+     // Contains the exception thrown while writing the parent iterator to the process.
+     def exception: Option[Exception] = Option(_exception)
+
+     // Terminates the writer thread, ignoring any exceptions that may occur due to cleanup.
+     def shutdownOnTaskCompletion() {
+       assert(context.isCompleted)
+       this.interrupt()
+     }
+
+     // -------------------------------------------------------------------------------------------
+     // Send the necessary data to the worker:
+     //  - split index
+     //  - command
+     //  - iterator
+
+     override def run(): Unit = Utils.logUncaughtExceptions {
+       try {
+         SparkEnv.set(env)
+         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+         val dataOut = new DataOutputStream(stream)
+
+         // Partition index
+         dataOut.writeInt(split.index)
+
+         // Spark files
+         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+
+         // Broadcast variables
+         dataOut.writeInt(broadcastVars.length)
+         for (broadcast <- broadcastVars) {
+           dataOut.writeLong(broadcast.value.id)
+           PythonRDD.writeUTF(broadcast.value.path, dataOut)
+         }
+
+         // Serialized command
+         dataOut.writeInt(command.length)
+         dataOut.write(command)
+
+         // Send it
+         dataOut.flush()
+
+         // Data
+         PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+         dataOut.writeInt(RubyConstant.DATA_EOF)
+         dataOut.flush()
+       } catch {
+         case e: Exception if context.isCompleted || context.isInterrupted =>
+           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+         case e: Exception =>
+           // We must avoid throwing exceptions here, because the thread uncaught exception handler
+           // will kill the whole executor (see org.apache.spark.executor.Executor).
+           _exception = e
+       } finally {
+         Try(worker.shutdownOutput()) // kill worker process
+       }
+     }
+   } // end WriterThread
+
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   class StreamReader(stream: DataInputStream, writerThread: WriterThread, context: TaskContext) extends Iterator[Array[Byte]] {
+
+     def hasNext = _nextObj != null
+     var _nextObj = read()
+
+     // -------------------------------------------------------------------------------------------
+
+     def next(): Array[Byte] = {
+       val obj = _nextObj
+       if (hasNext) {
+         _nextObj = read()
+       }
+       obj
+     }
+
+     // -------------------------------------------------------------------------------------------
+
+     private def read(): Array[Byte] = {
+       if (writerThread.exception.isDefined) {
+         throw writerThread.exception.get
+       }
+       try {
+         stream.readInt() match {
+           case length if length > 0 =>
+             val obj = new Array[Byte](length)
+             stream.readFully(obj)
+             obj
+           case RubyConstant.WORKER_DONE =>
+             val numAccumulatorUpdates = stream.readInt()
+             (1 to numAccumulatorUpdates).foreach { _ =>
+               val updateLen = stream.readInt()
+               val update = new Array[Byte](updateLen)
+               stream.readFully(update)
+               accumulator += Collections.singletonList(update)
+             }
+             null
+           case RubyConstant.WORKER_ERROR =>
+             // Exception from the worker
+
+             // message
+             val length = stream.readInt()
+             val obj = new Array[Byte](length)
+             stream.readFully(obj)
+
+             // stackTrace
+             val stackTraceLen = stream.readInt()
+             val stackTrace = new Array[String](stackTraceLen)
+             (0 until stackTraceLen).foreach { i =>
+               val length = stream.readInt()
+               val obj = new Array[Byte](length)
+               stream.readFully(obj)
+
+               stackTrace(i) = new String(obj, "utf-8")
+             }
+
+             // Worker will be killed
+             stream.close
+
+             // exception
+             val exception = new RubyException(new String(obj, "utf-8"), writerThread.exception.getOrElse(null))
+             exception.appendToStackTrace(stackTrace)
+
+             throw exception
+         }
+       } catch {
+
+         case e: Exception if context.isInterrupted =>
+           logDebug("Exception thrown after task interruption", e)
+           throw new TaskKilledException
+
+         case e: Exception if writerThread.exception.isDefined =>
+           logError("Worker exited unexpectedly (crashed)", e)
+           throw writerThread.exception.get
+
+         case eof: EOFException =>
+           throw new SparkException("Worker exited unexpectedly (crashed)", eof)
+       }
+     }
+   } // end StreamReader
+
+   /* ---------------------------------------------------------------------------------------------
+    * Monitor thread that controls the worker. Kills the worker if the task is interrupted.
+    */
+
+   class MonitorThread(workerId: Long, worker: Socket, context: TaskContext)
+     extends Thread("Worker Monitor for worker") {
+
+     setDaemon(true)
+
+     override def run() {
+       // Kill the worker if it is interrupted, checking until task completion.
+       while (!context.isInterrupted && !context.isCompleted) {
+         Thread.sleep(2000)
+       }
+       if (!context.isCompleted) {
+         try {
+           logWarning("Incomplete task interrupted: Attempting to kill Worker "+workerId.toString())
+           RubyWorker.kill(workerId)
+         } catch {
+           case e: Exception =>
+             logError("Exception when trying to kill worker "+workerId.toString(), e)
+         }
+       }
+     }
+   } // end MonitorThread
+ } // end RubyRDD
+
+
+
+ /* =================================================================================================
+  * Class PairwiseRDD
+  * =================================================================================================
+  *
+  * Form an RDD[(Long, Array[Byte])] from key-value pairs returned from Ruby.
+  * This is used by ruby-spark's shuffle operations.
+  * Borrowed from the Python package -> needs a new deserializeLongValue ->
+  * Marshal will add the same 4-byte header
+  */
+
+ class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) {
+   override def getPartitions = prev.partitions
+   override def compute(split: Partition, context: TaskContext) =
+     prev.iterator(split, context).grouped(2).map {
+       case Seq(a, b) => (Utils.deserializeLongValue(a.reverse), b)
+       case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
+     }
+   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
+ }
+
+
+
+ /* =================================================================================================
+  * Object RubyRDD
+  * =================================================================================================
+  */
+
+ object RubyRDD extends Logging {
+
+   def runJob(
+       sc: SparkContext,
+       rdd: JavaRDD[Array[Byte]],
+       partitions: ArrayList[Int],
+       allowLocal: Boolean,
+       filename: String): String = {
+     type ByteArray = Array[Byte]
+     type UnrolledPartition = Array[ByteArray]
+     val allPartitions: Array[UnrolledPartition] =
+       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+     writeRDDToFile(flattenedPartition.iterator, filename)
+   }
+
+   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = {
+     val file = new DataInputStream(new BufferedInputStream(new FileInputStream(filename)))
+     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+     try {
+       while (true) {
+         val length = file.readInt()
+         val obj = new Array[Byte](length)
+         file.readFully(obj)
+         objs.append(obj)
+       }
+     } catch {
+       case eof: EOFException => {}
+     }
+     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+   }
+
+   def writeRDDToFile[T](items: Iterator[T], filename: String): String = {
+     val file = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)))
+
+     try {
+       PythonRDD.writeIteratorToStream(items, file)
+     } finally {
+       file.close()
+     }
+
+     filename
+   }
+
+   def writeRDDToFile[T](rdd: RDD[T], filename: String): String = {
+     writeRDDToFile(rdd.collect.iterator, filename)
+   }
+
+   def readBroadcastFromFile(sc: JavaSparkContext, path: String, id: java.lang.Long): Broadcast[RubyBroadcast] = {
+     sc.broadcast(new RubyBroadcast(path, id))
+   }
+
+   /**
+    * Convert an RDD of serialized Ruby objects to an RDD of objects that are usable in Java.
+    */
+   def toJava(rbRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+     rbRDD.rdd.mapPartitions { iter =>
+       iter.flatMap { item =>
+         val obj = Marshal.load(item)
+         if(batched){
+           obj.asInstanceOf[Array[_]]
+         }
+         else{
+           // Non-batched: yield the single deserialized object
+           Seq(obj)
+         }
+       }
+     }.toJavaRDD()
+   }
+
+   /**
+    * Convert an RDD of Java objects to an RDD of serialized Ruby objects that are usable by Ruby.
+    */
+   def toRuby(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+     jRDD.rdd.mapPartitions { iter => new IterableMarshaller(iter) }
+   }
+
+ }
+
+
+
+ /* =================================================================================================
+  * Class RubyException
+  * =================================================================================================
+  */
+
+ class RubyException(msg: String, cause: Exception) extends RuntimeException(msg, cause) {
+   def appendToStackTrace(toAdded: Array[String]) {
+     val newStackTrace = getStackTrace.toBuffer
+
+     val regexpMatch = "(.*):([0-9]+):in `([a-z]+)'".r
+
+     for(item <- toAdded) {
+       item match {
+         case regexpMatch(fileName, lineNumber, methodName) =>
+           newStackTrace += new StackTraceElement("RubyWorker", methodName, fileName, lineNumber.toInt)
+         case _ => null
+       }
+     }
+
+     setStackTrace(newStackTrace.toArray)
+   }
+ }
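
The appendToStackTrace regex converts Ruby backtrace entries into JVM StackTraceElements, so a worker failure surfaces as a single mixed Ruby/Scala trace. A small illustration (the file path shown is only an example):

    // Sketch only: a Ruby backtrace line such as
    //   lib/spark/worker/worker.rb:95:in `compute'
    // matches the regex above and is appended as
    //   RubyWorker.compute(lib/spark/worker/worker.rb:95)
    val e = new RubyException("worker failed", null)
    e.appendToStackTrace(Array("lib/spark/worker/worker.rb:95:in `compute'"))
    e.getStackTrace.foreach(println)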