sqlspec 0.16.0__cp313-cp313-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- 51ff5a9eadfdefd49f98__mypyc.cpython-313-darwin.so +0 -0
- sqlspec/__init__.py +92 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +77 -0
- sqlspec/_sql.py +1347 -0
- sqlspec/_typing.py +680 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +361 -0
- sqlspec/adapters/adbc/driver.py +512 -0
- sqlspec/adapters/aiosqlite/__init__.py +19 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +253 -0
- sqlspec/adapters/aiosqlite/driver.py +248 -0
- sqlspec/adapters/asyncmy/__init__.py +19 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +180 -0
- sqlspec/adapters/asyncmy/driver.py +274 -0
- sqlspec/adapters/asyncpg/__init__.py +21 -0
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +229 -0
- sqlspec/adapters/asyncpg/driver.py +344 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/driver.py +558 -0
- sqlspec/adapters/duckdb/__init__.py +22 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +504 -0
- sqlspec/adapters/duckdb/driver.py +368 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +317 -0
- sqlspec/adapters/oracledb/driver.py +538 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +214 -0
- sqlspec/adapters/psqlpy/driver.py +530 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +426 -0
- sqlspec/adapters/psycopg/driver.py +796 -0
- sqlspec/adapters/sqlite/__init__.py +15 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +240 -0
- sqlspec/adapters/sqlite/driver.py +294 -0
- sqlspec/base.py +571 -0
- sqlspec/builder/__init__.py +62 -0
- sqlspec/builder/_base.py +440 -0
- sqlspec/builder/_column.py +324 -0
- sqlspec/builder/_ddl.py +1383 -0
- sqlspec/builder/_ddl_utils.py +104 -0
- sqlspec/builder/_delete.py +77 -0
- sqlspec/builder/_insert.py +241 -0
- sqlspec/builder/_merge.py +56 -0
- sqlspec/builder/_parsing_utils.py +140 -0
- sqlspec/builder/_select.py +174 -0
- sqlspec/builder/_update.py +186 -0
- sqlspec/builder/mixins/__init__.py +55 -0
- sqlspec/builder/mixins/_cte_and_set_ops.py +195 -0
- sqlspec/builder/mixins/_delete_operations.py +36 -0
- sqlspec/builder/mixins/_insert_operations.py +152 -0
- sqlspec/builder/mixins/_join_operations.py +115 -0
- sqlspec/builder/mixins/_merge_operations.py +416 -0
- sqlspec/builder/mixins/_order_limit_operations.py +123 -0
- sqlspec/builder/mixins/_pivot_operations.py +144 -0
- sqlspec/builder/mixins/_select_operations.py +599 -0
- sqlspec/builder/mixins/_update_operations.py +164 -0
- sqlspec/builder/mixins/_where_clause.py +609 -0
- sqlspec/cli.py +247 -0
- sqlspec/config.py +395 -0
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.cpython-313-darwin.so +0 -0
- sqlspec/core/cache.py +873 -0
- sqlspec/core/compiler.cpython-313-darwin.so +0 -0
- sqlspec/core/compiler.py +396 -0
- sqlspec/core/filters.cpython-313-darwin.so +0 -0
- sqlspec/core/filters.py +830 -0
- sqlspec/core/hashing.cpython-313-darwin.so +0 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.cpython-313-darwin.so +0 -0
- sqlspec/core/parameters.py +1209 -0
- sqlspec/core/result.cpython-313-darwin.so +0 -0
- sqlspec/core/result.py +664 -0
- sqlspec/core/splitter.cpython-313-darwin.so +0 -0
- sqlspec/core/splitter.py +819 -0
- sqlspec/core/statement.cpython-313-darwin.so +0 -0
- sqlspec/core/statement.py +666 -0
- sqlspec/driver/__init__.py +19 -0
- sqlspec/driver/_async.py +472 -0
- sqlspec/driver/_common.py +612 -0
- sqlspec/driver/_sync.py +473 -0
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_result_tools.py +164 -0
- sqlspec/driver/mixins/_sql_translator.py +36 -0
- sqlspec/exceptions.py +193 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +461 -0
- sqlspec/extensions/litestar/__init__.py +6 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/config.py +92 -0
- sqlspec/extensions/litestar/handlers.py +260 -0
- sqlspec/extensions/litestar/plugin.py +145 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/loader.cpython-313-darwin.so +0 -0
- sqlspec/loader.py +760 -0
- sqlspec/migrations/__init__.py +35 -0
- sqlspec/migrations/base.py +414 -0
- sqlspec/migrations/commands.py +443 -0
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +213 -0
- sqlspec/migrations/tracker.py +140 -0
- sqlspec/migrations/utils.py +129 -0
- sqlspec/protocols.py +400 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +23 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +386 -0
- sqlspec/storage/backends/obstore.py +459 -0
- sqlspec/storage/capabilities.py +102 -0
- sqlspec/storage/registry.py +239 -0
- sqlspec/typing.py +299 -0
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/correlation.py +150 -0
- sqlspec/utils/deprecation.py +106 -0
- sqlspec/utils/fixtures.cpython-313-darwin.so +0 -0
- sqlspec/utils/fixtures.py +58 -0
- sqlspec/utils/logging.py +127 -0
- sqlspec/utils/module_loader.py +89 -0
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +32 -0
- sqlspec/utils/sync_tools.cpython-313-darwin.so +0 -0
- sqlspec/utils/sync_tools.py +237 -0
- sqlspec/utils/text.cpython-313-darwin.so +0 -0
- sqlspec/utils/text.py +96 -0
- sqlspec/utils/type_guards.cpython-313-darwin.so +0 -0
- sqlspec/utils/type_guards.py +1135 -0
- sqlspec-0.16.0.dist-info/METADATA +365 -0
- sqlspec-0.16.0.dist-info/RECORD +148 -0
- sqlspec-0.16.0.dist-info/WHEEL +4 -0
- sqlspec-0.16.0.dist-info/entry_points.txt +2 -0
- sqlspec-0.16.0.dist-info/licenses/LICENSE +21 -0
- sqlspec-0.16.0.dist-info/licenses/NOTICE +29 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from litestar.types import Scope
|
|
5
|
+
|
|
6
|
+
__all__ = ("delete_sqlspec_scope_state", "get_sqlspec_scope_state", "set_sqlspec_scope_state")

