sqlspec-0.27.0-py3-none-any.whl → sqlspec-0.28.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec has been flagged as potentially problematic; follow the accompanying link for more details.
- sqlspec/_typing.py +93 -0
- sqlspec/adapters/adbc/adk/store.py +21 -11
- sqlspec/adapters/adbc/data_dictionary.py +27 -5
- sqlspec/adapters/adbc/driver.py +83 -14
- sqlspec/adapters/aiosqlite/adk/store.py +27 -18
- sqlspec/adapters/asyncmy/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/data_dictionary.py +24 -17
- sqlspec/adapters/bigquery/adk/store.py +30 -21
- sqlspec/adapters/bigquery/config.py +11 -0
- sqlspec/adapters/bigquery/driver.py +138 -1
- sqlspec/adapters/duckdb/adk/store.py +21 -11
- sqlspec/adapters/duckdb/driver.py +87 -1
- sqlspec/adapters/oracledb/adk/store.py +89 -206
- sqlspec/adapters/oracledb/driver.py +183 -2
- sqlspec/adapters/oracledb/litestar/store.py +22 -24
- sqlspec/adapters/psqlpy/adk/store.py +28 -27
- sqlspec/adapters/psqlpy/data_dictionary.py +24 -17
- sqlspec/adapters/psqlpy/driver.py +7 -10
- sqlspec/adapters/psycopg/adk/store.py +51 -33
- sqlspec/adapters/psycopg/data_dictionary.py +48 -34
- sqlspec/adapters/sqlite/adk/store.py +29 -19
- sqlspec/config.py +100 -2
- sqlspec/core/filters.py +18 -10
- sqlspec/core/result.py +133 -2
- sqlspec/driver/_async.py +89 -0
- sqlspec/driver/_common.py +64 -29
- sqlspec/driver/_sync.py +95 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +2 -2
- sqlspec/extensions/adk/service.py +3 -3
- sqlspec/extensions/adk/store.py +8 -8
- sqlspec/extensions/aiosql/adapter.py +3 -15
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/config.py +3 -6
- sqlspec/extensions/litestar/plugin.py +26 -2
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/protocols.py +40 -0
- sqlspec/storage/_utils.py +1 -14
- sqlspec/storage/backends/fsspec.py +3 -5
- sqlspec/storage/backends/local.py +1 -1
- sqlspec/storage/backends/obstore.py +10 -18
- sqlspec/typing.py +16 -0
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/module_loader.py +203 -3
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/serializers.py +110 -1
- sqlspec/utils/sync_tools.py +15 -5
- sqlspec/utils/type_guards.py +25 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +2 -2
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/RECORD +64 -50
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
from sqlspec.base import SQLSpec
|
|
5
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
6
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
7
|
+
from sqlspec.extensions.starlette._utils import get_or_create_session
|
|
8
|
+
from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
|
|
9
|
+
from sqlspec.utils.logging import get_logger
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from collections.abc import AsyncGenerator
|
|
13
|
+
|
|
14
|
+
from starlette.applications import Starlette
|
|
15
|
+
from starlette.requests import Request
|
|
16
|
+
|
|
17
|
+
__all__ = ("SQLSpecPlugin",)

# Fallback values used when a config's extension_config["starlette"] omits a setting.
DEFAULT_SESSION_KEY = "db_session"
DEFAULT_CONNECTION_KEY = "db_connection"
DEFAULT_POOL_KEY = "db_pool"
DEFAULT_COMMIT_MODE = "manual"

