mysql-pr 2.9.11 → 3.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,20 @@
1
+ # frozen_string_literal: true
2
+
1
3
  # Copyright (C) 2008-2012 TOMITA Masahiro
2
4
  # mailto:tommy@tmtm.org
3
5
 
4
6
  require "socket"
5
7
  require "timeout"
6
8
  require "digest/sha1"
9
+ require "digest/sha2"
7
10
  require "stringio"
11
+ require "openssl"
8
12
 
9
13
  class MysqlPR
10
14
  # MySQL network protocol
11
15
  class Protocol
12
-
13
16
  VERSION = 10
14
- MAX_PACKET_LENGTH = 2**24-1
17
+ MAX_PACKET_LENGTH = 2**24 - 1
15
18
 
16
19
  # Convert netdata to Ruby value
17
20
  # === Argument
@@ -23,42 +26,59 @@ class MysqlPR
23
26
  def self.net2value(pkt, type, unsigned)
24
27
  case type
25
28
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB
26
- return pkt.lcs
29
+ pkt.lcs
27
30
  when Field::TYPE_TINY
28
31
  v = pkt.utiny
29
- return unsigned ? v : v < 128 ? v : v-256
32
+ if unsigned
33
+ v
34
+ else
35
+ v < 128 ? v : v - 256
36
+ end
30
37
  when Field::TYPE_SHORT
31
38
  v = pkt.ushort
32
- return unsigned ? v : v < 32768 ? v : v-65536
39
+ if unsigned
40
+ v
41
+ else
42
+ v < 32_768 ? v : v - 65_536
43
+ end
33
44
  when Field::TYPE_INT24, Field::TYPE_LONG
34
45
  v = pkt.ulong
35
- return unsigned ? v : v < 2**32/2 ? v : v-2**32
46
+ if unsigned
47
+ v
48
+ else
49
+ v < 2**32 / 2 ? v : v - 2**32
50
+ end
36
51
  when Field::TYPE_LONGLONG
37
- n1, n2 = pkt.ulong, pkt.ulong
52
+ n1 = pkt.ulong
53
+ n2 = pkt.ulong
38
54
  v = (n2 << 32) | n1
39
- return unsigned ? v : v < 2**64/2 ? v : v-2**64
55
+ if unsigned
56
+ v
57
+ else
58
+ v < 2**64 / 2 ? v : v - 2**64
59
+ end
40
60
  when Field::TYPE_FLOAT
41
- return pkt.read(4).unpack('e').first
61
+ pkt.read(4).unpack1("e")
42
62
  when Field::TYPE_DOUBLE
43
- return pkt.read(8).unpack('E').first
63
+ pkt.read(8).unpack1("E")
44
64
  when Field::TYPE_DATE
45
65
  len = pkt.utiny
46
66
  y, m, d = pkt.read(len).unpack("vCC")
47
- t = MysqlPR::Time.new(y, m, d, nil, nil, nil)
48
- return t
67
+ MysqlPR::Time.new(y, m, d, nil, nil, nil)
68
+
49
69
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
50
70
  len = pkt.utiny
51
71
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
52
- return MysqlPR::Time.new(y, m, d, h, mi, s, false, sp)
72
+ MysqlPR::Time.new(y, m, d, h, mi, s, false, sp)
53
73
  when Field::TYPE_TIME
54
74
  len = pkt.utiny
55
75
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
56
76
  h = d.to_i * 24 + h.to_i
57
- return MysqlPR::Time.new(0, 0, 0, h, mi, s, sign!=0, sp)
77
+ MysqlPR::Time.new(0, 0, 0, h, mi, s, sign != 0, sp)
58
78
  when Field::TYPE_YEAR
59
- return pkt.ushort
79
+ pkt.ushort
60
80
  when Field::TYPE_BIT
61
- return pkt.lcs
81
+ pkt.lcs
62
82
  else
63
83
  raise "not implemented: type=#{type}"
64
84
  end
@@ -90,26 +110,24 @@ class MysqlPR
90
110
  val = [v].pack("V")
91
111
  elsif v < 256**8
92
112
  type = Field::TYPE_LONGLONG | 0x8000
93
- val = [v&0xffffffff, v>>32].pack("VV")
113
+ val = [v & 0xffffffff, v >> 32].pack("VV")
94
114
  else
95
115
  raise ProtocolError, "value too large: #{v}"
96
116
  end
117
+ elsif -v <= 256 / 2
118
+ type = Field::TYPE_TINY
119
+ val = [v].pack("C")
120
+ elsif -v <= 256**2 / 2
121
+ type = Field::TYPE_SHORT
122
+ val = [v].pack("v")
123
+ elsif -v <= 256**4 / 2
124
+ type = Field::TYPE_LONG
125
+ val = [v].pack("V")
126
+ elsif -v <= 256**8 / 2
127
+ type = Field::TYPE_LONGLONG
128
+ val = [v & 0xffffffff, v >> 32].pack("VV")
97
129
  else
