avro-jruby 1.7.5

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,550 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ require "net/http"
18
+
19
+ module Avro::IPC
20
+
21
# Wraps an error reported for a remote call: the client raises it when the
# server flags an error response, and the server wraps failures of #call in it.
class AvroRemoteError < Avro::AvroError
end

# Raised when an error message is sent by an Avro requestor or responder.
class AvroRemoteException < Avro::AvroError
end

# Raised by the socket transport when the socket stops delivering or
# accepting bytes.
class ConnectionClosedException < Avro::AvroError
end

# Record schemas exchanged during the IPC handshake.
HANDSHAKE_REQUEST_SCHEMA = Avro::Schema.parse(<<-JSON)
  {
  "type": "record",
  "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc",
  "fields": [
  {"name": "clientHash",
  "type": {"type": "fixed", "name": "MD5", "size": 16}},
  {"name": "clientProtocol", "type": ["null", "string"]},
  {"name": "serverHash", "type": "MD5"},
  {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}
  ]
  }
JSON

HANDSHAKE_RESPONSE_SCHEMA = Avro::Schema.parse(<<-JSON)
  {
  "type": "record",
  "name": "HandshakeResponse", "namespace": "org.apache.avro.ipc",
  "fields": [
  {"name": "match",
  "type": {"type": "enum", "name": "HandshakeMatch",
  "symbols": ["BOTH", "CLIENT", "NONE"]}},
  {"name": "serverProtocol", "type": ["null", "string"]},
  {"name": "serverHash",
  "type": ["null", {"type": "fixed", "name": "MD5", "size": 16}]},
  {"name": "meta",
  "type": ["null", {"type": "map", "values": "bytes"}]}
  ]
  }
JSON

