ruby-mysql 3.0.1 → 4.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,19 @@
1
1
  # coding: ascii-8bit
2
+
2
3
  # Copyright (C) 2008 TOMITA Masahiro
3
4
  # mailto:tommy@tmtm.org
4
5
 
5
6
  require "socket"
6
7
  require "stringio"
7
8
  require "openssl"
8
- require_relative 'authenticator.rb'
9
+ require "bigdecimal"
10
+ require "date"
11
+ require 'time'
12
+ require_relative 'authenticator'
9
13
 
10
14
  class Mysql
11
15
  # MySQL network protocol
12
16
  class Protocol
13
-
14
17
  VERSION = 10
15
18
  MAX_PACKET_LENGTH = 2**24-1
16
19
 
@@ -21,8 +24,11 @@ class Mysql
21
24
  # @return [Object] converted value.
22
25
  def self.net2value(pkt, type, unsigned)
23
26
  case type
24
- when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
27
+ when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_BLOB, Field::TYPE_JSON, Field::TYPE_GEOMETRY
25
28
  return pkt.lcs
29
+ when Field::TYPE_NEWDECIMAL
30
+ s = pkt.lcs
31
+ return s =~ /\./ && s !~ /\.0*\z/ ? BigDecimal(s) : s.to_i
26
32
  when Field::TYPE_TINY
27
33
  v = pkt.utiny
28
34
  return unsigned ? v : v < 128 ? v : v-256
@@ -37,18 +43,18 @@ class Mysql
37
43
  v = (n2 << 32) | n1
38
44
  return unsigned ? v : v < 0x8000_0000_0000_0000 ? v : v-0x10000_0000_0000_0000
39
45
  when Field::TYPE_FLOAT
40
- return pkt.read(4).unpack('e').first
46
+ return pkt.read(4).unpack1('e')
41
47
  when Field::TYPE_DOUBLE
42
- return pkt.read(8).unpack('E').first
48
+ return pkt.read(8).unpack1('E')
43
49
  when Field::TYPE_DATE
44
50
  len = pkt.utiny
45
51
  y, m, d = pkt.read(len).unpack("vCC")
46
- t = Time.new(y, m, d) rescue nil
52
+ t = Date.new(y, m, d) rescue nil
47
53
  return t
48
54
  when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
49
55
  len = pkt.utiny
50
56
  y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
51
- return Time.new(y, m, d, h, mi, Rational((s.to_i*1000000+sp.to_i)/1000000)) rescue nil
57
+ return Time.new(y, m, d, h, mi, Rational(s.to_i*1000000+sp.to_i, 1000000)) rescue nil
52
58
  when Field::TYPE_TIME
53
59
  len = pkt.utiny
54
60
  sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
@@ -70,6 +76,7 @@ class Mysql
70
76
  # @return [String] netdata
71
77
  # @raise [ProtocolError] value too large / value is not supported
72
78
  def self.value2net(v)
79
+ v = v == true ? 1 : v == false ? 0 : v
73
80
  case v
74
81
  when nil
75
82
  type = Field::TYPE_NULL
@@ -85,8 +92,12 @@ class Mysql
85
92
  type = Field::TYPE_LONGLONG | 0x8000
86
93
  val = [v&0xffffffff, v>>32].pack("VV")
87
94
  else
88
- raise ProtocolError, "value too large: #{v}"
95
+ type =Field::TYPE_NEWDECIMAL
96
+ val = Packet.lcs(v.to_s)
89
97
  end
98
+ when BigDecimal
99
+ type = Field::TYPE_NEWDECIMAL
100
+ val = Packet.lcs(v.to_s)
90
101
  when Float
91
102
  type = Field::TYPE_DOUBLE
92
103
  val = [v].pack("E")
@@ -96,6 +107,12 @@ class Mysql
96
107
  when Time
97
108
  type = Field::TYPE_DATETIME
98
109
  val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
110
+ when DateTime
111
+ type = Field::TYPE_DATETIME
112
+ val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, (v.sec_fraction*1000000).to_i].pack("CvCCCCCV")
113
+ when Date
114
+ type = Field::TYPE_DATE
115
+ val = [11, v.year, v.month, v.day, 0, 0, 0, 0].pack("CvCCCCCV")
99
116
  else
