sqlspec 0.16.1__cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- 51ff5a9eadfdefd49f98__mypyc.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/__init__.py +92 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +77 -0
- sqlspec/_sql.py +1780 -0
- sqlspec/_typing.py +680 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +361 -0
- sqlspec/adapters/adbc/driver.py +512 -0
- sqlspec/adapters/aiosqlite/__init__.py +19 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +253 -0
- sqlspec/adapters/aiosqlite/driver.py +248 -0
- sqlspec/adapters/asyncmy/__init__.py +19 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +180 -0
- sqlspec/adapters/asyncmy/driver.py +274 -0
- sqlspec/adapters/asyncpg/__init__.py +21 -0
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +229 -0
- sqlspec/adapters/asyncpg/driver.py +344 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/driver.py +558 -0
- sqlspec/adapters/duckdb/__init__.py +22 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +504 -0
- sqlspec/adapters/duckdb/driver.py +368 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +317 -0
- sqlspec/adapters/oracledb/driver.py +538 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +214 -0
- sqlspec/adapters/psqlpy/driver.py +530 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +426 -0
- sqlspec/adapters/psycopg/driver.py +796 -0
- sqlspec/adapters/sqlite/__init__.py +15 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +240 -0
- sqlspec/adapters/sqlite/driver.py +294 -0
- sqlspec/base.py +571 -0
- sqlspec/builder/__init__.py +62 -0
- sqlspec/builder/_base.py +473 -0
- sqlspec/builder/_column.py +320 -0
- sqlspec/builder/_ddl.py +1346 -0
- sqlspec/builder/_ddl_utils.py +103 -0
- sqlspec/builder/_delete.py +76 -0
- sqlspec/builder/_insert.py +256 -0
- sqlspec/builder/_merge.py +71 -0
- sqlspec/builder/_parsing_utils.py +140 -0
- sqlspec/builder/_select.py +170 -0
- sqlspec/builder/_update.py +188 -0
- sqlspec/builder/mixins/__init__.py +55 -0
- sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
- sqlspec/builder/mixins/_delete_operations.py +41 -0
- sqlspec/builder/mixins/_insert_operations.py +244 -0
- sqlspec/builder/mixins/_join_operations.py +122 -0
- sqlspec/builder/mixins/_merge_operations.py +476 -0
- sqlspec/builder/mixins/_order_limit_operations.py +135 -0
- sqlspec/builder/mixins/_pivot_operations.py +153 -0
- sqlspec/builder/mixins/_select_operations.py +603 -0
- sqlspec/builder/mixins/_update_operations.py +187 -0
- sqlspec/builder/mixins/_where_clause.py +621 -0
- sqlspec/cli.py +247 -0
- sqlspec/config.py +395 -0
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/cache.py +871 -0
- sqlspec/core/compiler.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/compiler.py +417 -0
- sqlspec/core/filters.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/filters.py +830 -0
- sqlspec/core/hashing.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/parameters.py +1237 -0
- sqlspec/core/result.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/result.py +677 -0
- sqlspec/core/splitter.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/splitter.py +819 -0
- sqlspec/core/statement.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/core/statement.py +676 -0
- sqlspec/driver/__init__.py +19 -0
- sqlspec/driver/_async.py +502 -0
- sqlspec/driver/_common.py +631 -0
- sqlspec/driver/_sync.py +503 -0
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_result_tools.py +193 -0
- sqlspec/driver/mixins/_sql_translator.py +86 -0
- sqlspec/exceptions.py +193 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +461 -0
- sqlspec/extensions/litestar/__init__.py +6 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/config.py +92 -0
- sqlspec/extensions/litestar/handlers.py +260 -0
- sqlspec/extensions/litestar/plugin.py +145 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/loader.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/loader.py +760 -0
- sqlspec/migrations/__init__.py +35 -0
- sqlspec/migrations/base.py +414 -0
- sqlspec/migrations/commands.py +443 -0
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +213 -0
- sqlspec/migrations/tracker.py +140 -0
- sqlspec/migrations/utils.py +129 -0
- sqlspec/protocols.py +407 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +23 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +386 -0
- sqlspec/storage/backends/obstore.py +459 -0
- sqlspec/storage/capabilities.py +102 -0
- sqlspec/storage/registry.py +239 -0
- sqlspec/typing.py +299 -0
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/correlation.py +150 -0
- sqlspec/utils/deprecation.py +106 -0
- sqlspec/utils/fixtures.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/fixtures.py +58 -0
- sqlspec/utils/logging.py +127 -0
- sqlspec/utils/module_loader.py +89 -0
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +32 -0
- sqlspec/utils/sync_tools.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/sync_tools.py +237 -0
- sqlspec/utils/text.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/text.py +96 -0
- sqlspec/utils/type_guards.cpython-310-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/type_guards.py +1139 -0
- sqlspec-0.16.1.dist-info/METADATA +365 -0
- sqlspec-0.16.1.dist-info/RECORD +148 -0
- sqlspec-0.16.1.dist-info/WHEEL +7 -0
- sqlspec-0.16.1.dist-info/entry_points.txt +2 -0
- sqlspec-0.16.1.dist-info/licenses/LICENSE +21 -0
- sqlspec-0.16.1.dist-info/licenses/NOTICE +29 -0
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import inspect
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from contextlib import AbstractAsyncContextManager
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
|
6
|
+
|
|
7
|
+
from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
|
|
8
|
+
|
|
9
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
10
|
+
from sqlspec.extensions.litestar._utils import (
|
|
11
|
+
delete_sqlspec_scope_state,
|
|
12
|
+
get_sqlspec_scope_state,
|
|
13
|
+
set_sqlspec_scope_state,
|
|
14
|
+
)
|
|
15
|
+
from sqlspec.utils.sync_tools import ensure_async_
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Coroutine
|
|
19
|
+
|
|
20
|
+
from litestar import Litestar
|
|
21
|
+
from litestar.datastructures.state import State
|
|
22
|
+
from litestar.types import Message, Scope
|
|
23
|
+
|
|
24
|
+
from sqlspec.config import DatabaseConfigProtocol, DriverT
|
|
25
|
+
from sqlspec.typing import ConnectionT, PoolT
|
|
26
|
+
|
|
27
|
+
# ASGI message types after which a request-scoped connection is no longer
# needed (response started, client disconnected, websocket closed/disconnected).
# The before-send handlers in this module close the connection and remove it
# from the ASGI scope when they see one of these message types.
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}