# The requestor writes handshake requests and reads handshake responses;
# the responder does the reverse. Grouped here by the schema they use.
HANDSHAKE_REQUESTOR_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_RESPONDER_READER = Avro::IO::DatumReader.new(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_RESPONDER_WRITER = Avro::IO::DatumWriter.new(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_REQUESTOR_READER = Avro::IO::DatumReader.new(HANDSHAKE_RESPONSE_SCHEMA)

# Call request/response metadata: a map whose values are bytes.
META_SCHEMA = Avro::Schema.parse('{"type": "map", "values": "bytes"}')
META_WRITER = Avro::IO::DatumWriter.new(META_SCHEMA)
META_READER = Avro::IO::DatumReader.new(META_SCHEMA)

# Error union used when a message declares no errors of its own.
SYSTEM_ERROR_SCHEMA = Avro::Schema.parse('["string"]')

# Protocol cache shared by requestors, keyed by transport remote name.
REMOTE_HASHES = {}
REMOTE_PROTOCOLS = {}

# Framing: every buffer is preceded by a length header of this many bytes,
# and outgoing messages are split into buffers no larger than BUFFER_SIZE.
BUFFER_HEADER_LENGTH = 4
BUFFER_SIZE = 8192
76
+
77
class Requestor
  # Base class for the client side of a protocol interaction.
  #
  # #request serializes a handshake request followed by a call request,
  # passes the bytes to +transport.transceive+, and decodes the handshake
  # and call response it receives. Remote hashes and protocols learned from
  # the handshake are cached per +transport.remote_name+ in REMOTE_HASHES and
  # REMOTE_PROTOCOLS so later Requestors for the same remote can reuse them.

  # remote_protocol= and remote_hash= are hand-written below (they also
  # update the shared caches), so only readers are generated for them.
  attr_reader :local_protocol, :transport, :remote_protocol, :remote_hash
  attr_accessor :send_protocol

  def initialize(local_protocol, transport)
    @local_protocol = local_protocol
    @transport = transport
    @remote_protocol = nil
    @remote_hash = nil
    @send_protocol = nil
  end

  # Sets the remote protocol and records it in REMOTE_PROTOCOLS.
  def remote_protocol=(new_remote_protocol)
    @remote_protocol = new_remote_protocol
    REMOTE_PROTOCOLS[transport.remote_name] = remote_protocol
  end

  # Sets the remote hash and records it in REMOTE_HASHES.
  def remote_hash=(new_remote_hash)
    @remote_hash = new_remote_hash
    REMOTE_HASHES[transport.remote_name] = remote_hash
  end

  # Writes a request message and reads a response or error message.
  #
  # Returns the decoded response datum. Raises AvroRemoteError when the
  # server answers with an error, Avro::AvroError on handshake/lookup failure.
  def request(message_name, request_datum)
    # build handshake and call request
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    write_handshake_request(buffer_encoder)
    write_call_request(message_name, request_datum, buffer_encoder)

    # send the handshake and call request; block until call response
    call_response = transport.transceive(buffer_writer.string)

    # process the handshake and call response
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_response))
    if read_handshake_response(buffer_decoder)
      read_call_response(message_name, buffer_decoder)
    else
      # The server did not know our protocol; resend with it attached.
      # A second mismatch raises inside read_handshake_response.
      request(message_name, request_datum)
    end
  end

  # Writes the handshake request: our protocol hash, our guess of the
  # server's hash, and (when send_protocol is set) our full protocol text.
  def write_handshake_request(encoder)
    local_hash = local_protocol.md5
    remote_name = transport.remote_name
    server_hash_guess = REMOTE_HASHES[remote_name]
    if server_hash_guess
      # An earlier Requestor already learned this remote's protocol; restore
      # it so read_call_response has a remote schema if the server says BOTH.
      @remote_protocol ||= REMOTE_PROTOCOLS[remote_name]
    else
      # First contact: optimistically assume the server speaks our protocol.
      server_hash_guess = local_hash
      self.remote_protocol = local_protocol
    end
    request_datum = {
      'clientHash' => local_hash,
      'serverHash' => server_hash_guess
    }
    request_datum['clientProtocol'] = local_protocol.to_s if send_protocol
    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)
  end

  # The format of a call request is:
  #   * request metadata, a map with values of type bytes
  #   * the message name, an Avro string, followed by
  #   * the message parameters. Parameters are serialized according to
  #     the message's request declaration.
  #
  # Raises Avro::AvroError if +message_name+ is not in the local protocol.
  def write_call_request(message_name, request_datum, encoder)
    # TODO request metadata (not yet implemented)
    request_metadata = {}
    META_WRITER.write(request_metadata, encoder)

    message = local_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown message: #{message_name}" unless message

    encoder.write_string(message.name)
    write_request(message.request, request_datum, encoder)
  end

  def write_request(request_schema, request_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(request_schema)
    datum_writer.write(request_datum, encoder)
  end

  # Reads the handshake response and updates protocol state accordingly.
  #
  # Returns true when the call response that follows can be read, false
  # when the request must be resent with our protocol attached ('NONE').
  # Raises Avro::AvroError on a repeated mismatch or an unknown match value.
  def read_handshake_response(decoder)
    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
    match = handshake_response['match']

    case match
    when 'BOTH'
      self.send_protocol = false
      true
    when 'CLIENT'
      raise Avro::AvroError, 'Handshake failure. match == CLIENT' if send_protocol
      self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol'])
      self.remote_hash = handshake_response['serverHash']
      self.send_protocol = false
      true
    when 'NONE'
      raise Avro::AvroError, 'Handshake failure. match == NONE' if send_protocol
      self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol'])
      self.remote_hash = handshake_response['serverHash']
      self.send_protocol = true
      false
    else
      raise Avro::AvroError, "Unexpected match: #{match}"
    end
  end

  # The format of a call response is:
  #   * response metadata, a map with values of type bytes
  #   * a one-byte error flag boolean, followed by either:
  #     * if the error flag is false,
  #       the message response, serialized per the message's response schema.
  #     * if the error flag is true,
  #       the error, serialized per the message's error union schema.
  def read_call_response(message_name, decoder)
    # Metadata is not used yet but must be consumed to reach the error flag.
    META_READER.read(decoder)

    # remote response schema
    remote_message_schema = remote_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown remote message: #{message_name}" unless remote_message_schema

    # local response schema
    local_message_schema = local_protocol.messages[message_name]
    raise Avro::AvroError, "Unknown local message: #{message_name}" unless local_message_schema

    if decoder.read_boolean
      writers_schema = remote_message_schema.errors || SYSTEM_ERROR_SCHEMA
      readers_schema = local_message_schema.errors || SYSTEM_ERROR_SCHEMA
      raise read_error(writers_schema, readers_schema, decoder)
    else
      writers_schema = remote_message_schema.response
      readers_schema = local_message_schema.response
      read_response(writers_schema, readers_schema, decoder)
    end
  end

  def read_response(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  # Returns (does not raise) an AvroRemoteError wrapping the decoded error.
  def read_error(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    AvroRemoteError.new(datum_reader.read(decoder))
  end
end
232
+
233
# Base class for the server side of a protocol interaction.
#
# Subclasses implement #call. #respond turns the bytes of one call request
# (handshake request + call request) into the bytes of the reply
# (handshake response + call response).
class Responder
  attr_reader :local_protocol, :local_hash, :protocol_cache

  def initialize(local_protocol)
    @local_protocol = local_protocol
    @local_hash = local_protocol.md5
    # Protocols seen so far, keyed by their MD5; seeded with our own.
    @protocol_cache = { local_hash => local_protocol }
  end

  # Called by a server to deserialize a request, compute and serialize
  # a response or error. Compare to 'handle()' in Thrift.
  #
  # call_request - String holding the handshake request and call request.
  # transport    - optional connection; when it already has a negotiated
  #                protocol the handshake is skipped (see #process_handshake).
  # Returns the reply bytes as a String.
  def respond(call_request, transport = nil)
    buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_request))
    buffer_writer = StringIO.new('', 'w+')
    buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer)
    error = nil
    response_metadata = {}
    # Offset where the handshake response ends. On a system error anything
    # written after it (a partial call response) is discarded.
    handshake_end = 0

    begin
      remote_protocol = process_handshake(buffer_decoder, buffer_encoder, transport)
      # handshake failure
      return buffer_writer.string unless remote_protocol

      handshake_end = buffer_writer.pos

      # read request using remote protocol; metadata is read but not yet used
      META_READER.read(buffer_decoder)
      remote_message_name = buffer_decoder.read_string

      # get remote and local request schemas so we can do
      # schema resolution (one fine day)
      remote_message = remote_protocol.messages[remote_message_name]
      raise Avro::AvroError, "Unknown remote message: #{remote_message_name}" unless remote_message

      local_message = local_protocol.messages[remote_message_name]
      raise Avro::AvroError, "Unknown local message: #{remote_message_name}" unless local_message

      request = read_request(remote_message.request, local_message.request, buffer_decoder)

      # perform server logic
      begin
        response = call(local_message, request)
      rescue AvroRemoteError => e
        error = e
      rescue StandardError, NotImplementedError => e
        # NotImplementedError (a ScriptError) is listed so the base #call
        # failure is still reported to the client rather than escaping.
        error = AvroRemoteError.new(e.to_s)
      end

      # write response using local protocol
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(!error.nil?)
      if error
        write_error(local_message.errors || SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
      else
        write_response(local_message.response, response, buffer_encoder)
      end
    rescue Avro::AvroError => e
      # Replace whatever follows the handshake response with a system error,
      # writing into the buffer that is actually returned.
      buffer_writer.truncate(handshake_end)
      buffer_writer.seek(handshake_end)
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(true)
      write_error(SYSTEM_ERROR_SCHEMA, AvroRemoteException.new(e.to_s), buffer_encoder)
    end
    buffer_writer.string
  end

  # Reads the handshake request from +decoder+, writes the handshake
  # response to +encoder+ and returns the client's protocol, or nil when it
  # is unknown (the response then carries match 'NONE').
  # When +connection+ already has a negotiated protocol, nothing is read or
  # written and that protocol is returned.
  def process_handshake(decoder, encoder, connection = nil)
    return connection.protocol if connection && connection.is_connected?

    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)

    # determine the remote protocol, caching any protocol text the client sent
    client_hash = handshake_request['clientHash']
    client_protocol = handshake_request['clientProtocol']
    remote_protocol = protocol_cache[client_hash]
    if !remote_protocol && client_protocol
      remote_protocol = Avro::Protocol.parse(client_protocol)
      protocol_cache[client_hash] = remote_protocol
    end

    # evaluate remote's guess of the local protocol
    match =
      if !remote_protocol
        'NONE'
      elsif local_hash == handshake_request['serverHash']
        'BOTH'
      else
        'CLIENT'
      end

    handshake_response = { 'match' => match }
    if match != 'BOTH'
      handshake_response['serverProtocol'] = local_protocol.to_s
      handshake_response['serverHash'] = local_hash
    end

    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)

    connection.protocol = remote_protocol if connection && match != 'NONE'

    remote_protocol
  end

  # Actual work done by server: cf. handler in thrift.
  def call(local_message, request)
    raise NotImplementedError
  end

  def read_request(writers_schema, readers_schema, decoder)
    datum_reader = Avro::IO::DatumReader.new(writers_schema, readers_schema)
    datum_reader.read(decoder)
  end

  def write_response(writers_schema, response_datum, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(response_datum, encoder)
  end

  # Serializes the error's message string with +writers_schema+.
  def write_error(writers_schema, error_exception, encoder)
    datum_writer = Avro::IO::DatumWriter.new(writers_schema)
    datum_writer.write(error_exception.to_s, encoder)
  end
end
372
+
373
class SocketTransport
  # A simple socket-based Transport implementation.
  #
  # A message is sent as a sequence of buffers, each preceded by its length
  # in bytes packed as a 4-byte big-endian unsigned integer ('N'), and is
  # terminated by a zero-length buffer. All lengths are byte counts, so
  # strings containing multibyte characters are framed correctly.

  attr_reader :sock, :remote_name
  attr_accessor :protocol

  def initialize(sock)
    @sock = sock
    @protocol = nil
  end

  # True once a protocol has been recorded for this connection.
  def is_connected?
    !!@protocol
  end

  # Sends +request+ as one framed message and returns the framed reply.
  def transceive(request)
    write_framed_message(request)
    read_framed_message
  end

  # Reads buffers until the zero-length terminator and returns their
  # concatenation. Raises ConnectionClosedException if the stream ends early.
  def read_framed_message
    message = []
    loop do
      buffer_length = read_buffer_length
      return message.join if buffer_length.zero?

      # Binary-backed buffer so the collected bytes are not re-encoded.
      buffer = StringIO.new(String.new)
      while buffer.tell < buffer_length
        chunk = sock.read(buffer_length - buffer.tell)
        # IO#read(n) returns nil (not "") at end of stream.
        raise ConnectionClosedException, 'Socket read 0 bytes.' if chunk.nil? || chunk.empty?
        buffer.write(chunk)
      end
      message << buffer.string
    end
  end

  # Splits +message+ into buffers of at most BUFFER_SIZE bytes.
  def write_framed_message(message)
    message_length = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_length
      buffer_length = [message_length - total_bytes_sent, BUFFER_SIZE].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_length))
      total_bytes_sent += buffer_length
    end
    # A message is always terminated by a zero-length buffer.
    write_buffer_length(0)
  end

  # Writes one length-prefixed buffer, looping until every byte is written.
  def write_buffer(chunk)
    buffer_length = chunk.bytesize
    write_buffer_length(buffer_length)
    total_bytes_sent = 0
    while total_bytes_sent < buffer_length
      remaining = chunk.byteslice(total_bytes_sent, buffer_length - total_bytes_sent)
      bytes_sent = sock.write(remaining)
      raise ConnectionClosedException, 'Socket sent 0 bytes.' if bytes_sent.zero?
      total_bytes_sent += bytes_sent
    end
  end

  def write_buffer_length(n)
    bytes_sent = sock.write([n].pack('N'))
    raise ConnectionClosedException, 'socket sent 0 bytes' if bytes_sent.zero?
  end

  # Reads the 4-byte length header of the next buffer.
  def read_buffer_length
    header = sock.read(BUFFER_HEADER_LENGTH)
    raise ConnectionClosedException, 'Socket read 0 bytes.' if header.nil? || header.empty?
    if header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedException, "Socket closed after #{header.bytesize} header bytes."
    end
    header.unpack('N')[0]
  end

  def close
    sock.close
  end
