singlestoredb 0.4.0__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of singlestoredb might be problematic. Click here for more details.

Files changed (120) hide show
  1. singlestoredb/__init__.py +33 -1
  2. singlestoredb/alchemy/__init__.py +90 -0
  3. singlestoredb/auth.py +5 -1
  4. singlestoredb/config.py +116 -14
  5. singlestoredb/connection.py +483 -516
  6. singlestoredb/converters.py +238 -135
  7. singlestoredb/exceptions.py +30 -2
  8. singlestoredb/functions/__init__.py +1 -0
  9. singlestoredb/functions/decorator.py +142 -0
  10. singlestoredb/functions/dtypes.py +1639 -0
  11. singlestoredb/functions/ext/__init__.py +2 -0
  12. singlestoredb/functions/ext/arrow.py +375 -0
  13. singlestoredb/functions/ext/asgi.py +661 -0
  14. singlestoredb/functions/ext/json.py +427 -0
  15. singlestoredb/functions/ext/mmap.py +306 -0
  16. singlestoredb/functions/ext/rowdat_1.py +744 -0
  17. singlestoredb/functions/signature.py +673 -0
  18. singlestoredb/fusion/__init__.py +11 -0
  19. singlestoredb/fusion/graphql.py +213 -0
  20. singlestoredb/fusion/handler.py +621 -0
  21. singlestoredb/fusion/handlers/stage.py +257 -0
  22. singlestoredb/fusion/handlers/utils.py +162 -0
  23. singlestoredb/fusion/handlers/workspace.py +412 -0
  24. singlestoredb/fusion/registry.py +164 -0
  25. singlestoredb/fusion/result.py +399 -0
  26. singlestoredb/http/__init__.py +27 -0
  27. singlestoredb/{http.py → http/connection.py} +555 -154
  28. singlestoredb/management/__init__.py +3 -0
  29. singlestoredb/management/billing_usage.py +148 -0
  30. singlestoredb/management/cluster.py +14 -6
  31. singlestoredb/management/manager.py +100 -38
  32. singlestoredb/management/organization.py +188 -0
  33. singlestoredb/management/region.py +5 -5
  34. singlestoredb/management/utils.py +281 -2
  35. singlestoredb/management/workspace.py +1344 -49
  36. singlestoredb/{clients/pymysqlsv → mysql}/__init__.py +16 -21
  37. singlestoredb/{clients/pymysqlsv → mysql}/_auth.py +39 -8
  38. singlestoredb/{clients/pymysqlsv → mysql}/charset.py +26 -23
  39. singlestoredb/{clients/pymysqlsv/connections.py → mysql/connection.py} +532 -165
  40. singlestoredb/{clients/pymysqlsv → mysql}/constants/CLIENT.py +0 -1
  41. singlestoredb/{clients/pymysqlsv → mysql}/constants/COMMAND.py +0 -1
  42. singlestoredb/{clients/pymysqlsv → mysql}/constants/CR.py +0 -2
  43. singlestoredb/{clients/pymysqlsv → mysql}/constants/ER.py +0 -1
  44. singlestoredb/{clients/pymysqlsv → mysql}/constants/FIELD_TYPE.py +1 -1
  45. singlestoredb/{clients/pymysqlsv → mysql}/constants/FLAG.py +0 -1
  46. singlestoredb/{clients/pymysqlsv → mysql}/constants/SERVER_STATUS.py +0 -1
  47. singlestoredb/mysql/converters.py +271 -0
  48. singlestoredb/{clients/pymysqlsv → mysql}/cursors.py +228 -112
  49. singlestoredb/mysql/err.py +92 -0
  50. singlestoredb/{clients/pymysqlsv → mysql}/optionfile.py +5 -4
  51. singlestoredb/{clients/pymysqlsv → mysql}/protocol.py +49 -20
  52. singlestoredb/mysql/tests/__init__.py +19 -0
  53. singlestoredb/{clients/pymysqlsv → mysql}/tests/base.py +32 -12
  54. singlestoredb/mysql/tests/conftest.py +37 -0
  55. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_DictCursor.py +11 -7
  56. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_SSCursor.py +17 -12
  57. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_basic.py +32 -24
  58. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_connection.py +130 -119
  59. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_converters.py +9 -7
  60. singlestoredb/mysql/tests/test_cursor.py +141 -0
  61. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_err.py +3 -2
  62. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_issues.py +35 -27
  63. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_load_local.py +13 -11
  64. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_nextset.py +7 -3
  65. singlestoredb/{clients/pymysqlsv → mysql}/tests/test_optionfile.py +2 -1
  66. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/__init__.py +1 -1
  67. singlestoredb/mysql/tests/thirdparty/test_MySQLdb/__init__.py +9 -0
  68. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/test_MySQLdb/capabilities.py +19 -17
  69. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/test_MySQLdb/dbapi20.py +31 -22
  70. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +3 -4
  71. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +24 -20
  72. singlestoredb/{clients/pymysqlsv → mysql}/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py +4 -4
  73. singlestoredb/{clients/pymysqlsv → mysql}/times.py +3 -4
  74. singlestoredb/pytest.py +283 -0
  75. singlestoredb/tests/empty.sql +0 -0
  76. singlestoredb/tests/ext_funcs/__init__.py +385 -0
  77. singlestoredb/tests/test.sql +210 -0
  78. singlestoredb/tests/test2.sql +1 -0
  79. singlestoredb/tests/test_basics.py +482 -115
  80. singlestoredb/tests/test_config.py +13 -13
  81. singlestoredb/tests/test_connection.py +241 -305
  82. singlestoredb/tests/test_dbapi.py +27 -0
  83. singlestoredb/tests/test_ext_func.py +1193 -0
  84. singlestoredb/tests/test_ext_func_data.py +1101 -0
  85. singlestoredb/tests/test_fusion.py +465 -0
  86. singlestoredb/tests/test_http.py +32 -26
  87. singlestoredb/tests/test_management.py +588 -8
  88. singlestoredb/tests/test_plugin.py +33 -0
  89. singlestoredb/tests/test_results.py +11 -12
  90. singlestoredb/tests/test_udf.py +687 -0
  91. singlestoredb/tests/utils.py +3 -2
  92. singlestoredb/utils/config.py +58 -0
  93. singlestoredb/utils/debug.py +13 -0
  94. singlestoredb/utils/mogrify.py +151 -0
  95. singlestoredb/utils/results.py +4 -1
  96. singlestoredb-1.0.4.dist-info/METADATA +139 -0
  97. singlestoredb-1.0.4.dist-info/RECORD +112 -0
  98. {singlestoredb-0.4.0.dist-info → singlestoredb-1.0.4.dist-info}/WHEEL +1 -1
  99. singlestoredb-1.0.4.dist-info/entry_points.txt +2 -0
  100. singlestoredb/clients/pymysqlsv/converters.py +0 -365
  101. singlestoredb/clients/pymysqlsv/err.py +0 -144
  102. singlestoredb/clients/pymysqlsv/tests/__init__.py +0 -19
  103. singlestoredb/clients/pymysqlsv/tests/test_cursor.py +0 -133
  104. singlestoredb/clients/pymysqlsv/tests/thirdparty/test_MySQLdb/__init__.py +0 -9
  105. singlestoredb/drivers/__init__.py +0 -45
  106. singlestoredb/drivers/base.py +0 -198
  107. singlestoredb/drivers/cymysql.py +0 -38
  108. singlestoredb/drivers/http.py +0 -47
  109. singlestoredb/drivers/mariadb.py +0 -40
  110. singlestoredb/drivers/mysqlconnector.py +0 -49
  111. singlestoredb/drivers/mysqldb.py +0 -60
  112. singlestoredb/drivers/pymysql.py +0 -37
  113. singlestoredb/drivers/pymysqlsv.py +0 -35
  114. singlestoredb/drivers/pyodbc.py +0 -65
  115. singlestoredb-0.4.0.dist-info/METADATA +0 -111
  116. singlestoredb-0.4.0.dist-info/RECORD +0 -86
  117. /singlestoredb/{clients → fusion/handlers}/__init__.py +0 -0
  118. /singlestoredb/{clients/pymysqlsv → mysql}/constants/__init__.py +0 -0
  119. {singlestoredb-0.4.0.dist-info → singlestoredb-1.0.4.dist-info}/LICENSE +0 -0
  120. {singlestoredb-0.4.0.dist-info → singlestoredb-1.0.4.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,10 @@
1
1
  #!/usr/bin/env python
2
2
  """SingleStoreDB connections and cursors."""
3
+ import abc
3
4
  import inspect
4
- import pprint
5
5
  import re
6
+ import warnings
6
7
  import weakref
7
- from collections import namedtuple
8
8
  from collections.abc import Mapping
9
9
  from collections.abc import MutableMapping
10
10
  from typing import Any
@@ -28,23 +28,17 @@ except ImportError:
28
28
  pass
29
29
 
30
30
  from . import auth
31
- from . import drivers
32
31
  from . import exceptions
33
- from . import types
34
32
  from .config import get_option
35
- from .drivers.base import Driver
36
- from .utils.convert_rows import convert_row
37
- from .utils.convert_rows import convert_rows
38
33
  from .utils.results import Description
39
- from .utils.results import format_results
40
34
  from .utils.results import Result
41
35
 
42
36
 
43
37
  # DB-API settings
44
38
  apilevel = '2.0'
45
39
  threadsafety = 1
46
- paramstyle = map_paramstyle = 'named'
47
- positional_paramstyle = 'numeric'
40
+ paramstyle = map_paramstyle = 'pyformat'
41
+ positional_paramstyle = 'format'
48
42
 
49
43
 
50
44
  # Type codes for character-based columns
@@ -102,7 +96,7 @@ def cast_bool_param(val: Any) -> bool:
102
96
  if val.lower() in ['on', 't', 'true', 'y', 'yes', 'enabled', 'enable']:
103
97
  return True
104
98
  elif val.lower() in ['off', 'f', 'false', 'n', 'no', 'disabled', 'disable']:
105
- return True
99
+ return False
106
100
 
107
101
  raise ValueError('Unrecognized value for bool: {}'.format(val))
108
102
 
@@ -127,10 +121,18 @@ def build_params(**kwargs: Any) -> Dict[str, Any]:
127
121
 
128
122
  # Set known parameters
129
123
  for name in inspect.getfullargspec(connect).args:
130
- if name == 'converters':
131
- out[name] = kwargs.get(name, {})
132
- elif name == 'results_format':
133
- out[name] = kwargs.get(name, get_option('results.format'))
124
+ if name == 'conv':
125
+ out[name] = kwargs.get(name, None)
126
+ elif name == 'results_format': # deprecated
127
+ if kwargs.get(name, None) is not None:
128
+ warnings.warn(
129
+ 'The `results_format=` parameter has been '
130
+ 'renamed to `results_type=`.',
131
+ DeprecationWarning,
132
+ )
133
+ out['results_type'] = kwargs.get(name, get_option('results.type'))
134
+ elif name == 'results_type':
135
+ out[name] = kwargs.get(name, get_option('results.type'))
134
136
  else:
135
137
  out[name] = kwargs.get(name, get_option(name))
136
138
 
@@ -157,6 +159,9 @@ def build_params(**kwargs: Any) -> Dict[str, Any]:
157
159
  if 'user' not in out and not out.get('password', None):
158
160
  out.pop('password', None)
159
161
 
162
+ if out.get('ssl_ca', '') and not out.get('ssl_verify_cert', None):
163
+ out['ssl_verify_cert'] = True
164
+
160
165
  return out
161
166
 
162
167
 
@@ -302,6 +307,13 @@ def quote_identifier(name: str) -> str:
302
307
  return f'`{name}`'
303
308
 
304
309
 
310
+ class Driver(object):
311
+ """Compatibility class for driver name."""
312
+
313
+ def __init__(self, name: str):
314
+ self.name = name
315
+
316
+
305
317
  class VariableAccessor(MutableMapping): # type: ignore
306
318
  """Variable accessor class."""
307
319
 
@@ -327,30 +339,30 @@ class VariableAccessor(MutableMapping): # type: ignore
327
339
 
328
340
  def __getitem__(self, name: str) -> Any:
329
341
  name = _name_check(name)
330
- with self.connection._i_cursor() as cur:
331
- cur.execute('show {} variables like "{}";'.format(self.vtype, name))
332
- out = list(cur)
333
- if not out:
334
- raise KeyError(f"No variable found with the name '{name}'.")
335
- if len(out) > 1:
336
- raise KeyError(f"Multiple variables found with the name '{name}'.")
337
- return self._cast_value(out[0][1])
342
+ out = self.connection._iquery(
343
+ 'show {} variables like %s;'.format(self.vtype),
344
+ [name],
345
+ )
346
+ if not out:
347
+ raise KeyError(f"No variable found with the name '{name}'.")
348
+ if len(out) > 1:
349
+ raise KeyError(f"Multiple variables found with the name '{name}'.")
350
+ return self._cast_value(out[0]['Value'])
338
351
 
339
352
  def __setitem__(self, name: str, value: Any) -> None:
340
353
  name = _name_check(name)
341
- with self.connection._i_cursor() as cur:
342
- if value is True:
343
- value = 'ON'
344
- elif value is False:
345
- value = 'OFF'
346
- if 'local' in self.vtype:
347
- cur.execute(
348
- 'set {} {}=:1;'.format(
349
- self.vtype.replace('local', 'session'), name,
350
- ), [value],
351
- )
352
- else:
353
- cur.execute('set {} {}=:1;'.format(self.vtype, name), [value])
354
+ if value is True:
355
+ value = 'ON'
356
+ elif value is False:
357
+ value = 'OFF'
358
+ if 'local' in self.vtype:
359
+ self.connection._iquery(
360
+ 'set {} {}=%s;'.format(
361
+ self.vtype.replace('local', 'session'), name,
362
+ ), [value],
363
+ )
364
+ else:
365
+ self.connection._iquery('set {} {}=%s;'.format(self.vtype, name), [value])
354
366
 
355
367
  def __delitem__(self, name: str) -> None:
356
368
  raise TypeError('Variables can not be deleted.')
@@ -365,17 +377,15 @@ class VariableAccessor(MutableMapping): # type: ignore
365
377
  del self[name]
366
378
 
367
379
  def __len__(self) -> int:
368
- with self.connection._i_cursor() as cur:
369
- cur.execute('show {} variables;'.format(self.vtype))
370
- return len(list(cur))
380
+ out = self.connection._iquery('show {} variables;'.format(self.vtype))
381
+ return len(list(out))
371
382
 
372
383
  def __iter__(self) -> Iterator[str]:
373
- with self.connection._i_cursor() as cur:
374
- cur.execute('show {} variables;'.format(self.vtype))
375
- return iter(x[0] for x in list(cur))
384
+ out = self.connection._iquery('show {} variables;'.format(self.vtype))
385
+ return iter(list(x.values())[0] for x in out)
376
386
 
377
387
 
378
- class Cursor(object):
388
+ class Cursor(metaclass=abc.ABCMeta):
379
389
  """
380
390
  Database cursor for submitting commands and queries.
381
391
 
@@ -384,21 +394,14 @@ class Cursor(object):
384
394
 
385
395
  """
386
396
 
387
- def __init__(
388
- self, connection: 'Connection', cursor: Any, driver: Driver,
389
- ):
397
+ def __init__(self, connection: 'Connection'):
390
398
  """Call ``Connection.cursor`` instead."""
391
399
  self.errorhandler = connection.errorhandler
392
- self._results_format: str = connection.results_format
393
- self._conn: Optional[Connection] = weakref.proxy(connection)
394
- self._cursor = cursor
395
- self._driver = driver
400
+ self._connection: Optional[Connection] = weakref.proxy(connection)
396
401
 
397
- #: Current row of the cursor.
398
- self.rownumber: Optional[int] = None
402
+ self._rownumber: Optional[int] = None
399
403
 
400
- #: Description of columns in the last executed query.
401
- self.description: Optional[List[Description]] = None
404
+ self._description: Optional[List[Description]] = None
402
405
 
403
406
  #: Default batch size of ``fetchmany`` calls.
404
407
  self.arraysize = get_option('results.arraysize')
@@ -413,86 +416,32 @@ class Cursor(object):
413
416
  #: Number of rows affected by the last query.
414
417
  self.rowcount: int = -1
415
418
 
416
- #: Messages generated during last query.
417
- self.messages: List[str] = []
419
+ self._messages: List[Tuple[int, str]] = []
418
420
 
419
421
  #: Row ID of the last modified row.
420
422
  self.lastrowid: Optional[int] = None
421
423
 
422
424
  @property
423
- def connection(self) -> Optional['Connection']:
424
- """
425
- Return the connection that the cursor belongs to.
426
-
427
- Returns
428
- -------
429
- Connection or None
430
-
431
- """
432
- return self._conn
425
+ def messages(self) -> List[Tuple[int, str]]:
426
+ """Messages created by the server."""
427
+ return self._messages
433
428
 
434
- def _set_description(self) -> None:
435
- """
436
- Return column descriptions for the current result set.
429
+ @abc.abstractproperty
430
+ def description(self) -> Optional[List[Description]]:
431
+ """The field descriptions of the last query."""
432
+ return self._description
437
433
 
438
- Returns
439
- -------
440
- list of Description
434
+ @abc.abstractproperty
435
+ def rownumber(self) -> Optional[int]:
436
+ """The last modified row number."""
437
+ return self._rownumber
441
438
 
442
- """
443
- if self._cursor.description:
444
- self._converters.clear()
445
- out = []
446
- for i, item in enumerate(self._cursor.description):
447
- item = list(item) + [None, None]
448
- item[1] = types.ColumnType.get_code(item[1])
449
- item[6] = not (not (item[6]))
450
- out.append(Description(*item[:9]))
451
-
452
- # Setup override converters, if the SET flag is set use that
453
- # converter but keep the same type code.
454
- if item[7] and item[7] & 2048: # SET_FLAG = 2048
455
- conv = self._driver.converters.get(247, None) # SET CODE = 247
456
- else:
457
- conv = self._driver.converters.get(item[1], None)
458
-
459
- encoding = None
460
-
461
- # Determine proper encoding for character fields as needed
462
- if self._driver.returns_bytes:
463
- if item[1] in CHAR_COLUMNS:
464
- if item[8] and item[8] == 63: # BINARY / BLOB
465
- pass
466
- elif self._conn is not None:
467
- encoding = self._conn.encoding
468
- else:
469
- encoding = 'utf-8'
470
- elif item[1] == 16: # BIT
471
- pass
472
- else:
473
- encoding = 'ascii'
474
-
475
- if conv is not None:
476
- self._converters.append((i, encoding, conv))
477
- elif encoding is not None:
478
- self._converters.append((i, encoding, None))
479
-
480
- self.description = out
481
-
482
- def _update_attrs(self) -> None:
483
- """Update cursor attributes from the last query."""
484
- if self._cursor is None:
485
- return
486
- self.messages[:] = getattr(self._cursor, 'messages', [])
487
- self.lastrowid = getattr(
488
- self._cursor, 'lastrowid',
489
- getattr(self._cursor, '_lastrowid', None),
490
- ) or None
491
- self.rowcount = getattr(
492
- self._cursor, 'rowcount',
493
- getattr(self._cursor, '_rowcount', -1),
494
- )
439
+ @property
440
+ def connection(self) -> Optional['Connection']:
441
+ """the connection that the cursor belongs to."""
442
+ return self._connection
495
443
 
444
+ @abc.abstractmethod
496
445
  def callproc(
497
446
  self, name: str,
498
447
  params: Optional[Sequence[Any]] = None,
@@ -506,18 +455,23 @@ class Cursor(object):
506
455
  multiple result sets, subsequent result sets can be accessed
507
456
  using :meth:`nextset`.
508
457
 
458
+ Examples
459
+ --------
460
+ >>> cur.callproc('myprocedure', ['arg1', 'arg2'])
461
+ >>> print(cur.fetchall())
462
+
509
463
  Parameters
510
464
  ----------
511
465
  name : str
512
466
  Name of the stored procedure
513
- params : iterable, optional
467
+ params : iterable, optional
514
468
  Parameters to the stored procedure
515
469
 
516
470
  """
517
471
  # NOTE: The `callproc` interface varies quite a bit between drivers
518
472
  # so it is implemented using `execute` here.
519
473
 
520
- if self._cursor is None:
474
+ if not self.is_connected():
521
475
  raise exceptions.InterfaceError(2048, 'Cursor is closed.')
522
476
 
523
477
  name = _name_check(name)
@@ -528,186 +482,144 @@ class Cursor(object):
528
482
  keys = ', '.join([f':{i+1}' for i in range(len(params))])
529
483
  self.execute(f'CALL {name}({keys});', params)
530
484
 
485
+ @abc.abstractmethod
486
+ def is_connected(self) -> bool:
487
+ """Is the cursor still connected?"""
488
+ raise NotImplementedError
489
+
490
+ @abc.abstractmethod
531
491
  def close(self) -> None:
532
492
  """Close the cursor."""
533
- if self._cursor is None:
534
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
535
-
536
- try:
537
- self._cursor.close()
538
-
539
- # Ignore weak reference errors. It just means the connection
540
- # was closed underneath us.
541
- except ReferenceError:
542
- pass
543
-
544
- except Exception as exc:
545
- raise self._driver.convert_exception(exc)
546
-
547
- self._cursor = None
548
- self._conn = None
493
+ raise NotImplementedError
549
494
 
495
+ @abc.abstractmethod
550
496
  def execute(
551
- self, oper: str,
552
- params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
553
- ) -> None:
497
+ self, query: str,
498
+ args: Optional[Union[Sequence[Any], Dict[str, Any], Any]] = None,
499
+ ) -> int:
554
500
  """
555
501
  Execute a SQL statement.
556
502
 
503
+ Queries can use the ``format``-style parameters (``%s``) when using a
504
+ list of parameters or ``pyformat``-style parameters (``%(key)s``)
505
+ when using a dictionary of parameters.
506
+
557
507
  Parameters
558
508
  ----------
559
- oper : str
509
+ query : str
560
510
  The SQL statement to execute
561
- params : iterable or dict, optional
511
+ args : Sequence or dict, optional
562
512
  Parameters to substitute into the SQL code
563
513
 
564
- """
565
- if self._cursor is None:
566
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
514
+ Examples
515
+ --------
516
+ >>> cur.execute('select * from mytable')
567
517
 
568
- self.description = None
569
- self.rownumber = None
518
+ >>> cur.execute('select * from mytable where id < %s', [100])
570
519
 
571
- try:
572
- if params:
573
- param_converter = sqlparams.SQLParams(
574
- isinstance(params, Mapping) and
575
- map_paramstyle or positional_paramstyle,
576
- self._driver.dbapi.paramstyle,
577
- escape_char=True,
578
- )
579
- self._cursor.execute(*param_converter.format(oper, params))
580
- else:
581
- self._cursor.execute(oper)
582
- except Exception as exc:
583
- raise self._driver.convert_exception(exc)
520
+ >>> cur.execute('select * from mytable where id < %(max)s', dict(max=100))
521
+
522
+ Returns
523
+ -------
524
+ Number of rows affected
584
525
 
585
- self._set_description()
586
- self._update_attrs()
587
- self.rownumber = 0
526
+ """
527
+ raise NotImplementedError
588
528
 
589
529
  def executemany(
590
- self, oper: str,
591
- param_seq: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,
592
- ) -> None:
530
+ self, query: str,
531
+ args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any], Any]]] = None,
532
+ ) -> int:
593
533
  """
594
534
  Execute SQL code against multiple sets of parameters.
595
535
 
536
+ Queries can use the ``format``-style parameters (``%s``) when using
537
+ lists of parameters or ``pyformat``-style parameters (``%(key)s``)
538
+ when using dictionaries of parameters.
539
+
596
540
  Parameters
597
541
  ----------
598
- oper : str
542
+ query : str
599
543
  The SQL statement to execute
600
- params_seq : iterable of iterables or dicts, optional
544
+ args : iterable of iterables or dicts, optional
601
545
  Sets of parameters to substitute into the SQL code
602
546
 
603
- """
604
- if self._cursor is None:
605
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
606
-
607
- self.description = None
608
- self.rownumber = None
609
-
610
- is_dataframe = False
611
- if isinstance(param_seq, DataFrame):
612
- is_dataframe = True
613
- else:
614
- param_seq = param_seq or [[]]
615
-
616
- try:
617
- # NOTE: Just implement using `execute` to cover driver inconsistencies
618
- if is_dataframe:
619
- for params in param_seq.itertuples(index=False):
620
- self.execute(oper, params)
547
+ Examples
548
+ --------
549
+ >>> cur.executemany('select * from mytable where id < %s',
550
+ ... [[100], [200], [300]])
621
551
 
