sqlspec 0.12.2-py3-none-any.whl → 0.13.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +16 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +17 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +17 -29
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +32 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +18 -9
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +44 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.1.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
@@ -3,10 +3,10 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
8
7
 
9
8
  import asyncmy
9
+ from asyncmy.pool import Pool as AsyncmyPool
10
10
 
11
11
  from sqlspec.adapters.asyncmy.driver import AsyncmyConnection, AsyncmyDriver
12
12
  from sqlspec.config import AsyncDatabaseConfig
@@ -193,7 +193,6 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
193
193
  if getattr(self, field, None) is not None and getattr(self, field) is not Empty
194
194
  }
195
195
 
196
- # Add connection-specific extras (not pool-specific ones)
197
196
  config.update(self.extras)
198
197
 
199
198
  return config
@@ -264,15 +263,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
264
263
  An AsyncmyDriver instance.
265
264
  """
266
265
  async with self.provide_connection(*args, **kwargs) as connection:
267
- # Create statement config with parameter style info if not already set
268
266
  statement_config = self.statement_config
267
+ # Inject parameter style info if not already set
269
268
  if statement_config.allowed_parameter_styles is None:
269
+ from dataclasses import replace
270
+
270
271
  statement_config = replace(
271
272
  statement_config,
272
273
  allowed_parameter_styles=self.supported_parameter_styles,
273
274
  target_parameter_style=self.preferred_parameter_style,
274
275
  )
275
-
276
276
  yield self.driver_type(connection=connection, config=statement_config)
277
277
 
278
278
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore
@@ -284,3 +284,16 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "Pool", AsyncmyDriver
284
284
  if not self.pool_instance:
285
285
  self.pool_instance = await self.create_pool()
286
286
  return self.pool_instance
287
+
288
+ def get_signature_namespace(self) -> "dict[str, type[Any]]":
289
+ """Get the signature namespace for Asyncmy types.
290
+
291
+ This provides all Asyncmy-specific types that Litestar needs to recognize
292
+ to avoid serialization attempts.
293
+
294
+ Returns:
295
+ Dictionary mapping type names to types.
296
+ """
297
+ namespace = super().get_signature_namespace()
298
+ namespace.update({"AsyncmyConnection": AsyncmyConnection, "AsyncmyPool": AsyncmyPool})
299
+ return namespace
@@ -1,12 +1,13 @@
1
1
  import logging
2
2
  from collections.abc import AsyncGenerator, Sequence
3
3
  from contextlib import asynccontextmanager
4
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
4
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
5
5
 
6
6
  from asyncmy import Connection
7
7
  from typing_extensions import TypeAlias
8
8
 
9
9
  from sqlspec.driver import AsyncDriverAdapterProtocol
10
+ from sqlspec.driver.connection import managed_transaction_async
10
11
  from sqlspec.driver.mixins import (
11
12
  AsyncPipelinedExecutionMixin,
12
13
  AsyncStorageMixin,
@@ -14,10 +15,11 @@ from sqlspec.driver.mixins import (
14
15
  ToSchemaMixin,
15
16
  TypeCoercionMixin,
16
17
  )
17
- from sqlspec.statement.parameters import ParameterStyle
18
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
18
+ from sqlspec.driver.parameters import normalize_parameter_sequence
19
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
20
+ from sqlspec.statement.result import SQLResult
19
21
  from sqlspec.statement.sql import SQL, SQLConfig
20
- from sqlspec.typing import DictRow, ModelDTOT, RowT
22
+ from sqlspec.typing import DictRow, RowT
21
23
 
22
24
  if TYPE_CHECKING:
23
25
  from asyncmy.cursors import Cursor, DictCursor
@@ -60,7 +62,7 @@ class AsyncmyDriver(
60
62
  self, connection: "Optional[AsyncmyConnection]" = None
61
63
  ) -> "AsyncGenerator[Union[Cursor, DictCursor], None]":
62
64
  conn = self._connection(connection)
63
- cursor = await conn.cursor()
65
+ cursor = conn.cursor()
64
66
  try:
65
67
  yield cursor
66
68
  finally:
@@ -68,95 +70,146 @@ class AsyncmyDriver(
68
70
 
69
71
  async def _execute_statement(
70
72
  self, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
71
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
73
+ ) -> SQLResult[RowT]:
72
74
  if statement.is_script:
73
75
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
74
76
  return await self._execute_script(sql, connection=connection, **kwargs)
75
77
 
76
- # Let the SQL object handle parameter style conversion based on dialect support
77
- sql, params = statement.compile(placeholder_style=self.default_parameter_style)
78
+ # Detect parameter styles in the SQL
79
+ detected_styles = set()
80
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
81
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
82
+ param_infos = validator.extract_parameters(sql_str)
83
+ if param_infos:
84
+ detected_styles = {p.style for p in param_infos}
85
+
86
+ # Determine target style based on what's in the SQL
87
+ target_style = self.default_parameter_style
88
+
89
+ # Check if there are unsupported styles
90
+ unsupported_styles = detected_styles - set(self.supported_parameter_styles)
91
+ if unsupported_styles:
92
+ # Force conversion to default style
93
+ target_style = self.default_parameter_style
94
+ elif detected_styles:
95
+ # Prefer the first supported style found
96
+ for style in detected_styles:
97
+ if style in self.supported_parameter_styles:
98
+ target_style = style
99
+ break
100
+
101
+ # Compile with the determined style
102
+ sql, params = statement.compile(placeholder_style=target_style)
78
103
 
79
104
  if statement.is_many:
80
- # Process parameter list through type coercion
81
105
  params = self._process_parameters(params)
82
106
  return await self._execute_many(sql, params, connection=connection, **kwargs)
83
107
 
84
- # Process parameters through type coercion
85
108
  params = self._process_parameters(params)
86
109
  return await self._execute(sql, params, statement, connection=connection, **kwargs)
87
110
 
88
111
  async def _execute(
89
112
  self, sql: str, parameters: Any, statement: SQL, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
90
- ) -> Union[SelectResultDict, DMLResultDict]:
91
- conn = self._connection(connection)
92
- # AsyncMy doesn't like empty lists/tuples, convert to None
93
- if not parameters:
94
- parameters = None
95
- async with self._get_cursor(conn) as cursor:
96
- # AsyncMy expects list/tuple parameters or dict for named params
97
- await cursor.execute(sql, parameters)
98
-
99
- if self.returns_rows(statement.expression):
100
- # For SELECT queries, fetch data and return SelectResultDict
101
- data = await cursor.fetchall()
102
- column_names = [desc[0] for desc in cursor.description or []]
103
- result: SelectResultDict = {"data": data, "column_names": column_names, "rows_affected": len(data)}
104
- return result
105
-
106
- # For DML/DDL queries, return DMLResultDict
107
- dml_result: DMLResultDict = {
108
- "rows_affected": cursor.rowcount if cursor.rowcount is not None else -1,
109
- "status_message": "OK",
110
- }
111
- return dml_result
113
+ ) -> SQLResult[RowT]:
114
+ # Use provided connection or driver's default connection
115
+ conn = connection if connection is not None else self._connection(None)
116
+
117
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
118
+ # Normalize parameters using consolidated utility
119
+ normalized_params = normalize_parameter_sequence(parameters)
120
+ # AsyncMy doesn't like empty lists/tuples, convert to None
121
+ final_params = (
122
+ normalized_params[0] if normalized_params and len(normalized_params) == 1 else normalized_params
123
+ )
124
+ if not final_params:
125
+ final_params = None
126
+
127
+ async with self._get_cursor(txn_conn) as cursor:
128
+ # AsyncMy expects list/tuple parameters or dict for named params
129
+ await cursor.execute(sql, final_params)
130
+
131
+ if self.returns_rows(statement.expression):
132
+ # For SELECT queries, fetch data and return SQLResult
133
+ data = await cursor.fetchall()
134
+ column_names = [desc[0] for desc in cursor.description or []]
135
+ return SQLResult(
136
+ statement=statement,
137
+ data=data,
138
+ column_names=column_names,
139
+ rows_affected=len(data),
140
+ operation_type="SELECT",
141
+ )
142
+
143
+ # For DML/DDL queries
144
+ return SQLResult(
145
+ statement=statement,
146
+ data=[],
147
+ rows_affected=cursor.rowcount if cursor.rowcount is not None else -1,
148
+ operation_type=self._determine_operation_type(statement),
149
+ metadata={"status_message": "OK"},
150
+ )
112
151
 
113
152
  async def _execute_many(
114
153
  self, sql: str, param_list: Any, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
115
- ) -> DMLResultDict:
116
- conn = self._connection(connection)
117
-
118
- # Convert parameter list to proper format for executemany
119
- params_list: list[Union[list[Any], tuple[Any, ...]]] = []
120
- if param_list and isinstance(param_list, Sequence):
121
- for param_set in param_list:
122
- if isinstance(param_set, (list, tuple)):
123
- params_list.append(param_set)
124
- elif param_set is None:
125
- params_list.append([])
126
- else:
127
- params_list.append([param_set])
128
-
129
- async with self._get_cursor(conn) as cursor:
130
- await cursor.executemany(sql, params_list)
131
- result: DMLResultDict = {
132
- "rows_affected": cursor.rowcount if cursor.rowcount != -1 else len(params_list),
133
- "status_message": "OK",
134
- }
135
- return result
154
+ ) -> SQLResult[RowT]:
155
+ # Use provided connection or driver's default connection
156
+ conn = connection if connection is not None else self._connection(None)
157
+
158
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
159
+ # Normalize parameter list using consolidated utility
160
+ normalized_param_list = normalize_parameter_sequence(param_list)
161
+
162
+ params_list: list[Union[list[Any], tuple[Any, ...]]] = []
163
+ if normalized_param_list and isinstance(normalized_param_list, Sequence):
164
+ for param_set in normalized_param_list:
165
+ if isinstance(param_set, (list, tuple)):
166
+ params_list.append(param_set)
167
+ elif param_set is None:
168
+ params_list.append([])
169
+ else:
170
+ params_list.append([param_set])
171
+
172
+ async with self._get_cursor(txn_conn) as cursor:
173
+ await cursor.executemany(sql, params_list)
174
+ return SQLResult(
175
+ statement=SQL(sql, _dialect=self.dialect),
176
+ data=[],
177
+ rows_affected=cursor.rowcount if cursor.rowcount != -1 else len(params_list),
178
+ operation_type="EXECUTE",
179
+ metadata={"status_message": "OK"},
180
+ )
136
181
 
137
182
  async def _execute_script(
138
183
  self, script: str, connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any
139
- ) -> ScriptResultDict:
140
- conn = self._connection(connection)
141
- # AsyncMy may not support multi-statement scripts without CLIENT_MULTI_STATEMENTS flag
142
- # Use the shared implementation to split and execute statements individually
143
- statements = self._split_script_statements(script)
144
- statements_executed = 0
145
-
146
- async with self._get_cursor(conn) as cursor:
147
- for statement_str in statements:
148
- if statement_str:
149
- await cursor.execute(statement_str)
150
- statements_executed += 1
151
-
152
- result: ScriptResultDict = {"statements_executed": statements_executed, "status_message": "SCRIPT EXECUTED"}
153
- return result
184
+ ) -> SQLResult[RowT]:
185
+ # Use provided connection or driver's default connection
186
+ conn = connection if connection is not None else self._connection(None)
187
+
188
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
189
+ # AsyncMy may not support multi-statement scripts without CLIENT_MULTI_STATEMENTS flag
190
+ statements = self._split_script_statements(script)
191
+ statements_executed = 0
192
+
193
+ async with self._get_cursor(txn_conn) as cursor:
194
+ for statement_str in statements:
195
+ if statement_str:
196
+ await cursor.execute(statement_str)
197
+ statements_executed += 1
198
+
199
+ return SQLResult(
200
+ statement=SQL(script, _dialect=self.dialect).as_script(),
201
+ data=[],
202
+ rows_affected=0,
203
+ operation_type="SCRIPT",
204
+ metadata={"status_message": "SCRIPT EXECUTED"},
205
+ total_statements=statements_executed,
206
+ successful_statements=statements_executed,
207
+ )
154
208
 
155
209
  async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
156
210
  self._ensure_pyarrow_installed()
157
211
  conn = self._connection(None)
158
-
159
- async with self._get_cursor(conn) as cursor:
212
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
160
213
  if mode == "replace":
161
214
  await cursor.execute(f"TRUNCATE TABLE {table_name}")
162
215
  elif mode == "create":
@@ -174,71 +227,6 @@ class AsyncmyDriver(
174
227
  await cursor.executemany(sql, data_for_ingest)
175
228
  return cursor.rowcount if cursor.rowcount is not None else -1
176
229
 
177
- async def _wrap_select_result(
178
- self, statement: SQL, result: SelectResultDict, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any
179
- ) -> "Union[SQLResult[ModelDTOT], SQLResult[RowT]]":
180
- data = result["data"]
181
- column_names = result["column_names"]
182
- rows_affected = result["rows_affected"]
183
-
184
- if not data:
185
- return SQLResult[RowT](
186
- statement=statement, data=[], column_names=column_names, rows_affected=0, operation_type="SELECT"
187
- )
188
-
189
- rows_as_dicts = [dict(zip(column_names, row)) for row in data]
190
-
191
- if schema_type:
192
- converted_data = self.to_schema(data=rows_as_dicts, schema_type=schema_type)
193
- return SQLResult[ModelDTOT](
194
- statement=statement,
195
- data=list(converted_data),
196
- column_names=column_names,
197
- rows_affected=rows_affected,
198
- operation_type="SELECT",
199
- )
200
-
201
- return SQLResult[RowT](
202
- statement=statement,
203
- data=rows_as_dicts,
204
- column_names=column_names,
205
- rows_affected=rows_affected,
206
- operation_type="SELECT",
207
- )
208
-
209
- async def _wrap_execute_result(
210
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
211
- ) -> SQLResult[RowT]:
212
- operation_type = "UNKNOWN"
213
- if statement.expression:
214
- operation_type = str(statement.expression.key).upper()
215
-
216
- # Handle script results
217
- if "statements_executed" in result:
218
- script_result = cast("ScriptResultDict", result)
219
- return SQLResult[RowT](
220
- statement=statement,
221
- data=[],
222
- rows_affected=0,
223
- operation_type="SCRIPT",
224
- metadata={
225
- "status_message": script_result.get("status_message", ""),
226
- "statements_executed": script_result.get("statements_executed", -1),
227
- },
228
- )
229
-
230
- # Handle DML results
231
- dml_result = cast("DMLResultDict", result)
232
- rows_affected = dml_result.get("rows_affected", -1)
233
- status_message = dml_result.get("status_message", "")
234
- return SQLResult[RowT](
235
- statement=statement,
236
- data=[],
237
- rows_affected=rows_affected,
238
- operation_type=operation_type,
239
- metadata={"status_message": status_message},
240
- )
241
-
242
230
  def _connection(self, connection: Optional["AsyncmyConnection"] = None) -> "AsyncmyConnection":
243
231
  """Get the connection to use for the operation."""
244
232
  return connection or self.connection
@@ -3,11 +3,12 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator, Awaitable, Callable
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
8
7
 
9
- from asyncpg import Record
8
+ from asyncpg import Connection, Record
10
9
  from asyncpg import create_pool as asyncpg_create_pool
10
+ from asyncpg.connection import ConnectionMeta
11
+ from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
11
12
  from typing_extensions import NotRequired, Unpack
12
13
 
13
14
  from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver
@@ -19,7 +20,6 @@ from sqlspec.utils.serializers import from_json, to_json
19
20
  if TYPE_CHECKING:
20
21
  from asyncio.events import AbstractEventLoop
21
22
 
22
- from asyncpg.pool import Pool
23
23
  from sqlglot.dialects.dialect import DialectType
24
24
 
25
25
 
@@ -224,7 +224,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
224
224
 
225
225
  super().__init__()
226
226
 
227
- # Set pool_instance after super().__init__() to ensure it's not overridden
228
227
  if pool_instance_from_kwargs is not None:
229
228
  self.pool_instance = pool_instance_from_kwargs
230
229
 
@@ -241,7 +240,6 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
241
240
  if getattr(self, field, None) is not None and getattr(self, field) is not Empty
242
241
  }
243
242
 
244
- # Add connection-specific extras (not pool-specific ones)
245
243
  config.update(self.extras)
246
244
 
247
245
  return config
@@ -318,15 +316,16 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
318
316
  An AsyncpgDriver instance.
319
317
  """
