ruby-mysql 2.9.14 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
1
  # coding: ascii-8bit
2
- # Copyright (C) 2008-2012 TOMITA Masahiro
2
+ # Copyright (C) 2008 TOMITA Masahiro
3
3
  # mailto:tommy@tmtm.org
4
4
 
5
5
  require "socket"
6
- require "timeout"
7
- require "digest/sha1"
8
6
  require "stringio"
7
+ require "openssl"
8
+ require_relative 'authenticator.rb'
9
9
 
10
10
  class Mysql
11
11
  # MySQL network protocol
@@ -15,12 +15,10 @@ class Mysql
15
15
  MAX_PACKET_LENGTH = 2**24-1
16
16
 
17
17
  # Convert netdata to Ruby value
18
- # === Argument
19
- # data :: [Packet] packet data
20
- # type :: [Integer] field type
21
- # unsigned :: [true or false] true if value is unsigned
22
- # === Return
23
- # Object :: converted value.
18
+ # @param data [Packet] packet data
19
+ # @param type [Integer] field type
20
+ # @param unsigned [true or false] true if value is unsigned
21
+ # @return [Object] converted value.
24
22
  def self.net2value(pkt, type, unsigned)
25
23
  case type
26
24
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
@@ -45,17 +43,18 @@ class Mysql
45
43
  when Field::TYPE_DATE
46
44
  len = pkt.utiny
47
45
  y, m, d = pkt.read(len).unpack("vCC")
48
- t = Mysql::Time.new(y, m, d, nil, nil, nil)
46
+ t = Time.new(y, m, d) rescue nil
49
47
  return t
50
48
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
51
49
  len = pkt.utiny
52
50
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
53
- return Mysql::Time.new(y, m, d, h, mi, s, false, sp)
51
+ return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
54
52
  when Field::TYPE_TIME
55
53
  len = pkt.utiny
56
54
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
57
- h = d.to_i * 24 + h.to_i
58
- return Mysql::Time.new(0, 0, 0, h, mi, s, sign!=0, sp)
55
+ r = d.to_i*86400 + h.to_i*3600 + mi.to_i*60 + s.to_i + sp.to_f/1000000
56
+ r *= -1 if sign != 0
57
+ return r
59
58
  when Field::TYPE_YEAR
60
59
  return pkt.ushort
61
60
  when Field::TYPE_BIT
@@ -66,13 +65,10 @@ class Mysql
66
65
  end
67
66
 
68
67
  # convert Ruby value to netdata
69
- # === Argument
70
- # v :: [Object] Ruby value.
71
- # === Return
72
- # Integer :: type of column. Field::TYPE_*
73
- # String :: netdata
74
- # === Exception
75
- # ProtocolError :: value too large / value is not supported
68
+ # @param v [Object] Ruby value.
69
+ # @return [Integer] type of column. Field::TYPE_*
70
+ # @return [String] netdata
71
+ # @raise [ProtocolError] value too large / value is not supported
76
72
  def self.value2net(v)
77
73
  case v
78
74
  when nil
@@ -97,12 +93,9 @@ class Mysql
97
93
  when String
98
94
  type = Field::TYPE_STRING
99
95
  val = Packet.lcs(v)
100
- when ::Time
96
+ when Time
101
97
  type = Field::TYPE_DATETIME
102
98
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
103
- when Mysql::Time
104
- type = Field::TYPE_DATETIME
105
- val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.second_part].pack("CvCCCCCV")
106
99
  else
107
100
  raise ProtocolError, "class #{v.class} is not supported"
108
101
  end
@@ -112,12 +105,14 @@ class Mysql
112
105
  attr_reader :server_info
113
106
  attr_reader :server_version
114
107
  attr_reader :thread_id
108
+ attr_reader :client_flags
115
109
  attr_reader :sqlstate
116
110
  attr_reader :affected_rows
117
111
  attr_reader :insert_id