# Public API of this module.
__all__ = (
    "SESSION_TERMINUS_ASGI_EVENTS",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
    "session_provider_maker",
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Create handler for manual connection management.

    The returned ``before_send`` handler never commits or rolls back. It only
    closes the request-scoped connection once a terminal ASGI message is seen
    (see ``SESSION_TERMINUS_ASGI_EVENTS``) and removes it from the ASGI scope.

    Args:
        connection_scope_key: The key used to store the connection in the ASGI scope.

    Returns:
        The handler callable.
    """

    async def handler(message: "Message", scope: "Scope") -> None:
        """Handle closing and cleaning up connections before sending the response.

        Args:
            message: ASGI Message.
            scope: ASGI Scope.
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        # Compare against None explicitly: a connection object's truthiness is
        # not a reliable "is present" signal (matches the autocommit handler).
        if connection is None or message["type"] not in SESSION_TERMINUS_ASGI_EVENTS:
            return
        try:
            close = getattr(connection, "close", None)
            if callable(close):
                await ensure_async_(close)()
        finally:
            # Drop the scope entry even if close() raised, so the scope never
            # keeps a reference to a connection whose close failed.
            delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def autocommit_handler_maker(
    connection_scope_key: str,
    commit_on_redirect: bool = False,
    extra_commit_statuses: "Optional[set[int]]" = None,
    extra_rollback_statuses: "Optional[set[int]]" = None,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Create handler for automatic transaction commit/rollback based on response status.

    On ``http.response.start`` the handler commits when the status is 2XX
    (or 3XX with ``commit_on_redirect``) or in ``extra_commit_statuses``,
    unless it is in ``extra_rollback_statuses``. Otherwise it rolls back.
    On any message in ``SESSION_TERMINUS_ASGI_EVENTS`` it closes the connection
    and removes it from the ASGI scope.

    Args:
        connection_scope_key: The key used to store the connection in the ASGI scope.
        commit_on_redirect: Issue a commit when the response status is a redirect (3XX).
        extra_commit_statuses: A set of additional status codes that trigger a commit.
        extra_rollback_statuses: A set of additional status codes that trigger a rollback.

    Raises:
        ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share status codes.

    Returns:
        The handler callable.
    """
    # Snapshot the caller's collections so mutating them after this call can
    # neither bypass the overlap check below nor change the handler's behaviour.
    commit_statuses = frozenset(extra_commit_statuses or ())
    rollback_statuses = frozenset(extra_rollback_statuses or ())

    if commit_statuses & rollback_statuses:
        msg = "Extra rollback statuses and commit statuses must not share any status codes"
        raise ImproperConfigurationError(msg)

    # 2XX always commits; 3XX commits only when explicitly requested.
    commit_range = range(200, 400 if commit_on_redirect else 300)

    async def handler(message: "Message", scope: "Scope") -> None:
        """Handle commit/rollback, closing and cleaning up connections before sending.

        Args:
            message: ASGI Message.
            scope: ASGI Scope.
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        if connection is None:
            return
        message_type = message["type"]
        try:
            if message_type == HTTP_RESPONSE_START:
                status = message["status"]
                # An explicit rollback status always wins over the commit rules.
                if (status in commit_range or status in commit_statuses) and status not in rollback_statuses:
                    commit = getattr(connection, "commit", None)
                    if callable(commit):
                        await ensure_async_(commit)()
                else:
                    rollback = getattr(connection, "rollback", None)
                    if callable(rollback):
                        await ensure_async_(rollback)()
        finally:
            if message_type in SESSION_TERMINUS_ASGI_EVENTS:
                try:
                    close = getattr(connection, "close", None)
                    if callable(close):
                        await ensure_async_(close)()
                finally:
                    # Drop the scope entry even if close() raised.
                    delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def lifespan_handler_maker(
    config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
    """Create lifespan handler for managing database connection pool lifecycle.

    Args:
        config: The database configuration object.
        pool_key: The key under which the connection pool will be stored in `app.state`.

    Returns:
        The lifespan handler function.
    """

    @contextlib.asynccontextmanager
    async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
        """Manage database pool lifecycle for the application.

        Creates the pool (sync or async ``create_pool``) and stores it in
        ``app.state`` under ``pool_key``. On shutdown, removes it and closes it.
        Close failures are logged rather than raised.

        Args:
            app: The Litestar application instance.

        Yields:
            Control to application during pool lifetime.
        """
        db_pool = await ensure_async_(config.create_pool)()
        app.state.update({pool_key: db_pool})
        try:
            yield
        finally:
            app.state.pop(pool_key, None)
            try:
                await ensure_async_(config.close_pool)()
            except Exception as e:
                # Shutdown is best-effort, but the failure must not vanish
                # silently when the application has no logger configured.
                if app.logger:
                    app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)
                else:
                    import logging

                    logging.getLogger(__name__).warning("Error closing database pool for %s. Error: %s", pool_key, e)

    return lifespan_handler
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def pool_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str
) -> "Callable[[State, Scope], Awaitable[PoolT]]":
    """Build a dependency provider that hands out the application-wide pool.

    Args:
        config: The database configuration object (not used by the provider itself).
        pool_key: Name of the ``app.state`` entry that holds the pool.

    Returns:
        An async provider returning the stored pool.
    """

    async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
        """Look up the pool in application state.

        Args:
            state: The Litestar application State object.
            scope: The ASGI scope; ignored because the pool is app-level.

        Returns:
            The database connection pool.

        Raises:
            ImproperConfigurationError: If nothing is stored under ``pool_key``.
        """
        stored_pool = state.get(pool_key)
        if stored_pool is not None:
            return cast("PoolT", stored_pool)
        # Most likely the lifespan handler that populates the pool never ran.
        error_message = (
            f"Database pool with key '{pool_key}' not found in application state. "
            "Ensure the SQLSpec lifespan handler is correctly configured and has run."
        )
        raise ImproperConfigurationError(error_message)

    return provide_pool
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def connection_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
    """Create a generator dependency that yields a request-scoped connection.

    The connection comes from ``config.provide_connection(pool)``, where the
    pool is read from application state. The connection is stored in the ASGI
    scope under ``connection_key`` while it is in use.

    Args:
        config: The database configuration object.
        pool_key: The key under which the pool is stored in application state.
        connection_key: The key used to store the connection in the ASGI scope.

    Returns:
        The connection provider function.
    """

    async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
        """Yield a connection obtained from the pool in ``state``.

        Args:
            state: The Litestar application State object.
            scope: The ASGI scope the connection is stored in.

        Yields:
            The database connection.

        Raises:
            ImproperConfigurationError: If the pool is not found in ``state``.
        """
        if (db_pool := state.get(pool_key)) is None:
            msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection."
            raise ImproperConfigurationError(msg)

        connection_cm = config.provide_connection(db_pool)

        if not isinstance(connection_cm, AbstractAsyncContextManager):
            # provide_connection returned a connection (or an awaitable of one)
            # rather than an async context manager, so there is nothing to exit.
            conn_instance: ConnectionT
            if hasattr(connection_cm, "__await__"):
                conn_instance = await cast("Awaitable[ConnectionT]", connection_cm)
            else:
                conn_instance = cast("ConnectionT", connection_cm)
            set_sqlspec_scope_state(scope, connection_key, conn_instance)
            yield conn_instance
            return

        # ``async with`` forwards any exception thrown into this generator at
        # the ``yield`` (e.g. from the request handler) to ``__aexit__``. A
        # manual ``__aexit__(None, None, None)`` would report every exit as
        # successful.
        async with connection_cm as entered_connection:
            set_sqlspec_scope_state(scope, connection_key, entered_connection)
            try:
                yield entered_connection
            finally:
                delete_sqlspec_scope_state(scope, connection_key)

    return provide_connection
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def session_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", connection_dependency_key: str
) -> "Callable[[Any], AsyncGenerator[DriverT, None]]":
    """Create a generator dependency that wraps the injected connection in a driver.

    The provider's generic ``*args, **kwargs`` signature is replaced with a
    single parameter named ``connection_dependency_key``. That parameter is
    annotated with ``config.connection_type`` and defaults to a
    ``Dependency(skip_validation=True)`` marker, so the connection dependency
    is injected by name.

    Args:
        config: The database configuration object.
        connection_dependency_key: Name of the connection dependency to inject.

    Returns:
        The session provider function.
    """

    async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, None]":
        # Accept the connection positionally or by its dependency name.
        connection = args[0] if args else kwargs.get(connection_dependency_key)
        yield cast("DriverT", config.driver_type(connection=connection))  # pyright: ignore

    connection_type = config.connection_type

    from litestar.params import Dependency

    session_return_type = AsyncGenerator[config.driver_type, None]  # type: ignore[name-defined]

    provide_session.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
        parameters=[
            inspect.Parameter(
                name=connection_dependency_key,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=connection_type,
                default=Dependency(skip_validation=True),
            )
        ],
        return_annotation=session_return_type,
    )

    # Plain functions always expose an ``__annotations__`` dict, so it can be
    # updated in place to mirror the rewritten signature.
    provide_session.__annotations__[connection_dependency_key] = connection_type
    provide_session.__annotations__["return"] = session_return_type

    return provide_session
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
2
|
+
|
|
3
|
+
from litestar.di import Provide
|
|
4
|
+
from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
5
|
+
|
|
6
|
+
from sqlspec.base import SQLSpec as SQLSpecBase
|
|
7
|
+
from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT
|
|
8
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
9
|
+
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
10
|
+
from sqlspec.typing import ConnectionT, PoolT
|
|
11
|
+
from sqlspec.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from click import Group
|
|
15
|
+
from litestar.config.app import AppConfig
|
|
16
|
+
|
|
17
|
+
# Module-level logger for the Litestar extension.
# NOTE(review): not referenced by the SQLSpec plugin class in this module — confirm it is still needed.
logger = get_logger("extensions.litestar")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
    """Litestar plugin for SQLSpec database integration.

    For each configured database it registers:
    - a lifespan handler for the pool;
    - a ``before_send`` handler;
    - dependency providers for the connection, pool and session.

    It also adds the database CLI command group and stores itself in the
    application state on startup.
    """

    # NOTE(review): "_config" is not assigned in this class (``__init__`` sets
    # ``_configs`` and ``_plugin_configs``) — confirm whether it is still needed.
    __slots__ = ("_config", "_plugin_configs")

    def __init__(self, config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]]) -> None:
        """Initialize SQLSpec plugin.

        Args:
            config: A bare database configuration, which is wrapped in
                ``DatabaseConfig(config=...)``; a ``DatabaseConfig``; or a
                list of ``DatabaseConfig`` objects.
        """
        self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
        if isinstance(config, DatabaseConfigProtocol):
            self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
        elif isinstance(config, DatabaseConfig):
            self._plugin_configs = [config]
        else:
            self._plugin_configs = config

    @property
    def config(self) -> "list[DatabaseConfig]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the plugin configuration.

        Returns:
            List of database configurations.
        """
        return self._plugin_configs

    def on_cli_init(self, cli: "Group") -> None:
        """Configure CLI commands for SQLSpec database operations.

        Args:
            cli: The Click command group to add commands to.
        """
        from sqlspec.extensions.litestar.cli import database_group

        cli.add_command(database_group)

    def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
        """Configure Litestar application with SQLSpec database integration.

        Args:
            app_config: The Litestar application configuration instance.

        Returns:
            The updated application configuration instance.

        Raises:
            ImproperConfigurationError: If dependency keys collide across configurations.
        """
        self._validate_dependency_keys()

        def store_sqlspec_in_state() -> None:
            app_config.state.sqlspec = self

        app_config.on_startup.append(store_sqlspec_in_state)
        app_config.signature_types.extend(
            [SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
        )

        signature_namespace = {}

        for c in self._plugin_configs:
            c.annotation = self.add_config(c.config)
            app_config.signature_types.append(c.annotation)
            app_config.signature_types.append(c.config.connection_type)  # type: ignore[union-attr]
            app_config.signature_types.append(c.config.driver_type)  # type: ignore[union-attr]

            if hasattr(c.config, "get_signature_namespace"):
                signature_namespace.update(c.config.get_signature_namespace())  # type: ignore[attr-defined]

            app_config.before_send.append(c.before_send_handler)
            app_config.lifespan.append(c.lifespan_handler)  # pyright: ignore[reportUnknownMemberType]
            # All configs share this one dependency mapping, which is why
            # _validate_dependency_keys insists these keys are unique.
            app_config.dependencies.update(
                {
                    c.connection_key: Provide(c.connection_provider),
                    c.pool_key: Provide(c.pool_provider),
                    c.session_key: Provide(c.session_provider),
                }
            )

        if signature_namespace:
            app_config.signature_namespace.update(signature_namespace)

        return app_config

    def get_annotations(self) -> "list[type[Union[SyncConfigT, AsyncConfigT]]]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the list of annotations.

        Returns:
            List of annotations.
        """
        return [c.annotation for c in self.config]

    def get_annotation(
        self, key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]"
    ) -> "type[Union[SyncConfigT, AsyncConfigT]]":
        """Return the annotation for the given configuration.

        Args:
            key: The configuration instance, its annotation, or its connection or pool key.

        Raises:
            KeyError: If no configuration is found for the given key.

        Returns:
            The annotation for the configuration.
        """
        for c in self.config:
            # Tuple membership compares with ``==`` only; a set literal would
            # hash ``key`` and raise TypeError for unhashable config objects.
            if key == c.config or key in (c.annotation, c.connection_key, c.pool_key):
                return c.annotation
        msg = f"No configuration found for {key}"
        raise KeyError(msg)

    def _validate_dependency_keys(self) -> None:
        """Validate that dependency keys are unique across configurations.

        Every configuration registers its connection, pool and session
        providers in the same ``app_config.dependencies`` mapping. A duplicate
        key would therefore silently replace another configuration's provider.

        Raises:
            ImproperConfigurationError: If connection, pool or session keys are not unique.
        """
        for key_attribute in ("connection_key", "pool_key", "session_key"):
            keys = [getattr(c, key_attribute) for c in self.config]
            if len(set(keys)) != len(keys):
                msg = f"When using multiple database configuration, each configuration must have a unique `{key_attribute}`."
                raise ImproperConfigurationError(detail=msg)
|