ruby-mysql 2.9.0
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of ruby-mysql might be problematic. Click here for more details.
- data/ChangeLog +3 -0
- data/README.rdoc +68 -0
- data/lib/mysql.rb +1080 -0
- data/lib/mysql/charset.rb +255 -0
- data/lib/mysql/constants.rb +164 -0
- data/lib/mysql/error.rb +518 -0
- data/lib/mysql/protocol.rb +710 -0
- metadata +73 -0
@@ -0,0 +1,710 @@
|
|
1
|
+
# Copyright (C) 2008-2009 TOMITA Masahiro
|
2
|
+
# mailto:tommy@tmtm.org
|
3
|
+
|
4
|
+
require "socket"
|
5
|
+
require "timeout"
|
6
|
+
require "digest/sha1"
|
7
|
+
require "thread"
|
8
|
+
require "stringio"
|
9
|
+
|
10
|
+
class Mysql
|
11
|
+
# MySQL network protocol
|
12
|
+
class Protocol
|
13
|
+
|
14
|
+
# Protocol version accepted in the server's initial handshake packet.
VERSION = 10
# Largest payload a single wire packet can carry (the length header is 3 bytes);
# longer payloads are split across consecutive packets.
MAX_PACKET_LENGTH = 0xffffff
|
16
|
+
|
17
|
+
# convert Numeric to LengthCodedBinary
# === Argument
# num :: [Integer / nil] value to encode (assumed non-negative; not range-checked).
#        nil is encoded as the NULL marker byte 0xfb.
# === Return
# [String] ASCII-8BIT string of 1, 3, 4 or 9 bytes
def self.lcb(num)
  # Built with pack so the NULL marker is ASCII-8BIT like every other branch;
  # a "\xfb" literal would take the source file's encoding instead.
  return [0xfb].pack("C") if num.nil?
  return [num].pack("C") if num < 251
  return [252, num].pack("Cv") if num < 65536
  return [253, num & 0xffff, num >> 16].pack("CvC") if num < 16777216
  [254, num & 0xffffffff, num >> 32].pack("CVV")
end
|
25
|
+
|
26
|
+
# convert String to LengthCodedString
# The string is converted to its binary form and prefixed with its length
# encoded by lcb.
# === Argument
# str :: [String]
# === Return
# [String] length prefix followed by the raw bytes
def self.lcs(str)
  raw = Charset.to_binary(str)
  prefix = lcb(raw.length)
  prefix + raw
end
|
31
|
+
|
32
|
+
# convert LengthCodedBinary to Integer
# === Argument
# lcb :: [String] LengthCodedBinary. Consumed in place: the decoded bytes are removed.
# === Return
# Integer, or nil for an empty string or the NULL marker (0xfb)
def self.lcb2int!(lcb)
  return nil if lcb.empty?
  # Compare the marker as an Integer byte value. Depending on the source
  # encoding, a "\xfb"-style string literal does not compare equal to the
  # ASCII-8BIT bytes read from the socket.
  case marker = lcb.slice!(0, 1).unpack("C").first
  when 0xfb
    nil
  when 0xfc
    lcb.slice!(0, 2).unpack("v").first
  when 0xfd
    low, high = lcb.slice!(0, 3).unpack("Cv")
    (high << 8) + low
  when 0xfe
    low, high = lcb.slice!(0, 8).unpack("VV")
    (high << 32) + low
  else
    marker
  end
end
|
54
|
+
|
55
|
+
# convert LengthCodedString to String
# === Argument
# lcs :: [String] LengthCodedString. Consumed in place: the length prefix and
#        the string bytes are removed.
# === Return
# String, or nil when lcs is empty or its prefix is the NULL marker
def self.lcs2str!(lcs)
  size = lcb2int!(lcs)
  return nil if size.nil?
  lcs.slice!(0, size)