622
- elif param_seq[0]:
623
- for params in param_seq:
624
- self.execute(oper, params)
625
- else:
626
- self.execute(oper)
552
+ >>> cur.executemany('select * from mytable where id < %(max)s',
553
+ ... [dict(max=100), dict(max=100), dict(max=300)])
627
554
 
628
- except Exception as exc:
629
- raise self._driver.convert_exception(exc)
555
+ Returns
556
+ -------
557
+ Number of rows affected
630
558
 
631
- self._set_description()
632
- self._update_attrs()
633
- self.rownumber = 0
559
+ """
560
+ # NOTE: Just implement using `execute` to cover driver inconsistencies
561
+ if not args:
562
+ self.execute(query)
563
+ else:
564
+ for params in args:
565
+ self.execute(query, params)
566
+ return self.rowcount
634
567
 
568
+ @abc.abstractmethod
635
569
  def fetchone(self) -> Optional[Result]:
636
570
  """
637
571
  Fetch a single row from the result set.
638
572
 
573
+ Examples
574
+ --------
575
+ >>> while True:
576
+ ... row = cur.fetchone()
577
+ ... if row is None:
578
+ ... break
579
+ ... print(row)
580
+
639
581
  Returns
640
582
  -------
641
583
  tuple
642
584
  Values of the returned row if there are rows remaining