100
117
  raise ProtocolError, "class #{v.class} is not supported"
101
118
  end
@@ -112,14 +129,18 @@ class Mysql
112
129
  attr_reader :server_status
113
130
  attr_reader :warning_count
114
131
  attr_reader :message
132
+ attr_reader :session_track
115
133
  attr_reader :get_server_public_key
134
+ attr_reader :field_count
116
135
  attr_accessor :charset
117
136
 
118
137
  # @state variable keep state for connection.
119
- # :INIT :: Initial state.
120
- # :READY :: Ready for command.
121
- # :FIELD :: After query(). retr_fields() is needed.
122
- # :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
138
+ # :INIT :: Initial state.
139
+ # :READY :: Ready for command.
140
+ # :WAIT_RESULT :: After query_command(). get_result() is needed.
141
+ # :FIELD :: After get_result(). retr_fields() is needed.
142
+ # :RESULT :: After retr_fields(), retr_all_records() is needed.
143
+ # :CLOSED :: Connection closed.
123
144
 
124
145
  # make socket connection to server.
125
146
  # @param opts [Hash]
@@ -137,22 +158,28 @@ class Mysql
137
158
  # @option :local_infile [Boolean]
138
159
  # @option :load_data_local_dir [String]
139
160
  # @option :ssl_mode [Integer]
161
+ # @option :ssl_context_params [Hash<:Symbol, String>]
140
162
  # @option :get_server_public_key [Boolean]
163
+ # @option :io [BasicSocket, OpenSSL::SSL::SSLSocket] Existing socket instance that will be used instead of creating a new socket
141
164
  # @raise [ClientError] connection timeout
142
165
  def initialize(opts)
166
+ @mutex = Mutex.new
143
167
  @opts = opts
144
168
  @charset = Mysql::Charset.by_name("utf8mb4")
145
169
  @insert_id = 0
146
170
  @warning_count = 0
171
+ @session_track = {}
147
172
  @gc_stmt_queue = [] # stmt id list which GC destroy.
148
173
  set_state :INIT
149
174
  @get_server_public_key = @opts[:get_server_public_key]
150
175
  begin
151
- if @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
176
+ if @opts[:io]
177
+ @socket = @opts[:io]
178
+ elsif @opts[:host].nil? or @opts[:host].empty? or @opts[:host] == "localhost"
152
179
  socket = @opts[:socket] || ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153
180
  @socket = Socket.unix(socket)
154
181
  else
155
- port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
182
+ port = @opts[:port] || ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql", "tcp") rescue MYSQL_TCP_PORT)
156
183
  @socket = Socket.tcp(@opts[:host], port, connect_timeout: @opts[:connect_timeout])
157
184
  end
158
185
  rescue Errno::ETIMEDOUT
@@ -161,153 +188,206 @@ class Mysql
161
188
  end
162
189
 
163
190
  def close
164
- @socket.close
191
+ @socket.close rescue nil
165
192
  end
166
193
 
167
194
  # initial negotiate and authenticate.
168
195
  # @param charset [Mysql::Charset, nil] charset for connection. nil: use server's charset
169
196
  # @raise [ProtocolError] The old style password is not supported
170
197
  def authenticate
