tmtm-ruby-mysql 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,558 @@
1
+ # Copyright (C) 2008 TOMITA Masahiro
2
+ # mailto:tommy@tmtm.org
3
+
4
+ require "socket"
5
+ require "timeout"
6
+ require "digest/sha1"
7
+ require "thread"
8
+ require "stringio"
9
+
10
+ class Mysql
11
+ # MySQL network protocol
12
+ class Protocol
13
+
14
# Protocol version byte the server's initial packet must carry
# (checked in InitialPacket.parse).
VERSION = 10
# Largest payload one packet header can describe: the length field is
# 3 bytes. Longer payloads are split over several packets, and a
# packet of exactly this size means "more data follows".
MAX_PACKET_LENGTH = 0xff_ff_ff
16
+
17
# Encode an Integer as a MySQL length-coded binary.
#
# === Argument
# num :: [Integer or nil] value to encode; nil becomes the NULL marker 0xfb.
# === Return
# String :: 1, 3, 4 or 9 bytes. Always ASCII-8BIT, so it can be
#           concatenated with other binary packet data.
def self.lcb(num)
  # Built with pack rather than a "\xfb" literal: a literal is tagged with
  # the source encoding (an invalid UTF-8 string), unlike every other
  # branch, and raises Encoding::CompatibilityError when joined with
  # non-ASCII binary data.
  return [251].pack("C") if num.nil?
  return [num].pack("C") if num < 251
  return [252, num].pack("Cv") if num < 65536
  return [253, num & 0xffff, num >> 16].pack("CvC") if num < 16777216
  [254, num & 0xffffffff, num >> 32].pack("CVV")
end
25
+
26
# Encode a String as a MySQL length-coded string: lcb(byte length)
# followed by the raw bytes.
#
# === Argument
# str :: [String] value to encode (any encoding).
# === Return
# String :: ASCII-8BIT length-coded string.
def self.lcs(str)
  # The prefix must count bytes, not characters; str.length undercounts
  # multibyte text. Also, joining a binary prefix with a non-ASCII string of
  # another encoding raises Encoding::CompatibilityError. pack("a*") gives a
  # binary copy of the raw bytes, which fixes both.
  bytes = [str].pack("a*")
  lcb(bytes.length) + bytes
end
30
+
31
# Decode a MySQL length-coded binary from the front of a buffer.
#
# === Argument
# lcb :: [String] buffer that starts with a length-coded binary. The bytes
#        that were decoded are removed from it.
# === Return
# Integer, or nil for an empty buffer or the NULL marker (0xfb).
def self.lcb2int!(lcb)
  return nil if lcb.empty?
  marker = ord(lcb)
  case marker
  when 251
    # NULL marker: consume the single byte, no value
    ord!(lcb)
    nil
  when 252
    # 0xfc followed by a 2-byte little-endian value
    _marker, value = lcb.unpack("Cv")
    lcb[0, 3] = ""
    value
  when 253
    # 0xfd followed by a 3-byte value: read 4 bytes, then shift the marker out
    word = lcb.unpack("V").first
    lcb[0, 4] = ""
    word >> 8
  when 254
    # 0xfe followed by an 8-byte value stored as two 32-bit halves
    _marker, low, high = lcb.unpack("CVV")
    lcb[0, 9] = ""
    (high << 32) | low
  else
    # any other marker byte is the value itself
    ord!(lcb)
  end
end
58
+
59
# Decode a MySQL length-coded string from the front of a buffer.
#
# === Argument
# lcs :: [String] buffer that starts with a length-coded string. The length
#        prefix and the string bytes are removed from it.
# === Return
# String, or nil when the prefix is the NULL marker or the buffer is empty.
def self.lcs2str!(lcs)
  size = lcb2int!(lcs)
  return nil if size.nil?
  lcs.slice!(0, size)
end
68
+
69
# True when data is an EOF packet: a 5-byte payload whose first byte is 0xfe.
#
# === Argument
# data :: [String] packet payload as returned by #read.
# === Return
# true or false
def self.eof_packet?(data)
  # Check the length first: an empty payload then yields false instead of
  # reaching ord, which raises on an empty string when String#ord is used.
  data.length == 5 && ord(data) == 0xfe