643
585
 
644
586
  """
645
- if self._cursor is None:
646
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
647
-
648
- try:
649
- out = self._cursor.fetchone()
650
- except Exception as exc:
651
- raise self._driver.convert_exception(exc)
652
-
653
- if out is not None and self.rownumber is not None:
654
- self.rownumber += 1
655
-
656
- if out is not None:
657
- out = convert_row(tuple(out), self._converters)
658
-
659
- return format_results(
660
- self._results_format,
661
- self.description or [],
662
- out, single=True,
663
- )
587
+ raise NotImplementedError
664
588
 
589
+ @abc.abstractmethod
665
590
  def fetchmany(self, size: Optional[int] = None) -> Result:
666
591
  """
667
592
  Fetch `size` rows from the result.
668
593
 
669
594
  If `size` is not specified, the `arraysize` attribute is used.
670
595
 
596
+ Examples
597
+ --------
598
+ >>> while True:
599
+ ... out = cur.fetchmany(100)
600
+ ... if not len(out):
601
+ ... break
602
+ ... for row in out:
603
+ ... print(row)
604
+
671
605
  Returns
672
606
  -------
673
607
  list of tuples
674
608
  Values of the returned rows if there are rows remaining
675
609
 
676
610
  """
677
- if self._cursor is None:
678
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
679
-
680
- if size is not None:
681
- size = max(int(size), 1)
682
- else:
683
- size = max(int(self.arraysize), 1)
684
-
685
- try:
686
- # This is to get around a bug in mysql.connector. For some reason,
687
- # fetchmany(1) returns the same row over and over again.
688
- if size == 1:
689
- out = [self._cursor.fetchone()]
690
- else:
691
- # Don't use a keyword parameter for size=. Pyodbc fails with that.
692
- out = self._cursor.fetchmany(size)
693
- except Exception as exc:
694
- raise self._driver.convert_exception(exc)
695
-
696
- out = convert_rows(out, self._converters)
697
-
698
- formatted: Result = format_results(
699
- self._results_format, self.description or [], out,
700
- )
701
-
702
- if self.rownumber is not None:
703
- self.rownumber += len(formatted)
704
-
705
- return formatted
611
+ raise NotImplementedError
706
612
 
613
+ @abc.abstractmethod
707
614
  def fetchall(self) -> Result:
708
615
  """
