sqlspec 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's advisory for this release for more details.

Files changed (64)
  1. sqlspec/adapters/adbc/driver.py +192 -28
  2. sqlspec/adapters/asyncmy/driver.py +72 -15
  3. sqlspec/adapters/asyncpg/config.py +23 -3
  4. sqlspec/adapters/asyncpg/driver.py +30 -14
  5. sqlspec/adapters/bigquery/driver.py +79 -9
  6. sqlspec/adapters/duckdb/driver.py +39 -56
  7. sqlspec/adapters/oracledb/driver.py +99 -52
  8. sqlspec/adapters/psqlpy/driver.py +89 -31
  9. sqlspec/adapters/psycopg/driver.py +11 -23
  10. sqlspec/adapters/sqlite/driver.py +77 -8
  11. sqlspec/base.py +29 -25
  12. sqlspec/builder/__init__.py +1 -1
  13. sqlspec/builder/_base.py +4 -5
  14. sqlspec/builder/_column.py +3 -3
  15. sqlspec/builder/_ddl.py +5 -1
  16. sqlspec/builder/_delete.py +5 -6
  17. sqlspec/builder/_insert.py +6 -7
  18. sqlspec/builder/_merge.py +5 -5
  19. sqlspec/builder/_parsing_utils.py +3 -3
  20. sqlspec/builder/_select.py +6 -5
  21. sqlspec/builder/_update.py +4 -5
  22. sqlspec/builder/mixins/_cte_and_set_ops.py +5 -1
  23. sqlspec/builder/mixins/_delete_operations.py +5 -1
  24. sqlspec/builder/mixins/_insert_operations.py +5 -1
  25. sqlspec/builder/mixins/_join_operations.py +5 -0
  26. sqlspec/builder/mixins/_merge_operations.py +5 -1
  27. sqlspec/builder/mixins/_order_limit_operations.py +5 -1
  28. sqlspec/builder/mixins/_pivot_operations.py +4 -1
  29. sqlspec/builder/mixins/_select_operations.py +5 -1
  30. sqlspec/builder/mixins/_update_operations.py +5 -1
  31. sqlspec/builder/mixins/_where_clause.py +5 -1
  32. sqlspec/cli.py +281 -33
  33. sqlspec/config.py +160 -10
  34. sqlspec/core/compiler.py +11 -3
  35. sqlspec/core/filters.py +30 -9
  36. sqlspec/core/parameters.py +67 -67
  37. sqlspec/core/result.py +62 -31
  38. sqlspec/core/splitter.py +160 -34
  39. sqlspec/core/statement.py +95 -14
  40. sqlspec/driver/_common.py +12 -3
  41. sqlspec/driver/mixins/_result_tools.py +21 -4
  42. sqlspec/driver/mixins/_sql_translator.py +45 -7
  43. sqlspec/extensions/aiosql/adapter.py +1 -1
  44. sqlspec/extensions/litestar/_utils.py +1 -1
  45. sqlspec/extensions/litestar/handlers.py +21 -0
  46. sqlspec/extensions/litestar/plugin.py +15 -8
  47. sqlspec/loader.py +12 -12
  48. sqlspec/migrations/loaders.py +5 -2
  49. sqlspec/migrations/utils.py +2 -2
  50. sqlspec/storage/backends/obstore.py +1 -3
  51. sqlspec/storage/registry.py +1 -1
  52. sqlspec/utils/__init__.py +7 -0
  53. sqlspec/utils/deprecation.py +6 -0
  54. sqlspec/utils/fixtures.py +239 -30
  55. sqlspec/utils/module_loader.py +5 -1
  56. sqlspec/utils/serializers.py +6 -0
  57. sqlspec/utils/singleton.py +6 -0
  58. sqlspec/utils/sync_tools.py +10 -1
  59. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/METADATA +1 -1
  60. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/RECORD +64 -64
  61. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/WHEEL +0 -0
  62. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/entry_points.txt +0 -0
  63. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/LICENSE +0 -0
  64. {sqlspec-0.18.0.dist-info → sqlspec-0.20.0.dist-info}/licenses/NOTICE +0 -0
