ruby-mysql 3.0.1 → 4.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,16 +1,19 @@
1
1
  # coding: ascii-8bit
2
+
2
3
  # Copyright (C) 2008 TOMITA Masahiro
3
4
  # mailto:tommy@tmtm.org
4
5
 
5
6
  require "socket"
6
7
  require "stringio"
7
8
  require "openssl"
8
- require_relative 'authenticator.rb'
9
+ require "bigdecimal"
10
+ require "date"
11
+ require 'time'
12
+ require_relative 'authenticator'
9
13
 
10
14
  class Mysql
11
15
  # MySQL network protocol
12
16
  class Protocol
13
-
14
17
  VERSION = 10
15
18
  MAX_PACKET_LENGTH = 2**24-1
16
19
 
@@ -21,8 +24,11 @@ class Mysql
21
24
  # @return [Object] converted value.
22
25
  def self.net2value(pkt, type, unsigned)
23
26
  case type
24
- when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
27
+ when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_BLOB, Field::TYPE_JSON, Field::TYPE_GEOMETRY
25
28
  return pkt.lcs
29
+ when Field::TYPE_NEWDECIMAL
30
+ s = pkt.lcs
31
+ return s =~ /\./ && s !~ /\.0*\z/ ? BigDecimal(s) : s.to_i
26
32
  when Field::TYPE_TINY
27
33
  v = pkt.utiny
28
34
  return unsigned ? v : v < 128 ? v : v-256
@@ -37,18 +43,18 @@ class Mysql
37
43
  v = (n2 << 32) | n1
38
44
  return unsigned ? v : v < 0x8000_0000_0000_0000 ? v : v-0x10000_0000_0000_0000
39
45
  when Field::TYPE_FLOAT
40
- return pkt.read(4).unpack('e').first
46
+ return pkt.read(4).unpack1('e')
41
47
  when Field::TYPE_DOUBLE
42
- return pkt.read(8).unpack('E').first
48
+ return pkt.read(8).unpack1('E')
43
49
  when Field::TYPE_DATE
44
50
  len = pkt.utiny
45
51
  y, m, d = pkt.read(len).unpack("vCC")
46
- t = Time.new(y, m, d) rescue nil
52
+ t = Date.new(y, m, d) rescue nil
47
53
  return t
48
54
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
49
55
  len = pkt.utiny
50
56
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
51
- return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
57
+ return Time.new(y, m, d, h, mi, Rational(s.to_i*1000000+sp.to_i, 1000000)) rescue nil
52
58
  when Field::TYPE_TIME
53
59
  len = pkt.utiny
54
60
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
@@ -70,6 +76,7 @@ class Mysql
70
76
  # @return [String] netdata
71
77
  # @raise [ProtocolError] value too large / value is not supported
72
78
  def self.value2net(v)
79
+ v = v == true ? 1 : v == false ? 0 : v
73
80
  case v
74
81
  when nil
75
82
  type = Field::TYPE_NULL
@@ -85,8 +92,12 @@ class Mysql
85
92
  type = Field::TYPE_LONGLONG | 0x8000
86
93
  val = [v&0xffffffff, v>>32].pack("VV")
87
94
  else
88
- raise ProtocolError, "value too large: #{v}"
95
+ type =Field::TYPE_NEWDECIMAL
96
+ val = Packet.lcs(v.to_s)
89
97
  end
98
+ when BigDecimal
99
+ type = Field::TYPE_NEWDECIMAL
100
+ val = Packet.lcs(v.to_s)
90
101
  when Float
91
102
  type = Field::TYPE_DOUBLE
92
103
  val = [v].pack("E")
@@ -96,6 +107,12 @@ class Mysql
96
107
  when Time
97
108
  type = Field::TYPE_DATETIME
98
109
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
110
+ when DateTime
111
+ type = Field::TYPE_DATETIME
112
+ val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, (v.sec_fraction*1000000).to_i].pack("CvCCCCCV")
113
+ when Date
114
+ type = Field::TYPE_DATE
115
+ val = [11, v.year, v.month, v.day, 0, 0, 0, 0].pack("CvCCCCCV")
99
116
  else