709
616
  Fetch all rows in the result set.
710
617
 
618
+ Examples
619
+ --------
620
+ >>> for row in cur.fetchall():
621
+ ... print(row)
622
+
711
623
  Returns
712
624
  -------
713
625
  list of tuples
@@ -716,29 +628,22 @@ class Cursor(object):
716
628
  If there are no rows to return
717
629
 
718
630
  """
719
- if self._cursor is None:
720
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
721
-
722
- try:
723
- out = self._cursor.fetchall()
724
- except Exception as exc:
725
- raise self._driver.convert_exception(exc)
726
-
727
- out = convert_rows(out, self._converters)
728
-
729
- formatted: Result = format_results(
730
- self._results_format, self.description or [], out,
731
- )
732
-
733
- if self.rownumber is not None:
734
- self.rownumber += len(formatted)
735
-
736
- return formatted
631
+ raise NotImplementedError
737
632
 
633
+ @abc.abstractmethod
738
634
  def nextset(self) -> Optional[bool]:
739
635
  """
740
636
  Skip to the next available result set.
741
637
 
638
+ This is used when calling a procedure that returns multiple
639
+ results sets.
640
+
641
+ Note
642
+ ----
643
+ The ``nextset`` method must be called until it returns an empty
644
+ set (i.e., once more than the number of expected result sets).
645
+ This is to retain compatibility with PyMySQL and MySQLdb.
646
+
742
647
  Returns
743
648
  -------
744
649
  ``True``
@@ -747,46 +652,19 @@ class Cursor(object):
747
652
  If no other result set is available
748
653
 
749
654
  """
750
- if self._cursor is None:
751
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
752
-
753
- self.rownumber = None
754
-
755
- try:
756
- out = self._cursor.nextset()
757
- self._set_description()
758
- if out:
759
- self.rownumber = 0
760
- return True
761
- return False
762
-
763
- except Exception as exc:
764
- exc = self._driver.convert_exception(exc)
765
- if getattr(exc, 'errno', -1) == 2053:
766
- return False
767
- self.rownumber = 0
768
- return True
655
+ raise NotImplementedError
769
656
 
657
+ @abc.abstractmethod
770
658
  def setinputsizes(self, sizes: Sequence[int]) -> None:
771
659
  """Predefine memory areas for parameters."""
772
- if self._cursor is None:
773
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
774
-
775
- try:
776
- self._cursor.setinputsizes(sizes)
777
- except Exception as exc:
778
- raise self._driver.convert_exception(exc)
660
+ raise NotImplementedError
779
661
 
662
+ @abc.abstractmethod
780
663
  def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
781
664
  """Set a column buffer size for fetches of large columns."""
782
- if self._cursor is None:
783
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
784
-
785
- try:
786
- self._cursor.setoutputsize(size, column)
787
- except Exception as exc:
788
- raise self._driver.convert_exception(exc)
665
+ raise NotImplementedError
789
666
 
667
+ @abc.abstractmethod
790
668
  def scroll(self, value: int, mode: str = 'relative') -> None:
791
669
  """
792
670
  Scroll the cursor to the position in the result set.
@@ -799,21 +677,7 @@ class Cursor(object):
799
677
  Where to move the cursor from: 'relative' or 'absolute'
800
678
 
