ruby-mysql 3.0.0 → 4.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,19 @@
1
1
  # coding: ascii-8bit
2
+
2
3
  # Copyright (C) 2008 TOMITA Masahiro
3
4
  # mailto:tommy@tmtm.org
4
5
 
5
6
  require "socket"
6
7
  require "stringio"
7
8
  require "openssl"
8
- require_relative 'authenticator.rb'
9
+ require "bigdecimal"
10
+ require "date"
11
+ require 'time'
12
+ require_relative 'authenticator'
9
13
 
10
14
  class Mysql
11
15
  # MySQL network protocol
12
16
  class Protocol
13
-
14
17
  VERSION = 10
15
18
  MAX_PACKET_LENGTH = 2**24-1
16
19
 
@@ -21,8 +24,11 @@ class Mysql
21
24
  # @return [Object] converted value.
22
25
  def self.net2value(pkt, type, unsigned)
23
26
  case type
24
- when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
27
+ when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_BLOB, Field::TYPE_JSON
25
28
  return pkt.lcs
29
+ when Field::TYPE_NEWDECIMAL
30
+ s = pkt.lcs
31
+ return s =~ /\./ && s !~ /\.0*\z/ ? BigDecimal(s) : s.to_i
26
32
  when Field::TYPE_TINY
27
33
  v = pkt.utiny
28
34
  return unsigned ? v : v < 128 ? v : v-256
@@ -37,18 +43,18 @@ class Mysql
37
43
  v = (n2 << 32) | n1
38
44
  return unsigned ? v : v < 0x8000_0000_0000_0000 ? v : v-0x10000_0000_0000_0000
39
45
  when Field::TYPE_FLOAT
40
- return pkt.read(4).unpack('e').first
46
+ return pkt.read(4).unpack1('e')
41
47
  when Field::TYPE_DOUBLE
42
- return pkt.read(8).unpack('E').first
48
+ return pkt.read(8).unpack1('E')
43
49
  when Field::TYPE_DATE
44
50
  len = pkt.utiny
45
51
  y, m, d = pkt.read(len).unpack("vCC")
46
- t = Time.new(y, m, d) rescue nil
52
+ t = Date.new(y, m, d) rescue nil
47
53
  return t
48
54
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
49
55
  len = pkt.utiny
50
56
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
51
- return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
57
+ return Time.new(y, m, d, h, mi, Rational(s.to_i*1000000+sp.to_i, 1000000)) rescue nil
52
58
  when Field::TYPE_TIME
53
59
  len = pkt.utiny
54
60
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
@@ -70,6 +76,7 @@ class Mysql
70
76
  # @return [String] netdata
71
77
  # @raise [ProtocolError] value too large / value is not supported
72
78
  def self.value2net(v)
79
+ v = v == true ? 1 : v == false ? 0 : v
73
80
  case v
74
81
  when nil
75
82
  type = Field::TYPE_NULL
@@ -85,8 +92,12 @@ class Mysql
85
92
  type = Field::TYPE_LONGLONG | 0x8000
86
93
  val = [v&0xffffffff, v>>32].pack("VV")
87
94
  else
88
- raise ProtocolError, "value too large: #{v}"
95
+ type =Field::TYPE_NEWDECIMAL
96
+ val = Packet.lcs(v.to_s)
89
97
  end
98
+ when BigDecimal
99
+ type = Field::TYPE_NEWDECIMAL
100
+ val = Packet.lcs(v.to_s)
90
101
  when Float
91
102
  type = Field::TYPE_DOUBLE
92
103
  val = [v].pack("E")
@@ -96,6 +107,12 @@ class Mysql
96
107
  when Time
97
108
  type = Field::TYPE_DATETIME
98
109
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
110
+ when DateTime
111
+ type = Field::TYPE_DATETIME
112
+ val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, (v.sec_fraction*1000000).to_i].pack("CvCCCCCV")
113
+ when Date
114
+ type = Field::TYPE_DATE
115
+ val = [11, v.year, v.month, v.day, 0, 0, 0, 0].pack("CvCCCCCV")
99
116
  else