171
- check_state :INIT
172
- reset
173
- init_packet = InitialPacket.parse read
174
- @server_info = init_packet.server_version
175
- @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
176
- @server_capabilities = init_packet.server_capabilities
177
- @thread_id = init_packet.thread_id
178
- @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH
179
- @client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile] || @opts[:load_data_local_dir]
180
- @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
181
- @client_flags |= @opts[:flags]
182
- if @opts[:charset]
183
- @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
184
- else
185
- @charset = Charset.by_number(init_packet.server_charset)
186
- @charset.encoding # raise error if unsupported charset
187
- end
188
- enable_ssl
189
- Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin)
190
- set_state :READY
191
- end
198
+ synchronize(before: :INIT, after: :READY) do
199
+ reset
200
+ init_packet = InitialPacket.parse read
201
+ @server_info = init_packet.server_version
202
+ @server_version = init_packet.server_version.split(/\D/)[0, 3].inject{|a, b| a.to_i*100+b.to_i}
203
+ @server_capabilities = init_packet.server_capabilities
204
+ @thread_id = init_packet.thread_id
205
+ @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_RESULTS | CLIENT_PS_MULTI_RESULTS | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | CLIENT_SESSION_TRACK | CLIENT_LOCAL_FILES
206
+ @client_flags |= CLIENT_CONNECT_WITH_DB if @opts[:database]
207
+ @client_flags |= @opts[:flags]
208
+ if @opts[:charset]
209
+ @charset = @opts[:charset].is_a?(Charset) ? @opts[:charset] : Charset.by_name(@opts[:charset])
210
+ else
211
+ @charset = Charset.by_number(init_packet.server_charset)
212
+ @charset.encoding # raise error if unsupported charset
213
+ end
214
+ enable_ssl
215
+ Authenticator.new(self).authenticate(@opts[:username], @opts[:password].to_s, @opts[:database], init_packet.scramble_buff, init_packet.auth_plugin, @opts[:connect_attrs])
216
+ end
217
+ end
218
+
219
+ SSL_MODE_KEY = {
220
+ SSL_MODE_DISABLED => 1,
221
+ SSL_MODE_PREFERRED => 2,
222
+ SSL_MODE_REQUIRED => 3,
223
+ SSL_MODE_VERIFY_CA => 4,
224
+ SSL_MODE_VERIFY_IDENTITY => 5,
225
+ '1' => 1,
226
+ '2' => 2,
227
+ '3' => 3,
228
+ '4' => 4,
229
+ '5' => 5,
230
+ 'disabled' => 1,
231
+ 'preferred' => 2,
232
+ 'required' => 3,
233
+ 'verify_ca' => 4,
234
+ 'verify_identity' => 5,
235
+ :disabled => 1,
236
+ :preferred => 2,
237
+ :required => 3,
238
+ :verify_ca => 4,
239
+ :verify_identity => 5,
240
+ }.freeze
192
241
 
193
242
  def enable_ssl
194
- case @opts[:ssl_mode]
195
- when SSL_MODE_DISABLED, '1', 'disabled'
196
- return
197
- when SSL_MODE_PREFERRED, '2', 'preferred'
243
+ ssl_mode = SSL_MODE_KEY[@opts[:ssl_mode]]
244
+ raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported" unless ssl_mode
245
+
246
+ return if ssl_mode == SSL_MODE_DISABLED
247
+ if ssl_mode == SSL_MODE_PREFERRED
198
248
  return if @socket.local_address.unix?
199
249
  return if @server_capabilities & CLIENT_SSL == 0
200
- when SSL_MODE_REQUIRED, '3', 'required'
201
- if @server_capabilities & CLIENT_SSL == 0
202
- raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
203
- end
204
- else
205
- raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
206
250
  end
207
- begin
208
- @client_flags |= CLIENT_SSL
209
- write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
210
- @socket = OpenSSL::SSL::SSLSocket.new(@socket)
211
- @socket.sync_close = true
212
- @socket.connect
213
- rescue => e
214
- @client_flags &= ~CLIENT_SSL
215
- return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
216
- raise e
251
+ if ssl_mode >= SSL_MODE_REQUIRED && @server_capabilities & CLIENT_SSL == 0
252
+ raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
217
253
  end
254
+
255
+ context = OpenSSL::SSL::SSLContext.new
256
+ context.set_params(@opts[:ssl_context_params])
257
+ context.verify_mode = OpenSSL::SSL::VERIFY_NONE if ssl_mode < SSL_MODE_VERIFY_CA
258
+ context.verify_hostname = false if ssl_mode < SSL_MODE_VERIFY_IDENTITY
259
+
260
+ ssl_socket = OpenSSL::SSL::SSLSocket.new(@socket, context)
261
+ ssl_socket.sync_close = true
262
+ ssl_socket.hostname = @opts[:host] if ssl_mode >= SSL_MODE_VERIFY_IDENTITY
263
+
264
+ @client_flags |= CLIENT_SSL
265
+ write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
266
+
267
+ ssl_socket.connect
268
+ @socket = ssl_socket
269
+ rescue OpenSSL::SSL::SSLError => e
270
+ @client_flags &= ~CLIENT_SSL
271
+ return if @opts[:ssl_mode] < SSL_MODE_REQUIRED
272
+ raise e
273
+ end
274
+
275
+ def ssl_cipher
276
+ @client_flags.allbits?(CLIENT_SSL) ? @socket.cipher : nil
218
277
  end
