ruby-spark 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +185 -0
- data/Rakefile +35 -0
- data/TODO.md +7 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/performance/prepare.sh +18 -0
- data/benchmark/performance/python.py +156 -0
- data/benchmark/performance/r.r +69 -0
- data/benchmark/performance/ruby.rb +167 -0
- data/benchmark/performance/run-all.sh +160 -0
- data/benchmark/performance/scala.scala +181 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +154 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +244 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +304 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +57 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1328 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +24 -0
- data/lib/spark/serializer/base.rb +170 -0
- data/lib/spark/serializer/cartesian.rb +37 -0
- data/lib/spark/serializer/marshal.rb +19 -0
- data/lib/spark/serializer/message_pack.rb +25 -0
- data/lib/spark/serializer/oj.rb +25 -0
- data/lib/spark/serializer/pair.rb +27 -0
- data/lib/spark/serializer/utf8.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +197 -0
- data/ruby-spark.gemspec +36 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +163 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +114 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +13 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +168 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +39 -0
- metadata +301 -0
@@ -0,0 +1,113 @@
|
|
1
|
+
package org.apache.spark.api.ruby.marshal
|
2
|
+
|
3
|
+
import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
|
4
|
+
|
5
|
+
import scala.collection.mutable.ArrayBuffer
|
6
|
+
import scala.collection.JavaConverters._
|
7
|
+
import scala.reflect.{ClassTag, classTag}
|
8
|
+
|
9
|
+
import org.apache.spark.mllib.regression.LabeledPoint
|
10
|
+
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
|
11
|
+
|
12
|
+
|
13
|
+
/* =================================================================================================
|
14
|
+
* class MarshalDump
|
15
|
+
* =================================================================================================
|
16
|
+
*/
|
17
|
+
class MarshalDump(os: DataOutputStream) {

  /** Ruby Marshal float payload for NaN. */
  val NAN_BYTELIST = "nan".getBytes
  /** Ruby Marshal float payload for negative infinity. */
  val NEGATIVE_INFINITY_BYTELIST = "-inf".getBytes
  /** Ruby Marshal float payload for positive infinity. */
  val INFINITY_BYTELIST = "inf".getBytes

  /**
   * Writes `data` to the stream as a Ruby Marshal value: a one-byte type tag followed by
   * the type-specific payload.
   *
   * Supported values: `null` ('0'), Boolean ('T'/'F'), Int ('i'), Double ('f'),
   * `Array[_]` ('[', elements dumped recursively) and `ArrayBuffer` (dumped as an array).
   *
   * @param data value to serialize
   * @throws IllegalArgumentException if `data`, or any nested element, has an unsupported type
   */
  def dump(data: Any): Unit = {
    data match {
      case null =>
        os.writeByte('0')

      case item: Boolean =>
        os.writeByte(if (item) 'T' else 'F')

      case item: Int =>
        os.writeByte('i')
        dumpInt(item)

      case item: Array[_] =>
        os.writeByte('[')
        dumpArray(item)

      case item: Double =>
        os.writeByte('f')
        dumpFloat(item)

      case item: ArrayBuffer[_] =>
        dump(item.toArray[Any])

      case other =>
        throw new IllegalArgumentException(s"Object ${other.getClass.getName} is not supported.")
    }
  }

  /**
   * Writes `data` using Ruby Marshal's compact integer encoding (mirrors Ruby's `w_long`):
   *
   *  - `0` is a single zero byte;
   *  - `1..122` are written as one byte holding `data + 5`;
   *  - `-123..-1` are written as one byte holding `data - 5`;
   *  - anything else is a signed byte count `n` (negated for negative numbers) followed by
   *    the `n` low-order bytes of the value, least significant first.
   */
  def dumpInt(data: Int): Unit = {
    if (data == 0) {
      os.writeByte(0)
    } else if (0 < data && data < 123) {
      os.writeByte(data + 5)
    } else if (-124 < data && data < 0) {
      os.writeByte((data - 5) & 0xff)
    } else {
      val buffer = new Array[Byte](4)
      var value = data
      var length = 0

      // Emit low-order bytes until only sign extension (0 or -1) remains. `data` is outside
      // the one-byte range here, so at least one byte is always written, and a 32-bit Int is
      // exhausted after at most 4 bytes.
      while (length < buffer.length && value != 0 && value != -1) {
        buffer(length) = (value & 0xff).toByte
        value >>= 8
        length += 1
      }

      // The header byte is the number of payload bytes (not that number plus one), negated
      // for negative values so the reader knows to sign-extend.
      os.writeByte(if (value < 0) -length else length)
      os.write(buffer, 0, length)
    }
  }

