cassandra_model_spark 0.0.1.5-java → 0.0.4-java

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,108 @@
+ package org.apache.spark.api.cassandra_model
+
+ import org.luaj.vm2._
+ import org.luaj.vm2.lib._
+ import org.apache.spark.sql.types._
+ import org.apache.spark.sql._
+
+ class LuaRowLib extends TwoArgFunction {
+   override def call(mod_name: LuaValue, env: LuaValue): LuaValue = {
+     val fn_table = new LuaTable()
+
+     fn_table.set("new", new newrow())
+     fn_table.set("append", new append())
+     fn_table.set("replace", new replace())
+     fn_table.set("slice", new slice())
+
+     env.set("row", fn_table)
+     fn_table
+   }
+
+   private def toLuaRowValue(lua_row: LuaValue): LuaRowValue = lua_row match {
+     case row: LuaRowValue => row
+   }
+
+   private def toLuaString(lua_key: LuaValue): String = lua_key match {
+     case str: LuaString => str.toString()
+   }
+
+   private def convertedValue(lua_value: LuaValue): Any = lua_value match {
+     case str: LuaString => str.toString()
+     case num: LuaInteger => num.toint()
+     case dfnum: LuaDouble => dfnum.todouble()
+   }
+
+   private def guessedDataType(value: Any): DataType = value match {
+     case str: String => StringType
+     case num: Int => IntegerType
+     case dfnum: Double => DoubleType
+   }
+
+   class newrow extends LibFunction {
+     override def call(): LuaValue = {
+       val new_fields: Array[StructField] = Array()
+       val new_schema = StructType(new_fields)
+       val new_values: Seq[Any] = Seq()
+       val new_row = Row.fromSeq(new_values)
+
+       new LuaRowValue(new_schema, new_row)
+     }
+   }
+
+   class append extends LibFunction {
+     override def call(lua_row: LuaValue, lua_key: LuaValue, lua_value: LuaValue): LuaValue = {
+       val row = toLuaRowValue(lua_row)
+       val key = toLuaString(lua_key)
+       val value = convertedValue(lua_value)
+       val data_type = guessedDataType(value)
+       val fields = row.schema.fields :+ StructField(key, data_type)
+       val new_schema = StructType(fields)
+       val new_values = row.row.toSeq :+ value
+       val new_row = Row.fromSeq(new_values)
+
+       new LuaRowValue(new_schema, new_row)
+     }
+   }
+
+   class replace extends LibFunction {
+     override def call(lua_row: LuaValue, lua_key: LuaValue, lua_value: LuaValue): LuaValue = {
+       val row = toLuaRowValue(lua_row)
+       val key = toLuaString(lua_key)
+       val value = convertedValue(lua_value)
+       val data_type = guessedDataType(value)
+       val schema = row.schema
+       val column_index = schema.fieldIndex(key)
+       val new_values = row.row.toSeq.updated(column_index, value)
+       val new_row = Row.fromSeq(new_values)
+
+       new LuaRowValue(schema, new_row)
+     }
+   }
+
+   class slice extends LibFunction {
+     override def call(lua_row: LuaValue, lua_keys: LuaValue): LuaValue = {
+       val row = toLuaRowValue(lua_row)
+       val key_list = toLuaTable(lua_keys)
+       val keys = tableToArray(key_list)
+       val schema = row.schema
+       val new_schema = StructType(keys.map(schema(_)))
+       val field_indices = keys.map(schema.fieldIndex(_))
+       val new_values = field_indices.map(row.row(_))
+       val new_row = Row.fromSeq(new_values)
+
+       new LuaRowValue(new_schema, new_row)
+     }
+
+     private def toLuaTable(lua_keys: LuaValue): LuaTable = lua_keys match {
+       case list: LuaTable => list
+     }
+
+     private def tableToArray(key_list: LuaValue): IndexedSeq[String] = (1 to key_list.length).map {
+       index: Int => key_list.get(index) match {
+         case str: LuaString => str.toString()
+       }
+     }
+   }
+
+ }
+
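
The new LuaRowLib above exposes a Lua-side `row` table (new/append/replace/slice) whose values wrap a Spark Row plus its StructType. A minimal usage sketch follows; it assumes luaj's JsePlatform is on the classpath, that LuaRowValue is the wrapper class defined elsewhere in this package, and the object and chunk names here are illustrative only.

    import org.luaj.vm2.lib.jse.JsePlatform
    import org.apache.spark.api.cassandra_model.LuaRowLib

    object LuaRowLibSketch {
      def main(args: Array[String]): Unit = {
        // Installing the library runs LuaRowLib.call, which registers the "row" table.
        val globals = JsePlatform.standardGlobals()
        globals.load(new LuaRowLib())

        // Build a row from Lua: start empty, then append typed columns.
        val script =
          """
          local r = row.new()
          r = row.append(r, 'name', 'alice')  -- appended as a StringType column
          r = row.append(r, 'age', 42)        -- appended as an IntegerType column
          return r
          """
        val result = globals.load(script, "row_sketch.lua").call()

        // result is a LuaRowValue carrying the accumulated schema and Row.
        println(result)
      }
    }
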
@@ -2,7 +2,7 @@ package org.apache.spark.api.cassandra_model
 
  import scala.collection.mutable._
 
- class MarshalLoader (dump: Array[Byte]) {
+ class MarshalLoader(dump: Array[Byte]) {
  private val bytes: Array[Byte] = dump
  private var parse_index: Int = 0
  private var symbol_table: List[String] = List()
@@ -34,7 +34,7 @@ class MarshalLoader (dump: Array[Byte]) {
  var bit: Int = 0
  var value: Int = 0
 
- for (bit <- 0 to num_bytes-1) {
+ for (bit <- 0 to num_bytes - 1) {
  val next_value = 0xff & nextByte()
  value += (next_value << (bit * 8))
  }
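
For reference, the loop in this hunk accumulates a little-endian integer: each byte is masked to 0..255 and shifted left by 8 bits per position. A standalone sketch of the same arithmetic (decodeIntBytes is illustrative, not part of MarshalLoader):

    object LittleEndianSketch {
      // Same accumulation as the MarshalLoader loop, over an explicit byte array.
      def decodeIntBytes(bytes: Array[Byte]): Int = {
        var value = 0
        for (bit <- 0 to bytes.length - 1) {
          val next_value = 0xff & bytes(bit)
          value += (next_value << (bit * 8))
        }
        value
      }

      def main(args: Array[String]): Unit = {
        println(decodeIntBytes(Array[Byte](0x01, 0x02)))  // 0x0201 = 513
        println(decodeIntBytes(Array[Byte](-1, 0x00)))    // 0xff = 255
      }
    }
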
@@ -145,7 +145,7 @@ class MarshalLoader (dump: Array[Byte]) {
  val length = decodeInt()
 
  var item = 0
- for (item <- 0 to length-1) {
+ for (item <- 0 to length - 1) {
  val key = decodeAny()
  val value = decodeAny()
  result(key) = value
@@ -160,7 +160,7 @@ class MarshalLoader (dump: Array[Byte]) {
  val length = decodeInt()
 
  var item = 0
- for (item <- 0 to length-1) {
+ for (item <- 0 to length - 1) {
  val value = decodeAny()
  list_result :+= value
  }
@@ -171,9 +171,9 @@ class MarshalLoader (dump: Array[Byte]) {
  }
 
  private def decodeObjectReference(): AnyRef = {
- val index = decodeInt()-1
+ val index = decodeInt() - 1
 
- object_table(index)
+  object_table(index)
  }
 
  private def decodeAny(): AnyRef = {
@@ -18,7 +18,15 @@ object MapStringStringRowMapping {
  val value = decoder.getValue()
 
  value match {
- case (m: Map[_, _]) => m map { case (key, value) => (String.valueOf(key), String.valueOf(value)) }
+ case (m: Map[_, _]) => m map {
+ case (key, value) => {
+ val new_value = value match {
+ case Some(some) => String.valueOf(some)
+ case None => null
+ }
+ (String.valueOf(key), new_value)
+ }
+ }
  case _ => new IllegalArgumentException("Unsupported Ruby Type")
  }
  } else {
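
The reworked mapping in this hunk unwraps Option values produced by the Ruby marshal decoder, turning Some into its string form and None into null before building the string map. A self-contained sketch of that behavior (the object and helper names are illustrative):

    object OptionFlatteningSketch {
      // Mirrors the pattern above: stringify keys, unwrap Option values, map None to null.
      def stringify(decoded: Map[Any, Any]): Map[String, String] =
        decoded.map {
          case (key, value) =>
            val newValue = value match {
              case Some(some) => String.valueOf(some)
              case None => null
              case other => String.valueOf(other)
            }
            (String.valueOf(key), newValue)
        }

      def main(args: Array[String]): Unit = {
        val decoded: Map[Any, Any] = Map("name" -> Some("alice"), "nickname" -> None)
        println(stringify(decoded)) // Map(name -> alice, nickname -> null)
      }
    }
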
@@ -28,7 +36,7 @@ object MapStringStringRowMapping {
 
  private def updatedRow(row: CassandraRow): CassandraRow = {
  val columns = row.columnNames
- val values = row.columnValues.map{
+ val values = row.columnValues.map {
  value => value match {
  case (blob: Array[Byte]) => decodeValue(blob)
  case _ => value
@@ -67,7 +75,7 @@ object SparkRowRowMapping {
 
  private def updatedRow(row: CassandraRow): CassandraRow = {
  val columns = row.columnNames
- val values = row.columnValues.map{
+ val values = row.columnValues.map {
  value => value match {
  case (blob: Array[Byte]) => decodeValue(blob)
  case _ => value
@@ -0,0 +1,20 @@
+ package org.apache.spark.api.cassandra_model
+
+ import org.apache.spark.rdd._
+ import com.datastax.spark.connector._
+ import com.datastax.spark.connector.rdd._
+ import org.apache.spark.sql._
+
+ object RowConversions {
+   def cassandraRDDToRowRDD(rdd: RDD[CassandraRow]): RDD[Row] = {
+     rdd.map(row => Row.fromSeq(cassandraToRow(row)))
+   }
+
+   private def cassandraToRow(row: CassandraRow): Seq[Any] = {
+     row.columnValues.map {
+       case (date: java.util.Date) => new java.sql.Timestamp(date.getTime())
+       case (uuid: java.util.UUID) => uuid.toString()
+       case value => value
+     }
+   }
+ }
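
The new RowConversions object normalizes driver types for Spark SQL: java.util.Date becomes java.sql.Timestamp and UUIDs become Strings. A hedged usage sketch, assuming a reachable Cassandra node, an existing keyspace and table, and the connector's cassandraTable implicit; the names and host address are placeholders:

    import org.apache.spark.{SparkConf, SparkContext}
    import com.datastax.spark.connector._
    import org.apache.spark.api.cassandra_model.RowConversions

    object RowConversionsSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("row-conversions-sketch")
          .setMaster("local[*]")                               // placeholder master
          .set("spark.cassandra.connection.host", "127.0.0.1") // placeholder host
        val sc = new SparkContext(conf)

        // cassandraTable comes from the connector implicits; it yields an RDD[CassandraRow].
        val cassandraRows = sc.cassandraTable("my_keyspace", "my_table")

        // Dates become java.sql.Timestamp and UUIDs become Strings, ready for a DataFrame.
        val sqlRows = RowConversions.cassandraRDDToRowRDD(cassandraRows)
        sqlRows.take(5).foreach(println)

        sc.stop()
      }
    }
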
@@ -38,5 +38,7 @@ require 'cassandra_model_spark/raw_connection'
  require 'cassandra_model_spark/connection_cache'
  require 'cassandra_model_spark/record'
  require 'cassandra_model_spark/query_builder'
+ require 'cassandra_model_spark/sql_schema'
+ require 'cassandra_model_spark/schema'
  require 'cassandra_model_spark/data_frame'
  require 'cassandra_model_spark/column_cast'
@@ -1,9 +1,9 @@
  module CassandraModel
  class ConnectionCache
  def self.clear
- @@cache.values.map(&:java_spark_context).map(&:stop)
+ @@cache.values.select(&:has_spark_context?).map(&:java_spark_context).map(&:stop)
  @@cache.values.map(&:shutdown)
  @@cache.clear
  end
  end
- end
+ end
@@ -1,13 +1,14 @@
  module CassandraModel
  module Spark
+ #noinspection RubyStringKeysInHashInspection
  class DataFrame
  include QueryHelper
 
  SQL_TYPE_MAP = {
- int: SqlIntegerType,
- text: SqlStringType,
- double: SqlDoubleType,
- timestamp: SqlTimestampType,
+ int: Lib::SqlIntegerType,
+ text: Lib::SqlStringType,
+ double: Lib::SqlDoubleType,
+ timestamp: Lib::SqlTimestampType,
  }.freeze
  #noinspection RubyStringKeysInHashInspection
  SQL_RUBY_TYPE_FUNCTIONS = {
@@ -21,6 +22,29 @@ module CassandraModel
 
  attr_reader :table_name, :record_klass
 
+ class << self
+ def from_csv(record_klass, path, options = {})
+ sql_context = options.delete(:sql_context) || create_sql_context(record_klass)
+ updated_options = csv_options(options)
+ csv_frame = sql_context.read.format('com.databricks.spark.csv').options(updated_options).load(path)
+
+ table_name = File.basename(path).gsub(/\./, '_') + "_#{SecureRandom.hex(2)}"
+ new(record_klass, nil, spark_data_frame: csv_frame, alias: table_name)
+ end
+
+ def create_sql_context(record_klass)
+ Lib::CassandraSQLContext.new(record_klass.table.connection.spark_context).tap do |context|
+ context.setKeyspace(record_klass.table.connection.config[:keyspace])
+ end
+ end
+
+ def csv_options(options)
+ options.inject('header' => 'true') do |memo, (key, value)|
+ memo.merge!(key.to_s.camelize(:lower) => value)
+ end.to_java
+ end
+ end
+
  def initialize(record_klass, rdd, options = {})
  @table_name = options.fetch(:alias) { record_klass.table_name }
  @sql_context = options[:sql_context]
@@ -36,7 +60,7 @@ module CassandraModel
  end
 
  def sql_context
- @sql_context ||= create_sql_context
+ @sql_context ||= self.class.create_sql_context(record_klass)
  end
 
  def union(rhs)
@@ -47,19 +71,9 @@ module CassandraModel
  end
 
  def spark_data_frame
- @frame ||= SparkSchemaBuilder.new.tap do |builder|
- record_klass.cassandra_columns.each do |name, type|
- select_name = record_klass.normalized_column(name)
- mapped_type = row_type_mapping[select_name]
- type = if mapped_type
- name = mapped_type[:name]
- mapped_type[:type]
- else
- SQL_TYPE_MAP.fetch(type) { SqlStringType }
- end
- builder.add_column(name.to_s, type)
- end
- end.create_data_frame(sql_context, rdd).tap { |frame| frame.register_temp_table(table_name.to_s) }
+ @frame ||= sql_context.createDataFrame(converted_rdd, record_klass.sql_schema.schema).tap do |frame|
+ frame.register_temp_table(table_name.to_s)
+ end
  end
 
  def cache
@@ -104,6 +118,12 @@ module CassandraModel
 
  end
 
+ def sql_frame(query, options)
+ spark_data_frame
+ new_frame = sql_context.sql(query)
+ self.class.new(options.delete(:class) || record_klass, nil, options.merge(spark_data_frame: new_frame))
+ end
+
  def query(restriction, options)
  spark_data_frame
  select_clause = select_columns(options)
@@ -129,6 +149,23 @@ module CassandraModel
  row_to_record(query.schema, row)
  end
 
+ def to_csv(path, options = {})
+ updated_options = csv_options(options)
+ spark_data_frame.write.format('com.databricks.spark.csv').options(updated_options).save(path)
+ end
+
+ def save_to(save_record_klass)
+ #noinspection RubyStringKeysInHashInspection
+ java_options = save_options_for_model(save_record_klass)
+
+ available_columns = spark_data_frame.schema.fields.map(&:name).map(&:to_sym)
+ column_map = save_record_klass.denormalized_column_map(available_columns)
+
+ save_frame = frame_to_save(available_columns, column_map)
+ save_frame(java_options, save_frame)
+ save_truth_table(column_map, java_options, save_record_klass)
+ end
+
  def ==(rhs)
  rhs.is_a?(DataFrame) &&
  record_klass == rhs.record_klass &&
@@ -152,16 +189,16 @@ module CassandraModel
 
  def initialize_rdd(rdd)
  if rdd
- @rdd = if @row_mapping[:mapper]
- @row_mapping[:mapper].mappedRDD(rdd)
- else
- rdd
- end
+ @rdd = rdd
  else
  @derived = true
  end
  end
 
+ def converted_rdd
+ Lib::SqlRowConversions.cassandraRDDToRowRDD(rdd)
+ end
+
  def initialize_row_mapping(options)
  @row_mapping = options.fetch(:row_mapping) do
  @record_klass.rdd_row_mapping || {}
@@ -172,12 +209,6 @@ module CassandraModel
  @row_mapping[:type_map] ||= {}
  end
 
- def create_sql_context
- CassandraSQLContext.new(record_klass.table.connection.spark_context).tap do |context|
- context.setKeyspace(record_klass.table.connection.config[:keyspace])
- end
- end
-
  def row_to_record(schema, row)
  attributes = row_attributes(row, schema)
 
@@ -204,25 +235,33 @@ module CassandraModel
  end
 
  def field_value(field, index, row)
- data_type = field.data_type
+ data_type = field.dataType
  if column_is_struct?(data_type)
  row_attributes(row.get(index), data_type)
  else
- decode_column_value(data_type, index, row)
+ decode_column_value(field, index, row)
  end
  end
 
- def decode_column_value(data_type, index, row)
- sql_type = data_type.to_string
+ def decode_column_value(field, index, row)
+ sql_type = field.dataType.toString
  converter = SQL_RUBY_TYPE_FUNCTIONS.fetch(sql_type) { :getString }
  value = row.public_send(converter, index)
 
+ data_column_name = record_klass.select_column(field.name.to_sym)
+ case record_klass.cassandra_columns[data_column_name]
+ when :uuid
+ value = Cassandra::Uuid.new(value)
+ when :timeuuid
+ value = Cassandra::TimeUuid.new(value)
+ end
+
  value = decode_hash(value) if column_is_string_map?(sql_type)
  value
  end
 
  def decode_hash(value)
- Hash[value.toSeq.array.to_a.map! { |pair| [pair._1.to_string, pair._2.to_string] }]
+ Hash[value.toSeq.array.to_a.map! { |pair| [pair._1.toString, pair._2.toString] }]
  end
 
  def column_is_string_map?(sql_type)
@@ -297,8 +336,12 @@ module CassandraModel
  '*'
  elsif column.respond_to?(:quote)
  column.quote('`')
- else
+ elsif column.is_a?(Symbol)
  "`#{select_column(column)}`"
+ elsif column.is_a?(String)
+ "'#{column.gsub(/'/, "\\\\'")}'"
+ else
+ column
  end
  end
 
@@ -369,6 +412,49 @@ module CassandraModel
  ThomasUtils::KeyComparer.new(updated_key, '=').quote('`')
  end
 
+ def frame_to_save(available_columns, column_map)
+ if available_columns == column_map.keys
+ spark_data_frame
+ else
+ select_clause = save_select_clause(column_map)
+ query({}, select: select_clause)
+ end
+ end
+
+ def csv_options(options)
+ self.class.csv_options(options)
+ end
+
+ def save_options_for_model(save_record_klass)
+ {
+ 'table' => save_record_klass.table_name,
+ 'keyspace' => save_record_klass.table.connection.config[:keyspace]
+ }.to_java
+ end
+
+ def save_truth_table(column_map, java_options, save_record_klass)
+ save_record_klass.composite_defaults.each do |row|
+ updated_map = row.inject({}.merge(column_map)) do |memo, (column, value)|
+ value = value.to_s if value.is_a?(Cassandra::Uuid)
+ memo.merge!(column => value)
+ end
+
+ select_clause = save_select_clause(updated_map)
+ frame = query({}, select: select_clause)
+ save_frame(java_options, frame)
+ end
+ end
+
+ def save_frame(java_options, save_frame)
+ save_frame.write.format('org.apache.spark.sql.cassandra').options(java_options).mode('Append').save
+ end
+
+ def save_select_clause(updated_column_map)
+ updated_column_map.map do |target, source|
+ {source => {as: target}}
+ end
+ end
+
  end
  end
  end