# Key under which the helpers below keep their private namespace dict directly inside the ASGI scope
# (each helper calls ``scope.setdefault(_SCOPE_NAMESPACE, {})``).
_SCOPE_NAMESPACE = "_sqlspec"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
    """Get an internal value from connection scope state.

    Values live in a private namespace dict stored on the scope under ``_SCOPE_NAMESPACE``; the
    namespace dict itself is created on first access.

    Note:
        ``key`` is never inserted into the namespace by this function. Without ``pop`` it behaves like
        ``dict.get()``; with ``pop=True`` it behaves like ``dict.pop()`` and removes the entry. In both
        cases ``default`` (``None`` unless given) is returned when ``key`` does not exist.

    Args:
        scope: The connection scope.
        key: Key to get from internal namespace in scope state.
        default: Value to return when ``key`` is not present.
        pop: Boolean flag dictating whether the value should be deleted from the state.

    Returns:
        Value mapped to ``key`` in internal connection scope namespace, or ``default`` if absent.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    return namespace.pop(key, default) if pop else namespace.get(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def set_sqlspec_scope_state(scope: "Scope", key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the SQLSpec namespace of the connection scope.

    The namespace dict is created on the scope (under ``_SCOPE_NAMESPACE``) if it does not exist yet;
    an existing entry for ``key`` is overwritten.

    Args:
        scope: The connection scope.
        key: Key to set under internal namespace in scope state.
        value: Value for key.
    """
    sqlspec_namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    sqlspec_namespace[key] = value
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def delete_sqlspec_scope_state(scope: "Scope", key: str) -> None:
    """Remove an internal value from connection scope state.

    Removing a key that is not present is a no-op. Both the before-send handlers and the connection
    provider clean up the same key, so whichever runs second must not fail.

    Args:
        scope: The connection scope.
        key: Key to remove from the internal namespace in scope state.
    """
    scope.setdefault(_SCOPE_NAMESPACE, {}).pop(key, None)  # type: ignore[misc]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Litestar CLI integration for SQLSpec migrations."""