end
|
64
|
+
|
65
|
+
# True if data is an EOF packet: exactly 5 bytes whose first byte is 0xfe.
# The first byte is compared numerically so the check does not depend on the
# encoding of a "\xfe" string literal versus the binary packet data.
def self.eof_packet?(data)
  data.length == 5 && data.unpack("C").first == 0xfe
end
|
68
|
+
|
69
|
+
# Convert netdata to Ruby value (binary-protocol column value)
# === Argument
# data :: [String] packet data. Consumed in place as the value is decoded.
# type :: [Integer] field type
# unsigned :: [true or false] true if value is unsigned
# === Return
# Object :: converted value.
# === Exception
# RuntimeError :: the field type is not implemented
def self.net2value(data, type, unsigned)
  case type
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB
    lcs2str!(data)
  when Field::TYPE_TINY
    v = data.slice!(0).ord
    (unsigned || v < 128) ? v : v - 256
  when Field::TYPE_SHORT
    v = data.slice!(0, 2).unpack("v").first
    (unsigned || v < 32768) ? v : v - 65536
  when Field::TYPE_INT24, Field::TYPE_LONG
    v = data.slice!(0, 4).unpack("V").first
    (unsigned || v < 2**31) ? v : v - 2**32
  when Field::TYPE_LONGLONG
    n1, n2 = data.slice!(0, 8).unpack("VV")
    v = (n2 << 32) | n1
    (unsigned || v < 2**63) ? v : v - 2**64
  when Field::TYPE_FLOAT
    data.slice!(0, 4).unpack("e").first
  when Field::TYPE_DOUBLE
    data.slice!(0, 8).unpack("E").first
  when Field::TYPE_DATE, Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
    # A length byte, then that many bytes; parts beyond the length unpack as nil.
    len = data.slice!(0).ord
    y, m, d, h, mi, s, sp = data.slice!(0, len).unpack("vCCCCCV")
    # Mysql::Time.new takes the negative flag 7th and the sub-second part 8th
    # (see the TYPE_TIME branch); the sub-second value must not land in the flag slot.
    Mysql::Time.new(y, m, d, h, mi, s, false, sp)
  when Field::TYPE_TIME
    len = data.slice!(0).ord
    sign, d, h, mi, s, sp = data.slice!(0, len).unpack("CVCCCV")
    h = d.to_i * 24 + h.to_i
    # A zero-length value has no sign byte (sign is nil); treat it as positive.
    Mysql::Time.new(0, 0, 0, h, mi, s, sign.to_i != 0, sp)
  when Field::TYPE_YEAR
    data.slice!(0, 2).unpack("v").first
  when Field::TYPE_BIT
    lcs2str!(data)
  else
    raise "not implemented: type=#{type}"
  end
end
|
114
|
+
|
115
|
+
# convert Ruby value to netdata (binary-protocol parameter)
# Non-negative integers use the smallest unsigned type that fits and carry the
# 0x8000 unsigned bit in the type; negative integers use the smallest signed type.
# === Argument
# v :: [Object] Ruby value.
# === Return
# Integer :: type of column. Field::TYPE_*
# String :: netdata
# === Exception
# ProtocolError :: value too large / value is not supported
def self.value2net(v)
  case v
  when nil
    [Field::TYPE_NULL, ""]
  when Integer
    if v >= 0
      case v
      when 0...256
        [Field::TYPE_TINY | 0x8000, [v].pack("C")]
      when 0...256**2
        [Field::TYPE_SHORT | 0x8000, [v].pack("v")]
      when 0...256**4
        [Field::TYPE_LONG | 0x8000, [v].pack("V")]
      when 0...256**8
        [Field::TYPE_LONGLONG | 0x8000, [v & 0xffffffff, v >> 32].pack("VV")]
      else
        raise ProtocolError, "value too large: #{v}"
      end
    else
      case -v
      when 1..256 / 2
        [Field::TYPE_TINY, [v].pack("C")]
      when 1..256**2 / 2
        [Field::TYPE_SHORT, [v].pack("v")]
      when 1..256**4 / 2
        [Field::TYPE_LONG, [v].pack("V")]
      when 1..256**8 / 2
        [Field::TYPE_LONGLONG, [v & 0xffffffff, v >> 32].pack("VV")]
      else
        raise ProtocolError, "value too large: #{v}"
      end
    end
  when Float
    [Field::TYPE_DOUBLE, [v].pack("E")]
  when String
    [Field::TYPE_STRING, lcs(v)]
  when Mysql::Time, ::Time
    [Field::TYPE_DATETIME, [7, v.year, v.month, v.day, v.hour, v.min, v.sec].pack("CvCCCCC")]
  else
    raise ProtocolError, "class #{v.class} is not supported"
  end
