ruby-spark 1.1.0.1-java

Files changed (180)
  1. checksums.yaml +7 -0
  2. data/.gitignore +37 -0
  3. data/Gemfile +47 -0
  4. data/Guardfile +5 -0
  5. data/LICENSE.txt +22 -0
  6. data/README.md +252 -0
  7. data/Rakefile +35 -0
  8. data/TODO.md +6 -0
  9. data/benchmark/aggregate.rb +33 -0
  10. data/benchmark/bisect.rb +88 -0
  11. data/benchmark/comparison/prepare.sh +18 -0
  12. data/benchmark/comparison/python.py +156 -0
  13. data/benchmark/comparison/r.r +69 -0
  14. data/benchmark/comparison/ruby.rb +167 -0
  15. data/benchmark/comparison/run-all.sh +160 -0
  16. data/benchmark/comparison/scala.scala +181 -0
  17. data/benchmark/custom_marshal.rb +94 -0
  18. data/benchmark/digest.rb +150 -0
  19. data/benchmark/enumerator.rb +88 -0
  20. data/benchmark/serializer.rb +82 -0
  21. data/benchmark/sort.rb +43 -0
  22. data/benchmark/sort2.rb +164 -0
  23. data/benchmark/take.rb +28 -0
  24. data/bin/ruby-spark +8 -0
  25. data/example/pi.rb +28 -0
  26. data/example/website_search.rb +83 -0
  27. data/ext/ruby_c/extconf.rb +3 -0
  28. data/ext/ruby_c/murmur.c +158 -0
  29. data/ext/ruby_c/murmur.h +9 -0
  30. data/ext/ruby_c/ruby-spark.c +18 -0
  31. data/ext/ruby_java/Digest.java +36 -0
  32. data/ext/ruby_java/Murmur2.java +98 -0
  33. data/ext/ruby_java/RubySparkExtService.java +28 -0
  34. data/ext/ruby_java/extconf.rb +3 -0
  35. data/ext/spark/build.sbt +73 -0
  36. data/ext/spark/project/plugins.sbt +9 -0
  37. data/ext/spark/sbt/sbt +34 -0
  38. data/ext/spark/src/main/scala/Exec.scala +91 -0
  39. data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
  40. data/ext/spark/src/main/scala/Marshal.scala +52 -0
  41. data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
  42. data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
  43. data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
  44. data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
  45. data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
  46. data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
  47. data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
  48. data/ext/spark/src/main/scala/RubyPage.scala +34 -0
  49. data/ext/spark/src/main/scala/RubyRDD.scala +392 -0
  50. data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
  51. data/ext/spark/src/main/scala/RubyTab.scala +11 -0
  52. data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
  53. data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
  54. data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
  55. data/lib/ruby-spark.rb +1 -0
  56. data/lib/spark.rb +198 -0
  57. data/lib/spark/accumulator.rb +260 -0
  58. data/lib/spark/broadcast.rb +98 -0
  59. data/lib/spark/build.rb +43 -0
  60. data/lib/spark/cli.rb +169 -0
  61. data/lib/spark/command.rb +86 -0
  62. data/lib/spark/command/base.rb +158 -0
  63. data/lib/spark/command/basic.rb +345 -0
  64. data/lib/spark/command/pair.rb +124 -0
  65. data/lib/spark/command/sort.rb +51 -0
  66. data/lib/spark/command/statistic.rb +144 -0
  67. data/lib/spark/command_builder.rb +141 -0
  68. data/lib/spark/command_validator.rb +34 -0
  69. data/lib/spark/config.rb +238 -0
  70. data/lib/spark/constant.rb +14 -0
  71. data/lib/spark/context.rb +322 -0
  72. data/lib/spark/error.rb +50 -0
  73. data/lib/spark/ext/hash.rb +41 -0
  74. data/lib/spark/ext/integer.rb +25 -0
  75. data/lib/spark/ext/io.rb +67 -0
  76. data/lib/spark/ext/ip_socket.rb +29 -0
  77. data/lib/spark/ext/module.rb +58 -0
  78. data/lib/spark/ext/object.rb +24 -0
  79. data/lib/spark/ext/string.rb +24 -0
  80. data/lib/spark/helper.rb +10 -0
  81. data/lib/spark/helper/logger.rb +40 -0
  82. data/lib/spark/helper/parser.rb +85 -0
  83. data/lib/spark/helper/serialize.rb +71 -0
  84. data/lib/spark/helper/statistic.rb +93 -0
  85. data/lib/spark/helper/system.rb +42 -0
  86. data/lib/spark/java_bridge.rb +19 -0
  87. data/lib/spark/java_bridge/base.rb +203 -0
  88. data/lib/spark/java_bridge/jruby.rb +23 -0
  89. data/lib/spark/java_bridge/rjb.rb +41 -0
  90. data/lib/spark/logger.rb +76 -0
  91. data/lib/spark/mllib.rb +100 -0
  92. data/lib/spark/mllib/classification/common.rb +31 -0
  93. data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
  94. data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
  95. data/lib/spark/mllib/classification/svm.rb +135 -0
  96. data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
  97. data/lib/spark/mllib/clustering/kmeans.rb +118 -0
  98. data/lib/spark/mllib/matrix.rb +120 -0
  99. data/lib/spark/mllib/regression/common.rb +73 -0
  100. data/lib/spark/mllib/regression/labeled_point.rb +41 -0
  101. data/lib/spark/mllib/regression/lasso.rb +100 -0
  102. data/lib/spark/mllib/regression/linear.rb +124 -0
  103. data/lib/spark/mllib/regression/ridge.rb +97 -0
  104. data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
  105. data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
  106. data/lib/spark/mllib/stat/distribution.rb +12 -0
  107. data/lib/spark/mllib/vector.rb +185 -0
  108. data/lib/spark/rdd.rb +1377 -0
  109. data/lib/spark/sampler.rb +92 -0
  110. data/lib/spark/serializer.rb +79 -0
  111. data/lib/spark/serializer/auto_batched.rb +59 -0
  112. data/lib/spark/serializer/base.rb +63 -0
  113. data/lib/spark/serializer/batched.rb +84 -0
  114. data/lib/spark/serializer/cartesian.rb +13 -0
  115. data/lib/spark/serializer/compressed.rb +27 -0
  116. data/lib/spark/serializer/marshal.rb +17 -0
  117. data/lib/spark/serializer/message_pack.rb +23 -0
  118. data/lib/spark/serializer/oj.rb +23 -0
  119. data/lib/spark/serializer/pair.rb +41 -0
  120. data/lib/spark/serializer/text.rb +25 -0
  121. data/lib/spark/sort.rb +189 -0
  122. data/lib/spark/stat_counter.rb +125 -0
  123. data/lib/spark/storage_level.rb +39 -0
  124. data/lib/spark/version.rb +3 -0
  125. data/lib/spark/worker/master.rb +144 -0
  126. data/lib/spark/worker/spark_files.rb +15 -0
  127. data/lib/spark/worker/worker.rb +200 -0
  128. data/ruby-spark.gemspec +47 -0
  129. data/spec/generator.rb +37 -0
  130. data/spec/inputs/lorem_300.txt +316 -0
  131. data/spec/inputs/numbers/1.txt +50 -0
  132. data/spec/inputs/numbers/10.txt +50 -0
  133. data/spec/inputs/numbers/11.txt +50 -0
  134. data/spec/inputs/numbers/12.txt +50 -0
  135. data/spec/inputs/numbers/13.txt +50 -0
  136. data/spec/inputs/numbers/14.txt +50 -0
  137. data/spec/inputs/numbers/15.txt +50 -0
  138. data/spec/inputs/numbers/16.txt +50 -0
  139. data/spec/inputs/numbers/17.txt +50 -0
  140. data/spec/inputs/numbers/18.txt +50 -0
  141. data/spec/inputs/numbers/19.txt +50 -0
  142. data/spec/inputs/numbers/2.txt +50 -0
  143. data/spec/inputs/numbers/20.txt +50 -0
  144. data/spec/inputs/numbers/3.txt +50 -0
  145. data/spec/inputs/numbers/4.txt +50 -0
  146. data/spec/inputs/numbers/5.txt +50 -0
  147. data/spec/inputs/numbers/6.txt +50 -0
  148. data/spec/inputs/numbers/7.txt +50 -0
  149. data/spec/inputs/numbers/8.txt +50 -0
  150. data/spec/inputs/numbers/9.txt +50 -0
  151. data/spec/inputs/numbers_0_100.txt +101 -0
  152. data/spec/inputs/numbers_1_100.txt +100 -0
  153. data/spec/lib/collect_spec.rb +42 -0
  154. data/spec/lib/command_spec.rb +68 -0
  155. data/spec/lib/config_spec.rb +64 -0
  156. data/spec/lib/context_spec.rb +165 -0
  157. data/spec/lib/ext_spec.rb +72 -0
  158. data/spec/lib/external_apps_spec.rb +45 -0
  159. data/spec/lib/filter_spec.rb +80 -0
  160. data/spec/lib/flat_map_spec.rb +100 -0
  161. data/spec/lib/group_spec.rb +109 -0
  162. data/spec/lib/helper_spec.rb +19 -0
  163. data/spec/lib/key_spec.rb +41 -0
  164. data/spec/lib/manipulation_spec.rb +122 -0
  165. data/spec/lib/map_partitions_spec.rb +87 -0
  166. data/spec/lib/map_spec.rb +91 -0
  167. data/spec/lib/mllib/classification_spec.rb +54 -0
  168. data/spec/lib/mllib/clustering_spec.rb +35 -0
  169. data/spec/lib/mllib/matrix_spec.rb +32 -0
  170. data/spec/lib/mllib/regression_spec.rb +116 -0
  171. data/spec/lib/mllib/vector_spec.rb +77 -0
  172. data/spec/lib/reduce_by_key_spec.rb +118 -0
  173. data/spec/lib/reduce_spec.rb +131 -0
  174. data/spec/lib/sample_spec.rb +46 -0
  175. data/spec/lib/serializer_spec.rb +88 -0
  176. data/spec/lib/sort_spec.rb +58 -0
  177. data/spec/lib/statistic_spec.rb +170 -0
  178. data/spec/lib/whole_text_files_spec.rb +33 -0
  179. data/spec/spec_helper.rb +38 -0
  180. metadata +389 -0
data/ext/spark/src/main/scala/RubyAccumulatorParam.scala
@@ -0,0 +1,69 @@
+ package org.apache.spark.api.ruby
+
+ import java.io._
+ import java.net._
+ import java.util.{List, ArrayList}
+
+ import scala.collection.JavaConversions._
+ import scala.collection.immutable._
+
+ import org.apache.spark._
+ import org.apache.spark.util.Utils
+
+ /**
+  * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
+  * collects a list of marshaled byte arrays that are passed to Ruby through a socket.
+  */
+ private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
+   extends AccumulatorParam[List[Array[Byte]]] {
+
+   // Utils.checkHost(serverHost, "Expected hostname")
+
+   val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+
+   // The socket must not be serialized.
+   // Otherwise: SparkException: Task not serializable
+   @transient var socket: Socket = null
+   @transient var socketOutputStream: DataOutputStream = null
+   @transient var socketInputStream: DataInputStream = null
+
+   def openSocket() {
+     synchronized {
+       if (socket == null || socket.isClosed) {
+         socket = new Socket(serverHost, serverPort)
+
+         socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
+         socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+       }
+     }
+   }
+
+   override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList
+
+   override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]): List[Array[Byte]] = synchronized {
+     if (serverHost == null) {
+       // This happens on the worker node, where we just want to remember all the updates
+       val1.addAll(val2)
+       val1
+     } else {
+       // This happens on the master, where we pass the updates to Ruby through a socket
+       openSocket()
+
+       socketOutputStream.writeInt(val2.size)
+       for (array <- val2) {
+         socketOutputStream.writeInt(array.length)
+         socketOutputStream.write(array)
+       }
+       socketOutputStream.flush()
+
+       // Wait for acknowledgement
+       // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
+       //
+       // if (in.readInt() != RubyConstant.ACCUMULATOR_ACK) {
+       //   throw new SparkException("Accumulator was not acknowledged")
+       // }
+
+       new ArrayList
+     }
+   }
+ }
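
The accumulator updates travel over a tiny length-prefixed protocol: a count, then one length-prefixed byte array per update. Below is a minimal sketch of the receiving side of that framing, which the driver-side Ruby server must read exactly; the names are illustrative and not part of the gem.

    import java.io.DataInputStream

    // Hypothetical receiving end of the accumulator protocol above: a count,
    // then `count` length-prefixed marshaled updates.
    object AccumulatorServerSketch {
      def readUpdates(in: DataInputStream): Seq[Array[Byte]] = {
        val count = in.readInt()      // number of updates in this batch
        (1 to count).map { _ =>
          val length = in.readInt()   // size of the next update
          val bytes = new Array[Byte](length)
          in.readFully(bytes)         // the marshaled payload
          bytes
        }
      }
    }
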
data/ext/spark/src/main/scala/RubyBroadcast.scala
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ import org.apache.spark.api.python.PythonBroadcast
+
+ /**
+  * A wrapper for a Ruby broadcast variable, which is written to disk by Ruby. After
+  * deserialization it writes the data back to disk, so that Ruby can read it again.
+  *
+  * The class reuses the Python logic; it exists only to give the Ruby side its own type.
+  */
+ class RubyBroadcast(@transient var _path: String, @transient var id: java.lang.Long) extends PythonBroadcast(_path) {
+
+ }
data/ext/spark/src/main/scala/RubyConstant.scala
@@ -0,0 +1,13 @@
+ package org.apache.spark.api.ruby
+
+ object RubyConstant {
+   val DATA_EOF = -2
+   val WORKER_ERROR = -1
+   val WORKER_DONE = 0
+   val CREATE_WORKER = 1
+   val KILL_WORKER = 2
+   val KILL_WORKER_AND_WAIT = 3
+   val SUCCESSFULLY_KILLED = 4
+   val UNSUCCESSFUL_KILLING = 5
+   val ACCUMULATOR_ACK = 6
+ }
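
These constants share the integer channel that otherwise carries payload lengths: a positive readInt() announces a payload, zero or negative values are control codes. A sketch of that dispatch for the worker's output stream, assuming the same layout that StreamReader in RubyRDD.scala (below) consumes:

    import java.io.DataInputStream
    import org.apache.spark.api.ruby.RubyConstant

    // Illustrative only; the real dispatch is StreamReader in RubyRDD.scala.
    object FrameDispatchSketch {
      def readFrame(in: DataInputStream): Option[Array[Byte]] = in.readInt() match {
        case length if length > 0 =>
          val payload = new Array[Byte](length)
          in.readFully(payload)
          Some(payload)                         // a data record
        case RubyConstant.WORKER_DONE => None   // no more data; accumulator updates follow
        case RubyConstant.WORKER_ERROR =>
          throw new RuntimeException("worker failed") // message and stack trace follow
        case other =>
          throw new RuntimeException("unexpected control code: " + other)
      }
    }
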
data/ext/spark/src/main/scala/RubyMLLibAPI.scala
@@ -0,0 +1,55 @@
+ package org.apache.spark.mllib.api.ruby
+
+ import java.util.ArrayList
+
+ import scala.collection.JavaConverters._
+
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.api.java.JavaRDD
+ import org.apache.spark.mllib.linalg._
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.classification.NaiveBayes
+ import org.apache.spark.mllib.clustering.GaussianMixtureModel
+ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+ import org.apache.spark.mllib.api.python.MLLibAPI
+
+ // The train* methods listed below are inherited unchanged from MLLibAPI.
+ class RubyMLLibAPI extends MLLibAPI {
+   // trainLinearRegressionModelWithSGD
+   // trainLassoModelWithSGD
+   // trainRidgeModelWithSGD
+   // trainLogisticRegressionModelWithSGD
+   // trainLogisticRegressionModelWithLBFGS
+   // trainSVMModelWithSGD
+   // trainKMeansModel
+   // trainGaussianMixture
+
+   // Rjb has a problem with theta: Array[Array[Double]]
+   override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
+     val model = NaiveBayes.train(data.rdd, lambda)
+
+     List(
+       Vectors.dense(model.labels),
+       Vectors.dense(model.pi),
+       model.theta.toSeq
+     ).map(_.asInstanceOf[Object]).asJava
+   }
+
+   // In Python, wt is just an Object
+   def predictSoftGMM(
+       data: JavaRDD[Vector],
+       wt: ArrayList[Object],
+       mu: ArrayList[Object],
+       si: ArrayList[Object]): RDD[Array[Double]] = {
+
+     // val weight = wt.asInstanceOf[Array[Double]]
+     val weight = wt.toArray.map(_.asInstanceOf[Double])
+     val mean = mu.toArray.map(_.asInstanceOf[DenseVector])
+     val sigma = si.toArray.map(_.asInstanceOf[DenseMatrix])
+     val gaussians = Array.tabulate(weight.length) {
+       i => new MultivariateGaussian(mean(i), sigma(i))
+     }
+     val model = new GaussianMixtureModel(weight, gaussians)
+     model.predictSoft(data)
+   }
+ }
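
trainNaiveBayes flattens the trained model into a bridge-friendly java.util.List of three elements so that Rjb never sees an Array[Array[Double]]. A hedged sketch of how a JVM-side caller would unpack that triple; the helper is hypothetical, since the real unpacking happens on the Ruby side:

    import org.apache.spark.mllib.linalg.DenseVector

    // Hypothetical JVM-side unpacking of the triple returned by trainNaiveBayes.
    object NaiveBayesUnpackSketch {
      def unpack(parts: java.util.List[Object]): (Array[Double], Array[Double], Array[Array[Double]]) = {
        val labels = parts.get(0).asInstanceOf[DenseVector].toArray        // class labels
        val pi     = parts.get(1).asInstanceOf[DenseVector].toArray        // log class priors
        val theta  = parts.get(2).asInstanceOf[Seq[Array[Double]]].toArray // per-class log feature probabilities
        (labels, pi, theta)
      }
    }
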
data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala
@@ -0,0 +1,21 @@
+ package org.apache.spark.mllib.api.ruby
+
+ import java.util.ArrayList
+
+ import org.apache.spark.mllib.util.LinearDataGenerator
+ import org.apache.spark.mllib.regression.LabeledPoint
+
+ object RubyMLLibUtilAPI {
+
+   // Ruby has trouble creating an Array[Double], so the weights arrive as strings.
+   def generateLinearInput(
+       intercept: Double,
+       weights: ArrayList[String],
+       nPoints: Int,
+       seed: Int,
+       eps: Double = 0.1): Seq[LabeledPoint] = {
+
+     LinearDataGenerator.generateLinearInput(intercept, weights.toArray.map(_.toString.toDouble), nPoints, seed, eps)
+   }
+
+ }
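
A hedged usage sketch: generating 100 labeled points around y = 0.5 + 2.0*x1 - 1.0*x2 with the default noise eps = 0.1. The string-typed weights are the workaround for the Array[Double] bridging problem noted above.

    import java.util.ArrayList
    import org.apache.spark.mllib.api.ruby.RubyMLLibUtilAPI

    object GenerateLinearInputSketch extends App {
      val weights = new ArrayList[String]()
      weights.add("2.0")   // coefficient of x1
      weights.add("-1.0")  // coefficient of x2
      val points = RubyMLLibUtilAPI.generateLinearInput(0.5, weights, 100, seed = 42)
      points.take(3).foreach(println)
    }
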
data/ext/spark/src/main/scala/RubyPage.scala
@@ -0,0 +1,34 @@
+ package org.apache.spark.ui.ruby
+
+ // import javax.servlet.http.HttpServletRequest
+
+ // import scala.xml.Node
+
+ // import org.apache.spark.ui.{WebUIPage, UIUtils}
+ // import org.apache.spark.util.Utils
+
+ // private[ui] class RubyPage(parent: RubyTab, rbConfig: Array[Tuple2[String, String]]) extends WebUIPage("") {
+
+ //   def render(request: HttpServletRequest): Seq[Node] = {
+ //     val content = UIUtils.listingTable(header, row, rbConfig)
+ //     UIUtils.headerSparkPage("Ruby Config", content, parent)
+ //   }
+
+ //   private def header = Seq(
+ //     "Number"
+ //   )
+
+ //   private def row(keyValue: (String, String)): Seq[Node] = {
+ //     // scalastyle:off
+ //     keyValue match {
+ //       case (key, value) =>
+ //         <tr>
+ //           <td>{key}</td>
+ //           <td>{value}</td>
+ //         </tr>
+ //     }
+ //     // scalastyle:on
+ //   }
+ // }
+
+ class RubyPage {} // stub: the Ruby config UI page above is disabled for now
data/ext/spark/src/main/scala/RubyRDD.scala
@@ -0,0 +1,392 @@
+ package org.apache.spark.api.ruby
+
+ import java.io._
+ import java.net._
+ import java.util.{List, ArrayList, Collections}
+
+ import scala.util.Try
+ import scala.reflect.ClassTag
+ import scala.collection.JavaConversions._
+
+ import org.apache.spark._
+ import org.apache.spark.{SparkEnv, Partition, SparkException, TaskContext}
+ import org.apache.spark.api.ruby._
+ import org.apache.spark.api.ruby.marshal._
+ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+ import org.apache.spark.api.python.PythonRDD
+ import org.apache.spark.broadcast.Broadcast
+ import org.apache.spark.rdd.RDD
+ import org.apache.spark.util.Utils
+ import org.apache.spark.InterruptibleIterator
+
+
+ /* =================================================================================================
+  * Class RubyRDD
+  * =================================================================================================
+  */
+
+ class RubyRDD(
+     @transient parent: RDD[_],
+     command: Array[Byte],
+     broadcastVars: ArrayList[Broadcast[RubyBroadcast]],
+     accumulator: Accumulator[List[Array[Byte]]])
+   extends RDD[Array[Byte]](parent) {
+
+   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+
+   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+   override def getPartitions: Array[Partition] = firstParent.partitions
+
+   override val partitioner = None
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+
+     val env = SparkEnv.get
+
+     // Get a worker and its id
+     val (worker, workerId) = RubyWorker.create(env)
+
+     // Start a thread to feed the process input from our parent's iterator
+     val writerThread = new WriterThread(env, worker, split, context)
+
+     context.addTaskCompletionListener { context =>
+       writerThread.shutdownOnTaskCompletion()
+       writerThread.join()
+
+       // Clean up the worker socket. This will also cause the worker to exit.
+       try {
+         RubyWorker.remove(worker, workerId)
+         worker.close()
+       } catch {
+         case e: Exception => logWarning("Failed to close worker socket", e)
+       }
+     }
+
+     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+     // Send data
+     writerThread.start()
+
+     // For forceful termination of the worker
+     new MonitorThread(workerId, worker, context).start()
+
+     // Return an iterator that reads records from the process's stdout
+     val stdoutIterator = new StreamReader(stream, writerThread, context)
+
+     // An iterator that wraps around an existing iterator to provide task killing functionality.
+     new InterruptibleIterator(context, stdoutIterator)
+
+   } // end compute
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+     extends Thread("stdout writer for worker") {
+
+     @volatile private var _exception: Exception = null
+
+     setDaemon(true)
+
+     // Contains the exception thrown while writing the parent iterator to the process.
+     def exception: Option[Exception] = Option(_exception)
+
+     // Terminates the writer thread, ignoring any exceptions that may occur due to cleanup.
+     def shutdownOnTaskCompletion() {
+       assert(context.isCompleted)
+       this.interrupt()
+     }
+
+     // -------------------------------------------------------------------------------------------
+     // Send the necessary data to the worker:
+     // - split index
+     // - command
+     // - iterator
+
+     override def run(): Unit = Utils.logUncaughtExceptions {
+       try {
+         SparkEnv.set(env)
+         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+         val dataOut = new DataOutputStream(stream)
+
+         // Partition index
+         dataOut.writeInt(split.index)
+
+         // Spark files
+         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+
+         // Broadcast variables
+         dataOut.writeInt(broadcastVars.length)
+         for (broadcast <- broadcastVars) {
+           dataOut.writeLong(broadcast.value.id)
+           PythonRDD.writeUTF(broadcast.value.path, dataOut)
+         }
+
+         // Serialized command
+         dataOut.writeInt(command.length)
+         dataOut.write(command)
+
+         // Send it
+         dataOut.flush()
+
+         // Data
+         PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+         dataOut.writeInt(RubyConstant.DATA_EOF)
+         dataOut.flush()
+       } catch {
+         case e: Exception if context.isCompleted || context.isInterrupted =>
+           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+         case e: Exception =>
+           // We must avoid throwing exceptions here, because the thread uncaught exception handler
+           // will kill the whole executor (see org.apache.spark.executor.Executor).
+           _exception = e
+       } finally {
+         Try(worker.shutdownOutput()) // kill the worker process
+       }
+     }
+   } // end WriterThread
+
+
+   /* ------------------------------------------------------------------------------------------ */
+
+   class StreamReader(stream: DataInputStream, writerThread: WriterThread, context: TaskContext) extends Iterator[Array[Byte]] {
+
+     def hasNext = _nextObj != null
+     var _nextObj = read()
+
+     // -------------------------------------------------------------------------------------------
+
+     def next(): Array[Byte] = {
+       val obj = _nextObj
+       if (hasNext) {
+         _nextObj = read()
+       }
+       obj
+     }
+
+     // -------------------------------------------------------------------------------------------
+
+     private def read(): Array[Byte] = {
+       if (writerThread.exception.isDefined) {
+         throw writerThread.exception.get
+       }
+       try {
+         stream.readInt() match {
+           case length if length > 0 =>
+             val obj = new Array[Byte](length)
+             stream.readFully(obj)
+             obj
+           case RubyConstant.WORKER_DONE =>
+             val numAccumulatorUpdates = stream.readInt()
+             (1 to numAccumulatorUpdates).foreach { _ =>
+               val updateLen = stream.readInt()
+               val update = new Array[Byte](updateLen)
+               stream.readFully(update)
+               accumulator += Collections.singletonList(update)
+             }
+             null
+           case RubyConstant.WORKER_ERROR =>
+             // Exception from the worker
+
+             // Message
+             val length = stream.readInt()
+             val obj = new Array[Byte](length)
+             stream.readFully(obj)
+
+             // Stack trace
+             val stackTraceLen = stream.readInt()
+             val stackTrace = new Array[String](stackTraceLen)
+             (0 until stackTraceLen).foreach { i =>
+               val length = stream.readInt()
+               val obj = new Array[Byte](length)
+               stream.readFully(obj)
+
+               stackTrace(i) = new String(obj, "utf-8")
+             }
+
+             // The worker will be killed
+             stream.close()
+
+             // Exception
+             val exception = new RubyException(new String(obj, "utf-8"), writerThread.exception.getOrElse(null))
+             exception.appendToStackTrace(stackTrace)
+
+             throw exception
+         }
+       } catch {
+
+         case e: Exception if context.isInterrupted =>
+           logDebug("Exception thrown after task interruption", e)
+           throw new TaskKilledException
+
+         case e: Exception if writerThread.exception.isDefined =>
+           logError("Worker exited unexpectedly (crashed)", e)
+           throw writerThread.exception.get
+
+         case eof: EOFException =>
+           throw new SparkException("Worker exited unexpectedly (crashed)", eof)
+       }
+     }
+   } // end StreamReader
+
+   /* ---------------------------------------------------------------------------------------------
+    * Monitor thread that controls the worker. Kills the worker if the task is interrupted.
+    */
+
+   class MonitorThread(workerId: Long, worker: Socket, context: TaskContext)
+     extends Thread("Worker Monitor for worker") {
+
+     setDaemon(true)
+
+     override def run() {
+       // Kill the worker if it is interrupted, checking until task completion.
+       while (!context.isInterrupted && !context.isCompleted) {
+         Thread.sleep(2000)
+       }
+       if (!context.isCompleted) {
+         try {
+           logWarning("Incomplete task interrupted: Attempting to kill Worker " + workerId.toString())
+           RubyWorker.kill(workerId)
+         } catch {
+           case e: Exception =>
+             logError("Exception when trying to kill worker " + workerId.toString(), e)
+         }
+       }
+     }
+   } // end MonitorThread
+ } // end RubyRDD
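
Taken together, WriterThread and StreamReader define the per-task wire protocol between executor and Ruby worker. The sketch below shows the read sequence the worker performs on its end, mirroring WriterThread.run; readTaskPreamble and readUTF are illustrative stand-ins (the real worker lives in data/lib/spark/worker/worker.rb, not shown here).

    import java.io.DataInputStream

    object WorkerHandshakeSketch {
      def readTaskPreamble(in: DataInputStream): Unit = {
        val splitIndex = in.readInt()       // partition index
        val sparkFilesDir = readUTF(in)     // SparkFiles root directory
        val broadcastCount = in.readInt()   // number of broadcast variables
        for (_ <- 1 to broadcastCount) {
          val broadcastId = in.readLong()   // broadcast id
          val broadcastPath = readUTF(in)   // file the driver wrote for it
        }
        val command = new Array[Byte](in.readInt())
        in.readFully(command)               // serialized Ruby command
        // ...then length-prefixed data records until RubyConstant.DATA_EOF (-2)
      }

      // PythonRDD.writeUTF frames a string as a 4-byte length plus UTF-8 bytes.
      private def readUTF(in: DataInputStream): String = {
        val bytes = new Array[Byte](in.readInt())
        in.readFully(bytes)
        new String(bytes, "utf-8")
      }
    }
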
+
+
+
+ /* =================================================================================================
+  * Class PairwiseRDD
+  * =================================================================================================
+  *
+  * Forms an RDD[(Long, Array[Byte])] from key-value pairs returned from Ruby.
+  * This is used by the Ruby-side shuffle operations. Borrowed from the Python
+  * package, but with its own key deserialization: using Marshal would add its
+  * 4-byte header, so keys arrive as raw packed longs and are byte-reversed
+  * before Utils.deserializeLongValue.
+  */
+
+ class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) {
+   override def getPartitions = prev.partitions
+   override def compute(split: Partition, context: TaskContext) =
+     prev.iterator(split, context).grouped(2).map {
+       case Seq(a, b) => (Utils.deserializeLongValue(a.reverse), b)
+       case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
+     }
+   val asJavaPairRDD: JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
+ }
+
+
+
+ /* =================================================================================================
+  * Object RubyRDD
+  * =================================================================================================
+  */
+
+ object RubyRDD extends Logging {
+
+   def runJob(
+       sc: SparkContext,
+       rdd: JavaRDD[Array[Byte]],
+       partitions: ArrayList[Int],
+       allowLocal: Boolean,
+       filename: String): String = {
+     type ByteArray = Array[Byte]
+     type UnrolledPartition = Array[ByteArray]
+     val allPartitions: Array[UnrolledPartition] =
+       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+     writeRDDToFile(flattenedPartition.iterator, filename)
+   }
+
+   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = {
+     val file = new DataInputStream(new BufferedInputStream(new FileInputStream(filename)))
+     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+     try {
+       while (true) {
+         val length = file.readInt()
+         val obj = new Array[Byte](length)
+         file.readFully(obj)
+         objs.append(obj)
+       }
+     } catch {
+       case eof: EOFException => // expected: end of file reached
+     }
+     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+   }
+
+   def writeRDDToFile[T](items: Iterator[T], filename: String): String = {
+     val file = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)))
+
+     try {
+       PythonRDD.writeIteratorToStream(items, file)
+     } finally {
+       file.close()
+     }
+
+     filename
+   }
+
+   def writeRDDToFile[T](rdd: RDD[T], filename: String): String = {
+     writeRDDToFile(rdd.collect.iterator, filename)
+   }
+
+   def readBroadcastFromFile(sc: JavaSparkContext, path: String, id: java.lang.Long): Broadcast[RubyBroadcast] = {
+     sc.broadcast(new RubyBroadcast(path, id))
+   }
+
+   /**
+    * Convert an RDD of serialized Ruby objects to an RDD of objects usable in Java.
+    */
+   def toJava(rbRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+     rbRDD.rdd.mapPartitions { iter =>
+       iter.flatMap { item =>
+         val obj = Marshal.load(item)
+         if (batched) {
+           obj.asInstanceOf[Array[_]]
+         } else {
+           Seq(obj)
+         }
+       }
+     }.toJavaRDD()
+   }
+
+   /**
+    * Convert an RDD of Java objects to an RDD of serialized Ruby objects usable by Ruby.
+    */
+   def toRuby(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+     jRDD.rdd.mapPartitions { iter => new IterableMarshaller(iter) }
+   }
+
+ }
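
readRDDFromFile and writeRDDToFile exchange partitions through a flat file of length-prefixed records, the same framing PythonRDD.writeIteratorToStream produces for byte arrays. A minimal sketch of a writer for that format, assuming byte-array records:

    import java.io.{BufferedOutputStream, DataOutputStream, FileOutputStream}

    object RecordFileSketch {
      def writeRecords(filename: String, records: Seq[Array[Byte]]): Unit = {
        val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)))
        try {
          for (record <- records) {
            out.writeInt(record.length) // 4-byte big-endian length prefix
            out.write(record)           // payload; readRDDFromFile reads until EOF
          }
        } finally {
          out.close()
        }
      }
    }
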
+
+
+
+ /* =================================================================================================
+  * Class RubyException
+  * =================================================================================================
+  */
+
+ class RubyException(msg: String, cause: Exception) extends RuntimeException(msg, cause) {
+   def appendToStackTrace(toAdd: Array[String]) {
+     val newStackTrace = getStackTrace.toBuffer
+
+     val regexpMatch = "(.*):([0-9]+):in `([a-z]+)'".r
+
+     for (item <- toAdd) {
+       item match {
+         case regexpMatch(fileName, lineNumber, methodName) =>
+           newStackTrace += new StackTraceElement("RubyWorker", methodName, fileName, lineNumber.toInt)
+         case _ => // lines that do not look like a Ruby backtrace frame are skipped
+       }
+     }
+
+     setStackTrace(newStackTrace.toArray)
+   }
+ }
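
The backtrace pattern accepts classic Ruby frames of the form file.rb:LINE:in `method', but its [a-z]+ method group skips names containing underscores, digits, ? or !. A small illustrative check, not part of the gem:

    object BacktraceRegexSketch extends App {
      val pattern = "(.*):([0-9]+):in `([a-z]+)'".r
      val frames = Seq(
        "worker.rb:42:in `compute'",   // matches
        "rdd.rb:7:in `reduce_by_key'"  // skipped: underscore is outside [a-z]+
      )
      frames.foreach {
        case pattern(file, line, method) => println(s"matched $file:$line in $method")
        case other                       => println(s"skipped $other")
      }
    }
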