98
- if -v <= 256/2
99
- type = Field::TYPE_TINY
100
- val = [v].pack("C")
101
- elsif -v <= 256**2/2
102
- type = Field::TYPE_SHORT
103
- val = [v].pack("v")
104
- elsif -v <= 256**4/2
105
- type = Field::TYPE_LONG
106
- val = [v].pack("V")
107
- elsif -v <= 256**8/2
108
- type = Field::TYPE_LONGLONG
109
- val = [v&0xffffffff, v>>32].pack("VV")
110
- else
111
- raise ProtocolError, "value too large: #{v}"
112
- end
130
+ raise ProtocolError, "value too large: #{v}"
113
131
  end
114
132
  when Float
115
133
  type = Field::TYPE_DOUBLE
@@ -123,18 +141,11 @@ class MysqlPR
123
141
  else
124
142
  raise ProtocolError, "class #{v.class} is not supported"
125
143
  end
126
- return type, val
144
+ [type, val]
127
145
  end
128
146
 
129
- attr_reader :server_info
130
- attr_reader :server_version
131
- attr_reader :thread_id
132
- attr_reader :sqlstate
133
- attr_reader :affected_rows
134
- attr_reader :insert_id
135
- attr_reader :server_status
136
- attr_reader :warning_count
137
- attr_reader :message
147
+ attr_reader :server_info, :server_version, :thread_id, :sqlstate, :affected_rows, :insert_id, :server_status,
148
+ :warning_count, :message
138
149
  attr_accessor :charset
139
150
 
140
151
  # @state variable keep state for connection.
@@ -151,22 +162,29 @@ class MysqlPR
151
162
  # conn_timeout :: [Integer] connect timeout (sec).
152
163
  # read_timeout :: [Integer] read timeout (sec).
153
164
  # write_timeout :: [Integer] write timeout (sec).
165
+ # ssl_options :: [Hash / nil] SSL options. nil means no SSL.
154
166
  # === Exception
155
167
  # [ClientError] :: connection timeout
156
- def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
168
+ def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout, ssl_options = nil)
157
169
  @insert_id = 0
158
170
  @warning_count = 0
159
- @gc_stmt_queue = [] # stmt id list which GC destroy.
171
+ @gc_stmt_queue = [] # stmt id list which GC destroy.
160
172
  set_state :INIT
161
173
  @read_timeout = read_timeout
162
174
  @write_timeout = write_timeout
175
+ @ssl_options = ssl_options
176
+ @ssl_enabled = false
163
177
  begin
164
178
  Timeout.timeout conn_timeout do
165
- if host.nil? or host.empty? or host == "localhost"
179
+ if host.nil? || host.empty? || (host == "localhost")
166
180
  socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
167
181
  @sock = UNIXSocket.new socket
168
182
  else
169
- port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
183
+ port ||= ENV["MYSQL_TCP_PORT"] || begin
184
+ Socket.getservbyname("mysql", "tcp")
185
+ rescue StandardError
186
+ MYSQL_TCP_PORT
187
+ end
170
188
  @sock = TCPSocket.new host, port
171
189
  end
172
190
  end
@@ -175,6 +193,18 @@ class MysqlPR
175
193
  end
176
194
  end
177
195
 
196
+ # Returns true if SSL is enabled for this connection
197
+ def ssl_enabled?
198
+ @ssl_enabled
199
+ end
200
+
201
+ # Returns SSL cipher info if SSL is enabled
202
+ def ssl_cipher
203
+ return nil unless @ssl_enabled && @sock.respond_to?(:cipher)
204
+
205
+ @sock.cipher
206
+ end
207
+
178
208
  def close
179
209
  @sock.close
180
210
  end
@@ -194,22 +224,170 @@ class MysqlPR
194
224
  reset
195
225
  init_packet = InitialPacket.parse read
196
226
  @server_info = init_packet.server_version
197
- @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
227
+ @server_version = init_packet.server_version.split(/\D/)[0, 3].inject { |a, b| a.to_i * 100 + b.to_i }
198
228
  @thread_id = init_packet.thread_id
199
- client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
229
+ client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS |
230
+ CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
231
+ client_flags |= CLIENT_PLUGIN_AUTH
200
232
  client_flags |= CLIENT_CONNECT_WITH_DB if db
201
233
  client_flags |= flag
202
234
  @charset = charset
203
235
  unless @charset
204
236
  @charset = Charset.by_number(init_packet.server_charset)