end
460
+
461
# Raised by FramedReader when a read from its underlying reader yields no
# data before a frame is complete.
class ConnectionClosedError < StandardError
end
462
+
463
# Writes a message to +writer+ as length-prefixed frames: each frame is a
# 4-byte big-endian byte count ('N') followed by that many bytes, and the
# message ends with a zero-length frame. Frames hold at most BUFFER_SIZE
# bytes. Lengths are byte counts, so multibyte strings are framed correctly.
class FramedWriter
  attr_reader :writer

  def initialize(writer)
    @writer = writer
  end

  def write_framed_message(message)
    message_size = message.bytesize
    total_bytes_sent = 0
    while total_bytes_sent < message_size
      buffer_size = [message_size - total_bytes_sent, BUFFER_SIZE].min
      write_buffer(message.byteslice(total_bytes_sent, buffer_size))
      total_bytes_sent += buffer_size
    end
    write_buffer_size(0)
  end

  # Everything written so far; requires +writer+ to respond to #string
  # (e.g. a StringIO).
  def to_s
    writer.string
  end

  private

  def write_buffer(chunk)
    write_buffer_size(chunk.bytesize)
    writer << chunk
  end

  def write_buffer_size(n)
    writer.write([n].pack('N'))
  end
end
497
+
498
# Reads a message framed as by FramedWriter: a sequence of 4-byte
# big-endian byte counts each followed by that many bytes, ending with a
# zero-length frame. Raises ConnectionClosedError if +reader+ runs dry.
class FramedReader
  attr_reader :reader

  def initialize(reader)
    @reader = reader
  end

  # Returns the concatenated payload of all frames up to the terminator.
  def read_framed_message
    message = []
    loop do
      buffer_size = read_buffer_size
      return message.join if buffer_size.zero?

      buffer = String.new
      while buffer.bytesize < buffer_size
        chunk = reader.read(buffer_size - buffer.bytesize)
        chunk_error?(chunk)
        buffer << chunk
      end
      message << buffer
    end
  end

  private

  def read_buffer_size
    header = reader.read(BUFFER_HEADER_LENGTH)
    chunk_error?(header)
    if header.bytesize < BUFFER_HEADER_LENGTH
      raise ConnectionClosedError, "Reader read only #{header.bytesize} header bytes"
    end
    header.unpack('N')[0]
  end

  # Raises when a read produced no data. IO#read(n) and StringIO#read(n)
  # return nil (not "") at end of stream, so both are treated as closed.
  def chunk_error?(chunk)
    raise ConnectionClosedError, 'Reader read 0 bytes' if chunk.nil? || chunk.empty?
  end
end
533
+
534
# Client-side transceiver: each call is framed, POSTed to "/" on host:port
# over a Net::HTTP connection opened at construction time, and the response
# body is deframed. Only works for clients. Sigh.
class HTTPTransceiver
  attr_reader :remote_name, :host, :port

  def initialize(host, port)
    @host = host
    @port = port
    @remote_name = format('%s:%s', host, port)
    @conn = Net::HTTP.start(host, port)
  end

  # Sends +message+ and returns the deframed response body.
  def transceive(message)
    framed_request = FramedWriter.new(StringIO.new)
    framed_request.write_framed_message(message)
    headers = { 'Content-Type' => 'avro/binary' }
    response = @conn.post('/', framed_request.to_s, headers)
    response_reader = FramedReader.new(StringIO.new(response.body))
    response_reader.read_framed_message
  end
end
550
+ end