  /** Writes the element count of `array` followed by every element (see [[dump]]). */
  def dumpArray(array: Array[_]): Unit = {
    dumpInt(array.length)

    for (item <- array) {
      dump(item)
    }
  }

  /**
   * Writes a float payload: "inf", "-inf", "nan", or the decimal representation produced by
   * `Double.toString` (shortest string that round-trips the value).
   */
  def dumpFloat(value: Double): Unit = {
    if (value.isPosInfinity) {
      dumpString(INFINITY_BYTELIST)
    } else if (value.isNegInfinity) {
      dumpString(NEGATIVE_INFINITY_BYTELIST)
    } else if (value.isNaN) {
      dumpString(NAN_BYTELIST)
    } else {
      dumpString(value.toString)
    }
  }

  /** Writes `data` as a length-prefixed string encoded as UTF-8. */
  def dumpString(data: String): Unit = {
    dumpString(data.getBytes("UTF-8"))
  }

  /** Writes a length-prefixed raw byte string. */
  def dumpString(data: Array[Byte]): Unit = {
    dumpInt(data.length)
    os.write(data)
  }

}
|
@@ -0,0 +1,220 @@
|
|
1
|
+
package org.apache.spark.api.ruby.marshal
|
2
|
+
|
3
|
+
import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
|
4
|
+
|
5
|
+
import scala.collection.mutable.ArrayBuffer
|
6
|
+
import scala.collection.JavaConverters._
|
7
|
+
import scala.reflect.{ClassTag, classTag}
|
8
|
+
|
9
|
+
import org.apache.spark.mllib.regression.LabeledPoint
|
10
|
+
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}
|
11
|
+
|
12
|
+
|
13
|
+
/* =================================================================================================
|
14
|
+
* class MarshalLoad
|
15
|
+
* =================================================================================================
|
16
|
+
*/
|
17
|
+
class MarshalLoad(is: DataInputStream) {

  /** Placeholder kept in `registeredLinks` while a user object's payload is still being read. */
  case class WaitForObject()

  /** Symbols in the order they were read; a `;` link is an index into this buffer. */
  val registeredSymbols = ArrayBuffer[String]()

  /** Floats, arrays and user objects in the order they were read; a `@` link indexes this buffer. */
  val registeredLinks = ArrayBuffer[Any]()

  /** Reads the next type byte from the stream and loads the value it introduces. */
  def load: Any = {
    load(is.readUnsignedByte.toChar)
  }

  /**
   * Loads a value whose type byte `dataType` has already been consumed.
   *
   * Link bytes (`@`, `;`) are not handled here; they are resolved by [[loadNextObject]].
   *
   * @throws IllegalArgumentException for an unsupported type byte
   */
  def load(dataType: Char): Any = {
    dataType match {
      case '0' => null
      case 'T' => true
      case 'F' => false
      case 'i' => loadInt
      case 'f' => loadAndRegisterFloat
      case ':' => loadAndRegisterSymbol
      case '[' => loadAndRegisterArray
      case 'U' => loadAndRegisterUserObject
      case _ =>
        throw new IllegalArgumentException(s"Format is not supported: $dataType.")
    }
  }


  // ----------------------------------------------------------------------------------------------
  // Load by type