219
278
 
220
279
  # Quit command
221
280
  def quit_command
222
- synchronize do
281
+ get_result if @state == :WAIT_RESULT
282
+ retr_fields if @state == :FIELD
283
+ retr_all_records(RawRecord) if @state == :RESULT
284
+ synchronize(before: :READY, after: :CLOSED) do
223
285
  reset
224
286
  write [COM_QUIT].pack("C")
225
287
  close
288
+ @gc_stmt_queue.clear
226
289
  end
227
290
  end
228
291
 
229
292
  # Query command
230
293
  # @param query [String] query string
231
- # @return [Integer, nil] number of fields of results. nil if no results.
232
294
  def query_command(query)
233
- check_state :READY
234
- begin
295
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
235
296
  reset
236
297
  write [COM_QUERY, @charset.convert(query)].pack("Ca*")
237
- get_result
238
- rescue
239
- set_state :READY
240
- raise
241
298
  end
242
299
  end
243
300
 
244
301
  # get result of query.
245
302
  # @return [integer, nil] number of fields of results. nil if no results.
246
303
  def get_result
247
- begin
304
+ synchronize(before: :WAIT_RESULT, error: :READY) do
248
305
  res_packet = ResultPacket.parse read
249
- if res_packet.field_count.to_i > 0 # result data exists
306
+ @field_count = res_packet.field_count
307
+ if @field_count.to_i > 0 # result data exists
250
308
  set_state :FIELD
251
- return res_packet.field_count
309
+ return @field_count
252
310
  end
253
- if res_packet.field_count.nil? # LOAD DATA LOCAL INFILE
311
+ if @field_count.nil? # LOAD DATA LOCAL INFILE
254
312
  send_local_file(res_packet.message)
255
313
  res_packet = ResultPacket.parse read
256
314
  end
257
- @affected_rows, @insert_id, @server_status, @warning_count, @message =
258
- res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
259
- set_state :READY
315
+ @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track =
316
+ res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message, res_packet.session_track
317
+ set_state :READY unless more_results?
260
318
  return nil
261
- rescue
262
- set_state :READY
263
- raise
264
319
  end
265
320
  end
266
321
 
322
+ def more_results?
323
+ @server_status & SERVER_MORE_RESULTS_EXISTS != 0
324
+ end
325
+
267
326
  # send local file to server
268
327
  def send_local_file(filename)
269
328
  filename = File.absolute_path(filename)
270
329
  if @opts[:local_infile] || @opts[:load_data_local_dir] && filename.start_with?(@opts[:load_data_local_dir])
271
330
  File.open(filename){|f| write f}
331
+ write nil # EOF
272
332
  else
333
+ write nil # send empty data instead of file contents
334
+ read # result packet
273
335
  raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
274
336
  end
275
- ensure
276
- write nil # EOF mark
277
337
  end
278
338
 
279
339
  # Retrieve n fields
280
- # @param n [Integer] number of fields
281
340
  # @return [Array<Mysql::Field>] field list
282
- def retr_fields(n)
283
- check_state :FIELD
284
- begin
285
- fields = n.times.map{Field.new FieldPacket.parse(read)}
341
+ def retr_fields
342
+ synchronize(before: :FIELD, after: :RESULT, error: :READY) do
343
+ @fields = @field_count.times.map{Field.new FieldPacket.parse(read)}
286
344
  read_eof_packet
287
- set_state :RESULT
288
- fields
289
- rescue
290
- set_state :READY
291
- raise
345
+ @no_more_records = false
346
+ @fields
292
347
  end
293
348
  end
294
349
 
