sqlspec 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/adapters/adbc/driver.py +192 -28
- sqlspec/adapters/asyncmy/driver.py +72 -15
- sqlspec/adapters/asyncpg/config.py +23 -3
- sqlspec/adapters/asyncpg/driver.py +30 -14
- sqlspec/adapters/bigquery/driver.py +79 -9
- sqlspec/adapters/duckdb/driver.py +39 -56
- sqlspec/adapters/oracledb/driver.py +99 -52
- sqlspec/adapters/psqlpy/driver.py +89 -31
- sqlspec/adapters/psycopg/driver.py +11 -23
- sqlspec/adapters/sqlite/driver.py +77 -8
- sqlspec/base.py +29 -25
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +4 -5
- sqlspec/builder/_column.py +3 -3
- sqlspec/builder/_ddl.py +5 -1
- sqlspec/builder/_delete.py +5 -6
- sqlspec/builder/_insert.py +6 -7
- sqlspec/builder/_merge.py +5 -5
- sqlspec/builder/_parsing_utils.py +3 -3
- sqlspec/builder/_select.py +6 -5
- sqlspec/builder/_update.py +4 -5
- sqlspec/builder/mixins/_cte_and_set_ops.py +5 -1
- sqlspec/builder/mixins/_delete_operations.py +5 -1
- sqlspec/builder/mixins/_insert_operations.py +5 -1
- sqlspec/builder/mixins/_join_operations.py +5 -0
- sqlspec/builder/mixins/_merge_operations.py +5 -1
- sqlspec/builder/mixins/_order_limit_operations.py +5 -1
- sqlspec/builder/mixins/_pivot_operations.py +4 -1
- sqlspec/builder/mixins/_select_operations.py +5 -1
- sqlspec/builder/mixins/_update_operations.py +5 -1
- sqlspec/builder/mixins/_where_clause.py +5 -1
- sqlspec/cli.py +281 -33
- sqlspec/config.py +160 -10
- sqlspec/core/compiler.py +11 -3
- sqlspec/core/filters.py +30 -9
- sqlspec/core/parameters.py +67 -67
- sqlspec/core/result.py +62 -31
- sqlspec/core/splitter.py +160 -34
- sqlspec/core/statement.py +95 -14
- sqlspec/driver/_common.py +12 -3
- sqlspec/driver/mixins/_result_tools.py +21 -4
- sqlspec/driver/mixins/_sql_translator.py +45 -7
- sqlspec/extensions/aiosql/adapter.py +1 -1
- sqlspec/extensions/litestar/_utils.py +1 -1
- sqlspec/extensions/litestar/handlers.py +21 -0
- sqlspec/extensions/litestar/plugin.py +15 -8
- sqlspec/loader.py +12 -12
- sqlspec/migrations/loaders.py +5 -2
- sqlspec/migrations/utils.py +2 -2
- sqlspec/storage/backends/obstore.py +1 -3
- sqlspec/storage/registry.py +1 -1
- sqlspec/utils/__init__.py +7 -0
- sqlspec/utils/deprecation.py +6 -0
- sqlspec/utils/fixtures.py +239 -30
- sqlspec/utils/module_loader.py +5 -1
- sqlspec/utils/serializers.py +6 -0
- sqlspec/utils/singleton.py +6 -0
- sqlspec/utils/sync_tools.py +10 -1
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/METADATA +1 -1
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/RECORD +64 -64
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
"""ADBC driver implementation for Arrow Database Connectivity.
|
|
2
2
|
|
|
3
|
-
Provides
|
|
4
|
-
|
|
5
|
-
for different database backends, and transaction management.
|
|
3
|
+
Provides database connectivity through ADBC with support for multiple
|
|
4
|
+
database dialects, parameter style conversion, and transaction management.
|
|
6
5
|
"""
|
|
7
6
|
|
|
8
7
|
import contextlib
|
|
@@ -10,6 +9,7 @@ import datetime
|
|
|
10
9
|
import decimal
|
|
11
10
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
12
11
|
|
|
12
|
+
from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
|
|
13
13
|
from sqlglot import exp
|
|
14
14
|
|
|
15
15
|
from sqlspec.core.cache import get_cache_config
|
|
@@ -53,22 +53,88 @@ DIALECT_PARAMETER_STYLES = {
|
|
|
53
53
|
}
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
def
|
|
57
|
-
"""
|
|
56
|
+
def _count_placeholders(expression: Any) -> int:
|
|
57
|
+
"""Count the number of unique parameter placeholders in a SQLGlot expression.
|
|
58
58
|
|
|
59
|
-
For PostgreSQL,
|
|
60
|
-
|
|
59
|
+
For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
|
|
60
|
+
For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
|
|
61
|
+
For named (:name) style: counts unique parameter names
|
|
61
62
|
|
|
62
63
|
Args:
|
|
63
64
|
expression: SQLGlot AST expression
|
|
64
|
-
parameters: Parameter values that may contain None
|
|
65
65
|
|
|
66
66
|
Returns:
|
|
67
|
-
|
|
67
|
+
Number of unique parameter placeholders expected
|
|
68
68
|
"""
|
|
69
|
-
|
|
70
|
-
|
|
69
|
+
numeric_params = set() # For $1, $2 style
|
|
70
|
+
qmark_count = 0 # For ? style
|
|
71
|
+
named_params = set() # For :name style
|
|
72
|
+
|
|
73
|
+
def count_node(node: Any) -> Any:
|
|
74
|
+
nonlocal qmark_count
|
|
75
|
+
if isinstance(node, exp.Parameter):
|
|
76
|
+
# PostgreSQL style: $1, $2, etc.
|
|
77
|
+
param_str = str(node)
|
|
78
|
+
if param_str.startswith("$") and param_str[1:].isdigit():
|
|
79
|
+
numeric_params.add(int(param_str[1:]))
|
|
80
|
+
elif ":" in param_str:
|
|
81
|
+
# Named parameter: :name
|
|
82
|
+
named_params.add(param_str)
|
|
83
|
+
else:
|
|
84
|
+
# Other parameter formats
|
|
85
|
+
named_params.add(param_str)
|
|
86
|
+
elif isinstance(node, exp.Placeholder):
|
|
87
|
+
# QMARK style: ?
|
|
88
|
+
qmark_count += 1
|
|
89
|
+
return node
|
|
90
|
+
|
|
91
|
+
expression.transform(count_node)
|
|
92
|
+
|
|
93
|
+
# Return the appropriate count based on parameter style detected
|
|
94
|
+
if numeric_params:
|
|
95
|
+
# PostgreSQL style: return highest numbered parameter
|
|
96
|
+
return max(numeric_params)
|
|
97
|
+
if named_params:
|
|
98
|
+
# Named parameters: return count of unique names
|
|
99
|
+
return len(named_params)
|
|
100
|
+
# QMARK style: return total count
|
|
101
|
+
return qmark_count
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _is_execute_many_parameters(parameters: Any) -> bool:
|
|
105
|
+
"""Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
|
|
106
|
+
return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple))
|
|
107
|
+
|
|
71
108
|
|
|
109
|
+
def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None:
|
|
110
|
+
"""Validate parameter count against placeholder count in SQL."""
|
|
111
|
+
placeholder_count = _count_placeholders(expression)
|
|
112
|
+
is_execute_many = _is_execute_many_parameters(parameters)
|
|
113
|
+
|
|
114
|
+
if is_execute_many:
|
|
115
|
+
# For execute_many, validate each inner parameter set
|
|
116
|
+
for i, param_set in enumerate(parameters):
|
|
117
|
+
param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0
|
|
118
|
+
if param_count != placeholder_count:
|
|
119
|
+
msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
|
|
120
|
+
raise SQLSpecError(msg)
|
|
121
|
+
else:
|
|
122
|
+
# For single execution, validate the parameter set directly
|
|
123
|
+
param_count = (
|
|
124
|
+
len(parameters)
|
|
125
|
+
if isinstance(parameters, (list, tuple))
|
|
126
|
+
else len(parameters)
|
|
127
|
+
if isinstance(parameters, dict)
|
|
128
|
+
else 0
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
if param_count != placeholder_count:
|
|
132
|
+
msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
|
|
133
|
+
raise SQLSpecError(msg)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _find_null_positions(parameters: Any) -> set[int]:
|
|
137
|
+
"""Find positions of None values in parameters for single execution."""
|
|
72
138
|
null_positions = set()
|
|
73
139
|
if isinstance(parameters, (list, tuple)):
|
|
74
140
|
for i, param in enumerate(parameters):
|
|
@@ -83,7 +149,37 @@ def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
|
|
|
83
149
|
null_positions.add(param_num - 1)
|
|
84
150
|
except ValueError:
|
|
85
151
|
pass
|
|
152
|
+
return null_positions
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "postgres") -> tuple[Any, Any]:
|
|
156
|
+
"""Transform AST to handle NULL parameters.
|
|
86
157
|
|
|
158
|
+
Replaces NULL parameter placeholders with NULL literals in the AST
|
|
159
|
+
to prevent Arrow from inferring 'na' types which cause binding errors.
|
|
160
|
+
Validates parameter count before transformation.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
expression: SQLGlot AST expression parsed with proper dialect
|
|
164
|
+
parameters: Parameter values that may contain None
|
|
165
|
+
dialect: SQLGlot dialect used for parsing (default: "postgres")
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
Tuple of (modified_expression, cleaned_parameters)
|
|
169
|
+
"""
|
|
170
|
+
if not parameters:
|
|
171
|
+
return expression, parameters
|
|
172
|
+
|
|
173
|
+
# Validate parameter count before transformation
|
|
174
|
+
_validate_parameter_counts(expression, parameters, dialect)
|
|
175
|
+
|
|
176
|
+
# For execute_many operations, skip AST transformation as different parameter
|
|
177
|
+
# sets may have None values in different positions, making transformation complex
|
|
178
|
+
if _is_execute_many_parameters(parameters):
|
|
179
|
+
return expression, parameters
|
|
180
|
+
|
|
181
|
+
# Find positions of None values for single execution
|
|
182
|
+
null_positions = _find_null_positions(parameters)
|
|
87
183
|
if not null_positions:
|
|
88
184
|
return expression, parameters
|
|
89
185
|
|
|
@@ -183,14 +279,28 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
|
|
|
183
279
|
|
|
184
280
|
|
|
185
281
|
def _convert_array_for_postgres_adbc(value: Any) -> Any:
|
|
186
|
-
"""Convert array values for PostgreSQL compatibility.
|
|
282
|
+
"""Convert array values for PostgreSQL compatibility.
|
|
283
|
+
|
|
284
|
+
Args:
|
|
285
|
+
value: Value to convert
|
|
286
|
+
|
|
287
|
+
Returns:
|
|
288
|
+
Converted value (tuples become lists)
|
|
289
|
+
"""
|
|
187
290
|
if isinstance(value, tuple):
|
|
188
291
|
return list(value)
|
|
189
292
|
return value
|
|
190
293
|
|
|
191
294
|
|
|
192
295
|
def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
|
|
193
|
-
"""Get type coercion map for Arrow type handling.
|
|
296
|
+
"""Get type coercion map for Arrow type handling.
|
|
297
|
+
|
|
298
|
+
Args:
|
|
299
|
+
dialect: Database dialect name
|
|
300
|
+
|
|
301
|
+
Returns:
|
|
302
|
+
Mapping of Python types to conversion functions
|
|
303
|
+
"""
|
|
194
304
|
type_map = {
|
|
195
305
|
datetime.datetime: lambda x: x,
|
|
196
306
|
datetime.date: lambda x: x,
|
|
@@ -245,8 +355,6 @@ class AdbcExceptionHandler:
|
|
|
245
355
|
return
|
|
246
356
|
|
|
247
357
|
try:
|
|
248
|
-
from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
|
|
249
|
-
|
|
250
358
|
if issubclass(exc_type, IntegrityError):
|
|
251
359
|
e = exc_val
|
|
252
360
|
msg = f"Integrity constraint violation: {e}"
|
|
@@ -282,9 +390,8 @@ class AdbcExceptionHandler:
|
|
|
282
390
|
class AdbcDriver(SyncDriverAdapterBase):
|
|
283
391
|
"""ADBC driver for Arrow Database Connectivity.
|
|
284
392
|
|
|
285
|
-
Provides database connectivity through ADBC with
|
|
286
|
-
|
|
287
|
-
conversion for different backends, and transaction management.
|
|
393
|
+
Provides database connectivity through ADBC with support for multiple
|
|
394
|
+
database dialects, parameter style conversion, and transaction management.
|
|
288
395
|
"""
|
|
289
396
|
|
|
290
397
|
__slots__ = ("_detected_dialect", "dialect")
|
|
@@ -309,7 +416,11 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
309
416
|
|
|
310
417
|
@staticmethod
|
|
311
418
|
def _ensure_pyarrow_installed() -> None:
|
|
312
|
-
"""Ensure PyArrow is installed.
|
|
419
|
+
"""Ensure PyArrow is installed.
|
|
420
|
+
|
|
421
|
+
Raises:
|
|
422
|
+
MissingDependencyError: If PyArrow is not installed
|
|
423
|
+
"""
|
|
313
424
|
from sqlspec.typing import PYARROW_INSTALLED
|
|
314
425
|
|
|
315
426
|
if not PYARROW_INSTALLED:
|
|
@@ -317,7 +428,14 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
317
428
|
|
|
318
429
|
@staticmethod
|
|
319
430
|
def _get_dialect(connection: "AdbcConnection") -> str:
|
|
320
|
-
"""Detect database dialect from connection information.
|
|
431
|
+
"""Detect database dialect from connection information.
|
|
432
|
+
|
|
433
|
+
Args:
|
|
434
|
+
connection: ADBC connection
|
|
435
|
+
|
|
436
|
+
Returns:
|
|
437
|
+
Detected dialect name (defaults to 'postgres')
|
|
438
|
+
"""
|
|
321
439
|
try:
|
|
322
440
|
driver_info = connection.adbc_get_info()
|
|
323
441
|
vendor_name = driver_info.get("vendor_name", "").lower()
|
|
@@ -334,31 +452,53 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
334
452
|
return "postgres"
|
|
335
453
|
|
|
336
454
|
def _handle_postgres_rollback(self, cursor: "Cursor") -> None:
|
|
337
|
-
"""Execute rollback for PostgreSQL after transaction failure.
|
|
455
|
+
"""Execute rollback for PostgreSQL after transaction failure.
|
|
456
|
+
|
|
457
|
+
Args:
|
|
458
|
+
cursor: Database cursor
|
|
459
|
+
"""
|
|
338
460
|
if self.dialect == "postgres":
|
|
339
461
|
with contextlib.suppress(Exception):
|
|
340
462
|
cursor.execute("ROLLBACK")
|
|
341
463
|
logger.debug("PostgreSQL rollback executed after transaction failure")
|
|
342
464
|
|
|
343
465
|
def _handle_postgres_empty_parameters(self, parameters: Any) -> Any:
|
|
344
|
-
"""Process empty parameters for PostgreSQL compatibility.
|
|
466
|
+
"""Process empty parameters for PostgreSQL compatibility.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
parameters: Parameter values
|
|
470
|
+
|
|
471
|
+
Returns:
|
|
472
|
+
Processed parameters
|
|
473
|
+
"""
|
|
345
474
|
if self.dialect == "postgres" and isinstance(parameters, dict) and not parameters:
|
|
346
475
|
return None
|
|
347
476
|
return parameters
|
|
348
477
|
|
|
349
478
|
def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
|
|
350
|
-
"""Create context manager for cursor.
|
|
479
|
+
"""Create context manager for cursor.
|
|
480
|
+
|
|
481
|
+
Args:
|
|
482
|
+
connection: Database connection
|
|
483
|
+
|
|
484
|
+
Returns:
|
|
485
|
+
Cursor context manager
|
|
486
|
+
"""
|
|
351
487
|
return AdbcCursor(connection)
|
|
352
488
|
|
|
353
489
|
def handle_database_exceptions(self) -> "AbstractContextManager[None]":
|
|
354
|
-
"""Handle database-specific exceptions and wrap them appropriately.
|
|
490
|
+
"""Handle database-specific exceptions and wrap them appropriately.
|
|
491
|
+
|
|
492
|
+
Returns:
|
|
493
|
+
Exception handler context manager
|
|
494
|
+
"""
|
|
355
495
|
return AdbcExceptionHandler()
|
|
356
496
|
|
|
357
497
|
def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[SQLResult]":
|
|
358
498
|
"""Handle special operations.
|
|
359
499
|
|
|
360
500
|
Args:
|
|
361
|
-
cursor:
|
|
501
|
+
cursor: Database cursor
|
|
362
502
|
statement: SQL statement to analyze
|
|
363
503
|
|
|
364
504
|
Returns:
|
|
@@ -368,7 +508,15 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
368
508
|
return None
|
|
369
509
|
|
|
370
510
|
def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
|
|
371
|
-
"""Execute SQL with multiple parameter sets.
|
|
511
|
+
"""Execute SQL with multiple parameter sets.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
cursor: Database cursor
|
|
515
|
+
statement: SQL statement to execute
|
|
516
|
+
|
|
517
|
+
Returns:
|
|
518
|
+
Execution result with row counts
|
|
519
|
+
"""
|
|
372
520
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
373
521
|
|
|
374
522
|
try:
|
|
@@ -398,7 +546,15 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
398
546
|
return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
|
|
399
547
|
|
|
400
548
|
def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
|
|
401
|
-
"""Execute single SQL statement.
|
|
549
|
+
"""Execute single SQL statement.
|
|
550
|
+
|
|
551
|
+
Args:
|
|
552
|
+
cursor: Database cursor
|
|
553
|
+
statement: SQL statement to execute
|
|
554
|
+
|
|
555
|
+
Returns:
|
|
556
|
+
Execution result with data or row count
|
|
557
|
+
"""
|
|
402
558
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
403
559
|
|
|
404
560
|
try:
|
|
@@ -430,7 +586,15 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
430
586
|
return self.create_execution_result(cursor, rowcount_override=row_count)
|
|
431
587
|
|
|
432
588
|
def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
|
|
433
|
-
"""Execute SQL script.
|
|
589
|
+
"""Execute SQL script containing multiple statements.
|
|
590
|
+
|
|
591
|
+
Args:
|
|
592
|
+
cursor: Database cursor
|
|
593
|
+
statement: SQL script to execute
|
|
594
|
+
|
|
595
|
+
Returns:
|
|
596
|
+
Execution result with statement counts
|
|
597
|
+
"""
|
|
434
598
|
if statement.is_script:
|
|
435
599
|
sql = statement._raw_sql
|
|
436
600
|
prepared_parameters: list[Any] = []
|
|
@@ -51,7 +51,10 @@ asyncmy_statement_config = StatementConfig(
|
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
class AsyncmyCursor:
|
|
54
|
-
"""
|
|
54
|
+
"""Context manager for AsyncMy cursor operations.
|
|
55
|
+
|
|
56
|
+
Provides automatic cursor acquisition and cleanup for database operations.
|
|
57
|
+
"""
|
|
55
58
|
|
|
56
59
|
__slots__ = ("connection", "cursor")
|
|
57
60
|
|
|
@@ -70,7 +73,11 @@ class AsyncmyCursor:
|
|
|
70
73
|
|
|
71
74
|
|
|
72
75
|
class AsyncmyExceptionHandler:
|
|
73
|
-
"""
|
|
76
|
+
"""Context manager for AsyncMy database exception handling.
|
|
77
|
+
|
|
78
|
+
Converts AsyncMy-specific exceptions to SQLSpec exceptions with appropriate
|
|
79
|
+
error categorization and context preservation.
|
|
80
|
+
"""
|
|
74
81
|
|
|
75
82
|
__slots__ = ()
|
|
76
83
|
|
|
@@ -116,10 +123,11 @@ class AsyncmyExceptionHandler:
|
|
|
116
123
|
|
|
117
124
|
|
|
118
125
|
class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
119
|
-
"""
|
|
126
|
+
"""MySQL/MariaDB database driver using AsyncMy client library.
|
|
120
127
|
|
|
121
|
-
|
|
122
|
-
type coercion, error handling,
|
|
128
|
+
Implements asynchronous database operations for MySQL and MariaDB servers
|
|
129
|
+
with support for parameter style conversion, type coercion, error handling,
|
|
130
|
+
and transaction management.
|
|
123
131
|
"""
|
|
124
132
|
|
|
125
133
|
__slots__ = ()
|
|
@@ -143,22 +151,33 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
143
151
|
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
|
|
144
152
|
|
|
145
153
|
def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor":
|
|
146
|
-
"""Create context manager for
|
|
154
|
+
"""Create cursor context manager for the connection.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
connection: AsyncMy database connection
|
|
158
|
+
|
|
159
|
+
Returns:
|
|
160
|
+
AsyncmyCursor: Context manager for cursor operations
|
|
161
|
+
"""
|
|
147
162
|
return AsyncmyCursor(connection)
|
|
148
163
|
|
|
149
164
|
def handle_database_exceptions(self) -> "AbstractAsyncContextManager[None]":
|
|
150
|
-
"""
|
|
165
|
+
"""Provide exception handling context manager.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
AbstractAsyncContextManager[None]: Context manager for AsyncMy exception handling
|
|
169
|
+
"""
|
|
151
170
|
return AsyncmyExceptionHandler()
|
|
152
171
|
|
|
153
172
|
async def _try_special_handling(self, cursor: Any, statement: "SQL") -> "Optional[SQLResult]":
|
|
154
|
-
"""
|
|
173
|
+
"""Handle AsyncMy-specific operations before standard execution.
|
|
155
174
|
|
|
156
175
|
Args:
|
|
157
176
|
cursor: AsyncMy cursor object
|
|
158
177
|
statement: SQL statement to analyze
|
|
159
178
|
|
|
160
179
|
Returns:
|
|
161
|
-
None
|
|
180
|
+
Optional[SQLResult]: None, always proceeds with standard execution
|
|
162
181
|
"""
|
|
163
182
|
_ = (cursor, statement)
|
|
164
183
|
return None
|
|
@@ -166,7 +185,15 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
166
185
|
async def _execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
|
|
167
186
|
"""Execute SQL script with statement splitting and parameter handling.
|
|
168
187
|
|
|
188
|
+
Splits multi-statement scripts and executes each statement sequentially.
|
|
169
189
|
Parameters are embedded as static values for script execution compatibility.
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
cursor: AsyncMy cursor object
|
|
193
|
+
statement: SQL script to execute
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
ExecutionResult: Script execution results with statement count
|
|
170
197
|
"""
|
|
171
198
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
172
199
|
statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
|
|
@@ -183,9 +210,20 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
183
210
|
)
|
|
184
211
|
|
|
185
212
|
async def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
|
|
186
|
-
"""Execute SQL with multiple parameter sets
|
|
213
|
+
"""Execute SQL statement with multiple parameter sets.
|
|
187
214
|
|
|
188
|
-
|
|
215
|
+
Uses AsyncMy's executemany for batch operations with MySQL type conversion
|
|
216
|
+
and parameter processing.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
cursor: AsyncMy cursor object
|
|
220
|
+
statement: SQL statement with multiple parameter sets
|
|
221
|
+
|
|
222
|
+
Returns:
|
|
223
|
+
ExecutionResult: Batch execution results
|
|
224
|
+
|
|
225
|
+
Raises:
|
|
226
|
+
ValueError: If no parameters provided for executemany operation
|
|
189
227
|
"""
|
|
190
228
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
191
229
|
|
|
@@ -200,9 +238,17 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
200
238
|
return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
|
|
201
239
|
|
|
202
240
|
async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
|
|
203
|
-
"""Execute single SQL statement
|
|
241
|
+
"""Execute single SQL statement.
|
|
242
|
+
|
|
243
|
+
Handles parameter processing, result fetching, and data transformation
|
|
244
|
+
for MySQL/MariaDB operations.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
cursor: AsyncMy cursor object
|
|
248
|
+
statement: SQL statement to execute
|
|
204
249
|
|
|
205
|
-
|
|
250
|
+
Returns:
|
|
251
|
+
ExecutionResult: Statement execution results with data or row counts
|
|
206
252
|
"""
|
|
207
253
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
208
254
|
await cursor.execute(sql, prepared_parameters or None)
|
|
@@ -228,6 +274,9 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
228
274
|
"""Begin a database transaction.
|
|
229
275
|
|
|
230
276
|
Explicitly starts a MySQL transaction to ensure proper transaction boundaries.
|
|
277
|
+
|
|
278
|
+
Raises:
|
|
279
|
+
SQLSpecError: If transaction initialization fails
|
|
231
280
|
"""
|
|
232
281
|
try:
|
|
233
282
|
async with AsyncmyCursor(self.connection) as cursor:
|
|
@@ -237,7 +286,11 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
237
286
|
raise SQLSpecError(msg) from e
|
|
238
287
|
|
|
239
288
|
async def rollback(self) -> None:
|
|
240
|
-
"""Rollback the current transaction.
|
|
289
|
+
"""Rollback the current transaction.
|
|
290
|
+
|
|
291
|
+
Raises:
|
|
292
|
+
SQLSpecError: If transaction rollback fails
|
|
293
|
+
"""
|
|
241
294
|
try:
|
|
242
295
|
await self.connection.rollback()
|
|
243
296
|
except asyncmy.errors.MySQLError as e:
|
|
@@ -245,7 +298,11 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
|
|
|
245
298
|
raise SQLSpecError(msg) from e
|
|
246
299
|
|
|
247
300
|
async def commit(self) -> None:
|
|
248
|
-
"""Commit the current transaction.
|
|
301
|
+
"""Commit the current transaction.
|
|
302
|
+
|
|
303
|
+
Raises:
|
|
304
|
+
SQLSpecError: If transaction commit fails
|
|
305
|
+
"""
|
|
249
306
|
try:
|
|
250
307
|
await self.connection.commit()
|
|
251
308
|
except asyncmy.errors.MySQLError as e:
|
|
@@ -124,12 +124,32 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
|
|
|
124
124
|
config = self._get_pool_config_dict()
|
|
125
125
|
|
|
126
126
|
if "init" not in config:
|
|
127
|
-
config["init"] = self.
|
|
127
|
+
config["init"] = self._init_connection
|
|
128
128
|
|
|
129
129
|
return await asyncpg_create_pool(**config)
|
|
130
130
|
|
|
131
|
-
async def
|
|
132
|
-
"""Initialize
|
|
131
|
+
async def _init_connection(self, connection: "AsyncpgConnection") -> None:
|
|
132
|
+
"""Initialize connection with JSON codecs and pgvector support."""
|
|
133
|
+
|
|
134
|
+
try:
|
|
135
|
+
# Set up JSON type codec
|
|
136
|
+
await connection.set_type_codec(
|
|
137
|
+
"json",
|
|
138
|
+
encoder=self.driver_features.get("json_serializer", to_json),
|
|
139
|
+
decoder=self.driver_features.get("json_deserializer", from_json),
|
|
140
|
+
schema="pg_catalog",
|
|
141
|
+
)
|
|
142
|
+
# Set up JSONB type codec
|
|
143
|
+
await connection.set_type_codec(
|
|
144
|
+
"jsonb",
|
|
145
|
+
encoder=self.driver_features.get("json_serializer", to_json),
|
|
146
|
+
decoder=self.driver_features.get("json_deserializer", from_json),
|
|
147
|
+
schema="pg_catalog",
|
|
148
|
+
)
|
|
149
|
+
except Exception as e:
|
|
150
|
+
logger.debug("Failed to configure JSON type codecs for asyncpg: %s", e)
|
|
151
|
+
|
|
152
|
+
# Initialize pgvector support
|
|
133
153
|
try:
|
|
134
154
|
import pgvector.asyncpg
|
|
135
155
|
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations.
|
|
2
2
|
|
|
3
|
-
Provides async PostgreSQL connectivity with
|
|
4
|
-
|
|
5
|
-
- Resource management
|
|
6
|
-
- PostgreSQL COPY operation support
|
|
7
|
-
- Transaction management
|
|
3
|
+
Provides async PostgreSQL connectivity with parameter processing, resource management,
|
|
4
|
+
PostgreSQL COPY operation support, and transaction management.
|
|
8
5
|
"""
|
|
9
6
|
|
|
10
7
|
import re
|
|
@@ -102,13 +99,9 @@ class AsyncpgExceptionHandler:
|
|
|
102
99
|
class AsyncpgDriver(AsyncDriverAdapterBase):
|
|
103
100
|
"""AsyncPG PostgreSQL driver for async database operations.
|
|
104
101
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
- PostgreSQL exception handling
|
|
109
|
-
- Transaction management
|
|
110
|
-
- SQL statement compilation and caching
|
|
111
|
-
- Parameter processing and type coercion
|
|
102
|
+
Supports COPY operations, numeric parameter style handling, PostgreSQL
|
|
103
|
+
exception handling, transaction management, SQL statement compilation
|
|
104
|
+
and caching, and parameter processing with type coercion.
|
|
112
105
|
"""
|
|
113
106
|
|
|
114
107
|
__slots__ = ()
|
|
@@ -193,7 +186,15 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
|
|
|
193
186
|
await cursor.execute(sql_text)
|
|
194
187
|
|
|
195
188
|
async def _execute_script(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
|
|
196
|
-
"""Execute SQL script with statement splitting and parameter handling.
|
|
189
|
+
"""Execute SQL script with statement splitting and parameter handling.
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
cursor: AsyncPG connection object
|
|
193
|
+
statement: SQL statement containing multiple statements
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
ExecutionResult with script execution details
|
|
197
|
+
"""
|
|
197
198
|
sql, _ = self._get_compiled_sql(statement, self.statement_config)
|
|
198
199
|
statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
|
|
199
200
|
|
|
@@ -210,7 +211,15 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
|
|
|
210
211
|
)
|
|
211
212
|
|
|
212
213
|
async def _execute_many(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
|
|
213
|
-
"""Execute SQL with multiple parameter sets using AsyncPG's executemany.
|
|
214
|
+
"""Execute SQL with multiple parameter sets using AsyncPG's executemany.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
cursor: AsyncPG connection object
|
|
218
|
+
statement: SQL statement with multiple parameter sets
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
ExecutionResult with batch execution details
|
|
222
|
+
"""
|
|
214
223
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
215
224
|
|
|
216
225
|
if prepared_parameters:
|
|
@@ -226,6 +235,13 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
|
|
|
226
235
|
"""Execute single SQL statement.
|
|
227
236
|
|
|
228
237
|
Handles both SELECT queries and non-SELECT operations.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
cursor: AsyncPG connection object
|
|
241
|
+
statement: SQL statement to execute
|
|
242
|
+
|
|
243
|
+
Returns:
|
|
244
|
+
ExecutionResult with statement execution details
|
|
229
245
|
"""
|
|
230
246
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
231
247
|
|