320
318
  async with self.provide_connection(*args, **kwargs) as connection:
321
- # Create statement config with parameter style info if not already set
322
319
  statement_config = self.statement_config
320
+ # Inject parameter style info if not already set
323
321
  if statement_config is not None and statement_config.allowed_parameter_styles is None:
322
+ from dataclasses import replace
323
+
324
324
  statement_config = replace(
325
325
  statement_config,
326
326
  allowed_parameter_styles=self.supported_parameter_styles,
327
327
  target_parameter_style=self.preferred_parameter_style,
328
328
  )
329
-
330
329
  yield self.driver_type(connection=connection, config=statement_config)
331
330
 
332
331
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
@@ -348,27 +347,16 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
348
347
  Returns:
349
348
  Dictionary mapping type names to types.
350
349
  """
351
- # Get base types from parent
352
350
  namespace = super().get_signature_namespace()
353
-
354
- # Add AsyncPG-specific types
355
- try:
356
- from asyncpg import Connection, Record
357
- from asyncpg.connection import ConnectionMeta
358
- from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
359
-
360
- namespace.update(
361
- {
362
- "Connection": Connection,
363
- "Pool": Pool,
364
- "PoolConnectionProxy": PoolConnectionProxy,
365
- "PoolConnectionProxyMeta": PoolConnectionProxyMeta,
366
- "ConnectionMeta": ConnectionMeta,
367
- "Record": Record,
368
- "AsyncpgConnection": type(AsyncpgConnection), # The Union type alias
369
- }
370
- )
371
- except ImportError:
372
- logger.warning("Failed to import AsyncPG types for signature namespace")
373
-
351
+ namespace.update(
352
+ {
353
+ "Connection": Connection,
354
+ "Pool": Pool,
355
+ "PoolConnectionProxy": PoolConnectionProxy,
356
+ "PoolConnectionProxyMeta": PoolConnectionProxyMeta,
357
+ "ConnectionMeta": ConnectionMeta,
358
+ "Record": Record,
359
+ "AsyncpgConnection": type(AsyncpgConnection),
360
+ }
361
+ )
374
362
  return namespace