sqlspec 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's release page for more details.

Files changed (54)
  1. sqlspec/__init__.py +15 -0
  2. sqlspec/_serialization.py +16 -2
  3. sqlspec/_typing.py +40 -7
  4. sqlspec/adapters/adbc/__init__.py +7 -0
  5. sqlspec/adapters/adbc/config.py +183 -17
  6. sqlspec/adapters/adbc/driver.py +392 -0
  7. sqlspec/adapters/aiosqlite/__init__.py +5 -1
  8. sqlspec/adapters/aiosqlite/config.py +24 -6
  9. sqlspec/adapters/aiosqlite/driver.py +264 -0
  10. sqlspec/adapters/asyncmy/__init__.py +7 -2
  11. sqlspec/adapters/asyncmy/config.py +71 -11
  12. sqlspec/adapters/asyncmy/driver.py +246 -0
  13. sqlspec/adapters/asyncpg/__init__.py +9 -0
  14. sqlspec/adapters/asyncpg/config.py +102 -25
  15. sqlspec/adapters/asyncpg/driver.py +444 -0
  16. sqlspec/adapters/duckdb/__init__.py +5 -1
  17. sqlspec/adapters/duckdb/config.py +194 -12
  18. sqlspec/adapters/duckdb/driver.py +225 -0
  19. sqlspec/adapters/oracledb/__init__.py +7 -4
  20. sqlspec/adapters/oracledb/config/__init__.py +4 -4
  21. sqlspec/adapters/oracledb/config/_asyncio.py +96 -12
  22. sqlspec/adapters/oracledb/config/_common.py +1 -1
  23. sqlspec/adapters/oracledb/config/_sync.py +96 -12
  24. sqlspec/adapters/oracledb/driver.py +571 -0
  25. sqlspec/adapters/psqlpy/__init__.py +0 -0
  26. sqlspec/adapters/psqlpy/config.py +258 -0
  27. sqlspec/adapters/psqlpy/driver.py +335 -0
  28. sqlspec/adapters/psycopg/__init__.py +16 -0
  29. sqlspec/adapters/psycopg/config/__init__.py +6 -6
  30. sqlspec/adapters/psycopg/config/_async.py +107 -15
  31. sqlspec/adapters/psycopg/config/_common.py +2 -2
  32. sqlspec/adapters/psycopg/config/_sync.py +107 -15
  33. sqlspec/adapters/psycopg/driver.py +578 -0
  34. sqlspec/adapters/sqlite/__init__.py +7 -0
  35. sqlspec/adapters/sqlite/config.py +24 -6
  36. sqlspec/adapters/sqlite/driver.py +305 -0
  37. sqlspec/base.py +565 -63
  38. sqlspec/exceptions.py +30 -0
  39. sqlspec/extensions/litestar/__init__.py +19 -0
  40. sqlspec/extensions/litestar/_utils.py +56 -0
  41. sqlspec/extensions/litestar/config.py +87 -0
  42. sqlspec/extensions/litestar/handlers.py +213 -0
  43. sqlspec/extensions/litestar/plugin.py +105 -11
  44. sqlspec/statement.py +373 -0
  45. sqlspec/typing.py +81 -17
  46. sqlspec/utils/__init__.py +3 -0
  47. sqlspec/utils/fixtures.py +4 -5
  48. sqlspec/utils/sync_tools.py +335 -0
  49. {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/METADATA +4 -1
  50. sqlspec-0.9.0.dist-info/RECORD +61 -0
  51. sqlspec-0.7.1.dist-info/RECORD +0 -46
  52. {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/WHEEL +0 -0
  53. {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/LICENSE +0 -0
  54. {sqlspec-0.7.1.dist-info → sqlspec-0.9.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/exceptions.py CHANGED
@@ -1,3 +1,5 @@
1
+ from collections.abc import Generator
2
+ from contextlib import contextmanager
1
3
  from typing import Any, Optional
2
4
 
3
5
  __all__ = (
@@ -6,7 +8,9 @@ __all__ = (
6
8
  "MissingDependencyError",
7
9
  "MultipleResultsFoundError",
8
10
  "NotFoundError",
11
+ "ParameterStyleMismatchError",
9
12
  "RepositoryError",
13
+ "SQLParsingError",
10
14
  "SQLSpecError",
11
15
  "SerializationError",
12
16
  )
@@ -74,6 +78,20 @@ class SQLParsingError(SQLSpecError):
74
78
  super().__init__(message)
75
79
 
76
80
 
81
class ParameterStyleMismatchError(SQLSpecError):
    """Signals that bound parameters do not fit the query's placeholder style.

    The shape of the supplied parameters (a mapping, a tuple, ...) must agree with
    the placeholder style used in the SQL text (named, positional, ...); this error
    reports a disagreement between the two.
    """

    def __init__(self, message: Optional[str] = None) -> None:
        """Create the error.

        Args:
            message: Custom description. When omitted, a default message about
                dictionary parameters without named placeholders is used.
        """
        super().__init__(
            message
            if message is not None
            else "Parameter style mismatch: dictionary parameters provided but no named placeholders found in SQL."
        )
93
+
94
+
77
95
  class ImproperConfigurationError(SQLSpecError):
78
96
  """Improper Configuration error.
79
97
 
@@ -99,3 +117,15 @@ class NotFoundError(RepositoryError):
99
117
 
100
118
class MultipleResultsFoundError(RepositoryError):
    """More than one database result was found where exactly one was required.

    A specialisation of ``RepositoryError`` for the "expected a single row" case.
    """
120
+
121
+
122
@contextmanager
def wrap_exceptions(wrap_exceptions: bool = True) -> Generator[None, None, None]:
    """Optionally convert exceptions raised inside the ``with`` body into ``RepositoryError``.

    Args:
        wrap_exceptions: When truthy (the default), any ``Exception`` raised in the
            managed block is re-raised as ``RepositoryError`` with the original as its
            ``__cause__``. When falsy, the original exception propagates unchanged.

    Raises:
        RepositoryError: If the managed block raises and wrapping is enabled.

    Yields:
        None.
    """
    try:
        yield
    except Exception as exc:
        # Any falsy flag disables wrapping; ``is False`` would still wrap on ``0`` or ``None``.
        if not wrap_exceptions:
            raise
        msg = "An error occurred during the operation."
        raise RepositoryError(detail=msg) from exc
@@ -0,0 +1,19 @@
1
+ from sqlspec.extensions.litestar.config import DatabaseConfig
2
+ from sqlspec.extensions.litestar.handlers import (
3
+ autocommit_handler_maker,
4
+ connection_provider_maker,
5
+ lifespan_handler_maker,
6
+ manual_handler_maker,
7
+ pool_provider_maker,
8
+ )
9
+ from sqlspec.extensions.litestar.plugin import SQLSpec
10
+
11
# Public API of the Litestar extension: the per-database config dataclass, the plugin,
# and the handler/provider factories imported from ``sqlspec.extensions.litestar.handlers``.
__all__ = (
    "DatabaseConfig",
    "SQLSpec",
    "autocommit_handler_maker",
    "connection_provider_maker",
    "lifespan_handler_maker",
    "manual_handler_maker",
    "pool_provider_maker",
)
@@ -0,0 +1,56 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ if TYPE_CHECKING:
4
+ from litestar.types import Scope
5
+
6
__all__ = (
    "delete_sqlspec_scope_state",
    "get_sqlspec_scope_state",
    "set_sqlspec_scope_state",
)

# Key of the private dict inside the ASGI scope under which all sqlspec scope state is kept,
# so sqlspec entries cannot collide with other keys stored directly in the scope.
_SCOPE_NAMESPACE = "_sqlspec"
13
+
14
+
15
def get_sqlspec_scope_state(scope: "Scope", key: str, default: Any = None, pop: bool = False) -> Any:
    """Get an internal value from connection scope state.

    Values live in a private namespace dict stored in ``scope`` under ``_SCOPE_NAMESPACE``;
    that namespace dict is created (empty) on first access.

    Note:
        The lookup behaves like ``dict.get()`` — or ``dict.pop()`` when ``pop`` is true:
        ``default`` is returned when ``key`` is missing, but it is never stored in the
        namespace. Callers rely on this to detect a missing value via ``None``.

    Args:
        scope: The connection scope.
        key: Key to get from internal namespace in scope state.
        default: Value returned when ``key`` is absent.
        pop: Boolean flag dictating whether the value should be deleted from the state.

    Returns:
        Value mapped to ``key`` in internal connection scope namespace, or ``default``.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    return namespace.pop(key, default) if pop else namespace.get(key, default)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
36
+
37
+
38
def set_sqlspec_scope_state(scope: "Scope", key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in sqlspec's private namespace of the scope.

    The namespace dict is created on demand, and an existing entry for ``key`` is overwritten.

    Args:
        scope: The connection scope.
        key: Name of the entry inside the internal namespace.
        value: Object to store.
    """
    namespace = scope.setdefault(_SCOPE_NAMESPACE, {})  # type: ignore[misc]
    namespace[key] = value
47
+
48
+
49
def delete_sqlspec_scope_state(scope: "Scope", key: str) -> None:
    """Remove an internal value from connection scope state.

    Args:
        scope: The connection scope.
        key: Key to remove from the internal namespace in scope state.

    Raises:
        KeyError: If ``key`` is not present in the internal namespace.
    """
    del scope.setdefault(_SCOPE_NAMESPACE, {})[key]  # type: ignore[misc]
@@ -0,0 +1,87 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
3
+
4
+ from sqlspec.base import (
5
+ ConnectionT,
6
+ PoolT,
7
+ )
8
+ from sqlspec.exceptions import ImproperConfigurationError
9
+ from sqlspec.extensions.litestar.handlers import (
10
+ autocommit_handler_maker,
11
+ connection_provider_maker,
12
+ lifespan_handler_maker,
13
+ manual_handler_maker,
14
+ pool_provider_maker,
15
+ session_provider_maker,
16
+ )
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import Awaitable
20
+ from contextlib import AbstractAsyncContextManager
21
+
22
+ from litestar import Litestar
23
+ from litestar.datastructures.state import State
24
+ from litestar.types import BeforeMessageSendHookHandler, Scope
25
+
26
+ from sqlspec.base import (
27
+ AsyncConfigT,
28
+ ConnectionT,
29
+ DriverT,
30
+ PoolT,
31
+ SyncConfigT,
32
+ )
33
+
34
# Transaction-handling strategy; selects which before-send handler ``DatabaseConfig`` builds.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"
# Default dependency / state keys for the connection, pool and driver-session providers.
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"
39
+
40
+
41
@dataclass
class DatabaseConfig:
    """Litestar-facing wrapper around a single SQLSpec database configuration.

    ``__post_init__`` derives everything registered with Litestar for this database:
    the before-send handler matching ``commit_mode``, the lifespan handler that manages
    the pool, and the connection / pool / session dependency providers.

    Attributes:
        config: The wrapped SQLSpec sync or async database configuration.
        connection_key: Key for the connection dependency and its scope-state entry.
        pool_key: Key for the pool dependency and its app-state entry. Replaced with a
            config-specific value when the wrapped config does not support connection
            pooling and the default key is in use.
        session_key: Key for the driver-session dependency and its scope-state entry.
        commit_mode: ``"manual"``, ``"autocommit"`` or ``"autocommit_include_redirect"``.
        extra_commit_statuses: Extra HTTP status codes that commit in the autocommit modes.
        extra_rollback_statuses: Extra HTTP status codes that roll back in the autocommit modes.
    """

    config: "Union[SyncConfigT, AsyncConfigT]" = field()  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
    connection_key: str = field(default=DEFAULT_CONNECTION_KEY)
    pool_key: str = field(default=DEFAULT_POOL_KEY)
    session_key: str = field(default=DEFAULT_SESSION_KEY)
    commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE)
    extra_commit_statuses: "Optional[set[int]]" = field(default=None)
    extra_rollback_statuses: "Optional[set[int]]" = field(default=None)
    # The following are derived in ``__post_init__`` and not accepted by the constructor.
    connection_provider: "Callable[[State,Scope], Awaitable[ConnectionT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    session_provider: "Callable[[State,Scope], Awaitable[DriverT]]" = field(init=False, repr=False, hash=False)  # pyright: ignore[reportGeneralTypeIssues]
    before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False)
    lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(
        init=False,
        repr=False,
        hash=False,
    )
    # Not set by ``__post_init__``; assigned externally when the config is registered.
    annotation: "type[Union[SyncConfigT, AsyncConfigT]]" = field(init=False, repr=False, hash=False)  # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]

    def __post_init__(self) -> None:
        """Build the handlers and dependency providers for this configuration.

        Raises:
            ImproperConfigurationError: If ``commit_mode`` is not a supported value.
        """
        if not self.config.support_connection_pooling and self.pool_key == DEFAULT_POOL_KEY:  # type: ignore[union-attr,unused-ignore]
            # Configs without pooling would otherwise all share the default pool key; derive a
            # unique identifier so they don't conflict with other configs that may get added.
            self.pool_key = f"_{self.pool_key}_{id(self.config)}"
        if self.commit_mode == "manual":
            self.before_send_handler = manual_handler_maker(connection_scope_key=self.connection_key)
        elif self.commit_mode in ("autocommit", "autocommit_include_redirect"):
            # The two autocommit modes differ only in whether redirect (3xx) responses commit.
            self.before_send_handler = autocommit_handler_maker(
                commit_on_redirect=self.commit_mode == "autocommit_include_redirect",
                extra_commit_statuses=self.extra_commit_statuses,
                extra_rollback_statuses=self.extra_rollback_statuses,
                connection_scope_key=self.connection_key,
            )
        else:
            msg = f"Invalid commit mode: {self.commit_mode}"  # type: ignore[unreachable,unused-ignore]
            raise ImproperConfigurationError(detail=msg)
        self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key)
        self.connection_provider = connection_provider_maker(connection_key=self.connection_key, config=self.config)
        self.pool_provider = pool_provider_maker(pool_key=self.pool_key, config=self.config)
        self.session_provider = session_provider_maker(session_key=self.session_key, config=self.config)
@@ -0,0 +1,213 @@
1
+ import contextlib
2
+ from typing import TYPE_CHECKING, Any, Callable, Optional, cast
3
+
4
+ from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
5
+
6
+ from sqlspec.exceptions import ImproperConfigurationError
7
+ from sqlspec.extensions.litestar._utils import (
8
+ delete_sqlspec_scope_state,
9
+ get_sqlspec_scope_state,
10
+ set_sqlspec_scope_state,
11
+ )
12
+ from sqlspec.utils.sync_tools import maybe_async_
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import AsyncGenerator, Awaitable, Coroutine
16
+ from contextlib import AbstractAsyncContextManager
17
+
18
+ from litestar import Litestar
19
+ from litestar.datastructures.state import State
20
+ from litestar.types import Message, Scope
21
+
22
+ from sqlspec.base import ConnectionT, DatabaseConfigProtocol, DriverT, PoolT
23
+
24
+
25
# Message types on which the before-send handlers close the request connection and
# remove it from sqlspec scope state.
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
"""ASGI events that terminate a session scope."""
27
+
28
+
29
def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Set up the handler that closes the request connection without committing.

    The handler never commits or rolls back. It only closes the connection and removes
    it from sqlspec scope state once a message whose type is in
    ``SESSION_TERMINUS_ASGI_EVENTS`` is about to be sent.

    Args:
        connection_scope_key: The key under which the connection is stored in sqlspec scope state.

    Returns:
        The handler callable.
    """

    async def handler(message: "Message", scope: "Scope") -> None:
        """Close and clean up the request connection before a terminating message is sent.

        Args:
            message: ASGI-``Message``
            scope: An ASGI-``Scope``
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        if connection is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            # Closing is best-effort: a failure here must not block sending the message.
            with contextlib.suppress(Exception):
                await maybe_async_(connection.close)()
            delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
53
+
54
+
55
def autocommit_handler_maker(
    connection_scope_key: str,
    commit_on_redirect: bool = False,
    extra_commit_statuses: "Optional[set[int]]" = None,
    extra_rollback_statuses: "Optional[set[int]]" = None,
) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
    """Set up the handler to issue a transaction commit or rollback based on the response status.

    On a response-start message, the connection is committed when the status is 2xx
    (or 3xx with ``commit_on_redirect``) or listed in ``extra_commit_statuses``, unless
    it is listed in ``extra_rollback_statuses``. Any other status rolls back. On any
    message whose type is in ``SESSION_TERMINUS_ASGI_EVENTS`` the connection is then
    closed and removed from sqlspec scope state.

    Args:
        connection_scope_key: The key under which the connection is stored in sqlspec scope state.
        commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``).
        extra_commit_statuses: A set of additional status codes that trigger a commit.
        extra_rollback_statuses: A set of additional status codes that trigger a rollback.

    Raises:
        ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share any status codes.

    Returns:
        The handler callable.
    """
    # Snapshot the caller's collections so later mutation cannot change the handler's behaviour.
    commit_statuses = frozenset(extra_commit_statuses or ())
    rollback_statuses = frozenset(extra_rollback_statuses or ())

    if commit_statuses & rollback_statuses:
        msg = "Extra rollback statuses and commit statuses must not share any status codes"
        raise ImproperConfigurationError(msg)

    commit_range = range(200, 400 if commit_on_redirect else 300)

    async def handler(message: "Message", scope: "Scope") -> None:
        """Handle commit/rollback, closing and cleaning up sessions before sending.

        Args:
            message: ASGI-``litestar.types.Message``
            scope: An ASGI-``litestar.types.Scope``
        """
        connection = get_sqlspec_scope_state(scope, connection_scope_key)
        try:
            if connection is not None and message["type"] == HTTP_RESPONSE_START:
                status = message["status"]
                if (status in commit_range or status in commit_statuses) and status not in rollback_statuses:
                    await maybe_async_(connection.commit)()
                else:
                    await maybe_async_(connection.rollback)()
        finally:
            # Runs even if commit/rollback failed, so the connection is never left open.
            if connection is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
                with contextlib.suppress(Exception):
                    await maybe_async_(connection.close)()
                delete_sqlspec_scope_state(scope, connection_scope_key)

    return handler
110
+
111
+
112
def lifespan_handler_maker(
    config: "DatabaseConfigProtocol[Any, Any, Any]",
    pool_key: str,
) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
    """Create a Litestar lifespan context manager that owns the database pool.

    On startup the pool returned by ``config.create_pool`` is published in ``app.state``
    under ``pool_key``. On shutdown that entry is removed and ``config.close_pool`` is
    awaited. A failure while closing is logged as a warning (if the app has a logger)
    and does not propagate.

    Args:
        config: The database configuration.
        pool_key: The key to use for the connection pool within Litestar.

    Returns:
        The generated lifespan handler for the connection.
    """

    async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
        app.state.update({pool_key: await maybe_async_(config.create_pool)()})
        try:
            yield
        finally:
            app.state.pop(pool_key, None)
            try:
                await maybe_async_(config.close_pool)()
            except Exception as e:  # noqa: BLE001
                if app.logger:
                    app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)

    return contextlib.asynccontextmanager(lifespan_handler)
141
+
142
+
143
def connection_provider_maker(
    connection_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
) -> "Callable[[State,Scope], Awaitable[ConnectionT]]":
    """Create the dependency provider yielding the request's database connection.

    The provider reuses a connection already cached in sqlspec scope state under
    ``connection_key``. Otherwise it opens one with ``config.create_connection`` and
    caches it there.

    Args:
        connection_key: The dependency key to use for the session within Litestar.
        config: The database configuration.

    Returns:
        The generated connection provider for the connection.
    """

    async def provide_connection(state: "State", scope: "Scope") -> "ConnectionT":
        cached = get_sqlspec_scope_state(scope, connection_key)
        if cached is not None:
            return cast("ConnectionT", cached)
        fresh = await maybe_async_(config.create_connection)()
        set_sqlspec_scope_state(scope, connection_key, fresh)
        return cast("ConnectionT", fresh)

    return provide_connection
165
+
166
+
167
def pool_provider_maker(
    pool_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
) -> "Callable[[State,Scope], Awaitable[PoolT]]":
    """Build the pool provider for the database configuration.

    The pool is resolved in this order:

    1. The value cached in sqlspec scope state under ``pool_key``.
    2. The pool the lifespan handler stored in application state under ``pool_key``.
    3. Only if neither exists, a new pool from ``config.create_pool``.

    Scope state is per request, so without step 2 every request would call
    ``config.create_pool``.

    Args:
        pool_key: The dependency key to use for the pool within Litestar.
        config: The database configuration.

    Returns:
        The generated connection pool provider for the database.
    """

    async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
        pool = get_sqlspec_scope_state(scope, pool_key)
        if pool is not None:
            return cast("PoolT", pool)
        # Prefer the application-wide pool published by the lifespan handler.
        pool = state.get(pool_key)
        if pool is None:
            pool = await maybe_async_(config.create_pool)()
        set_sqlspec_scope_state(scope, pool_key, pool)
        return cast("PoolT", pool)

    return provide_pool
189
+
190
+
191
def session_provider_maker(
    session_key: str,
    config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
) -> "Callable[[State,Scope], Awaitable[DriverT]]":
    """Create the dependency provider yielding a driver session for the request.

    A session cached in sqlspec scope state under ``session_key`` is reused. Otherwise
    a new connection is opened with ``config.create_connection``, wrapped in
    ``config.driver_type``, and the resulting session is cached under ``session_key``.

    NOTE(review): the connection opened here is stored only inside the session, not
    under a connection key. The before-send handlers in this module look connections up
    by connection key, so they appear never to commit or close this one. Confirm
    whether it leaks.

    Args:
        session_key: The dependency key to use for the session within Litestar.
        config: The database configuration.

    Returns:
        The generated session provider for the database.
    """

    async def provide_session(state: "State", scope: "Scope") -> "DriverT":
        existing = get_sqlspec_scope_state(scope, session_key)
        if existing is not None:
            return cast("DriverT", existing)
        connection = await maybe_async_(config.create_connection)()
        driver = config.driver_type(connection=connection)  # pyright: ignore[reportCallIssue]
        set_sqlspec_scope_state(scope, session_key, driver)
        return cast("DriverT", driver)

    return provide_session
@@ -1,34 +1,64 @@
1
- from typing import TYPE_CHECKING
1
+ from typing import TYPE_CHECKING, Any, Literal, Union
2
2
 
3
+ from litestar.di import Provide
3
4
  from litestar.plugins import InitPluginProtocol
4
5
 
6
+ from sqlspec.base import (
7
+ AsyncConfigT,
8
+ ConnectionT,
9
+ DatabaseConfigProtocol,
10
+ PoolT,
11
+ SyncConfigT,
12
+ )
13
+ from sqlspec.base import SQLSpec as SQLSpecBase
14
+ from sqlspec.exceptions import ImproperConfigurationError
15
+ from sqlspec.extensions.litestar.config import DatabaseConfig
16
+
5
17
  if TYPE_CHECKING:
18
+ from click import Group
6
19
  from litestar.config.app import AppConfig
7
20
 
8
- from sqlspec.base import ConfigManager
21
+
22
# Transaction-handling strategies accepted for a database config.
# NOTE(review): these duplicate the definitions in ``sqlspec.extensions.litestar.config``;
# consider importing them from there so the two copies cannot drift apart.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
DEFAULT_COMMIT_MODE: CommitMode = "manual"
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"
9
27
 
10
28
 
11
- class SQLSpecPlugin(InitPluginProtocol):
29
+ class SQLSpec(InitPluginProtocol, SQLSpecBase):
12
30
  """SQLSpec plugin."""
13
31
 
14
- __slots__ = ("_config",)
32
+ __slots__ = ("_config", "_plugin_configs")
15
33
 
16
- def __init__(self, config: "ConfigManager") -> None:
34
+ def __init__(
35
+ self,
36
+ config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]],
37
+ ) -> None:
17
38
  """Initialize ``SQLSpecPlugin``.
18
39
 
19
40
  Args:
20
41
  config: configure SQLSpec plugin for use with Litestar.
21
42
  """
22
- self._config = config
43
+ self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {}
44
+ if isinstance(config, DatabaseConfigProtocol):
45
+ self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)]
46
+ elif isinstance(config, DatabaseConfig):
47
+ self._plugin_configs = [config]
48
+ else:
49
+ self._plugin_configs = config
23
50
 
24
51
  @property
25
- def config(self) -> "ConfigManager":
52
+ def config(self) -> "list[DatabaseConfig]": # pyright: ignore[reportInvalidTypeVarUse]
26
53
  """Return the plugin config.
27
54
 
28
55
  Returns:
29
56
  ConfigManager.
30
57
  """
31
- return self._config
58
+ return self._plugin_configs
59
+
60
+ def on_cli_init(self, cli: "Group") -> None:
61
+ """Configure the CLI for use with SQLSpec."""
32
62
 
33
63
  def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
34
64
  """Configure application for use with SQLSpec.
@@ -39,8 +69,72 @@ class SQLSpecPlugin(InitPluginProtocol):
39
69
  Returns:
40
70
  The updated :class:`AppConfig <.config.app.AppConfig>` instance.
41
71
  """
72
+ self._validate_dependency_keys()
73
+ app_config.signature_types.extend(
74
+ [
75
+ SQLSpec,
76
+ ConnectionT,
77
+ PoolT,
78
+ DatabaseConfig,
79
+ DatabaseConfigProtocol,
80
+ SyncConfigT,
81
+ AsyncConfigT,
82
+ ]
83
+ )
84
+ for c in self._plugin_configs:
85
+ c.annotation = self.add_config(c.config)
86
+ app_config.before_send.append(c.before_send_handler)
87
+ app_config.lifespan.append(c.lifespan_handler) # pyright: ignore[reportUnknownMemberType]
88
+ app_config.dependencies.update(
89
+ {
90
+ c.connection_key: Provide(c.connection_provider),
91
+ c.pool_key: Provide(c.pool_provider),
92
+ c.session_key: Provide(c.session_provider),
93
+ },
94
+ )
42
95
 
43
- from sqlspec.base import ConfigManager
44
-
45
- app_config.signature_types.append(ConfigManager)
46
96
  return app_config
97
+
98
+ def get_annotations(self) -> "list[type[Union[SyncConfigT, AsyncConfigT]]]": # pyright: ignore[reportInvalidTypeVarUse]
99
+ """Return the list of annotations.
100
+
101
+ Returns:
102
+ List of annotations.
103
+ """
104
+ return [c.annotation for c in self.config]
105
+
106
+ def get_annotation(
107
+ self,
108
+ key: "Union[str, SyncConfigT, AsyncConfigT, type[Union[SyncConfigT, AsyncConfigT]]]",
109
+ ) -> "type[Union[SyncConfigT, AsyncConfigT]]":
110
+ """Return the annotation for the given configuration.
111
+
112
+ Args:
113
+ key: The configuration instance or key to lookup
114
+
115
+ Raises:
116
+ KeyError: If no configuration is found for the given key.
117
+
118
+ Returns:
119
+ The annotation for the configuration.
120
+ """
121
+ for c in self.config:
122
+ if key == c.config or key in {c.annotation, c.connection_key, c.pool_key}:
123
+ return c.annotation
124
+ msg = f"No configuration found for {key}"
125
+ raise KeyError(msg)
126
+
127
+ def _validate_dependency_keys(self) -> None:
128
+ """Verify uniqueness of ``connection_key`` and ``pool_key``.
129
+
130
+ Raises:
131
+ ImproperConfigurationError: If session keys or pool keys are not unique.
132
+ """
133
+ connection_keys = [c.connection_key for c in self.config]
134
+ pool_keys = [c.pool_key for c in self.config]
135
+ if len(set(connection_keys)) != len(connection_keys):
136
+ msg = "When using multiple database configuration, each configuration must have a unique `connection_key`."
137
+ raise ImproperConfigurationError(detail=msg)
138
+ if len(set(pool_keys)) != len(pool_keys):
139
+ msg = "When using multiple database configuration, each configuration must have a unique `pool_key`."
140
+ raise ImproperConfigurationError(detail=msg)