100
117
  raise ProtocolError, "class #{v.class} is not supported"
101
118
  end
@@ -112,14 +129,18 @@ class Mysql
112
129
  attr_reader :server_status
113
130
  attr_reader :warning_count
114
131
  attr_reader :message
132
+ attr_reader :session_track
115
133
  attr_reader :get_server_public_key
134
+ attr_reader :field_count
116
135
  attr_accessor :charset
117
136
 
118
137
  # @state variable keep state for connection.
119
- # :INIT :: Initial state.
120
- # :READY :: Ready for command.
121
- # :FIELD :: After query(). retr_fields() is needed.
122
- # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
138
+ # :INIT :: Initial state.
139
+ # :READY :: Ready for command.
140
+ # :WAIT_RESULT :: After query_command(). get_result() is needed.
141
+ # :FIELD :: After get_result(). retr_fields() is needed.
142
+ # :RESULT :: After retr_fields(), retr_all_records() is needed.
143
+ # :CLOSED :: Connection closed.
123
144
 
124
145
  # make socket connection to server.
125
146
  # @param opts [Hash]
@@ -137,22 +158,28 @@ class Mysql
137
158
  # @option :local_infile [Boolean]
138
159
  # @option :load_data_local_dir [String]
139
160
  # @option :ssl_mode [Integer]
161
+ # @option :ssl_context_params [Hash<:Symbol, String>]
140
162
  # @option :get_server_public_key [Boolean]
163
+ # @option :io [BasicSocket, OpenSSL::SSL::SSLSocket] Existing socket instance that will be used instead of creating a new socket
141
164
  # @raise [ClientError] connection timeout
142
165
  def initialize(opts)
166
+ @mutex = Mutex.new
143
167
  @opts = opts
144
168
  @charset = Mysql::Charset.by_name("utf8mb4")
145
169
  @insert_id = 0
146
170
  @warning_count = 0
171
+ @session_track = {}
147
172
  @gc_stmt_queue = [] # stmt id list which GC destroy.
148
173
  set_state :INIT
149
174
  @get_server_public_key = @opts[:get_server_public_key]
150
175
  begin
151
- if @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
176
+ if @opts[:io]
177
+ @socket = @opts[:io]
178
+ elsif @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
152
179
  socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
180
  @socket = Socket.unix(socket)
154
181
  else
155
- port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
182
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql", "tcp") rescue MYSQL_TCP_PORT)
156
183
  @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
157
184
  end
158
185
  rescue Errno::ETIMEDOUT
@@ -161,153 +188,206 @@ class Mysql
161
188
  end
162
189
 
163
190
  def close
164
- @socket.close
191
+ @socket.close rescue nil
165
192
  end
166
193
 
167
194
  # initial negotiate and authenticate.
168
195
  # @param charset [Mysql::Charset, nil] charset for connection. nil: use server's charset
169
196
  # @raise [ProtocolError] The old style password is not supported
170
197
  def authenticate
