sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. See the package's registry page for more details.

Files changed (212)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,389 @@
1
+ """Flask extension for SQLSpec database integration."""
2
+
3
+ import atexit
4
+ from typing import TYPE_CHECKING, Any, Literal
5
+
6
+ from sqlspec.base import SQLSpec
7
+ from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig
8
+ from sqlspec.exceptions import ImproperConfigurationError
9
+ from sqlspec.extensions.flask._state import FlaskConfigState
10
+ from sqlspec.extensions.flask._utils import get_or_create_session
11
+ from sqlspec.utils.logging import get_logger
12
+ from sqlspec.utils.portal import PortalProvider
13
+
14
+ if TYPE_CHECKING:
15
+ from flask import Flask, Response
16
+
17
__all__ = ("SQLSpecPlugin",)

# Fallbacks used by ``SQLSpecPlugin`` when a config's ``extension_config["flask"]``
# mapping does not provide ``commit_mode`` / ``session_key``.
DEFAULT_COMMIT_MODE: Literal["manual"] = "manual"
DEFAULT_SESSION_KEY = "db_session"

logger = get_logger("extensions.flask")


class SQLSpecPlugin:
    """Flask extension for SQLSpec database integration.

    Provides request-scoped session management, automatic transaction handling,
    and async adapter support via portal pattern.

    Example:
        from flask import Flask
        from sqlspec import SQLSpec
        from sqlspec.adapters.sqlite import SqliteConfig
        from sqlspec.extensions.flask import SQLSpecPlugin

        sqlspec = SQLSpec()
        config = SqliteConfig(
            pool_config={"database": "app.db"},
            extension_config={
                "flask": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        )
        sqlspec.add_config(config)

        app = Flask(__name__)
        plugin = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        def list_users():
            db = plugin.get_session()
            result = db.execute("SELECT * FROM users")
            return {"users": result.all()}
    """

    def __init__(self, sqlspec: SQLSpec, app: "Flask | None" = None) -> None:
        """Initialize Flask extension with SQLSpec instance.

        Builds one ``FlaskConfigState`` per config registered on ``sqlspec`` and,
        when ``app`` is given, initializes it immediately via ``init_app``.

        Args:
            sqlspec: SQLSpec instance with registered configs.
            app: Optional Flask application to initialize immediately.
        """
        self._sqlspec = sqlspec
        self._config_states: list[FlaskConfigState] = []
        self._portal: PortalProvider | None = None
        self._has_async_configs = False
        self._cleanup_registered = False
        self._shutdown_complete = False

        for cfg in self._sqlspec.configs.values():
            state = self._create_config_state(cfg)
            self._config_states.append(state)

            if state.is_async:
                self._has_async_configs = True

        if app is not None:
            self.init_app(app)

    def _create_config_state(self, config: Any) -> FlaskConfigState:
        """Create configuration state from database config.

        Reads the optional ``extension_config["flask"]`` mapping and falls back to
        the module defaults for any missing key.

        Args:
            config: Database configuration instance.

        Returns:
            FlaskConfigState instance.
        """
        flask_config = config.extension_config.get("flask", {})

        session_key = flask_config.get("session_key", DEFAULT_SESSION_KEY)
        connection_key = flask_config.get("connection_key", f"sqlspec_connection_{session_key}")
        commit_mode = flask_config.get("commit_mode", DEFAULT_COMMIT_MODE)
        extra_commit_statuses = flask_config.get("extra_commit_statuses")
        extra_rollback_statuses = flask_config.get("extra_rollback_statuses")

        is_async = isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig))

        return FlaskConfigState(
            config=config,
            connection_key=connection_key,
            session_key=session_key,
            commit_mode=commit_mode,
            extra_commit_statuses=extra_commit_statuses,
            extra_rollback_statuses=extra_rollback_statuses,
            is_async=is_async,
        )

    def init_app(self, app: "Flask") -> None:
        """Initialize Flask application with SQLSpec.

        Validates configuration, starts the async portal if needed (reusing one
        that is already running), creates pools, and registers request hooks plus
        an ``atexit`` shutdown hook.

        Args:
            app: Flask application to initialize.

        Raises:
            ImproperConfigurationError: If extension already registered or keys not unique.
        """
        if "sqlspec" in app.extensions:
            msg = "SQLSpec extension already registered on this Flask application"
            raise ImproperConfigurationError(msg)

        self._validate_unique_keys()

        # Start the portal only once. Every request hook routes async calls through
        # the current ``self._portal``. Keeping a single portal means pools created
        # for each app are driven by the same portal that created them. Replacing
        # ``self._portal`` would also orphan the previous one, because ``shutdown``
        # only stops ``self._portal``.
        if self._has_async_configs and self._portal is None:
            self._portal = PortalProvider()
            self._portal.start()
            logger.debug("Portal provider started for async adapters")

        pools: dict[str, Any] = {}
        for config_state in self._config_states:
            if not config_state.config.supports_connection_pooling:
                continue
            if config_state.is_async:
                pool = self._portal.portal.call(config_state.config.create_pool)  # type: ignore[union-attr,arg-type]
            else:
                pool = config_state.config.create_pool()
            pools[config_state.session_key] = pool

        app.extensions["sqlspec"] = {"plugin": self, "pools": pools}

        app.before_request(self._before_request_handler)
        app.after_request(self._after_request_handler)
        app.teardown_appcontext(self._teardown_appcontext_handler)
        self._register_shutdown_hook()

        logger.debug("SQLSpec Flask extension initialized")

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}. Use unique session_key values."
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _register_shutdown_hook(self) -> None:
        """Register ``shutdown`` with ``atexit`` once per plugin instance."""

        if self._cleanup_registered:
            return

        atexit.register(self.shutdown)
        self._cleanup_registered = True

    @staticmethod
    def _session_cache_key(config_state: FlaskConfigState) -> str:
        """Return the ``g`` attribute name this plugin reads the cached request session from."""
        return f"sqlspec_session_cache_{config_state.session_key}"

    def _resolve_config_state(self, key: "str | None") -> FlaskConfigState:
        """Return the state for ``key``, or the first registered state when ``key`` is None.

        Raises:
            ImproperConfigurationError: If ``key`` is unknown, or no configs are registered.
        """
        if key is not None:
            return self._get_config_state_by_key(key)
        if not self._config_states:
            # Previously surfaced as a bare IndexError from ``self._config_states[0]``.
            msg = "No database configurations registered with SQLSpec"
            raise ImproperConfigurationError(msg)
        return self._config_states[0]

    def _before_request_handler(self) -> None:
        """Acquire connection before request.

        Stores connection in Flask g object for each configured database.
        For pooled configs, also stores the ``provide_connection`` context manager
        under ``f"{connection_key}_ctx"`` so teardown can exit it.
        """
        from flask import current_app, g

        for config_state in self._config_states:
            if config_state.config.supports_connection_pooling:
                pool = current_app.extensions["sqlspec"]["pools"][config_state.session_key]
                conn_ctx = config_state.config.provide_connection(pool)

                if config_state.is_async:
                    connection = self._portal.portal.call(conn_ctx.__aenter__)  # type: ignore[union-attr]
                else:
                    connection = conn_ctx.__enter__()  # type: ignore[union-attr]

                setattr(g, f"{config_state.connection_key}_ctx", conn_ctx)
            elif config_state.is_async:
                connection = self._portal.portal.call(config_state.config.create_connection)  # type: ignore[union-attr,arg-type]
            else:
                connection = config_state.config.create_connection()

            setattr(g, config_state.connection_key, connection)

    def _after_request_handler(self, response: "Response") -> "Response":
        """Handle transaction after request based on response status.

        Skips configs in ``"manual"`` commit mode and configs for which no session
        was cached on ``g`` during the request.

        Args:
            response: Flask response object.

        Returns:
            Response object (unchanged).
        """
        from flask import g

        for config_state in self._config_states:
            if config_state.commit_mode == "manual":
                continue

            session = getattr(g, self._session_cache_key(config_state), None)

            if session is None:
                continue

            if config_state.should_commit(response.status_code):
                self._execute_commit(session, config_state)
            elif config_state.should_rollback(response.status_code):
                self._execute_rollback(session, config_state)

        return response

    def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> None:
        """Clean up connections when request context ends.

        Exits the pooled-connection context manager when one was stored, otherwise
        closes the connection directly. Then removes the connection, context, and
        session-cache attributes from ``g``. Close errors are logged, not raised.

        Args:
            _exc: Exception that occurred (if any). Unused.
        """
        from flask import g

        for config_state in self._config_states:
            connection = getattr(g, config_state.connection_key, None)
            ctx_key = f"{config_state.connection_key}_ctx"
            conn_ctx = getattr(g, ctx_key, None)

            if connection is not None:
                try:
                    if conn_ctx is not None:
                        if config_state.is_async:
                            self._portal.portal.call(conn_ctx.__aexit__, None, None, None)  # type: ignore[union-attr]
                        else:
                            conn_ctx.__exit__(None, None, None)
                    elif config_state.is_async:
                        self._portal.portal.call(connection.close)  # type: ignore[union-attr]
                    else:
                        connection.close()
                except Exception:
                    logger.exception("Error closing connection")

            for attr in (config_state.connection_key, ctx_key, self._session_cache_key(config_state)):
                if hasattr(g, attr):
                    delattr(g, attr)

    def get_session(self, key: "str | None" = None) -> Any:
        """Get or create database session for current request.

        Sessions are cached per request for consistency.

        Args:
            key: Session key for multi-database configs. Defaults to first config if None.

        Returns:
            Database session (driver instance).

        Raises:
            ImproperConfigurationError: If ``key`` is unknown or no configs are registered.
        """
        config_state = self._resolve_config_state(key)

        return get_or_create_session(config_state, self._portal.portal if self._portal else None)

    def get_connection(self, key: "str | None" = None) -> Any:
        """Get database connection for current request.

        Args:
            key: Session key for multi-database configs. Defaults to first config if None.

        Returns:
            Raw database connection stored on ``g`` by the before-request hook.

        Raises:
            ImproperConfigurationError: If ``key`` is unknown or no configs are registered.
        """
        from flask import g

        config_state = self._resolve_config_state(key)

        return getattr(g, config_state.connection_key)

    def _get_config_state_by_key(self, key: str) -> FlaskConfigState:
        """Get config state by session key.

        Args:
            key: Session key to look up.

        Returns:
            FlaskConfigState for the key.

        Raises:
            ImproperConfigurationError: If key not found.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found for key: {key}"
        raise ImproperConfigurationError(msg)

    def shutdown(self) -> None:
        """Dispose connection pools and stop async portal.

        Idempotent: subsequent calls return immediately.
        """

        if self._shutdown_complete:
            return

        self._shutdown_complete = True

        for config_state in self._config_states:
            if config_state.config.supports_connection_pooling:
                self._close_pool_state(config_state)

        if self._portal is not None:
            try:
                self._portal.stop()
            except Exception:
                logger.exception("Error stopping portal during shutdown")
            finally:
                self._portal = None

    def _close_pool_state(self, config_state: FlaskConfigState) -> None:
        """Close pool associated with configuration state; errors are logged, not raised."""

        try:
            if config_state.is_async:
                if self._portal is None:
                    logger.debug(
                        "Portal not initialized - skipping async pool shutdown for %s", config_state.session_key
                    )
                    return
                _ = self._portal.portal.call(config_state.config.close_pool)  # type: ignore[arg-type]
            else:
                config_state.config.close_pool()
        except Exception:
            logger.exception("Error closing pool during shutdown for key %s", config_state.session_key)

    def _execute_commit(self, session: Any, config_state: FlaskConfigState) -> None:
        """Execute commit on the request connection for ``config_state``.

        NOTE(review): a failed commit is only logged. The already-built response is
        still returned to the client, so the client may see success for data that
        was not committed.

        Args:
            session: Database session (currently unused; the commit goes through the raw connection).
            config_state: Configuration state.
        """
        try:
            connection = self.get_connection(config_state.session_key)
            if config_state.is_async:
                self._portal.portal.call(connection.commit)  # type: ignore[union-attr]
            else:
                connection.commit()
        except Exception:
            logger.exception("Error committing transaction")

    def _execute_rollback(self, session: Any, config_state: FlaskConfigState) -> None:
        """Execute rollback on the request connection for ``config_state``.

        Failures are logged at debug level only, since there may be no active
        transaction to roll back.

        Args:
            session: Database session (currently unused; the rollback goes through the raw connection).
            config_state: Configuration state.
        """
        try:
            connection = self.get_connection(config_state.session_key)
            if config_state.is_async:
                self._portal.portal.call(connection.rollback)  # type: ignore[union-attr]
            else:
                connection.rollback()
        except Exception as exc:
            logger.debug("Rollback failed (may be no active transaction): %s", exc)
@@ -1,6 +1,23 @@
1
- from sqlspec.extensions.litestar import handlers, providers
2
1
  from sqlspec.extensions.litestar.cli import database_group
3
- from sqlspec.extensions.litestar.config import DatabaseConfig
4
- from sqlspec.extensions.litestar.plugin import SQLSpec
2
+ from sqlspec.extensions.litestar.config import LitestarConfig
3
+ from sqlspec.extensions.litestar.plugin import (
4
+ DEFAULT_COMMIT_MODE,
5
+ DEFAULT_CONNECTION_KEY,
6
+ DEFAULT_POOL_KEY,
7
+ DEFAULT_SESSION_KEY,
8
+ CommitMode,
9
+ SQLSpecPlugin,
10
+ )
11
+ from sqlspec.extensions.litestar.store import BaseSQLSpecStore
5
12
 
6
- __all__ = ("DatabaseConfig", "SQLSpec", "database_group", "handlers", "providers")
13
# Public names re-exported by ``sqlspec.extensions.litestar``.
__all__ = (
    # Plugin defaults and types
    "DEFAULT_COMMIT_MODE",
    "DEFAULT_CONNECTION_KEY",
    "DEFAULT_POOL_KEY",
    "DEFAULT_SESSION_KEY",
    # Session store base class
    "BaseSQLSpecStore",
    # Commit-mode type, config and plugin
    "CommitMode",
    "LitestarConfig",
    "SQLSpecPlugin",
    # CLI group
    "database_group",
)
@@ -3,22 +3,18 @@
3
3
  from contextlib import suppress
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ import rich_click as click
6
7
  from litestar.cli._utils import LitestarGroup
7
8
 
8
9
  from sqlspec.cli import add_migration_commands
9
10
 
10
- try:
11
- import rich_click as click
12
- except ImportError:
13
- import click # type: ignore[no-redef]
14
-
15
11
  if TYPE_CHECKING:
16
12
  from litestar import Litestar
17
13
 
18
- from sqlspec.extensions.litestar.plugin import SQLSpec
14
+ from sqlspec.extensions.litestar.plugin import SQLSpecPlugin
19
15
 
20
16
 
21
- def get_database_migration_plugin(app: "Litestar") -> "SQLSpec":
17
def get_database_migration_plugin(app: "Litestar") -> "SQLSpecPlugin":
    """Look up the registered ``SQLSpecPlugin`` on a Litestar application.

    Args:
        app: Litestar application whose plugin registry is searched.

    Returns:
        The ``SQLSpecPlugin`` instance registered on ``app``.

    Raises:
        ImproperConfigurationError: If the SQLSpec plugin is not found
    """
    from sqlspec.exceptions import ImproperConfigurationError
    from sqlspec.extensions.litestar.plugin import SQLSpecPlugin

    try:
        return app.plugins.get(SQLSpecPlugin)
    except KeyError:
        # Missing plugin: fall through to the configuration error below
        # (raised outside the handler so the KeyError is not chained).
        pass

    msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
    raise ImproperConfigurationError(msg)
40
36
 
41
37
 
42
- @click.group(cls=LitestarGroup, name="db")
38
@click.group(cls=LitestarGroup, name="db", aliases=["database"])
def database_group(ctx: "click.Context") -> None:
    """Manage SQLSpec database components."""
    # Wrap the incoming Litestar CLI context object together with the plugin's
    # config so the migration sub-commands can reach both.
    litestar_obj = ctx.obj
    sqlspec_plugin = get_database_migration_plugin(litestar_obj.app)
    ctx.obj = {"app": litestar_obj, "configs": sqlspec_plugin.config}


add_migration_commands(database_group)


def add_sessions_delete_expired_command() -> None:
    """Attach a ``delete-expired`` command to Litestar's ``sessions`` CLI group.

    Returns without registering anything when Litestar's sessions CLI module
    (or its console helper) cannot be imported.
    """
    try:
        from litestar.cli._utils import console
        from litestar.cli.commands.sessions import get_session_backend, sessions_group
    except ImportError:
        return

    @sessions_group.command("delete-expired")  # type: ignore[misc]
    @click.option(
        "--verbose", is_flag=True, default=False, help="Show detailed information about the cleanup operation"
    )
    def delete_expired_sessions_command(app: "Litestar", verbose: bool) -> None:
        """Delete expired sessions from the session store.

        This command removes all sessions that have passed their expiration time.
        It can be scheduled via cron or systemd timers for automatic maintenance.

        Examples:
            litestar sessions delete-expired
            litestar sessions delete-expired --verbose
        """
        import anyio

        session_store = get_session_backend(app).config.get_store_from_app(app)

        # Only stores exposing ``delete_expired`` can be cleaned up.
        if not hasattr(session_store, "delete_expired"):
            console.print(f"[red]{type(session_store).__name__} does not support deleting expired sessions")
            return

        async def _run_cleanup() -> int:
            return await session_store.delete_expired()  # type: ignore[no-any-return]

        removed = anyio.run(_run_cleanup)

        if removed > 0:
            summary = (
                f"[green]Successfully deleted {removed} expired session(s)"
                if verbose
                else f"[green]Deleted {removed} expired session(s)"
            )
            console.print(summary)
        else:
            console.print("[yellow]No expired sessions found")


add_sessions_delete_expired_command()