ruby-spark 1.1.0.1-java
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +252 -0
- data/Rakefile +35 -0
- data/TODO.md +6 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/comparison/prepare.sh +18 -0
- data/benchmark/comparison/python.py +156 -0
- data/benchmark/comparison/r.r +69 -0
- data/benchmark/comparison/ruby.rb +167 -0
- data/benchmark/comparison/run-all.sh +160 -0
- data/benchmark/comparison/scala.scala +181 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/example/website_search.rb +83 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +392 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +158 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +238 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +322 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +67 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1377 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +79 -0
- data/lib/spark/serializer/auto_batched.rb +59 -0
- data/lib/spark/serializer/base.rb +63 -0
- data/lib/spark/serializer/batched.rb +84 -0
- data/lib/spark/serializer/cartesian.rb +13 -0
- data/lib/spark/serializer/compressed.rb +27 -0
- data/lib/spark/serializer/marshal.rb +17 -0
- data/lib/spark/serializer/message_pack.rb +23 -0
- data/lib/spark/serializer/oj.rb +23 -0
- data/lib/spark/serializer/pair.rb +41 -0
- data/lib/spark/serializer/text.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +200 -0
- data/ruby-spark.gemspec +47 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +165 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +122 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +88 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +170 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +38 -0
- metadata +389 -0
data/ext/spark/src/main/scala/RubyAccumulatorParam.scala
@@ -0,0 +1,69 @@
+package org.apache.spark.api.ruby
+
+import java.io._
+import java.net._
+import java.util.{List, ArrayList}
+
+import scala.collection.JavaConversions._
+import scala.collection.immutable._
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Ruby through a socket.
+ */
+private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
+  extends AccumulatorParam[List[Array[Byte]]] {
+
+  // Utils.checkHost(serverHost, "Expected hostname")
+
+  val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+
+  // Socket should not be serialized
+  // Otherwise: SparkException: Task not serializable
+  @transient var socket: Socket = null
+  @transient var socketOutputStream: DataOutputStream = null
+  @transient var socketInputStream: DataInputStream = null
+
+  def openSocket(){
+    synchronized {
+      if (socket == null || socket.isClosed) {
+        socket = new Socket(serverHost, serverPort)
+
+        socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
+        socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+      }
+    }
+  }
+
+  override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList
+
+  override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]) : List[Array[Byte]] = synchronized {
+    if (serverHost == null) {
+      // This happens on the worker node, where we just want to remember all the updates
+      val1.addAll(val2)
+      val1
+    } else {
+      // This happens on the master, where we pass the updates to Ruby through a socket
+      openSocket()
+
+      socketOutputStream.writeInt(val2.size)
+      for (array <- val2) {
+        socketOutputStream.writeInt(array.length)
+        socketOutputStream.write(array)
+      }
+      socketOutputStream.flush()
+
+      // Wait for acknowledgement
+      // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
+      //
+      // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
+      //   throw new SparkException("Accumulator was not acknowledged")
+      // }
+
+      new ArrayList
    }
  }
}
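The addInPlace branch above defines a small length-prefixed wire format for shipping accumulator updates to the driver: an update count, then each update as a 4-byte length followed by the marshaled bytes. The receiving end is presumably the Ruby accumulator server in data/lib/spark/accumulator.rb; the sketch below only illustrates that framing and is not code from the gem.

import java.io.{BufferedInputStream, DataInputStream}
import java.net.ServerSocket

// Sketch: read one update batch written by RubyAccumulatorParam.addInPlace.
// Layout on the socket: [count][len_1][bytes_1]...[len_count][bytes_count].
object AccumulatorBatchReaderSketch {
  def readBatch(in: DataInputStream): Seq[Array[Byte]] = {
    val count = in.readInt()                 // number of serialized updates in this batch
    (1 to count).map { _ =>
      val length = in.readInt()              // length prefix of one marshaled payload
      val bytes = new Array[Byte](length)
      in.readFully(bytes)
      bytes
    }
  }

  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0)         // any free port; illustrative only
    val client = server.accept()
    val in = new DataInputStream(new BufferedInputStream(client.getInputStream))
    println(s"received ${readBatch(in).size} accumulator updates")
    client.close()
    server.close()
  }
}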
data/ext/spark/src/main/scala/RubyBroadcast.scala
@@ -0,0 +1,13 @@
+package org.apache.spark.api.ruby
+
+import org.apache.spark.api.python.PythonBroadcast
+
+/**
+ * A wrapper for a Ruby broadcast variable, which is written to disk by Ruby. It also
+ * writes the data to disk after deserialization, so Ruby can read it back from disk.
+ *
+ * The class reuses the Python logic - kept only for clearer semantics.
+ */
+class RubyBroadcast(@transient var _path: String, @transient var id: java.lang.Long) extends PythonBroadcast(_path) {
+
+}
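A brief usage sketch for orientation (the file path and context setup are illustrative assumptions, not gem code): the driver wraps a file that Ruby has already written via RubyRDD.readBroadcastFromFile, defined later in this diff, which simply calls sc.broadcast(new RubyBroadcast(path, id)).

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.ruby.RubyRDD

object BroadcastRegistrationSketch {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext(new SparkContext("local[*]", "ruby-spark-sketch"))
    // Ruby has already serialized the broadcast value into this (hypothetical) file.
    val bc = RubyRDD.readBroadcastFromFile(jsc, "/tmp/ruby-spark-broadcast-1.bin", 1L)
    println(bc.value.id)   // the id Ruby uses to look the value up on the worker side
    jsc.stop()
  }
}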
data/ext/spark/src/main/scala/RubyConstant.scala
@@ -0,0 +1,13 @@
+package org.apache.spark.api.ruby
+
+object RubyConstant {
+  val DATA_EOF = -2
+  val WORKER_ERROR = -1
+  val WORKER_DONE = 0
+  val CREATE_WORKER = 1
+  val KILL_WORKER = 2
+  val KILL_WORKER_AND_WAIT = 3
+  val SUCCESSFULLY_KILLED = 4
+  val UNSUCCESSFUL_KILLING = 5
+  val ACCUMULATOR_ACK = 6
+}
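These constants define the control protocol spoken between the JVM and the Ruby side: worker lifecycle commands plus the data and accumulator markers used by RubyRDD.scala further down. The real handling lives in RubyWorker.scala and in the Ruby master/worker; the dispatch loop below is only a hedged sketch with placeholder bodies.

import java.io.{DataInputStream, DataOutputStream}
import org.apache.spark.api.ruby.RubyConstant

// Illustrative dispatch on the control codes; not gem code.
object ControlDispatchSketch {
  def handleOne(in: DataInputStream, out: DataOutputStream): Unit = {
    in.readInt() match {
      case RubyConstant.CREATE_WORKER =>
        // spawn a new worker process (omitted)
      case RubyConstant.KILL_WORKER =>
        // kill asynchronously, no reply expected (omitted)
      case RubyConstant.KILL_WORKER_AND_WAIT =>
        // kill, then report the outcome so the caller can block on it
        out.writeInt(RubyConstant.SUCCESSFULLY_KILLED)
        out.flush()
      case _ =>
        out.writeInt(RubyConstant.UNSUCCESSFUL_KILLING)
        out.flush()
    }
  }
}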
data/ext/spark/src/main/scala/RubyMLLibAPI.scala
@@ -0,0 +1,55 @@
+package org.apache.spark.mllib.api.ruby
+
+import java.util.ArrayList
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.classification.NaiveBayes
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.mllib.api.python.MLLibAPI
+
+
+class RubyMLLibAPI extends MLLibAPI {
+  // trainLinearRegressionModelWithSGD
+  // trainLassoModelWithSGD
+  // trainRidgeModelWithSGD
+  // trainLogisticRegressionModelWithSGD
+  // trainLogisticRegressionModelWithLBFGS
+  // trainSVMModelWithSGD
+  // trainKMeansModel
+  // trainGaussianMixture
+
+  // Rjb has a problem with theta: Array[Array[Double]]
+  override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
+    val model = NaiveBayes.train(data.rdd, lambda)
+
+    List(
+      Vectors.dense(model.labels),
+      Vectors.dense(model.pi),
+      model.theta.toSeq
+    ).map(_.asInstanceOf[Object]).asJava
+  }
+
+  // In Python, wt is just an Object
+  def predictSoftGMM(
+      data: JavaRDD[Vector],
+      wt: ArrayList[Object],
+      mu: ArrayList[Object],
+      si: ArrayList[Object]): RDD[Array[Double]] = {
+
+    // val weight = wt.asInstanceOf[Array[Double]]
+    val weight = wt.toArray.map(_.asInstanceOf[Double])
+    val mean = mu.toArray.map(_.asInstanceOf[DenseVector])
+    val sigma = si.toArray.map(_.asInstanceOf[DenseMatrix])
+    val gaussians = Array.tabulate(weight.length){
+      i => new MultivariateGaussian(mean(i), sigma(i))
+    }
+    val model = new GaussianMixtureModel(weight, gaussians)
+    model.predictSoft(data)
+  }
+}
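Both methods above work around bridge limitations: Rjb cannot pass Array[Array[Double]] (or a plain Array[Double]) across the boundary, so trainNaiveBayes boxes its result into a java.util.List[Object] and predictSoftGMM accepts its arguments as ArrayList[Object]. The sketch below (an assumption, not gem code) shows how the JVM side could recover typed values from that boxed list, with the element order following the List(...) built above.

import org.apache.spark.mllib.linalg.DenseVector

object NaiveBayesResultSketch {
  def unpack(result: java.util.List[Object]): (Array[Double], Array[Double], Seq[Array[Double]]) = {
    val labels = result.get(0).asInstanceOf[DenseVector].toArray    // class labels
    val pi     = result.get(1).asInstanceOf[DenseVector].toArray    // log class priors
    val theta  = result.get(2).asInstanceOf[Seq[Array[Double]]]     // log conditional probabilities per label
    (labels, pi, theta)
  }
}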
data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala
@@ -0,0 +1,21 @@
+package org.apache.spark.mllib.api.ruby
+
+import java.util.ArrayList
+
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.regression.LabeledPoint
+
+object RubyMLLibUtilAPI {
+
+  // Ruby has a problem creating an Array[Double]
+  def generateLinearInput(
+      intercept: Double,
+      weights: ArrayList[String],
+      nPoints: Int,
+      seed: Int,
+      eps: Double = 0.1): Seq[LabeledPoint] = {
+
+    LinearDataGenerator.generateLinearInput(intercept, weights.toArray.map(_.toString.toDouble), nPoints, seed, eps)
+  }
+
+}
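Because Rjb also struggles to construct an Array[Double] argument, generateLinearInput takes the weights as strings and converts them on the JVM side. A usage sketch (values are illustrative, not from the gem):

import java.util.ArrayList
import org.apache.spark.mllib.api.ruby.RubyMLLibUtilAPI

object LinearInputSketch {
  def main(args: Array[String]): Unit = {
    val weights = new ArrayList[String]()
    weights.add("1.5")
    weights.add("-0.3")
    // 10 noisy points around intercept 0.0, using the default eps = 0.1
    val points = RubyMLLibUtilAPI.generateLinearInput(0.0, weights, nPoints = 10, seed = 42)
    points.foreach(println)
  }
}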
data/ext/spark/src/main/scala/RubyPage.scala
@@ -0,0 +1,34 @@
+package org.apache.spark.ui.ruby
+
+// import javax.servlet.http.HttpServletRequest
+
+// import scala.xml.Node
+
+// import org.apache.spark.ui.{WebUIPage, UIUtils}
+// import org.apache.spark.util.Utils
+
+// private[ui] class RubyPage(parent: RubyTab, rbConfig: Array[Tuple2[String, String]]) extends WebUIPage("") {
+
+//   def render(request: HttpServletRequest): Seq[Node] = {
+//     val content = UIUtils.listingTable(header, row, rbConfig)
+//     UIUtils.headerSparkPage("Ruby Config", content, parent)
+//   }
+
+//   private def header = Seq(
+//     "Number"
+//   )
+
+//   private def row(keyValue: (String, String)): Seq[Node] = {
+//     // scalastyle:off
+//     keyValue match {
+//       case (key, value) =>
+//         <tr>
+//           <td>{key}</td>
+//           <td>{value}</td>
+//         </tr>
+//     }
+//     // scalastyle:on
+//   }
+// }
+
+class RubyPage {}
data/ext/spark/src/main/scala/RubyRDD.scala
@@ -0,0 +1,392 @@
+package org.apache.spark.api.ruby
+
+import java.io._
+import java.net._
+import java.util.{List, ArrayList, Collections}
+
+import scala.util.Try
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+import org.apache.spark._
+import org.apache.spark.{SparkEnv, Partition, SparkException, TaskContext}
+import org.apache.spark.api.ruby._
+import org.apache.spark.api.ruby.marshal._
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+import org.apache.spark.InterruptibleIterator
+
+
+/* =================================================================================================
+ * Class RubyRDD
+ * =================================================================================================
+ */
+
+class RubyRDD(
+    @transient parent: RDD[_],
+    command: Array[Byte],
+    broadcastVars: ArrayList[Broadcast[RubyBroadcast]],
+    accumulator: Accumulator[List[Array[Byte]]])
+  extends RDD[Array[Byte]](parent){
+
+  val bufferSize = conf.getInt("spark.buffer.size", 65536)
+
+  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  override def getPartitions: Array[Partition] = firstParent.partitions
+
+  override val partitioner = None
+
+  /* ------------------------------------------------------------------------------------------ */
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+
+    val env = SparkEnv.get
+
+    // Get worker and id
+    val (worker, workerId) = RubyWorker.create(env)
+
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addTaskCompletionListener { context =>
+      writerThread.shutdownOnTaskCompletion()
+      writerThread.join()
+
+      // Cleanup the worker socket. This will also cause the worker to exit.
+      try {
+        RubyWorker.remove(worker, workerId)
+        worker.close()
+      } catch {
+        case e: Exception => logWarning("Failed to close worker socket", e)
+      }
+    }
+
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+    // Send data
+    writerThread.start()
+
+    // For forceful termination of the worker
+    new MonitorThread(workerId, worker, context).start()
+
+    // Return an iterator that reads lines from the process's stdout
+    val stdoutIterator = new StreamReader(stream, writerThread, context)
+
+    // An iterator that wraps around an existing iterator to provide task killing functionality.
+    new InterruptibleIterator(context, stdoutIterator)
+
+  } // end compute
+
+  /* ------------------------------------------------------------------------------------------ */
+
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread("stdout writer for worker") {
+
+    @volatile private var _exception: Exception = null
+
+    setDaemon(true)
+
+    // Contains the exception thrown while writing the parent iterator to the process.
+    def exception: Option[Exception] = Option(_exception)
+
+    // Terminates the writer thread, ignoring any exceptions that may occur due to cleanup.
+    def shutdownOnTaskCompletion() {
+      assert(context.isCompleted)
+      this.interrupt()
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Send the necessary data to the worker
+    // - split index
+    // - command
+    // - iterator
+
+    override def run(): Unit = Utils.logUncaughtExceptions {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+
+        // Partition index
+        dataOut.writeInt(split.index)
+
+        // Spark files
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.value.id)
+          PythonRDD.writeUTF(broadcast.value.path, dataOut)
+        }
+
+        // Serialized command
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+
+        // Send it
+        dataOut.flush()
+
+        // Data
+        PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+        dataOut.writeInt(RubyConstant.DATA_EOF)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.isCompleted || context.isInterrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill worker process
+      }
+    }
+  } // end WriterThread
+
+
+  /* ------------------------------------------------------------------------------------------ */
+
+  class StreamReader(stream: DataInputStream, writerThread: WriterThread, context: TaskContext) extends Iterator[Array[Byte]] {
+
+    def hasNext = _nextObj != null
+    var _nextObj = read()
+
+    // -------------------------------------------------------------------------------------------
+
+    def next(): Array[Byte] = {
+      val obj = _nextObj
+      if (hasNext) {
+        _nextObj = read()
+      }
+      obj
+    }
+
+    // -------------------------------------------------------------------------------------------
+
+    private def read(): Array[Byte] = {
+      if (writerThread.exception.isDefined) {
+        throw writerThread.exception.get
+      }
+      try {
+        stream.readInt() match {
+          case length if length > 0 =>
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+            obj
+          case RubyConstant.WORKER_DONE =>
+            val numAccumulatorUpdates = stream.readInt()
+            (1 to numAccumulatorUpdates).foreach { _ =>
+              val updateLen = stream.readInt()
+              val update = new Array[Byte](updateLen)
+              stream.readFully(update)
+              accumulator += Collections.singletonList(update)
+            }
+            null
+          case RubyConstant.WORKER_ERROR =>
+            // Exception from worker
+
+            // message
+            val length = stream.readInt()
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+
+            // stackTrace
+            val stackTraceLen = stream.readInt()
+            val stackTrace = new Array[String](stackTraceLen)
+            (0 until stackTraceLen).foreach { i =>
+              val length = stream.readInt()
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+
+              stackTrace(i) = new String(obj, "utf-8")
+            }
+
+            // Worker will be killed
+            stream.close
+
+            // exception
+            val exception = new RubyException(new String(obj, "utf-8"), writerThread.exception.getOrElse(null))
+            exception.appendToStackTrace(stackTrace)
+
+            throw exception
+        }
+      } catch {
+
+        case e: Exception if context.isInterrupted =>
+          logDebug("Exception thrown after task interruption", e)
+          throw new TaskKilledException
+
+        case e: Exception if writerThread.exception.isDefined =>
+          logError("Worker exited unexpectedly (crashed)", e)
+          throw writerThread.exception.get
+
+        case eof: EOFException =>
+          throw new SparkException("Worker exited unexpectedly (crashed)", eof)
+      }
+    }
+  } // end StreamReader
+
+  /* ---------------------------------------------------------------------------------------------
+   * Monitor thread to control the worker. Kills the worker if the task is interrupted.
+   */
+
+  class MonitorThread(workerId: Long, worker: Socket, context: TaskContext)
+    extends Thread("Worker Monitor for worker") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      while (!context.isInterrupted && !context.isCompleted) {
+        Thread.sleep(2000)
+      }
+      if (!context.isCompleted) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Worker "+workerId.toString())
+          RubyWorker.kill(workerId)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker "+workerId.toString(), e)
+        }
+      }
+    }
+  } // end MonitorThread
+} // end RubyRDD
+
+
+
+/* =================================================================================================
+ * Class PairwiseRDD
+ * =================================================================================================
+ *
+ * Form an RDD[(Long, Array[Byte])] from key-value pairs returned from Ruby.
+ * This is used by ruby-spark's shuffle operations.
+ * Borrowed from the Python package -> needs a new deserializeLongValue ->
+ * Marshal adds the same 4-byte header
+ */
+
+class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) {
+  override def getPartitions = prev.partitions
+  override def compute(split: Partition, context: TaskContext) =
+    prev.iterator(split, context).grouped(2).map {
+      case Seq(a, b) => (Utils.deserializeLongValue(a.reverse), b)
+      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
+    }
+  val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+
+
+/* =================================================================================================
+ * Object RubyRDD
+ * =================================================================================================
+ */
+
+object RubyRDD extends Logging {
+
+  def runJob(
+      sc: SparkContext,
+      rdd: JavaRDD[Array[Byte]],
+      partitions: ArrayList[Int],
+      allowLocal: Boolean,
+      filename: String): String = {
+    type ByteArray = Array[Byte]
+    type UnrolledPartition = Array[ByteArray]
+    val allPartitions: Array[UnrolledPartition] =
+      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+    writeRDDToFile(flattenedPartition.iterator, filename)
+  }
+
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = {
+    val file = new DataInputStream(new BufferedInputStream(new FileInputStream(filename)))
+    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+    try {
+      while (true) {
+        val length = file.readInt()
+        val obj = new Array[Byte](length)
+        file.readFully(obj)
+        objs.append(obj)
+      }
+    } catch {
+      case eof: EOFException => {}
+    }
+    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+  }
+
+  def writeRDDToFile[T](items: Iterator[T], filename: String): String = {
+    val file = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)))
+
+    try {
+      PythonRDD.writeIteratorToStream(items, file)
+    } finally {
+      file.close()
+    }
+
+    filename
+  }
+
+  def writeRDDToFile[T](rdd: RDD[T], filename: String): String = {
+    writeRDDToFile(rdd.collect.iterator, filename)
+  }
+
+  def readBroadcastFromFile(sc: JavaSparkContext, path: String, id: java.lang.Long): Broadcast[RubyBroadcast] = {
+    sc.broadcast(new RubyBroadcast(path, id))
+  }
+
+  /**
+   * Convert an RDD of serialized Ruby objects to an RDD of objects that are usable in Java.
+   */
+  def toJava(rbRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+    rbRDD.rdd.mapPartitions { iter =>
+      iter.flatMap { item =>
+        val obj = Marshal.load(item)
+        if(batched){
+          obj.asInstanceOf[Array[_]]
+        }
+        else{
+          Seq(item)
+        }
+      }
+    }.toJavaRDD()
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Ruby objects that are usable by Ruby.
+   */
+  def toRuby(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+    jRDD.rdd.mapPartitions { iter => new IterableMarshaller(iter) }
+  }
+
+}
+
+
+
+/* =================================================================================================
+ * Class RubyException
+ * =================================================================================================
+ */
+
+class RubyException(msg: String, cause: Exception) extends RuntimeException(msg, cause) {
+  def appendToStackTrace(toAdded: Array[String]) {
+    val newStactTrace = getStackTrace.toBuffer
+
+    var regexpMatch = "(.*):([0-9]+):in `([a-z]+)'".r
+
+    for(item <- toAdded) {
+      item match {
+        case regexpMatch(fileName, lineNumber, methodName) =>
+          newStactTrace += new StackTraceElement("RubyWorker", methodName, fileName, lineNumber.toInt)
+        case _ => null
+      }
+    }
+
+    setStackTrace(newStactTrace.toArray)
+  }
+}
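StreamReader above decodes a simple framed reply from the worker: each result is a positive 4-byte length followed by the payload, WORKER_DONE (0) ends the results and is followed by length-prefixed accumulator updates, and WORKER_ERROR carries an error message plus a Ruby stack trace. The sender side presumably lives in the Ruby worker (data/lib/spark/worker/worker.rb); the sketch below only mirrors the happy path of that framing and is not code from the gem.

import java.io.DataOutputStream
import org.apache.spark.api.ruby.RubyConstant

// Sketch of a well-behaved worker reply, as StreamReader.read() expects it.
object WorkerReplySketch {
  def writeReply(out: DataOutputStream, rows: Seq[Array[Byte]], accUpdates: Seq[Array[Byte]]): Unit = {
    rows.foreach { row =>
      out.writeInt(row.length)                // positive length => a result payload follows
      out.write(row)
    }
    out.writeInt(RubyConstant.WORKER_DONE)    // 0 => no more results
    out.writeInt(accUpdates.size)             // then the accumulator updates, length-prefixed
    accUpdates.foreach { update =>
      out.writeInt(update.length)
      out.write(update)
    }
    out.flush()
  }
}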