171
- check_state :INIT
172
- reset
173
- init_packet = InitialPacket.parse read
174
- @server_info = init_packet.server_version
175
- @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
- @server_capabilities = init_packet.server_capabilities
177
- @thread_id = init_packet.thread_id
178
- @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH
179
- @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
- @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
- @client_flags |= @opts[:flags]
182
- if @opts[:charset]
183
- @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
- else
185
- @charset = Charset.by_number(init_packet.server_charset)
186
- @charset.encoding # raise error if unsupported charset
187
- end
188
- enable_ssl
189
- Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
190
- set_state :READY
191
- end
198
+ synchronize(before: :INIT, after: :READY) do
199
+ reset
200
+ init_packet = InitialPacket.parse read
201
+ @server_info = init_packet.server_version
202
+ @server_version = init_packet.server_version.split(/\D/)[0, 3].inject{|a, b| a.to_i*100+b.to_i}
203
+ @server_capabilities = init_packet.server_capabilities
204
+ @thread_id = init_packet.thread_id
205
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | CLIENT_SESSION_TRACK | CLIENT_LOCAL_FILES
206
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
207
+ @client_flags |= @opts[:flags]
208
+ if @opts[:charset]
209
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
210
+ else
211
+ @charset = Charset.by_number(init_packet.server_charset)
212
+ @charset.encoding # raise error if unsupported charset
213
+ end
214
+ enable_ssl
215
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin, @opts[:connect_attrs])
216
+ end
217
+ end
218
+
219
+ SSL_MODE_KEY = {
220
+ SSL_MODE_DISABLED => 1,
221
+ SSL_MODE_PREFERRED => 2,
222
+ SSL_MODE_REQUIRED => 3,
223
+ SSL_MODE_VERIFY_CA => 4,
224
+ SSL_MODE_VERIFY_IDENTITY => 5,
225
+ '1' => 1,
226
+ '2' => 2,
227
+ '3' => 3,
228
+ '4' => 4,
229
+ '5' => 5,
230
+ 'disabled' => 1,
231
+ 'preferred' => 2,
232
+ 'required' => 3,
233
+ 'verify_ca' => 4,
234
+ 'verify_identity' => 5,
235
+ :disabled => 1,
236
+ :preferred => 2,
237
+ :required => 3,
238
+ :verify_ca => 4,
239
+ :verify_identity => 5,
240
+ }.freeze
192
241
 
193
242
  def enable_ssl
194
- case @opts[:ssl_mode]
195
- when SSL_MODE_DISABLED, '1', 'disabled'
196
- return
197
- when SSL_MODE_PREFERRED, '2', 'preferred'
243
+ ssl_mode = SSL_MODE_KEY[@opts[:ssl_mode]]
244
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported" unless ssl_mode
245
+
246
+ return if ssl_mode == SSL_MODE_DISABLED
247
+ if ssl_mode == SSL_MODE_PREFERRED
198
248
  return if @socket.local_address.unix?
199
249
  return if @server_capabilities & CLIENT_SSL == 0
200
- when SSL_MODE_REQUIRED, '3', 'required'
201
- if @server_capabilities & CLIENT_SSL == 0
202
- raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
- end
204
- else
205
- raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
250
  end
207
- begin
208
- @client_flags |= CLIENT_SSL
209
- write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
- @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
- @socket.sync_close = true
212
- @socket.connect
213
- rescue => e
214
- @client_flags &= ~CLIENT_SSL
215
- return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
- raise e
251
+ if ssl_mode >= SSL_MODE_REQUIRED && @server_capabilities & CLIENT_SSL == 0
252
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
217
253
  end
254
+
255
+ context = OpenSSL::SSL::SSLContext.new
256
+ context.set_params(@opts[:ssl_context_params])
257
+ context.verify_mode = OpenSSL::SSL::VERIFY_NONE if ssl_mode < SSL_MODE_VERIFY_CA
258
+ context.verify_hostname = false if ssl_mode < SSL_MODE_VERIFY_IDENTITY
259
+
260
+ ssl_socket = OpenSSL::SSL::SSLSocket.new(@socket, context)
261
+ ssl_socket.sync_close = true
262
+ ssl_socket.hostname = @opts[:host] if ssl_mode >= SSL_MODE_VERIFY_IDENTITY
263
+
264
+ @client_flags |= CLIENT_SSL
265
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
266
+
267
+ ssl_socket.connect
268
+ @socket = ssl_socket
269
+ rescue OpenSSL::SSL::SSLError => e
270
+ @client_flags &= ~CLIENT_SSL
271
+ return if @opts[:ssl_mode] < SSL_MODE_REQUIRED
272
+ raise e
273
+ end
274
+
275
+ def ssl_cipher
276
+ @client_flags.allbits?(CLIENT_SSL) ? @socket.cipher : nil
218
277
  end
219
278
 
220
279
  # Quit command
221
280
  def quit_command