118
112
  attr_reader :server_status
119
113
  attr_reader :warning_count
120
114
  attr_reader :message
115
+ attr_reader :get_server_public_key
121
116
  attr_accessor :charset
122
117
 
123
118
  # @state variable keep state for connection.
@@ -127,72 +122,101 @@ class Mysql
127
122
  # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
128
123
 
129
124
  # make socket connection to server.
130
- # === Argument
131
- # host :: [String] if "localhost" or "" nil then use UNIXSocket. Otherwise use TCPSocket
132
- # port :: [Integer] port number using by TCPSocket
133
- # socket :: [String] socket file name using by UNIXSocket
134
- # conn_timeout :: [Integer] connect timeout (sec).
135
- # read_timeout :: [Integer] read timeout (sec).
136
- # write_timeout :: [Integer] write timeout (sec).
137
- # === Exception
138
- # [ClientError] :: connection timeout
139
- def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
125
+ # @param opts [Hash]
126
+ # @option :host [String] hostname mysqld running
127
+ # @option :username [String] username to connect to mysqld
128
+ # @option :password [String] password to connect to mysqld
129
+ # @option :database [String] initial database name
130
+ # @option :port [String] port number (used if host is not 'localhost' or nil)
131
+ # @option :socket [String] socket filename (used if host is 'localhost' or nil)
132
+ # @option :flags [Integer] connection flag. Mysql::CLIENT_* ORed
133
+ # @option :charset [Mysql::Charset] character set
134
+ # @option :connect_timeout [Numeric, nil]
135
+ # @option :read_timeout [Numeric, nil]
136
+ # @option :write_timeout [Numeric, nil]
137
+ # @option :local_infile [Boolean]
138
+ # @option :load_data_local_dir [String]
139
+ # @option :ssl_mode [Integer]
140
+ # @option :get_server_public_key [Boolean]
141
+ # @raise [ClientError] connection timeout
142
+ def initialize(opts)
143
+ @opts = opts
144
+ @charset = Mysql::Charset.by_name("utf8mb4")
140
145
  @insert_id = 0
141
146
  @warning_count = 0
142
147
  @gc_stmt_queue = [] # stmt id list which GC destroy.
143
148
  set_state :INIT
144
- @read_timeout = read_timeout
145
- @write_timeout = write_timeout
149
+ @get_server_public_key = @opts[:get_server_public_key]
146
150
  begin
147
- Timeout.timeout conn_timeout do
148
- if host.nil? or host.empty? or host == "localhost"
149
- socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
150
- @sock = UNIXSocket.new socket
151
- else
152
- port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
153
- @sock = TCPSocket.new host, port
154
- end
151
+ if @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
152
+ socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
+ @socket = Socket.unix(socket)
154
+ else
155
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
156
+ @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
155
157
  end
156
- rescue Timeout::Error
158
+ rescue Errno::ETIMEDOUT
157
159
  raise ClientError, "connection timeout"
158
160
  end
159
161
  end
160
162
 
161
163
  def close
162
- @sock.close
164
+ @socket.close
163
165
  end
164
166
 
165
167
  # initial negotiate and authenticate.
166
- # === Argument
167
- # user :: [String / nil] username
168
- # passwd :: [String / nil] password
169
- # db :: [String / nil] default database name. nil: no default.
170
- # flag :: [Integer] client flag
171
- # charset :: [Mysql::Charset / nil] charset for connection. nil: use server's charset
172
- # === Exception
173
- # ProtocolError :: The old style password is not supported
174
- def authenticate(user, passwd, db, flag, charset)
168
+ # @param charset [Mysql::Charset, nil] charset for connection. nil: use server's charset
169
+ # @raise [ProtocolError] The old style password is not supported
170
+ def authenticate
175
171
  check_state :INIT
176
- @authinfo = [user, passwd, db, flag, charset]
177
172
  reset
178
173
  init_packet = InitialPacket.parse read
