singlestoredb-1.16.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. singlestoredb/__init__.py +75 -0
  2. singlestoredb/ai/__init__.py +2 -0
  3. singlestoredb/ai/chat.py +139 -0
  4. singlestoredb/ai/embeddings.py +128 -0
  5. singlestoredb/alchemy/__init__.py +90 -0
  6. singlestoredb/apps/__init__.py +3 -0
  7. singlestoredb/apps/_cloud_functions.py +90 -0
  8. singlestoredb/apps/_config.py +72 -0
  9. singlestoredb/apps/_connection_info.py +18 -0
  10. singlestoredb/apps/_dashboards.py +47 -0
  11. singlestoredb/apps/_process.py +32 -0
  12. singlestoredb/apps/_python_udfs.py +100 -0
  13. singlestoredb/apps/_stdout_supress.py +30 -0
  14. singlestoredb/apps/_uvicorn_util.py +36 -0
  15. singlestoredb/auth.py +245 -0
  16. singlestoredb/config.py +484 -0
  17. singlestoredb/connection.py +1487 -0
  18. singlestoredb/converters.py +950 -0
  19. singlestoredb/docstring/__init__.py +33 -0
  20. singlestoredb/docstring/attrdoc.py +126 -0
  21. singlestoredb/docstring/common.py +230 -0
  22. singlestoredb/docstring/epydoc.py +267 -0
  23. singlestoredb/docstring/google.py +412 -0
  24. singlestoredb/docstring/numpydoc.py +562 -0
  25. singlestoredb/docstring/parser.py +100 -0
  26. singlestoredb/docstring/py.typed +1 -0
  27. singlestoredb/docstring/rest.py +256 -0
  28. singlestoredb/docstring/tests/__init__.py +1 -0
  29. singlestoredb/docstring/tests/_pydoctor.py +21 -0
  30. singlestoredb/docstring/tests/test_epydoc.py +729 -0
  31. singlestoredb/docstring/tests/test_google.py +1007 -0
  32. singlestoredb/docstring/tests/test_numpydoc.py +1100 -0
  33. singlestoredb/docstring/tests/test_parse_from_object.py +109 -0
  34. singlestoredb/docstring/tests/test_parser.py +248 -0
  35. singlestoredb/docstring/tests/test_rest.py +547 -0
  36. singlestoredb/docstring/tests/test_util.py +70 -0
  37. singlestoredb/docstring/util.py +141 -0
  38. singlestoredb/exceptions.py +120 -0
  39. singlestoredb/functions/__init__.py +16 -0
  40. singlestoredb/functions/decorator.py +201 -0
  41. singlestoredb/functions/dtypes.py +1793 -0
  42. singlestoredb/functions/ext/__init__.py +1 -0
  43. singlestoredb/functions/ext/arrow.py +375 -0
  44. singlestoredb/functions/ext/asgi.py +2133 -0
  45. singlestoredb/functions/ext/json.py +420 -0
  46. singlestoredb/functions/ext/mmap.py +413 -0
  47. singlestoredb/functions/ext/rowdat_1.py +724 -0
  48. singlestoredb/functions/ext/timer.py +89 -0
  49. singlestoredb/functions/ext/utils.py +218 -0
  50. singlestoredb/functions/signature.py +1578 -0
  51. singlestoredb/functions/typing/__init__.py +41 -0
  52. singlestoredb/functions/typing/numpy.py +20 -0
  53. singlestoredb/functions/typing/pandas.py +2 -0
  54. singlestoredb/functions/typing/polars.py +2 -0
  55. singlestoredb/functions/typing/pyarrow.py +2 -0
  56. singlestoredb/functions/utils.py +421 -0
  57. singlestoredb/fusion/__init__.py +11 -0
  58. singlestoredb/fusion/graphql.py +213 -0
  59. singlestoredb/fusion/handler.py +916 -0
  60. singlestoredb/fusion/handlers/__init__.py +0 -0
  61. singlestoredb/fusion/handlers/export.py +525 -0
  62. singlestoredb/fusion/handlers/files.py +690 -0
  63. singlestoredb/fusion/handlers/job.py +660 -0
  64. singlestoredb/fusion/handlers/models.py +250 -0
  65. singlestoredb/fusion/handlers/stage.py +502 -0
  66. singlestoredb/fusion/handlers/utils.py +324 -0
  67. singlestoredb/fusion/handlers/workspace.py +956 -0
  68. singlestoredb/fusion/registry.py +249 -0
  69. singlestoredb/fusion/result.py +399 -0
  70. singlestoredb/http/__init__.py +27 -0
  71. singlestoredb/http/connection.py +1267 -0
  72. singlestoredb/magics/__init__.py +34 -0
  73. singlestoredb/magics/run_personal.py +137 -0
  74. singlestoredb/magics/run_shared.py +134 -0
  75. singlestoredb/management/__init__.py +9 -0
  76. singlestoredb/management/billing_usage.py +148 -0
  77. singlestoredb/management/cluster.py +462 -0
  78. singlestoredb/management/export.py +295 -0
  79. singlestoredb/management/files.py +1102 -0
  80. singlestoredb/management/inference_api.py +105 -0
  81. singlestoredb/management/job.py +887 -0
  82. singlestoredb/management/manager.py +373 -0
  83. singlestoredb/management/organization.py +226 -0
  84. singlestoredb/management/region.py +169 -0
  85. singlestoredb/management/utils.py +423 -0
  86. singlestoredb/management/workspace.py +1927 -0
  87. singlestoredb/mysql/__init__.py +177 -0
  88. singlestoredb/mysql/_auth.py +298 -0
  89. singlestoredb/mysql/charset.py +214 -0
  90. singlestoredb/mysql/connection.py +2032 -0
  91. singlestoredb/mysql/constants/CLIENT.py +38 -0
  92. singlestoredb/mysql/constants/COMMAND.py +32 -0
  93. singlestoredb/mysql/constants/CR.py +78 -0
  94. singlestoredb/mysql/constants/ER.py +474 -0
  95. singlestoredb/mysql/constants/EXTENDED_TYPE.py +3 -0
  96. singlestoredb/mysql/constants/FIELD_TYPE.py +48 -0
  97. singlestoredb/mysql/constants/FLAG.py +15 -0
  98. singlestoredb/mysql/constants/SERVER_STATUS.py +10 -0
  99. singlestoredb/mysql/constants/VECTOR_TYPE.py +6 -0
  100. singlestoredb/mysql/constants/__init__.py +0 -0
  101. singlestoredb/mysql/converters.py +271 -0
  102. singlestoredb/mysql/cursors.py +896 -0
  103. singlestoredb/mysql/err.py +92 -0
  104. singlestoredb/mysql/optionfile.py +20 -0
  105. singlestoredb/mysql/protocol.py +450 -0
  106. singlestoredb/mysql/tests/__init__.py +19 -0
  107. singlestoredb/mysql/tests/base.py +126 -0
  108. singlestoredb/mysql/tests/conftest.py +37 -0
  109. singlestoredb/mysql/tests/test_DictCursor.py +132 -0
  110. singlestoredb/mysql/tests/test_SSCursor.py +141 -0
  111. singlestoredb/mysql/tests/test_basic.py +452 -0
  112. singlestoredb/mysql/tests/test_connection.py +851 -0
  113. singlestoredb/mysql/tests/test_converters.py +58 -0
  114. singlestoredb/mysql/tests/test_cursor.py +141 -0
  115. singlestoredb/mysql/tests/test_err.py +16 -0
  116. singlestoredb/mysql/tests/test_issues.py +514 -0
  117. singlestoredb/mysql/tests/test_load_local.py +75 -0
  118. singlestoredb/mysql/tests/test_nextset.py +88 -0
  119. singlestoredb/mysql/tests/test_optionfile.py +27 -0
  120. singlestoredb/mysql/tests/thirdparty/__init__.py +6 -0
  121. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  122. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/capabilities.py +323 -0
  123. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/dbapi20.py +865 -0
  124. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +110 -0
  125. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +224 -0
  126. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +101 -0
  127. singlestoredb/mysql/times.py +23 -0
  128. singlestoredb/notebook/__init__.py +16 -0
  129. singlestoredb/notebook/_objects.py +213 -0
  130. singlestoredb/notebook/_portal.py +352 -0
  131. singlestoredb/py.typed +0 -0
  132. singlestoredb/pytest.py +352 -0
  133. singlestoredb/server/__init__.py +0 -0
  134. singlestoredb/server/docker.py +452 -0
  135. singlestoredb/server/free_tier.py +267 -0
  136. singlestoredb/tests/__init__.py +0 -0
  137. singlestoredb/tests/alltypes.sql +307 -0
  138. singlestoredb/tests/alltypes_no_nulls.sql +208 -0
  139. singlestoredb/tests/empty.sql +0 -0
  140. singlestoredb/tests/ext_funcs/__init__.py +702 -0
  141. singlestoredb/tests/local_infile.csv +3 -0
  142. singlestoredb/tests/test.ipynb +18 -0
  143. singlestoredb/tests/test.sql +680 -0
  144. singlestoredb/tests/test2.ipynb +18 -0
  145. singlestoredb/tests/test2.sql +1 -0
  146. singlestoredb/tests/test_basics.py +1332 -0
  147. singlestoredb/tests/test_config.py +318 -0
  148. singlestoredb/tests/test_connection.py +3103 -0
  149. singlestoredb/tests/test_dbapi.py +27 -0
  150. singlestoredb/tests/test_exceptions.py +45 -0
  151. singlestoredb/tests/test_ext_func.py +1472 -0
  152. singlestoredb/tests/test_ext_func_data.py +1101 -0
  153. singlestoredb/tests/test_fusion.py +1527 -0
  154. singlestoredb/tests/test_http.py +288 -0
  155. singlestoredb/tests/test_management.py +1599 -0
  156. singlestoredb/tests/test_plugin.py +33 -0
  157. singlestoredb/tests/test_results.py +171 -0
  158. singlestoredb/tests/test_types.py +132 -0
  159. singlestoredb/tests/test_udf.py +737 -0
  160. singlestoredb/tests/test_udf_returns.py +459 -0
  161. singlestoredb/tests/test_vectorstore.py +51 -0
  162. singlestoredb/tests/test_xdict.py +333 -0
  163. singlestoredb/tests/utils.py +141 -0
  164. singlestoredb/types.py +373 -0
  165. singlestoredb/utils/__init__.py +0 -0
  166. singlestoredb/utils/config.py +950 -0
  167. singlestoredb/utils/convert_rows.py +69 -0
  168. singlestoredb/utils/debug.py +13 -0
  169. singlestoredb/utils/dtypes.py +205 -0
  170. singlestoredb/utils/events.py +65 -0
  171. singlestoredb/utils/mogrify.py +151 -0
  172. singlestoredb/utils/results.py +585 -0
  173. singlestoredb/utils/xdict.py +425 -0
  174. singlestoredb/vectorstore.py +192 -0
  175. singlestoredb/warnings.py +5 -0
  176. singlestoredb-1.16.1.dist-info/METADATA +165 -0
  177. singlestoredb-1.16.1.dist-info/RECORD +183 -0
  178. singlestoredb-1.16.1.dist-info/WHEEL +5 -0
  179. singlestoredb-1.16.1.dist-info/entry_points.txt +2 -0
  180. singlestoredb-1.16.1.dist-info/licenses/LICENSE +201 -0
  181. singlestoredb-1.16.1.dist-info/top_level.txt +3 -0
  182. sqlx/__init__.py +4 -0
  183. sqlx/magic.py +113 -0