100
117
  raise ProtocolError, "class #{v.class} is not supported"
101
118
  end
@@ -112,14 +129,18 @@ class Mysql
112
129
  attr_reader :server_status
113
130
  attr_reader :warning_count
114
131
  attr_reader :message
132
+ attr_reader :session_track
115
133
  attr_reader :get_server_public_key
134
+ attr_reader :field_count
116
135
  attr_accessor :charset
117
136
 
118
137
  # @state variable keep state for connection.
119
- # :INIT :: Initial state.
120
- # :READY :: Ready for command.
121
- # :FIELD :: After query(). retr_fields() is needed.
122
- # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
138
+ # :INIT :: Initial state.
139
+ # :READY :: Ready for command.
140
+ # :WAIT_RESULT :: After query_command(). get_result() is needed.
141
+ # :FIELD :: After get_result(). retr_fields() is needed.
142
+ # :RESULT :: After retr_fields(), retr_all_records() is needed.
143
+ # :CLOSED :: Connection closed.
123
144
 
124
145
  # make socket connection to server.
125
146
  # @param opts [Hash]
@@ -137,13 +158,16 @@ class Mysql
137
158
  # @option :local_infile [Boolean]
138
159
  # @option :load_data_local_dir [String]
139
160
  # @option :ssl_mode [Integer]
161
+ # @option :ssl_context_params [Hash<:Symbol, String>]
140
162
  # @option :get_server_public_key [Boolean]
141
163
  # @raise [ClientError] connection timeout
142
164
  def initialize(opts)
165
+ @mutex = Mutex.new
143
166
  @opts = opts
144
167
  @charset = Mysql::Charset.by_name("utf8mb4")
145
168
  @insert_id = 0
146
169
  @warning_count = 0
170
+ @session_track = {}
147
171
  @gc_stmt_queue = [] # stmt id list which GC destroy.
148
172
  set_state :INIT
149
173
  @get_server_public_key = @opts[:get_server_public_key]
@@ -152,7 +176,7 @@ class Mysql
152
176
  socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
177
  @socket = Socket.unix(socket)
154
178
  else
155
- port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
179
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql", "tcp") rescue MYSQL_TCP_PORT)
156
180
  @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
157
181
  end
158
182
  rescue Errno::ETIMEDOUT
@@ -168,146 +192,199 @@ class Mysql
168
192
  # @param charset [Mysql::Charset, nil] charset for connection. nil: use server's charset
169
193
  # @raise [ProtocolError] The old style password is not supported
170
194
  def authenticate
171
- check_state :INIT
172
- reset
173
- init_packet = InitialPacket.parse read
174
- @server_info = init_packet.server_version
175
- @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
- @server_capabilities = init_packet.server_capabilities
177
- @thread_id = init_packet.thread_id
178
- @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
179
- @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
- @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
- @client_flags |= @opts[:flags]
182
- if @opts[:charset]
183
- @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
- else
185
- @charset = Charset.by_number(init_packet.server_charset)
186
- @charset.encoding # raise error if unsupported charset
187
- end
188
- enable_ssl
189
- Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
190
- set_state :READY
191
- end
195
+ synchronize(before: :INIT, after: :READY) do
196
+ reset
197
+ init_packet = InitialPacket.parse read
198
+ @server_info = init_packet.server_version
199
+ @server_version = init_packet.server_version.split(/\D/)[0, 3].inject{|a, b| a.to_i*100+b.to_i}
200
+ @server_capabilities = init_packet.server_capabilities
201
+ @thread_id = init_packet.thread_id
202
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | CLIENT_SESSION_TRACK | CLIENT_LOCAL_FILES
203
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
204
+ @client_flags |= @opts[:flags]
205
+ if @opts[:charset]
206
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
207
+ else
208
+ @charset = Charset.by_number(init_packet.server_charset)
209
+ @charset.encoding # raise error if unsupported charset
210
+ end
211
+ enable_ssl
212
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin, @opts[:connect_attrs])
213
+ end
214
+ end
215
+
216
+ SSL_MODE_KEY = {
217
+ SSL_MODE_DISABLED => 1,
218
+ SSL_MODE_PREFERRED => 2,
219
+ SSL_MODE_REQUIRED => 3,
220
+ SSL_MODE_VERIFY_CA => 4,
221
+ SSL_MODE_VERIFY_IDENTITY => 5,
222
+ '1' => 1,
223
+ '2' => 2,
224
+ '3' => 3,
225
+ '4' => 4,
226
+ '5' => 5,
227
+ 'disabled' => 1,
228
+ 'preferred' => 2,
229
+ 'required' => 3,
230
+ 'verify_ca' => 4,
231
+ 'verify_identity' => 5,
232
+ :disabled => 1,
233
+ :preferred => 2,
234
+ :required => 3,
235
+ :verify_ca => 4,
236
+ :verify_identity => 5,
237
+ }.freeze
192
238
 