205
- @charset.encoding # raise error if unsupported charset
237
+ @charset.encoding # raise error if unsupported charset
206
238
  end
207
- netpw = encrypt_password passwd, init_packet.scramble_buff
208
- write AuthenticationPacket.serialize(client_flags, 1024**3, @charset.number, user, netpw, db)
209
- raise ProtocolError, 'The old style password is not supported' if read.to_s == "\xfe"
210
- set_state :READY
239
+
240
+ # SSL handshake if requested and server supports it
241
+ if @ssl_options && (init_packet.server_capabilities & CLIENT_SSL) != 0
242
+ client_flags |= CLIENT_SSL
243
+ # Send SSL request packet (partial auth packet with SSL flag)
244
+ write SSLRequestPacket.serialize(client_flags, 1024**3, @charset.number)
245
+ # Upgrade connection to SSL
246
+ upgrade_to_ssl
247
+ elsif @ssl_options && @ssl_options[:required]
248
+ raise ClientError, "SSL required but server does not support SSL"
249
+ end
250
+
251
+ auth_plugin = init_packet.auth_plugin_name || "mysql_native_password"
252
+ scramble = init_packet.scramble_buff
253
+
254
+ # Choose password encryption based on auth plugin
255
+ netpw = if auth_plugin == "caching_sha2_password"
256
+ encrypt_password_sha256(passwd, scramble)
257
+ else
258
+ encrypt_password(passwd, scramble)
259
+ end
260
+
261
+ write AuthenticationPacket.serialize(client_flags, 1024**3, @charset.number, user, netpw, db, auth_plugin)
262
+
263
+ # Read response
264
+ response = read
265
+ response_data = response.to_s
266
+
267
+ # Handle different response types
268
+ case response_data.getbyte(0)
269
+ when 0x00
270
+ # OK packet - authentication successful
271
+ set_state :READY
272
+ when 0xfe
273
+ # Auth switch request
274
+ handle_auth_switch(response_data, passwd)
275
+ when 0x01
276
+ # More data - caching_sha2_password specific
277
+ handle_caching_sha2_more_data(response_data, passwd, scramble)
278
+ else
279
+ raise ProtocolError, "Unexpected auth response: #{response_data.getbyte(0)}"
280
+ end
281
+ end
282
+
283
+ # Handle auth switch request
284
+ def handle_auth_switch(response_data, passwd)
285
+ # Parse auth switch request: 0xfe + plugin_name + scramble
286
+ pkt = Packet.new(response_data[1..])
287
+ plugin_name = pkt.string
288
+ scramble = pkt.to_s
289
+
290
+ if plugin_name == "mysql_native_password"
291
+ netpw = encrypt_password(passwd, scramble)
292
+ write netpw
293
+ read # OK or error
294
+ set_state :READY
295
+ elsif plugin_name == "caching_sha2_password"
296
+ netpw = encrypt_password_sha256(passwd, scramble)
297
+ write netpw
298
+ response = read
299
+ if response.to_s.getbyte(0) == 0x01
300
+ handle_caching_sha2_more_data(response.to_s, passwd, scramble)
301
+ else
302
+ set_state :READY
303
+ end
304
+ else
305
+ raise ProtocolError, "Unsupported auth plugin: #{plugin_name}"
306
+ end
307
+ end
308
+
309
+ # Handle caching_sha2_password "more data" response
310
+ def handle_caching_sha2_more_data(response_data, passwd, scramble)
311
+ # 0x01 + status byte
312
+ status = response_data.getbyte(1)
313
+
314
+ case status
315
+ when 0x03
316
+ # Fast auth success - server already has cached password hash
317
+ read # Read the final OK packet
318
+ set_state :READY
319
+ when 0x04
320
+ # Full authentication required
321
+ if @ssl_enabled
322
+ # Send plaintext password over SSL
323
+ write "#{passwd}\x00"
324
+ else
325
+ # Need RSA encryption - request public key
326
+ write "\x02" # Request public key
327
+ pubkey_response = read
328
+ pubkey_data = pubkey_response.to_s
329
+
330
+ raise ProtocolError, "Failed to get server public key" unless pubkey_data.getbyte(0) == 0x01
331
+
332
+ # Got public key
333
+ public_key = pubkey_data[1..]
334
+ encrypted_password = rsa_encrypt_password(passwd, scramble, public_key)
335
+ write encrypted_password
336
+
337
+ end
338
+ read
339
+ set_state :READY
340
+ else
341
+ raise ProtocolError, "Unknown caching_sha2_password status: #{status}"
342
+ end
343
+ end
344
+
345
+ # RSA encrypt password for caching_sha2_password
346
+ def rsa_encrypt_password(passwd, scramble, public_key_pem)
347
+ # XOR password with scramble
348
+ passwd_bytes = "#{passwd}\x00".bytes
349
+ scramble_bytes = scramble.bytes
350
+ xored = passwd_bytes.each_with_index.map { |b, i| b ^ scramble_bytes[i % scramble_bytes.length] }
351
+
352
+ # Encrypt with RSA public key
353
+ rsa = OpenSSL::PKey::RSA.new(public_key_pem)
354
+ rsa.public_encrypt(xored.pack("C*"), OpenSSL::PKey::RSA::PKCS1_OAEP_PADDING)
355
+ end
356
+
357
+ private
358
+
359
+ # Upgrade the connection to SSL/TLS
360
+ def upgrade_to_ssl
361
+ ssl_context = OpenSSL::SSL::SSLContext.new
362
+
363
+ # Configure SSL context based on options
364
+ ssl_context.ca_file = @ssl_options[:ca] if @ssl_options[:ca]
365
+ ssl_context.cert = OpenSSL::X509::Certificate.new(File.read(@ssl_options[:cert])) if @ssl_options[:cert]
366
+ ssl_context.key = OpenSSL::PKey::RSA.new(File.read(@ssl_options[:key])) if @ssl_options[:key]
367
+ ssl_context.ca_path = @ssl_options[:ca_path] if @ssl_options[:ca_path]
368
+
369
+ # Set verification mode
370
+ ssl_context.verify_mode = if @ssl_options[:verify] == false
371
+ OpenSSL::SSL::VERIFY_NONE
372
+ else
373
+ OpenSSL::SSL::VERIFY_PEER
374
+ end
375
+
376
+ # Set minimum TLS version if specified
377
+ ssl_context.min_version = @ssl_options[:min_version] if @ssl_options[:min_version]
378
+
379
+ # Wrap socket in SSL
380
+ ssl_socket = OpenSSL::SSL::SSLSocket.new(@sock, ssl_context)
381
+ ssl_socket.hostname = @ssl_options[:hostname] if @ssl_options[:hostname]
382
+ ssl_socket.sync_close = true
383
+ ssl_socket.connect
384
+
385
+ @sock = ssl_socket
386
+ @ssl_enabled = true
211
387
  end
