singlestoredb 1.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- singlestoredb/__init__.py +75 -0
- singlestoredb/ai/__init__.py +2 -0
- singlestoredb/ai/chat.py +139 -0
- singlestoredb/ai/embeddings.py +128 -0
- singlestoredb/alchemy/__init__.py +90 -0
- singlestoredb/apps/__init__.py +3 -0
- singlestoredb/apps/_cloud_functions.py +90 -0
- singlestoredb/apps/_config.py +72 -0
- singlestoredb/apps/_connection_info.py +18 -0
- singlestoredb/apps/_dashboards.py +47 -0
- singlestoredb/apps/_process.py +32 -0
- singlestoredb/apps/_python_udfs.py +100 -0
- singlestoredb/apps/_stdout_supress.py +30 -0
- singlestoredb/apps/_uvicorn_util.py +36 -0
- singlestoredb/auth.py +245 -0
- singlestoredb/config.py +484 -0
- singlestoredb/connection.py +1487 -0
- singlestoredb/converters.py +950 -0
- singlestoredb/docstring/__init__.py +33 -0
- singlestoredb/docstring/attrdoc.py +126 -0
- singlestoredb/docstring/common.py +230 -0
- singlestoredb/docstring/epydoc.py +267 -0
- singlestoredb/docstring/google.py +412 -0
- singlestoredb/docstring/numpydoc.py +562 -0
- singlestoredb/docstring/parser.py +100 -0
- singlestoredb/docstring/py.typed +1 -0
- singlestoredb/docstring/rest.py +256 -0
- singlestoredb/docstring/tests/__init__.py +1 -0
- singlestoredb/docstring/tests/_pydoctor.py +21 -0
- singlestoredb/docstring/tests/test_epydoc.py +729 -0
- singlestoredb/docstring/tests/test_google.py +1007 -0
- singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
- singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
- singlestoredb/docstring/tests/test_parser.py +248 -0
- singlestoredb/docstring/tests/test_rest.py +547 -0
- singlestoredb/docstring/tests/test_util.py +70 -0
- singlestoredb/docstring/util.py +141 -0
- singlestoredb/exceptions.py +120 -0
- singlestoredb/functions/__init__.py +16 -0
- singlestoredb/functions/decorator.py +201 -0
- singlestoredb/functions/dtypes.py +1793 -0
- singlestoredb/functions/ext/__init__.py +1 -0
- singlestoredb/functions/ext/arrow.py +375 -0
- singlestoredb/functions/ext/asgi.py +2133 -0
- singlestoredb/functions/ext/json.py +420 -0
- singlestoredb/functions/ext/mmap.py +413 -0
- singlestoredb/functions/ext/rowdat_1.py +724 -0
- singlestoredb/functions/ext/timer.py +89 -0
- singlestoredb/functions/ext/utils.py +218 -0
- singlestoredb/functions/signature.py +1578 -0
- singlestoredb/functions/typing/__init__.py +41 -0
- singlestoredb/functions/typing/numpy.py +20 -0
- singlestoredb/functions/typing/pandas.py +2 -0
- singlestoredb/functions/typing/polars.py +2 -0
- singlestoredb/functions/typing/pyarrow.py +2 -0
- singlestoredb/functions/utils.py +421 -0
- singlestoredb/fusion/__init__.py +11 -0
- singlestoredb/fusion/graphql.py +213 -0
- singlestoredb/fusion/handler.py +916 -0
- singlestoredb/fusion/handlers/__init__.py +0 -0
- singlestoredb/fusion/handlers/export.py +525 -0
- singlestoredb/fusion/handlers/files.py +690 -0
- singlestoredb/fusion/handlers/job.py +660 -0
- singlestoredb/fusion/handlers/models.py +250 -0
- singlestoredb/fusion/handlers/stage.py +502 -0
- singlestoredb/fusion/handlers/utils.py +324 -0
- singlestoredb/fusion/handlers/workspace.py +956 -0
- singlestoredb/fusion/registry.py +249 -0
- singlestoredb/fusion/result.py +399 -0
- singlestoredb/http/__init__.py +27 -0
- singlestoredb/http/connection.py +1267 -0
- singlestoredb/magics/__init__.py +34 -0
- singlestoredb/magics/run_personal.py +137 -0
- singlestoredb/magics/run_shared.py +134 -0
- singlestoredb/management/__init__.py +9 -0
- singlestoredb/management/billing_usage.py +148 -0
- singlestoredb/management/cluster.py +462 -0
- singlestoredb/management/export.py +295 -0
- singlestoredb/management/files.py +1102 -0
- singlestoredb/management/inference_api.py +105 -0
- singlestoredb/management/job.py +887 -0
- singlestoredb/management/manager.py +373 -0
- singlestoredb/management/organization.py +226 -0
- singlestoredb/management/region.py +169 -0
- singlestoredb/management/utils.py +423 -0
- singlestoredb/management/workspace.py +1927 -0
- singlestoredb/mysql/__init__.py +177 -0
- singlestoredb/mysql/_auth.py +298 -0
- singlestoredb/mysql/charset.py +214 -0
- singlestoredb/mysql/connection.py +2032 -0
- singlestoredb/mysql/constants/CLIENT.py +38 -0
- singlestoredb/mysql/constants/COMMAND.py +32 -0
- singlestoredb/mysql/constants/CR.py +78 -0
- singlestoredb/mysql/constants/ER.py +474 -0
- singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
- singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
- singlestoredb/mysql/constants/FLAG.py +15 -0
- singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
- singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
- singlestoredb/mysql/constants/__init__.py +0 -0
- singlestoredb/mysql/converters.py +271 -0
- singlestoredb/mysql/cursors.py +896 -0
- singlestoredb/mysql/err.py +92 -0
- singlestoredb/mysql/optionfile.py +20 -0
- singlestoredb/mysql/protocol.py +450 -0
- singlestoredb/mysql/tests/__init__.py +19 -0
- singlestoredb/mysql/tests/base.py +126 -0
- singlestoredb/mysql/tests/conftest.py +37 -0
- singlestoredb/mysql/tests/test_DictCursor.py +132 -0
- singlestoredb/mysql/tests/test_SSCursor.py +141 -0
- singlestoredb/mysql/tests/test_basic.py +452 -0
- singlestoredb/mysql/tests/test_connection.py +851 -0
- singlestoredb/mysql/tests/test_converters.py +58 -0
- singlestoredb/mysql/tests/test_cursor.py +141 -0
- singlestoredb/mysql/tests/test_err.py +16 -0
- singlestoredb/mysql/tests/test_issues.py +514 -0
- singlestoredb/mysql/tests/test_load_local.py +75 -0
- singlestoredb/mysql/tests/test_nextset.py +88 -0
- singlestoredb/mysql/tests/test_optionfile.py +27 -0
- singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
- singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
- singlestoredb/mysql/times.py +23 -0
- singlestoredb/notebook/__init__.py +16 -0
- singlestoredb/notebook/_objects.py +213 -0
- singlestoredb/notebook/_portal.py +352 -0
- singlestoredb/py.typed +0 -0
- singlestoredb/pytest.py +352 -0
- singlestoredb/server/__init__.py +0 -0
- singlestoredb/server/docker.py +452 -0
- singlestoredb/server/free_tier.py +267 -0
- singlestoredb/tests/__init__.py +0 -0
- singlestoredb/tests/alltypes.sql +307 -0
- singlestoredb/tests/alltypes_no_nulls.sql +208 -0
- singlestoredb/tests/empty.sql +0 -0
- singlestoredb/tests/ext_funcs/__init__.py +702 -0
- singlestoredb/tests/local_infile.csv +3 -0
- singlestoredb/tests/test.ipynb +18 -0
- singlestoredb/tests/test.sql +680 -0
- singlestoredb/tests/test2.ipynb +18 -0
- singlestoredb/tests/test2.sql +1 -0
- singlestoredb/tests/test_basics.py +1332 -0
- singlestoredb/tests/test_config.py +318 -0
- singlestoredb/tests/test_connection.py +3103 -0
- singlestoredb/tests/test_dbapi.py +27 -0
- singlestoredb/tests/test_exceptions.py +45 -0
- singlestoredb/tests/test_ext_func.py +1472 -0
- singlestoredb/tests/test_ext_func_data.py +1101 -0
- singlestoredb/tests/test_fusion.py +1527 -0
- singlestoredb/tests/test_http.py +288 -0
- singlestoredb/tests/test_management.py +1599 -0
- singlestoredb/tests/test_plugin.py +33 -0
- singlestoredb/tests/test_results.py +171 -0
- singlestoredb/tests/test_types.py +132 -0
- singlestoredb/tests/test_udf.py +737 -0
- singlestoredb/tests/test_udf_returns.py +459 -0
- singlestoredb/tests/test_vectorstore.py +51 -0
- singlestoredb/tests/test_xdict.py +333 -0
- singlestoredb/tests/utils.py +141 -0
- singlestoredb/types.py +373 -0
- singlestoredb/utils/__init__.py +0 -0
- singlestoredb/utils/config.py +950 -0
- singlestoredb/utils/convert_rows.py +69 -0
- singlestoredb/utils/debug.py +13 -0
- singlestoredb/utils/dtypes.py +205 -0
- singlestoredb/utils/events.py +65 -0
- singlestoredb/utils/mogrify.py +151 -0
- singlestoredb/utils/results.py +585 -0
- singlestoredb/utils/xdict.py +425 -0
- singlestoredb/vectorstore.py +192 -0
- singlestoredb/warnings.py +5 -0
- singlestoredb-1.16.1.dist-info/METADATA +165 -0
- singlestoredb-1.16.1.dist-info/RECORD +183 -0
- singlestoredb-1.16.1.dist-info/WHEEL +5 -0
- singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
- singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
- singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
- sqlx/__init__.py +4 -0
- sqlx/magic.py +113 -0
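For orientation, a minimal usage sketch of the client added in this wheel (the `connect()` entry point and DB-API cursor documented in the connection.py diff below); this example is illustrative only, and the host, user, password, and database values are placeholders:

# Minimal sketch (not part of the package diff): connect via the DB-API
# interface that singlestoredb exposes, run a query, and close.
import singlestoredb as s2

conn = s2.connect(
    host='localhost', port=3306,
    user='me', password='secret', database='mydb',
)
cur = conn.cursor()          # cursor class is selected by results_type
cur.execute('SELECT 1')
print(cur.fetchall())
cur.close()
conn.close()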
|
@@ -0,0 +1,2032 @@
|
|
|
1
|
+
# type: ignore
|
|
2
|
+
# Python implementation of the MySQL client-server protocol
|
|
3
|
+
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
|
4
|
+
# Error codes:
|
|
5
|
+
# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
|
|
6
|
+
import errno
|
|
7
|
+
import functools
|
|
8
|
+
import io
|
|
9
|
+
import os
|
|
10
|
+
import queue
|
|
11
|
+
import socket
|
|
12
|
+
import struct
|
|
13
|
+
import sys
|
|
14
|
+
import traceback
|
|
15
|
+
import warnings
|
|
16
|
+
from collections.abc import Iterable
|
|
17
|
+
from typing import Any
|
|
18
|
+
from typing import Dict
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import _singlestoredb_accel
|
|
22
|
+
except (ImportError, ModuleNotFoundError):
|
|
23
|
+
_singlestoredb_accel = None
|
|
24
|
+
|
|
25
|
+
from . import _auth
|
|
26
|
+
from ..utils import events
|
|
27
|
+
|
|
28
|
+
from .charset import charset_by_name, charset_by_id
|
|
29
|
+
from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
|
|
30
|
+
from . import converters
|
|
31
|
+
from .cursors import (
|
|
32
|
+
Cursor,
|
|
33
|
+
CursorSV,
|
|
34
|
+
DictCursor,
|
|
35
|
+
DictCursorSV,
|
|
36
|
+
NamedtupleCursor,
|
|
37
|
+
NamedtupleCursorSV,
|
|
38
|
+
ArrowCursor,
|
|
39
|
+
ArrowCursorSV,
|
|
40
|
+
NumpyCursor,
|
|
41
|
+
NumpyCursorSV,
|
|
42
|
+
PandasCursor,
|
|
43
|
+
PandasCursorSV,
|
|
44
|
+
PolarsCursor,
|
|
45
|
+
PolarsCursorSV,
|
|
46
|
+
SSCursor,
|
|
47
|
+
SSCursorSV,
|
|
48
|
+
SSDictCursor,
|
|
49
|
+
SSDictCursorSV,
|
|
50
|
+
SSNamedtupleCursor,
|
|
51
|
+
SSNamedtupleCursorSV,
|
|
52
|
+
SSArrowCursor,
|
|
53
|
+
SSArrowCursorSV,
|
|
54
|
+
SSNumpyCursor,
|
|
55
|
+
SSNumpyCursorSV,
|
|
56
|
+
SSPandasCursor,
|
|
57
|
+
SSPandasCursorSV,
|
|
58
|
+
SSPolarsCursor,
|
|
59
|
+
SSPolarsCursorSV,
|
|
60
|
+
)
|
|
61
|
+
from .optionfile import Parser
|
|
62
|
+
from .protocol import (
|
|
63
|
+
dump_packet,
|
|
64
|
+
MysqlPacket,
|
|
65
|
+
FieldDescriptorPacket,
|
|
66
|
+
OKPacketWrapper,
|
|
67
|
+
EOFPacketWrapper,
|
|
68
|
+
LoadLocalPacketWrapper,
|
|
69
|
+
)
|
|
70
|
+
from . import err
|
|
71
|
+
from ..config import get_option
|
|
72
|
+
from .. import fusion
|
|
73
|
+
from .. import connection
|
|
74
|
+
from ..connection import Connection as BaseConnection
|
|
75
|
+
from ..utils.debug import log_query
|
|
76
|
+
|
|
77
|
+
try:
|
|
78
|
+
import ssl
|
|
79
|
+
|
|
80
|
+
SSL_ENABLED = True
|
|
81
|
+
except ImportError:
|
|
82
|
+
ssl = None
|
|
83
|
+
SSL_ENABLED = False
|
|
84
|
+
|
|
85
|
+
try:
|
|
86
|
+
import getpass
|
|
87
|
+
|
|
88
|
+
DEFAULT_USER = getpass.getuser()
|
|
89
|
+
del getpass
|
|
90
|
+
except (ImportError, KeyError):
|
|
91
|
+
# KeyError occurs when there's no entry in OS database for a current user.
|
|
92
|
+
DEFAULT_USER = None
|
|
93
|
+
|
|
94
|
+
DEBUG = get_option('debug.connection')
|
|
95
|
+
|
|
96
|
+
TEXT_TYPES = {
|
|
97
|
+
FIELD_TYPE.BIT,
|
|
98
|
+
FIELD_TYPE.BLOB,
|
|
99
|
+
FIELD_TYPE.LONG_BLOB,
|
|
100
|
+
FIELD_TYPE.MEDIUM_BLOB,
|
|
101
|
+
FIELD_TYPE.STRING,
|
|
102
|
+
FIELD_TYPE.TINY_BLOB,
|
|
103
|
+
FIELD_TYPE.VAR_STRING,
|
|
104
|
+
FIELD_TYPE.VARCHAR,
|
|
105
|
+
FIELD_TYPE.GEOMETRY,
|
|
106
|
+
FIELD_TYPE.BSON,
|
|
107
|
+
FIELD_TYPE.FLOAT32_VECTOR_JSON,
|
|
108
|
+
FIELD_TYPE.FLOAT64_VECTOR_JSON,
|
|
109
|
+
FIELD_TYPE.INT8_VECTOR_JSON,
|
|
110
|
+
FIELD_TYPE.INT16_VECTOR_JSON,
|
|
111
|
+
FIELD_TYPE.INT32_VECTOR_JSON,
|
|
112
|
+
FIELD_TYPE.INT64_VECTOR_JSON,
|
|
113
|
+
FIELD_TYPE.FLOAT32_VECTOR,
|
|
114
|
+
FIELD_TYPE.FLOAT64_VECTOR,
|
|
115
|
+
FIELD_TYPE.INT8_VECTOR,
|
|
116
|
+
FIELD_TYPE.INT16_VECTOR,
|
|
117
|
+
FIELD_TYPE.INT32_VECTOR,
|
|
118
|
+
FIELD_TYPE.INT64_VECTOR,
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
UNSET = 'unset'
|
|
122
|
+
|
|
123
|
+
DEFAULT_CHARSET = 'utf8mb4'
|
|
124
|
+
|
|
125
|
+
MAX_PACKET_LEN = 2**24 - 1
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _pack_int24(n):
|
|
129
|
+
return struct.pack('<I', n)[:3]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
|
|
133
|
+
def _lenenc_int(i):
|
|
134
|
+
if i < 0:
|
|
135
|
+
raise ValueError(
|
|
136
|
+
'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,
|
|
137
|
+
)
|
|
138
|
+
elif i < 0xFB:
|
|
139
|
+
return bytes([i])
|
|
140
|
+
elif i < (1 << 16):
|
|
141
|
+
return b'\xfc' + struct.pack('<H', i)
|
|
142
|
+
elif i < (1 << 24):
|
|
143
|
+
return b'\xfd' + struct.pack('<I', i)[:3]
|
|
144
|
+
elif i < (1 << 64):
|
|
145
|
+
return b'\xfe' + struct.pack('<Q', i)
|
|
146
|
+
else:
|
|
147
|
+
raise ValueError(
|
|
148
|
+
'Encoding %x is larger than %x - no representation in LengthEncodedInteger'
|
|
149
|
+
% (i, (1 << 64)),
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class Connection(BaseConnection):
|
|
154
|
+
"""
|
|
155
|
+
Representation of a socket with a mysql server.
|
|
156
|
+
|
|
157
|
+
The proper way to get an instance of this class is to call
|
|
158
|
+
``connect()``.
|
|
159
|
+
|
|
160
|
+
Establish a connection to the SingleStoreDB database.
|
|
161
|
+
|
|
162
|
+
Parameters
|
|
163
|
+
----------
|
|
164
|
+
host : str, optional
|
|
165
|
+
Host where the database server is located.
|
|
166
|
+
user : str, optional
|
|
167
|
+
Username to log in as.
|
|
168
|
+
password : str, optional
|
|
169
|
+
Password to use.
|
|
170
|
+
database : str, optional
|
|
171
|
+
Database to use, None to not use a particular one.
|
|
172
|
+
port : int, optional
|
|
173
|
+
Server port to use, default is usually OK. (default: 3306)
|
|
174
|
+
bind_address : str, optional
|
|
175
|
+
When the client has multiple network interfaces, specify
|
|
176
|
+
the interface from which to connect to the host. Argument can be
|
|
177
|
+
a hostname or an IP address.
|
|
178
|
+
unix_socket : str, optional
|
|
179
|
+
Use a unix socket rather than TCP/IP.
|
|
180
|
+
read_timeout : int, optional
|
|
181
|
+
The timeout for reading from the connection in seconds
|
|
182
|
+
(default: None - no timeout)
|
|
183
|
+
write_timeout : int, optional
|
|
184
|
+
The timeout for writing to the connection in seconds
|
|
185
|
+
(default: None - no timeout)
|
|
186
|
+
charset : str, optional
|
|
187
|
+
Charset to use.
|
|
188
|
+
collation : str, optional
|
|
189
|
+
The charset collation
|
|
190
|
+
sql_mode : str, optional
|
|
191
|
+
Default SQL_MODE to use.
|
|
192
|
+
read_default_file : str, optional
|
|
193
|
+
Specifies my.cnf file to read these parameters from under the
|
|
194
|
+
[client] section.
|
|
195
|
+
conv : Dict[str, Callable[Any]], optional
|
|
196
|
+
Conversion dictionary to use instead of the default one.
|
|
197
|
+
This is used to provide custom marshalling and unmarshalling of types.
|
|
198
|
+
See converters.
|
|
199
|
+
use_unicode : bool, optional
|
|
200
|
+
Whether or not to default to unicode strings.
|
|
201
|
+
This option defaults to true.
|
|
202
|
+
client_flag : int, optional
|
|
203
|
+
Custom flags to send to MySQL. Find potential values in constants.CLIENT.
|
|
204
|
+
cursorclass : type, optional
|
|
205
|
+
Custom cursor class to use.
|
|
206
|
+
init_command : str, optional
|
|
207
|
+
Initial SQL statement to run when connection is established.
|
|
208
|
+
connect_timeout : int, optional
|
|
209
|
+
The timeout for connecting to the database in seconds.
|
|
210
|
+
(default: 10, min: 1, max: 31536000)
|
|
211
|
+
ssl : Dict[str, str], optional
|
|
212
|
+
A dict of arguments similar to mysql_ssl_set()'s parameters or
|
|
213
|
+
an ssl.SSLContext.
|
|
214
|
+
ssl_ca : str, optional
|
|
215
|
+
Path to the file that contains a PEM-formatted CA certificate.
|
|
216
|
+
ssl_cert : str, optional
|
|
217
|
+
Path to the file that contains a PEM-formatted client certificate.
|
|
218
|
+
ssl_cipher : str, optional
|
|
219
|
+
SSL ciphers to allow.
|
|
220
|
+
ssl_disabled : bool, optional
|
|
221
|
+
A boolean value that disables usage of TLS.
|
|
222
|
+
ssl_key : str, optional
|
|
223
|
+
Path to the file that contains a PEM-formatted private key for the
|
|
224
|
+
client certificate.
|
|
225
|
+
ssl_verify_cert : str, optional
|
|
226
|
+
Set to true to check the server certificate's validity.
|
|
227
|
+
ssl_verify_identity : bool, optional
|
|
228
|
+
Set to true to check the server's identity.
|
|
229
|
+
tls_sni_servername: str, optional
|
|
230
|
+
Set server host name for TLS connection
|
|
231
|
+
read_default_group : str, optional
|
|
232
|
+
Group to read from in the configuration file.
|
|
233
|
+
autocommit : bool, optional
|
|
234
|
+
Autocommit mode. None means use server default. (default: False)
|
|
235
|
+
local_infile : bool, optional
|
|
236
|
+
Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
|
237
|
+
max_allowed_packet : int, optional
|
|
238
|
+
Max size of packet sent to server in bytes. (default: 16MB)
|
|
239
|
+
Only used to limit size of "LOAD LOCAL INFILE" data packet smaller
|
|
240
|
+
than default (16KB).
|
|
241
|
+
defer_connect : bool, optional
|
|
242
|
+
Don't explicitly connect on construction - wait for connect call.
|
|
243
|
+
(default: False)
|
|
244
|
+
auth_plugin_map : Dict[str, type], optional
|
|
245
|
+
A dict of plugin names to a class that processes that plugin.
|
|
246
|
+
The class will take the Connection object as the argument to the
|
|
247
|
+
constructor. The class needs an authenticate method taking an
|
|
248
|
+
authentication packet as an argument. For the dialog plugin, a
|
|
249
|
+
prompt(echo, prompt) method can be used (if no authenticate method)
|
|
250
|
+
for returning a string from the user. (experimental)
|
|
251
|
+
server_public_key : str, optional
|
|
252
|
+
SHA256 authentication plugin public key value. (default: None)
|
|
253
|
+
binary_prefix : bool, optional
|
|
254
|
+
Add _binary prefix on bytes and bytearray. (default: False)
|
|
255
|
+
compress :
|
|
256
|
+
Not supported.
|
|
257
|
+
named_pipe :
|
|
258
|
+
Not supported.
|
|
259
|
+
db : str, optional
|
|
260
|
+
**DEPRECATED** Alias for database.
|
|
261
|
+
passwd : str, optional
|
|
262
|
+
**DEPRECATED** Alias for password.
|
|
263
|
+
parse_json : bool, optional
|
|
264
|
+
Parse JSON values into Python objects?
|
|
265
|
+
invalid_values : Dict[int, Any], optional
|
|
266
|
+
Dictionary of values to use in place of invalid values
|
|
267
|
+
found during conversion of data. The default is to return the byte content
|
|
268
|
+
containing the invalid value. The keys are the integers associtated with
|
|
269
|
+
the column type.
|
|
270
|
+
pure_python : bool, optional
|
|
271
|
+
Should we ignore the C extension even if it's available?
|
|
272
|
+
This can be given explicitly using True or False, or if the value is None,
|
|
273
|
+
the C extension will be loaded if it is available. If set to False and
|
|
274
|
+
the C extension can't be loaded, a NotSupportedError is raised.
|
|
275
|
+
nan_as_null : bool, optional
|
|
276
|
+
Should NaN values be treated as NULLs in parameter substitution including
|
|
277
|
+
uploading data?
|
|
278
|
+
inf_as_null : bool, optional
|
|
279
|
+
Should Inf values be treated as NULLs in parameter substitution including
|
|
280
|
+
uploading data?
|
|
281
|
+
track_env : bool, optional
|
|
282
|
+
Should the connection track the SINGLESTOREDB_URL environment variable?
|
|
283
|
+
enable_extended_data_types : bool, optional
|
|
284
|
+
Should extended data types (BSON, vector) be enabled?
|
|
285
|
+
vector_data_format : str, optional
|
|
286
|
+
Specify the data type of vector values: json or binary
|
|
287
|
+
|
|
288
|
+
See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
|
|
289
|
+
in the specification.
|
|
290
|
+
|
|
291
|
+
"""
|
|
292
|
+
|
|
293
|
+
driver = 'mysql'
|
|
294
|
+
paramstyle = 'pyformat'
|
|
295
|
+
|
|
296
|
+
_sock = None
|
|
297
|
+
_auth_plugin_name = ''
|
|
298
|
+
_closed = False
|
|
299
|
+
_secure = False
|
|
300
|
+
_tls_sni_servername = None
|
|
301
|
+
|
|
302
|
+
def __init__( # noqa: C901
|
|
303
|
+
self,
|
|
304
|
+
*,
|
|
305
|
+
user=None, # The first four arguments is based on DB-API 2.0 recommendation.
|
|
306
|
+
password='',
|
|
307
|
+
host=None,
|
|
308
|
+
database=None,
|
|
309
|
+
unix_socket=None,
|
|
310
|
+
port=0,
|
|
311
|
+
charset='',
|
|
312
|
+
collation=None,
|
|
313
|
+
sql_mode=None,
|
|
314
|
+
read_default_file=None,
|
|
315
|
+
conv=None,
|
|
316
|
+
use_unicode=True,
|
|
317
|
+
client_flag=0,
|
|
318
|
+
cursorclass=None,
|
|
319
|
+
init_command=None,
|
|
320
|
+
connect_timeout=10,
|
|
321
|
+
read_default_group=None,
|
|
322
|
+
autocommit=False,
|
|
323
|
+
local_infile=False,
|
|
324
|
+
max_allowed_packet=16 * 1024 * 1024,
|
|
325
|
+
defer_connect=False,
|
|
326
|
+
auth_plugin_map=None,
|
|
327
|
+
read_timeout=None,
|
|
328
|
+
write_timeout=None,
|
|
329
|
+
bind_address=None,
|
|
330
|
+
binary_prefix=False,
|
|
331
|
+
program_name=None,
|
|
332
|
+
server_public_key=None,
|
|
333
|
+
ssl=None,
|
|
334
|
+
ssl_ca=None,
|
|
335
|
+
ssl_cert=None,
|
|
336
|
+
ssl_cipher=None,
|
|
337
|
+
ssl_disabled=None,
|
|
338
|
+
ssl_key=None,
|
|
339
|
+
ssl_verify_cert=None,
|
|
340
|
+
ssl_verify_identity=None,
|
|
341
|
+
tls_sni_servername=None,
|
|
342
|
+
parse_json=True,
|
|
343
|
+
invalid_values=None,
|
|
344
|
+
pure_python=None,
|
|
345
|
+
buffered=True,
|
|
346
|
+
results_type='tuples',
|
|
347
|
+
compress=None, # not supported
|
|
348
|
+
named_pipe=None, # not supported
|
|
349
|
+
passwd=None, # deprecated
|
|
350
|
+
db=None, # deprecated
|
|
351
|
+
driver=None, # internal use
|
|
352
|
+
conn_attrs=None,
|
|
353
|
+
multi_statements=None,
|
|
354
|
+
client_found_rows=None,
|
|
355
|
+
nan_as_null=None,
|
|
356
|
+
inf_as_null=None,
|
|
357
|
+
encoding_errors='strict',
|
|
358
|
+
track_env=False,
|
|
359
|
+
enable_extended_data_types=True,
|
|
360
|
+
vector_data_format='binary',
|
|
361
|
+
):
|
|
362
|
+
BaseConnection.__init__(**dict(locals()))
|
|
363
|
+
|
|
364
|
+
if db is not None and database is None:
|
|
365
|
+
# We will raise warning in 2022 or later.
|
|
366
|
+
# See https://github.com/PyMySQL/PyMySQL/issues/939
|
|
367
|
+
# warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
|
|
368
|
+
database = db
|
|
369
|
+
if passwd is not None and not password:
|
|
370
|
+
# We will raise warning in 2022 or later.
|
|
371
|
+
# See https://github.com/PyMySQL/PyMySQL/issues/939
|
|
372
|
+
# warnings.warn(
|
|
373
|
+
# "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
|
|
374
|
+
# )
|
|
375
|
+
password = passwd
|
|
376
|
+
|
|
377
|
+
if compress or named_pipe:
|
|
378
|
+
raise NotImplementedError(
|
|
379
|
+
'compress and named_pipe arguments are not supported',
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
self._local_infile = bool(local_infile)
|
|
383
|
+
self._local_infile_stream = None
|
|
384
|
+
if self._local_infile:
|
|
385
|
+
client_flag |= CLIENT.LOCAL_FILES
|
|
386
|
+
if multi_statements:
|
|
387
|
+
client_flag |= CLIENT.MULTI_STATEMENTS
|
|
388
|
+
if client_found_rows:
|
|
389
|
+
client_flag |= CLIENT.FOUND_ROWS
|
|
390
|
+
|
|
391
|
+
if read_default_group and not read_default_file:
|
|
392
|
+
if sys.platform.startswith('win'):
|
|
393
|
+
read_default_file = 'c:\\my.ini'
|
|
394
|
+
else:
|
|
395
|
+
read_default_file = '/etc/my.cnf'
|
|
396
|
+
|
|
397
|
+
if read_default_file:
|
|
398
|
+
if not read_default_group:
|
|
399
|
+
read_default_group = 'client'
|
|
400
|
+
|
|
401
|
+
cfg = Parser()
|
|
402
|
+
cfg.read(os.path.expanduser(read_default_file))
|
|
403
|
+
|
|
404
|
+
def _config(key, arg):
|
|
405
|
+
if arg:
|
|
406
|
+
return arg
|
|
407
|
+
try:
|
|
408
|
+
return cfg.get(read_default_group, key)
|
|
409
|
+
except Exception:
|
|
410
|
+
return arg
|
|
411
|
+
|
|
412
|
+
user = _config('user', user)
|
|
413
|
+
password = _config('password', password)
|
|
414
|
+
host = _config('host', host)
|
|
415
|
+
database = _config('database', database)
|
|
416
|
+
unix_socket = _config('socket', unix_socket)
|
|
417
|
+
port = int(_config('port', port))
|
|
418
|
+
bind_address = _config('bind-address', bind_address)
|
|
419
|
+
charset = _config('default-character-set', charset)
|
|
420
|
+
if not ssl:
|
|
421
|
+
ssl = {}
|
|
422
|
+
if isinstance(ssl, dict):
|
|
423
|
+
for key in ['ca', 'capath', 'cert', 'key', 'cipher']:
|
|
424
|
+
value = _config('ssl-' + key, ssl.get(key))
|
|
425
|
+
if value:
|
|
426
|
+
ssl[key] = value
|
|
427
|
+
|
|
428
|
+
self.ssl = False
|
|
429
|
+
if not ssl_disabled:
|
|
430
|
+
if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \
|
|
431
|
+
ssl_verify_cert or ssl_verify_identity:
|
|
432
|
+
ssl = {
|
|
433
|
+
'ca': ssl_ca,
|
|
434
|
+
'check_hostname': bool(ssl_verify_identity),
|
|
435
|
+
'verify_mode': ssl_verify_cert
|
|
436
|
+
if ssl_verify_cert is not None
|
|
437
|
+
else False,
|
|
438
|
+
}
|
|
439
|
+
if ssl_cert is not None:
|
|
440
|
+
ssl['cert'] = ssl_cert
|
|
441
|
+
if ssl_key is not None:
|
|
442
|
+
ssl['key'] = ssl_key
|
|
443
|
+
if ssl_cipher is not None:
|
|
444
|
+
ssl['cipher'] = ssl_cipher
|
|
445
|
+
if ssl:
|
|
446
|
+
if not SSL_ENABLED:
|
|
447
|
+
raise NotImplementedError('ssl module not found')
|
|
448
|
+
self.ssl = True
|
|
449
|
+
client_flag |= CLIENT.SSL
|
|
450
|
+
self.ctx = self._create_ssl_ctx(ssl)
|
|
451
|
+
|
|
452
|
+
self.host = host or 'localhost'
|
|
453
|
+
self.port = port or 3306
|
|
454
|
+
if type(self.port) is not int:
|
|
455
|
+
raise ValueError('port should be of type int')
|
|
456
|
+
self.user = user or DEFAULT_USER
|
|
457
|
+
self.password = password or b''
|
|
458
|
+
if isinstance(self.password, str):
|
|
459
|
+
self.password = self.password.encode('latin1')
|
|
460
|
+
self.db = database
|
|
461
|
+
self.unix_socket = unix_socket
|
|
462
|
+
self.bind_address = bind_address
|
|
463
|
+
if not (0 < connect_timeout <= 31536000):
|
|
464
|
+
raise ValueError('connect_timeout should be >0 and <=31536000')
|
|
465
|
+
self.connect_timeout = connect_timeout or None
|
|
466
|
+
if read_timeout is not None and read_timeout <= 0:
|
|
467
|
+
raise ValueError('read_timeout should be > 0')
|
|
468
|
+
self._read_timeout = read_timeout
|
|
469
|
+
if write_timeout is not None and write_timeout <= 0:
|
|
470
|
+
raise ValueError('write_timeout should be > 0')
|
|
471
|
+
self._write_timeout = write_timeout
|
|
472
|
+
|
|
473
|
+
self.charset = charset or DEFAULT_CHARSET
|
|
474
|
+
self.collation = collation
|
|
475
|
+
self.use_unicode = use_unicode
|
|
476
|
+
self.encoding_errors = encoding_errors
|
|
477
|
+
|
|
478
|
+
self.encoding = charset_by_name(self.charset).encoding
|
|
479
|
+
|
|
480
|
+
client_flag |= CLIENT.CAPABILITIES
|
|
481
|
+
client_flag |= CLIENT.CONNECT_WITH_DB
|
|
482
|
+
|
|
483
|
+
self.client_flag = client_flag
|
|
484
|
+
|
|
485
|
+
self.pure_python = pure_python
|
|
486
|
+
self.results_type = results_type
|
|
487
|
+
self.resultclass = MySQLResult
|
|
488
|
+
if cursorclass is not None:
|
|
489
|
+
self.cursorclass = cursorclass
|
|
490
|
+
elif buffered:
|
|
491
|
+
if 'dict' in self.results_type:
|
|
492
|
+
self.cursorclass = DictCursor
|
|
493
|
+
elif 'namedtuple' in self.results_type:
|
|
494
|
+
self.cursorclass = NamedtupleCursor
|
|
495
|
+
elif 'numpy' in self.results_type:
|
|
496
|
+
self.cursorclass = NumpyCursor
|
|
497
|
+
elif 'arrow' in self.results_type:
|
|
498
|
+
self.cursorclass = ArrowCursor
|
|
499
|
+
elif 'pandas' in self.results_type:
|
|
500
|
+
self.cursorclass = PandasCursor
|
|
501
|
+
elif 'polars' in self.results_type:
|
|
502
|
+
self.cursorclass = PolarsCursor
|
|
503
|
+
else:
|
|
504
|
+
self.cursorclass = Cursor
|
|
505
|
+
else:
|
|
506
|
+
if 'dict' in self.results_type:
|
|
507
|
+
self.cursorclass = SSDictCursor
|
|
508
|
+
elif 'namedtuple' in self.results_type:
|
|
509
|
+
self.cursorclass = SSNamedtupleCursor
|
|
510
|
+
elif 'numpy' in self.results_type:
|
|
511
|
+
self.cursorclass = SSNumpyCursor
|
|
512
|
+
elif 'arrow' in self.results_type:
|
|
513
|
+
self.cursorclass = SSArrowCursor
|
|
514
|
+
elif 'pandas' in self.results_type:
|
|
515
|
+
self.cursorclass = SSPandasCursor
|
|
516
|
+
elif 'polars' in self.results_type:
|
|
517
|
+
self.cursorclass = SSPolarsCursor
|
|
518
|
+
else:
|
|
519
|
+
self.cursorclass = SSCursor
|
|
520
|
+
|
|
521
|
+
if self.pure_python is False and _singlestoredb_accel is None:
|
|
522
|
+
try:
|
|
523
|
+
import _singlestortedb_accel # noqa: F401
|
|
524
|
+
except Exception:
|
|
525
|
+
import traceback
|
|
526
|
+
traceback.print_exc(file=sys.stderr)
|
|
527
|
+
finally:
|
|
528
|
+
raise err.NotSupportedError(
|
|
529
|
+
'pure_python=False, but the '
|
|
530
|
+
'C extension can not be loaded',
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
if self.pure_python is True:
|
|
534
|
+
pass
|
|
535
|
+
|
|
536
|
+
# The C extension handles these types internally.
|
|
537
|
+
elif _singlestoredb_accel is not None:
|
|
538
|
+
self.resultclass = MySQLResultSV
|
|
539
|
+
if self.cursorclass is Cursor:
|
|
540
|
+
self.cursorclass = CursorSV
|
|
541
|
+
elif self.cursorclass is SSCursor:
|
|
542
|
+
self.cursorclass = SSCursorSV
|
|
543
|
+
elif self.cursorclass is DictCursor:
|
|
544
|
+
self.cursorclass = DictCursorSV
|
|
545
|
+
self.results_type = 'dicts'
|
|
546
|
+
elif self.cursorclass is SSDictCursor:
|
|
547
|
+
self.cursorclass = SSDictCursorSV
|
|
548
|
+
self.results_type = 'dicts'
|
|
549
|
+
elif self.cursorclass is NamedtupleCursor:
|
|
550
|
+
self.cursorclass = NamedtupleCursorSV
|
|
551
|
+
self.results_type = 'namedtuples'
|
|
552
|
+
elif self.cursorclass is SSNamedtupleCursor:
|
|
553
|
+
self.cursorclass = SSNamedtupleCursorSV
|
|
554
|
+
self.results_type = 'namedtuples'
|
|
555
|
+
elif self.cursorclass is NumpyCursor:
|
|
556
|
+
self.cursorclass = NumpyCursorSV
|
|
557
|
+
self.results_type = 'numpy'
|
|
558
|
+
elif self.cursorclass is SSNumpyCursor:
|
|
559
|
+
self.cursorclass = SSNumpyCursorSV
|
|
560
|
+
self.results_type = 'numpy'
|
|
561
|
+
elif self.cursorclass is ArrowCursor:
|
|
562
|
+
self.cursorclass = ArrowCursorSV
|
|
563
|
+
self.results_type = 'arrow'
|
|
564
|
+
elif self.cursorclass is SSArrowCursor:
|
|
565
|
+
self.cursorclass = SSArrowCursorSV
|
|
566
|
+
self.results_type = 'arrow'
|
|
567
|
+
elif self.cursorclass is PandasCursor:
|
|
568
|
+
self.cursorclass = PandasCursorSV
|
|
569
|
+
self.results_type = 'pandas'
|
|
570
|
+
elif self.cursorclass is SSPandasCursor:
|
|
571
|
+
self.cursorclass = SSPandasCursorSV
|
|
572
|
+
self.results_type = 'pandas'
|
|
573
|
+
elif self.cursorclass is PolarsCursor:
|
|
574
|
+
self.cursorclass = PolarsCursorSV
|
|
575
|
+
self.results_type = 'polars'
|
|
576
|
+
elif self.cursorclass is SSPolarsCursor:
|
|
577
|
+
self.cursorclass = SSPolarsCursorSV
|
|
578
|
+
self.results_type = 'polars'
|
|
579
|
+
|
|
580
|
+
self._result = None
|
|
581
|
+
self._affected_rows = 0
|
|
582
|
+
self.host_info = 'Not connected'
|
|
583
|
+
|
|
584
|
+
# specified autocommit mode. None means use server default.
|
|
585
|
+
self.autocommit_mode = autocommit
|
|
586
|
+
|
|
587
|
+
if conv is None:
|
|
588
|
+
conv = converters.conversions
|
|
589
|
+
|
|
590
|
+
conv = conv.copy()
|
|
591
|
+
|
|
592
|
+
self.parse_json = parse_json
|
|
593
|
+
self.invalid_values = (invalid_values or {}).copy()
|
|
594
|
+
|
|
595
|
+
# Disable JSON parsing for Arrow
|
|
596
|
+
if self.results_type in ['arrow']:
|
|
597
|
+
conv[245] = None
|
|
598
|
+
self.parse_json = False
|
|
599
|
+
|
|
600
|
+
# Disable date/time parsing for polars; let polars do the parsing
|
|
601
|
+
elif self.results_type in ['polars']:
|
|
602
|
+
conv[7] = None
|
|
603
|
+
conv[10] = None
|
|
604
|
+
conv[12] = None
|
|
605
|
+
|
|
606
|
+
# Need for MySQLdb compatibility.
|
|
607
|
+
self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
|
|
608
|
+
self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
|
|
609
|
+
self.sql_mode = sql_mode
|
|
610
|
+
self.init_command = init_command
|
|
611
|
+
self.max_allowed_packet = max_allowed_packet
|
|
612
|
+
self._auth_plugin_map = auth_plugin_map or {}
|
|
613
|
+
self._binary_prefix = binary_prefix
|
|
614
|
+
self.server_public_key = server_public_key
|
|
615
|
+
|
|
616
|
+
if self.connection_params['nan_as_null'] or \
|
|
617
|
+
self.connection_params['inf_as_null']:
|
|
618
|
+
float_encoder = self.encoders.get(float)
|
|
619
|
+
if float_encoder is not None:
|
|
620
|
+
self.encoders[float] = functools.partial(
|
|
621
|
+
float_encoder,
|
|
622
|
+
nan_as_null=self.connection_params['nan_as_null'],
|
|
623
|
+
inf_as_null=self.connection_params['inf_as_null'],
|
|
624
|
+
)
|
|
625
|
+
|
|
626
|
+
from .. import __version__ as VERSION_STRING
|
|
627
|
+
|
|
628
|
+
if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
|
|
629
|
+
VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']
|
|
630
|
+
|
|
631
|
+
self._connect_attrs = {
|
|
632
|
+
'_os': str(sys.platform),
|
|
633
|
+
'_pid': str(os.getpid()),
|
|
634
|
+
'_client_name': 'SingleStoreDB Python Client',
|
|
635
|
+
'_client_version': VERSION_STRING,
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
if program_name:
|
|
639
|
+
self._connect_attrs['program_name'] = program_name
|
|
640
|
+
if conn_attrs is not None:
|
|
641
|
+
# do not overwrite the attributes that we set ourselves
|
|
642
|
+
for k, v in conn_attrs.items():
|
|
643
|
+
if k not in self._connect_attrs:
|
|
644
|
+
self._connect_attrs[k] = v
|
|
645
|
+
|
|
646
|
+
self._is_committable = True
|
|
647
|
+
self._in_sync = False
|
|
648
|
+
self._tls_sni_servername = tls_sni_servername
|
|
649
|
+
self._track_env = bool(track_env) or self.host == 'singlestore.com'
|
|
650
|
+
self._enable_extended_data_types = enable_extended_data_types
|
|
651
|
+
if vector_data_format.lower() in ['json', 'binary']:
|
|
652
|
+
self._vector_data_format = vector_data_format
|
|
653
|
+
else:
|
|
654
|
+
raise ValueError(
|
|
655
|
+
'unknown value for vector_data_format, '
|
|
656
|
+
f'expecting "json" or "binary": {vector_data_format}',
|
|
657
|
+
)
|
|
658
|
+
self._connection_info = {}
|
|
659
|
+
events.subscribe(self._handle_event)
|
|
660
|
+
|
|
661
|
+
if defer_connect or self._track_env:
|
|
662
|
+
self._sock = None
|
|
663
|
+
else:
|
|
664
|
+
self.connect()
|
|
665
|
+
|
|
666
|
+
def _handle_event(self, data: Dict[str, Any]) -> None:
|
|
667
|
+
if data.get('name', '') == 'singlestore.portal.connection_updated':
|
|
668
|
+
self._connection_info = dict(data)
|
|
669
|
+
|
|
670
|
+
@property
|
|
671
|
+
def messages(self):
|
|
672
|
+
# TODO
|
|
673
|
+
[]
|
|
674
|
+
|
|
675
|
+
def __enter__(self):
|
|
676
|
+
return self
|
|
677
|
+
|
|
678
|
+
def __exit__(self, *exc_info):
|
|
679
|
+
del exc_info
|
|
680
|
+
self.close()
|
|
681
|
+
|
|
682
|
+
def _raise_mysql_exception(self, data):
|
|
683
|
+
err.raise_mysql_exception(data)
|
|
684
|
+
|
|
685
|
+
def _create_ssl_ctx(self, sslp):
|
|
686
|
+
if isinstance(sslp, ssl.SSLContext):
|
|
687
|
+
return sslp
|
|
688
|
+
ca = sslp.get('ca')
|
|
689
|
+
capath = sslp.get('capath')
|
|
690
|
+
hasnoca = ca is None and capath is None
|
|
691
|
+
ctx = ssl.create_default_context(cafile=ca, capath=capath)
|
|
692
|
+
ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
|
|
693
|
+
verify_mode_value = sslp.get('verify_mode')
|
|
694
|
+
if verify_mode_value is None:
|
|
695
|
+
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
|
|
696
|
+
elif isinstance(verify_mode_value, bool):
|
|
697
|
+
ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
|
|
698
|
+
else:
|
|
699
|
+
if isinstance(verify_mode_value, str):
|
|
700
|
+
verify_mode_value = verify_mode_value.lower()
|
|
701
|
+
if verify_mode_value in ('none', '0', 'false', 'no'):
|
|
702
|
+
ctx.verify_mode = ssl.CERT_NONE
|
|
703
|
+
elif verify_mode_value == 'optional':
|
|
704
|
+
ctx.verify_mode = ssl.CERT_OPTIONAL
|
|
705
|
+
elif verify_mode_value in ('required', '1', 'true', 'yes'):
|
|
706
|
+
ctx.verify_mode = ssl.CERT_REQUIRED
|
|
707
|
+
else:
|
|
708
|
+
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
|
|
709
|
+
if 'cert' in sslp:
|
|
710
|
+
ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
|
|
711
|
+
if 'cipher' in sslp:
|
|
712
|
+
ctx.set_ciphers(sslp['cipher'])
|
|
713
|
+
ctx.options |= ssl.OP_NO_SSLv2
|
|
714
|
+
ctx.options |= ssl.OP_NO_SSLv3
|
|
715
|
+
return ctx
|
|
716
|
+
|
|
717
|
+
def close(self):
|
|
718
|
+
"""
|
|
719
|
+
Send the quit message and close the socket.
|
|
720
|
+
|
|
721
|
+
See `Connection.close()
|
|
722
|
+
<https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
|
|
723
|
+
in the specification.
|
|
724
|
+
|
|
725
|
+
Raises
|
|
726
|
+
------
|
|
727
|
+
Error : If the connection is already closed.
|
|
728
|
+
|
|
729
|
+
"""
|
|
730
|
+
self._result = None
|
|
731
|
+
if self.host == 'singlestore.com':
|
|
732
|
+
return
|
|
733
|
+
if self._closed:
|
|
734
|
+
raise err.Error('Already closed')
|
|
735
|
+
events.unsubscribe(self._handle_event)
|
|
736
|
+
self._closed = True
|
|
737
|
+
if self._sock is None:
|
|
738
|
+
return
|
|
739
|
+
send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
|
|
740
|
+
try:
|
|
741
|
+
self._write_bytes(send_data)
|
|
742
|
+
except Exception:
|
|
743
|
+
pass
|
|
744
|
+
finally:
|
|
745
|
+
self._force_close()
|
|
746
|
+
|
|
747
|
+
@property
|
|
748
|
+
def open(self):
|
|
749
|
+
"""Return True if the connection is open."""
|
|
750
|
+
return self._sock is not None
|
|
751
|
+
|
|
752
|
+
def is_connected(self):
|
|
753
|
+
"""Return True if the connection is open."""
|
|
754
|
+
return self.open
|
|
755
|
+
|
|
756
|
+
def _force_close(self):
|
|
757
|
+
"""Close connection without QUIT message."""
|
|
758
|
+
if self._sock:
|
|
759
|
+
try:
|
|
760
|
+
self._sock.close()
|
|
761
|
+
except: # noqa
|
|
762
|
+
pass
|
|
763
|
+
self._sock = None
|
|
764
|
+
self._rfile = None
|
|
765
|
+
|
|
766
|
+
__del__ = _force_close
|
|
767
|
+
|
|
768
|
+
def autocommit(self, value):
|
|
769
|
+
"""Enable autocommit in the server."""
|
|
770
|
+
self.autocommit_mode = bool(value)
|
|
771
|
+
current = self.get_autocommit()
|
|
772
|
+
if value != current:
|
|
773
|
+
self._send_autocommit_mode()
|
|
774
|
+
|
|
775
|
+
def get_autocommit(self):
|
|
776
|
+
"""Retrieve autocommit status."""
|
|
777
|
+
return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)
|
|
778
|
+
|
|
779
|
+
def _read_ok_packet(self):
|
|
780
|
+
pkt = self._read_packet()
|
|
781
|
+
if not pkt.is_ok_packet():
|
|
782
|
+
raise err.OperationalError(
|
|
783
|
+
CR.CR_COMMANDS_OUT_OF_SYNC,
|
|
784
|
+
'Command Out of Sync',
|
|
785
|
+
)
|
|
786
|
+
ok = OKPacketWrapper(pkt)
|
|
787
|
+
self.server_status = ok.server_status
|
|
788
|
+
return ok
|
|
789
|
+
|
|
790
|
+
def _send_autocommit_mode(self):
|
|
791
|
+
"""Set whether or not to commit after every execute()."""
|
|
792
|
+
log_query('SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode))
|
|
793
|
+
self._execute_command(
|
|
794
|
+
COMMAND.COM_QUERY, 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode),
|
|
795
|
+
)
|
|
796
|
+
self._read_ok_packet()
|
|
797
|
+
|
|
798
|
+
def begin(self):
|
|
799
|
+
"""Begin transaction."""
|
|
800
|
+
log_query('BEGIN')
|
|
801
|
+
if self.host == 'singlestore.com':
|
|
802
|
+
return
|
|
803
|
+
self._execute_command(COMMAND.COM_QUERY, 'BEGIN')
|
|
804
|
+
self._read_ok_packet()
|
|
805
|
+
|
|
806
|
+
def commit(self):
|
|
807
|
+
"""
|
|
808
|
+
Commit changes to stable storage.
|
|
809
|
+
|
|
810
|
+
See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
|
|
811
|
+
in the specification.
|
|
812
|
+
|
|
813
|
+
"""
|
|
814
|
+
log_query('COMMIT')
|
|
815
|
+
if not self._is_committable or self.host == 'singlestore.com':
|
|
816
|
+
self._is_committable = True
|
|
817
|
+
return
|
|
818
|
+
self._execute_command(COMMAND.COM_QUERY, 'COMMIT')
|
|
819
|
+
self._read_ok_packet()
|
|
820
|
+
|
|
821
|
+
def rollback(self):
|
|
822
|
+
"""
|
|
823
|
+
Roll back the current transaction.
|
|
824
|
+
|
|
825
|
+
See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
|
|
826
|
+
in the specification.
|
|
827
|
+
|
|
828
|
+
"""
|
|
829
|
+
log_query('ROLLBACK')
|
|
830
|
+
if not self._is_committable or self.host == 'singlestore.com':
|
|
831
|
+
self._is_committable = True
|
|
832
|
+
return
|
|
833
|
+
self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')
|
|
834
|
+
self._read_ok_packet()
|
|
835
|
+
|
|
836
|
+
def show_warnings(self):
|
|
837
|
+
"""Send the "SHOW WARNINGS" SQL command."""
|
|
838
|
+
log_query('SHOW WARNINGS')
|
|
839
|
+
self._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
|
|
840
|
+
result = self.resultclass(self)
|
|
841
|
+
result.read()
|
|
842
|
+
return result.rows
|
|
843
|
+
|
|
844
|
+
def select_db(self, db):
|
|
845
|
+
"""
|
|
846
|
+
Set current db.
|
|
847
|
+
|
|
848
|
+
db : str
|
|
849
|
+
The name of the db.
|
|
850
|
+
|
|
851
|
+
"""
|
|
852
|
+
self._execute_command(COMMAND.COM_INIT_DB, db)
|
|
853
|
+
self._read_ok_packet()
|
|
854
|
+
|
|
855
|
+
def escape(self, obj, mapping=None):
|
|
856
|
+
"""
|
|
857
|
+
Escape whatever value is passed.
|
|
858
|
+
|
|
859
|
+
Non-standard, for internal use; do not use this in your applications.
|
|
860
|
+
|
|
861
|
+
"""
|
|
862
|
+
dtype = type(obj)
|
|
863
|
+
if dtype is str or isinstance(obj, str):
|
|
864
|
+
return "'{}'".format(self.escape_string(obj))
|
|
865
|
+
if dtype is bytes or dtype is bytearray or isinstance(obj, (bytes, bytearray)):
|
|
866
|
+
return self._quote_bytes(obj)
|
|
867
|
+
if mapping is None:
|
|
868
|
+
mapping = self.encoders
|
|
869
|
+
return converters.escape_item(obj, self.charset, mapping=mapping)
|
|
870
|
+
|
|
871
|
+
def literal(self, obj):
|
|
872
|
+
"""
|
|
873
|
+
Alias for escape().
|
|
874
|
+
|
|
875
|
+
Non-standard, for internal use; do not use this in your applications.
|
|
876
|
+
|
|
877
|
+
"""
|
|
878
|
+
return self.escape(obj, self.encoders)
|
|
879
|
+
|
|
880
|
+
def escape_string(self, s):
|
|
881
|
+
"""Escape a string value."""
|
|
882
|
+
if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
|
|
883
|
+
return s.replace("'", "''")
|
|
884
|
+
return converters.escape_string(s)
|
|
885
|
+
|
|
886
|
+
def _quote_bytes(self, s):
|
|
887
|
+
if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
|
|
888
|
+
if self._binary_prefix:
|
|
889
|
+
return "_binary X'{}'".format(s.hex())
|
|
890
|
+
return "X'{}'".format(s.hex())
|
|
891
|
+
return converters.escape_bytes(s)
|
|
892
|
+
|
|
893
|
+
def cursor(self):
|
|
894
|
+
"""Create a new cursor to execute queries with."""
|
|
895
|
+
return self.cursorclass(self)
|
|
896
|
+
|
|
897
|
+
# The following methods are INTERNAL USE ONLY (called from Cursor)
|
|
898
|
+
def query(self, sql, unbuffered=False, infile_stream=None):
|
|
899
|
+
"""
|
|
900
|
+
Run a query on the server.
|
|
901
|
+
|
|
902
|
+
Internal use only.
|
|
903
|
+
|
|
904
|
+
"""
|
|
905
|
+
# if DEBUG:
|
|
906
|
+
# print("DEBUG: sending query:", sql)
|
|
907
|
+
handler = fusion.get_handler(sql)
|
|
908
|
+
if handler is not None:
|
|
909
|
+
self._is_committable = False
|
|
910
|
+
self._result = fusion.execute(self, sql, handler=handler)
|
|
911
|
+
self._affected_rows = self._result.affected_rows
|
|
912
|
+
else:
|
|
913
|
+
self._is_committable = True
|
|
914
|
+
if isinstance(sql, str):
|
|
915
|
+
sql = sql.encode(self.encoding, 'surrogateescape')
|
|
916
|
+
self._local_infile_stream = infile_stream
|
|
917
|
+
self._execute_command(COMMAND.COM_QUERY, sql)
|
|
918
|
+
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
|
919
|
+
self._local_infile_stream = None
|
|
920
|
+
return self._affected_rows
|
|
921
|
+
|
|
922
|
+
def next_result(self, unbuffered=False):
|
|
923
|
+
"""
|
|
924
|
+
Retrieve the next result set.
|
|
925
|
+
|
|
926
|
+
Internal use only.
|
|
927
|
+
|
|
928
|
+
"""
|
|
929
|
+
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
|
930
|
+
return self._affected_rows
|
|
931
|
+
|
|
932
|
+
def affected_rows(self):
|
|
933
|
+
"""
|
|
934
|
+
Return number of affected rows.
|
|
935
|
+
|
|
936
|
+
Internal use only.
|
|
937
|
+
|
|
938
|
+
"""
|
|
939
|
+
return self._affected_rows
|
|
940
|
+
|
|
941
|
+
def kill(self, thread_id):
|
|
942
|
+
"""
|
|
943
|
+
Execute kill command.
|
|
944
|
+
|
|
945
|
+
Internal use only.
|
|
946
|
+
|
|
947
|
+
"""
|
|
948
|
+
arg = struct.pack('<I', thread_id)
|
|
949
|
+
self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
|
|
950
|
+
return self._read_ok_packet()
|
|
951
|
+
|
|
952
|
+
def ping(self, reconnect=True):
|
|
953
|
+
"""
|
|
954
|
+
Check if the server is alive.
|
|
955
|
+
|
|
956
|
+
Parameters
|
|
957
|
+
----------
|
|
958
|
+
reconnect : bool, optional
|
|
959
|
+
If the connection is closed, reconnect.
|
|
960
|
+
|
|
961
|
+
Raises
|
|
962
|
+
------
|
|
963
|
+
Error : If the connection is closed and reconnect=False.
|
|
964
|
+
|
|
965
|
+
"""
|
|
966
|
+
if self._sock is None:
|
|
967
|
+
if reconnect:
|
|
968
|
+
self.connect()
|
|
969
|
+
reconnect = False
|
|
970
|
+
else:
|
|
971
|
+
raise err.Error('Already closed')
|
|
972
|
+
try:
|
|
973
|
+
self._execute_command(COMMAND.COM_PING, '')
|
|
974
|
+
self._read_ok_packet()
|
|
975
|
+
except Exception:
|
|
976
|
+
if reconnect:
|
|
977
|
+
self.connect()
|
|
978
|
+
self.ping(False)
|
|
979
|
+
else:
|
|
980
|
+
raise
|
|
981
|
+
|
|
982
|
+
def set_charset(self, charset):
|
|
983
|
+
"""Deprecated. Use set_character_set() instead."""
|
|
984
|
+
# This function has been implemented in old PyMySQL.
|
|
985
|
+
# But this name is different from MySQLdb.
|
|
986
|
+
# So we keep this function for compatibility and add
|
|
987
|
+
# new set_character_set() function.
|
|
988
|
+
self.set_character_set(charset)
|
|
989
|
+
|
|
990
|
+
def set_character_set(self, charset, collation=None):
|
|
991
|
+
"""
|
|
992
|
+
Set charaset (and collation) on the server.
|
|
993
|
+
|
|
994
|
+
Send "SET NAMES charset [COLLATE collation]" query.
|
|
995
|
+
Update Connection.encoding based on charset.
|
|
996
|
+
|
|
997
|
+
Parameters
|
|
998
|
+
----------
|
|
999
|
+
charset : str
|
|
1000
|
+
The charset to enable.
|
|
1001
|
+
collation : str, optional
|
|
1002
|
+
The collation value
|
|
1003
|
+
|
|
1004
|
+
"""
|
|
1005
|
+
# Make sure charset is supported.
|
|
1006
|
+
encoding = charset_by_name(charset).encoding
|
|
1007
|
+
|
|
1008
|
+
if collation:
|
|
1009
|
+
query = f'SET NAMES {charset} COLLATE {collation}'
|
|
1010
|
+
else:
|
|
1011
|
+
query = f'SET NAMES {charset}'
|
|
1012
|
+
self._execute_command(COMMAND.COM_QUERY, query)
|
|
1013
|
+
self._read_packet()
|
|
1014
|
+
self.charset = charset
|
|
1015
|
+
self.encoding = encoding
|
|
1016
|
+
self.collation = collation
|
|
1017
|
+
|
|
1018
|
+
def _sync_connection(self):
|
|
1019
|
+
"""Synchronize connection with env variable."""
|
|
1020
|
+
if self._in_sync:
|
|
1021
|
+
return
|
|
1022
|
+
|
|
1023
|
+
if not self._track_env:
|
|
1024
|
+
return
|
|
1025
|
+
|
|
1026
|
+
url = self._connection_info.get('connection_url')
|
|
1027
|
+
if not url:
|
|
1028
|
+
url = os.environ.get('SINGLESTOREDB_URL')
|
|
1029
|
+
if not url:
|
|
1030
|
+
return
|
|
1031
|
+
|
|
1032
|
+
out = {}
|
|
1033
|
+
urlp = connection._parse_url(url)
|
|
1034
|
+
out.update(urlp)
|
|
1035
|
+
|
|
1036
|
+
out = connection._cast_params(out)
|
|
1037
|
+
|
|
1038
|
+
# Set default port based on driver.
|
|
1039
|
+
if 'port' not in out or not out['port']:
|
|
1040
|
+
out['port'] = int(get_option('port') or 3306)
|
|
1041
|
+
|
|
1042
|
+
# If there is no user and the password is empty, remove the password key.
|
|
1043
|
+
if 'user' not in out and not out.get('password', None):
|
|
1044
|
+
out.pop('password', None)
|
|
1045
|
+
|
|
1046
|
+
if out['host'] == 'singlestore.com':
|
|
1047
|
+
raise err.InterfaceError(0, 'Connection URL has not been established')
|
|
1048
|
+
|
|
1049
|
+
# If it's just a password change, we don't need to reconnect
|
|
1050
|
+
if self._sock is not None and \
|
|
1051
|
+
(self.host, self.port, self.user, self.db) == \
|
|
1052
|
+
(out['host'], out['port'], out['user'], out.get('database')):
|
|
1053
|
+
return
|
|
1054
|
+
|
|
1055
|
+
self.host = out['host']
|
|
1056
|
+
self.port = out['port']
|
|
1057
|
+
self.user = out['user']
|
|
1058
|
+
if isinstance(out['password'], str):
|
|
1059
|
+
self.password = out['password'].encode('latin-1')
|
|
1060
|
+
else:
|
|
1061
|
+
self.password = out['password'] or b''
|
|
1062
|
+
self.db = out.get('database')
|
|
1063
|
+
try:
|
|
1064
|
+
self._in_sync = True
|
|
1065
|
+
self.connect()
|
|
1066
|
+
finally:
|
|
1067
|
+
self._in_sync = False
|
|
1068
|
+
|
|
1069
|
+
def connect(self, sock=None):
|
|
1070
|
+
"""
|
|
1071
|
+
Connect to server using existing parameters.
|
|
1072
|
+
|
|
1073
|
+
Internal use only.
|
|
1074
|
+
|
|
1075
|
+
"""
|
|
1076
|
+
self._closed = False
|
|
1077
|
+
try:
|
|
1078
|
+
if sock is None:
|
|
1079
|
+
if self.unix_socket:
|
|
1080
|
+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
1081
|
+
sock.settimeout(self.connect_timeout)
|
|
1082
|
+
sock.connect(self.unix_socket)
|
|
1083
|
+
self.host_info = 'Localhost via UNIX socket'
|
|
1084
|
+
self._secure = True
|
|
1085
|
+
if DEBUG:
|
|
1086
|
+
print('connected using unix_socket')
|
|
1087
|
+
else:
|
|
1088
|
+
kwargs = {}
|
|
1089
|
+
if self.bind_address is not None:
|
|
1090
|
+
kwargs['source_address'] = (self.bind_address, 0)
|
|
1091
|
+
while True:
|
|
1092
|
+
try:
|
|
1093
|
+
sock = socket.create_connection(
|
|
1094
|
+
(self.host, self.port), self.connect_timeout, **kwargs,
|
|
1095
|
+
)
|
|
1096
|
+
break
|
|
1097
|
+
except OSError as e:
|
|
1098
|
+
if e.errno == errno.EINTR:
|
|
1099
|
+
continue
|
|
1100
|
+
raise
|
|
1101
|
+
self.host_info = 'socket %s:%d' % (self.host, self.port)
|
|
1102
|
+
if DEBUG:
|
|
1103
|
+
print('connected using socket')
|
|
1104
|
+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
|
1105
|
+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
|
1106
|
+
sock.settimeout(None)
|
|
1107
|
+
|
|
1108
|
+
self._sock = sock
|
|
1109
|
+
self._rfile = sock.makefile('rb')
|
|
1110
|
+
self._next_seq_id = 0
|
|
1111
|
+
|
|
1112
|
+
self._get_server_information()
|
|
1113
|
+
self._request_authentication()
|
|
1114
|
+
|
|
1115
|
+
# Send "SET NAMES" query on init for:
|
|
1116
|
+
# - Ensure charaset (and collation) is set to the server.
|
|
1117
|
+
# - collation_id in handshake packet may be ignored.
|
|
1118
|
+
# - If collation is not specified, we don't know what is server's
|
|
1119
|
+
# default collation for the charset. For example, default collation
|
|
1120
|
+
# of utf8mb4 is:
|
|
1121
|
+
# - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
|
|
1122
|
+
# - MySQL 8.0: utf8mb4_0900_ai_ci
|
|
1123
|
+
#
|
|
1124
|
+
# Reference:
|
|
1125
|
+
# - https://github.com/PyMySQL/PyMySQL/issues/1092
|
|
1126
|
+
# - https://github.com/wagtail/wagtail/issues/9477
|
|
1127
|
+
# - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
|
|
1128
|
+
self.set_character_set(self.charset, self.collation)
|
|
1129
|
+
|
|
1130
|
+
if self.sql_mode is not None:
|
|
1131
|
+
c = self.cursor()
|
|
1132
|
+
c.execute('SET sql_mode=%s', (self.sql_mode,))
|
|
1133
|
+
c.close()
|
|
1134
|
+
|
|
1135
|
+
if self._enable_extended_data_types:
|
|
1136
|
+
c = self.cursor()
|
|
1137
|
+
try:
|
|
1138
|
+
c.execute('SET @@SESSION.enable_extended_types_metadata=on')
|
|
1139
|
+
except self.OperationalError:
|
|
1140
|
+
pass
|
|
1141
|
+
c.close()
|
|
1142
|
+
|
|
1143
|
+
if self._vector_data_format:
|
|
1144
|
+
c = self.cursor()
|
|
1145
|
+
try:
|
|
1146
|
+
val = self._vector_data_format
|
|
1147
|
+
c.execute(f'SET @@SESSION.vector_type_project_format={val}')
|
|
1148
|
+
except self.OperationalError:
|
|
1149
|
+
pass
|
|
1150
|
+
c.close()
|
|
1151
|
+
|
|
1152
|
+
if self.init_command is not None:
|
|
1153
|
+
c = self.cursor()
|
|
1154
|
+
c.execute(self.init_command)
|
|
1155
|
+
c.close()
|
|
1156
|
+
|
|
1157
|
+
if self.autocommit_mode is not None:
|
|
1158
|
+
self.autocommit(self.autocommit_mode)
|
|
1159
|
+
|
|
1160
|
+
except BaseException as e:
|
|
1161
|
+
self._rfile = None
|
|
1162
|
+
if sock is not None:
|
|
1163
|
+
try:
|
|
1164
|
+
sock.close()
|
|
1165
|
+
except: # noqa
|
|
1166
|
+
pass
|
|
1167
|
+
|
|
1168
|
+
if isinstance(e, (OSError, IOError, socket.error)):
|
|
1169
|
+
exc = err.OperationalError(
|
|
1170
|
+
CR.CR_CONN_HOST_ERROR,
|
|
1171
|
+
f'Can\'t connect to MySQL server on {self.host!r} ({e})',
|
|
1172
|
+
)
|
|
1173
|
+
# Keep original exception and traceback to investigate error.
|
|
1174
|
+
exc.original_exception = e
|
|
1175
|
+
exc.traceback = traceback.format_exc()
|
|
1176
|
+
if DEBUG:
|
|
1177
|
+
print(exc.traceback)
|
|
1178
|
+
raise exc
|
|
1179
|
+
|
|
1180
|
+
# If e is neither DatabaseError or IOError, It's a bug.
|
|
1181
|
+
# But raising AssertionError hides original error.
|
|
1182
|
+
# So just reraise it.
|
|
1183
|
+
raise
|
|
1184
|
+
|
|
1185
|
+
def write_packet(self, payload):
|
|
1186
|
+
"""
|
|
1187
|
+
Writes an entire "mysql packet" in its entirety to the network.
|
|
1188
|
+
|
|
1189
|
+
Adds its length and sequence number.
|
|
1190
|
+
|
|
1191
|
+
"""
|
|
1192
|
+
# Internal note: when you build packet manually and calls _write_bytes()
|
|
1193
|
+
# directly, you should set self._next_seq_id properly.
|
|
1194
|
+
data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload
|
|
1195
|
+
if DEBUG:
|
|
1196
|
+
dump_packet(data)
|
|
1197
|
+
self._write_bytes(data)
|
|
1198
|
+
self._next_seq_id = (self._next_seq_id + 1) % 256
|
|
1199
|
+
|
|
1200
|
+
    def _read_packet(self, packet_type=MysqlPacket):
        """
        Read an entire "mysql packet" from the network.

        Raises
        ------
        OperationalError : If the connection to the MySQL server is lost.
        InternalError : If the packet sequence number is wrong.

        Returns
        -------
        MysqlPacket

        """
        buff = bytearray()
        while True:
            packet_header = self._read_bytes(4)
            # if DEBUG: dump_packet(packet_header)

            btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
            bytes_to_read = btrl + (btrh << 16)
            if packet_number != self._next_seq_id:
                self._force_close()
                if packet_number == 0:
                    # MariaDB sends error packet with seqno==0 when shutdown
                    raise err.OperationalError(
                        CR.CR_SERVER_LOST,
                        'Lost connection to MySQL server during query',
                    )
                raise err.InternalError(
                    'Packet sequence number wrong - got %d expected %d'
                    % (packet_number, self._next_seq_id),
                )
            self._next_seq_id = (self._next_seq_id + 1) % 256

            recv_data = self._read_bytes(bytes_to_read)
            if DEBUG:
                dump_packet(recv_data)
            buff += recv_data
            # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
            if bytes_to_read == 0xFFFFFF:
                continue
            if bytes_to_read < MAX_PACKET_LEN:
                break

        packet = packet_type(bytes(buff), self.encoding)
        if packet.is_error_packet():
            if self._result is not None and self._result.unbuffered_active is True:
                self._result.unbuffered_active = False
            packet.raise_for_error()
        return packet
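    # Illustrative sketch (editorial addition): payloads of 0xFFFFFF (2**24 - 1)
    # bytes or more are split across packets, and _read_packet() keeps reading
    # until it sees a chunk shorter than MAX_PACKET_LEN. A sender following the
    # same rule might look like this (hypothetical helper, for illustration only):
    #
    #     def split_payload(payload, max_len=0xFFFFFF):
    #         """Yield wire-sized chunks; a final short (possibly empty) chunk
    #         tells the reader the payload is complete."""
    #         while True:
    #             chunk, payload = payload[:max_len], payload[max_len:]
    #             yield chunk
    #             if len(chunk) < max_len:
    #                 break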
    def _read_bytes(self, num_bytes):
        if self._read_timeout is not None:
            self._sock.settimeout(self._read_timeout)
        while True:
            try:
                data = self._rfile.read(num_bytes)
                break
            except OSError as e:
                if e.errno == errno.EINTR:
                    continue
                self._force_close()
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    'Lost connection to MySQL server during query (%s)' % (e,),
                )
            except BaseException:
                # Don't convert unknown exception to MySQLError.
                self._force_close()
                raise
        if len(data) < num_bytes:
            self._force_close()
            raise err.OperationalError(
                CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',
            )
        return data

    def _write_bytes(self, data):
        if self._write_timeout is not None:
            self._sock.settimeout(self._write_timeout)
        try:
            self._sock.sendall(data)
        except OSError as e:
            self._force_close()
            raise err.OperationalError(
                CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',
            )

    def _read_query_result(self, unbuffered=False):
        self._result = None
        if unbuffered:
            result = self.resultclass(self, unbuffered=unbuffered)
        else:
            result = self.resultclass(self)
        result.read()
        self._result = result
        if result.server_status is not None:
            self.server_status = result.server_status
        return result.affected_rows

    def insert_id(self):
        if self._result:
            return self._result.insert_id
        else:
            return 0

    def _execute_command(self, command, sql):
        """
        Execute command.

        Raises
        ------
        InterfaceError : If the connection is closed.
        ValueError : If no username was specified.

        """
        self._sync_connection()

        if self._sock is None:
            raise err.InterfaceError(0, 'The connection has been closed')

        # If the last query was unbuffered, make sure it finishes before
        # sending new commands
        if self._result is not None:
            if self._result.unbuffered_active:
                warnings.warn('Previous unbuffered result was left incomplete')
                self._result._finish_unbuffered_query()
            while self._result.has_next:
                self.next_result()
            self._result = None

        if isinstance(sql, str):
            sql = sql.encode(self.encoding)

        packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

        # tiny optimization: build the first packet manually instead of
        # calling self.write_packet()
        prelude = struct.pack('<iB', packet_size, command)
        packet = prelude + sql[: packet_size - 1]
        self._write_bytes(packet)
        if DEBUG:
            dump_packet(packet)
        self._next_seq_id = 1

        if packet_size < MAX_PACKET_LEN:
            return

        sql = sql[packet_size - 1:]
        while True:
            packet_size = min(MAX_PACKET_LEN, len(sql))
            self.write_packet(sql[:packet_size])
            sql = sql[packet_size:]
            if not sql and packet_size < MAX_PACKET_LEN:
                break
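    # Illustrative sketch (editorial addition): the first command packet built in
    # _execute_command() relies on the fact that packet_size always fits in three
    # bytes, so packing it as a 4-byte little-endian int leaves a zero high byte
    # that doubles as the initial sequence id:
    #
    #     >>> import struct
    #     >>> struct.pack('<iB', 9, 0x03)   # 9-byte payload, COM_QUERY (0x03)
    #     b'\t\x00\x00\x00\x03'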
    def _request_authentication(self):  # noqa: C901
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError('Did not specify a username')

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, str):
            self.user = self.user.encode(self.encoding)

        data_init = struct.pack(
            '<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',
        )

        if self.ssl and self.server_capabilities & CLIENT.SSL:
            self.write_packet(data_init)

            hostname = self.host
            if self._tls_sni_servername:
                hostname = self._tls_sni_servername
            self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)
            self._rfile = self._sock.makefile('rb')
            self._secure = True

        data = data_init + self.user + b'\0'

        authresp = b''
        plugin_name = None

        if self._auth_plugin_name == '':
            plugin_name = b''
            authresp = _auth.scramble_native_password(self.password, self.salt)
        elif self._auth_plugin_name == 'mysql_native_password':
            plugin_name = b'mysql_native_password'
            authresp = _auth.scramble_native_password(self.password, self.salt)
        elif self._auth_plugin_name == 'caching_sha2_password':
            plugin_name = b'caching_sha2_password'
            if self.password:
                if DEBUG:
                    print('caching_sha2: trying fast path')
                authresp = _auth.scramble_caching_sha2(self.password, self.salt)
            else:
                if DEBUG:
                    print('caching_sha2: empty password')
        elif self._auth_plugin_name == 'sha256_password':
            plugin_name = b'sha256_password'
            if self.ssl and self.server_capabilities & CLIENT.SSL:
                authresp = self.password + b'\0'
            elif self.password:
                authresp = b'\1'  # request public key
            else:
                authresp = b'\0'  # empty password

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += _lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover - no testing against servers w/o secure auth (>=5.0)
            data += authresp + b'\0'

        if self.server_capabilities & CLIENT.CONNECT_WITH_DB:
            db = self.db
            if isinstance(db, str):
                db = db.encode(self.encoding)
            data += (db or b'') + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            data += (plugin_name or b'') + b'\0'

        if self.server_capabilities & CLIENT.CONNECT_ATTRS:
            connect_attrs = b''
            for k, v in self._connect_attrs.items():
                k = k.encode('utf-8')
                connect_attrs += _lenenc_int(len(k)) + k
                v = v.encode('utf-8')
                connect_attrs += _lenenc_int(len(v)) + v
            data += _lenenc_int(len(connect_attrs)) + connect_attrs

        self.write_packet(data)
        auth_packet = self._read_packet()

        # if the authentication method isn't accepted, the first byte
        # of the response will be the octet 254
        if auth_packet.is_auth_switch_request():
            if DEBUG:
                print('received auth switch')
            # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if (
                self.server_capabilities & CLIENT.PLUGIN_AUTH
                and plugin_name is not None
            ):
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                raise err.OperationalError('received unknown auth switch request')
        elif auth_packet.is_extra_auth_data():
            if DEBUG:
                print('received extra data')
            # https://dev.mysql.com/doc/internals/en/successful-authentication.html
            if self._auth_plugin_name == 'caching_sha2_password':
                auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
            elif self._auth_plugin_name == 'sha256_password':
                auth_packet = _auth.sha256_password_auth(self, auth_packet)
            else:
                raise err.OperationalError(
                    'Received extra packet for auth method %r', self._auth_plugin_name,
                )

        if DEBUG:
            print('Authentication succeeded')
    def _process_auth(self, plugin_name, auth_packet):
        handler = self._get_auth_plugin_handler(plugin_name)
        if handler:
            try:
                return handler.authenticate(auth_packet)
            except AttributeError:
                if plugin_name != b'dialog':
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                        "Authentication plugin '%s'"
                        ' not loaded: - %r missing authenticate method'
                        % (plugin_name, type(handler)),
                    )
        if plugin_name == b'caching_sha2_password':
            return _auth.caching_sha2_password_auth(self, auth_packet)
        elif plugin_name == b'sha256_password':
            return _auth.sha256_password_auth(self, auth_packet)
        elif plugin_name == b'mysql_native_password':
            data = _auth.scramble_native_password(self.password, auth_packet.read_all())
        elif plugin_name == b'client_ed25519':
            data = _auth.ed25519_password(self.password, auth_packet.read_all())
        elif plugin_name == b'mysql_old_password':
            data = (
                _auth.scramble_old_password(self.password, auth_packet.read_all())
                + b'\0'
            )
        elif plugin_name == b'mysql_clear_password':
            # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
            data = self.password + b'\0'
        elif plugin_name == b'auth_gssapi_client':
            data = _auth.gssapi_auth(auth_packet.read_all())
        elif plugin_name == b'dialog':
            pkt = auth_packet
            while True:
                flag = pkt.read_uint8()
                echo = (flag & 0x06) == 0x02
                last = (flag & 0x01) == 0x01
                prompt = pkt.read_all()

                if prompt == b'Password: ':
                    self.write_packet(self.password + b'\0')
                elif handler:
                    resp = 'no response - TypeError within plugin.prompt method'
                    try:
                        resp = handler.prompt(echo, prompt)
                        self.write_packet(resp + b'\0')
                    except AttributeError:
                        raise err.OperationalError(
                            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                            "Authentication plugin '%s'"
                            ' not loaded: - %r missing prompt method'
                            % (plugin_name, handler),
                        )
                    except TypeError:
                        raise err.OperationalError(
                            CR.CR_AUTH_PLUGIN_ERR,
                            "Authentication plugin '%s'"
                            " %r didn't respond with string. Returned '%r' to prompt %r"
                            % (plugin_name, handler, resp, prompt),
                        )
                else:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                        "Authentication plugin '%s' not configured" % (plugin_name,),
                    )
                pkt = self._read_packet()
                pkt.check_error()
                if pkt.is_ok_packet() or last:
                    break
            return pkt
        else:
            raise err.OperationalError(
                CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                "Authentication plugin '%s' not configured" % plugin_name,
            )

        self.write_packet(data)
        pkt = self._read_packet()
        pkt.check_error()
        return pkt

    def _get_auth_plugin_handler(self, plugin_name):
        plugin_class = self._auth_plugin_map.get(plugin_name)
        if not plugin_class and isinstance(plugin_name, bytes):
            plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
        if plugin_class:
            try:
                handler = plugin_class(self)
            except TypeError:
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s'"
                    ' not loaded: - %r cannot be constructed with connection object'
                    % (plugin_name, plugin_class),
                )
        else:
            handler = None
        return handler
    # _mysql support
    def thread_id(self):
        return self.server_thread_id[0]

    def character_set_name(self):
        return self.charset

    def get_host_info(self):
        return self.host_info

    def get_proto_info(self):
        return self.protocol_version

    def _get_server_information(self):
        i = 0
        packet = self._read_packet()
        data = packet.get_all_data()

        self.protocol_version = data[i]
        i += 1

        server_end = data.find(b'\0', i)
        self.server_version = data[i:server_end].decode('latin1')
        i = server_end + 1

        self.server_thread_id = struct.unpack('<I', data[i: i + 4])
        i += 4

        self.salt = data[i: i + 8]
        i += 9  # 8 + 1(filler)

        self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]
        i += 2

        if len(data) >= i + 6:
            lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])
            i += 6
            # TODO: deprecate server_language and server_charset.
            # mysqlclient-python doesn't provide them.
            self.server_language = lang
            try:
                self.server_charset = charset_by_id(lang).name
            except KeyError:
                # unknown collation
                self.server_charset = None

            self.server_status = stat
            if DEBUG:
                print('server_status: %x' % stat)

            self.server_capabilities |= cap_h << 16
            if DEBUG:
                print('salt_len:', salt_len)
            salt_len = max(12, salt_len - 9)

        # reserved
        i += 10

        if len(data) >= i + salt_len:
            # salt_len includes auth_plugin_data_part_1 and filler
            self.salt += data[i: i + salt_len]
            i += salt_len

        i += 1
        # AUTH PLUGIN NAME may appear here.
        if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
            # Due to Bug#59453, the auth-plugin-name is missing the terminating
            # NUL-char in versions prior to 5.5.10 and 5.6.2.
            # ref: https://dev.mysql.com/doc/internals/en/
            #      connection-phase-packets.html#packet-Protocol::Handshake
            # We don't use version checks since MariaDB is corrected and reports
            # versions earlier than those two.
            server_end = data.find(b'\0', i)
            if server_end < 0:  # pragma: no cover - very specific upstream bug
                # \0 not found and this is the last field, so take it all
                self._auth_plugin_name = data[i:].decode('utf-8')
            else:
                self._auth_plugin_name = data[i:server_end].decode('utf-8')

    def get_server_info(self):
        return self.server_version

    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError

class MySQLResult:
    """
    Results of a SQL query.

    Parameters
    ----------
    connection : Connection
        The connection the result came from.
    unbuffered : bool, optional
        Should the reads be unbuffered?

    """

    def __init__(self, connection, unbuffered=False):
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False
        self.converters = []
        self.fields = []
        self.encoding_errors = self.connection.encoding_errors
        if unbuffered:
            try:
                self.init_unbuffered_query()
            except Exception:
                self.connection = None
                self.unbuffered_active = False
                raise

    def __del__(self):
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """
        Initialize an unbuffered query.

        Raises
        ------
        OperationalError : If the connection to the MySQL server is lost.
        InternalError : Other errors.

        """
        self.unbuffered_active = True
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        if not self.connection._local_infile:
            raise RuntimeError(
                '**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except Exception:
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                'Commands Out of Sync',
            )
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper
        #    instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should be a tuple of rows for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active and self.connection._sock is not None:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                if e.args[0] in (
                    ER.QUERY_TIMEOUT,
                    ER.STATEMENT_TIMEOUT,
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        row = []
        for i, (encoding, converter) in enumerate(self.converters):
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    try:
                        data = data.decode(encoding, errors=self.encoding_errors)
                    except UnicodeDecodeError:
                        raise UnicodeDecodeError(
                            'failed to decode string value in column '
                            f"'{self.fields[i].name}' using encoding '{encoding}'; " +
                            "use the 'encoding_errors' option on the connection " +
                            'to specify how to handle this error',
                        )
                if DEBUG:
                    print('DEBUG: DATA = ', data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for i in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode the result using the connection encoding
                    # regardless of charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary mean BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data are encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f'DEBUG: field={field}, converter={converter}')
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
        self.description = tuple(description)
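# Illustrative sketch (editorial addition): the per-column encoding chosen in
# MySQLResult._get_descriptions() boils down to this decision, shown here as a
# standalone function for clarity (hypothetical helper, not part of the module):
#
#     def choose_encoding(field_type, charsetnr, use_unicode, conn_encoding):
#         if not use_unicode:
#             return None                      # hand back raw bytes
#         if field_type == FIELD_TYPE.JSON:
#             return conn_encoding             # JSON always uses the connection encoding
#         if field_type in TEXT_TYPES:
#             return None if charsetnr == 63 else conn_encoding  # 63 == binary
#         return 'ascii'                       # numbers, dates, etc.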
class MySQLResultSV(MySQLResult):

    def __init__(self, connection, unbuffered=False):
        MySQLResult.__init__(self, connection, unbuffered=unbuffered)
        self.options = {
            k: v for k, v in dict(
                default_converters=converters.decoders,
                results_type=connection.results_type,
                parse_json=connection.parse_json,
                invalid_values=connection.invalid_values,
                unbuffered=unbuffered,
                encoding_errors=connection.encoding_errors,
            ).items() if v is not UNSET
        }

    def _read_rowdata_packet(self, *args, **kwargs):
        return _singlestoredb_accel.read_rowdata_packet(self, False, *args, **kwargs)

    def _read_rowdata_packet_unbuffered(self, *args, **kwargs):
        return _singlestoredb_accel.read_rowdata_packet(self, True, *args, **kwargs)
class LoadLocalFile:

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server"""
        if not self.connection._sock:
            raise err.InterfaceError(0, 'Connection is closed')

        conn = self.connection
        infile = conn._local_infile_stream

        # 16KB is efficient enough
        packet_size = min(conn.max_allowed_packet, 16 * 1024)

        try:

            if self.filename in [':stream:', b':stream:']:

                if infile is None:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        ':stream: specified for LOCAL INFILE, but no stream was supplied',
                    )

                # Binary IO
                elif isinstance(infile, io.RawIOBase):
                    while True:
                        chunk = infile.read(packet_size)
                        if not chunk:
                            break
                        conn.write_packet(chunk)

                # Text IO
                elif isinstance(infile, io.TextIOBase):
                    while True:
                        chunk = infile.read(packet_size)
                        if not chunk:
                            break
                        conn.write_packet(chunk.encode('utf8'))

                # Iterable of bytes or str
                elif isinstance(infile, Iterable):
                    for chunk in infile:
                        if not chunk:
                            continue
                        if isinstance(chunk, str):
                            conn.write_packet(chunk.encode('utf8'))
                        else:
                            conn.write_packet(chunk)

                # Queue (empty value ends the iteration)
                elif isinstance(infile, queue.Queue):
                    while True:
                        chunk = infile.get()
                        if not chunk:
                            break
                        if isinstance(chunk, str):
                            conn.write_packet(chunk.encode('utf8'))
                        else:
                            conn.write_packet(chunk)

                else:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        ':stream: specified for LOCAL INFILE, ' +
                        f'but stream type is unrecognized: {infile}',
                    )

            else:
                try:
                    with open(self.filename, 'rb') as open_file:
                        while True:
                            chunk = open_file.read(packet_size)
                            if not chunk:
                                break
                            conn.write_packet(chunk)
                except OSError:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        f"Can't find file '{self.filename!s}'",
                    )

        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b'')