@@ -1,8 +1,7 @@
1
1
  """ADBC driver implementation for Arrow Database Connectivity.
2
2
 
3
- Provides ADBC driver integration with multi-dialect database connections,
4
- Arrow-native data handling with type coercion, parameter style conversion
5
- for different database backends, and transaction management.
3
+ Provides database connectivity through ADBC with support for multiple
4
+ database dialects, parameter style conversion, and transaction management.
6
5
  """
7
6
 
8
7
  import contextlib
@@ -10,6 +9,7 @@ import datetime
10
9
  import decimal
11
10
  from typing import TYPE_CHECKING, Any, Optional, cast
12
11
 
12
+ from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
13
13
  from sqlglot import exp
14
14
 
15
15
  from sqlspec.core.cache import get_cache_config
@@ -53,22 +53,88 @@ DIALECT_PARAMETER_STYLES = {
53
53
  }
54
54
 
55
55
 
56
- def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
57
- """AST transformer for NULL parameter handling.
56
+ def _count_placeholders(expression: Any) -> int:
57
+ """Count the number of unique parameter placeholders in a SQLGlot expression.
58
58
 
59
- For PostgreSQL, replaces NULL parameter placeholders with NULL literals
60
- in the AST to prevent Arrow from inferring 'na' types which cause binding errors.
59
+ For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
60
+ For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
61
+ For named (:name) style: counts unique parameter names
61
62
 
62
63
  Args:
63
64
  expression: SQLGlot AST expression
64
- parameters: Parameter values that may contain None
65
65
 
66
66
  Returns:
67
- Tuple of (modified_expression, cleaned_parameters)
67
+ Number of unique parameter placeholders expected
68
68
  """
69
- if not parameters:
70
- return expression, parameters
69
+ numeric_params = set() # For $1, $2 style
70
+ qmark_count = 0 # For ? style
71
+ named_params = set() # For :name style
72
+
73
+ def count_node(node: Any) -> Any:
74
+ nonlocal qmark_count
75
+ if isinstance(node, exp.Parameter):
76
+ # PostgreSQL style: $1, $2, etc.
77
+ param_str = str(node)
78
+ if param_str.startswith("$") and param_str[1:].isdigit():
79
+ numeric_params.add(int(param_str[1:]))
80
+ elif ":" in param_str:
81
+ # Named parameter: :name
82
+ named_params.add(param_str)
83
+ else:
84
+ # Other parameter formats
85
+ named_params.add(param_str)
86
+ elif isinstance(node, exp.Placeholder):
87
+ # QMARK style: ?
88
+ qmark_count += 1
89
+ return node
90
+
91
+ expression.transform(count_node)
92
+
93
+ # Return the appropriate count based on parameter style detected
94
+ if numeric_params:
95
+ # PostgreSQL style: return highest numbered parameter
96
+ return max(numeric_params)
97
+ if named_params:
98
+ # Named parameters: return count of unique names
99
+ return len(named_params)
100
+ # QMARK style: return total count
101
+ return qmark_count
102
+
103
+
104
+ def _is_execute_many_parameters(parameters: Any) -> bool:
105
+ """Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
106
+ return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple))
107
+
71
108
 
109
+ def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None:
110
+ """Validate parameter count against placeholder count in SQL."""
111
+ placeholder_count = _count_placeholders(expression)
112
+ is_execute_many = _is_execute_many_parameters(parameters)
113
+
114
+ if is_execute_many:
115
+ # For execute_many, validate each inner parameter set
116
+ for i, param_set in enumerate(parameters):
117
+ param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0
118
+ if param_count != placeholder_count:
119
+ msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
120
+ raise SQLSpecError(msg)
121
+ else:
122
+ # For single execution, validate the parameter set directly
123
+ param_count = (
124
+ len(parameters)
125
+ if isinstance(parameters, (list, tuple))
126
+ else len(parameters)
127
+ if isinstance(parameters, dict)
128
+ else 0
129
+ )
130
+
131
+ if param_count != placeholder_count:
132
+ msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
133
+ raise SQLSpecError(msg)
134
+
135
+
136
+ def _find_null_positions(parameters: Any) -> set[int]:
137
+ """Find positions of None values in parameters for single execution."""
72
138
  null_positions = set()
