sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -644
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -462
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +217 -451
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +418 -498
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +592 -634
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +393 -436
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +549 -942
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -550
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +741 -0
  31. sqlspec/adapters/psycopg/driver.py +732 -733
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +243 -426
  35. sqlspec/base.py +220 -825
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
  137. sqlspec-0.12.0.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -330
  150. sqlspec/mixins.py +0 -306
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.0.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,474 @@
1
+ """AioSQL adapter implementation for SQLSpec.
2
+
3
+ This module provides adapter classes that implement the aiosql adapter protocols
4
+ while using SQLSpec drivers under the hood. This enables users to load SQL queries
5
+ from files using aiosql while leveraging all of SQLSpec's advanced features.
6
+ """
7
+
8
+ import logging
9
+ from collections.abc import AsyncGenerator, Generator
10
+ from contextlib import asynccontextmanager, contextmanager
11
+ from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar, Union, cast
12
+
13
+ from sqlspec.exceptions import MissingDependencyError
14
+ from sqlspec.statement.result import SQLResult
15
+ from sqlspec.statement.sql import SQL, SQLConfig
16
+ from sqlspec.typing import AIOSQL_INSTALLED, RowT
17
+
18
+ if TYPE_CHECKING:
19
+ from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
20
+
21
+ logger = logging.getLogger("sqlspec.extensions.aiosql")
22
+
23
+ __all__ = ("AiosqlAsyncAdapter", "AiosqlSyncAdapter")
24
+
25
+ T = TypeVar("T")
26
+
27
+
28
def _check_aiosql_available() -> None:
    """Ensure the optional ``aiosql`` dependency is importable.

    Raises:
        MissingDependencyError: When ``AIOSQL_INSTALLED`` is false.
    """
    if AIOSQL_INSTALLED:
        return
    package_name = "aiosql"
    raise MissingDependencyError(package_name, "aiosql")
32
+
33
+
34
+ def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
35
+ """Normalize dialect name for SQLGlot compatibility.
36
+
37
+ Args:
38
+ dialect: Original dialect name (can be str, Dialect, type[Dialect], or None)
39
+
40
+ Returns:
41
+ Normalized dialect name
42
+ """
43
+ # Handle different dialect types
44
+ if dialect is None:
45
+ return "sql"
46
+
47
+ # Extract string from dialect class or instance
48
+ if hasattr(dialect, "__name__"): # It's a class
49
+ dialect_str = str(dialect.__name__).lower() # pyright: ignore
50
+ elif hasattr(dialect, "name"): # It's an instance with name attribute
51
+ dialect_str = str(dialect.name).lower() # pyright: ignore
52
+ elif isinstance(dialect, str):
53
+ dialect_str = dialect.lower()
54
+ else:
55
+ dialect_str = str(dialect).lower()
56
+
57
+ # Map common dialect aliases to SQLGlot names
58
+ dialect_mapping = {
59
+ "postgresql": "postgres",
60
+ "psycopg": "postgres",
61
+ "asyncpg": "postgres",
62
+ "psqlpy": "postgres",
63
+ "sqlite3": "sqlite",
64
+ "aiosqlite": "sqlite",
65
+ }
66
+ return dialect_mapping.get(dialect_str, dialect_str)
67
+
68
+
69
class _AiosqlAdapterBase:
    """Shared plumbing for the sync and async aiosql adapters.

    Stores the SQLSpec driver and wraps raw SQL text coming from aiosql into a
    SQLSpec ``SQL`` statement object.
    """

    def __init__(
        self, driver: "Union[SyncDriverAdapterProtocol[Any, Any], AsyncDriverAdapterProtocol[Any, Any]]"
    ) -> None:
        """Verify aiosql is available, then keep a reference to ``driver``.

        Args:
            driver: SQLSpec driver to use for execution.

        Raises:
            MissingDependencyError: Raised by ``_check_aiosql_available`` when aiosql is not installed.
        """
        _check_aiosql_available()
        self.driver = driver

    def process_sql(self, query_name: str, op_type: "Any", sql: str) -> str:
        """Return ``sql`` unchanged.

        ``query_name`` and ``op_type`` are accepted for aiosql protocol
        compatibility but are not used.
        """
        return sql

    def _create_sql_object(self, sql: str, parameters: "Any" = None) -> SQL:
        """Build a non-strict, non-validating ``SQL`` object in the driver's dialect."""
        statement_config = SQLConfig(strict_mode=False, enable_validation=False)
        dialect_name = _normalize_dialect(self.driver.dialect)
        return SQL(sql, parameters, config=statement_config, dialect=dialect_name)
