sqlspec 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_typing.py +93 -0
- sqlspec/adapters/adbc/adk/store.py +21 -11
- sqlspec/adapters/adbc/data_dictionary.py +27 -5
- sqlspec/adapters/adbc/driver.py +83 -14
- sqlspec/adapters/aiosqlite/adk/store.py +27 -18
- sqlspec/adapters/asyncmy/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/adk/store.py +26 -16
- sqlspec/adapters/asyncpg/data_dictionary.py +24 -17
- sqlspec/adapters/bigquery/adk/store.py +30 -21
- sqlspec/adapters/bigquery/config.py +11 -0
- sqlspec/adapters/bigquery/driver.py +138 -1
- sqlspec/adapters/duckdb/adk/store.py +21 -11
- sqlspec/adapters/duckdb/driver.py +87 -1
- sqlspec/adapters/oracledb/adk/store.py +89 -206
- sqlspec/adapters/oracledb/driver.py +183 -2
- sqlspec/adapters/oracledb/litestar/store.py +22 -24
- sqlspec/adapters/psqlpy/adk/store.py +28 -27
- sqlspec/adapters/psqlpy/data_dictionary.py +24 -17
- sqlspec/adapters/psqlpy/driver.py +7 -10
- sqlspec/adapters/psycopg/adk/store.py +51 -33
- sqlspec/adapters/psycopg/data_dictionary.py +48 -34
- sqlspec/adapters/sqlite/adk/store.py +29 -19
- sqlspec/config.py +100 -2
- sqlspec/core/filters.py +18 -10
- sqlspec/core/result.py +133 -2
- sqlspec/driver/_async.py +89 -0
- sqlspec/driver/_common.py +64 -29
- sqlspec/driver/_sync.py +95 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +2 -2
- sqlspec/extensions/adk/service.py +3 -3
- sqlspec/extensions/adk/store.py +8 -8
- sqlspec/extensions/aiosql/adapter.py +3 -15
- sqlspec/extensions/fastapi/__init__.py +21 -0
- sqlspec/extensions/fastapi/extension.py +331 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +71 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +389 -0
- sqlspec/extensions/litestar/config.py +3 -6
- sqlspec/extensions/litestar/plugin.py +26 -2
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +25 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +254 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/protocols.py +40 -0
- sqlspec/storage/_utils.py +1 -14
- sqlspec/storage/backends/fsspec.py +3 -5
- sqlspec/storage/backends/local.py +1 -1
- sqlspec/storage/backends/obstore.py +10 -18
- sqlspec/typing.py +16 -0
- sqlspec/utils/__init__.py +25 -4
- sqlspec/utils/arrow_helpers.py +81 -0
- sqlspec/utils/module_loader.py +203 -3
- sqlspec/utils/portal.py +311 -0
- sqlspec/utils/serializers.py +110 -1
- sqlspec/utils/sync_tools.py +15 -5
- sqlspec/utils/type_guards.py +25 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +2 -2
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/RECORD +64 -50
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
sqlspec/_typing.py
CHANGED
|
@@ -377,18 +377,101 @@ class ArrowRecordBatchResult(Protocol):
|
|
|
377
377
|
return None
|
|
378
378
|
|
|
379
379
|
|
|
380
|
+
@runtime_checkable
|
|
381
|
+
class ArrowSchemaProtocol(Protocol):
|
|
382
|
+
"""Typed shim for pyarrow.Schema."""
|
|
383
|
+
|
|
384
|
+
def field(self, i: int) -> Any:
|
|
385
|
+
"""Get field by index."""
|
|
386
|
+
...
|
|
387
|
+
|
|
388
|
+
@property
|
|
389
|
+
def names(self) -> "list[str]":
|
|
390
|
+
"""Get list of field names."""
|
|
391
|
+
...
|
|
392
|
+
|
|
393
|
+
def __len__(self) -> int:
|
|
394
|
+
"""Get number of fields."""
|
|
395
|
+
return 0
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@runtime_checkable
|
|
399
|
+
class ArrowRecordBatchReaderProtocol(Protocol):
|
|
400
|
+
"""Typed shim for pyarrow.RecordBatchReader."""
|
|
401
|
+
|
|
402
|
+
def read_all(self) -> Any:
|
|
403
|
+
"""Read all batches into a table."""
|
|
404
|
+
...
|
|
405
|
+
|
|
406
|
+
def read_next_batch(self) -> Any:
|
|
407
|
+
"""Read next batch."""
|
|
408
|
+
...
|
|
409
|
+
|
|
410
|
+
def __iter__(self) -> "Iterable[Any]":
|
|
411
|
+
"""Iterate over batches."""
|
|
412
|
+
...
|
|
413
|
+
|
|
414
|
+
|
|
380
415
|
try:
|
|
381
416
|
from pyarrow import RecordBatch as ArrowRecordBatch
|
|
417
|
+
from pyarrow import RecordBatchReader as ArrowRecordBatchReader
|
|
418
|
+
from pyarrow import Schema as ArrowSchema
|
|
382
419
|
from pyarrow import Table as ArrowTable
|
|
383
420
|
|
|
384
421
|
PYARROW_INSTALLED = True
|
|
385
422
|
except ImportError:
|
|
386
423
|
ArrowTable = ArrowTableResult # type: ignore[assignment,misc]
|
|
387
424
|
ArrowRecordBatch = ArrowRecordBatchResult # type: ignore[assignment,misc]
|
|
425
|
+
ArrowSchema = ArrowSchemaProtocol # type: ignore[assignment,misc]
|
|
426
|
+
ArrowRecordBatchReader = ArrowRecordBatchReaderProtocol # type: ignore[assignment,misc]
|
|
388
427
|
|
|
389
428
|
PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
|
|
390
429
|
|
|
391
430
|
|
|
431
|
+
@runtime_checkable
|
|
432
|
+
class PandasDataFrameProtocol(Protocol):
|
|
433
|
+
"""Typed shim for pandas.DataFrame."""
|
|
434
|
+
|
|
435
|
+
def __len__(self) -> int:
|
|
436
|
+
"""Get number of rows."""
|
|
437
|
+
...
|
|
438
|
+
|
|
439
|
+
def __getitem__(self, key: Any) -> Any:
|
|
440
|
+
"""Get column or row."""
|
|
441
|
+
...
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
@runtime_checkable
|
|
445
|
+
class PolarsDataFrameProtocol(Protocol):
|
|
446
|
+
"""Typed shim for polars.DataFrame."""
|
|
447
|
+
|
|
448
|
+
def __len__(self) -> int:
|
|
449
|
+
"""Get number of rows."""
|
|
450
|
+
...
|
|
451
|
+
|
|
452
|
+
def __getitem__(self, key: Any) -> Any:
|
|
453
|
+
"""Get column or row."""
|
|
454
|
+
...
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
try:
|
|
458
|
+
from pandas import DataFrame as PandasDataFrame
|
|
459
|
+
|
|
460
|
+
PANDAS_INSTALLED = True
|
|
461
|
+
except ImportError:
|
|
462
|
+
PandasDataFrame = PandasDataFrameProtocol # type: ignore[assignment,misc]
|
|
463
|
+
PANDAS_INSTALLED = False
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
try:
|
|
467
|
+
from polars import DataFrame as PolarsDataFrame
|
|
468
|
+
|
|
469
|
+
POLARS_INSTALLED = True
|
|
470
|
+
except ImportError:
|
|
471
|
+
PolarsDataFrame = PolarsDataFrameProtocol # type: ignore[assignment,misc]
|
|
472
|
+
POLARS_INSTALLED = False
|
|
473
|
+
|
|
474
|
+
|
|
392
475
|
@runtime_checkable
|
|
393
476
|
class NumpyArrayStub(Protocol):
|
|
394
477
|
"""Protocol stub for numpy.ndarray when numpy is not installed.
|
|
@@ -639,7 +722,9 @@ __all__ = (
|
|
|
639
722
|
"OBSTORE_INSTALLED",
|
|
640
723
|
"OPENTELEMETRY_INSTALLED",
|
|
641
724
|
"ORJSON_INSTALLED",
|
|
725
|
+
"PANDAS_INSTALLED",
|
|
642
726
|
"PGVECTOR_INSTALLED",
|
|
727
|
+
"POLARS_INSTALLED",
|
|
643
728
|
"PROMETHEUS_INSTALLED",
|
|
644
729
|
"PYARROW_INSTALLED",
|
|
645
730
|
"PYDANTIC_INSTALLED",
|
|
@@ -650,7 +735,11 @@ __all__ = (
|
|
|
650
735
|
"AiosqlSQLOperationType",
|
|
651
736
|
"AiosqlSyncProtocol",
|
|
652
737
|
"ArrowRecordBatch",
|
|
738
|
+
"ArrowRecordBatchReader",
|
|
739
|
+
"ArrowRecordBatchReaderProtocol",
|
|
653
740
|
"ArrowRecordBatchResult",
|
|
741
|
+
"ArrowSchema",
|
|
742
|
+
"ArrowSchemaProtocol",
|
|
654
743
|
"ArrowTable",
|
|
655
744
|
"ArrowTableResult",
|
|
656
745
|
"AttrsInstance",
|
|
@@ -670,6 +759,10 @@ __all__ = (
|
|
|
670
759
|
"Histogram",
|
|
671
760
|
"NumpyArray",
|
|
672
761
|
"NumpyArrayStub",
|
|
762
|
+
"PandasDataFrame",
|
|
763
|
+
"PandasDataFrameProtocol",
|
|
764
|
+
"PolarsDataFrame",
|
|
765
|
+
"PolarsDataFrameProtocol",
|
|
673
766
|
"Span",
|
|
674
767
|
"Status",
|
|
675
768
|
"StatusCode",
|
|
@@ -639,31 +639,41 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]):
|
|
|
639
639
|
finally:
|
|
640
640
|
cursor.close() # type: ignore[no-untyped-call]
|
|
641
641
|
|
|
642
|
-
def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
643
|
-
"""List
|
|
642
|
+
def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
|
|
643
|
+
"""List sessions for an app, optionally filtered by user.
|
|
644
644
|
|
|
645
645
|
Args:
|
|
646
646
|
app_name: Application name.
|
|
647
|
-
user_id: User identifier.
|
|
647
|
+
user_id: User identifier. If None, lists all sessions for the app.
|
|
648
648
|
|
|
649
649
|
Returns:
|
|
650
650
|
List of session records ordered by update_time DESC.
|
|
651
651
|
|
|
652
652
|
Notes:
|
|
653
|
-
Uses composite index on (app_name, user_id).
|
|
654
|
-
"""
|
|
655
|
-
sql = f"""
|
|
656
|
-
SELECT id, app_name, user_id, state, create_time, update_time
|
|
657
|
-
FROM {self._session_table}
|
|
658
|
-
WHERE app_name = ? AND user_id = ?
|
|
659
|
-
ORDER BY update_time DESC
|
|
653
|
+
Uses composite index on (app_name, user_id) when user_id is provided.
|
|
660
654
|
"""
|
|
655
|
+
if user_id is None:
|
|
656
|
+
sql = f"""
|
|
657
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
658
|
+
FROM {self._session_table}
|
|
659
|
+
WHERE app_name = ?
|
|
660
|
+
ORDER BY update_time DESC
|
|
661
|
+
"""
|
|
662
|
+
params: tuple[str, ...] = (app_name,)
|
|
663
|
+
else:
|
|
664
|
+
sql = f"""
|
|
665
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
666
|
+
FROM {self._session_table}
|
|
667
|
+
WHERE app_name = ? AND user_id = ?
|
|
668
|
+
ORDER BY update_time DESC
|
|
669
|
+
"""
|
|
670
|
+
params = (app_name, user_id)
|
|
661
671
|
|
|
662
672
|
try:
|
|
663
673
|
with self._config.provide_connection() as conn:
|
|
664
674
|
cursor = conn.cursor()
|
|
665
675
|
try:
|
|
666
|
-
cursor.execute(sql,
|
|
676
|
+
cursor.execute(sql, params)
|
|
667
677
|
rows = cursor.fetchall()
|
|
668
678
|
|
|
669
679
|
return [
|
|
@@ -300,22 +300,44 @@ class AdbcDataDictionary(SyncDataDictionaryBase):
|
|
|
300
300
|
for row in result.data or []
|
|
301
301
|
]
|
|
302
302
|
|
|
303
|
+
if dialect == "postgres":
|
|
304
|
+
schema_name = schema or "public"
|
|
305
|
+
sql = """
|
|
306
|
+
SELECT
|
|
307
|
+
a.attname::text AS column_name,
|
|
308
|
+
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
|
|
309
|
+
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
|
|
310
|
+
pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default
|
|
311
|
+
FROM pg_catalog.pg_attribute a
|
|
312
|
+
JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
|
|
313
|
+
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
|
|
314
|
+
LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
|
|
315
|
+
WHERE c.relname = ?
|
|
316
|
+
AND n.nspname = ?
|
|
317
|
+
AND a.attnum > 0
|
|
318
|
+
AND NOT a.attisdropped
|
|
319
|
+
ORDER BY a.attnum
|
|
320
|
+
"""
|
|
321
|
+
result = adbc_driver.execute(sql, (table, schema_name))
|
|
322
|
+
return result.data or []
|
|
323
|
+
|
|
303
324
|
if schema:
|
|
304
|
-
sql =
|
|
325
|
+
sql = """
|
|
305
326
|
SELECT column_name, data_type, is_nullable, column_default
|
|
306
327
|
FROM information_schema.columns
|
|
307
|
-
WHERE table_name =
|
|
328
|
+
WHERE table_name = ? AND table_schema = ?
|
|
308
329
|
ORDER BY ordinal_position
|
|
309
330
|
"""
|
|
331
|
+
result = adbc_driver.execute(sql, (table, schema))
|
|
310
332
|
else:
|
|
311
|
-
sql =
|
|
333
|
+
sql = """
|
|
312
334
|
SELECT column_name, data_type, is_nullable, column_default
|
|
313
335
|
FROM information_schema.columns
|
|
314
|
-
WHERE table_name =
|
|
336
|
+
WHERE table_name = ?
|
|
315
337
|
ORDER BY ordinal_position
|
|
316
338
|
"""
|
|
339
|
+
result = adbc_driver.execute(sql, (table,))
|
|
317
340
|
|
|
318
|
-
result = adbc_driver.execute(sql)
|
|
319
341
|
return result.data or []
|
|
320
342
|
|
|
321
343
|
def list_available_features(self) -> "list[str]":
|
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -15,6 +15,7 @@ from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
|
|
|
15
15
|
from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter
|
|
16
16
|
from sqlspec.core.cache import get_cache_config
|
|
17
17
|
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
|
|
18
|
+
from sqlspec.core.result import create_arrow_result
|
|
18
19
|
from sqlspec.core.statement import SQL, StatementConfig
|
|
19
20
|
from sqlspec.driver import SyncDriverAdapterBase
|
|
20
21
|
from sqlspec.exceptions import (
|
|
@@ -23,7 +24,6 @@ from sqlspec.exceptions import (
|
|
|
23
24
|
DataError,
|
|
24
25
|
ForeignKeyViolationError,
|
|
25
26
|
IntegrityError,
|
|
26
|
-
MissingDependencyError,
|
|
27
27
|
NotNullViolationError,
|
|
28
28
|
SQLParsingError,
|
|
29
29
|
SQLSpecError,
|
|
@@ -32,6 +32,7 @@ from sqlspec.exceptions import (
|
|
|
32
32
|
)
|
|
33
33
|
from sqlspec.typing import Empty
|
|
34
34
|
from sqlspec.utils.logging import get_logger
|
|
35
|
+
from sqlspec.utils.module_loader import ensure_pyarrow
|
|
35
36
|
|
|
36
37
|
if TYPE_CHECKING:
|
|
37
38
|
from contextlib import AbstractContextManager
|
|
@@ -39,9 +40,12 @@ if TYPE_CHECKING:
|
|
|
39
40
|
from adbc_driver_manager.dbapi import Cursor
|
|
40
41
|
|
|
41
42
|
from sqlspec.adapters.adbc._types import AdbcConnection
|
|
42
|
-
from sqlspec.
|
|
43
|
+
from sqlspec.builder import QueryBuilder
|
|
44
|
+
from sqlspec.core import Statement, StatementFilter
|
|
45
|
+
from sqlspec.core.result import ArrowResult, SQLResult
|
|
43
46
|
from sqlspec.driver import ExecutionResult
|
|
44
47
|
from sqlspec.driver._sync import SyncDataDictionaryBase
|
|
48
|
+
from sqlspec.typing import StatementParameters
|
|
45
49
|
|
|
46
50
|
__all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config")
|
|
47
51
|
|
|
@@ -507,18 +511,6 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
507
511
|
self.dialect = statement_config.dialect
|
|
508
512
|
self._data_dictionary: SyncDataDictionaryBase | None = None
|
|
509
513
|
|
|
510
|
-
@staticmethod
|
|
511
|
-
def _ensure_pyarrow_installed() -> None:
|
|
512
|
-
"""Ensure PyArrow is installed.
|
|
513
|
-
|
|
514
|
-
Raises:
|
|
515
|
-
MissingDependencyError: If PyArrow is not installed
|
|
516
|
-
"""
|
|
517
|
-
from sqlspec.typing import PYARROW_INSTALLED
|
|
518
|
-
|
|
519
|
-
if not PYARROW_INSTALLED:
|
|
520
|
-
raise MissingDependencyError(package="pyarrow", install_package="arrow")
|
|
521
|
-
|
|
522
514
|
@staticmethod
|
|
523
515
|
def _get_dialect(connection: "AdbcConnection") -> str:
|
|
524
516
|
"""Detect database dialect from connection information.
|
|
@@ -863,3 +855,80 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
863
855
|
if self._data_dictionary is None:
|
|
864
856
|
self._data_dictionary = AdbcDataDictionary()
|
|
865
857
|
return self._data_dictionary
|
|
858
|
+
|
|
859
|
+
def select_to_arrow(
|
|
860
|
+
self,
|
|
861
|
+
statement: "Statement | QueryBuilder",
|
|
862
|
+
/,
|
|
863
|
+
*parameters: "StatementParameters | StatementFilter",
|
|
864
|
+
statement_config: "StatementConfig | None" = None,
|
|
865
|
+
return_format: str = "table",
|
|
866
|
+
native_only: bool = False,
|
|
867
|
+
batch_size: int | None = None,
|
|
868
|
+
arrow_schema: Any = None,
|
|
869
|
+
**kwargs: Any,
|
|
870
|
+
) -> "ArrowResult":
|
|
871
|
+
"""Execute query and return results as Apache Arrow (ADBC native path).
|
|
872
|
+
|
|
873
|
+
ADBC provides zero-copy Arrow support via cursor.fetch_arrow_table().
|
|
874
|
+
This is 5-10x faster than the conversion path for large datasets.
|
|
875
|
+
|
|
876
|
+
Args:
|
|
877
|
+
statement: SQL statement, string, or QueryBuilder
|
|
878
|
+
*parameters: Query parameters or filters
|
|
879
|
+
statement_config: Optional statement configuration override
|
|
880
|
+
return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch
|
|
881
|
+
native_only: Ignored for ADBC (always uses native path)
|
|
882
|
+
batch_size: Batch size hint (for future streaming implementation)
|
|
883
|
+
arrow_schema: Optional pyarrow.Schema for type casting
|
|
884
|
+
**kwargs: Additional keyword arguments
|
|
885
|
+
|
|
886
|
+
Returns:
|
|
887
|
+
ArrowResult with native Arrow data
|
|
888
|
+
|
|
889
|
+
Raises:
|
|
890
|
+
MissingDependencyError: If pyarrow not installed
|
|
891
|
+
SQLExecutionError: If query execution fails
|
|
892
|
+
|
|
893
|
+
Example:
|
|
894
|
+
>>> result = driver.select_to_arrow(
|
|
895
|
+
... "SELECT * FROM users WHERE age > $1", 18
|
|
896
|
+
... )
|
|
897
|
+
>>> df = result.to_pandas() # Fast zero-copy conversion
|
|
898
|
+
"""
|
|
899
|
+
ensure_pyarrow()
|
|
900
|
+
|
|
901
|
+
import pyarrow as pa
|
|
902
|
+
|
|
903
|
+
# Prepare statement
|
|
904
|
+
config = statement_config or self.statement_config
|
|
905
|
+
prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
|
|
906
|
+
|
|
907
|
+
# Use ADBC cursor for native Arrow
|
|
908
|
+
with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions():
|
|
909
|
+
if cursor is None:
|
|
910
|
+
msg = "Failed to create cursor"
|
|
911
|
+
raise DatabaseConnectionError(msg)
|
|
912
|
+
|
|
913
|
+
# Get compiled SQL and parameters
|
|
914
|
+
sql, driver_params = self._get_compiled_sql(prepared_statement, config)
|
|
915
|
+
|
|
916
|
+
# Execute query
|
|
917
|
+
cursor.execute(sql, driver_params or ())
|
|
918
|
+
|
|
919
|
+
# Fetch as Arrow table (zero-copy!)
|
|
920
|
+
arrow_table = cursor.fetch_arrow_table()
|
|
921
|
+
|
|
922
|
+
# Apply schema casting if requested
|
|
923
|
+
if arrow_schema is not None:
|
|
924
|
+
arrow_table = arrow_table.cast(arrow_schema)
|
|
925
|
+
|
|
926
|
+
# Convert to batch if requested
|
|
927
|
+
if return_format == "batch":
|
|
928
|
+
batches = arrow_table.to_batches()
|
|
929
|
+
arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
|
|
930
|
+
else:
|
|
931
|
+
arrow_data = arrow_table
|
|
932
|
+
|
|
933
|
+
# Create ArrowResult
|
|
934
|
+
return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
|
|
@@ -136,7 +136,7 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
|
|
|
136
136
|
"""
|
|
137
137
|
super().__init__(config)
|
|
138
138
|
|
|
139
|
-
def _get_create_sessions_table_sql(self) -> str:
|
|
139
|
+
async def _get_create_sessions_table_sql(self) -> str:
|
|
140
140
|
"""Get SQLite CREATE TABLE SQL for sessions.
|
|
141
141
|
|
|
142
142
|
Returns:
|
|
@@ -163,7 +163,7 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
|
|
|
163
163
|
ON {self._session_table}(update_time DESC);
|
|
164
164
|
"""
|
|
165
165
|
|
|
166
|
-
def _get_create_events_table_sql(self) -> str:
|
|
166
|
+
async def _get_create_events_table_sql(self) -> str:
|
|
167
167
|
"""Get SQLite CREATE TABLE SQL for events.
|
|
168
168
|
|
|
169
169
|
Returns:
|
|
@@ -228,11 +228,10 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
|
|
|
228
228
|
|
|
229
229
|
async def create_tables(self) -> None:
|
|
230
230
|
"""Create both sessions and events tables if they don't exist."""
|
|
231
|
-
async with self._config.
|
|
232
|
-
await self._enable_foreign_keys(
|
|
233
|
-
await
|
|
234
|
-
await
|
|
235
|
-
await conn.commit()
|
|
231
|
+
async with self._config.provide_session() as driver:
|
|
232
|
+
await self._enable_foreign_keys(driver.connection)
|
|
233
|
+
await driver.execute_script(await self._get_create_sessions_table_sql())
|
|
234
|
+
await driver.execute_script(await self._get_create_events_table_sql())
|
|
236
235
|
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
|
|
237
236
|
|
|
238
237
|
async def create_session(
|
|
@@ -343,29 +342,39 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
|
|
|
343
342
|
await conn.execute(sql, (state_json, now_julian, session_id))
|
|
344
343
|
await conn.commit()
|
|
345
344
|
|
|
346
|
-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
347
|
-
"""List
|
|
345
|
+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
|
|
346
|
+
"""List sessions for an app, optionally filtered by user.
|
|
348
347
|
|
|
349
348
|
Args:
|
|
350
349
|
app_name: Application name.
|
|
351
|
-
user_id: User identifier.
|
|
350
|
+
user_id: User identifier. If None, lists all sessions for the app.
|
|
352
351
|
|
|
353
352
|
Returns:
|
|
354
353
|
List of session records ordered by update_time DESC.
|
|
355
354
|
|
|
356
355
|
Notes:
|
|
357
|
-
Uses composite index on (app_name, user_id).
|
|
358
|
-
"""
|
|
359
|
-
sql = f"""
|
|
360
|
-
SELECT id, app_name, user_id, state, create_time, update_time
|
|
361
|
-
FROM {self._session_table}
|
|
362
|
-
WHERE app_name = ? AND user_id = ?
|
|
363
|
-
ORDER BY update_time DESC
|
|
356
|
+
Uses composite index on (app_name, user_id) when user_id is provided.
|
|
364
357
|
"""
|
|
358
|
+
if user_id is None:
|
|
359
|
+
sql = f"""
|
|
360
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
361
|
+
FROM {self._session_table}
|
|
362
|
+
WHERE app_name = ?
|
|
363
|
+
ORDER BY update_time DESC
|
|
364
|
+
"""
|
|
365
|
+
params: tuple[str, ...] = (app_name,)
|
|
366
|
+
else:
|
|
367
|
+
sql = f"""
|
|
368
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
369
|
+
FROM {self._session_table}
|
|
370
|
+
WHERE app_name = ? AND user_id = ?
|
|
371
|
+
ORDER BY update_time DESC
|
|
372
|
+
"""
|
|
373
|
+
params = (app_name, user_id)
|
|
365
374
|
|
|
366
375
|
async with self._config.provide_connection() as conn:
|
|
367
376
|
await self._enable_foreign_keys(conn)
|
|
368
|
-
cursor = await conn.execute(sql,
|
|
377
|
+
cursor = await conn.execute(sql, params)
|
|
369
378
|
rows = await cursor.fetchall()
|
|
370
379
|
|
|
371
380
|
return [
|
|
@@ -106,7 +106,7 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
|
|
|
106
106
|
|
|
107
107
|
return (col_def, fk_constraint)
|
|
108
108
|
|
|
109
|
-
def _get_create_sessions_table_sql(self) -> str:
|
|
109
|
+
async def _get_create_sessions_table_sql(self) -> str:
|
|
110
110
|
"""Get MySQL CREATE TABLE SQL for sessions.
|
|
111
111
|
|
|
112
112
|
Returns:
|
|
@@ -145,7 +145,7 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
|
|
|
145
145
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
|
|
146
146
|
"""
|
|
147
147
|
|
|
148
|
-
def _get_create_events_table_sql(self) -> str:
|
|
148
|
+
async def _get_create_events_table_sql(self) -> str:
|
|
149
149
|
"""Get MySQL CREATE TABLE SQL for events.
|
|
150
150
|
|
|
151
151
|
Returns:
|
|
@@ -199,9 +199,9 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
|
|
|
199
199
|
|
|
200
200
|
async def create_tables(self) -> None:
|
|
201
201
|
"""Create both sessions and events tables if they don't exist."""
|
|
202
|
-
async with self._config.
|
|
203
|
-
await
|
|
204
|
-
await
|
|
202
|
+
async with self._config.provide_session() as driver:
|
|
203
|
+
await driver.execute_script(await self._get_create_sessions_table_sql())
|
|
204
|
+
await driver.execute_script(await self._get_create_events_table_sql())
|
|
205
205
|
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
|
|
206
206
|
|
|
207
207
|
async def create_session(
|
|
@@ -326,29 +326,39 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
|
|
|
326
326
|
await cursor.execute(sql, (session_id,))
|
|
327
327
|
await conn.commit()
|
|
328
328
|
|
|
329
|
-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
330
|
-
"""List
|
|
329
|
+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
|
|
330
|
+
"""List sessions for an app, optionally filtered by user.
|
|
331
331
|
|
|
332
332
|
Args:
|
|
333
333
|
app_name: Application name.
|
|
334
|
-
user_id: User identifier.
|
|
334
|
+
user_id: User identifier. If None, lists all sessions for the app.
|
|
335
335
|
|
|
336
336
|
Returns:
|
|
337
337
|
List of session records ordered by update_time DESC.
|
|
338
338
|
|
|
339
339
|
Notes:
|
|
340
|
-
Uses composite index on (app_name, user_id).
|
|
341
|
-
"""
|
|
342
|
-
sql = f"""
|
|
343
|
-
SELECT id, app_name, user_id, state, create_time, update_time
|
|
344
|
-
FROM {self._session_table}
|
|
345
|
-
WHERE app_name = %s AND user_id = %s
|
|
346
|
-
ORDER BY update_time DESC
|
|
340
|
+
Uses composite index on (app_name, user_id) when user_id is provided.
|
|
347
341
|
"""
|
|
342
|
+
if user_id is None:
|
|
343
|
+
sql = f"""
|
|
344
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
345
|
+
FROM {self._session_table}
|
|
346
|
+
WHERE app_name = %s
|
|
347
|
+
ORDER BY update_time DESC
|
|
348
|
+
"""
|
|
349
|
+
params: tuple[str, ...] = (app_name,)
|
|
350
|
+
else:
|
|
351
|
+
sql = f"""
|
|
352
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
353
|
+
FROM {self._session_table}
|
|
354
|
+
WHERE app_name = %s AND user_id = %s
|
|
355
|
+
ORDER BY update_time DESC
|
|
356
|
+
"""
|
|
357
|
+
params = (app_name, user_id)
|
|
348
358
|
|
|
349
359
|
try:
|
|
350
360
|
async with self._config.provide_connection() as conn, conn.cursor() as cursor:
|
|
351
|
-
await cursor.execute(sql,
|
|
361
|
+
await cursor.execute(sql, params)
|
|
352
362
|
rows = await cursor.fetchall()
|
|
353
363
|
|
|
354
364
|
return [
|
|
@@ -84,7 +84,7 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
|
|
|
84
84
|
"""
|
|
85
85
|
super().__init__(config)
|
|
86
86
|
|
|
87
|
-
def _get_create_sessions_table_sql(self) -> str:
|
|
87
|
+
async def _get_create_sessions_table_sql(self) -> str:
|
|
88
88
|
"""Get PostgreSQL CREATE TABLE SQL for sessions.
|
|
89
89
|
|
|
90
90
|
Returns:
|
|
@@ -125,7 +125,7 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
|
|
|
125
125
|
WHERE state != '{{}}'::jsonb;
|
|
126
126
|
"""
|
|
127
127
|
|
|
128
|
-
def _get_create_events_table_sql(self) -> str:
|
|
128
|
+
async def _get_create_events_table_sql(self) -> str:
|
|
129
129
|
"""Get PostgreSQL CREATE TABLE SQL for events.
|
|
130
130
|
|
|
131
131
|
Returns:
|
|
@@ -181,9 +181,9 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
|
|
|
181
181
|
|
|
182
182
|
async def create_tables(self) -> None:
|
|
183
183
|
"""Create both sessions and events tables if they don't exist."""
|
|
184
|
-
async with self.config.
|
|
185
|
-
await
|
|
186
|
-
await
|
|
184
|
+
async with self.config.provide_session() as driver:
|
|
185
|
+
await driver.execute_script(await self._get_create_sessions_table_sql())
|
|
186
|
+
await driver.execute_script(await self._get_create_events_table_sql())
|
|
187
187
|
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
|
|
188
188
|
|
|
189
189
|
async def create_session(
|
|
@@ -294,29 +294,39 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
|
|
|
294
294
|
async with self.config.provide_connection() as conn:
|
|
295
295
|
await conn.execute(sql, session_id)
|
|
296
296
|
|
|
297
|
-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
|
|
298
|
-
"""List
|
|
297
|
+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
|
|
298
|
+
"""List sessions for an app, optionally filtered by user.
|
|
299
299
|
|
|
300
300
|
Args:
|
|
301
301
|
app_name: Application name.
|
|
302
|
-
user_id: User identifier.
|
|
302
|
+
user_id: User identifier. If None, lists all sessions for the app.
|
|
303
303
|
|
|
304
304
|
Returns:
|
|
305
305
|
List of session records ordered by update_time DESC.
|
|
306
306
|
|
|
307
307
|
Notes:
|
|
308
|
-
Uses composite index on (app_name, user_id).
|
|
309
|
-
"""
|
|
310
|
-
sql = f"""
|
|
311
|
-
SELECT id, app_name, user_id, state, create_time, update_time
|
|
312
|
-
FROM {self._session_table}
|
|
313
|
-
WHERE app_name = $1 AND user_id = $2
|
|
314
|
-
ORDER BY update_time DESC
|
|
308
|
+
Uses composite index on (app_name, user_id) when user_id is provided.
|
|
315
309
|
"""
|
|
310
|
+
if user_id is None:
|
|
311
|
+
sql = f"""
|
|
312
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
313
|
+
FROM {self._session_table}
|
|
314
|
+
WHERE app_name = $1
|
|
315
|
+
ORDER BY update_time DESC
|
|
316
|
+
"""
|
|
317
|
+
params = [app_name]
|
|
318
|
+
else:
|
|
319
|
+
sql = f"""
|
|
320
|
+
SELECT id, app_name, user_id, state, create_time, update_time
|
|
321
|
+
FROM {self._session_table}
|
|
322
|
+
WHERE app_name = $1 AND user_id = $2
|
|
323
|
+
ORDER BY update_time DESC
|
|
324
|
+
"""
|
|
325
|
+
params = [app_name, user_id]
|
|
316
326
|
|
|
317
327
|
try:
|
|
318
328
|
async with self.config.provide_connection() as conn:
|
|
319
|
-
rows = await conn.fetch(sql,
|
|
329
|
+
rows = await conn.fetch(sql, *params)
|
|
320
330
|
|
|
321
331
|
return [
|
|
322
332
|
SessionRecord(
|