tmtm-ruby-mysql 0.0.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,558 @@
1
+ # Copyright (C) 2008 TOMITA Masahiro
2
+ # mailto:tommy@tmtm.org
3
+
4
+ require "socket"
5
+ require "timeout"
6
+ require "digest/sha1"
7
+ require "thread"
8
+ require "stringio"
9
+
10
class Mysql
  # MySQL client/server network protocol.
  #
  # Handles packet framing (3-byte payload length + 1-byte sequence number),
  # the length-coded value encodings, and the individual packet types that
  # are exchanged with the server.
  class Protocol

    # Protocol version this client accepts in the server's initial packet.
    VERSION = 10
    # Largest payload a single physical packet can carry (3-byte length field).
    MAX_PACKET_LENGTH = 2**24-1

    # convert Numeric to LengthCodedBinary
    # === Argument
    # num :: [Integer or nil] value to encode; nil is encoded as the NULL marker 0xfb.
    # === Return
    # String :: 1, 3, 4 or 9 byte binary string.
    def self.lcb(num)
      return [251].pack("C") if num.nil?
      return [num].pack("C") if num < 251
      return [252, num].pack("Cv") if num < 65536
      return [253, num & 0xffff, num >> 16].pack("CvC") if num < 16777216
      [254, num & 0xffffffff, num >> 32].pack("CVV")
    end

    # convert String to LengthCodedString
    # === Argument
    # str :: [String]
    # === Return
    # String :: binary string: LengthCodedBinary of the byte length of +str+,
    #           followed by the bytes of +str+.
    def self.lcs(str)
      # The length prefix must count bytes, not characters, or multibyte
      # strings are framed incorrectly. [str].pack("a*") yields a binary copy,
      # so joining it to the binary prefix cannot raise
      # Encoding::CompatibilityError.
      lcb(str.bytesize) + [str].pack("a*")
    end

    # convert LengthCodedBinary to Integer
    # === Argument
    # lcb :: [String] LengthCodedBinary. The consumed bytes are removed from it.
    # === Return
    # Integer, or nil when +lcb+ is empty or starts with the NULL marker (0xfb).
    def self.lcb2int!(lcb)
      return nil if lcb.empty?
      case ord lcb
      when 251
        ord! lcb
        nil
      when 252
        _, v = lcb.unpack("Cv")
        lcb[0, 3] = ""
        v
      when 253
        # 3-byte value: read 4 bytes little-endian and shift out the marker byte.
        v, = lcb.unpack("V")
        lcb[0, 4] = ""
        v >> 8
      when 254
        _, v1, v2 = lcb.unpack("CVV")
        lcb[0, 9] = ""
        (v2 << 32) | v1
      else
        ord! lcb
      end
    end

    # convert LengthCodedString to String
    # === Argument
    # lcs :: [String] LengthCodedString. The consumed bytes are removed from it.
    # === Return
    # String or nil (NULL marker, or no data left)
    def self.lcs2str!(lcs)
      len = lcb2int! lcs
      len && lcs.slice!(0, len)
    end

    # true if +data+ is an EOF packet: 5 bytes long and starting with 0xfe.
    # The length is tested first so an empty packet never reaches +ord+
    # (String#ord raises ArgumentError on an empty string).
    def self.eof_packet?(data)
      data.length == 5 && ord(data) == 0xfe
    end

    # Convert netdata to Ruby value
    # === Argument
    # data :: [String] packet data. The consumed bytes are removed from it.
    # type :: [Integer] field type
    # unsigned :: [true or false] true if value is unsigned
    # === Return
    # Object :: converted value.
    # === Exception
    # RuntimeError :: unknown field type
    def self.net2value(data, type, unsigned)
      case type
      when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB
        return Protocol.lcs2str!(data)
      when Field::TYPE_TINY
        v = ord! data
        return unsigned || v < 2**7 ? v : v - 2**8
      when Field::TYPE_SHORT
        v = data.slice!(0, 2).unpack("v").first
        return unsigned || v < 2**15 ? v : v - 2**16
      when Field::TYPE_INT24, Field::TYPE_LONG
        # INT24 is also transferred in 4 bytes in the binary protocol.
        v = data.slice!(0, 4).unpack("V").first
        return unsigned || v < 2**31 ? v : v - 2**32
      when Field::TYPE_LONGLONG
        n1, n2 = data.slice!(0, 8).unpack("VV")
        v = (n2 << 32) | n1
        return unsigned || v < 2**63 ? v : v - 2**64
      when Field::TYPE_FLOAT
        return data.slice!(0, 4).unpack("e").first
      when Field::TYPE_DOUBLE
        return data.slice!(0, 8).unpack("E").first
      when Field::TYPE_DATE, Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
        len = ord! data
        y, m, d, h, mi, s, bs = data.slice!(0, len).unpack("vCCCCCV")
        return Mysql::Time.new(y, m, d, h, mi, s, bs)
      when Field::TYPE_TIME
        len = ord! data
        sign, d, h, mi, s, sp = data.slice!(0, len).unpack("CVCCCV")
        # fold the day count into the hour field
        h = d.to_i * 24 + h.to_i
        return Mysql::Time.new(0, 0, 0, h, mi, s, sign != 0, sp)
      when Field::TYPE_YEAR
        return data.slice!(0, 2).unpack("v").first
      when Field::TYPE_BIT
        return Protocol.lcs2str!(data)
      else
        raise "not implemented: type=#{type}"
      end
    end

    # convert Ruby value to netdata
    # === Argument
    # v :: [Object] Ruby value.
    # === Return
    # [type, netdata] :: field type (0x8000 set for unsigned integers) and encoded value
    # === Exception
    # ProtocolError :: value too large / value is not supported
    def self.value2net(v)
      case v
      when nil
        type = Field::TYPE_NULL
        val = ""
      when Integer
        if v >= 0
          # smallest unsigned type that fits; 0x8000 marks the type unsigned
          if v < 256
            type = Field::TYPE_TINY | 0x8000
            val = [v].pack("C")
          elsif v < 256**2
            type = Field::TYPE_SHORT | 0x8000
            val = [v].pack("v")
          elsif v < 256**4
            type = Field::TYPE_LONG | 0x8000
            val = [v].pack("V")
          elsif v < 256**8
            type = Field::TYPE_LONGLONG | 0x8000
            val = [v & 0xffffffff, v >> 32].pack("VV")
          else
            raise ProtocolError, "value too large: #{v}"
          end
        else
          # smallest signed type that fits; pack wraps negatives to two's complement
          if -v <= 256/2
            type = Field::TYPE_TINY
            val = [v].pack("C")
          elsif -v <= 256**2/2
            type = Field::TYPE_SHORT
            val = [v].pack("v")
          elsif -v <= 256**4/2
            type = Field::TYPE_LONG
            val = [v].pack("V")
          elsif -v <= 256**8/2
            type = Field::TYPE_LONGLONG
            val = [v & 0xffffffff, v >> 32].pack("VV")
          else
            raise ProtocolError, "value too large: #{v}"
          end
        end
      when Float
        type = Field::TYPE_DOUBLE
        val = [v].pack("E")
      when String
        type = Field::TYPE_STRING
        val = Protocol.lcs(v)
      when Mysql::Time, ::Time
        type = Field::TYPE_DATETIME
        val = [7, v.year, v.month, v.day, v.hour, v.min, v.sec].pack("CvCCCCC")
      else
        raise ProtocolError, "class #{v.class} is not supported"
      end
      return type, val
    end

    # Byte access helpers: String#ord exists on Ruby 1.9+, while Ruby 1.8's
    # String#[] already returns an Integer.
    if "".respond_to? :ord
      # first byte of +str+
      def self.ord(str)
        str.ord
      end
      # remove and return the first byte of +str+
      def self.ord!(str)
        str.slice!(0).ord
      end
    else
      def self.ord(str)
        str[0]
      end
      def self.ord!(str)
        str.slice!(0)
      end
    end

    # SQLSTATE of the last packet read ("00000" when it was not an error).
    attr_reader :sqlstate

    # make socket connection to server.
    # === Argument
    # host :: [String] if "localhost" or "" nil then use UNIXSocket. Otherwise use TCPSocket
    # port :: [Integer] port number using by TCPSocket
    # socket :: [String] socket file name using by UNIXSocket
    # conn_timeout :: [Integer] connect timeout (sec).
    # read_timeout :: [Integer] read timeout (sec).
    # write_timeout :: [Integer] write timeout (sec).
    # === Exception
    # ClientError :: connection timeout
    def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
      begin
        Timeout.timeout conn_timeout do
          if host.nil? || host.empty? || host == "localhost"
            socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
            @sock = UNIXSocket.new socket
          else
            port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
            @sock = TCPSocket.new host, port
          end
        end
      rescue Timeout::Error
        raise ClientError, "connection timeout"
      end
      @read_timeout = read_timeout
      @write_timeout = write_timeout
      @seq = 0 # packet counter. reset by each command
      @mutex = Mutex.new
    end

    # Close the underlying socket.
    def close
      @sock.close
    end

    # Run the given block while holding this connection's mutex.
    # Returns the block's value.
    def synchronize
      @mutex.synchronize { yield }
    end

    # Reset sequence number
    def reset
      @seq = 0
    end

    # Read one packet data
    #
    # Physical packets whose payload is exactly MAX_PACKET_LENGTH bytes are
    # continued by the following packet, and their payloads are concatenated.
    # An error packet (first byte 0xff) is turned into an exception.
    # === Return
    # String
    # === Exception
    # ProtocolError :: invalid packet sequence number
    # ClientError :: read timeout / connection closed or short read
    # ServerError :: the server sent an error packet
    def read
      ret = ""
      loop do
        len = nil
        begin
          Timeout.timeout @read_timeout do
            header = @sock.read(4)
            raise ClientError, "connection closed by server" unless header && header.bytesize == 4
            len1, len2, seq = header.unpack("CvC")
            len = (len2 << 8) + len1
            raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
            @seq = (@seq + 1) % 256
            payload = @sock.read(len)
            raise ClientError, "connection closed by server" unless payload && payload.bytesize == len
            ret.concat payload
          end
        rescue Timeout::Error
          raise ClientError, "read timeout"
        end
        break if len < MAX_PACKET_LENGTH
      end

      @sqlstate = "00000"

      # Error packet
      if !ret.empty? && Protocol.ord(ret) == 0xff
        _, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
        unless marker == "#"
          _, errno, message = ret.unpack("Cva*") # Version 4.0 Error (no SQLSTATE)
          @sqlstate = ""
        end
        if Mysql::ServerError::ERROR_MAP.key? errno
          raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
        end
        raise Mysql::ServerError.new(message, @sqlstate)
      end
      ret
    end

    # Write one packet data
    #
    # The payload is split into physical packets of at most MAX_PACKET_LENGTH
    # bytes. When the payload length is a multiple of MAX_PACKET_LENGTH
    # (including an empty payload), a final empty packet is sent so that the
    # last physical packet is always shorter than MAX_PACKET_LENGTH.
    # === Argument
    # data [String / IO] :: packet payload
    # === Exception
    # ClientError :: write timeout
    def write(data)
      data = StringIO.new data if data.is_a? String
      begin
        @sock.sync = false
        loop do
          chunk = data.read(MAX_PACKET_LENGTH) || ""
          size = chunk.bytesize
          Timeout.timeout @write_timeout do
            @sock.write [size % 256, size / 256, @seq].pack("CvC")
            @sock.write chunk
          end
          @seq = (@seq + 1) % 256
          break if size < MAX_PACKET_LENGTH
        end
        @sock.sync = true
        Timeout.timeout @write_timeout do
          @sock.flush
        end
      rescue Timeout::Error
        raise ClientError, "write timeout"
      end
    end

    # Send one packet
    # === Argument
    # packet :: [*Packet] object responding to +serialize+
    def send_packet(packet)
      write packet.serialize
    end

    # Read EOF packet
    # === Exception
    # ProtocolError :: packet is not EOF
    def read_eof_packet
      data = read
      raise ProtocolError, "packet is not EOF" unless Protocol.eof_packet? data
    end

    # Read initial packet
    # === Return
    # InitialPacket ::
    # === Exception
    # ProtocolError :: invalid packet
    def read_initial_packet
      InitialPacket.parse read
    end

    # Read result packet
    # === Return
    # ResultPacket ::
    def read_result_packet
      ResultPacket.parse read
    end

    # Read field packet
    # === Return
    # FieldPacket :: packet data
    # === Exception
    # ProtocolError :: invalid packet
    def read_field_packet
      FieldPacket.parse read
    end

    # Read prepare result packet
    # === Return
    # PrepareResultPacket ::
    # === Exception
    # ProtocolError :: invalid packet
    def read_prepare_result_packet
      PrepareResultPacket.parse read
    end

    # client->server packet base class
    class TxPacket
    end

    # server->client packet base class
    class RxPacket
    end

    # Initial packet (server greeting)
    class InitialPacket < RxPacket
      # Parse the greeting payload.
      # === Exception
      # ProtocolError :: unsupported protocol version / non-zero filler bytes
      def self.parse(data)
        protocol_version, server_version, thread_id, scramble_buff, f0,
        server_capabilities, server_charset, server_status, f1,
        rest_scramble_buff = data.unpack("CZ*Va8CvCva13Z13")
        raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
        raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
        raise ProtocolError, "invalid packet: f1=#{f1.inspect}" unless f1 == "\0\0\0\0\0\0\0\0\0\0\0\0\0"
        # the scramble is split into an 8-byte part and the remainder
        scramble_buff.concat rest_scramble_buff
        self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
      end

      attr_accessor :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff

      def initialize(*args)
        @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
      end

      # Scramble a plain-text password with the server's scramble:
      # SHA1(plain) XOR SHA1(scramble_buff + SHA1(SHA1(plain))).
      # Returns "" for a nil or empty password.
      def crypt_password(plain)
        return "" if plain.nil? || plain.empty?
        hash_stage1 = Digest::SHA1.digest plain
        hash_stage2 = Digest::SHA1.digest hash_stage1
        hash_stage3 = Digest::SHA1.digest(@scramble_buff + hash_stage2)
        hash_stage1.unpack("C*").zip(hash_stage3.unpack("C*")).map { |a, b| a ^ b }.pack("C*")
      end
    end

    # Authentication packet (client handshake response)
    class AuthenticationPacket < TxPacket
      attr_accessor :client_flags, :max_packet_size, :charset_number, :username, :scrambled_password, :databasename

      def initialize(*args)
        @client_flags, @max_packet_size, @charset_number, @username, @scrambled_password, @databasename = args
      end

      def serialize
        [
          client_flags,
          max_packet_size,
          # one byte on the wire (not a LengthCodedBinary); nil is sent as 0
          charset_number || 0,
          "", # always 0x00 * 23
          username,
          Protocol.lcs(scrambled_password),
          databasename
        ].pack("VVCa23Z*a*Z*")
      end
    end

    # Quit packet
    class QuitPacket < TxPacket
      def serialize
        [COM_QUIT].pack("C")
      end
    end

    # Query packet
    class QueryPacket < TxPacket
      attr_accessor :query

      def initialize(*args)
        @query, = args
      end

      def serialize
        [COM_QUERY, query].pack("Ca*")
      end
    end

    # Result packet: OK packet (field_count == 0) or result set header
    class ResultPacket < RxPacket
      def self.parse(data)
        field_count = Protocol.lcb2int! data
        if field_count == 0
          affected_rows = Protocol.lcb2int! data
          insert_id = Protocol.lcb2int!(data)
          server_status, warning_count, message = data.unpack("vva*")
          return self.new field_count, affected_rows, insert_id, server_status, warning_count, message
        else
          return self.new field_count
        end
      end

      attr_accessor :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

      def initialize(*args)
        @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
      end
    end

    # Field packet
    class FieldPacket < RxPacket
      # === Exception
      # ProtocolError :: non-zero filler
      def self.parse(data)
        Protocol.lcs2str! data # first string (catalog) is not used
        db = Protocol.lcs2str! data
        table = Protocol.lcs2str! data
        org_table = Protocol.lcs2str! data
        name = Protocol.lcs2str! data
        org_name = Protocol.lcs2str! data
        _, charsetnr, length, type, flags, decimals, f1, data = data.unpack("CvVCvCva*")
        raise ProtocolError, "invalid packet: f1=#{f1}" unless f1 == 0
        default = Protocol.lcs2str! data
        return self.new db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default
      end

      attr_accessor :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

      def initialize(*args)
        @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
      end
    end

    # Prepare packet
    class PreparePacket < TxPacket
      attr_accessor :query

      def initialize(*args)
        @query, = args
      end

      def serialize
        [COM_STMT_PREPARE, query].pack("Ca*")
      end
    end

    # Prepare result packet
    class PrepareResultPacket < RxPacket
      # === Exception
      # ProtocolError :: first byte or filler is not 0x00
      def self.parse(data)
        raise ProtocolError, "invalid packet" unless Protocol.ord!(data) == 0x00
        statement_id, field_count, param_count, f, warning_count = data.unpack("VvvCv")
        raise ProtocolError, "invalid packet" unless f == 0x00
        self.new statement_id, field_count, param_count, warning_count
      end

      attr_accessor :statement_id, :field_count, :param_count, :warning_count

      def initialize(*args)
        @statement_id, @field_count, @param_count, @warning_count = args
      end
    end

    # Execute packet
    class ExecutePacket < TxPacket
      attr_accessor :statement_id, :cursor_type, :values

      def initialize(*args)
        @statement_id, @cursor_type, @values = args
      end

      # command, statement id, cursor type, iteration count (always 1), then
      # (only when there are parameters) null bitmap, new-params-bound flag,
      # parameter types and non-NULL parameter values.
      def serialize
        head = [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1].pack("CVCV")
        return head if values.empty?
        netvalues = ""
        types = values.map do |v|
          t, n = Protocol.value2net v
          netvalues.concat n if v # NULLs are carried only by the bitmap
          t
        end
        head + [null_bitmap(values), 1, types.pack("v*"), netvalues].pack("a*Ca*a*")
      end

      private

      # make null bitmap
      #
      # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
      def null_bitmap(values)
        bitmap = values.enum_for(:each_slice, 8).map do |vals|
          vals.reverse.inject(0) { |b, v| (b << 1 | (v ? 0 : 1)) }
        end
        bitmap.pack("C*")
      end
    end

    # Fetch packet
    class FetchPacket < TxPacket
      attr_accessor :statement_id, :fetch_length

      def initialize(*args)
        @statement_id, @fetch_length = args
      end

      def serialize
        [Mysql::COM_STMT_FETCH, statement_id, fetch_length].pack("CVV")
      end
    end

    # Stmt close packet
    class StmtClosePacket < TxPacket
      attr_accessor :statement_id

      def initialize(*args)
        @statement_id, = args
      end

      def serialize
        [Mysql::COM_STMT_CLOSE, statement_id].pack("CV")
      end
    end
  end
end