801
679
  """
802
- if self._cursor is None:
803
- raise exceptions.InterfaceError(2048, 'Cursor is closed.')
804
-
805
- value = int(value)
806
- try:
807
- self._cursor.scroll(value, mode=mode)
808
- if self.rownumber is not None:
809
- if mode == 'relative':
810
- self.rownumber += value
811
- elif mode == 'absolute':
812
- self.rownumber = value
813
- else:
814
- raise ValueError(f'Unrecognized scroll mode {mode}')
815
- except Exception as exc:
816
- raise self._driver.convert_exception(exc)
680
+ raise NotImplementedError
817
681
 
818
682
  def next(self) -> Optional[Result]:
819
683
  """
@@ -829,18 +693,12 @@ class Cursor(object):
829
693
  tuple of values
830
694
 
831
695
  """
832
- if self._cursor is None:
696
+ if not self.is_connected():
833
697
  raise exceptions.InterfaceError(2048, 'Cursor is closed.')
834
-
835
- try:
836
- out = self.fetchone()
837
- if out is None:
838
- raise StopIteration
839
- return out
840
- except StopIteration:
841
- raise
842
- except Exception as exc:
843
- raise self._driver.convert_exception(exc)
698
+ out = self.fetchone()
699
+ if out is None:
700
+ raise StopIteration
701
+ return out
844
702
 
845
703
  __next__ = next
846
704
 
@@ -859,24 +717,22 @@ class Cursor(object):
859
717
  """Exit a context."""
860
718
  self.close()
861
719
 
862
- def is_connected(self) -> bool:
863
- """
864
- Check if the cursor is connected.
865
-
866
- Returns
867
- -------
868
- bool
869
-
870
- """
871
- if self._conn is None:
872
- return False
873
- return self._conn.is_connected()
874
-
875
720
 
876
721
  class ShowResult(Sequence[Any]):
877
722
  """
878
723
  Simple result object.
879
724
 
725
+ This object is primarily used for displaying results to a
726
+ terminal or web browser, but it can also be treated like a
727
+ simple data frame where columns are accessible using either
728
+ dictionary key-like syntax or attribute syntax.
729
+
730
+ Examples
731
+ --------
732
+ >>> conn.show.status().Value[10]
733
+
734
+ >>> conn.show.status()[10]['Value']
735
+
880
736
  Parameters
881
737
  ----------
882
738
  *args : Any
@@ -884,10 +740,14 @@ class ShowResult(Sequence[Any]):
884
740
  **kwargs : Any
885
741
  Keyword parameters to send to underlying list constructor
886
742
 
