sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +7 -15
- sqlspec/_serialization.py +55 -25
- sqlspec/_typing.py +155 -52
- sqlspec/adapters/adbc/_types.py +1 -1
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +880 -0
- sqlspec/adapters/adbc/config.py +62 -12
- sqlspec/adapters/adbc/data_dictionary.py +74 -2
- sqlspec/adapters/adbc/driver.py +226 -58
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +44 -50
- sqlspec/adapters/aiosqlite/_types.py +1 -1
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +536 -0
- sqlspec/adapters/aiosqlite/config.py +86 -16
- sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
- sqlspec/adapters/aiosqlite/driver.py +127 -38
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +7 -7
- sqlspec/adapters/asyncmy/__init__.py +7 -1
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +503 -0
- sqlspec/adapters/asyncmy/config.py +59 -17
- sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
- sqlspec/adapters/asyncmy/driver.py +293 -62
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +2 -1
- sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
- sqlspec/adapters/asyncpg/_types.py +11 -7
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +460 -0
- sqlspec/adapters/asyncpg/config.py +57 -36
- sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
- sqlspec/adapters/asyncpg/driver.py +153 -23
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/_types.py +1 -1
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +585 -0
- sqlspec/adapters/bigquery/config.py +36 -11
- sqlspec/adapters/bigquery/data_dictionary.py +42 -2
- sqlspec/adapters/bigquery/driver.py +489 -144
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +55 -23
- sqlspec/adapters/duckdb/_types.py +2 -2
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +563 -0
- sqlspec/adapters/duckdb/config.py +79 -21
- sqlspec/adapters/duckdb/data_dictionary.py +41 -2
- sqlspec/adapters/duckdb/driver.py +225 -44
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +5 -5
- sqlspec/adapters/duckdb/type_converter.py +51 -21
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +20 -2
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1628 -0
- sqlspec/adapters/oracledb/config.py +120 -36
- sqlspec/adapters/oracledb/data_dictionary.py +87 -20
- sqlspec/adapters/oracledb/driver.py +475 -86
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +765 -0
- sqlspec/adapters/oracledb/migrations.py +316 -25
- sqlspec/adapters/oracledb/type_converter.py +91 -16
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +2 -1
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +483 -0
- sqlspec/adapters/psqlpy/config.py +45 -19
- sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
- sqlspec/adapters/psqlpy/driver.py +108 -41
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +40 -11
- sqlspec/adapters/psycopg/_type_handlers.py +80 -0
- sqlspec/adapters/psycopg/_types.py +2 -1
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +962 -0
- sqlspec/adapters/psycopg/config.py +65 -37
- sqlspec/adapters/psycopg/data_dictionary.py +91 -3
- sqlspec/adapters/psycopg/driver.py +200 -78
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +1 -1
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +582 -0
- sqlspec/adapters/sqlite/config.py +85 -16
- sqlspec/adapters/sqlite/data_dictionary.py +34 -2
- sqlspec/adapters/sqlite/driver.py +120 -52
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +5 -5
- sqlspec/base.py +45 -26
- sqlspec/builder/__init__.py +73 -4
- sqlspec/builder/_base.py +91 -58
- sqlspec/builder/_column.py +5 -5
- sqlspec/builder/_ddl.py +98 -89
- sqlspec/builder/_delete.py +5 -4
- sqlspec/builder/_dml.py +388 -0
- sqlspec/{_sql.py → builder/_factory.py} +41 -44
- sqlspec/builder/_insert.py +5 -82
- sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
- sqlspec/builder/_merge.py +446 -11
- sqlspec/builder/_parsing_utils.py +9 -11
- sqlspec/builder/_select.py +1313 -25
- sqlspec/builder/_update.py +11 -42
- sqlspec/cli.py +76 -69
- sqlspec/config.py +331 -62
- sqlspec/core/__init__.py +5 -4
- sqlspec/core/cache.py +18 -18
- sqlspec/core/compiler.py +6 -8
- sqlspec/core/filters.py +55 -47
- sqlspec/core/hashing.py +9 -9
- sqlspec/core/parameters.py +76 -45
- sqlspec/core/result.py +234 -47
- sqlspec/core/splitter.py +16 -17
- sqlspec/core/statement.py +32 -31
- sqlspec/core/type_conversion.py +3 -2
- sqlspec/driver/__init__.py +1 -3
- sqlspec/driver/_async.py +183 -160
- sqlspec/driver/_common.py +197 -109
- sqlspec/driver/_sync.py +189 -161
- sqlspec/driver/mixins/_result_tools.py +20 -236
- sqlspec/driver/mixins/_sql_translator.py +4 -4
- sqlspec/exceptions.py +70 -7
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/adapter.py +69 -61
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/__init__.py +21 -4
- sqlspec/extensions/litestar/cli.py +54 -10
- sqlspec/extensions/litestar/config.py +56 -266
- sqlspec/extensions/litestar/handlers.py +46 -17
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +349 -224
- sqlspec/extensions/litestar/providers.py +25 -25
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/loader.py +30 -49
- sqlspec/migrations/base.py +200 -76
- sqlspec/migrations/commands.py +591 -62
- sqlspec/migrations/context.py +6 -9
- sqlspec/migrations/fix.py +199 -0
- sqlspec/migrations/loaders.py +47 -19
- sqlspec/migrations/runner.py +241 -75
- sqlspec/migrations/tracker.py +237 -21
- sqlspec/migrations/utils.py +51 -3
- sqlspec/migrations/validation.py +177 -0
- sqlspec/protocols.py +106 -36
- sqlspec/storage/_utils.py +85 -0
- sqlspec/storage/backends/fsspec.py +133 -107
- sqlspec/storage/backends/local.py +78 -51
- sqlspec/storage/backends/obstore.py +276 -168
- sqlspec/storage/registry.py +75 -39
- sqlspec/typing.py +30 -84
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/config_resolver.py +6 -6
- sqlspec/utils/correlation.py +4 -5
- sqlspec/utils/data_transformation.py +3 -2
- sqlspec/utils/deprecation.py +9 -8
- sqlspec/utils/fixtures.py +4 -4
- sqlspec/utils/logging.py +46 -6
- sqlspec/utils/module_loader.py +205 -5
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +113 -4
- sqlspec/utils/sync_tools.py +36 -22
- sqlspec/utils/text.py +1 -2
- sqlspec/utils/type_guards.py +136 -20
- sqlspec/utils/version.py +433 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
- sqlspec-0.28.0.dist-info/RECORD +221 -0
- sqlspec/builder/mixins/__init__.py +0 -55
- sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
- sqlspec/builder/mixins/_delete_operations.py +0 -50
- sqlspec/builder/mixins/_insert_operations.py +0 -282
- sqlspec/builder/mixins/_merge_operations.py +0 -698
- sqlspec/builder/mixins/_order_limit_operations.py +0 -145
- sqlspec/builder/mixins/_pivot_operations.py +0 -157
- sqlspec/builder/mixins/_select_operations.py +0 -930
- sqlspec/builder/mixins/_update_operations.py +0 -199
- sqlspec/builder/mixins/_where_clause.py +0 -1298
- sqlspec-0.26.0.dist-info/RECORD +0 -157
- sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
from sqlspec.base import SQLSpec
|
|
5
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
6
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
7
|
+
from sqlspec.extensions.starlette._utils import get_or_create_session
|
|
8
|
+
from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
|
|
9
|
+
from sqlspec.utils.logging import get_logger
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from collections.abc import AsyncGenerator
|
|
13
|
+
|
|
14
|
+
from starlette.applications import Starlette
|
|
15
|
+
from starlette.requests import Request
|
|
16
|
+
|
|
17
|
+
__all__ = ("SQLSpecPlugin",)