|
|
2
|
+
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from litestar.cli._utils import LitestarGroup
|
|
7
|
+
|
|
8
|
+
from sqlspec.cli import add_migration_commands
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import rich_click as click
|
|
12
|
+
except ImportError:
|
|
13
|
+
import click # type: ignore[no-redef]
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from litestar import Litestar
|
|
17
|
+
|
|
18
|
+
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
    """Look up the SQLSpec plugin registered on a Litestar application.

    Args:
        app: The Litestar application

    Returns:
        The SQLSpec plugin

    Raises:
        ImproperConfigurationError: If the SQLSpec plugin is not found
    """
    # Imported lazily: the plugin module imports this CLI module from ``on_cli_init``.
    from sqlspec.exceptions import ImproperConfigurationError
    from sqlspec.extensions.litestar.plugin import SQLSpec

    try:
        return app.plugins.get(SQLSpec)
    except KeyError:
        pass

    # Raised outside the ``except`` block so the KeyError is not attached as context.
    msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
    raise ImproperConfigurationError(msg)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@click.group(cls=LitestarGroup, name="db")
@click.pass_context
def database_group(ctx: "click.Context") -> None:
    """Manage SQLSpec database components."""
    # ``@click.pass_context`` is required: click invokes a group callback with only its declared params,
    # so without it ``ctx`` is never supplied. ``ctx.obj`` here is the Litestar CLI environment, which
    # exposes the loaded application as ``.app``; it is replaced by a dict for the migration subcommands.
    ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}


