ruby-mysql 3.0.0 → 4.0.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,16 +1,19 @@
1
1
  # coding: ascii-8bit
2
+
2
3
  # Copyright (C) 2008 TOMITA Masahiro
3
4
  # mailto:tommy@tmtm.org
4
5
 
5
6
  require "socket"
6
7
  require "stringio"
7
8
  require "openssl"
8
- require_relative 'authenticator.rb'
9
+ require "bigdecimal"
10
+ require "date"
11
+ require 'time'
12
+ require_relative 'authenticator'
9
13
 
10
14
  class Mysql
11
15
  # MySQL network protocol
12
16
  class Protocol
13
-
14
17
  VERSION = 10
15
18
  MAX_PACKET_LENGTH = 2**24-1
16
19
 
@@ -21,8 +24,11 @@ class Mysql
21
24
  # @return [Object] converted value.
22
25
  def self.net2value(pkt, type, unsigned)
23
26
  case type
24
- when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
27
+ when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_BLOB, Field::TYPE_JSON
25
28
  return pkt.lcs
29
+ when Field::TYPE_NEWDECIMAL
30
+ s = pkt.lcs
31
+ return s =~ /\./ && s !~ /\.0*\z/ ? BigDecimal(s) : s.to_i
26
32
  when Field::TYPE_TINY
27
33
  v = pkt.utiny
28
34
  return unsigned ? v : v < 128 ? v : v-256
@@ -37,18 +43,18 @@ class Mysql
37
43
  v = (n2 << 32) | n1
38
44
  return unsigned ? v : v < 0x8000_0000_0000_0000 ? v : v-0x10000_0000_0000_0000
39
45
  when Field::TYPE_FLOAT
40
- return pkt.read(4).unpack('e').first
46
+ return pkt.read(4).unpack1('e')
41
47
  when Field::TYPE_DOUBLE
42
- return pkt.read(8).unpack('E').first
48
+ return pkt.read(8).unpack1('E')
43
49
  when Field::TYPE_DATE
44
50
  len = pkt.utiny
45
51
  y, m, d = pkt.read(len).unpack("vCC")
46
- t = Time.new(y, m, d) rescue nil
52
+ t = Date.new(y, m, d) rescue nil
47
53
  return t
48
54
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
49
55
  len = pkt.utiny
50
56
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
51
- return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
57
+ return Time.new(y, m, d, h, mi, Rational(s.to_i*1000000+sp.to_i, 1000000)) rescue nil
52
58
  when Field::TYPE_TIME
53
59
  len = pkt.utiny
54
60
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
@@ -70,6 +76,7 @@ class Mysql
70
76
  # @return [String] netdata
71
77
  # @raise [ProtocolError] value too large / value is not supported
72
78
  def self.value2net(v)
79
+ v = v == true ? 1 : v == false ? 0 : v
73
80
  case v
74
81
  when nil
75
82
  type = Field::TYPE_NULL
@@ -85,8 +92,12 @@ class Mysql
85
92
  type = Field::TYPE_LONGLONG | 0x8000
86
93
  val = [v&0xffffffff, v>>32].pack("VV")
87
94
  else
88
- raise ProtocolError, "value too large: #{v}"
95
+ type =Field::TYPE_NEWDECIMAL
96
+ val = Packet.lcs(v.to_s)
89
97
  end
98
+ when BigDecimal
99
+ type = Field::TYPE_NEWDECIMAL
100
+ val = Packet.lcs(v.to_s)
90
101
  when Float
91
102
  type = Field::TYPE_DOUBLE
92
103
  val = [v].pack("E")
@@ -96,6 +107,12 @@ class Mysql
96
107
  when Time
97
108
  type = Field::TYPE_DATETIME
98
109
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
110
+ when DateTime
111
+ type = Field::TYPE_DATETIME
112
+ val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, (v.sec_fraction*1000000).to_i].pack("CvCCCCCV")
113
+ when Date
114
+ type = Field::TYPE_DATE
115
+ val = [11, v.year, v.month, v.day, 0, 0, 0, 0].pack("CvCCCCCV")
99
116
  else