end
|
176
|
+
|
177
|
+
# Connection state: server information recorded by authenticate, plus values
# taken from the most recent server responses.
attr_reader :server_info, :server_version, :thread_id, :sqlstate,
            :affected_rows, :insert_id, :server_status, :warning_count, :message
# Character set of the connection; chosen in authenticate when not given.
attr_accessor :charset
|
187
|
+
|
188
|
+
# make socket connection to server.
# A UNIX domain socket is used when host is nil, "" or "localhost"; any other
# host gets a TCP connection.
# === Argument
# host :: [String / nil] server host name
# port :: [Integer / nil] TCP port; falls back to ENV["MYSQL_TCP_PORT"], then the
#         "mysql" service entry, then MYSQL_TCP_PORT
# socket :: [String / nil] UNIX socket path; falls back to ENV["MYSQL_UNIX_PORT"],
#           then MYSQL_UNIX_PORT
# conn_timeout :: [Integer] connect timeout (sec).
# read_timeout :: [Integer] read timeout (sec).
# write_timeout :: [Integer] write timeout (sec).
# === Exception
# [ClientError] :: connection timeout
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
  @mutex = Mutex.new
  @read_timeout = read_timeout
  @write_timeout = write_timeout
  begin
    Timeout.timeout(conn_timeout) do
      local = host.nil? || host.empty? || host == "localhost"
      if local
        path = socket || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
        @sock = UNIXSocket.new(path)
      else
        tcp_port = port || ENV["MYSQL_TCP_PORT"] || begin
          Socket.getservbyname("mysql", "tcp")
        rescue StandardError
          MYSQL_TCP_PORT
        end
        @sock = TCPSocket.new(host, tcp_port)
      end
    end
  rescue Timeout::Error
    raise ClientError, "connection timeout"
  end
end
|
216
|
+
|
217
|
+
# Close the underlying socket and return whatever its close returns.
def close
  sock = @sock
  sock.close
end
|
220
|
+
|
221
|
+
# initial negotiate and authenticate.
# Reads the server's initial handshake packet, records server information,
# sends the authentication packet and consumes the server's reply.
# === Argument
# user :: [String / nil] username
# passwd :: [String / nil] password
# db :: [String / nil] default database name. nil: no default.
# flag :: [Integer] client flag (OR-ed into the default capability flags)
# charset :: [Mysql::Charset / nil] charset for connection. nil: use server's charset
def authenticate(user, passwd, db, flag, charset)
  @authinfo = [user, passwd, db, flag, charset]
  synchronize do
    reset
    handshake = InitialPacket.parse(read)
    @server_info = handshake.server_version
    # e.g. "5.1.30-log" -> 50130 (first three numeric components)
    @server_version = handshake.server_version.split(/\D/)[0, 3].inject { |acc, part| acc.to_i * 100 + part.to_i }
    @thread_id = handshake.thread_id
    capabilities = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS |
                   CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
    capabilities |= CLIENT_CONNECT_WITH_DB if db
    capabilities |= flag
    @charset = charset || Charset.by_number(handshake.server_charset)
    @charset.encoding unless charset # raise error if unsupported charset
    scrambled = encrypt_password(passwd, handshake.scramble_buff)
    write AuthenticationPacket.serialize(capabilities, 1024**3, @charset.number, user, scrambled, db)
    read # skip OK packet
  end
