singlestoredb 1.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. singlestoredb/__init__.py +75 -0
  2. singlestoredb/ai/__init__.py +2 -0
  3. singlestoredb/ai/chat.py +139 -0
  4. singlestoredb/ai/embeddings.py +128 -0
  5. singlestoredb/alchemy/__init__.py +90 -0
  6. singlestoredb/apps/__init__.py +3 -0
  7. singlestoredb/apps/_cloud_functions.py +90 -0
  8. singlestoredb/apps/_config.py +72 -0
  9. singlestoredb/apps/_connection_info.py +18 -0
  10. singlestoredb/apps/_dashboards.py +47 -0
  11. singlestoredb/apps/_process.py +32 -0
  12. singlestoredb/apps/_python_udfs.py +100 -0
  13. singlestoredb/apps/_stdout_supress.py +30 -0
  14. singlestoredb/apps/_uvicorn_util.py +36 -0
  15. singlestoredb/auth.py +245 -0
  16. singlestoredb/config.py +484 -0
  17. singlestoredb/connection.py +1487 -0
  18. singlestoredb/converters.py +950 -0
  19. singlestoredb/docstring/__init__.py +33 -0
  20. singlestoredb/docstring/attrdoc.py +126 -0
  21. singlestoredb/docstring/common.py +230 -0
  22. singlestoredb/docstring/epydoc.py +267 -0
  23. singlestoredb/docstring/google.py +412 -0
  24. singlestoredb/docstring/numpydoc.py +562 -0
  25. singlestoredb/docstring/parser.py +100 -0
  26. singlestoredb/docstring/py.typed +1 -0
  27. singlestoredb/docstring/rest.py +256 -0
  28. singlestoredb/docstring/tests/__init__.py +1 -0
  29. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  30. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  31. singlestoredb/docstring/tests/test_google.py +1007 -0
  32. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  33. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  34. singlestoredb/docstring/tests/test_parser.py +248 -0
  35. singlestoredb/docstring/tests/test_rest.py +547 -0
  36. singlestoredb/docstring/tests/test_util.py +70 -0
  37. singlestoredb/docstring/util.py +141 -0
  38. singlestoredb/exceptions.py +120 -0
  39. singlestoredb/functions/__init__.py +16 -0
  40. singlestoredb/functions/decorator.py +201 -0
  41. singlestoredb/functions/dtypes.py +1793 -0
  42. singlestoredb/functions/ext/__init__.py +1 -0
  43. singlestoredb/functions/ext/arrow.py +375 -0
  44. singlestoredb/functions/ext/asgi.py +2133 -0
  45. singlestoredb/functions/ext/json.py +420 -0
  46. singlestoredb/functions/ext/mmap.py +413 -0
  47. singlestoredb/functions/ext/rowdat_1.py +724 -0
  48. singlestoredb/functions/ext/timer.py +89 -0
  49. singlestoredb/functions/ext/utils.py +218 -0
  50. singlestoredb/functions/signature.py +1578 -0
  51. singlestoredb/functions/typing/__init__.py +41 -0
  52. singlestoredb/functions/typing/numpy.py +20 -0
  53. singlestoredb/functions/typing/pandas.py +2 -0
  54. singlestoredb/functions/typing/polars.py +2 -0
  55. singlestoredb/functions/typing/pyarrow.py +2 -0
  56. singlestoredb/functions/utils.py +421 -0
  57. singlestoredb/fusion/__init__.py +11 -0
  58. singlestoredb/fusion/graphql.py +213 -0
  59. singlestoredb/fusion/handler.py +916 -0
  60. singlestoredb/fusion/handlers/__init__.py +0 -0
  61. singlestoredb/fusion/handlers/export.py +525 -0
  62. singlestoredb/fusion/handlers/files.py +690 -0
  63. singlestoredb/fusion/handlers/job.py +660 -0
  64. singlestoredb/fusion/handlers/models.py +250 -0
  65. singlestoredb/fusion/handlers/stage.py +502 -0
  66. singlestoredb/fusion/handlers/utils.py +324 -0
  67. singlestoredb/fusion/handlers/workspace.py +956 -0
  68. singlestoredb/fusion/registry.py +249 -0
  69. singlestoredb/fusion/result.py +399 -0
  70. singlestoredb/http/__init__.py +27 -0
  71. singlestoredb/http/connection.py +1267 -0
  72. singlestoredb/magics/__init__.py +34 -0
  73. singlestoredb/magics/run_personal.py +137 -0
  74. singlestoredb/magics/run_shared.py +134 -0
  75. singlestoredb/management/__init__.py +9 -0
  76. singlestoredb/management/billing_usage.py +148 -0
  77. singlestoredb/management/cluster.py +462 -0
  78. singlestoredb/management/export.py +295 -0
  79. singlestoredb/management/files.py +1102 -0
  80. singlestoredb/management/inference_api.py +105 -0
  81. singlestoredb/management/job.py +887 -0
  82. singlestoredb/management/manager.py +373 -0
  83. singlestoredb/management/organization.py +226 -0
  84. singlestoredb/management/region.py +169 -0
  85. singlestoredb/management/utils.py +423 -0
  86. singlestoredb/management/workspace.py +1927 -0
  87. singlestoredb/mysql/__init__.py +177 -0
  88. singlestoredb/mysql/_auth.py +298 -0
  89. singlestoredb/mysql/charset.py +214 -0
  90. singlestoredb/mysql/connection.py +2032 -0
  91. singlestoredb/mysql/constants/CLIENT.py +38 -0
  92. singlestoredb/mysql/constants/COMMAND.py +32 -0
  93. singlestoredb/mysql/constants/CR.py +78 -0
  94. singlestoredb/mysql/constants/ER.py +474 -0
  95. singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
  96. singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
  97. singlestoredb/mysql/constants/FLAG.py +15 -0
  98. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  99. singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
  100. singlestoredb/mysql/constants/__init__.py +0 -0
  101. singlestoredb/mysql/converters.py +271 -0
  102. singlestoredb/mysql/cursors.py +896 -0
  103. singlestoredb/mysql/err.py +92 -0
  104. singlestoredb/mysql/optionfile.py +20 -0
  105. singlestoredb/mysql/protocol.py +450 -0
  106. singlestoredb/mysql/tests/__init__.py +19 -0
  107. singlestoredb/mysql/tests/base.py +126 -0
  108. singlestoredb/mysql/tests/conftest.py +37 -0
  109. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  110. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  111. singlestoredb/mysql/tests/test_basic.py +452 -0
  112. singlestoredb/mysql/tests/test_connection.py +851 -0
  113. singlestoredb/mysql/tests/test_converters.py +58 -0
  114. singlestoredb/mysql/tests/test_cursor.py +141 -0
  115. singlestoredb/mysql/tests/test_err.py +16 -0
  116. singlestoredb/mysql/tests/test_issues.py +514 -0
  117. singlestoredb/mysql/tests/test_load_local.py +75 -0
  118. singlestoredb/mysql/tests/test_nextset.py +88 -0
  119. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  120. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  121. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  122. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  123. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  124. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  125. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  126. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  127. singlestoredb/mysql/times.py +23 -0
  128. singlestoredb/notebook/__init__.py +16 -0
  129. singlestoredb/notebook/_objects.py +213 -0
  130. singlestoredb/notebook/_portal.py +352 -0
  131. singlestoredb/py.typed +0 -0
  132. singlestoredb/pytest.py +352 -0
  133. singlestoredb/server/__init__.py +0 -0
  134. singlestoredb/server/docker.py +452 -0
  135. singlestoredb/server/free_tier.py +267 -0
  136. singlestoredb/tests/__init__.py +0 -0
  137. singlestoredb/tests/alltypes.sql +307 -0
  138. singlestoredb/tests/alltypes_no_nulls.sql +208 -0
  139. singlestoredb/tests/empty.sql +0 -0
  140. singlestoredb/tests/ext_funcs/__init__.py +702 -0
  141. singlestoredb/tests/local_infile.csv +3 -0
  142. singlestoredb/tests/test.ipynb +18 -0
  143. singlestoredb/tests/test.sql +680 -0
  144. singlestoredb/tests/test2.ipynb +18 -0
  145. singlestoredb/tests/test2.sql +1 -0
  146. singlestoredb/tests/test_basics.py +1332 -0
  147. singlestoredb/tests/test_config.py +318 -0
  148. singlestoredb/tests/test_connection.py +3103 -0
  149. singlestoredb/tests/test_dbapi.py +27 -0
  150. singlestoredb/tests/test_exceptions.py +45 -0
  151. singlestoredb/tests/test_ext_func.py +1472 -0
  152. singlestoredb/tests/test_ext_func_data.py +1101 -0
  153. singlestoredb/tests/test_fusion.py +1527 -0
  154. singlestoredb/tests/test_http.py +288 -0
  155. singlestoredb/tests/test_management.py +1599 -0
  156. singlestoredb/tests/test_plugin.py +33 -0
  157. singlestoredb/tests/test_results.py +171 -0
  158. singlestoredb/tests/test_types.py +132 -0
  159. singlestoredb/tests/test_udf.py +737 -0
  160. singlestoredb/tests/test_udf_returns.py +459 -0
  161. singlestoredb/tests/test_vectorstore.py +51 -0
  162. singlestoredb/tests/test_xdict.py +333 -0
  163. singlestoredb/tests/utils.py +141 -0
  164. singlestoredb/types.py +373 -0
  165. singlestoredb/utils/__init__.py +0 -0
  166. singlestoredb/utils/config.py +950 -0
  167. singlestoredb/utils/convert_rows.py +69 -0
  168. singlestoredb/utils/debug.py +13 -0
  169. singlestoredb/utils/dtypes.py +205 -0
  170. singlestoredb/utils/events.py +65 -0
  171. singlestoredb/utils/mogrify.py +151 -0
  172. singlestoredb/utils/results.py +585 -0
  173. singlestoredb/utils/xdict.py +425 -0
  174. singlestoredb/vectorstore.py +192 -0
  175. singlestoredb/warnings.py +5 -0
  176. singlestoredb-1.16.1.dist-info/METADATA +165 -0
  177. singlestoredb-1.16.1.dist-info/RECORD +183 -0
  178. singlestoredb-1.16.1.dist-info/WHEEL +5 -0
  179. singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
  180. singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
  181. singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
  182. sqlx/__init__.py +4 -0
  183. sqlx/magic.py +113 -0