179
174
  @server_info = init_packet.server_version
180
175
  @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
+ @server_capabilities = init_packet.server_capabilities
181
177
  @thread_id = init_packet.thread_id
182
- client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
183
- client_flags |= CLIENT_CONNECT_WITH_DB if db
184
- client_flags |= flag
185
- @charset = charset
186
- unless @charset
178
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
179
+ @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
+ @client_flags |= @opts[:flags]
182
+ if @opts[:charset]
183
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
+ else
187
185
  @charset = Charset.by_number(init_packet.server_charset)
188
186
  @charset.encoding # raise error if unsupported charset
189
187
  end
190
- netpw = encrypt_password passwd, init_packet.scramble_buff
191
- write AuthenticationPacket.serialize(client_flags, 1024**3, @charset.number, user, netpw, db)
192
- raise ProtocolError, 'The old style password is not supported' if read.to_s == "\xfe"
188
+ enable_ssl
189
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
193
190
  set_state :READY
194
191
  end
195
192
 
193
+ def enable_ssl
194
+ case @opts[:ssl_mode]
195
+ when SSL_MODE_DISABLED, '1', 'disabled'
196
+ return
197
+ when SSL_MODE_PREFERRED, '2', 'preferred'
198
+ return if @socket.local_address.unix?
199
+ return if @server_capabilities & CLIENT_SSL == 0
200
+ when SSL_MODE_REQUIRED, '3', 'required'
201
+ if @server_capabilities & CLIENT_SSL == 0
202
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
+ end
204
+ else
205
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
+ end
207
+ begin
208
+ @client_flags |= CLIENT_SSL
209
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
+ @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
+ @socket.sync_close = true
212
+ @socket.connect
213
+ rescue => e
214
+ @client_flags &= ~CLIENT_SSL
215
+ return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
+ raise e
217
+ end
218
+ end
219
+
196
220
  # Quit command
197
221
  def quit_command
198
222
  synchronize do
@@ -203,10 +227,8 @@ class Mysql
203
227
  end
204
228
 
205
229
  # Query command
206
- # === Argument
207
- # query :: [String] query string
208
- # === Return
209
- # [Integer / nil] number of fields of results. nil if no results.
230
+ # @param query [String] query string
231
+ # @return [Integer, nil] number of fields of results. nil if no results.
210
232
  def query_command(query)
211
233
  check_state :READY
212
234
  begin
@@ -220,8 +242,7 @@ class Mysql
220
242
  end
221
243
 
222
244
  # get result of query.
223
- # === Return
224
- # [integer / nil] number of fields of results. nil if no results.
245
+ # @return [integer, nil] number of fields of results. nil if no results.
225
246
  def get_result
226
247
  begin
227
248
  res_packet = ResultPacket.parse read
@@ -230,10 +251,7 @@ class Mysql
230
251
  return res_packet.field_count
231
252
  end
232
253
  if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
233
- filename = res_packet.message
234
- File.open(filename){|f| write f}
235
- write nil # EOF mark
236
- read
254
+ send_local_file(res_packet.message)
237
255
  end
238
256
  @affected_rows, @insert_id, @server_status, @warning_count, @message =
239
257
  res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
@@ -245,11 +263,22 @@ class Mysql
245
263
  end
246
264
  end
247
265
 
266
+ # send local file to server
267
+ def send_local_file(filename)
268
+ filename = File.absolute_path(filename)
269
+ if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
270
+ File.open(filename){|f| write f}
271
+ else
272
+ raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
273
+ end
274
+ ensure
275
+ write nil # EOF mark
276
+ read
277
+ end
278
+
248
279
  # Retrieve n fields
249
- # === Argument
250
- # n :: [Integer] number of fields
251
- # === Return
252
- # [Array of Mysql::Field] field list
280
+ # @param n [Integer] number of fields
281
+ # @return [Array<Mysql::Field>] field list
253
282
  def retr_fields(n)