end
|
249
|
+
|
250
|
+
# Quit command
# Sends COM_QUIT and closes the socket without reading a reply.
def quit_command
  synchronize do
    reset
    packet = [COM_QUIT].pack("C")
    write(packet)
    close
  end
end
|
258
|
+
|
259
|
+
# Query command
# The query text is converted with the connection charset before sending.
# === Argument
# query :: [String] query string
# === Return
# [Integer / nil] number of fields of results. nil if no results.
def query_command(query)
  synchronize do
    reset
    sql = @charset.convert(query)
    write([COM_QUERY, sql].pack("Ca*"))
    get_result
  end
end
|
271
|
+
|
272
|
+
# get result of query.
# A positive field count means a result set follows. Otherwise the OK packet's
# counters are recorded. A LOAD DATA LOCAL INFILE request (NULL field count)
# is answered by streaming the requested local file, and the server's final
# OK packet is then used for the counters.
# === Return
# [integer / nil] number of fields of results. nil if no results.
def get_result
  res_packet = ResultPacket.parse read
  if res_packet.field_count.to_i > 0 # result data exists
    return res_packet.field_count
  end
  if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
    # NOTE(review): the file name is chosen by the server, so a malicious server
    # can request any file readable by this process; consider an opt-in switch.
    filename = res_packet.message
    # Binary mode: the file contents are sent byte-for-byte.
    File.open(filename, "rb"){|f| write f}
    write nil # EOF mark
    # The packet that follows is the OK packet for the load; without re-parsing
    # it, the counters below would come from the INFILE request (all nil).
    res_packet = ResultPacket.parse read
  end
  @affected_rows, @insert_id, @server_status, @warning_count, @message =
    res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
  nil
end
|
290
|
+
|
291
|
+
# Retrieve n fields
# Reads n column-definition packets, then the terminating EOF packet.
# === Argument
# n :: [Integer] number of fields
# === Return
# [Array of Mysql::Field] field list
def retr_fields(n)
  columns = Array.new(n) { Field.new(FieldPacket.parse(read)) }
  read_eof_packet
  columns
end
|
301
|
+
|
302
|
+
# Retrieve all records for simple query
# Reads text-protocol row packets until the terminating EOF packet, then
# records the server status carried by that EOF packet.
# === Argument
# fields :: [Array of Mysql::Field] field list
# === Return
# [Array of Array of String] all records (NULL columns are nil)
def retr_all_records(fields)
  all_recs = []
  until self.class.eof_packet?(data = read)
    rec = fields.map do
      s = self.class.lcs2str!(data)
      s && charset.force_encoding(s)
    end
    all_recs.push rec
  end
  # EOF packet layout: 0xfe, 2-byte warning count, 2-byte status flags.
  # Decode both status bytes; reading only data[3] dropped the high byte.
  @server_status = data[3, 2].unpack("v").first
  all_recs
end
|
319
|
+
|
320
|
+
# Field list command
# Sends COM_FIELD_LIST and collects column definitions until the EOF packet.
# === Argument
# table :: [String] table name.
# field :: [String / nil] field name that may contain wild card.
# === Return
# [Array of Field] field list
def field_list_command(table, field)
  synchronize do
    reset
    write [COM_FIELD_LIST, table, 0, field].pack("Ca*Ca*")
    columns = []
    loop do
      packet = read
      break if self.class.eof_packet?(packet)
      columns << Field.new(FieldPacket.parse(packet))
    end
    columns
  end
end
|
337
|
+
|
338
|
+
# Process info command
# Sends COM_PROCESS_INFO, reads the field count, the column definitions and
# the EOF packet that ends them.
# === Return
# [Array of Field] field list
def process_info_command
  synchronize do
    reset
    write [COM_PROCESS_INFO].pack("C")
    count = self.class.lcb2int!(read)
    columns = count.times.map { Field.new(FieldPacket.parse(read)) }
    read_eof_packet
    columns
  end