end
72
+
73
# Convert netdata to Ruby value.
# Decodes one binary-protocol column value from the front of data.
# === Argument
# data :: [String] packet data. The decoded bytes are removed from it.
# type :: [Integer] field type (Field::TYPE_*)
# unsigned :: [true or false] true if value is unsigned
# === Return
# Object :: converted value (String, Integer, Float, Mysql::Time or nil).
# === Exception
# RuntimeError :: the field type is not implemented.
def self.net2value(data, type, unsigned)
  case type
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB
    return Protocol.lcs2str!(data)
  when Field::TYPE_TINY
    v = ord! data
    return unsigned || v < 2**7 ? v : v - 2**8
  when Field::TYPE_SHORT
    v = data.slice!(0, 2).unpack("v").first
    return unsigned || v < 2**15 ? v : v - 2**16
  when Field::TYPE_INT24, Field::TYPE_LONG
    v = data.slice!(0, 4).unpack("V").first
    return unsigned || v < 2**31 ? v : v - 2**32
  when Field::TYPE_LONGLONG
    n1, n2 = data.slice!(0, 8).unpack("VV")
    v = (n2 << 32) | n1
    return unsigned || v < 2**63 ? v : v - 2**64
  when Field::TYPE_FLOAT
    return data.slice!(0, 4).unpack("e").first
  when Field::TYPE_DOUBLE
    return data.slice!(0, 8).unpack("E").first
  when Field::TYPE_DATE, Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
    len = ord! data
    y, m, d, h, mi, s, bs = data.slice!(0, len).unpack("vCCCCCV")
    # Pass the fractional seconds in the same position the TYPE_TIME branch
    # below uses (7th = negative flag, 8th = second part). Passing bs as the
    # 7th argument made a second part of 0 (truthy in Ruby) look negative.
    # bs is nil when the value has no fractional part, hence to_i.
    # NOTE(review): assumes Mysql::Time.new(y, m, d, h, mi, s, neg, second_part)
    # as implied by the TIME call; confirm against Mysql::Time.
    return Mysql::Time.new(y, m, d, h, mi, s, false, bs.to_i)
  when Field::TYPE_TIME
    len = ord! data
    sign, d, h, mi, s, sp = data.slice!(0, len).unpack("CVCCCV")
    h = d.to_i * 24 + h.to_i
    # With len == 0 every field, including sign, is nil; nil != 0 would
    # wrongly mark an empty (zero) TIME as negative.
    return Mysql::Time.new(0, 0, 0, h, mi, s, sign.to_i != 0, sp)
  when Field::TYPE_YEAR
    return data.slice!(0, 2).unpack("v").first
  when Field::TYPE_BIT
    return Protocol.lcs2str!(data)
  else
    raise "not implemented: type=#{type}"
  end
end
118
+
119
# Encode a Ruby value as a prepared-statement parameter.
#
# === Argument
# v :: [Object] nil, Integer, Float, String, Mysql::Time or ::Time.
# === Return
# [type, netdata] :: field type (with 0x8000 set for non-negative
#                    integers) and the encoded bytes.
# === Exception
# ProtocolError :: integer outside the 64-bit range / unsupported class
def self.value2net(v)
  type, val =
    case v
    when nil
      [Field::TYPE_NULL, ""]
    when Integer
      # Candidate integer types from narrowest to widest: [bytes, type, encoder].
      widths = [
        [1, Field::TYPE_TINY, lambda { |i| [i].pack("C") }],
        [2, Field::TYPE_SHORT, lambda { |i| [i].pack("v") }],
        [4, Field::TYPE_LONG, lambda { |i| [i].pack("V") }],
        [8, Field::TYPE_LONGLONG, lambda { |i| [i & 0xffffffff, i >> 32].pack("VV") }],
      ]
      unsigned = v >= 0
      fit = widths.find do |bytes, _type, _encoder|
        limit = 256**bytes
        unsigned ? v < limit : -v <= limit / 2
      end
      raise ProtocolError, "value too large: #{v}" unless fit
      _bytes, int_type, encoder = fit
      [unsigned ? int_type | 0x8000 : int_type, encoder.call(v)]
    when Float
      [Field::TYPE_DOUBLE, [v].pack("E")]
    when String
      [Field::TYPE_STRING, Protocol.lcs(v)]
    when Mysql::Time, ::Time
      [Field::TYPE_DATETIME, [7, v.year, v.month, v.day, v.hour, v.min, v.sec].pack("CvCCCCC")]
    else
      raise ProtocolError, "class #{v.class} is not supported"
    end
  [type, val]
