sqlspec 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_typing.py +93 -0
- sqlspec/adapters/adbc/adk/store.py +21 -11
- sqlspec/adapters/adbc/data_dictionary.py +27 -5
- sqlspec/adapters/adbc/driver.py +83 -14
- sqlspec/adapters/aiosqlite/adk/store.py +27 -18
- sqlspec/adapters/asyncmy/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/data_dictionary.py +24 -17
- sqlspec/adapters/bigquery/adk/store.py +30 -21
- sqlspec/adapters/bigquery/config.py +11 -0
- sqlspec/adapters/bigquery/driver.py +138 -1
- sqlspec/adapters/duckdb/adk/store.py +21 -11
- sqlspec/adapters/duckdb/driver.py +87 -1
- sqlspec/adapters/oracledb/adk/store.py +89 -206
- sqlspec/adapters/oracledb/driver.py +183 -2
- sqlspec/adapters/oracledb/litestar/store.py +22 -24
- sqlspec/adapters/psqlpy/adk/store.py +28 -27
- sqlspec/adapters/psqlpy/data_dictionary.py +24 -17
- sqlspec/adapters/psqlpy/driver.py +7 -10
- sqlspec/adapters/psycopg/adk/store.py +51 -33
- sqlspec/adapters/psycopg/data_dictionary.py +48 -34
- sqlspec/adapters/sqlite/adk/store.py +29 -19
- sqlspec/config.py +100 -2
- sqlspec/core/filters.py +18 -10
- sqlspec/core/result.py +133 -2
- sqlspec/driver/_async.py +89 -0
- sqlspec/driver/_common.py +64 -29
- sqlspec/driver/_sync.py +95 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +2 -2
- sqlspec/extensions/adk/service.py +3 -3
- sqlspec/extensions/adk/store.py +8 -8
- sqlspec/extensions/aiosql/adapter.py +3 -15
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/config.py +3 -6
- sqlspec/extensions/litestar/plugin.py +26 -2
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/protocols.py +40 -0
- sqlspec/storage/_utils.py +1 -14
- sqlspec/storage/backends/fsspec.py +3 -5
- sqlspec/storage/backends/local.py +1 -1
- sqlspec/storage/backends/obstore.py +10 -18
- sqlspec/typing.py +16 -0
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/module_loader.py +203 -3
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/serializers.py +110 -1
- sqlspec/utils/sync_tools.py +15 -5
- sqlspec/utils/type_guards.py +25 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +2 -2
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/RECORD +64 -50
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
"""Flask extension for SQLSpec database integration."""
|
|
2
|
+
|
|
3
|
+
import atexit
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
5
|
+
|
|
6
|
+
from sqlspec.base import SQLSpec
|
|
7
|
+
from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig
|
|
8
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
9
|
+
from sqlspec.extensions.flask._state import FlaskConfigState
|
|
10
|
+
from sqlspec.extensions.flask._utils import get_or_create_session
|
|
11
|
+
from sqlspec.utils.logging import get_logger
|
|
12
|
+
from sqlspec.utils.portal import PortalProvider
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from flask import Flask, Response
|
|
16
|
+
|
|
17
|
+
__all__ = ("SQLSpecPlugin",)

logger = get_logger("extensions.flask")

