singlestoredb 1.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. singlestoredb/__init__.py +75 -0
  2. singlestoredb/ai/__init__.py +2 -0
  3. singlestoredb/ai/chat.py +139 -0
  4. singlestoredb/ai/embeddings.py +128 -0
  5. singlestoredb/alchemy/__init__.py +90 -0
  6. singlestoredb/apps/__init__.py +3 -0
  7. singlestoredb/apps/_cloud_functions.py +90 -0
  8. singlestoredb/apps/_config.py +72 -0
  9. singlestoredb/apps/_connection_info.py +18 -0
  10. singlestoredb/apps/_dashboards.py +47 -0
  11. singlestoredb/apps/_process.py +32 -0
  12. singlestoredb/apps/_python_udfs.py +100 -0
  13. singlestoredb/apps/_stdout_supress.py +30 -0
  14. singlestoredb/apps/_uvicorn_util.py +36 -0
  15. singlestoredb/auth.py +245 -0
  16. singlestoredb/config.py +484 -0
  17. singlestoredb/connection.py +1487 -0
  18. singlestoredb/converters.py +950 -0
  19. singlestoredb/docstring/__init__.py +33 -0
  20. singlestoredb/docstring/attrdoc.py +126 -0
  21. singlestoredb/docstring/common.py +230 -0
  22. singlestoredb/docstring/epydoc.py +267 -0
  23. singlestoredb/docstring/google.py +412 -0
  24. singlestoredb/docstring/numpydoc.py +562 -0
  25. singlestoredb/docstring/parser.py +100 -0
  26. singlestoredb/docstring/py.typed +1 -0
  27. singlestoredb/docstring/rest.py +256 -0
  28. singlestoredb/docstring/tests/__init__.py +1 -0
  29. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  30. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  31. singlestoredb/docstring/tests/test_google.py +1007 -0
  32. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  33. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  34. singlestoredb/docstring/tests/test_parser.py +248 -0
  35. singlestoredb/docstring/tests/test_rest.py +547 -0
  36. singlestoredb/docstring/tests/test_util.py +70 -0
  37. singlestoredb/docstring/util.py +141 -0
  38. singlestoredb/exceptions.py +120 -0
  39. singlestoredb/functions/__init__.py +16 -0
  40. singlestoredb/functions/decorator.py +201 -0
  41. singlestoredb/functions/dtypes.py +1793 -0
  42. singlestoredb/functions/ext/__init__.py +1 -0
  43. singlestoredb/functions/ext/arrow.py +375 -0
  44. singlestoredb/functions/ext/asgi.py +2133 -0
  45. singlestoredb/functions/ext/json.py +420 -0
  46. singlestoredb/functions/ext/mmap.py +413 -0
  47. singlestoredb/functions/ext/rowdat_1.py +724 -0
  48. singlestoredb/functions/ext/timer.py +89 -0
  49. singlestoredb/functions/ext/utils.py +218 -0
  50. singlestoredb/functions/signature.py +1578 -0
  51. singlestoredb/functions/typing/__init__.py +41 -0
  52. singlestoredb/functions/typing/numpy.py +20 -0
  53. singlestoredb/functions/typing/pandas.py +2 -0
  54. singlestoredb/functions/typing/polars.py +2 -0
  55. singlestoredb/functions/typing/pyarrow.py +2 -0
  56. singlestoredb/functions/utils.py +421 -0
  57. singlestoredb/fusion/__init__.py +11 -0
  58. singlestoredb/fusion/graphql.py +213 -0
  59. singlestoredb/fusion/handler.py +916 -0
  60. singlestoredb/fusion/handlers/__init__.py +0 -0
  61. singlestoredb/fusion/handlers/export.py +525 -0
  62. singlestoredb/fusion/handlers/files.py +690 -0
  63. singlestoredb/fusion/handlers/job.py +660 -0
  64. singlestoredb/fusion/handlers/models.py +250 -0
  65. singlestoredb/fusion/handlers/stage.py +502 -0
  66. singlestoredb/fusion/handlers/utils.py +324 -0
  67. singlestoredb/fusion/handlers/workspace.py +956 -0
  68. singlestoredb/fusion/registry.py +249 -0
  69. singlestoredb/fusion/result.py +399 -0
  70. singlestoredb/http/__init__.py +27 -0
  71. singlestoredb/http/connection.py +1267 -0
  72. singlestoredb/magics/__init__.py +34 -0
  73. singlestoredb/magics/run_personal.py +137 -0
  74. singlestoredb/magics/run_shared.py +134 -0
  75. singlestoredb/management/__init__.py +9 -0
  76. singlestoredb/management/billing_usage.py +148 -0
  77. singlestoredb/management/cluster.py +462 -0
  78. singlestoredb/management/export.py +295 -0
  79. singlestoredb/management/files.py +1102 -0
  80. singlestoredb/management/inference_api.py +105 -0
  81. singlestoredb/management/job.py +887 -0
  82. singlestoredb/management/manager.py +373 -0
  83. singlestoredb/management/organization.py +226 -0
  84. singlestoredb/management/region.py +169 -0
  85. singlestoredb/management/utils.py +423 -0
  86. singlestoredb/management/workspace.py +1927 -0
  87. singlestoredb/mysql/__init__.py +177 -0
  88. singlestoredb/mysql/_auth.py +298 -0
  89. singlestoredb/mysql/charset.py +214 -0
  90. singlestoredb/mysql/connection.py +2032 -0
  91. singlestoredb/mysql/constants/CLIENT.py +38 -0
  92. singlestoredb/mysql/constants/COMMAND.py +32 -0
  93. singlestoredb/mysql/constants/CR.py +78 -0
  94. singlestoredb/mysql/constants/ER.py +474 -0
  95. singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
  96. singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
  97. singlestoredb/mysql/constants/FLAG.py +15 -0
  98. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  99. singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
  100. singlestoredb/mysql/constants/__init__.py +0 -0
  101. singlestoredb/mysql/converters.py +271 -0
  102. singlestoredb/mysql/cursors.py +896 -0
  103. singlestoredb/mysql/err.py +92 -0
  104. singlestoredb/mysql/optionfile.py +20 -0
  105. singlestoredb/mysql/protocol.py +450 -0
  106. singlestoredb/mysql/tests/__init__.py +19 -0
  107. singlestoredb/mysql/tests/base.py +126 -0
  108. singlestoredb/mysql/tests/conftest.py +37 -0
  109. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  110. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  111. singlestoredb/mysql/tests/test_basic.py +452 -0
  112. singlestoredb/mysql/tests/test_connection.py +851 -0
  113. singlestoredb/mysql/tests/test_converters.py +58 -0
  114. singlestoredb/mysql/tests/test_cursor.py +141 -0
  115. singlestoredb/mysql/tests/test_err.py +16 -0
  116. singlestoredb/mysql/tests/test_issues.py +514 -0
  117. singlestoredb/mysql/tests/test_load_local.py +75 -0
  118. singlestoredb/mysql/tests/test_nextset.py +88 -0
  119. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  120. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  121. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  122. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  123. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  124. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  125. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  126. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  127. singlestoredb/mysql/times.py +23 -0
  128. singlestoredb/notebook/__init__.py +16 -0
  129. singlestoredb/notebook/_objects.py +213 -0
  130. singlestoredb/notebook/_portal.py +352 -0
  131. singlestoredb/py.typed +0 -0
  132. singlestoredb/pytest.py +352 -0
  133. singlestoredb/server/__init__.py +0 -0
  134. singlestoredb/server/docker.py +452 -0
  135. singlestoredb/server/free_tier.py +267 -0
  136. singlestoredb/tests/__init__.py +0 -0
  137. singlestoredb/tests/alltypes.sql +307 -0
  138. singlestoredb/tests/alltypes_no_nulls.sql +208 -0
  139. singlestoredb/tests/empty.sql +0 -0
  140. singlestoredb/tests/ext_funcs/__init__.py +702 -0
  141. singlestoredb/tests/local_infile.csv +3 -0
  142. singlestoredb/tests/test.ipynb +18 -0
  143. singlestoredb/tests/test.sql +680 -0
  144. singlestoredb/tests/test2.ipynb +18 -0
  145. singlestoredb/tests/test2.sql +1 -0
  146. singlestoredb/tests/test_basics.py +1332 -0
  147. singlestoredb/tests/test_config.py +318 -0
  148. singlestoredb/tests/test_connection.py +3103 -0
  149. singlestoredb/tests/test_dbapi.py +27 -0
  150. singlestoredb/tests/test_exceptions.py +45 -0
  151. singlestoredb/tests/test_ext_func.py +1472 -0
  152. singlestoredb/tests/test_ext_func_data.py +1101 -0
  153. singlestoredb/tests/test_fusion.py +1527 -0
  154. singlestoredb/tests/test_http.py +288 -0
  155. singlestoredb/tests/test_management.py +1599 -0
  156. singlestoredb/tests/test_plugin.py +33 -0
  157. singlestoredb/tests/test_results.py +171 -0
  158. singlestoredb/tests/test_types.py +132 -0
  159. singlestoredb/tests/test_udf.py +737 -0
  160. singlestoredb/tests/test_udf_returns.py +459 -0
  161. singlestoredb/tests/test_vectorstore.py +51 -0
  162. singlestoredb/tests/test_xdict.py +333 -0
  163. singlestoredb/tests/utils.py +141 -0
  164. singlestoredb/types.py +373 -0
  165. singlestoredb/utils/__init__.py +0 -0
  166. singlestoredb/utils/config.py +950 -0
  167. singlestoredb/utils/convert_rows.py +69 -0
  168. singlestoredb/utils/debug.py +13 -0
  169. singlestoredb/utils/dtypes.py +205 -0
  170. singlestoredb/utils/events.py +65 -0
  171. singlestoredb/utils/mogrify.py +151 -0
  172. singlestoredb/utils/results.py +585 -0
  173. singlestoredb/utils/xdict.py +425 -0
  174. singlestoredb/vectorstore.py +192 -0
  175. singlestoredb/warnings.py +5 -0
  176. singlestoredb-1.16.1.dist-info/METADATA +165 -0
  177. singlestoredb-1.16.1.dist-info/RECORD +183 -0
  178. singlestoredb-1.16.1.dist-info/WHEEL +5 -0
  179. singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
  180. singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
  181. singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
  182. sqlx/__init__.py +4 -0
  183. sqlx/magic.py +113 -0