212
388
 
389
+ public
390
+
213
391
  # Quit command
214
392
  def quit_command
215
393
  synchronize do
@@ -230,7 +408,7 @@ class MysqlPR
230
408
  reset
231
409
  write [COM_QUERY, @charset.convert(query)].pack("Ca*")
232
410
  get_result
233
- rescue
411
+ rescue StandardError
234
412
  set_state :READY
235
413
  raise
236
414
  end
@@ -240,26 +418,27 @@ class MysqlPR
240
418
  # === Return
241
419
  # [integer / nil] number of fields of results. nil if no results.
242
420
  def get_result
243
- begin
244
- res_packet = ResultPacket.parse read
245
- if res_packet.field_count.to_i > 0 # result data exists
246
- set_state :FIELD
247
- return res_packet.field_count
248
- end
249
- if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
250
- filename = res_packet.message
251
- File.open(filename){|f| write f}
252
- write nil # EOF mark
253
- read
254
- end
255
- @affected_rows, @insert_id, @server_status, @warning_count, @message =
256
- res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
257
- set_state :READY
258
- return nil
259
- rescue
260
- set_state :READY
261
- raise
262
- end
421
+ res_packet = ResultPacket.parse read
422
+ if res_packet.field_count.to_i.positive? # result data exists
423
+ set_state :FIELD
424
+ return res_packet.field_count
425
+ end
426
+ if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
427
+ filename = res_packet.message
428
+ File.open(filename) { |f| write f }
429
+ write nil # EOF mark
430
+ read
431
+ end
432
+ @affected_rows = res_packet.affected_rows
433
+ @insert_id = res_packet.insert_id
434
+ @server_status = res_packet.server_status
435
+ @warning_count = res_packet.warning_count
436
+ @message = res_packet.message
437
+ set_state :READY
438
+ nil
439
+ rescue StandardError
440
+ set_state :READY
441
+ raise
263
442
  end
264
443
 
265
444
  # Retrieve n fields
@@ -270,11 +449,11 @@ class MysqlPR
270
449
  def retr_fields(n)
271
450
  check_state :FIELD
272
451
  begin
273
- fields = n.times.map{Field.new FieldPacket.parse(read)}
452
+ fields = n.times.map { Field.new FieldPacket.parse(read) }
274
453
  read_eof_packet
275
454
  set_state :RESULT
276
455
  fields
277
- rescue
456
+ rescue StandardError
278
457
  set_state :READY
279
458
  raise
280
459
  end
@@ -328,11 +507,11 @@ class MysqlPR
328
507
  reset