100
117
  raise ProtocolError, "class #{v.class} is not supported"
101
118
  end
@@ -112,14 +129,18 @@ class Mysql
112
129
  attr_reader :server_status
113
130
  attr_reader :warning_count
114
131
  attr_reader :message
132
+ attr_reader :session_track
115
133
  attr_reader :get_server_public_key
134
+ attr_reader :field_count
116
135
  attr_accessor :charset
117
136
 
118
137
  # @state variable keep state for connection.
119
- # :INIT :: Initial state.
120
- # :READY :: Ready for command.
121
- # :FIELD :: After query(). retr_fields() is needed.
122
- # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
138
+ # :INIT :: Initial state.
139
+ # :READY :: Ready for command.
140
+ # :WAIT_RESULT :: After query_command(). get_result() is needed.
141
+ # :FIELD :: After get_result(). retr_fields() is needed.
142
+ # :RESULT :: After retr_fields(), retr_all_records() is needed.
143
+ # :CLOSED :: Connection closed.
123
144
 
124
145
  # make socket connection to server.
125
146
  # @param opts [Hash]
@@ -137,13 +158,16 @@ class Mysql
137
158
  # @option :local_infile [Boolean]
138
159
  # @option :load_data_local_dir [String]
139
160
  # @option :ssl_mode [Integer]
161
+ # @option :ssl_context_params [Hash<:Symbol, String>]
140
162
  # @option :get_server_public_key [Boolean]
141
163
  # @raise [ClientError] connection timeout
142
164
  def initialize(opts)
165
+ @mutex = Mutex.new
143
166
  @opts = opts
144
167
  @charset = Mysql::Charset.by_name("utf8mb4")
145
168
  @insert_id = 0
146
169
  @warning_count = 0
170
+ @session_track = {}
147
171
  @gc_stmt_queue = [] # stmt id list which GC destroy.
148
172
  set_state :INIT
149
173
  @get_server_public_key = @opts[:get_server_public_key]
@@ -152,7 +176,7 @@ class Mysql
152
176
  socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
177
  @socket = Socket.unix(socket)
154
178
  else
155
- port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
179
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql", "tcp") rescue MYSQL_TCP_PORT)
156
180
  @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
157
181
  end
158
182
  rescue Errno::ETIMEDOUT
@@ -168,146 +192,199 @@ class Mysql
168
192
  # @param charset [Mysql::Charset, nil] charset for connection. nil: use server's charset
169
193
  # @raise [ProtocolError] The old style password is not supported
170
194
  def authenticate
171
- check_state :INIT
172
- reset
173
- init_packet = InitialPacket.parse read
174
- @server_info = init_packet.server_version
175
- @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
- @server_capabilities = init_packet.server_capabilities
177
- @thread_id = init_packet.thread_id
178
- @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
179
- @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
- @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
- @client_flags |= @opts[:flags]
182
- if @opts[:charset]
183
- @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
- else
185
- @charset = Charset.by_number(init_packet.server_charset)
186
- @charset.encoding # raise error if unsupported charset
187
- end
188
- enable_ssl
189
- Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
190
- set_state :READY
191
- end
195
+ synchronize(before: :INIT, after: :READY) do
196
+ reset
197
+ init_packet = InitialPacket.parse read
198
+ @server_info = init_packet.server_version
199
+ @server_version = init_packet.server_version.split(/\D/)[0, 3].inject{|a, b| a.to_i*100+b.to_i}
200
+ @server_capabilities = init_packet.server_capabilities
201
+ @thread_id = init_packet.thread_id
202
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | CLIENT_SESSION_TRACK | CLIENT_LOCAL_FILES
203
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
204
+ @client_flags |= @opts[:flags]
205
+ if @opts[:charset]
206
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
207
+ else
208
+ @charset = Charset.by_number(init_packet.server_charset)
209
+ @charset.encoding # raise error if unsupported charset
210
+ end
211
+ enable_ssl
212
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin, @opts[:connect_attrs])
213
+ end
214
+ end
215
+
216
# Normalizes every accepted spelling of the :ssl_mode option (SSL_MODE_*
# constant, numeric string, lowercase name as String or Symbol) to the
# numeric mode 1..5 (disabled, preferred, required, verify_ca,
# verify_identity) that enable_ssl compares against.
SSL_MODE_KEY = {
  SSL_MODE_DISABLED => 1, SSL_MODE_PREFERRED => 2, SSL_MODE_REQUIRED => 3,
  SSL_MODE_VERIFY_CA => 4, SSL_MODE_VERIFY_IDENTITY => 5,
  '1' => 1, '2' => 2, '3' => 3, '4' => 4, '5' => 5,
  'disabled' => 1, 'preferred' => 2, 'required' => 3, 'verify_ca' => 4, 'verify_identity' => 5,
  disabled: 1, preferred: 2, required: 3, verify_ca: 4, verify_identity: 5,
}.freeze
192
238
 