# Register the generic SQLSpec migration subcommands under the ``db`` group.
add_migration_commands(database_group)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
|
3
|
+
|
|
4
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
5
|
+
from sqlspec.extensions.litestar.handlers import (
|
|
6
|
+
autocommit_handler_maker,
|
|
7
|
+
connection_provider_maker,
|
|
8
|
+
lifespan_handler_maker,
|
|
9
|
+
manual_handler_maker,
|
|
10
|
+
pool_provider_maker,
|
|
11
|
+
session_provider_maker,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from collections.abc import AsyncGenerator, Awaitable
|
|
16
|
+
from contextlib import AbstractAsyncContextManager
|
|
17
|
+
|
|
18
|
+
from litestar import Litestar
|
|
19
|
+
from litestar.datastructures.state import State
|
|
20
|
+
from litestar.types import BeforeMessageSendHookHandler, Scope
|
|
21
|
+
|
|
22
|
+
from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT
|
|
23
|
+
from sqlspec.typing import ConnectionT, PoolT
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Transaction strategies accepted by ``DatabaseConfig.commit_mode``.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"
# Default dependency / state keys used when a ``DatabaseConfig`` does not override them.
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"

__all__ = (
    "DEFAULT_COMMIT_MODE",
    "DEFAULT_CONNECTION_KEY",
    "DEFAULT_POOL_KEY",
    "DEFAULT_SESSION_KEY",
    "CommitMode",
    "DatabaseConfig",
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class DatabaseConfig:
    """Litestar-facing wrapper around a single SQLSpec database configuration.

    The user-supplied fields choose dependency keys and the commit strategy. ``__post_init__`` derives the
    non-init fields: the before-send handler, the lifespan handler, and the pool / connection / session
    dependency providers.
    """

    config: "Union[SyncConfigT, AsyncConfigT]" = field()  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
    connection_key: str = field(default=DEFAULT_CONNECTION_KEY)
    pool_key: str = field(default=DEFAULT_POOL_KEY)
    session_key: str = field(default=DEFAULT_SESSION_KEY)
    commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
    extra_commit_statuses: "Optional[set[int]]" = field(default=None)
    extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
    enable_correlation_middleware: bool = field(default=True)
    # Derived in ``__post_init__``; not constructor arguments.
    connection_provider: "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]" = field(  # pyright: ignore[reportGeneralTypeIssues]
        init=False, repr=False, hash=False
    )
    pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    session_provider: "Callable[[Any], AsyncGenerator[DriverT, None]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
    lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
        init=False, repr=False, hash=False
    )
    # Assigned later by the plugin (``on_app_init``), not here.
    annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False)  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]

    def __post_init__(self) -> None:
        # A non-pooling config left on the default pool key gets a key unique to this config instance.
        uses_default_pool_key = self.pool_key == DEFAULT_POOL_KEY
        if uses_default_pool_key and not self.config.supports_connection_pooling:  # type: ignore[union-attr,unused-ignore]
            self.pool_key = f"_{self.pool_key}_{id(self.config)}"

        mode = self.commit_mode
        if mode == "manual":
            self.before_send_handler = manual_handler_maker(connection_scope_key=self.connection_key)
        elif mode in ("autocommit", "autocommit_include_redirect"):
            # The two autocommit modes differ only in whether 3xx responses commit.
            self.before_send_handler = autocommit_handler_maker(
                commit_on_redirect=mode == "autocommit_include_redirect",
                extra_commit_statuses=self.extra_commit_statuses,
                extra_rollback_statuses=self.extra_rollback_statuses,
                connection_scope_key=self.connection_key,
            )
        else:
            msg = f"Invalid commit mode: {self.commit_mode}"
            raise ImproperConfigurationError(detail=msg)

        cfg = self.config
        self.lifespan_handler = lifespan_handler_maker(config=cfg, pool_key=self.pool_key)
        self.connection_provider = connection_provider_maker(
            connection_key=self.connection_key, pool_key=self.pool_key, config=cfg
        )
        self.pool_provider = pool_provider_maker(config=cfg, pool_key=self.pool_key)
        self.session_provider = session_provider_maker(config=cfg, connection_dependency_key=self.connection_key)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import inspect
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from contextlib import AbstractAsyncContextManager
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
|
6
|
+
|
|
7
|
+
from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
|
|
8
|
+
|
|
9
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
10
|
+
from sqlspec.extensions.litestar._utils import (
|
|
11
|
+
delete_sqlspec_scope_state,
|
|
12
|
+
get_sqlspec_scope_state,
|
|
13
|
+
set_sqlspec_scope_state,
|
|
14
|
+
)
|
|
15
|
+
from sqlspec.utils.sync_tools import ensure_async_
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from collections.abc import Awaitable, Coroutine
|
|
19
|
+
|
|
20
|
+
from litestar import Litestar
|
|
21
|
+
from litestar.datastructures.state import State
|
|
22
|
+
from litestar.types import Message, Scope
|
|
23
|
+
|
|
24
|
+
from sqlspec.config import DatabaseConfigProtocol, DriverT
|
|
25
|
+
from sqlspec.typing import ConnectionT, PoolT
|
|
26
|
+
|
|
27
|
+
# ASGI message types that end a request/websocket session; the before-send handlers below close the
# scope's connection and remove it from the scope state when one of these is sent.
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}