end
179
+
180
# Byte helpers that work on every supported Ruby. Where String#ord exists
# (1.9+), String#[] yields a one-character String, so ord is used. Older
# Rubies lack String#ord, and their String#[] already returns an Integer.
class << self
  if "".respond_to?(:ord)
    # Code of the first character of str (the first byte for binary data).
    def ord(str)
      str.ord
    end

    # Remove the first character of str and return its code.
    def ord!(str)
      str.slice!(0).ord
    end
  else
    def ord(str)
      str[0]
    end

    def ord!(str)
      str.slice!(0)
    end
  end
end
195
+
196
# SQLSTATE recorded by the last #read: "00000" for a normal packet, the
# server's value for an error packet, "" for a 4.0-style error packet.
attr_reader :sqlstate

# Open the socket connection to the server.
# === Argument
# host :: [String] nil, "" or "localhost" selects a UNIXSocket, anything else a TCPSocket
# port :: [Integer] TCP port; defaults to $MYSQL_TCP_PORT, the "mysql" service, then MYSQL_TCP_PORT
# socket :: [String] UNIX socket path; defaults to $MYSQL_UNIX_PORT, then MYSQL_UNIX_PORT
# conn_timeout :: [Integer] connect timeout (sec).
# read_timeout :: [Integer] read timeout (sec).
# write_timeout :: [Integer] write timeout (sec).
# === Exception
# ClientError :: the connection was not established within conn_timeout
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
  begin
    Timeout.timeout(conn_timeout) do
      local = host.nil? || host.empty? || host == "localhost"
      @sock =
        if local
          UNIXSocket.new(socket || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT)
        else
          tcp_port = port || ENV["MYSQL_TCP_PORT"] || begin
            Socket.getservbyname("mysql", "tcp")
          rescue StandardError
            MYSQL_TCP_PORT
          end
          TCPSocket.new(host, tcp_port)
        end
    end
  rescue Timeout::Error
    raise ClientError, "connection timeout"
  end
  @read_timeout = read_timeout
  @write_timeout = write_timeout
  @seq = 0 # packet sequence number, reset for every command
  @mutex = Mutex.new
end
225
+
226
# Close the socket to the server.
def close
  sock = @sock
  sock.close
end
229
+
230
# Run the given block while holding this connection's mutex and return
# whatever the block returns.
def synchronize
  @mutex.synchronize { yield }
end
235
+
236
# Restart packet sequence numbering at 0 (done for every new command).
# Returns the new sequence number.
def reset
  @seq = 0