  /**
   * Reads an integer in Ruby Marshal's compact encoding:
   * `0` -> 0; `5..127` -> value + 5; `-128..-5` -> value - 5; otherwise the signed first
   * byte is a count of little-endian payload bytes (negative count = negative number).
   */
  def loadInt: Int = {
    val c = is.readByte.toInt

    if (c == 0) {
      0
    } else if (4 < c && c < 128) {
      c - 5
    } else if (-129 < c && c < -4) {
      c + 5
    } else if (c > 0) {
      // `c` payload bytes of a non-negative number.
      var result = 0L
      for (i <- 0 until c) {
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
      result.toInt
    } else {
      // `-c` payload bytes of a negative number: start from all ones (-1) and replace the
      // low-order bytes one at a time.
      var result = -1L
      for (i <- 0 until -c) {
        result &= ~((0xff << (8 * i)).toLong)
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
      result.toInt
    }
  }

  /** Loads a float and registers it so later `@` links can refer to it. */
  def loadAndRegisterFloat: Double = {
    val result = loadFloat
    registeredLinks += result
    result
  }

  /** Loads a float payload ("nan", "inf", "-inf" or a decimal string). */
  def loadFloat: Double = {
    val string = loadString
    string match {
      case "nan" => Double.NaN
      case "inf" => Double.PositiveInfinity
      case "-inf" => Double.NegativeInfinity
      case _ => string.toDouble
    }
  }

  /** Reads a length-prefixed byte string and decodes it as UTF-8. */
  def loadString: String = {
    new String(loadStringBytes, "UTF-8")
  }

  /**
   * Reads a length-prefixed byte string.
   *
   * @throws IllegalArgumentException if the stream ends early or the length is negative
   */
  def loadStringBytes: Array[Byte] = {
    val size = loadSize("String")
    val buffer = new Array[Byte](size)

    var readSize = 0
    while (readSize < size) {
      val read = is.read(buffer, readSize, size - readSize)

      if (read == -1) {
        throw new IllegalArgumentException("Marshal too short.")
      }

      readSize += read
    }

    buffer
  }

  /** Loads a symbol and registers it so later `;` links can refer to it. */
  def loadAndRegisterSymbol: String = {
    val result = loadString
    registeredSymbols += result
    result
  }

  /**
   * Loads an array. The array is registered before its elements are read, because Ruby
   * numbers objects in the order it starts writing them.
   */
  def loadAndRegisterArray: Array[Any] = {
    val size = loadSize("Array")
    val array = new Array[Any](size)

    registeredLinks += array

    for (i <- 0 until size) {
      array(i) = loadNextObject
    }

    array
  }

  /**
   * Loads a user-marshalled object ('U'): a class-name symbol followed by its payload, then
   * converts it to the matching Spark MLlib type.
   *
   * @throws IllegalArgumentException for an unsupported class name
   */
  def loadAndRegisterUserObject: Any = {
    val klass = loadNextObject.asInstanceOf[String]

    // Register future class before load the next object
    registeredLinks += WaitForObject()
    val index = registeredLinks.size - 1

    val data = loadNextObject

    val result = klass match {
      case "Spark::Mllib::LabeledPoint" => createLabeledPoint(data)
      case "Spark::Mllib::DenseVector" => createDenseVector(data)
      case "Spark::Mllib::SparseVector" => createSparseVector(data)
      case other =>
        throw new IllegalArgumentException(s"Object $other is not supported.")
    }

    registeredLinks(index) = result

    result
  }


  // ----------------------------------------------------------------------------------------------
  // Other loads

  /** Reads a type byte and returns either the linked object it refers to or a freshly loaded one. */
  def loadNextObject: Any = {
    val dataType = is.readUnsignedByte.toChar

    if (isLinkType(dataType)) {
      readLink(dataType)
    } else {
      load(dataType)
    }
  }


  // ----------------------------------------------------------------------------------------------
  // To java objects

  /** Builds a LabeledPoint from `[label, features]`; the label may be an Int or a Double. */
  def createLabeledPoint(data: Any): LabeledPoint = {
    val array = data.asInstanceOf[Array[_]]
    new LabeledPoint(labelToDouble(array(0)), array(1).asInstanceOf[Vector])
  }

  /** Builds a DenseVector from an array of numbers (see [[toDouble]] for the conversion). */
  def createDenseVector(data: Any): DenseVector = {
    new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
  }

  /** Builds a SparseVector from `[size, indices, values]`. */
  def createSparseVector(data: Any): SparseVector = {
    val array = data.asInstanceOf[Array[_]]
    val size = array(0).asInstanceOf[Int]
    val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
    val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))

    new SparseVector(size, indices, values)
  }


  // ----------------------------------------------------------------------------------------------
  // Helpers

  /** Converts an Int or Double to Double; any other value becomes 0.0. */
  def toDouble(data: Any): Double = data match {
    case x: Int => x.toDouble
    case x: Double => x
    case _ => 0.0
  }

  /** Strict numeric conversion for labels: unlike [[toDouble]], rejects non-numeric values. */
  private def labelToDouble(label: Any): Double = label match {
    case x: Double => x
    case x: Int => x.toDouble
    case other =>
      throw new IllegalArgumentException(s"LabeledPoint label must be numeric, got: $other.")
  }

  /** Reads a length/count field and rejects negative values, which only a corrupt stream produces. */
  private def loadSize(what: String): Int = {
    val size = loadInt
    if (size < 0) {
      throw new IllegalArgumentException(s"$what size cannot be negative: $size.")
    }
    size
  }