329
508
  write [COM_PROCESS_INFO].pack("C")
330
509
  field_count = read.lcb
331
- fields = field_count.times.map{Field.new FieldPacket.parse(read)}
510
+ fields = field_count.times.map { Field.new FieldPacket.parse(read) }
332
511
  read_eof_packet
333
512
  set_state :RESULT
334
- return fields
335
- rescue
513
+ fields
514
+ rescue StandardError
336
515
  set_state :READY
337
516
  raise
338
517
  end
@@ -380,12 +559,12 @@ class MysqlPR
380
559
  reset
381
560
  write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
382
561
  res_packet = PrepareResultPacket.parse read
383
- if res_packet.param_count > 0
384
- res_packet.param_count.times{read} # skip parameter packet
562
+ if res_packet.param_count.positive?
563
+ res_packet.param_count.times { read } # skip parameter packet
385
564
  read_eof_packet
386
565
  end
387
- if res_packet.field_count > 0
388
- fields = res_packet.field_count.times.map{Field.new FieldPacket.parse(read)}
566
+ if res_packet.field_count.positive?
567
+ fields = res_packet.field_count.times.map { Field.new FieldPacket.parse(read) }
389
568
  read_eof_packet
390
569
  else
391
570
  fields = []
@@ -406,7 +585,7 @@ class MysqlPR
406
585
  reset
407
586
  write ExecutePacket.serialize(stmt_id, MysqlPR::Stmt::CURSOR_TYPE_NO_CURSOR, values)
408
587
  get_result
409
- rescue
588
+ rescue StandardError
410
589
  set_state :READY
411
590
  raise
412
591
  end
@@ -449,36 +628,34 @@ class MysqlPR
449
628
  private
450
629
 
451
630
  def check_state(st)
452
- raise 'command out of sync' unless @state == st
631
+ raise "command out of sync" unless @state == st
453
632
  end
454
633
 
455
634
  def set_state(st)
456
635
  @state = st
457
- if st == :READY
458
- gc_disabled = GC.disable unless RUBY_PLATFORM == 'java'
459
- begin
460
- while st = @gc_stmt_queue.shift
461
- reset
462
- write [COM_STMT_CLOSE, st].pack("CV")
463
- end
464
- ensure
465
- GC.enable unless gc_disabled unless RUBY_PLATFORM == 'java'
636
+ return unless st == :READY
637
+
638
+ gc_disabled = GC.disable unless RUBY_PLATFORM == "java"
639
+ begin
640
+ while (st = @gc_stmt_queue.shift)
641
+ reset
642
+ write [COM_STMT_CLOSE, st].pack("CV")
466
643
  end
644
+ ensure
645
+ GC.enable if RUBY_PLATFORM != "java" && !gc_disabled
467
646
  end
468
647
  end
469
648
 
470
649
  def synchronize
471
- begin
472
- check_state :READY
473
- return yield
474
- ensure
475
- set_state :READY
476
- end
650
+ check_state :READY
651
+ yield
652
+ ensure
653
+ set_state :READY
477
654
  end
478
655
 
479
656
  # Reset sequence number
480
657
  def reset
481
- @seq = 0 # packet counter. reset by each command
658
+ @seq = 0 # packet counter. reset by each command
482
659
  end
483
660
 
484
661
  # Read one packet data
@@ -489,35 +666,39 @@ class MysqlPR
489
666
  def read
490
667
  ret = ""
491
668
  len = nil
492
- begin
669
+ loop do
493
670
  Timeout.timeout @read_timeout do
494
671
  header = @sock.read(4)
495
672
  raise EOFError unless header && header.length == 4
673
+
496
674
  len1, len2, seq = header.unpack("CvC")
497
675
  len = (len2 << 8) + len1
498
676
  raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
677
+
499
678
  @seq = (@seq + 1) % 256
500
679
  ret = @sock.read(len)
501
680
  raise EOFError unless ret && ret.length == len
502
681
  end
503
682
  rescue EOFError
504
- raise ClientError::ServerGoneError, 'The MySQL server has gone away'
683
+ raise ClientError::ServerGoneError, "The MySQL server has gone away"
505
684
  rescue Timeout::Error
506
685
  raise ClientError, "read timeout"
507
- end while len == MAX_PACKET_LENGTH
686
+ break unless len == MAX_PACKET_LENGTH
687
+ end
508
688
 
509
689
  @sqlstate = "00000"
510
690
 
511
- # Error packet
512
- if ret[0] == ?\xff
513
- f, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
691
+ # Error packet (use getbyte for encoding-safe comparison)
692
+ if ret.getbyte(0) == 0xff
693
+ _, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
514
694
  unless marker == "#"
