sqlspec-0.27.0-py3-none-any.whl → sqlspec-0.28.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_typing.py +93 -0
- sqlspec/adapters/adbc/adk/store.py +21 -11
- sqlspec/adapters/adbc/data_dictionary.py +27 -5
- sqlspec/adapters/adbc/driver.py +83 -14
- sqlspec/adapters/aiosqlite/adk/store.py +27 -18
- sqlspec/adapters/asyncmy/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/data_dictionary.py +24 -17
- sqlspec/adapters/bigquery/adk/store.py +30 -21
- sqlspec/adapters/bigquery/config.py +11 -0
- sqlspec/adapters/bigquery/driver.py +138 -1
- sqlspec/adapters/duckdb/adk/store.py +21 -11
- sqlspec/adapters/duckdb/driver.py +87 -1
- sqlspec/adapters/oracledb/adk/store.py +89 -206
- sqlspec/adapters/oracledb/driver.py +183 -2
- sqlspec/adapters/oracledb/litestar/store.py +22 -24
- sqlspec/adapters/psqlpy/adk/store.py +28 -27
- sqlspec/adapters/psqlpy/data_dictionary.py +24 -17
- sqlspec/adapters/psqlpy/driver.py +7 -10
- sqlspec/adapters/psycopg/adk/store.py +51 -33
- sqlspec/adapters/psycopg/data_dictionary.py +48 -34
- sqlspec/adapters/sqlite/adk/store.py +29 -19
- sqlspec/config.py +100 -2
- sqlspec/core/filters.py +18 -10
- sqlspec/core/result.py +133 -2
- sqlspec/driver/_async.py +89 -0
- sqlspec/driver/_common.py +64 -29
- sqlspec/driver/_sync.py +95 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +2 -2
- sqlspec/extensions/adk/service.py +3 -3
- sqlspec/extensions/adk/store.py +8 -8
- sqlspec/extensions/aiosql/adapter.py +3 -15
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/config.py +3 -6
- sqlspec/extensions/litestar/plugin.py +26 -2
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/protocols.py +40 -0
- sqlspec/storage/_utils.py +1 -14
- sqlspec/storage/backends/fsspec.py +3 -5
- sqlspec/storage/backends/local.py +1 -1
- sqlspec/storage/backends/obstore.py +10 -18
- sqlspec/typing.py +16 -0
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/module_loader.py +203 -3
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/serializers.py +110 -1
- sqlspec/utils/sync_tools.py +15 -5
- sqlspec/utils/type_guards.py +25 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +2 -2
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/RECORD +64 -50
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|
|
6
6
|
from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
|
|
7
7
|
from sqlspec.utils.logging import get_logger
|
|
8
8
|
from sqlspec.utils.serializers import from_json, to_json
|
|
9
|
-
from sqlspec.utils.sync_tools import async_
|
|
9
|
+
from sqlspec.utils.sync_tools import async_, run_
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
12
|
from sqlspec.adapters.sqlite.config import SqliteConfig
|
|
@@ -140,7 +140,7 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
|
|
|
140
140
|
"""
|
|
141
141
|
super().__init__(config)
|
|
142
142
|
|
|
143
|
-
def _get_create_sessions_table_sql(self) -> str:
|
|
143
|
+
async def _get_create_sessions_table_sql(self) -> str:
|
|
144
144
|
"""Get SQLite CREATE TABLE SQL for sessions.
|
|
145
145
|
|
|
146
146
|
Returns:
|
|
@@ -172,7 +172,7 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
|
|
|
172
172
|
ON {self._session_table}(update_time DESC);
|
|
173
173
|
"""
|
|
174
174
|
|
|
175
|
-
def _get_create_events_table_sql(self) -> str:
|
|
175
|
+
async def _get_create_events_table_sql(self) -> str:
|
|
176
176
|
"""Get SQLite CREATE TABLE SQL for events.
|
|
177
177
|
|
|
178
178
|
Returns:
|
|
@@ -237,10 +237,10 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
|
|
|
237
237
|
|
|
238
238
|
def _create_tables(self) -> None:
|
|
239
239
|
"""Synchronous implementation of create_tables."""
|
|
240
|
-
with self._config.
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
240
|
+
with self._config.provide_session() as driver:
|
|
241
|
+
driver.connection.execute("PRAGMA foreign_keys = ON")
|
|
242
|
+
driver.execute_script(run_(self._get_create_sessions_table_sql)())
|
|
243
|
+
driver.execute_script(run_(self._get_create_events_table_sql)())
|
|
244
244
|
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
|
|
245
245
|
|
|
246
246
|
async def create_tables(self) -> None:
|
|
@@ -370,18 +370,28 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
|
|
|
370
370
|
"""
|
|
371
371
|
await async_(self._update_session_state)(session_id, state)
|
|
372
372
|
|
|
373
|
-
def _list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
373
|
+
def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]":
|
|
374
374
|
"""Synchronous implementation of list_sessions."""
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
375
|
+
if user_id is None:
|
|
376
|
+
sql = f"""
|
|
377
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
378
|
+
FROM {self._session_table}
|
|
379
|
+
WHERE app_name = ?
|
|
380
|
+
ORDER BY update_time DESC
|
|
381
|
+
"""
|
|
382
|
+
params: tuple[str, ...] = (app_name,)
|
|
383
|
+
else:
|
|
384
|
+
sql = f"""
|
|
385
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
386
|
+
FROM {self._session_table}
|
|
387
|
+
WHERE app_name = ? AND user_id = ?
|
|
388
|
+
ORDER BY update_time DESC
|
|
389
|
+
"""
|
|
390
|
+
params = (app_name, user_id)
|
|
381
391
|
|
|
382
392
|
with self._config.provide_connection() as conn:
|
|
383
393
|
self._enable_foreign_keys(conn)
|
|
384
|
-
cursor = conn.execute(sql,
|
|
394
|
+
cursor = conn.execute(sql, params)
|
|
385
395
|
rows = cursor.fetchall()
|
|
386
396
|
|
|
387
397
|
return [
|
|
@@ -396,18 +406,18 @@ class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]):
|
|
|
396
406
|
for row in rows
|
|
397
407
|
]
|
|
398
408
|
|
|
399
|
-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
400
|
-
"""List
|
|
409
|
+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
|
|
410
|
+
"""List sessions for an app, optionally filtered by user.
|
|
401
411
|
|
|
402
412
|
Args:
|
|
403
413
|
app_name: Application name.
|
|
404
|
-
user_id: User identifier.
|
|
414
|
+
user_id: User identifier. If None, lists all sessions for the app.
|
|
405
415
|
|
|
406
416
|
Returns:
|
|
407
417
|
List of session records ordered by update_time DESC.
|
|
408
418
|
|
|
409
419
|
Notes:
|
|
410
|
-
Uses composite index on (app_name, user_id).
|
|
420
|
+
Uses composite index on (app_name, user_id) when user_id is provided.
|
|
411
421
|
"""
|
|
412
422
|
return await async_(self._list_sessions)(app_name, user_id)
|
|
413
423
|
|
sqlspec/config.py
CHANGED
|
@@ -26,11 +26,13 @@ __all__ = (
|
|
|
26
26
|
"ConfigT",
|
|
27
27
|
"DatabaseConfigProtocol",
|
|
28
28
|
"DriverT",
|
|
29
|
+
"FlaskConfig",
|
|
29
30
|
"LifecycleConfig",
|
|
30
31
|
"LitestarConfig",
|
|
31
32
|
"MigrationConfig",
|
|
32
33
|
"NoPoolAsyncConfig",
|
|
33
34
|
"NoPoolSyncConfig",
|
|
35
|
+
"StarletteConfig",
|
|
34
36
|
"SyncConfigT",
|
|
35
37
|
"SyncDatabaseConfig",
|
|
36
38
|
)
|
|
@@ -98,6 +100,49 @@ class MigrationConfig(TypedDict):
|
|
|
98
100
|
"""Wrap migrations in transactions when supported. When enabled (default for adapters that support it), each migration runs in a transaction that is committed on success or rolled back on failure. This prevents partial migrations from leaving the database in an inconsistent state. Requires adapter support for transactional DDL. Defaults to True for PostgreSQL, SQLite, and DuckDB; False for MySQL, Oracle, and BigQuery. Individual migrations can override this with a '-- transactional: false' comment."""
|
|
99
101
|
|
|
100
102
|
|
|
103
|
+
class FlaskConfig(TypedDict):
|
|
104
|
+
"""Configuration options for Flask SQLSpec extension.
|
|
105
|
+
|
|
106
|
+
All fields are optional with sensible defaults. Use in extension_config["flask"]:
|
|
107
|
+
|
|
108
|
+
Example:
|
|
109
|
+
from sqlspec.adapters.asyncpg import AsyncpgConfig
|
|
110
|
+
|
|
111
|
+
config = AsyncpgConfig(
|
|
112
|
+
pool_config={"dsn": "postgresql://localhost/mydb"},
|
|
113
|
+
extension_config={
|
|
114
|
+
"flask": {
|
|
115
|
+
"commit_mode": "autocommit",
|
|
116
|
+
"session_key": "db"
|
|
117
|
+
}
|
|
118
|
+
}
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
Notes:
|
|
122
|
+
This TypedDict provides type safety for extension config.
|
|
123
|
+
Flask extension uses g object for request-scoped storage.
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
connection_key: NotRequired[str]
|
|
127
|
+
"""Key for storing connection in Flask g object. Default: auto-generated from session_key."""
|
|
128
|
+
|
|
129
|
+
session_key: NotRequired[str]
|
|
130
|
+
"""Key for accessing session via plugin.get_session(). Default: 'db_session'."""
|
|
131
|
+
|
|
132
|
+
commit_mode: NotRequired[Literal["manual", "autocommit", "autocommit_include_redirect"]]
|
|
133
|
+
"""Transaction commit mode. Default: 'manual'.
|
|
134
|
+
- manual: No automatic commits, user handles explicitly
|
|
135
|
+
- autocommit: Commits on 2xx status, rollback otherwise
|
|
136
|
+
- autocommit_include_redirect: Commits on 2xx-3xx status, rollback otherwise
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
extra_commit_statuses: NotRequired[set[int]]
|
|
140
|
+
"""Additional HTTP status codes that trigger commit. Default: None."""
|
|
141
|
+
|
|
142
|
+
extra_rollback_statuses: NotRequired[set[int]]
|
|
143
|
+
"""Additional HTTP status codes that trigger rollback. Default: None."""
|
|
144
|
+
|
|
145
|
+
|
|
101
146
|
class LitestarConfig(TypedDict):
|
|
102
147
|
"""Configuration options for Litestar SQLSpec plugin.
|
|
103
148
|
|
|
@@ -126,6 +171,61 @@ class LitestarConfig(TypedDict):
|
|
|
126
171
|
"""Additional HTTP status codes that trigger rollback. Default: set()"""
|
|
127
172
|
|
|
128
173
|
|
|
174
|
+
class StarletteConfig(TypedDict):
|
|
175
|
+
"""Configuration options for Starlette and FastAPI extensions.
|
|
176
|
+
|
|
177
|
+
All fields are optional with sensible defaults. Use in extension_config["starlette"]:
|
|
178
|
+
|
|
179
|
+
Example:
|
|
180
|
+
from sqlspec.adapters.asyncpg import AsyncpgConfig
|
|
181
|
+
|
|
182
|
+
config = AsyncpgConfig(
|
|
183
|
+
pool_config={"dsn": "postgresql://localhost/mydb"},
|
|
184
|
+
extension_config={
|
|
185
|
+
"starlette": {
|
|
186
|
+
"commit_mode": "autocommit",
|
|
187
|
+
"session_key": "db"
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
Notes:
|
|
193
|
+
Both Starlette and FastAPI extensions use the "starlette" key.
|
|
194
|
+
This TypedDict provides type safety for extension config.
|
|
195
|
+
"""
|
|
196
|
+
|
|
197
|
+
connection_key: NotRequired[str]
|
|
198
|
+
"""Key for storing connection in request.state. Default: 'db_connection'"""
|
|
199
|
+
|
|
200
|
+
pool_key: NotRequired[str]
|
|
201
|
+
"""Key for storing connection pool in app.state. Default: 'db_pool'"""
|
|
202
|
+
|
|
203
|
+
session_key: NotRequired[str]
|
|
204
|
+
"""Key for storing session in request.state. Default: 'db_session'"""
|
|
205
|
+
|
|
206
|
+
commit_mode: NotRequired[Literal["manual", "autocommit", "autocommit_include_redirect"]]
|
|
207
|
+
"""Transaction commit mode. Default: 'manual'
|
|
208
|
+
|
|
209
|
+
- manual: No automatic commit/rollback
|
|
210
|
+
- autocommit: Commit on 2xx, rollback otherwise
|
|
211
|
+
- autocommit_include_redirect: Commit on 2xx-3xx, rollback otherwise
|
|
212
|
+
"""
|
|
213
|
+
|
|
214
|
+
extra_commit_statuses: NotRequired[set[int]]
|
|
215
|
+
"""Additional HTTP status codes that trigger commit. Default: set()
|
|
216
|
+
|
|
217
|
+
Example:
|
|
218
|
+
extra_commit_statuses={201, 202}
|
|
219
|
+
"""
|
|
220
|
+
|
|
221
|
+
extra_rollback_statuses: NotRequired[set[int]]
|
|
222
|
+
"""Additional HTTP status codes that trigger rollback. Default: set()
|
|
223
|
+
|
|
224
|
+
Example:
|
|
225
|
+
extra_rollback_statuses={409}
|
|
226
|
+
"""
|
|
227
|
+
|
|
228
|
+
|
|
129
229
|
class ADKConfig(TypedDict):
|
|
130
230
|
"""Configuration options for ADK session store extension.
|
|
131
231
|
|
|
@@ -356,8 +456,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
|
|
|
356
456
|
migration_config = self.migration_config or {}
|
|
357
457
|
script_location = migration_config.get("script_location", "migrations")
|
|
358
458
|
|
|
359
|
-
from pathlib import Path
|
|
360
|
-
|
|
361
459
|
migration_path = Path(script_location)
|
|
362
460
|
if migration_path.exists() and not self._migration_loader.list_files():
|
|
363
461
|
self._migration_loader.load_sql(migration_path)
|
sqlspec/core/filters.py
CHANGED
|
@@ -576,10 +576,13 @@ class LimitOffsetFilter(PaginationFilter):
|
|
|
576
576
|
limit_placeholder = exp.Placeholder(this=limit_param_name)
|
|
577
577
|
offset_placeholder = exp.Placeholder(this=offset_param_name)
|
|
578
578
|
|
|
579
|
-
|
|
580
|
-
current_statement =
|
|
581
|
-
|
|
582
|
-
|
|
579
|
+
if statement.statement_expression is not None:
|
|
580
|
+
current_statement = statement.statement_expression
|
|
581
|
+
else:
|
|
582
|
+
try:
|
|
583
|
+
current_statement = sqlglot.parse_one(statement.raw_sql, dialect=statement.dialect)
|
|
584
|
+
except Exception:
|
|
585
|
+
current_statement = exp.Select().from_(f"({statement.raw_sql})")
|
|
583
586
|
|
|
584
587
|
if isinstance(current_statement, exp.Select):
|
|
585
588
|
new_statement = current_statement.limit(limit_placeholder).offset(offset_placeholder)
|
|
@@ -587,7 +590,6 @@ class LimitOffsetFilter(PaginationFilter):
|
|
|
587
590
|
new_statement = exp.Select().from_(current_statement).limit(limit_placeholder).offset(offset_placeholder)
|
|
588
591
|
|
|
589
592
|
result = statement.copy(statement=new_statement)
|
|
590
|
-
|
|
591
593
|
result = result.add_named_parameter(limit_param_name, self.limit)
|
|
592
594
|
return result.add_named_parameter(offset_param_name, self.offset)
|
|
593
595
|
|
|
@@ -628,12 +630,18 @@ class OrderByFilter(StatementFilter):
|
|
|
628
630
|
col_expr = exp.column(self.field_name)
|
|
629
631
|
order_expr = col_expr.desc() if converted_sort_order == "desc" else col_expr.asc()
|
|
630
632
|
|
|
631
|
-
if statement.statement_expression is None:
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
633
|
+
if statement.statement_expression is not None:
|
|
634
|
+
current_statement = statement.statement_expression
|
|
635
|
+
else:
|
|
636
|
+
try:
|
|
637
|
+
current_statement = sqlglot.parse_one(statement.raw_sql, dialect=statement.dialect)
|
|
638
|
+
except Exception:
|
|
639
|
+
current_statement = exp.Select().from_(f"({statement.raw_sql})")
|
|
640
|
+
|
|
641
|
+
if isinstance(current_statement, exp.Select):
|
|
642
|
+
new_statement = current_statement.order_by(order_expr)
|
|
635
643
|
else:
|
|
636
|
-
new_statement = exp.Select().from_(
|
|
644
|
+
new_statement = exp.Select().from_(current_statement).order_by(order_expr)
|
|
637
645
|
|
|
638
646
|
return statement.copy(statement=new_statement)
|
|
639
647
|
|
sqlspec/core/result.py
CHANGED
|
@@ -16,13 +16,14 @@ from mypy_extensions import mypyc_attr
|
|
|
16
16
|
from typing_extensions import TypeVar
|
|
17
17
|
|
|
18
18
|
from sqlspec.core.compiler import OperationType
|
|
19
|
+
from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow
|
|
19
20
|
from sqlspec.utils.schema import to_schema
|
|
20
21
|
|
|
21
22
|
if TYPE_CHECKING:
|
|
22
23
|
from collections.abc import Iterator
|
|
23
24
|
|
|
24
25
|
from sqlspec.core.statement import SQL
|
|
25
|
-
from sqlspec.typing import SchemaT
|
|
26
|
+
from sqlspec.typing import ArrowTable, PandasDataFrame, PolarsDataFrame, SchemaT
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
__all__ = ("ArrowResult", "SQLResult", "StatementResult")
|
|
@@ -618,7 +619,7 @@ class ArrowResult(StatementResult):
|
|
|
618
619
|
"""
|
|
619
620
|
return self.data is not None
|
|
620
621
|
|
|
621
|
-
def get_data(self) ->
|
|
622
|
+
def get_data(self) -> "ArrowTable":
|
|
622
623
|
"""Get the Apache Arrow Table from the result.
|
|
623
624
|
|
|
624
625
|
Returns:
|
|
@@ -626,10 +627,19 @@ class ArrowResult(StatementResult):
|
|
|
626
627
|
|
|
627
628
|
Raises:
|
|
628
629
|
ValueError: If no Arrow table is available.
|
|
630
|
+
TypeError: If data is not an Arrow Table.
|
|
629
631
|
"""
|
|
630
632
|
if self.data is None:
|
|
631
633
|
msg = "No Arrow table available for this result"
|
|
632
634
|
raise ValueError(msg)
|
|
635
|
+
|
|
636
|
+
ensure_pyarrow()
|
|
637
|
+
|
|
638
|
+
import pyarrow as pa
|
|
639
|
+
|
|
640
|
+
if not isinstance(self.data, pa.Table):
|
|
641
|
+
msg = f"Expected an Arrow Table, but got {type(self.data).__name__}"
|
|
642
|
+
raise TypeError(msg)
|
|
633
643
|
return self.data
|
|
634
644
|
|
|
635
645
|
@property
|
|
@@ -680,6 +690,127 @@ class ArrowResult(StatementResult):
|
|
|
680
690
|
|
|
681
691
|
return cast("int", self.data.num_columns)
|
|
682
692
|
|
|
693
|
+
def to_pandas(self) -> "PandasDataFrame":
|
|
694
|
+
"""Convert Arrow data to pandas DataFrame.
|
|
695
|
+
|
|
696
|
+
Returns:
|
|
697
|
+
pandas DataFrame containing the result data.
|
|
698
|
+
|
|
699
|
+
Raises:
|
|
700
|
+
ValueError: If no Arrow table is available.
|
|
701
|
+
|
|
702
|
+
Examples:
|
|
703
|
+
>>> result = session.select_to_arrow("SELECT * FROM users")
|
|
704
|
+
>>> df = result.to_pandas()
|
|
705
|
+
>>> print(df.head())
|
|
706
|
+
"""
|
|
707
|
+
if self.data is None:
|
|
708
|
+
msg = "No Arrow table available"
|
|
709
|
+
raise ValueError(msg)
|
|
710
|
+
|
|
711
|
+
ensure_pandas()
|
|
712
|
+
|
|
713
|
+
import pandas as pd
|
|
714
|
+
|
|
715
|
+
result = self.data.to_pandas()
|
|
716
|
+
if not isinstance(result, pd.DataFrame):
|
|
717
|
+
msg = f"Expected a pandas DataFrame, but got {type(result).__name__}"
|
|
718
|
+
raise TypeError(msg)
|
|
719
|
+
return result
|
|
720
|
+
|
|
721
|
+
def to_polars(self) -> "PolarsDataFrame":
|
|
722
|
+
"""Convert Arrow data to Polars DataFrame.
|
|
723
|
+
|
|
724
|
+
Returns:
|
|
725
|
+
Polars DataFrame containing the result data.
|
|
726
|
+
|
|
727
|
+
Raises:
|
|
728
|
+
ValueError: If no Arrow table is available.
|
|
729
|
+
|
|
730
|
+
Examples:
|
|
731
|
+
>>> result = session.select_to_arrow("SELECT * FROM users")
|
|
732
|
+
>>> df = result.to_polars()
|
|
733
|
+
>>> print(df.head())
|
|
734
|
+
"""
|
|
735
|
+
if self.data is None:
|
|
736
|
+
msg = "No Arrow table available"
|
|
737
|
+
raise ValueError(msg)
|
|
738
|
+
|
|
739
|
+
ensure_polars()
|
|
740
|
+
|
|
741
|
+
import polars as pl
|
|
742
|
+
|
|
743
|
+
result = pl.from_arrow(self.data)
|
|
744
|
+
if not isinstance(result, pl.DataFrame):
|
|
745
|
+
msg = f"Expected a Polars DataFrame, but got {type(result).__name__}"
|
|
746
|
+
raise TypeError(msg)
|
|
747
|
+
return result
|
|
748
|
+
|
|
749
|
+
def to_dict(self) -> "list[dict[str, Any]]":
|
|
750
|
+
"""Convert Arrow data to list of dictionaries.
|
|
751
|
+
|
|
752
|
+
Returns:
|
|
753
|
+
List of dictionaries, one per row.
|
|
754
|
+
|
|
755
|
+
Raises:
|
|
756
|
+
ValueError: If no Arrow table is available.
|
|
757
|
+
|
|
758
|
+
Examples:
|
|
759
|
+
>>> result = session.select_to_arrow(
|
|
760
|
+
... "SELECT id, name FROM users"
|
|
761
|
+
... )
|
|
762
|
+
>>> rows = result.to_dict()
|
|
763
|
+
>>> print(rows[0])
|
|
764
|
+
{'id': 1, 'name': 'Alice'}
|
|
765
|
+
"""
|
|
766
|
+
if self.data is None:
|
|
767
|
+
msg = "No Arrow table available"
|
|
768
|
+
raise ValueError(msg)
|
|
769
|
+
|
|
770
|
+
return cast("list[dict[str, Any]]", self.data.to_pylist())
|
|
771
|
+
|
|
772
|
+
def __len__(self) -> int:
|
|
773
|
+
"""Return number of rows in the Arrow table.
|
|
774
|
+
|
|
775
|
+
Returns:
|
|
776
|
+
Number of rows.
|
|
777
|
+
|
|
778
|
+
Raises:
|
|
779
|
+
ValueError: If no Arrow table is available.
|
|
780
|
+
|
|
781
|
+
Examples:
|
|
782
|
+
>>> result = session.select_to_arrow("SELECT * FROM users")
|
|
783
|
+
>>> print(len(result))
|
|
784
|
+
100
|
|
785
|
+
"""
|
|
786
|
+
if self.data is None:
|
|
787
|
+
msg = "No Arrow table available"
|
|
788
|
+
raise ValueError(msg)
|
|
789
|
+
|
|
790
|
+
return cast("int", self.data.num_rows)
|
|
791
|
+
|
|
792
|
+
def __iter__(self) -> "Iterator[dict[str, Any]]":
|
|
793
|
+
"""Iterate over rows as dictionaries.
|
|
794
|
+
|
|
795
|
+
Yields:
|
|
796
|
+
Dictionary for each row.
|
|
797
|
+
|
|
798
|
+
Raises:
|
|
799
|
+
ValueError: If no Arrow table is available.
|
|
800
|
+
|
|
801
|
+
Examples:
|
|
802
|
+
>>> result = session.select_to_arrow(
|
|
803
|
+
... "SELECT id, name FROM users"
|
|
804
|
+
... )
|
|
805
|
+
>>> for row in result:
|
|
806
|
+
... print(row["name"])
|
|
807
|
+
"""
|
|
808
|
+
if self.data is None:
|
|
809
|
+
msg = "No Arrow table available"
|
|
810
|
+
raise ValueError(msg)
|
|
811
|
+
|
|
812
|
+
yield from self.data.to_pylist()
|
|
813
|
+
|
|
683
814
|
|
|
684
815
|
def create_sql_result(
|
|
685
816
|
statement: "SQL",
|
sqlspec/driver/_async.py
CHANGED
|
@@ -4,6 +4,7 @@ from abc import abstractmethod
|
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Final, TypeVar, overload
|
|
5
5
|
|
|
6
6
|
from sqlspec.core import SQL, Statement
|
|
7
|
+
from sqlspec.core.result import create_arrow_result
|
|
7
8
|
from sqlspec.driver._common import (
|
|
8
9
|
CommonDriverAttributesMixin,
|
|
9
10
|
DataDictionaryMixin,
|
|
@@ -12,7 +13,10 @@ from sqlspec.driver._common import (
|
|
|
12
13
|
handle_single_row_error,
|
|
13
14
|
)
|
|
14
15
|
from sqlspec.driver.mixins import SQLTranslatorMixin
|
|
16
|
+
from sqlspec.exceptions import ImproperConfigurationError
|
|
17
|
+
from sqlspec.utils.arrow_helpers import convert_dict_to_arrow
|
|
15
18
|
from sqlspec.utils.logging import get_logger
|
|
19
|
+
from sqlspec.utils.module_loader import ensure_pyarrow
|
|
16
20
|
|
|
17
21
|
if TYPE_CHECKING:
|
|
18
22
|
from collections.abc import Sequence
|
|
@@ -341,6 +345,91 @@ class AsyncDriverAdapterBase(CommonDriverAttributesMixin, SQLTranslatorMixin):
|
|
|
341
345
|
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
342
346
|
return result.get_data(schema_type=schema_type)
|
|
343
347
|
|
|
348
|
+
async def select_to_arrow(
|
|
349
|
+
self,
|
|
350
|
+
statement: "Statement | QueryBuilder",
|
|
351
|
+
/,
|
|
352
|
+
*parameters: "StatementParameters | StatementFilter",
|
|
353
|
+
statement_config: "StatementConfig | None" = None,
|
|
354
|
+
return_format: str = "table",
|
|
355
|
+
native_only: bool = False,
|
|
356
|
+
batch_size: int | None = None,
|
|
357
|
+
arrow_schema: Any = None,
|
|
358
|
+
**kwargs: Any,
|
|
359
|
+
) -> "Any":
|
|
360
|
+
"""Execute query and return results as Apache Arrow format (async).
|
|
361
|
+
|
|
362
|
+
This base implementation uses the conversion path: execute() → dict → Arrow.
|
|
363
|
+
Adapters with native Arrow support (ADBC, DuckDB, BigQuery) override this
|
|
364
|
+
method to use zero-copy native paths for 5-10x performance improvement.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
statement: SQL query string, Statement, or QueryBuilder
|
|
368
|
+
*parameters: Query parameters (same format as execute()/select())
|
|
369
|
+
statement_config: Optional statement configuration override
|
|
370
|
+
return_format: "table" for pyarrow.Table (default), "reader" for RecordBatchReader,
|
|
371
|
+
"batches" for iterator of RecordBatches
|
|
372
|
+
native_only: If True, raise error if native Arrow unavailable (default: False)
|
|
373
|
+
batch_size: Rows per batch for "batches" format (default: None = all rows)
|
|
374
|
+
arrow_schema: Optional pyarrow.Schema for type casting
|
|
375
|
+
**kwargs: Additional keyword arguments
|
|
376
|
+
|
|
377
|
+
Returns:
|
|
378
|
+
ArrowResult containing pyarrow.Table, RecordBatchReader, or RecordBatches
|
|
379
|
+
|
|
380
|
+
Raises:
|
|
381
|
+
ImproperConfigurationError: If native_only=True and adapter doesn't support native Arrow
|
|
382
|
+
|
|
383
|
+
Examples:
|
|
384
|
+
>>> result = await driver.select_to_arrow(
|
|
385
|
+
... "SELECT * FROM users WHERE age > ?", 18
|
|
386
|
+
... )
|
|
387
|
+
>>> df = result.to_pandas()
|
|
388
|
+
>>> print(df.head())
|
|
389
|
+
|
|
390
|
+
>>> # Force native Arrow path (raises error if unavailable)
|
|
391
|
+
>>> result = await driver.select_to_arrow(
|
|
392
|
+
... "SELECT * FROM users", native_only=True
|
|
393
|
+
... )
|
|
394
|
+
"""
|
|
395
|
+
# Check pyarrow is available
|
|
396
|
+
ensure_pyarrow()
|
|
397
|
+
|
|
398
|
+
# Check if native_only requested but not supported
|
|
399
|
+
if native_only:
|
|
400
|
+
msg = (
|
|
401
|
+
f"Adapter '{self.__class__.__name__}' does not support native Arrow results. "
|
|
402
|
+
f"Use native_only=False to allow conversion path, or switch to an adapter "
|
|
403
|
+
f"with native Arrow support (ADBC, DuckDB, BigQuery)."
|
|
404
|
+
)
|
|
405
|
+
raise ImproperConfigurationError(msg)
|
|
406
|
+
|
|
407
|
+
# Execute query using standard path
|
|
408
|
+
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
|
|
409
|
+
|
|
410
|
+
# Convert dict results to Arrow
|
|
411
|
+
arrow_data = convert_dict_to_arrow(
|
|
412
|
+
result.data,
|
|
413
|
+
return_format=return_format, # type: ignore[arg-type]
|
|
414
|
+
batch_size=batch_size,
|
|
415
|
+
)
|
|
416
|
+
if arrow_schema is not None:
|
|
417
|
+
import pyarrow as pa
|
|
418
|
+
|
|
419
|
+
if not isinstance(arrow_schema, pa.Schema):
|
|
420
|
+
msg = f"arrow_schema must be a pyarrow.Schema, got {type(arrow_schema).__name__}"
|
|
421
|
+
raise TypeError(msg)
|
|
422
|
+
|
|
423
|
+
arrow_data = arrow_data.cast(arrow_schema)
|
|
424
|
+
return create_arrow_result(
|
|
425
|
+
statement=result.statement,
|
|
426
|
+
data=arrow_data,
|
|
427
|
+
rows_affected=result.rows_affected,
|
|
428
|
+
last_inserted_id=result.last_inserted_id,
|
|
429
|
+
execution_time=result.execution_time,
|
|
430
|
+
metadata=result.metadata,
|
|
431
|
+
)
|
|
432
|
+
|
|
344
433
|
async def select_value(
|
|
345
434
|
self,
|
|
346
435
|
statement: "Statement | QueryBuilder",
|