193
239
  def enable_ssl
194
- case @opts[:ssl_mode]
195
- when SSL_MODE_DISABLED, '1', 'disabled'
196
- return
197
- when SSL_MODE_PREFERRED, '2', 'preferred'
240
+ ssl_mode = SSL_MODE_KEY[@opts[:ssl_mode]]
241
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported" unless ssl_mode
242
+
243
+ return if ssl_mode == SSL_MODE_DISABLED
244
+ if ssl_mode == SSL_MODE_PREFERRED
198
245
  return if @socket.local_address.unix?
199
246
  return if @server_capabilities & CLIENT_SSL == 0
200
- when SSL_MODE_REQUIRED, '3', 'required'
201
- if @server_capabilities & CLIENT_SSL == 0
202
- raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
- end
204
- else
205
- raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
247
  end
207
- begin
208
- @client_flags |= CLIENT_SSL
209
- write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
- @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
- @socket.sync_close = true
212
- @socket.connect
213
- rescue => e
214
- @client_flags &= ~CLIENT_SSL
215
- return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
- raise e
248
+ if ssl_mode >= SSL_MODE_REQUIRED && @server_capabilities & CLIENT_SSL == 0
249
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
217
250
  end
251
+
252
+ context = OpenSSL::SSL::SSLContext.new
253
+ context.set_params(@opts[:ssl_context_params])
254
+ context.verify_mode = OpenSSL::SSL::VERIFY_NONE if ssl_mode < SSL_MODE_VERIFY_CA
255
+ context.verify_hostname = false if ssl_mode < SSL_MODE_VERIFY_IDENTITY
256
+
257
+ ssl_socket = OpenSSL::SSL::SSLSocket.new(@socket, context)
258
+ ssl_socket.sync_close = true
259
+ ssl_socket.hostname = @opts[:host] if ssl_mode >= SSL_MODE_VERIFY_IDENTITY
260
+
261
+ @client_flags |= CLIENT_SSL
262
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
263
+
264
+ ssl_socket.connect
265
+ @socket = ssl_socket
266
+ rescue OpenSSL::SSL::SSLError => e
267
+ @client_flags &= ~CLIENT_SSL
268
+ return if @opts[:ssl_mode] < SSL_MODE_REQUIRED
269
+ raise e
270
+ end
271
+
272
+ def ssl_cipher
273
+ @client_flags.allbits?(CLIENT_SSL) ? @socket.cipher : nil
218
274
  end
219
275
 
220
276
  # Quit command
221
277
  def quit_command
222
- synchronize do
278
+ get_result if @state == :WAIT_RESULT
279
+ retr_fields if @state == :FIELD
280
+ retr_all_records(RawRecord) if @state == :RESULT
281
+ synchronize(before: :READY, after: :CLOSED) do
223
282
  reset
224
283
  write [COM_QUIT].pack("C")
225
284
  close
285
+ @gc_stmt_queue.clear
226
286
  end
227
287
  end
228
288
 
229
289
  # Query command
