sqlspec 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +40 -7
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +183 -17
- sqlspec/adapters/adbc/driver.py +392 -0
- sqlspec/adapters/aiosqlite/__init__.py +5 -1
- sqlspec/adapters/aiosqlite/config.py +24 -6
- sqlspec/adapters/aiosqlite/driver.py +264 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +71 -11
- sqlspec/adapters/asyncmy/driver.py +246 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +102 -25
- sqlspec/adapters/asyncpg/driver.py +444 -0
- sqlspec/adapters/duckdb/__init__.py +5 -1
- sqlspec/adapters/duckdb/config.py +194 -12
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +7 -4
- sqlspec/adapters/oracledb/config/__init__.py +4 -4
- sqlspec/adapters/oracledb/config/_asyncio.py +96 -12
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +96 -12
- sqlspec/adapters/oracledb/driver.py +571 -0
- sqlspec/adapters/psqlpy/__init__.py +0 -0
- sqlspec/adapters/psqlpy/config.py +258 -0
- sqlspec/adapters/psqlpy/driver.py +335 -0
- sqlspec/adapters/psycopg/__init__.py +16 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +107 -15
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +107 -15
- sqlspec/adapters/psycopg/driver.py +578 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +24 -6
- sqlspec/adapters/sqlite/driver.py +305 -0
- sqlspec/base.py +565 -63
- sqlspec/exceptions.py +30 -0
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +87 -0
- sqlspec/extensions/litestar/handlers.py +213 -0
- sqlspec/extensions/litestar/plugin.py +105 -11
- sqlspec/statement.py +373 -0
- sqlspec/typing.py +81 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/METADATA +4 -1
- sqlspec-0.9.0.dist-info/RECORD +61 -0
- sqlspec-0.7.1.dist-info/RECORD +0 -46
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/exceptions.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from collections.abc import Generator
|
|
2
|
+
from contextlib import contextmanager
|
|
1
3
|
from typing import Any, Optional
|
|
2
4
|
|
|
3
5
|
__all__ = (
|
|
@@ -6,7 +8,9 @@ __all__ = (
|
|
|
6
8
|
"MissingDependencyError",
|
|
7
9
|
"MultipleResultsFoundError",
|
|
8
10
|
"NotFoundError",
|
|
11
|
+
"ParameterStyleMismatchError",
|
|
9
12
|
"RepositoryError",
|
|
13
|
+
"SQLParsingError",
|
|
10
14
|
"SQLSpecError",
|
|
11
15
|
"SerializationError",
|
|
12
16
|
)
|
|
@@ -74,6 +78,20 @@ class SQLParsingError(SQLSpecError):
|
|
|
74
78
|
super().__init__(message)
|
|
75
79
|
|
|
76
80
|
|
|
81
|
+
class ParameterStyleMismatchError(SQLSpecError):
    """Raised when the supplied parameters do not fit the SQL placeholders.

    The parameter container (mapping, tuple, ...) implies a placeholder style
    (named, positional, ...); this error signals that the SQL text uses a
    different one.
    """

    def __init__(self, message: Optional[str] = None) -> None:
        # Use a generic explanation when the caller does not provide one.
        fallback = "Parameter style mismatch: dictionary parameters provided but no named placeholders found in SQL."
        super().__init__(fallback if message is None else message)
|
|
93
|
+
|
|
94
|
+
|
|
77
95
|
class ImproperConfigurationError(SQLSpecError):
|
|
78
96
|
"""Improper Configuration error.
|
|
79
97
|
|
|
@@ -99,3 +117,15 @@ class NotFoundError(RepositoryError):
|
|
|
99
117
|
|
|
100
118
|
class MultipleResultsFoundError(RepositoryError):
    """Raised when exactly one database row was expected but several were returned."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@contextmanager
def wrap_exceptions(wrap_exceptions: bool = True) -> Generator[None, None, None]:
    """Optionally re-raise exceptions from the managed block as ``RepositoryError``.

    Args:
        wrap_exceptions: When truthy (the default), any ``Exception`` raised inside
            the ``with`` block is replaced by a ``RepositoryError`` chained to the
            original exception (``raise ... from exc``). When falsy, the original
            exception propagates unchanged.

    Raises:
        RepositoryError: If an exception occurs inside the block and wrapping is enabled.

    Yields:
        None
    """
    try:
        yield
    except Exception as exc:
        # Truthiness instead of ``is False`` so every falsy flag (``0``, ``None``)
        # disables wrapping, not only the ``False`` singleton.
        if not wrap_exceptions:
            raise
        msg = "An error occurred during the operation."
        raise RepositoryError(detail=msg) from exc
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
2
|
+
from sqlspec.extensions.litestar.handlers import (
|
|
3
|
+
autocommit_handler_maker,
|
|
4
|
+
connection_provider_maker,
|
|
5
|
+
lifespan_handler_maker,
|
|
6
|
+
manual_handler_maker,
|
|
7
|
+
pool_provider_maker,
|
|
8
|
+
)
|
|
9
|
+
from sqlspec.extensions.litestar.plugin import SQLSpec
|
|
10
|
+
|
|
11
|
+
# Public API of the Litestar extension: the config/plugin classes plus the
# factories that build handlers and dependency providers.
__all__ = (
    "DatabaseConfig",
    "SQLSpec",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from litestar.types import Scope
|
|
5
|
+
|
|
6
|
+
__all__ = ("delete_sqlspec_scope_state", "get_sqlspec_scope_state", "set_sqlspec_scope_state")

# Key of the private dictionary inside the ASGI scope where SQLSpec keeps its values.
_SCOPE_NAMESPACE = "_sqlspec"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
    """Get an internal value from connection scope state.

    Values live in a private namespace dictionary stored in ``scope`` under
    ``_SCOPE_NAMESPACE``; that (empty) namespace dictionary is created in the scope
    if it does not exist yet.

    Note:
        The lookup behaves like ``dict.get(key, default)``: when ``key`` is missing,
        ``default`` is returned but is *not* stored in the namespace. With
        ``pop=True`` it behaves like ``dict.pop(key, default)`` and removes the key
        if it is present.

    Args:
        scope: The connection scope.
        key: Key to get from internal namespace in scope state.
        default: Value returned when ``key`` is not present.
        pop: Boolean flag dictating whether the value should be deleted from the state.

    Returns:
        Value mapped to ``key`` in internal connection scope namespace, or
        ``default`` if the key is not present.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    if pop:
        return namespace.pop(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
    return namespace.get(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def set_sqlspec_scope_state(scope: "Scope", key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in SQLSpec's private namespace of the scope.

    The namespace dictionary is created in the scope on first use.

    Args:
        scope: The connection scope.
        key: Name of the entry inside the internal namespace.
        value: Object to store.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    namespace[key] = value
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def delete_sqlspec_scope_state(scope: "Scope", key: str) -> None:
    """Remove an internal value from connection scope state.

    Args:
        scope: The connection scope.
        key: Key to remove from the internal namespace in scope state.

    Raises:
        KeyError: If ``key`` is not present in the internal namespace.
    """
    del scope.setdefault(_SCOPE_NAMESPACE, {})[key]  # type: ignore[misc]
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
|
|
3
|
+
|
|
4
|
+
from sqlspec.base import (
|
|
5
|
+
ConnectionT,
|
|
6
|
+
PoolT,
|
|
7
|
+
)
|
|
8
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
9
|
+
from sqlspec.extensions.litestar.handlers import (
|
|
10
|
+
autocommit_handler_maker,
|
|
11
|
+
connection_provider_maker,
|
|
12
|
+
lifespan_handler_maker,
|
|
13
|
+
manual_handler_maker,
|
|
14
|
+
pool_provider_maker,
|
|
15
|
+
session_provider_maker,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from collections.abc import Awaitable
|
|
20
|
+
from contextlib import AbstractAsyncContextManager
|
|
21
|
+
|
|
22
|
+
from litestar import Litestar
|
|
23
|
+
from litestar.datastructures.state import State
|
|
24
|
+
from litestar.types import BeforeMessageSendHookHandler, Scope
|
|
25
|
+
|
|
26
|
+
from sqlspec.base import (
|
|
27
|
+
AsyncConfigT,
|
|
28
|
+
ConnectionT,
|
|
29
|
+
DriverT,
|
|
30
|
+
PoolT,
|
|
31
|
+
SyncConfigT,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Transaction strategies understood by DatabaseConfig.commit_mode.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"

# Default dependency / state keys for the objects each configuration exposes.
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class DatabaseConfig:
    """Litestar integration settings for one SQLSpec database configuration.

    ``__post_init__`` derives, from the options below, a ``before_send`` handler
    chosen by ``commit_mode``, a lifespan handler that creates and closes the
    pool, and dependency providers for the connection, pool and session.

    Attributes:
        config: The SQLSpec sync or async database configuration.
        connection_key: Dependency / scope-state key for the request connection.
        pool_key: Dependency / app-state key for the pool. If the configuration
            does not support pooling and the default key is used, it is replaced by
            a unique key derived from ``id(config)``.
        session_key: Dependency / scope-state key for the driver session.
        commit_mode: ``"manual"``, ``"autocommit"`` or ``"autocommit_include_redirect"``.
        extra_commit_statuses: Extra status codes that commit in the autocommit modes.
        extra_rollback_statuses: Extra status codes that roll back in the autocommit modes.
    """

    config: "Union[SyncConfigT, AsyncConfigT]" = field()  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
    connection_key: str = field(default=DEFAULT_CONNECTION_KEY)
    pool_key: str = field(default=DEFAULT_POOL_KEY)
    session_key: str = field(default=DEFAULT_SESSION_KEY)
    commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
    extra_commit_statuses: "Optional[set[int]]" = field(default=None)
    extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
    # The fields below are computed in __post_init__ and are not constructor arguments.
    connection_provider: "Callable[[State,Scope], Awaitable[ConnectionT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    session_provider: "Callable[[State,Scope], Awaitable[DriverT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
    lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
        init=False,
        repr=False,
        hash=False,
    )
    # Set by the plugin (SQLSpec.add_config) when the app is initialised.
    annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False)  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]

    def __post_init__(self) -> None:
        """Build the Litestar handlers and providers for this configuration.

        Raises:
            ImproperConfigurationError: If ``commit_mode`` is not a recognised mode.
        """
        if not self.config.support_connection_pooling and self.pool_key == DEFAULT_POOL_KEY:  # type: ignore[union-attr,unused-ignore]
            # Without pooling the pool key must still be unique; generate one so it
            # won't conflict with other configs that may get added.
            self.pool_key = f"_{self.pool_key}_{id(self.config)}"
        if self.commit_mode == "manual":
            self.before_send_handler = manual_handler_maker(connection_scope_key=self.connection_key)
        elif self.commit_mode in {"autocommit", "autocommit_include_redirect"}:
            # The two autocommit modes differ only in whether 3xx responses commit.
            self.before_send_handler = autocommit_handler_maker(
                commit_on_redirect=self.commit_mode == "autocommit_include_redirect",
                extra_commit_statuses=self.extra_commit_statuses,
                extra_rollback_statuses=self.extra_rollback_statuses,
                connection_scope_key=self.connection_key,
            )
        else:
            msg = f"Invalid commit mode: {self.commit_mode}"  # type: ignore[unreachable,unused-ignore]
            raise ImproperConfigurationError(detail=msg)
        self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key)
        self.connection_provider = connection_provider_maker(connection_key=self.connection_key, config=self.config)
        self.pool_provider = pool_provider_maker(pool_key=self.pool_key, config=self.config)
        self.session_provider = session_provider_maker(session_key=self.session_key, config=self.config)
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
|
3
|
+
|
|
4
|
+
from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
|
|
5
|
+
|
|
6
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
7
|
+
from sqlspec.extensions.litestar._utils import (
|
|
8
|
+
delete_sqlspec_scope_state,
|
|
9
|
+
get_sqlspec_scope_state,
|
|
10
|
+
set_sqlspec_scope_state,
|
|
11
|
+
)
|
|
12
|
+
from sqlspec.utils.sync_tools import maybe_async_
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from collections.abc import AsyncGenerator, Awaitable, Coroutine
|
|
16
|
+
from contextlib import AbstractAsyncContextManager
|
|
17
|
+
|
|
18
|
+
from litestar import Litestar
|
|
19
|
+
from litestar.datastructures.state import State
|
|
20
|
+
from litestar.types import Message, Scope
|
|
21
|
+
|
|
22
|
+
from sqlspec.base import ConnectionT, DatabaseConfigProtocol, DriverT, PoolT
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# When one of these ASGI messages is sent, the before-send handlers close the
# request connection and drop it from the scope state.
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT}
"""ASGI events that terminate a session scope."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Set up the handler that closes and cleans up the request connection.

    In manual mode no commit or rollback is issued; transaction control is left to
    the application. When a session-terminating ASGI event is sent, the connection
    found in the scope state (if any) is closed and removed from the state.

    Args:
        connection_scope_key: The key under which the connection is stored in the connection scope state.

    Returns:
        The handler callable
    """

    async def handler(message: "Message", scope: "Scope") -> None:
        """Close and clean up the connection before a terminating message is sent.

        Args:
            message: ASGI-``Message``
            scope: An ASGI-``Scope``
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            # Best-effort close: any exception raised by close() is suppressed.
            with contextlib.suppress(Exception):
                await maybe_async_(connection.close)()
            delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def autocommit_handler_maker(
    connection_scope_key: str,
    commit_on_redirect: bool = False,
    extra_commit_statuses: "Optional[set[int]]" = None,
    extra_rollback_statuses: "Optional[set[int]]" = None,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Build a before-send handler that commits or rolls back by response status.

    On ``http.response.start`` a status commits when it is 2xx (2xx or 3xx with
    ``commit_on_redirect``) or listed in ``extra_commit_statuses``, unless it is
    listed in ``extra_rollback_statuses``; every other status rolls back. After a
    session-terminating ASGI event the connection is closed and removed from the
    scope state.

    Args:
        connection_scope_key: The key under which the connection is stored in the connection scope state.
        commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``).
        extra_commit_statuses: A set of additional status codes that trigger a commit.
        extra_rollback_statuses: A set of additional status codes that trigger a rollback.

    Raises:
        ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share any status codes.

    Returns:
        The handler callable.
    """
    commit_statuses = set() if extra_commit_statuses is None else extra_commit_statuses
    rollback_statuses = set() if extra_rollback_statuses is None else extra_rollback_statuses

    if commit_statuses & rollback_statuses:
        msg = "Extra rollback statuses and commit statuses must not share any status codes"
        raise ImproperConfigurationError(msg)

    upper_bound = 400 if commit_on_redirect else 300
    commit_range = range(200, upper_bound)

    def _should_commit(status: int) -> bool:
        # Explicit rollback codes always win over any commit rule.
        if status in rollback_statuses:
            return False
        return status in commit_range or status in commit_statuses

    async def handler(message: "Message", scope: "Scope") -> None:
        """Commit or roll back, then close and clean up the connection when the session ends.

        Args:
            message: ASGI-``litestar.types.Message``
            scope: An ASGI-``litestar.types.Scope``
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        try:
            if connection is not None and message["type"] == HTTP_RESPONSE_START:
                finish = connection.commit if _should_commit(message["status"]) else connection.rollback
                await maybe_async_(finish)()
        finally:
            if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
                with contextlib.suppress(Exception):
                    await maybe_async_(connection.close)()
                delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def lifespan_handler_maker(
    config: "DatabaseConfigProtocol[Any, Any, Any]",
    pool_key: str,
) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
    """Create a Litestar lifespan context manager that owns the database pool.

    On startup the pool is created with ``config.create_pool`` and stored in
    ``app.state`` under ``pool_key``. On shutdown the entry is removed from the state
    and ``config.close_pool`` is called; an exception while closing is logged as a
    warning (when the app has a logger) instead of being raised.

    Args:
        config: The database configuration.
        pool_key: The key to use for the connection pool within Litestar.

    Returns:
        The generated lifespan handler for the connection.
    """

    async def _shutdown(app: "Litestar") -> None:
        # Detach the pool from the app first, then try to close it.
        app.state.pop(pool_key, None)
        try:
            await maybe_async_(config.close_pool)()
        except Exception as e:  # noqa: BLE001
            if app.logger:
                app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)

    @contextlib.asynccontextmanager
    async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
        created_pool = await maybe_async_(config.create_pool)()
        app.state.update({pool_key: created_pool})
        try:
            yield
        finally:
            await _shutdown(app)

    return lifespan_handler
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def connection_provider_maker(
    connection_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
) -> "Callable[[State,Scope], Awaitable[ConnectionT]]":
    """Build the per-request connection provider for the database configuration.

    The provider reuses the connection stored in the scope state under
    ``connection_key``; if there is none, it creates one with
    ``config.create_connection`` and stores it there for later lookups.

    Args:
        connection_key: The dependency key to use for the session within Litestar.
        config: The database configuration.

    Returns:
        The generated connection provider for the connection.
    """

    async def provide_connection(state: "State", scope: "Scope") -> "ConnectionT":
        # ``state`` is accepted for Litestar dependency injection; it is not used here.
        existing = get_sqlspec_scope_state(scope, connection_key)
        if existing is not None:
            return cast("ConnectionT", existing)
        created = await maybe_async_(config.create_connection)()
        set_sqlspec_scope_state(scope, connection_key, created)
        return cast("ConnectionT", created)

    return provide_connection
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def pool_provider_maker(
    pool_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
) -> "Callable[[State,Scope], Awaitable[PoolT]]":
    """Build the pool provider for the database configuration.

    The provider returns the application-wide pool stored in ``state`` under
    ``pool_key``; that is where the lifespan handler from ``lifespan_handler_maker``
    stores it (``app.state.update({pool_key: ...})``). Only when no pool is present
    is a new one created with ``config.create_pool``; it is then stored in
    ``state`` so subsequent requests reuse it.

    Args:
        pool_key: The dependency key to use for the pool within Litestar.
        config: The database configuration.

    Returns:
        The generated pool provider for the database.
    """

    async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
        # The pool is application-scoped. Looking it up in the per-request scope
        # state (as before) missed the lifespan-created pool, so every request
        # created a fresh pool that nothing ever closed.
        pool = state.get(pool_key)
        if pool is None:
            pool = await maybe_async_(config.create_pool)()
            state.update({pool_key: pool})
        return cast("PoolT", pool)

    return provide_pool
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def session_provider_maker(
    session_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
    connection_key: "Optional[str]" = None,
) -> "Callable[[State,Scope], Awaitable[DriverT]]":
    """Build the session provider for the database configuration.

    The provider returns the driver session cached in the scope state under
    ``session_key``. On first use within a request it builds one with
    ``config.driver_type(connection=...)`` and caches it.

    When ``connection_key`` is given, the session wraps the connection stored
    under that key in the scope state, creating and registering it there if
    needed. The before-send handlers built by ``manual_handler_maker`` /
    ``autocommit_handler_maker`` look the connection up by their
    ``connection_scope_key``. When that key equals ``connection_key``, the
    session's connection is therefore committed/rolled back and closed with the
    request.

    When ``connection_key`` is ``None`` (the original behaviour), a dedicated
    connection is created that is not registered under any connection key.
    NOTE(review): nothing in this module commits or closes that connection;
    callers should pass ``connection_key``.

    Args:
        session_key: The dependency key to use for the session within Litestar.
        config: The database configuration.
        connection_key: Optional scope-state key of the request connection the
            session should share.

    Returns:
        The generated session provider for the database.
    """

    async def _acquire_connection(scope: "Scope") -> Any:
        # Return the connection the new session should wrap.
        if connection_key is None:
            return await maybe_async_(config.create_connection)()
        connection = get_sqlspec_scope_state(scope, connection_key)
        if connection is None:
            connection = await maybe_async_(config.create_connection)()
            set_sqlspec_scope_state(scope, connection_key, connection)
        return connection

    async def provide_session(state: "State", scope: "Scope") -> "DriverT":
        session = get_sqlspec_scope_state(scope, session_key)
        if session is None:
            connection = await _acquire_connection(scope)
            session = config.driver_type(connection=connection)  # pyright: ignore[reportCallIssue]
            set_sqlspec_scope_state(scope, session_key, session)
        return cast("DriverT", session)

    return provide_session
|
|
@@ -1,34 +1,64 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Literal, Union
|
|
2
2
|
|
|
3
|
+
from litestar.di import Provide
|
|
3
4
|
from litestar.plugins import InitPluginProtocol
|
|
4
5
|
|
|
6
|
+
from sqlspec.base import (
|
|
7
|
+
AsyncConfigT,
|
|
8
|
+
ConnectionT,
|
|
9
|
+
DatabaseConfigProtocol,
|
|
10
|
+
PoolT,
|
|
11
|
+
SyncConfigT,
|
|
12
|
+
)
|
|
13
|
+
from sqlspec.base import SQLSpec as SQLSpecBase
|
|
14
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
15
|
+
from sqlspec.extensions.litestar.config import DatabaseConfig
|
|
16
|
+
|
|
5
17
|
if TYPE_CHECKING:
|
|
18
|
+
from click import Group
|
|
6
19
|
from litestar.config.app import AppConfig
|
|
7
20
|
|
|
8
|
-
|
|
21
|
+
|
|
22
|
+
# NOTE(review): these duplicate the definitions in
# ``sqlspec.extensions.litestar.config``; consider importing them from there.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"

# Default dependency keys.
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"
|
|
9
27
|
|
|
10
28
|
|
|
11
|
-
class
|
|
29
|
+
class SQLSpec(InitPluginProtocol, SQLSpecBase):
    """Litestar plugin that registers SQLSpec database configurations with an app."""

    # NOTE(review): ``_config`` does not appear to be used by this class, and
    # ``_configs`` (assigned in ``__init__``) is not listed here; presumably it is
    # provided by ``SQLSpecBase``. Verify before changing.
    __slots__ = ("_config", "_plugin_configs")

    def __init__(
        self,
        config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]],
    ) -> None:
        """Initialize ``SQLSpecPlugin``.

        Args:
            config: configure SQLSpec plugin for use with Litestar. A bare SQLSpec
                database configuration is wrapped in a ``DatabaseConfig`` with default
                settings; a single ``DatabaseConfig`` or a list of them is used as given.
        """
        self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
        if isinstance(config, DatabaseConfigProtocol):
            self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
        elif isinstance(config, DatabaseConfig):
            self._plugin_configs = [config]
        else:
            self._plugin_configs = config

    @property
    def config(self) -> "list[DatabaseConfig]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the plugin config.

        Returns:
            The list of ``DatabaseConfig`` instances managed by this plugin.
        """
        return self._plugin_configs

    def on_cli_init(self, cli: "Group") -> None:
        """Configure the CLI for use with SQLSpec."""

    def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
        """Configure application for use with SQLSpec.

        Args:
            app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.

        Returns:
            The updated :class:`AppConfig <.config.app.AppConfig>` instance.
        """
        self._validate_dependency_keys()
        app_config.signature_types.extend(
            [
                SQLSpec,
                ConnectionT,
                PoolT,
                DatabaseConfig,
                DatabaseConfigProtocol,
                SyncConfigT,
                AsyncConfigT,
            ]
        )
        for c in self._plugin_configs:
            c.annotation = self.add_config(c.config)
            app_config.before_send.append(c.before_send_handler)
            app_config.lifespan.append(c.lifespan_handler)  # pyright: ignore[reportUnknownMemberType]
            app_config.dependencies.update(
                {
                    c.connection_key: Provide(c.connection_provider),
                    c.pool_key: Provide(c.pool_provider),
                    c.session_key: Provide(c.session_provider),
                },
            )

        return app_config

    def get_annotations(self) -> "list[type[Union[SyncConfigT, AsyncConfigT]]]":  # pyright: ignore[reportInvalidTypeVarUse]
        """Return the list of annotations.

        Returns:
            List of annotations.
        """
        return [c.annotation for c in self.config]

    def get_annotation(
        self,
        key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]",
    ) -> "type[Union[SyncConfigT, AsyncConfigT]]":
        """Return the annotation for the given configuration.

        Args:
            key: The configuration instance, its annotation, or one of its
                connection / pool / session dependency keys.

        Raises:
            KeyError: If no configuration is found for the given key.

        Returns:
            The annotation for the configuration.
        """
        for c in self.config:
            # Tuple membership compares with ``==`` and never hashes ``key``;
            # the previous set literal raised TypeError for unhashable config
            # instances that did not match the first configuration.
            if key == c.config or key in (c.annotation, c.connection_key, c.pool_key, c.session_key):
                return c.annotation
        msg = f"No configuration found for {key}"
        raise KeyError(msg)

    def _validate_dependency_keys(self) -> None:
        """Verify uniqueness of ``connection_key``, ``pool_key`` and ``session_key``.

        Duplicate keys would silently overwrite each other's dependency providers
        in ``app_config.dependencies``.

        Raises:
            ImproperConfigurationError: If connection keys, pool keys or session keys are not unique.
        """
        connection_keys = [c.connection_key for c in self.config]
        pool_keys = [c.pool_key for c in self.config]
        session_keys = [c.session_key for c in self.config]
        if len(set(connection_keys)) != len(connection_keys):
            msg = "When using multiple database configuration, each configuration must have a unique `connection_key`."
            raise ImproperConfigurationError(detail=msg)
        if len(set(pool_keys)) != len(pool_keys):
            msg = "When using multiple database configuration, each configuration must have a unique `pool_key`."
            raise ImproperConfigurationError(detail=msg)
        if len(set(session_keys)) != len(session_keys):
            msg = "When using multiple database configuration, each configuration must have a unique `session_key`."
            raise ImproperConfigurationError(detail=msg)