end
|
351
|
+
|
352
|
+
# Ping command (COM_PING, no arguments)
# === Return
# [String] reply packet from simple_command
def ping_command
  packet = [COM_PING].pack("C")
  simple_command(packet)
end
|
356
|
+
|
357
|
+
# Kill command (COM_PROCESS_KILL)
# === Argument
# pid :: [Integer] thread id, sent as a 4-byte little-endian integer
# === Return
# [String] reply packet from simple_command
def kill_command(pid)
  packet = [COM_PROCESS_KILL, pid].pack("CV")
  simple_command(packet)
end
|
361
|
+
|
362
|
+
# Refresh command (COM_REFRESH)
# === Argument
# op :: [Integer] refresh option bits, sent as one byte
# === Return
# [String] reply packet from simple_command
def refresh_command(op)
  packet = [COM_REFRESH, op].pack("CC")
  simple_command(packet)
end
|
366
|
+
|
367
|
+
# Set option command (COM_SET_OPTION)
# === Argument
# opt :: [Integer] option, sent as a 2-byte little-endian integer
# === Return
# [String] reply packet from simple_command
def set_option_command(opt)
  packet = [COM_SET_OPTION, opt].pack("Cv")
  simple_command(packet)
end
|
371
|
+
|
372
|
+
# Shutdown command (COM_SHUTDOWN)
# === Argument
# level :: [Integer] shutdown level, sent as one byte
# === Return
# [String] reply packet from simple_command
def shutdown_command(level)
  packet = [COM_SHUTDOWN, level].pack("CC")
  simple_command(packet)
end
|
376
|
+
|
377
|
+
# Statistics command (COM_STATISTICS, no arguments)
# === Return
# [String] reply packet from simple_command
def statistics_command
  packet = [COM_STATISTICS].pack("C")
  simple_command(packet)
end
|
381
|
+
|
382
|
+
# Stmt prepare command
# Parameter definition packets are read and discarded; column definitions
# are returned as Field objects.
# === Argument
# stmt :: [String] prepared statement
# === Return
# [Integer] statement id
# [Integer] number of parameters
# [Array of Field] field list
def stmt_prepare_command(stmt)
  synchronize do
    reset
    write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
    prepared = PrepareResultPacket.parse(read)
    if prepared.param_count > 0
      prepared.param_count.times { read } # parameter definitions are not used
      read_eof_packet
    end
    columns = []
    if prepared.field_count > 0
      columns = prepared.field_count.times.map { Field.new(FieldPacket.parse(read)) }
      read_eof_packet
    end
    [prepared.statement_id, prepared.param_count, columns]
  end
end
|
407
|
+
|
408
|
+
# Stmt execute command
# Executes without a cursor and returns the response as get_result does.
# === Argument
# stmt_id :: [Integer] statement id
# values :: [Array] parameters
# === Return
# [Integer / nil] number of fields (nil when there is no result set)
def stmt_execute_command(stmt_id, values)
  synchronize do
    reset
    packet = ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
    write(packet)
    get_result
  end
end
|
421
|
+
|
422
|
+
# Retrieve all records for prepared statement
# Reads binary-protocol row packets until the terminating EOF packet, then
# records the server status carried by that EOF packet.
# === Argument
# fields :: [Array of Mysql::Fields] field list
# charset :: [Mysql::Charset]
# === Return
# [Array of Array of Object] all records
def stmt_retr_all_records(fields, charset)
  all_recs = []
  until self.class.eof_packet?(data = read)
    all_recs.push stmt_parse_record_packet(data, fields, charset)
  end
  # EOF packet layout: 0xfe, 2-byte warning count, 2-byte status flags.
  # Keep server_status current after a prepared-statement fetch as well.
  @server_status = data[3, 2].unpack("v").first
  all_recs