230
290
  # @param query [String] query string
231
- # @return [Integer, nil] number of fields of results. nil if no results.
232
291
  def query_command(query)
233
- check_state :READY
234
- begin
292
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
235
293
  reset
236
294
  write [COM_QUERY, @charset.convert(query)].pack("Ca*")
237
- get_result
238
- rescue
239
- set_state :READY
240
- raise
241
295
  end
242
296
  end
243
297
 
244
298
  # get result of query.
245
299
  # @return [integer, nil] number of fields of results. nil if no results.
246
300
  def get_result
247
- begin
301
+ synchronize(before: :WAIT_RESULT, error: :READY) do
248
302
  res_packet = ResultPacket.parse read
249
- if res_packet.field_count.to_i > 0 # result data exists
303
+ @field_count = res_packet.field_count
304
+ if @field_count.to_i > 0 # result data exists
250
305
  set_state :FIELD
251
- return res_packet.field_count
306
+ return @field_count
252
307
  end
253
- if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
308
+ if @field_count.nil? # LOAD DATA LOCAL INFILE
254
309
  send_local_file(res_packet.message)
310
+ res_packet = ResultPacket.parse read
255
311
  end
256
- @affected_rows, @insert_id, @server_status, @warning_count, @message =
257
- res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
258
- set_state :READY
312
+ @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track =
313
+ res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message, res_packet.session_track
314
+ set_state :READY unless more_results?
259
315
  return nil
260
- rescue
261
- set_state :READY
262
- raise
263
316
  end
264
317
  end
265
318
 
319
+ def more_results?
320
+ @server_status & SERVER_MORE_RESULTS_EXISTS != 0
321
+ end
322
+
266
323
  # send local file to server
267
324
  def send_local_file(filename)
268
325
  filename = File.absolute_path(filename)
269
326
  if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
270
327
  File.open(filename){|f| write f}
328
+ write nil # EOF
271
329
  else
330
+ write nil # send empty data instead of file contents
331
+ read # result packet
272
332
  raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
273
333
  end
274
- ensure
275
- write nil # EOF mark
276
- read
277
334
  end
278
335
 
279
336
  # Retrieve n fields
280
- # @param n [Integer] number of fields
281
337
  # @return [Array<Mysql::Field>] field list
282
- def retr_fields(n)
283
- check_state :FIELD
284
- begin
285
- fields = n.times.map{Field.new FieldPacket.parse(read)}
338
+ def retr_fields
339
+ synchronize(before: :FIELD, after: :RESULT, error: :READY) do
340
+ @fields = @field_count.times.map{Field.new FieldPacket.parse(read)}
286
341
  read_eof_packet
287
- set_state :RESULT
288
- fields
289
- rescue
290
- set_state :READY
291
- raise
342
+ @no_more_records = false
343
+ @fields
292
344
  end
293
345
  end
294
346
 
295
- # Retrieve all records for simple query
296
- # @param fields [Array<Mysql::Field>] number of fields
297
- # @return [Array<Array<String>>] all records
298
- def retr_all_records(fields)
299
- check_state :RESULT
300
- enc = charset.encoding
301
- begin
302
- all_recs = []
303
- until (pkt = read).eof?
304
- all_recs.push RawRecord.new(pkt, fields, enc)
347
+ # Retrieve one record for simple query or prepared statement
348
+ # @param record_class [RawRecord or StmtRawRecord]
349
+ # @return [<record_class>] record
350
+ # @return [nil] no more record
351
+ def retr_record(record_class)
352
+ return nil if @no_more_records
353
+ synchronize(before: :RESULT) do
354
+ enc = charset.encoding
355
+ begin
356
+ unless (pkt = read).eof?
357
+ return record_class.new(pkt, @fields, enc)
358
+ end
359
+ pkt.utiny
360
+ pkt.ushort
361
+ @server_status = pkt.ushort
362
+ set_state(more_results? ? :WAIT_RESULT : :READY)
363
+ @no_more_records = true
364
+ return nil
365
+ end
366
+ end
367
+ end
368
+
369
+ # Retrieve all records for simple query or prepared statement
370
+ # @param record_class [RawRecord or StmtRawRecord]
371
+ # @return [Array<record_class>] all records
372
+ def retr_all_records(record_class)
373
+ synchronize(before: :RESULT) do
374
+ enc = charset.encoding
375
+ begin
376
+ all_recs = []
377
+ until (pkt = read).eof?
378
+ all_recs.push record_class.new(pkt, @fields, enc)
379
+ end
380
+ pkt.utiny # 0xFE
381
+ _warnings = pkt.ushort
382
+ @server_status = pkt.ushort
383
+ @no_more_records = true
384
+ all_recs
385
+ ensure
386
+ set_state(more_results? ? :WAIT_RESULT : :READY)
305
387
  end
