singlestoredb 0.3.3__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of singlestoredb might be problematic. Click here for more details.

Files changed (121)
  1. singlestoredb/__init__.py +33 -2
  2. singlestoredb/alchemy/__init__.py +90 -0
  3. singlestoredb/auth.py +6 -4
  4. singlestoredb/config.py +116 -16
  5. singlestoredb/connection.py +489 -523
  6. singlestoredb/converters.py +275 -26
  7. singlestoredb/exceptions.py +30 -4
  8. singlestoredb/functions/__init__.py +1 -0
  9. singlestoredb/functions/decorator.py +142 -0
  10. singlestoredb/functions/dtypes.py +1639 -0
  11. singlestoredb/functions/ext/__init__.py +2 -0
  12. singlestoredb/functions/ext/arrow.py +375 -0
  13. singlestoredb/functions/ext/asgi.py +661 -0
  14. singlestoredb/functions/ext/json.py +427 -0
  15. singlestoredb/functions/ext/mmap.py +306 -0
  16. singlestoredb/functions/ext/rowdat_1.py +744 -0
  17. singlestoredb/functions/signature.py +673 -0
  18. singlestoredb/fusion/__init__.py +11 -0
  19. singlestoredb/fusion/graphql.py +213 -0
  20. singlestoredb/fusion/handler.py +621 -0
  21. singlestoredb/fusion/handlers/__init__.py +0 -0
  22. singlestoredb/fusion/handlers/stage.py +257 -0
  23. singlestoredb/fusion/handlers/utils.py +162 -0
  24. singlestoredb/fusion/handlers/workspace.py +412 -0
  25. singlestoredb/fusion/registry.py +164 -0
  26. singlestoredb/fusion/result.py +399 -0
  27. singlestoredb/http/__init__.py +27 -0
  28. singlestoredb/http/connection.py +1192 -0
  29. singlestoredb/management/__init__.py +3 -2
  30. singlestoredb/management/billing_usage.py +148 -0
  31. singlestoredb/management/cluster.py +19 -14
  32. singlestoredb/management/manager.py +100 -40
  33. singlestoredb/management/organization.py +188 -0
  34. singlestoredb/management/region.py +6 -8
  35. singlestoredb/management/utils.py +253 -4
  36. singlestoredb/management/workspace.py +1153 -35
  37. singlestoredb/mysql/__init__.py +177 -0
  38. singlestoredb/mysql/_auth.py +298 -0
  39. singlestoredb/mysql/charset.py +214 -0
  40. singlestoredb/mysql/connection.py +1814 -0
  41. singlestoredb/mysql/constants/CLIENT.py +38 -0
  42. singlestoredb/mysql/constants/COMMAND.py +32 -0
  43. singlestoredb/mysql/constants/CR.py +78 -0
  44. singlestoredb/mysql/constants/ER.py +474 -0
  45. singlestoredb/mysql/constants/FIELD_TYPE.py +32 -0
  46. singlestoredb/mysql/constants/FLAG.py +15 -0
  47. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  48. singlestoredb/mysql/constants/__init__.py +0 -0
  49. singlestoredb/mysql/converters.py +271 -0
  50. singlestoredb/mysql/cursors.py +713 -0
  51. singlestoredb/mysql/err.py +92 -0
  52. singlestoredb/mysql/optionfile.py +20 -0
  53. singlestoredb/mysql/protocol.py +388 -0
  54. singlestoredb/mysql/tests/__init__.py +19 -0
  55. singlestoredb/mysql/tests/base.py +126 -0
  56. singlestoredb/mysql/tests/conftest.py +37 -0
  57. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  58. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  59. singlestoredb/mysql/tests/test_basic.py +452 -0
  60. singlestoredb/mysql/tests/test_connection.py +851 -0
  61. singlestoredb/mysql/tests/test_converters.py +58 -0
  62. singlestoredb/mysql/tests/test_cursor.py +141 -0
  63. singlestoredb/mysql/tests/test_err.py +16 -0
  64. singlestoredb/mysql/tests/test_issues.py +514 -0
  65. singlestoredb/mysql/tests/test_load_local.py +75 -0
  66. singlestoredb/mysql/tests/test_nextset.py +88 -0
  67. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  68. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  69. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  70. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  71. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  72. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  73. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  74. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  75. singlestoredb/mysql/times.py +23 -0
  76. singlestoredb/pytest.py +283 -0
  77. singlestoredb/tests/empty.sql +0 -0
  78. singlestoredb/tests/ext_funcs/__init__.py +385 -0
  79. singlestoredb/tests/test.sql +210 -0
  80. singlestoredb/tests/test2.sql +1 -0
  81. singlestoredb/tests/test_basics.py +482 -117
  82. singlestoredb/tests/test_config.py +13 -15
  83. singlestoredb/tests/test_connection.py +241 -289
  84. singlestoredb/tests/test_dbapi.py +27 -0
  85. singlestoredb/tests/test_exceptions.py +0 -2
  86. singlestoredb/tests/test_ext_func.py +1193 -0
  87. singlestoredb/tests/test_ext_func_data.py +1101 -0
  88. singlestoredb/tests/test_fusion.py +465 -0
  89. singlestoredb/tests/test_http.py +32 -28
  90. singlestoredb/tests/test_management.py +588 -10
  91. singlestoredb/tests/test_plugin.py +33 -0
  92. singlestoredb/tests/test_results.py +11 -14
  93. singlestoredb/tests/test_types.py +0 -2
  94. singlestoredb/tests/test_udf.py +687 -0
  95. singlestoredb/tests/test_xdict.py +0 -2
  96. singlestoredb/tests/utils.py +3 -4
  97. singlestoredb/types.py +4 -5
  98. singlestoredb/utils/config.py +71 -12
  99. singlestoredb/utils/convert_rows.py +0 -2
  100. singlestoredb/utils/debug.py +13 -0
  101. singlestoredb/utils/mogrify.py +151 -0
  102. singlestoredb/utils/results.py +4 -3
  103. singlestoredb/utils/xdict.py +12 -12
  104. singlestoredb-1.0.3.dist-info/METADATA +139 -0
  105. singlestoredb-1.0.3.dist-info/RECORD +112 -0
  106. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/WHEEL +1 -1
  107. singlestoredb-1.0.3.dist-info/entry_points.txt +2 -0
  108. singlestoredb/drivers/__init__.py +0 -46
  109. singlestoredb/drivers/base.py +0 -200
  110. singlestoredb/drivers/cymysql.py +0 -40
  111. singlestoredb/drivers/http.py +0 -49
  112. singlestoredb/drivers/mariadb.py +0 -42
  113. singlestoredb/drivers/mysqlconnector.py +0 -51
  114. singlestoredb/drivers/mysqldb.py +0 -62
  115. singlestoredb/drivers/pymysql.py +0 -39
  116. singlestoredb/drivers/pyodbc.py +0 -67
  117. singlestoredb/http.py +0 -794
  118. singlestoredb-0.3.3.dist-info/METADATA +0 -105
  119. singlestoredb-0.3.3.dist-info/RECORD +0 -46
  120. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/LICENSE +0 -0
  121. {singlestoredb-0.3.3.dist-info → singlestoredb-1.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1814 @@
1
+ # type: ignore
2
+ # Python implementation of the MySQL client-server protocol
3
+ # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
+ # Error codes:
5
+ # https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
6
+ import errno
7
+ import functools
8
+ import os
9
+ import socket
10
+ import struct
11
+ import sys
12
+ import traceback
13
+ import warnings
14
+
15
+ try:
16
+ import _singlestoredb_accel
17
+ except (ImportError, ModuleNotFoundError):
18
+ _singlestoredb_accel = None
19
+
20
+ from . import _auth
21
+
22
+ from .charset import charset_by_name, charset_by_id
23
+ from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
24
+ from . import converters
25
+ from .cursors import (
26
+ Cursor,
27
+ CursorSV,
28
+ DictCursor,
29
+ DictCursorSV,
30
+ NamedtupleCursor,
31
+ NamedtupleCursorSV,
32
+ SSCursor,
33
+ SSCursorSV,
34
+ SSDictCursor,
35
+ SSDictCursorSV,
36
+ SSNamedtupleCursor,
37
+ SSNamedtupleCursorSV,
38
+ )
39
+ from .optionfile import Parser
40
+ from .protocol import (
41
+ dump_packet,
42
+ MysqlPacket,
43
+ FieldDescriptorPacket,
44
+ OKPacketWrapper,
45
+ EOFPacketWrapper,
46
+ LoadLocalPacketWrapper,
47
+ )
48
+ from . import err
49
+ from ..config import get_option
50
+ from .. import fusion
51
+ from .. import connection
52
+ from ..connection import Connection as BaseConnection
53
+ from ..utils.debug import log_query
54
+
55
+ try:
56
+ import ssl
57
+
58
+ SSL_ENABLED = True
59
+ except ImportError:
60
+ ssl = None
61
+ SSL_ENABLED = False
62
+
63
# Default login name: the OS user running this process, or None when it
# cannot be determined.
try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # KeyError: no entry in the OS user database for the current user
    # (Python < 3.13). OSError: the same condition on Python 3.13+, where
    # getuser() normalizes lookup failures to OSError. ImportError: the
    # ``pwd`` fallback module is unavailable on this platform.
    DEFAULT_USER = None
71
+
72
# Value of the ``debug.connection`` configuration option.
DEBUG = get_option('debug.connection')

# Column type codes grouped together as "text" types by this module.
TEXT_TYPES = {
    FIELD_TYPE.BIT, FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB, FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING, FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING, FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
}
85
+
86
# Sentinel string used by this module to mark a value as "not set".
UNSET = 'unset'

# Character set used when the caller does not request one.
DEFAULT_CHARSET = 'utf8mb4'

# Largest payload that fits in one packet (24-bit length field): 16 MiB - 1.
MAX_PACKET_LEN = (1 << 24) - 1
91
+
92
+
93
def _pack_int24(n):
    """
    Pack ``n`` as a 3-byte little-endian unsigned integer.

    ``n`` is packed as a 4-byte unsigned int and the most significant
    byte is dropped.

    """
    four_bytes = struct.pack('<I', n)
    return four_bytes[:-1]
95
+
96
+
97
+ # https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
98
def _lenenc_int(i):
    """
    Encode ``i`` as a MySQL length-encoded integer.

    Values below 0xFB use one byte. Larger values use a 0xFC, 0xFD or
    0xFE marker followed by a 2-, 3- or 8-byte little-endian integer.

    Parameters
    ----------
    i : int
        Non-negative integer smaller than 2**64.

    Returns
    -------
    bytes

    Raises
    ------
    ValueError
        If ``i`` is negative or does not fit in 64 bits.

    """
    if i < 0:
        raise ValueError(
            'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,
        )
    if i < 0xFB:
        return bytes([i])
    if i < (1 << 16):
        return b'\xfc' + struct.pack('<H', i)
    if i < (1 << 24):
        return b'\xfd' + struct.pack('<I', i)[:3]
    if i < (1 << 64):
        return b'\xfe' + struct.pack('<Q', i)
    # Report the largest encodable value (2**64 - 1). The previous bound of
    # 2**64 made the message claim a value was larger than itself.
    raise ValueError(
        'Encoding %x is larger than %x - no representation in LengthEncodedInteger'
        % (i, (1 << 64) - 1),
    )
116
+
117
+
118
+ class Connection(BaseConnection):
119
+ """
120
+ Representation of a socket with a mysql server.
121
+
122
+ The proper way to get an instance of this class is to call
123
+ ``connect()``.
124
+
125
+ Establish a connection to the SingleStoreDB database.
126
+
127
+ Parameters
128
+ ----------
129
+ host : str, optional
130
+ Host where the database server is located.
131
+ user : str, optional
132
+ Username to log in as.
133
+ password : str, optional
134
+ Password to use.
135
+ database : str, optional
136
+ Database to use, None to not use a particular one.
137
+ port : int, optional
138
+ Server port to use, default is usually OK. (default: 3306)
139
+ bind_address : str, optional
140
+ When the client has multiple network interfaces, specify
141
+ the interface from which to connect to the host. Argument can be
142
+ a hostname or an IP address.
143
+ unix_socket : str, optional
144
+ Use a unix socket rather than TCP/IP.
145
+ read_timeout : int, optional
146
+ The timeout for reading from the connection in seconds
147
+ (default: None - no timeout)
148
+ write_timeout : int, optional
149
+ The timeout for writing to the connection in seconds
150
+ (default: None - no timeout)
151
+ charset : str, optional
152
+ Charset to use.
153
+ collation : str, optional
154
+ The charset collation
155
+ sql_mode : str, optional
156
+ Default SQL_MODE to use.
157
+ read_default_file : str, optional
158
+ Specifies my.cnf file to read these parameters from under the
159
+ [client] section.
160
+ conv : Dict[str, Callable[Any]], optional
161
+ Conversion dictionary to use instead of the default one.
162
+ This is used to provide custom marshalling and unmarshalling of types.
163
+ See converters.
164
+ use_unicode : bool, optional
165
+ Whether or not to default to unicode strings.
166
+ This option defaults to true.
167
+ client_flag : int, optional
168
+ Custom flags to send to MySQL. Find potential values in constants.CLIENT.
169
+ cursorclass : type, optional
170
+ Custom cursor class to use.
171
+ init_command : str, optional
172
+ Initial SQL statement to run when connection is established.
173
+ connect_timeout : int, optional
174
+ The timeout for connecting to the database in seconds.
175
+ (default: 10, min: 1, max: 31536000)
176
+ ssl : Dict[str, str], optional
177
+ A dict of arguments similar to mysql_ssl_set()'s parameters or
178
+ an ssl.SSLContext.
179
+ ssl_ca : str, optional
180
+ Path to the file that contains a PEM-formatted CA certificate.
181
+ ssl_cert : str, optional
182
+ Path to the file that contains a PEM-formatted client certificate.
183
+ ssl_cipher : str, optional
184
+ SSL ciphers to allow.
185
+ ssl_disabled : bool, optional
186
+ A boolean value that disables usage of TLS.
187
+ ssl_key : str, optional
188
+ Path to the file that contains a PEM-formatted private key for the
189
+ client certificate.
190
+ ssl_verify_cert : str, optional
191
+ Set to true to check the server certificate's validity.
192
+ ssl_verify_identity : bool, optional
193
+ Set to true to check the server's identity.
194
+ read_default_group : str, optional
195
+ Group to read from in the configuration file.
196
+ autocommit : bool, optional
197
+ Autocommit mode. None means use server default. (default: False)
198
+ local_infile : bool, optional
199
+ Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
200
+ max_allowed_packet : int, optional
201
+ Max size of packet sent to server in bytes. (default: 16MB)
202
+ Only used to limit size of "LOAD LOCAL INFILE" data packet smaller
203
+ than default (16KB).
204
+ defer_connect : bool, optional
205
+ Don't explicitly connect on construction - wait for connect call.
206
+ (default: False)
207
+ auth_plugin_map : Dict[str, type], optional
208
+ A dict of plugin names to a class that processes that plugin.
209
+ The class will take the Connection object as the argument to the
210
+ constructor. The class needs an authenticate method taking an
211
+ authentication packet as an argument. For the dialog plugin, a
212
+ prompt(echo, prompt) method can be used (if no authenticate method)
213
+ for returning a string from the user. (experimental)
214
+ server_public_key : str, optional
215
+ SHA256 authentication plugin public key value. (default: None)
216
+ binary_prefix : bool, optional
217
+ Add _binary prefix on bytes and bytearray. (default: False)
218
+ compress :
219
+ Not supported.
220
+ named_pipe :
221
+ Not supported.
222
+ db : str, optional
223
+ **DEPRECATED** Alias for database.
224
+ passwd : str, optional
225
+ **DEPRECATED** Alias for password.
226
+ parse_json : bool, optional
227
+ Parse JSON values into Python objects?
228
+ invalid_values : Dict[int, Any], optional
229
+ Dictionary of values to use in place of invalid values
230
+ found during conversion of data. The default is to return the byte content
231
+ containing the invalid value. The keys are the integers associated with
232
+ the column type.
233
+ pure_python : bool, optional
234
+ Should we ignore the C extension even if it's available?
235
+ This can be given explicitly using True or False, or if the value is None,
236
+ the C extension will be loaded if it is available. If set to False and
237
+ the C extension can't be loaded, a NotSupportedError is raised.
238
+ nan_as_null : bool, optional
239
+ Should NaN values be treated as NULLs in parameter substitution including
240
+ uploading data?
241
+ inf_as_null : bool, optional
242
+ Should Inf values be treated as NULLs in parameter substitution including
243
+ uploading data?
244
+ track_env : bool, optional
245
+ Should the connection track the SINGLESTOREDB_URL environment variable?
246
+
247
+ See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
248
+ in the specification.
249
+
250
+ """
251
+
252
+ driver = 'mysql'
253
+ paramstyle = 'pyformat'
254
+
255
+ _sock = None
256
+ _auth_plugin_name = ''
257
+ _closed = False
258
+ _secure = False
259
+
260
def __init__(  # noqa: C901
    self,
    *,
    user=None,  # The first four arguments is based on DB-API 2.0 recommendation.
    password='',
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset='',
    collation=None,
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=None,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_cipher=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    parse_json=True,
    invalid_values=None,
    pure_python=None,
    buffered=True,
    results_type='tuples',
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
    driver=None,  # internal use
    conn_attrs=None,
    multi_statements=None,
    nan_as_null=None,
    inf_as_null=None,
    encoding_errors='strict',
    track_env=False,
):
    """
    Initialize the connection state and, unless deferred, connect.

    Parameters are described in the class docstring. The raw argument
    values are also forwarded to ``BaseConnection.__init__``.

    Raises
    ------
    NotImplementedError
        If ``compress`` or ``named_pipe`` is requested, or if SSL is
        requested but the ``ssl`` module could not be imported.
    ValueError
        If ``port``, ``connect_timeout``, ``read_timeout`` or
        ``write_timeout`` is invalid.
    err.NotSupportedError
        If ``pure_python=False`` but the C extension is not available.

    """
    # Must stay the first statement: ``locals()`` here holds exactly the
    # constructor arguments (plus ``self``).
    BaseConnection.__init__(**dict(locals()))

    if db is not None and database is None:
        # We will raise warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
        database = db
    if passwd is not None and not password:
        # We will raise warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn(
        #     "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
        # )
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            'compress and named_pipe arguments are not supported',
        )

    self._local_infile = bool(local_infile)
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES
    if multi_statements:
        client_flag |= CLIENT.MULTI_STATEMENTS

    if read_default_group and not read_default_file:
        if sys.platform.startswith('win'):
            read_default_file = 'c:\\my.ini'
        else:
            read_default_file = '/etc/my.cnf'

    if read_default_file:
        if not read_default_group:
            read_default_group = 'client'

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # Explicit arguments win; otherwise fall back to the option file.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config('user', user)
        password = _config('password', password)
        host = _config('host', host)
        database = _config('database', database)
        unix_socket = _config('socket', unix_socket)
        port = int(_config('port', port))
        bind_address = _config('bind-address', bind_address)
        charset = _config('default-character-set', charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            for key in ['ca', 'capath', 'cert', 'key', 'cipher']:
                value = _config('ssl-' + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Individual ssl_* arguments replace any ``ssl`` dict gathered above.
        if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \
                ssl_verify_cert or ssl_verify_identity:
            ssl = {
                'ca': ssl_ca,
                'check_hostname': bool(ssl_verify_identity),
                'verify_mode': ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl['cert'] = ssl_cert
            if ssl_key is not None:
                ssl['key'] = ssl_key
            if ssl_cipher is not None:
                ssl['cipher'] = ssl_cipher
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError('ssl module not found')
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or 'localhost'
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError('port should be of type int')
    self.user = user or DEFAULT_USER
    self.password = password or b''
    if isinstance(self.password, str):
        self.password = self.password.encode('latin1')
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError('connect_timeout should be >0 and <=31536000')
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError('read_timeout should be > 0')
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError('write_timeout should be > 0')
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.collation = collation
    self.use_unicode = use_unicode
    self.encoding_errors = encoding_errors

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.pure_python = pure_python
    self.results_type = results_type
    self.resultclass = MySQLResult
    if cursorclass is not None:
        self.cursorclass = cursorclass
    elif buffered:
        if 'dict' in self.results_type:
            self.cursorclass = DictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = NamedtupleCursor
        else:
            self.cursorclass = Cursor
    else:
        if 'dict' in self.results_type:
            self.cursorclass = SSDictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = SSNamedtupleCursor
        else:
            self.cursorclass = SSCursor

    if self.pure_python is False and _singlestoredb_accel is None:
        # The extension failed to import at module load. Retry only so the
        # real import error gets reported, then refuse to continue.
        # ``importlib`` is used instead of an ``import`` statement so no
        # function-local ``_singlestoredb_accel`` name is created. Such a
        # local would shadow the module global tested below and raise
        # UnboundLocalError.
        import importlib
        try:
            importlib.import_module('_singlestoredb_accel')
        except Exception:
            traceback.print_exc(file=sys.stderr)
        raise err.NotSupportedError(
            'pure_python=False, but the '
            'C extension can not be loaded',
        )

    # The C extension handles these types internally.
    if self.pure_python is not True and _singlestoredb_accel is not None:
        self.resultclass = MySQLResultSV
        if self.cursorclass is Cursor:
            self.cursorclass = CursorSV
        elif self.cursorclass is SSCursor:
            self.cursorclass = SSCursorSV
        elif self.cursorclass is DictCursor:
            self.cursorclass = DictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is SSDictCursor:
            self.cursorclass = SSDictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is NamedtupleCursor:
            self.cursorclass = NamedtupleCursorSV
            self.results_type = 'namedtuples'
        elif self.cursorclass is SSNamedtupleCursor:
            self.cursorclass = SSNamedtupleCursorSV
            self.results_type = 'namedtuples'

    self._result = None
    self._affected_rows = 0
    self.host_info = 'Not connected'

    # specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    self.parse_json = parse_json
    self.invalid_values = (invalid_values or {}).copy()

    # Need for MySQLdb compatibility.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    if self.connection_params['nan_as_null'] or \
            self.connection_params['inf_as_null']:
        float_encoder = self.encoders.get(float)
        if float_encoder is not None:
            self.encoders[float] = functools.partial(
                float_encoder,
                nan_as_null=self.connection_params['nan_as_null'],
                inf_as_null=self.connection_params['inf_as_null'],
            )

    from .. import __version__ as VERSION_STRING

    self._connect_attrs = {
        '_os': str(sys.platform),
        '_pid': str(os.getpid()),
        '_client_name': 'SingleStoreDB Python Client',
        '_client_version': VERSION_STRING,
    }

    if program_name:
        self._connect_attrs['program_name'] = program_name
    if conn_attrs is not None:
        # do not overwrite the attributes that we set ourselves
        for k, v in conn_attrs.items():
            if k not in self._connect_attrs:
                self._connect_attrs[k] = v

    self._in_sync = False
    self._track_env = bool(track_env) or self.host == 'singlestore.com'

    if defer_connect or self._track_env:
        self._sock = None
    else:
        self.connect()
548
+
549
@property
def messages(self):
    """
    Return messages collected for this connection.

    Message collection is not implemented yet, so this is always an
    empty list. Previously the list expression was evaluated and
    discarded, so the property returned None.

    """
    # TODO: collect server messages.
    return []
553
+
554
def __enter__(self):
    """Enter a ``with`` block; the connection itself is the bound value."""
    conn = self
    return conn
556
+
557
def __exit__(self, *exc_info):
    """
    Leave a ``with`` block by closing the connection.

    Returns None, so an exception raised inside the block propagates.

    """
    self.close()
560
+
561
def _raise_mysql_exception(self, data):
    """Delegate ``data`` (a server error payload) to ``err.raise_mysql_exception``."""
    raiser = err.raise_mysql_exception
    raiser(data)
563
+
564
def _create_ssl_ctx(self, sslp):
    """
    Build an ``ssl.SSLContext`` from connection SSL parameters.

    Parameters
    ----------
    sslp : ssl.SSLContext or dict
        Either a ready-made context, which is returned unchanged, or a
        dict that may hold ``ca``, ``capath``, ``cert``, ``key``,
        ``cipher``, ``check_hostname`` and ``verify_mode``.

    Returns
    -------
    ssl.SSLContext

    """
    if isinstance(sslp, ssl.SSLContext):
        return sslp

    cafile = sslp.get('ca')
    cadir = sslp.get('capath')
    no_ca_given = cafile is None and cadir is None

    context = ssl.create_default_context(cafile=cafile, capath=cadir)
    # Configure hostname checking before verify_mode: the ssl module refuses
    # CERT_NONE while check_hostname is still enabled.
    context.check_hostname = not no_ca_given and sslp.get('check_hostname', True)

    fallback_mode = ssl.CERT_NONE if no_ca_given else ssl.CERT_REQUIRED
    requested = sslp.get('verify_mode')
    if requested is None:
        context.verify_mode = fallback_mode
    elif isinstance(requested, bool):
        context.verify_mode = ssl.CERT_REQUIRED if requested else ssl.CERT_NONE
    else:
        if isinstance(requested, str):
            requested = requested.lower()
        if requested in ('none', '0', 'false', 'no'):
            context.verify_mode = ssl.CERT_NONE
        elif requested == 'optional':
            context.verify_mode = ssl.CERT_OPTIONAL
        elif requested in ('required', '1', 'true', 'yes'):
            context.verify_mode = ssl.CERT_REQUIRED
        else:
            context.verify_mode = fallback_mode

    if 'cert' in sslp:
        context.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
    if 'cipher' in sslp:
        context.set_ciphers(sslp['cipher'])
    context.options |= ssl.OP_NO_SSLv2
    context.options |= ssl.OP_NO_SSLv3
    return context
595
+
596
def close(self):
    """
    Send the quit message and close the socket.

    See `Connection.close()
    <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    Does nothing when ``self.host`` is ``'singlestore.com'``.

    Raises
    ------
    Error : If the connection is already closed.

    """
    if self.host == 'singlestore.com':
        return
    if self._closed:
        raise err.Error('Already closed')
    self._closed = True
    if self._sock is not None:
        quit_packet = struct.pack('<iB', 1, COMMAND.COM_QUIT)
        try:
            self._write_bytes(quit_packet)
        except Exception:
            # Best effort only; the socket is force-closed either way.
            pass
        finally:
            self._force_close()
623
+
624
@property
def open(self):
    """True while a socket is attached to this connection."""
    sock = self._sock
    return sock is not None
628
+
629
def is_connected(self):
    """Report whether the connection is open (same value as ``self.open``)."""
    connected = self.open
    return connected
632
+
633
def _force_close(self):
    """
    Close the connection without sending a QUIT message.

    Errors raised while closing the socket are ignored (best effort). The
    socket and read-file references are always cleared afterwards.

    """
    sock = self._sock
    if sock:
        try:
            sock.close()
        except Exception:
            # Best effort: the socket is discarded regardless. ``Exception``
            # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            pass
    self._sock = None
    self._rfile = None

__del__ = _force_close
644
+
645
def autocommit(self, value):
    """
    Enable or disable autocommit on the server.

    Parameters
    ----------
    value : Any
        Truthy to enable autocommit, falsy to disable it.

    """
    self.autocommit_mode = bool(value)
    current = self.get_autocommit()
    # Compare the normalized flag, so truthy/falsy non-bool values
    # (e.g. 'on', '') don't trigger a redundant SET AUTOCOMMIT round trip.
    if self.autocommit_mode != current:
        self._send_autocommit_mode()
651
+
652
def get_autocommit(self):
    """Return True when the last server status has the autocommit bit set."""
    flag = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return bool(self.server_status & flag)
655
+
656
def _read_ok_packet(self):
    """
    Read the next packet and require it to be an OK packet.

    Updates ``self.server_status`` from the packet and returns the wrapped
    OK packet.

    Raises
    ------
    err.OperationalError
        If the packet is not an OK packet (commands out of sync).

    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        ok_packet = OKPacketWrapper(packet)
        self.server_status = ok_packet.server_status
        return ok_packet
    raise err.OperationalError(
        CR.CR_COMMANDS_OUT_OF_SYNC,
        'Command Out of Sync',
    )
666
+
667
def _send_autocommit_mode(self):
    """
    Send ``SET AUTOCOMMIT`` using ``self.autocommit_mode``.

    The statement is built once, then logged and executed, and the
    server's OK packet is consumed.

    """
    sql = 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode)
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    self._read_ok_packet()
674
+
675
def begin(self):
    """
    Begin a transaction.

    The statement is always logged. It is not sent when ``self.host`` is
    ``'singlestore.com'``.

    """
    statement = 'BEGIN'
    log_query(statement)
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, statement)
        self._read_ok_packet()
682
+
683
def commit(self):
    """
    Commit changes to stable storage.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.

    The statement is always logged. It is not sent when ``self.host`` is
    ``'singlestore.com'``.

    """
    statement = 'COMMIT'
    log_query(statement)
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, statement)
        self._read_ok_packet()
696
+
697
def rollback(self):
    """
    Roll back the current transaction.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.

    The statement is always logged. It is not sent when ``self.host`` is
    ``'singlestore.com'``.

    """
    statement = 'ROLLBACK'
    log_query(statement)
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, statement)
        self._read_ok_packet()
710
+
711
def show_warnings(self):
    """Run ``SHOW WARNINGS`` and return the rows of its result."""
    statement = 'SHOW WARNINGS'
    log_query(statement)
    self._execute_command(COMMAND.COM_QUERY, statement)
    warning_result = self.resultclass(self)
    warning_result.read()
    return warning_result.rows
718
+
719
def select_db(self, db):
    """
    Set the current database.

    Parameters
    ----------
    db : str
        Name of the database to switch to.

    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
729
+
730
def escape(self, obj, mapping=None):
    """
    Escape whatever value is passed.

    Strings are quoted after ``escape_string``; ``bytes``/``bytearray`` go
    through ``_quote_bytes``. Anything else is handed to
    ``converters.escape_item`` using ``mapping`` (default:
    ``self.encoders``).

    Non-standard, for internal use; do not use this in your applications.

    """
    if isinstance(obj, str):
        return "'{}'".format(self.escape_string(obj))
    if isinstance(obj, (bytes, bytearray)):
        return self._quote_bytes(obj)
    encoders = self.encoders if mapping is None else mapping
    return converters.escape_item(obj, self.charset, mapping=encoders)
745
+
746
def literal(self, obj):
    """
    Escape ``obj`` with this connection's encoders (alias for ``escape()``).

    Non-standard, for internal use; do not use this in your applications.

    """
    encoders = self.encoders
    return self.escape(obj, encoders)
754
+
755
def escape_string(self, s):
    """
    Escape a string value for inclusion in SQL.

    When the server reports NO_BACKSLASH_ESCAPES, only single quotes are
    doubled. Otherwise ``converters.escape_string`` is used.

    """
    no_backslash = self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    if not no_backslash:
        return converters.escape_string(s)
    return s.replace("'", "''")
760
+
761
def _quote_bytes(self, s):
    """
    Quote a bytes-like value as an SQL literal.

    Under NO_BACKSLASH_ESCAPES a hex literal is produced, optionally with a
    ``_binary`` prefix. Otherwise ``converters.escape_bytes`` is used.

    """
    if not self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
        return converters.escape_bytes(s)
    if self._binary_prefix:
        return "_binary X'{}'".format(s.hex())
    return "X'{}'".format(s.hex())
767
+
768
def cursor(self):
    """Return a new cursor instance bound to this connection."""
    factory = self.cursorclass
    return factory(self)
770
+ return self.cursorclass(self)
771
+
772
+ # The following methods are INTERNAL USE ONLY (called from Cursor)
773
def query(self, sql, unbuffered=False):
    """
    Run a query on the server.

    If ``fusion.get_handler`` recognizes ``sql``, the statement is
    executed by that handler. Otherwise it is encoded with the connection
    encoding and sent as COM_QUERY. Returns the affected-row count.

    Internal use only.

    """
    fusion_handler = fusion.get_handler(sql)
    if fusion_handler is None:
        if isinstance(sql, str):
            sql = sql.encode(self.encoding, 'surrogateescape')
        self._execute_command(COMMAND.COM_QUERY, sql)
        self._affected_rows = self._read_query_result(unbuffered=unbuffered)
    else:
        self._result = fusion.execute(self, sql, handler=fusion_handler)
        self._affected_rows = self._result.affected_rows
    return self._affected_rows
792
+
793
+ def next_result(self, unbuffered=False):
794
+ """
795
+ Retrieve the next result set.
796
+
797
+ Internal use only.
798
+
799
+ """
800
+ self._affected_rows = self._read_query_result(unbuffered=unbuffered)
801
+ return self._affected_rows
802
+
803
+ def affected_rows(self):
804
+ """
805
+ Return number of affected rows.
806
+
807
+ Internal use only.
808
+
809
+ """
810
+ return self._affected_rows
811
+
812
+ def kill(self, thread_id):
813
+ """
814
+ Execute kill command.
815
+
816
+ Internal use only.
817
+
818
+ """
819
+ arg = struct.pack('<I', thread_id)
820
+ self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
821
+ return self._read_ok_packet()
822
+
823
+ def ping(self, reconnect=True):
824
+ """
825
+ Check if the server is alive.
826
+
827
+ Parameters
828
+ ----------
829
+ reconnect : bool, optional
830
+ If the connection is closed, reconnect.
831
+
832
+ Raises
833
+ ------
834
+ Error : If the connection is closed and reconnect=False.
835
+
836
+ """
837
+ if self._sock is None:
838
+ if reconnect:
839
+ self.connect()
840
+ reconnect = False
841
+ else:
842
+ raise err.Error('Already closed')
843
+ try:
844
+ self._execute_command(COMMAND.COM_PING, '')
845
+ self._read_ok_packet()
846
+ except Exception:
847
+ if reconnect:
848
+ self.connect()
849
+ self.ping(False)
850
+ else:
851
+ raise
852
+
853
+ def set_charset(self, charset):
854
+ """Deprecated. Use set_character_set() instead."""
855
+ # This function has been implemented in old PyMySQL.
856
+ # But this name is different from MySQLdb.
857
+ # So we keep this function for compatibility and add
858
+ # new set_character_set() function.
859
+ self.set_character_set(charset)
860
+
861
+ def set_character_set(self, charset, collation=None):
862
+ """
863
+ Set charaset (and collation) on the server.
864
+
865
+ Send "SET NAMES charset [COLLATE collation]" query.
866
+ Update Connection.encoding based on charset.
867
+
868
+ Parameters
869
+ ----------
870
+ charset : str
871
+ The charset to enable.
872
+ collation : str, optional
873
+ The collation value
874
+
875
+ """
876
+ # Make sure charset is supported.
877
+ encoding = charset_by_name(charset).encoding
878
+
879
+ if collation:
880
+ query = f'SET NAMES {charset} COLLATE {collation}'
881
+ else:
882
+ query = f'SET NAMES {charset}'
883
+ self._execute_command(COMMAND.COM_QUERY, query)
884
+ self._read_packet()
885
+ self.charset = charset
886
+ self.encoding = encoding
887
+ self.collation = collation
888
+
889
+ def _sync_connection(self):
890
+ """Synchronize connection with env variable."""
891
+ if self._in_sync:
892
+ return
893
+
894
+ if not self._track_env:
895
+ return
896
+
897
+ url = os.environ.get('SINGLESTOREDB_URL')
898
+ if not url:
899
+ return
900
+
901
+ out = {}
902
+ urlp = connection._parse_url(url)
903
+ out.update(urlp)
904
+
905
+ out = connection._cast_params(out)
906
+
907
+ # Set default port based on driver.
908
+ if 'port' not in out or not out['port']:
909
+ out['port'] = int(get_option('port') or 3306)
910
+
911
+ # If there is no user and the password is empty, remove the password key.
912
+ if 'user' not in out and not out.get('password', None):
913
+ out.pop('password', None)
914
+
915
+ if out['host'] == 'singlestore.com':
916
+ raise err.InterfaceError(0, 'Connection URL has not been established')
917
+
918
+ # If it's just a password change, we don't need to reconnect
919
+ if self._sock is not None and \
920
+ (self.host, self.port, self.user, self.db) == \
921
+ (out['host'], out['port'], out['user'], out.get('database')):
922
+ return
923
+
924
+ self.host = out['host']
925
+ self.port = out['port']
926
+ self.user = out['user']
927
+ if isinstance(out['password'], str):
928
+ self.password = out['password'].encode('latin-1')
929
+ else:
930
+ self.password = out['password'] or b''
931
+ self.db = out.get('database')
932
+ try:
933
+ self._in_sync = True
934
+ self.connect()
935
+ finally:
936
+ self._in_sync = False
937
+
938
+ def connect(self, sock=None):
939
+ """
940
+ Connect to server using existing parameters.
941
+
942
+ Internal use only.
943
+
944
+ """
945
+ self._closed = False
946
+ try:
947
+ if sock is None:
948
+ if self.unix_socket:
949
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
950
+ sock.settimeout(self.connect_timeout)
951
+ sock.connect(self.unix_socket)
952
+ self.host_info = 'Localhost via UNIX socket'
953
+ self._secure = True
954
+ if DEBUG:
955
+ print('connected using unix_socket')
956
+ else:
957
+ kwargs = {}
958
+ if self.bind_address is not None:
959
+ kwargs['source_address'] = (self.bind_address, 0)
960
+ while True:
961
+ try:
962
+ sock = socket.create_connection(
963
+ (self.host, self.port), self.connect_timeout, **kwargs,
964
+ )
965
+ break
966
+ except OSError as e:
967
+ if e.errno == errno.EINTR:
968
+ continue
969
+ raise
970
+ self.host_info = 'socket %s:%d' % (self.host, self.port)
971
+ if DEBUG:
972
+ print('connected using socket')
973
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
974
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
975
+ sock.settimeout(None)
976
+
977
+ self._sock = sock
978
+ self._rfile = sock.makefile('rb')
979
+ self._next_seq_id = 0
980
+
981
+ self._get_server_information()
982
+ self._request_authentication()
983
+
984
+ # Send "SET NAMES" query on init for:
985
+ # - Ensure charaset (and collation) is set to the server.
986
+ # - collation_id in handshake packet may be ignored.
987
+ # - If collation is not specified, we don't know what is server's
988
+ # default collation for the charset. For example, default collation
989
+ # of utf8mb4 is:
990
+ # - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
991
+ # - MySQL 8.0: utf8mb4_0900_ai_ci
992
+ #
993
+ # Reference:
994
+ # - https://github.com/PyMySQL/PyMySQL/issues/1092
995
+ # - https://github.com/wagtail/wagtail/issues/9477
996
+ # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
997
+ self.set_character_set(self.charset, self.collation)
998
+
999
+ if self.sql_mode is not None:
1000
+ c = self.cursor()
1001
+ c.execute('SET sql_mode=%s', (self.sql_mode,))
1002
+ c.close()
1003
+
1004
+ if self.init_command is not None:
1005
+ c = self.cursor()
1006
+ c.execute(self.init_command)
1007
+ c.close()
1008
+
1009
+ if self.autocommit_mode is not None:
1010
+ self.autocommit(self.autocommit_mode)
1011
+
1012
+ except BaseException as e:
1013
+ self._rfile = None
1014
+ if sock is not None:
1015
+ try:
1016
+ sock.close()
1017
+ except: # noqa
1018
+ pass
1019
+
1020
+ if isinstance(e, (OSError, IOError, socket.error)):
1021
+ exc = err.OperationalError(
1022
+ CR.CR_CONN_HOST_ERROR,
1023
+ f'Can\'t connect to MySQL server on {self.host!r} ({e})',
1024
+ )
1025
+ # Keep original exception and traceback to investigate error.
1026
+ exc.original_exception = e
1027
+ exc.traceback = traceback.format_exc()
1028
+ if DEBUG:
1029
+ print(exc.traceback)
1030
+ raise exc
1031
+
1032
+ # If e is neither DatabaseError or IOError, It's a bug.
1033
+ # But raising AssertionError hides original error.
1034
+ # So just reraise it.
1035
+ raise
1036
+
1037
+ def write_packet(self, payload):
1038
+ """
1039
+ Writes an entire "mysql packet" in its entirety to the network.
1040
+
1041
+ Adds its length and sequence number.
1042
+
1043
+ """
1044
+ # Internal note: when you build packet manually and calls _write_bytes()
1045
+ # directly, you should set self._next_seq_id properly.
1046
+ data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload
1047
+ if DEBUG:
1048
+ dump_packet(data)
1049
+ self._write_bytes(data)
1050
+ self._next_seq_id = (self._next_seq_id + 1) % 256
1051
+
1052
+ def _read_packet(self, packet_type=MysqlPacket):
1053
+ """
1054
+ Read an entire "mysql packet" in its entirety from the network.
1055
+
1056
+ Raises
1057
+ ------
1058
+ OperationalError : If the connection to the MySQL server is lost.
1059
+ InternalError : If the packet sequence number is wrong.
1060
+
1061
+ Returns
1062
+ -------
1063
+ MysqlPacket
1064
+
1065
+ """
1066
+ buff = bytearray()
1067
+ while True:
1068
+ packet_header = self._read_bytes(4)
1069
+ # if DEBUG: dump_packet(packet_header)
1070
+
1071
+ btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
1072
+ bytes_to_read = btrl + (btrh << 16)
1073
+ if packet_number != self._next_seq_id:
1074
+ self._force_close()
1075
+ if packet_number == 0:
1076
+ # MariaDB sends error packet with seqno==0 when shutdown
1077
+ raise err.OperationalError(
1078
+ CR.CR_SERVER_LOST,
1079
+ 'Lost connection to MySQL server during query',
1080
+ )
1081
+ raise err.InternalError(
1082
+ 'Packet sequence number wrong - got %d expected %d'
1083
+ % (packet_number, self._next_seq_id),
1084
+ )
1085
+ self._next_seq_id = (self._next_seq_id + 1) % 256
1086
+
1087
+ recv_data = self._read_bytes(bytes_to_read)
1088
+ if DEBUG:
1089
+ dump_packet(recv_data)
1090
+ buff += recv_data
1091
+ # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
1092
+ if bytes_to_read == 0xFFFFFF:
1093
+ continue
1094
+ if bytes_to_read < MAX_PACKET_LEN:
1095
+ break
1096
+
1097
+ packet = packet_type(bytes(buff), self.encoding)
1098
+ if packet.is_error_packet():
1099
+ if self._result is not None and self._result.unbuffered_active is True:
1100
+ self._result.unbuffered_active = False
1101
+ packet.raise_for_error()
1102
+ return packet
1103
+
1104
+ def _read_bytes(self, num_bytes):
1105
+ if self._read_timeout is not None:
1106
+ self._sock.settimeout(self._read_timeout)
1107
+ while True:
1108
+ try:
1109
+ data = self._rfile.read(num_bytes)
1110
+ break
1111
+ except OSError as e:
1112
+ if e.errno == errno.EINTR:
1113
+ continue
1114
+ self._force_close()
1115
+ raise err.OperationalError(
1116
+ CR.CR_SERVER_LOST,
1117
+ 'Lost connection to MySQL server during query (%s)' % (e,),
1118
+ )
1119
+ except BaseException:
1120
+ # Don't convert unknown exception to MySQLError.
1121
+ self._force_close()
1122
+ raise
1123
+ if len(data) < num_bytes:
1124
+ self._force_close()
1125
+ raise err.OperationalError(
1126
+ CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',
1127
+ )
1128
+ return data
1129
+
1130
+ def _write_bytes(self, data):
1131
+ if self._write_timeout is not None:
1132
+ self._sock.settimeout(self._write_timeout)
1133
+ try:
1134
+ self._sock.sendall(data)
1135
+ except OSError as e:
1136
+ self._force_close()
1137
+ raise err.OperationalError(
1138
+ CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',
1139
+ )
1140
+
1141
+ def _read_query_result(self, unbuffered=False):
1142
+ self._result = None
1143
+ if unbuffered:
1144
+ result = self.resultclass(self, unbuffered=unbuffered)
1145
+ else:
1146
+ result = self.resultclass(self)
1147
+ result.read()
1148
+ self._result = result
1149
+ if result.server_status is not None:
1150
+ self.server_status = result.server_status
1151
+ return result.affected_rows
1152
+
1153
+ def insert_id(self):
1154
+ if self._result:
1155
+ return self._result.insert_id
1156
+ else:
1157
+ return 0
1158
+
1159
+ def _execute_command(self, command, sql):
1160
+ """
1161
+ Execute command.
1162
+
1163
+ Raises
1164
+ ------
1165
+ InterfaceError : If the connection is closed.
1166
+ ValueError : If no username was specified.
1167
+
1168
+ """
1169
+ self._sync_connection()
1170
+
1171
+ if self._sock is None:
1172
+ raise err.InterfaceError(0, 'The connection has been closed')
1173
+
1174
+ # If the last query was unbuffered, make sure it finishes before
1175
+ # sending new commands
1176
+ if self._result is not None:
1177
+ if self._result.unbuffered_active:
1178
+ warnings.warn('Previous unbuffered result was left incomplete')
1179
+ self._result._finish_unbuffered_query()
1180
+ while self._result.has_next:
1181
+ self.next_result()
1182
+ self._result = None
1183
+
1184
+ if isinstance(sql, str):
1185
+ sql = sql.encode(self.encoding)
1186
+
1187
+ packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command
1188
+
1189
+ # tiny optimization: build first packet manually instead of
1190
+ # calling self..write_packet()
1191
+ prelude = struct.pack('<iB', packet_size, command)
1192
+ packet = prelude + sql[: packet_size - 1]
1193
+ self._write_bytes(packet)
1194
+ if DEBUG:
1195
+ dump_packet(packet)
1196
+ self._next_seq_id = 1
1197
+
1198
+ if packet_size < MAX_PACKET_LEN:
1199
+ return
1200
+
1201
+ sql = sql[packet_size - 1:]
1202
+ while True:
1203
+ packet_size = min(MAX_PACKET_LEN, len(sql))
1204
+ self.write_packet(sql[:packet_size])
1205
+ sql = sql[packet_size:]
1206
+ if not sql and packet_size < MAX_PACKET_LEN:
1207
+ break
1208
+
1209
+ def _request_authentication(self): # noqa: C901
1210
+ # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
1211
+ if int(self.server_version.split('.', 1)[0]) >= 5:
1212
+ self.client_flag |= CLIENT.MULTI_RESULTS
1213
+
1214
+ if self.user is None:
1215
+ raise ValueError('Did not specify a username')
1216
+
1217
+ charset_id = charset_by_name(self.charset).id
1218
+ if isinstance(self.user, str):
1219
+ self.user = self.user.encode(self.encoding)
1220
+
1221
+ data_init = struct.pack(
1222
+ '<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',
1223
+ )
1224
+
1225
+ if self.ssl and self.server_capabilities & CLIENT.SSL:
1226
+ self.write_packet(data_init)
1227
+
1228
+ self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
1229
+ self._rfile = self._sock.makefile('rb')
1230
+ self._secure = True
1231
+
1232
+ data = data_init + self.user + b'\0'
1233
+
1234
+ authresp = b''
1235
+ plugin_name = None
1236
+
1237
+ if self._auth_plugin_name == '':
1238
+ plugin_name = b''
1239
+ authresp = _auth.scramble_native_password(self.password, self.salt)
1240
+ elif self._auth_plugin_name == 'mysql_native_password':
1241
+ plugin_name = b'mysql_native_password'
1242
+ authresp = _auth.scramble_native_password(self.password, self.salt)
1243
+ elif self._auth_plugin_name == 'caching_sha2_password':
1244
+ plugin_name = b'caching_sha2_password'
1245
+ if self.password:
1246
+ if DEBUG:
1247
+ print('caching_sha2: trying fast path')
1248
+ authresp = _auth.scramble_caching_sha2(self.password, self.salt)
1249
+ else:
1250
+ if DEBUG:
1251
+ print('caching_sha2: empty password')
1252
+ elif self._auth_plugin_name == 'sha256_password':
1253
+ plugin_name = b'sha256_password'
1254
+ if self.ssl and self.server_capabilities & CLIENT.SSL:
1255
+ authresp = self.password + b'\0'
1256
+ elif self.password:
1257
+ authresp = b'\1' # request public key
1258
+ else:
1259
+ authresp = b'\0' # empty password
1260
+
1261
+ if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
1262
+ data += _lenenc_int(len(authresp)) + authresp
1263
+ elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
1264
+ data += struct.pack('B', len(authresp)) + authresp
1265
+ else: # pragma: no cover - no testing against servers w/o secure auth (>=5.0)
1266
+ data += authresp + b'\0'
1267
+
1268
+ if self.server_capabilities & CLIENT.CONNECT_WITH_DB:
1269
+ db = self.db
1270
+ if isinstance(db, str):
1271
+ db = db.encode(self.encoding)
1272
+ data += (db or b'') + b'\0'
1273
+
1274
+ if self.server_capabilities & CLIENT.PLUGIN_AUTH:
1275
+ data += (plugin_name or b'') + b'\0'
1276
+
1277
+ if self.server_capabilities & CLIENT.CONNECT_ATTRS:
1278
+ connect_attrs = b''
1279
+ for k, v in self._connect_attrs.items():
1280
+ k = k.encode('utf-8')
1281
+ connect_attrs += _lenenc_int(len(k)) + k
1282
+ v = v.encode('utf-8')
1283
+ connect_attrs += _lenenc_int(len(v)) + v
1284
+ data += _lenenc_int(len(connect_attrs)) + connect_attrs
1285
+
1286
+ self.write_packet(data)
1287
+ auth_packet = self._read_packet()
1288
+
1289
+ # if authentication method isn't accepted the first byte
1290
+ # will have the octet 254
1291
+ if auth_packet.is_auth_switch_request():
1292
+ if DEBUG:
1293
+ print('received auth switch')
1294
+ # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
1295
+ auth_packet.read_uint8() # 0xfe packet identifier
1296
+ plugin_name = auth_packet.read_string()
1297
+ if (
1298
+ self.server_capabilities & CLIENT.PLUGIN_AUTH
1299
+ and plugin_name is not None
1300
+ ):
1301
+ auth_packet = self._process_auth(plugin_name, auth_packet)
1302
+ else:
1303
+ raise err.OperationalError('received unknown auth switch request')
1304
+ elif auth_packet.is_extra_auth_data():
1305
+ if DEBUG:
1306
+ print('received extra data')
1307
+ # https://dev.mysql.com/doc/internals/en/successful-authentication.html
1308
+ if self._auth_plugin_name == 'caching_sha2_password':
1309
+ auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
1310
+ elif self._auth_plugin_name == 'sha256_password':
1311
+ auth_packet = _auth.sha256_password_auth(self, auth_packet)
1312
+ else:
1313
+ raise err.OperationalError(
1314
+ 'Received extra packet for auth method %r', self._auth_plugin_name,
1315
+ )
1316
+
1317
+ if DEBUG:
1318
+ print('Succeed to auth')
1319
+
1320
+ def _process_auth(self, plugin_name, auth_packet):
1321
+ handler = self._get_auth_plugin_handler(plugin_name)
1322
+ if handler:
1323
+ try:
1324
+ return handler.authenticate(auth_packet)
1325
+ except AttributeError:
1326
+ if plugin_name != b'dialog':
1327
+ raise err.OperationalError(
1328
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1329
+ "Authentication plugin '%s'"
1330
+ ' not loaded: - %r missing authenticate method'
1331
+ % (plugin_name, type(handler)),
1332
+ )
1333
+ if plugin_name == b'caching_sha2_password':
1334
+ return _auth.caching_sha2_password_auth(self, auth_packet)
1335
+ elif plugin_name == b'sha256_password':
1336
+ return _auth.sha256_password_auth(self, auth_packet)
1337
+ elif plugin_name == b'mysql_native_password':
1338
+ data = _auth.scramble_native_password(self.password, auth_packet.read_all())
1339
+ elif plugin_name == b'client_ed25519':
1340
+ data = _auth.ed25519_password(self.password, auth_packet.read_all())
1341
+ elif plugin_name == b'mysql_old_password':
1342
+ data = (
1343
+ _auth.scramble_old_password(self.password, auth_packet.read_all())
1344
+ + b'\0'
1345
+ )
1346
+ elif plugin_name == b'mysql_clear_password':
1347
+ # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
1348
+ data = self.password + b'\0'
1349
+ elif plugin_name == b'auth_gssapi_client':
1350
+ data = _auth.gssapi_auth(auth_packet.read_all())
1351
+ elif plugin_name == b'dialog':
1352
+ pkt = auth_packet
1353
+ while True:
1354
+ flag = pkt.read_uint8()
1355
+ echo = (flag & 0x06) == 0x02
1356
+ last = (flag & 0x01) == 0x01
1357
+ prompt = pkt.read_all()
1358
+
1359
+ if prompt == b'Password: ':
1360
+ self.write_packet(self.password + b'\0')
1361
+ elif handler:
1362
+ resp = 'no response - TypeError within plugin.prompt method'
1363
+ try:
1364
+ resp = handler.prompt(echo, prompt)
1365
+ self.write_packet(resp + b'\0')
1366
+ except AttributeError:
1367
+ raise err.OperationalError(
1368
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1369
+ "Authentication plugin '%s'"
1370
+ ' not loaded: - %r missing prompt method'
1371
+ % (plugin_name, handler),
1372
+ )
1373
+ except TypeError:
1374
+ raise err.OperationalError(
1375
+ CR.CR_AUTH_PLUGIN_ERR,
1376
+ "Authentication plugin '%s'"
1377
+ " %r didn't respond with string. Returned '%r' to prompt %r"
1378
+ % (plugin_name, handler, resp, prompt),
1379
+ )
1380
+ else:
1381
+ raise err.OperationalError(
1382
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1383
+ "Authentication plugin '%s' not configured" % (plugin_name,),
1384
+ )
1385
+ pkt = self._read_packet()
1386
+ pkt.check_error()
1387
+ if pkt.is_ok_packet() or last:
1388
+ break
1389
+ return pkt
1390
+ else:
1391
+ raise err.OperationalError(
1392
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1393
+ "Authentication plugin '%s' not configured" % plugin_name,
1394
+ )
1395
+
1396
+ self.write_packet(data)
1397
+ pkt = self._read_packet()
1398
+ pkt.check_error()
1399
+ return pkt
1400
+
1401
+ def _get_auth_plugin_handler(self, plugin_name):
1402
+ plugin_class = self._auth_plugin_map.get(plugin_name)
1403
+ if not plugin_class and isinstance(plugin_name, bytes):
1404
+ plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
1405
+ if plugin_class:
1406
+ try:
1407
+ handler = plugin_class(self)
1408
+ except TypeError:
1409
+ raise err.OperationalError(
1410
+ CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1411
+ "Authentication plugin '%s'"
1412
+ ' not loaded: - %r cannot be constructed with connection object'
1413
+ % (plugin_name, plugin_class),
1414
+ )
1415
+ else:
1416
+ handler = None
1417
+ return handler
1418
+
1419
+ # _mysql support
1420
+ def thread_id(self):
1421
+ return self.server_thread_id[0]
1422
+
1423
+ def character_set_name(self):
1424
+ return self.charset
1425
+
1426
+ def get_host_info(self):
1427
+ return self.host_info
1428
+
1429
+ def get_proto_info(self):
1430
+ return self.protocol_version
1431
+
1432
+ def _get_server_information(self):
1433
+ i = 0
1434
+ packet = self._read_packet()
1435
+ data = packet.get_all_data()
1436
+
1437
+ self.protocol_version = data[i]
1438
+ i += 1
1439
+
1440
+ server_end = data.find(b'\0', i)
1441
+ self.server_version = data[i:server_end].decode('latin1')
1442
+ i = server_end + 1
1443
+
1444
+ self.server_thread_id = struct.unpack('<I', data[i: i + 4])
1445
+ i += 4
1446
+
1447
+ self.salt = data[i: i + 8]
1448
+ i += 9 # 8 + 1(filler)
1449
+
1450
+ self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]
1451
+ i += 2
1452
+
1453
+ if len(data) >= i + 6:
1454
+ lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])
1455
+ i += 6
1456
+ # TODO: deprecate server_language and server_charset.
1457
+ # mysqlclient-python doesn't provide it.
1458
+ self.server_language = lang
1459
+ try:
1460
+ self.server_charset = charset_by_id(lang).name
1461
+ except KeyError:
1462
+ # unknown collation
1463
+ self.server_charset = None
1464
+
1465
+ self.server_status = stat
1466
+ if DEBUG:
1467
+ print('server_status: %x' % stat)
1468
+
1469
+ self.server_capabilities |= cap_h << 16
1470
+ if DEBUG:
1471
+ print('salt_len:', salt_len)
1472
+ salt_len = max(12, salt_len - 9)
1473
+
1474
+ # reserved
1475
+ i += 10
1476
+
1477
+ if len(data) >= i + salt_len:
1478
+ # salt_len includes auth_plugin_data_part_1 and filler
1479
+ self.salt += data[i: i + salt_len]
1480
+ i += salt_len
1481
+
1482
+ i += 1
1483
+ # AUTH PLUGIN NAME may appear here.
1484
+ if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
1485
+ # Due to Bug#59453 the auth-plugin-name is missing the terminating
1486
+ # NUL-char in versions prior to 5.5.10 and 5.6.2.
1487
+ # ref: https://dev.mysql.com/doc/internals/en/
1488
+ # connection-phase-packets.html#packet-Protocol::Handshake
1489
+ # didn't use version checks as mariadb is corrected and reports
1490
+ # earlier than those two.
1491
+ server_end = data.find(b'\0', i)
1492
+ if server_end < 0: # pragma: no cover - very specific upstream bug
1493
+ # not found \0 and last field so take it all
1494
+ self._auth_plugin_name = data[i:].decode('utf-8')
1495
+ else:
1496
+ self._auth_plugin_name = data[i:server_end].decode('utf-8')
1497
+
1498
+ def get_server_info(self):
1499
+ return self.server_version
1500
+
1501
+ Warning = err.Warning
1502
+ Error = err.Error
1503
+ InterfaceError = err.InterfaceError
1504
+ DatabaseError = err.DatabaseError
1505
+ DataError = err.DataError
1506
+ OperationalError = err.OperationalError
1507
+ IntegrityError = err.IntegrityError
1508
+ InternalError = err.InternalError
1509
+ ProgrammingError = err.ProgrammingError
1510
+ NotSupportedError = err.NotSupportedError
1511
+
1512
+
1513
+ class MySQLResult:
1514
+ """
1515
+ Results of a SQL query.
1516
+
1517
+ Parameters
1518
+ ----------
1519
+ connection : Connection
1520
+ The connection the result came from.
1521
+ unbuffered : bool, optional
1522
+ Should the reads be unbuffered?
1523
+
1524
+ """
1525
+
1526
+ def __init__(self, connection, unbuffered=False):
1527
+ self.connection = connection
1528
+ self.affected_rows = None
1529
+ self.insert_id = None
1530
+ self.server_status = None
1531
+ self.warning_count = 0
1532
+ self.message = None
1533
+ self.field_count = 0
1534
+ self.description = None
1535
+ self.rows = None
1536
+ self.has_next = None
1537
+ self.unbuffered_active = False
1538
+ self.converters = []
1539
+ self.fields = []
1540
+ self.encoding_errors = self.connection.encoding_errors
1541
+ if unbuffered:
1542
+ try:
1543
+ self.init_unbuffered_query()
1544
+ except Exception:
1545
+ self.connection = None
1546
+ self.unbuffered_active = False
1547
+ raise
1548
+
1549
+ def __del__(self):
1550
+ if self.unbuffered_active:
1551
+ self._finish_unbuffered_query()
1552
+
1553
+ def read(self):
1554
+ try:
1555
+ first_packet = self.connection._read_packet()
1556
+
1557
+ if first_packet.is_ok_packet():
1558
+ self._read_ok_packet(first_packet)
1559
+ elif first_packet.is_load_local_packet():
1560
+ self._read_load_local_packet(first_packet)
1561
+ else:
1562
+ self._read_result_packet(first_packet)
1563
+ finally:
1564
+ self.connection = None
1565
+
1566
+ def init_unbuffered_query(self):
1567
+ """
1568
+ Initialize an unbuffered query.
1569
+
1570
+ Raises
1571
+ ------
1572
+ OperationalError : If the connection to the MySQL server is lost.
1573
+ InternalError : Other errors.
1574
+
1575
+ """
1576
+ self.unbuffered_active = True
1577
+ first_packet = self.connection._read_packet()
1578
+
1579
+ if first_packet.is_ok_packet():
1580
+ self._read_ok_packet(first_packet)
1581
+ self.unbuffered_active = False
1582
+ self.connection = None
1583
+ elif first_packet.is_load_local_packet():
1584
+ self._read_load_local_packet(first_packet)
1585
+ self.unbuffered_active = False
1586
+ self.connection = None
1587
+ else:
1588
+ self.field_count = first_packet.read_length_encoded_integer()
1589
+ self._get_descriptions()
1590
+
1591
+ # Apparently, MySQLdb picks this number because it's the maximum
1592
+ # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
1593
+ # we set it to this instead of None, which would be preferred.
1594
+ self.affected_rows = 18446744073709551615
1595
+
1596
+ def _read_ok_packet(self, first_packet):
1597
+ ok_packet = OKPacketWrapper(first_packet)
1598
+ self.affected_rows = ok_packet.affected_rows
1599
+ self.insert_id = ok_packet.insert_id
1600
+ self.server_status = ok_packet.server_status
1601
+ self.warning_count = ok_packet.warning_count
1602
+ self.message = ok_packet.message
1603
+ self.has_next = ok_packet.has_next
1604
+
1605
+ def _read_load_local_packet(self, first_packet):
1606
+ if not self.connection._local_infile:
1607
+ raise RuntimeError(
1608
+ '**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
1609
+ )
1610
+ load_packet = LoadLocalPacketWrapper(first_packet)
1611
+ sender = LoadLocalFile(load_packet.filename, self.connection)
1612
+ try:
1613
+ sender.send_data()
1614
+ except Exception:
1615
+ self.connection._read_packet() # skip ok packet
1616
+ raise
1617
+
1618
+ ok_packet = self.connection._read_packet()
1619
+ if (
1620
+ not ok_packet.is_ok_packet()
1621
+ ): # pragma: no cover - upstream induced protocol error
1622
+ raise err.OperationalError(
1623
+ CR.CR_COMMANDS_OUT_OF_SYNC,
1624
+ 'Commands Out of Sync',
1625
+ )
1626
+ self._read_ok_packet(ok_packet)
1627
+
1628
+ def _check_packet_is_eof(self, packet):
1629
+ if not packet.is_eof_packet():
1630
+ return False
1631
+ # TODO: Support CLIENT.DEPRECATE_EOF
1632
+ # 1) Add DEPRECATE_EOF to CAPABILITIES
1633
+ # 2) Mask CAPABILITIES with server_capabilities
1634
+ # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper
1635
+ # instead of EOFPacketWrapper
1636
+ wp = EOFPacketWrapper(packet)
1637
+ self.warning_count = wp.warning_count
1638
+ self.has_next = wp.has_next
1639
+ return True
1640
+
1641
+ def _read_result_packet(self, first_packet):
1642
+ self.field_count = first_packet.read_length_encoded_integer()
1643
+ self._get_descriptions()
1644
+ self._read_rowdata_packet()
1645
+
1646
+ def _read_rowdata_packet_unbuffered(self):
1647
+ # Check if in an active query
1648
+ if not self.unbuffered_active:
1649
+ return
1650
+
1651
+ # EOF
1652
+ packet = self.connection._read_packet()
1653
+ if self._check_packet_is_eof(packet):
1654
+ self.unbuffered_active = False
1655
+ self.connection = None
1656
+ self.rows = None
1657
+ return
1658
+
1659
+ row = self._read_row_from_packet(packet)
1660
+ self.affected_rows = 1
1661
+ self.rows = (row,) # rows should tuple of row for MySQL-python compatibility.
1662
+ return row
1663
+
1664
+ def _finish_unbuffered_query(self):
1665
+ # After much reading on the MySQL protocol, it appears that there is,
1666
+ # in fact, no way to stop MySQL from sending all the data after
1667
+ # executing a query, so we just spin, and wait for an EOF packet.
1668
+ while self.unbuffered_active and self.connection._sock is not None:
1669
+ try:
1670
+ packet = self.connection._read_packet()
1671
+ except err.OperationalError as e:
1672
+ if e.args[0] in (
1673
+ ER.QUERY_TIMEOUT,
1674
+ ER.STATEMENT_TIMEOUT,
1675
+ ):
1676
+ # if the query timed out we can simply ignore this error
1677
+ self.unbuffered_active = False
1678
+ self.connection = None
1679
+ return
1680
+
1681
+ raise
1682
+
1683
+ if self._check_packet_is_eof(packet):
1684
+ self.unbuffered_active = False
1685
+ self.connection = None # release reference to kill cyclic reference.
1686
+
1687
+ def _read_rowdata_packet(self):
1688
+ """Read a rowdata packet for each data row in the result set."""
1689
+ rows = []
1690
+ while True:
1691
+ packet = self.connection._read_packet()
1692
+ if self._check_packet_is_eof(packet):
1693
+ self.connection = None # release reference to kill cyclic reference.
1694
+ break
1695
+ rows.append(self._read_row_from_packet(packet))
1696
+
1697
+ self.affected_rows = len(rows)
1698
+ self.rows = tuple(rows)
1699
+
1700
+ def _read_row_from_packet(self, packet):
1701
+ row = []
1702
+ for encoding, converter in self.converters:
1703
+ try:
1704
+ data = packet.read_length_coded_string()
1705
+ except IndexError:
1706
+ # No more columns in this row
1707
+ # See https://github.com/PyMySQL/PyMySQL/pull/434
1708
+ break
1709
+ if data is not None:
1710
+ if encoding is not None:
1711
+ data = data.decode(encoding, errors=self.encoding_errors)
1712
+ if DEBUG:
1713
+ print('DEBUG: DATA = ', data)
1714
+ if converter is not None:
1715
+ data = converter(data)
1716
+ row.append(data)
1717
+ return tuple(row)
1718
+
1719
def _get_descriptions(self):
    """Read a column descriptor packet for each column in the result.

    For each of ``self.field_count`` columns, a ``FieldDescriptorPacket`` is
    read from the connection. Its ``description()`` tuple is collected, and
    an ``(encoding, converter)`` pair is appended to ``self.converters`` for
    use by the row decoder. The column packets must be followed by an EOF
    packet.

    Sets ``self.fields``, ``self.converters`` and ``self.description``.

    Raises
    ------
    AssertionError
        If the packet following the column descriptors is not an EOF packet.

    """
    self.fields = []
    self.converters = []
    use_unicode = self.connection.use_unicode
    conn_encoding = self.connection.encoding
    description = []

    for _ in range(self.field_count):
        field = self.connection._read_packet(FieldDescriptorPacket)
        self.fields.append(field)
        description.append(field.description())
        field_type = field.type_code
        if use_unicode:
            if field_type == FIELD_TYPE.JSON:
                # When SELECT from JSON column: charset = binary
                # When SELECT CAST(... AS JSON): charset = connection encoding
                # This behavior is different from TEXT / BLOB.
                # We should decode result by connection encoding regardless charsetnr.
                # See https://github.com/PyMySQL/PyMySQL/issues/488
                encoding = conn_encoding  # SELECT CAST(... AS JSON)
            elif field_type in TEXT_TYPES:
                if field.charsetnr == 63:  # binary
                    # TEXTs with charset=binary means BINARY types.
                    encoding = None
                else:
                    encoding = conn_encoding
            else:
                # Integers, Dates and Times, and other basic data is encoded in ascii
                encoding = 'ascii'
        else:
            encoding = None
        converter = self.connection.decoders.get(field_type)
        if converter is converters.through:
            # Identity converter: skip the call when decoding each row.
            converter = None
        if DEBUG:
            print(f'DEBUG: field={field}, converter={converter}')
        self.converters.append((encoding, converter))

    eof_packet = self.connection._read_packet()
    if not eof_packet.is_eof_packet():
        # Raise explicitly instead of using ``assert`` so this protocol check
        # still runs under ``python -O``. The exception type and message are
        # unchanged for existing callers.
        raise AssertionError('Protocol error, expecting EOF')
    self.description = tuple(description)