@@ -0,0 +1,1332 @@
1
+ #!/usr/bin/env python
2
+ # type: ignore
3
+ """Basic SingleStoreDB connection testing."""
4
+ import datetime
5
+ import decimal
6
+ import math
7
+ import os
8
+ import unittest
9
+ from typing import Optional
10
+
11
+ from requests.exceptions import InvalidJSONError
12
+
13
+ try:
14
+ import numpy as np
15
+ has_numpy = True
16
+ except ImportError:
17
+ has_numpy = False
18
+
19
+ try:
20
+ import shapely.wkt
21
+ has_shapely = True
22
+ except ImportError:
23
+ has_shapely = False
24
+
25
+ try:
26
+ import pygeos
27
+ from pygeos.testing import assert_geometries_equal
28
+ has_pygeos = True
29
+ except ImportError:
30
+ has_pygeos = False
31
+
32
+ try:
33
+ import pydantic
34
+ has_pydantic = True
35
+ except ImportError:
36
+ has_pydantic = False
37
+
38
+ import singlestoredb as s2
39
+ from . import utils
40
+ # import traceback
41
+
42
+
43
+ class TestBasics(unittest.TestCase):
44
+
45
+ dbname: str = ''
46
+ dbexisted: bool = False
47
+
48
@classmethod
def setUpClass(cls):
    """Load ``test.sql`` into a test database once for the whole class.

    Stores the database name and whether it already existed, so that
    ``tearDownClass`` only drops databases this class created.
    """
    here = os.path.dirname(__file__)
    cls.dbname, cls.dbexisted = utils.load_sql(os.path.join(here, 'test.sql'))
52
+
53
@classmethod
def tearDownClass(cls):
    """Drop the test database unless it existed before ``setUpClass`` ran."""
    if cls.dbexisted:
        return
    utils.drop_database(cls.dbname)
57
+
58
def setUp(self):
    """Open a connection to the class's test database and a cursor on it."""
    dbname = type(self).dbname
    self.conn = s2.connect(database=dbname)
    self.cur = self.conn.cursor()
61
+
62
def tearDown(self):
    """Close the cursor, then the connection, ignoring any error.

    Each resource is handled in its own ``try`` so that a failure closing
    the cursor (or the attribute being missing or ``None``) does not stop
    the connection from being closed.
    """
    for attr in ('cur', 'conn'):
        try:
            resource = getattr(self, attr)
            if resource is not None:
                resource.close()
        except Exception:
            # Best-effort cleanup: errors are deliberately ignored.
            pass
76
+
77
def test_connection(self):
    """The class's test database appears in ``SHOW DATABASES`` output."""
    self.cur.execute('show databases')
    # Set comprehension instead of set([...]): no throwaway list.
    dbs = {x[0] for x in self.cur.fetchall()}
    assert type(self).dbname in dbs, dbs
81
+
82
def test_fetchall(self):
    """``fetchall`` returns every row of ``data`` plus cursor metadata."""
    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]
    # (column name, acceptable type codes) for each column of ``data``.
    expected_cols = [('id', [253, 15]), ('name', [253, 15]), ('value', [8])]

    cur = self.cur
    cur.execute('select * from data')
    out = cur.fetchall()

    desc, rowcount, rownumber, lastrowid = (
        cur.description, cur.rowcount, cur.rownumber, cur.lastrowid,
    )

    assert sorted(out) == sorted(expected), out

    # -1 is also accepted: the DB-API value for "row count not determined".
    assert rowcount in (5, -1), rowcount
    assert rownumber == 5, rownumber
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    for col, (name, type_codes) in zip(desc, expected_cols):
        assert col.name == name, col.name
        assert col.type_code in type_codes, col.type_code
110
+
111
def test_fetchone(self):
    """Repeated ``fetchone`` calls walk ``data`` row by row until ``None``."""
    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]
    # (column name, acceptable type codes) for each column of ``data``.
    expected_cols = [('id', [253, 15]), ('name', [253, 15]), ('value', [8])]

    self.cur.execute('select * from data')

    out = []
    row = self.cur.fetchone()
    while row is not None:
        out.append(row)
        # rownumber advances by one with every fetched row.
        assert self.cur.rownumber == len(out), self.cur.rownumber
        row = self.cur.fetchone()

    desc = self.cur.description
    rowcount = self.cur.rowcount
    rownumber = self.cur.rownumber
    lastrowid = self.cur.lastrowid

    assert sorted(out) == sorted(expected), out

    # -1 is also accepted: the DB-API value for "row count not determined".
    assert rowcount in (5, -1), rowcount
    assert rownumber == 5, rownumber
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    for col, (name, type_codes) in zip(desc, expected_cols):
        assert col.name == name, col.name
        assert col.type_code in type_codes, col.type_code