295
- # Retrieve all records for simple query
296
- # @param fields [Array<Mysql::Field>] number of fields
297
- # @return [Array<Array<String>>] all records
298
- def retr_all_records(fields)
299
- check_state :RESULT
300
- enc = charset.encoding
301
- begin
302
- all_recs = []
303
- until (pkt = read).eof?
304
- all_recs.push RawRecord.new(pkt, fields, enc)
350
+ # Retrieve one record for simple query or prepared statement
351
+ # @param record_class [RawRecord or StmtRawRecord]
352
+ # @return [<record_class>] record
353
+ # @return [nil] no more record
354
+ def retr_record(record_class)
355
+ return nil if @no_more_records
356
+ synchronize(before: :RESULT) do
357
+ enc = charset.encoding
358
+ begin
359
+ unless (pkt = read).eof?
360
+ return record_class.new(pkt, @fields, enc)
361
+ end
362
+ pkt.utiny
363
+ pkt.ushort
364
+ @server_status = pkt.ushort
365
+ set_state(more_results? ? :WAIT_RESULT : :READY)
366
+ @no_more_records = true
367
+ return nil
368
+ end
369
+ end
370
+ end
371
+
372
+ # Retrieve all records for simple query or prepared statement
373
+ # @param record_class [RawRecord or StmtRawRecord]
374
+ # @return [Array<record_class>] all records
375
+ def retr_all_records(record_class)
376
+ synchronize(before: :RESULT) do
377
+ enc = charset.encoding
378
+ begin
379
+ all_recs = []
380
+ until (pkt = read).eof?
381
+ all_recs.push record_class.new(pkt, @fields, enc)
382
+ end
383
+ pkt.utiny # 0xFE
384
+ _warnings = pkt.ushort
385
+ @server_status = pkt.ushort
386
+ @no_more_records = true
387
+ all_recs
388
+ ensure
389
+ set_state(more_results? ? :WAIT_RESULT : :READY)
305
390
  end
306
- pkt.read(3)
307
- @server_status = pkt.utiny
308
- all_recs
309
- ensure
310
- set_state :READY
311
391
  end
312
392
  end
313
393
 
@@ -345,7 +425,7 @@ class Mysql
345
425
  # @param stmt [String] prepared statement
346
426
  # @return [Array<Integer, Integer, Array<Field>>] statement id, number of parameters, field list
347
427
  def stmt_prepare_command(stmt)
348
- synchronize do
428
+ synchronize(before: :READY, after: :READY) do
349
429
  reset
350
430
  write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
351
431
  res_packet = PrepareResultPacket.parse read
@@ -368,41 +448,22 @@ class Mysql
368
448
  # @param values [Array] parameters
369
449
  # @return [Integer] number of fields
370
450
  def stmt_execute_command(stmt_id, values)
371
- check_state :READY
372
- begin
451
+ synchronize(before: :READY, after: :WAIT_RESULT, error: :READY) do
373
452
  reset
374
453
  write ExecutePacket.serialize(stmt_id, Mysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
375
- get_result
376
- rescue
377
- set_state :READY
378
- raise
379
- end
380
- end
381
-
382
- # Retrieve all records for prepared statement
383
- # @param fields [Array of Mysql::Fields] field list
384
- # @param charset [Mysql::Charset]
385
- # @return [Array<Array<Object>>] all records
386
- def stmt_retr_all_records(fields, charset)
387
- check_state :RESULT
388
- enc = charset.encoding
389
- begin
390
- all_recs = []
391
- until (pkt = read).eof?
392
- all_recs.push StmtRawRecord.new(pkt, fields, enc)
393
- end
394
- all_recs
395
- ensure
396
- set_state :READY
397
454
  end
398
455
  end
399
456
 
400
457
  # Stmt close command
401
458
  # @param stmt_id [Integer] statement id
402
459
  def stmt_close_command(stmt_id)
403
- synchronize do
460
+ get_result if @state == :WAIT_RESULT
461
+ retr_fields if @state == :FIELD
462
+ retr_all_records(StmtRawRecord) if @state == :RESULT
463
+ synchronize(before: :READY, after: :READY) do
404
464
  reset
405
465
  write [COM_STMT_CLOSE, stmt_id].pack("CV")
466
+ @gc_stmt_queue.delete stmt_id
406
467
  end
407
468
  end
408
469
 
@@ -411,30 +472,35 @@ class Mysql
411
472
  end
412
473
 
413
474
  def check_state(st)
414
- raise 'command out of sync' unless @state == st
475
+ raise Mysql::ClientError::CommandsOutOfSync, 'command out of sync' unless @state == st
415
476
  end
416
477
 
417
478
  def set_state(st)
418
479
  @state = st
419
- if st == :READY && !@gc_stmt_queue.empty?
420
- gc_disabled = GC.disable
421
- begin
422
- while st = @gc_stmt_queue.shift
423
- reset
424
- write [COM_STMT_CLOSE, st].pack("CV")
425
- end
426
- ensure
427
- GC.enable unless gc_disabled
480
+ return if st != :READY || @gc_stmt_queue.empty? || @socket&.closed?
481
+ gc_disabled = GC.disable
482
+ begin
483
+ while (st = @gc_stmt_queue.shift)
484
+ reset
485
+ write [COM_STMT_CLOSE, st].pack("CV")
428
486
  end
487
+ ensure
488
+ GC.enable unless gc_disabled
429
489
  end
430
490
  end
431
491
 
432
- def synchronize
433
- begin
434
- check_state :READY
435
- return yield
436
- ensure
437
- set_state :READY
492
+ def synchronize(before: nil, after: nil, error: nil)
493
+ @mutex.synchronize do
494
+ check_state before if before
495
+ begin
496
+ return yield
497
+ rescue
498
+ set_state error if error
499
+ raised = true
500
+ raise
501
+ ensure
502
+ set_state after if after && !raised
503
+ end
438
504
  end
439
505
  end
440
506
 
@@ -450,18 +516,19 @@ class Mysql
450
516
  data = ''
451
517
  len = nil
452
518
  begin
453
- header = read_timeout(4, @opts[:read_timeout])
519
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:read_timeout]
520
+ header = read_timeout(4, timeout)
454
521
  raise EOFError unless header && header.length == 4