222
- synchronize do
281
+ get_result if @state == :WAIT_RESULT
282
+ retr_fields if @state == :FIELD
283
+ retr_all_records(RawRecord) if @state == :RESULT
284
+ synchronize(before: :READY, after: :CLOSED) do
223
285
  reset
224
286
  write [COM_QUIT].pack("C")
225
287
  close
288
+ @gc_stmt_queue.clear
226
289
  end
227
290
  end
228
291
 
229
292
  # Query command
230
293
  # @param query [String] query string
231
- # @return [Integer, nil] number of fields of results. nil if no results.
232
294
  def query_command(query)
233
- check_state :READY
234
- begin
295
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
235
296
  reset
236
297
  write [COM_QUERY, @charset.convert(query)].pack("Ca*")
237
- get_result
238
- rescue
239
- set_state :READY
240
- raise
241
298
  end
242
299
  end
243
300
 
244
301
  # get result of query.
245
302
  # @return [integer, nil] number of fields of results. nil if no results.
246
303
  def get_result
247
- begin
304
+ synchronize(before: :WAIT_RESULT, error: :READY) do
248
305
  res_packet = ResultPacket.parse read
249
- if res_packet.field_count.to_i > 0 # result data exists
306
+ @field_count = res_packet.field_count
307
+ if @field_count.to_i > 0 # result data exists
250
308
  set_state :FIELD
251
- return res_packet.field_count
309
+ return @field_count
252
310
  end
253
- if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
311
+ if @field_count.nil? # LOAD DATA LOCAL INFILE
254
312
  send_local_file(res_packet.message)
255
313
  res_packet = ResultPacket.parse read
256
314
  end
257
- @affected_rows, @insert_id, @server_status, @warning_count, @message =
258
- res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
259
- set_state :READY
315
+ @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track =
316
+ res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message, res_packet.session_track
317
+ set_state :READY unless more_results?
260
318
  return nil
261
- rescue
262
- set_state :READY
263
- raise
264
319
  end
265
320
  end
266
321
 
322
+ def more_results?
323
+ @server_status & SERVER_MORE_RESULTS_EXISTS != 0
324
+ end
325
+
267
326
  # send local file to server
268
327
  def send_local_file(filename)
269
328
  filename = File.absolute_path(filename)
270
329
  if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
271
330
  File.open(filename){|f| write f}
331
+ write nil # EOF
272
332
  else
333
+ write nil # send empty data instead of file contents
334
+ read # result packet
273
335
  raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
274
336
  end
275
- ensure
276
- write nil # EOF mark
277
337
  end
278
338
 
279
339
  # Retrieve n fields
280
- # @param n [Integer] number of fields
281
340
  # @return [Array<Mysql::Field>] field list
282
- def retr_fields(n)
283
- check_state :FIELD
284
- begin
285
- fields = n.times.map{Field.new FieldPacket.parse(read)}
341
+ def retr_fields
342
+ synchronize(before: :FIELD, after: :RESULT, error: :READY) do
343
+ @fields = @field_count.times.map{Field.new FieldPacket.parse(read)}
286
344
  read_eof_packet
287
- set_state :RESULT
288
- fields
289
- rescue
290
- set_state :READY
291
- raise
345
+ @no_more_records = false
346
+ @fields
292
347
  end
293
348
  end
294
349
 