145
+
146
def test_fetchmany(self):
    """``fetchmany(size=3)`` returns batches of at most 3 rows until empty.

    Also checks the cursor metadata. The ``description`` assertions carry
    failure messages, like the other fetch tests in this class.
    """
    self.cur.execute('select * from data')

    out = []
    while True:
        rows = self.cur.fetchmany(size=3)
        assert len(rows) <= 3, rows
        if not rows:
            break
        out.extend(rows)
        # rownumber tracks the total number of rows fetched so far.
        assert self.cur.rownumber == len(out), self.cur.rownumber

    desc = self.cur.description
    rowcount = self.cur.rowcount
    rownumber = self.cur.rownumber
    lastrowid = self.cur.lastrowid

    assert sorted(out) == sorted([
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]), out

    # -1 is also accepted: the DB-API value for "row count not determined".
    assert rowcount in (5, -1), rowcount
    assert rownumber == 5, rownumber
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    assert desc[0].name == 'id', desc[0].name
    assert desc[0].type_code in [253, 15], desc[0].type_code
    assert desc[1].name == 'name', desc[1].name
    assert desc[1].type_code in [253, 15], desc[1].type_code
    assert desc[2].name == 'value', desc[2].name
    assert desc[2].type_code == 8, desc[2].type_code
181
+
182
def test_arraysize(self):
    """``fetchmany()`` with no size argument honours ``cursor.arraysize``.

    Changing ``arraysize`` between calls changes the batch size, and a final
    ``fetchall`` returns nothing once all five rows are consumed.
    """
    self.cur.execute('select * from data')

    # (arraysize to set, or None to keep it; fetch method;
    #  expected batch length; expected rownumber afterwards)
    steps = [
        (3, self.cur.fetchmany, 3, 3),
        (1, self.cur.fetchmany, 1, 4),
        (None, self.cur.fetchmany, 1, 5),
        (None, self.cur.fetchall, 0, 5),
    ]
    for size, fetch, n_rows, row_number in steps:
        if size is not None:
            self.cur.arraysize = size
            assert self.cur.arraysize == size
        rows = fetch()
        assert len(rows) == n_rows, rows
        assert self.cur.rownumber == row_number, self.cur.rownumber
206
+
207
def test_execute_with_dict_params(self):
    """Named ``%(name)s`` parameters are filled from a mapping.

    A mapping that lacks the referenced key raises ``KeyError``.
    """
    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
    ]
    # (column name, acceptable type codes) for each column of ``data``.
    expected_cols = [('id', [253, 15]), ('name', [253, 15]), ('value', [8])]
    query = 'select * from data where id < %(name)s'

    self.cur.execute(query, dict(name='d'))
    out = self.cur.fetchall()

    desc = self.cur.description
    rowcount = self.cur.rowcount
    lastrowid = self.cur.lastrowid

    assert sorted(out) == sorted(expected), out
    assert rowcount in (3, -1), rowcount
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    for col, (name, type_codes) in zip(desc, expected_cols):
        assert col.name == name, col.name
        assert col.type_code in type_codes, col.type_code

    with self.assertRaises(KeyError):
        self.cur.execute(query, dict(foo='d'))
233
+
234
def test_execute_with_positional_params(self):
    """Positional ``%s`` parameters are filled from a sequence.

    Passing more or fewer values than there are placeholders raises
    ``TypeError``.
    """
    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
    ]
    # (column name, acceptable type codes) for each column of ``data``.
    expected_cols = [('id', [253, 15]), ('name', [253, 15]), ('value', [8])]

    self.cur.execute('select * from data where id < %s', ['d'])
    out = self.cur.fetchall()

    desc = self.cur.description
    rowcount = self.cur.rowcount
    lastrowid = self.cur.lastrowid

    assert sorted(out) == sorted(expected), out
    assert rowcount in (3, -1), rowcount
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    for col, (name, type_codes) in zip(desc, expected_cols):
        assert col.name == name, col.name
        assert col.type_code in type_codes, col.type_code

    # Too many values, then too few, for a two-placeholder query.
    for bad_params in (['d', 'e', 'f'], ['d']):
        with self.assertRaises(TypeError):
            self.cur.execute(
                'select * from data where id < %s and id > %s', bad_params,
            )
265
+
266
def test_execute_with_escaped_positional_substitutions(self):
    """Positional ``%s`` parameters work next to quoted literals.

    Row 0 of ``alltypes`` has ``time`` 00:07:00 (420 seconds). It must be
    returned whether that time is passed as a parameter, written as a
    literal, or combined with a parameterised ``id`` condition.
    """
    expected = (0, datetime.timedelta(seconds=420))

    self.cur.execute(
        'select `id`, `time` from alltypes where `time` = %s', ['00:07:00'],
    )
    first = self.cur.fetchall()[0]
    assert first == expected, first

    self.cur.execute('select `id`, `time` from alltypes where `time` = "00:07:00"')
    first = self.cur.fetchall()[0]
    assert first == expected, first

    # NOTE: a check that an explicitly numbered placeholder (``%1s``) raises
    # IndexError was disabled (commented out) in the original test.

    self.cur.execute(
        'select `id`, `time` from alltypes where `id` = %s '
        'or `time` = "00:07:00"', [0],
    )
    first = self.cur.fetchall()[0]
    assert first == expected, first
289
+
290
def test_execute_with_escaped_substitutions(self):
    """Named ``%(name)s`` parameters work in queries containing colons.

    Row 0 of ``alltypes`` has ``time`` 00:07:00 (420 seconds).
    """
    expected = (0, datetime.timedelta(seconds=420))
    time_query = 'select `id`, `time` from alltypes where `time` = %(time)s'

    self.cur.execute(time_query, dict(time='00:07:00'))
    out = self.cur.fetchall()
    assert out[0] == expected, out[0]

    # NOTE(review): this re-runs the query above verbatim and only checks the
    # row count; it may originally have exercised a different escaping form.
    # Confirm what was intended.
    self.cur.execute(time_query, dict(time='00:07:00'))
    out = self.cur.fetchall()
    assert len(out) == 1, out

    # The mapping has no 'time' key for %(time)s, so substitution fails.
    with self.assertRaises(KeyError):
        self.cur.execute(
            'select `id`, `time`, `char_100` from alltypes '
            'where `time` = %(time)s or `char_100` like "foo:bar"',
            dict(x='00:07:00'),
        )

    self.cur.execute(
        'select `id`, `time`, `char_100` from alltypes '
        'where `time` = %(time)s or `char_100` like "foo::bar"',
        dict(time='00:07:00'),
    )
    out = self.cur.fetchall()
    assert out[0][:2] == expected, out[0]
319
+
320
def test_is_connected(self):
    """Check ``is_connected`` after closing the cursor, then the connection.

    Closing the cursor leaves the connection open. Closing the connection
    afterwards makes both report that they are disconnected.
    """
    conn, cur = self.conn, self.cur

    assert conn.is_connected()
    assert cur.is_connected()

    cur.close()
    assert not cur.is_connected()
    assert conn.is_connected()

    conn.close()
    assert not cur.is_connected()
    assert not conn.is_connected()
329
+
330
def test_connection_attr(self):
    """Entering ``self.conn`` as a context manager yields ``self.conn``."""
    # NOTE(review): the original comment said the context manager is used to
    # reach the underlying object because self.conn is a weakref.proxy.
    # However, the identity assertion below requires the entered object to
    # *be* self.conn. Confirm which behaviour is intended.
    with self.conn as entered:
        assert entered is self.conn