  // ----------------------------------------------------------------------------------------------
  // Cache

  /**
   * Resolves a link: `@` indexes `registeredLinks`, `;` indexes `registeredSymbols`.
   *
   * @throws IllegalArgumentException for any other type byte
   */
  def readLink(dataType: Char): Any = {
    val index = loadInt

    dataType match {
      case '@' => registeredLinks(index)
      case ';' => registeredSymbols(index)
      case other =>
        throw new IllegalArgumentException(s"Link type is not supported: $other.")
    }
  }

  /** True for the type bytes that refer back to an already-loaded object or symbol. */
  def isLinkType(dataType: Char): Boolean = {
    dataType == ';' || dataType == '@'
  }

}
|
@@ -0,0 +1,69 @@
|
|
1
|
+
package org.apache.spark.api.ruby
|
2
|
+
|
3
|
+
import java.io._
|
4
|
+
import java.net._
|
5
|
+
import java.util.{List, ArrayList}
|
6
|
+
|
7
|
+
import scala.collection.JavaConversions._
|
8
|
+
import scala.collection.immutable._
|
9
|
+
|
10
|
+
import org.apache.spark._
|
11
|
+
import org.apache.spark.util.Utils
|
12
|
+
|
13
|
+
/**
 * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
 * collects a list of serialized byte arrays that are passed to Ruby through a socket.
 *
 * @param serverHost host of the Ruby accumulator server; `null` where updates are only
 *                   collected locally (see `addInPlace`)
 * @param serverPort port of the Ruby accumulator server
 */
private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
  extends AccumulatorParam[List[Array[Byte]]] {

  // Utils.checkHost(serverHost, "Expected hostname")

  val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)

  // The socket and its streams must not be serialized with this param;
  // otherwise Spark fails with "SparkException: Task not serializable".
  @transient var socket: Socket = null
  @transient var socketOutputStream: DataOutputStream = null
  @transient var socketInputStream: DataInputStream = null

  /** Connects to the Ruby accumulator server unless an open connection already exists. */
  def openSocket(): Unit = {
    synchronized {
      if (socket == null || socket.isClosed) {
        socket = new Socket(serverHost, serverPort)

        socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
        socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
      }
    }
  }

  /**
   * Closes the current connection and forgets its streams so the next update reconnects.
   * Errors while closing are ignored on purpose: the connection is being discarded anyway.
   */
  private def closeSocket(): Unit = synchronized {
    if (socket != null) {
      try {
        socket.close()
      } catch {
        case _: IOException => // best effort; the socket is abandoned either way
      }
    }
    socket = null
    socketInputStream = null
    socketOutputStream = null
  }

  override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList[Array[Byte]]()

  /**
   * Merges `val2` into the accumulator.
   *
   * When `serverHost` is null the updates are appended to `val1` and `val1` is returned.
   * Otherwise `val2` is written to the Ruby server as `count` followed by
   * `(length, bytes)` pairs, and an empty list is returned.
   *
   * @throws IOException if the update cannot be sent; the connection is then dropped so the
   *                     next call starts on a fresh socket instead of a half-written stream
   */
  override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]): List[Array[Byte]] = synchronized {
    if (serverHost == null) {
      // This happens on the worker node, where we just want to remember all the updates
      val1.addAll(val2)
      val1
    } else {
      // This happens on the master, where we pass the updates to Ruby through a socket
      openSocket()

      try {
        socketOutputStream.writeInt(val2.size)
        val iterator = val2.iterator()
        while (iterator.hasNext) {
          val array = iterator.next()
          socketOutputStream.writeInt(array.length)
          socketOutputStream.write(array)
        }
        socketOutputStream.flush()
      } catch {
        case e: IOException =>
          // A partially written frame would desynchronize the Ruby reader if the same
          // socket were reused, so discard it before propagating the failure.
          closeSocket()
          throw e
      }

      // Wait for acknowledgement
      // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
      //
      // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
      //   throw new SparkException("Accumulator was not acknowledged")
      // }

      new ArrayList[Array[Byte]]()
    }
  }
}
|
@@ -0,0 +1,13 @@
|
|
1
|
+
package org.apache.spark.api.ruby
|
2
|
+
|
3
|
+
import org.apache.spark.api.python.PythonBroadcast
|
4
|
+
|
5
|
+
/**
 * Wrapper for a Ruby broadcast variable whose value Ruby has already written to disk.
 *
 * The original description of the inherited behaviour: after deserialization the data is
 * written to disk again, so that Ruby can read it from there.
 *
 * All logic is inherited from [[PythonBroadcast]]. This subclass only gives the Ruby
 * broadcast its own type and carries an extra `id`. Both constructor fields are
 * `@transient`, so they are not part of the serialized form.
 *
 * @param _path file the broadcast data was written to
 * @param id    identifier associated with this broadcast
 */