295
- # Retrieve all records for simple query
296
- # @param fields [Array<Mysql::Field>] number of fields
297
- # @return [Array<Array<String>>] all records
298
- def retr_all_records(fields)
299
- check_state :RESULT
300
- enc = charset.encoding
301
- begin
302
- all_recs = []
303
- until (pkt = read).eof?
304
- all_recs.push RawRecord.new(pkt, fields, enc)
350
+ # Retrieve one record for simple query or prepared statement
351
+ # @param record_class [RawRecord or StmtRawRecord]
352
+ # @return [<record_class>] record
353
+ # @return [nil] no more record
354
+ def retr_record(record_class)
355
+ return nil if @no_more_records
356
+ synchronize(before: :RESULT) do
357
+ enc = charset.encoding
358
+ begin
359
+ unless (pkt = read).eof?
360
+ return record_class.new(pkt, @fields, enc)
361
+ end
362
+ pkt.utiny
363
+ pkt.ushort
364
+ @server_status = pkt.ushort
365
+ set_state(more_results? ? :WAIT_RESULT : :READY)
366
+ @no_more_records = true
367
+ return nil
368
+ end
369
+ end
370
+ end
371
+
372
+ # Retrieve all records for simple query or prepared statement
373
+ # @param record_class [RawRecord or StmtRawRecord]
374
+ # @return [Array<record_class>] all records
375
+ def retr_all_records(record_class)
376
+ synchronize(before: :RESULT) do
377
+ enc = charset.encoding
378
+ begin
379
+ all_recs = []
380
+ until (pkt = read).eof?
381
+ all_recs.push record_class.new(pkt, @fields, enc)
382
+ end
383
+ pkt.utiny # 0xFE
384
+ _warnings = pkt.ushort
385
+ @server_status = pkt.ushort
386
+ @no_more_records = true
387
+ all_recs
388
+ ensure
389
+ set_state(more_results? ? :WAIT_RESULT : :READY)
305
390
  end
306
- pkt.read(3)
307
- @server_status = pkt.utiny
308
- all_recs
309
- ensure
310
- set_state :READY
311
391
  end
312
392
  end
313
393
 
@@ -345,7 +425,7 @@ class Mysql
345
425
  # @param stmt [String] prepared statement
346
426
  # @return [Array<Integer, Integer, Array<Field>>] statement id, number of parameters, field list
347
427
  def stmt_prepare_command(stmt)
348
- synchronize do
428
+ synchronize(before: :READY, after: :READY) do
349
429
  reset
350
430
  write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
351
431
  res_packet = PrepareResultPacket.parse read
@@ -368,41 +448,22 @@ class Mysql
368
448
  # @param values [Array] parameters
369
449
  # @return [Integer] number of fields
370
450
  def stmt_execute_command(stmt_id, values)
371
- check_state :READY
372
- begin
451
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
373
452
  reset
374
453
  write ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
375
- get_result
376
- rescue
377
- set_state :READY
378
- raise
379
- end
380
- end
381
-
382
- # Retrieve all records for prepared statement
383
- # @param fields [Array of Mysql::Fields] field list
384
- # @param charset [Mysql::Charset]
385
- # @return [Array<Array<Object>>] all records
386
- def stmt_retr_all_records(fields, charset)
387
- check_state :RESULT
388
- enc = charset.encoding
389
- begin
390
- all_recs = []
391
- until (pkt = read).eof?
392
- all_recs.push StmtRawRecord.new(pkt, fields, enc)
393
- end
394
- all_recs
395
- ensure
396
- set_state :READY
397
454
  end
398
455
  end
399
456
 
400
457
  # Stmt close command
401
458
  # @param stmt_id [Integer] statement id
402
459
  def stmt_close_command(stmt_id)
403
- synchronize do
460
+ get_result if @state == :WAIT_RESULT
461
+ retr_fields if @state == :FIELD
462
+ retr_all_records(StmtRawRecord) if @state == :RESULT
463
+ synchronize(before: :READY, after: :READY) do
404
464
  reset
405
465
  write [COM_STMT_CLOSE, stmt_id].pack("CV")
466
+ @gc_stmt_queue.delete stmt_id
406
467
  end
407
468
  end
408
469
 
@@ -411,30 +472,35 @@ class Mysql
411
472
  end
412
473
 
413
474
  def check_state(st)
414
- raise 'command out of sync' unless @state == st
475
+ raise Mysql::ClientError::CommandsOutOfSync, 'command out of sync' unless @state == st
415
476
  end
416
477
 
417
478
  def set_state(st)
418
479
  @state = st