end
|
435
|
+
|
436
|
+
# Stmt close command
# Sends COM_STMT_CLOSE; no response packet is read afterwards.
# === Argument
# stmt_id :: [Integer] statement id
def stmt_close_command(stmt_id)
  synchronize do
    reset
    packet = [COM_STMT_CLOSE, stmt_id].pack("CV")
    write(packet)
  end
end
|
445
|
+
|
446
|
+
private
|
447
|
+
|
448
|
+
# Parse statement result packet (one binary-protocol row)
# Layout consumed here: one header byte, a NULL bitmap whose first two bits
# are skipped, then the values of the non-NULL columns in field order.
# === Argument
# data :: [String] row packet; consumed in place
# fields :: [Array of Fields]
# charset :: [Mysql::Charset]
# === Return
# [Array of Object] one record
def stmt_parse_record_packet(data, fields, charset)
  data.slice!(0) # header byte, not used
  bitmap_bytes = (fields.length + 7 + 2) / 8
  null_bits = data.slice!(0, bitmap_bytes).unpack("b*").first
  Array.new(fields.length) do |idx|
    field = fields[idx]
    next nil if null_bits[idx + 2, 1] == "1"
    is_unsigned = (field.flags & Field::UNSIGNED_FLAG) != 0
    value = self.class.net2value(data, field.type, is_unsigned)
    if value.is_a?(Numeric) || value.is_a?(Mysql::Time)
      value
    elsif field.type == Field::TYPE_BIT || (field.flags & Field::BINARY_FLAG) != 0
      Charset.to_binary(value)
    else
      charset.force_encoding(value)
    end
  end
end
|
475
|
+
|
476
|
+
# Run the given block while holding this connection's mutex; returns the block's value.
def synchronize
  @mutex.synchronize { yield }
end
|
481
|
+
|
482
|
+
# Reset sequence number.
# Every command starts a new packet exchange whose sequence counter begins at 0;
# read and write advance it.
def reset
  @seq = 0
end
|
486
|
+
|
487
|
+
# Read one packet data
# Reassembles a payload split across several wire packets (a chunk of exactly
# MAX_PACKET_LENGTH bytes is followed by another chunk) and turns an error
# packet (first byte 0xff) into a Mysql::ServerError.
# === Return
# [String] packet data
# === Exception
# [ProtocolError] invalid packet sequence number
# [ClientError] read timeout, or the connection was closed mid-packet
# [Mysql::ServerError] the server sent an error packet
def read
  ret = ""
  len = nil
  begin
    Timeout.timeout @read_timeout do
      header = @sock.read(4)
      # IO#read returns nil (or a short string) once the peer has closed the socket.
      raise ClientError, "Lost connection to server during query" unless header && header.length == 4
      len1, len2, seq = header.unpack("CvC")
      len = (len2 << 8) + len1
      raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
      @seq = (@seq + 1) % 256
      body = @sock.read(len)
      raise ClientError, "Lost connection to server during query" unless body && body.length == len
      ret.concat body
    end
  rescue Timeout::Error
    raise ClientError, "read timeout"
  end while len == MAX_PACKET_LENGTH

  @sqlstate = "00000"

  # Error packet. The marker is compared as an Integer: depending on the
  # source encoding, a "\xff" literal never equals the binary packet data.
  if ret.unpack("C").first == 0xff
    _, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
    unless marker == "#"
      _, errno, message = ret.unpack("Cva*") # Version 4.0 Error
      @sqlstate = ""
    end
    if Mysql::ServerError::ERROR_MAP.key? errno
      raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
    end
    raise Mysql::ServerError.new(message, @sqlstate)
  end
  ret