@@ -0,0 +1,2032 @@
1
+ # type: ignore
2
+ # Python implementation of the MySQL client-server protocol
3
+ # http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
+ # Error codes:
5
+ # https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
6
+ import errno
7
+ import functools
8
+ import io
9
+ import os
10
+ import queue
11
+ import socket
12
+ import struct
13
+ import sys
14
+ import traceback
15
+ import warnings
16
+ from collections.abc import Iterable
17
+ from typing import Any
18
+ from typing import Dict
19
+
20
+ try:
21
+ import _singlestoredb_accel
22
+ except (ImportError, ModuleNotFoundError):
23
+ _singlestoredb_accel = None
24
+
25
+ from . import _auth
26
+ from ..utils import events
27
+
28
+ from .charset import charset_by_name, charset_by_id
29
+ from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
30
+ from . import converters
31
+ from .cursors import (
32
+ Cursor,
33
+ CursorSV,
34
+ DictCursor,
35
+ DictCursorSV,
36
+ NamedtupleCursor,
37
+ NamedtupleCursorSV,
38
+ ArrowCursor,
39
+ ArrowCursorSV,
40
+ NumpyCursor,
41
+ NumpyCursorSV,
42
+ PandasCursor,
43
+ PandasCursorSV,
44
+ PolarsCursor,
45
+ PolarsCursorSV,
46
+ SSCursor,
47
+ SSCursorSV,
48
+ SSDictCursor,
49
+ SSDictCursorSV,
50
+ SSNamedtupleCursor,
51
+ SSNamedtupleCursorSV,
52
+ SSArrowCursor,
53
+ SSArrowCursorSV,
54
+ SSNumpyCursor,
55
+ SSNumpyCursorSV,
56
+ SSPandasCursor,
57
+ SSPandasCursorSV,
58
+ SSPolarsCursor,
59
+ SSPolarsCursorSV,
60
+ )
61
+ from .optionfile import Parser
62
+ from .protocol import (
63
+ dump_packet,
64
+ MysqlPacket,
65
+ FieldDescriptorPacket,
66
+ OKPacketWrapper,
67
+ EOFPacketWrapper,
68
+ LoadLocalPacketWrapper,
69
+ )
70
+ from . import err
71
+ from ..config import get_option
72
+ from .. import fusion
73
+ from .. import connection
74
+ from ..connection import Connection as BaseConnection
75
+ from ..utils.debug import log_query
76
+
77
# Optional standard-library support: TLS is only offered when ``ssl`` exists.
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    ssl = None
    SSL_ENABLED = False

# Fallback user name for connections that do not specify ``user``.
try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # No user name could be determined, e.g. the current UID has no entry in
    # the OS user database. Python < 3.13 reports this with KeyError (or
    # ImportError when ``pwd`` is unavailable); Python 3.13+ raises OSError.
    DEFAULT_USER = None

# Connection-level debug flag, read once at import time.
DEBUG = get_option('debug.connection')

# Field types grouped together as "text" types: string, BLOB, BIT, geometry,
# BSON and vector (both JSON and binary encodings) columns.
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
    FIELD_TYPE.BSON,
    FIELD_TYPE.FLOAT32_VECTOR_JSON,
    FIELD_TYPE.FLOAT64_VECTOR_JSON,
    FIELD_TYPE.INT8_VECTOR_JSON,
    FIELD_TYPE.INT16_VECTOR_JSON,
    FIELD_TYPE.INT32_VECTOR_JSON,
    FIELD_TYPE.INT64_VECTOR_JSON,
    FIELD_TYPE.FLOAT32_VECTOR,
    FIELD_TYPE.FLOAT64_VECTOR,
    FIELD_TYPE.INT8_VECTOR,
    FIELD_TYPE.INT16_VECTOR,
    FIELD_TYPE.INT32_VECTOR,
    FIELD_TYPE.INT64_VECTOR,
}

# NOTE(review): presumably a sentinel marking "option not set" — confirm at
# its use sites.
UNSET = 'unset'

# Charset used when the caller passes no (or an empty) ``charset``.
DEFAULT_CHARSET = 'utf8mb4'