334
+
335
def test_executemany(self):
    """``executemany`` accepts a list of named-parameter mappings."""
    # NOTE: Doesn't actually do anything since no rows match
    # (every id in ``data`` sorts before 'y' and 'z').
    param_sets = [dict(name=value) for value in ('z', 'y')]
    self.cur.executemany('delete from data where id > %(name)s', param_sets)
341
+
342
def test_executemany_no_args(self):
    """``executemany`` can be called with only a query string."""
    query = 'select * from data where id > "z"'
    self.cur.executemany(query)
344
+
345
def test_context_managers(self):
    """Leaving the ``with`` blocks closes both the cursor and the connection."""
    with s2.connect() as conn, conn.cursor() as cur:
        assert cur.is_connected()
        assert conn.is_connected()

    assert not cur.is_connected()
    assert not conn.is_connected()
352
+
353
def test_iterator(self):
    """Iterating the cursor directly yields every row of ``data``."""
    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]

    self.cur.execute('select * from data')
    out = [row for row in self.cur]

    assert sorted(out) == sorted(expected), out
367
+
368
def test_urls(self):
    """``build_params`` splits connection URLs into connection parameters.

    Covers URLs with and without a scheme, port, credentials, database and
    query options. Also checks that a non-numeric port and an unknown query
    option both raise ``ValueError``.
    """
    from singlestoredb.connection import build_params
    from singlestoredb.config import get_option

    def check_default_port(out):
        # No port in the URL: accept the configured or standard default
        # for whichever driver was selected.
        if out['driver'] in ['http', 'https']:
            assert out['port'] in [get_option('http_port'), 80, 443], out['port']
        else:
            assert out['port'] in [get_option('port'), 3306], out['port']

    def check_default_credentials(out):
        # No credentials in the URL: they are either absent from the result
        # or equal to the configured options.
        assert 'user' not in out or out['user'] == get_option('user'), out['user']
        assert 'password' not in out or out['password'] == get_option(
            'password',
        ), out['password']

    # Full URL (without scheme)
    url = 'me:p455w0rd@s2host.com:3307/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 3307, out['port']
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == 'p455w0rd', out['password']

    # Full URL (with scheme)
    url = 'http://me:p455w0rd@s2host.com:3307/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'http', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 3307, out['port']
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == 'p455w0rd', out['password']

    # No port
    url = 'me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == 'p455w0rd', out['password']

    # No http port
    url = 'http://me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'http', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] in [get_option('http_port'), 80], out['port']
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == 'p455w0rd', out['password']

    # No https port
    url = 'https://me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'https', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] in [get_option('http_port'), 443], out['port']
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == 'p455w0rd', out['password']

    # Invalid port
    url = 'https://me:p455w0rd@s2host.com:foo/mydb'
    with self.assertRaises(ValueError):
        build_params(host=url)

    # Empty password
    url = 'me:@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert out['database'] == 'mydb', out['database']
    assert out['user'] == 'me', out['user']
    assert out['password'] == '', out['password']

    # No user/password
    url = 's2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert out['database'] == 'mydb', out['database']
    check_default_credentials(out)

    # Just hostname
    url = 's2host.com'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert 'database' not in out
    check_default_credentials(out)

    # Just hostname and port
    url = 's2host.com:1000'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 1000, out['port']
    assert 'database' not in out
    check_default_credentials(out)

    # Query options (local_infile=1 is converted to True)
    url = 's2host.com:1000?local_infile=1&charset=utf8'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 1000, out['port']
    assert 'database' not in out
    check_default_credentials(out)
    assert out['local_infile'] is True, out['local_infile']
    assert out['charset'] == 'utf8', out['charset']

    # Bad query option
    url = 's2host.com:1000?bad_param=10'
    with self.assertRaises(ValueError):
        build_params(host=url)
503
+
504
def test_wrap_exc(self):
    """Check that a SQL syntax error is raised as ``s2.ProgrammingError``.

    The error must carry errno 1064 and the server's error message.
    """
    with self.assertRaises(s2.ProgrammingError) as ctx:
        self.cur.execute('garbage syntax')

    err = ctx.exception
    assert err.errno == 1064, err.errno
    assert 'You have an error in your SQL syntax' in err.errmsg, err.errmsg
511
+
512
def test_extended_types(self):
    """Round-trip geometry, vector and date/time values through ``extended_types``.

    The same five rows are inserted twice: first with shapely geometries
    (ids 1-5), then with pygeos geometries (ids 6-10). All rows are tagged
    with a fresh ``testkey``. With an HTTP driver, the vector column is
    written as ``''`` and expected back as ``b''``. Otherwise it is written
    as a little-endian float32 array and compared element-wise.
    """
    if not has_numpy or not has_pygeos or not has_shapely:
        self.skipTest('Test requires numpy, pygeos, and shapely')

    import uuid

    key = str(uuid.uuid4())
    is_http = 'http' in self.conn.driver

    # (geography WKT, geographypoint WKT, vector, dt, d, t, td); ids and the
    # geometry library are chosen per insert batch below.
    base_rows = [
        (
            'POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))', 'POINT(1.5 1.5)',
            [0.5, 0.6], datetime.datetime(1950, 1, 2, 12, 13, 14),
            datetime.date(1950, 1, 2), datetime.time(12, 13, 14),
            datetime.timedelta(seconds=123456),
        ),
        (
            'POLYGON((5 1, 6 1, 6 2, 5 2, 5 1))', 'POINT(5.5 1.5)',
            [1.3, 2.5], datetime.datetime(1960, 3, 4, 15, 16, 17),
            datetime.date(1960, 3, 4), datetime.time(15, 16, 17),
            datetime.timedelta(seconds=2),
        ),
        (
            'POLYGON((5 5, 6 5, 6 6, 5 6, 5 5))', 'POINT(5.5 5.5)',
            [10.3, 11.1], datetime.datetime(1970, 6, 7, 18, 19, 20),
            datetime.date(1970, 5, 6), datetime.time(18, 19, 20),
            datetime.timedelta(seconds=-2),
        ),
        (
            'POLYGON((1 5, 2 5, 2 6, 1 6, 1 5))', 'POINT(1.5 5.5)',
            [3.3, 3.4], datetime.datetime(1980, 8, 9, 21, 22, 23),
            datetime.date(1980, 7, 8), datetime.time(21, 22, 23),
            datetime.timedelta(seconds=-123456),
        ),
        (
            'POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))', 'POINT(3.5 3.5)',
            [2.9, 9.5], datetime.datetime(2010, 10, 11, 1, 2, 3),
            datetime.date(2010, 8, 9), datetime.time(1, 2, 3),
            datetime.timedelta(seconds=0),
        ),
    ]

    insert_sql = (
        'INSERT INTO extended_types '
        '(id, geography, geographypoint, vectors, dt, d, t, td, testkey) '
        'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)'
    )

    def build_rows(first_id, from_wkt):
        # Insert-ready rows: consecutive ids from first_id, WKT parsed with
        # from_wkt, and the vector encoded for the active driver.
        rows = []
        for offset, (poly, point, vec, dt, d, t, td) in enumerate(base_rows):
            rows.append([
                first_id + offset, from_wkt(poly), from_wkt(point),
                '' if is_http else np.array(vec, dtype='<f4'),
                dt, d, t, td, key,
            ])
        return rows

    def check_rows(expected, fetched, assert_geom_equal):
        # zip() stops at the shorter input, so check the row count first;
        # otherwise a query returning no rows would pass without any checks.
        assert len(fetched) == len(expected), fetched
        for data_row, row in zip(expected, fetched):
            assert data_row[0] == row[0]
            assert_geom_equal(data_row[1], row[1])
            assert_geom_equal(data_row[2], row[2])
            if is_http:
                assert row[3] == b''
            else:
                assert (data_row[3] == np.frombuffer(row[3], dtype='<f4')).all()

    def assert_shapely_equal(expected, wkt):
        assert expected.equals_exact(shapely.wkt.loads(wkt), 1e-4)

    def assert_pygeos_equal(expected, wkt):
        assert_geometries_equal(expected, pygeos.io.from_wkt(wkt))

    # shapely data
    new_data = build_rows(1, shapely.wkt.loads)
    self.cur.executemany(insert_sql, new_data)
    self.cur.execute(
        'SELECT * FROM extended_types WHERE testkey = %s ORDER BY id', [key],
    )
    check_rows(new_data, self.cur.fetchall(), assert_shapely_equal)

    # pygeos data
    new_data = build_rows(6, pygeos.io.from_wkt)
    self.cur.executemany(insert_sql, new_data)
    self.cur.execute(
        'SELECT * FROM extended_types WHERE id >= 6 and testkey = %s ORDER BY id',
        [key],
    )
    check_rows(new_data, self.cur.fetchall(), assert_pygeos_equal)