logger = get_logger("extensions.starlette")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SQLSpecPlugin:
    """SQLSpec integration for Starlette applications.

    Provides middleware-based session management, automatic transaction handling,
    and connection pooling lifecycle management.

    Example:
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse
        from starlette.routing import Route

        from sqlspec import SQLSpec
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.extensions.starlette import SQLSpecPlugin

        sqlspec = SQLSpec()
        sqlspec.add_config(AsyncpgConfig(
            bind_key="default",
            pool_config={"dsn": "postgresql://localhost/mydb"},
            extension_config={
                "starlette": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        ))

        async def list_users(request):
            db = db_ext.get_session(request)
            result = await db.execute("SELECT * FROM users")
            return JSONResponse({"users": result.all()})

        app = Starlette(routes=[Route("/users", list_users)])
        db_ext = SQLSpecPlugin(sqlspec, app)
    """

    __slots__ = ("_config_states", "_sqlspec")

    # Commit modes handled by ``_add_middleware``; any other value is a configuration error.
    _VALID_COMMIT_MODES = ("manual", "autocommit", "autocommit_include_redirect")

    def __init__(self, sqlspec: SQLSpec, app: "Starlette | None" = None) -> None:
        """Initialize SQLSpec Starlette extension.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional Starlette application to initialize immediately.
        """
        self._sqlspec = sqlspec
        self._config_states: list[SQLSpecConfigState] = []

        for cfg in self._sqlspec.configs.values():
            settings = self._extract_starlette_settings(cfg)
            state = self._create_config_state(cfg, settings)
            self._config_states.append(state)

        if app is not None:
            self.init_app(app)

    def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
        """Extract Starlette settings from ``config.extension_config``.

        Settings that are not provided fall back to the module-level ``DEFAULT_*`` values.

        Args:
            config: Database configuration instance.

        Returns:
            Dictionary of Starlette-specific settings.
        """
        # Tolerate a missing/None extension_config or "starlette" section.
        extension_config = config.extension_config or {}
        starlette_config = extension_config.get("starlette") or {}

        connection_key = starlette_config.get("connection_key", DEFAULT_CONNECTION_KEY)
        pool_key = starlette_config.get("pool_key", DEFAULT_POOL_KEY)
        session_key = starlette_config.get("session_key", DEFAULT_SESSION_KEY)
        commit_mode = starlette_config.get("commit_mode", DEFAULT_COMMIT_MODE)

        # Non-pooling configs never store a pool, but their pool_key still takes part in
        # the duplicate-key check; give each one a private per-config key so several
        # non-pooling configs can coexist with default settings.
        if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY:
            pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}"

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
        }

    def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> SQLSpecConfigState:
        """Create configuration state object.

        Args:
            config: Database configuration instance.
            settings: Extracted Starlette settings.

        Returns:
            Configuration state instance.
        """
        return SQLSpecConfigState(
            config=config,
            connection_key=settings["connection_key"],
            pool_key=settings["pool_key"],
            session_key=settings["session_key"],
            commit_mode=settings["commit_mode"],
            extra_commit_statuses=settings["extra_commit_statuses"],
            extra_rollback_statuses=settings["extra_rollback_statuses"],
        )

    def init_app(self, app: "Starlette") -> None:
        """Initialize Starlette application with SQLSpec.

        Validates configuration, wraps lifespan, and adds middleware.

        Args:
            app: Starlette application instance.

        Raises:
            ImproperConfigurationError: If state keys collide across configs or a
                commit_mode is not recognized.
        """
        # Validate everything before mutating ``app`` so a bad config leaves it untouched.
        self._validate_unique_keys()
        self._validate_commit_modes()

        original_lifespan = app.router.lifespan_context

        @asynccontextmanager
        async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]":
            # Plugin pools are entered before and exited after the app's own lifespan.
            async with self.lifespan(app), original_lifespan(app):
                yield

        app.router.lifespan_context = combined_lifespan

        for config_state in self._config_states:
            self._add_middleware(app, config_state)

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.pool_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}"
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _validate_commit_modes(self) -> None:
        """Validate that every config uses a supported commit mode.

        Raises:
            ImproperConfigurationError: If a commit_mode is not recognized.
        """
        for state in self._config_states:
            if state.commit_mode not in self._VALID_COMMIT_MODES:
                expected = ", ".join(self._VALID_COMMIT_MODES)
                msg = f"Invalid commit_mode {state.commit_mode!r}; expected one of: {expected}"
                raise ImproperConfigurationError(msg)

    def _add_middleware(self, app: "Starlette", config_state: SQLSpecConfigState) -> None:
        """Add transaction middleware for configuration.

        Args:
            app: Starlette application instance.
            config_state: Configuration state.

        Raises:
            ImproperConfigurationError: If ``config_state.commit_mode`` is not recognized
                (previously this silently added no middleware, leaving requests without
                a connection).
        """
        commit_mode = config_state.commit_mode
        if commit_mode == "manual":
            app.add_middleware(SQLSpecManualMiddleware, config_state=config_state)
        elif commit_mode == "autocommit":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=False)
        elif commit_mode == "autocommit_include_redirect":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=True)
        else:
            expected = ", ".join(self._VALID_COMMIT_MODES)
            msg = f"Invalid commit_mode {commit_mode!r}; expected one of: {expected}"
            raise ImproperConfigurationError(msg)

    @asynccontextmanager
    async def lifespan(self, app: "Starlette") -> "AsyncGenerator[None, None]":
        """Manage connection pool lifecycle.

        Creates a pool for every config that supports pooling and stores it on
        ``app.state`` under the config's ``pool_key``. On exit, every pool that was
        successfully created is closed, in reverse creation order. A pool is closed
        even if creating a later pool fails or closing another pool raises.

        Args:
            app: Starlette application instance.

        Yields:
            None
        """
        from contextlib import AsyncExitStack

        async with AsyncExitStack() as stack:
            for config_state in self._config_states:
                config = config_state.config
                if not config.supports_connection_pooling:
                    continue
                pool = await config.create_pool()
                # Register the close right away so a failure while creating a later
                # pool still releases this one.
                stack.push_async_callback(self._close_pool, config)
                setattr(app.state, config_state.pool_key, pool)
            yield

    @staticmethod
    async def _close_pool(config: Any) -> None:
        """Close a config's pool, awaiting the result when ``close_pool()`` returns one.

        Args:
            config: Database configuration whose pool should be closed.
        """
        close_result = config.close_pool()
        if close_result is not None:
            await close_result

    def get_session(self, request: "Request", key: "str | None" = None) -> Any:
        """Get or create database session for request.

        Sessions are cached per request to ensure consistency.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database session.

        Returns:
            Database session (driver instance).

        Raises:
            ImproperConfigurationError: If no configurations are registered.
            ValueError: If ``key`` does not match any configuration's session key.
        """
        config_state = self._resolve_config_state(key)
        return get_or_create_session(request, config_state)

    def get_connection(self, request: "Request", key: "str | None" = None) -> Any:
        """Get database connection from request state.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database connection.

        Returns:
            Database connection object.

        Raises:
            ImproperConfigurationError: If no configurations are registered.
            ValueError: If ``key`` does not match any configuration's session key.
        """
        config_state = self._resolve_config_state(key)
        # NOTE(review): the connection is only present while this plugin's middleware is
        # handling the request; outside that window the attribute lookup fails.
        return getattr(request.state, config_state.connection_key)

    def _resolve_config_state(self, key: "str | None") -> SQLSpecConfigState:
        """Return the config state for ``key``, or the first registered one when ``key`` is None.

        Args:
            key: Optional session key.

        Returns:
            Matching configuration state.

        Raises:
            ImproperConfigurationError: If ``key`` is None and no configurations are registered.
            ValueError: If ``key`` does not match any configuration's session key.
        """
        if key is not None:
            return self._get_config_state_by_key(key)
        if not self._config_states:
            msg = "No SQLSpec configurations are registered with this Starlette plugin"
            raise ImproperConfigurationError(msg)
        return self._config_states[0]

    def _get_config_state_by_key(self, key: str) -> SQLSpecConfigState:
        """Get configuration state by session key.

        Args:
            key: Session key to search for.

        Returns:
            Configuration state matching the key.

        Raises:
            ValueError: If no configuration found with the specified key.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found with session_key: {key}"
        raise ValueError(msg)
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
|
|
5
|
+
from sqlspec.utils.logging import get_logger
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from starlette.requests import Request
|
|
9
|
+
|
|
10
|
+
from sqlspec.extensions.starlette._state import SQLSpecConfigState
|
|
11
|
+
|
|
12
|
+
__all__ = ("SQLSpecAutocommitMiddleware", "SQLSpecManualMiddleware")

# Status-code boundaries consulted by SQLSpecAutocommitMiddleware._should_commit:
# [200, 300) commits, [300, 400) commits only when redirects are included.
HTTP_200_OK = 200
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_400_BAD_REQUEST = 400

logger = get_logger("extensions.starlette.middleware")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SQLSpecManualMiddleware(BaseHTTPMiddleware):
    """Middleware for manual transaction mode.

    Acquires a connection (from the pool when the config supports pooling, otherwise a
    newly created connection), stores it on ``request.state`` under the configured
    connection key, and releases it after the request.
    No automatic commit or rollback - user code must handle transactions.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState") -> None:
        """Initialize middleware.

        Args:
            app: ASGI application wrapped by this middleware.
            config_state: Configuration state for this database.
        """
        super().__init__(app)
        self.config_state = config_state

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with manual transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._call_with_connection(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._call_with_connection(request, call_next, connection)
        finally:
            await connection.close()

    async def _call_with_connection(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Expose ``connection`` on ``request.state`` while the downstream app runs.

        The attribute is removed afterwards on both the pooled and the non-pooled path,
        so a released or closed connection is never left reachable from the request.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Connection to expose for this request.

        Returns:
            HTTP response.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            return await call_next(request)
        finally:
            delattr(request.state, connection_key)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SQLSpecAutocommitMiddleware(BaseHTTPMiddleware):
    """Middleware for autocommit transaction mode.

    Acquires connection, commits on success status codes, rollbacks on error status codes.
    Any exception raised downstream (or by the commit itself) triggers a rollback and is
    re-raised.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState", include_redirect: bool = False) -> None:
        """Initialize middleware.

        Args:
            app: ASGI application wrapped by this middleware.
            config_state: Configuration state for this database.
            include_redirect: If True, commit on 3xx status codes as well.
        """
        super().__init__(app)
        self.config_state = config_state
        self.include_redirect = include_redirect

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with autocommit transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._run_in_transaction(request, call_next, connection)

        connection = await config.create_connection()
        try:
            return await self._run_in_transaction(request, call_next, connection)
        finally:
            await connection.close()

    async def _run_in_transaction(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Run the downstream app with ``connection`` exposed, then commit or roll back.

        The connection is removed from ``request.state`` afterwards on both the pooled
        and the non-pooled path.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Connection to expose and finalize for this request.

        Returns:
            HTTP response.
        """
        connection_key = self.config_state.connection_key
        setattr(request.state, connection_key, connection)
        try:
            response = await call_next(request)
            # NOTE(review): assumes the raw driver connection exposes awaitable
            # commit()/rollback(); some async drivers (e.g. asyncpg) manage transactions
            # differently - confirm per adapter.
            if self._should_commit(response.status_code):
                await connection.commit()
            else:
                await connection.rollback()
        except Exception:
            await connection.rollback()
            raise
        finally:
            delattr(request.state, connection_key)
        return response

    def _should_commit(self, status_code: int) -> bool:
        """Determine if response status code should trigger commit.

        Explicit ``extra_commit_statuses`` take precedence over ``extra_rollback_statuses``,
        which take precedence over the default rule: 2xx commits, 3xx commits only when
        ``include_redirect`` is set, everything else rolls back.

        Args:
            status_code: HTTP status code.

        Returns:
            True if should commit, False if should rollback.
        """
        extra_commit = self.config_state.extra_commit_statuses or set()
        extra_rollback = self.config_state.extra_rollback_statuses or set()

        if status_code in extra_commit:
            return True
        if status_code in extra_rollback:
            return False

        if HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES:
            return True
        return bool(self.include_redirect and HTTP_300_MULTIPLE_CHOICES <= status_code < HTTP_400_BAD_REQUEST)
|
sqlspec/protocols.py
CHANGED
|
@@ -39,6 +39,7 @@ __all__ = (
|
|
|
39
39
|
"ParameterValueProtocol",
|
|
40
40
|
"SQLBuilderProtocol",
|
|
41
41
|
"SelectBuilderProtocol",
|
|
42
|
+
"SupportsArrowResults",
|
|
42
43
|
"WithMethodProtocol",
|
|
43
44
|
)
|
|
44
45
|
|
|
@@ -440,3 +441,42 @@ class SelectBuilderProtocol(SQLBuilderProtocol, Protocol):
|
|
|
440
441
|
def select(self, *columns: "str | exp.Expression") -> Self:
|
|
441
442
|
"""Add SELECT columns to the query."""
|
|
442
443
|
...
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@runtime_checkable
|
|
447
|
+
class SupportsArrowResults(Protocol):
|
|
448
|
+
"""Protocol for adapters that support Arrow result format.
|
|
449
|
+
|
|
450
|
+
Adapters implementing this protocol can return query results in Apache Arrow
|
|
451
|
+
format via the select_to_arrow() method, enabling zero-copy data transfer and
|
|
452
|
+
efficient integration with data science tools.
|
|
453
|
+
"""
|
|
454
|
+
|
|
455
|
+
def select_to_arrow(
|
|
456
|
+
self,
|
|
457
|
+
statement: Any,
|
|
458
|
+
/,
|
|
459
|
+
*parameters: Any,
|
|
460
|
+
statement_config: Any | None = None,
|
|
461
|
+
return_format: str = "table",
|
|
462
|
+
native_only: bool = False,
|
|
463
|
+
batch_size: int | None = None,
|
|
464
|
+
arrow_schema: Any | None = None,
|
|
465
|
+
**kwargs: Any,
|
|
466
|
+
) -> "ArrowTable | ArrowRecordBatch":
|
|
467
|
+
"""Execute query and return results as Apache Arrow Table or RecordBatch.
|
|
468
|
+
|
|
469
|
+
Args:
|
|
470
|
+
statement: SQL statement to execute.
|
|
471
|
+
*parameters: Query parameters and filters.
|
|
472
|
+
statement_config: Optional statement configuration override.
|
|
473
|
+
return_format: Output format - "table", "reader", or "batches".
|
|
474
|
+
native_only: If True, raise error when native Arrow path unavailable.
|
|
475
|
+
batch_size: Chunk size for streaming modes.
|
|
476
|
+
arrow_schema: Optional target Arrow schema for type casting.
|
|
477
|
+
**kwargs: Additional keyword arguments.
|
|
478
|
+
|
|
479
|
+
Returns:
|
|
480
|
+
ArrowResult containing Arrow data.
|
|
481
|
+
"""
|
|
482
|
+
...
|
sqlspec/storage/_utils.py
CHANGED
|
@@ -2,23 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
|
|
5
|
-
from sqlspec.exceptions import MissingDependencyError
|
|
6
|
-
from sqlspec.typing import PYARROW_INSTALLED
|
|
7
|
-
|
|
8
5
|
if TYPE_CHECKING:
|
|
9
6
|
from pathlib import Path
|
|
10
7
|
|
|
11
|
-
__all__ = ("
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def ensure_pyarrow() -> None:
|
|
15
|
-
"""Ensure PyArrow is available for Arrow operations.
|
|
16
|
-
|
|
17
|
-
Raises:
|
|
18
|
-
MissingDependencyError: If pyarrow is not installed.
|
|
19
|
-
"""
|
|
20
|
-
if not PYARROW_INSTALLED:
|
|
21
|
-
raise MissingDependencyError(package="pyarrow", install_package="pyarrow")
|
|
8
|
+
__all__ = ("resolve_storage_path",)
|
|
22
9
|
|
|
23
10
|
|
|
24
11
|
def resolve_storage_path(
|
|
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Any
|
|
|
5
5
|
|
|
6
6
|
from mypy_extensions import mypyc_attr
|
|
7
7
|
|
|
8
|
-
from sqlspec.
|
|
9
|
-
from sqlspec.
|
|
10
|
-
from sqlspec.typing import FSSPEC_INSTALLED
|
|
8
|
+
from sqlspec.storage._utils import resolve_storage_path
|
|
9
|
+
from sqlspec.utils.module_loader import ensure_fsspec, ensure_pyarrow
|
|
11
10
|
from sqlspec.utils.sync_tools import async_
|
|
12
11
|
|
|
13
12
|
if TYPE_CHECKING:
|
|
@@ -105,8 +104,7 @@ class FSSpecBackend:
|
|
|
105
104
|
__slots__ = ("_fs_uri", "backend_type", "base_path", "fs", "protocol")
|
|
106
105
|
|
|
107
106
|
def __init__(self, uri: str, **kwargs: Any) -> None:
|
|
108
|
-
|
|
109
|
-
raise MissingDependencyError(package="fsspec", install_package="fsspec")
|
|
107
|
+
ensure_fsspec()
|
|
110
108
|
|
|
111
109
|
base_path = kwargs.pop("base_path", "")
|
|
112
110
|
|
|
@@ -12,7 +12,7 @@ from urllib.parse import unquote, urlparse
|
|
|
12
12
|
|
|
13
13
|
from mypy_extensions import mypyc_attr
|
|
14
14
|
|
|
15
|
-
from sqlspec.
|
|
15
|
+
from sqlspec.utils.module_loader import ensure_pyarrow
|
|
16
16
|
from sqlspec.utils.sync_tools import async_
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
@@ -8,19 +8,17 @@ import fnmatch
|
|
|
8
8
|
import io
|
|
9
9
|
import logging
|
|
10
10
|
from collections.abc import AsyncIterator, Iterator
|
|
11
|
-
from
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Final, cast
|
|
12
13
|
from urllib.parse import urlparse
|
|
13
14
|
|
|
14
|
-
from sqlspec.utils.sync_tools import async_
|
|
15
|
-
|
|
16
|
-
if TYPE_CHECKING:
|
|
17
|
-
from pathlib import Path
|
|
18
|
-
|
|
19
15
|
from mypy_extensions import mypyc_attr
|
|
20
16
|
|
|
21
|
-
from sqlspec.exceptions import
|
|
22
|
-
from sqlspec.storage._utils import
|
|
23
|
-
from sqlspec.typing import
|
|
17
|
+
from sqlspec.exceptions import StorageOperationFailedError
|
|
18
|
+
from sqlspec.storage._utils import resolve_storage_path
|
|
19
|
+
from sqlspec.typing import ArrowRecordBatch, ArrowTable
|
|
20
|
+
from sqlspec.utils.module_loader import ensure_obstore, ensure_pyarrow
|
|
21
|
+
from sqlspec.utils.sync_tools import async_
|
|
24
22
|
|
|
25
23
|
__all__ = ("ObStoreBackend",)
|
|
26
24
|
|
|
@@ -119,11 +117,8 @@ class ObStoreBackend:
|
|
|
119
117
|
uri: Storage URI (e.g., 's3://bucket', 'file:///path', 'gs://bucket')
|
|
120
118
|
**kwargs: Additional options including base_path and obstore configuration
|
|
121
119
|
|
|
122
|
-
Raises:
|
|
123
|
-
MissingDependencyError: If obstore is not installed.
|
|
124
120
|
"""
|
|
125
|
-
|
|
126
|
-
raise MissingDependencyError(package="obstore", install_package="obstore")
|
|
121
|
+
ensure_obstore()
|
|
127
122
|
|
|
128
123
|
try:
|
|
129
124
|
# Extract base_path from kwargs
|
|
@@ -144,8 +139,6 @@ class ObStoreBackend:
|
|
|
144
139
|
|
|
145
140
|
self.store = MemoryStore()
|
|
146
141
|
elif uri.startswith("file://"):
|
|
147
|
-
from pathlib import Path as PathlibPath
|
|
148
|
-
|
|
149
142
|
from obstore.store import LocalStore
|
|
150
143
|
|
|
151
144
|
# Parse URI to extract path
|
|
@@ -155,7 +148,7 @@ class ObStoreBackend:
|
|
|
155
148
|
# Append fragment if present (handles paths with '#' character)
|
|
156
149
|
if parsed.fragment:
|
|
157
150
|
path_str = f"{path_str}#{parsed.fragment}"
|
|
158
|
-
path_obj =
|
|
151
|
+
path_obj = Path(path_str)
|
|
159
152
|
|
|
160
153
|
# If path points to a file, use its parent as the base directory
|
|
161
154
|
if path_obj.is_file():
|
|
@@ -194,9 +187,8 @@ class ObStoreBackend:
|
|
|
194
187
|
|
|
195
188
|
def _resolve_path_for_local_store(self, path: "str | Path") -> str:
|
|
196
189
|
"""Resolve path for LocalStore which expects relative paths from its root."""
|
|
197
|
-
from pathlib import Path as PathlibPath
|
|
198
190
|
|
|
199
|
-
path_obj =
|
|
191
|
+
path_obj = Path(str(path))
|
|
200
192
|
|
|
201
193
|
# If absolute path, try to make it relative to LocalStore root
|
|
202
194
|
if path_obj.is_absolute() and self._local_store_root:
|
sqlspec/typing.py
CHANGED
|
@@ -16,7 +16,9 @@ from sqlspec._typing import (
|
|
|
16
16
|
OBSTORE_INSTALLED,
|
|
17
17
|
OPENTELEMETRY_INSTALLED,
|
|
18
18
|
ORJSON_INSTALLED,
|
|
19
|
+
PANDAS_INSTALLED,
|
|
19
20
|
PGVECTOR_INSTALLED,
|
|
21
|
+
POLARS_INSTALLED,
|
|
20
22
|
PROMETHEUS_INSTALLED,
|
|
21
23
|
PYARROW_INSTALLED,
|
|
22
24
|
PYDANTIC_INSTALLED,
|
|
@@ -26,6 +28,10 @@ from sqlspec._typing import (
|
|
|
26
28
|
AiosqlSQLOperationType,
|
|
27
29
|
AiosqlSyncProtocol,
|
|
28
30
|
ArrowRecordBatch,
|
|
31
|
+
ArrowRecordBatchReader,
|
|
32
|
+
ArrowRecordBatchReaderProtocol,
|
|
33
|
+
ArrowSchema,
|
|
34
|
+
ArrowSchemaProtocol,
|
|
29
35
|
ArrowTable,
|
|
30
36
|
AttrsInstance,
|
|
31
37
|
AttrsInstanceStub,
|
|
@@ -41,6 +47,8 @@ from sqlspec._typing import (
|
|
|
41
47
|
Gauge,
|
|
42
48
|
Histogram,
|
|
43
49
|
NumpyArray,
|
|
50
|
+
PandasDataFrame,
|
|
51
|
+
PolarsDataFrame,
|
|
44
52
|
Span,
|
|
45
53
|
Status,
|
|
46
54
|
StatusCode,
|
|
@@ -136,7 +144,9 @@ __all__ = (
|
|
|
136
144
|
"OBSTORE_INSTALLED",
|
|
137
145
|
"OPENTELEMETRY_INSTALLED",
|
|
138
146
|
"ORJSON_INSTALLED",
|
|
147
|
+
"PANDAS_INSTALLED",
|
|
139
148
|
"PGVECTOR_INSTALLED",
|
|
149
|
+
"POLARS_INSTALLED",
|
|
140
150
|
"PROMETHEUS_INSTALLED",
|
|
141
151
|
"PYARROW_INSTALLED",
|
|
142
152
|
"PYDANTIC_INSTALLED",
|
|
@@ -147,6 +157,10 @@ __all__ = (
|
|
|
147
157
|
"AiosqlSQLOperationType",
|
|
148
158
|
"AiosqlSyncProtocol",
|
|
149
159
|
"ArrowRecordBatch",
|
|
160
|
+
"ArrowRecordBatchReader",
|
|
161
|
+
"ArrowRecordBatchReaderProtocol",
|
|
162
|
+
"ArrowSchema",
|
|
163
|
+
"ArrowSchemaProtocol",
|
|
150
164
|
"ArrowTable",
|
|
151
165
|
"AttrsInstance",
|
|
152
166
|
"BaseModel",
|
|
@@ -162,6 +176,8 @@ __all__ = (
|
|
|
162
176
|
"Gauge",
|
|
163
177
|
"Histogram",
|
|
164
178
|
"NumpyArray",
|
|
179
|
+
"PandasDataFrame",
|
|
180
|
+
"PolarsDataFrame",
|
|
165
181
|
"PoolT",
|
|
166
182
|
"SchemaT",
|
|
167
183
|
"Span",
|
sqlspec/utils/__init__.py
CHANGED
|
@@ -1,10 +1,31 @@
|
|
|
1
1
|
"""Utility functions and classes for SQLSpec.
|
|
2
2
|
|
|
3
3
|
This package provides various utility modules for deprecation handling,
|
|
4
|
-
fixture loading, logging, module loading
|
|
5
|
-
|
|
4
|
+
fixture loading, logging, module loading (including dependency checking),
|
|
5
|
+
portal pattern for async bridging, singleton patterns, sync/async conversion,
|
|
6
|
+
text processing, and type guards.
|
|
6
7
|
"""
|
|
7
8
|
|
|
8
|
-
from sqlspec.utils import
|
|
9
|
+
from sqlspec.utils import (
|
|
10
|
+
deprecation,
|
|
11
|
+
fixtures,
|
|
12
|
+
logging,
|
|
13
|
+
module_loader,
|
|
14
|
+
portal,
|
|
15
|
+
singleton,
|
|
16
|
+
sync_tools,
|
|
17
|
+
text,
|
|
18
|
+
type_guards,
|
|
19
|
+
)
|
|
9
20
|
|
|
10
|
-
__all__ = (
|
|
21
|
+
# Public surface of the package: the utility submodules imported above.
__all__ = (
    "deprecation", "fixtures", "logging",
    "module_loader", "portal", "singleton",
    "sync_tools", "text", "type_guards",
)
|