306
- pkt.read(3)
307
- @server_status = pkt.utiny
308
- all_recs
309
- ensure
310
- set_state :READY
311
388
  end
312
389
  end
313
390
 
@@ -345,7 +422,7 @@ class Mysql
345
422
  # @param stmt [String] prepared statement
346
423
  # @return [Array<Integer, Integer, Array<Field>>] statement id, number of parameters, field list
347
424
  def stmt_prepare_command(stmt)
348
- synchronize do
425
+ synchronize(before: :READY, after: :READY) do
349
426
  reset
350
427
  write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
351
428
  res_packet = PrepareResultPacket.parse read
@@ -368,41 +445,22 @@ class Mysql
368
445
  # @param values [Array] parameters
369
446
  # @return [Integer] number of fields
370
447
  def stmt_execute_command(stmt_id, values)
371
- check_state :READY
372
- begin
448
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
373
449
  reset
374
450
  write ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
375
- get_result
376
- rescue
377
- set_state :READY
378
- raise
379
- end
380
- end
381
-
382
- # Retrieve all records for prepared statement
383
- # @param fields [Array of Mysql::Fields] field list
384
- # @param charset [Mysql::Charset]
385
- # @return [Array<Array<Object>>] all records
386
- def stmt_retr_all_records(fields, charset)
387
- check_state :RESULT
388
- enc = charset.encoding
389
- begin
390
- all_recs = []
391
- until (pkt = read).eof?
392
- all_recs.push StmtRawRecord.new(pkt, fields, enc)
393
- end
394
- all_recs
395
- ensure
396
- set_state :READY
397
451
  end
398
452
  end
399
453
 
400
454
  # Stmt close command
401
455
  # @param stmt_id [Integer] statement id
402
456
  def stmt_close_command(stmt_id)
403
- synchronize do
457
+ get_result if @state == :WAIT_RESULT
458
+ retr_fields if @state == :FIELD
459
+ retr_all_records(StmtRawRecord) if @state == :RESULT
460
+ synchronize(before: :READY, after: :READY) do
404
461
  reset
405
462
  write [COM_STMT_CLOSE, stmt_id].pack("CV")
463
+ @gc_stmt_queue.delete stmt_id
406
464
  end
407
465
  end
408
466
 
@@ -411,30 +469,35 @@ class Mysql
411
469
  end
412
470
 
413
471
  def check_state(st)
414
- raise 'command out of sync' unless @state == st
472
+ raise Mysql::ClientError::CommandsOutOfSync, 'command out of sync' unless @state == st
415
473
  end
416
474
 
417
475
  def set_state(st)
418
476
  @state = st
419
- if st == :READY && !@gc_stmt_queue.empty?
420
- gc_disabled = GC.disable
421
- begin
422
- while st = @gc_stmt_queue.shift
423
- reset
424
- write [COM_STMT_CLOSE, st].pack("CV")
425
- end
426
- ensure
427
- GC.enable unless gc_disabled
477
+ return if st != :READY || @gc_stmt_queue.empty? || @socket&.closed?
478
+ gc_disabled = GC.disable
479
+ begin
480
+ while (st = @gc_stmt_queue.shift)
481
+ reset
482
+ write [COM_STMT_CLOSE, st].pack("CV")
428
483
  end