515
- f, errno, message = ret.unpack("Cva*") # Version 4.0 Error
695
+ _, errno, message = ret.unpack("Cva*") # Version 4.0 Error
516
696
  @sqlstate = ""
517
697
  end
518
698
  if MysqlPR::ServerError::ERROR_MAP.key? errno
519
699
  raise MysqlPR::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
520
700
  end
701
+
521
702
  raise MysqlPR::ServerError.new(message, @sqlstate)
522
703
  end
523
704
  Packet.new(ret)
@@ -527,32 +708,30 @@ class MysqlPR
527
708
  # === Argument
528
709
  # data :: [String / IO] packet data. If data is nil, write empty packet.
529
710
  def write(data)
530
- begin
531
- @sock.sync = false
532
- if data.nil?
711
+ @sock.sync = false
712
+ if data.nil?
713
+ Timeout.timeout @write_timeout do
714
+ @sock.write [0, 0, @seq].pack("CvC")
715
+ end
716
+ @seq = (@seq + 1) % 256
717
+ else
718
+ data = StringIO.new data if data.is_a? String
719
+ while (d = data.read(MAX_PACKET_LENGTH))
533
720
  Timeout.timeout @write_timeout do
534
- @sock.write [0, 0, @seq].pack("CvC")
721
+ @sock.write [d.length % 256, d.length / 256, @seq].pack("CvC")
722
+ @sock.write d
535
723
  end
536
724
  @seq = (@seq + 1) % 256
537
- else
538
- data = StringIO.new data if data.is_a? String
539
- while d = data.read(MAX_PACKET_LENGTH)
540
- Timeout.timeout @write_timeout do
541
- @sock.write [d.length%256, d.length/256, @seq].pack("CvC")
542
- @sock.write d
543
- end
544
- @seq = (@seq + 1) % 256
545
- end
546
- end
547
- @sock.sync = true
548
- Timeout.timeout @write_timeout do
549
- @sock.flush
550
725
  end
551
- rescue Errno::EPIPE
552
- raise ClientError::ServerGoneError, 'The MySQL server has gone away'
553
- rescue Timeout::Error
554
- raise ClientError, "write timeout"
555
726
  end
727
+ @sock.sync = true
728
+ Timeout.timeout @write_timeout do
729
+ @sock.flush
730
+ end
731
+ rescue Errno::EPIPE
732
+ raise ClientError::ServerGoneError, "The MySQL server has gone away"
733
+ rescue Timeout::Error
734
+ raise ClientError, "write timeout"
556
735
  end
557
736
 
558
737
  # Read EOF packet
@@ -575,17 +754,36 @@ class MysqlPR
575
754
  end
576
755
  end
577
756
 
578
- # Encrypt password
757
+ # Encrypt password for mysql_native_password (SHA1)
579
758
  # === Argument
580
759
  # plain :: [String] plain password.
581
760
  # scramble :: [String] scramble code from initial packet.
582
761
  # === Return
583
762
  # [String] encrypted password
584
763
  def encrypt_password(plain, scramble)
585
- return "" if plain.nil? or plain.empty?
764
+ return "" if plain.nil? || plain.empty?
765
+
586
766
  hash_stage1 = Digest::SHA1.digest plain
587
767
  hash_stage2 = Digest::SHA1.digest hash_stage1
588
- return hash_stage1.unpack("C*").zip(Digest::SHA1.digest(scramble+hash_stage2).unpack("C*")).map{|a,b| a^b}.pack("C*")
768
+ hash_stage1.unpack("C*").zip(Digest::SHA1.digest(scramble + hash_stage2).unpack("C*")).map do |a, b|
769
+ a ^ b
770
+ end.pack("C*")
771
+ end
772
+
773
+ # Encrypt password for caching_sha2_password (SHA256)
774
+ # === Argument
775
+ # plain :: [String] plain password.
776
+ # scramble :: [String] scramble code from initial packet.
777
+ # === Return
778
+ # [String] encrypted password
779
+ def encrypt_password_sha256(plain, scramble)
780
+ return "" if plain.nil? || plain.empty?
781
+
782
+ hash_stage1 = Digest::SHA256.digest(plain)
783
+ hash_stage2 = Digest::SHA256.digest(hash_stage1)
784
+ hash_stage1.unpack("C*").zip(Digest::SHA256.digest(hash_stage2 + scramble).unpack("C*")).map do |a, b|
785
+ a ^ b
786
+ end.pack("C*")
589
787
  end
590
788
 
591
789
  # Initial packet
@@ -599,18 +797,34 @@ class MysqlPR
599
797
  server_capabilities = pkt.ushort
600
798
  server_charset = pkt.utiny
601
799
  server_status = pkt.ushort