254
283
  check_state :FIELD
255
284
  begin
@@ -264,10 +293,8 @@ class Mysql
264
293
  end
265
294
 
266
295
  # Retrieve all records for simple query
267
- # === Argument
268
- # fields :: [Array<Mysql::Field>] number of fields
269
- # === Return
270
- # [Array of Array of String] all records
296
+ # @param fields [Array<Mysql::Field>] number of fields
297
+ # @return [Array<Array<String>>] all records
271
298
  def retr_all_records(fields)
272
299
  check_state :RESULT
273
300
  enc = charset.encoding
@@ -284,43 +311,6 @@ class Mysql
284
311
  end
285
312
  end
286
313
 
287
- # Field list command
288
- # === Argument
289
- # table :: [String] table name.
290
- # field :: [String / nil] field name that may contain wild card.
291
- # === Return
292
- # [Array of Field] field list
293
- def field_list_command(table, field)
294
- synchronize do
295
- reset
296
- write [COM_FIELD_LIST, table, 0, field].pack("Ca*Ca*")
297
- fields = []
298
- until (data = read).eof?
299
- fields.push Field.new(FieldPacket.parse(data))
300
- end
301
- return fields
302
- end
303
- end
304
-
305
- # Process info command
306
- # === Return
307
- # [Array of Field] field list
308
- def process_info_command
309
- check_state :READY
310
- begin
311
- reset
312
- write [COM_PROCESS_INFO].pack("C")
313
- field_count = read.lcb
314
- fields = field_count.times.map{Field.new FieldPacket.parse(read)}
315
- read_eof_packet
316
- set_state :RESULT
317
- return fields
318
- rescue
319
- set_state :READY
320
- raise
321
- end
322
- end
323
-
324
314
  # Ping command
325
315
  def ping_command
326
316
  simple_command [COM_PING].pack("C")
@@ -352,12 +342,8 @@ class Mysql
352
342
  end
353
343
 
354
344
  # Stmt prepare command
355
- # === Argument
356
- # stmt :: [String] prepared statement
357
- # === Return
358
- # [Integer] statement id
359
- # [Integer] number of parameters
360
- # [Array of Field] field list
345
+ # @param stmt [String] prepared statement
346
+ # @return [Array<Integer, Integer, Array<Field>>] statement id, number of parameters, field list
361
347
  def stmt_prepare_command(stmt)
362
348
  synchronize do
363
349
  reset
@@ -378,11 +364,9 @@ class Mysql
378
364
  end
379
365
 
380
366
  # Stmt execute command
381
- # === Argument
382
- # stmt_id :: [Integer] statement id
383
- # values :: [Array] parameters
384
- # === Return
385
- # [Integer] number of fields
367
+ # @param stmt_id [Integer] statement id
368
+ # @param values [Array] parameters
369
+ # @return [Integer] number of fields
386
370
  def stmt_execute_command(stmt_id, values)
387
371
  check_state :READY
388
372
  begin
@@ -396,11 +380,9 @@ class Mysql
396
380
  end
397
381
 
398
382
  # Retrieve all records for prepared statement
399
- # === Argument
400
- # fields :: [Array of Mysql::Fields] field list
401
- # charset :: [Mysql::Charset]
402
- # === Return
403
- # [Array of Array of Object] all records
383
+ # @param fields [Array of Mysql::Fields] field list
384
+ # @param charset [Mysql::Charset]
385
+ # @return [Array<Array<Object>>] all records
404
386
  def stmt_retr_all_records(fields, charset)
405
387
  check_state :RESULT
406
388
  enc = charset.encoding
@@ -416,8 +398,7 @@ class Mysql
416
398
  end
417
399
 
418
400
  # Stmt close command
419
- # === Argument
420
- # stmt_id :: [Integer] statement id
401
+ # @param stmt_id [Integer] statement id
421
402
  def stmt_close_command(stmt_id)