end
|
524
|
+
|
525
|
+
# Write one packet data
# The payload is split into packets of at most MAX_PACKET_LENGTH bytes. A
# payload whose last chunk is exactly MAX_PACKET_LENGTH bytes is terminated
# with an empty packet, so the receiver knows the payload is complete.
# === Argument
# data :: [String / IO] packet data. If data is nil, write empty packet.
# === Exception
# [ClientError] write timeout
def write(data)
  begin
    @sock.sync = false
    need_empty_packet = data.nil?
    unless data.nil?
      data = StringIO.new data if data.is_a? String
      while (d = data.read(MAX_PACKET_LENGTH))
        Timeout.timeout @write_timeout do
          @sock.write [d.length%256, d.length/256, @seq].pack("CvC")
          @sock.write d
        end
        @seq = (@seq + 1) % 256
        need_empty_packet = (d.length == MAX_PACKET_LENGTH)
      end
    end
    if need_empty_packet
      Timeout.timeout @write_timeout do
        @sock.write [0, 0, @seq].pack("CvC")
      end
      @seq = (@seq + 1) % 256
    end
    @sock.sync = true
    Timeout.timeout @write_timeout do
      @sock.flush
    end
  rescue Timeout::Error
    raise ClientError, "write timeout"
  end
end
|
554
|
+
|
555
|
+
# Read EOF packet
# Reads one packet and requires it to be an EOF packet.
# === Exception
# [ProtocolError] packet is not EOF
def read_eof_packet
  packet = read
  return if self.class.eof_packet?(packet)
  raise ProtocolError, "packet is not EOF"
end
|
562
|
+
|
563
|
+
# Send simple command
# Starts a new exchange, writes the packet and reads one reply packet.
# === Argument
# packet :: [String] packet data
# === Return
# [String] received data
def simple_command(packet)
  synchronize { reset; write(packet); read }
end
|
575
|
+
|
576
|
+
# Encrypt password
# Computes SHA1(plain) XOR SHA1(scramble + SHA1(SHA1(plain))).
# === Argument
# plain :: [String] plain password.
# scramble :: [String] scramble code from initial packet.
# === Return
# [String] encrypted password ("" for a nil or empty password)
def encrypt_password(plain, scramble)
  return "" if plain.nil? || plain.empty?
  stage1 = Digest::SHA1.digest(plain)
  stage2 = Digest::SHA1.digest(stage1)
  mask = Digest::SHA1.digest(scramble + stage2)
  pairs = stage1.unpack("C*").zip(mask.unpack("C*"))
  pairs.map { |a, b| a ^ b }.pack("C*")
end
|
588
|
+
|
589
|
+
# Initial packet (server handshake)
class InitialPacket
  # Parse the server's initial handshake packet.
  # === Argument
  # data :: [String] packet payload
  # === Return
  # [InitialPacket]
  # === Exception
  # [ProtocolError] protocol version other than VERSION, or a non-zero filler byte
  def self.parse(data)
    protocol_version, server_version, thread_id, scramble_buff, f0,
    server_capabilities, server_charset, server_status, _reserved,
    rest_scramble_buff = data.unpack("CZ*Va8CvCva13Z13")
    raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
    raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
    # The 13 bytes after the status flags are deliberately not validated:
    # MySQL 5.5+ puts the upper capability flags and the auth-data length
    # there, so requiring them to be zero rejected those servers.
    scramble_buff.concat rest_scramble_buff
    self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
  end

  attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff

  def initialize(*args)
    @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
  end
end
|
608
|
+
|
609
|
+
# Result packet (first packet of a command response)
class ResultPacket
  # Parse the first response packet.
  # field_count == 0 :: OK packet; all attributes are filled in
  # field_count nil  :: LOAD DATA LOCAL INFILE request; message is the rest of
  #                     the packet (used by get_result as the file name)
  # otherwise        :: result-set header; only field_count is set
  def self.parse(data)
    field_count = Protocol.lcb2int!(data)
    return self.new(nil, nil, nil, nil, nil, data) if field_count.nil? # LOAD DATA LOCAL INFILE
    return self.new(field_count) unless field_count == 0
    affected_rows = Protocol.lcb2int!(data)
    insert_id = Protocol.lcb2int!(data)
    server_status, warning_count, info = data.unpack("vva*")
    self.new(field_count, affected_rows, insert_id, server_status, warning_count, Protocol.lcs2str!(info))
  end

  attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

  def initialize(*args)
    @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
  end