193
239
  def enable_ssl
194
- case @opts[:ssl_mode]
195
- when SSL_MODE_DISABLED, '1', 'disabled'
196
- return
197
- when SSL_MODE_PREFERRED, '2', 'preferred'
240
+ ssl_mode = SSL_MODE_KEY[@opts[:ssl_mode]]
241
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported" unless ssl_mode
242
+
243
+ return if ssl_mode == SSL_MODE_DISABLED
244
+ if ssl_mode == SSL_MODE_PREFERRED
198
245
  return if @socket.local_address.unix?
199
246
  return if @server_capabilities & CLIENT_SSL == 0
200
- when SSL_MODE_REQUIRED, '3', 'required'
201
- if @server_capabilities & CLIENT_SSL == 0
202
- raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
- end
204
- else
205
- raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
247
  end
207
- begin
208
- @client_flags |= CLIENT_SSL
209
- write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
- @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
- @socket.sync_close = true
212
- @socket.connect
213
- rescue => e
214
- @client_flags &= ~CLIENT_SSL
215
- return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
- raise e
248
+ if ssl_mode >= SSL_MODE_REQUIRED && @server_capabilities & CLIENT_SSL == 0
249
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
217
250
  end
251
+
252
+ context = OpenSSL::SSL::SSLContext.new
253
+ context.set_params(@opts[:ssl_context_params])
254
+ context.verify_mode = OpenSSL::SSL::VERIFY_NONE if ssl_mode < SSL_MODE_VERIFY_CA
255
+ context.verify_hostname = false if ssl_mode < SSL_MODE_VERIFY_IDENTITY
256
+
257
+ ssl_socket = OpenSSL::SSL::SSLSocket.new(@socket, context)
258
+ ssl_socket.sync_close = true
259
+ ssl_socket.hostname = @opts[:host] if ssl_mode >= SSL_MODE_VERIFY_IDENTITY
260
+
261
+ @client_flags |= CLIENT_SSL
262
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
263
+
264
+ ssl_socket.connect
265
+ @socket = ssl_socket
266
+ rescue OpenSSL::SSL::SSLError => e
267
+ @client_flags &= ~CLIENT_SSL
268
+ return if @opts[:ssl_mode] < SSL_MODE_REQUIRED
269
+ raise e
270
+ end
271
+
272
# Cipher of the TLS session, or nil when the connection is not using TLS.
# @return [Array, nil] result of @socket.cipher when CLIENT_SSL is set
def ssl_cipher
  return nil unless @client_flags.allbits?(CLIENT_SSL)

  @socket.cipher
end
219
275
 
220
276
  # Quit command
221
277
  def quit_command
222
- synchronize do
278
+ get_result if @state == :WAIT_RESULT
279
+ retr_fields if @state == :FIELD
280
+ retr_all_records(RawRecord) if @state == :RESULT
281
+ synchronize(before: :READY, after: :CLOSED) do
223
282
  reset
224
283
  write [COM_QUIT].pack("C")
225
284
  close
285
+ @gc_stmt_queue.clear
226
286
  end
227
287
  end
228
288
 
229
289
  # Query command
230
290
  # @param query [String] query string
231
- # @return [Integer, nil] number of fields of results. nil if no results.
232
291
  def query_command(query)