455
522
  len1, len2, seq = header.unpack("CvC")
456
523
  len = (len2 << 8) + len1
457
524
  raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
458
525
  @seq = (@seq + 1) % 256
459
- ret = read_timeout(len, @opts[:read_timeout])
526
+ ret = read_timeout(len, timeout)
460
527
  raise EOFError unless ret && ret.length == len
461
528
  data.concat ret
462
- rescue EOFError
463
- @socket.close rescue nil
464
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
529
+ rescue EOFError, OpenSSL::SSL::SSLError
530
+ close
531
+ raise ClientError::ServerLost, 'Lost connection to server during query'
465
532
  rescue Errno::ETIMEDOUT
466
533
  raise ClientError, "read timeout"
467
534
  end while len == MAX_PACKET_LENGTH
@@ -495,10 +562,10 @@ class Mysql
495
562
  r = @socket.read_nonblock(len - result.size, exception: false)
496
563
  case r
497
564
  when :wait_readable
498
- IO.select([@socket], nil, nil, e - now)
565
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
499
566
  next
500
567
  when :wait_writable
501
- IO.select(nil, [@socket], nil, e - now)
568
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
502
569
  next
503
570
  else
504
571
  result << r
@@ -510,26 +577,25 @@ class Mysql
510
577
  # Write one packet data
511
578
  # @param data [String, IO, nil] packet data. If data is nil, write empty packet.
512
579
  def write(data)
513
- begin
514
- @socket.sync = false
515
- if data.nil?
516
- write_timeout([0, 0, @seq].pack("CvC"), @opts[:write_timeout])
580
+ timeout = @state == :INIT ? @opts[:connect_timeout] : @opts[:write_timeout]
581
+ @socket.sync = false
582
+ if data.nil?
583
+ write_timeout([0, 0, @seq].pack("CvC"), timeout)
584
+ @seq = (@seq + 1) % 256
585
+ else
586
+ data = StringIO.new data if data.is_a? String
587
+ while (d = data.read(MAX_PACKET_LENGTH))
588
+ write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, timeout)
517
589
  @seq = (@seq + 1) % 256
518
- else
519
- data = StringIO.new data if data.is_a? String
520
- while d = data.read(MAX_PACKET_LENGTH)
521
- write_timeout([d.length%256, d.length/256, @seq].pack("CvC")+d, @opts[:write_timeout])
522
- @seq = (@seq + 1) % 256
523
- end
524
590
  end
525
- @socket.sync = true
526
- @socket.flush
527
- rescue Errno::EPIPE
528
- @socket.close rescue nil
529
- raise ClientError::ServerGoneError, 'MySQL server has gone away'
530
- rescue Errno::ETIMEDOUT
531
- raise ClientError, "write timeout"
532
591
  end