484
+ ensure
485
+ GC.enable unless gc_disabled
429
486
  end
430
487
  end
431
488
 
432
- def synchronize
433
- begin
434
- check_state :READY
435
- return yield
436
- ensure
437
- set_state :READY
489
+ def synchronize(before: nil, after: nil, error: nil)
490
+ @mutex.synchronize do
491
+ check_state before if before
492
+ begin
493
+ return yield
494
+ rescue
495
+ set_state error if error
496
+ raised = true
497
+ raise
498
+ ensure
499
+ set_state after if after && !raised
500
+ end
438
501
  end
439
502
  end
440
503
 
@@ -450,17 +513,19 @@ class Mysql
450
513
  data = ''
451
514
  len = nil
452
515
  begin
453
- header = read_timeout(4, @opts[:read_timeout])
516
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:read_timeout]
517
+ header = read_timeout(4, timeout)
454
518
  raise EOFError unless header && header.length == 4
455
519
  len1, len2, seq = header.unpack("CvC")
456
520
  len = (len2 << 8) + len1
457
521
  raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
522
  @seq = (@seq + 1) % 256
459
- ret = read_timeout(len, @opts[:read_timeout])
523
+ ret = read_timeout(len, timeout)
460
524
  raise EOFError unless ret && ret.length == len
461
525
  data.concat ret
462
- rescue EOFError
463
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
526
+ rescue EOFError, OpenSSL::SSL::SSLError
527
+ close
528
+ raise ClientError::ServerLost, 'Lost connection to server during query'
464
529
  rescue Errno::ETIMEDOUT
465
530
  raise ClientError, "read timeout"
466
531
  end while len == MAX_PACKET_LENGTH
@@ -474,6 +539,7 @@ class Mysql
474
539
  _, errno, message = data.unpack("Cva*") # Version 4.0 Error
475
540
  @sqlstate = ""
476
541
  end
542
+ @server_status &= ~SERVER_MORE_RESULTS_EXISTS
477
543
  message.force_encoding(@charset.encoding)
478
544
  if Mysql::ServerError::ERROR_MAP.key? errno
479
545
  raise Mysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
@@ -493,10 +559,10 @@ class Mysql
493
559
  r = @socket.read_nonblock(len - result.size, exception: false)
494
560
  case r
495
561
  when :wait_readable
496
- IO.select([@socket], nil, nil, e - now)
562
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
497
563
  next
498
564
  when :wait_writable
499
- IO.select(nil, [@socket], nil, e - now)
565
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
500
566
  next
501
567
  else
502
568
  result << r
@@ -508,25 +574,25 @@ class Mysql
508
574
  # Write one packet data
509
575
  # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
510
576
  def write(data)
511
- begin
512
- @socket.sync = false
513
- if data.nil?
514
- write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
577
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:write_timeout]
578
+ @socket.sync = false
579
+ if data.nil?
580
+ write_timeout([0, 0, @seq].pack("CvC"), timeout)
581
+ @seq = (@seq + 1) % 256
582
+ else
583
+ data = StringIO.new data if data.is_a? String
584
+ while (d = data.read(MAX_PACKET_LENGTH))
585
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, timeout)
515
586
  @seq = (@seq + 1) % 256
516
- else
517
- data = StringIO.new data if data.is_a? String
518
- while d = data.read(MAX_PACKET_LENGTH)
519
- write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
520
- @seq = (@seq + 1) % 256
521
- end
522
587
  end
523
- @socket.sync = true
524
- @socket.flush
525
- rescue Errno::EPIPE
526
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
527
- rescue Errno::ETIMEDOUT
528
- raise ClientError, "write timeout"
529
588
  end
589
+ @socket.sync = true
590
+ @socket.flush
591
+ rescue Errno::EPIPE, OpenSSL::SSL::SSLError
592
+ close
593
+ raise ClientError::ServerGoneError, 'MySQL server has gone away'
594
+ rescue Errno::ETIMEDOUT
595
+ raise ClientError, "write timeout"
530
596
  end
