ruby-spark 1.1.0.1-java
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +252 -0
- data/Rakefile +35 -0
- data/TODO.md +6 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/comparison/prepare.sh +18 -0
- data/benchmark/comparison/python.py +156 -0
- data/benchmark/comparison/r.r +69 -0
- data/benchmark/comparison/ruby.rb +167 -0
- data/benchmark/comparison/run-all.sh +160 -0
- data/benchmark/comparison/scala.scala +181 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/example/website_search.rb +83 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +392 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +158 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +238 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +322 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +67 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1377 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +79 -0
- data/lib/spark/serializer/auto_batched.rb +59 -0
- data/lib/spark/serializer/base.rb +63 -0
- data/lib/spark/serializer/batched.rb +84 -0
- data/lib/spark/serializer/cartesian.rb +13 -0
- data/lib/spark/serializer/compressed.rb +27 -0
- data/lib/spark/serializer/marshal.rb +17 -0
- data/lib/spark/serializer/message_pack.rb +23 -0
- data/lib/spark/serializer/oj.rb +23 -0
- data/lib/spark/serializer/pair.rb +41 -0
- data/lib/spark/serializer/text.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +200 -0
- data/ruby-spark.gemspec +47 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +165 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +122 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +88 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +170 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +38 -0
- metadata +389 -0
data/ext/spark/build.sbt
ADDED
@@ -0,0 +1,73 @@
import AssemblyKeys._

assemblySettings

// Default values
val defaultScalaVersion = "2.10.4"
val defaultSparkVersion = "1.3.0"
val defaultSparkCoreVersion = "2.10"
val defaultSparkHome = "target"
val defaultHadoopVersion = "1.0.4"

// Values
val _scalaVersion = scala.util.Properties.envOrElse("SCALA_VERSION", defaultScalaVersion)
val _sparkVersion = scala.util.Properties.envOrElse("SPARK_VERSION", defaultSparkVersion)
val _sparkCoreVersion = scala.util.Properties.envOrElse("SPARK_CORE_VERSION", defaultSparkCoreVersion)
val _sparkHome = scala.util.Properties.envOrElse("SPARK_HOME", defaultSparkHome)
val _hadoopVersion = scala.util.Properties.envOrElse("HADOOP_VERSION", defaultHadoopVersion)

// Project settings
name := "ruby-spark"

version := "1.0.0"

scalaVersion := _scalaVersion

javacOptions ++= Seq("-source", "1.7", "-target", "1.7")

// Jar target folder
artifactPath in Compile in packageBin := file(s"${_sparkHome}/ruby-spark.jar")
outputPath in packageDependency := file(s"${_sparkHome}/ruby-spark-deps.jar")

// Protocol buffer support
seq(sbtprotobuf.ProtobufPlugin.protobufSettings: _*)

// Additional libraries
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % _sparkVersion excludeAll(ExclusionRule(organization = "org.apache.hadoop")),
  "org.apache.spark" %% "spark-graphx" % _sparkVersion,
  "org.apache.spark" %% "spark-mllib" % _sparkVersion,
  "org.apache.hadoop" % "hadoop-client" % _hadoopVersion,
  "com.github.fommil.netlib" % "all" % "1.1.2",
  "org.scalatest" % "scalatest_2.10" % "2.2.1" % "test"
)

// Repositories
resolvers ++= Seq(
  "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
  "Spray Repository" at "http://repo.spray.cc/",
  "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
  "Akka Repository" at "http://repo.akka.io/releases/",
  "Twitter4J Repository" at "http://twitter4j.org/maven2/",
  "Apache HBase" at "https://repository.apache.org/content/repositories/releases",
  "Twitter Maven Repo" at "http://maven.twttr.com/",
  "scala-tools" at "https://oss.sonatype.org/content/groups/scala-tools",
  "Typesafe repository" at "http://repo.typesafe.com/typesafe/releases/",
  "Second Typesafe repo" at "http://repo.typesafe.com/typesafe/maven-releases/",
  "Mesosphere Public Repository" at "http://downloads.mesosphere.io/maven",
  Resolver.sonatypeRepo("public")
)

// Merge strategy
mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
  {
    case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
    case m if m.startsWith("META-INF") => MergeStrategy.discard
    case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first
    case PathList("org", "apache", xs @ _*) => MergeStrategy.first
    case PathList("org", "jboss", xs @ _*) => MergeStrategy.first
    case "about.html" => MergeStrategy.rename
    case "reference.conf" => MergeStrategy.concat
    case _ => MergeStrategy.first
  }
}
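Every version above can be overridden through the environment before running sbt. A minimal, illustrative sketch (not part of the gem) of the same `envOrElse` fallback pattern the build definition uses:

```scala
// Illustrative only: envOrElse returns the exported variable if set,
// otherwise the hard-coded default, exactly as in build.sbt above.
object EffectiveVersions {
  def main(args: Array[String]): Unit = {
    val scalaV = scala.util.Properties.envOrElse("SCALA_VERSION", "2.10.4")
    val sparkV = scala.util.Properties.envOrElse("SPARK_VERSION", "1.3.0")
    println(s"scala=$scalaV spark=$sparkV") // e.g. run with SPARK_VERSION exported
  }
}
```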
data/ext/spark/project/plugins.sbt
ADDED
@@ -0,0 +1,9 @@
resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)

resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"

resolvers += "Spray Repository" at "http://repo.spray.cc/"

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.10.2")

addSbtPlugin("com.github.gseitz" % "sbt-protobuf" % "0.3.3")
data/ext/spark/sbt/sbt
ADDED
@@ -0,0 +1,34 @@
#!/bin/bash

# This script launches sbt for this project. If present it uses the system
# version of sbt. If there is no system version of sbt it attempts to download
# sbt locally.
SBT_VERSION=0.13.7
URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
JAR=sbt/sbt-launch-${SBT_VERSION}.jar

# Download sbt launch jar if it hasn't been downloaded yet
if [ ! -f ${JAR} ]; then
  # Download
  printf "Attempting to fetch sbt\n"
  JAR_DL=${JAR}.part
  if hash curl 2>/dev/null; then
    (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR}
  elif hash wget 2>/dev/null; then
    (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR}
  else
    printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
    exit -1
  fi
fi
if [ ! -f ${JAR} ]; then
  # We failed to download
  printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
  exit -1
fi
printf "Launching sbt from ${JAR}\n"
java \
  -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \
  -jar ${JAR} \
  "$@"
data/ext/spark/src/main/scala/Exec.scala
ADDED
@@ -0,0 +1,91 @@
package org.apache.spark.api.ruby

import java.io.{File, FileOutputStream, InputStreamReader, BufferedReader}

import scala.collection.JavaConversions._

import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.util._


/* =================================================================================================
 * class FileCommand
 * =================================================================================================
 *
 * Save the command to a file and then execute it, because from Scala you cannot simply run
 * something like "bash --norc -i -c 'source .zshrc; ruby master.rb'".
 */

class FileCommand(command: String) extends Logging {

  var pb: ProcessBuilder = null
  var file: File = null

  // The command is complete.
  def this(command: String, env: SparkEnv) = {
    this(command)
    create(env)
  }

  // The template must contain %s, which will be replaced by the command.
  def this(template: String, command: String, env: SparkEnv, envVars: Map[String, String]) = {
    this(template.format(command), env)
    setEnvVars(envVars)
  }

  private def create(env: SparkEnv) {
    val dir = new File(env.sparkFilesDir)
    val ext = if(Utils.isWindows) ".cmd" else ".sh"
    val shell = if(Utils.isWindows) "cmd" else "bash"

    file = File.createTempFile("command", ext, dir)

    val out = new FileOutputStream(file)
    out.write(command.getBytes)
    out.close

    logInfo(s"New FileCommand at ${file.getAbsolutePath}")

    pb = new ProcessBuilder(shell, file.getAbsolutePath)
  }

  def setEnvVars(vars: Map[String, String]) {
    pb.environment().putAll(vars)
  }

  def run = {
    new ExecutedFileCommand(pb.start)
  }
}


/* =================================================================================================
 * class ExecutedFileCommand
 * =================================================================================================
 *
 * Represents a process executed from a file.
 */

class ExecutedFileCommand(process: Process) {

  var reader: BufferedReader = null

  def readLine = {
    openInput
    reader.readLine.toString.trim
  }

  def openInput {
    if(reader != null){
      return
    }

    val input = new InputStreamReader(process.getInputStream)
    reader = new BufferedReader(input)
  }

  // Delegation
  def destroy = process.destroy
  def getInputStream = process.getInputStream
  def getErrorStream = process.getErrorStream
}
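A minimal sketch (not part of the gem) showing that `ExecutedFileCommand` can be exercised on its own: it only wraps a `java.lang.Process`, which is the same kind of object that `FileCommand.run` returns; the `echo` command below is just an example.

```scala
import org.apache.spark.api.ruby.ExecutedFileCommand

// Hedged sketch: wrap an arbitrary process and read its first output line.
object ExecDemo {
  def main(args: Array[String]): Unit = {
    val process = new ProcessBuilder("echo", "hello").start()
    val executed = new ExecutedFileCommand(process)
    println(executed.readLine) // prints "hello"
    executed.destroy
  }
}
```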
data/ext/spark/src/main/scala/Marshal.scala
ADDED
@@ -0,0 +1,52 @@
package org.apache.spark.api.ruby.marshal

import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._


/* =================================================================================================
 * object Marshal
 * =================================================================================================
 */
object Marshal {
  def load(bytes: Array[Byte]) = {
    val is = new DataInputStream(new ByteArrayInputStream(bytes))

    val majorVersion = is.readUnsignedByte // 4
    val minorVersion = is.readUnsignedByte // 8

    (new MarshalLoad(is)).load
  }

  def dump(data: Any) = {
    val aos = new ByteArrayOutputStream
    val os = new DataOutputStream(aos)

    os.writeByte(4)
    os.writeByte(8)

    (new MarshalDump(os)).dump(data)
    aos.toByteArray
  }
}


/* =================================================================================================
 * class IterableMarshaller
 * =================================================================================================
 */
class IterableMarshaller(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
  private val buffer = new ArrayBuffer[Any]

  override def hasNext: Boolean = iter.hasNext

  override def next(): Array[Byte] = {
    while (iter.hasNext) {
      buffer += iter.next()
    }

    Marshal.dump(buffer)
  }
}
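A minimal round-trip sketch (not part of the gem) using the `Marshal` object above; the array literal is just an example value.

```scala
import org.apache.spark.api.ruby.marshal.Marshal

// Dump an array into Ruby's Marshal 4.8 format and read it back.
object MarshalRoundTrip {
  def main(args: Array[String]): Unit = {
    val bytes = Marshal.dump(Array(1, 2, 3)) // first two bytes are the version: 0x04 0x08
    val back  = Marshal.load(bytes).asInstanceOf[Array[Any]]
    println(back.mkString(", "))             // 1, 2, 3
  }
}
```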
data/ext/spark/src/main/scala/MarshalDump.scala
ADDED
@@ -0,0 +1,113 @@
package org.apache.spark.api.ruby.marshal

import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}


/* =================================================================================================
 * class MarshalDump
 * =================================================================================================
 */
class MarshalDump(os: DataOutputStream) {

  val NAN_BYTELIST = "nan".getBytes
  val NEGATIVE_INFINITY_BYTELIST = "-inf".getBytes
  val INFINITY_BYTELIST = "inf".getBytes

  def dump(data: Any) {
    data match {
      case null =>
        os.writeByte('0')

      case item: Boolean =>
        val char = if(item) 'T' else 'F'
        os.writeByte(char)

      case item: Int =>
        os.writeByte('i')
        dumpInt(item)

      case item: Array[_] =>
        os.writeByte('[')
        dumpArray(item)

      case item: Double =>
        os.writeByte('f')
        dumpFloat(item)

      case item: ArrayBuffer[Any] => dump(item.toArray)
    }
  }

  def dumpInt(data: Int) {
    if(data == 0){
      os.writeByte(0)
    }
    else if (0 < data && data < 123) {
      os.writeByte(data + 5)
    }
    else if (-124 < data && data < 0) {
      os.writeByte((data - 5) & 0xff)
    }
    else {
      val buffer = new Array[Byte](4)
      var value = data

      var i = 0
      while(i != 4 && value != 0 && value != -1){
        buffer(i) = (value & 0xff).toByte
        value = value >> 8

        i += 1
      }
      val length = i + 1
      if(value < 0){
        os.writeByte(-length)
      }
      else{
        os.writeByte(length)
      }
      os.write(buffer, 0, length)
    }
  }

  def dumpArray(array: Array[_]) {
    dumpInt(array.size)

    for(item <- array) {
      dump(item)
    }
  }

  def dumpFloat(value: Double) {
    if(value.isPosInfinity){
      dumpString(INFINITY_BYTELIST)
    }
    else if(value.isNegInfinity){
      dumpString(NEGATIVE_INFINITY_BYTELIST)
    }
    else if(value.isNaN){
      dumpString(NAN_BYTELIST)
    }
    else{
      // dumpString("%.17g".format(value))
      dumpString(value.toString)
    }
  }

  def dumpString(data: String) {
    dumpString(data.getBytes)
  }

  def dumpString(data: Array[Byte]) {
    dumpInt(data.size)
    os.write(data)
  }

}
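A small sketch (not part of the gem) of the single-byte branch of `dumpInt` above: Ruby Marshal stores integers in 1..122 as value + 5. The stream and value here are only illustrative.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.api.ruby.marshal.MarshalDump

// Hedged sketch: 10 falls into the 0 < n < 123 branch, so it is written as the single byte 15.
object DumpIntDemo {
  def main(args: Array[String]): Unit = {
    val aos = new ByteArrayOutputStream
    new MarshalDump(new DataOutputStream(aos)).dumpInt(10)
    println(aos.toByteArray.toList) // List(15)
  }
}
```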
data/ext/spark/src/main/scala/MarshalLoad.scala
ADDED
@@ -0,0 +1,220 @@
package org.apache.spark.api.ruby.marshal

import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}


/* =================================================================================================
 * class MarshalLoad
 * =================================================================================================
 */
class MarshalLoad(is: DataInputStream) {

  case class WaitForObject()

  val registeredSymbols = ArrayBuffer[String]()
  val registeredLinks = ArrayBuffer[Any]()

  def load: Any = {
    load(is.readUnsignedByte.toChar)
  }

  def load(dataType: Char): Any = {
    dataType match {
      case '0' => null
      case 'T' => true
      case 'F' => false
      case 'i' => loadInt
      case 'f' => loadAndRegisterFloat
      case ':' => loadAndRegisterSymbol
      case '[' => loadAndRegisterArray
      case 'U' => loadAndRegisterUserObject
      case _ =>
        throw new IllegalArgumentException(s"Format is not supported: $dataType.")
    }
  }


  // ----------------------------------------------------------------------------------------------
  // Load by type

  def loadInt: Int = {
    var c = is.readByte.toInt

    if (c == 0) {
      return 0
    } else if (4 < c && c < 128) {
      return c - 5
    } else if (-129 < c && c < -4) {
      return c + 5
    }

    var result: Long = 0

    if (c > 0) {
      result = 0
      for( i <- 0 until c ) {
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
    } else {
      c = -c
      result = -1
      for( i <- 0 until c ) {
        result &= ~((0xff << (8 * i)).toLong)
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
    }

    result.toInt
  }

  def loadAndRegisterFloat: Double = {
    val result = loadFloat
    registeredLinks += result
    result
  }

  def loadFloat: Double = {
    val string = loadString
    string match {
      case "nan" => Double.NaN
      case "inf" => Double.PositiveInfinity
      case "-inf" => Double.NegativeInfinity
      case _ => string.toDouble
    }
  }

  def loadString: String = {
    new String(loadStringBytes)
  }

  def loadStringBytes: Array[Byte] = {
    val size = loadInt
    val buffer = new Array[Byte](size)

    var readSize = 0
    while(readSize < size){
      val read = is.read(buffer, readSize, size-readSize)

      if(read == -1){
        throw new IllegalArgumentException("Marshal too short.")
      }

      readSize += read
    }

    buffer
  }

  def loadAndRegisterSymbol: String = {
    val result = loadString
    registeredSymbols += result
    result
  }

  def loadAndRegisterArray: Array[Any] = {
    val size = loadInt
    val array = new Array[Any](size)

    registeredLinks += array

    for( i <- 0 until size ) {
      array(i) = loadNextObject
    }

    array
  }

  def loadAndRegisterUserObject: Any = {
    val klass = loadNextObject.asInstanceOf[String]

    // Register the future object before loading the next one
    registeredLinks += WaitForObject()
    val index = registeredLinks.size - 1

    val data = loadNextObject

    val result = klass match {
      case "Spark::Mllib::LabeledPoint" => createLabeledPoint(data)
      case "Spark::Mllib::DenseVector" => createDenseVector(data)
      case "Spark::Mllib::SparseVector" => createSparseVector(data)
      case other =>
        throw new IllegalArgumentException(s"Object $other is not supported.")
    }

    registeredLinks(index) = result

    result
  }


  // ----------------------------------------------------------------------------------------------
  // Other loads

  def loadNextObject: Any = {
    val dataType = is.readUnsignedByte.toChar

    if(isLinkType(dataType)){
      readLink(dataType)
    }
    else{
      load(dataType)
    }
  }


  // ----------------------------------------------------------------------------------------------
  // To java objects

  def createLabeledPoint(data: Any): LabeledPoint = {
    val array = data.asInstanceOf[Array[_]]
    new LabeledPoint(array(0).asInstanceOf[Double], array(1).asInstanceOf[Vector])
  }

  def createDenseVector(data: Any): DenseVector = {
    new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
  }

  def createSparseVector(data: Any): SparseVector = {
    val array = data.asInstanceOf[Array[_]]
    val size = array(0).asInstanceOf[Int]
    val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
    val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))

    new SparseVector(size, indices, values)
  }


  // ----------------------------------------------------------------------------------------------
  // Helpers

  def toDouble(data: Any): Double = data match {
    case x: Int => x.toDouble
    case x: Double => x
    case _ => 0.0
  }


  // ----------------------------------------------------------------------------------------------
  // Cache

  def readLink(dataType: Char): Any = {
    val index = loadInt

    dataType match {
      case '@' => registeredLinks(index)
      case ';' => registeredSymbols(index)
    }
  }

  def isLinkType(dataType: Char): Boolean = {
    dataType == ';' || dataType == '@'
  }

}
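To illustrate the other direction, a minimal sketch (not part of the gem) that feeds `MarshalLoad` a hand-built stream. Note it starts at the type tag; the two version bytes are consumed by `Marshal.load`, not here.

```scala
import java.io.{ByteArrayInputStream, DataInputStream}
import org.apache.spark.api.ruby.marshal.MarshalLoad

// Hedged sketch: 'i' followed by byte 15 decodes to the Fixnum 10 (15 - 5).
object LoadDemo {
  def main(args: Array[String]): Unit = {
    val bytes = Array[Byte]('i'.toByte, 15.toByte)
    val value = new MarshalLoad(new DataInputStream(new ByteArrayInputStream(bytes))).load
    println(value) // 10
  }
}
```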