end
|
631
|
+
|
632
|
+
# Field packet (column definition)
class FieldPacket
  # Parse a column definition packet.
  # === Exception
  # [ProtocolError] the 2-byte filler after the decimals byte is not zero
  def self.parse(data)
    _catalog, db, table, org_table, name, org_name = Array.new(6) { Protocol.lcs2str!(data) }
    _fixed_len, charsetnr, length, type, flags, decimals, filler, rest = data.unpack("CvVCvCva*")
    raise ProtocolError, "invalid packet: f1=#{filler}" unless filler == 0
    default = Protocol.lcs2str!(rest)
    self.new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
  end

  attr_reader :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

  def initialize(*args)
    @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
  end
end
|
653
|
+
|
654
|
+
# Prepare result packet (reply to COM_STMT_PREPARE)
class PrepareResultPacket
  # Parse the prepare reply: a zero status byte, then statement id, column
  # count, parameter count, a zero filler byte and the warning count.
  # === Exception
  # [ProtocolError] status or filler byte is not zero
  def self.parse(data)
    status = data.slice!(0, 1)
    raise ProtocolError, "invalid packet" unless status == "\0"
    statement_id, field_count, param_count, filler, warning_count = data.unpack("VvvCv")
    raise ProtocolError, "invalid packet" unless filler == 0x00
    self.new(statement_id, field_count, param_count, warning_count)
  end

  attr_reader :statement_id, :field_count, :param_count, :warning_count

  def initialize(*args)
    @statement_id, @field_count, @param_count, @warning_count = args
  end
end
|
669
|
+
|
670
|
+
# Authentication packet (client handshake response)
class AuthenticationPacket
  # Build the handshake response packet.
  # === Argument
  # client_flags :: [Integer] capability flags (4 bytes)
  # max_packet_size :: [Integer] maximum packet size (4 bytes)
  # charset_number :: [Integer] charset/collation id; the field is exactly one byte
  # username :: [String / nil] user name, NUL-terminated
  # scrambled_password :: [String] encrypted password, sent length-prefixed
  # databasename :: [String / nil] default database, NUL-terminated
  # === Return
  # [String] packet data
  def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename)
    [
      client_flags,
      max_packet_size,
      # Packed as a single byte: lcb would emit 3 bytes for ids >= 251 and
      # shift every following field.
      charset_number,
      "", # always 0x00 * 23
      username,
      Protocol.lcs(scrambled_password),
      databasename
    ].pack("VVCa23Z*A*Z*")
  end
end
|
684
|
+
|
685
|
+
# Execute packet (COM_STMT_EXECUTE)
class ExecutePacket
  # Build a COM_STMT_EXECUTE packet.
  # === Argument
  # statement_id :: [Integer] statement id
  # cursor_type :: [Integer] cursor flag byte
  # values :: [Array] parameter values; nil values are flagged in the NULL
  #           bitmap and contribute no value bytes
  # === Return
  # [String] packet data
  def self.serialize(statement_id, cursor_type, values)
    null_map = null_bitmap(values)
    param_data = ""
    param_types = []
    values.each do |value|
      type, bytes = Protocol.value2net(value)
      param_data.concat(bytes) if value
      param_types << type
    end
    [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, null_map, 1, param_types.pack("v*"), param_data].pack("CVCVa*Ca*a*")
  end

  # make null bitmap
  # Bit i of each byte is set when the i-th value of that group of 8 is nil or false.
  #
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
  def self.null_bitmap(values)
    bytes = values.each_slice(8).map do |group|
      bits = 0
      group.each_with_index { |v, i| bits |= (1 << i) unless v }
      bits
    end
    bytes.pack("C*")
  end
end
|
709
|
+
end
|
710
|
+
end
|