233
- check_state :READY
234
- begin
292
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
235
293
  reset
236
294
  write [COM_QUERY, @charset.convert(query)].pack("Ca*")
237
- get_result
238
- rescue
239
- set_state :READY
240
- raise
241
295
  end
242
296
  end
243
297
 
244
298
  # get result of query.
245
299
  # @return [integer, nil] number of fields of results. nil if no results.
246
300
  def get_result
247
- begin
301
+ synchronize(before: :WAIT_RESULT, error: :READY) do
248
302
  res_packet = ResultPacket.parse read
249
- if res_packet.field_count.to_i > 0 # result data exists
303
+ @field_count = res_packet.field_count
304
+ if @field_count.to_i > 0 # result data exists
250
305
  set_state :FIELD
251
- return res_packet.field_count
306
+ return @field_count
252
307
  end
253
- if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
308
+ if @field_count.nil? # LOAD DATA LOCAL INFILE
254
309
  send_local_file(res_packet.message)
310
+ res_packet = ResultPacket.parse read
255
311
  end
256
- @affected_rows, @insert_id, @server_status, @warning_count, @message =
257
- res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
258
- set_state :READY
312
+ @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track =
313
+ res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message, res_packet.session_track
314
+ set_state :READY unless more_results?
259
315
  return nil
260
- rescue
261
- set_state :READY
262
- raise
263
316
  end
264
317
  end
265
318
 
319
# True when the last server status carries SERVER_MORE_RESULTS_EXISTS,
# i.e. another result set follows the current one.
def more_results?
  @server_status.anybits?(SERVER_MORE_RESULTS_EXISTS)
end
322
+
266
323
  # send local file to server
267
324
  def send_local_file(filename)
268
325
  filename = File.absolute_path(filename)
269
326
  if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
270
327
  File.open(filename){|f| write f}
328
+ write nil # EOF
271
329
  else
330
+ write nil # send empty data instead of file contents
331
+ read # result packet
272
332
  raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
273
333
  end
274
- ensure
275
- write nil # EOF mark
276
- read
277
334
  end
278
335
 
279
336
  # Retrieve n fields
280
- # @param n [Integer] number of fields
281
337
  # @return [Array<Mysql::Field>] field list
282
- def retr_fields(n)
283
- check_state :FIELD
284
- begin
285
- fields = n.times.map{Field.new FieldPacket.parse(read)}
338
+ def retr_fields
339
+ synchronize(before: :FIELD, after: :RESULT, error: :READY) do
340
+ @fields = @field_count.times.map{Field.new FieldPacket.parse(read)}
286
341
  read_eof_packet
287
- set_state :RESULT
288
- fields
289
- rescue
290
- set_state :READY
291
- raise
342
+ @no_more_records = false
343
+ @fields
292
344
  end
293
345
  end
294
346
 
295
- # Retrieve all records for simple query
296
- # @param fields [Array<Mysql::Field>] number of fields
297
- # @return [Array<Array<String>>] all records
298
- def retr_all_records(fields)
299
- check_state :RESULT
300
- enc = charset.encoding
301
- begin
302
- all_recs = []
303
- until (pkt = read).eof?
304
- all_recs.push RawRecord.new(pkt, fields, enc)
347
# Retrieve the next row of the current result set (text or binary protocol).
# On the terminating EOF packet the server status is recorded and the state
# moves to :WAIT_RESULT (more result sets follow) or :READY. Every later
# call returns nil without touching the connection.
# @param record_class [Class] RawRecord or StmtRawRecord
# @return [Object] an instance of record_class
# @return [nil] no more record
def retr_record(record_class)
  return nil if @no_more_records

  synchronize(before: :RESULT) do
    encoding = charset.encoding
    packet = read
    if packet.eof?
      packet.utiny # EOF marker
      packet.ushort # warning count, unused
      @server_status = packet.ushort
      set_state(more_results? ? :WAIT_RESULT : :READY)
      @no_more_records = true
      nil
    else
      record_class.new(packet, @fields, encoding)
    end
  end