1761
+
1762
+
1763
class MySQLResultSV(MySQLResult):
    """``MySQLResult`` whose row reading is delegated to ``_singlestoredb_accel``.

    ``self.options`` collects the decoding settings taken from the
    connection, leaving out any that are ``UNSET``. The two row-reading
    hooks call ``_singlestoredb_accel.read_rowdata_packet`` with this result
    object.

    """

    def __init__(self, connection, unbuffered=False):
        MySQLResult.__init__(self, connection, unbuffered=unbuffered)
        # Settings handed to the accelerated reader; values left as UNSET are
        # omitted entirely.
        self.options = {
            k: v for k, v in dict(
                default_converters=converters.decoders,
                results_type=connection.results_type,
                parse_json=connection.parse_json,
                invalid_values=connection.invalid_values,
                unbuffered=unbuffered,
            ).items() if v is not UNSET
        }

    # These hooks are ordinary methods, not ``functools.partial(..., self, ...)``
    # objects stored on the instance. A partial bound to ``self`` and stored in
    # ``self.__dict__`` would be a self-reference cycle. That cycle keeps each
    # result object, and its buffered rows, alive until the cyclic garbage
    # collector runs, instead of freeing it as soon as the last reference goes
    # away.
    def _read_rowdata_packet(self, *args, **kwargs):
        """Call the accelerated row reader with the flag ``False`` (buffered)."""
        return _singlestoredb_accel.read_rowdata_packet(
            self, False, *args, **kwargs,
        )

    def _read_rowdata_packet_unbuffered(self, *args, **kwargs):
        """Call the accelerated row reader with the flag ``True`` (unbuffered)."""
        return _singlestoredb_accel.read_rowdata_packet(
            self, True, *args, **kwargs,
        )