class RubyBroadcast(
    @transient var _path: String,
    @transient var id: java.lang.Long)
  extends PythonBroadcast(_path)
|
@@ -0,0 +1,13 @@
|
|
1
|
+
package org.apache.spark.api.ruby
|
2
|
+
|
3
|
+
/**
 * Integer codes used by the Scala side when talking to Ruby processes.
 *
 * These values form part of a wire protocol and are presumably mirrored on the Ruby side
 * (likely `lib/spark/constant.rb` — verify before changing any value).
 */
object RubyConstant {

  val DATA_EOF: Int = -2
  val WORKER_ERROR: Int = -1
  val WORKER_DONE: Int = 0

  val CREATE_WORKER: Int = 1
  val KILL_WORKER: Int = 2
  val KILL_WORKER_AND_WAIT: Int = 3
  val SUCCESSFULLY_KILLED: Int = 4
  val UNSUCCESSFUL_KILLING: Int = 5

  val ACCUMULATOR_ACK: Int = 6

}
|
@@ -0,0 +1,55 @@
|
|
1
|
+
package org.apache.spark.mllib.api.ruby
|
2
|
+
|
3
|
+
import java.util.ArrayList
|
4
|
+
|
5
|
+
import scala.collection.JavaConverters._
|
6
|
+
|
7
|
+
import org.apache.spark.rdd.RDD
|
8
|
+
import org.apache.spark.api.java.JavaRDD
|
9
|
+
import org.apache.spark.mllib.linalg._
|
10
|
+
import org.apache.spark.mllib.regression.LabeledPoint
|
11
|
+
import org.apache.spark.mllib.classification.NaiveBayes
|
12
|
+
import org.apache.spark.mllib.clustering.GaussianMixtureModel
|
13
|
+
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
|
14
|
+
import org.apache.spark.mllib.api.python.MLLibAPI
|
15
|
+
|
16
|
+
|
17
|
+
class RubyMLLibAPI extends MLLibAPI {
  // Methods taken from MLLibAPI unchanged (listed for reference):
  //   trainLinearRegressionModelWithSGD, trainLassoModelWithSGD, trainRidgeModelWithSGD,
  //   trainLogisticRegressionModelWithSGD, trainLogisticRegressionModelWithLBFGS,
  //   trainSVMModelWithSGD, trainKMeansModel, trainGaussianMixture

  /**
   * Trains a naive Bayes model and returns its parameters as a Java list
   * `[labels, pi, theta]`, where `labels` and `pi` are dense vectors.
   *
   * `theta` is handed back as a Seq of rows rather than `Array[Array[Double]]` because Rjb
   * has a problem with the nested array type.
   */
  override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
    val trained = NaiveBayes.train(data.rdd, lambda)

    val labels: Object = Vectors.dense(trained.labels)
    val priors: Object = Vectors.dense(trained.pi)
    val conditionals: Object = trained.theta.toSeq

    Seq(labels, priors, conditionals).asJava
  }

  /**
   * Rebuilds a Gaussian mixture model from its components and applies
   * `GaussianMixtureModel.predictSoft` to `data`.
   *
   * Component `i` is made from `mu(i)` and `si(i)`; the number of components is taken
   * from `wt`.
   *
   * @param data vectors to evaluate
   * @param wt   mixture weights, each element cast to Double (in the Python API this is a plain Object)
   * @param mu   component means, each element cast to DenseVector
   * @param si   component covariances, each element cast to DenseMatrix
   */
  def predictSoftGMM(
      data: JavaRDD[Vector],
      wt: ArrayList[Object],
      mu: ArrayList[Object],
      si: ArrayList[Object]): RDD[Array[Double]] = {

    val weights = wt.toArray.map(_.asInstanceOf[Double])
    val means = mu.toArray.map(_.asInstanceOf[DenseVector])
    val covariances = si.toArray.map(_.asInstanceOf[DenseMatrix])

    val components = weights.indices
      .map(k => new MultivariateGaussian(means(k), covariances(k)))
      .toArray

    new GaussianMixtureModel(weights, components).predictSoft(data)
  }
}
|