sqlspec 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. See the package's registry page for more details.

Files changed (64)
  1. sqlspec/_typing.py +93 -0
  2. sqlspec/adapters/adbc/adk/store.py +21 -11
  3. sqlspec/adapters/adbc/data_dictionary.py +27 -5
  4. sqlspec/adapters/adbc/driver.py +83 -14
  5. sqlspec/adapters/aiosqlite/adk/store.py +27 -18
  6. sqlspec/adapters/asyncmy/adk/store.py +26 -16
  7. sqlspec/adapters/asyncpg/adk/store.py +26 -16
  8. sqlspec/adapters/asyncpg/data_dictionary.py +24 -17
  9. sqlspec/adapters/bigquery/adk/store.py +30 -21
  10. sqlspec/adapters/bigquery/config.py +11 -0
  11. sqlspec/adapters/bigquery/driver.py +138 -1
  12. sqlspec/adapters/duckdb/adk/store.py +21 -11
  13. sqlspec/adapters/duckdb/driver.py +87 -1
  14. sqlspec/adapters/oracledb/adk/store.py +89 -206
  15. sqlspec/adapters/oracledb/driver.py +183 -2
  16. sqlspec/adapters/oracledb/litestar/store.py +22 -24
  17. sqlspec/adapters/psqlpy/adk/store.py +28 -27
  18. sqlspec/adapters/psqlpy/data_dictionary.py +24 -17
  19. sqlspec/adapters/psqlpy/driver.py +7 -10
  20. sqlspec/adapters/psycopg/adk/store.py +51 -33
  21. sqlspec/adapters/psycopg/data_dictionary.py +48 -34
  22. sqlspec/adapters/sqlite/adk/store.py +29 -19
  23. sqlspec/config.py +100 -2
  24. sqlspec/core/filters.py +18 -10
  25. sqlspec/core/result.py +133 -2
  26. sqlspec/driver/_async.py +89 -0
  27. sqlspec/driver/_common.py +64 -29
  28. sqlspec/driver/_sync.py +95 -0
  29. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +2 -2
  30. sqlspec/extensions/adk/service.py +3 -3
  31. sqlspec/extensions/adk/store.py +8 -8
  32. sqlspec/extensions/aiosql/adapter.py +3 -15
  33. sqlspec/extensions/fastapi/__init__.py +21 -0
  34. sqlspec/extensions/fastapi/extension.py +331 -0
  35. sqlspec/extensions/fastapi/providers.py +543 -0
  36. sqlspec/extensions/flask/__init__.py +36 -0
  37. sqlspec/extensions/flask/_state.py +71 -0
  38. sqlspec/extensions/flask/_utils.py +40 -0
  39. sqlspec/extensions/flask/extension.py +389 -0
  40. sqlspec/extensions/litestar/config.py +3 -6
  41. sqlspec/extensions/litestar/plugin.py +26 -2
  42. sqlspec/extensions/starlette/__init__.py +10 -0
  43. sqlspec/extensions/starlette/_state.py +25 -0
  44. sqlspec/extensions/starlette/_utils.py +52 -0
  45. sqlspec/extensions/starlette/extension.py +254 -0
  46. sqlspec/extensions/starlette/middleware.py +154 -0
  47. sqlspec/protocols.py +40 -0
  48. sqlspec/storage/_utils.py +1 -14
  49. sqlspec/storage/backends/fsspec.py +3 -5
  50. sqlspec/storage/backends/local.py +1 -1
  51. sqlspec/storage/backends/obstore.py +10 -18
  52. sqlspec/typing.py +16 -0
  53. sqlspec/utils/__init__.py +25 -4
  54. sqlspec/utils/arrow_helpers.py +81 -0
  55. sqlspec/utils/module_loader.py +203 -3
  56. sqlspec/utils/portal.py +311 -0
  57. sqlspec/utils/serializers.py +110 -1
  58. sqlspec/utils/sync_tools.py +15 -5
  59. sqlspec/utils/type_guards.py +25 -0
  60. {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +2 -2
  61. {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/RECORD +64 -50
  62. {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  63. {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  64. {sqlspec-0.27.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
sqlspec/_typing.py CHANGED
@@ -377,18 +377,101 @@ class ArrowRecordBatchResult(Protocol):
377
377
  return None
378
378
 
379
379
 
380
+ @runtime_checkable
381
+ class ArrowSchemaProtocol(Protocol):
382
+ """Typed shim for pyarrow.Schema."""
383
+
384
+ def field(self, i: int) -> Any:
385
+ """Get field by index."""
386
+ ...
387
+
388
+ @property
389
+ def names(self) -> "list[str]":
390
+ """Get list of field names."""
391
+ ...
392
+
393
+ def __len__(self) -> int:
394
+ """Get number of fields."""
395
+ return 0
396
+
397
+
398
+ @runtime_checkable
399
+ class ArrowRecordBatchReaderProtocol(Protocol):
400
+ """Typed shim for pyarrow.RecordBatchReader."""
401
+
402
+ def read_all(self) -> Any:
403
+ """Read all batches into a table."""
404
+ ...
405
+
406
+ def read_next_batch(self) -> Any:
407
+ """Read next batch."""
408
+ ...
409
+
410
+ def __iter__(self) -> "Iterable[Any]":
411
+ """Iterate over batches."""
412
+ ...
413
+
414
+
380
415
  try:
381
416
  from pyarrow import RecordBatch as ArrowRecordBatch
417
+ from pyarrow import RecordBatchReader as ArrowRecordBatchReader
418
+ from pyarrow import Schema as ArrowSchema
382
419
  from pyarrow import Table as ArrowTable
383
420
 
384
421
  PYARROW_INSTALLED = True
385
422
  except ImportError:
386
423
  ArrowTable = ArrowTableResult # type: ignore[assignment,misc]
387
424
  ArrowRecordBatch = ArrowRecordBatchResult # type: ignore[assignment,misc]
425
+ ArrowSchema = ArrowSchemaProtocol # type: ignore[assignment,misc]
426
+ ArrowRecordBatchReader = ArrowRecordBatchReaderProtocol # type: ignore[assignment,misc]
388
427
 
389
428
  PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
390
429
 
391
430
 
431
+ @runtime_checkable
432
+ class PandasDataFrameProtocol(Protocol):
433
+ """Typed shim for pandas.DataFrame."""
434
+
435
+ def __len__(self) -> int:
436
+ """Get number of rows."""
437
+ ...
438
+
439
+ def __getitem__(self, key: Any) -> Any:
440
+ """Get column or row."""
441
+ ...
442
+
443
+
444
+ @runtime_checkable
445
+ class PolarsDataFrameProtocol(Protocol):
446
+ """Typed shim for polars.DataFrame."""
447
+
448
+ def __len__(self) -> int:
449
+ """Get number of rows."""
450
+ ...
451
+
452
+ def __getitem__(self, key: Any) -> Any:
453
+ """Get column or row."""
454
+ ...
455
+
456
+
457
+ try:
458
+ from pandas import DataFrame as PandasDataFrame
459
+
460
+ PANDAS_INSTALLED = True
461
+ except ImportError:
462
+ PandasDataFrame = PandasDataFrameProtocol # type: ignore[assignment,misc]
463
+ PANDAS_INSTALLED = False
464
+
465
+
466
+ try:
467
+ from polars import DataFrame as PolarsDataFrame
468
+
469
+ POLARS_INSTALLED = True
470
+ except ImportError:
471
+ PolarsDataFrame = PolarsDataFrameProtocol # type: ignore[assignment,misc]
472
+ POLARS_INSTALLED = False
473
+
474
+
392
475
  @runtime_checkable
393
476
  class NumpyArrayStub(Protocol):
394
477
  """Protocol stub for numpy.ndarray when numpy is not installed.
@@ -639,7 +722,9 @@ __all__ = (
639
722
  "OBSTORE_INSTALLED",
640
723
  "OPENTELEMETRY_INSTALLED",
641
724
  "ORJSON_INSTALLED",
725
+ "PANDAS_INSTALLED",
642
726
  "PGVECTOR_INSTALLED",
727
+ "POLARS_INSTALLED",
643
728
  "PROMETHEUS_INSTALLED",
644
729
  "PYARROW_INSTALLED",
645
730
  "PYDANTIC_INSTALLED",
@@ -650,7 +735,11 @@ __all__ = (
650
735
  "AiosqlSQLOperationType",
651
736
  "AiosqlSyncProtocol",
652
737
  "ArrowRecordBatch",
738
+ "ArrowRecordBatchReader",
739
+ "ArrowRecordBatchReaderProtocol",
653
740
  "ArrowRecordBatchResult",
741
+ "ArrowSchema",
742
+ "ArrowSchemaProtocol",
654
743
  "ArrowTable",
655
744
  "ArrowTableResult",
656
745
  "AttrsInstance",
@@ -670,6 +759,10 @@ __all__ = (
670
759
  "Histogram",
671
760
  "NumpyArray",
672
761
  "NumpyArrayStub",
762
+ "PandasDataFrame",
763
+ "PandasDataFrameProtocol",
764
+ "PolarsDataFrame",
765
+ "PolarsDataFrameProtocol",
673
766
  "Span",
674
767
  "Status",
675
768
  "StatusCode",
@@ -639,31 +639,41 @@ class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]):
639
639
  finally:
640
640
  cursor.close() # type: ignore[no-untyped-call]
641
641
 
642
- def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
643
- """List all sessions for a user in an app.
642
+ def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
643
+ """List sessions for an app, optionally filtered by user.
644
644
 
645
645
  Args:
646
646
  app_name: Application name.
647
- user_id: User identifier.
647
+ user_id: User identifier. If None, lists all sessions for the app.
648
648
 
649
649
  Returns:
650
650
  List of session records ordered by update_time DESC.
651
651
 
652
652
  Notes:
653
- Uses composite index on (app_name, user_id).
654
- """
655
- sql = f"""
656
- SELECT id, app_name, user_id, state, create_time, update_time
657
- FROM {self._session_table}
658
- WHERE app_name = ? AND user_id = ?
659
- ORDER BY update_time DESC
653
+ Uses composite index on (app_name, user_id) when user_id is provided.
660
654
  """
655
+ if user_id is None:
656
+ sql = f"""
657
+ SELECT id, app_name, user_id, state, create_time, update_time
658
+ FROM {self._session_table}
659
+ WHERE app_name = ?
660
+ ORDER BY update_time DESC
661
+ """
662
+ params: tuple[str, ...] = (app_name,)
663
+ else:
664
+ sql = f"""
665
+ SELECT id, app_name, user_id, state, create_time, update_time
666
+ FROM {self._session_table}
667
+ WHERE app_name = ? AND user_id = ?
668
+ ORDER BY update_time DESC
669
+ """
670
+ params = (app_name, user_id)
661
671
 
662
672
  try:
663
673
  with self._config.provide_connection() as conn:
664
674
  cursor = conn.cursor()
665
675
  try:
666
- cursor.execute(sql, (app_name, user_id))
676
+ cursor.execute(sql, params)
667
677
  rows = cursor.fetchall()
668
678
 
669
679
  return [
@@ -300,22 +300,44 @@ class AdbcDataDictionary(SyncDataDictionaryBase):
300
300
  for row in result.data or []
301
301
  ]
302
302
 
303
+ if dialect == "postgres":
304
+ schema_name = schema or "public"
305
+ sql = """
306
+ SELECT
307
+ a.attname::text AS column_name,
308
+ pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
309
+ CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
310
+ pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default
311
+ FROM pg_catalog.pg_attribute a
312
+ JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
313
+ JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
314
+ LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
315
+ WHERE c.relname = ?
316
+ AND n.nspname = ?
317
+ AND a.attnum > 0
318
+ AND NOT a.attisdropped
319
+ ORDER BY a.attnum
320
+ """
321
+ result = adbc_driver.execute(sql, (table, schema_name))
322
+ return result.data or []
323
+
303
324
  if schema:
304
- sql = f"""
325
+ sql = """
305
326
  SELECT column_name, data_type, is_nullable, column_default
306
327
  FROM information_schema.columns
307
- WHERE table_name = '{table}' AND table_schema = '{schema}'
328
+ WHERE table_name = ? AND table_schema = ?
308
329
  ORDER BY ordinal_position
309
330
  """
331
+ result = adbc_driver.execute(sql, (table, schema))
310
332
  else:
311
- sql = f"""
333
+ sql = """
312
334
  SELECT column_name, data_type, is_nullable, column_default
313
335
  FROM information_schema.columns
314
- WHERE table_name = '{table}'
336
+ WHERE table_name = ?
315
337
  ORDER BY ordinal_position
316
338
  """
339
+ result = adbc_driver.execute(sql, (table,))
317
340
 
318
- result = adbc_driver.execute(sql)
319
341
  return result.data or []
320
342
 
321
343
  def list_available_features(self) -> "list[str]":
@@ -15,6 +15,7 @@ from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
15
15
  from sqlspec.adapters.adbc.type_converter import ADBCTypeConverter
16
16
  from sqlspec.core.cache import get_cache_config
17
17
  from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
18
+ from sqlspec.core.result import create_arrow_result
18
19
  from sqlspec.core.statement import SQL, StatementConfig
19
20
  from sqlspec.driver import SyncDriverAdapterBase
20
21
  from sqlspec.exceptions import (
@@ -23,7 +24,6 @@ from sqlspec.exceptions import (
23
24
  DataError,
24
25
  ForeignKeyViolationError,
25
26
  IntegrityError,
26
- MissingDependencyError,
27
27
  NotNullViolationError,
28
28
  SQLParsingError,
29
29
  SQLSpecError,
@@ -32,6 +32,7 @@ from sqlspec.exceptions import (
32
32
  )
33
33
  from sqlspec.typing import Empty
34
34
  from sqlspec.utils.logging import get_logger
35
+ from sqlspec.utils.module_loader import ensure_pyarrow
35
36
 
36
37
  if TYPE_CHECKING:
37
38
  from contextlib import AbstractContextManager
@@ -39,9 +40,12 @@ if TYPE_CHECKING:
39
40
  from adbc_driver_manager.dbapi import Cursor
40
41
 
41
42
  from sqlspec.adapters.adbc._types import AdbcConnection
42
- from sqlspec.core.result import SQLResult
43
+ from sqlspec.builder import QueryBuilder
44
+ from sqlspec.core import Statement, StatementFilter
45
+ from sqlspec.core.result import ArrowResult, SQLResult
43
46
  from sqlspec.driver import ExecutionResult
44
47
  from sqlspec.driver._sync import SyncDataDictionaryBase
48
+ from sqlspec.typing import StatementParameters
45
49
 
46
50
  __all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config")
47
51
 
@@ -507,18 +511,6 @@ class AdbcDriver(SyncDriverAdapterBase):
507
511
  self.dialect = statement_config.dialect
508
512
  self._data_dictionary: SyncDataDictionaryBase | None = None
509
513
 
510
- @staticmethod
511
- def _ensure_pyarrow_installed() -> None:
512
- """Ensure PyArrow is installed.
513
-
514
- Raises:
515
- MissingDependencyError: If PyArrow is not installed
516
- """
517
- from sqlspec.typing import PYARROW_INSTALLED
518
-
519
- if not PYARROW_INSTALLED:
520
- raise MissingDependencyError(package="pyarrow", install_package="arrow")
521
-
522
514
  @staticmethod
523
515
  def _get_dialect(connection: "AdbcConnection") -> str:
524
516
  """Detect database dialect from connection information.
@@ -863,3 +855,80 @@ class AdbcDriver(SyncDriverAdapterBase):
863
855
  if self._data_dictionary is None:
864
856
  self._data_dictionary = AdbcDataDictionary()
865
857
  return self._data_dictionary
858
+
859
+ def select_to_arrow(
860
+ self,
861
+ statement: "Statement | QueryBuilder",
862
+ /,
863
+ *parameters: "StatementParameters | StatementFilter",
864
+ statement_config: "StatementConfig | None" = None,
865
+ return_format: str = "table",
866
+ native_only: bool = False,
867
+ batch_size: int | None = None,
868
+ arrow_schema: Any = None,
869
+ **kwargs: Any,
870
+ ) -> "ArrowResult":
871
+ """Execute query and return results as Apache Arrow (ADBC native path).
872
+
873
+ ADBC provides zero-copy Arrow support via cursor.fetch_arrow_table().
874
+ This is 5-10x faster than the conversion path for large datasets.
875
+
876
+ Args:
877
+ statement: SQL statement, string, or QueryBuilder
878
+ *parameters: Query parameters or filters
879
+ statement_config: Optional statement configuration override
880
+ return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch
881
+ native_only: Ignored for ADBC (always uses native path)
882
+ batch_size: Batch size hint (for future streaming implementation)
883
+ arrow_schema: Optional pyarrow.Schema for type casting
884
+ **kwargs: Additional keyword arguments
885
+
886
+ Returns:
887
+ ArrowResult with native Arrow data
888
+
889
+ Raises:
890
+ MissingDependencyError: If pyarrow not installed
891
+ SQLExecutionError: If query execution fails
892
+
893
+ Example:
894
+ >>> result = driver.select_to_arrow(
895
+ ... "SELECT * FROM users WHERE age > $1", 18
896
+ ... )
897
+ >>> df = result.to_pandas() # Fast zero-copy conversion
898
+ """
899
+ ensure_pyarrow()
900
+
901
+ import pyarrow as pa
902
+
903
+ # Prepare statement
904
+ config = statement_config or self.statement_config
905
+ prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
906
+
907
+ # Use ADBC cursor for native Arrow
908
+ with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions():
909
+ if cursor is None:
910
+ msg = "Failed to create cursor"
911
+ raise DatabaseConnectionError(msg)
912
+
913
+ # Get compiled SQL and parameters
914
+ sql, driver_params = self._get_compiled_sql(prepared_statement, config)
915
+
916
+ # Execute query
917
+ cursor.execute(sql, driver_params or ())
918
+
919
+ # Fetch as Arrow table (zero-copy!)
920
+ arrow_table = cursor.fetch_arrow_table()
921
+
922
+ # Apply schema casting if requested
923
+ if arrow_schema is not None:
924
+ arrow_table = arrow_table.cast(arrow_schema)
925
+
926
+ # Convert to batch if requested
927
+ if return_format == "batch":
928
+ batches = arrow_table.to_batches()
929
+ arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
930
+ else:
931
+ arrow_data = arrow_table
932
+
933
+ # Create ArrowResult
934
+ return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
@@ -136,7 +136,7 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
136
136
  """
137
137
  super().__init__(config)
138
138
 
139
- def _get_create_sessions_table_sql(self) -> str:
139
+ async def _get_create_sessions_table_sql(self) -> str:
140
140
  """Get SQLite CREATE TABLE SQL for sessions.
141
141
 
142
142
  Returns:
@@ -163,7 +163,7 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
163
163
  ON {self._session_table}(update_time DESC);
164
164
  """
165
165
 
166
- def _get_create_events_table_sql(self) -> str:
166
+ async def _get_create_events_table_sql(self) -> str:
167
167
  """Get SQLite CREATE TABLE SQL for events.
168
168
 
169
169
  Returns:
@@ -228,11 +228,10 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
228
228
 
229
229
  async def create_tables(self) -> None:
230
230
  """Create both sessions and events tables if they don't exist."""
231
- async with self._config.provide_connection() as conn:
232
- await self._enable_foreign_keys(conn)
233
- await conn.executescript(self._get_create_sessions_table_sql())
234
- await conn.executescript(self._get_create_events_table_sql())
235
- await conn.commit()
231
+ async with self._config.provide_session() as driver:
232
+ await self._enable_foreign_keys(driver.connection)
233
+ await driver.execute_script(await self._get_create_sessions_table_sql())
234
+ await driver.execute_script(await self._get_create_events_table_sql())
236
235
  logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
237
236
 
238
237
  async def create_session(
@@ -343,29 +342,39 @@ class AiosqliteADKStore(BaseAsyncADKStore["AiosqliteConfig"]):
343
342
  await conn.execute(sql, (state_json, now_julian, session_id))
344
343
  await conn.commit()
345
344
 
346
- async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
347
- """List all sessions for a user in an app.
345
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
346
+ """List sessions for an app, optionally filtered by user.
348
347
 
349
348
  Args:
350
349
  app_name: Application name.
351
- user_id: User identifier.
350
+ user_id: User identifier. If None, lists all sessions for the app.
352
351
 
353
352
  Returns:
354
353
  List of session records ordered by update_time DESC.
355
354
 
356
355
  Notes:
357
- Uses composite index on (app_name, user_id).
358
- """
359
- sql = f"""
360
- SELECT id, app_name, user_id, state, create_time, update_time
361
- FROM {self._session_table}
362
- WHERE app_name = ? AND user_id = ?
363
- ORDER BY update_time DESC
356
+ Uses composite index on (app_name, user_id) when user_id is provided.
364
357
  """
358
+ if user_id is None:
359
+ sql = f"""
360
+ SELECT id, app_name, user_id, state, create_time, update_time
361
+ FROM {self._session_table}
362
+ WHERE app_name = ?
363
+ ORDER BY update_time DESC
364
+ """
365
+ params: tuple[str, ...] = (app_name,)
366
+ else:
367
+ sql = f"""
368
+ SELECT id, app_name, user_id, state, create_time, update_time
369
+ FROM {self._session_table}
370
+ WHERE app_name = ? AND user_id = ?
371
+ ORDER BY update_time DESC
372
+ """
373
+ params = (app_name, user_id)
365
374
 
366
375
  async with self._config.provide_connection() as conn:
367
376
  await self._enable_foreign_keys(conn)
368
- cursor = await conn.execute(sql, (app_name, user_id))
377
+ cursor = await conn.execute(sql, params)
369
378
  rows = await cursor.fetchall()
370
379
 
371
380
  return [
@@ -106,7 +106,7 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
106
106
 
107
107
  return (col_def, fk_constraint)
108
108
 
109
- def _get_create_sessions_table_sql(self) -> str:
109
+ async def _get_create_sessions_table_sql(self) -> str:
110
110
  """Get MySQL CREATE TABLE SQL for sessions.
111
111
 
112
112
  Returns:
@@ -145,7 +145,7 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
145
145
  ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
146
146
  """
147
147
 
148
- def _get_create_events_table_sql(self) -> str:
148
+ async def _get_create_events_table_sql(self) -> str:
149
149
  """Get MySQL CREATE TABLE SQL for events.
150
150
 
151
151
  Returns:
@@ -199,9 +199,9 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
199
199
 
200
200
  async def create_tables(self) -> None:
201
201
  """Create both sessions and events tables if they don't exist."""
202
- async with self._config.provide_connection() as conn, conn.cursor() as cursor:
203
- await cursor.execute(self._get_create_sessions_table_sql())
204
- await cursor.execute(self._get_create_events_table_sql())
202
+ async with self._config.provide_session() as driver:
203
+ await driver.execute_script(await self._get_create_sessions_table_sql())
204
+ await driver.execute_script(await self._get_create_events_table_sql())
205
205
  logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
206
206
 
207
207
  async def create_session(
@@ -326,29 +326,39 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
326
326
  await cursor.execute(sql, (session_id,))
327
327
  await conn.commit()
328
328
 
329
- async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
330
- """List all sessions for a user in an app.
329
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
330
+ """List sessions for an app, optionally filtered by user.
331
331
 
332
332
  Args:
333
333
  app_name: Application name.
334
- user_id: User identifier.
334
+ user_id: User identifier. If None, lists all sessions for the app.
335
335
 
336
336
  Returns:
337
337
  List of session records ordered by update_time DESC.
338
338
 
339
339
  Notes:
340
- Uses composite index on (app_name, user_id).
341
- """
342
- sql = f"""
343
- SELECT id, app_name, user_id, state, create_time, update_time
344
- FROM {self._session_table}
345
- WHERE app_name = %s AND user_id = %s
346
- ORDER BY update_time DESC
340
+ Uses composite index on (app_name, user_id) when user_id is provided.
347
341
  """
342
+ if user_id is None:
343
+ sql = f"""
344
+ SELECT id, app_name, user_id, state, create_time, update_time
345
+ FROM {self._session_table}
346
+ WHERE app_name = %s
347
+ ORDER BY update_time DESC
348
+ """
349
+ params: tuple[str, ...] = (app_name,)
350
+ else:
351
+ sql = f"""
352
+ SELECT id, app_name, user_id, state, create_time, update_time
353
+ FROM {self._session_table}
354
+ WHERE app_name = %s AND user_id = %s
355
+ ORDER BY update_time DESC
356
+ """
357
+ params = (app_name, user_id)
348
358
 
349
359
  try:
350
360
  async with self._config.provide_connection() as conn, conn.cursor() as cursor:
351
- await cursor.execute(sql, (app_name, user_id))
361
+ await cursor.execute(sql, params)
352
362
  rows = await cursor.fetchall()
353
363
 
354
364
  return [
@@ -84,7 +84,7 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
84
84
  """
85
85
  super().__init__(config)
86
86
 
87
- def _get_create_sessions_table_sql(self) -> str:
87
+ async def _get_create_sessions_table_sql(self) -> str:
88
88
  """Get PostgreSQL CREATE TABLE SQL for sessions.
89
89
 
90
90
  Returns:
@@ -125,7 +125,7 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
125
125
  WHERE state != '{{}}'::jsonb;
126
126
  """
127
127
 
128
- def _get_create_events_table_sql(self) -> str:
128
+ async def _get_create_events_table_sql(self) -> str:
129
129
  """Get PostgreSQL CREATE TABLE SQL for events.
130
130
 
131
131
  Returns:
@@ -181,9 +181,9 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
181
181
 
182
182
  async def create_tables(self) -> None:
183
183
  """Create both sessions and events tables if they don't exist."""
184
- async with self.config.provide_connection() as conn:
185
- await conn.execute(self._get_create_sessions_table_sql())
186
- await conn.execute(self._get_create_events_table_sql())
184
+ async with self.config.provide_session() as driver:
185
+ await driver.execute_script(await self._get_create_sessions_table_sql())
186
+ await driver.execute_script(await self._get_create_events_table_sql())
187
187
  logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
188
188
 
189
189
  async def create_session(
@@ -294,29 +294,39 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
294
294
  async with self.config.provide_connection() as conn:
295
295
  await conn.execute(sql, session_id)
296
296
 
297
- async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
298
- """List all sessions for a user in an app.
297
+ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
298
+ """List sessions for an app, optionally filtered by user.
299
299
 
300
300
  Args:
301
301
  app_name: Application name.
302
- user_id: User identifier.
302
+ user_id: User identifier. If None, lists all sessions for the app.
303
303
 
304
304
  Returns:
305
305
  List of session records ordered by update_time DESC.
306
306
 
307
307
  Notes:
308
- Uses composite index on (app_name, user_id).
309
- """
310
- sql = f"""
311
- SELECT id, app_name, user_id, state, create_time, update_time
312
- FROM {self._session_table}
313
- WHERE app_name = $1 AND user_id = $2
314
- ORDER BY update_time DESC
308
+ Uses composite index on (app_name, user_id) when user_id is provided.
315
309
  """
310
+ if user_id is None:
311
+ sql = f"""
312
+ SELECT id, app_name, user_id, state, create_time, update_time
313
+ FROM {self._session_table}
314
+ WHERE app_name = $1
315
+ ORDER BY update_time DESC
316
+ """
317
+ params = [app_name]
318
+ else:
319
+ sql = f"""
320
+ SELECT id, app_name, user_id, state, create_time, update_time
321
+ FROM {self._session_table}
322
+ WHERE app_name = $1 AND user_id = $2
323
+ ORDER BY update_time DESC
324
+ """
325
+ params = [app_name, user_id]
316
326
 
317
327
  try:
318
328
  async with self.config.provide_connection() as conn:
319
- rows = await conn.fetch(sql, app_name, user_id)
329
+ rows = await conn.fetch(sql, *params)
320
330
 
321
331
  return [
322
332
  SessionRecord(