531
597
 
532
598
  def write_timeout(data, timeout)
@@ -536,12 +602,12 @@ class Mysql
536
602
  while len < data.size
537
603
  now = Time.now
538
604
  raise Errno::ETIMEDOUT if now > e
539
- l = @socket.write_nonblock(data[len..-1], exception: false)
605
+ l = @socket.write_nonblock(data[len..], exception: false)
540
606
  case l
541
607
  when :wait_readable
542
- IO.select([@socket], nil, nil, e - now)
608
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
543
609
  when :wait_writable
544
- IO.select(nil, [@socket], nil, e - now)
610
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
545
611
  else
546
612
  len += l
547
613
  end
@@ -552,14 +618,18 @@ class Mysql
552
618
  # Read EOF packet
553
619
  # @raise [ProtocolError] packet is not EOF
554
620
  def read_eof_packet
555
- raise ProtocolError, "packet is not EOF" unless read.eof?
621
+ pkt = read
622
+ raise ProtocolError, "packet is not EOF" unless pkt.eof?
623
+ pkt.utiny # 0xFE
624
+ _warnings = pkt.ushort
625
+ @server_status = pkt.ushort
556
626
  end
557
627
 
558
628
  # Send simple command
559
629
  # @param packet :: [String] packet data
560
630
  # @return [String] received data
561
631
  def simple_command(packet)
562
- synchronize do
632
+ synchronize(before: :READY, after: :READY) do
563
633
  reset
564
634
  write packet
565
635
  read.to_s
@@ -610,7 +680,10 @@ class Mysql
610
680
  server_status = pkt.ushort
611
681
  warning_count = pkt.ushort
612
682
  message = pkt.lcs
613
- return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
683
+ session_track = parse_session_track(pkt.lcs) if server_status & SERVER_SESSION_STATE_CHANGED
684
+ message = pkt.lcs unless pkt.to_s.empty?
685
+
686
+ return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message, session_track)
614
687
  elsif field_count.nil? # LOAD DATA LOCAL INFILE
615
688
  return self.new(nil, nil, nil, nil, nil, pkt.to_s)
616
689
  else
@@ -618,10 +691,38 @@ class Mysql
618
691
  end
619
692
  end
620
693
 
621
- attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message
694
+ def self.parse_session_track(data)
695
+ session_track = {}
696
+ pkt = Packet.new(data.to_s)
697
+ until pkt.to_s.empty?
698
+ type = pkt.lcb
699
+ session_track[type] ||= []
700
+ case type
701
+ when SESSION_TRACK_SYSTEM_VARIABLES
702
+ p = Packet.new(pkt.lcs)
703
+ session_track[type].push [p.lcs, p.lcs]
704
+ when SESSION_TRACK_SCHEMA
705
+ pkt.lcb # skip
706
+ session_track[type].push pkt.lcs
707
+ when SESSION_TRACK_STATE_CHANGE
708
+ session_track[type].push pkt.lcs
709
+ when SESSION_TRACK_GTIDS
710
+ pkt.lcb # skip
711
+ pkt.lcb # skip
712
+ session_track[type].push pkt.lcs
713
+ when SESSION_TRACK_TRANSACTION_CHARACTERISTICS, SESSION_TRACK_TRANSACTION_STATE
714
+ pkt.lcb # skip
715
+ session_track[type].push pkt.lcs
716
+ end
717
+ end
718
+ session_track
719
+ end
720
+
721
+ attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message, :session_track
622
722
 
623
723
  def initialize(*args)
624
- @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
724
+ @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track = args
725
+ @session_track ||= {}
625
726
  end
626
727
  end
627
728
 
@@ -676,7 +777,7 @@ class Mysql
676
777
 
677
778
  # Authentication packet
678
779
  class AuthenticationPacket