419
- if st == :READY && !@gc_stmt_queue.empty?
420
- gc_disabled = GC.disable
421
- begin
422
- while st = @gc_stmt_queue.shift
423
- reset
424
- write [COM_STMT_CLOSE, st].pack("CV")
425
- end
426
- ensure
427
- GC.enable unless gc_disabled
480
+ return if st != :READY || @gc_stmt_queue.empty? || @socket&.closed?
481
+ gc_disabled = GC.disable
482
+ begin
483
+ while (st = @gc_stmt_queue.shift)
484
+ reset
485
+ write [COM_STMT_CLOSE, st].pack("CV")
428
486
  end
487
+ ensure
488
+ GC.enable unless gc_disabled
429
489
  end
430
490
  end
431
491
 
432
- def synchronize
433
- begin
434
- check_state :READY
435
- return yield
436
- ensure
437
- set_state :READY
492
+ def synchronize(before: nil, after: nil, error: nil)
493
+ @mutex.synchronize do
494
+ check_state before if before
495
+ begin
496
+ return yield
497
+ rescue
498
+ set_state error if error
499
+ raised = true
500
+ raise
501
+ ensure
502
+ set_state after if after && !raised
503
+ end
438
504
  end
439
505
  end
440
506
 
@@ -450,18 +516,19 @@ class Mysql
450
516
  data = ''
451
517
  len = nil
452
518
  begin
453
- header = read_timeout(4, @opts[:read_timeout])
519
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:read_timeout]
520
+ header = read_timeout(4, timeout)
454
521
  raise EOFError unless header && header.length == 4
455
522
  len1, len2, seq = header.unpack("CvC")
456
523
  len = (len2 << 8) + len1
457
524
  raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
525
  @seq = (@seq + 1) % 256
459
- ret = read_timeout(len, @opts[:read_timeout])
526
+ ret = read_timeout(len, timeout)
460
527
  raise EOFError unless ret && ret.length == len
461
528
  data.concat ret
462
- rescue EOFError
463
- @socket.close rescue nil
464
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
529
+ rescue EOFError, OpenSSL::SSL::SSLError
530
+ close
531
+ raise ClientError::ServerLost, 'Lost connection to server during query'
465
532
  rescue Errno::ETIMEDOUT
466
533
  raise ClientError, "read timeout"
467
534
  end while len == MAX_PACKET_LENGTH
@@ -495,10 +562,10 @@ class Mysql
495
562
  r = @socket.read_nonblock(len - result.size, exception: false)
496
563
  case r
497
564
  when :wait_readable
498
- IO.select([@socket], nil, nil, e - now)
565
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
499
566
  next
500
567
  when :wait_writable
501
- IO.select(nil, [@socket], nil, e - now)
568
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
502
569
  next
503
570
  else
504
571
  result << r
@@ -510,26 +577,25 @@ class Mysql
510
577
  # Write one packet data
511
578
  # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
512
579
  def write(data)
513
- begin
514
- @socket.sync = false
515
- if data.nil?
516
- write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
580
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:write_timeout]
581
+ @socket.sync = false
582
+ if data.nil?
583
+ write_timeout([0, 0, @seq].pack("CvC"), timeout)
584
+ @seq = (@seq + 1) % 256
585
+ else
586
+ data = StringIO.new data if data.is_a? String
587
+ while (d = data.read(MAX_PACKET_LENGTH))
588
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, timeout)
517
589
  @seq = (@seq + 1) % 256
518
- else
519
- data = StringIO.new data if data.is_a? String
520
- while d = data.read(MAX_PACKET_LENGTH)
521
- write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
522
- @seq = (@seq + 1) % 256
523
- end
524
590
  end
525
- @socket.sync = true
526
- @socket.flush
527
- rescue Errno::EPIPE
528
- @socket.close rescue nil
529
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
530
- rescue Errno::ETIMEDOUT
531
- raise ClientError, "write timeout"
532
591
  end
592
+ @socket.sync = true
593
+ @socket.flush
594
+ rescue Errno::EPIPE, OpenSSL::SSL::SSLError
595
+ close
596
+ raise ClientError::ServerGoneError, 'MySQL server has gone away'
597
+ rescue Errno::ETIMEDOUT
598
+ raise ClientError, "write timeout"
533
599
  end
