ruby-spark 1.0.0
- checksums.yaml +7 -0
- data/.gitignore +37 -0
- data/Gemfile +47 -0
- data/Guardfile +5 -0
- data/LICENSE.txt +22 -0
- data/README.md +185 -0
- data/Rakefile +35 -0
- data/TODO.md +7 -0
- data/benchmark/aggregate.rb +33 -0
- data/benchmark/bisect.rb +88 -0
- data/benchmark/custom_marshal.rb +94 -0
- data/benchmark/digest.rb +150 -0
- data/benchmark/enumerator.rb +88 -0
- data/benchmark/performance/prepare.sh +18 -0
- data/benchmark/performance/python.py +156 -0
- data/benchmark/performance/r.r +69 -0
- data/benchmark/performance/ruby.rb +167 -0
- data/benchmark/performance/run-all.sh +160 -0
- data/benchmark/performance/scala.scala +181 -0
- data/benchmark/serializer.rb +82 -0
- data/benchmark/sort.rb +43 -0
- data/benchmark/sort2.rb +164 -0
- data/benchmark/take.rb +28 -0
- data/bin/ruby-spark +8 -0
- data/example/pi.rb +28 -0
- data/ext/ruby_c/extconf.rb +3 -0
- data/ext/ruby_c/murmur.c +158 -0
- data/ext/ruby_c/murmur.h +9 -0
- data/ext/ruby_c/ruby-spark.c +18 -0
- data/ext/ruby_java/Digest.java +36 -0
- data/ext/ruby_java/Murmur2.java +98 -0
- data/ext/ruby_java/RubySparkExtService.java +28 -0
- data/ext/ruby_java/extconf.rb +3 -0
- data/ext/spark/build.sbt +73 -0
- data/ext/spark/project/plugins.sbt +9 -0
- data/ext/spark/sbt/sbt +34 -0
- data/ext/spark/src/main/scala/Exec.scala +91 -0
- data/ext/spark/src/main/scala/MLLibAPI.scala +4 -0
- data/ext/spark/src/main/scala/Marshal.scala +52 -0
- data/ext/spark/src/main/scala/MarshalDump.scala +113 -0
- data/ext/spark/src/main/scala/MarshalLoad.scala +220 -0
- data/ext/spark/src/main/scala/RubyAccumulatorParam.scala +69 -0
- data/ext/spark/src/main/scala/RubyBroadcast.scala +13 -0
- data/ext/spark/src/main/scala/RubyConstant.scala +13 -0
- data/ext/spark/src/main/scala/RubyMLLibAPI.scala +55 -0
- data/ext/spark/src/main/scala/RubyMLLibUtilAPI.scala +21 -0
- data/ext/spark/src/main/scala/RubyPage.scala +34 -0
- data/ext/spark/src/main/scala/RubyRDD.scala +364 -0
- data/ext/spark/src/main/scala/RubySerializer.scala +14 -0
- data/ext/spark/src/main/scala/RubyTab.scala +11 -0
- data/ext/spark/src/main/scala/RubyUtils.scala +15 -0
- data/ext/spark/src/main/scala/RubyWorker.scala +257 -0
- data/ext/spark/src/test/scala/MarshalSpec.scala +84 -0
- data/lib/ruby-spark.rb +1 -0
- data/lib/spark.rb +198 -0
- data/lib/spark/accumulator.rb +260 -0
- data/lib/spark/broadcast.rb +98 -0
- data/lib/spark/build.rb +43 -0
- data/lib/spark/cli.rb +169 -0
- data/lib/spark/command.rb +86 -0
- data/lib/spark/command/base.rb +154 -0
- data/lib/spark/command/basic.rb +345 -0
- data/lib/spark/command/pair.rb +124 -0
- data/lib/spark/command/sort.rb +51 -0
- data/lib/spark/command/statistic.rb +144 -0
- data/lib/spark/command_builder.rb +141 -0
- data/lib/spark/command_validator.rb +34 -0
- data/lib/spark/config.rb +244 -0
- data/lib/spark/constant.rb +14 -0
- data/lib/spark/context.rb +304 -0
- data/lib/spark/error.rb +50 -0
- data/lib/spark/ext/hash.rb +41 -0
- data/lib/spark/ext/integer.rb +25 -0
- data/lib/spark/ext/io.rb +57 -0
- data/lib/spark/ext/ip_socket.rb +29 -0
- data/lib/spark/ext/module.rb +58 -0
- data/lib/spark/ext/object.rb +24 -0
- data/lib/spark/ext/string.rb +24 -0
- data/lib/spark/helper.rb +10 -0
- data/lib/spark/helper/logger.rb +40 -0
- data/lib/spark/helper/parser.rb +85 -0
- data/lib/spark/helper/serialize.rb +71 -0
- data/lib/spark/helper/statistic.rb +93 -0
- data/lib/spark/helper/system.rb +42 -0
- data/lib/spark/java_bridge.rb +19 -0
- data/lib/spark/java_bridge/base.rb +203 -0
- data/lib/spark/java_bridge/jruby.rb +23 -0
- data/lib/spark/java_bridge/rjb.rb +41 -0
- data/lib/spark/logger.rb +76 -0
- data/lib/spark/mllib.rb +100 -0
- data/lib/spark/mllib/classification/common.rb +31 -0
- data/lib/spark/mllib/classification/logistic_regression.rb +223 -0
- data/lib/spark/mllib/classification/naive_bayes.rb +97 -0
- data/lib/spark/mllib/classification/svm.rb +135 -0
- data/lib/spark/mllib/clustering/gaussian_mixture.rb +82 -0
- data/lib/spark/mllib/clustering/kmeans.rb +118 -0
- data/lib/spark/mllib/matrix.rb +120 -0
- data/lib/spark/mllib/regression/common.rb +73 -0
- data/lib/spark/mllib/regression/labeled_point.rb +41 -0
- data/lib/spark/mllib/regression/lasso.rb +100 -0
- data/lib/spark/mllib/regression/linear.rb +124 -0
- data/lib/spark/mllib/regression/ridge.rb +97 -0
- data/lib/spark/mllib/ruby_matrix/matrix_adapter.rb +53 -0
- data/lib/spark/mllib/ruby_matrix/vector_adapter.rb +57 -0
- data/lib/spark/mllib/stat/distribution.rb +12 -0
- data/lib/spark/mllib/vector.rb +185 -0
- data/lib/spark/rdd.rb +1328 -0
- data/lib/spark/sampler.rb +92 -0
- data/lib/spark/serializer.rb +24 -0
- data/lib/spark/serializer/base.rb +170 -0
- data/lib/spark/serializer/cartesian.rb +37 -0
- data/lib/spark/serializer/marshal.rb +19 -0
- data/lib/spark/serializer/message_pack.rb +25 -0
- data/lib/spark/serializer/oj.rb +25 -0
- data/lib/spark/serializer/pair.rb +27 -0
- data/lib/spark/serializer/utf8.rb +25 -0
- data/lib/spark/sort.rb +189 -0
- data/lib/spark/stat_counter.rb +125 -0
- data/lib/spark/storage_level.rb +39 -0
- data/lib/spark/version.rb +3 -0
- data/lib/spark/worker/master.rb +144 -0
- data/lib/spark/worker/spark_files.rb +15 -0
- data/lib/spark/worker/worker.rb +197 -0
- data/ruby-spark.gemspec +36 -0
- data/spec/generator.rb +37 -0
- data/spec/inputs/lorem_300.txt +316 -0
- data/spec/inputs/numbers/1.txt +50 -0
- data/spec/inputs/numbers/10.txt +50 -0
- data/spec/inputs/numbers/11.txt +50 -0
- data/spec/inputs/numbers/12.txt +50 -0
- data/spec/inputs/numbers/13.txt +50 -0
- data/spec/inputs/numbers/14.txt +50 -0
- data/spec/inputs/numbers/15.txt +50 -0
- data/spec/inputs/numbers/16.txt +50 -0
- data/spec/inputs/numbers/17.txt +50 -0
- data/spec/inputs/numbers/18.txt +50 -0
- data/spec/inputs/numbers/19.txt +50 -0
- data/spec/inputs/numbers/2.txt +50 -0
- data/spec/inputs/numbers/20.txt +50 -0
- data/spec/inputs/numbers/3.txt +50 -0
- data/spec/inputs/numbers/4.txt +50 -0
- data/spec/inputs/numbers/5.txt +50 -0
- data/spec/inputs/numbers/6.txt +50 -0
- data/spec/inputs/numbers/7.txt +50 -0
- data/spec/inputs/numbers/8.txt +50 -0
- data/spec/inputs/numbers/9.txt +50 -0
- data/spec/inputs/numbers_0_100.txt +101 -0
- data/spec/inputs/numbers_1_100.txt +100 -0
- data/spec/lib/collect_spec.rb +42 -0
- data/spec/lib/command_spec.rb +68 -0
- data/spec/lib/config_spec.rb +64 -0
- data/spec/lib/context_spec.rb +163 -0
- data/spec/lib/ext_spec.rb +72 -0
- data/spec/lib/external_apps_spec.rb +45 -0
- data/spec/lib/filter_spec.rb +80 -0
- data/spec/lib/flat_map_spec.rb +100 -0
- data/spec/lib/group_spec.rb +109 -0
- data/spec/lib/helper_spec.rb +19 -0
- data/spec/lib/key_spec.rb +41 -0
- data/spec/lib/manipulation_spec.rb +114 -0
- data/spec/lib/map_partitions_spec.rb +87 -0
- data/spec/lib/map_spec.rb +91 -0
- data/spec/lib/mllib/classification_spec.rb +54 -0
- data/spec/lib/mllib/clustering_spec.rb +35 -0
- data/spec/lib/mllib/matrix_spec.rb +32 -0
- data/spec/lib/mllib/regression_spec.rb +116 -0
- data/spec/lib/mllib/vector_spec.rb +77 -0
- data/spec/lib/reduce_by_key_spec.rb +118 -0
- data/spec/lib/reduce_spec.rb +131 -0
- data/spec/lib/sample_spec.rb +46 -0
- data/spec/lib/serializer_spec.rb +13 -0
- data/spec/lib/sort_spec.rb +58 -0
- data/spec/lib/statistic_spec.rb +168 -0
- data/spec/lib/whole_text_files_spec.rb +33 -0
- data/spec/spec_helper.rb +39 -0
- metadata +301 -0
data/ext/spark/src/main/scala/MarshalDump.scala
@@ -0,0 +1,113 @@
package org.apache.spark.api.ruby.marshal

import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}


/* =================================================================================================
 * class MarshalDump
 * =================================================================================================
 */
class MarshalDump(os: DataOutputStream) {

  val NAN_BYTELIST = "nan".getBytes
  val NEGATIVE_INFINITY_BYTELIST = "-inf".getBytes
  val INFINITY_BYTELIST = "inf".getBytes

  def dump(data: Any) {
    data match {
      case null =>
        os.writeByte('0')

      case item: Boolean =>
        val char = if(item) 'T' else 'F'
        os.writeByte(char)

      case item: Int =>
        os.writeByte('i')
        dumpInt(item)

      case item: Array[_] =>
        os.writeByte('[')
        dumpArray(item)

      case item: Double =>
        os.writeByte('f')
        dumpFloat(item)

      case item: ArrayBuffer[Any] => dump(item.toArray)
    }
  }

  def dumpInt(data: Int) {
    if(data == 0){
      os.writeByte(0)
    }
    else if (0 < data && data < 123) {
      os.writeByte(data + 5)
    }
    else if (-124 < data && data < 0) {
      os.writeByte((data - 5) & 0xff)
    }
    else {
      val buffer = new Array[Byte](4)
      var value = data

      var i = 0
      while(i != 4 && value != 0 && value != -1){
        buffer(i) = (value & 0xff).toByte
        value = value >> 8

        i += 1
      }
      val length = i
      if(value < 0){
        os.writeByte(-length)
      }
      else{
        os.writeByte(length)
      }
      os.write(buffer, 0, length)
    }
  }

  def dumpArray(array: Array[_]) {
    dumpInt(array.size)

    for(item <- array) {
      dump(item)
    }
  }

  def dumpFloat(value: Double) {
    if(value.isPosInfinity){
      dumpString(INFINITY_BYTELIST)
    }
    else if(value.isNegInfinity){
      dumpString(NEGATIVE_INFINITY_BYTELIST)
    }
    else if(value.isNaN){
      dumpString(NAN_BYTELIST)
    }
    else{
      // dumpString("%.17g".format(value))
      dumpString(value.toString)
    }
  }

  def dumpString(data: String) {
    dumpString(data.getBytes)
  }

  def dumpString(data: Array[Byte]) {
    dumpInt(data.size)
    os.write(data)
  }

}
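For orientation, a minimal sketch (not part of the gem) of what the integer encoding in MarshalDump produces, assuming the class above is compiled and on the classpath: zero is a single 0x00 byte, small positive values are written as value + 5, small negative values as (value - 5) & 0xff, and everything else as a signed length byte followed by little-endian bytes. The helper name bytesOf is made up for the illustration.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.api.ruby.marshal.MarshalDump

object DumpIntDemo extends App {
  // Collect the raw bytes that dumpInt writes for a given Int
  def bytesOf(n: Int): List[Int] = {
    val bos = new ByteArrayOutputStream()
    new MarshalDump(new DataOutputStream(bos)).dumpInt(n)
    bos.toByteArray.map(_ & 0xff).toList
  }

  println(bytesOf(0))    // List(0)         -- zero is a single byte
  println(bytesOf(10))   // List(15)        -- small positives: value + 5
  println(bytesOf(-10))  // List(241)       -- small negatives: (value - 5) & 0xff
  println(bytesOf(300))  // List(2, 44, 1)  -- length byte, then little-endian bytes
}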
data/ext/spark/src/main/scala/MarshalLoad.scala
@@ -0,0 +1,220 @@
package org.apache.spark.api.ruby.marshal

import java.io.{DataInputStream, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector}


/* =================================================================================================
 * class MarshalLoad
 * =================================================================================================
 */
class MarshalLoad(is: DataInputStream) {

  case class WaitForObject()

  val registeredSymbols = ArrayBuffer[String]()
  val registeredLinks = ArrayBuffer[Any]()

  def load: Any = {
    load(is.readUnsignedByte.toChar)
  }

  def load(dataType: Char): Any = {
    dataType match {
      case '0' => null
      case 'T' => true
      case 'F' => false
      case 'i' => loadInt
      case 'f' => loadAndRegisterFloat
      case ':' => loadAndRegisterSymbol
      case '[' => loadAndRegisterArray
      case 'U' => loadAndRegisterUserObject
      case _ =>
        throw new IllegalArgumentException(s"Format is not supported: $dataType.")
    }
  }


  // ----------------------------------------------------------------------------------------------
  // Load by type

  def loadInt: Int = {
    var c = is.readByte.toInt

    if (c == 0) {
      return 0
    } else if (4 < c && c < 128) {
      return c - 5
    } else if (-129 < c && c < -4) {
      return c + 5
    }

    var result: Long = 0

    if (c > 0) {
      result = 0
      for( i <- 0 until c ) {
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
    } else {
      c = -c
      result = -1
      for( i <- 0 until c ) {
        result &= ~((0xff << (8 * i)).toLong)
        result |= (is.readUnsignedByte << (8 * i)).toLong
      }
    }

    result.toInt
  }

  def loadAndRegisterFloat: Double = {
    val result = loadFloat
    registeredLinks += result
    result
  }

  def loadFloat: Double = {
    val string = loadString
    string match {
      case "nan"  => Double.NaN
      case "inf"  => Double.PositiveInfinity
      case "-inf" => Double.NegativeInfinity
      case _ => string.toDouble
    }
  }

  def loadString: String = {
    new String(loadStringBytes)
  }

  def loadStringBytes: Array[Byte] = {
    val size = loadInt
    val buffer = new Array[Byte](size)

    var readSize = 0
    while(readSize < size){
      val read = is.read(buffer, readSize, size-readSize)

      if(read == -1){
        throw new IllegalArgumentException("Marshal too short.")
      }

      readSize += read
    }

    buffer
  }

  def loadAndRegisterSymbol: String = {
    val result = loadString
    registeredSymbols += result
    result
  }

  def loadAndRegisterArray: Array[Any] = {
    val size = loadInt
    val array = new Array[Any](size)

    registeredLinks += array

    for( i <- 0 until size ) {
      array(i) = loadNextObject
    }

    array
  }

  def loadAndRegisterUserObject: Any = {
    val klass = loadNextObject.asInstanceOf[String]

    // Register the future class before loading the next object
    registeredLinks += WaitForObject()
    val index = registeredLinks.size - 1

    val data = loadNextObject

    val result = klass match {
      case "Spark::Mllib::LabeledPoint" => createLabeledPoint(data)
      case "Spark::Mllib::DenseVector"  => createDenseVector(data)
      case "Spark::Mllib::SparseVector" => createSparseVector(data)
      case other =>
        throw new IllegalArgumentException(s"Object $other is not supported.")
    }

    registeredLinks(index) = result

    result
  }


  // ----------------------------------------------------------------------------------------------
  // Other loads

  def loadNextObject: Any = {
    val dataType = is.readUnsignedByte.toChar

    if(isLinkType(dataType)){
      readLink(dataType)
    }
    else{
      load(dataType)
    }
  }


  // ----------------------------------------------------------------------------------------------
  // To java objects

  def createLabeledPoint(data: Any): LabeledPoint = {
    val array = data.asInstanceOf[Array[_]]
    new LabeledPoint(array(0).asInstanceOf[Double], array(1).asInstanceOf[Vector])
  }

  def createDenseVector(data: Any): DenseVector = {
    new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
  }

  def createSparseVector(data: Any): SparseVector = {
    val array = data.asInstanceOf[Array[_]]
    val size = array(0).asInstanceOf[Int]
    val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
    val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))

    new SparseVector(size, indices, values)
  }


  // ----------------------------------------------------------------------------------------------
  // Helpers

  def toDouble(data: Any): Double = data match {
    case x: Int => x.toDouble
    case x: Double => x
    case _ => 0.0
  }


  // ----------------------------------------------------------------------------------------------
  // Cache

  def readLink(dataType: Char): Any = {
    val index = loadInt

    dataType match {
      case '@' => registeredLinks(index)
      case ';' => registeredSymbols(index)
    }
  }

  def isLinkType(dataType: Char): Boolean = {
    dataType == ';' || dataType == '@'
  }

}
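A small round-trip sketch (illustrative only, again assuming both classes above are on the classpath). It exercises only the value encoding handled by MarshalDump and MarshalLoad; it does not cover the Ruby Marshal version header or the gem's stream framing.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import org.apache.spark.api.ruby.marshal.{MarshalDump, MarshalLoad}

object MarshalRoundTrip extends App {
  // Dump a mixed array with MarshalDump...
  val out = new ByteArrayOutputStream()
  new MarshalDump(new DataOutputStream(out)).dump(Array[Any](1, 2.5, true, null))

  // ...and read it back with MarshalLoad
  val in = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
  val loaded = new MarshalLoad(in).load.asInstanceOf[Array[Any]]

  println(loaded.mkString(", "))  // expected: 1, 2.5, true, null
}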
data/ext/spark/src/main/scala/RubyAccumulatorParam.scala
@@ -0,0 +1,69 @@
package org.apache.spark.api.ruby

import java.io._
import java.net._
import java.util.{List, ArrayList}

import scala.collection.JavaConversions._
import scala.collection.immutable._

import org.apache.spark._
import org.apache.spark.util.Utils

/**
 * Internal class that acts as an `AccumulatorParam` for Ruby accumulators. Inside, it
 * collects a list of serialized byte arrays that we pass to Ruby through a socket.
 */
private class RubyAccumulatorParam(serverHost: String, serverPort: Int)
  extends AccumulatorParam[List[Array[Byte]]] {

  // Utils.checkHost(serverHost, "Expected hostname")

  val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)

  // The socket should not be serialized.
  // Otherwise: SparkException: Task not serializable
  @transient var socket: Socket = null
  @transient var socketOutputStream: DataOutputStream = null
  @transient var socketInputStream: DataInputStream = null

  def openSocket(){
    synchronized {
      if (socket == null || socket.isClosed) {
        socket = new Socket(serverHost, serverPort)

        socketInputStream = new DataInputStream(new BufferedInputStream(socket.getInputStream, bufferSize))
        socketOutputStream = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
      }
    }
  }

  override def zero(value: List[Array[Byte]]): List[Array[Byte]] = new ArrayList

  override def addInPlace(val1: List[Array[Byte]], val2: List[Array[Byte]]): List[Array[Byte]] = synchronized {
    if (serverHost == null) {
      // This happens on the worker node, where we just want to remember all the updates
      val1.addAll(val2)
      val1
    } else {
      // This happens on the master, where we pass the updates to Ruby through a socket
      openSocket()

      socketOutputStream.writeInt(val2.size)
      for (array <- val2) {
        socketOutputStream.writeInt(array.length)
        socketOutputStream.write(array)
      }
      socketOutputStream.flush()

      // Wait for acknowledgement
      // http://stackoverflow.com/questions/28560133/ruby-server-java-scala-client-deadlock
      //
      // if(in.readInt() != RubyConstant.ACCUMULATOR_ACK){
      //   throw new SparkException("Accumulator was not acknowledged")
      // }

      new ArrayList
    }
  }
}
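The update batch that addInPlace sends to the Ruby side is plain length-prefixed framing: the number of items, then each item as its byte length followed by the bytes. The reader below is only a sketch of that framing; the real receiving end is the Ruby accumulator server in the gem, not this Scala code.

import java.io.DataInputStream

object AccumulatorFraming {
  // Reads one batch in the format written by addInPlace above
  def readBatch(in: DataInputStream): Seq[Array[Byte]] = {
    val count = in.readInt()
    (0 until count).map { _ =>
      val item = new Array[Byte](in.readInt())
      in.readFully(item)
      item
    }
  }
}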
data/ext/spark/src/main/scala/RubyBroadcast.scala
@@ -0,0 +1,13 @@
package org.apache.spark.api.ruby

import org.apache.spark.api.python.PythonBroadcast

/**
 * A wrapper for a Ruby broadcast variable, which is written to disk by Ruby. It will also
 * write the data back to disk after deserialization, so that Ruby can read it from disk.
 *
 * The class reuses the Python broadcast logic; the separate name is only for semantics.
 */
class RubyBroadcast(@transient var _path: String, @transient var id: java.lang.Long) extends PythonBroadcast(_path) {

}
data/ext/spark/src/main/scala/RubyConstant.scala
@@ -0,0 +1,13 @@
package org.apache.spark.api.ruby

object RubyConstant {
  val DATA_EOF = -2
  val WORKER_ERROR = -1
  val WORKER_DONE = 0
  val CREATE_WORKER = 1
  val KILL_WORKER = 2
  val KILL_WORKER_AND_WAIT = 3
  val SUCCESSFULLY_KILLED = 4
  val UNSUCCESSFUL_KILLING = 5
  val ACCUMULATOR_ACK = 6
}
data/ext/spark/src/main/scala/RubyMLLibAPI.scala
@@ -0,0 +1,55 @@
package org.apache.spark.mllib.api.ruby

import java.util.ArrayList

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.api.python.MLLibAPI


class RubyMLLibAPI extends MLLibAPI {
  // trainLinearRegressionModelWithSGD
  // trainLassoModelWithSGD
  // trainRidgeModelWithSGD
  // trainLogisticRegressionModelWithSGD
  // trainLogisticRegressionModelWithLBFGS
  // trainSVMModelWithSGD
  // trainKMeansModel
  // trainGaussianMixture

  // Rjb has a problem with theta: Array[Array[Double]]
  override def trainNaiveBayes(data: JavaRDD[LabeledPoint], lambda: Double) = {
    val model = NaiveBayes.train(data.rdd, lambda)

    List(
      Vectors.dense(model.labels),
      Vectors.dense(model.pi),
      model.theta.toSeq
    ).map(_.asInstanceOf[Object]).asJava
  }

  // In Python, wt is just an Object
  def predictSoftGMM(
      data: JavaRDD[Vector],
      wt: ArrayList[Object],
      mu: ArrayList[Object],
      si: ArrayList[Object]): RDD[Array[Double]] = {

    // val weight = wt.asInstanceOf[Array[Double]]
    val weight = wt.toArray.map(_.asInstanceOf[Double])
    val mean = mu.toArray.map(_.asInstanceOf[DenseVector])
    val sigma = si.toArray.map(_.asInstanceOf[DenseMatrix])
    val gaussians = Array.tabulate(weight.length){
      i => new MultivariateGaussian(mean(i), sigma(i))
    }
    val model = new GaussianMixtureModel(weight, gaussians)
    model.predictSoft(data)
  }
}
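A hypothetical driver-side sketch of calling the API above directly from Scala. In the gem these methods are reached from Ruby through JRuby or Rjb, so this is only for orientation; it assumes an existing SparkContext.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.api.ruby.RubyMLLibAPI
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object RubyMLLibAPIDemo {
  def run(sc: SparkContext): Unit = {
    val points = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))
    )).toJavaRDD()

    // Returns [labels, pi, theta] packed as a java.util.List for the Ruby side
    val result = new RubyMLLibAPI().trainNaiveBayes(points, 1.0)
    println(result)
  }
}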