__all__ = (
    "SESSION_TERMINUS_ASGI_EVENTS",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
    "session_provider_maker",
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Create handler for manual connection management.

    The handler never commits or rolls back; transaction control is left to application code. It only
    closes the request's connection once the session terminates.

    Args:
        connection_scope_key: The key used to store the connection in the ASGI scope.

    Returns:
        The handler callable.
    """

    async def handler(message: "Message", scope: "Scope") -> None:
        """Close and forget the scope's connection when a session-terminating message is sent.

        Args:
            message: ASGI Message.
            scope: ASGI Scope.
        """
        if message["type"] not in SESSION_TERMINUS_ASGI_EVENTS:
            return
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        if connection is None:
            return
        try:
            if hasattr(connection, "close") and callable(connection.close):
                await ensure_async_(connection.close)()
        finally:
            # Clear the scope entry even if close() raised, so a dead connection is not reused.
            delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def autocommit_handler_maker(
    connection_scope_key: str,
    commit_on_redirect: bool = False,
    extra_commit_statuses: "Optional[set[int]]" = None,
    extra_rollback_statuses: "Optional[set[int]]" = None,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Create handler for automatic transaction commit/rollback based on response status.

    On ``http.response.start`` the connection is committed when the status is 2XX (2XX-3XX with
    ``commit_on_redirect``) or listed in ``extra_commit_statuses``, unless it is listed in
    ``extra_rollback_statuses``. Any other status triggers a rollback. On any session-terminating message
    the connection is closed and removed from the scope state.

    Args:
        connection_scope_key: The key used to store the connection in the ASGI scope.
        commit_on_redirect: Issue a commit when the response status is a redirect (3XX).
        extra_commit_statuses: A set of additional status codes that trigger a commit.
        extra_rollback_statuses: A set of additional status codes that trigger a rollback.

    Raises:
        ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share status codes.

    Returns:
        The handler callable.
    """
    # Snapshot the caller's collections so mutating them later cannot change the handler's behaviour.
    commit_statuses = frozenset(extra_commit_statuses or ())
    rollback_statuses = frozenset(extra_rollback_statuses or ())

    if commit_statuses & rollback_statuses:
        msg = "Extra rollback statuses and commit statuses must not share any status codes"
        raise ImproperConfigurationError(msg)

    commit_range = range(200, 400 if commit_on_redirect else 300)

    async def handler(message: "Message", scope: "Scope") -> None:
        """Handle commit/rollback, closing and cleaning up connections before sending.

        Args:
            message: ASGI Message.
            scope: ASGI Scope.
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        if connection is None:
            return
        message_type = message["type"]
        try:
            if message_type == HTTP_RESPONSE_START:
                status = message["status"]
                if (status in commit_range or status in commit_statuses) and status not in rollback_statuses:
                    if hasattr(connection, "commit") and callable(connection.commit):
                        await ensure_async_(connection.commit)()
                elif hasattr(connection, "rollback") and callable(connection.rollback):
                    await ensure_async_(connection.rollback)()
        finally:
            if message_type in SESSION_TERMINUS_ASGI_EVENTS:
                try:
                    if hasattr(connection, "close") and callable(connection.close):
                        await ensure_async_(connection.close)()
                finally:
                    # Always forget the connection, even if close() raised.
                    delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def lifespan_handler_maker(
    config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
    """Build a lifespan context manager that owns the database pool for the application's lifetime.

    Args:
        config: The database configuration object.
        pool_key: The key under which the connection pool will be stored in `app.state`.

    Returns:
        The lifespan handler function.
    """

    @contextlib.asynccontextmanager
    async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
        """Create the pool on startup, publish it in ``app.state`` and close it on shutdown.

        Args:
            app: The Litestar application instance.

        Yields:
            Control to application during pool lifetime.
        """
        pool = await ensure_async_(config.create_pool)()
        app.state.update({pool_key: pool})
        try:
            yield
        finally:
            app.state.pop(pool_key, None)
            # Shutdown is best-effort: a failing close is logged (when the app has a logger), not raised.
            try:
                await ensure_async_(config.close_pool)()
            except Exception as close_error:
                if app.logger:
                    app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, close_error)

    return lifespan_handler
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def pool_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str
) -> "Callable[[State, Scope], Awaitable[PoolT]]":
    """Build a dependency provider that returns the application-level database pool.

    Args:
        config: The database configuration object.
        pool_key: The key used to store the connection pool in `app.state`.

    Returns:
        The pool provider function.
    """

    async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
        """Return the pool stored in application state under ``pool_key``.

        Args:
            state: The Litestar application State object.
            scope: The ASGI scope (unused for app-level pool).

        Returns:
            The database connection pool.

        Raises:
            ImproperConfigurationError: If the pool is not found in `app.state`.
        """
        pool = state.get(pool_key)
        if pool is not None:
            return cast("PoolT", pool)
        msg = (
            f"Database pool with key '{pool_key}' not found in application state. "
            "Ensure the SQLSpec lifespan handler is correctly configured and has run."
        )
        raise ImproperConfigurationError(msg)

    return provide_pool
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def connection_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str, connection_key: str
) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
    """Create a generator dependency that yields a connection obtained from the application pool.

    Args:
        config: The database configuration object. ``config.provide_connection(pool)`` may return an async
            context manager, an awaitable resolving to a connection, or a connection object.
        pool_key: The key under which the connection pool is stored in ``app.state``.
        connection_key: The key under which the connection is recorded in the ASGI scope state.

    Returns:
        The connection provider function.
    """

    async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
        """Yield a connection for the current request and record it in the scope state.

        Raises:
            ImproperConfigurationError: If the pool is not found in ``app.state``.
        """
        if (db_pool := state.get(pool_key)) is None:
            msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection."
            raise ImproperConfigurationError(msg)

        connection_cm = config.provide_connection(db_pool)

        if not isinstance(connection_cm, AbstractAsyncContextManager):
            # Plain connection (or awaitable of one): closing is left to the before-send handlers in this
            # module, which read it back from the scope state.
            conn_instance: ConnectionT
            if hasattr(connection_cm, "__await__"):
                conn_instance = await cast("Awaitable[ConnectionT]", connection_cm)
            else:
                conn_instance = cast("ConnectionT", connection_cm)
            set_sqlspec_scope_state(scope, connection_key, conn_instance)
            yield conn_instance
            return

        # ``async with`` forwards any exception thrown into this generator to the context manager's
        # ``__aexit__`` (a manual ``__aexit__(None, None, None)`` would report a clean exit), so the config
        # can roll back or discard the connection.
        async with connection_cm as entered_connection:
            set_sqlspec_scope_state(scope, connection_key, entered_connection)
            try:
                yield entered_connection
            finally:
                # A before-send handler may already have removed the entry.
                with contextlib.suppress(KeyError):
                    delete_sqlspec_scope_state(scope, connection_key)

    return provide_connection
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def session_provider_maker(
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", connection_dependency_key: str
) -> "Callable[[Any], AsyncGenerator[DriverT, None]]":
    """Create a dependency that wraps the request's connection in the config's driver type.

    Litestar resolves a provider's dependencies from its signature. The ``*args/**kwargs`` provider is
    therefore given a synthetic signature (and matching ``__annotations__``) with a single parameter named
    ``connection_dependency_key``, typed as ``config.connection_type``.

    Args:
        config: The database configuration object.
        connection_dependency_key: Name of the connection dependency to receive.

    Returns:
        The session provider function.
    """
    from litestar.params import Dependency

    async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, None]":
        connection = args[0] if args else kwargs.get(connection_dependency_key)
        yield cast("DriverT", config.driver_type(connection=connection))  # pyright: ignore

    connection_annotation = config.connection_type
    session_annotation = AsyncGenerator[config.driver_type, None]  # type: ignore[name-defined]

    provide_session.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
        parameters=[
            inspect.Parameter(
                name=connection_dependency_key,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=connection_annotation,
                default=Dependency(skip_validation=True),
            )
        ],
        return_annotation=session_annotation,
    )
    provide_session.__annotations__.update(
        {connection_dependency_key: connection_annotation, "return": session_annotation}
    )

    return provide_session
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
2
|
+
|
|
3
|
+
from litestar.di import Provide
|
|
4
|
+
from litestar.plugins import CLIPlugin, InitPluginProtocol
|
|
5
|
+
|
|
6
|
+
from sqlspec.base import SQLSpec as SQLSpecBase
|
|
7
|
+
from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT
|
|
8
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
9
|
+
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
10
|
+
from sqlspec.typing import ConnectionT, PoolT
|
|
11
|
+
from sqlspec.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from click import Group
|
|
15
|
+
from litestar.config.app import AppConfig
|
|
16
|
+
|
|
17
|
+
# Module-level logger for the Litestar extension.
logger = get_logger("extensions.litestar")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SQLSpec(InitPluginProtocol, CLIPlugin, SQLSpecBase):
    """Litestar plugin for SQLSpec database integration."""

    __slots__ = ("_config", "_plugin_configs")

    def __init__(self, config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]]) -> None:
        """Initialize SQLSpec plugin.

        Args:
            config: Database configuration for SQLSpec plugin. A bare SQLSpec config is wrapped in a
                ``DatabaseConfig`` with default keys. A list may contain ``DatabaseConfig`` instances and/or
                bare configs (which are wrapped the same way).
        """
        self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
        if isinstance(config, DatabaseConfigProtocol):
            self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
        elif isinstance(config, DatabaseConfig):
            self._plugin_configs = [config]
        else:
            # Copy the caller's list (so later mutation of it has no effect) and wrap any bare configs.
            self._plugin_configs = [c if isinstance(c, DatabaseConfig) else DatabaseConfig(config=c) for c in config]

    @property
    def config(self) -> "list[DatabaseConfig]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the plugin configuration.

        Returns:
            List of database configurations.
        """
        return self._plugin_configs

    def on_cli_init(self, cli: "Group") -> None:
        """Configure CLI commands for SQLSpec database operations.

        Args:
            cli: The Click command group to add commands to.
        """
        # Imported lazily; the CLI module imports this plugin module itself.
        from sqlspec.extensions.litestar.cli import database_group

        cli.add_command(database_group)

    def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
        """Configure Litestar application with SQLSpec database integration.

        Registers, for every configuration: its annotation/connection/driver types as signature types,
        its before-send and lifespan handlers, and its connection, pool and session dependencies.

        Args:
            app_config: The Litestar application configuration instance.

        Returns:
            The updated application configuration instance.
        """

        self._validate_dependency_keys()

        def store_sqlspec_in_state() -> None:
            app_config.state.sqlspec = self

        app_config.on_startup.append(store_sqlspec_in_state)
        app_config.signature_types.extend(
            [SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]
        )

        signature_namespace: dict[str, Any] = {}

        for c in self._plugin_configs:
            c.annotation = self.add_config(c.config)
            app_config.signature_types.append(c.annotation)
            app_config.signature_types.append(c.config.connection_type)  # type: ignore[union-attr]
            app_config.signature_types.append(c.config.driver_type)  # type: ignore[union-attr]

            if hasattr(c.config, "get_signature_namespace"):
                signature_namespace.update(c.config.get_signature_namespace())  # type: ignore[attr-defined]

            app_config.before_send.append(c.before_send_handler)
            app_config.lifespan.append(c.lifespan_handler)  # pyright: ignore[reportUnknownMemberType]
            app_config.dependencies.update(
                {
                    c.connection_key: Provide(c.connection_provider),
                    c.pool_key: Provide(c.pool_provider),
                    c.session_key: Provide(c.session_provider),
                }
            )

        if signature_namespace:
            app_config.signature_namespace.update(signature_namespace)

        return app_config

    def get_annotations(self) -> "list[type[Union[SyncConfigT, AsyncConfigT]]]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the list of annotations.

        Returns:
            List of annotations.
        """
        return [c.annotation for c in self.config]

    def get_annotation(
        self, key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]"
    ) -> "type[Union[SyncConfigT, AsyncConfigT]]":
        """Return the annotation for the given configuration.

        Args:
            key: The configuration instance, its annotation type, or one of its connection, pool or
                session keys.

        Raises:
            KeyError: If no configuration is found for the given key.

        Returns:
            The annotation for the configuration.
        """
        for c in self.config:
            # A tuple (not a set) is used so ``key`` need not be hashable, e.g. an unhashable config instance.
            if key == c.config or key in (c.annotation, c.connection_key, c.pool_key, c.session_key):
                return c.annotation
        msg = f"No configuration found for {key}"
        raise KeyError(msg)

    def _validate_dependency_keys(self) -> None:
        """Validate that connection, pool and session keys are unique across configurations.

        Each of these keys names an application dependency, so a duplicate would silently make one
        configuration's provider replace another's.

        Raises:
            ImproperConfigurationError: If connection keys, pool keys or session keys are not unique.
        """
        for key_attr in ("connection_key", "pool_key", "session_key"):
            keys = [getattr(c, key_attr) for c in self.config]
            if len(set(keys)) != len(keys):
                msg = f"When using multiple database configuration, each configuration must have a unique `{key_attr}`."
                raise ImproperConfigurationError(detail=msg)
|