sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -621
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -431
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +218 -436
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +417 -487
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +600 -553
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +392 -406
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +548 -921
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -533
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +725 -0
- sqlspec/adapters/psycopg/driver.py +734 -694
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +242 -405
- sqlspec/base.py +220 -784
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
- sqlspec-0.12.1.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -331
- sqlspec/mixins.py +0 -305
- sqlspec/statement.py +0 -378
- sqlspec-0.11.1.dist-info/RECORD +0 -69
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
"""AioSQL adapter implementation for SQLSpec.
|
|
2
|
+
|
|
3
|
+
This module provides adapter classes that implement the aiosql adapter protocols
|
|
4
|
+
while using SQLSpec drivers under the hood. This enables users to load SQL queries
|
|
5
|
+
from files using aiosql while leveraging all of SQLSpec's advanced features.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from collections.abc import AsyncGenerator, Generator
|
|
10
|
+
from contextlib import asynccontextmanager, contextmanager
|
|
11
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar, Union, cast
|
|
12
|
+
|
|
13
|
+
from sqlspec.exceptions import MissingDependencyError
|
|
14
|
+
from sqlspec.statement.result import SQLResult
|
|
15
|
+
from sqlspec.statement.sql import SQL, SQLConfig
|
|
16
|
+
from sqlspec.typing import AIOSQL_INSTALLED, RowT
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
|
|
20
|
+
|
|
21
|
+
# Public adapter classes exported by this module.
__all__ = ("AiosqlAsyncAdapter", "AiosqlSyncAdapter")

# Module logger; the adapters use it to warn about the ignored record_class argument.
logger = logging.getLogger("sqlspec.extensions.aiosql")

# Generic type variable (not referenced elsewhere in this module as shown).
T = TypeVar("T")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _check_aiosql_available() -> None:
    """Ensure the optional aiosql dependency is installed.

    Raises:
        MissingDependencyError: If aiosql is not available.
    """
    if AIOSQL_INSTALLED:
        return
    package = "aiosql"
    raise MissingDependencyError(package, "aiosql")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _normalize_dialect(dialect: "Union[str, Any, None]") -> str:
    """Convert a dialect identifier into the name SQLGlot understands.

    Args:
        dialect: The dialect as a string, a Dialect class, an object exposing a
            ``name`` attribute, any other object, or None.

    Returns:
        The lowercased dialect name, with driver/library aliases (for example
        ``"asyncpg"`` or ``"sqlite3"``) folded into their SQLGlot equivalents.
        Returns ``"sql"`` when ``dialect`` is None.
    """
    if dialect is None:
        return "sql"

    # Resolve a raw name. Check order matters: anything exposing __name__
    # (e.g. a class) wins, then objects with a ``name`` attribute, then strings.
    if hasattr(dialect, "__name__"):
        name = str(dialect.__name__).lower()  # pyright: ignore
    elif hasattr(dialect, "name"):
        name = str(dialect.name).lower()  # pyright: ignore
    elif isinstance(dialect, str):
        name = dialect.lower()
    else:
        name = str(dialect).lower()

    # Driver/library names that SQLGlot knows under a different dialect name.
    postgres_aliases = ("postgresql", "psycopg", "asyncpg", "psqlpy")
    sqlite_aliases = ("sqlite3", "aiosqlite")

    if name in postgres_aliases:
        return "postgres"
    if name in sqlite_aliases:
        return "sqlite"
    return name
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class _AiosqlAdapterBase:
    """Common state and helpers shared by the sync and async aiosql adapters."""

    def __init__(
        self, driver: "Union[SyncDriverAdapterProtocol[Any, Any], AsyncDriverAdapterProtocol[Any, Any]]"
    ) -> None:
        """Store the SQLSpec driver after confirming aiosql is importable.

        Args:
            driver: SQLSpec driver (sync or async) that will execute statements.

        Raises:
            MissingDependencyError: If aiosql is not installed.
        """
        _check_aiosql_available()
        self.driver = driver

    def process_sql(self, query_name: str, op_type: "Any", sql: str) -> str:
        """aiosql hook for rewriting query text; this adapter leaves it untouched.

        Args:
            query_name: Name of the query being processed (unused).
            op_type: aiosql operation type (unused).
            sql: Raw SQL text.

        Returns:
            ``sql`` exactly as received.
        """
        return sql

    def _create_sql_object(self, sql: str, parameters: "Any" = None) -> SQL:
        """Wrap raw SQL text and parameters in a SQLSpec ``SQL`` statement.

        Strict mode and validation are disabled, and the driver's dialect is
        normalized to a SQLGlot dialect name.

        Args:
            sql: SQL text.
            parameters: Bound parameters, or None.

        Returns:
            The configured ``SQL`` object.
        """
        statement_config = SQLConfig(strict_mode=False, enable_validation=False)
        return SQL(
            sql,
            parameters,
            config=statement_config,
            dialect=_normalize_dialect(self.driver.dialect),
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class AiosqlSyncAdapter(_AiosqlAdapterBase):
    """Sync adapter that implements the aiosql driver protocol using SQLSpec drivers.

    This adapter bridges aiosql's sync driver protocol with SQLSpec's sync drivers,
    enabling all of SQLSpec's drivers to work with queries loaded by aiosql.
    """

    is_aio_driver: ClassVar[bool] = False

    def __init__(self, driver: "SyncDriverAdapterProtocol[Any, Any]") -> None:
        """Initialize the sync adapter.

        Args:
            driver: SQLSpec sync driver to use for execution.
        """
        super().__init__(driver)

    @staticmethod
    def _warn_if_record_class(record_class: Optional[Any]) -> None:
        """Log that ``record_class`` is ignored when aiosql supplies one.

        Args:
            record_class: Value passed by aiosql; only whether it is None matters.
        """
        if record_class is not None:
            logger.warning(
                "record_class parameter is deprecated and ignored. "
                "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
            )

    @staticmethod
    def _first_value(row: Any) -> Optional[Any]:
        """Extract the first column value from a single result row.

        Args:
            row: A mapping row, an indexable (tuple/list-like) row, or any other object.

        Returns:
            The first value of a mapping (iteration order), element ``0`` of an
            indexable row, None for an empty row, or ``row`` itself otherwise.
        """
        # Local import keeps the module's top-level imports unchanged.
        from collections.abc import Mapping

        if isinstance(row, Mapping):
            # Read any mapping (not just dict) by value: indexing a mapping with 0
            # would look up the key 0 and raise KeyError.
            return next(iter(row.values()), None)
        if hasattr(row, "__getitem__"):
            # Tuple/list-like rows; EAFP so rows without __len__ still work.
            try:
                return row[0]
            except IndexError:
                return None
        # Scalar or attribute-style object: return as-is.
        return row

    def select(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Generator[Any, None, None]:
        """Execute a SELECT query and yield its rows.

        Args:
            conn: Database connection (passed through to the SQLSpec driver).
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.
            record_class: Deprecated - use schema_type in driver.execute instead.

        Yields:
            Query result rows.

        Note:
            This is a generator, so the query executes only when iteration starts.
            record_class is ignored; use schema_type in driver.execute or
            _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_if_record_class(record_class)
        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)
        if isinstance(result, SQLResult) and result.data is not None:
            yield from result.data

    def select_one(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Optional[RowT]:
        """Execute a SELECT query and return its first row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.
            record_class: Deprecated - use schema_type in driver.execute instead.

        Returns:
            The first result row, or None when there are no rows.

        Note:
            record_class is ignored; use schema_type in driver.execute or
            _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_if_record_class(record_class)
        sql_obj = self._create_sql_object(sql, parameters)
        result = cast("SQLResult[RowT]", self.driver.execute(sql_obj, connection=conn))
        if isinstance(result, SQLResult) and result.data:
            return cast("Optional[RowT]", result.data[0])
        return None

    def select_value(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute a SELECT query and return the first value of the first row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Returns:
            The first value of the first row, or None when there are no rows.
        """
        row = self.select_one(conn, query_name, sql, parameters)
        if row is None:
            return None
        return self._first_value(row)

    @contextmanager
    def select_cursor(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Generator[Any, None, None]:
        """Execute a SELECT query and expose the result through a cursor-like object.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Yields:
            An object with ``fetchall()``/``fetchone()`` over the already-executed result.
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = self.driver.execute(sql_obj, connection=conn)

        class CursorLike:
            """Minimal cursor facade over an executed SQLSpec result."""

            def __init__(self, result: Any) -> None:
                self.result = result

            def fetchall(self) -> list[Any]:
                # Read the stored result rather than the enclosing closure variable.
                if isinstance(self.result, SQLResult) and self.result.data is not None:
                    return list(self.result.data)
                return []

            def fetchone(self) -> Optional[Any]:
                rows = self.fetchall()
                return rows[0] if rows else None

        yield CursorLike(result)

    def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> int:
        """Execute INSERT/UPDATE/DELETE and return the number of affected rows.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Returns:
            Number of affected rows (0 when the result reports none).
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = cast("SQLResult[Any]", self.driver.execute(sql_obj, connection=conn))
        # Fall back to 0 when the attribute is missing or None, honoring the int return.
        return getattr(result, "rows_affected", None) or 0

    def insert_update_delete_many(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> int:
        """Execute INSERT/UPDATE/DELETE once per parameter set.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Sequence of parameter sets.

        Returns:
            Number of affected rows (0 when the result reports none).
        """
        # For executemany, sqlspec filters are not extracted from individual parameter sets.
        sql_obj = self._create_sql_object(sql)
        result = cast("SQLResult[Any]", self.driver.execute_many(sql_obj, parameters=parameters, connection=conn))
        return getattr(result, "rows_affected", None) or 0

    def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute INSERT ... RETURNING and return the first returned row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Returns:
            The first returned row, or None.
        """
        # INSERT ... RETURNING produces rows, so it is handled like a single-row select.
        return self.select_one(conn, query_name, sql, parameters)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class AiosqlAsyncAdapter(_AiosqlAdapterBase):
    """Async adapter that implements the aiosql driver protocol using SQLSpec drivers.

    This adapter bridges aiosql's async driver protocol with SQLSpec's async drivers,
    enabling all of SQLSpec's features to work with queries loaded by aiosql.
    """

    is_aio_driver: ClassVar[bool] = True

    def __init__(self, driver: "AsyncDriverAdapterProtocol[Any, Any]") -> None:
        """Initialize the async adapter.

        Args:
            driver: SQLSpec async driver to use for execution.
        """
        super().__init__(driver)

    @staticmethod
    def _warn_if_record_class(record_class: Optional[Any]) -> None:
        """Log that ``record_class`` is ignored when aiosql supplies one.

        Args:
            record_class: Value passed by aiosql; only whether it is None matters.
        """
        if record_class is not None:
            logger.warning(
                "record_class parameter is deprecated and ignored. "
                "Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
            )

    @staticmethod
    def _first_value(row: Any) -> Optional[Any]:
        """Extract the first column value from a single result row.

        Args:
            row: A mapping row, an indexable (tuple/list-like) row, or any other object.

        Returns:
            The first value of a mapping (iteration order), element ``0`` of an
            indexable row, None for an empty row, or ``row`` itself otherwise.
        """
        # Local import keeps the module's top-level imports unchanged.
        from collections.abc import Mapping

        if isinstance(row, Mapping):
            # Read any mapping (not just dict) by value: indexing a mapping with 0
            # would look up the key 0 and raise KeyError.
            return next(iter(row.values()), None)
        if hasattr(row, "__getitem__"):
            # Tuple/list-like rows; EAFP so rows without __len__ still work.
            try:
                return row[0]
            except IndexError:
                return None
        # Scalar or attribute-style object: return as-is.
        return row

    async def select(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> list[Any]:
        """Execute a SELECT query and return all rows as a list.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.
            record_class: Deprecated - use schema_type in driver.execute instead.

        Returns:
            List of query result rows (empty when there is no data).

        Note:
            record_class is ignored; use schema_type in driver.execute or
            _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_if_record_class(record_class)
        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]
        if isinstance(result, SQLResult) and result.data is not None:
            return list(result.data)
        return []

    async def select_one(
        self, conn: Any, query_name: str, sql: str, parameters: "Any", record_class: Optional[Any] = None
    ) -> Optional[Any]:
        """Execute a SELECT query and return its first row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.
            record_class: Deprecated - use schema_type in driver.execute instead.

        Returns:
            The first result row, or None when there are no rows.

        Note:
            record_class is ignored; use schema_type in driver.execute or
            _sqlspec_schema_type in parameters for type mapping.
        """
        self._warn_if_record_class(record_class)
        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]
        if isinstance(result, SQLResult) and result.data:
            return result.data[0]
        return None

    async def select_value(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute a SELECT query and return the first value of the first row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Returns:
            The first value of the first row, or None when there are no rows.
        """
        row = await self.select_one(conn, query_name, sql, parameters)
        if row is None:
            return None
        return self._first_value(row)

    @asynccontextmanager
    async def select_cursor(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> AsyncGenerator[Any, None]:
        """Execute a SELECT query and expose the result through an async cursor-like object.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Yields:
            An object with awaitable ``fetchall()``/``fetchone()`` over the
            already-executed result.
        """
        sql_obj = self._create_sql_object(sql, parameters)
        result = await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

        class AsyncCursorLike:
            """Minimal async cursor facade over an executed SQLSpec result."""

            def __init__(self, result: Any) -> None:
                self.result = result

            async def fetchall(self) -> list[Any]:
                # Instance method reading the stored result (was a staticmethod
                # reading the enclosing closure variable).
                if isinstance(self.result, SQLResult) and self.result.data is not None:
                    return list(self.result.data)
                return []

            async def fetchone(self) -> Optional[Any]:
                rows = await self.fetchall()
                return rows[0] if rows else None

        yield AsyncCursorLike(result)

    async def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> None:
        """Execute INSERT/UPDATE/DELETE.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Note:
            The async version returns None per the aiosql protocol.
        """
        sql_obj = self._create_sql_object(sql, parameters)
        await self.driver.execute(sql_obj, connection=conn)  # type: ignore[misc]

    async def insert_update_delete_many(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> None:
        """Execute INSERT/UPDATE/DELETE once per parameter set.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Sequence of parameter sets.

        Note:
            The async version returns None per the aiosql protocol.
        """
        # For executemany, sqlspec filters are not extracted from individual parameter sets.
        sql_obj = self._create_sql_object(sql)
        await self.driver.execute_many(sql_obj, parameters=parameters, connection=conn)  # type: ignore[misc]

    async def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "Any") -> Optional[Any]:
        """Execute INSERT ... RETURNING and return the first returned row.

        Args:
            conn: Database connection.
            query_name: Name of the query.
            sql: SQL string.
            parameters: Query parameters.

        Returns:
            The first returned row, or None.
        """
        # INSERT ... RETURNING produces rows, so it is handled like a single-row select.
        return await self.select_one(conn, query_name, sql, parameters)
|
|
@@ -2,9 +2,4 @@ from sqlspec.extensions.litestar import handlers, providers
|
|
|
2
2
|
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
3
3
|
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
4
4
|
|
|
5
|
-
__all__ = (
|
|
6
|
-
"DatabaseConfig",
|
|
7
|
-
"SQLSpec",
|
|
8
|
-
"handlers",
|
|
9
|
-
"providers",
|
|
10
|
-
)
|
|
5
|
+
__all__ = ("DatabaseConfig", "SQLSpec", "handlers", "providers")
|
|
@@ -3,11 +3,7 @@ from typing import TYPE_CHECKING, Any
|
|
|
3
3
|
if TYPE_CHECKING:
|
|
4
4
|
from litestar.types import Scope
|
|
5
5
|
|
|
6
|
-
__all__ = (
|
|
7
|
-
"delete_sqlspec_scope_state",
|
|
8
|
-
"get_sqlspec_scope_state",
|
|
9
|
-
"set_sqlspec_scope_state",
|
|
10
|
-
)
|
|
6
|
+
__all__ = ("delete_sqlspec_scope_state", "get_sqlspec_scope_state", "set_sqlspec_scope_state")
|
|
11
7
|
|
|
12
8
|
_SCOPE_NAMESPACE = "_sqlspec"
|
|
13
9
|
|
|
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
|
|
|
19
19
|
from litestar.datastructures.state import State
|
|
20
20
|
from litestar.types import BeforeMessageSendHookHandler, Scope
|
|
21
21
|
|
|
22
|
-
from sqlspec.
|
|
22
|
+
from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT
|
|
23
23
|
from sqlspec.typing import ConnectionT, PoolT
|
|
24
24
|
|
|
25
25
|
|
|
@@ -48,6 +48,7 @@ class DatabaseConfig:
|
|
|
48
48
|
commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
|
|
49
49
|
extra_commit_statuses: "Optional[set[int]]" = field(default=None)
|
|
50
50
|
extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
|
|
51
|
+
enable_correlation_middleware: bool = field(default=True)
|
|
51
52
|
connection_provider: "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]" = field( # pyright: ignore[reportGeneralTypeIssues]
|
|
52
53
|
init=False, repr=False, hash=False
|
|
53
54
|
)
|
|
@@ -55,14 +56,12 @@ class DatabaseConfig:
|
|
|
55
56
|
session_provider: "Callable[[Any], AsyncGenerator[DriverT, None]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues]
|
|
56
57
|
before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
|
|
57
58
|
lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
|
|
58
|
-
init=False,
|
|
59
|
-
repr=False,
|
|
60
|
-
hash=False,
|
|
59
|
+
init=False, repr=False, hash=False
|
|
61
60
|
)
|
|
62
61
|
annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False) # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
|
|
63
62
|
|
|
64
63
|
def __post_init__(self) -> None:
|
|
65
|
-
if not self.config.
|
|
64
|
+
if not self.config.supports_connection_pooling and self.pool_key == DEFAULT_POOL_KEY: # type: ignore[union-attr,unused-ignore]
|
|
66
65
|
"""If the database configuration does not support connection pooling, the pool key must be unique. We just automatically generate a unique identify so it won't conflict with other configs that may get added"""
|
|
67
66
|
self.pool_key = f"_{self.pool_key}_{id(self.config)}"
|
|
68
67
|
if self.commit_mode == "manual":
|
|
@@ -82,7 +81,7 @@ class DatabaseConfig:
|
|
|
82
81
|
connection_scope_key=self.connection_key,
|
|
83
82
|
)
|
|
84
83
|
else:
|
|
85
|
-
msg = f"Invalid commit mode: {self.commit_mode}"
|
|
84
|
+
msg = f"Invalid commit mode: {self.commit_mode}"
|
|
86
85
|
raise ImproperConfigurationError(detail=msg)
|
|
87
86
|
self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key)
|
|
88
87
|
self.connection_provider = connection_provider_maker(
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# ruff: noqa: PLC2801
|
|
2
1
|
import contextlib
|
|
3
2
|
import inspect
|
|
4
3
|
from collections.abc import AsyncGenerator
|
|
@@ -22,10 +21,9 @@ if TYPE_CHECKING:
|
|
|
22
21
|
from litestar.datastructures.state import State
|
|
23
22
|
from litestar.types import Message, Scope
|
|
24
23
|
|
|
25
|
-
from sqlspec.
|
|
24
|
+
from sqlspec.config import DatabaseConfigProtocol, DriverT
|
|
26
25
|
from sqlspec.typing import ConnectionT, PoolT
|
|
27
26
|
|
|
28
|
-
|
|
29
27
|
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
|
|
30
28
|
"""ASGI events that terminate a session scope."""
|
|
31
29
|
|
|
@@ -125,8 +123,7 @@ def autocommit_handler_maker(
|
|
|
125
123
|
|
|
126
124
|
|
|
127
125
|
def lifespan_handler_maker(
|
|
128
|
-
config: "DatabaseConfigProtocol[Any, Any, Any]",
|
|
129
|
-
pool_key: str,
|
|
126
|
+
config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
|
|
130
127
|
) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
|
|
131
128
|
"""Build the lifespan handler for managing the database connection pool.
|
|
132
129
|
|
|
@@ -158,7 +155,7 @@ def lifespan_handler_maker(
|
|
|
158
155
|
app.state.pop(pool_key, None)
|
|
159
156
|
try:
|
|
160
157
|
await ensure_async_(config.close_pool)()
|
|
161
|
-
except Exception as e:
|
|
158
|
+
except Exception as e:
|
|
162
159
|
if app.logger: # pragma: no cover
|
|
163
160
|
app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)
|
|
164
161
|
|
|
@@ -208,9 +205,7 @@ def pool_provider_maker(
|
|
|
208
205
|
|
|
209
206
|
|
|
210
207
|
def connection_provider_maker(
|
|
211
|
-
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
|
|
212
|
-
pool_key: str,
|
|
213
|
-
connection_key: str,
|
|
208
|
+
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
|
|
214
209
|
) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
|
|
215
210
|
async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
|
|
216
211
|
db_pool = state.get(pool_key)
|
|
@@ -238,7 +233,7 @@ def connection_provider_maker(
|
|
|
238
233
|
finally:
|
|
239
234
|
if entered_connection is not None:
|
|
240
235
|
await connection_cm.__aexit__(None, None, None)
|
|
241
|
-
|
|
236
|
+
delete_sqlspec_scope_state(scope, connection_key) # Clear from scope
|
|
242
237
|
|
|
243
238
|
return provide_connection
|
|
244
239
|
|
|
@@ -251,8 +246,14 @@ def session_provider_maker(
|
|
|
251
246
|
|
|
252
247
|
conn_type_annotation = config.connection_type
|
|
253
248
|
|
|
249
|
+
# Import Dependency at function level to avoid circular imports
|
|
250
|
+
from litestar.params import Dependency
|
|
251
|
+
|
|
254
252
|
db_conn_param = inspect.Parameter(
|
|
255
|
-
name=connection_dependency_key,
|
|
253
|
+
name=connection_dependency_key,
|
|
254
|
+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
255
|
+
annotation=conn_type_annotation,
|
|
256
|
+
default=Dependency(skip_validation=True),
|
|
256
257
|
)
|
|
257
258
|
|
|
258
259
|
provider_signature = inspect.Signature(
|
|
@@ -266,6 +267,6 @@ def session_provider_maker(
|
|
|
266
267
|
provide_session.__annotations__ = {}
|
|
267
268
|
|
|
268
269
|
provide_session.__annotations__[connection_dependency_key] = conn_type_annotation
|
|
269
|
-
provide_session.__annotations__["return"] = config.driver_type
|
|
270
|
+
provide_session.__annotations__["return"] = AsyncGenerator[config.driver_type, None] # type: ignore[name-defined]
|
|
270
271
|
|
|
271
272
|
return provide_session
|