end
240
+
241
# Read one packet's data. A payload split because it reached
# MAX_PACKET_LENGTH is reassembled from the following packets.
# === Return
# String :: payload data
# === Exception
# ProtocolError :: invalid packet sequence number
# ClientError :: read timeout, or the connection ended in the middle of a packet
# Mysql::ServerError (or the class from ERROR_MAP) :: the packet is an error packet
def read
  ret = ""
  len = nil
  loop do
    begin
      Timeout.timeout @read_timeout do
        header = @sock.read(4)
        # IO#read returns nil (or fewer bytes) at EOF. Report a lost connection
        # instead of failing with NoMethodError on nil.
        raise ClientError, "Lost connection to server during query" if header.nil? || header.length < 4
        len1, len2, seq = header.unpack("CvC")
        len = (len2 << 8) + len1
        raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
        @seq = (@seq + 1) % 256
        data = @sock.read(len)
        raise ClientError, "Lost connection to server during query" if data.nil? || data.length < len
        ret.concat data
      end
    rescue Timeout::Error
      raise ClientError, "read timeout"
    end
    # A packet of exactly MAX_PACKET_LENGTH means the payload continues.
    break unless len == MAX_PACKET_LENGTH
  end

  @sqlstate = "00000"

  # Error packet (0xff marker). An empty payload cannot be one, and ord
  # must not be asked for the first byte of an empty string.
  if !ret.empty? && Protocol.ord(ret) == 0xff
    _, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
    unless marker == "#"
      _, errno, message = ret.unpack("Cva*") # Version 4.0 Error
      @sqlstate = ""
    end
    if Mysql::ServerError::ERROR_MAP.key? errno
      raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
    end
    raise Mysql::ServerError.new(message, @sqlstate)
  end
  ret
end
278
+
279
# Write one packet's data, split into chunks of at most MAX_PACKET_LENGTH
# bytes, each preceded by a 4-byte header (3-byte length + sequence number).
# === Argument
# data :: [String / IO] payload; an IO is read until EOF.
# === Exception
# ClientError :: write timeout
def write(data)
  begin
    @sock.sync = false
    data = StringIO.new data if data.is_a? String
    send_chunk = lambda do |chunk|
      Timeout.timeout @write_timeout do
        @sock.write [chunk.length % 256, chunk.length / 256, @seq].pack("CvC")
        @sock.write chunk
      end
      @seq = (@seq + 1) % 256
    end
    last_length = nil
    while (chunk = data.read(MAX_PACKET_LENGTH))
      send_chunk.call(chunk)
      last_length = chunk.length
    end
    # A packet of exactly MAX_PACKET_LENGTH bytes means "more follows" (this
    # class's #read relies on that), so such a payload must be terminated by
    # a zero-length packet. An empty payload is likewise sent as one empty
    # packet rather than nothing.
    send_chunk.call("") if last_length.nil? || last_length == MAX_PACKET_LENGTH
    @sock.sync = true
    Timeout.timeout @write_timeout do
      @sock.flush
    end
  rescue Timeout::Error
    raise ClientError, "write timeout"
  end
end
301
+
302
# Serialize a *Packet (anything with #serialize) and write it to the server.
# === Argument
# packet :: [*Packet]
def send_packet(packet)
  payload = packet.serialize
  write(payload)
end
308
+
309
# Read the next packet and require it to be an EOF packet.
# === Return
# nil
# === Exception
# ProtocolError :: packet is not EOF
def read_eof_packet
  packet = read
  return if Protocol.eof_packet?(packet)
  raise ProtocolError, "packet is not EOF"
end
316
+
317
# Read and parse the server's initial (handshake) packet.
# === Return
# InitialPacket ::
# === Exception
# ProtocolError :: invalid packet (raised by InitialPacket.parse)
def read_initial_packet
  raw = read
  InitialPacket.parse(raw)
end
325
+
326
# Read and parse a result packet.
# === Return
# ResultPacket ::
def read_result_packet
  raw = read
  ResultPacket.parse(raw)
end
332
+
333
# Read and parse a field (column definition) packet.
# === Return
# FieldPacket :: packet data
# === Exception
# ProtocolError :: invalid packet (raised by FieldPacket.parse)
def read_field_packet
  raw = read
  FieldPacket.parse(raw)
end
341
+
342
# Read and parse the response to a prepare request.
# === Return
# PrepareResultPacket ::
# === Exception
# ProtocolError :: invalid packet (raised by PrepareResultPacket.parse)
def read_prepare_result_packet
  raw = read
  PrepareResultPacket.parse(raw)
end
350
+
351
# Base class of packets the client sends to the server; subclasses
# implement #serialize.
class TxPacket; end

