ruby-mysql 2.9.14 → 3.0.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,11 +1,11 @@
1
1
  # coding: ascii-8bit
2
- # Copyright (C) 2008-2012 TOMITA Masahiro
2
+ # Copyright (C) 2008 TOMITA Masahiro
3
3
  # mailto:tommy@tmtm.org
4
4
 
5
5
  require "socket"
6
- require "timeout"
7
- require "digest/sha1"
8
6
  require "stringio"
7
+ require "openssl"
8
+ require_relative 'authenticator.rb'
9
9
 
10
10
  class Mysql
11
11
  # MySQL network protocol
@@ -15,12 +15,10 @@ class Mysql
15
15
  MAX_PACKET_LENGTH = 2**24-1
16
16
 
17
17
  # Convert netdata to Ruby value
18
- # === Argument
19
- # data :: [Packet] packet data
20
- # type :: [Integer] field type
21
- # unsigned :: [true or false] true if value is unsigned
22
- # === Return
23
- # Object :: converted value.
18
+ # @param pkt [Packet] packet data
19
+ # @param type [Integer] field type
20
+ # @param unsigned [Boolean] true if value is unsigned
21
+ # @return [Object] converted value.
24
22
  def self.net2value(pkt, type, unsigned)
25
23
  case type
26
24
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
@@ -45,17 +43,18 @@ class Mysql
45
43
  when Field::TYPE_DATE
46
44
  len = pkt.utiny
47
45
  y, m, d = pkt.read(len).unpack("vCC")
48
- t = Mysql::Time.new(y, m, d, nil, nil, nil)
46
+ t = Time.new(y, m, d) rescue nil
49
47
  return t
50
48
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
51
49
  len = pkt.utiny
52
50
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
53
- return Mysql::Time.new(y, m, d, h, mi, s, false, sp)
51
+ return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
54
52
  when Field::TYPE_TIME
55
53
  len = pkt.utiny
56
54
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
57
- h = d.to_i * 24 + h.to_i
58
- return Mysql::Time.new(0, 0, 0, h, mi, s, sign!=0, sp)
55
+ r = d.to_i*86400 + h.to_i*3600 + mi.to_i*60 + s.to_i + sp.to_f/1000000
56
+ r *= -1 if sign != 0
57
+ return r
59
58
  when Field::TYPE_YEAR
60
59
  return pkt.ushort
61
60
  when Field::TYPE_BIT
@@ -66,13 +65,10 @@ class Mysql
66
65
  end
67
66
 
68
67
  # convert Ruby value to netdata
69
- # === Argument
70
- # v :: [Object] Ruby value.
71
- # === Return
72
- # Integer :: type of column. Field::TYPE_*
73
- # String :: netdata
74
- # === Exception
75
- # ProtocolError :: value too large / value is not supported
68
+ # @param v [Object] Ruby value.
69
+ # @return [Integer] type of column. Field::TYPE_*
70
+ # @return [String] netdata
71
+ # @raise [ProtocolError] value too large / value is not supported
76
72
  def self.value2net(v)
77
73
  case v
78
74
  when nil
@@ -97,12 +93,9 @@ class Mysql
97
93
  when String
98
94
  type = Field::TYPE_STRING
99
95
  val = Packet.lcs(v)
100
- when ::Time
96
+ when Time
101
97
  type = Field::TYPE_DATETIME
102
98
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
103
- when Mysql::Time
104
- type = Field::TYPE_DATETIME
105
- val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.second_part].pack("CvCCCCCV")
106
99
  else
107
100
  raise ProtocolError, "class #{v.class} is not supported"
108
101
  end
@@ -112,12 +105,14 @@ class Mysql
112
105
  attr_reader :server_info
113
106
  attr_reader :server_version
114
107
  attr_reader :thread_id
108
+ attr_reader :client_flags
115
109
  attr_reader :sqlstate
116
110
  attr_reader :affected_rows
117
111
  attr_reader :insert_id
118
112
  attr_reader :server_status
119
113
  attr_reader :warning_count
120
114
  attr_reader :message
115
+ attr_reader :get_server_public_key
121
116
  attr_accessor :charset
122
117
 
123
118
  # @state variable keep state for connection.
@@ -127,72 +122,101 @@ class Mysql
127
122
  # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
128
123
 
129
124
  # make socket connection to server.
130
- # === Argument
131
- # host :: [String] if "localhost" or "" nil then use UNIXSocket. Otherwise use TCPSocket
132
- # port :: [Integer] port number using by TCPSocket
133
- # socket :: [String] socket file name using by UNIXSocket
134
- # conn_timeout :: [Integer] connect timeout (sec).
135
- # read_timeout :: [Integer] read timeout (sec).
136
- # write_timeout :: [Integer] write timeout (sec).
137
- # === Exception
138
- # [ClientError] :: connection timeout
139
- def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
125
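+ # @param opts [Hash] connection options
126
+ # @option opts [String, nil] :host hostname where mysqld is running; nil, empty or 'localhost' connects via UNIX socket
127
+ # @option opts [String] :username username to connect to mysqld
128
+ # @option opts [String, nil] :password password to connect to mysqld (nil is treated as empty)
129
+ # @option opts [String] :database initial database name
130
+ # @option opts [Integer, String] :port TCP port number (used unless host is nil, empty or 'localhost'); falls back to ENV['MYSQL_TCP_PORT'], then the 'mysql' service entry, then MYSQL_TCP_PORT
131
+ # @option opts [String] :socket UNIX socket filename (used if host is nil, empty or 'localhost'); falls back to ENV['MYSQL_UNIX_PORT'], then MYSQL_UNIX_PORT
132
+ # @option opts [Integer] :flags Mysql::CLIENT_* flags ORed into the client capability flags
133
+ # @option opts [Mysql::Charset, String, nil] :charset connection character set (object or name); nil: use the server's charset
134
+ # @option opts [Numeric, nil] :connect_timeout TCP connect timeout in seconds
135
+ # @option opts [Numeric, nil] :read_timeout read timeout in seconds; nil or 0: no timeout
136
+ # @option opts [Numeric, nil] :write_timeout write timeout in seconds; nil or 0: no timeout
137
+ # @option opts [Boolean] :local_infile allow LOAD DATA LOCAL INFILE for any file
138
+ # @option opts [String] :load_data_local_dir allow LOAD DATA LOCAL INFILE only for files under this directory
139
+ # @option opts [Integer, String] :ssl_mode SSL_MODE_DISABLED / SSL_MODE_PREFERRED / SSL_MODE_REQUIRED, or '1'/'2'/'3', or 'disabled'/'preferred'/'required'
140
+ # @option opts [Boolean] :get_server_public_key stored and exposed via the get_server_public_key reader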
141
+ # @raise [ClientError] connection timeout
142
+ def initialize(opts)
143
+ @opts = opts
144
+ @charset = Mysql::Charset.by_name("utf8mb4")
140
145
  @insert_id = 0
141
146
  @warning_count = 0
142
147
  @gc_stmt_queue = [] # stmt id list which GC destroy.
143
148
  set_state :INIT
144
- @read_timeout = read_timeout
145
- @write_timeout = write_timeout
149
+ @get_server_public_key = @opts[:get_server_public_key]
146
150
  begin
147
- Timeout.timeout conn_timeout do
148
- if host.nil? or host.empty? or host == "localhost"
149
- socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
150
- @sock = UNIXSocket.new socket
151
- else
152
- port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
153
- @sock = TCPSocket.new host, port
154
- end
151
+ if @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
152
+ socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
+ @socket = Socket.unix(socket)
154
+ else
155
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
156
+ @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
155
157
  end
156
- rescue Timeout::Error
158
+ rescue Errno::ETIMEDOUT
157
159
  raise ClientError, "connection timeout"
158
160
  end
159
161
  end
160
162
 
161
163
  def close
162
- @sock.close
164
+ @socket.close
163
165
  end
164
166
 
165
167
  # initial negotiate and authenticate.
166
- # === Argument
167
- # user :: [String / nil] username
168
- # passwd :: [String / nil] password
169
- # db :: [String / nil] default database name. nil: no default.
170
- # flag :: [Integer] client flag
171
- # charset :: [Mysql::Charset / nil] charset for connection. nil: use server's charset
172
- # === Exception
173
- # ProtocolError :: The old style password is not supported
174
- def authenticate(user, passwd, db, flag, charset)
168
+ # The connection charset comes from @opts[:charset] (a Charset or its name); if unset, the server's charset is used.
169
+ # @raise [ProtocolError] The old style password is not supported
170
+ def authenticate
175
171
  check_state :INIT
176
- @authinfo = [user, passwd, db, flag, charset]
177
172
  reset
178
173
  init_packet = InitialPacket.parse read
179
174
  @server_info = init_packet.server_version
180
175
  @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
+ @server_capabilities = init_packet.server_capabilities
181
177
  @thread_id = init_packet.thread_id
182
- client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
183
- client_flags |= CLIENT_CONNECT_WITH_DB if db
184
- client_flags |= flag
185
- @charset = charset
186
- unless @charset
178
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
179
+ @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
+ @client_flags |= @opts[:flags]
182
+ if @opts[:charset]
183
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
+ else
187
185
  @charset = Charset.by_number(init_packet.server_charset)
188
186
  @charset.encoding # raise error if unsupported charset
189
187
  end
190
- netpw = encrypt_password passwd, init_packet.scramble_buff
191
- write AuthenticationPacket.serialize(client_flags, 1024**3, @charset.number, user, netpw, db)
192
- raise ProtocolError, 'The old style password is not supported' if read.to_s == "\xfe"
188
+ enable_ssl
189
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
193
190
  set_state :READY
194
191
  end
195
192
 
193
+ def enable_ssl
194
+ case @opts[:ssl_mode]
195
+ when SSL_MODE_DISABLED, '1', 'disabled'
196
+ return
197
+ when SSL_MODE_PREFERRED, '2', 'preferred'
198
+ return if @socket.local_address.unix?
199
+ return if @server_capabilities & CLIENT_SSL == 0
200
+ when SSL_MODE_REQUIRED, '3', 'required'
201
+ if @server_capabilities & CLIENT_SSL == 0
202
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
+ end
204
+ else
205
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
+ end
207
+ begin
208
+ @client_flags |= CLIENT_SSL
209
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
+ @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
+ @socket.sync_close = true
212
+ @socket.connect
213
+ rescue => e
214
+ @client_flags &= ~CLIENT_SSL
215
+ return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
+ raise e
217
+ end
218
+ end
219
+
196
220
  # Quit command
197
221
  def quit_command
198
222
  synchronize do
@@ -203,10 +227,8 @@ class Mysql
203
227
  end
204
228
 
205
229
  # Query command
206
- # === Argument
207
- # query :: [String] query string
208
- # === Return
209
- # [Integer / nil] number of fields of results. nil if no results.
230
+ # @param query [String] query string
231
+ # @return [Integer, nil] number of fields of results. nil if no results.
210
232
  def query_command(query)
211
233
  check_state :READY
212
234
  begin
@@ -220,8 +242,7 @@ class Mysql
220
242
  end
221
243
 
222
244
  # get result of query.
223
- # === Return
224
- # [integer / nil] number of fields of results. nil if no results.
245
+ # @return [Integer, nil] number of fields of results. nil if no results.
225
246
  def get_result
226
247
  begin
227
248
  res_packet = ResultPacket.parse read
@@ -230,10 +251,7 @@ class Mysql
230
251
  return res_packet.field_count
231
252
  end
232
253
  if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
233
- filename = res_packet.message
234
- File.open(filename){|f| write f}
235
- write nil # EOF mark
236
- read
254
+ send_local_file(res_packet.message)
237
255
  end
238
256
  @affected_rows, @insert_id, @server_status, @warning_count, @message =
239
257
  res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
@@ -245,11 +263,22 @@ class Mysql
245
263
  end
246
264
  end
247
265
 
266
+ # send local file to server
267
+ def send_local_file(filename)
268
+ filename = File.absolute_path(filename)
269
+ if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
270
+ File.open(filename){|f| write f}
271
+ else
272
+ raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
273
+ end
274
+ ensure
275
+ write nil # EOF mark
276
+ read
277
+ end
278
+
248
279
  # Retrieve n fields
249
- # === Argument
250
- # n :: [Integer] number of fields
251
- # === Return
252
- # [Array of Mysql::Field] field list
280
+ # @param n [Integer] number of fields
281
+ # @return [Array<Mysql::Field>] field list
253
282
  def retr_fields(n)
254
283
  check_state :FIELD
255
284
  begin
@@ -264,10 +293,8 @@ class Mysql
264
293
  end
265
294
 
266
295
  # Retrieve all records for simple query
267
- # === Argument
268
- # fields :: [Array<Mysql::Field>] number of fields
269
- # === Return
270
- # [Array of Array of String] all records
296
+ # @param fields [Array<Mysql::Field>] field list
297
+ # @return [Array<Array<String>>] all records
271
298
  def retr_all_records(fields)
272
299
  check_state :RESULT
273
300
  enc = charset.encoding
@@ -284,43 +311,6 @@ class Mysql
284
311
  end
285
312
  end
286
313
 
287
- # Field list command
288
- # === Argument
289
- # table :: [String] table name.
290
- # field :: [String / nil] field name that may contain wild card.
291
- # === Return
292
- # [Array of Field] field list
293
- def field_list_command(table, field)
294
- synchronize do
295
- reset
296
- write [COM_FIELD_LIST, table, 0, field].pack("Ca*Ca*")
297
- fields = []
298
- until (data = read).eof?
299
- fields.push Field.new(FieldPacket.parse(data))
300
- end
301
- return fields
302
- end
303
- end
304
-
305
- # Process info command
306
- # === Return
307
- # [Array of Field] field list
308
- def process_info_command
309
- check_state :READY
310
- begin
311
- reset
312
- write [COM_PROCESS_INFO].pack("C")
313
- field_count = read.lcb
314
- fields = field_count.times.map{Field.new FieldPacket.parse(read)}
315
- read_eof_packet
316
- set_state :RESULT
317
- return fields
318
- rescue
319
- set_state :READY
320
- raise
321
- end
322
- end
323
-
324
314
  # Ping command
325
315
  def ping_command
326
316
  simple_command [COM_PING].pack("C")
@@ -352,12 +342,8 @@ class Mysql
352
342
  end
353
343
 
354
344
  # Stmt prepare command
355
- # === Argument
356
- # stmt :: [String] prepared statement
357
- # === Return
358
- # [Integer] statement id
359
- # [Integer] number of parameters
360
- # [Array of Field] field list
345
+ # @param stmt [String] prepared statement
346
+ # @return [Array(Integer, Integer, Array<Field>)] statement id, number of parameters, field list
361
347
  def stmt_prepare_command(stmt)
362
348
  synchronize do
363
349
  reset
@@ -378,11 +364,9 @@ class Mysql
378
364
  end
379
365
 
380
366
  # Stmt execute command
381
- # === Argument
382
- # stmt_id :: [Integer] statement id
383
- # values :: [Array] parameters
384
- # === Return
385
- # [Integer] number of fields
367
+ # @param stmt_id [Integer] statement id
368
+ # @param values [Array] parameters
369
+ # @return [Integer] number of fields
386
370
  def stmt_execute_command(stmt_id, values)
387
371
  check_state :READY
388
372
  begin
@@ -396,11 +380,9 @@ class Mysql
396
380
  end
397
381
 
398
382
  # Retrieve all records for prepared statement
399
- # === Argument
400
- # fields :: [Array of Mysql::Fields] field list
401
- # charset :: [Mysql::Charset]
402
- # === Return
403
- # [Array of Array of Object] all records
383
+ # @param fields [Array<Mysql::Field>] field list
384
+ # @param charset [Mysql::Charset]
385
+ # @return [Array<Array<Object>>] all records
404
386
  def stmt_retr_all_records(fields, charset)
405
387
  check_state :RESULT
406
388
  enc = charset.encoding
@@ -416,8 +398,7 @@ class Mysql
416
398
  end
417
399
 
418
400
  # Stmt close command
419
- # === Argument
420
- # stmt_id :: [Integer] statement id
401
+ # @param stmt_id [Integer] statement id
421
402
  def stmt_close_command(stmt_id)
422
403
  synchronize do
423
404
  reset
@@ -429,15 +410,13 @@ class Mysql
429
410
  @gc_stmt_queue.push stmt_id
430
411
  end
431
412
 
432
- private
433
-
434
413
  def check_state(st)
435
414
  raise 'command out of sync' unless @state == st
436
415
  end
437
416
 
438
417
  def set_state(st)
439
418
  @state = st
440
- if st == :READY
419
+ if st == :READY && !@gc_stmt_queue.empty?
441
420
  gc_disabled = GC.disable
442
421
  begin
443
422
  while st = @gc_stmt_queue.shift
@@ -465,28 +444,24 @@ class Mysql
465
444
  end
466
445
 
467
446
  # Read one packet data
468
- # === Return
469
- # [Packet] packet data
470
- # === Exception
471
- # [ProtocolError] invalid packet sequence number
447
+ # @return [Packet] packet data
448
+ # @raise [ProtocolError] invalid packet sequence number
472
449
  def read
473
450
  data = ''
474
451
  len = nil
475
452
  begin
476
- Timeout.timeout @read_timeout do
477
- header = @sock.read(4)
478
- raise EOFError unless header && header.length == 4
479
- len1, len2, seq = header.unpack("CvC")
480
- len = (len2 << 8) + len1
481
- raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
482
- @seq = (@seq + 1) % 256
483
- ret = @sock.read(len)
484
- raise EOFError unless ret && ret.length == len
485
- data.concat ret
486
- end
453
+ header = read_timeout(4, @opts[:read_timeout])
454
+ raise EOFError unless header && header.length == 4
455
+ len1, len2, seq = header.unpack("CvC")
456
+ len = (len2 << 8) + len1
457
+ raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
+ @seq = (@seq + 1) % 256
459
+ ret = read_timeout(len, @opts[:read_timeout])
460
+ raise EOFError unless ret && ret.length == len
461
+ data.concat ret
487
462
  rescue EOFError
488
463
  raise ClientError::ServerGoneError, 'MySQL server has gone away'
489
- rescue Timeout::Error
464
+ rescue Errno::ETIMEDOUT
490
465
  raise ClientError, "read timeout"
491
466
  end while len == MAX_PACKET_LENGTH
492
467
 
@@ -494,64 +469,95 @@ class Mysql
494
469
 
495
470
  # Error packet
496
471
  if data[0] == ?\xff
497
- f, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
472
+ _, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
498
473
  unless marker == "#"
499
- f, errno, message = data.unpack("Cva*") # Version 4.0 Error
474
+ _, errno, message = data.unpack("Cva*") # Version 4.0 Error
500
475
  @sqlstate = ""
501
476
  end
502
477
  message.force_encoding(@charset.encoding)
503
478
  if Mysql::ServerError::ERROR_MAP.key? errno
504
479
  raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
505
480
  end
506
- raise Mysql::ServerError.new(message, @sqlstate)
481
+ raise Mysql::ServerError.new(message, @sqlstate, errno)
507
482
  end
508
483
  Packet.new(data)
509
484
  end
510
485
 
486
+ def read_timeout(len, timeout)
487
+ return @socket.read(len) if timeout.nil? || timeout == 0
488
+ result = ''
489
+ e = Time.now + timeout
490
+ while result.size < len
491
+ now = Time.now
492
+ raise Errno::ETIMEDOUT if now > e
493
+ r = @socket.read_nonblock(len - result.size, exception: false)
494
+ case r
495
+ when :wait_readable
496
+ IO.select([@socket], nil, nil, e - now)
497
+ next
498
+ when :wait_writable
499
+ IO.select(nil, [@socket], nil, e - now)
500
+ next
501
+ else
502
+ result << r
503
+ end
504
+ end
505
+ return result
506
+ end
507
+
511
508
  # Write one packet data
512
- # === Argument
513
- # data :: [String / IO] packet data. If data is nil, write empty packet.
509
+ # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
514
510
  def write(data)
515
511
  begin
516
- @sock.sync = false
512
+ @socket.sync = false
517
513
  if data.nil?
518
- Timeout.timeout @write_timeout do
519
- @sock.write [0, 0, @seq].pack("CvC")
520
- end
514
+ write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
521
515
  @seq = (@seq + 1) % 256
522
516
  else
523
517
  data = StringIO.new data if data.is_a? String
524
518
  while d = data.read(MAX_PACKET_LENGTH)
525
- Timeout.timeout @write_timeout do
526
- @sock.write [d.length%256, d.length/256, @seq].pack("CvC")
527
- @sock.write d
528
- end
519
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
529
520
  @seq = (@seq + 1) % 256
530
521
  end
531
522
  end
532
- @sock.sync = true
533
- Timeout.timeout @write_timeout do
534
- @sock.flush
535
- end
523
+ @socket.sync = true
524
+ @socket.flush
536
525
  rescue Errno::EPIPE
537
526
  raise ClientError::ServerGoneError, 'MySQL server has gone away'
538
- rescue Timeout::Error
527
+ rescue Errno::ETIMEDOUT
539
528
  raise ClientError, "write timeout"
540
529
  end
541
530
  end
542
531
 
532
+ def write_timeout(data, timeout)
533
+ return @socket.write(data) if timeout.nil? || timeout == 0
534
+ len = 0
535
+ e = Time.now + timeout
536
+ while len < data.size
537
+ now = Time.now
538
+ raise Errno::ETIMEDOUT if now > e
539
+ l = @socket.write_nonblock(data[len..-1], exception: false)
540
+ case l
541
+ when :wait_readable
542
+ IO.select([@socket], nil, nil, e - now)
543
+ when :wait_writable
544
+ IO.select(nil, [@socket], nil, e - now)
545
+ else
546
+ len += l
547
+ end
548
+ end
549
+ return len
550
+ end
551
+
543
552
  # Read EOF packet
544
- # === Exception
545
- # [ProtocolError] packet is not EOF
553
+ # @raise [ProtocolError] packet is not EOF
546
554
  def read_eof_packet
547
555
  raise ProtocolError, "packet is not EOF" unless read.eof?
548
556
  end
549
557
 
550
558
  # Send simple command
551
- # === Argument
552
- # packet :: [String] packet data
553
- # === Return
554
- # [String] received data
559
+ # @param packet [String] packet data
560
+ # @return [String] received data
555
561
  def simple_command(packet)
556
562
  synchronize do
557
563
  reset
@@ -560,19 +566,6 @@ class Mysql
560
566
  end
561
567
  end
562
568
 
563
- # Encrypt password
564
- # === Argument
565
- # plain :: [String] plain password.
566
- # scramble :: [String] scramble code from initial packet.
567
- # === Return
568
- # [String] encrypted password
569
- def encrypt_password(plain, scramble)
570
- return "" if plain.nil? or plain.empty?
571
- hash_stage1 = Digest::SHA1.digest plain
572
- hash_stage2 = Digest::SHA1.digest hash_stage1
573
- return hash_stage1.unpack("C*").zip(Digest::SHA1.digest(scramble+hash_stage2).unpack("C*")).map{|a,b| a^b}.pack("C*")
574
- end
575
-
576
569
  # Initial packet
577
570
  class InitialPacket
578
571
  def self.parse(pkt)
@@ -584,18 +577,26 @@ class Mysql
584
577
  server_capabilities = pkt.ushort
585
578
  server_charset = pkt.utiny
586
579
  server_status = pkt.ushort
587
- _f1 = pkt.read(13)
580
+ server_capabilities2 = pkt.ushort
581
+ scramble_length = pkt.utiny
582
+ _f1 = pkt.read(10)
588
583
  rest_scramble_buff = pkt.string
584
+ auth_plugin = pkt.string
585
+
586
+ server_capabilities |= server_capabilities2 << 16
587
+ scramble_buff.concat rest_scramble_buff
588
+
589
589
  raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
590
590
  raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
591
- scramble_buff.concat rest_scramble_buff
592
- self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
591
+ raise ProtocolError, "invalid packet: scramble_length(#{scramble_length}) != length of scramble(#{scramble_buff.size + 1})" unless scramble_length == scramble_buff.size + 1
592
+
593
+ self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff, auth_plugin
593
594
  end
594
595
 
595
- attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff
596
+ attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff, :auth_plugin
596
597
 
597
598
  def initialize(*args)
598
- @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
599
+ @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff, @auth_plugin = args
599
600
  end
600
601
  end
601
602
 
@@ -675,16 +676,35 @@ class Mysql
675
676
 
676
677
  # Authentication packet
677
678
  class AuthenticationPacket
678
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename)
679
- [
679
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
680
+ data = [
680
681
  client_flags,
681
682
  max_packet_size,
682
- Packet.lcb(charset_number),
683
+ charset_number,
683
684
  "", # always 0x00 * 23
684
685
  username,
685
686
  Packet.lcs(scrambled_password),
686
- databasename
687
- ].pack("VVa*a23Z*A*Z*")
687
+ ]
688
+ pack = "VVCa23Z*A*"
689
+ if databasename
690
+ data.push databasename
691
+ pack.concat "Z*"
692
+ end
693
+ data.push auth_plugin
694
+ pack.concat "Z*"
695
+ data.pack(pack)
696
+ end
697
+ end
698
+
699
+ # TLS Authentication packet
700
+ class TlsAuthenticationPacket
701
+ def self.serialize(client_flags, max_packet_size, charset_number)
702
+ [
703
+ client_flags,
704
+ max_packet_size,
705
+ charset_number,
706
+ "", # always 0x00 * 23
707
+ ].pack("VVCa23")
688
708
  end
689
709
  end
690
710
 
@@ -712,6 +732,21 @@ class Mysql
712
732
  end
713
733
 
714
734
  end
735
+
736
+ class AuthenticationResultPacket
737
+ def self.parse(pkt)
738
+ result = pkt.utiny
739
+ auth_plugin = pkt.string
740
+ scramble = pkt.string
741
+ self.new(result, auth_plugin, scramble)
742
+ end
743
+
744
+ attr_reader :result, :auth_plugin, :scramble
745
+
746
+ def initialize(*args)
747
+ @result, @auth_plugin, @scramble = args
748
+ end
749
+ end
715
750
  end
716
751
 
717
752
  class RawRecord
@@ -732,17 +767,15 @@ class Mysql
732
767
  end
733
768
 
734
769
  class StmtRawRecord
735
- # === Argument
736
- # pkt :: [Packet]
737
- # fields :: [Array of Fields]
738
- # encoding:: [Encoding]
770
+ # @param packet [Packet]
771
+ # @param fields [Array<Field>]
772
+ # @param encoding [Encoding]
739
773
  def initialize(packet, fields, encoding)
740
774
  @packet, @fields, @encoding = packet, fields, encoding
741
775
  end
742
776
 
743
777
  # Parse statement result packet
744
- # === Return
745
- # [Array of Object] one record
778
+ # @return [Array<Object>] one record
746
779
  def parse_record_packet
747
780
  @packet.utiny # skip first byte
748
781
  null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
@@ -752,7 +785,7 @@ class Mysql
752
785
  else
753
786
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
754
787
  v = Protocol.net2value(@packet, f.type, unsigned)
755
- if v.is_a? Numeric or v.is_a? Mysql::Time
788
+ if v.nil? or v.is_a? Numeric or v.is_a? Time
756
789
  v
757
790
  elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
758
791
  Charset.to_binary(v)