sqlspec 0.10.1 (py3-none-any.whl) → 0.11.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of sqlspec might be problematic; the original page links to further details.

@@ -1,4 +1,8 @@
1
+ # ruff: noqa: PLC2801
1
2
  import contextlib
3
+ import inspect
4
+ from collections.abc import AsyncGenerator
5
+ from contextlib import AbstractAsyncContextManager
2
6
  from typing import TYPE_CHECKING, Any, Callable, Optional, cast
3
7
 
4
8
  from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
@@ -9,11 +13,10 @@ from sqlspec.extensions.litestar._utils import (
9
13
  get_sqlspec_scope_state,
10
14
  set_sqlspec_scope_state,
11
15
  )
12
- from sqlspec.utils.sync_tools import maybe_async_
16
+ from sqlspec.utils.sync_tools import ensure_async_
13
17
 
14
18
  if TYPE_CHECKING:
15
- from collections.abc import AsyncGenerator, Awaitable, Coroutine
16
- from contextlib import AbstractAsyncContextManager
19
+ from collections.abc import Awaitable, Coroutine
17
20
 
18
21
  from litestar import Litestar
19
22
  from litestar.datastructures.state import State
@@ -26,28 +29,38 @@ if TYPE_CHECKING:
26
29
  SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
27
30
  """ASGI events that terminate a session scope."""
28
31
 
32
+ __all__ = (
33
+ "SESSION_TERMINUS_ASGI_EVENTS",
34
+ "autocommit_handler_maker",
35
+ "connection_provider_maker",
36
+ "lifespan_handler_maker",
37
+ "manual_handler_maker",
38
+ "pool_provider_maker",
39
+ "session_provider_maker",
40
+ )
41
+
29
42
 
30
43
  def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
31
- """Set up the handler to issue a transaction commit or rollback based on specified status codes
44
+ """Set up the handler to close the connection.
45
+
32
46
  Args:
33
- connection_scope_key: The key to use within the application state
47
+ connection_scope_key: The key used to store the connection in the ASGI scope.
34
48
 
35
49
  Returns:
36
- The handler callable
50
+ The handler callable.
37
51
  """
38
52
 
39
53
  async def handler(message: "Message", scope: "Scope") -> None:
40
- """Handle commit/rollback, closing and cleaning up sessions before sending.
54
+ """Handle closing and cleaning up connections before sending the response.
41
55
 
42
56
  Args:
43
- message: ASGI-``Message``
44
- scope: An ASGI-``Scope``
45
-
57
+ message: ASGI Message.
58
+ scope: ASGI Scope.
46
59
  """
47
60
  connection = get_sqlspec_scope_state(scope, connection_scope_key)
48
61
  if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
49
- with contextlib.suppress(Exception):
50
- await maybe_async_(connection.close)()
62
+ if hasattr(connection, "close") and callable(connection.close):
63
+ await ensure_async_(connection.close)()
51
64
  delete_sqlspec_scope_state(scope, connection_scope_key)
52
65
 
53
66
  return handler
@@ -59,18 +72,19 @@ def autocommit_handler_maker(
59
72
  extra_commit_statuses: "Optional[set[int]]" = None,
60
73
  extra_rollback_statuses: "Optional[set[int]]" = None,
61
74
  ) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]":
62
- """Set up the handler to issue a transaction commit or rollback based on specified status codes
75
+ """Set up the handler to issue a transaction commit or rollback based on response status codes.
76
+
63
77
  Args:
64
- commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
65
- extra_commit_statuses: A set of additional status codes that trigger a commit
66
- extra_rollback_statuses: A set of additional status codes that trigger a rollback
67
- connection_scope_key: The key to use within the application state
78
+ connection_scope_key: The key used to store the connection in the ASGI scope.
79
+ commit_on_redirect: Issue a commit when the response status is a redirect (3XX).
80
+ extra_commit_statuses: A set of additional status codes that trigger a commit.
81
+ extra_rollback_statuses: A set of additional status codes that trigger a rollback.
68
82
 
