sqlspec 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic.
- sqlspec/__init__.py +15 -0
- sqlspec/_serialization.py +16 -2
- sqlspec/_typing.py +1 -1
- sqlspec/adapters/adbc/__init__.py +7 -0
- sqlspec/adapters/adbc/config.py +160 -17
- sqlspec/adapters/adbc/driver.py +333 -0
- sqlspec/adapters/aiosqlite/__init__.py +6 -2
- sqlspec/adapters/aiosqlite/config.py +25 -7
- sqlspec/adapters/aiosqlite/driver.py +275 -0
- sqlspec/adapters/asyncmy/__init__.py +7 -2
- sqlspec/adapters/asyncmy/config.py +75 -14
- sqlspec/adapters/asyncmy/driver.py +255 -0
- sqlspec/adapters/asyncpg/__init__.py +9 -0
- sqlspec/adapters/asyncpg/config.py +99 -20
- sqlspec/adapters/asyncpg/driver.py +288 -0
- sqlspec/adapters/duckdb/__init__.py +6 -2
- sqlspec/adapters/duckdb/config.py +195 -13
- sqlspec/adapters/duckdb/driver.py +225 -0
- sqlspec/adapters/oracledb/__init__.py +11 -8
- sqlspec/adapters/oracledb/config/__init__.py +6 -6
- sqlspec/adapters/oracledb/config/_asyncio.py +98 -13
- sqlspec/adapters/oracledb/config/_common.py +1 -1
- sqlspec/adapters/oracledb/config/_sync.py +99 -14
- sqlspec/adapters/oracledb/driver.py +498 -0
- sqlspec/adapters/psycopg/__init__.py +11 -0
- sqlspec/adapters/psycopg/config/__init__.py +6 -6
- sqlspec/adapters/psycopg/config/_async.py +105 -13
- sqlspec/adapters/psycopg/config/_common.py +2 -2
- sqlspec/adapters/psycopg/config/_sync.py +105 -13
- sqlspec/adapters/psycopg/driver.py +616 -0
- sqlspec/adapters/sqlite/__init__.py +7 -0
- sqlspec/adapters/sqlite/config.py +25 -7
- sqlspec/adapters/sqlite/driver.py +303 -0
- sqlspec/base.py +416 -36
- sqlspec/extensions/litestar/__init__.py +19 -0
- sqlspec/extensions/litestar/_utils.py +56 -0
- sqlspec/extensions/litestar/config.py +81 -0
- sqlspec/extensions/litestar/handlers.py +188 -0
- sqlspec/extensions/litestar/plugin.py +100 -11
- sqlspec/typing.py +72 -17
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/fixtures.py +4 -5
- sqlspec/utils/sync_tools.py +335 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/METADATA +1 -1
- sqlspec-0.8.0.dist-info/RECORD +57 -0
- sqlspec-0.7.1.dist-info/RECORD +0 -46
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.7.1.dist-info → sqlspec-0.8.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/extensions/litestar/config.py
ADDED
@@ -0,0 +1,81 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+
+from sqlspec.base import (
+    ConnectionT,
+    PoolT,
+)
+from sqlspec.exceptions import ImproperConfigurationError
+from sqlspec.extensions.litestar.handlers import (
+    autocommit_handler_maker,
+    connection_provider_maker,
+    lifespan_handler_maker,
+    manual_handler_maker,
+    pool_provider_maker,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable
+    from contextlib import AbstractAsyncContextManager
+
+    from litestar import Litestar
+    from litestar.datastructures.state import State
+    from litestar.types import BeforeMessageSendHookHandler, Scope
+
+    from sqlspec.base import (
+        AsyncConfigT,
+        ConnectionT,
+        PoolT,
+        SyncConfigT,
+    )
+
+CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
+DEFAULT_COMMIT_MODE: CommitMode = "manual"
+DEFAULT_CONNECTION_KEY = "db_connection"
+DEFAULT_POOL_KEY = "db_pool"
+
+
+@dataclass
+class DatabaseConfig:
+    config: "Union[SyncConfigT, AsyncConfigT]" = field()  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
+    connection_key: str = field(default=DEFAULT_CONNECTION_KEY)
+    pool_key: str = field(default=DEFAULT_POOL_KEY)
+    commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
+    extra_commit_statuses: "Optional[set[int]]" = field(default=None)
+    extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
+    connection_provider: "Callable[[State,Scope], Awaitable[ConnectionT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
+    pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
+    before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
+    lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
+        init=False,
+        repr=False,
+        hash=False,
+    )
+    annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False)  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
+
+    def __post_init__(self) -> None:
+        if not self.config.support_connection_pooling and self.pool_key == DEFAULT_POOL_KEY:  # type: ignore[union-attr,unused-ignore]
+            """If the database configuration does not support connection pooling, the pool key must be unique. We just automatically generate a unique identify so it won't conflict with other configs that may get added"""
+            self.pool_key = f"_{self.pool_key}_{id(self.config)}"
+        if self.commit_mode == "manual":
+            self.before_send_handler = manual_handler_maker(connection_scope_key=self.connection_key)
+        elif self.commit_mode == "autocommit":
+            self.before_send_handler = autocommit_handler_maker(
+                commit_on_redirect=False,
+                extra_commit_statuses=self.extra_commit_statuses,
+                extra_rollback_statuses=self.extra_rollback_statuses,
+                connection_scope_key=self.connection_key,
+            )
+        elif self.commit_mode == "autocommit_include_redirect":
+            self.before_send_handler = autocommit_handler_maker(
+                commit_on_redirect=True,
+                extra_commit_statuses=self.extra_commit_statuses,
+                extra_rollback_statuses=self.extra_rollback_statuses,
+                connection_scope_key=self.connection_key,
+            )
+        else:
+            msg = f"Invalid commit mode: {self.commit_mode}"  # type: ignore[unreachable]
+            raise ImproperConfigurationError(detail=msg)
+        self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key)
+        self.connection_provider = connection_provider_maker(connection_key=self.connection_key, config=self.config)
+        self.pool_provider = pool_provider_maker(pool_key=self.pool_key, config=self.config)
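The DatabaseConfig dataclass above is the per-database unit of configuration for the Litestar extension: __post_init__ derives the before-send handler from commit_mode and builds the lifespan handler plus the connection and pool providers. A minimal usage sketch follows; the AiosqliteConfig class name and import path are assumptions based on the adapter list above, not something shown in this hunk.

# Hypothetical sketch: the adapter class name and constructor are assumed; only
# the DatabaseConfig fields come from the diff above.
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig  # assumed import path
from sqlspec.extensions.litestar.config import DatabaseConfig

db = DatabaseConfig(
    config=AiosqliteConfig(),        # any sync or async sqlspec adapter config
    commit_mode="autocommit",        # "manual" | "autocommit" | "autocommit_include_redirect"
    connection_key="db_connection",  # dependency key injected into route handlers
    pool_key="db_pool",              # dependency key for the pool
)
# __post_init__ has now populated db.before_send_handler, db.lifespan_handler,
# db.connection_provider, and db.pool_provider for the plugin to register.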
sqlspec/extensions/litestar/handlers.py
ADDED
@@ -0,0 +1,188 @@
+import contextlib
+from typing import TYPE_CHECKING, Any, Callable, Optional, cast
+
+from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
+
+from sqlspec.exceptions import ImproperConfigurationError
+from sqlspec.extensions.litestar._utils import (
+    delete_sqlspec_scope_state,
+    get_sqlspec_scope_state,
+    set_sqlspec_scope_state,
+)
+from sqlspec.utils.sync_tools import maybe_async_
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator, Awaitable, Coroutine
+    from contextlib import AbstractAsyncContextManager
+
+    from litestar import Litestar
+    from litestar.datastructures.state import State
+    from litestar.types import Message, Scope
+
+    from sqlspec.base import ConnectionT, DatabaseConfigProtocol, DriverT, PoolT
+
+
+SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
+"""ASGI events that terminate a session scope."""
+
+
+def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
+    """Set up the handler to issue a transaction commit or rollback based on specified status codes
+    Args:
+        connection_scope_key: The key to use within the application state
+
+    Returns:
+        The handler callable
+    """
+
+    async def handler(message: "Message", scope: "Scope") -> None:
+        """Handle commit/rollback, closing and cleaning up sessions before sending.
+
+        Args:
+            message: ASGI-``Message``
+            scope: An ASGI-``Scope``
+
+        """
+        connection = get_sqlspec_scope_state(scope, connection_scope_key)
+        if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
+            with contextlib.suppress(Exception):
+                await maybe_async_(connection.close)()
+            delete_sqlspec_scope_state(scope, connection_scope_key)
+
+    return handler
+
+
+def autocommit_handler_maker(
+    connection_scope_key: str,
+    commit_on_redirect: bool = False,
+    extra_commit_statuses: "Optional[set[int]]" = None,
+    extra_rollback_statuses: "Optional[set[int]]" = None,
+) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
+    """Set up the handler to issue a transaction commit or rollback based on specified status codes
+    Args:
+        commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
+        extra_commit_statuses: A set of additional status codes that trigger a commit
+        extra_rollback_statuses: A set of additional status codes that trigger a rollback
+        connection_scope_key: The key to use within the application state
+
+    Raises:
+        ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share any status codes
+
+    Returns:
+        The handler callable
+    """
+    if extra_commit_statuses is None:
+        extra_commit_statuses = set()
+
+    if extra_rollback_statuses is None:
+        extra_rollback_statuses = set()
+
+    if len(extra_commit_statuses & extra_rollback_statuses) > 0:
+        msg = "Extra rollback statuses and commit statuses must not share any status codes"
+        raise ImproperConfigurationError(msg)
+
+    commit_range = range(200, 400 if commit_on_redirect else 300)
+
+    async def handler(message: "Message", scope: "Scope") -> None:
+        """Handle commit/rollback, closing and cleaning up sessions before sending.
+
+        Args:
+            message: ASGI-``litestar.types.Message``
+            scope: An ASGI-``litestar.types.Scope``
+
+        """
+        connection = get_sqlspec_scope_state(scope, connection_scope_key)
+        try:
+            if connection is not None and message["type"] == HTTP_RESPONSE_START:
+                if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
+                    "status"
+                ] not in extra_rollback_statuses:
+                    await maybe_async_(connection.commit)()
+                else:
+                    await maybe_async_(connection.rollback)()
+        finally:
+            if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
+                with contextlib.suppress(Exception):
+                    await maybe_async_(connection.close)()
+                delete_sqlspec_scope_state(scope, connection_scope_key)
+
+    return handler
+
+
+def lifespan_handler_maker(
+    config: "DatabaseConfigProtocol[Any, Any, Any]",
+    pool_key: str,
+) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
+    """Build the lifespan handler for the database configuration.
+
+    Args:
+        config: The database configuration.
+        pool_key: The key to use for the connection pool within Litestar.
+
+    Returns:
+        The generated lifespan handler for the connection.
+    """
+
+    @contextlib.asynccontextmanager
+    async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
+        db_pool = await maybe_async_(config.create_pool)()
+        app.state.update({pool_key: db_pool})
+        try:
+            yield
+        finally:
+            app.state.pop(pool_key, None)
+            try:
+                await maybe_async_(config.close_pool)()
+            except Exception as e:  # noqa: BLE001
+                if app.logger:
+                    app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)
+
+    return lifespan_handler
+
+
+def connection_provider_maker(
+    connection_key: str,
+    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
+) -> "Callable[[State,Scope], Awaitable[ConnectionT]]":
+    """Build the connection provider for the database configuration.
+
+    Args:
+        connection_key: The dependency key to use for the session within Litestar.
+        config: The database configuration.
+
+    Returns:
+        The generated connection provider for the connection.
+    """
+
+    async def provide_connection(state: "State", scope: "Scope") -> "ConnectionT":
+        connection = get_sqlspec_scope_state(scope, connection_key)
+        if connection is None:
+            connection = await maybe_async_(config.create_connection)()
+            set_sqlspec_scope_state(scope, connection_key, connection)
+        return cast("ConnectionT", connection)
+
+    return provide_connection
+
+
+def pool_provider_maker(
+    pool_key: str,
+    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
+) -> "Callable[[State,Scope], Awaitable[PoolT]]":
+    """Build the pool provider for the database configuration.
+
+    Args:
+        pool_key: The dependency key to use for the pool within Litestar.
+        config: The database configuration.
+
+    Returns:
+        The generated connection pool for the database.
+    """
+
+    async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
+        pool = get_sqlspec_scope_state(scope, pool_key)
+        if pool is None:
+            pool = await maybe_async_(config.create_pool)()
+            set_sqlspec_scope_state(scope, pool_key, pool)
+        return cast("PoolT", pool)
+
+    return provide_pool
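All of the handlers above route driver calls through maybe_async_ from the new sqlspec/utils/sync_tools.py, so a single code path serves both sync drivers (sqlite, duckdb) and async drivers (asyncpg, aiosqlite). That module's implementation is not part of this hunk; the sketch below only illustrates the wrapper's apparent contract (call a sync or async callable, always get something awaitable back) and is not sqlspec's code.

# Illustrative sketch only, not sqlspec's implementation of maybe_async_.
import inspect
from typing import Any, Awaitable, Callable


def maybe_async_sketch(fn: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        # Await coroutine results from async drivers; return sync results as-is.
        if inspect.isawaitable(result):
            return await result
        return result

    return wrapper

# e.g. `await maybe_async_sketch(connection.commit)()` works whether
# connection.commit is a plain method or a coroutine function.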
sqlspec/extensions/litestar/plugin.py
CHANGED
@@ -1,34 +1,63 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal, Union

+from litestar.di import Provide
 from litestar.plugins import InitPluginProtocol

+from sqlspec.base import (
+    AsyncConfigT,
+    ConnectionT,
+    DatabaseConfigProtocol,
+    PoolT,
+    SyncConfigT,
+)
+from sqlspec.base import SQLSpec as SQLSpecBase
+from sqlspec.exceptions import ImproperConfigurationError
+from sqlspec.extensions.litestar.config import DatabaseConfig
+
 if TYPE_CHECKING:
+    from click import Group
     from litestar.config.app import AppConfig

-
+
+CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
+DEFAULT_COMMIT_MODE: CommitMode = "manual"
+DEFAULT_CONNECTION_KEY = "db_connection"
+DEFAULT_POOL_KEY = "db_pool"


-class SQLSpecPlugin(InitPluginProtocol):
+class SQLSpec(InitPluginProtocol, SQLSpecBase):
     """SQLSpec plugin."""

-    __slots__ = ("_config",)
+    __slots__ = ("_config", "_plugin_configs")

-    def __init__(
+    def __init__(
+        self,
+        config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]],
+    ) -> None:
         """Initialize ``SQLSpecPlugin``.

         Args:
             config: configure SQLSpec plugin for use with Litestar.
         """
-        self.
+        self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
+        if isinstance(config, DatabaseConfigProtocol):
+            self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
+        elif isinstance(config, DatabaseConfig):
+            self._plugin_configs = [config]
+        else:
+            self._plugin_configs = config

     @property
-    def config(self) -> "
+    def config(self) -> "list[DatabaseConfig]": # pyright: ignore[reportInvalidTypeVarUse]
         """Return the plugin config.

         Returns:
             ConfigManager.
         """
-        return self.
+        return self._plugin_configs
+
+    def on_cli_init(self, cli: "Group") -> None:
+        """Configure the CLI for use with SQLSpec."""

     def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
         """Configure application for use with SQLSpec.
@@ -39,8 +68,68 @@ class SQLSpecPlugin(InitPluginProtocol):
         Returns:
             The updated :class:`AppConfig <.config.app.AppConfig>` instance.
         """
+        self._validate_dependency_keys()
+        app_config.signature_types.extend(
+            [
+                SQLSpec,
+                ConnectionT,
+                PoolT,
+                DatabaseConfig,
+                DatabaseConfigProtocol,
+                SyncConfigT,
+                AsyncConfigT,
+            ]
+        )
+        for c in self._plugin_configs:
+            c.annotation = self.add_config(c.config)
+            app_config.before_send.append(c.before_send_handler)
+            app_config.lifespan.append(c.lifespan_handler)  # pyright: ignore[reportUnknownMemberType]
+            app_config.dependencies.update(
+                {c.connection_key: Provide(c.connection_provider), c.pool_key: Provide(c.pool_provider)},
+            )

-        from sqlspec.base import ConfigManager
-
-        app_config.signature_types.append(ConfigManager)
         return app_config
+
+    def get_annotations(self) -> "list[type[Union[SyncConfigT, AsyncConfigT]]]": # pyright: ignore[reportInvalidTypeVarUse]
+        """Return the list of annotations.
+
+        Returns:
+            List of annotations.
+        """
+        return [c.annotation for c in self.config]
+
+    def get_annotation(
+        self,
+        key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]",
+    ) -> "type[Union[SyncConfigT, AsyncConfigT]]":
+        """Return the annotation for the given configuration.
+
+        Args:
+            key: The configuration instance or key to lookup
+
+        Raises:
+            KeyError: If no configuration is found for the given key.
+
+        Returns:
+            The annotation for the configuration.
+        """
+        for c in self.config:
+            if key == c.config or key in {c.annotation, c.connection_key, c.pool_key}:
+                return c.annotation
+        msg = f"No configuration found for {key}"
+        raise KeyError(msg)
+
+    def _validate_dependency_keys(self) -> None:
+        """Verify uniqueness of ``connection_key`` and ``pool_key``.
+
+        Raises:
+            ImproperConfigurationError: If session keys or pool keys are not unique.
+        """
+        connection_keys = [c.connection_key for c in self.config]
+        pool_keys = [c.pool_key for c in self.config]
+        if len(set(connection_keys)) != len(connection_keys):
+            msg = "When using multiple database configuration, each configuration must have a unique `connection_key`."
+            raise ImproperConfigurationError(detail=msg)
+        if len(set(pool_keys)) != len(pool_keys):
+            msg = "When using multiple database configuration, each configuration must have a unique `pool_key`."
+            raise ImproperConfigurationError(detail=msg)
sqlspec/typing.py
CHANGED
@@ -11,6 +11,7 @@ from sqlspec._typing import (
     UNSET,
     BaseModel,
     DataclassProtocol,
+    DTOData,
     Empty,
     EmptyType,
     Struct,
@@ -38,26 +39,53 @@ FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")

 :class:`~advanced_alchemy.filters.StatementFilter`
 """
+SupportedSchemaModel: TypeAlias = "Union[Struct, BaseModel, DataclassProtocol]"
+"""Type alias for pydantic or msgspec models.

+:class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`DataclassProtocol`
+"""
+ModelDTOT = TypeVar("ModelDTOT", bound="SupportedSchemaModel")
+"""Type variable for model DTOs.

-
+:class:`msgspec.Struct`|:class:`pydantic.BaseModel`
+"""
+PydanticOrMsgspecT = SupportedSchemaModel
 """Type alias for pydantic or msgspec models.

 :class:`msgspec.Struct` or :class:`pydantic.BaseModel`
 """
-
+ModelDict: TypeAlias = "Union[dict[str, Any], SupportedSchemaModel, DTOData[SupportedSchemaModel]]"
 """Type alias for model dictionaries.

 Represents:
 - :type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`
 """
-
+ModelDictList: TypeAlias = "Sequence[Union[dict[str, Any], SupportedSchemaModel]]"
 """Type alias for model dictionary lists.

 A list or sequence of any of the following:
 - :type:`Sequence`[:type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]

 """
+BulkModelDict: TypeAlias = (
+    "Union[Sequence[Union[dict[str, Any], SupportedSchemaModel]], DTOData[list[SupportedSchemaModel]]]"
+)
+"""Type alias for bulk model dictionaries.
+
+Represents:
+- :type:`Sequence`[:type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel`]
+- :class:`DTOData`[:type:`list[ModelT]`]
+"""
+
+StatementParameterType: TypeAlias = "Union[dict[str, Any], list[Any], tuple[Any, ...], None]"
+"""Type alias for parameter types.
+
+Represents:
+- :type:`dict[str, Any]`
+- :type:`list[Any]`
+- :type:`tuple[Any, ...]`
+- :type:`None`
+"""


 def is_dataclass_instance(obj: Any) -> "TypeGuard[DataclassProtocol]":
@@ -286,7 +314,14 @@ def is_schema_or_dict_without_field(


 def is_dataclass(obj: "Any") -> "TypeGuard[DataclassProtocol]":
-    """Check if an object is a dataclass.
+    """Check if an object is a dataclass.
+
+    Args:
+        obj: Value to check.
+
+    Returns:
+        bool
+    """
     return is_dataclass_instance(obj)


@@ -294,17 +329,33 @@ def is_dataclass_with_field(
     obj: "Any",
     field_name: str,
 ) -> "TypeGuard[object]":  # Can't specify dataclass type directly
-    """Check if an object is a dataclass and has a specific field.
+    """Check if an object is a dataclass and has a specific field.
+
+    Args:
+        obj: Value to check.
+        field_name: Field name to check for.
+
+    Returns:
+        bool
+    """
     return is_dataclass(obj) and hasattr(obj, field_name)


 def is_dataclass_without_field(obj: "Any", field_name: str) -> "TypeGuard[object]":
-    """Check if an object is a dataclass and does not have a specific field.
+    """Check if an object is a dataclass and does not have a specific field.
+
+    Args:
+        obj: Value to check.
+        field_name: Field name to check for.
+
+    Returns:
+        bool
+    """
     return is_dataclass(obj) and not hasattr(obj, field_name)


 def extract_dataclass_fields(
-
+    obj: "DataclassProtocol",
     exclude_none: bool = False,
     exclude_empty: bool = False,
     include: "Optional[AbstractSet[str]]" = None,
@@ -313,12 +364,14 @@ def extract_dataclass_fields(
     """Extract dataclass fields.

     Args:
-
+        obj: A dataclass instance.
         exclude_none: Whether to exclude None values.
         exclude_empty: Whether to exclude Empty values.
         include: An iterable of fields to include.
         exclude: An iterable of fields to exclude.

+    Raises:
+        ValueError: If there are fields that are both included and excluded.

     Returns:
         A tuple of dataclass fields.
@@ -330,11 +383,11 @@ def extract_dataclass_fields(
         msg = f"Fields {common} are both included and excluded."
         raise ValueError(msg)

-    dataclass_fields: Iterable[Field[Any]] = fields(
+    dataclass_fields: Iterable[Field[Any]] = fields(obj)
     if exclude_none:
-        dataclass_fields = (field for field in dataclass_fields if getattr(
+        dataclass_fields = (field for field in dataclass_fields if getattr(obj, field.name) is not None)
     if exclude_empty:
-        dataclass_fields = (field for field in dataclass_fields if getattr(
+        dataclass_fields = (field for field in dataclass_fields if getattr(obj, field.name) is not Empty)
     if include:
         dataclass_fields = (field for field in dataclass_fields if field.name in include)
     if exclude:
@@ -344,7 +397,7 @@ def extract_dataclass_fields(


 def extract_dataclass_items(
-
+    obj: "DataclassProtocol",
     exclude_none: bool = False,
     exclude_empty: bool = False,
     include: "Optional[AbstractSet[str]]" = None,
@@ -355,7 +408,7 @@ def extract_dataclass_items(
     Unlike the 'asdict' method exports by the stdlib, this function does not pickle values.

     Args:
-
+        obj: A dataclass instance.
         exclude_none: Whether to exclude None values.
         exclude_empty: Whether to exclude Empty values.
         include: An iterable of fields to include.
@@ -364,8 +417,8 @@ def extract_dataclass_items(
     Returns:
         A tuple of key/value pairs.
     """
-    dataclass_fields = extract_dataclass_fields(
-    return tuple((field.name, getattr(
+    dataclass_fields = extract_dataclass_fields(obj, exclude_none, exclude_empty, include, exclude)
+    return tuple((field.name, getattr(obj, field.name)) for field in dataclass_fields)


 def dataclass_to_dict(
@@ -442,9 +495,11 @@ __all__ = (
     "EmptyType",
     "FailFast",
     "FilterTypeT",
-    "
-    "
+    "ModelDict",
+    "ModelDictList",
+    "StatementParameterType",
     "Struct",
+    "SupportedSchemaModel",
     "TypeAdapter",
     "UnsetType",
     "convert",
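The new aliases are mostly signature vocabulary for the driver layer. A hedged illustration of how they might appear in user code (run_query is hypothetical, not a sqlspec API):

# Hypothetical illustration; only the imported aliases come from the diff above.
from sqlspec.typing import ModelDictList, StatementParameterType


def run_query(sql: str, parameters: StatementParameterType = None) -> ModelDictList:
    """Accept dict/list/tuple/None parameters and return a sequence of dicts or
    schema models (msgspec Struct, pydantic BaseModel, or dataclass)."""
    ...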
sqlspec/utils/__init__.py
CHANGED
sqlspec/utils/fixtures.py
CHANGED
@@ -19,7 +19,7 @@ def open_fixture(fixtures_path: "Union[Path, AsyncPath]", fixture_name: str) ->
         fixture_name (str): The fixture name to load.

     Raises:
-        :
+        FileNotFoundError: Fixtures not found.

     Returns:
         Any: The parsed JSON data
@@ -43,8 +43,8 @@ async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_na
         fixture_name (str): The fixture name to load.

     Raises:
-        :
-        :
+        FileNotFoundError: Fixtures not found.
+        MissingDependencyError: The `anyio` library is required to use this function.

     Returns:
         Any: The parsed JSON data
@@ -52,8 +52,7 @@ async def open_fixture_async(fixtures_path: "Union[Path, AsyncPath]", fixture_na
     try:
         from anyio import Path as AsyncPath
     except ImportError as exc:
-
-        raise MissingDependencyError(msg) from exc
+        raise MissingDependencyError(package="anyio") from exc

     fixture = AsyncPath(fixtures_path / f"{fixture_name}.json")
     if await fixture.exists():
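The fixture helpers now document their failure modes explicitly: both raise FileNotFoundError when {fixture_name}.json is missing, and the async variant raises MissingDependencyError unless anyio is installed. A short usage sketch based on those docstrings (the fixtures directory path is illustrative):

# Usage sketch; the path "tests/fixtures" is an example, not a sqlspec convention.
import asyncio
from pathlib import Path
from typing import Any

from sqlspec.utils.fixtures import open_fixture_async


async def load_users() -> Any:
    # Reads tests/fixtures/users.json; raises FileNotFoundError if it is absent
    # and MissingDependencyError if `anyio` is not installed.
    return await open_fixture_async(Path("tests/fixtures"), "users")


data = asyncio.run(load_users())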