# Largest value that fits in 24 bits (cf. ``_pack_int24``).
MAX_PACKET_LEN = 2**24 - 1
126
+
127
+
128
+ def _pack_int24(n):
129
+ return struct.pack('<I', n)[:3]
130
+
131
+
132
+ # https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
133
+ def _lenenc_int(i):
134
+ if i < 0:
135
+ raise ValueError(
136
+ 'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,
137
+ )
138
+ elif i < 0xFB:
139
+ return bytes([i])
140
+ elif i < (1 << 16):
141
+ return b'\xfc' + struct.pack('<H', i)
142
+ elif i < (1 << 24):
143
+ return b'\xfd' + struct.pack('<I', i)[:3]
144
+ elif i < (1 << 64):
145
+ return b'\xfe' + struct.pack('<Q', i)
146
+ else:
147
+ raise ValueError(
148
+ 'Encoding %x is larger than %x - no representation in LengthEncodedInteger'
149
+ % (i, (1 << 64)),
150
+ )
151
+
152
+
153
+ class Connection(BaseConnection):
154
+ """
155
+ Representation of a socket with a mysql server.
156
+
157
+ The proper way to get an instance of this class is to call
158
+ ``connect()``.
159
+
160
+ Establish a connection to the SingleStoreDB database.
161
+
162
+ Parameters
163
+ ----------
164
+ host : str, optional
165
+ Host where the database server is located.
166
+ user : str, optional
167
+ Username to log in as.
168
+ password : str, optional
169
+ Password to use.
170
+ database : str, optional
171
+ Database to use, None to not use a particular one.
172
+ port : int, optional
173
+ Server port to use, default is usually OK. (default: 3306)
174
+ bind_address : str, optional
175
+ When the client has multiple network interfaces, specify
176
+ the interface from which to connect to the host. Argument can be
177
+ a hostname or an IP address.
178
+ unix_socket : str, optional
179
+ Use a unix socket rather than TCP/IP.
180
+ read_timeout : int, optional
181
+ The timeout for reading from the connection in seconds
182
+ (default: None - no timeout)
183
+ write_timeout : int, optional
184
+ The timeout for writing to the connection in seconds
185
+ (default: None - no timeout)
186
+ charset : str, optional
187
+ Charset to use.
188
+ collation : str, optional
189
+ The charset collation
190
+ sql_mode : str, optional
191
+ Default SQL_MODE to use.
192
+ read_default_file : str, optional
193
+ Specifies my.cnf file to read these parameters from under the
194
+ [client] section.
195
+ conv : Dict[str, Callable[Any]], optional
196
+ Conversion dictionary to use instead of the default one.
197
+ This is used to provide custom marshalling and unmarshalling of types.
198
+ See converters.
199
+ use_unicode : bool, optional
200
+ Whether or not to default to unicode strings.
201
+ This option defaults to true.
202
+ client_flag : int, optional
203
+ Custom flags to send to MySQL. Find potential values in constants.CLIENT.
204
+ cursorclass : type, optional
205
+ Custom cursor class to use.
206
+ init_command : str, optional
207
+ Initial SQL statement to run when connection is established.
208
+ connect_timeout : int, optional
209
+ The timeout for connecting to the database in seconds.
210
+ (default: 10, min: 1, max: 31536000)
211
+ ssl : Dict[str, str], optional
212
+ A dict of arguments similar to mysql_ssl_set()'s parameters or
213
+ an ssl.SSLContext.
214
+ ssl_ca : str, optional
215
+ Path to the file that contains a PEM-formatted CA certificate.
216
+ ssl_cert : str, optional
217
+ Path to the file that contains a PEM-formatted client certificate.
218
+ ssl_cipher : str, optional
219
+ SSL ciphers to allow.
220
+ ssl_disabled : bool, optional
221
+ A boolean value that disables usage of TLS.
222
+ ssl_key : str, optional
223
+ Path to the file that contains a PEM-formatted private key for the
224
+ client certificate.
225
+ ssl_verify_cert : str, optional
226
+ Set to true to check the server certificate's validity.
227
+ ssl_verify_identity : bool, optional
228
+ Set to true to check the server's identity.
229
+ tls_sni_servername : str, optional
230
+ Set server host name for TLS connection
231
+ read_default_group : str, optional
232
+ Group to read from in the configuration file.
233
+ autocommit : bool, optional
234
+ Autocommit mode. None means use server default. (default: False)
235
+ local_infile : bool, optional
236
+ Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
237
+ max_allowed_packet : int, optional
238
+ Max size of packet sent to server in bytes. (default: 16MB)
239
+ Only used to limit size of "LOAD LOCAL INFILE" data packet smaller
240
+ than default (16KB).
241
+ defer_connect : bool, optional
242
+ Don't explicitly connect on construction - wait for connect call.
243
+ (default: False)
244
+ auth_plugin_map : Dict[str, type], optional
245
+ A dict of plugin names to a class that processes that plugin.
246
+ The class will take the Connection object as the argument to the
247
+ constructor. The class needs an authenticate method taking an
248
+ authentication packet as an argument. For the dialog plugin, a
249
+ prompt(echo, prompt) method can be used (if no authenticate method)
250
+ for returning a string from the user. (experimental)
251
+ server_public_key : str, optional
252
+ SHA256 authentication plugin public key value. (default: None)
253
+ binary_prefix : bool, optional
254
+ Add _binary prefix on bytes and bytearray. (default: False)
255
+ compress :
256
+ Not supported.
257
+ named_pipe :
258
+ Not supported.
259
+ db : str, optional
260
+ **DEPRECATED** Alias for database.
261
+ passwd : str, optional
262
+ **DEPRECATED** Alias for password.
263
+ parse_json : bool, optional
264
+ Parse JSON values into Python objects?
265
+ invalid_values : Dict[int, Any], optional
266
+ Dictionary of values to use in place of invalid values
267
+ found during conversion of data. The default is to return the byte content
268
+ containing the invalid value. The keys are the integers associated with
269
+ the column type.
270
+ pure_python : bool, optional
271
+ Should we ignore the C extension even if it's available?
272
+ This can be given explicitly using True or False, or if the value is None,
273
+ the C extension will be loaded if it is available. If set to False and
274
+ the C extension can't be loaded, a NotSupportedError is raised.
275
+ nan_as_null : bool, optional
276
+ Should NaN values be treated as NULLs in parameter substitution including
277
+ uploading data?
278
+ inf_as_null : bool, optional
279
+ Should Inf values be treated as NULLs in parameter substitution including
280
+ uploading data?
281
+ track_env : bool, optional
282
+ Should the connection track the SINGLESTOREDB_URL environment variable?
283
+ enable_extended_data_types : bool, optional
284
+ Should extended data types (BSON, vector) be enabled?
285
+ vector_data_format : str, optional
286
+ Specify the data type of vector values: json or binary
287
+
288
+ See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
289
+ in the specification.
290
+
291
+ """
292
+
293
+ driver = 'mysql'
294
+ paramstyle = 'pyformat'
295
+
296
+ _sock = None
297
+ _auth_plugin_name = ''
298
+ _closed = False
299
+ _secure = False
300
+ _tls_sni_servername = None
301
+
302
def __init__(  # noqa: C901
    self,
    *,
    user=None,  # The first four arguments are based on the DB-API 2.0 recommendation.
    password='',
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset='',
    collation=None,
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=None,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_cipher=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    tls_sni_servername=None,
    parse_json=True,
    invalid_values=None,
    pure_python=None,
    buffered=True,
    results_type='tuples',
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
    driver=None,  # internal use
    conn_attrs=None,
    multi_statements=None,
    client_found_rows=None,
    nan_as_null=None,
    inf_as_null=None,
    encoding_errors='strict',
    track_env=False,
    enable_extended_data_types=True,
    vector_data_format='binary',
):
    """
    Initialize connection state and, unless deferred, connect.

    See the class docstring for a description of each parameter.

    Raises
    ------
    NotImplementedError
        If ``compress`` or ``named_pipe`` is requested, or if SSL is
        requested but the ``ssl`` module is unavailable.
    ValueError
        If ``port``, ``connect_timeout``, ``read_timeout``,
        ``write_timeout`` or ``vector_data_format`` is invalid.
    err.NotSupportedError
        If ``pure_python=False`` but the C extension cannot be loaded.

    """
    # This must stay the first statement: at this point ``locals()`` holds
    # exactly ``self`` plus the constructor arguments.
    BaseConnection.__init__(**dict(locals()))

    if db is not None and database is None:
        # 'db' is a deprecated alias for 'database'. No DeprecationWarning
        # is emitted yet; see https://github.com/PyMySQL/PyMySQL/issues/939
        database = db
    if passwd is not None and not password:
        # 'passwd' is a deprecated alias for 'password' (no warning yet;
        # see the PyMySQL issue above).
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            'compress and named_pipe arguments are not supported',
        )

    # Translate boolean options into client capability flags.
    self._local_infile = bool(local_infile)
    self._local_infile_stream = None
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES
    if multi_statements:
        client_flag |= CLIENT.MULTI_STATEMENTS
    if client_found_rows:
        client_flag |= CLIENT.FOUND_ROWS

    # A config group without a file implies the platform's default file.
    if read_default_group and not read_default_file:
        if sys.platform.startswith('win'):
            read_default_file = 'c:\\my.ini'
        else:
            read_default_file = '/etc/my.cnf'

    if read_default_file:
        if not read_default_group:
            read_default_group = 'client'

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # Truthy explicit arguments win over the option file; any
            # lookup failure falls back to the argument value.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config('user', user)
        password = _config('password', password)
        host = _config('host', host)
        database = _config('database', database)
        unix_socket = _config('socket', unix_socket)
        port = int(_config('port', port))
        bind_address = _config('bind-address', bind_address)
        charset = _config('default-character-set', charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            # Work on a copy so option-file values are not written into the
            # dict the caller passed in.
            ssl = dict(ssl)
            for key in ['ca', 'capath', 'cert', 'key', 'cipher']:
                value = _config('ssl-' + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Any individual ssl_* argument replaces the ``ssl`` dict/context.
        if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \
                ssl_verify_cert or ssl_verify_identity:
            ssl = {
                'ca': ssl_ca,
                'check_hostname': bool(ssl_verify_identity),
                'verify_mode': ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl['cert'] = ssl_cert
            if ssl_key is not None:
                ssl['key'] = ssl_key
            if ssl_cipher is not None:
                ssl['cipher'] = ssl_cipher
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError('ssl module not found')
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or 'localhost'
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError('port should be of type int')
    self.user = user or DEFAULT_USER
    self.password = password or b''
    if isinstance(self.password, str):
        self.password = self.password.encode('latin1')
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError('connect_timeout should be >0 and <=31536000')
    # The range check above guarantees a positive (truthy) value.
    self.connect_timeout = connect_timeout
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError('read_timeout should be > 0')
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError('write_timeout should be > 0')
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.collation = collation
    self.use_unicode = use_unicode
    self.encoding_errors = encoding_errors

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.pure_python = pure_python
    self.results_type = results_type
    self.resultclass = MySQLResult

    # Choose the default cursor class from ``results_type``. These are
    # substring tests checked in order (e.g. 'dicts' matches 'dict').
    if cursorclass is not None:
        self.cursorclass = cursorclass
    else:
        candidates = [
            ('dict', DictCursor, SSDictCursor),
            ('namedtuple', NamedtupleCursor, SSNamedtupleCursor),
            ('numpy', NumpyCursor, SSNumpyCursor),
            ('arrow', ArrowCursor, SSArrowCursor),
            ('pandas', PandasCursor, SSPandasCursor),
            ('polars', PolarsCursor, SSPolarsCursor),
        ]
        self.cursorclass = Cursor if buffered else SSCursor
        for token, buffered_cls, unbuffered_cls in candidates:
            if token in self.results_type:
                self.cursorclass = buffered_cls if buffered else unbuffered_cls
                break

    if self.pure_python is False and _singlestoredb_accel is None:
        # Re-try the import only to print why the C extension fails to
        # load. It is bound to an alias on purpose: a plain
        # ``import _singlestoredb_accel`` would make that name local to this
        # whole function and break the module-level lookups in this method.
        try:
            import _singlestoredb_accel as _accel_probe  # noqa: F401
        except Exception:
            traceback.print_exc(file=sys.stderr)
        raise err.NotSupportedError(
            'pure_python=False, but the '
            'C extension can not be loaded',
        )

    # The C extension handles these types internally: swap each pure-Python
    # cursor for its accelerated ("SV") variant, pinning ``results_type``
    # where the variant requires it.
    if self.pure_python is not True and _singlestoredb_accel is not None:
        self.resultclass = MySQLResultSV
        sv_variants = [
            (Cursor, CursorSV, None),
            (SSCursor, SSCursorSV, None),
            (DictCursor, DictCursorSV, 'dicts'),
            (SSDictCursor, SSDictCursorSV, 'dicts'),
            (NamedtupleCursor, NamedtupleCursorSV, 'namedtuples'),
            (SSNamedtupleCursor, SSNamedtupleCursorSV, 'namedtuples'),
            (NumpyCursor, NumpyCursorSV, 'numpy'),
            (SSNumpyCursor, SSNumpyCursorSV, 'numpy'),
            (ArrowCursor, ArrowCursorSV, 'arrow'),
            (SSArrowCursor, SSArrowCursorSV, 'arrow'),
            (PandasCursor, PandasCursorSV, 'pandas'),
            (SSPandasCursor, SSPandasCursorSV, 'pandas'),
            (PolarsCursor, PolarsCursorSV, 'polars'),
            (SSPolarsCursor, SSPolarsCursorSV, 'polars'),
        ]
        for plain_cls, sv_cls, pinned_type in sv_variants:
            if self.cursorclass is plain_cls:
                self.cursorclass = sv_cls
                if pinned_type is not None:
                    self.results_type = pinned_type
                break

    self._result = None
    self._affected_rows = 0
    self.host_info = 'Not connected'

    # Specified autocommit mode. None means use the server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    # Copy so the shared default conversion table is never mutated below.
    conv = conv.copy()

    self.parse_json = parse_json
    self.invalid_values = (invalid_values or {}).copy()

    # Disable JSON parsing for Arrow (245: presumably the JSON type code).
    if self.results_type in ['arrow']:
        conv[245] = None
        self.parse_json = False

    # Disable date/time parsing for polars; let polars do the parsing.
    # NOTE(review): 7/10/12 are presumably TIMESTAMP/DATE/DATETIME — confirm.
    elif self.results_type in ['polars']:
        conv[7] = None
        conv[10] = None
        conv[12] = None

    # Needed for MySQLdb compatibility: non-int keys are Python types to
    # encode, int keys are column type codes to decode.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    # Bind the NaN/Inf-as-NULL options into the float encoder.
    if self.connection_params['nan_as_null'] or \
            self.connection_params['inf_as_null']:
        float_encoder = self.encoders.get(float)
        if float_encoder is not None:
            self.encoders[float] = functools.partial(
                float_encoder,
                nan_as_null=self.connection_params['nan_as_null'],
                inf_as_null=self.connection_params['inf_as_null'],
            )

    from .. import __version__ as VERSION_STRING

    if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
        VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']

    self._connect_attrs = {
        '_os': str(sys.platform),
        '_pid': str(os.getpid()),
        '_client_name': 'SingleStoreDB Python Client',
        '_client_version': VERSION_STRING,
    }

    if program_name:
        self._connect_attrs['program_name'] = program_name
    if conn_attrs is not None:
        # Do not overwrite the attributes that we set ourselves.
        for k, v in conn_attrs.items():
            if k not in self._connect_attrs:
                self._connect_attrs[k] = v

    self._is_committable = True
    self._in_sync = False
    self._tls_sni_servername = tls_sni_servername
    self._track_env = bool(track_env) or self.host == 'singlestore.com'
    self._enable_extended_data_types = enable_extended_data_types
    if vector_data_format.lower() in ['json', 'binary']:
        # NOTE(review): stored with the caller's casing (e.g. 'JSON'); confirm
        # that downstream comparisons are case-insensitive.
        self._vector_data_format = vector_data_format
    else:
        raise ValueError(
            'unknown value for vector_data_format, '
            f'expecting "json" or "binary": {vector_data_format}',
        )
    self._connection_info = {}
    # NOTE(review): the event registry now references this connection via
    # the bound method; confirm close() unsubscribes.
    events.subscribe(self._handle_event)

    # Environment-tracking connections are never connected eagerly here.
    if defer_connect or self._track_env:
        self._sock = None
    else:
        self.connect()
665
+
666
def _handle_event(self, data: Dict[str, Any]) -> None:
    """
    Record connection details published on the event bus.

    Only events named ``singlestore.portal.connection_updated`` are
    considered; a shallow copy of the whole payload replaces
    ``self._connection_info``. Every other event is ignored.

    """
    event_name = data.get('name', '')
    if event_name != 'singlestore.portal.connection_updated':
        return
    self._connection_info = dict(data)
669
+
670
@property
def messages(self):
    """
    Messages received from the server for this connection.

    Message collection is not implemented, so this is always an empty
    list.

    """
    # TODO: collect server messages
    return []
674
+
675
def __enter__(self):
    """Enter a ``with`` block; the connection itself is the target."""
    connection_obj = self
    return connection_obj
677
+
678
def __exit__(self, *exc_info):
    """
    Leave a ``with`` block by closing the connection.

    Exception details are ignored. Returning ``None`` lets any exception
    raised inside the block keep propagating.

    """
    self.close()
    return None
681
+
682
def _raise_mysql_exception(self, data):
    """
    Hand a raw error payload to ``err.raise_mysql_exception``.

    ``self`` is not used. Presumably the helper raises the matching
    MySQL error (TODO confirm in the ``err`` module).

    """
    raise_from_packet = err.raise_mysql_exception
    raise_from_packet(data)
684
+
685
def _create_ssl_ctx(self, sslp):
    """
    Build an ``ssl.SSLContext`` from SSL connection parameters.

    Parameters
    ----------
    sslp : ssl.SSLContext or dict
        An existing context (returned unchanged), or a dict that may
        contain ``ca``, ``capath``, ``cert``, ``key``, ``cipher``,
        ``check_hostname`` and ``verify_mode``.

    Returns
    -------
    ssl.SSLContext

    Notes
    -----
    ``verify_mode`` may be:

    - ``None``: verify only when a CA is configured;
    - a bool;
    - one of the strings ``none``/``0``/``false``/``no``, ``optional``,
      or ``required``/``1``/``true``/``yes`` (case-insensitive).

    Unrecognized values fall back to the ``None`` behaviour.

    Hostname checking is enabled only when a CA is configured and
    certificates are verified. ``ssl`` refuses to combine
    ``check_hostname=True`` with ``CERT_NONE``.

    """
    if isinstance(sslp, ssl.SSLContext):
        return sslp

    ca = sslp.get('ca')
    capath = sslp.get('capath')
    hasnoca = ca is None and capath is None
    ctx = ssl.create_default_context(cafile=ca, capath=capath)

    # Resolve the requested verification mode before touching ctx.
    default_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
    verify_mode_value = sslp.get('verify_mode')
    if verify_mode_value is None:
        verify_mode = default_mode
    elif isinstance(verify_mode_value, bool):
        verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
    else:
        if isinstance(verify_mode_value, str):
            verify_mode_value = verify_mode_value.lower()
        if verify_mode_value in ('none', '0', 'false', 'no'):
            verify_mode = ssl.CERT_NONE
        elif verify_mode_value == 'optional':
            verify_mode = ssl.CERT_OPTIONAL
        elif verify_mode_value in ('required', '1', 'true', 'yes'):
            verify_mode = ssl.CERT_REQUIRED
        else:
            verify_mode = default_mode

    check_hostname = not hasnoca and sslp.get('check_hostname', True)
    # check_hostname must be set before verify_mode, and must be False
    # for CERT_NONE; otherwise assigning CERT_NONE raises ValueError.
    ctx.check_hostname = bool(check_hostname) and verify_mode != ssl.CERT_NONE
    ctx.verify_mode = verify_mode

    if 'cert' in sslp:
        ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
    if 'cipher' in sslp:
        ctx.set_ciphers(sslp['cipher'])
    ctx.options |= ssl.OP_NO_SSLv2
    ctx.options |= ssl.OP_NO_SSLv3
    return ctx
716
+
717
def close(self):
    """
    Send the quit message and close the socket.

    See `Connection.close()
    <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    When ``self.host`` is ``'singlestore.com'``, only the current result
    is cleared and nothing else happens.

    Raises
    ------
    Error : If the connection is already closed.

    """
    self._result = None
    if self.host == 'singlestore.com':
        return
    if self._closed:
        raise err.Error('Already closed')

    events.unsubscribe(self._handle_event)
    self._closed = True

    sock = self._sock
    if sock is None:
        return

    quit_packet = struct.pack('<iB', 1, COMMAND.COM_QUIT)
    try:
        self._write_bytes(quit_packet)
    except Exception:
        # Best effort: the server may already have gone away.
        pass
    finally:
        self._force_close()
746
+
747
@property
def open(self):
    """Return True if the connection is open."""
    # A live connection always has a socket attached.
    has_socket = self._sock is not None
    return has_socket
751
+
752
def is_connected(self):
    """
    Return True if the connection is open.

    Equivalent to reading the ``open`` property.

    """
    return self.open
755
+
756
def _force_close(self):
    """
    Close the connection without sending a QUIT message.

    Any error raised while closing the socket is deliberately ignored.
    Both the socket and its read buffer are always reset to ``None``.
    This function is also installed as ``__del__``.

    """
    sock = self._sock
    if sock:
        try:
            sock.close()
        except:  # noqa: E722 - teardown must never raise
            pass
    self._sock = None
    self._rfile = None

__del__ = _force_close
767
+
768
def autocommit(self, value):
    """
    Enable or disable autocommit on the server.

    Parameters
    ----------
    value : bool-like
        Desired autocommit state; it is normalized with ``bool()``.

    ``SET AUTOCOMMIT`` is sent only when the normalized value differs
    from the state the server last reported. Truthy non-bool values such
    as ``2`` therefore do not cause a redundant round trip.

    """
    self.autocommit_mode = bool(value)
    current = self.get_autocommit()
    if self.autocommit_mode != current:
        self._send_autocommit_mode()
774
+
775
def get_autocommit(self):
    """
    Retrieve autocommit status.

    Reads the AUTOCOMMIT bit of the most recent server status flags
    stored on this connection. No query is sent.

    """
    autocommit_bit = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return (self.server_status & autocommit_bit) != 0
778
+
779
def _read_ok_packet(self):
    """
    Read the next packet and require it to be an OK packet.

    ``self.server_status`` is updated from the packet.

    Returns
    -------
    OKPacketWrapper

    Raises
    ------
    OperationalError
        With ``CR_COMMANDS_OUT_OF_SYNC`` if the packet is not an OK packet.

    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        ok = OKPacketWrapper(packet)
        self.server_status = ok.server_status
        return ok
    raise err.OperationalError(
        CR.CR_COMMANDS_OUT_OF_SYNC,
        'Command Out of Sync',
    )
789
+
790
def _send_autocommit_mode(self):
    """
    Push ``self.autocommit_mode`` to the server with ``SET AUTOCOMMIT``.

    The statement is logged, sent, and its OK response is validated.

    """
    sql = 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode)
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    self._read_ok_packet()
797
+
798
def begin(self):
    """
    Begin transaction.

    The statement is always logged. It is not sent when ``self.host``
    is ``'singlestore.com'``.

    """
    log_query('BEGIN')
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'BEGIN')
        self._read_ok_packet()
805
+
806
def commit(self):
    """
    Commit changes to stable storage.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.

    Nothing is sent when ``_is_committable`` is False (``query()`` sets
    it after a statement routed to a fusion handler) or when
    ``self.host`` is ``'singlestore.com'``. In both cases the flag is
    reset to True.

    """
    log_query('COMMIT')
    skip = (not self._is_committable) or self.host == 'singlestore.com'
    if skip:
        self._is_committable = True
        return
    self._execute_command(COMMAND.COM_QUERY, 'COMMIT')
    self._read_ok_packet()
820
+
821
def rollback(self):
    """
    Roll back the current transaction.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.

    Skipped, with ``_is_committable`` reset to True, when
    ``_is_committable`` is False or ``self.host`` is
    ``'singlestore.com'``.

    """
    log_query('ROLLBACK')
    if self._is_committable and self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')
        self._read_ok_packet()
        return
    self._is_committable = True
835
+
836
def show_warnings(self):
    """
    Send the "SHOW WARNINGS" SQL command.

    Returns
    -------
    The ``rows`` of the fully read ``SHOW WARNINGS`` result.

    """
    sql = 'SHOW WARNINGS'
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    warnings_result = self.resultclass(self)
    warnings_result.read()
    return warnings_result.rows
843
+
844
def select_db(self, db):
    """
    Set current db.

    Sends ``COM_INIT_DB`` and validates the OK response. ``self.db`` is
    not modified.

    Parameters
    ----------
    db : str
        The name of the db.

    """
    init_db = COMMAND.COM_INIT_DB
    self._execute_command(init_db, db)
    self._read_ok_packet()
854
+
855
def escape(self, obj, mapping=None):
    """
    Escape whatever value is passed.

    Non-standard, for internal use; do not use this in your applications.

    Strings are quoted via ``escape_string`` and bytes-like values via
    ``_quote_bytes``. Everything else goes through
    ``converters.escape_item`` with ``mapping``, which defaults to
    ``self.encoders``.

    """
    if isinstance(obj, str):
        return "'{}'".format(self.escape_string(obj))
    if isinstance(obj, (bytes, bytearray)):
        return self._quote_bytes(obj)
    encoders = self.encoders if mapping is None else mapping
    return converters.escape_item(obj, self.charset, mapping=encoders)
870
+
871
def literal(self, obj):
    """
    Alias for escape().

    Non-standard, for internal use; do not use this in your applications.

    Always escapes with the connection's own ``encoders`` mapping.

    """
    encoders = self.encoders
    return self.escape(obj, encoders)
879
+
880
def escape_string(self, s):
    """
    Escape a string value.

    If the server reports ``NO_BACKSLASH_ESCAPES``, only single quotes
    are doubled. Otherwise ``converters.escape_string`` does the work.

    """
    no_backslash_escapes = (
        self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    )
    if not no_backslash_escapes:
        return converters.escape_string(s)
    return s.replace("'", "''")
885
+
886
def _quote_bytes(self, s):
    """
    Render a bytes-like value as an SQL literal.

    If the server reports ``NO_BACKSLASH_ESCAPES``, the value becomes a
    hex literal (``X'..'``), prefixed with ``_binary`` when
    ``self._binary_prefix`` is set. Otherwise
    ``converters.escape_bytes`` is used.

    """
    no_backslash_escapes = (
        self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    )
    if not no_backslash_escapes:
        return converters.escape_bytes(s)
    hex_digits = s.hex()
    if self._binary_prefix:
        return "_binary X'{}'".format(hex_digits)
    return "X'{}'".format(hex_digits)
892
+
893
def cursor(self):
    """
    Create a new cursor to execute queries with.

    The cursor is an instance of ``self.cursorclass``, constructed with
    this connection as its only argument.

    """
    cursor_factory = self.cursorclass
    return cursor_factory(self)
896
+
897
+ # The following methods are INTERNAL USE ONLY (called from Cursor)
898
def query(self, sql, unbuffered=False, infile_stream=None):
    """
    Run a query on the server.

    Internal use only.

    Parameters
    ----------
    sql : str or bytes
        The statement to run. If ``fusion.get_handler`` returns a
        handler, the statement goes to ``fusion.execute``. Otherwise it
        is sent as ``COM_QUERY``, with ``str`` encoded using
        ``self.encoding`` and ``surrogateescape``.
    unbuffered : bool, optional
        Passed through to ``_read_query_result``.
    infile_stream : optional
        Stored as ``self._local_infile_stream`` for the duration of this
        query only; presumably a readable stream used for
        ``LOAD DATA LOCAL`` (TODO confirm with the result reader).

    Returns
    -------
    The affected-row count of the result.

    """
    handler = fusion.get_handler(sql)
    if handler is not None:
        # commit()/rollback() become no-ops after a fusion statement.
        self._is_committable = False
        self._result = fusion.execute(self, sql, handler=handler)
        self._affected_rows = self._result.affected_rows
    else:
        self._is_committable = True
        if isinstance(sql, str):
            sql = sql.encode(self.encoding, 'surrogateescape')
        self._local_infile_stream = infile_stream
        try:
            self._execute_command(COMMAND.COM_QUERY, sql)
            self._affected_rows = self._read_query_result(unbuffered=unbuffered)
        finally:
            # Drop the reference even if the query fails, so the
            # connection does not keep the caller's stream alive.
            self._local_infile_stream = None
    return self._affected_rows
921
+
922
def next_result(self, unbuffered=False):
    """
    Retrieve the next result set.

    Internal use only.

    Returns
    -------
    The affected-row count of the newly read result, which is also
    stored as ``self._affected_rows``.

    """
    rows = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = rows
    return rows
931
+
932
def affected_rows(self):
    """
    Return number of affected rows.

    Internal use only. This is the value recorded by the most recent
    ``query()`` or ``next_result()`` call.

    """
    count = self._affected_rows
    return count
940
+
941
def kill(self, thread_id):
    """
    Execute kill command.

    Internal use only.

    Parameters
    ----------
    thread_id : int
        Server thread to kill, packed as an unsigned 32-bit integer.

    Returns
    -------
    OKPacketWrapper
        The server's OK response.

    """
    self._execute_command(
        COMMAND.COM_PROCESS_KILL,
        struct.pack('<I', thread_id),
    )
    return self._read_ok_packet()
951
+
952
def ping(self, reconnect=True):
    """
    Check if the server is alive.

    Parameters
    ----------
    reconnect : bool, optional
        If the connection is closed, reconnect.

    Raises
    ------
    Error : If the connection is closed and reconnect=False.

    Notes
    -----
    At most one reconnect happens per call. After reconnecting, the
    retry ping is issued with ``reconnect=False``.

    """
    if self._sock is None:
        if not reconnect:
            raise err.Error('Already closed')
        self.connect()
        reconnect = False
    try:
        self._execute_command(COMMAND.COM_PING, '')
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)
981
+
982
def set_charset(self, charset):
    """
    Deprecated. Use set_character_set() instead.

    Old PyMySQL provided this name, while MySQLdb calls it
    ``set_character_set``. It is kept as a compatibility shim.

    """
    new_charset = charset
    self.set_character_set(new_charset)
989
+
990
def set_character_set(self, charset, collation=None):
    """
    Set charset (and collation) on the server.

    Send "SET NAMES charset [COLLATE collation]" query.
    Update Connection.encoding based on charset.

    Parameters
    ----------
    charset : str
        The charset to enable.
    collation : str, optional
        The collation value

    """
    # Look the charset up first so an unsupported one fails before
    # anything is sent.
    encoding = charset_by_name(charset).encoding

    query = f'SET NAMES {charset}'
    if collation:
        query += f' COLLATE {collation}'
    self._execute_command(COMMAND.COM_QUERY, query)
    self._read_packet()

    self.charset = charset
    self.encoding = encoding
    self.collation = collation
1017
+
1018
def _sync_connection(self):
    """
    Synchronize connection with env variable.

    Active only when environment tracking is enabled and no sync is
    already in progress. The target URL comes from the last
    ``connection_updated`` event, falling back to ``SINGLESTOREDB_URL``.

    If host, port, user and database already match the live
    connection, nothing happens. Otherwise the old socket is closed, the
    connection parameters are replaced, and ``connect()`` is called.

    Raises
    ------
    InterfaceError
        If the URL's host is still ``singlestore.com``.

    """
    if self._in_sync or not self._track_env:
        return

    url = self._connection_info.get('connection_url')
    if not url:
        url = os.environ.get('SINGLESTOREDB_URL')
    if not url:
        return

    out = {}
    urlp = connection._parse_url(url)
    out.update(urlp)

    out = connection._cast_params(out)

    # Set default port based on driver.
    if 'port' not in out or not out['port']:
        out['port'] = int(get_option('port') or 3306)

    # If there is no user and the password is empty, remove the password key.
    if 'user' not in out and not out.get('password', None):
        out.pop('password', None)

    if out['host'] == 'singlestore.com':
        raise err.InterfaceError(0, 'Connection URL has not been established')

    # _request_authentication() stores self.user as bytes, while the
    # parsed URL yields text. Compare both as text; otherwise this check
    # never matches and every command forces a reconnect.
    current_user = self.user
    if isinstance(current_user, bytes):
        current_user = current_user.decode(self.encoding)
    target_user = out['user']
    if isinstance(target_user, bytes):
        target_user = target_user.decode(self.encoding)

    # If it's just a password change, we don't need to reconnect
    if self._sock is not None and \
            (self.host, self.port, current_user, self.db) == \
            (out['host'], out['port'], target_user, out.get('database')):
        return

    # Release the socket of the connection being replaced.
    self._force_close()

    self.host = out['host']
    self.port = out['port']
    self.user = out['user']
    # The URL may omit the password entirely.
    password = out.get('password')
    if isinstance(password, str):
        self.password = password.encode('latin-1')
    else:
        self.password = password or b''
    self.db = out.get('database')
    try:
        self._in_sync = True
        self.connect()
    finally:
        self._in_sync = False
1068
+
1069
def connect(self, sock=None):
    """
    Connect to server using existing parameters.

    Internal use only.

    Opens a UNIX-domain or TCP socket unless ``sock`` is supplied, then
    performs the handshake and authentication. Afterwards it applies the
    session settings: character set/collation, ``sql_mode``, extended
    type metadata, vector data format, ``init_command`` and autocommit.

    Parameters
    ----------
    sock : socket.socket, optional
        An already-connected socket to use as-is.

    Raises
    ------
    OperationalError
        With ``CR_CONN_HOST_ERROR`` when an ``OSError`` occurs. Any
        other exception is re-raised unchanged. On every failure the
        sockets are closed and ``_sock``/``_rfile`` are reset to
        ``None``.

    """
    self._closed = False
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = 'Localhost via UNIX socket'
                self._secure = True
                if DEBUG:
                    print('connected using unix_socket')
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs['source_address'] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs,
                        )
                        break
                    except OSError as e:
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = 'socket %s:%d' % (self.host, self.port)
                if DEBUG:
                    print('connected using socket')
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Blocking mode from here on; _read_bytes/_write_bytes apply
            # their own timeouts when configured.
            sock.settimeout(None)

        self._sock = sock
        self._rfile = sock.makefile('rb')
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        # Send "SET NAMES" query on init for:
        # - Ensure charaset (and collation) is set to the server.
        #   - collation_id in handshake packet may be ignored.
        # - If collation is not specified, we don't know what is server's
        #   default collation for the charset. For example, default collation
        #   of utf8mb4 is:
        #   - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
        #   - MySQL 8.0: utf8mb4_0900_ai_ci
        #
        # Reference:
        # - https://github.com/PyMySQL/PyMySQL/issues/1092
        # - https://github.com/wagtail/wagtail/issues/9477
        # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
        self.set_character_set(self.charset, self.collation)

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute('SET sql_mode=%s', (self.sql_mode,))
            c.close()

        if self._enable_extended_data_types:
            c = self.cursor()
            try:
                c.execute('SET @@SESSION.enable_extended_types_metadata=on')
            except self.OperationalError:
                pass
            c.close()

        if self._vector_data_format:
            c = self.cursor()
            try:
                val = self._vector_data_format
                c.execute(f'SET @@SESSION.vector_type_project_format={val}')
            except self.OperationalError:
                pass
            c.close()

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)

    except BaseException as e:
        # After the TLS upgrade in _request_authentication, self._sock is
        # an SSLSocket that has taken over the descriptor of ``sock``, so
        # closing ``sock`` alone would leak it. Close both, then clear
        # the attributes so a failed connect never leaves ``open`` True.
        for stale in (getattr(self, '_sock', None), sock):
            if stale is not None:
                try:
                    stale.close()
                except:  # noqa
                    pass
        self._sock = None
        self._rfile = None

        if isinstance(e, (OSError, IOError, socket.error)):
            exc = err.OperationalError(
                CR.CR_CONN_HOST_ERROR,
                f'Can\'t connect to MySQL server on {self.host!r} ({e})',
            )
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc

        # If e is neither DatabaseError or IOError, It's a bug.
        # But raising AssertionError hides original error.
        # So just reraise it.
        raise
1184
+
1185
def write_packet(self, payload):
    """
    Writes an entire "mysql packet" in its entirety to the network.

    Adds its length and sequence number: a 24-bit length
    (``_pack_int24``) followed by one sequence byte. The sequence
    counter then advances modulo 256.

    """
    # Internal note: when you build packet manually and calls _write_bytes()
    # directly, you should set self._next_seq_id properly.
    header = _pack_int24(len(payload)) + bytes((self._next_seq_id,))
    packet = header + payload
    if DEBUG:
        dump_packet(packet)
    self._write_bytes(packet)
    self._next_seq_id = (self._next_seq_id + 1) % 256
1199
+
1200
def _read_packet(self, packet_type=MysqlPacket):
    """
    Read an entire "mysql packet" in its entirety from the network.

    A wire packet carrying exactly 0xFFFFFF payload bytes means more
    follow, so such packets are concatenated into one buffer. Every
    wire packet's sequence number must equal ``self._next_seq_id``,
    which advances modulo 256.

    Parameters
    ----------
    packet_type : type, optional
        Wrapper class, called as ``packet_type(payload, self.encoding)``.

    Raises
    ------
    OperationalError : If the connection to the MySQL server is lost.
    InternalError : If the packet sequence number is wrong.

    Returns
    -------
    MysqlPacket
        An instance of ``packet_type``. If it is an error packet, its
        ``raise_for_error()`` is called instead of returning.

    """
    buff = bytearray()
    while True:
        packet_header = self._read_bytes(4)
        # if DEBUG: dump_packet(packet_header)

        # Header: 3-byte little-endian payload length + 1-byte sequence id.
        btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
        bytes_to_read = btrl + (btrh << 16)
        if packet_number != self._next_seq_id:
            self._force_close()
            if packet_number == 0:
                # MariaDB sends error packet with seqno==0 when shutdown
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    'Lost connection to MySQL server during query',
                )
            raise err.InternalError(
                'Packet sequence number wrong - got %d expected %d'
                % (packet_number, self._next_seq_id),
            )
        self._next_seq_id = (self._next_seq_id + 1) % 256

        recv_data = self._read_bytes(bytes_to_read)
        if DEBUG:
            dump_packet(recv_data)
        buff += recv_data
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        if bytes_to_read == 0xFFFFFF:
            continue
        if bytes_to_read < MAX_PACKET_LEN:
            break

    packet = packet_type(bytes(buff), self.encoding)
    if packet.is_error_packet():
        # An error ends any in-progress unbuffered result before raising.
        if self._result is not None and self._result.unbuffered_active is True:
            self._result.unbuffered_active = False
        packet.raise_for_error()
    return packet
1251
+
1252
def _read_bytes(self, num_bytes):
    """
    Read exactly ``num_bytes`` bytes from the socket's read buffer.

    ``self._read_timeout``, when set, is applied to the socket first.
    Reads interrupted by ``EINTR`` are retried.

    Raises
    ------
    OperationalError
        With ``CR_SERVER_LOST`` if the read fails with ``OSError`` or
        returns fewer bytes than requested. The connection is
        force-closed on every failure. Non-``OSError`` exceptions are
        re-raised unchanged after the close.

    """
    if self._read_timeout is not None:
        self._sock.settimeout(self._read_timeout)

    while True:
        try:
            chunk = self._rfile.read(num_bytes)
        except OSError as exc:
            if exc.errno == errno.EINTR:
                continue
            self._force_close()
            raise err.OperationalError(
                CR.CR_SERVER_LOST,
                'Lost connection to MySQL server during query (%s)' % (exc,),
            )
        except BaseException:
            # Don't convert unknown exception to MySQLError.
            self._force_close()
            raise
        else:
            break

    if len(chunk) >= num_bytes:
        return chunk
    self._force_close()
    raise err.OperationalError(
        CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',
    )
1277
+
1278
def _write_bytes(self, data):
    """
    Send ``data`` in full over the socket.

    ``self._write_timeout``, when set, is applied to the socket first.

    Raises
    ------
    OperationalError
        With ``CR_SERVER_GONE_ERROR`` if sending fails with ``OSError``.
        The connection is force-closed first.

    """
    sock = self._sock
    if self._write_timeout is not None:
        sock.settimeout(self._write_timeout)
    try:
        sock.sendall(data)
    except OSError as e:
        self._force_close()
        raise err.OperationalError(
            CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',
        )
1288
+
1289
def _read_query_result(self, unbuffered=False):
    """
    Read the result of the command just sent and make it current.

    In unbuffered mode the result object is only constructed (with
    ``unbuffered`` passed through). Otherwise it is read completely.
    ``self.server_status`` is refreshed when the result carries a status.

    Returns
    -------
    The result's ``affected_rows``.

    """
    self._result = None
    if not unbuffered:
        result = self.resultclass(self)
        result.read()
    else:
        result = self.resultclass(self, unbuffered=unbuffered)
    self._result = result
    status = result.server_status
    if status is not None:
        self.server_status = status
    return result.affected_rows
1300
+
1301
def insert_id(self):
    """Return the current result's insert id, or 0 if there is no result."""
    result = self._result
    return result.insert_id if result else 0
1306
+
1307
def _execute_command(self, command, sql):
    """
    Execute command.

    Sends ``command`` with ``sql`` as its argument. First the connection
    is re-synchronized via ``_sync_connection()`` and any pending result
    is drained. Arguments too long for one packet are split across
    several packets.

    Parameters
    ----------
    command : int
        A ``COMMAND.*`` code, packed as the first payload byte.
    sql : str or bytes
        Command argument; ``str`` is encoded with ``self.encoding``.

    Raises
    ------
    InterfaceError : If the connection is closed.
    ValueError : If no username was specified.

    """
    self._sync_connection()

    if self._sock is None:
        raise err.InterfaceError(0, 'The connection has been closed')

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn('Previous unbuffered result was left incomplete')
            self._result._finish_unbuffered_query()
        # Consume remaining result sets; next_result() replaces self._result.
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self.encoding)

    packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # tiny optimization: build first packet manually instead of
    # calling self.write_packet()
    # '<i' packs packet_size into 4 bytes; assuming MAX_PACKET_LEN fits in
    # 24 bits, the top byte is 0 and serves as sequence id 0.
    prelude = struct.pack('<iB', packet_size, command)
    packet = prelude + sql[: packet_size - 1]
    self._write_bytes(packet)
    if DEBUG:
        dump_packet(packet)
    self._next_seq_id = 1

    if packet_size < MAX_PACKET_LEN:
        return

    # Send the remainder in MAX_PACKET_LEN chunks; a final shorter
    # (possibly empty) packet marks the end.
    sql = sql[packet_size - 1:]
    while True:
        packet_size = min(MAX_PACKET_LEN, len(sql))
        self.write_packet(sql[:packet_size])
        sql = sql[packet_size:]
        if not sql and packet_size < MAX_PACKET_LEN:
            break
1356
+
1357
def _request_authentication(self):  # noqa: C901
    """
    Send the handshake response and complete authentication.

    Steps:

    1. If ``self.ssl`` is set and the server advertises SSL, upgrade the
       socket to TLS.
    2. Build the initial auth response for the plugin the server
       advertised.
    3. Follow an AuthSwitchRequest or extra-auth-data exchange if the
       server asks for one.

    Raises
    ------
    ValueError
        If no username was specified.
    OperationalError
        On an unusable auth switch request or unexpected extra auth data.

    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split('.', 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError('Did not specify a username')

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, str):
        self.user = self.user.encode(self.encoding)

    data_init = struct.pack(
        '<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',
    )

    if self.ssl and self.server_capabilities & CLIENT.SSL:
        # Send the fixed-size header alone, then upgrade to TLS.
        self.write_packet(data_init)

        hostname = self.host
        if self._tls_sni_servername:
            hostname = self._tls_sni_servername
        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)
        self._rfile = self._sock.makefile('rb')
        self._secure = True

    data = data_init + self.user + b'\0'

    authresp = b''
    plugin_name = None

    if self._auth_plugin_name == '':
        plugin_name = b''
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'mysql_native_password':
        plugin_name = b'mysql_native_password'
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'caching_sha2_password':
        plugin_name = b'caching_sha2_password'
        if self.password:
            if DEBUG:
                print('caching_sha2: trying fast path')
            authresp = _auth.scramble_caching_sha2(self.password, self.salt)
        else:
            if DEBUG:
                print('caching_sha2: empty password')
    elif self._auth_plugin_name == 'sha256_password':
        plugin_name = b'sha256_password'
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            authresp = self.password + b'\0'
        elif self.password:
            authresp = b'\1'  # request public key
        else:
            authresp = b'\0'  # empty password

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack('B', len(authresp)) + authresp
    else:  # pragma: no cover - no testing against servers w/o secure auth (>=5.0)
        data += authresp + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        db = self.db
        if isinstance(db, str):
            db = db.encode(self.encoding)
        data += (db or b'') + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        data += (plugin_name or b'') + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        connect_attrs = b''
        for k, v in self._connect_attrs.items():
            k = k.encode('utf-8')
            connect_attrs += _lenenc_int(len(k)) + k
            v = v.encode('utf-8')
            connect_attrs += _lenenc_int(len(v)) + v
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        if DEBUG:
            print('received auth switch')
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (
            self.server_capabilities & CLIENT.PLUGIN_AUTH
            and plugin_name is not None
        ):
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            raise err.OperationalError('received unknown auth switch request')
    elif auth_packet.is_extra_auth_data():
        if DEBUG:
            print('received extra data')
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        if self._auth_plugin_name == 'caching_sha2_password':
            auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
        elif self._auth_plugin_name == 'sha256_password':
            auth_packet = _auth.sha256_password_auth(self, auth_packet)
        else:
            # Format the message; passing the name as a second exception
            # argument left '%r' unformatted.
            raise err.OperationalError(
                'Received extra packet for auth method %r'
                % (self._auth_plugin_name,),
            )

    if DEBUG:
        print('Succeed to auth')
1470
+
1471
def _process_auth(self, plugin_name, auth_packet):
    """
    Complete authentication after the server requested a plugin switch.

    A handler registered in ``self._auth_plugin_map`` takes precedence,
    and its ``authenticate(auth_packet)`` result is returned directly.
    A handler without ``authenticate`` is tolerated only for ``dialog``,
    which then uses the handler's ``prompt``. Otherwise the built-in
    implementation for ``plugin_name`` produces the response.

    Parameters
    ----------
    plugin_name : bytes
        Plugin name sent by the server.
    auth_packet : MysqlPacket
        The server's switch request, read past the plugin name.

    Returns
    -------
    The last packet read from the server.

    Raises
    ------
    OperationalError
        If no implementation exists for ``plugin_name`` or a registered
        handler is unusable.

    """
    handler = self._get_auth_plugin_handler(plugin_name)
    if handler:
        try:
            return handler.authenticate(auth_packet)
        except AttributeError:
            # A handler without authenticate() is only usable for dialog,
            # which falls through to the prompt loop below.
            if plugin_name != b'dialog':
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s'"
                    ' not loaded: - %r missing authenticate method'
                    % (plugin_name, type(handler)),
                )
    if plugin_name == b'caching_sha2_password':
        return _auth.caching_sha2_password_auth(self, auth_packet)
    elif plugin_name == b'sha256_password':
        return _auth.sha256_password_auth(self, auth_packet)
    elif plugin_name == b'mysql_native_password':
        data = _auth.scramble_native_password(self.password, auth_packet.read_all())
    elif plugin_name == b'client_ed25519':
        data = _auth.ed25519_password(self.password, auth_packet.read_all())
    elif plugin_name == b'mysql_old_password':
        data = (
            _auth.scramble_old_password(self.password, auth_packet.read_all())
            + b'\0'
        )
    elif plugin_name == b'mysql_clear_password':
        # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
        data = self.password + b'\0'
    elif plugin_name == b'auth_gssapi_client':
        data = _auth.gssapi_auth(auth_packet.read_all())
    elif plugin_name == b'dialog':
        pkt = auth_packet
        while True:
            flag = pkt.read_uint8()
            echo = (flag & 0x06) == 0x02
            # Bit 0 marks the last question; the loop ends after it.
            last = (flag & 0x01) == 0x01
            prompt = pkt.read_all()

            if prompt == b'Password: ':
                self.write_packet(self.password + b'\0')
            elif handler:
                resp = 'no response - TypeError within plugin.prompt method'
                try:
                    resp = handler.prompt(echo, prompt)
                    self.write_packet(resp + b'\0')
                except AttributeError:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                        "Authentication plugin '%s'"
                        ' not loaded: - %r missing prompt method'
                        % (plugin_name, handler),
                    )
                except TypeError:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_ERR,
                        "Authentication plugin '%s'"
                        " %r didn't respond with string. Returned '%r' to prompt %r"
                        % (plugin_name, handler, resp, prompt),
                    )
            else:
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s' not configured" % (plugin_name,),
                )
            pkt = self._read_packet()
            pkt.check_error()
            if pkt.is_ok_packet() or last:
                break
        return pkt
    else:
        raise err.OperationalError(
            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
            "Authentication plugin '%s' not configured" % plugin_name,
        )

    # Plugins that built ``data`` above send it as one response packet.
    self.write_packet(data)
    pkt = self._read_packet()
    pkt.check_error()
    return pkt
1551
+
1552
def _get_auth_plugin_handler(self, plugin_name):
    """
    Instantiate the registered handler for an auth plugin, if any.

    ``self._auth_plugin_map`` is looked up with ``plugin_name`` as given.
    For a ``bytes`` name with no match, it is looked up again with the
    ASCII-decoded name. A matching class is called with this connection.

    Returns
    -------
    The handler instance, or ``None`` when nothing is registered.

    Raises
    ------
    OperationalError
        ``CR_AUTH_PLUGIN_CANNOT_LOAD`` if constructing the handler
        raises ``TypeError``.

    """
    plugin_map = self._auth_plugin_map
    plugin_class = plugin_map.get(plugin_name)
    if not plugin_class and isinstance(plugin_name, bytes):
        plugin_class = plugin_map.get(plugin_name.decode('ascii'))
    if not plugin_class:
        return None
    try:
        return plugin_class(self)
    except TypeError:
        raise err.OperationalError(
            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
            "Authentication plugin '%s'"
            ' not loaded: - %r cannot be constructed with connection object'
            % (plugin_name, plugin_class),
        )
1569
+
1570
+ # _mysql support
1571
def thread_id(self):
    """Return the server thread id recorded from the initial handshake."""
    thread_ids = self.server_thread_id
    return thread_ids[0]
1573
+
1574
def character_set_name(self):
    """Return the connection's configured character set name."""
    name = self.charset
    return name
1576
+
1577
def get_host_info(self):
    """Return the transport description set by ``connect()``."""
    info = self.host_info
    return info
1579
+
1580
def get_proto_info(self):
    """Return the protocol version byte from the server handshake."""
    version = self.protocol_version
    return version
1582
+
1583
def _get_server_information(self):
    """
    Parse the server's initial handshake packet.

    Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id``, ``salt`` and ``server_capabilities``.

    When the packet has the extended section, it also sets
    ``server_language``, ``server_charset`` and ``server_status``,
    extends the capabilities and salt, and sets ``_auth_plugin_name``
    if the server advertises ``PLUGIN_AUTH``.

    """
    i = 0
    packet = self._read_packet()
    data = packet.get_all_data()

    self.protocol_version = data[i]
    i += 1

    server_end = data.find(b'\0', i)
    self.server_version = data[i:server_end].decode('latin1')
    i = server_end + 1

    self.server_thread_id = struct.unpack('<I', data[i: i + 4])
    i += 4

    self.salt = data[i: i + 8]
    i += 9  # 8 + 1(filler)

    self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]
    i += 2

    # Length of the second salt part. It stays 0 when the handshake lacks
    # the extended section, so the salt extension below becomes a no-op
    # instead of hitting an unbound local.
    salt_len = 0
    if len(data) >= i + 6:
        lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])
        i += 6
        # TODO: deprecate server_language and server_charset.
        # mysqlclient-python doesn't provide it.
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None

        self.server_status = stat
        if DEBUG:
            print('server_status: %x' % stat)

        self.server_capabilities |= cap_h << 16
        if DEBUG:
            print('salt_len:', salt_len)
        salt_len = max(12, salt_len - 9)

    # reserved
    i += 10

    if len(data) >= i + salt_len:
        # salt_len includes auth_plugin_data_part_1 and filler
        self.salt += data[i: i + salt_len]
        i += salt_len

    i += 1
    # AUTH PLUGIN NAME may appear here.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/
        # connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        server_end = data.find(b'\0', i)
        if server_end < 0:  # pragma: no cover - very specific upstream bug
            # not found \0 and last field so take it all
            self._auth_plugin_name = data[i:].decode('utf-8')
        else:
            self._auth_plugin_name = data[i:server_end].decode('utf-8')