679
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
780
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin, connect_attrs)
680
781
  data = [
681
782
  client_flags,
682
783
  max_packet_size,
@@ -692,7 +793,8 @@ class Mysql
692
793
  end
693
794
  data.push auth_plugin
694
795
  pack.concat "Z*"
695
- data.pack(pack)
796
+ attr = connect_attrs.map{|k, v| [Packet.lcs(k.to_s), Packet.lcs(v.to_s)]}.flatten.join
797
+ data.pack(pack) + Packet.lcb(attr.size)+attr
696
798
  end
697
799
  end
698
800
 
@@ -715,7 +817,7 @@ class Mysql
715
817
  netvalues = ""
716
818
  types = values.map do |v|
717
819
  t, n = Protocol.value2net v
718
- netvalues.concat n if v
820
+ netvalues.concat n unless v.nil?
719
821
  t
720
822
  end
721
823
  [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
@@ -725,14 +827,14 @@ class Mysql
725
827
  #
726
828
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
727
829
  def self.null_bitmap(values)
728
- bitmap = values.enum_for(:each_slice,8).map do |vals|
729
- vals.reverse.inject(0){|b, v|(b << 1 | (v ? 0 : 1))}
830
+ bitmap = values.enum_for(:each_slice, 8).map do |vals|
831
+ vals.reverse.inject(0){|b, v| (b << 1 | (v.nil? ? 1 : 0))}
730
832
  end
731
833
  return bitmap.pack("C*")
732
834
  end
733
-
734
835
  end
735
836
 
837
+ # Authentication result packet
736
838
  class AuthenticationResultPacket
737
839
  def self.parse(pkt)
738
840
  result = pkt.utiny
@@ -749,6 +851,7 @@ class Mysql
749
851
  end
750
852
  end
751
853
 
854
+ # raw record
752
855
  class RawRecord
753
856
  def initialize(packet, fields, encoding)
754
857
  @packet, @fields, @encoding = packet, fields, encoding
@@ -756,16 +859,19 @@ class Mysql
756
859
 
757
860
  def to_a
758
861
  @fields.map do |f|
759
- if s = @packet.lcs
760
- unless f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
761
- s = Charset.convert_encoding(s, @encoding)
762
- end
862
+ s = @packet.lcs
863
+ if s.nil?
864
+ nil
865
+ elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
866
+ s.b
867
+ else
868
+ Charset.convert_encoding(s, @encoding)
763
869
  end
764
- s
765
870
  end
766
871
  end
767
872
  end
768
873
 
874
+ # prepared statement raw record
769
875
  class StmtRawRecord
770
876
  # @param pkt [Packet]
771
877
  # @param fields [Array of Fields]
@@ -778,19 +884,21 @@ class Mysql
778
884
  # @return [Array<Object>] one record
779
885
  def parse_record_packet
780
886
  @packet.utiny # skip first byte
781
- null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
887
+ null_bit_map = @packet.read((@fields.length+7+2)/8).unpack1("b*")
782
888
  rec = @fields.each_with_index.map do |f, i|
783
- if null_bit_map[i+2] == ?1
889
+ if null_bit_map[i+2] == '1'
784
890
  nil
785
891
  else
786
892
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
787
893
  v = Protocol.net2value(@packet, f.type, unsigned)
788
- if v.nil? or v.is_a? Numeric or v.is_a? Time
894
+ if v.nil? or v.is_a? Numeric or v.is_a? Time or v.is_a? Date
789
895
  v
790
- elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
896
+ elsif f.type == Field::TYPE_BIT
791
897
  Charset.to_binary(v)
898
+ elsif v.is_a? String
899
+ f.charsetnr == Charset::BINARY_CHARSET_NUMBER ? Charset.to_binary(v) : Charset.convert_encoding(v, @encoding)
792
900
  else
793
- Charset.convert_encoding(v, @encoding)
901
+ v
794
902
  end
795
903
  end
796
904
  end
@@ -798,6 +906,5 @@ class Mysql
798
906
  end
799
907
 
800
908
  alias to_a parse_record_packet
801
-
802
909
  end
803
910
  end