sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (212)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,254 @@
1
+ from contextlib import asynccontextmanager
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ from sqlspec.base import SQLSpec
5
+ from sqlspec.exceptions import ImproperConfigurationError
6
+ from sqlspec.extensions.starlette._state import SQLSpecConfigState
7
+ from sqlspec.extensions.starlette._utils import get_or_create_session
8
+ from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
9
+ from sqlspec.utils.logging import get_logger
10
+
11
+ if TYPE_CHECKING:
12
+ from collections.abc import AsyncGenerator
13
+
14
+ from starlette.applications import Starlette
15
+ from starlette.requests import Request
16
+
17
__all__ = ("SQLSpecPlugin",)

logger = get_logger("extensions.starlette")

# Fallback values used when a config's ``extension_config["starlette"]``
# mapping omits the corresponding setting.
DEFAULT_COMMIT_MODE = "manual"
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_SESSION_KEY = "db_session"
25
+
26
+
27
class SQLSpecPlugin:
    """SQLSpec integration for Starlette applications.

    Provides middleware-based session management, automatic transaction handling,
    and connection pooling lifecycle management.

    Example:
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse
        from sqlspec import SQLSpec
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.extensions.starlette import SQLSpecPlugin

        sqlspec = SQLSpec()
        sqlspec.add_config(AsyncpgConfig(
            bind_key="default",
            pool_config={"dsn": "postgresql://localhost/mydb"},
            extension_config={
                "starlette": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        ))

        app = Starlette()
        db_ext = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        async def list_users(request):
            db = db_ext.get_session(request)
            result = await db.execute("SELECT * FROM users")
            return JSONResponse({"users": result.all()})
    """

    __slots__ = ("_config_states", "_sqlspec")

    # Commit modes handled by ``_add_middleware``; anything else is a config error.
    _COMMIT_MODES = frozenset({"manual", "autocommit", "autocommit_include_redirect"})

    def __init__(self, sqlspec: SQLSpec, app: "Starlette | None" = None) -> None:
        """Initialize SQLSpec Starlette extension.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional Starlette application to initialize immediately.
        """
        self._sqlspec = sqlspec
        self._config_states: list[SQLSpecConfigState] = []

        for cfg in self._sqlspec.configs.values():
            settings = self._extract_starlette_settings(cfg)
            state = self._create_config_state(cfg, settings)
            self._config_states.append(state)

        if app is not None:
            self.init_app(app)

    def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
        """Extract Starlette settings from config.extension_config.

        Args:
            config: Database configuration instance.

        Returns:
            Dictionary of Starlette-specific settings, with module defaults
            filled in for any missing keys.
        """
        # Tolerate a missing/None extension_config or "starlette" entry.
        starlette_config = (config.extension_config or {}).get("starlette") or {}

        connection_key = starlette_config.get("connection_key", DEFAULT_CONNECTION_KEY)
        pool_key = starlette_config.get("pool_key", DEFAULT_POOL_KEY)
        session_key = starlette_config.get("session_key", DEFAULT_SESSION_KEY)
        commit_mode = starlette_config.get("commit_mode", DEFAULT_COMMIT_MODE)

        # Non-pooling configs never store a pool, but still need a unique key so
        # the duplicate-key validation does not flag several of them.
        if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY:
            pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}"

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
        }

    def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState:
        """Create configuration state object.

        Args:
            config: Database configuration instance.
            settings: Extracted Starlette settings.

        Returns:
            Configuration state instance.
        """
        return SQLSpecConfigState(
            config=config,
            connection_key=settings["connection_key"],
            pool_key=settings["pool_key"],
            session_key=settings["session_key"],
            commit_mode=settings["commit_mode"],
            extra_commit_statuses=settings["extra_commit_statuses"],
            extra_rollback_statuses=settings["extra_rollback_statuses"],
        )

    def init_app(self, app: "Starlette") -> None:
        """Initialize Starlette application with SQLSpec.

        Validates configuration, wraps the application's lifespan so pools are
        created before (and closed after) the original lifespan, and adds one
        transaction middleware per configuration.

        Args:
            app: Starlette application instance.

        Raises:
            ImproperConfigurationError: If state keys collide across configs or
                a config uses an unknown ``commit_mode``.
        """
        # Validate everything before mutating ``app`` so a bad config leaves it untouched.
        self._validate_unique_keys()
        self._validate_commit_modes()

        original_lifespan = app.router.lifespan_context

        @asynccontextmanager
        async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[Any, None]":
            async with self.lifespan(app), original_lifespan(app) as lifespan_state:
                # Forward what the wrapped lifespan yields: Starlette merges a yielded
                # mapping into request state, so yielding None here would drop it.
                yield lifespan_state

        app.router.lifespan_context = combined_lifespan

        for config_state in self._config_states:
            self._add_middleware(app, config_state)

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.pool_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}"
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _validate_commit_modes(self) -> None:
        """Validate that every config uses a supported ``commit_mode``.

        Raises:
            ImproperConfigurationError: If a config's commit mode is unknown.
        """
        for state in self._config_states:
            if state.commit_mode not in self._COMMIT_MODES:
                raise self._commit_mode_error(state)

    def _commit_mode_error(self, config_state: SQLSpecConfigState) -> ImproperConfigurationError:
        """Build the error raised for an unsupported ``commit_mode``."""
        expected = ", ".join(sorted(self._COMMIT_MODES))
        msg = (
            f"Invalid commit_mode {config_state.commit_mode!r} for session_key "
            f"{config_state.session_key!r}; expected one of: {expected}"
        )
        return ImproperConfigurationError(msg)

    def _add_middleware(self, app: "Starlette", config_state: SQLSpecConfigState) -> None:
        """Add transaction middleware for configuration.

        Args:
            app: Starlette application instance.
            config_state: Configuration state.

        Raises:
            ImproperConfigurationError: If ``commit_mode`` is not supported
                (previously such configs silently got no middleware at all).
        """
        if config_state.commit_mode == "manual":
            app.add_middleware(SQLSpecManualMiddleware, config_state=config_state)
        elif config_state.commit_mode == "autocommit":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=False)
        elif config_state.commit_mode == "autocommit_include_redirect":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=True)
        else:
            raise self._commit_mode_error(config_state)

    @asynccontextmanager
    async def lifespan(self, app: "Starlette") -> "AsyncGenerator[None, None]":
        """Manage connection pool lifecycle.

        Creates a pool for every pooling-capable config and stores it on
        ``app.state`` under that config's pool key. Each pool's cleanup is
        registered as soon as it is created, so a failure while creating a
        later pool still closes the earlier ones. On exit, every pool close is
        attempted (in reverse creation order) even if one of them raises.

        Args:
            app: Starlette application instance.

        Yields:
            None
        """
        from contextlib import AsyncExitStack

        async with AsyncExitStack() as cleanup:
            for config_state in self._config_states:
                if not config_state.config.supports_connection_pooling:
                    continue
                pool = await config_state.config.create_pool()
                setattr(app.state, config_state.pool_key, pool)
                cleanup.push_async_callback(self._close_pool, config_state)
            yield

    @staticmethod
    async def _close_pool(config_state: SQLSpecConfigState) -> None:
        """Close one config's pool, awaiting the result when ``close_pool`` is async."""
        close_result = config_state.config.close_pool()
        if close_result is not None:
            await close_result

    def _resolve_config_state(self, key: "str | None") -> SQLSpecConfigState:
        """Return the config state for ``key``, or the first registered one when ``key`` is None.

        Raises:
            ImproperConfigurationError: If ``key`` is None and no configs are registered.
            ValueError: If no configuration has the given session key.
        """
        if key is not None:
            return self._get_config_state_by_key(key)
        if not self._config_states:
            msg = "No SQLSpec configurations are registered with this Starlette plugin"
            raise ImproperConfigurationError(msg)
        return self._config_states[0]

    def get_session(self, request: "Request", key: "str | None" = None) -> Any:
        """Get or create database session for request.

        Sessions are cached per request to ensure consistency.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database session.
                Defaults to the first registered configuration.

        Returns:
            Database session (driver instance).
        """
        config_state = self._resolve_config_state(key)

        return get_or_create_session(request, config_state)

    def get_connection(self, request: "Request", key: "str | None" = None) -> Any:
        """Get database connection from request state.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database connection.
                Defaults to the first registered configuration.

        Returns:
            Database connection object.

        Raises:
            AttributeError: If no connection is stored on ``request.state`` under
                the config's connection key (e.g. outside the middleware).
        """
        config_state = self._resolve_config_state(key)

        return getattr(request.state, config_state.connection_key)

    def _get_config_state_by_key(self, key: str) -> SQLSpecConfigState:
        """Get configuration state by session key.

        Args:
            key: Session key to search for.

        Returns:
            Configuration state matching the key.

        Raises:
            ValueError: If no configuration found with the specified key.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found with session_key: {key}"
        raise ValueError(msg)
@@ -0,0 +1,154 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ from starlette.middleware.base import BaseHTTPMiddleware
4
+
5
+ from sqlspec.utils.logging import get_logger
6
+
7
+ if TYPE_CHECKING:
8
+ from starlette.requests import Request
9
+
10
+ from sqlspec.extensions.starlette._state import SQLSpecConfigState
11
+
12
__all__ = ("SQLSpecAutocommitMiddleware", "SQLSpecManualMiddleware")

logger = get_logger("extensions.starlette.middleware")

# Status-code boundaries for the autocommit decision: [200, 300) commits,
# [300, 400) commits only when redirects are included.
HTTP_200_OK = 200
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_400_BAD_REQUEST = 400
19
+
20
+
21
class SQLSpecManualMiddleware(BaseHTTPMiddleware):
    """Middleware for manual transaction mode.

    Acquires a connection (from the pool stored on ``app.state`` when the config
    supports pooling, otherwise a fresh one from ``create_connection()``), stores
    it on ``request.state`` under the configured connection key, and removes and
    releases it after the downstream handler returns.
    No automatic commit or rollback - user code must handle transactions.

    NOTE(review): with ``BaseHTTPMiddleware``, ``call_next`` can return before a
    streaming response body or background tasks have finished, so the connection
    may be released while they still run. Confirm endpoints do not use it there.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState") -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
        """
        super().__init__(app)
        self.config_state = config_state

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with manual transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._dispatch_with_connection(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._dispatch_with_connection(request, call_next, connection)
        finally:
            await connection.close()

    async def _dispatch_with_connection(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Expose ``connection`` on ``request.state`` for the duration of ``call_next``.

        The attribute is always removed afterwards so no stale (released or
        closed) connection stays reachable from the request, in both the pooled
        and non-pooled paths.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            return await call_next(request)
        finally:
            delattr(request.state, connection_key)
66
+
67
+
68
class SQLSpecAutocommitMiddleware(BaseHTTPMiddleware):
    """Middleware for autocommit transaction mode.

    Acquires a connection (pooled or freshly created), stores it on
    ``request.state`` under the configured connection key, then commits when the
    response status qualifies (see ``_should_commit``) and rolls back otherwise,
    or when the downstream handler or the commit raises.

    NOTE(review): with ``BaseHTTPMiddleware``, ``call_next`` can return before a
    streaming response body or background tasks have finished, so the commit and
    release may happen while they still run. Confirm endpoints do not use the
    connection there.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState", include_redirect: bool = False) -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
            include_redirect: If True, commit on 3xx status codes as well.
        """
        super().__init__(app)
        self.config_state = config_state
        self.include_redirect = include_redirect

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with autocommit transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._dispatch_in_transaction(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._dispatch_in_transaction(request, call_next, connection)
        finally:
            await connection.close()

    async def _dispatch_in_transaction(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Run ``call_next`` with ``connection`` exposed, then commit or roll back.

        Shared by the pooled and non-pooled paths. The connection attribute is
        always removed from ``request.state`` afterwards so a released/closed
        connection is never left reachable from the request.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            response = await call_next(request)

            if self._should_commit(response.status_code):
                await connection.commit()
            else:
                await connection.rollback()
        except Exception:
            # Handler or commit failed: discard pending work before propagating.
            await connection.rollback()
            raise
        else:
            return response
        finally:
            delattr(request.state, connection_key)

    def _should_commit(self, status_code: int) -> bool:
        """Determine if response status code should trigger commit.

        Explicit ``extra_commit_statuses`` win over ``extra_rollback_statuses``;
        otherwise 2xx commits, and 3xx commits only when ``include_redirect``.

        Args:
            status_code: HTTP status code.

        Returns:
            True if should commit, False if should rollback.
        """
        extra_commit = self.config_state.extra_commit_statuses or set()
        extra_rollback = self.config_state.extra_rollback_statuses or set()

        if status_code in extra_commit:
            return True
        if status_code in extra_rollback:
            return False

        if HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES:
            return True
        return bool(self.include_redirect and HTTP_300_MULTIPLE_CHOICES <= status_code < HTTP_400_BAD_REQUEST)
sqlspec/loader.py CHANGED
@@ -9,7 +9,7 @@ import re
9
9
  import time
10
10
  from datetime import datetime, timezone
11
11
  from pathlib import Path
12
- from typing import TYPE_CHECKING, Any, Final, Optional, Union
12
+ from typing import TYPE_CHECKING, Any, Final
13
13
  from urllib.parse import unquote, urlparse
14
14
 
15
15
  from sqlspec.core.cache import get_cache, get_cache_config
@@ -95,7 +95,7 @@ class NamedStatement:
95
95
 
96
96
  __slots__ = ("dialect", "name", "sql", "start_line")
97
97
 
98
- def __init__(self, name: str, sql: str, dialect: "Optional[str]" = None, start_line: int = 0) -> None:
98
+ def __init__(self, name: str, sql: str, dialect: "str | None" = None, start_line: int = 0) -> None:
99
99
  self.name = name
100
100
  self.sql = sql
101
101
  self.dialect = dialect
@@ -112,11 +112,7 @@ class SQLFile:
112
112
  __slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
113
113
 
114
114
  def __init__(
115
- self,
116
- content: str,
117
- path: str,
118
- metadata: "Optional[dict[str, Any]]" = None,
119
- loaded_at: "Optional[datetime]" = None,
115
+ self, content: str, path: str, metadata: "dict[str, Any] | None" = None, loaded_at: "datetime | None" = None
120
116
  ) -> None:
121
117
  """Initialize SQLFile.