@@ -0,0 +1,1267 @@
1
+ #!/usr/bin/env python
2
+ """SingleStoreDB HTTP API interface."""
3
+ import datetime
4
+ import decimal
5
+ import functools
6
+ import io
7
+ import json
8
+ import math
9
+ import os
10
+ import re
11
+ import time
12
+ from base64 import b64decode
13
+ from collections.abc import Iterable
14
+ from collections.abc import Sequence
15
+ from typing import Any
16
+ from typing import Callable
17
+ from typing import Dict
18
+ from typing import List
19
+ from typing import Optional
20
+ from typing import Tuple
21
+ from typing import Union
22
+ from urllib.parse import urljoin
23
+ from urllib.parse import urlparse
24
+
25
+ import requests
26
+
27
+ try:
28
+ import numpy as np
29
+ has_numpy = True
30
+ except ImportError:
31
+ has_numpy = False
32
+
33
+ try:
34
+ import pygeos
35
+ has_pygeos = True
36
+ except ImportError:
37
+ has_pygeos = False
38
+
39
+ try:
40
+ import shapely.geometry
41
+ import shapely.wkt
42
+ has_shapely = True
43
+ except ImportError:
44
+ has_shapely = False
45
+
46
+ try:
47
+ import pydantic
48
+ has_pydantic = True
49
+ except ImportError:
50
+ has_pydantic = False
51
+
52
+ from .. import connection
53
+ from .. import fusion
54
+ from .. import types
55
+ from ..config import get_option
56
+ from ..converters import converters
57
+ from ..exceptions import DatabaseError # noqa: F401
58
+ from ..exceptions import DataError
59
+ from ..exceptions import Error # noqa: F401
60
+ from ..exceptions import IntegrityError
61
+ from ..exceptions import InterfaceError
62
+ from ..exceptions import InternalError
63
+ from ..exceptions import NotSupportedError
64
+ from ..exceptions import OperationalError
65
+ from ..exceptions import ProgrammingError
66
+ from ..exceptions import Warning # noqa: F401
67
+ from ..utils.convert_rows import convert_rows
68
+ from ..utils.debug import log_query
69
+ from ..utils.mogrify import mogrify
70
+ from ..utils.results import Description
71
+ from ..utils.results import format_results
72
+ from ..utils.results import get_schema
73
+ from ..utils.results import Result
74
+
75
+
76
# DB-API 2.0 module-level globals (see PEP 249).
apilevel = '2.0'
paramstyle = 'named'
threadsafety = 1