end
368
+
369
# Retrieve every remaining row of the current result set (text or binary
# protocol) and the server status from its EOF packet. The state is always
# moved on, even when reading fails: to :WAIT_RESULT if more result sets
# follow, otherwise :READY.
# @param record_class [Class] RawRecord or StmtRawRecord
# @return [Array<Object>] all records, as record_class instances
def retr_all_records(record_class)
  synchronize(before: :RESULT) do
    encoding = charset.encoding
    begin
      records = []
      packet = read
      until packet.eof?
        records << record_class.new(packet, @fields, encoding)
        packet = read
      end
      packet.utiny # EOF marker
      packet.ushort # warning count, unused
      @server_status = packet.ushort
      @no_more_records = true
      records
    ensure
      set_state(more_results? ? :WAIT_RESULT : :READY)
    end
  end
end
313
390
 
@@ -345,7 +422,7 @@ class Mysql
345
422
  # @param stmt [String] prepared statement
346
423
  # @return [Array<Integer, Integer, Array<Field>>] statement id, number of parameters, field list
347
424
  def stmt_prepare_command(stmt)
348
- synchronize do
425
+ synchronize(before: :READY, after: :READY) do
349
426
  reset
350
427
  write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
351
428
  res_packet = PrepareResultPacket.parse read
@@ -368,41 +445,22 @@ class Mysql
368
445
  # @param values [Array] parameters
369
446
  # @return [Integer] number of fields
370
447
  def stmt_execute_command(stmt_id, values)
371
- check_state :READY
372
- begin
448
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
373
449
  reset
374
450
  write ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
375
- get_result
376
- rescue
377
- set_state :READY
378
- raise
379
- end
380
- end
381
-
382
- # Retrieve all records for prepared statement
383
- # @param fields [Array of Mysql::Fields] field list
384
- # @param charset [Mysql::Charset]
385
- # @return [Array<Array<Object>>] all records
386
- def stmt_retr_all_records(fields, charset)
387
- check_state :RESULT
388
- enc = charset.encoding
389
- begin
390
- all_recs = []
391
- until (pkt = read).eof?
392
- all_recs.push StmtRawRecord.new(pkt, fields, enc)
393
- end
394
- all_recs
395
- ensure
396
- set_state :READY
397
451
  end
398
452
  end
399
453
 
400
454
  # Stmt close command
401
455
  # @param stmt_id [Integer] statement id
402
456
  def stmt_close_command(stmt_id)
403
- synchronize do
457
+ get_result if @state == :WAIT_RESULT
458
+ retr_fields if @state == :FIELD
459
+ retr_all_records(StmtRawRecord) if @state == :RESULT
460
+ synchronize(before: :READY, after: :READY) do
404
461
  reset
405
462
  write [COM_STMT_CLOSE, stmt_id].pack("CV")
463
+ @gc_stmt_queue.delete stmt_id
406
464
  end
407
465
  end
408
466
 
@@ -411,30 +469,35 @@ class Mysql
411
469
  end
412
470
 
413
471
  def check_state(st)
414
- raise 'command out of sync' unless @state == st
472
+ raise Mysql::ClientError::CommandsOutOfSync, 'command out of sync' unless @state == st
415
473
  end
416
474
 
417
475
  def set_state(st)
418
476
  @state = st
419
- if st == :READY && !@gc_stmt_queue.empty?
420
- gc_disabled = GC.disable
421
- begin
422
- while st = @gc_stmt_queue.shift
423
- reset
424
- write [COM_STMT_CLOSE, st].pack("CV")
425
- end
426
- ensure
427
- GC.enable unless gc_disabled
477
+ return if st != :READY || @gc_stmt_queue.empty? || @socket&.closed?
478
+ gc_disabled = GC.disable
479
+ begin
480
+ while (st = @gc_stmt_queue.shift)
481
+ reset
482
+ write [COM_STMT_CLOSE, st].pack("CV")
428
483
  end
484
+ ensure
485
+ GC.enable unless gc_disabled
429
486
  end
430
487
  end
431
488
 