122
118
 
@@ -163,7 +159,7 @@ class SQLFileLoader:
163
159
 
164
160
  __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
165
161
 
166
- def __init__(self, *, encoding: str = "utf-8", storage_registry: "Optional[StorageRegistry]" = None) -> None:
162
+ def __init__(self, *, encoding: str = "utf-8", storage_registry: "StorageRegistry | None" = None) -> None:
167
163
  """Initialize the SQL file loader.
168
164
 
169
165
  Args:
@@ -188,7 +184,7 @@ class SQLFileLoader:
188
184
  """
189
185
  raise SQLFileNotFoundError(path)
190
186
 
191
- def _generate_file_cache_key(self, path: Union[str, Path]) -> str:
187
+ def _generate_file_cache_key(self, path: str | Path) -> str:
192
188
  """Generate cache key for a file path.
193
189
 
194
190
  Args:
@@ -201,7 +197,7 @@ class SQLFileLoader:
201
197
  path_hash = hashlib.md5(path_str.encode(), usedforsecurity=False).hexdigest()
202
198
  return f"file:{path_hash[:16]}"
203
199
 
204
- def _calculate_file_checksum(self, path: Union[str, Path]) -> str:
200
+ def _calculate_file_checksum(self, path: str | Path) -> str:
205
201
  """Calculate checksum for file content validation.