69
83
  Raises:
70
- ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share any status codes
84
+ ImproperConfigurationError: If extra_commit_statuses and extra_rollback_statuses share status codes.
71
85
 
72
86
  Returns:
73
- The handler callable
87
+ The handler callable.
74
88
  """
75
89
  if extra_commit_statuses is None:
76
90
  extra_commit_statuses = set()
@@ -85,12 +99,11 @@ def autocommit_handler_maker(
85
99
  commit_range = range(200, 400 if commit_on_redirect else 300)
86
100
 
87
101
  async def handler(message: "Message", scope: "Scope") -> None:
88
- """Handle commit/rollback, closing and cleaning up sessions before sending.
102
+ """Handle commit/rollback, closing and cleaning up connections before sending.
89
103
 
90
104
  Args:
91
- message: ASGI-``litestar.types.Message``
92
- scope: An ASGI-``litestar.types.Scope``
93
-
105
+ message: ASGI Message.
106
+ scope: ASGI Scope.
94
107
  """
95
108
  connection = get_sqlspec_scope_state(scope, connection_scope_key)
96
109
  try:
@@ -98,13 +111,14 @@ def autocommit_handler_maker(
98
111
  if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
99
112
  "status"
100
113
  ] not in extra_rollback_statuses:
101
- await maybe_async_(connection.commit)()
102
- else:
103
- await maybe_async_(connection.rollback)()
114
+ if hasattr(connection, "commit") and callable(connection.commit):
115
+ await ensure_async_(connection.commit)()
116
+ elif hasattr(connection, "rollback") and callable(connection.rollback):
117
+ await ensure_async_(connection.rollback)()
104
118
  finally:
105
119
  if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
106
- with contextlib.suppress(Exception):
107
- await maybe_async_(connection.close)()
120
+ if hasattr(connection, "close") and callable(connection.close):
121
+ await ensure_async_(connection.close)()
108
122
  delete_sqlspec_scope_state(scope, connection_scope_key)
109
123
 
110
124
  return handler
@@ -114,101 +128,144 @@ def lifespan_handler_maker(
114
128
  config: "DatabaseConfigProtocol[Any, Any, Any]",
115
129
  pool_key: str,
116
130
  ) -> "Callable[[Litestar], AbstractAsyncContextManager[None]]":
117
- """Build the lifespan handler for the database configuration.
131
+ """Build the lifespan handler for managing the database connection pool.
132
+
133
+ The pool is created on application startup and closed on shutdown.
118
134
 
119
135
  Args:
120
- config: The database configuration.
121
- pool_key: The key to use for the connection pool within Litestar.
136
+ config: The database configuration object.
137
+ pool_key: The key under which the connection pool will be stored in `app.state`.
122
138
 
123
139
  Returns:
124
- The generated lifespan handler for the connection.
140
+ The generated lifespan handler.
125
141
  """
126
142
 
127
143
  @contextlib.asynccontextmanager
128
144
  async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]":
129
- db_pool = await maybe_async_(config.create_pool)()
145
+ """Manages the database pool lifecycle.
146
+
147
+ Args:
148
+ app: The Litestar application instance.
149
+
150
+ Yields:
151
+ The generated lifespan handler.
152
+ """
153
+ db_pool = await ensure_async_(config.create_pool)()
130
154
  app.state.update({pool_key: db_pool})
131
155
  try:
132
156
  yield
133
157
  finally:
134
158
  app.state.pop(pool_key, None)
135
159
  try:
136
- await maybe_async_(config.close_pool)()
160
+ await ensure_async_(config.close_pool)()
137
161
  except Exception as e: # noqa: BLE001
138
- if app.logger:
162
+ if app.logger: # pragma: no cover
139
163
  app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e)
140
164
 
141
165
  return lifespan_handler
142
166
 
143
167
 
144
- def connection_provider_maker(
145
- connection_key: str,
146
- config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
147
- ) -> "Callable[[State,Scope], Awaitable[ConnectionT]]":
148
- """Build the connection provider for the database configuration.
168
+ def pool_provider_maker(
169
+ config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", pool_key: str
170
+ ) -> "Callable[[State, Scope], Awaitable[PoolT]]":
171
+ """Build the pool provider to inject the application-level database pool.
149
172
 
150
173
  Args:
151
- connection_key: The dependency key to use for the session within Litestar.
152
- config: The database configuration.
174
+ config: The database configuration object.
175
+ pool_key: The key used to store the connection pool in `app.state`.
153
176
 
154
177
  Returns:
155
- The generated connection provider for the connection.
178
+ The generated pool provider.
156
179
  """