# Base class of packets received from the server; subclasses implement
# the .parse class method.
class RxPacket; end
358
+
359
# Initial packet
# The server's handshake: protocol/server version, thread id, capability
# flags, charset, status and the scramble used by #crypt_password.
class InitialPacket < RxPacket
  # Parse the handshake payload.
  # === Exception
  # ProtocolError :: protocol version is not VERSION, or the packet is malformed/truncated
  def self.parse(data)
    protocol_version, server_version, thread_id, scramble_buff, f0,
    server_capabilities, server_charset, server_status, f1,
    rest_scramble_buff = data.unpack("CZ*Va8CvCva13Z13")
    raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
    raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
    # Only require the 13-byte block to be present. MySQL 5.5+ stores the
    # upper capability flags and the auth-plugin-data length there (and
    # MariaDB uses part of the remainder), so demanding all zeros rejected
    # those servers.
    raise ProtocolError, "invalid packet: f1=#{f1.inspect}" unless f1.length == 13
    scramble_buff.concat rest_scramble_buff
    self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
  end

  attr_accessor :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff

  def initialize(*args)
    @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
  end

  # Scramble a plain-text password for the authentication packet:
  # SHA1(plain) XOR SHA1(scramble_buff + SHA1(SHA1(plain))).
  # Returns "" when plain is nil or empty.
  def crypt_password(plain)
    return "" if plain.nil? or plain.empty?
    hash_stage1 = Digest::SHA1.digest plain
    hash_stage2 = Digest::SHA1.digest hash_stage1
    return hash_stage1.unpack("C*").zip(Digest::SHA1.digest(@scramble_buff+hash_stage2).unpack("C*")).map{|a,b| a^b}.pack("C*")
  end
end
385
+
386
# Authentication packet
# The client's reply to the server's initial packet.
class AuthenticationPacket < TxPacket
  attr_accessor :client_flags, :max_packet_size, :charset_number, :username, :scrambled_password, :databasename

  def initialize(*args)
    @client_flags, @max_packet_size, @charset_number, @username, @scrambled_password, @databasename = args
  end

  # Layout: client flags (V), max packet size (V), charset number (1 byte),
  # 23 zero bytes, NUL-terminated user name, length-coded scrambled
  # password, NUL-terminated database name.
  def serialize
    [
      client_flags,
      max_packet_size,
      # The charset field is a single byte. Encoding it as a length-coded
      # binary emitted 3 bytes (0xfc + 2) for charset numbers >= 251 and
      # shifted every later field. to_i maps a nil charset to 0.
      charset_number.to_i,
      "", # always 0x00 * 23
      username,
      Protocol.lcs(scrambled_password),
      databasename
    ].pack("VVCa23Z*A*Z*")
  end
end
406
+
407
# Quit packet
# COM_QUIT carries no arguments: the payload is just the command byte.
class QuitPacket < TxPacket
  def serialize
    command = [COM_QUIT]
    command.pack("C")
  end
end
413
+
414
# Query packet
# COM_QUERY request: the command byte followed by the raw SQL text.
class QueryPacket < TxPacket
  attr_accessor :query

  # args :: [query]
  def initialize(*args)
    @query = args.first
  end

  def serialize
    [COM_QUERY, query].pack("Ca*")
  end
end
426
+
427
# Result packet
# Parsed by Protocol#read_result_packet. With field_count 0 it also carries
# affected rows, insert id, status, warning count and message; otherwise only
# field_count is set.
class ResultPacket < RxPacket
  def self.parse(data)
    field_count = Protocol.lcb2int!(data)
    return new(field_count) unless field_count == 0
    affected_rows = Protocol.lcb2int!(data)
    insert_id = Protocol.lcb2int!(data)
    server_status, warning_count, message = data.unpack("vva*")
    new(field_count, affected_rows, insert_id, server_status, warning_count, message)
  end

  attr_accessor :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

  def initialize(*args)
    @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
  end