206
202
 
207
203
  Args:
@@ -218,7 +214,7 @@ class SQLFileLoader:
218
214
  except Exception as e:
219
215
  raise SQLFileParseError(str(path), str(path), e) from e
220
216
 
221
- def _is_file_unchanged(self, path: Union[str, Path], cached_file: CachedSQLFile) -> bool:
217
+ def _is_file_unchanged(self, path: str | Path, cached_file: CachedSQLFile) -> bool:
222
218
  """Check if file has changed since caching.
223
219
 
224
220
  Args:
@@ -235,7 +231,7 @@ class SQLFileLoader:
235
231
  else:
236
232
  return current_checksum == cached_file.sql_file.checksum
237
233
 
238
- def _read_file_content(self, path: Union[str, Path]) -> str:
234
+ def _read_file_content(self, path: str | Path) -> str:
239
235
  """Read file content using storage backend.
240
236
 
241
237
  Args:
@@ -349,7 +345,7 @@ class SQLFileLoader:
349
345
 
350
346
  return statements
351
347
 
352
- def load_sql(self, *paths: Union[str, Path]) -> None:
348
+ def load_sql(self, *paths: str | Path) -> None:
353
349
  """Load SQL files and parse named queries.
354
350
 
355
351
  Args:
@@ -358,43 +354,20 @@ class SQLFileLoader:
358
354
  correlation_id = CorrelationContext.get()
359
355
  start_time = time.perf_counter()
360
356
 
361
- logger.info("Loading SQL files", extra={"file_count": len(paths), "correlation_id": correlation_id})
362
-
363
- loaded_count = 0
364
- query_count_before = len(self._queries)
365
-
366
357
  try:
367
358
  for path in paths:
368
359
  path_str = str(path)
369
360
  if "://" in path_str:
370
361
  self._load_single_file(path, None)
371
- loaded_count += 1
372
362
  else:
373
363
  path_obj = Path(path)
374
364
  if path_obj.is_dir():
375
- loaded_count += self._load_directory(path_obj)
365
+ self._load_directory(path_obj)
376
366
  elif path_obj.exists():
377
367
  self._load_single_file(path_obj, None)
378
- loaded_count += 1
379
368
  elif path_obj.suffix:
380
369
  self._raise_file_not_found(str(path))
381
370
 
382
- duration = time.perf_counter() - start_time
383
- new_queries = len(self._queries) - query_count_before
384
-
385
- logger.info(
386
- "Loaded %d SQL files with %d new queries in %.3fms",
387
- loaded_count,
388
- new_queries,
389
- duration * 1000,
390
- extra={
391
- "files_loaded": loaded_count,
392
- "new_queries": new_queries,
393
- "duration_ms": duration * 1000,
394
- "correlation_id": correlation_id,
395
- },
396
- )
397
-
398
371
  except Exception as e:
399
372
  duration = time.perf_counter() - start_time
400
373
  logger.exception(
@@ -408,34 +381,40 @@ class SQLFileLoader:
408
381
  )
409
382
  raise
410
383
 
411
- def _load_directory(self, dir_path: Path) -> int:
412
- """Load all SQL files from a directory."""
384
+ def _load_directory(self, dir_path: Path) -> None:
385
+ """Load all SQL files from a directory.
386
+
387
+ Args:
388
+ dir_path: Directory path to load SQL files from.
389
+ """
413
390
  sql_files = list(dir_path.rglob("*.sql"))
414
391
  if not sql_files:
415
- return 0
392
+ return
416
393
 
417
394
  for file_path in sql_files:
418
395
  relative_path = file_path.relative_to(dir_path)
419
396
  namespace_parts = relative_path.parent.parts
420
397
  self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
421
- return len(sql_files)
422
398
 
423
- def _load_single_file(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
399
+ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> bool:
424
400
  """Load a single SQL file with optional namespace.