# Server error numbers grouped by the DB-API exception class they map to.
# ``get_exc_type`` checks the interface, data, programming and integrity
# groups in that order and returns the class of the first group that
# contains the code.
_interface_errors = {
    0,
    2004,  # CR_IPSOCK_ERROR
    2006,  # CR_SERVER_GONE_ERROR
    2012,  # CR_HANDSHAKE_ERR
    2013,  # CR_SERVER_LOST
    2014,  # CR_COMMANDS_OUT_OF_SYNC
}
_data_errors = {
    1171,  # ER_PRIMARY_CANT_HAVE_NULL
    1230,  # ER_NO_DEFAULT
    1264,  # ER_WARN_DATA_OUT_OF_RANGE
    1265,  # ER_WARN_DATA_TRUNCATED
    1365,  # ER_DIVISION_BY_ZERO
    1406,  # ER_DATA_TOO_LONG
    1441,  # ER_DATETIME_FUNCTION_OVERFLOW
}
_programming_errors = {
    1007,  # ER_DB_CREATE_EXISTS
    1049,  # ER_BAD_DB_ERROR
    1064,  # ER_PARSE_ERROR
    1065,  # ER_EMPTY_QUERY
    1082,  # ER_NO_SUCH_INDEX
    1102,  # ER_WRONG_DB_NAME
    1103,  # ER_WRONG_TABLE_NAME
    1110,  # ER_FIELD_SPECIFIED_TWICE
    1111,  # ER_INVALID_GROUP_FUNC_USE
    1112,  # ER_UNSUPPORTED_EXTENSION
    1113,  # ER_TABLE_MUST_HAVE_COLUMNS
    1146,  # ER_NO_SUCH_TABLE
    1149,  # ER_SYNTAX_ERROR
    1179,  # ER_CANT_DO_THIS_DURING_AN_TRANSACTION
    1449,  # ER_NO_SUCH_USER
    # Wrong number of arguments; symbolic name unconfirmed
    # (possibly ER_WRONG_PARAMCOUNT_TO_NATIVE_FCT).
    1582,
    1741,  # ER_NO_SUCH_KEY_VALUE
}
_integrity_errors = {
    1048,  # ER_BAD_NULL_ERROR
    1062,  # ER_DUP_ENTRY
    1169,  # ER_DUP_UNIQUE
    1215,  # ER_CANNOT_ADD_FOREIGN
    1216,  # ER_NO_REFERENCED_ROW
    1217,  # ER_ROW_IS_REFERENCED
    # NOTE: 1264 is also in ``_data_errors``, which ``get_exc_type`` checks
    # first, so this entry never selects IntegrityError there.
    1264,  # ER_DATA_OUT_OF_RANGE
    1364,  # ER_NO_DEFAULT_FOR_FIELD
    1401,  # ER_XAER_RMERR
    1451,  # ER_ROW_IS_REFERENCED_2
    1452,  # ER_NO_REFERENCED_ROW_2
    1460,  # ER_XAER_OUTSIDE
    1826,  # ER_DUP_CONSTRAINT_NAME
    4025,  # ER_CONSTRAINT_FAILED
}
134
+
135
+
136
def get_precision_scale(type_code: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Extract the numeric parameters from a type declaration string.

    ``'DECIMAL(10, 2)'`` gives ``(10, 2)``, ``'VARCHAR(255)'`` gives
    ``(255, None)`` and a string without ``(`` gives ``(None, None)``.
    A parenthesized integer pair is preferred over a single integer.

    Raises
    ------
    ValueError
        If the string contains ``(`` but no parenthesized integer or
        integer pair

    """
    if '(' in type_code:
        pair = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', type_code)
        if pair is not None:
            return int(pair.group(1)), int(pair.group(2))
        single = re.search(r'\(\s*(\d+)\s*\)', type_code)
        if single is not None:
            return int(single.group(1)), None
        raise ValueError(f'Unrecognized type code: {type_code}')
    return None, None
147
+
148
+
149
def get_exc_type(code: int) -> type:
    """
    Return the DB-API exception class for a server error code.

    The interface, data, programming and integrity error tables are
    consulted in that order and the first table containing ``code`` wins.
    Any other code of 1000 or more maps to ``OperationalError``; anything
    lower maps to ``InternalError``.

    """
    for table, exc_type in (
        (_interface_errors, InterfaceError),
        (_data_errors, DataError),
        (_programming_errors, ProgrammingError),
        (_integrity_errors, IntegrityError),
    ):
        if code in table:
            return exc_type
    return OperationalError if code >= 1000 else InternalError
162
+
163
+
164
def identity(x: Any) -> Any:
    """Return ``x`` unchanged (a no-op converter)."""
    result = x
    return result
167
+
168
+
169
def b64decode_converter(
    converter: Optional[Callable[..., Any]],
    x: Optional[str],
    encoding: str = 'utf-8',
) -> Optional[Any]:
    """
    Base64-decode ``x`` and optionally pass the bytes through ``converter``.

    Parameters
    ----------
    converter : Callable, optional
        Function applied to the decoded bytes; when None the raw bytes are
        returned
    x : str, optional
        Base64-encoded value; None is passed through unchanged
    encoding : str, optional
        Currently unused by this function

    Returns
    -------
    Any
        None when ``x`` is None, the decoded bytes when ``converter`` is
        None, otherwise whatever ``converter`` returns

    """
    if x is None:
        return None
    decoded = b64decode(x)
    if converter is None:
        return decoded
    return converter(decoded)
180
+
181
+
182
def encode_timedelta(obj: datetime.timedelta) -> str:
    """
    Encode a timedelta as a ``[-]HH:MM:SS[.ffffff]`` string.

    Hours are not wrapped at 24, so durations longer than a day give hour
    counts above 23. Negative durations are written as a minus sign
    followed by the magnitude (``-00:00:01`` for minus one second).
    Python normalizes a negative timedelta to negative ``days`` plus
    non-negative ``seconds``. Formatting those fields directly would
    produce strings such as ``-1:59:59``.

    """
    sign = '-' if obj < datetime.timedelta(0) else ''
    magnitude = abs(obj)
    total_seconds = magnitude.days * 86400 + magnitude.seconds
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    text = f'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}'
    if magnitude.microseconds:
        text += f'.{magnitude.microseconds:06d}'
    return text
192
+
193
+
194
def encode_time(obj: datetime.time) -> str:
    """Encode a time as ``HH:MM:SS``, appending ``.ffffff`` when microseconds are set."""
    text = f'{obj.hour:02}:{obj.minute:02}:{obj.second:02}'
    if obj.microsecond:
        text = f'{text}.{obj.microsecond:06}'
    return text
201
+
202
+
203
def encode_datetime(obj: datetime.datetime) -> str:
    """
    Encode a datetime as ``YYYY-MM-DD HH:MM:SS[.ffffff]``.

    Only the wall-clock fields are written; any ``tzinfo`` is ignored.
    The fractional part is included only when microseconds are non-zero.

    """
    pieces = [
        f'{obj.year:04}-{obj.month:02}-{obj.day:02}',
        ' ',
        f'{obj.hour:02}:{obj.minute:02}:{obj.second:02}',
    ]
    if obj.microsecond:
        pieces.append(f'.{obj.microsecond:06}')
    return ''.join(pieces)
212
+
213
+
214
def encode_date(obj: datetime.date) -> str:
    """Encode a date as a zero-padded ``YYYY-MM-DD`` string."""
    return f'{obj.year:04}-{obj.month:02}-{obj.day:02}'
218
+
219
+
220
def encode_struct_time(obj: time.struct_time) -> str:
    """
    Encode a ``time.struct_time`` as ``YYYY-MM-DD HH:MM:SS``.

    Only the first six fields (year through second) are used. They go
    through ``datetime.datetime`` first, so out-of-range values (such as a
    leap second of 60) raise ``ValueError``.

    """
    stamp = datetime.datetime(*obj[:6])
    return (
        f'{stamp.year:04}-{stamp.month:02}-{stamp.day:02} '
        f'{stamp.hour:02}:{stamp.minute:02}:{stamp.second:02}'
    )
223
+
224
+
225
def encode_decimal(o: decimal.Decimal) -> str:
    """Encode a Decimal in fixed-point notation (never scientific)."""
    return f'{o:f}'
228
+
229
+
230
# Most parameter values are serialized by the JSON encoder as-is; values of
# the types registered below are first converted by the mapped function.
encoders = {
    datetime.datetime: encode_datetime,
    datetime.date: encode_date,
    datetime.time: encode_time,
    datetime.timedelta: encode_timedelta,
    time.struct_time: encode_struct_time,
    decimal.Decimal: encode_decimal,
}


if has_shapely:
    # Shapely geometries are sent as WKT text.
    encoders.update(
        dict.fromkeys(
            (
                shapely.geometry.Point,
                shapely.geometry.Polygon,
                shapely.geometry.LineString,
            ),
            shapely.wkt.dumps,
        ),
    )

if has_numpy:

    def encode_ndarray(obj: np.ndarray) -> bytes:  # type: ignore
        """Return the raw data buffer of ``obj`` via ``ndarray.tobytes``."""
        # NOTE(review): converted parameters end up in the JSON request body
        # (``Cursor._execute`` posts them with ``json=``), and the standard
        # ``json`` module cannot serialize ``bytes``. Confirm how array
        # parameters are meant to be sent.
        return obj.tobytes()

    encoders[np.ndarray] = encode_ndarray

if has_pygeos:
    # PyGEOS geometries are sent as WKT text.
    encoders[pygeos.Geometry] = pygeos.io.to_wkt
257
+
258
+
259
def convert_special_type(
    arg: Any,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Any:
    """
    Convert one query parameter into a form suitable for the JSON request.

    Parameters
    ----------
    arg : Any
        The parameter value
    nan_as_null : bool, optional
        If True, Python / NumPy float NaN values are replaced by None
    inf_as_null : bool, optional
        If True, Python / NumPy float infinities are replaced by None

    Returns
    -------
    Any
        None for a nulled float, the output of the matching function in
        ``encoders``, or ``arg`` unchanged when no encoder applies

    """
    dtype = type(arg)
    is_float = dtype is float or (
        has_numpy and dtype in (
            np.float16, np.float32, np.float64,
            getattr(np, 'float128', np.float64),
        )
    )
    if is_float:
        if nan_as_null and math.isnan(arg):
            return None
        if inf_as_null and math.isinf(arg):
            return None

    # Try the exact type first, then its base classes in MRO order. This
    # way subclasses of registered types (e.g. a ``datetime.datetime``
    # subclass) are encoded like their base instead of reaching the JSON
    # encoder unconverted. ``datetime`` is itself a subclass of ``date``;
    # walking the MRO in order keeps the most specific encoder.
    for base in dtype.__mro__:
        func = encoders.get(base)
        if func is not None:
            return func(arg)  # type: ignore
    return arg
281
+
282
+
283
def convert_special_params(
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Optional[Union[Sequence[Any], Dict[str, Any]]]:
    """
    Convert every parameter value with ``convert_special_type``.

    Parameters
    ----------
    params : sequence or mapping, optional
        Query parameters; None is returned unchanged
    nan_as_null : bool, optional
        Replace float NaN values with None
    inf_as_null : bool, optional
        Replace float infinities with None

    Returns
    -------
    dict, tuple or None
        A new dict (same keys) for mapping input, otherwise a tuple of the
        converted values

    """
    # Any Mapping (not only dict) must keep its keys. Treating it as a
    # sequence would iterate and convert only the keys.
    from collections.abc import Mapping

    if params is None:
        return None
    converter = functools.partial(
        convert_special_type,
        nan_as_null=nan_as_null,
        inf_as_null=inf_as_null,
    )
    if isinstance(params, Mapping):
        return {k: converter(v) for k, v in params.items()}
    return tuple(map(converter, params))
299
+
300
+
301
class PyMyField:
    """
    Minimal column-metadata object exposing PyMySQL-style field attributes.

    Attributes
    ----------
    name : str
        Column name
    flags : int
        Column flag bits
    charsetnr : int
        Character set number

    """

    def __init__(self, name: str, flags: int, charset: int) -> None:
        self.name, self.flags, self.charsetnr = name, flags, charset
308
+
309
+
310
+ class PyMyResult(object):
311
+ """Result for PyMySQL compatibility."""
312
+
313
+ def __init__(self) -> None:
314
+ self.fields: List[PyMyField] = []
315
+ self.unbuffered_active = False
316
+
317
+ def append(self, item: PyMyField) -> None:
318
+ self.fields.append(item)
319
+
320
+
321
+ class Cursor(connection.Cursor):
322
+ """
323
+ SingleStoreDB HTTP database cursor.
324
+
325
+ Cursor objects should not be created directly. They should come from
326
+ the `cursor` method on the `Connection` object.
327
+
328
+ Parameters
329
+ ----------
330
+ connection : Connection
331
+ The HTTP Connection object the cursor belongs to
332
+
333
+ """
334
+
335
def __init__(self, conn: 'Connection'):
    """Create a cursor bound to the HTTP connection ``conn``."""
    connection.Cursor.__init__(self, conn)
    # Owning connection; ``close`` sets this to None.
    self._connection: Optional[Connection] = conn

    # Buffers filled by the execute methods (one entry per result set).
    self._results: List[List[Tuple[Any, ...]]] = [[]]
    self._descriptions: List[List[Description]] = []
    self._schemas: List[Dict[str, Any]] = []
    self._pymy_results: List[PyMyResult] = []

    # Output format for fetched rows, taken from the connection.
    self._results_type: str = (
        'tuples' if conn is None else conn._results_type
    )

    # Current position in the buffered results; -1 means "none".
    self._row_idx: int = -1
    self._result_idx: int = -1

    self.arraysize: int = get_option('results.arraysize')
    self.rowcount: int = 0
    self.lastrowid: Optional[int] = None
    self._expect_results: bool = False
350
+
351
@property
def _result(self) -> Optional[PyMyResult]:
    """PyMySQL-compatible result object for the current result set, or None."""
    idx = self._result_idx
    return None if idx < 0 else self._pymy_results[idx]
357
+
358
@property
def description(self) -> Optional[List[Description]]:
    """
    DB-API column descriptions of the current result set.

    None when no descriptions are buffered or the current result-set index
    is out of range.

    """
    idx = self._result_idx
    if self._descriptions and 0 <= idx < len(self._descriptions):
        return self._descriptions[idx]
    return None
366
+
367
@property
def _schema(self) -> Optional[Any]:
    """Schema of the current result set, or None if there is no current set."""
    idx = self._result_idx
    if self._schemas and 0 <= idx < len(self._schemas):
        return self._schemas[idx]
    return None
374
+
375
def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
    """
    Send a POST request through the owning HTTP connection.

    Parameters
    ----------
    path : str
        The path of the resource
    *args : positional parameters, optional
        Extra parameters forwarded to the connection's POST call
    **kwargs : keyword parameters, optional
        Extra keyword parameters forwarded to the connection's POST call.
        When ``timeout`` is absent, the connection's ``connect_timeout``
        is used.

    Returns
    -------
    requests.Response

    Raises
    ------
    ProgrammingError
        If the cursor has been closed

    """
    conn = self._connection
    if conn is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed.')
    # Only look up the default timeout when the caller did not pass one.
    if 'timeout' not in kwargs:
        kwargs['timeout'] = conn.connection_params['connect_timeout']
    return conn._post(path, *args, **kwargs)
398
+
399
def callproc(
    self, name: str,
    params: Optional[Sequence[Any]] = None,
) -> None:
    """
    Call a stored procedure.

    Parameters
    ----------
    name : str
        Name of the stored procedure; validated with
        ``connection._name_check`` before use
    params : sequence, optional
        Parameters to the stored procedure, bound to ``%s`` placeholders

    Raises
    ------
    ProgrammingError
        If the cursor has been closed

    """
    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed.')

    name = connection._name_check(name)

    if params:
        placeholders = ', '.join(['%s'] * len(params))
        self._execute(f'CALL {name}({placeholders});', params, is_callproc=True)
    else:
        self._execute(f'CALL {name}();', is_callproc=True)
424
+
425
def close(self) -> None:
    """Close the cursor by detaching it from its connection."""
    self._connection = None
428
+
429
def execute(
    self, query: str,
    args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    infile_stream: Optional[  # type: ignore
        Union[
            io.RawIOBase,
            io.TextIOBase,
            Iterable[Union[bytes, str]],
            connection.InfileQueue,
        ]
    ] = None,
) -> int:
    """
    Execute a SQL statement.

    Parameters
    ----------
    query : str
        The SQL statement to execute
    args : iterable or dict, optional
        Parameters to substitute into the SQL code
    infile_stream : stream, iterable or InfileQueue, optional
        Forwarded to ``_execute``. NOTE(review): ``_execute`` currently
        does not read it; confirm whether LOCAL INFILE data is supported
        over HTTP.

    Returns
    -------
    int
        The cursor's ``rowcount`` after execution

    """
    rowcount = self._execute(query, args, infile_stream=infile_stream)
    return rowcount
453
+
454
def _validate_param_subs(
    self, query: str,
    args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
) -> None:
    """
    Check that ``args`` fit the ``%`` placeholders in ``query``.

    The query is interpolated with Python's ``%`` operator only so that a
    mismatch raises (e.g. ``TypeError`` for a wrong argument count or
    ``KeyError`` for a missing name). The interpolated string is
    discarded. Nothing is checked when ``args`` is empty or None.

    """
    if not args:
        return
    substitutions = tuple(args) if isinstance(args, Sequence) else args
    _ = query % substitutions
464
+
465
def _execute_fusion_query(
    self,
    oper: Union[str, bytes],
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    handler: Any = None,
) -> int:
    """
    Run a statement through a Fusion handler and buffer its result.

    Parameters are interpolated into ``oper`` locally with ``mogrify``
    before ``fusion.execute`` is called. The result's description, schema,
    rows and PyMySQL-style field list are appended to the cursor's buffers.

    Returns
    -------
    int
        Number of rows produced by the statement

    """
    sql = mogrify(oper, params)
    if isinstance(sql, bytes):
        sql = sql.decode('utf-8')

    log_query(sql, None)

    # Force the 'tuples' results type while the handler runs, and restore
    # the caller's setting afterwards even if the handler raises.
    saved_results_type = self._results_type
    self._results_type = 'tuples'
    try:
        mgmt_res = fusion.execute(
            self._connection,  # type: ignore
            sql,
            handler=handler,
        )
    finally:
        self._results_type = saved_results_type

    self._descriptions.append(list(mgmt_res.description))
    self._schemas.append(
        get_schema(self._results_type, list(mgmt_res.description)),
    )
    self._results.append(list(mgmt_res.rows))
    self.rowcount = len(self._results[-1])

    pymy_res = PyMyResult()
    for fld in mgmt_res.fields:
        pymy_res.append(PyMyField(fld.name, fld.flags, fld.charsetnr))
    self._pymy_results.append(pymy_res)

    if self._results and self._results[0]:
        self._row_idx = 0
        self._result_idx = 0

    return self.rowcount
511
+
512
def _execute(
    self, oper: str,
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    is_callproc: bool = False,
    infile_stream: Optional[  # type: ignore
        Union[
            io.RawIOBase,
            io.TextIOBase,
            Iterable[Union[bytes, str]],
            connection.InfileQueue,
        ]
    ] = None,
) -> int:
    """
    Execute a statement over the HTTP API and buffer its results.

    Statements starting with SELECT, SHOW, CALL, ECHO, DESCRIBE or WITH are
    posted to the ``query/tuples`` endpoint; their result sets are
    converted and buffered on the cursor. All other statements go to
    ``exec``, and only the affected-row count is recorded. Statements for
    which ``fusion.get_handler`` returns a handler are dispatched to
    ``_execute_fusion_query`` instead.

    Parameters
    ----------
    oper : str
        The SQL statement
    params : sequence, dict or pydantic model, optional
        Parameters to substitute into the statement
    is_callproc : bool, optional
        If True, an extra empty result set is appended after the real ones,
        for compatibility with PyMySQL/MySQLdb
    infile_stream : optional
        Currently not read by this method

    Returns
    -------
    int
        Rows in the first result set for queries, or rows affected for
        other statements

    Raises
    ------
    ProgrammingError
        If the cursor has been closed
    Error subclass
        Chosen by ``get_exc_type`` from the error code when the HTTP status
        is 400 or more and the response has a body
    InterfaceError
        If the HTTP status is 400 or more and the response body is empty
    OperationalError
        If the decoded response contains an ``error`` entry

    """
    # Reset all per-statement state.
    self._descriptions = []
    self._schemas = []
    self._results = []
    self._pymy_results = []
    self._row_idx = -1
    self._result_idx = -1
    self.rowcount = 0
    self._expect_results = False

    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed.')

    # Only statements starting with one of these keywords are sent to the
    # row-returning endpoint; anything else is sent to ``exec``.
    sql_type = 'exec'
    if re.match(r'^\s*(select|show|call|echo|describe|with)\s+', oper, flags=re.I):
        self._expect_results = True
        sql_type = 'query'

    if has_pydantic and isinstance(params, pydantic.BaseModel):
        params = params.model_dump()

    self._validate_param_subs(oper, params)

    handler = fusion.get_handler(oper)
    if handler is not None:
        return self._execute_fusion_query(oper, params, handler=handler)

    oper, params = self._connection._convert_params(oper, params)

    log_query(oper, params)

    data: Dict[str, Any] = dict(sql=oper)
    if params is not None:
        data['args'] = convert_special_params(
            params,
            nan_as_null=self._connection.connection_params['nan_as_null'],
            inf_as_null=self._connection.connection_params['inf_as_null'],
        )
    if self._connection._database:
        data['database'] = self._connection._database

    if sql_type == 'query':
        res = self._post('query/tuples', json=data)
    else:
        res = self._post('exec', json=data)

    if res.status_code >= 400:
        if res.text:
            # Bodies of the form "Error <code> ...: <message>" carry a
            # server error code; otherwise fall back to the HTTP status.
            m = re.match(r'^Error\s+(\d+).*?:', res.text)
            if m:
                icode = int(m.group(1))
                msg = res.text.split(':', 1)[-1]
            else:
                icode = res.status_code
                msg = res.text
            raise get_exc_type(icode)(icode, msg.strip())
        raise InterfaceError(errno=res.status_code, msg='HTTP Error')

    out = json.loads(res.text)

    if 'error' in out:
        raise OperationalError(
            errno=out['error'].get('code', 0),
            msg=out['error'].get('message', 'HTTP Error'),
        )

    if sql_type == 'query':
        # description: (name, type_code, display_size, internal_size,
        #               precision, scale, null_ok, column_flags, charset)

        # Remove converters for things the JSON parser already converted.
        http_converters = dict(self._connection.decoders)
        for type_id in (4, 5, 6, 15, 245, 247, 249, 250, 251, 252, 253, 254):
            http_converters.pop(type_id, None)

        # Merge passed-in converters (integer keys only).
        if self._connection._conv:
            for k, v in self._connection._conv.items():
                if isinstance(k, int):
                    http_converters[k] = v

        if 'arrow' in self._results_type:
            # Make JSON a string for Arrow.
            def json_to_str(x: Any) -> Optional[str]:
                if x is None:
                    return None
                return json.dumps(x)
            http_converters[245] = json_to_str

        elif 'polars' in self._results_type:
            # Don't convert date/times in polars.
            http_converters.pop(7, None)
            http_converters.pop(10, None)
            http_converters.pop(12, None)

        results = out['results']

        if results and results[0]:
            self._row_idx = 0
            self._result_idx = 0

        # Convert data to Python types, one result set at a time.
        for result in results:

            pymy_res = PyMyResult()
            convs = []

            description: List[Description] = []
            for i, col in enumerate(result.get('columns', [])):
                charset = 0
                flags = 0
                full_type = col['dataType']
                data_type = full_type.split('(')[0]
                type_code = types.ColumnType.get_code(data_type)
                prec, scale = get_precision_scale(full_type)
                converter = http_converters.get(type_code, None)

                # Search the full type string: a modifier written after the
                # parameters (e.g. "DECIMAL(10,2) UNSIGNED") is not part of
                # ``data_type``, which stops at the first '('.
                if 'UNSIGNED' in full_type:
                    flags = 32

                if data_type.endswith('BLOB') or data_type.endswith('BINARY'):
                    # Binary column values are base64-decoded before any
                    # type converter runs.
                    converter = functools.partial(
                        b64decode_converter, converter,  # type: ignore
                    )
                    charset = 63  # BINARY

                if type_code == 0:  # DECIMAL
                    type_code = types.ColumnType.get_code('NEWDECIMAL')
                elif type_code == 15:  # VARCHAR / VARBINARY
                    type_code = types.ColumnType.get_code('VARSTRING')

                if converter is not None:
                    convs.append((i, None, converter))

                description.append(
                    Description(
                        str(col['name']), type_code,
                        None, None, prec, scale,
                        col.get('nullable', False),
                        flags, charset,
                    ),
                )
                pymy_res.append(PyMyField(col['name'], flags, charset))

            self._descriptions.append(description)
            self._schemas.append(get_schema(self._results_type, description))

            rows = convert_rows(result.get('rows', []), convs)

            self._results.append(rows)
            self._pymy_results.append(pymy_res)

        # For compatibility with PyMySQL/MySQLdb
        if is_callproc:
            self._results.append([])

        # The response may contain no result sets at all (empty
        # ``out['results']``). Report zero rows instead of raising
        # IndexError.
        self.rowcount = len(self._results[0]) if self._results else 0

    else:
        # For compatibility with PyMySQL/MySQLdb
        if is_callproc:
            self._results.append([])

        self.rowcount = out['rowsAffected']

    return self.rowcount
700
+
701
def executemany(
    self, query: str,
    args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,
) -> int:
    """
    Execute SQL code against multiple sets of parameters.

    The statement is executed once per parameter set. The current result
    set of each execution is collected into the cursor's results, and
    ``rowcount`` becomes the sum over all executions.

    Parameters
    ----------
    query : str
        The SQL statement to execute
    args : iterable of iterables or dicts, optional
        Sets of parameters to substitute into the SQL code. An object with
        an ``itertuples`` method (e.g. a pandas DataFrame) is walked row by
        row without its index. When None or empty, the statement runs once
        without parameters.

    Returns
    -------
    int
        Total rowcount across all executions

    Raises
    ------
    ProgrammingError
        If the cursor has been closed

    """
    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed.')

    total = 0
    if args is None or len(args) == 0:
        self.execute(query)
        total += self.rowcount
    else:
        collected = []
        last_description = []
        last_schema = {}
        # Plain iteration over a DataFrame yields column labels, so walk
        # its rows instead.
        if hasattr(args, 'itertuples'):
            param_sets = args.itertuples(index=False)  # type: ignore
        else:
            param_sets = iter(args)
        for param_set in param_sets:
            self.execute(query, param_set)
            if self._descriptions:
                last_description = self._descriptions[-1]
            if self._schemas:
                last_schema = self._schemas[-1]
            current_rows = self._rows
            if current_rows is not None:
                collected.append(current_rows)
            total += self.rowcount
        # NOTE(review): only results, descriptions and schemas are replaced
        # here; ``_pymy_results``, ``_result_idx`` and ``_row_idx`` still
        # reflect the last execution only.
        self._results = collected
        self._descriptions = [last_description] * len(collected)
        self._schemas = [last_schema] * len(collected)

    self.rowcount = total
    return self.rowcount
748
+
749
@property
def _has_row(self) -> bool:
    """True when both the result-set index and the row index point at existing data."""
    result_idx = self._result_idx
    row_idx = self._row_idx
    if not 0 <= result_idx < len(self._results):
        return False
    return 0 <= row_idx < len(self._results[result_idx])
757
+
758
@property
def _rows(self) -> List[Tuple[Any, ...]]:
    """Rows of the active result set, or an empty list when the cursor is not on a row."""
    return self._results[self._result_idx] if self._has_row else []
764
+
765
def fetchone(self) -> Optional[Result]:
    """
    Return the next row of the current result set and advance the cursor.

    Returns
    -------
    tuple
        The row, formatted according to the cursor's ``results_type``
    None
        When the current result set has no more rows

    Raises
    ------
    ProgrammingError
        If the connection is closed or no query has been submitted

    """
    conn = self._connection
    if conn is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed')
    if not self._expect_results:
        raise conn.ProgrammingError(msg='No query has been submitted')

    if not self._has_row:
        return None

    row = self._rows[self._row_idx]
    self._row_idx += 1
    description = self.description or []
    return format_results(
        self._results_type, description, row,
        single=True, schema=self._schema,
    )
791
+
792
def fetchmany(
    self,
    size: Optional[int] = None,
) -> Result:
    """
    Return up to `size` rows from the current result set.

    A missing or zero `size` falls back to ``arraysize``; the batch size is
    never less than one.

    Returns
    -------
    list of tuples
        The next rows, formatted according to ``results_type``. When no
        rows remain, an empty dict is returned for dict-style result types
        and an empty tuple otherwise.

    Raises
    ------
    ProgrammingError
        If the connection is closed or no query has been submitted

    """
    conn = self._connection
    if conn is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed')
    if not self._expect_results:
        raise conn.ProgrammingError(msg='No query has been submitted')

    if not self._has_row:
        return {} if 'dict' in self._results_type else tuple()

    batch = max(int(size if size else self.arraysize), 1)
    start = self._row_idx
    chunk = self._rows[start:start + batch]
    self._row_idx = start + len(chunk)
    return format_results(
        self._results_type, self.description or [],
        chunk, schema=self._schema,
    )
825
+
826
def fetchall(self) -> Result:
    """
    Fetch all remaining rows in the current result set.

    Returns
    -------
    list of tuples
        Values of the remaining rows, formatted according to
        ``results_type``. When no rows remain, an empty dict is returned
        for dict-style result types and an empty tuple otherwise.

    Raises
    ------
    ProgrammingError
        If the connection is closed or no query has been submitted

    """
    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed')
    if not self._expect_results:
        raise self._connection.ProgrammingError(msg='No query has been submitted')
    if not self._has_row:
        if 'dict' in self._results_type:
            return {}
        return tuple()
    out = list(self._rows[self._row_idx:])
    # Advance past everything just returned. Assigning ``len(out)`` instead
    # would rewind the cursor whenever some rows had already been fetched,
    # so a later fetch would hand the same rows out a second time.
    self._row_idx += len(out)
    return format_results(
        self._results_type, self.description or [],
        out, schema=self._schema,
    )
850
+
851
def nextset(self) -> Optional[bool]:
    """
    Move the cursor to the next result set.

    Returns
    -------
    bool or None
        ``True`` when another result set is available (``rowcount`` is
        updated to its length); ``None`` when there are no more sets, in
        which case both the result and row indexes are reset to -1.

    Raises
    ------
    ProgrammingError
        If the connection is closed

    """
    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed')

    if self._result_idx < 0:
        self._row_idx = -1
        return None

    upcoming = self._result_idx + 1
    if upcoming >= len(self._results):
        self._result_idx = -1
        self._row_idx = -1
        return None

    self._result_idx, self._row_idx = upcoming, 0
    self.rowcount = len(self._results[upcoming])
    return True
871
+
872
def setinputsizes(self, sizes: Sequence[int]) -> None:
    """Accept parameter size hints for DB-API compatibility; they are ignored."""
    return None
875
+
876
def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
    """Accept a column buffer size hint for DB-API compatibility; it is ignored."""
    return None
879
+
880
@property
def rownumber(self) -> Optional[int]:
    """
    Zero-based position of the cursor within the current result set.

    Returns
    -------
    int or None
        The row index, or ``None`` when the cursor is not positioned
        (negative internal index)

    """
    position = self._row_idx
    return None if position < 0 else position
893
+
894
def scroll(self, value: int, mode: str = 'relative') -> None:
    """
    Reposition the cursor inside the current result set.

    No bounds checking is done; an out-of-range position simply leaves the
    cursor with no current row.

    Parameters
    ----------
    value : int
        Offset (``'relative'``) or target index (``'absolute'``)
    mode : str
        Either ``'relative'`` or ``'absolute'``

    Raises
    ------
    ProgrammingError
        If the connection is closed
    ValueError
        If `mode` is not one of the two accepted values

    """
    if self._connection is None:
        raise ProgrammingError(errno=2048, msg='Connection is closed')

    if mode == 'absolute':
        self._row_idx = value
        return

    if mode != 'relative':
        raise ValueError(
            f'{mode} is not a valid mode, '
            'expecting "relative" or "absolute"',
        )

    self._row_idx += value
917
+
918
def next(self) -> Optional[Result]:
    """
    Return the next row of the result set (iterator protocol).

    Returns
    -------
    tuple
        Values from the next result row, as produced by ``fetchone``

    Raises
    ------
    InterfaceError
        If the connection is closed
    StopIteration
        When no more rows exist

    """
    if self._connection is None:
        raise InterfaceError(errno=2048, msg='Connection is closed')
    row = self.fetchone()
    if row is not None:
        return row
    raise StopIteration

__next__ = next
938
+
939
def __iter__(self) -> Iterable[Tuple[Any, ...]]:
    """Iterate over the raw rows remaining after the current position (the cursor itself is not advanced)."""
    remaining = self._rows[self._row_idx:]
    return iter(remaining)
942
+
943
def __enter__(self) -> 'Cursor':
    """Start a ``with`` block; the cursor itself is the managed object."""
    cursor = self
    return cursor
946
+
947
def __exit__(
    self, exc_type: Optional[object],
    exc_value: Optional[Exception], exc_traceback: Optional[str],
) -> None:
    """Leave the ``with`` block by closing the cursor; exceptions are not suppressed."""
    self.close()
953
+
954
@property
def open(self) -> bool:
    """Whether the cursor is attached to a connection that reports itself as connected."""
    conn = self._connection
    return False if conn is None else conn.is_connected()
960
+
961
def is_connected(self) -> bool:
    """
    Report whether the cursor is still connected.

    This is a method form of the ``open`` property.

    Returns
    -------
    bool

    """
    connected = self.open
    return connected
971
+
972
+
973
class Connection(connection.Connection):
    """
    SingleStoreDB HTTP database connection.

    Instances of this object are typically created through the
    `connect` function rather than creating them directly.

    See Also
    --------
    `connect`

    """
    driver = 'https'
    paramstyle = 'qmark'

    def __init__(self, **kwargs: Any):
        from .. import __version__ as client_version

        def _param(name: str, default: Any = None) -> Any:
            # ``connect()`` forwards every keyword argument, so options the
            # caller never set arrive as explicit ``None`` values. Treat them
            # like missing keys so the defaults below still apply; otherwise
            # e.g. ``driver=None`` yields a ``None://`` URL and
            # ``ssl_verify_cert=None`` would switch certificate checks off.
            value = kwargs.get(name)
            return default if value is None else value

        if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
            client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']

        connection.Connection.__init__(self, **kwargs)

        host = _param('host', get_option('host'))
        port = _param('port', get_option('http_port'))

        self._sess: Optional[requests.Session] = requests.Session()

        user = _param('user', get_option('user'))
        password = _param('password', get_option('password'))
        if user is not None and password is not None:
            self._sess.auth = (user, password)
        elif user is not None:
            self._sess.auth = (user, '')
        self._sess.headers.update({
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Accept-Encoding': 'compress,identity',
            'User-Agent': f'SingleStoreDB-Python/{client_version}',
        })

        if _param('ssl_disabled', get_option('ssl_disabled')):
            self._sess.verify = False
        else:
            ssl_key = _param('ssl_key', get_option('ssl_key'))
            ssl_cert = _param('ssl_cert', get_option('ssl_cert'))
            if ssl_key and ssl_cert:
                # requests expects a client certificate pair as (cert, key).
                self._sess.cert = (ssl_cert, ssl_key)
            elif ssl_cert:
                self._sess.cert = ssl_cert

            ssl_ca = _param('ssl_ca', get_option('ssl_ca'))
            if ssl_ca:
                self._sess.verify = ssl_ca

            if not _param('ssl_verify_cert', True):
                self._sess.verify = False

        if _param('multi_statements', False):
            raise self.InterfaceError(
                0, 'The Data API does not allow multiple '
                'statements within a query',
            )

        self._version = _param('version', 'v2')
        self.driver = _param('driver', 'https')

        self.encoders = {k: v for (k, v) in converters.items() if type(k) is not int}
        self.decoders = {k: v for (k, v) in converters.items() if type(k) is int}

        self._database = _param('database', get_option('database'))
        self._url = f'{self.driver}://{host}:{port}/api/{self._version}/'
        self._host = host
        self._messages: List[Tuple[int, str]] = []
        self._autocommit: bool = True
        self._conv = kwargs.get('conv', None)
        self._in_sync: bool = False
        # The ``singlestore.com`` host is a placeholder that is always
        # resolved from the ``SINGLESTOREDB_URL`` environment variable.
        self._track_env: bool = _param('track_env', False) \
            or host == 'singlestore.com'

    @property
    def messages(self) -> List[Tuple[int, str]]:
        """List of ``(code, message)`` tuples stored on this connection."""
        return self._messages

    def connect(self) -> 'Connection':
        """Connect to the server (a no-op for HTTP; returns the connection)."""
        return self

    def _sync_connection(self, kwargs: Dict[str, Any]) -> None:
        """
        Re-point the connection at the URL in ``SINGLESTOREDB_URL``.

        Only active when environment tracking is enabled. When the URL names
        the same host, port, user and database, the current session is kept
        and only its password is refreshed. Otherwise a new session is
        created (reusing headers and TLS settings), the old one is closed,
        and the pending request's JSON payload gets the new database name.

        Parameters
        ----------
        kwargs : Dict[str, Any]
            Keyword arguments of the pending POST request; ``kwargs['json']``
            may be updated in place

        Raises
        ------
        InterfaceError
            If the connection is closed, or the connection URL is still the
            ``singlestore.com`` placeholder

        """
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed.')

        if self._in_sync:
            return

        if not self._track_env:
            return

        url = os.environ.get('SINGLESTOREDB_URL')
        if not url:
            if self._host == 'singlestore.com':
                raise InterfaceError(0, 'Connection URL has not been established')
            return

        out = {}
        urlp = connection._parse_url(url)
        out.update(urlp)
        out = connection._cast_params(out)

        # Set default port based on driver.
        if 'port' not in out or not out['port']:
            if out.get('driver', 'https') == 'http':
                out['port'] = int(get_option('port') or 80)
            else:
                out['port'] = int(get_option('port') or 443)

        # If there is no user and the password is empty, remove the password key.
        if 'user' not in out and not out.get('password', None):
            out.pop('password', None)

        if out['host'] == 'singlestore.com':
            raise InterfaceError(0, 'Connection URL has not been established')

        # The URL may omit credentials; use .get() to avoid KeyError.
        new_user = out.get('user')
        new_password = out.get('password')

        # Get current connection attributes
        curr_url = urlparse(self._url, scheme='singlestoredb', allow_fragments=True)
        if self._sess.auth is not None:
            auth = tuple(self._sess.auth)  # type: ignore
        else:
            auth = (None, None)  # type: ignore

        # Same endpoint, user and database: no new session is needed, but a
        # changed password (e.g. a rotated token) must still be applied or
        # subsequent requests keep sending stale credentials.
        if (curr_url.hostname, curr_url.port, auth[0], self._database) == \
                (out['host'], out['port'], new_user, out.get('database')):
            if new_user is not None and new_password is not None:
                self._sess.auth = (new_user, new_password)
            return

        try:
            self._in_sync = True
            sess = requests.Session()
            if new_user is not None:
                sess.auth = (new_user, new_password or '')
            sess.headers.update(self._sess.headers)
            sess.verify = self._sess.verify
            sess.cert = self._sess.cert
            self._database = out.get('database')
            self._host = out['host']
            self._url = f'{out.get("driver", "https")}://{out["host"]}:{out["port"]}' \
                f'/api/{self._version}/'
            old_sess, self._sess = self._sess, sess
            # Release the pooled connections held by the replaced session.
            old_sess.close()
            if self._database and isinstance(kwargs.get('json'), dict):
                kwargs['json']['database'] = self._database
        finally:
            self._in_sync = False

    def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """
        Invoke a POST request on the HTTP connection.

        Parameters
        ----------
        path : str
            The path of the resource
        *args : positional parameters, optional
            Extra parameters to the POST request
        **kwargs : keyword parameters, optional
            Extra keyword parameters to the POST request

        Returns
        -------
        requests.Response

        Raises
        ------
        InterfaceError
            If the connection is closed

        """
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed.')

        self._sync_connection(kwargs)

        return self._sess.post(urljoin(self._url, path), *args, **kwargs)

    def close(self) -> None:
        """
        Close the connection and release its HTTP session.

        Closing a ``singlestore.com`` placeholder connection is a no-op.

        Raises
        ------
        Error
            If the connection is already closed

        """
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise Error(errno=2048, msg='Connection is closed')
        sess, self._sess = self._sess, None
        # Free pooled sockets instead of waiting for garbage collection.
        sess.close()

    def autocommit(self, value: bool = True) -> None:
        """Set autocommit mode."""
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        self._autocommit = value

    def commit(self) -> None:
        """
        Commit the pending transaction.

        A no-op in autocommit mode; explicit transactions are not supported.

        """
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if self._autocommit:
            return
        raise NotSupportedError(msg='operation not supported')

    def rollback(self) -> None:
        """
        Rollback the pending transaction.

        A no-op in autocommit mode; explicit transactions are not supported.

        """
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if self._autocommit:
            return
        raise NotSupportedError(msg='operation not supported')

    def cursor(self) -> Cursor:
        """
        Create a new cursor object.

        Returns
        -------
        Cursor

        """
        return Cursor(self)

    def __enter__(self) -> 'Connection':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context."""
        self.close()

    @property
    def open(self) -> bool:
        """
        Check if the database is still connected.

        Issues a GET to the server's ``/ping`` endpoint; any transport error
        is reported as "not connected" rather than raised.

        """
        if self._sess is None:
            return False
        url = '/'.join(self._url.split('/')[:3]) + '/ping'
        try:
            res = self._sess.get(url)
        except requests.exceptions.RequestException:
            return False
        # 4xx/5xx responses are failures; 400 is not a success code.
        return res.status_code < 400 and res.text == 'pong'

    def is_connected(self) -> bool:
        """
        Check if the database is still connected.

        Returns
        -------
        bool

        """
        return self.open
1231
+
1232
+
1233
def connect(
    host: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    port: Optional[int] = None,
    database: Optional[str] = None,
    driver: Optional[str] = None,
    pure_python: Optional[bool] = None,
    local_infile: Optional[bool] = None,
    charset: Optional[str] = None,
    ssl_key: Optional[str] = None,
    ssl_cert: Optional[str] = None,
    ssl_ca: Optional[str] = None,
    ssl_disabled: Optional[bool] = None,
    ssl_cipher: Optional[str] = None,
    ssl_verify_cert: Optional[bool] = None,
    ssl_verify_identity: Optional[bool] = None,
    conv: Optional[Dict[int, Callable[..., Any]]] = None,
    credential_type: Optional[str] = None,
    autocommit: Optional[bool] = None,
    results_type: Optional[str] = None,
    buffered: Optional[bool] = None,
    results_format: Optional[str] = None,
    program_name: Optional[str] = None,
    conn_attrs: Optional[Dict[str, str]] = None,
    multi_statements: Optional[bool] = None,
    connect_timeout: Optional[int] = None,
    nan_as_null: Optional[bool] = None,
    inf_as_null: Optional[bool] = None,
    encoding_errors: Optional[str] = None,
    track_env: Optional[bool] = None,
    enable_extended_data_types: Optional[bool] = None,
    vector_data_format: Optional[str] = None,
) -> Connection:
    """
    Create a SingleStoreDB HTTP connection.

    Every argument is an optional keyword accepted by :class:`Connection`.
    Only arguments that were actually given (i.e. are not ``None``) are
    forwarded. Unset options therefore fall back to the connection's
    defaults and global options instead of being overridden by ``None``;
    for example, an unset ``ssl_verify_cert`` no longer disables
    certificate verification, and an unset ``driver`` no longer produces
    a ``None://`` URL.

    Returns
    -------
    Connection

    """
    params = dict(locals())
    return Connection(**{k: v for k, v in params.items() if v is not None})