end
447
+
448
# Field packet
# Column definition, parsed by Protocol#read_field_packet.
class FieldPacket < RxPacket
  # The leading length-coded string (before db) is read and discarded.
  # default is nil when no trailing default value is present.
  # === Exception
  # ProtocolError :: the 2-byte filler after decimals is not zero
  def self.parse(data)
    _first, db, table, org_table, name, org_name = Array.new(6) { Protocol.lcs2str!(data) }
    _f0, charsetnr, length, type, flags, decimals, filler, rest = data.unpack("CvVCvCva*")
    raise ProtocolError, "invalid packet: f1=#{filler}" unless filler == 0
    default = Protocol.lcs2str!(rest)
    new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
  end

  attr_accessor :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

  def initialize(*args)
    @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
  end
end
469
+
470
# Prepare packet
# COM_STMT_PREPARE request: the command byte followed by the SQL text.
class PreparePacket < TxPacket
  attr_accessor :query

  # args :: [query]
  def initialize(*args)
    @query = args.first
  end

  def serialize
    [COM_STMT_PREPARE, query].pack("Ca*")
  end
end
482
+
483
# Prepare result packet
# Response to a PreparePacket, parsed by Protocol#read_prepare_result_packet.
class PrepareResultPacket < RxPacket
  # === Exception
  # ProtocolError :: the leading status byte or the reserved byte is not 0
  def self.parse(data)
    status = Protocol.ord!(data)
    raise ProtocolError, "invalid packet" unless status == 0x00
    statement_id, field_count, param_count, reserved, warning_count = data.unpack("VvvCv")
    raise ProtocolError, "invalid packet" unless reserved == 0x00
    new(statement_id, field_count, param_count, warning_count)
  end

  attr_accessor :statement_id, :field_count, :param_count, :warning_count

  def initialize(*args)
    @statement_id, @field_count, @param_count, @warning_count = args
  end
end
498
+
499
# Execute packet
# COM_STMT_EXECUTE request for a prepared statement and its bound values.
class ExecutePacket < TxPacket
  attr_accessor :statement_id, :cursor_type, :values

  def initialize(*args)
    @statement_id, @cursor_type, @values = args
  end

  # Layout: command byte, statement id (V), cursor type (C), 1 (V), NULL
  # bitmap, 1 (C), parameter types (v each), encoded parameter values.
  # NOTE(review): the two literal 1s are the iteration count and the
  # new-params-bound flag in the COM_STMT_EXECUTE layout.
  def serialize
    nbm = null_bitmap values
    types = []
    netvalues = []
    values.each do |v|
      t, n = Protocol.value2net v
      types << t
      netvalues << n # value2net returns "" for nil, so NULLs add no bytes
    end
    # Join the encoded values with pack("a*") so the result is binary.
    # Concatenating into a "" literal raised Encoding::CompatibilityError when
    # a non-ASCII UTF-8 string value was mixed with non-ASCII binary data
    # (e.g. an integer >= 128).
    netdata = netvalues.pack("a*" * netvalues.size)
    [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netdata].pack("CVCVa*Ca*a*")
  end

  private

  # make null bitmap
  #
  # One bit per value, least significant bit first, 8 values per byte; a set
  # bit marks nil. If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
  def null_bitmap(values)
    bitmap = values.enum_for(:each_slice,8).map do |vals|
      vals.reverse.inject(0){|b, v|(b<<1 | (v ? 0 : 1))}
    end
    return bitmap.pack("C*")
  end

end
531
+
532
# Fetch packet
# COM_STMT_FETCH request: statement id and number of rows to fetch.
class FetchPacket < TxPacket
  attr_accessor :statement_id, :fetch_length

  # args :: [statement_id, fetch_length]
  def initialize(*args)
    @statement_id = args[0]
    @fetch_length = args[1]
  end

  def serialize
    [Mysql::COM_STMT_FETCH, statement_id, fetch_length].pack("CVV")
  end
end
544
+
545
# Stmt close packet
# COM_STMT_CLOSE request: releases the prepared statement with this id.
class StmtClosePacket < TxPacket
  attr_accessor :statement_id

  # args :: [statement_id]
  def initialize(*args)
    @statement_id = args.first
  end

  def serialize
    [Mysql::COM_STMT_CLOSE, statement_id].pack("CV")
  end
end
557
+ end
558
+ end