602
- f1 = pkt.read(13)
603
- rest_scramble_buff = pkt.string
800
+ server_capabilities_upper = pkt.ushort
801
+ auth_plugin_data_length = pkt.utiny
802
+ pkt.read(10) # reserved
803
+ # Read rest of scramble (12 bytes for caching_sha2_password, or variable)
804
+ rest_scramble_len = [auth_plugin_data_length - 8, 12].max
805
+ rest_scramble_buff = pkt.read(rest_scramble_len)
806
+ # Remove trailing null if present
807
+ rest_scramble_buff = rest_scramble_buff.sub(/\x00+\z/, "")
808
+ auth_plugin_name = begin
809
+ pkt.string
810
+ rescue StandardError
811
+ "mysql_native_password"
812
+ end
604
813
  raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
605
- raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
814
+ raise ProtocolError, "invalid packet: f0=#{f0}" unless f0.zero?
815
+
606
816
  scramble_buff.concat rest_scramble_buff
607
- self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
817
+ server_capabilities |= (server_capabilities_upper << 16)
818
+ new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status,
819
+ scramble_buff, auth_plugin_name
608
820
  end
609
821
 
610
- attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff
822
+ attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset,
823
+ :server_status, :scramble_buff, :auth_plugin_name
611
824
 
612
825
  def initialize(*args)
613
- @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
826
+ @protocol_version, @server_version, @thread_id, @server_capabilities,
827
+ @server_charset, @server_status, @scramble_buff, @auth_plugin_name = args
614
828
  end
615
829
  end
616
830
 
@@ -618,17 +832,17 @@ class MysqlPR
618
832
  class ResultPacket
619
833
  def self.parse(pkt)
620
834
  field_count = pkt.lcb
621
- if field_count == 0
835
+ if field_count.zero?
622
836
  affected_rows = pkt.lcb
623
837
  insert_id = pkt.lcb
624
838
  server_status = pkt.ushort
625
839
  warning_count = pkt.ushort
626
840
  message = pkt.lcs
627
- return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
628
- elsif field_count.nil? # LOAD DATA LOCAL INFILE
629
- return self.new(nil, nil, nil, nil, nil, pkt.to_s)
841
+ new(field_count, affected_rows, insert_id, server_status, warning_count, message)
842
+ elsif field_count.nil? # LOAD DATA LOCAL INFILE
843
+ new(nil, nil, nil, nil, nil, pkt.to_s)
630
844
  else
631
- return self.new(field_count)
845
+ new(field_count)
632
846
  end
633
847
  end
634
848
 
@@ -642,13 +856,13 @@ class MysqlPR
642
856
  # Field packet
643
857
  class FieldPacket
644
858
  def self.parse(pkt)
645
- first = pkt.lcs
859
+ pkt.lcs
646
860
  db = pkt.lcs
647
861
  table = pkt.lcs
648
862
  org_table = pkt.lcs
649
863
  name = pkt.lcs
650
864
  org_name = pkt.lcs
651
- f0 = pkt.utiny
865
+ pkt.utiny
652
866
  charsetnr = pkt.ushort
653
867
  length = pkt.ulong
654
868
  type = pkt.utiny
@@ -656,9 +870,10 @@ class MysqlPR
656
870
  decimals = pkt.utiny
657
871
  f1 = pkt.ushort
658
872
 
659
- raise ProtocolError, "invalid packet: f1=#{f1}" unless f1 == 0
873
+ raise ProtocolError, "invalid packet: f1=#{f1}" unless f1.zero?
874
+
660
875
  default = pkt.lcs
661
- return self.new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
876
+ new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
662
877
  end
663
878
 
664
879
  attr_reader :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default
@@ -671,14 +886,16 @@ class MysqlPR
671
886
  # Prepare result packet
672
887
  class PrepareResultPacket
673
888
  def self.parse(pkt)
674
- raise ProtocolError, "invalid packet" unless pkt.utiny == 0
889
+ raise ProtocolError, "invalid packet" unless pkt.utiny.zero?
890
+
675
891
  statement_id = pkt.ulong
676
892
  field_count = pkt.ushort
677
893
  param_count = pkt.ushort
678
894
  f = pkt.utiny
679
895
  warning_count = pkt.ushort
680
- raise ProtocolError, "invalid packet" unless f == 0x00
681
- self.new statement_id, field_count, param_count, warning_count
896
+ raise ProtocolError, "invalid packet" unless f.zero?
897
+
898
+ new statement_id, field_count, param_count, warning_count
682
899
  end
683
900
 
684
901
  attr_reader :statement_id, :field_count, :param_count, :warning_count
@@ -688,18 +905,35 @@ class MysqlPR
688
905
  end
689
906
  end
690
907
 