92
+
93
+
94
class AiosqlSyncAdapter(_AiosqlAdapterBase):
    """Sync adapter that implements aiosql protocol using SQLSpec drivers.

    This adapter bridges aiosql's sync driver protocol with SQLSpec's sync drivers,
    enabling all of SQLSpec's drivers to work with queries loaded by aiosql.
    """

    is_aio_driver: ClassVar[bool] = False

    def __init__(self, driver: "SyncDriverAdapterProtocol[Any, Any]") -> None:
        """Initialize the sync adapter.

        Args:
            driver: SQLSpec sync driver to use for execution
        """
        super().__init__(driver)

    @staticmethod
    def _warn_record_class(record_class: Optional[Any]) -> None:
        """Log a deprecation warning when aiosql passes a non-None ``record_class``."""
        if record_class is not None:
            logger.warning(
                "record_class parameter is deprecated and ignored. "
                "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
            )

    @staticmethod
    def _first_value(row: Any) -> Optional[Any]:
        """Return the first column value of ``row``.

        Mappings yield their first value; sequences their first item (``None`` when
        empty). Strings/bytes and non-subscriptable objects are treated as scalars
        and returned as-is (previously a ``str`` row returned only its first character).
        """
        from collections.abc import Mapping

        if isinstance(row, Mapping):
            return next(iter(row.values()), None)
        if isinstance(row, (str, bytes, bytearray)) or not hasattr(row, "__getitem__"):
            return row
        try:
            return row[0]
        except IndexError:
            return None  # empty sequence-like row
        except (KeyError, TypeError):
            return row  # subscriptable but not positionally indexable

    @staticmethod
    def _rows_affected(result: Any) -> int:
        """Return ``result.rows_affected`` as an int, using 0 when absent or None."""
        rows = getattr(result, "rows_affected", None)
        return rows if rows is not None else 0

    def select(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Generator[Any, None, None]:
        """Execute a SELECT query and return results as generator.

        Args:
            conn: Database connection (passed through to SQLSpec driver)
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters
            record_class: Deprecated - use schema_type in driver.execute instead

        Yields:
            Query result rows

        Note:
            record_class parameter is ignored. Use schema_type in driver.execute
            or _sqlspec_schema_type in parameters for type mapping.
            Because this is a generator, execution happens on first iteration.
        """
        self._warn_record_class(record_class)

        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)

        if isinstance(result, SQLResult) and result.data is not None:
            yield from result.data

    def select_one(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Optional[RowT]:
        """Execute a SELECT query and return first result.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters
            record_class: Deprecated - use schema_type in driver.execute instead

        Returns:
            First result row or None

        Note:
            record_class parameter is ignored. Use schema_type in driver.execute
            or _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_record_class(record_class)

        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)

        if isinstance(result, SQLResult) and result.data:
            return cast("Optional[RowT]", result.data[0])
        return None

    def select_value(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute a SELECT query and return first value of first row.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Returns:
            First value of first row or None
        """
        row = self.select_one(conn, query_name, sql, parameters)
        if row is None:
            return None
        return self._first_value(row)

    @contextmanager
    def select_cursor(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Generator[Any, None, None]:
        """Execute a SELECT query and return cursor context manager.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Yields:
            Cursor-like object exposing ``fetchall``/``fetchone`` over the already
            fetched result rows.
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)

        class CursorLike:
            def __init__(self, result: Any) -> None:
                self.result = result

            def fetchall(self) -> list[Any]:
                if isinstance(self.result, SQLResult) and self.result.data is not None:
                    return list(self.result.data)
                return []

            def fetchone(self) -> Optional[Any]:
                rows = self.fetchall()
                return rows[0] if rows else None

        yield CursorLike(result)

    def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> int:
        """Execute INSERT/UPDATE/DELETE and return affected rows.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Returns:
            Number of affected rows (0 when the result does not report it)
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)
        return self._rows_affected(result)

    def insert_update_delete_many(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> int:
        """Execute INSERT/UPDATE/DELETE with many parameter sets.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Sequence of parameter sets

        Returns:
            Number of affected rows (0 when the result does not report it)
        """
        # For executemany, parameter sets are passed to the driver directly rather
        # than being bound onto the SQL object.
        sql_obj = self._create_sql_object(sql)
        result = self.driver.execute_many(sql_obj, parameters=parameters, connection=conn)
        return self._rows_affected(result)

    def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute INSERT with RETURNING and return result.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Returns:
            Returned value or None
        """
        # INSERT RETURNING is treated like a select that returns data
        return self.select_one(conn, query_name, sql, parameters)
287
+
288
+
289
class AiosqlAsyncAdapter(_AiosqlAdapterBase):
    """Async adapter that implements aiosql protocol using SQLSpec drivers.

    This adapter bridges aiosql's async driver protocol with SQLSpec's async drivers,
    enabling all of SQLSpec's features to work with queries loaded by aiosql.
    """

    is_aio_driver: ClassVar[bool] = True

    def __init__(self, driver: "AsyncDriverAdapterProtocol[Any, Any]") -> None:
        """Initialize the async adapter.

        Args:
            driver: SQLSpec async driver to use for execution
        """
        super().__init__(driver)

    @staticmethod
    def _warn_record_class(record_class: Optional[Any]) -> None:
        """Log a deprecation warning when aiosql passes a non-None ``record_class``."""
        if record_class is not None:
            logger.warning(
                "record_class parameter is deprecated and ignored. "
                "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
            )

    @staticmethod
    def _first_value(row: Any) -> Optional[Any]:
        """Return the first column value of ``row``.

        Mappings yield their first value; sequences their first item (``None`` when
        empty). Strings/bytes and non-subscriptable objects are treated as scalars
        and returned as-is (previously a ``str`` row returned only its first character).
        """
        from collections.abc import Mapping

        if isinstance(row, Mapping):
            return next(iter(row.values()), None)
        if isinstance(row, (str, bytes, bytearray)) or not hasattr(row, "__getitem__"):
            return row
        try:
            return row[0]
        except IndexError:
            return None  # empty sequence-like row
        except (KeyError, TypeError):
            return row  # subscriptable but not positionally indexable

    async def select(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> list[Any]:
        """Execute a SELECT query and return results as list.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters
            record_class: Deprecated - use schema_type in driver.execute instead

        Returns:
            List of query result rows

        Note:
            record_class parameter is ignored. Use schema_type in driver.execute
            or _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_record_class(record_class)

        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

        if isinstance(result, SQLResult) and result.data is not None:
            return list(result.data)
        return []

    async def select_one(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Optional[Any]:
        """Execute a SELECT query and return first result.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters
            record_class: Deprecated - use schema_type in driver.execute instead

        Returns:
            First result row or None

        Note:
            record_class parameter is ignored. Use schema_type in driver.execute
            or _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_record_class(record_class)

        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

        if isinstance(result, SQLResult) and result.data:
            return result.data[0]
        return None

    async def select_value(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute a SELECT query and return first value of first row.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Returns:
            First value of first row or None
        """
        row = await self.select_one(conn, query_name, sql, parameters)
        if row is None:
            return None
        return self._first_value(row)

    @asynccontextmanager
    async def select_cursor(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> AsyncGenerator[Any, None]:
        """Execute a SELECT query and return cursor context manager.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Yields:
            Cursor-like object exposing async ``fetchall``/``fetchone`` over the
            already fetched result rows.
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

        class AsyncCursorLike:
            def __init__(self, result: Any) -> None:
                self.result = result

            async def fetchall(self) -> list[Any]:
                # Read the instance's own result rather than the enclosing closure.
                if isinstance(self.result, SQLResult) and self.result.data is not None:
                    return list(self.result.data)
                return []

            async def fetchone(self) -> Optional[Any]:
                rows = await self.fetchall()
                return rows[0] if rows else None

        yield AsyncCursorLike(result)

    async def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> None:
        """Execute INSERT/UPDATE/DELETE.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Note:
            Async version returns None per aiosql protocol
        """
        sql_obj = self._create_sql_object(sql, parameters)
        await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

    async def insert_update_delete_many(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> None:
        """Execute INSERT/UPDATE/DELETE with many parameter sets.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Sequence of parameter sets

        Note:
            Async version returns None per aiosql protocol
        """
        # For executemany, parameter sets are passed to the driver directly rather
        # than being bound onto the SQL object.
        sql_obj = self._create_sql_object(sql)
        await self.driver.execute_many(sql_obj, parameters=parameters, connection=conn)  # type: ignore[misc]

    async def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute INSERT with RETURNING and return result.

        Args:
            conn: Database connection
            query_name: Name of the query
            sql: SQL string
            parameters: Query parameters

        Returns:
            Returned value or None
        """
        # INSERT RETURNING is treated like a select that returns data
        return await self.select_one(conn, query_name, sql, parameters)
@@ -2,9 +2,4 @@ from sqlspec.extensions.litestar import handlers, providers
2
2
  from sqlspec.extensions.litestar.config import DatabaseConfig
3
3
  from sqlspec.extensions.litestar.plugin import SQLSpec
4
4
 
5
- __all__ = (
6
- "DatabaseConfig",
7
- "SQLSpec",
8
- "handlers",
9
- "providers",
10
- )
5
+ __all__ = ("DatabaseConfig", "SQLSpec", "handlers", "providers")
@@ -3,11 +3,7 @@ from typing import TYPE_CHECKING, Any
3
3
  if TYPE_CHECKING:
4
4
  from litestar.types import Scope
5
5
 
6
- __all__ = (
7
- "delete_sqlspec_scope_state",
8
- "get_sqlspec_scope_state",
9
- "set_sqlspec_scope_state",
10
- )
6
+ __all__ = ("delete_sqlspec_scope_state", "get_sqlspec_scope_state", "set_sqlspec_scope_state")
11
7
 
12
8
  _SCOPE_NAMESPACE = "_sqlspec"
13
9
 
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
19
19
  from litestar.datastructures.state import State
20
20
  from litestar.types import BeforeMessageSendHookHandler, Scope
21
21
 
22
- from sqlspec.base import AsyncConfigT, DriverT, SyncConfigT
22
+ from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT
23
23
  from sqlspec.typing import ConnectionT, PoolT
24
24
 
25
25
 
@@ -48,6 +48,7 @@ class DatabaseConfig:
48
48
  commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
49
49
  extra_commit_statuses: "Optional[set[int]]" = field(default=None)
50
50
  extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
51
+ enable_correlation_middleware: bool = field(default=True)
51
52
  connection_provider: "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]" = field( # pyright: ignore[reportGeneralTypeIssues]
52
53
  init=False, repr=False, hash=False
53
54
  )
@@ -55,14 +56,12 @@ class DatabaseConfig:
55
56
  session_provider: "Callable[[Any], AsyncGenerator[DriverT, None]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues]
56
57
  before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
57
58
  lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
58
- init=False,
59
- repr=False,
60
- hash=False,
59
+ init=False, repr=False, hash=False
61
60
  )
62
61
  annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False) # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
63
62
 
64
63
  def __post_init__(self) -> None:
65
- if not self.config.support_connection_pooling and self.pool_key == DEFAULT_POOL_KEY: # type: ignore[union-attr,unused-ignore]
64
+ if not self.config.supports_connection_pooling and self.pool_key == DEFAULT_POOL_KEY: # type: ignore[union-attr,unused-ignore]
66
65
  """If the database configuration does not support connection pooling, the pool key must be unique. We just automatically generate a unique identify so it won't conflict with other configs that may get added"""
67
66
  self.pool_key = f"_{self.pool_key}_{id(self.config)}"
68
67
  if self.commit_mode == "manual":
@@ -82,7 +81,7 @@ class DatabaseConfig:
82
81
  connection_scope_key=self.connection_key,
83
82
  )
84
83
  else:
85
- msg = f"Invalid commit mode: {self.commit_mode}" # type: ignore[unreachable]
84
+ msg = f"Invalid commit mode: {self.commit_mode}"
86
85
  raise ImproperConfigurationError(detail=msg)
87
86
  self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key)