432
# Run the block while holding the connection mutex, with state bookkeeping.
# @param before [Symbol, nil] state required on entry (checked with check_state)
# @param after [Symbol, nil] state set when the block ends without raising,
#   including a non-local `return` out of the block
# @param error [Symbol, nil] state set when the block raises a StandardError
# @return [Object] the block's value
# NOTE(review): exceptions outside StandardError (e.g. Interrupt) skip the
# error state and still apply +after+.
def synchronize(before: nil, after: nil, error: nil)
  @mutex.synchronize do
    check_state(before) if before
    failed = false
    begin
      return yield
    rescue StandardError
      set_state(error) if error
      failed = true
      raise
    ensure
      set_state(after) if after && !failed
    end
  end
end
440
503
 
@@ -450,17 +513,19 @@ class Mysql
450
513
  data = ''
451
514
  len = nil
452
515
  begin
453
- header = read_timeout(4, @opts[:read_timeout])
516
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:read_timeout]
517
+ header = read_timeout(4, timeout)
454
518
  raise EOFError unless header && header.length == 4
455
519
  len1, len2, seq = header.unpack("CvC")
456
520
  len = (len2 << 8) + len1
457
521
  raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
522
  @seq = (@seq + 1) % 256
459
- ret = read_timeout(len, @opts[:read_timeout])
523
+ ret = read_timeout(len, timeout)
460
524
  raise EOFError unless ret && ret.length == len
461
525
  data.concat ret
462
- rescue EOFError
463
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
526
+ rescue EOFError, OpenSSL::SSL::SSLError
527
+ close
528
+ raise ClientError::ServerLost, 'Lost connection to server during query'
464
529
  rescue Errno::ETIMEDOUT
465
530
  raise ClientError, "read timeout"
466
531
  end while len == MAX_PACKET_LENGTH
@@ -474,6 +539,7 @@ class Mysql
474
539
  _, errno, message = data.unpack("Cva*") # Version 4.0 Error
475
540
  @sqlstate = ""
476
541
  end
542
+ @server_status &= ~SERVER_MORE_RESULTS_EXISTS
477
543
  message.force_encoding(@charset.encoding)
478
544
  if Mysql::ServerError::ERROR_MAP.key? errno
479
545
  raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
@@ -493,10 +559,10 @@ class Mysql
493
559
  r = @socket.read_nonblock(len - result.size, exception: false)
494
560
  case r
495
561
  when :wait_readable
496
- IO.select([@socket], nil, nil, e - now)
562
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
497
563
  next
498
564
  when :wait_writable
499
- IO.select(nil, [@socket], nil, e - now)
565
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
500
566
  next
501
567
  else
502
568
  result << r
@@ -508,25 +574,25 @@ class Mysql
508
574
  # Write one packet data
509
575
  # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
510
576
  def write(data)
511
- begin
512
- @socket.sync = false
513
- if data.nil?
514
- write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
577
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:write_timeout]
578
+ @socket.sync = false
579
+ if data.nil?
580
+ write_timeout([0, 0, @seq].pack("CvC"), timeout)
581
+ @seq = (@seq + 1) % 256
582
+ else
583
+ data = StringIO.new data if data.is_a? String
584
+ while (d = data.read(MAX_PACKET_LENGTH))
585
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, timeout)
515
586
  @seq = (@seq + 1) % 256
516
- else
517
- data = StringIO.new data if data.is_a? String
518
- while d = data.read(MAX_PACKET_LENGTH)
519
- write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
520
- @seq = (@seq + 1) % 256
521
- end
522
587
  end
523
- @socket.sync = true
524
- @socket.flush
525
- rescue Errno::EPIPE
526
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
527
- rescue Errno::ETIMEDOUT
528
- raise ClientError, "write timeout"
529
588
  end
589
+ @socket.sync = true
590
+ @socket.flush
591
+ rescue Errno::EPIPE, OpenSSL::SSL::SSLError
592
+ close
593
+ raise ClientError::ServerGoneError, 'MySQL server has gone away'
594
+ rescue Errno::ETIMEDOUT
595
+ raise ClientError, "write timeout"
530
596
  end
531
597
 
532
598
  def write_timeout(data, timeout)
