sqlspec 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/adapters/adbc/driver.py +192 -28
- sqlspec/adapters/asyncmy/driver.py +72 -15
- sqlspec/adapters/asyncpg/config.py +23 -3
- sqlspec/adapters/asyncpg/driver.py +30 -14
- sqlspec/adapters/bigquery/driver.py +79 -9
- sqlspec/adapters/duckdb/driver.py +39 -56
- sqlspec/adapters/oracledb/driver.py +99 -52
- sqlspec/adapters/psqlpy/driver.py +89 -31
- sqlspec/adapters/psycopg/driver.py +11 -23
- sqlspec/adapters/sqlite/driver.py +77 -8
- sqlspec/base.py +11 -11
- sqlspec/builder/__init__.py +1 -1
- sqlspec/builder/_base.py +4 -5
- sqlspec/builder/_column.py +3 -3
- sqlspec/builder/_ddl.py +5 -1
- sqlspec/builder/_delete.py +5 -6
- sqlspec/builder/_insert.py +6 -7
- sqlspec/builder/_merge.py +5 -5
- sqlspec/builder/_parsing_utils.py +3 -3
- sqlspec/builder/_select.py +6 -5
- sqlspec/builder/_update.py +4 -5
- sqlspec/builder/mixins/_cte_and_set_ops.py +5 -1
- sqlspec/builder/mixins/_delete_operations.py +5 -1
- sqlspec/builder/mixins/_insert_operations.py +5 -1
- sqlspec/builder/mixins/_join_operations.py +5 -0
- sqlspec/builder/mixins/_merge_operations.py +5 -1
- sqlspec/builder/mixins/_order_limit_operations.py +5 -1
- sqlspec/builder/mixins/_pivot_operations.py +4 -1
- sqlspec/builder/mixins/_select_operations.py +5 -1
- sqlspec/builder/mixins/_update_operations.py +5 -1
- sqlspec/builder/mixins/_where_clause.py +5 -1
- sqlspec/config.py +15 -15
- sqlspec/core/compiler.py +11 -3
- sqlspec/core/filters.py +30 -9
- sqlspec/core/parameters.py +67 -67
- sqlspec/core/result.py +62 -31
- sqlspec/core/splitter.py +160 -34
- sqlspec/core/statement.py +95 -14
- sqlspec/driver/_common.py +12 -3
- sqlspec/driver/mixins/_result_tools.py +21 -4
- sqlspec/driver/mixins/_sql_translator.py +45 -7
- sqlspec/extensions/aiosql/adapter.py +1 -1
- sqlspec/extensions/litestar/_utils.py +1 -1
- sqlspec/extensions/litestar/config.py +186 -2
- sqlspec/extensions/litestar/handlers.py +21 -0
- sqlspec/extensions/litestar/plugin.py +237 -3
- sqlspec/loader.py +12 -12
- sqlspec/migrations/loaders.py +5 -2
- sqlspec/migrations/utils.py +2 -2
- sqlspec/storage/backends/obstore.py +1 -3
- sqlspec/storage/registry.py +1 -1
- sqlspec/utils/__init__.py +7 -0
- sqlspec/utils/deprecation.py +6 -0
- sqlspec/utils/fixtures.py +239 -30
- sqlspec/utils/module_loader.py +5 -1
- sqlspec/utils/serializers.py +6 -0
- sqlspec/utils/singleton.py +6 -0
- sqlspec/utils/sync_tools.py +10 -1
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/METADATA +230 -44
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/RECORD +64 -64
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.19.0.dist-info → sqlspec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""SQL translation mixin for cross-database compatibility."""
|
|
2
|
+
|
|
1
3
|
from typing import Final, NoReturn, Optional
|
|
2
4
|
|
|
3
5
|
from mypy_extensions import trait
|
|
@@ -33,8 +35,7 @@ class SQLTranslatorMixin:
|
|
|
33
35
|
Returns:
|
|
34
36
|
SQL string in target dialect
|
|
35
37
|
|
|
36
|
-
|
|
37
|
-
SQLConversionError: If parsing or conversion fails
|
|
38
|
+
|
|
38
39
|
"""
|
|
39
40
|
|
|
40
41
|
parsed_expression: Optional[exp.Expression] = None
|
|
@@ -53,7 +54,15 @@ class SQLTranslatorMixin:
|
|
|
53
54
|
return self._generate_sql_safely(parsed_expression, target_dialect, pretty)
|
|
54
55
|
|
|
55
56
|
def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
|
|
56
|
-
"""Parse statement with
|
|
57
|
+
"""Parse statement with error handling.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
statement: SQL statement to parse
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
Parsed expression
|
|
64
|
+
|
|
65
|
+
"""
|
|
57
66
|
try:
|
|
58
67
|
sql_string = str(statement)
|
|
59
68
|
|
|
@@ -62,23 +71,52 @@ class SQLTranslatorMixin:
|
|
|
62
71
|
self._raise_parse_error(e)
|
|
63
72
|
|
|
64
73
|
def _generate_sql_safely(self, expression: "exp.Expression", dialect: DialectType, pretty: bool) -> str:
|
|
65
|
-
"""Generate SQL with
|
|
74
|
+
"""Generate SQL with error handling.
|
|
75
|
+
|
|
76
|
+
Args:
|
|
77
|
+
expression: Parsed expression to convert
|
|
78
|
+
dialect: Target SQL dialect
|
|
79
|
+
pretty: Whether to format the output SQL
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
Generated SQL string
|
|
83
|
+
|
|
84
|
+
"""
|
|
66
85
|
try:
|
|
67
86
|
return expression.sql(dialect=dialect, pretty=pretty)
|
|
68
87
|
except Exception as e:
|
|
69
88
|
self._raise_conversion_error(dialect, e)
|
|
70
89
|
|
|
71
90
|
def _raise_statement_parse_error(self) -> NoReturn:
|
|
72
|
-
"""Raise error for unparsable statements.
|
|
91
|
+
"""Raise error for unparsable statements.
|
|
92
|
+
|
|
93
|
+
Raises:
|
|
94
|
+
SQLConversionError: Always raised
|
|
95
|
+
"""
|
|
73
96
|
msg = "Statement could not be parsed"
|
|
74
97
|
raise SQLConversionError(msg)
|
|
75
98
|
|
|
76
99
|
def _raise_parse_error(self, e: Exception) -> NoReturn:
|
|
77
|
-
"""Raise error for parsing failures.
|
|
100
|
+
"""Raise error for parsing failures.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
e: Original exception that caused the failure
|
|
104
|
+
|
|
105
|
+
Raises:
|
|
106
|
+
SQLConversionError: Always raised
|
|
107
|
+
"""
|
|
78
108
|
error_msg = f"Failed to parse SQL statement: {e!s}"
|
|
79
109
|
raise SQLConversionError(error_msg) from e
|
|
80
110
|
|
|
81
111
|
def _raise_conversion_error(self, dialect: DialectType, e: Exception) -> NoReturn:
|
|
82
|
-
"""Raise error for conversion failures.
|
|
112
|
+
"""Raise error for conversion failures.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
dialect: Target dialect that caused the failure
|
|
116
|
+
e: Original exception that caused the failure
|
|
117
|
+
|
|
118
|
+
Raises:
|
|
119
|
+
SQLConversionError: Always raised
|
|
120
|
+
"""
|
|
83
121
|
error_msg = f"Failed to convert SQL expression to {dialect}: {e!s}"
|
|
84
122
|
raise SQLConversionError(error_msg) from e
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
This module provides adapter classes that implement the aiosql adapter protocols
|
|
4
4
|
while using SQLSpec drivers under the hood. This enables users to load SQL queries
|
|
5
|
-
from files using aiosql while
|
|
5
|
+
from files using aiosql while using SQLSpec's features for execution and type mapping.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import logging
|
|
@@ -12,7 +12,7 @@ def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop:
|
|
|
12
12
|
"""Get an internal value from connection scope state.
|
|
13
13
|
|
|
14
14
|
Note:
|
|
15
|
-
If called with a default value, this method behaves like
|
|
15
|
+
If called with a default value, this method behaves like `dict.setdefault()`, both setting the key in the
|
|
16
16
|
namespace to the default value, and returning it.
|
|
17
17
|
|
|
18
18
|
If called without a default value, the method behaves like `dict.get()`, returning ``None`` if the key does not
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
|
-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
|
|
3
3
|
|
|
4
4
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
5
|
+
from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state
|
|
5
6
|
from sqlspec.extensions.litestar.handlers import (
|
|
6
7
|
autocommit_handler_maker,
|
|
7
8
|
connection_provider_maker,
|
|
@@ -13,13 +14,14 @@ from sqlspec.extensions.litestar.handlers import (
|
|
|
13
14
|
|
|
14
15
|
if TYPE_CHECKING:
|
|
15
16
|
from collections.abc import AsyncGenerator, Awaitable
|
|
16
|
-
from contextlib import AbstractAsyncContextManager
|
|
17
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
17
18
|
|
|
18
19
|
from litestar import Litestar
|
|
19
20
|
from litestar.datastructures.state import State
|
|
20
21
|
from litestar.types import BeforeMessageSendHookHandler, Scope
|
|
21
22
|
|
|
22
23
|
from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT
|
|
24
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
23
25
|
from sqlspec.typing import ConnectionT, PoolT
|
|
24
26
|
|
|
25
27
|
|
|
@@ -34,8 +36,10 @@ __all__ = (
|
|
|
34
36
|
"DEFAULT_CONNECTION_KEY",
|
|
35
37
|
"DEFAULT_POOL_KEY",
|
|
36
38
|
"DEFAULT_SESSION_KEY",
|
|
39
|
+
"AsyncDatabaseConfig",
|
|
37
40
|
"CommitMode",
|
|
38
41
|
"DatabaseConfig",
|
|
42
|
+
"SyncDatabaseConfig",
|
|
39
43
|
)
|
|
40
44
|
|
|
41
45
|
|
|
@@ -90,3 +94,183 @@ class DatabaseConfig:
|
|
|
90
94
|
self.session_provider = session_provider_maker(
|
|
91
95
|
config=self.config, connection_dependency_key=self.connection_key
|
|
92
96
|
)
|
|
97
|
+
|
|
98
|
+
def get_request_session(
|
|
99
|
+
self, state: "State", scope: "Scope"
|
|
100
|
+
) -> "Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]":
|
|
101
|
+
"""Get a session instance from the current request.
|
|
102
|
+
|
|
103
|
+
This method provides access to the database session that has been added to the request
|
|
104
|
+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
|
|
105
|
+
existing session in the request scope state, and if not found, creates a new one using
|
|
106
|
+
the connection from the scope.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
state: The Litestar application State object.
|
|
110
|
+
scope: The ASGI scope containing the request context.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
A driver session instance.
|
|
114
|
+
|
|
115
|
+
Raises:
|
|
116
|
+
ImproperConfigurationError: If no connection is available in the scope.
|
|
117
|
+
"""
|
|
118
|
+
# Create a unique scope key for sessions to avoid conflicts
|
|
119
|
+
session_scope_key = f"{self.session_key}_instance"
|
|
120
|
+
|
|
121
|
+
# Try to get existing session from scope
|
|
122
|
+
session = get_sqlspec_scope_state(scope, session_scope_key)
|
|
123
|
+
if session is not None:
|
|
124
|
+
return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session)
|
|
125
|
+
|
|
126
|
+
# Get connection from scope state
|
|
127
|
+
connection = get_sqlspec_scope_state(scope, self.connection_key)
|
|
128
|
+
if connection is None:
|
|
129
|
+
msg = f"No database connection found in scope for key '{self.connection_key}'. "
|
|
130
|
+
msg += "Ensure the connection dependency is properly configured and available."
|
|
131
|
+
raise ImproperConfigurationError(detail=msg)
|
|
132
|
+
|
|
133
|
+
# Create new session using the connection
|
|
134
|
+
# Access driver_type which is available on all config types
|
|
135
|
+
session = self.config.driver_type(connection=connection) # type: ignore[union-attr]
|
|
136
|
+
|
|
137
|
+
# Store session in scope for future use
|
|
138
|
+
set_sqlspec_scope_state(scope, session_scope_key, session)
|
|
139
|
+
|
|
140
|
+
return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session)
|
|
141
|
+
|
|
142
|
+
def get_request_connection(self, state: "State", scope: "Scope") -> "Any":
|
|
143
|
+
"""Get a connection instance from the current request.
|
|
144
|
+
|
|
145
|
+
This method provides access to the database connection that has been added to the request
|
|
146
|
+
scope. This is useful in guards, middleware, or other contexts where you need direct
|
|
147
|
+
access to the connection that's been established for the current request.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
state: The Litestar application State object.
|
|
151
|
+
scope: The ASGI scope containing the request context.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
A database connection instance.
|
|
155
|
+
|
|
156
|
+
Raises:
|
|
157
|
+
ImproperConfigurationError: If no connection is available in the scope.
|
|
158
|
+
"""
|
|
159
|
+
connection = get_sqlspec_scope_state(scope, self.connection_key)
|
|
160
|
+
if connection is None:
|
|
161
|
+
msg = f"No database connection found in scope for key '{self.connection_key}'. "
|
|
162
|
+
msg += "Ensure the connection dependency is properly configured and available."
|
|
163
|
+
raise ImproperConfigurationError(detail=msg)
|
|
164
|
+
|
|
165
|
+
return cast("Any", connection)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Add passthrough methods to both specialized classes for convenience
|
|
169
|
+
class SyncDatabaseConfig(DatabaseConfig):
|
|
170
|
+
"""Sync-specific DatabaseConfig with better typing for get_request_session."""
|
|
171
|
+
|
|
172
|
+
def get_request_session(self, state: "State", scope: "Scope") -> "SyncDriverAdapterBase":
|
|
173
|
+
"""Get a sync session instance from the current request.
|
|
174
|
+
|
|
175
|
+
This method provides access to the database session that has been added to the request
|
|
176
|
+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
|
|
177
|
+
existing session in the request scope state, and if not found, creates a new one using
|
|
178
|
+
the connection from the scope.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
state: The Litestar application State object.
|
|
182
|
+
scope: The ASGI scope containing the request context.
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
A sync driver session instance.
|
|
186
|
+
"""
|
|
187
|
+
session = super().get_request_session(state, scope)
|
|
188
|
+
return cast("SyncDriverAdapterBase", session)
|
|
189
|
+
|
|
190
|
+
def provide_session(self) -> "AbstractContextManager[SyncDriverAdapterBase]":
|
|
191
|
+
"""Provide a database session context manager.
|
|
192
|
+
|
|
193
|
+
This is a passthrough to the underlying config's provide_session method
|
|
194
|
+
for convenient access to database sessions.
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
Context manager that yields a sync driver session.
|
|
198
|
+
"""
|
|
199
|
+
return self.config.provide_session() # type: ignore[union-attr,no-any-return]
|
|
200
|
+
|
|
201
|
+
def provide_connection(self) -> "AbstractContextManager[Any]":
|
|
202
|
+
"""Provide a database connection context manager.
|
|
203
|
+
|
|
204
|
+
This is a passthrough to the underlying config's provide_connection method
|
|
205
|
+
for convenient access to database connections.
|
|
206
|
+
|
|
207
|
+
Returns:
|
|
208
|
+
Context manager that yields a sync database connection.
|
|
209
|
+
"""
|
|
210
|
+
return self.config.provide_connection() # type: ignore[union-attr,no-any-return]
|
|
211
|
+
|
|
212
|
+
def create_connection(self) -> "Any":
|
|
213
|
+
"""Create and return a new database connection.
|
|
214
|
+
|
|
215
|
+
This is a passthrough to the underlying config's create_connection method
|
|
216
|
+
for direct connection creation without context management.
|
|
217
|
+
|
|
218
|
+
Returns:
|
|
219
|
+
A new sync database connection.
|
|
220
|
+
"""
|
|
221
|
+
return self.config.create_connection() # type: ignore[union-attr]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class AsyncDatabaseConfig(DatabaseConfig):
|
|
225
|
+
"""Async-specific DatabaseConfig with better typing for get_request_session."""
|
|
226
|
+
|
|
227
|
+
def get_request_session(self, state: "State", scope: "Scope") -> "AsyncDriverAdapterBase":
|
|
228
|
+
"""Get an async session instance from the current request.
|
|
229
|
+
|
|
230
|
+
This method provides access to the database session that has been added to the request
|
|
231
|
+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
|
|
232
|
+
existing session in the request scope state, and if not found, creates a new one using
|
|
233
|
+
the connection from the scope.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
state: The Litestar application State object.
|
|
237
|
+
scope: The ASGI scope containing the request context.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
An async driver session instance.
|
|
241
|
+
"""
|
|
242
|
+
session = super().get_request_session(state, scope)
|
|
243
|
+
return cast("AsyncDriverAdapterBase", session)
|
|
244
|
+
|
|
245
|
+
def provide_session(self) -> "AbstractAsyncContextManager[AsyncDriverAdapterBase]":
|
|
246
|
+
"""Provide a database session context manager.
|
|
247
|
+
|
|
248
|
+
This is a passthrough to the underlying config's provide_session method
|
|
249
|
+
for convenient access to database sessions.
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
Context manager that yields an async driver session.
|
|
253
|
+
"""
|
|
254
|
+
return self.config.provide_session() # type: ignore[union-attr,no-any-return]
|
|
255
|
+
|
|
256
|
+
def provide_connection(self) -> "AbstractAsyncContextManager[Any]":
|
|
257
|
+
"""Provide a database connection context manager.
|
|
258
|
+
|
|
259
|
+
This is a passthrough to the underlying config's provide_connection method
|
|
260
|
+
for convenient access to database connections.
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
Context manager that yields an async database connection.
|
|
264
|
+
"""
|
|
265
|
+
return self.config.provide_connection() # type: ignore[union-attr,no-any-return]
|
|
266
|
+
|
|
267
|
+
async def create_connection(self) -> "Any":
|
|
268
|
+
"""Create and return a new database connection.
|
|
269
|
+
|
|
270
|
+
This is a passthrough to the underlying config's create_connection method
|
|
271
|
+
for direct connection creation without context management.
|
|
272
|
+
|
|
273
|
+
Returns:
|
|
274
|
+
A new async database connection.
|
|
275
|
+
"""
|
|
276
|
+
return await self.config.create_connection() # type: ignore[union-attr]
|
|
@@ -199,6 +199,17 @@ def pool_provider_maker(
|
|
|
199
199
|
def connection_provider_maker(
|
|
200
200
|
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
|
|
201
201
|
) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
|
|
202
|
+
"""Create provider for database connections with proper lifecycle management.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
config: The database configuration object.
|
|
206
|
+
pool_key: The key used to retrieve the connection pool from `app.state`.
|
|
207
|
+
connection_key: The key used to store the connection in the ASGI scope.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
The connection provider function.
|
|
211
|
+
"""
|
|
212
|
+
|
|
202
213
|
async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
|
|
203
214
|
if (db_pool := state.get(pool_key)) is None:
|
|
204
215
|
msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection."
|
|
@@ -230,6 +241,16 @@ def connection_provider_maker(
|
|
|
230
241
|
def session_provider_maker(
|
|
231
242
|
config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", connection_dependency_key: str
|
|
232
243
|
) -> "Callable[[Any], AsyncGenerator[DriverT, None]]":
|
|
244
|
+
"""Create provider for database driver sessions.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
config: The database configuration object.
|
|
248
|
+
connection_dependency_key: The key used for connection dependency injection.
|
|
249
|
+
|
|
250
|
+
Returns:
|
|
251
|
+
The session provider function.
|
|
252
|
+
"""
|
|
253
|
+
|
|
233
254
|
async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, None]":
|
|
234
255
|
yield cast("DriverT", config.driver_type(connection=args[0] if args else kwargs.get(connection_dependency_key))) # pyright: ignore
|
|
235
256
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, Optional, Union
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload
|
|
2
2
|
|
|
3
3
|
from litestar.di import Provide
|
|
4
4
|
from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
@@ -6,14 +6,17 @@ from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
|
6
6
|
from sqlspec.base import SQLSpec as SQLSpecBase
|
|
7
7
|
from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT
|
|
8
8
|
from sqlspec.exceptions import ImproperConfigurationError
|
|
9
|
-
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
9
|
+
from sqlspec.extensions.litestar.config import AsyncDatabaseConfig, DatabaseConfig, SyncDatabaseConfig
|
|
10
10
|
from sqlspec.typing import ConnectionT, PoolT
|
|
11
11
|
from sqlspec.utils.logging import get_logger
|
|
12
12
|
|
|
13
13
|
if TYPE_CHECKING:
|
|
14
14
|
from click import Group
|
|
15
15
|
from litestar.config.app import AppConfig
|
|
16
|
+
from litestar.datastructures.state import State
|
|
17
|
+
from litestar.types import Scope
|
|
16
18
|
|
|
19
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
17
20
|
from sqlspec.loader import SQLFileLoader
|
|
18
21
|
|
|
19
22
|
logger = get_logger("extensions.litestar")
|
|
@@ -131,11 +134,242 @@ class SQLSpec(SQLSpecBase, InitPluginProtocol, CLIPlugin):
|
|
|
131
134
|
The annotation for the configuration.
|
|
132
135
|
"""
|
|
133
136
|
for c in self.config:
|
|
134
|
-
|
|
137
|
+
# Check annotation only if it's been set (during on_app_init)
|
|
138
|
+
annotation_match = hasattr(c, "annotation") and key == c.annotation
|
|
139
|
+
if key == c.config or annotation_match or key in {c.connection_key, c.pool_key}:
|
|
140
|
+
if not hasattr(c, "annotation"):
|
|
141
|
+
msg = (
|
|
142
|
+
"Annotation not set for configuration. Ensure the plugin has been initialized with on_app_init."
|
|
143
|
+
)
|
|
144
|
+
raise AttributeError(msg)
|
|
135
145
|
return c.annotation
|
|
136
146
|
msg = f"No configuration found for {key}"
|
|
137
147
|
raise KeyError(msg)
|
|
138
148
|
|
|
149
|
+
@overload
|
|
150
|
+
def get_config(self, name: "type[SyncConfigT]") -> "SyncConfigT": ...
|
|
151
|
+
|
|
152
|
+
@overload
|
|
153
|
+
def get_config(self, name: "type[AsyncConfigT]") -> "AsyncConfigT": ...
|
|
154
|
+
|
|
155
|
+
@overload
|
|
156
|
+
def get_config(self, name: str) -> "DatabaseConfig": ...
|
|
157
|
+
|
|
158
|
+
@overload
|
|
159
|
+
def get_config(self, name: "type[SyncDatabaseConfig]") -> "SyncDatabaseConfig": ...
|
|
160
|
+
|
|
161
|
+
@overload
|
|
162
|
+
def get_config(self, name: "type[AsyncDatabaseConfig]") -> "AsyncDatabaseConfig": ...
|
|
163
|
+
|
|
164
|
+
def get_config(
|
|
165
|
+
self, name: "Union[type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]], str, Any]"
|
|
166
|
+
) -> "Union[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT], DatabaseConfig, SyncDatabaseConfig, AsyncDatabaseConfig]":
|
|
167
|
+
"""Get a configuration instance by name, supporting both base behavior and Litestar extensions.
|
|
168
|
+
|
|
169
|
+
This method extends the base get_config to support Litestar-specific lookup patterns
|
|
170
|
+
while maintaining compatibility with the base class signature. It supports lookup by
|
|
171
|
+
connection key, pool key, session key, config instance, or annotation type.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
name: The configuration identifier - can be:
|
|
175
|
+
- Type annotation (base class behavior)
|
|
176
|
+
- connection_key (e.g., "auth_db_connection")
|
|
177
|
+
- pool_key (e.g., "analytics_db_pool")
|
|
178
|
+
- session_key (e.g., "reporting_db_session")
|
|
179
|
+
- config instance
|
|
180
|
+
- annotation type
|
|
181
|
+
|
|
182
|
+
Raises:
|
|
183
|
+
KeyError: If no configuration is found for the given name.
|
|
184
|
+
|
|
185
|
+
Returns:
|
|
186
|
+
The configuration instance for the specified name.
|
|
187
|
+
"""
|
|
188
|
+
# First try base class behavior for type-based lookup
|
|
189
|
+
# Only call super() if name matches the expected base class types
|
|
190
|
+
if not isinstance(name, str):
|
|
191
|
+
try:
|
|
192
|
+
return super().get_config(name) # type: ignore[no-any-return]
|
|
193
|
+
except (KeyError, AttributeError):
|
|
194
|
+
# Fall back to Litestar-specific lookup patterns
|
|
195
|
+
pass
|
|
196
|
+
|
|
197
|
+
# Litestar-specific lookups by string keys
|
|
198
|
+
if isinstance(name, str):
|
|
199
|
+
for c in self.config:
|
|
200
|
+
if name in {c.connection_key, c.pool_key, c.session_key}:
|
|
201
|
+
return c # Return the DatabaseConfig wrapper for string lookups
|
|
202
|
+
|
|
203
|
+
# Lookup by config instance or annotation
|
|
204
|
+
for c in self.config:
|
|
205
|
+
annotation_match = hasattr(c, "annotation") and name == c.annotation
|
|
206
|
+
if name == c.config or annotation_match:
|
|
207
|
+
return c.config # Return the underlying config for type-based lookups
|
|
208
|
+
|
|
209
|
+
msg = f"No database configuration found for name '{name}'. Available keys: {self._get_available_keys()}"
|
|
210
|
+
raise KeyError(msg)
|
|
211
|
+
|
|
212
|
+
def provide_request_session(
|
|
213
|
+
self,
|
|
214
|
+
key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]",
|
|
215
|
+
state: "State",
|
|
216
|
+
scope: "Scope",
|
|
217
|
+
) -> "Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]":
|
|
218
|
+
"""Provide a database session for the specified configuration key from request scope.
|
|
219
|
+
|
|
220
|
+
This is a convenience method that combines get_config and get_request_session
|
|
221
|
+
into a single call, similar to Advanced Alchemy's provide_session pattern.
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
key: The configuration identifier (same as get_config)
|
|
225
|
+
state: The Litestar application State object
|
|
226
|
+
scope: The ASGI scope containing the request context
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
A driver session instance for the specified database configuration
|
|
230
|
+
|
|
231
|
+
Example:
|
|
232
|
+
>>> sqlspec_plugin = connection.app.state.sqlspec
|
|
233
|
+
>>> # Direct session access by key
|
|
234
|
+
>>> auth_session = sqlspec_plugin.provide_request_session(
|
|
235
|
+
... "auth_db", state, scope
|
|
236
|
+
... )
|
|
237
|
+
>>> analytics_session = sqlspec_plugin.provide_request_session(
|
|
238
|
+
... "analytics_db", state, scope
|
|
239
|
+
... )
|
|
240
|
+
"""
|
|
241
|
+
# Get DatabaseConfig wrapper for Litestar methods
|
|
242
|
+
db_config = self._get_database_config(key)
|
|
243
|
+
return db_config.get_request_session(state, scope)
|
|
244
|
+
|
|
245
|
+
def provide_sync_request_session(
|
|
246
|
+
self, key: "Union[str, SyncConfigT, type[SyncConfigT]]", state: "State", scope: "Scope"
|
|
247
|
+
) -> "SyncDriverAdapterBase":
|
|
248
|
+
"""Provide a sync database session for the specified configuration key from request scope.
|
|
249
|
+
|
|
250
|
+
This method provides better type hints for sync database sessions, ensuring the returned
|
|
251
|
+
session is properly typed as SyncDriverAdapterBase for better IDE support and type safety.
|
|
252
|
+
|
|
253
|
+
Args:
|
|
254
|
+
key: The sync configuration identifier
|
|
255
|
+
state: The Litestar application State object
|
|
256
|
+
scope: The ASGI scope containing the request context
|
|
257
|
+
|
|
258
|
+
Returns:
|
|
259
|
+
A sync driver session instance for the specified database configuration
|
|
260
|
+
|
|
261
|
+
Example:
|
|
262
|
+
>>> sqlspec_plugin = connection.app.state.sqlspec
|
|
263
|
+
>>> auth_session = sqlspec_plugin.provide_sync_request_session(
|
|
264
|
+
... "auth_db", state, scope
|
|
265
|
+
... )
|
|
266
|
+
>>> # auth_session is now correctly typed as SyncDriverAdapterBase
|
|
267
|
+
"""
|
|
268
|
+
# Get DatabaseConfig wrapper for Litestar methods
|
|
269
|
+
db_config = self._get_database_config(key)
|
|
270
|
+
session = db_config.get_request_session(state, scope)
|
|
271
|
+
return cast("SyncDriverAdapterBase", session)
|
|
272
|
+
|
|
273
|
+
def provide_async_request_session(
|
|
274
|
+
self, key: "Union[str, AsyncConfigT, type[AsyncConfigT]]", state: "State", scope: "Scope"
|
|
275
|
+
) -> "AsyncDriverAdapterBase":
|
|
276
|
+
"""Provide an async database session for the specified configuration key from request scope.
|
|
277
|
+
|
|
278
|
+
This method provides better type hints for async database sessions, ensuring the returned
|
|
279
|
+
session is properly typed as AsyncDriverAdapterBase for better IDE support and type safety.
|
|
280
|
+
|
|
281
|
+
Args:
|
|
282
|
+
key: The async configuration identifier
|
|
283
|
+
state: The Litestar application State object
|
|
284
|
+
scope: The ASGI scope containing the request context
|
|
285
|
+
|
|
286
|
+
Returns:
|
|
287
|
+
An async driver session instance for the specified database configuration
|
|
288
|
+
|
|
289
|
+
Example:
|
|
290
|
+
>>> sqlspec_plugin = connection.app.state.sqlspec
|
|
291
|
+
>>> auth_session = sqlspec_plugin.provide_async_request_session(
|
|
292
|
+
... "auth_db", state, scope
|
|
293
|
+
... )
|
|
294
|
+
>>> # auth_session is now correctly typed as AsyncDriverAdapterBase
|
|
295
|
+
"""
|
|
296
|
+
# Get DatabaseConfig wrapper for Litestar methods
|
|
297
|
+
db_config = self._get_database_config(key)
|
|
298
|
+
session = db_config.get_request_session(state, scope)
|
|
299
|
+
return cast("AsyncDriverAdapterBase", session)
|
|
300
|
+
|
|
301
|
+
def provide_request_connection(
|
|
302
|
+
self,
|
|
303
|
+
key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]",
|
|
304
|
+
state: "State",
|
|
305
|
+
scope: "Scope",
|
|
306
|
+
) -> Any:
|
|
307
|
+
"""Provide a database connection for the specified configuration key from request scope.
|
|
308
|
+
|
|
309
|
+
This is a convenience method that combines get_config and get_request_connection
|
|
310
|
+
into a single call.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
key: The configuration identifier (same as get_config)
|
|
314
|
+
state: The Litestar application State object
|
|
315
|
+
scope: The ASGI scope containing the request context
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
A database connection instance for the specified database configuration
|
|
319
|
+
|
|
320
|
+
Example:
|
|
321
|
+
>>> sqlspec_plugin = connection.app.state.sqlspec
|
|
322
|
+
>>> # Direct connection access by key
|
|
323
|
+
>>> auth_conn = sqlspec_plugin.provide_request_connection(
|
|
324
|
+
... "auth_db", state, scope
|
|
325
|
+
... )
|
|
326
|
+
>>> analytics_conn = sqlspec_plugin.provide_request_connection(
|
|
327
|
+
... "analytics_db", state, scope
|
|
328
|
+
... )
|
|
329
|
+
"""
|
|
330
|
+
# Get DatabaseConfig wrapper for Litestar methods
|
|
331
|
+
db_config = self._get_database_config(key)
|
|
332
|
+
return db_config.get_request_connection(state, scope)
|
|
333
|
+
|
|
334
|
+
def _get_database_config(
|
|
335
|
+
self, key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]"
|
|
336
|
+
) -> DatabaseConfig:
|
|
337
|
+
"""Get a DatabaseConfig wrapper instance by name.
|
|
338
|
+
|
|
339
|
+
This is used internally by provide_request_session and provide_request_connection
|
|
340
|
+
to get the DatabaseConfig wrapper that has the request session methods.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
key: The configuration identifier
|
|
344
|
+
|
|
345
|
+
Returns:
|
|
346
|
+
The DatabaseConfig wrapper instance
|
|
347
|
+
|
|
348
|
+
Raises:
|
|
349
|
+
KeyError: If no configuration is found for the given key
|
|
350
|
+
"""
|
|
351
|
+
# For string keys, lookup by connection/pool/session keys
|
|
352
|
+
if isinstance(key, str):
|
|
353
|
+
for c in self.config:
|
|
354
|
+
if key in {c.connection_key, c.pool_key, c.session_key}:
|
|
355
|
+
return c
|
|
356
|
+
|
|
357
|
+
# For other keys, lookup by config instance or annotation
|
|
358
|
+
for c in self.config:
|
|
359
|
+
annotation_match = hasattr(c, "annotation") and key == c.annotation
|
|
360
|
+
if key == c.config or annotation_match:
|
|
361
|
+
return c
|
|
362
|
+
|
|
363
|
+
msg = f"No database configuration found for name '{key}'. Available keys: {self._get_available_keys()}"
|
|
364
|
+
raise KeyError(msg)
|
|
365
|
+
|
|
366
|
+
def _get_available_keys(self) -> "list[str]":
|
|
367
|
+
"""Get a list of all available configuration keys for error messages."""
|
|
368
|
+
keys = []
|
|
369
|
+
for c in self.config:
|
|
370
|
+
keys.extend([c.connection_key, c.pool_key, c.session_key])
|
|
371
|
+
return keys
|
|
372
|
+
|
|
139
373
|
def _validate_dependency_keys(self) -> None:
|
|
140
374
|
"""Validate that connection and pool keys are unique across configurations.
|
|
141
375
|
|