73
139
  if isinstance(parameters, (list, tuple)):
74
140
  for i, param in enumerate(parameters):
@@ -83,7 +149,37 @@ def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
83
149
  null_positions.add(param_num - 1)
84
150
  except ValueError:
85
151
  pass
152
+ return null_positions
153
+
154
+
155
+ def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "postgres") -> tuple[Any, Any]:
156
+ """Transform AST to handle NULL parameters.
86
157
 
158
+ Replaces NULL parameter placeholders with NULL literals in the AST
159
+ to prevent Arrow from inferring 'na' types which cause binding errors.
160
+ Validates parameter count before transformation.
161
+
162
+ Args:
163
+ expression: SQLGlot AST expression parsed with proper dialect
164
+ parameters: Parameter values that may contain None
165
+ dialect: SQLGlot dialect used for parsing (default: "postgres")
166
+
167
+ Returns:
168
+ Tuple of (modified_expression, cleaned_parameters)
169
+ """
170
+ if not parameters:
171
+ return expression, parameters
172
+
173
+ # Validate parameter count before transformation
174
+ _validate_parameter_counts(expression, parameters, dialect)
175
+
176
+ # For execute_many operations, skip AST transformation as different parameter
177
+ # sets may have None values in different positions, making transformation complex
178
+ if _is_execute_many_parameters(parameters):
179
+ return expression, parameters
180
+
181
+ # Find positions of None values for single execution
182
+ null_positions = _find_null_positions(parameters)
87
183
  if not null_positions:
88
184
  return expression, parameters
89
185
 
@@ -183,14 +279,28 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
183
279
 
184
280
 
185
281
  def _convert_array_for_postgres_adbc(value: Any) -> Any:
186
- """Convert array values for PostgreSQL compatibility."""
282
+ """Convert array values for PostgreSQL compatibility.
283
+
284
+ Args:
285
+ value: Value to convert
286
+
287
+ Returns:
288
+ Converted value (tuples become lists)
289
+ """
187
290
  if isinstance(value, tuple):
188
291
  return list(value)
189
292
  return value
190
293
 
191
294
 
192
295
  def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