@@ -536,12 +602,12 @@ class Mysql
536
602
  while len < data.size
537
603
  now = Time.now
538
604
  raise Errno::ETIMEDOUT if now > e
539
- l = @socket.write_nonblock(data[len..-1], exception: false)
605
+ l = @socket.write_nonblock(data[len..], exception: false)
540
606
  case l
541
607
  when :wait_readable
542
- IO.select([@socket], nil, nil, e - now)
608
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
543
609
  when :wait_writable
544
- IO.select(nil, [@socket], nil, e - now)
610
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
545
611
  else
546
612
  len += l
547
613
  end
@@ -552,14 +618,18 @@ class Mysql
552
618
  # Read EOF packet
553
619
  # @raise [ProtocolError] packet is not EOF
554
620
  def read_eof_packet
555
- raise ProtocolError, "packet is not EOF" unless read.eof?
621
+ pkt = read
622
+ raise ProtocolError, "packet is not EOF" unless pkt.eof?
623
+ pkt.utiny # 0xFE
624
+ _warnings = pkt.ushort
625
+ @server_status = pkt.ushort
556
626
  end
557
627
 
558
628
  # Send simple command
559
629
  # @param packet :: [String] packet data
560
630
  # @return [String] received data
561
631
  def simple_command(packet)
562
- synchronize do
632
+ synchronize(before: :READY, after: :READY) do
563
633
  reset
564
634
  write packet
565
635
  read.to_s
@@ -610,7 +680,10 @@ class Mysql
610
680
  server_status = pkt.ushort
611
681
  warning_count = pkt.ushort
612
682
  message = pkt.lcs
613
- return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
683
+ session_track = parse_session_track(pkt.lcs) if server_status & SERVER_SESSION_STATE_CHANGED
684
+ message = pkt.lcs unless pkt.to_s.empty?
685
+
686
+ return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message, session_track)
614
687
  elsif field_count.nil? # LOAD DATA LOCAL INFILE
615
688
  return self.new(nil, nil, nil, nil, nil, pkt.to_s)
616
689
  else
@@ -618,10 +691,38 @@ class Mysql
618
691
  end
619
692
  end
620
693
 
621
- attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message
694
# Decode the session state information of an OK packet.
#
# Each entry is a type byte (read as a length-coded integer) followed by a
# length-coded payload. Known types are decoded. An entry of a type this
# client does not know is kept as its raw payload; consuming that payload
# keeps the following entries aligned instead of misparsing them.
#
# @param data [String, nil] session state info (nil or "" => no entries)
# @return [Hash{Integer => Array}] decoded values grouped by entry type
def self.parse_session_track(data)
  session_track = {}
  pkt = Packet.new(data.to_s)
  until pkt.to_s.empty?
    type = pkt.lcb
    entries = (session_track[type] ||= [])
    case type
    when SESSION_TRACK_SYSTEM_VARIABLES
      variable = Packet.new(pkt.lcs)
      entries.push [variable.lcs, variable.lcs]
    when SESSION_TRACK_SCHEMA
      pkt.lcb # skip: entry length
      entries.push pkt.lcs
    when SESSION_TRACK_STATE_CHANGE
      entries.push pkt.lcs
    when SESSION_TRACK_GTIDS
      pkt.lcb # skip: entry length
      pkt.lcb # skip: encoding specification
      entries.push pkt.lcs
    when SESSION_TRACK_TRANSACTION_CHARACTERISTICS, SESSION_TRACK_TRANSACTION_STATE
      pkt.lcb # skip: entry length
      entries.push pkt.lcs
    else
      # Unknown type: consume (and keep) its payload so the next entry is
      # read from the correct offset.
      entries.push pkt.lcs
    end
  end
  session_track
end
720
+
721
+ attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message, :session_track
622
722
 
623
723
  def initialize(*args)
624
- @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
724
+ @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track = args
725
+ @session_track ||= {}
625
726
  end
626
727
  end
627
728
 
@@ -676,7 +777,7 @@ class Mysql
676
777
 
677
778
  # Authentication packet
678
779
  class AuthenticationPacket