908
+ # SSL Request packet - sent before SSL handshake
909
+ class SSLRequestPacket
910
+ def self.serialize(client_flags, max_packet_size, charset_number)
911
+ [
912
+ client_flags,
913
+ max_packet_size,
914
+ charset_number,
915
+ "" # filler: 23 bytes of 0x00
916
+ ].pack("VVCa23")
917
+ end
918
+ end
919
+
691
920
  # Authentication packet
692
921
  class AuthenticationPacket
693
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename)
694
- [
922
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename,
923
+ auth_plugin_name = nil)
924
+ packet = [
695
925
  client_flags,
696
926
  max_packet_size,
697
- Packet.lcb(charset_number),
698
- "", # always 0x00 * 23
699
- username,
700
- Packet.lcs(scrambled_password),
701
- databasename
702
- ].pack("VVa*a23Z*A*Z*")
927
+ charset_number,
928
+ "" # reserved 23 bytes
929
+ ].pack("VVCa23")
930
+
931
+ packet << "#{username}\x00"
932
+ packet << Packet.lcs(scrambled_password)
933
+ packet << "#{databasename}\x00" if databasename && (client_flags & MysqlPR::CLIENT_CONNECT_WITH_DB) != 0
934
+ packet << "#{auth_plugin_name}\x00" if auth_plugin_name && (client_flags & MysqlPR::CLIENT_PLUGIN_AUTH) != 0
935
+
936
+ packet
703
937
  end
704
938
  end
705
939
 
@@ -713,30 +947,32 @@ class MysqlPR
713
947
  netvalues.concat n if v
714
948
  t
715
949
  end
716
- [MysqlPR::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
950
+ [MysqlPR::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"),
951
+ netvalues].pack("CVCVa*Ca*a*")
717
952
  end
718
953
 
719
954
  # make null bitmap
720
955
  #
721
956
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
722
957
  def self.null_bitmap(values)
723
- bitmap = values.enum_for(:each_slice,8).map do |vals|
724
- vals.reverse.inject(0){|b, v|(b << 1 | (v ? 0 : 1))}
958
+ bitmap = values.enum_for(:each_slice, 8).map do |vals|
959
+ vals.reverse.inject(0) { |b, v| (b << 1 | (v ? 0 : 1)) }
725
960
  end
726
- return bitmap.pack("C*")
961
+ bitmap.pack("C*")
727
962
  end
728
-
729
963
  end
730
964
  end
731
965
 
732
966
  class RawRecord
733
967
  def initialize(packet, nfields, encoding)
734
- @packet, @nfields, @encoding = packet, nfields, encoding
968
+ @packet = packet
969
+ @nfields = nfields
970
+ @encoding = encoding
735
971
  end
736
972
 
737
973
  def to_a
738
974
  @nfields.times.map do
739
- if s = @packet.lcs
975
+ if (s = @packet.lcs)
740
976
  s = Charset.convert_encoding(s, @encoding)
741
977
  end
742
978
  s
@@ -750,34 +986,34 @@ class MysqlPR
750
986
  # fields :: [Array of Fields]
751
987
  # encoding:: [Encoding]
752
988
  def initialize(packet, fields, encoding)
753
- @packet, @fields, @encoding = packet, fields, encoding
989
+ @packet = packet
990
+ @fields = fields
991
+ @encoding = encoding
754
992
  end
755
993
 
756
994
  # Parse statement result packet
757
995
  # === Return
758
996
  # [Array of Object] one record
759
997
  def parse_record_packet
760
- @packet.utiny # skip first byte
761
- null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
762
- rec = @fields.each_with_index.map do |f, i|
763
- if null_bit_map[i+2] == ?1
998
+ @packet.utiny # skip first byte
999
+ null_bit_map = @packet.read((@fields.length + 7 + 2) / 8).unpack1("b*")
1000
+ @fields.each_with_index.map do |f, i|
1001
+ if null_bit_map[i + 2] == "1"
764
1002
  nil
765
1003
  else
766
1004
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
767
1005
  v = Protocol.net2value(@packet, f.type, unsigned)
768
- if v.is_a? Numeric or v.is_a? MysqlPR::Time
1006
+ if v.is_a?(Numeric) || v.is_a?(MysqlPR::Time)
769
1007
  v
770
- elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
1008
+ elsif (f.type == Field::TYPE_BIT) || (f.charsetnr == Charset::BINARY_CHARSET_NUMBER)
771
1009
  Charset.to_binary(v)
772
1010
  else
773
1011
  Charset.convert_encoding(v, @encoding)
774
1012
  end
775
1013
  end
776
1014
  end
777
- rec
778
1015
  end
779
1016
 
780
1017
  alias to_a parse_record_packet
781
-
782
1018
  end
783
1019
  end