649
+
650
+ def test_alltypes(self):
651
+ self.cur.execute('select * from alltypes where id = 0')
652
+ names = [x[0] for x in self.cur.description]
653
+ types = [x[1] for x in self.cur.description]
654
+ out = self.cur.fetchone()
655
+ row = dict(zip(names, out))
656
+ typ = dict(zip(names, types))
657
+
658
+ bits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
659
+
660
+ def otype(x):
661
+ return x
662
+
663
+ assert row['id'] == 0, row['id']
664
+ assert typ['id'] == otype(3), typ['id']
665
+
666
+ assert row['tinyint'] == 80, row['tinyint']
667
+ assert typ['tinyint'] == otype(1), typ['tinyint']
668
+
669
+ assert row['bool'] == 0, row['bool']
670
+ assert typ['bool'] == otype(1), typ['bool']
671
+
672
+ assert row['boolean'] == 1, row['boolean']
673
+ assert typ['boolean'] == otype(1), typ['boolean']
674
+
675
+ assert row['smallint'] == -27897, row['smallint']
676
+ assert typ['smallint'] == otype(2), typ['smallint']
677
+
678
+ assert row['mediumint'] == 104729, row['mediumint']
679
+ assert typ['mediumint'] == otype(9), typ['mediumint']
680
+
681
+ assert row['int24'] == -200899, row['int24']
682
+ assert typ['int24'] == otype(9), typ['int24']
683
+
684
+ assert row['int'] == -1295369311, row['int']
685
+ assert typ['int'] == otype(3), typ['int']
686
+
687
+ assert row['integer'] == -1741727421, row['integer']
688
+ assert typ['integer'] == otype(3), typ['integer']
689
+
690
+ assert row['bigint'] == -266883847, row['bigint']
691
+ assert typ['bigint'] == otype(8), typ['bigint']
692
+
693
+ assert row['float'] == -146487000.0, row['float']
694
+ assert typ['float'] == otype(4), typ['float']
695
+
696
+ assert row['double'] == -474646154.719356, row['double']
697
+ assert typ['double'] == otype(5), typ['double']
698
+
699
+ assert row['real'] == -901409776.279346, row['real']
700
+ assert typ['real'] == otype(5), typ['real']
701
+
702
+ assert row['decimal'] == decimal.Decimal('28111097.610822'), row['decimal']
703
+ assert typ['decimal'] == otype(246), typ['decimal']
704
+
705
+ assert row['dec'] == decimal.Decimal('389451155.931428'), row['dec']
706
+ assert typ['dec'] == otype(246), typ['dec']
707
+
708
+ assert row['fixed'] == decimal.Decimal('-143773416.044092'), row['fixed']
709
+ assert typ['fixed'] == otype(246), typ['fixed']
710
+
711
+ assert row['numeric'] == decimal.Decimal('866689461.300046'), row['numeric']
712
+ assert typ['numeric'] == otype(246), typ['numeric']
713
+
714
+ assert row['date'] == datetime.date(8524, 11, 10), row['date']
715
+ assert typ['date'] == 10, typ['date']
716
+
717
+ assert row['time'] == datetime.timedelta(minutes=7), row['time']
718
+ assert typ['time'] == 11, typ['time']
719
+
720
+ assert row['time_6'] == datetime.timedelta(
721
+ hours=1, minutes=10, microseconds=2,
722
+ ), row['time_6']
723
+ assert typ['time_6'] == 11, typ['time_6']
724
+
725
+ assert row['datetime'] == datetime.datetime(
726
+ 9948, 3, 11, 15, 29, 22,
727
+ ), row['datetime']
728
+ assert typ['datetime'] == 12, typ['datetime']
729
+
730
+ assert row['datetime_6'] == datetime.datetime(
731
+ 1756, 10, 29, 2, 2, 42, 8,
732
+ ), row['datetime_6']
733
+ assert typ['datetime_6'] == 12, typ['datetime_6']
734
+
735
+ assert row['timestamp'] == datetime.datetime(
736
+ 1980, 12, 31, 1, 10, 23,
737
+ ), row['timestamp']
738
+ assert typ['timestamp'] == otype(7), typ['timestamp']
739
+
740
+ assert row['timestamp_6'] == datetime.datetime(
741
+ 1991, 1, 2, 22, 15, 10, 6,
742
+ ), row['timestamp_6']
743
+ assert typ['timestamp_6'] == otype(7), typ['timestamp_6']
744
+
745
+ assert row['year'] == 1923, row['year']
746
+ assert typ['year'] == otype(13), typ['year']
747
+
748
+ assert row['char_100'] == \
749
+ 'This is a test of a 100 character column.', row['char_100']
750
+ assert typ['char_100'] == otype(254), typ['char_100']
751
+
752
+ assert row['binary_100'] == bytearray(bits + [0] * 84), row['binary_100']
753
+ assert typ['binary_100'] == otype(254), typ['binary_100']
754
+
755
+ assert row['varchar_200'] == \
756
+ 'This is a test of a variable character column.', row['varchar_200']
757
+ assert typ['varchar_200'] == otype(253), typ['varchar_200'] # why not 15?
758
+
759
+ assert row['varbinary_200'] == bytearray(bits * 2), row['varbinary_200']
760
+ assert typ['varbinary_200'] == otype(253), typ['varbinary_200'] # why not 15?
761
+
762
+ assert row['longtext'] == 'This is a longtext column.', row['longtext']
763
+ assert typ['longtext'] == otype(251), typ['longtext']
764
+
765
+ assert row['mediumtext'] == 'This is a mediumtext column.', row['mediumtext']
766
+ assert typ['mediumtext'] == otype(250), typ['mediumtext']
767
+
768
+ assert row['text'] == 'This is a text column.', row['text']
769
+ assert typ['text'] == otype(252), typ['text']
770
+
771
+ assert row['tinytext'] == 'This is a tinytext column.'
772
+ assert typ['tinytext'] == otype(249), typ['tinytext']
773
+
774
+ assert row['longblob'] == bytearray(bits * 3), row['longblob']
775
+ assert typ['longblob'] == otype(251), typ['longblob']
776
+
777
+ assert row['mediumblob'] == bytearray(bits * 2), row['mediumblob']
778
+ assert typ['mediumblob'] == otype(250), typ['mediumblob']
779
+
780
+ assert row['blob'] == bytearray(bits), row['blob']
781
+ assert typ['blob'] == otype(252), typ['blob']
782
+
783
+ assert row['tinyblob'] == bytearray([10, 11, 12, 13, 14, 15]), row['tinyblob']
784
+ assert typ['tinyblob'] == otype(249), typ['tinyblob']
785
+
786
+ assert row['json'] == {'a': 10, 'b': 2.75, 'c': 'hello world'}, row['json']
787
+ assert typ['json'] == otype(245), typ['json']
788
+
789
+ assert row['enum'] == 'one', row['enum']
790
+ assert typ['enum'] == otype(253), typ['enum'] # mysql code: 247
791
+
792
+ # TODO: HTTP sees this as a varchar, so it doesn't become a set.
793
+ assert row['set'] in [{'two'}, 'two'], row['set']
794
+ assert typ['set'] == otype(253), typ['set'] # mysql code: 248
795
+
796
+ assert row['bit'] == b'\x00\x00\x00\x00\x00\x00\x00\x80', row['bit']
797
+ assert typ['bit'] == otype(16), typ['bit']
798
+
799
+ def test_alltypes_nulls(self):
800
+ self.cur.execute('select * from alltypes where id = 1')
801
+ names = [x[0] for x in self.cur.description]
802
+ types = [x[1] for x in self.cur.description]
803
+ out = self.cur.fetchone()
804
+ row = dict(zip(names, out))
805
+ typ = dict(zip(names, types))
806
+
807
+ def otype(x):
808
+ return x
809
+
810
+ assert row['id'] == 1, row['id']
811
+ assert typ['id'] == otype(3), typ['id']
812
+
813
+ assert row['tinyint'] is None, row['tinyint']
814
+ assert typ['tinyint'] == otype(1), typ['tinyint']
815
+
816
+ assert row['bool'] is None, row['bool']
817
+ assert typ['bool'] == otype(1), typ['bool']
818
+
819
+ assert row['boolean'] is None, row['boolean']
820
+ assert typ['boolean'] == otype(1), typ['boolean']
821
+
822
+ assert row['smallint'] is None, row['smallint']
823
+ assert typ['smallint'] == otype(2), typ['smallint']
824
+
825
+ assert row['mediumint'] is None, row['mediumint']
826
+ assert typ['mediumint'] == otype(9), typ['mediumint']
827
+
828
+ assert row['int24'] is None, row['int24']
829
+ assert typ['int24'] == otype(9), typ['int24']
830
+
831
+ assert row['int'] is None, row['int']
832
+ assert typ['int'] == otype(3), typ['int']
833
+
834
+ assert row['integer'] is None, row['integer']
835
+ assert typ['integer'] == otype(3), typ['integer']
836
+
837
+ assert row['bigint'] is None, row['bigint']
838
+ assert typ['bigint'] == otype(8), typ['bigint']
839
+
840
+ assert row['float'] is None, row['float']
841
+ assert typ['float'] == otype(4), typ['float']
842
+
843
+ assert row['double'] is None, row['double']
844
+ assert typ['double'] == otype(5), typ['double']
845
+
846
+ assert row['real'] is None, row['real']
847
+ assert typ['real'] == otype(5), typ['real']
848
+
849
+ assert row['decimal'] is None, row['decimal']
850
+ assert typ['decimal'] == otype(246), typ['decimal']
851
+
852
+ assert row['dec'] is None, row['dec']
853
+ assert typ['dec'] == otype(246), typ['dec']
854
+
855
+ assert row['fixed'] is None, row['fixed']
856
+ assert typ['fixed'] == otype(246), typ['fixed']
857
+
858
+ assert row['numeric'] is None, row['numeric']
859
+ assert typ['numeric'] == otype(246), typ['numeric']
860
+
861
+ assert row['date'] is None, row['date']
862
+ assert typ['date'] == 10, typ['date']
863
+
864
+ assert row['time'] is None, row['time']
865
+ assert typ['time'] == 11, typ['time']
866
+
867
+ assert row['time'] is None, row['time']
868
+ assert typ['time_6'] == 11, typ['time_6']
869
+
870
+ assert row['datetime'] is None, row['datetime']
871
+ assert typ['datetime'] == 12, typ['datetime']
872
+
873
+ assert row['datetime_6'] is None, row['datetime_6']
874
+ assert typ['datetime'] == 12, typ['datetime']
875
+
876
+ assert row['timestamp'] is None, row['timestamp']
877
+ assert typ['timestamp'] == otype(7), typ['timestamp']
878
+
879
+ assert row['timestamp_6'] is None, row['timestamp_6']
880
+ assert typ['timestamp_6'] == otype(7), typ['timestamp_6']
881
+
882
+ assert row['year'] is None, row['year']
883
+ assert typ['year'] == otype(13), typ['year']
884
+
885
+ assert row['char_100'] is None, row['char_100']
886
+ assert typ['char_100'] == otype(254), typ['char_100']
887
+
888
+ assert row['binary_100'] is None, row['binary_100']
889
+ assert typ['binary_100'] == otype(254), typ['binary_100']
890
+
891
+ assert row['varchar_200'] is None, typ['varchar_200']
892
+ assert typ['varchar_200'] == otype(253), typ['varchar_200'] # why not 15?
893
+
894
+ assert row['varbinary_200'] is None, row['varbinary_200']
895
+ assert typ['varbinary_200'] == otype(253), typ['varbinary_200'] # why not 15?
896
+
897
+ assert row['longtext'] is None, row['longtext']
898
+ assert typ['longtext'] == otype(251), typ['longtext']
899
+
900
+ assert row['mediumtext'] is None, row['mediumtext']
901
+ assert typ['mediumtext'] == otype(250), typ['mediumtext']
902
+
903
+ assert row['text'] is None, row['text']
904
+ assert typ['text'] == otype(252), typ['text']
905
+
906
+ assert row['tinytext'] is None, row['tinytext']
907
+ assert typ['tinytext'] == otype(249), typ['tinytext']
908
+
909
+ assert row['longblob'] is None, row['longblob']
910
+ assert typ['longblob'] == otype(251), typ['longblob']
911
+
912
+ assert row['mediumblob'] is None, row['mediumblob']
913
+ assert typ['mediumblob'] == otype(250), typ['mediumblob']
914
+
915
+ assert row['blob'] is None, row['blob']
916
+ assert typ['blob'] == otype(252), typ['blob']
917
+
918
+ assert row['tinyblob'] is None, row['tinyblob']
919
+ assert typ['tinyblob'] == otype(249), typ['tinyblob']
920
+
921
+ assert row['json'] is None, row['json']
922
+ assert typ['json'] == otype(245), typ['json']
923
+
924
+ assert row['enum'] is None, row['enum']
925
+ assert typ['enum'] == otype(253), typ['enum'] # mysql code: 247
926
+
927
+ assert row['set'] is None, row['set']
928
+ assert typ['set'] == otype(253), typ['set'] # mysql code: 248
929
+
930
+ assert row['bit'] is None, row['bit']
931
+ assert typ['bit'] == otype(16), typ['bit']
932
+
933
+ def test_alltypes_mins(self):
934
+ self.cur.execute('select * from alltypes where id = 2')
935
+ names = [x[0] for x in self.cur.description]
936
+ out = self.cur.fetchone()
937
+ row = dict(zip(names, out))
938
+
939
+ expected = dict(
940
+ id=2,
941
+ tinyint=-128,
942
+ unsigned_tinyint=0,
943
+ bool=-128,
944
+ boolean=-128,
945
+ smallint=-32768,
946
+ unsigned_smallint=0,
947
+ mediumint=-8388608,
948
+ unsigned_mediumint=0,
949
+ int24=-8388608,
950
+ unsigned_int24=0,
951
+ int=-2147483648,
952
+ unsigned_int=0,
953
+ integer=-2147483648,
954
+ unsigned_integer=0,
955
+ bigint=-9223372036854775808,
956
+ unsigned_bigint=0,
957
+ float=0,
958
+ double=-1.7976931348623158e308,
959
+ real=-1.7976931348623158e308,
960
+ decimal=decimal.Decimal('-99999999999999.999999'),
961
+ dec=-decimal.Decimal('99999999999999.999999'),
962
+ fixed=decimal.Decimal('-99999999999999.999999'),
963
+ numeric=decimal.Decimal('-99999999999999.999999'),
964
+ date=datetime.date(1000, 1, 1),
965
+ time=-1 * datetime.timedelta(hours=838, minutes=59, seconds=59),
966
+ time_6=-1 * datetime.timedelta(hours=838, minutes=59, seconds=59),
967
+ datetime=datetime.datetime(1000, 1, 1, 0, 0, 0),
968
+ datetime_6=datetime.datetime(1000, 1, 1, 0, 0, 0, 0),
969
+ timestamp=datetime.datetime(1970, 1, 1, 0, 0, 1),
970
+ timestamp_6=datetime.datetime(1970, 1, 1, 0, 0, 1, 0),
971
+ year=1901,
972
+ char_100='',
973
+ binary_100=b'\x00' * 100,
974
+ varchar_200='',
975
+ varbinary_200=b'',
976
+ longtext='',
977
+ mediumtext='',
978
+ text='',
979
+ tinytext='',
980
+ longblob=b'',
981
+ mediumblob=b'',
982
+ blob=b'',
983
+ tinyblob=b'',
984
+ json={},
985
+ enum='one',
986
+ set='two',
987
+ bit=b'\x00\x00\x00\x00\x00\x00\x00\x00',
988
+ )
989
+
990
+ for k, v in sorted(row.items()):
991
+ assert v == expected[k], '{} != {} in key {}'.format(v, expected[k], k)
992
+
993
+ def test_alltypes_maxs(self):
994
+ self.cur.execute('select * from alltypes where id = 3')
995
+ names = [x[0] for x in self.cur.description]
996
+ out = self.cur.fetchone()
997
+ row = dict(zip(names, out))
998
+
999
+ expected = dict(
1000
+ id=3,
1001
+ tinyint=127,
1002
+ unsigned_tinyint=255,
1003
+ bool=127,
1004
+ boolean=127,
1005
+ smallint=32767,
1006
+ unsigned_smallint=65535,
1007
+ mediumint=8388607,
1008
+ unsigned_mediumint=16777215,
1009
+ int24=8388607,
1010
+ unsigned_int24=16777215,
1011
+ int=2147483647,
1012
+ unsigned_int=4294967295,
1013
+ integer=2147483647,
1014
+ unsigned_integer=4294967295,
1015
+ bigint=9223372036854775807,
1016
+ unsigned_bigint=18446744073709551615,
1017
+ float=0,
1018
+ double=1.7976931348623158e308,
1019
+ real=1.7976931348623158e308,
1020
+ decimal=decimal.Decimal('99999999999999.999999'),
1021
+ dec=decimal.Decimal('99999999999999.999999'),
1022
+ fixed=decimal.Decimal('99999999999999.999999'),
1023
+ numeric=decimal.Decimal('99999999999999.999999'),
1024
+ date=datetime.date(9999, 12, 31),
1025
+ time=datetime.timedelta(hours=838, minutes=59, seconds=59),
1026
+ time_6=datetime.timedelta(hours=838, minutes=59, seconds=59),
1027
+ datetime=datetime.datetime(9999, 12, 31, 23, 59, 59),
1028
+ datetime_6=datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),
1029
+ timestamp=datetime.datetime(2038, 1, 19, 3, 14, 7),
1030
+ timestamp_6=datetime.datetime(2038, 1, 19, 3, 14, 7, 999999),
1031
+ year=2155,
1032
+ char_100='',
1033
+ binary_100=b'\x00' * 100,
1034
+ varchar_200='',
1035
+ varbinary_200=b'',
1036
+ longtext='',
1037
+ mediumtext='',
1038
+ text='',
1039
+ tinytext='',
1040
+ longblob=b'',
1041
+ mediumblob=b'',
1042
+ blob=b'',
1043
+ tinyblob=b'',
1044
+ json={},
1045
+ enum='one',
1046
+ set='two',
1047
+ bit=b'\xff\xff\xff\xff\xff\xff\xff\xff',
1048
+ )
1049
+
1050
+ for k, v in sorted(row.items()):
1051
+ # TODO: Figure out how to get time zones working
1052
+ if 'timestamp' in k:
1053
+ continue
1054
+ assert v == expected[k], '{} != {} in key {}'.format(v, expected[k], k)
1055
+
1056
+ def test_alltypes_zeros(self):
1057
+ self.cur.execute('select * from alltypes where id = 4')
1058
+ names = [x[0] for x in self.cur.description]
1059
+ out = self.cur.fetchone()
1060
+ row = dict(zip(names, out))
1061
+
1062
+ expected = dict(
1063
+ id=4,
1064
+ tinyint=0,
1065
+ unsigned_tinyint=0,
1066
+ bool=0,
1067
+ boolean=0,
1068
+ smallint=0,
1069
+ unsigned_smallint=0,
1070
+ mediumint=0,
1071
+ unsigned_mediumint=0,
1072
+ int24=0,
1073
+ unsigned_int24=0,
1074
+ int=0,
1075
+ unsigned_int=0,
1076
+ integer=0,
1077
+ unsigned_integer=0,
1078
+ bigint=0,
1079
+ unsigned_bigint=0,
1080
+ float=0,
1081
+ double=0,
1082
+ real=0,
1083
+ decimal=decimal.Decimal('0.0'),
1084
+ dec=decimal.Decimal('0.0'),
1085
+ fixed=decimal.Decimal('0.0'),
1086
+ numeric=decimal.Decimal('0.0'),
1087
+ date=None,
1088
+ time=datetime.timedelta(hours=0, minutes=0, seconds=0),
1089
+ time_6=datetime.timedelta(hours=0, minutes=0, seconds=0, microseconds=0),
1090
+ datetime=None,
1091
+ datetime_6=None,
1092
+ timestamp=None,
1093
+ timestamp_6=None,
1094
+ year=None,
1095
+ char_100='',
1096
+ binary_100=b'\x00' * 100,
1097
+ varchar_200='',
1098
+ varbinary_200=b'',
1099
+ longtext='',
1100
+ mediumtext='',
1101
+ text='',
1102
+ tinytext='',
1103
+ longblob=b'',
1104
+ mediumblob=b'',
1105
+ blob=b'',
1106
+ tinyblob=b'',
1107
+ json={},
1108
+ enum='one',
1109
+ set='two',
1110
+ bit=b'\x00\x00\x00\x00\x00\x00\x00\x00',
1111
+ )
1112
+
1113
+ for k, v in sorted(row.items()):
1114
+ assert v == expected[k], '{} != {} in key {}'.format(v, expected[k], k)
1115
+
1116
+ def _test_MySQLdb(self):
1117
+ try:
1118
+ import json
1119
+ import MySQLdb
1120
+ except (ModuleNotFoundError, ImportError):
1121
+ self.skipTest('MySQLdb is not installed')
1122
+
1123
+ self.cur.execute('select * from alltypes order by id')
1124
+ s2_out = self.cur.fetchall()
1125
+
1126
+ port = self.conn.connection_params['port']
1127
+ if 'http' in self.conn.driver:
1128
+ port = 3306
1129
+
1130
+ args = dict(
1131
+ host=self.conn.connection_params['host'],
1132
+ port=port,
1133
+ user=self.conn.connection_params['user'],
1134
+ password=self.conn.connection_params['password'],
1135
+ database=type(self).dbname,
1136
+ )
1137
+
1138
+ with MySQLdb.connect(**args) as conn:
1139
+ conn.converter[245] = json.loads
1140
+ with conn.cursor() as cur:
1141
+ cur.execute('select * from alltypes order by id')
1142
+ mydb_out = cur.fetchall()
1143
+
1144
+ for a, b in zip(s2_out, mydb_out):
1145
+ assert a == b, (a, b)
1146
+
1147
+ def test_int_string(self):
1148
+ string = 'a' * 48
1149
+ self.cur.execute(f"SELECT 1, '{string}'")
1150
+ self.assertEqual((1, string), self.cur.fetchone())
1151
+
1152
+ def test_double_string(self):
1153
+ string = 'a' * 49
1154
+ self.cur.execute(f"SELECT 1.2 :> DOUBLE, '{string}'")
1155
+ self.assertEqual((1.2, string), self.cur.fetchone())
1156
+
1157
+ def test_year_string(self):
1158
+ string = 'a' * 49
1159
+ self.cur.execute(f"SELECT 1999 :> YEAR, '{string}'")
1160
+ self.assertEqual((1999, string), self.cur.fetchone())
1161
+
1162
+ def test_nan_as_null(self):
1163
+ with self.assertRaises((s2.ProgrammingError, InvalidJSONError)):
1164
+ self.cur.execute('SELECT %s :> DOUBLE AS X', [math.nan])
1165
+
1166
+ with s2.connect(database=type(self).dbname, nan_as_null=True) as conn:
1167
+ with conn.cursor() as cur:
1168
+ cur.execute('SELECT %s :> DOUBLE AS X', [math.nan])
1169
+ self.assertEqual(None, list(cur)[0][0])
1170
+
1171
+ with s2.connect(database=type(self).dbname, nan_as_null=True) as conn:
1172
+ with conn.cursor() as cur:
1173
+ cur.execute('SELECT %s :> DOUBLE AS X', [1.234])
1174
+ self.assertEqual(1.234, list(cur)[0][0])
1175
+
1176
+ def test_inf_as_null(self):
1177
+ with self.assertRaises((s2.ProgrammingError, InvalidJSONError)):
1178
+ self.cur.execute('SELECT %s :> DOUBLE AS X', [math.inf])
1179
+
1180
+ with s2.connect(database=type(self).dbname, inf_as_null=True) as conn:
1181
+ with conn.cursor() as cur:
1182
+ cur.execute('SELECT %s :> DOUBLE AS X', [math.inf])
1183
+ self.assertEqual(None, list(cur)[0][0])
1184
+
1185
+ with s2.connect(database=type(self).dbname, inf_as_null=True) as conn:
1186
+ with conn.cursor() as cur:
1187
+ cur.execute('SELECT %s :> DOUBLE AS X', [1.234])
1188
+ self.assertEqual(1.234, list(cur)[0][0])
1189
+
1190
+ def test_encoding_errors(self):
1191
+ with s2.connect(
1192
+ database=type(self).dbname,
1193
+ encoding_errors='strict',
1194
+ ) as conn:
1195
+ with conn.cursor() as cur:
1196
+ cur.execute('SELECT * FROM badutf8')
1197
+ list(cur)
1198
+
1199
+ with s2.connect(
1200
+ database=type(self).dbname,
1201
+ encoding_errors='backslashreplace',
1202
+ ) as conn:
1203
+ with conn.cursor() as cur:
1204
+ cur.execute('SELECT * FROM badutf8')
1205
+ list(cur)
1206
+
1207
+ def test_character_lengths(self):
1208
+ if 'http' in self.conn.driver:
1209
+ self.skipTest('Character lengths too long for HTTP interface')
1210
+
1211
+ tbl_id = str(id(self))
1212
+
1213
+ self.cur.execute('DROP TABLE IF EXISTS test_character_lengths')
1214
+ self.cur.execute(rf'''
1215
+ CREATE TABLE `test_character_lengths_{tbl_id}` (
1216
+ `id` text CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL,
1217
+ `char_col` longtext CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL,
1218
+ `int_col` INT,
1219
+ PRIMARY KEY (`id`),
1220
+ SORT KEY `id` (`id`)
1221
+ ) AUTOSTATS_CARDINALITY_MODE=INCREMENTAL
1222
+ AUTOSTATS_HISTOGRAM_MODE=CREATE
1223
+ AUTOSTATS_SAMPLING=ON
1224
+ SQL_MODE='STRICT_ALL_TABLES'
1225
+ ''')
1226
+
1227
+ CHAR_STR_SHORT = 'a'
1228
+ CHAR_STR_LONG = 'a' * (2**8-1)
1229
+ SHORT_STR_SHORT = 'a' * ((2**8-1) + 1)
1230
+ SHORT_STR_LONG = 'a' * (2**16-1)
1231
+ INT24_STR_SHORT = 'a' * ((2**16-1) + 1)
1232
+ INT24_STR_LONG = 'a' * (2**24-1)
1233
+ INT64_STR_SHORT = 'a' * ((2**24-1) + 1)
1234
+ INT64_STR_LONG = 'a' * ((2**24-1) + 100000)
1235
+
1236
+ data = [
1237
+ ['CHAR_SHORT', CHAR_STR_SHORT, 123456],
1238
+ ['CHAR_LONG', CHAR_STR_LONG, 123456],
1239
+ ['SHORT_SHORT', SHORT_STR_SHORT, 123456],
1240
+ ['SHORT_LONG', SHORT_STR_LONG, 123456],
1241
+ ['INT24_SHORT', INT24_STR_SHORT, 123456],
1242
+ ['INT24_LONG', INT24_STR_LONG, 123456],
1243
+ ['INT64_SHORT', INT64_STR_SHORT, 123456],
1244
+ ['INT64_LONG', INT64_STR_LONG, 123456],
1245
+ ]
1246
+
1247
+ self.cur.executemany(
1248
+ f'INSERT INTO test_character_lengths_{tbl_id}(id, char_col, int_col) '
1249
+ 'VALUES (%s, %s, %s)', data,
1250
+ )
1251
+
1252
+ for i, row in enumerate(data):
1253
+ self.cur.execute(
1254
+ f'SELECT id, char_col, int_col FROM test_character_lengths_{tbl_id} '
1255
+ 'WHERE id = %s',
1256
+ [row[0]],
1257
+ )
1258
+ assert data[i] == list(list(self.cur)[0])
1259
+
1260
+ try:
1261
+ self.cur.execute(f'DROP TABLE test_character_lengths_{tbl_id}')
1262
+ except Exception:
1263
+ pass
1264
+
1265
+ def test_pydantic(self):
1266
+ if not has_pydantic:
1267
+ self.skipTest('Test requires pydantic')
1268
+
1269
+ tblname = 'foo_' + str(id(self))
1270
+
1271
+ class FooData(pydantic.BaseModel):
1272
+ x: Optional[int]
1273
+ y: Optional[float]
1274
+ z: Optional[str] = None
1275
+
1276
+ self.cur.execute(f'''
1277
+ CREATE TABLE {tblname}(
1278
+ x INT,
1279
+ y DOUBLE,
1280
+ z TEXT
1281
+ )
1282
+ ''')
1283
+
1284
+ self.cur.execute(
1285
+ f'INSERT INTO {tblname}(x, y) VALUES (%(x)s, %(y)s)',
1286
+ FooData(x=2, y=3.23),
1287
+ )
1288
+
1289
+ self.cur.execute('SELECT * FROM ' + tblname)
1290
+
1291
+ assert list(sorted(self.cur.fetchall())) == \
1292
+ list(sorted([(2, 3.23, None)]))
1293
+
1294
+ self.cur.executemany(
1295
+ f'INSERT INTO {tblname}(x) VALUES (%(x)s)',
1296
+ [FooData(x=3, y=3.12), FooData(x=10, y=100.11)],
1297
+ )
1298
+
1299
+ self.cur.execute('SELECT * FROM ' + tblname)
1300
+
1301
+ assert list(sorted(self.cur.fetchall())) == \
1302
+ list(
1303
+ sorted([
1304
+ (2, 3.23, None),
1305
+ (3, None, None),
1306
+ (10, None, None),
1307
+ ]),
1308
+ )
1309
+
1310
+ def test_charset(self):
1311
+ self.skipTest('Skip until charset commands are re-implemented')
1312
+
1313
+ with s2.connect(database=type(self).dbname) as conn:
1314
+ with conn.cursor() as cur:
1315
+ cur.execute('''
1316
+ select json_extract_string('{"foo":"😀"}', "bar");
1317
+ ''')
1318
+
1319
+ if 'http' in self.conn.driver:
1320
+ self.skipTest('Charset is not use in HTTP interface')
1321
+
1322
+ with self.assertRaises(s2.OperationalError):
1323
+ with s2.connect(database=type(self).dbname, charset='utf8') as conn:
1324
+ with conn.cursor() as cur:
1325
+ cur.execute('''
1326
+ select json_extract_string('{"foo":"😀"}', "bar");
1327
+ ''')
1328
+
1329
+
1330
+ if __name__ == '__main__':
1331
+ import nose2
1332
+ nose2.main()