logger = get_logger("extensions.starlette")

# Fallback values used when a config's ``extension_config["starlette"]`` omits a setting.
DEFAULT_SESSION_KEY = "db_session"  # request.state attribute for the cached driver session
DEFAULT_CONNECTION_KEY = "db_connection"  # request.state attribute for the raw connection
DEFAULT_POOL_KEY = "db_pool"  # app.state attribute holding the connection pool
DEFAULT_COMMIT_MODE = "manual"  # no automatic commit/rollback unless configured otherwise
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SQLSpecPlugin:
    """SQLSpec integration for Starlette applications.

    Provides middleware-based session management, automatic transaction handling,
    and connection pooling lifecycle management.

    Example:
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse
        from sqlspec import SQLSpec
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.extensions.starlette import SQLSpecPlugin

        sqlspec = SQLSpec()
        sqlspec.add_config(AsyncpgConfig(
            bind_key="default",
            pool_config={"dsn": "postgresql://localhost/mydb"},
            extension_config={
                "starlette": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        ))

        app = Starlette()
        db_ext = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        async def list_users(request):
            db = db_ext.get_session(request)
            result = await db.execute("SELECT * FROM users")
            return JSONResponse({"users": result.all()})
    """

    __slots__ = ("_config_states", "_sqlspec")

    # Commit modes understood by ``_add_middleware``. Any other value would previously
    # register no middleware at all, so it is rejected as a configuration error instead.
    _VALID_COMMIT_MODES = ("manual", "autocommit", "autocommit_include_redirect")

    def __init__(self, sqlspec: SQLSpec, app: "Starlette | None" = None) -> None:
        """Initialize SQLSpec Starlette extension.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional Starlette application to initialize immediately.

        Raises:
            ImproperConfigurationError: If a config declares an unknown ``commit_mode``,
                or (when ``app`` is given) if state keys collide across configs.
        """
        self._sqlspec = sqlspec
        self._config_states: list[SQLSpecConfigState] = []

        for cfg in self._sqlspec.configs.values():
            settings = self._extract_starlette_settings(cfg)
            state = self._create_config_state(cfg, settings)
            self._config_states.append(state)

        if app is not None:
            self.init_app(app)

    def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
        """Extract Starlette settings from config.extension_config.

        Args:
            config: Database configuration instance.

        Returns:
            Dictionary of Starlette-specific settings.

        Raises:
            ImproperConfigurationError: If ``commit_mode`` is not a supported mode.
        """
        starlette_config = config.extension_config.get("starlette", {})

        connection_key = starlette_config.get("connection_key", DEFAULT_CONNECTION_KEY)
        pool_key = starlette_config.get("pool_key", DEFAULT_POOL_KEY)
        session_key = starlette_config.get("session_key", DEFAULT_SESSION_KEY)
        commit_mode = starlette_config.get("commit_mode", DEFAULT_COMMIT_MODE)

        # Fail fast: an unknown mode would otherwise leave requests without a connection.
        if commit_mode not in self._VALID_COMMIT_MODES:
            msg = (
                f"Invalid Starlette commit_mode {commit_mode!r}; "
                f"expected one of: {', '.join(self._VALID_COMMIT_MODES)}"
            )
            raise ImproperConfigurationError(msg)

        # Non-pooling configs never store a pool, but the key still participates in the
        # uniqueness check, so give each one a private key to avoid false duplicates.
        if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY:
            pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}"

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
        }

    def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState:
        """Create configuration state object.

        Args:
            config: Database configuration instance.
            settings: Extracted Starlette settings.

        Returns:
            Configuration state instance.
        """
        return SQLSpecConfigState(
            config=config,
            connection_key=settings["connection_key"],
            pool_key=settings["pool_key"],
            session_key=settings["session_key"],
            commit_mode=settings["commit_mode"],
            extra_commit_statuses=settings["extra_commit_statuses"],
            extra_rollback_statuses=settings["extra_rollback_statuses"],
        )

    def init_app(self, app: "Starlette") -> None:
        """Initialize Starlette application with SQLSpec.

        Validates configuration, wraps lifespan, and adds middleware.

        Args:
            app: Starlette application instance.

        Raises:
            ImproperConfigurationError: If state keys collide across configs.
        """
        self._validate_unique_keys()

        original_lifespan = app.router.lifespan_context

        @asynccontextmanager
        async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]":
            # Pools are opened before the application's own lifespan runs, so startup
            # hooks can already use them; they are closed after it finishes.
            async with self.lifespan(app), original_lifespan(app):
                yield

        app.router.lifespan_context = combined_lifespan

        for config_state in self._config_states:
            self._add_middleware(app, config_state)

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.pool_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}"
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _add_middleware(self, app: "Starlette", config_state: SQLSpecConfigState) -> None:
        """Add transaction middleware for configuration.

        Args:
            app: Starlette application instance.
            config_state: Configuration state.

        Raises:
            ImproperConfigurationError: If ``config_state.commit_mode`` is unsupported.
        """
        if config_state.commit_mode == "manual":
            app.add_middleware(SQLSpecManualMiddleware, config_state=config_state)
        elif config_state.commit_mode == "autocommit":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=False)
        elif config_state.commit_mode == "autocommit_include_redirect":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=True)
        else:
            # Guard for states built outside ``_extract_starlette_settings``.
            msg = f"Invalid Starlette commit_mode {config_state.commit_mode!r}"
            raise ImproperConfigurationError(msg)

    @asynccontextmanager
    async def lifespan(self, app: "Starlette") -> "AsyncGenerator[None, None]":
        """Manage connection pool lifecycle.

        Pools are created for every pooling config and stored on ``app.state``. If
        creating a later pool fails, the pools already created are still closed.

        Args:
            app: Starlette application instance.

        Yields:
            None
        """
        opened: list[SQLSpecConfigState] = []
        try:
            for config_state in self._config_states:
                if config_state.config.supports_connection_pooling:
                    pool = await config_state.config.create_pool()
                    setattr(app.state, config_state.pool_key, pool)
                    opened.append(config_state)
            yield
        finally:
            await self._close_pools(opened)

    async def _close_pools(self, config_states: "list[SQLSpecConfigState]") -> None:
        """Close the pools of the given configs, attempting every one.

        A failure while closing one pool is logged and does not prevent the others from
        being closed; the first failure is re-raised once all closes were attempted.

        Args:
            config_states: States whose pools were opened by ``lifespan``.
        """
        first_error: "Exception | None" = None
        for config_state in config_states:
            try:
                close_result = config_state.config.close_pool()
                # ``close_pool`` may return an awaitable (async configs) or ``None``.
                if close_result is not None:
                    await close_result
            except Exception as exc:
                logger.exception("Failed to close connection pool for key %s", config_state.pool_key)
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def _resolve_config_state(self, key: "str | None") -> SQLSpecConfigState:
        """Return the state for ``key``, or the first registered state when ``key`` is None.

        Args:
            key: Optional session key.

        Returns:
            Matching configuration state.

        Raises:
            ImproperConfigurationError: If ``key`` is None and no configs are registered.
            ValueError: If no configuration matches ``key``.
        """
        if key is not None:
            return self._get_config_state_by_key(key)
        if not self._config_states:
            msg = "No database configurations are registered with this SQLSpec instance"
            raise ImproperConfigurationError(msg)
        return self._config_states[0]

    def get_session(self, request: "Request", key: "str | None" = None) -> Any:
        """Get or create database session for request.

        Sessions are cached per request to ensure consistency.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database session.

        Returns:
            Database session (driver instance).
        """
        config_state = self._resolve_config_state(key)

        return get_or_create_session(request, config_state)

    def get_connection(self, request: "Request", key: "str | None" = None) -> Any:
        """Get database connection from request state.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database connection.

        Returns:
            Database connection object.
        """
        config_state = self._resolve_config_state(key)

        return getattr(request.state, config_state.connection_key)

    def _get_config_state_by_key(self, key: str) -> SQLSpecConfigState:
        """Get configuration state by session key.

        Args:
            key: Session key to search for.

        Returns:
            Configuration state matching the key.

        Raises:
            ValueError: If no configuration found with the specified key.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found with session_key: {key}"
        raise ValueError(msg)
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
|
|
5
|
+
from sqlspec.utils.logging import get_logger
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from starlette.requests import Request
|
|
9
|
+
|
|
10
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
11
|
+
|
|
12
|
+
__all__ = ("SQLSpecAutocommitMiddleware", "SQLSpecManualMiddleware")

logger = get_logger("extensions.starlette.middleware")

# HTTP status-code boundaries used to classify responses in autocommit mode:
# by default [200, 300) commits, and [300, 400) commits only when redirects are included.
HTTP_400_BAD_REQUEST = 400
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_200_OK = 200
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SQLSpecManualMiddleware(BaseHTTPMiddleware):
    """Middleware for manual transaction mode.

    Acquires connection from pool, stores in request.state, releases after request.
    No automatic commit or rollback - user code must handle transactions.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState") -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
        """
        super().__init__(app)
        self.config_state = config_state

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with manual transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._call_with_connection(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._call_with_connection(request, call_next, connection)
        finally:
            await connection.close()

    async def _call_with_connection(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Expose ``connection`` on ``request.state`` for the duration of ``call_next``.

        The attribute is removed afterwards in both the pooled and non-pooled paths, so
        a released or closed connection is never left reachable from the request.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Database connection for this request.

        Returns:
            HTTP response from ``call_next``.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            return await call_next(request)
        finally:
            delattr(request.state, connection_key)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SQLSpecAutocommitMiddleware(BaseHTTPMiddleware):
    """Middleware for autocommit transaction mode.

    Acquires connection, commits on success status codes, rollbacks on error status codes.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState", include_redirect: bool = False) -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
            include_redirect: If True, commit on 3xx status codes as well.
        """
        super().__init__(app)
        self.config_state = config_state
        self.include_redirect = include_redirect

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with autocommit transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._call_in_transaction(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._call_in_transaction(request, call_next, connection)
        finally:
            await connection.close()

    async def _call_in_transaction(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Run ``call_next`` with ``connection`` exposed, then commit or roll back.

        Commits when ``_should_commit`` accepts the response status, rolls back otherwise,
        and rolls back before re-raising if the handler (or the commit) raises. The
        connection attribute is removed from ``request.state`` afterwards in every path.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Database connection for this request.

        Returns:
            HTTP response from ``call_next``.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            response = await call_next(request)

            if self._should_commit(response.status_code):
                await connection.commit()
            else:
                await connection.rollback()
        except Exception:
            await connection.rollback()
            raise
        else:
            return response
        finally:
            delattr(request.state, connection_key)

    def _should_commit(self, status_code: int) -> bool:
        """Determine if response status code should trigger commit.

        Explicit ``extra_commit_statuses`` take precedence over ``extra_rollback_statuses``,
        which take precedence over the default 2xx (and optional 3xx) rule.

        Args:
            status_code: HTTP status code.

        Returns:
            True if should commit, False if should rollback.
        """
        extra_commit = self.config_state.extra_commit_statuses or set()
        extra_rollback = self.config_state.extra_rollback_statuses or set()

        if status_code in extra_commit:
            return True
        if status_code in extra_rollback:
            return False

        if HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES:
            return True
        return bool(self.include_redirect and HTTP_300_MULTIPLE_CHOICES <= status_code < HTTP_400_BAD_REQUEST)
|
sqlspec/loader.py
CHANGED
|
@@ -9,7 +9,7 @@ import re
|
|
|
9
9
|
import time
|
|
10
10
|
from datetime import datetime, timezone
|
|
11
11
|
from pathlib import Path
|
|
12
|
-
from typing import TYPE_CHECKING, Any, Final
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Final
|
|
13
13
|
from urllib.parse import unquote, urlparse
|
|
14
14
|
|
|
15
15
|
from sqlspec.core.cache import get_cache, get_cache_config
|
|
@@ -95,7 +95,7 @@ class NamedStatement:
|
|
|
95
95
|
|
|
96
96
|
__slots__ = ("dialect", "name", "sql", "start_line")
|
|
97
97
|
|
|
98
|
-
def __init__(self, name: str, sql: str, dialect: "
|
|
98
|
+
def __init__(self, name: str, sql: str, dialect: "str | None" = None, start_line: int = 0) -> None:
|
|
99
99
|
self.name = name
|
|
100
100
|
self.sql = sql
|
|
101
101
|
self.dialect = dialect
|
|
@@ -112,11 +112,7 @@ class SQLFile:
|
|
|
112
112
|
__slots__ = ("checksum", "content", "loaded_at", "metadata", "path")
|
|
113
113
|
|
|
114
114
|
def __init__(
|
|
115
|
-
self,
|
|
116
|
-
content: str,
|
|
117
|
-
path: str,
|
|
118
|
-
metadata: "Optional[dict[str, Any]]" = None,
|
|
119
|
-
loaded_at: "Optional[datetime]" = None,
|
|
115
|
+
self, content: str, path: str, metadata: "dict[str, Any] | None" = None, loaded_at: "datetime | None" = None
|
|
120
116
|
) -> None:
|
|
121
117
|
"""Initialize SQLFile.
|
|
122
118
|
|
|
@@ -163,7 +159,7 @@ class SQLFileLoader:
|
|
|
163
159
|
|
|
164
160
|
__slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry")
|
|
165
161
|
|
|
166
|
-
def __init__(self, *, encoding: str = "utf-8", storage_registry: "
|
|
162
|
+
def __init__(self, *, encoding: str = "utf-8", storage_registry: "StorageRegistry | None" = None) -> None:
|
|
167
163
|
"""Initialize the SQL file loader.
|
|
168
164
|
|
|
169
165
|
Args:
|
|
@@ -188,7 +184,7 @@ class SQLFileLoader:
|
|
|
188
184
|
"""
|
|
189
185
|
raise SQLFileNotFoundError(path)
|
|
190
186
|
|
|
191
|
-
def _generate_file_cache_key(self, path:
|
|
187
|
+
def _generate_file_cache_key(self, path: str | Path) -> str:
|
|
192
188
|
"""Generate cache key for a file path.
|
|
193
189
|
|
|
194
190
|
Args:
|
|
@@ -201,7 +197,7 @@ class SQLFileLoader:
|
|
|
201
197
|
path_hash = hashlib.md5(path_str.encode(), usedforsecurity=False).hexdigest()
|
|
202
198
|
return f"file:{path_hash[:16]}"
|
|
203
199
|
|
|
204
|
-
def _calculate_file_checksum(self, path:
|
|
200
|
+
def _calculate_file_checksum(self, path: str | Path) -> str:
|
|
205
201
|
"""Calculate checksum for file content validation.
|
|
206
202
|
|
|
207
203
|
Args:
|
|
@@ -218,7 +214,7 @@ class SQLFileLoader:
|
|
|
218
214
|
except Exception as e:
|
|
219
215
|
raise SQLFileParseError(str(path), str(path), e) from e
|
|
220
216
|
|
|
221
|
-
def _is_file_unchanged(self, path:
|
|
217
|
+
def _is_file_unchanged(self, path: str | Path, cached_file: CachedSQLFile) -> bool:
|
|
222
218
|
"""Check if file has changed since caching.
|
|
223
219
|
|
|
224
220
|
Args:
|
|
@@ -235,7 +231,7 @@ class SQLFileLoader:
|
|
|
235
231
|
else:
|
|
236
232
|
return current_checksum == cached_file.sql_file.checksum
|
|
237
233
|
|
|
238
|
-
def _read_file_content(self, path:
|
|
234
|
+
def _read_file_content(self, path: str | Path) -> str:
|
|
239
235
|
"""Read file content using storage backend.
|
|
240
236
|
|
|
241
237
|
Args:
|
|
@@ -349,7 +345,7 @@ class SQLFileLoader:
|
|
|
349
345
|
|
|
350
346
|
return statements
|
|
351
347
|
|
|
352
|
-
def load_sql(self, *paths:
|
|
348
|
+
def load_sql(self, *paths: str | Path) -> None:
|
|
353
349
|
"""Load SQL files and parse named queries.
|
|
354
350
|
|
|
355
351
|
Args:
|
|
@@ -358,43 +354,20 @@ class SQLFileLoader:
|
|
|
358
354
|
correlation_id = CorrelationContext.get()
|
|
359
355
|
start_time = time.perf_counter()
|
|
360
356
|
|
|
361
|
-
logger.info("Loading SQL files", extra={"file_count": len(paths), "correlation_id": correlation_id})
|
|
362
|
-
|
|
363
|
-
loaded_count = 0
|
|
364
|
-
query_count_before = len(self._queries)
|
|
365
|
-
|
|
366
357
|
try:
|
|
367
358
|
for path in paths:
|
|
368
359
|
path_str = str(path)
|
|
369
360
|
if "://" in path_str:
|
|
370
361
|
self._load_single_file(path, None)
|
|
371
|
-
loaded_count += 1
|
|
372
362
|
else:
|
|
373
363
|
path_obj = Path(path)
|
|
374
364
|
if path_obj.is_dir():
|
|
375
|
-
|
|
365
|
+
self._load_directory(path_obj)
|
|
376
366
|
elif path_obj.exists():
|
|
377
367
|
self._load_single_file(path_obj, None)
|
|
378
|
-
loaded_count += 1
|
|
379
368
|
elif path_obj.suffix:
|
|
380
369
|
self._raise_file_not_found(str(path))
|
|
381
370
|
|
|
382
|
-
duration = time.perf_counter() - start_time
|
|
383
|
-
new_queries = len(self._queries) - query_count_before
|
|
384
|
-
|
|
385
|
-
logger.info(
|
|
386
|
-
"Loaded %d SQL files with %d new queries in %.3fms",
|
|
387
|
-
loaded_count,
|
|
388
|
-
new_queries,
|
|
389
|
-
duration * 1000,
|
|
390
|
-
extra={
|
|
391
|
-
"files_loaded": loaded_count,
|
|
392
|
-
"new_queries": new_queries,
|
|
393
|
-
"duration_ms": duration * 1000,
|
|
394
|
-
"correlation_id": correlation_id,
|
|
395
|
-
},
|
|
396
|
-
)
|
|
397
|
-
|
|
398
371
|
except Exception as e:
|
|
399
372
|
duration = time.perf_counter() - start_time
|
|
400
373
|
logger.exception(
|
|
@@ -408,34 +381,40 @@ class SQLFileLoader:
|
|
|
408
381
|
)
|
|
409
382
|
raise
|
|
410
383
|
|
|
411
|
-
def _load_directory(self, dir_path: Path) ->
|
|
412
|
-
"""Load all SQL files from a directory.
|
|
384
|
+
def _load_directory(self, dir_path: Path) -> None:
|
|
385
|
+
"""Load all SQL files from a directory.
|
|
386
|
+
|
|
387
|
+
Args:
|
|
388
|
+
dir_path: Directory path to load SQL files from.
|
|
389
|
+
"""
|
|
413
390
|
sql_files = list(dir_path.rglob("*.sql"))
|
|
414
391
|
if not sql_files:
|
|
415
|
-
return
|
|
392
|
+
return
|
|
416
393
|
|
|
417
394
|
for file_path in sql_files:
|
|
418
395
|
relative_path = file_path.relative_to(dir_path)
|
|
419
396
|
namespace_parts = relative_path.parent.parts
|
|
420
397
|
self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None)
|
|
421
|
-
return len(sql_files)
|
|
422
398
|
|
|
423
|
-
def _load_single_file(self, file_path:
|
|
399
|
+
def _load_single_file(self, file_path: str | Path, namespace: str | None) -> bool:
|
|
424
400
|
"""Load a single SQL file with optional namespace.
|
|
425
401
|
|
|
426
402
|
Args:
|
|
427
403
|
file_path: Path to the SQL file.
|
|
428
404
|
namespace: Optional namespace prefix for queries.
|
|
405
|
+
|
|
406
|
+
Returns:
|
|
407
|
+
True if file was newly loaded, False if already cached.
|
|
429
408
|
"""
|
|
430
409
|
path_str = str(file_path)
|
|
431
410
|
|
|
432
411
|
if path_str in self._files:
|
|
433
|
-
return
|
|
412
|
+
return False
|
|
434
413
|
|
|
435
414
|
cache_config = get_cache_config()
|
|
436
415
|
if not cache_config.compiled_cache_enabled:
|
|
437
416
|
self._load_file_without_cache(file_path, namespace)
|
|
438
|
-
return
|
|
417
|
+
return True
|
|
439
418
|
|
|
440
419
|
cache_key_str = self._generate_file_cache_key(file_path)
|
|
441
420
|
cache = get_cache()
|
|
@@ -459,7 +438,7 @@ class SQLFileLoader:
|
|
|
459
438
|
)
|
|
460
439
|
self._queries[namespaced_name] = statement
|
|
461
440
|
self._query_to_file[namespaced_name] = path_str
|
|
462
|
-
return
|
|
441
|
+
return True
|
|
463
442
|
|
|
464
443
|
self._load_file_without_cache(file_path, namespace)
|
|
465
444
|
|
|
@@ -476,7 +455,9 @@ class SQLFileLoader:
|
|
|
476
455
|
cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements)
|
|
477
456
|
cache.put("file", cache_key_str, cached_file_data)
|
|
478
457
|
|
|
479
|
-
|
|
458
|
+
return True
|
|
459
|
+
|
|
460
|
+
def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) -> None:
|
|
480
461
|
"""Load a single SQL file without using cache.
|
|
481
462
|
|
|
482
463
|
Args:
|
|
@@ -503,7 +484,7 @@ class SQLFileLoader:
|
|
|
503
484
|
self._queries[namespaced_name] = statement
|
|
504
485
|
self._query_to_file[namespaced_name] = path_str
|
|
505
486
|
|
|
506
|
-
def add_named_sql(self, name: str, sql: str, dialect: "
|
|
487
|
+
def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None:
|
|
507
488
|
"""Add a named SQL query directly without loading from a file.
|
|
508
489
|
|
|
509
490
|
Args:
|
|
@@ -529,7 +510,7 @@ class SQLFileLoader:
|
|
|
529
510
|
self._queries[normalized_name] = statement
|
|
530
511
|
self._query_to_file[normalized_name] = "<directly added>"
|
|
531
512
|
|
|
532
|
-
def get_file(self, path:
|
|
513
|
+
def get_file(self, path: str | Path) -> "SQLFile | None":
|
|
533
514
|
"""Get a loaded SQLFile object by path.
|
|
534
515
|
|
|
535
516
|
Args:
|
|
@@ -540,7 +521,7 @@ class SQLFileLoader:
|
|
|
540
521
|
"""
|
|
541
522
|
return self._files.get(str(path))
|
|
542
523
|
|
|
543
|
-
def get_file_for_query(self, name: str) -> "
|
|
524
|
+
def get_file_for_query(self, name: str) -> "SQLFile | None":
|
|
544
525
|
"""Get the SQLFile object containing a query.
|
|
545
526
|
|
|
546
527
|
Args:
|