# Fallbacks applied when a config's extension_config["flask"] mapping omits
# "commit_mode" or "session_key" (see SQLSpecPlugin._create_config_state).
DEFAULT_COMMIT_MODE: Literal["manual"] = "manual"
DEFAULT_SESSION_KEY = "db_session"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SQLSpecPlugin:
    """Flask extension for SQLSpec database integration.

    Provides request-scoped session management, automatic transaction handling,
    and async adapter support via portal pattern.

    Example:
        from flask import Flask
        from sqlspec import SQLSpec
        from sqlspec.adapters.sqlite import SqliteConfig
        from sqlspec.extensions.flask import SQLSpecPlugin

        sqlspec = SQLSpec()
        config = SqliteConfig(
            pool_config={"database": "app.db"},
            extension_config={
                "flask": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        )
        sqlspec.add_config(config)

        app = Flask(__name__)
        plugin = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        def list_users():
            db = plugin.get_session()
            result = db.execute("SELECT * FROM users")
            return {"users": result.all()}
    """

    def __init__(self, sqlspec: SQLSpec, app: "Flask | None" = None) -> None:
        """Initialize Flask extension with SQLSpec instance.

        Args:
            sqlspec: SQLSpec instance with registered configs.
            app: Optional Flask application to initialize immediately.
        """
        self._sqlspec = sqlspec
        self._config_states: list[FlaskConfigState] = []
        # Started lazily by ``init_app`` and then shared by every app this plugin
        # is attached to, so async pools always run on the same portal.
        self._portal: PortalProvider | None = None
        self._has_async_configs = False
        self._cleanup_registered = False
        self._shutdown_complete = False

        for cfg in self._sqlspec.configs.values():
            state = self._create_config_state(cfg)
            self._config_states.append(state)

            if state.is_async:
                self._has_async_configs = True

        if app is not None:
            self.init_app(app)

    def _create_config_state(self, config: Any) -> FlaskConfigState:
        """Create configuration state from database config.

        Reads the optional ``extension_config["flask"]`` mapping of ``config`` and
        falls back to ``DEFAULT_SESSION_KEY`` / ``DEFAULT_COMMIT_MODE`` for missing keys.

        Args:
            config: Database configuration instance.

        Returns:
            FlaskConfigState instance.
        """
        flask_config = config.extension_config.get("flask", {})

        session_key = flask_config.get("session_key", DEFAULT_SESSION_KEY)
        connection_key = flask_config.get("connection_key", f"sqlspec_connection_{session_key}")
        commit_mode = flask_config.get("commit_mode", DEFAULT_COMMIT_MODE)
        extra_commit_statuses = flask_config.get("extra_commit_statuses")
        extra_rollback_statuses = flask_config.get("extra_rollback_statuses")

        is_async = isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig))

        return FlaskConfigState(
            config=config,
            connection_key=connection_key,
            session_key=session_key,
            commit_mode=commit_mode,
            extra_commit_statuses=extra_commit_statuses,
            extra_rollback_statuses=extra_rollback_statuses,
            is_async=is_async,
        )

    def init_app(self, app: "Flask") -> None:
        """Initialize Flask application with SQLSpec.

        Validates configuration, starts the async portal if needed (reusing one
        already started by a previous ``init_app`` call), registers the shutdown
        hook, creates pools, and registers request hooks.

        Args:
            app: Flask application to initialize.

        Raises:
            ImproperConfigurationError: If extension already registered or keys not unique.
        """
        if "sqlspec" in app.extensions:
            msg = "SQLSpec extension already registered on this Flask application"
            raise ImproperConfigurationError(msg)

        self._validate_unique_keys()

        # Reuse an existing portal: replacing it would leak the running one and
        # move portal calls for previously initialized apps onto a different portal
        # than the one that created their async pools.
        if self._has_async_configs and self._portal is None:
            self._portal = PortalProvider()
            self._portal.start()
            logger.debug("Portal provider started for async adapters")

        # Register cleanup before creating pools so that a failure part-way
        # through still stops the portal and closes pools at interpreter exit.
        self._register_shutdown_hook()

        pools: dict[str, Any] = {}
        for config_state in self._config_states:
            if not config_state.config.supports_connection_pooling:
                continue
            if config_state.is_async:
                pool = self._portal.portal.call(config_state.config.create_pool)  # type: ignore[union-attr,arg-type]
            else:
                pool = config_state.config.create_pool()
            pools[config_state.session_key] = pool

        app.extensions["sqlspec"] = {"plugin": self, "pools": pools}

        app.before_request(self._before_request_handler)
        app.after_request(self._after_request_handler)
        app.teardown_appcontext(self._teardown_appcontext_handler)

        logger.debug("SQLSpec Flask extension initialized")

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Both ``connection_key`` and ``session_key`` of every config are checked
        against the keys of all previously seen configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}. Use unique session_key values."
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _register_shutdown_hook(self) -> None:
        """Register ``shutdown`` with ``atexit`` (at most once per plugin instance)."""
        if self._cleanup_registered:
            return

        atexit.register(self.shutdown)
        self._cleanup_registered = True

    def _before_request_handler(self) -> None:
        """Acquire a connection for every configured database before the request.

        Each connection is stored on ``flask.g`` under the config's ``connection_key``.
        For pooled configs, the ``provide_connection`` context manager is entered and
        stored under ``"<connection_key>_ctx"`` so teardown can exit it.
        """
        from flask import current_app, g

        for config_state in self._config_states:
            if config_state.config.supports_connection_pooling:
                pool = current_app.extensions["sqlspec"]["pools"][config_state.session_key]
                conn_ctx = config_state.config.provide_connection(pool)

                if config_state.is_async:
                    connection = self._portal.portal.call(conn_ctx.__aenter__)  # type: ignore[union-attr]
                else:
                    connection = conn_ctx.__enter__()  # type: ignore[union-attr]

                # Keep the context manager so teardown releases via __exit__/__aexit__.
                setattr(g, f"{config_state.connection_key}_ctx", conn_ctx)
            elif config_state.is_async:
                connection = self._portal.portal.call(config_state.config.create_connection)  # type: ignore[union-attr,arg-type]
            else:
                connection = config_state.config.create_connection()

            setattr(g, config_state.connection_key, connection)

    def _after_request_handler(self, response: "Response") -> "Response":
        """Handle transaction after request based on response status.

        Configs in ``"manual"`` commit mode are skipped, as are configs for which
        no session was cached on ``flask.g`` during this request. Otherwise the
        config state's ``should_commit`` / ``should_rollback`` decide, from
        ``response.status_code``, whether to commit or roll back.

        Args:
            response: Flask response object.

        Returns:
            Response object (unchanged).
        """
        from flask import g

        for config_state in self._config_states:
            if config_state.commit_mode == "manual":
                continue

            session = getattr(g, f"sqlspec_session_cache_{config_state.session_key}", None)
            if session is None:
                continue

            if config_state.should_commit(response.status_code):
                self._execute_commit(session, config_state)
            elif config_state.should_rollback(response.status_code):
                self._execute_rollback(session, config_state)

        return response

    def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> None:
        """Clean up connections when the app context ends.

        Removes this plugin's connection, context-manager and session-cache entries
        from ``flask.g`` and releases each connection: pooled connections by exiting
        their context manager, others by calling ``close()``. Errors while releasing
        are logged, not raised.

        Args:
            _exc: Exception that occurred (if any). Currently unused; context
                managers are exited with ``(None, None, None)``.
        """
        from flask import g

        for config_state in self._config_states:
            ctx_key = f"{config_state.connection_key}_ctx"
            # Drop the request-scoped entries up front; cleanup below works on the
            # local references, so ``g`` is cleared even if releasing fails.
            connection = g.pop(config_state.connection_key, None)
            conn_ctx = g.pop(ctx_key, None)
            g.pop(f"sqlspec_session_cache_{config_state.session_key}", None)

            if connection is None:
                continue

            try:
                if conn_ctx is not None:
                    if config_state.is_async:
                        self._portal.portal.call(conn_ctx.__aexit__, None, None, None)  # type: ignore[union-attr]
                    else:
                        conn_ctx.__exit__(None, None, None)
                elif config_state.is_async:
                    self._portal.portal.call(connection.close)  # type: ignore[union-attr]
                else:
                    connection.close()
            except Exception:
                logger.exception("Error closing connection")

    def get_session(self, key: "str | None" = None) -> Any:
        """Get or create database session for current request.

        Sessions are cached per request for consistency.

        Args:
            key: Session key for multi-database configs. Defaults to first config if None.

        Returns:
            Database session (driver instance).

        Raises:
            ImproperConfigurationError: If ``key`` matches no configuration.
        """
        config_state = self._config_states[0] if key is None else self._get_config_state_by_key(key)

        return get_or_create_session(config_state, self._portal.portal if self._portal else None)

    def get_connection(self, key: "str | None" = None) -> Any:
        """Get database connection for current request.

        Args:
            key: Session key for multi-database configs. Defaults to first config if None.

        Returns:
            Raw database connection stored on ``flask.g`` by the before-request hook.

        Raises:
            ImproperConfigurationError: If ``key`` matches no configuration.
            AttributeError: If no connection is stored for this request.
        """
        from flask import g

        config_state = self._config_states[0] if key is None else self._get_config_state_by_key(key)

        return getattr(g, config_state.connection_key)

    def _get_config_state_by_key(self, key: str) -> FlaskConfigState:
        """Get config state by session key.

        Args:
            key: Session key to look up.

        Returns:
            FlaskConfigState for the key.

        Raises:
            ImproperConfigurationError: If key not found.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found for key: {key}"
        raise ImproperConfigurationError(msg)

    def shutdown(self) -> None:
        """Dispose connection pools and stop async portal.

        Safe to call repeatedly; only the first call does any work.
        """
        if self._shutdown_complete:
            return

        self._shutdown_complete = True

        for config_state in self._config_states:
            if config_state.config.supports_connection_pooling:
                self._close_pool_state(config_state)

        # Pools are closed first because async pools need the portal to close.
        if self._portal is not None:
            try:
                self._portal.stop()
            except Exception:
                logger.exception("Error stopping portal during shutdown")
            finally:
                self._portal = None

    def _close_pool_state(self, config_state: FlaskConfigState) -> None:
        """Close pool associated with configuration state; errors are logged."""
        try:
            if config_state.is_async:
                if self._portal is None:
                    logger.debug(
                        "Portal not initialized - skipping async pool shutdown for %s", config_state.session_key
                    )
                    return
                _ = self._portal.portal.call(config_state.config.close_pool)  # type: ignore[arg-type]
            else:
                config_state.config.close_pool()
        except Exception:
            logger.exception("Error closing pool during shutdown for key %s", config_state.session_key)

    def _execute_commit(self, session: Any, config_state: FlaskConfigState) -> None:
        """Commit the request's connection for ``config_state``; failures are logged.

        Args:
            session: Database session (not used; the raw connection is committed).
            config_state: Configuration state.
        """
        try:
            connection = self.get_connection(config_state.session_key)
            if config_state.is_async:
                self._portal.portal.call(connection.commit)  # type: ignore[union-attr]
            else:
                connection.commit()
        except Exception:
            logger.exception("Error committing transaction")

    def _execute_rollback(self, session: Any, config_state: FlaskConfigState) -> None:
        """Roll back the request's connection for ``config_state``; failures are logged at debug level.

        Args:
            session: Database session (not used; the raw connection is rolled back).
            config_state: Configuration state.
        """
        try:
            connection = self.get_connection(config_state.session_key)
            if config_state.is_async:
                self._portal.portal.call(connection.rollback)  # type: ignore[union-attr]
            else:
                connection.rollback()
        except Exception as exc:
            logger.debug("Rollback failed (may be no active transaction): %s", exc)
|
|
@@ -40,8 +40,7 @@ class LitestarConfig(TypedDict):
|
|
|
40
40
|
in_memory: NotRequired[bool]
|
|
41
41
|
"""Enable in-memory table storage (Oracle-specific). Default: False.
|
|
42
42
|
|
|
43
|
-
When enabled, tables are created with the
|
|
44
|
-
which stores table data in columnar format in memory for faster query performance.
|
|
43
|
+
When enabled, tables are created with the in-memory attribute for databases that support it.
|
|
45
44
|
|
|
46
45
|
This is an Oracle-specific feature that requires:
|
|
47
46
|
- Oracle Database 12.1.0.2 or higher
|
|
@@ -62,8 +61,6 @@ class LitestarConfig(TypedDict):
|
|
|
62
61
|
)
|
|
63
62
|
|
|
64
63
|
Notes:
|
|
65
|
-
-
|
|
66
|
-
-
|
|
67
|
-
- Requires Oracle Database In-Memory option license
|
|
68
|
-
- Ignored by non-Oracle adapters
|
|
64
|
+
- Tables created with INMEMORY PRIORITY HIGH clause
|
|
65
|
+
- Ignored by unsupported adapters
|
|
69
66
|
"""
|
|
@@ -25,8 +25,9 @@ from sqlspec.extensions.litestar.handlers import (
|
|
|
25
25
|
pool_provider_maker,
|
|
26
26
|
session_provider_maker,
|
|
27
27
|
)
|
|
28
|
-
from sqlspec.typing import ConnectionT, PoolT
|
|
28
|
+
from sqlspec.typing import NUMPY_INSTALLED, ConnectionT, PoolT, SchemaT
|
|
29
29
|
from sqlspec.utils.logging import get_logger
|
|
30
|
+
from sqlspec.utils.serializers import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate
|
|
30
31
|
|
|
31
32
|
if TYPE_CHECKING:
|
|
32
33
|
from collections.abc import AsyncGenerator, Callable
|
|
@@ -82,6 +83,10 @@ class _PluginConfigState:
|
|
|
82
83
|
class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
|
|
83
84
|
"""Litestar plugin for SQLSpec database integration.
|
|
84
85
|
|
|
86
|
+
Automatically configures NumPy array serialization when NumPy is installed,
|
|
87
|
+
enabling seamless bidirectional conversion between NumPy arrays and JSON
|
|
88
|
+
for vector embedding workflows.
|
|
89
|
+
|
|
85
90
|
Session Table Migrations:
|
|
86
91
|
The Litestar extension includes migrations for creating session storage tables.
|
|
87
92
|
To include these migrations in your database migration workflow, add 'litestar'
|
|
@@ -225,6 +230,8 @@ class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
|
|
|
225
230
|
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
|
|
226
231
|
"""Configure Litestar application with SQLSpec database integration.
|
|
227
232
|
|
|
233
|
+
Automatically registers NumPy array serialization when NumPy is installed.
|
|
234
|
+
|
|
228
235
|
Args:
|
|
229
236
|
app_config: The Litestar application configuration instance.
|
|
230
237
|
|
|
@@ -239,7 +246,7 @@ class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
|
|
|
239
246
|
app_config.on_startup.append(store_sqlspec_in_state)
|
|
240
247
|
app_config.signature_types.extend([SQLSpec, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT])
|
|
241
248
|
|
|
242
|
-
signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT}
|
|
249
|
+
signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT, "SchemaT": SchemaT}
|
|
243
250
|
|
|
244
251
|
for state in self._plugin_configs:
|
|
245
252
|
state.annotation = type(state.config)
|
|
@@ -262,6 +269,23 @@ class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
|
|
|
262
269
|
if signature_namespace:
|
|
263
270
|
app_config.signature_namespace.update(signature_namespace)
|
|
264
271
|
|
|
272
|
+
if NUMPY_INSTALLED:
|
|
273
|
+
import numpy as np
|
|
274
|
+
|
|
275
|
+
if app_config.type_encoders is None:
|
|
276
|
+
app_config.type_encoders = {np.ndarray: numpy_array_enc_hook}
|
|
277
|
+
else:
|
|
278
|
+
encoders_dict = dict(app_config.type_encoders)
|
|
279
|
+
encoders_dict[np.ndarray] = numpy_array_enc_hook
|
|
280
|
+
app_config.type_encoders = encoders_dict
|
|
281
|
+
|
|
282
|
+
if app_config.type_decoders is None:
|
|
283
|
+
app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] # type: ignore[list-item]
|
|
284
|
+
else:
|
|
285
|
+
decoders_list = list(app_config.type_decoders)
|
|
286
|
+
decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) # type: ignore[arg-type]
|
|
287
|
+
app_config.type_decoders = decoders_list
|
|
288
|
+
|
|
265
289
|
return app_config
|
|
266
290
|
|
|
267
291
|
def get_annotations(
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Starlette extension for SQLSpec.
|
|
2
|
+
|
|
3
|
+
Provides middleware-based session management, automatic transaction handling,
|
|
4
|
+
and connection pooling lifecycle management for Starlette applications.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from sqlspec.extensions.starlette.extension import SQLSpecPlugin
|
|
8
|
+
from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
|
|
9
|
+
|
|
10
|
+
__all__ = ("SQLSpecAutocommitMiddleware", "SQLSpecManualMiddleware", "SQLSpecPlugin")
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sqlspec.config import DatabaseConfigProtocol
|
|
6
|
+
|
|
7
|
+
__all__ = ("CommitMode", "SQLSpecConfigState")

# Transaction-handling modes a database configuration may select.
CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"]


@dataclass(init=True, repr=True, eq=True)
class SQLSpecConfigState:
    """Per-database bookkeeping for the Starlette integration.

    One record per registered database configuration, holding the keys and
    transaction settings the middleware and session helpers work with.
    """

    # The database configuration this record describes.
    config: "DatabaseConfigProtocol[Any, Any, Any]"
    # Attribute on ``request.state`` that holds the request's connection.
    connection_key: str
    # Key for the connection pool (presumably where the pool is stored — confirm in extension.py).
    pool_key: str
    # Prefix for request-scoped session data (the cached driver lives at "<session_key>_instance").
    session_key: str
    # Selected transaction-handling mode.
    commit_mode: CommitMode
    # Optional extra status codes; presumably HTTP statuses that force a commit — TODO confirm.
    extra_commit_statuses: "set[int] | None"
    # Optional extra status codes; presumably HTTP statuses that force a rollback — TODO confirm.
    extra_rollback_statuses: "set[int] | None"
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from starlette.requests import Request
|
|
5
|
+
|
|
6
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
7
|
+
|
|
8
|
+
__all__ = ("get_connection_from_request", "get_or_create_session")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_connection_from_request(request: "Request", config_state: "SQLSpecConfigState") -> Any:
    """Return the database connection stored on the request's state.

    Args:
        request: Starlette request whose ``state`` carries the connection.
        config_state: Configuration state naming the attribute (``connection_key``) to read.

    Returns:
        The object stored on ``request.state`` under ``config_state.connection_key``.
    """
    request_state = request.state
    attribute_name = config_state.connection_key
    return getattr(request_state, attribute_name)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_or_create_session(request: "Request", config_state: "SQLSpecConfigState") -> Any:
    """Return the request's driver session, building and caching it on first use.

    The driver is memoised on ``request.state`` under ``"<session_key>_instance"``,
    so every call within one request yields the same object.

    Args:
        request: Starlette request instance.
        config_state: Configuration state for the database.

    Returns:
        Driver instance built from ``config_state.config.driver_type``.
    """
    cache_attr = f"{config_state.session_key}_instance"
    session = getattr(request.state, cache_attr, None)

    if session is None:
        # First use in this request: wrap the stored connection in a driver.
        conn = get_connection_from_request(request, config_state)
        db_config = config_state.config
        session = db_config.driver_type(
            connection=conn,
            statement_config=db_config.statement_config,
            driver_features=db_config.driver_features,
        )
        setattr(request.state, cache_attr, session)

    return session
|