592
+ @socket.sync = true
593
+ @socket.flush
594
+ rescue Errno::EPIPE, OpenSSL::SSL::SSLError
595
+ close
596
+ raise ClientError::ServerGoneError, 'MySQL server has gone away'
597
+ rescue Errno::ETIMEDOUT
598
+ raise ClientError, "write timeout"
533
599
  end
534
600
 
535
601
  def write_timeout(data, timeout)
@@ -539,12 +605,12 @@ class Mysql
539
605
  while len < data.size
540
606
  now = Time.now
541
607
  raise Errno::ETIMEDOUT if now > e
542
- l = @socket.write_nonblock(data[len..-1], exception: false)
608
+ l = @socket.write_nonblock(data[len..], exception: false)
543
609
  case l
544
610
  when :wait_readable
545
- IO.select([@socket], nil, nil, e - now)
611
+ IO.select([@socket], nil, nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
546
612
  when :wait_writable
547
- IO.select(nil, [@socket], nil, e - now)
613
+ IO.select(nil, [@socket], nil, e - now) # rubocop:disable Lint/IncompatibleIoSelectWithFiberScheduler
548
614
  else
549
615
  len += l
550
616
  end
@@ -555,14 +621,18 @@ class Mysql
555
621
  # Read EOF packet
556
622
  # @raise [ProtocolError] packet is not EOF
557
623
  def read_eof_packet
558
- raise ProtocolError, "packet is not EOF" unless read.eof?
624
+ pkt = read
625
+ raise ProtocolError, "packet is not EOF" unless pkt.eof?
626
+ pkt.utiny # 0xFE
627
+ _warnings = pkt.ushort
628
+ @server_status = pkt.ushort
559
629
  end
560
630
 
561
631
  # Send simple command
562
632
  # @param packet :: [String] packet data
563
633
  # @return [String] received data
564
634
  def simple_command(packet)
565
- synchronize do
635
+ synchronize(before: :READY, after: :READY) do
566
636
  reset
567
637
  write packet
568
638
  read.to_s
@@ -613,7 +683,10 @@ class Mysql
613
683
  server_status = pkt.ushort
614
684
  warning_count = pkt.ushort
615
685
  message = pkt.lcs
616
- return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
686
+ session_track = parse_session_track(pkt.lcs) if server_status & SERVER_SESSION_STATE_CHANGED
687
+ message = pkt.lcs unless pkt.to_s.empty?
688
+
689
+ return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message, session_track)
617
690
  elsif field_count.nil? # LOAD DATA LOCAL INFILE
618
691
  return self.new(nil, nil, nil, nil, nil, pkt.to_s)
619
692
  else
@@ -621,10 +694,38 @@ class Mysql
621
694
  end
622
695
  end
623
696
 
624
- attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message
697
+ def self.parse_session_track(data)
698
+ session_track = {}
699
+ pkt = Packet.new(data.to_s)
700
+ until pkt.to_s.empty?
701
+ type = pkt.lcb
702
+ session_track[type] ||= []
703
+ case type
704
+ when SESSION_TRACK_SYSTEM_VARIABLES
705
+ p = Packet.new(pkt.lcs)
706
+ session_track[type].push [p.lcs, p.lcs]
707
+ when SESSION_TRACK_SCHEMA
708
+ pkt.lcb # skip
709
+ session_track[type].push pkt.lcs
710
+ when SESSION_TRACK_STATE_CHANGE
711
+ session_track[type].push pkt.lcs
712
+ when SESSION_TRACK_GTIDS
713
+ pkt.lcb # skip
714
+ pkt.lcb # skip
715
+ session_track[type].push pkt.lcs
716
+ when SESSION_TRACK_TRANSACTION_CHARACTERISTICS, SESSION_TRACK_TRANSACTION_STATE
717
+ pkt.lcb # skip
718
+ session_track[type].push pkt.lcs
719
+ end
720
+ end
721
+ session_track
722
+ end
723
+
724
+ attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message, :session_track
625
725
 
626
726
  def initialize(*args)
627
- @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
727
+ @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message, @session_track = args
728
+ @session_track ||= {}
628
729
  end
629
730
  end
630
731
 
@@ -679,7 +780,7 @@ class Mysql
679
780
 
680
781
  # Authentication packet