157
180
 
158
- async def provide_connection(state: "State", scope: "Scope") -> "ConnectionT":
159
- connection = get_sqlspec_scope_state(scope, connection_key)
160
- if connection is None:
161
- connection = await maybe_async_(config.create_connection)()
162
- set_sqlspec_scope_state(scope, connection_key, connection)
163
- return cast("ConnectionT", connection)
181
+ async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
182
+ """Provides the database pool from `app.state`.
164
183
 
165
- return provide_connection
184
+ Args:
185
+ state: The Litestar application State object.
186
+ scope: The ASGI scope (unused for app-level pool).
166
187
 
167
188
 
168
- def pool_provider_maker(
169
- pool_key: str,
170
- config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
171
- ) -> "Callable[[State,Scope], Awaitable[PoolT]]":
172
- """Build the pool provider for the database configuration.
189
+ Returns:
190
+ The database connection pool.
173
191
 
174
- Args:
175
- pool_key: The dependency key to use for the pool within Litestar.
176
- config: The database configuration.
192
+ Raises:
193
+ ImproperConfigurationError: If the pool is not found in `app.state`.
194
+ """
195
+ # The pool is stored in app.state by the lifespan handler.
196
+ # state.get(key) accesses app.state[key]
197
+ db_pool = state.get(pool_key)
198
+ if db_pool is None:
199
+ # This case should ideally not happen if the lifespan handler ran correctly.
200
+ msg = (
201
+ f"Database pool with key '{pool_key}' not found in application state. "
202
+ "Ensure the SQLSpec lifespan handler is correctly configured and has run."
203
+ )
204
+ raise ImproperConfigurationError(msg)
205
+ return cast("PoolT", db_pool)
177
206
 
178
- Returns:
179
- The generated connection pool for the database.
180
- """
207
+ return provide_pool
181
208
 
182
- async def provide_pool(state: "State", scope: "Scope") -> "PoolT":
183
- pool = get_sqlspec_scope_state(scope, pool_key)
184
- if pool is None:
185
- pool = await maybe_async_(config.create_pool)()
186
- set_sqlspec_scope_state(scope, pool_key, pool)
187
- return cast("PoolT", pool)
188
209
 
189
- return provide_pool
210
+ def connection_provider_maker(
211
+ config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
212
+ pool_key: str,
213
+ connection_key: str,
214
+ ) -> "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]":
215
+ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ConnectionT, None]":
216
+ db_pool = state.get(pool_key)
217
+ if db_pool is None:
218
+ msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection."
219
+ raise ImproperConfigurationError(msg)
220
+
221
+ connection_cm = config.provide_connection(db_pool)
222
+
223
+ if not isinstance(connection_cm, AbstractAsyncContextManager):
224
+ conn_instance: ConnectionT
225
+ if hasattr(connection_cm, "__await__"):
226
+ conn_instance = await cast("Awaitable[ConnectionT]", connection_cm)
227
+ else:
228
+ conn_instance = cast("ConnectionT", connection_cm)
229
+ set_sqlspec_scope_state(scope, connection_key, conn_instance)
230
+ yield conn_instance
231
+ return
232
+
233
+ entered_connection: Optional[ConnectionT] = None
234
+ try:
235
+ entered_connection = await connection_cm.__aenter__()
236
+ set_sqlspec_scope_state(scope, connection_key, entered_connection)
237
+ yield entered_connection
238
+ finally:
239
+ if entered_connection is not None:
240
+ await connection_cm.__aexit__(None, None, None)
241
+ delete_sqlspec_scope_state(scope, connection_key) # Optional: clear from scope
242
+
243
+ return provide_connection
190
244
 