422
403
  synchronize do
423
404
  reset
@@ -429,15 +410,13 @@ class Mysql
429
410
  @gc_stmt_queue.push stmt_id
430
411
  end
431
412
 
432
- private
433
-
434
413
  def check_state(st)
435
414
  raise 'command out of sync' unless @state == st
436
415
  end
437
416
 
438
417
  def set_state(st)
439
418
  @state = st
440
- if st == :READY
419
+ if st == :READY && !@gc_stmt_queue.empty?
441
420
  gc_disabled = GC.disable
442
421
  begin
443
422
  while st = @gc_stmt_queue.shift
@@ -465,28 +444,24 @@ class Mysql
465
444
  end
466
445
 
467
446
  # Read one packet data
468
- # === Return
469
- # [Packet] packet data
470
- # === Exception
471
- # [ProtocolError] invalid packet sequence number
447
+ # @return [Packet] packet data
448
+ # @rails [ProtocolError] invalid packet sequence number
472
449
  def read
473
450
  data = ''
474
451
  len = nil
475
452
  begin
476
- Timeout.timeout @read_timeout do
477
- header = @sock.read(4)
478
- raise EOFError unless header && header.length == 4
479
- len1, len2, seq = header.unpack("CvC")
480
- len = (len2 << 8) + len1
481
- raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
482
- @seq = (@seq + 1) % 256
483
- ret = @sock.read(len)
484
- raise EOFError unless ret && ret.length == len
485
- data.concat ret
486
- end
453
+ header = read_timeout(4, @opts[:read_timeout])
454
+ raise EOFError unless header && header.length == 4
455
+ len1, len2, seq = header.unpack("CvC")
456
+ len = (len2 << 8) + len1
457
+ raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
+ @seq = (@seq + 1) % 256
459
+ ret = read_timeout(len, @opts[:read_timeout])
460
+ raise EOFError unless ret && ret.length == len
461
+ data.concat ret
487
462
  rescue EOFError
488
463
  raise ClientError::ServerGoneError, 'MySQL server has gone away'
489
- rescue Timeout::Error
464
+ rescue Errno::ETIMEDOUT
490
465
  raise ClientError, "read timeout"
491
466
  end while len == MAX_PACKET_LENGTH
492
467
 
@@ -494,64 +469,95 @@ class Mysql
494
469
 
495
470
  # Error packet
496
471
  if data[0] == ?\xff
497
- f, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
472
+ _, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
498
473
  unless marker == "#"
499
- f, errno, message = data.unpack("Cva*") # Version 4.0 Error
474
+ _, errno, message = data.unpack("Cva*") # Version 4.0 Error
500
475
  @sqlstate = ""
501
476
  end
502
477
  message.force_encoding(@charset.encoding)
503
478
  if Mysql::ServerError::ERROR_MAP.key? errno
504
479
  raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
505
480
  end
506
- raise Mysql::ServerError.new(message, @sqlstate)
481
+ raise Mysql::ServerError.new(message, @sqlstate, errno)
507
482
  end
508
483
  Packet.new(data)
509
484
  end
510
485
 
486
+ def read_timeout(len, timeout)
487
+ return @socket.read(len) if timeout.nil? || timeout == 0
488
+ result = ''
489
+ e = Time.now + timeout
490
+ while result.size < len
491
+ now = Time.now
492
+ raise Errno::ETIMEDOUT if now > e
493
+ r = @socket.read_nonblock(len - result.size, exception: false)
494
+ case r
495
+ when :wait_readable
496
+ IO.select([@socket], nil, nil, e - now)
497
+ next
498
+ when :wait_writable
499
+ IO.select(nil, [@socket], nil, e - now)
500
+ next
501
+ else
502
+ result << r
503
+ end
504
+ end
505
+ return result
506
+ end
507
+
511
508
  # Write one packet data