1648
+
1649
+ def get_server_info(self):
1650
+ return self.server_version
1651
+
1652
+ Warning = err.Warning
1653
+ Error = err.Error
1654
+ InterfaceError = err.InterfaceError
1655
+ DatabaseError = err.DatabaseError
1656
+ DataError = err.DataError
1657
+ OperationalError = err.OperationalError
1658
+ IntegrityError = err.IntegrityError
1659
+ InternalError = err.InternalError
1660
+ ProgrammingError = err.ProgrammingError
1661
+ NotSupportedError = err.NotSupportedError
1662
+
1663
+
1664
+ class MySQLResult:
1665
+ """
1666
+ Results of a SQL query.
1667
+
1668
+ Parameters
1669
+ ----------
1670
+ connection : Connection
1671
+ The connection the result came from.
1672
+ unbuffered : bool, optional
1673
+ Should the reads be unbuffered?
1674
+
1675
+ """
1676
+
1677
+ def __init__(self, connection, unbuffered=False):
1678
+ self.connection = connection
1679
+ self.affected_rows = None
1680
+ self.insert_id = None
1681
+ self.server_status = None
1682
+ self.warning_count = 0
1683
+ self.message = None
1684
+ self.field_count = 0
1685
+ self.description = None
1686
+ self.rows = None
1687
+ self.has_next = None
1688
+ self.unbuffered_active = False
1689
+ self.converters = []
1690
+ self.fields = []
1691
+ self.encoding_errors = self.connection.encoding_errors
1692
+ if unbuffered:
1693
+ try:
1694
+ self.init_unbuffered_query()
1695
+ except Exception:
1696
+ self.connection = None
1697
+ self.unbuffered_active = False
1698
+ raise
1699
+
1700
+ def __del__(self):
1701
+ if self.unbuffered_active:
1702
+ self._finish_unbuffered_query()
1703
+
1704
+ def read(self):
1705
+ try:
1706
+ first_packet = self.connection._read_packet()
1707
+
1708
+ if first_packet.is_ok_packet():
1709
+ self._read_ok_packet(first_packet)
1710
+ elif first_packet.is_load_local_packet():
1711
+ self._read_load_local_packet(first_packet)
1712
+ else:
1713
+ self._read_result_packet(first_packet)
1714
+ finally:
1715
+ self.connection = None
1716
+
1717
+ def init_unbuffered_query(self):
1718
+ """
1719
+ Initialize an unbuffered query.
1720
+
1721
+ Raises
1722
+ ------
1723
+ OperationalError : If the connection to the MySQL server is lost.
1724
+ InternalError : Other errors.
1725
+
1726
+ """
1727
+ self.unbuffered_active = True
1728
+ first_packet = self.connection._read_packet()
1729
+
1730
+ if first_packet.is_ok_packet():
1731
+ self._read_ok_packet(first_packet)
1732
+ self.unbuffered_active = False
1733
+ self.connection = None
1734
+ elif first_packet.is_load_local_packet():
1735
+ self._read_load_local_packet(first_packet)
1736
+ self.unbuffered_active = False
1737
+ self.connection = None
1738
+ else:
1739
+ self.field_count = first_packet.read_length_encoded_integer()
1740
+ self._get_descriptions()
1741
+
1742
+ # Apparently, MySQLdb picks this number because it's the maximum
1743
+ # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
1744
+ # we set it to this instead of None, which would be preferred.
1745
+ self.affected_rows = 18446744073709551615
1746
+
1747
+ def _read_ok_packet(self, first_packet):
1748
+ ok_packet = OKPacketWrapper(first_packet)
1749
+ self.affected_rows = ok_packet.affected_rows
1750
+ self.insert_id = ok_packet.insert_id
1751
+ self.server_status = ok_packet.server_status
1752
+ self.warning_count = ok_packet.warning_count
1753
+ self.message = ok_packet.message
1754
+ self.has_next = ok_packet.has_next
1755
+
1756
+ def _read_load_local_packet(self, first_packet):
1757
+ if not self.connection._local_infile:
1758
+ raise RuntimeError(
1759
+ '**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
1760
+ )
1761
+ load_packet = LoadLocalPacketWrapper(first_packet)
1762
+ sender = LoadLocalFile(load_packet.filename, self.connection)
1763
+ try:
1764
+ sender.send_data()
1765
+ except Exception:
1766
+ self.connection._read_packet() # skip ok packet
1767
+ raise
1768
+
1769
+ ok_packet = self.connection._read_packet()
1770
+ if (
1771
+ not ok_packet.is_ok_packet()
1772
+ ): # pragma: no cover - upstream induced protocol error
1773
+ raise err.OperationalError(
1774
+ CR.CR_COMMANDS_OUT_OF_SYNC,
1775
+ 'Commands Out of Sync',
1776
+ )
1777
+ self._read_ok_packet(ok_packet)
1778
+
1779
+ def _check_packet_is_eof(self, packet):
1780
+ if not packet.is_eof_packet():
1781
+ return False
1782
+ # TODO: Support CLIENT.DEPRECATE_EOF
1783
+ # 1) Add DEPRECATE_EOF to CAPABILITIES
1784
+ # 2) Mask CAPABILITIES with server_capabilities
1785
+ # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper
1786
+ # instead of EOFPacketWrapper
1787
+ wp = EOFPacketWrapper(packet)
1788
+ self.warning_count = wp.warning_count
1789
+ self.has_next = wp.has_next
1790
+ return True
1791
+
1792
+ def _read_result_packet(self, first_packet):
1793
+ self.field_count = first_packet.read_length_encoded_integer()
1794
+ self._get_descriptions()
1795
+ self._read_rowdata_packet()
1796
+
1797
+ def _read_rowdata_packet_unbuffered(self):
1798
+ # Check if in an active query
1799
+ if not self.unbuffered_active:
1800
+ return
1801
+
1802
+ # EOF
1803
+ packet = self.connection._read_packet()
1804
+ if self._check_packet_is_eof(packet):
1805
+ self.unbuffered_active = False
1806
+ self.connection = None
1807
+ self.rows = None
1808
+ return
1809
+
1810
+ row = self._read_row_from_packet(packet)
1811
+ self.affected_rows = 1
1812
+ self.rows = (row,) # rows should tuple of row for MySQL-python compatibility.
1813
+ return row
1814
+
1815
+ def _finish_unbuffered_query(self):
1816
+ # After much reading on the MySQL protocol, it appears that there is,
1817
+ # in fact, no way to stop MySQL from sending all the data after
1818
+ # executing a query, so we just spin, and wait for an EOF packet.
1819
+ while self.unbuffered_active and self.connection._sock is not None:
1820
+ try:
1821
+ packet = self.connection._read_packet()
1822
+ except err.OperationalError as e:
1823
+ if e.args[0] in (
1824
+ ER.QUERY_TIMEOUT,
1825
+ ER.STATEMENT_TIMEOUT,
1826
+ ):
1827
+ # if the query timed out we can simply ignore this error
1828
+ self.unbuffered_active = False
1829
+ self.connection = None
1830
+ return
1831
+
1832
+ raise
1833
+
1834
+ if self._check_packet_is_eof(packet):
1835
+ self.unbuffered_active = False
1836
+ self.connection = None # release reference to kill cyclic reference.
1837
+
1838
+ def _read_rowdata_packet(self):
1839
+ """Read a rowdata packet for each data row in the result set."""
1840
+ rows = []
1841
+ while True:
1842
+ packet = self.connection._read_packet()
1843
+ if self._check_packet_is_eof(packet):
1844
+ self.connection = None # release reference to kill cyclic reference.
1845
+ break
1846
+ rows.append(self._read_row_from_packet(packet))
1847
+
1848
+ self.affected_rows = len(rows)
1849
+ self.rows = tuple(rows)
1850
+
1851
+ def _read_row_from_packet(self, packet):
1852
+ row = []
1853
+ for i, (encoding, converter) in enumerate(self.converters):
1854
+ try:
1855
+ data = packet.read_length_coded_string()
1856
+ except IndexError:
1857
+ # No more columns in this row
1858
+ # See https://github.com/PyMySQL/PyMySQL/pull/434
1859
+ break
1860
+ if data is not None:
1861
+ if encoding is not None:
1862
+ try:
1863
+ data = data.decode(encoding, errors=self.encoding_errors)
1864
+ except UnicodeDecodeError:
1865
+ raise UnicodeDecodeError(
1866
+ 'failed to decode string value in column '
1867
+ f"'{self.fields[i].name}' using encoding '{encoding}'; " +
1868
+ "use the 'encoding_errors' option on the connection " +
1869
+ 'to specify how to handle this error',
1870
+ )
1871
+ if DEBUG:
1872
+ print('DEBUG: DATA = ', data)
1873
+ if converter is not None:
1874
+ data = converter(data)
1875
+ row.append(data)
1876
+ return tuple(row)
1877
+
1878
+ def _get_descriptions(self):
1879
+ """Read a column descriptor packet for each column in the result."""
1880
+ self.fields = []
1881
+ self.converters = []
1882
+ use_unicode = self.connection.use_unicode
1883
+ conn_encoding = self.connection.encoding
1884
+ description = []
1885
+
1886
+ for i in range(self.field_count):
1887
+ field = self.connection._read_packet(FieldDescriptorPacket)
1888
+ self.fields.append(field)
1889
+ description.append(field.description())
1890
+ field_type = field.type_code
1891
+ if use_unicode:
1892
+ if field_type == FIELD_TYPE.JSON:
1893
+ # When SELECT from JSON column: charset = binary
1894
+ # When SELECT CAST(... AS JSON): charset = connection encoding
1895
+ # This behavior is different from TEXT / BLOB.
1896
+ # We should decode result by connection encoding regardless charsetnr.
1897
+ # See https://github.com/PyMySQL/PyMySQL/issues/488
1898
+ encoding = conn_encoding # SELECT CAST(... AS JSON)
1899
+ elif field_type in TEXT_TYPES:
1900
+ if field.charsetnr == 63: # binary
1901
+ # TEXTs with charset=binary means BINARY types.
1902
+ encoding = None
1903
+ else:
1904
+ encoding = conn_encoding
1905
+ else:
1906
+ # Integers, Dates and Times, and other basic data is encoded in ascii
1907
+ encoding = 'ascii'
1908
+ else:
1909
+ encoding = None
1910
+ converter = self.connection.decoders.get(field_type)
1911
+ if converter is converters.through:
1912
+ converter = None
1913
+ if DEBUG:
1914
+ print(f'DEBUG: field={field}, converter={converter}')
1915
+ self.converters.append((encoding, converter))
1916
+
1917
+ eof_packet = self.connection._read_packet()
1918
+ assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
1919
+ self.description = tuple(description)
1920
+
1921
+
1922
+ class MySQLResultSV(MySQLResult):
1923
+
1924
+ def __init__(self, connection, unbuffered=False):
1925
+ MySQLResult.__init__(self, connection, unbuffered=unbuffered)
1926
+ self.options = {
1927
+ k: v for k, v in dict(
1928
+ default_converters=converters.decoders,
1929
+ results_type=connection.results_type,
1930
+ parse_json=connection.parse_json,
1931
+ invalid_values=connection.invalid_values,
1932
+ unbuffered=unbuffered,
1933
+ encoding_errors=connection.encoding_errors,
1934
+ ).items() if v is not UNSET
1935
+ }
1936
+
1937
+ def _read_rowdata_packet(self, *args, **kwargs):
1938
+ return _singlestoredb_accel.read_rowdata_packet(self, False, *args, **kwargs)
1939
+
1940
+ def _read_rowdata_packet_unbuffered(self, *args, **kwargs):
1941
+ return _singlestoredb_accel.read_rowdata_packet(self, True, *args, **kwargs)
1942
+
1943
+
1944
+ class LoadLocalFile:
1945
+
1946
+ def __init__(self, filename, connection):
1947
+ self.filename = filename
1948
+ self.connection = connection
1949
+
1950
+ def send_data(self):
1951
+ """Send data packets from the local file to the server"""
1952
+ if not self.connection._sock:
1953
+ raise err.InterfaceError(0, 'Connection is closed')
1954
+
1955
+ conn = self.connection
1956
+ infile = conn._local_infile_stream
1957
+
1958
+ # 16KB is efficient enough
1959
+ packet_size = min(conn.max_allowed_packet, 16 * 1024)
1960
+
1961
+ try:
1962
+
1963
+ if self.filename in [':stream:', b':stream:']:
1964
+
1965
+ if infile is None:
1966
+ raise err.OperationalError(
1967
+ ER.FILE_NOT_FOUND,
1968
+ ':stream: specified for LOCAL INFILE, but no stream was supplied',
1969
+ )
1970
+
1971
+ # Binary IO
1972
+ elif isinstance(infile, io.RawIOBase):
1973
+ while True:
1974
+ chunk = infile.read(packet_size)
1975
+ if not chunk:
1976
+ break
1977
+ conn.write_packet(chunk)
1978
+
1979
+ # Text IO
1980
+ elif isinstance(infile, io.TextIOBase):
1981
+ while True:
1982
+ chunk = infile.read(packet_size)
1983
+ if not chunk:
1984
+ break
1985
+ conn.write_packet(chunk.encode('utf8'))
1986
+
1987
+ # Iterable of bytes or str
1988
+ elif isinstance(infile, Iterable):
1989
+ for chunk in infile:
1990
+ if not chunk:
1991
+ continue
1992
+ if isinstance(chunk, str):
1993
+ conn.write_packet(chunk.encode('utf8'))
1994
+ else:
1995
+ conn.write_packet(chunk)
1996
+
1997
+ # Queue (empty value ends the iteration)
1998
+ elif isinstance(infile, queue.Queue):
1999
+ while True:
2000
+ chunk = infile.get()
2001
+ if not chunk:
2002
+ break
2003
+ if isinstance(chunk, str):
2004
+ conn.write_packet(chunk.encode('utf8'))
2005
+ else:
2006
+ conn.write_packet(chunk)
2007
+
2008
+ else:
2009
+ raise err.OperationalError(
2010
+ ER.FILE_NOT_FOUND,
2011
+ ':stream: specified for LOCAL INFILE, ' +
2012
+ f'but stream type is unrecognized: {infile}',
2013
+ )
2014
+
2015
+ else:
2016
+ try:
2017
+ with open(self.filename, 'rb') as open_file:
2018
+ while True:
2019
+ chunk = open_file.read(packet_size)
2020
+ if not chunk:
2021
+ break
2022
+ conn.write_packet(chunk)
2023
+ except OSError:
2024
+ raise err.OperationalError(
2025
+ ER.FILE_NOT_FOUND,
2026
+ f"Can't find file '{self.filename!s}'",
2027
+ )
2028
+
2029
+ finally:
2030
+ if not conn._closed:
2031
+ # send the empty packet to signify we are done sending data
2032
+ conn.write_packet(b'')