534
600
 
535
601
  def write_timeout(data, timeout)
@@ -539,12 +605,12 @@ class Mysql
539
605
  while len < data.size
540
606
  now = Time.now
541
607
  raise Errno::ETIMEDOUT if now > e
542
- l = @socket.write_nonblock(data[len..-1], exception: false)
608
+ l = @socket.write_nonblock(data[len..], exception: false)
543
609
  case l
544
610
  when :wait_readable
545
- IO.select([@socket], nil, nil, e - now)
611
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
546
612
  when :wait_writable
547
- IO.select(nil, [@socket], nil, e - now)
613
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
548
614
  else
549
615
  len += l
550
616
  end
@@ -555,14 +621,18 @@ class Mysql
555
621
  # Read EOF packet
556
622
  # @raise [ProtocolError] packet is not EOF
557
623
  def read_eof_packet
558
- raise ProtocolError, "packet is not EOF" unless read.eof?
624
+ pkt = read
625
+ raise ProtocolError, "packet is not EOF" unless pkt.eof?
626
+ pkt.utiny # 0xFE
627
+ _warnings = pkt.ushort
628
+ @server_status = pkt.ushort
559
629
  end
560
630
 
561
631
  # Send simple command
562
632
  # @param packet :: [String] packet data
563
633
  # @return [String] received data
564
634
  def simple_command(packet)
565
- synchronize do
635
+ synchronize(before: :READY, after: :READY) do
566
636
  reset
567
637
  write packet
568
638
  read.to_s
@@ -613,7 +683,10 @@ class Mysql
613
683
  server_status = pkt.ushort
614
684
  warning_count = pkt.ushort
615
685
  message = pkt.lcs
616
- return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
686
+ session_track = parse_session_track(pkt.lcs) if server_status & SERVER_SESSION_STATE_CHANGED
687
+ message = pkt.lcs unless pkt.to_s.empty?
688
+
689
+ return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message, session_track)
617
690
  elsif field_count.nil? # LOAD DATA LOCAL INFILE
618
691
  return self.new(nil, nil, nil, nil, nil, pkt.to_s)
619
692
  else
@@ -621,10 +694,38 @@ class Mysql
621
694
  end
622
695
  end
623
696
 
624
- attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message
697
+ def self.parse_session_track(data)
698
+ session_track = {}
699
+ pkt = Packet.new(data.to_s)
700
+ until pkt.to_s.empty?
701
+ type = pkt.lcb
702
+ session_track[type] ||= []
703
+ case type
704
+ when SESSION_TRACK_SYSTEM_VARIABLES
705
+ p = Packet.new(pkt.lcs)
706
+ session_track[type].push [p.lcs, p.lcs]
707
+ when SESSION_TRACK_SCHEMA
708
+ pkt.lcb # skip
709
+ session_track[type].push pkt.lcs
710
+ when SESSION_TRACK_STATE_CHANGE
711
+ session_track[type].push pkt.lcs
712
+ when SESSION_TRACK_GTIDS
713
+ pkt.lcb # skip
714
+ pkt.lcb # skip
715
+ session_track[type].push pkt.lcs
716
+ when SESSION_TRACK_TRANSACTION_CHARACTERISTICS, SESSION_TRACK_TRANSACTION_STATE
717
+ pkt.lcb # skip
718
+ session_track[type].push pkt.lcs
719
+ end
720
+ end
721
+ session_track
722
+ end
723
+
724
+ attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message, :session_track
625
725
 
626
726
  def initialize(*args)
627
- @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
727
+ @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track = args
728
+ @session_track ||= {}
628
729
  end
629
730
  end
630
731
 
@@ -679,7 +780,7 @@ class Mysql
679
780
 
680
781
  # Authentication packet
681
782
  class AuthenticationPacket