193
- """Get type coercion map for Arrow type handling."""
296
+ """Get type coercion map for Arrow type handling.
297
+
298
+ Args:
299
+ dialect: Database dialect name
300
+
301
+ Returns:
302
+ Mapping of Python types to conversion functions
303
+ """
194
304
  type_map = {
195
305
  datetime.datetime: lambda x: x,
196
306
  datetime.date: lambda x: x,
@@ -245,8 +355,6 @@ class AdbcExceptionHandler:
245
355
  return
246
356
 
247
357
  try:
248
- from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
249
-
250
358
  if issubclass(exc_type, IntegrityError):
251
359
  e = exc_val
252
360
  msg = f"Integrity constraint violation: {e}"
@@ -282,9 +390,8 @@ class AdbcExceptionHandler:
282
390
  class AdbcDriver(SyncDriverAdapterBase):
283
391
  """ADBC driver for Arrow Database Connectivity.
284
392
 
285
- Provides database connectivity through ADBC with multi-database dialect
286
- support, Arrow-native data handling with type coercion, parameter style
287
- conversion for different backends, and transaction management.
393
+ Provides database connectivity through ADBC with support for multiple
394
+ database dialects, parameter style conversion, and transaction management.
288
395
  """
289
396
 
290
397
  __slots__ = ("_detected_dialect", "dialect")
@@ -309,7 +416,11 @@ class AdbcDriver(SyncDriverAdapterBase):
309
416
 
310
417
  @staticmethod
311
418
  def _ensure_pyarrow_installed() -> None:
312
- """Ensure PyArrow is installed."""
419
+ """Ensure PyArrow is installed.
420
+
421
+ Raises:
422
+ MissingDependencyError: If PyArrow is not installed
423
+ """
313
424
  from sqlspec.typing import PYARROW_INSTALLED
314
425
 
315
426
  if not PYARROW_INSTALLED:
@@ -317,7 +428,14 @@ class AdbcDriver(SyncDriverAdapterBase):
317
428
 
318
429
  @staticmethod
319
430
  def _get_dialect(connection: "AdbcConnection") -> str:
320
- """Detect database dialect from connection information."""
431
+ """Detect database dialect from connection information.
432
+
433
+ Args:
434
+ connection: ADBC connection
435
+
436
+ Returns:
437
+ Detected dialect name (defaults to 'postgres')
438
+ """
321
439
  try:
322
440
  driver_info = connection.adbc_get_info()
323
441
  vendor_name = driver_info.get("vendor_name", "").lower()
@@ -334,31 +452,53 @@ class AdbcDriver(SyncDriverAdapterBase):
334
452
  return "postgres"
335
453
 
336
454
  def _handle_postgres_rollback(self, cursor: "Cursor") -> None:
337
- """Execute rollback for PostgreSQL after transaction failure."""
455
+ """Execute rollback for PostgreSQL after transaction failure.
456
+
457
+ Args:
458
+ cursor: Database cursor
459
+ """
338
460
  if self.dialect == "postgres":
339
461
  with contextlib.suppress(Exception):
340
462
  cursor.execute("ROLLBACK")
341
463
  logger.debug("PostgreSQL rollback executed after transaction failure")
342
464
 
343
465
  def _handle_postgres_empty_parameters(self, parameters: Any) -> Any:
344
- """Process empty parameters for PostgreSQL compatibility."""
466
+ """Process empty parameters for PostgreSQL compatibility.
467
+
468
+ Args:
469
+ parameters: Parameter values
470
+
471
+ Returns:
472
+ Processed parameters
473
+ """
345
474
  if self.dialect == "postgres" and isinstance(parameters, dict) and not parameters:
346
475
  return None
347
476
  return parameters
348
477
 
349
478
  def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
350
- """Create context manager for cursor."""
479
+ """Create context manager for cursor.
480
+
481
+ Args:
482
+ connection: Database connection
483
+
484
+ Returns:
485
+ Cursor context manager
486
+ """
351
487
  return AdbcCursor(connection)
352
488
 
353
489
  def handle_database_exceptions(self) -> "AbstractContextManager[None]":
354
- """Handle database-specific exceptions and wrap them appropriately."""
490
+ """Handle database-specific exceptions and wrap them appropriately.
491
+
492
+ Returns:
493
+ Exception handler context manager
494
+ """
355
495
  return AdbcExceptionHandler()
356
496
 
357
497
  def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[SQLResult]":
358
498
  """Handle special operations.
359
499
 
360
500
  Args:
361
- cursor: Cursor object
501
+ cursor: Database cursor
362
502
  statement: SQL statement to analyze
363
503
 
364
504
  Returns:
@@ -368,7 +508,15 @@ class AdbcDriver(SyncDriverAdapterBase):
368
508
  return None
369
509
 
370
510
  def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
371
- """Execute SQL with multiple parameter sets."""
511
+ """Execute SQL with multiple parameter sets.
512
+
513
+ Args:
514
+ cursor: Database cursor
515
+ statement: SQL statement to execute
516
+
517
+ Returns:
518
+ Execution result with row counts
519
+ """
372
520
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
373
521
 
374
522
  try:
@@ -398,7 +546,15 @@ class AdbcDriver(SyncDriverAdapterBase):
398
546
  return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
399
547
 
400
548
  def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
401
- """Execute single SQL statement."""
549
+ """Execute single SQL statement.
550
+
551
+ Args:
552
+ cursor: Database cursor
553
+ statement: SQL statement to execute
554
+
555
+ Returns:
556
+ Execution result with data or row count
557
+ """
402
558
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
403
559
 
404
560
  try:
@@ -430,7 +586,15 @@ class AdbcDriver(SyncDriverAdapterBase):
430
586
  return self.create_execution_result(cursor, rowcount_override=row_count)
431
587
 
432
588
  def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
433
- """Execute SQL script."""
589
+ """Execute SQL script containing multiple statements.
590
+
591
+ Args:
592
+ cursor: Database cursor
593
+ statement: SQL script to execute
594
+
595
+ Returns:
596
+ Execution result with statement counts
597
+ """
434
598
  if statement.is_script:
435
599
  sql = statement._raw_sql
436
600
  prepared_parameters: list[Any] = []
@@ -51,7 +51,10 @@ asyncmy_statement_config = StatementConfig(
51
51
 
52
52
 
53
53
  class AsyncmyCursor:
54
- """Async context manager for AsyncMy cursor management."""
54
+ """Context manager for AsyncMy cursor operations.
55
+
56
+ Provides automatic cursor acquisition and cleanup for database operations.
57
+ """
55
58
 
56
59
  __slots__ = ("connection", "cursor")
57
60
 
@@ -70,7 +73,11 @@ class AsyncmyCursor:
70
73
 
71
74
 
72
75
  class AsyncmyExceptionHandler:
73
- """Custom async context manager for handling AsyncMy database exceptions."""
76
+ """Context manager for AsyncMy database exception handling.
77
+
78
+ Converts AsyncMy-specific exceptions to SQLSpec exceptions with appropriate
79
+ error categorization and context preservation.
80
+ """
74
81
 
75
82
  __slots__ = ()
76
83
 
@@ -116,10 +123,11 @@ class AsyncmyExceptionHandler:
116
123
 
117
124
 
118
125
  class AsyncmyDriver(AsyncDriverAdapterBase):
119
- """AsyncMy MySQL/MariaDB driver.
126
+ """MySQL/MariaDB database driver using AsyncMy client library.
120
127
 
121
- Provides MySQL/MariaDB connectivity with parameter style conversion,
122
- type coercion, error handling, and transaction management.
128
+ Implements asynchronous database operations for MySQL and MariaDB servers
129
+ with support for parameter style conversion, type coercion, error handling,
130
+ and transaction management.
123
131
  """
124
132
 
125
133
  __slots__ = ()
@@ -143,22 +151,33 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
143
151
  super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
144
152
 
145
153
  def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor":
146
- """Create context manager for AsyncMy cursor."""
154
+ """Create cursor context manager for the connection.
155
+
156
+ Args:
157
+ connection: AsyncMy database connection
158
+
159
+ Returns:
160
+ AsyncmyCursor: Context manager for cursor operations
161
+ """
147
162
  return AsyncmyCursor(connection)
148
163
 
149
164
  def handle_database_exceptions(self) -> "AbstractAsyncContextManager[None]":
150
- """Handle database-specific exceptions and wrap them appropriately."""
165
+ """Provide exception handling context manager.
166
+
167
+ Returns:
168
+ AbstractAsyncContextManager[None]: Context manager for AsyncMy exception handling
169
+ """
151
170
  return AsyncmyExceptionHandler()
152
171
 
153
172
  async def _try_special_handling(self, cursor: Any, statement: "SQL") -> "Optional[SQLResult]":
154
- """Hook for AsyncMy-specific special operations.
173
+ """Handle AsyncMy-specific operations before standard execution.
155
174
 
156
175
  Args:
157
176
  cursor: AsyncMy cursor object
158
177
  statement: SQL statement to analyze
159
178
 
160
179
  Returns:
161
- None - always proceeds with standard execution for AsyncMy
180
+ Optional[SQLResult]: None, always proceeds with standard execution
162
181
  """
163
182
  _ = (cursor, statement)
164
183
  return None
@@ -166,7 +185,15 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
166
185
  async def _execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
167
186
  """Execute SQL script with statement splitting and parameter handling.
168
187
 
188
+ Splits multi-statement scripts and executes each statement sequentially.
169
189
  Parameters are embedded as static values for script execution compatibility.
190
+
191
+ Args:
192
+ cursor: AsyncMy cursor object
193
+ statement: SQL script to execute
194
+
195
+ Returns:
196
+ ExecutionResult: Script execution results with statement count
170
197
  """
171
198
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
172
199
  statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
@@ -183,9 +210,20 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
183
210
  )
184
211
 
185
212
  async def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
186
- """Execute SQL with multiple parameter sets using AsyncMy batch processing.
213
+ """Execute SQL statement with multiple parameter sets.
187
214
 
188
- Handles MySQL type conversion and parameter processing.
215
+ Uses AsyncMy's executemany for batch operations with MySQL type conversion
216
+ and parameter processing.
217
+
218
+ Args:
219
+ cursor: AsyncMy cursor object
220
+ statement: SQL statement with multiple parameter sets
221
+
222
+ Returns:
223
+ ExecutionResult: Batch execution results
224
+
225
+ Raises:
226
+ ValueError: If no parameters provided for executemany operation
189
227
  """
190
228
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
191
229
 
@@ -200,9 +238,17 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
200
238
  return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
201
239
 
202
240
  async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
203
- """Execute single SQL statement with AsyncMy MySQL data handling.
241
+ """Execute single SQL statement.
242
+
243
+ Handles parameter processing, result fetching, and data transformation
244
+ for MySQL/MariaDB operations.
245
+
246
+ Args:
247
+ cursor: AsyncMy cursor object
248
+ statement: SQL statement to execute
204
249
 
205
- Handles parameter processing and MySQL result processing.
250
+ Returns:
251
+ ExecutionResult: Statement execution results with data or row counts
206
252
  """
207
253
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
208
254
  await cursor.execute(sql, prepared_parameters or None)
@@ -228,6 +274,9 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
228
274
  """Begin a database transaction.
229
275
 
230
276
  Explicitly starts a MySQL transaction to ensure proper transaction boundaries.
277
+
278
+ Raises:
279
+ SQLSpecError: If transaction initialization fails
231
280
  """
232
281
  try:
233
282
  async with AsyncmyCursor(self.connection) as cursor:
@@ -237,7 +286,11 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
237
286
  raise SQLSpecError(msg) from e
238
287
 
239
288
  async def rollback(self) -> None:
240
- """Rollback the current transaction."""
289
+ """Rollback the current transaction.
290
+
291
+ Raises:
292
+ SQLSpecError: If transaction rollback fails
293
+ """
241
294
  try:
242
295
  await self.connection.rollback()
243
296
  except asyncmy.errors.MySQLError as e:
@@ -245,7 +298,11 @@ class AsyncmyDriver(AsyncDriverAdapterBase):
245
298
  raise SQLSpecError(msg) from e
246
299
 
247
300
  async def commit(self) -> None:
248
- """Commit the current transaction."""
301
+ """Commit the current transaction.
302
+
303
+ Raises:
304
+ SQLSpecError: If transaction commit fails
305
+ """
249
306
  try:
250
307
  await self.connection.commit()
251
308
  except asyncmy.errors.MySQLError as e:
@@ -124,12 +124,32 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
124
124
  config = self._get_pool_config_dict()
125
125
 
126
126
  if "init" not in config:
127
- config["init"] = self._init_pgvector_connection
127
+ config["init"] = self._init_connection
128
128
 
129
129
  return await asyncpg_create_pool(**config)
130
130
 
131
- async def _init_pgvector_connection(self, connection: "AsyncpgConnection") -> None:
132
- """Initialize pgvector support for asyncpg connections."""
131
+ async def _init_connection(self, connection: "AsyncpgConnection") -> None:
132
+ """Initialize connection with JSON codecs and pgvector support."""
133
+
134
+ try:
135
+ # Set up JSON type codec
136
+ await connection.set_type_codec(
137
+ "json",
138
+ encoder=self.driver_features.get("json_serializer", to_json),
139
+ decoder=self.driver_features.get("json_deserializer", from_json),
140
+ schema="pg_catalog",
141
+ )
142
+ # Set up JSONB type codec
143
+ await connection.set_type_codec(
144
+ "jsonb",
145
+ encoder=self.driver_features.get("json_serializer", to_json),
146
+ decoder=self.driver_features.get("json_deserializer", from_json),
147
+ schema="pg_catalog",
148
+ )
149
+ except Exception as e:
150
+ logger.debug("Failed to configure JSON type codecs for asyncpg: %s", e)
151
+
152
+ # Initialize pgvector support
133
153
  try:
134
154
  import pgvector.asyncpg
135
155
 
@@ -1,10 +1,7 @@
1
1
  """AsyncPG PostgreSQL driver implementation for async PostgreSQL operations.
2
2
 
3
- Provides async PostgreSQL connectivity with:
4
- - Parameter processing with type coercion
5
- - Resource management
6
- - PostgreSQL COPY operation support
7
- - Transaction management
3
+ Provides async PostgreSQL connectivity with parameter processing, resource management,
4
+ PostgreSQL COPY operation support, and transaction management.
8
5
  """
9
6
 
10
7
  import re
@@ -102,13 +99,9 @@ class AsyncpgExceptionHandler:
102
99
  class AsyncpgDriver(AsyncDriverAdapterBase):
103
100
  """AsyncPG PostgreSQL driver for async database operations.
104
101
 
105
- Features:
106
- - COPY operation support
107
- - Numeric parameter style handling
108
- - PostgreSQL exception handling
109
- - Transaction management
110
- - SQL statement compilation and caching
111
- - Parameter processing and type coercion
102
+ Supports COPY operations, numeric parameter style handling, PostgreSQL
103
+ exception handling, transaction management, SQL statement compilation
104
+ and caching, and parameter processing with type coercion.
112
105
  """
113
106
 
114
107
  __slots__ = ()
@@ -193,7 +186,15 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
193
186
  await cursor.execute(sql_text)
194
187
 
195
188
  async def _execute_script(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
196
- """Execute SQL script with statement splitting and parameter handling."""
189
+ """Execute SQL script with statement splitting and parameter handling.
190
+
191
+ Args:
192
+ cursor: AsyncPG connection object
193
+ statement: SQL statement containing multiple statements
194
+
195
+ Returns:
196
+ ExecutionResult with script execution details
197
+ """
197
198
  sql, _ = self._get_compiled_sql(statement, self.statement_config)
198
199
  statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
199
200
 
@@ -210,7 +211,15 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
210
211
  )
211
212
 
212
213
  async def _execute_many(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
213
- """Execute SQL with multiple parameter sets using AsyncPG's executemany."""
214
+ """Execute SQL with multiple parameter sets using AsyncPG's executemany.
215
+
216
+ Args:
217
+ cursor: AsyncPG connection object
218
+ statement: SQL statement with multiple parameter sets
219
+
220
+ Returns:
221
+ ExecutionResult with batch execution details
222
+ """
214
223
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
215
224
 
216
225
  if prepared_parameters:
@@ -226,6 +235,13 @@ class AsyncpgDriver(AsyncDriverAdapterBase):
226
235
  """Execute single SQL statement.
227
236
 
228
237
  Handles both SELECT queries and non-SELECT operations.
238
+
239
+ Args:
240
+ cursor: AsyncPG connection object
241
+ statement: SQL statement to execute
242
+
243
+ Returns:
244
+ ExecutionResult with statement execution details
229
245
  """
230
246
  sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
231
247