88
87
  self.connection_provider = connection_provider_maker(
@@ -1,4 +1,3 @@
1
- # ruff: noqa: PLC2801
2
1
  import contextlib
3
2
  import inspect
4
3
  from collections.abc import AsyncGenerator
@@ -22,10 +21,9 @@ if TYPE_CHECKING:
22
21
  from litestar.datastructures.state import State
23
22
  from litestar.types import Message, Scope
24
23
 
25
- from sqlspec.base import DatabaseConfigProtocol, DriverT
24
+ from sqlspec.config import DatabaseConfigProtocol, DriverT
26
25
  from sqlspec.typing import ConnectionT, PoolT
27
26
 
28
-
29
27
  SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
30
28
  """ASGI events that terminate a session scope."""
31
29
 
@@ -125,8 +123,7 @@ def autocommit_handler_maker(
125
123
 
126
124
 
127
125
  def lifespan_handler_maker(
128
- config: "DatabaseConfigProtocol[Any, Any, Any]",
129
- pool_key: str,
126
+ config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
130
127
  ) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
131
128
  """Build the lifespan handler for managing the database connection pool.
132
129
 
@@ -158,7 +155,7 @@ def lifespan_handler_maker(
158
155
  app.state.pop(pool_key, None)
159
156
  try:
160
157
  await ensure_async_(config.close_pool)()
161
- except Exception as e: # noqa: BLE001
158
+ except Exception as e:
162
159
  if app.logger: # pragma: no cover
163
160
  app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)
164
161
 
@@ -208,9 +205,7 @@ def pool_provider_maker(
208
205
 
209
206
 
210
207
  def connection_provider_maker(
211
- config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
212
- pool_key: str,
213
- connection_key: str,
208
+ config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
214
209
  ) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
215
210
  async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
216
211
  db_pool = state.get(pool_key)
@@ -238,7 +233,7 @@ def connection_provider_maker(
238
233
  finally:
239
234
  if entered_connection is not None:
240
235
  await connection_cm.__aexit__(None, None, None)
241
- delete_sqlspec_scope_state(scope, connection_key) # Optional: clear from scope
236
+ delete_sqlspec_scope_state(scope, connection_key) # Clear from scope
242
237
 
243
238
  return provide_connection
244
239
 
@@ -251,8 +246,14 @@ def session_provider_maker(
251
246
 
252
247
  conn_type_annotation = config.connection_type
253
248
 
249
+ # Import Dependency at function level to avoid circular imports
250
+ from litestar.params import Dependency
251
+
254
252
  db_conn_param = inspect.Parameter(
255
- name=connection_dependency_key, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=conn_type_annotation
253
+ name=connection_dependency_key,
254
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
255
+ annotation=conn_type_annotation,
256
+ default=Dependency(skip_validation=True),
256
257
  )
257
258
 
258
259
  provider_signature = inspect.Signature(
@@ -266,6 +267,6 @@ def session_provider_maker(
266
267
  provide_session.__annotations__ = {}
267
268
 
268
269
  provide_session.__annotations__[connection_dependency_key] = conn_type_annotation
269
- provide_session.__annotations__["return"] = config.driver_type
270
+ provide_session.__annotations__["return"] = AsyncGenerator[config.driver_type, None] # type: ignore[name-defined]
270
271
 
271
272
  return provide_session