425
401
 
426
402
  Args:
427
403
  file_path: Path to the SQL file.
428
404
  namespace: Optional namespace prefix for queries.
405
+
406
+ Returns:
407
+ True if file was newly loaded, False if already cached.
429
408
  """
430
409
  path_str = str(file_path)
431
410
 
432
411
  if path_str in self._files:
433
- return
412
+ return False
434
413
 
435
414
  cache_config = get_cache_config()
436
415
  if not cache_config.compiled_cache_enabled:
437
416
  self._load_file_without_cache(file_path, namespace)
438
- return
417
+ return True
439
418
 
440
419
  cache_key_str = self._generate_file_cache_key(file_path)
441
420
  cache = get_cache()
@@ -459,7 +438,7 @@ class SQLFileLoader:
459
438
  )
460
439
  self._queries[namespaced_name] = statement
461
440
  self._query_to_file[namespaced_name] = path_str
462
- return
441
+ return True
463
442
 
464
443
  self._load_file_without_cache(file_path, namespace)
465
444
 
@@ -476,7 +455,9 @@ class SQLFileLoader:
476
455
  cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements)
477
456
  cache.put("file", cache_key_str, cached_file_data)
478
457
 
479
- def _load_file_without_cache(self, file_path: Union[str, Path], namespace: Optional[str]) -> None:
458
+ return True
459
+
460
+ def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) -> None:
480
461
  """Load a single SQL file without using cache.