679
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
780
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin, connect_attrs)
680
781
  data = [
681
782
  client_flags,
682
783
  max_packet_size,
@@ -692,7 +793,8 @@ class Mysql
692
793
  end
693
794
  data.push auth_plugin
694
795
  pack.concat "Z*"
695
- data.pack(pack)
796
+ attr = connect_attrs.map{|k, v| [Packet.lcs(k.to_s), Packet.lcs(v.to_s)]}.flatten.join
797
+ data.pack(pack) + Packet.lcb(attr.size)+attr
696
798
  end
697
799
  end
698
800
 
@@ -715,7 +817,7 @@ class Mysql
715
817
  netvalues = ""
716
818
  types = values.map do |v|
717
819
  t, n = Protocol.value2net v
718
- netvalues.concat n if v
820
+ netvalues.concat n unless v.nil?
719
821
  t
720
822
  end
721
823
  [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
@@ -725,14 +827,14 @@ class Mysql
725
827
  #
726
828
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
727
829
  def self.null_bitmap(values)
728
- bitmap = values.enum_for(:each_slice,8).map do |vals|
729
- vals.reverse.inject(0){|b, v|(b << 1 | (v ? 0 : 1))}
830
+ bitmap = values.enum_for(:each_slice, 8).map do |vals|
831
+ vals.reverse.inject(0){|b, v| (b << 1 | (v.nil? ? 1 : 0))}
730
832
  end
731
833
  return bitmap.pack("C*")
732
834
  end
733
-
734
835
  end
735
836
 
837
+ # Authentication result packet
736
838
  class AuthenticationResultPacket
737
839
  def self.parse(pkt)
738
840
  result = pkt.utiny
@@ -749,6 +851,7 @@ class Mysql
749
851
  end
750
852
  end
751
853
 
854
+ # raw record
752
855
  class RawRecord
753
856
  def initialize(packet, fields, encoding)
754
857
  @packet, @fields, @encoding = packet, fields, encoding
@@ -756,16 +859,19 @@ class Mysql
756
859
 
757
860
  def to_a
758
861
  @fields.map do |f|
759
- if s = @packet.lcs
760
- unless f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
761
- s = Charset.convert_encoding(s, @encoding)
762
- end
862
+ s = @packet.lcs
863
+ if s.nil?
864
+ nil
865
+ elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
866
+ s.b
867
+ else
868
+ Charset.convert_encoding(s, @encoding)
763
869
  end
764
- s
765
870
  end
766
871
  end
767
872
  end
768
873
 
874
+ # prepared statement raw record
769
875
  class StmtRawRecord
770
876
  # @param pkt [Packet]
771
877
  # @param fields [Array of Fields]
@@ -778,19 +884,21 @@ class Mysql
778
884
  # @return [Array<Object>] one record
779
885
  def parse_record_packet
780
886
  @packet.utiny # skip first byte
781
- null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
887
+ null_bit_map = @packet.read((@fields.length+7+2)/8).unpack1("b*")
782
888
  rec = @fields.each_with_index.map do |f, i|
783
- if null_bit_map[i+2] == ?1
889
+ if null_bit_map[i+2] == '1'
784
890
  nil
785
891
  else
786
892
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
787
893
  v = Protocol.net2value(@packet, f.type, unsigned)
788
- if v.nil? or v.is_a? Numeric or v.is_a? Time
894
+ if v.nil? or v.is_a? Numeric or v.is_a? Time or v.is_a? Date
789
895
  v
790
- elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
896
+ elsif f.type == Field::TYPE_BIT
791
897
  Charset.to_binary(v)
898
+ elsif v.is_a? String
899
+ f.charsetnr == Charset::BINARY_CHARSET_NUMBER ? Charset.to_binary(v) : Charset.convert_encoding(v, @encoding)
792
900
  else
793
- Charset.convert_encoding(v, @encoding)
901
+ v
794
902
  end
795
903
  end
796
904
  end
@@ -798,6 +906,5 @@ class Mysql
798
906
  end
799
907
 
800
908
  alias to_a parse_record_packet
801
-
802
909
  end
803
910
  end