682
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
783
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin, connect_attrs)
683
784
  data = [
684
785
  client_flags,
685
786
  max_packet_size,
@@ -695,7 +796,8 @@ class Mysql
695
796
  end
696
797
  data.push auth_plugin
697
798
  pack.concat "Z*"
698
- data.pack(pack)
799
+ attr = connect_attrs.map{|k, v| [Packet.lcs(k.to_s), Packet.lcs(v.to_s)]}.flatten.join
800
+ data.pack(pack) + Packet.lcb(attr.size)+attr
699
801
  end
700
802
  end
701
803
 
@@ -718,7 +820,7 @@ class Mysql
718
820
  netvalues = ""
719
821
  types = values.map do |v|
720
822
  t, n = Protocol.value2net v
721
- netvalues.concat n if v
823
+ netvalues.concat n unless v.nil?
722
824
  t
723
825
  end
724
826
  [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
@@ -728,14 +830,14 @@ class Mysql
728
830
  #
729
831
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
730
832
  def self.null_bitmap(values)
731
- bitmap = values.enum_for(:each_slice,8).map do |vals|
732
- vals.reverse.inject(0){|b, v|(b << 1 | (v ? 0 : 1))}
833
+ bitmap = values.enum_for(:each_slice, 8).map do |vals|
834
+ vals.reverse.inject(0){|b, v| (b << 1 | (v.nil? ? 1 : 0))}
733
835
  end
734
836
  return bitmap.pack("C*")
735
837
  end
736
-
737
838
  end
738
839
 
840
+ # Authentication result packet
739
841
  class AuthenticationResultPacket
740
842
  def self.parse(pkt)
741
843
  result = pkt.utiny
@@ -752,6 +854,7 @@ class Mysql
752
854
  end
753
855
  end
754
856
 
857
+ # raw record
755
858
  class RawRecord
756
859
  def initialize(packet, fields, encoding)
757
860
  @packet, @fields, @encoding = packet, fields, encoding
@@ -759,16 +862,19 @@ class Mysql
759
862
 
760
863
  def to_a
761
864
  @fields.map do |f|
762
- if s = @packet.lcs
763
- unless f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
764
- s = Charset.convert_encoding(s, @encoding)
765
- end
865
+ s = @packet.lcs
866
+ if s.nil?
867
+ nil
868
+ elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
869
+ s.b
870
+ else
871
+ Charset.convert_encoding(s, @encoding)
766
872
  end
767
- s
768
873
  end
769
874
  end
770
875
  end
771
876
 
877
+ # prepared statement raw record
772
878
  class StmtRawRecord
773
879
  # @param pkt [Packet]
774
880
  # @param fields [Array of Fields]
@@ -781,19 +887,21 @@ class Mysql
781
887
  # @return [Array<Object>] one record
782
888
  def parse_record_packet
783
889
  @packet.utiny # skip first byte
784
- null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
890
+ null_bit_map = @packet.read((@fields.length+7+2)/8).unpack1("b*")
785
891
  rec = @fields.each_with_index.map do |f, i|
786
- if null_bit_map[i+2] == ?1
892
+ if null_bit_map[i+2] == '1'
787
893
  nil
788
894
  else
789
895
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
790
896
  v = Protocol.net2value(@packet, f.type, unsigned)
791
- if v.nil? or v.is_a? Numeric or v.is_a? Time
897
+ if v.nil? or v.is_a? Numeric or v.is_a? Time or v.is_a? Date
792
898
  v
793
- elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
899
+ elsif f.type == Field::TYPE_BIT
794
900
  Charset.to_binary(v)
901
+ elsif v.is_a? String
902
+ f.charsetnr == Charset::BINARY_CHARSET_NUMBER ? Charset.to_binary(v) : Charset.convert_encoding(v, @encoding)
795
903
  else
796
- Charset.convert_encoding(v, @encoding)
904
+ v
797
905
  end
798
906
  end
799
907
  end
@@ -801,6 +909,5 @@ class Mysql
801
909
  end
802
910
 
803
911
  alias to_a parse_record_packet
804
-
805
912
  end
806
913
  end