743
+ See Also
744
+ --------
745
+ :attr:`Connection.show`
746
+
887
747
  """
888
748
 
889
749
  def __init__(self, *args: Any, **kwargs: Any) -> None:
890
- self._data: List[Any] = []
750
+ self._data: List[Dict[str, Any]] = []
891
751
  item: Any = None
892
752
  for item in list(*args, **kwargs):
893
753
  self._data.append(item)
@@ -896,41 +756,69 @@ class ShowResult(Sequence[Any]):
896
756
  return self._data[item]
897
757
 
898
758
  def __getattr__(self, name: str) -> List[Any]:
759
+ if name.startswith('_ipython'):
760
+ raise AttributeError(name)
899
761
  out = []
900
762
  for item in self._data:
901
- out.append(getattr(item, name))
763
+ out.append(item[name])
902
764
  return out
903
765
 
904
766
  def __len__(self) -> int:
905
767
  return len(self._data)
906
768
 
907
- def _repr_pretty_(self, p: Any, cycle: bool) -> None:
908
- if cycle:
909
- p.text('[...]')
910
- else:
911
- p.text('[\n')
912
- for item in self._data:
913
- p.text(' ')
914
- p.text(pprint.pformat(item))
915
- p.text('\n')
916
- p.text(']')
769
+ def __repr__(self) -> str:
770
+ if not self._data:
771
+ return ''
772
+ return '\n{}\n'.format(self._format_table(self._data))
773
+
774
+ @property
775
+ def columns(self) -> List[str]:
776
+ """The columns in the result."""
777
+ if not self._data:
778
+ return []
779
+ return list(self._data[0].keys())
780
+
781
+ def _format_table(self, rows: Sequence[Dict[str, Any]]) -> str:
782
+ if not self._data:
783
+ return ''
784
+
785
+ keys = rows[0].keys()
786
+ lens = [len(x) for x in keys]
787
+
788
+ for row in self._data:
789
+ align = ['<'] * len(keys)
790
+ for i, k in enumerate(keys):
791
+ lens[i] = max(lens[i], len(str(row[k])))
792
+ align[i] = '<' if isinstance(row[k], (bytes, bytearray, str)) else '>'
793
+
794
+ fmt = '| %s |' % '|'.join([' {:%s%d} ' % (x, y) for x, y in zip(align, lens)])
795
+
796
+ out = []
797
+ out.append(fmt.format(*keys))
798
+ out.append('-' * len(out[0]))
799
+ for row in rows:
800
+ out.append(fmt.format(*[str(x) for x in row.values()]))
801
+ return '\n'.join(out)
802
+
803
+ def __str__(self) -> str:
804
+ return self.__repr__()
917
805
 
918
806
  def _repr_html_(self) -> str:
919
807
  if not self._data:
920
808
  return ''
921
809
  cell_style = 'style="text-align: left; vertical-align: top"'
922
810
  out = []
923
- out.append('<table>')
811
+ out.append('<table border="1" class="dataframe">')
924
812
  out.append('<thead>')
925
813
  out.append('<tr>')
926
- for name in self._data[0]._fields:
814
+ for name in self._data[0].keys():
927
815
  out.append(f'<th {cell_style}>{name}</th>')
928
816
  out.append('</tr>')
929
817
  out.append('</thead>')
930
818
  out.append('<tbody>')
931
819
  for row in self._data:
932
820
  out.append('<tr>')
933
- for item in row:
821
+ for item in row.values():
934
822
  out.append(f'<td {cell_style}>{item}</td>')
935
823
  out.append('</tr>')
936
824
  out.append('</tbody>')
@@ -939,7 +827,14 @@ class ShowResult(Sequence[Any]):
939
827
 
940
828
 
941
829
  class ShowAccessor(object):
942
- """Accessor for ``SHOW`` commands."""
830
+ """
831
+ Accessor for ``SHOW`` commands.
832
+
833
+ See Also
834
+ --------
835
+ :attr:`Connection.show`
836
+
837
+ """
943
838
 
944
839
  def __init__(self, conn: 'Connection'):
945
840
  self._conn = conn
@@ -948,152 +843,172 @@ class ShowAccessor(object):
948
843
  """Show the column information for the given table."""
949
844
  table = quote_identifier(table)
950
845
  if full:
951
- return self._query(f'full columns in {table}')
952
- return self._query(f'columns in {table}')
846
+ return self._iquery(f'full columns in {table}')
847
+ return self._iquery(f'columns in {table}')
953
848
 
954
849
  def tables(self, extended: bool = False) -> ShowResult:
955
850
  """Show tables in the current database."""
956
851
  if extended:
957
- return self._query('tables extended')
958
- return self._query('tables')
852
+ return self._iquery('tables extended')
853
+ return self._iquery('tables')
959
854
 
960
855
  def warnings(self) -> ShowResult:
961
856
  """Show warnings."""
962
- return self._query('warnings')
857
+ return self._iquery('warnings')
963
858
 
964
859
  def errors(self) -> ShowResult:
965
860
  """Show errors."""
966
- return self._query('errors')
861
+ return self._iquery('errors')
967
862
 
968
863
  def databases(self, extended: bool = False) -> ShowResult:
969
864
  """Show all databases in the server."""
970
865
  if extended:
971
- return self._query('databases extended')
972
- return self._query('databases')
866
+ return self._iquery('databases extended')
867
+ return self._iquery('databases')
973
868
 
974
869
  def database_status(self) -> ShowResult:
975
870
  """Show status of the current database."""
976
- return self._query('database status')
871
+ return self._iquery('database status')
977
872
 
978
873
  def global_status(self) -> ShowResult:
979
874
  """Show global status of the current server."""
980
- return self._query('global status')
875
+ return self._iquery('global status')
981
876
 
982
877
  def indexes(self, table: str) -> ShowResult:
983
878
  """Show all indexes in the given table."""
984
879
  table = quote_identifier(table)
985
- return self._query('indexes in {table}')
880
+ return self._iquery(f'indexes in {table}')
986
881
 
987
882
  def functions(self) -> ShowResult:
988
883
  """Show all functions in the current database."""
989
- return self._query('functions')
884
+ return self._iquery('functions')
990
885
 
991
886
  def partitions(self, extended: bool = False) -> ShowResult:
992
887
  """Show partitions in the current database."""
993
888
  if extended:
994
- return self._query('partitions extended')
995
- return self._query('partitions')
889
+ return self._iquery('partitions extended')
890
+ return self._iquery('partitions')
996
891
 
997
892
  def pipelines(self) -> ShowResult:
998
893
  """Show all pipelines in the current database."""
999
- return self._query('pipelines')
894
+ return self._iquery('pipelines')
1000
895
 
1001
- def plan(self, plan_id: str, json: bool = False) -> ShowResult:
896
+ def plan(self, plan_id: int, json: bool = False) -> ShowResult:
1002
897
  """Show the plan for the given plan ID."""
1003
- plan_id = quote_identifier(plan_id)
898
+ plan_id = int(plan_id)
1004
899
  if json:
1005
- return self._query(f'plan json {plan_id}')
1006
- return self._query(f'plan {plan_id}')
900
+ return self._iquery(f'plan json {plan_id}')
901
+ return self._iquery(f'plan {plan_id}')
1007
902
 
1008
903
  def plancache(self) -> ShowResult:
1009
904
  """Show all query statements compiled and executed."""
1010
- return self._query('plancache')
905
+ return self._iquery('plancache')
1011
906
 
1012
907
  def processlist(self) -> ShowResult:
1013
908
  """Show details about currently running threads."""
1014
- return self._query('processlist')
909
+ return self._iquery('processlist')
1015
910
 
1016
911
  def reproduction(self, outfile: Optional[str] = None) -> ShowResult:
1017
912
  """Show troubleshooting data for query optimizer and code generation."""
1018
913
  if outfile:
1019
914
  outfile = outfile.replace('"', r'\"')
1020
- return self._query('reproduction into outfile "{outfile}"')
1021
- return self._query('reproduction')
915
+ return self._iquery('reproduction into outfile "{outfile}"')
916
+ return self._iquery('reproduction')
1022
917
 
1023
918
  def schemas(self) -> ShowResult:
1024
919
  """Show schemas in the server."""
1025
- return self._query('schemas')
920
+ return self._iquery('schemas')
1026
921
 
1027
922
  def session_status(self) -> ShowResult:
1028
923
  """Show server status information for a session."""
1029
- return self._query('session status')
924
+ return self._iquery('session status')
1030
925
 
1031
926
  def status(self, extended: bool = False) -> ShowResult:
1032
927
  """Show server status information."""
1033
928
  if extended:
1034
- return self._query('status extended')
1035
- return self._query('status')
929
+ return self._iquery('status extended')
930
+ return self._iquery('status')
1036
931
 
1037
932
  def table_status(self) -> ShowResult:
1038
933
  """Show table status information for the current database."""
1039
- return self._query('table status')
934
+ return self._iquery('table status')
1040
935
 
1041
936
  def procedures(self) -> ShowResult:
1042
937
  """Show all procedures in the current database."""
1043
- return self._query('procedures')
938
+ return self._iquery('procedures')
1044
939
 
1045
940
  def aggregates(self) -> ShowResult:
1046
941
  """Show all aggregate functions in the current database."""
1047
- return self._query('aggregates')
942
+ return self._iquery('aggregates')
1048
943
 
1049
944
  def create_aggregate(self, name: str) -> ShowResult:
1050
945
  """Show the function creation code for the given aggregate function."""
1051
946
  name = quote_identifier(name)
1052
- return self._query(f'create aggregate {name}')
947
+ return self._iquery(f'create aggregate {name}')
1053
948
 
1054
949
  def create_function(self, name: str) -> ShowResult:
1055
950
  """Show the function creation code for the given function."""
1056
951
  name = quote_identifier(name)
1057
- return self._query(f'create function {name}')
952
+ return self._iquery(f'create function {name}')
1058
953
 
1059
954
  def create_pipeline(self, name: str, extended: bool = False) -> ShowResult:
1060
955
  """Show the pipeline creation code for the given pipeline."""
1061
956
  name = quote_identifier(name)
1062
957
  if extended:
1063
- return self._query(f'create pipeline {name} extended')
1064
- return self._query(f'create pipeline {name}')
958
+ return self._iquery(f'create pipeline {name} extended')
959
+ return self._iquery(f'create pipeline {name}')
1065
960
 
1066
961
  def create_table(self, name: str) -> ShowResult:
1067
962
  """Show the table creation code for the given table."""
1068
963
  name = quote_identifier(name)
1069
- return self._query(f'create table {name}')
964
+ return self._iquery(f'create table {name}')
1070
965
 
1071
966
  def create_view(self, name: str) -> ShowResult:
1072
967
  """Show the view creation code for the given view."""
1073
968
  name = quote_identifier(name)
1074
- return self._query(f'create view {name}')
1075
-
1076
- def _query(self, qtype: str) -> ShowResult:
969
+ return self._iquery(f'create view {name}')
970
+
971
+ # def grants(
972
+ # self,
973
+ # user: Optional[str] = None,
974
+ # hostname: Optional[str] = None,
975
+ # role: Optional[str] = None
976
+ # ) -> ShowResult:
977
+ # """Show the privileges for the given user or role."""
978
+ # if user:
979
+ # if not re.match(r'^[\w+-_]+$', user):
980
+ # raise ValueError(f'User name is not valid: {user}')
981
+ # if hostname and not re.match(r'^[\w+-_\.]+$', hostname):
982
+ # raise ValueError(f'Hostname is not valid: {hostname}')
983
+ # if hostname:
984
+ # return self._iquery(f"grants for '{user}@{hostname}'")
985
+ # return self._iquery(f"grants for '{user}'")
986
+ # if role:
987
+ # if not re.match(r'^[\w+-_]+$', role):
988
+ # raise ValueError(f'Role is not valid: {role}')
989
+ # return self._iquery(f"grants for role '{role}'")
990
+ # return self._iquery('grants')
991
+
992
+ def _iquery(self, qtype: str) -> ShowResult:
1077
993
  """Query the given object type."""
1078
- with self._conn._i_cursor() as cur:
1079
- cur.execute(f'show {qtype}')
1080
- out = []
1081
- if cur.description:
1082
- names = [under2camel(str(x[0]).replace(' ', '')) for x in cur.description]
1083
- names[0] = 'Name'
1084
- item_type = namedtuple('Row', names) # type: ignore
1085
- for item in cur.fetchall():
1086
- out.append(item_type(*item))
1087
- return ShowResult(out)
1088
-
1089
-
1090
- class Connection(object):
994
+ out = self._conn._iquery(f'show {qtype}')
995
+ for i, row in enumerate(out):
996
+ new_row = {}
997
+ for j, (k, v) in enumerate(row.items()):
998
+ if j == 0:
999
+ k = 'Name'
1000
+ new_row[under2camel(k)] = v
1001
+ out[i] = new_row
1002
+ return ShowResult(out)
1003
+
1004
+
1005
+ class Connection(metaclass=abc.ABCMeta):
1091
1006
  """
1092
1007
  SingleStoreDB connection.
1093
1008
 
1094
1009
  Instances of this object are typically created through the
1095
1010
  :func:`singlestoredb.connect` function rather than creating them directly.
1096
- See the :func:`connect` function for parameter definitions.
1011
+ See the :func:`singlestoredb.connect` function for parameter definitions.
1097
1012
 
1098
1013
  See Also
1099
1014
  --------
@@ -1112,20 +1027,25 @@ class Connection(object):
1112
1027
  ProgrammingError = exceptions.ProgrammingError
1113
1028
  NotSupportedError = exceptions.NotSupportedError
1114
1029
 