1782
+
1783
+
1784
class LoadLocalFile:
    """Send the contents of a client-side file to the server as packets.

    Presumably used to answer a local-infile request (``LOAD DATA LOCAL
    INFILE``) — TODO confirm against the caller.

    Parameters
    ----------
    filename : str
        Path of the local file to send.
    connection : Connection
        Connection whose ``write_packet`` is used to send the data.

    """

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server.

        The file is sent in chunks of at most
        ``min(conn.max_allowed_packet, 16 * 1024)`` bytes. An empty packet
        follows to mark the end of the data. The empty packet is written
        even if sending failed, as long as the connection is not closed.

        Raises
        ------
        InterfaceError
            If the connection has no (truthy) socket.
        OperationalError
            With ``ER.FILE_NOT_FOUND`` if an ``OSError`` occurs while the
            file is opened and streamed. The original ``OSError`` is chained
            as ``__cause__``, so the real reason (e.g. permission denied)
            stays visible.

        """
        if not self.connection._sock:
            raise err.InterfaceError(0, '')
        conn = self.connection

        try:
            with open(self.filename, 'rb') as open_file:
                packet_size = min(
                    conn.max_allowed_packet, 16 * 1024,
                )  # 16KB is efficient enough
                while True:
                    chunk = open_file.read(packet_size)
                    if not chunk:
                        break
                    conn.write_packet(chunk)
        except OSError as exc:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                f"Can't find file '{self.filename}'",
            ) from exc
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b'')