512
- # === Argument
513
- # data :: [String / IO] packet data. If data is nil, write empty packet.
509
+ # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
514
510
  def write(data)
515
511
  begin
516
- @sock.sync = false
512
+ @socket.sync = false
517
513
  if data.nil?
518
- Timeout.timeout @write_timeout do
519
- @sock.write [0, 0, @seq].pack("CvC")
520
- end
514
+ write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
521
515
  @seq = (@seq + 1) % 256
522
516
  else
523
517
  data = StringIO.new data if data.is_a? String
524
518
  while d = data.read(MAX_PACKET_LENGTH)
525
- Timeout.timeout @write_timeout do
526
- @sock.write [d.length%256, d.length/256, @seq].pack("CvC")
527
- @sock.write d
528
- end
519
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
529
520
  @seq = (@seq + 1) % 256
530
521
  end
531
522
  end
532
- @sock.sync = true
533
- Timeout.timeout @write_timeout do
534
- @sock.flush
535
- end
523
+ @socket.sync = true
524
+ @socket.flush
536
525
  rescue Errno::EPIPE
537
526
  raise ClientError::ServerGoneError, 'MySQL server has gone away'
538
- rescue Timeout::Error
527
+ rescue Errno::ETIMEDOUT
539
528
  raise ClientError, "write timeout"
540
529
  end
541
530
  end
542
531
 
532
+ def write_timeout(data, timeout)
533
+ return @socket.write(data) if timeout.nil? || timeout == 0
534
+ len = 0
535
+ e = Time.now + timeout
536
+ while len < data.size
537
+ now = Time.now
538
+ raise Errno::ETIMEDOUT if now > e
539
+ l = @socket.write_nonblock(data[len..-1], exception: false)
540
+ case l
541
+ when :wait_readable
542
+ IO.select([@socket], nil, nil, e - now)
543
+ when :wait_writable
544
+ IO.select(nil, [@socket], nil, e - now)
545
+ else
546
+ len += l
547
+ end
548
+ end
549
+ return len
550
+ end
551
+
543
552
  # Read EOF packet
544
- # === Exception
545
- # [ProtocolError] packet is not EOF
553
+ # @raise [ProtocolError] packet is not EOF
546
554
  def read_eof_packet
547
555
  raise ProtocolError, "packet is not EOF" unless read.eof?
548
556
  end
549
557
 
550
558
  # Send simple command
551
- # === Argument
552
- # packet :: [String] packet data
553
- # === Return
554
- # [String] received data
559
+ # @param packet :: [String] packet data
560
+ # @return [String] received data
555
561
  def simple_command(packet)
556
562
  synchronize do
557
563
  reset
@@ -560,19 +566,6 @@ class Mysql
560
566
  end
561
567
  end
562
568
 
563
- # Encrypt password
564
- # === Argument
565
- # plain :: [String] plain password.
566
- # scramble :: [String] scramble code from initial packet.
567
- # === Return
568
- # [String] encrypted password
569
- def encrypt_password(plain, scramble)
570
- return "" if plain.nil? or plain.empty?
571
- hash_stage1 = Digest::SHA1.digest plain
572
- hash_stage2 = Digest::SHA1.digest hash_stage1
573
- return hash_stage1.unpack("C*").zip(Digest::SHA1.digest(scramble+hash_stage2).unpack("C*")).map{|a,b| a^b}.pack("C*")
574
- end
575
-
576
569
  # Initial packet
577
570
  class InitialPacket
578
571
  def self.parse(pkt)
@@ -584,18 +577,26 @@ class Mysql
584
577
  server_capabilities = pkt.ushort
585
578
  server_charset = pkt.utiny
586
579
  server_status = pkt.ushort
587
- _f1 = pkt.read(13)
580
+ server_capabilities2 = pkt.ushort
581
+ scramble_length = pkt.utiny
582
+ _f1 = pkt.read(10)
588
583
  rest_scramble_buff = pkt.string