1030
+ #: Read-only DB-API parameter style
1031
+ paramstyle = 'pyformat'
1032
+
1033
+ # Must be set by subclass
1034
+ driver = ''
1035
+
1036
+ # Populated when first needed
1037
+ _map_param_converter: Optional[sqlparams.SQLParams] = None
1038
+ _positional_param_converter: Optional[sqlparams.SQLParams] = None
1039
+
1115
1040
  def __init__(self, **kwargs: Any):
1116
1041
  """Call :func:`singlestoredb.connect` instead."""
1117
- self._conn: Optional[Any] = None
1042
+ self.connection_params: Dict[str, Any] = kwargs
1118
1043
  self.errorhandler = None
1119
- self.connection_params: Dict[str, Any] = build_params(**kwargs)
1120
-
1121
- #: Query results format ('tuple', 'namedtuple', 'dict', 'dataframe')
1122
- self.results_format = self.connection_params.pop(
1123
- 'results_format',
1124
- get_option('results.format'),
1125
- )
1044
+ self._results_type: str = kwargs.get('results_type', None) or 'tuples'
1126
1045
 
1127
1046
  #: Session encoding
1128
- self.encoding = self.connection_params.get('charset', 'utf-8').replace('mb4', '')
1047
+ self.encoding = self.connection_params.get('charset', None) or 'utf-8'
1048
+ self.encoding = self.encoding.replace('mb4', '')
1129
1049
 
1130
1050
  # Handle various authentication types
1131
1051
  credential_type = self.connection_params.get('credential_type', None)
@@ -1135,14 +1055,6 @@ class Connection(object):
1135
1055
  self.connection_params['password'] = str(info)
1136
1056
  self.connection_params['credential_type'] = auth.JWT
1137
1057
 
1138
- drv_name = re.sub(r'^\w+\+', r'', self.connection_params['driver']).lower()
1139
- self._driver = drivers.get_driver(drv_name, self.connection_params)
1140
-
1141
- try:
1142
- self._conn = self._driver.connect()
1143
- except Exception as exc:
1144
- raise self._driver.convert_exception(exc)
1145
-
1146
1058
  #: Attribute-like access to global server variables
1147
1059
  self.globals = VariableAccessor(self, 'global')
1148
1060
 
@@ -1161,41 +1073,93 @@ class Connection(object):
1161
1073
  #: Attribute-like access to all cluster server variables
1162
1074
  self.cluster_vars = VariableAccessor(self, 'cluster')
1163
1075
 
1076
+ # For backwards compatibility with SQLAlchemy package
1077
+ self._driver = Driver(self.driver)
1078
+
1079
+ # Output decoders
1080
+ self.decoders: Dict[int, Callable[[Any], Any]] = {}
1081
+
1082
+ @classmethod
1083
+ def _convert_params(
1084
+ cls, oper: str,
1085
+ params: Optional[Union[Sequence[Any], Dict[str, Any], Any]],
1086
+ ) -> Tuple[Any, ...]:
1087
+ """Convert query to correct parameter format."""
1088
+ if params:
1089
+
1090
+ if cls._map_param_converter is None:
1091
+ cls._map_param_converter = sqlparams.SQLParams(
1092
+ map_paramstyle, cls.paramstyle, escape_char=True,
1093
+ )
1094
+
1095
+ if cls._positional_param_converter is None:
1096
+ cls._positional_param_converter = sqlparams.SQLParams(
1097
+ positional_paramstyle, cls.paramstyle, escape_char=True,
1098
+ )
1099
+
1100
+ is_sequence = isinstance(params, Sequence) \
1101
+ and not isinstance(params, str) \
1102
+ and not isinstance(params, bytes)
1103
+ is_mapping = isinstance(params, Mapping)
1104
+
1105
+ param_converter = cls._map_param_converter \
1106
+ if is_mapping else cls._positional_param_converter
1107
+
1108
+ if not is_sequence and not is_mapping:
1109
+ params = [params]
1110
+
1111
+ return param_converter.format(oper, params)
1112
+
1113
+ return (oper, None)
1114
+
1164
1115
  def autocommit(self, value: bool = True) -> None:
1165
1116
  """Set autocommit mode."""
1166
- if self._conn is None:
1167
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1168
1117
  self.locals.autocommit = bool(value)
1169
1118
 
1119
+ @abc.abstractmethod
1120
+ def connect(self) -> 'Connection':
1121
+ """Connect to the server."""
1122
+ raise NotImplementedError
1123
+
1124
+ def _iquery(
1125
+ self, oper: str,
1126
+ params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
1127
+ fix_names: bool = True,
1128
+ ) -> List[Dict[str, Any]]:
1129
+ """Return the results of a query as a list of dicts (for internal use)."""
1130
+ with self.cursor() as cur:
1131
+ cur.execute(oper, params)
1132
+ if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I):
1133
+ return []
1134
+ out = list(cur.fetchall())
1135
+ if not out:
1136
+ return []
1137
+ if isinstance(out, DataFrame):
1138
+ out = out.to_dict(orient='records')
1139
+ elif isinstance(out[0], (tuple, list)):
1140
+ if cur.description:
1141
+ names = [x[0] for x in cur.description]
1142
+ if fix_names:
1143
+ names = [under2camel(str(x).replace(' ', '')) for x in names]
1144
+ out = [{k: v for k, v in zip(names, row)} for row in out]
1145
+ return out
1146
+
1147
+ @abc.abstractmethod
1170
1148
  def close(self) -> None:
1171
1149
  """Close the database connection."""
1172
- if self._conn is None:
1173
- return None
1174
- try:
1175
- self._conn.close()
1176
- except Exception as exc:
1177
- raise self._driver.convert_exception(exc)
1178
- finally:
1179
- self._conn = None
1150
+ raise NotImplementedError
1180
1151
 
1152
+ @abc.abstractmethod
1181
1153
  def commit(self) -> None:
1182
1154
  """Commit the pending transaction."""
1183
- if self._conn is None:
1184
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1185
- try:
1186
- self._conn.commit()
1187
- except Exception as exc:
1188
- raise self._driver.convert_exception(exc)
1155
+ raise NotImplementedError
1189
1156
 
1157
+ @abc.abstractmethod
1190
1158
  def rollback(self) -> None:
1191
1159
  """Rollback the pending transaction."""
1192
- if self._conn is None:
1193
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1194
- try:
1195
- self._conn.rollback()
1196
- except Exception as exc:
1197
- raise self._driver.convert_exception(exc)
1160
+ raise NotImplementedError
1198
1161
 
1162
+ @abc.abstractmethod
1199
1163
  def cursor(self) -> Cursor:
1200
1164
  """
1201
1165
  Create a new cursor object.
@@ -1209,46 +1173,12 @@ class Connection(object):
1209
1173
  :class:`Cursor`
1210
1174
 
1211
1175
  """
1212
- if self._conn is None:
1213
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1214
- try:
1215
- cur = self._conn.cursor()
1216
- except Exception as exc:
1217
- raise self._driver.convert_exception(exc)
1218
- return Cursor(self, cur, self._driver)
1219
-
1220
- def _i_cursor(self) -> Cursor:
1221
- """
1222
- Create a cursor for internal use.
1223
-
1224
- Internal cursors always return tuples in results.
1225
- These are used to ensure that methods that query the database
1226
- have a consistent results structure regardless of the
1227
- `results.format` option.
1228
-
1229
- Returns
1230
- -------
1231
- Cursor
1232
-
1233
- """
1234
- out = self.cursor()
1235
- out._results_format = 'tuple'
1236
- return out
1237
-
1238
- @property
1239
- def messages(self) -> Sequence[Tuple[int, str]]:
1240
- """
1241
- Return messages generated by the connection.
1242
-
1243
- Returns
1244
- -------
1245
- list of tuples
1246
- Each tuple contains an int code and a message
1176
+ raise NotImplementedError
1247
1177
 
1248
- """
1249
- if self._conn is None:
1250
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1251
- return self._conn.messages
1178
+ @abc.abstractproperty
1179
+ def messages(self) -> List[Tuple[int, str]]:
1180
+ """Messages generated during the connection."""
1181
+ raise NotImplementedError
1252
1182
 