191
245
 
192
246
  def session_provider_maker(
193
- session_key: str,
194
- config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]",
195
- ) -> "Callable[[State,Scope], Awaitable[DriverT]]":
196
- """Build the session provider for the database configuration.
247
+ config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", connection_dependency_key: str
248
+ ) -> "Callable[[Any], AsyncGenerator[DriverT, None]]":
249
+ async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, None]":
250
+ yield cast("DriverT", config.driver_type(connection=args[0] if args else kwargs.get(connection_dependency_key))) # pyright: ignore
197
251
 
198
- Args:
199
- session_key: The dependency key to use for the session within Litestar.
200
- config: The database configuration.
252
+ conn_type_annotation = config.connection_type
201
253
 
202
- Returns:
203
- The generated session provider for the database.
204
- """
254
+ db_conn_param = inspect.Parameter(
255
+ name=connection_dependency_key, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=conn_type_annotation
256
+ )
257
+
258
+ provider_signature = inspect.Signature(
259
+ parameters=[db_conn_param],
260
+ return_annotation=AsyncGenerator[config.driver_type, None], # type: ignore[name-defined]
261
+ )
262
+
263
+ provide_session.__signature__ = provider_signature # type: ignore[attr-defined]
264
+
265
+ if not hasattr(provide_session, "__annotations__") or provide_session.__annotations__ is None:
266
+ provide_session.__annotations__ = {}
205
267
 
206
- async def provide_session(state: "State", scope: "Scope") -> "DriverT":
207
- session = get_sqlspec_scope_state(scope, session_key)
208
- if session is None:
209
- connection = await maybe_async_(config.create_connection)()
210
- session = config.driver_type(connection=connection) # pyright: ignore[reportCallIssue]
211
- set_sqlspec_scope_state(scope, session_key, session)
212
- return cast("DriverT", session)
268
+ provide_session.__annotations__[connection_dependency_key] = conn_type_annotation
269
+ provide_session.__annotations__["return"] = config.driver_type
213
270
 
214
271
  return provide_session
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Literal, Union
1
+ from typing import TYPE_CHECKING, Any, Union
2
2
 
3
3
  from litestar.di import Provide
4
4
  from litestar.plugins import InitPluginProtocol
@@ -19,13 +19,6 @@ if TYPE_CHECKING:
19
19
  from litestar.config.app import AppConfig
20
20
 
21
21
 
22
- CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]
23
- DEFAULT_COMMIT_MODE: CommitMode = "manual"
24
- DEFAULT_CONNECTION_KEY = "db_connection"
25
- DEFAULT_POOL_KEY = "db_pool"
26
- DEFAULT_SESSION_KEY = "db_session"
27
-
28
-
29
22
  class SQLSpec(InitPluginProtocol, SQLSpecBase):
30
23
  """SQLSpec plugin."""
31
24
 
@@ -70,6 +63,13 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
70
63
  The updated :class:`AppConfig <.config.app.AppConfig>` instance.
71
64
  """
72
65
  self._validate_dependency_keys()
66
+
67
+ def store_sqlspec_in_state() -> None:
68
+ app_config.state.sqlspec = self
69
+
70
+ app_config.on_startup.append(store_sqlspec_in_state)
71
+
72
+ # Register types for injection
73
73
  app_config.signature_types.extend(
74
74
  [
75
75
  SQLSpec,
@@ -82,6 +82,7 @@ class SQLSpec(InitPluginProtocol, SQLSpecBase):
82
82
  AsyncConfigT,
83
83
  ]
84
84
  )
85
+
85
86
  for c in self._plugin_configs:
86
87
  c.annotation = self.add_config(c.config)
87
88
  app_config.signature_types.append(c.annotation)