584
+ auth_plugin = pkt.string
585
+
586
+ server_capabilities |= server_capabilities2 << 16
587
+ scramble_buff.concat rest_scramble_buff
588
+
589
589
  raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
590
590
  raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
591
- scramble_buff.concat rest_scramble_buff
592
- self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
591
+ raise ProtocolError, "invalid packet: scramble_length(#{scramble_length}) != length of scramble(#{scramble_buff.size + 1})" unless scramble_length == scramble_buff.size + 1
592
+
593
+ self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff, auth_plugin
593
594
  end
594
595
 
595
- attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff
596
+ attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff, :auth_plugin
596
597
 
597
598
  def initialize(*args)
598
- @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
599
+ @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff, @auth_plugin = args
599
600
  end
600
601
  end
601
602
 
@@ -675,16 +676,35 @@ class Mysql
675
676
 
676
677
  # Authentication packet
677
678
  class AuthenticationPacket
678
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename)
679
- [
679
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
680
+ data = [
680
681
  client_flags,
681
682
  max_packet_size,
682
- Packet.lcb(charset_number),
683
+ charset_number,
683
684
  "", # always 0x00 * 23
684
685
  username,
685
686
  Packet.lcs(scrambled_password),
686
- databasename
687
- ].pack("VVa*a23Z*A*Z*")
687
+ ]
688
+ pack = "VVCa23Z*A*"
689
+ if databasename
690
+ data.push databasename
691
+ pack.concat "Z*"
692
+ end
693
+ data.push auth_plugin
694
+ pack.concat "Z*"
695
+ data.pack(pack)
696
+ end
697
+ end
698
+
699
+ # TLS Authentication packet
700
+ class TlsAuthenticationPacket
701
+ def self.serialize(client_flags, max_packet_size, charset_number)
702
+ [
703
+ client_flags,
704
+ max_packet_size,
705
+ charset_number,
706
+ "", # always 0x00 * 23
707
+ ].pack("VVCa23")
688
708
  end
689
709
  end
690
710
 
@@ -712,6 +732,21 @@ class Mysql
712
732
  end
713
733
 
714
734
  end
735
+
736
+ class AuthenticationResultPacket
737
+ def self.parse(pkt)
738
+ result = pkt.utiny
739
+ auth_plugin = pkt.string
740
+ scramble = pkt.string
741
+ self.new(result, auth_plugin, scramble)
742
+ end
743
+
744
+ attr_reader :result, :auth_plugin, :scramble
745
+
746
+ def initialize(*args)
747
+ @result, @auth_plugin, @scramble = args
748
+ end
749
+ end
715
750
  end
716
751
 
717
752
  class RawRecord
@@ -732,17 +767,15 @@ class Mysql
732
767
  end
733
768
 
734
769
  class StmtRawRecord
735
- # === Argument
736
- # pkt :: [Packet]
737
- # fields :: [Array of Fields]
738
- # encoding:: [Encoding]
770
+ # @param pkt [Packet]
771
+ # @param fields [Array of Fields]
772
+ # @param encoding [Encoding]
739
773
  def initialize(packet, fields, encoding)
740
774
  @packet, @fields, @encoding = packet, fields, encoding
741
775
  end
742
776
 
743
777
  # Parse statement result packet
744
- # === Return
745
- # [Array of Object] one record
778
+ # @return [Array<Object>] one record
746
779
  def parse_record_packet
747
780
  @packet.utiny # skip first byte
748
781
  null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
@@ -752,7 +785,7 @@ class Mysql
752
785
  else
753
786
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
754
787
  v = Protocol.net2value(@packet, f.type, unsigned)
755
- if v.is_a? Numeric or v.is_a? Mysql::Time
788
+ if v.nil? or v.is_a? Numeric or v.is_a? Time
756
789
  v
757
790
  elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
758
791
  Charset.to_binary(v)