1253
1183
  def __enter__(self) -> 'Connection':
1254
1184
  """Enter a context."""
@@ -1261,6 +1191,7 @@ class Connection(object):
1261
1191
  """Exit a context."""
1262
1192
  self.close()
1263
1193
 
1194
+ @abc.abstractmethod
1264
1195
  def is_connected(self) -> bool:
1265
1196
  """
1266
1197
  Determine if the database is still connected.
@@ -1270,12 +1201,7 @@ class Connection(object):
1270
1201
  bool
1271
1202
 
1272
1203
  """
1273
- if self._conn is None:
1274
- return False
1275
- try:
1276
- return self._driver.is_connected(self._conn)
1277
- except Exception as exc:
1278
- raise self._driver.convert_exception(exc)
1204
+ raise NotImplementedError
1279
1205
 
1280
1206
  def enable_data_api(self, port: Optional[int] = None) -> int:
1281
1207
  """
@@ -1301,14 +1227,11 @@ class Connection(object):
1301
1227
  port number of the HTTP server
1302
1228
 
1303
1229
  """
1304
- if self._conn is None:
1305
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1306
- with self._i_cursor() as cur:
1307
- if port is not None:
1308
- self.globals.http_proxy_port = int(port)
1309
- self.globals.http_api = True
1310
- cur.execute('restart proxy')
1311
- return int(self.globals.http_proxy_port)
1230
+ if port is not None:
1231
+ self.globals.http_proxy_port = int(port)
1232
+ self.globals.http_api = True
1233
+ self._iquery('restart proxy')
1234
+ return int(self.globals.http_proxy_port)
1312
1235
 
1313
1236
  enable_http_api = enable_data_api
1314
1237
 
@@ -1321,11 +1244,8 @@ class Connection(object):
1321
1244
  :meth:`enable_data_api`
1322
1245
 
1323
1246
  """
1324
- if self._conn is None:
1325
- raise exceptions.InterfaceError(2048, 'Connection is closed.')
1326
- with self._i_cursor() as cur:
1327
- self.globals.http_api = False
1328
- cur.execute('restart proxy')
1247
+ self.globals.http_api = False
1248
+ self._iquery('restart proxy')
1329
1249
 
1330
1250
  disable_http_api = disable_data_api
1331
1251
 
@@ -1346,14 +1266,25 @@ def connect(
1346
1266
  password: Optional[str] = None, port: Optional[int] = None,
1347
1267
  database: Optional[str] = None, driver: Optional[str] = None,
1348
1268
  pure_python: Optional[bool] = None, local_infile: Optional[bool] = None,
1349
- odbc_driver: Optional[str] = None, charset: Optional[str] = None,
1269
+ charset: Optional[str] = None,
1350
1270
  ssl_key: Optional[str] = None, ssl_cert: Optional[str] = None,
1351
1271
  ssl_ca: Optional[str] = None, ssl_disabled: Optional[bool] = None,
1352
- ssl_cipher: Optional[str] = None,
1353
- converters: Optional[Dict[int, Callable[..., Any]]] = None,
1354
- results_format: Optional[str] = None,
1272
+ ssl_cipher: Optional[str] = None, ssl_verify_cert: Optional[bool] = None,
1273
+ ssl_verify_identity: Optional[bool] = None,
1274
+ conv: Optional[Dict[int, Callable[..., Any]]] = None,
1355
1275
  credential_type: Optional[str] = None,
1356
1276
  autocommit: Optional[bool] = None,
1277
+ results_type: Optional[str] = None,
1278
+ buffered: Optional[bool] = None,
1279
+ results_format: Optional[str] = None,
1280
+ program_name: Optional[str] = None,
1281
+ conn_attrs: Optional[Dict[str, str]] = None,
1282
+ multi_statements: Optional[bool] = None,
1283
+ connect_timeout: Optional[int] = None,
1284
+ nan_as_null: Optional[bool] = None,
1285
+ inf_as_null: Optional[bool] = None,
1286
+ encoding_errors: Optional[str] = None,
1287
+ track_env: Optional[bool] = None,
1357
1288
  ) -> Connection:
1358
1289
  """
1359
1290
  Return a SingleStoreDB connection.
@@ -1363,7 +1294,7 @@ def connect(
1363
1294
  host : str, optional
1364
1295
  Hostname, IP address, or URL that describes the connection.
1365
1296
  The scheme or protocol defines which database connector to use.
1366
- By default, the ``pymysql`` scheme is used. To connect to the
1297
+ By default, the ``mysql`` scheme is used. To connect to the
1367
1298
  HTTP API, the scheme can be set to ``http`` or ``https``. The username,
1368
1299
  password, host, and port are specified as in a standard URL. The path
1369
1300
  indicates the database name. The overall form of the URL is:
@@ -1383,8 +1314,6 @@ def connect(
1383
1314
  Use the connector in pure Python mode
1384
1315
  local_infile : bool, optional
1385
1316
  Allow local file uploads
1386
- odbc_driver : str, optional
1387
- Name of the ODBC driver to use for ODBC connections
1388
1317
  charset : str, optional
1389
1318
  Character set for string values
1390
1319
  ssl_key : str, optional
@@ -1397,14 +1326,41 @@ def connect(
1397
1326
  Sets the SSL cipher list
1398
1327
  ssl_disabled : bool, optional
1399
1328
  Disable SSL usage
1400
- converters : dict[int, Callable], optional
1329
+ ssl_verify_cert : bool, optional
1330
+ Verify the server's certificate. This is automatically enabled if
1331
+ ``ssl_ca`` is also specified.
1332
+ ssl_verify_identity : bool, optional
1333
+ Verify the server's identity
1334
+ conv : dict[int, Callable], optional
1401
1335
  Dictionary of data conversion functions
1402
- results_format : str, optional
1403
- Format of query results: tuple, namedtuple, dict, or dataframe
1404
1336
  credential_type : str, optional
1405
1337
  Type of authentication to use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO
1406
1338
  autocommit : bool, optional
1407
1339
  Enable autocommits
1340
+ results_type : str, optional
1341
+ The form of the query results: tuples, namedtuples, dicts
1342
+ results_format : str, optional
1343
+ Deprecated. This option has been renamed to results_type.
1344
+ program_name : str, optional
1345
+ Name of the program
1346
+ conn_attrs : dict, optional
1347
+ Additional connection attributes for telemetry. Example:
1348
+ {'program_version': "1.0.2", "_connector_name": "dbt connector"}
1349
+ multi_statements: bool, optional
1350
+ Should multiple statements be allowed within a single query?
1351
+ connect_timeout : int, optional
1352
+ The timeout for connecting to the database in seconds.
1353
+ (default: 10, min: 1, max: 31536000)
1354
+ nan_as_null : bool, optional
1355
+ Should NaN values be treated as NULLs when used in parameter
1356
+ substitutions including uploaded data?
1357
+ inf_as_null : bool, optional
1358
+ Should Inf values be treated as NULLs when used in parameter
1359
+ substitutions including uploaded data?
1360
+ encoding_errors : str, optional
1361
+ The error handler name for value decoding errors
1362
+ track_env : bool, optional
1363
+ Should the connection track the SINGLESTOREDB_URL environment variable?
1408
1364
 
1409
1365
  Examples
1410
1366
  --------
@@ -1460,4 +1416,15 @@ def connect(
1460
1416
  :class:`Connection`
1461
1417
 
1462
1418
  """
1463
- return Connection(**dict(locals()))
1419
+ params = build_params(**dict(locals()))
1420
+ driver = params.get('driver', 'mysql')
1421
+
1422
+ if not driver or driver == 'mysql':
1423
+ from .mysql.connection import Connection # type: ignore
1424
+ return Connection(**params)
1425
+
1426
+ if driver in ['http', 'https']:
1427
+ from .http.connection import Connection
1428
+ return Connection(**params)
1429
+
1430
+ raise ValueError(f'Unrecognized protocol: {driver}')