681
782
  class AuthenticationPacket
682
- def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin)
783
+ def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename, auth_plugin, connect_attrs)
683
784
  data = [
684
785
  client_flags,
685
786
  max_packet_size,
@@ -695,7 +796,8 @@ class Mysql
695
796
  end
696
797
  data.push auth_plugin
697
798
  pack.concat "Z*"
698
- data.pack(pack)
799
+ attr = connect_attrs.map{|k, v| [Packet.lcs(k.to_s), Packet.lcs(v.to_s)]}.flatten.join
800
+ data.pack(pack) + Packet.lcb(attr.size)+attr
699
801
  end
700
802
  end
701
803
 
@@ -718,7 +820,7 @@ class Mysql
718
820
  netvalues = ""
719
821
  types = values.map do |v|
720
822
  t, n = Protocol.value2net v
721
- netvalues.concat n if v
823
+ netvalues.concat n unless v.nil?
722
824
  t
723
825
  end
724
826
  [Mysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
@@ -728,14 +830,14 @@ class Mysql
728
830
  #
729
831
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
730
832
  def self.null_bitmap(values)
731
- bitmap = values.enum_for(:each_slice,8).map do |vals|
732
- vals.reverse.inject(0){|b, v|(b << 1 | (v ? 0 : 1))}
833
+ bitmap = values.enum_for(:each_slice, 8).map do |vals|
834
+ vals.reverse.inject(0){|b, v| (b << 1 | (v.nil? ? 1 : 0))}
733
835
  end
734
836
  return bitmap.pack("C*")
735
837
  end
736
-
737
838
  end
738
839
 
840
+ # Authentication result packet
739
841
  class AuthenticationResultPacket
740
842
  def self.parse(pkt)
741
843
  result = pkt.utiny
@@ -752,6 +854,7 @@ class Mysql
752
854
  end
753
855
  end
754
856
 
857
+ # raw record
755
858
  class RawRecord
756
859
  def initialize(packet, fields, encoding)
757
860
  @packet, @fields, @encoding = packet, fields, encoding
@@ -759,16 +862,19 @@ class Mysql
759
862
 
760
863
  def to_a
761
864
  @fields.map do |f|
762
- if s = @packet.lcs
763
- unless f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
764
- s = Charset.convert_encoding(s, @encoding)
765
- end
865
+ s = @packet.lcs
866
+ if s.nil?
867
+ nil
868
+ elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
869
+ s.b
870
+ else
871
+ Charset.convert_encoding(s, @encoding)
766
872
  end
767
- s
768
873
  end
769
874
  end
770
875
  end
771
876
 
877
+ # prepared statement raw record
772
878
  class StmtRawRecord
773
879
  # @param pkt [Packet]
774
880
  # @param fields [Array of Fields]
@@ -781,19 +887,21 @@ class Mysql
781
887
  # @return [Array<Object>] one record
782
888
  def parse_record_packet
783
889
  @packet.utiny # skip first byte
784
- null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
890
+ null_bit_map = @packet.read((@fields.length+7+2)/8).unpack1("b*")
785
891
  rec = @fields.each_with_index.map do |f, i|
786
- if null_bit_map[i+2] == ?1
892
+ if null_bit_map[i+2] == '1'
787
893
  nil
788
894
  else
789
895
  unsigned = f.flags & Field::UNSIGNED_FLAG != 0
790
896
  v = Protocol.net2value(@packet, f.type, unsigned)
791
- if v.nil? or v.is_a? Numeric or v.is_a? Time
897
+ if v.nil? or v.is_a? Numeric or v.is_a? Time or v.is_a? Date
792
898
  v
793
- elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
899
+ elsif f.type == Field::TYPE_BIT
794
900
  Charset.to_binary(v)
901
+ elsif v.is_a? String
902
+ f.charsetnr == Charset::BINARY_CHARSET_NUMBER ? Charset.to_binary(v) : Charset.convert_encoding(v, @encoding)
795
903
  else
796
- Charset.convert_encoding(v, @encoding)
904
+ v
797
905
  end
798
906
  end
799
907
  end
@@ -801,6 +909,5 @@ class Mysql
801
909
  end
802
910
 
803
911
  alias to_a parse_record_packet
804
-
805
912
  end
806
913
  end