481
462
 
482
463
  Args:
@@ -503,7 +484,7 @@ class SQLFileLoader:
503
484
  self._queries[namespaced_name] = statement
504
485
  self._query_to_file[namespaced_name] = path_str
505
486
 
506
- def add_named_sql(self, name: str, sql: str, dialect: "Optional[str]" = None) -> None:
487
+ def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None:
507
488
  """Add a named SQL query directly without loading from a file.
508
489
 
509
490
  Args:
@@ -529,7 +510,7 @@ class SQLFileLoader:
529
510
  self._queries[normalized_name] = statement
530
511
  self._query_to_file[normalized_name] = "<directly added>"
531
512
 
532
- def get_file(self, path: Union[str, Path]) -> "Optional[SQLFile]":
513
+ def get_file(self, path: str | Path) -> "SQLFile | None":
533
514
  """Get a loaded SQLFile object by path.
534
515
 
535
516
  Args:
@@ -540,7 +521,7 @@ class SQLFileLoader:
540
521
  """
541
522
  return self._files.get(str(path))
542
523
 
543
- def get_file_for_query(self, name: str) -> "Optional[SQLFile]":
524
+ def get_file_for_query(self, name: str) -> "SQLFile | None":
544
525
  """Get the SQLFile object containing a query.
545
526
 
546
527
  Args: