sqlspec 0.26.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +7 -15
- sqlspec/_serialization.py +55 -25
- sqlspec/_typing.py +62 -52
- sqlspec/adapters/adbc/_types.py +1 -1
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +870 -0
- sqlspec/adapters/adbc/config.py +62 -12
- sqlspec/adapters/adbc/data_dictionary.py +52 -2
- sqlspec/adapters/adbc/driver.py +144 -45
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +44 -50
- sqlspec/adapters/aiosqlite/_types.py +1 -1
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +527 -0
- sqlspec/adapters/aiosqlite/config.py +86 -16
- sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
- sqlspec/adapters/aiosqlite/driver.py +127 -38
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +7 -7
- sqlspec/adapters/asyncmy/__init__.py +7 -1
- sqlspec/adapters/asyncmy/_types.py +1 -1
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +493 -0
- sqlspec/adapters/asyncmy/config.py +59 -17
- sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
- sqlspec/adapters/asyncmy/driver.py +293 -62
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +2 -1
- sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
- sqlspec/adapters/asyncpg/_types.py +11 -7
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +450 -0
- sqlspec/adapters/asyncpg/config.py +57 -36
- sqlspec/adapters/asyncpg/data_dictionary.py +41 -2
- sqlspec/adapters/asyncpg/driver.py +153 -23
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/_types.py +1 -1
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +576 -0
- sqlspec/adapters/bigquery/config.py +25 -11
- sqlspec/adapters/bigquery/data_dictionary.py +42 -2
- sqlspec/adapters/bigquery/driver.py +352 -144
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +55 -23
- sqlspec/adapters/duckdb/_types.py +2 -2
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +553 -0
- sqlspec/adapters/duckdb/config.py +79 -21
- sqlspec/adapters/duckdb/data_dictionary.py +41 -2
- sqlspec/adapters/duckdb/driver.py +138 -43
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +5 -5
- sqlspec/adapters/duckdb/type_converter.py +51 -21
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +20 -2
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1745 -0
- sqlspec/adapters/oracledb/config.py +120 -36
- sqlspec/adapters/oracledb/data_dictionary.py +87 -20
- sqlspec/adapters/oracledb/driver.py +292 -84
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +767 -0
- sqlspec/adapters/oracledb/migrations.py +316 -25
- sqlspec/adapters/oracledb/type_converter.py +91 -16
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +2 -1
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +482 -0
- sqlspec/adapters/psqlpy/config.py +45 -19
- sqlspec/adapters/psqlpy/data_dictionary.py +41 -2
- sqlspec/adapters/psqlpy/driver.py +101 -31
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +40 -11
- sqlspec/adapters/psycopg/_type_handlers.py +80 -0
- sqlspec/adapters/psycopg/_types.py +2 -1
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +944 -0
- sqlspec/adapters/psycopg/config.py +65 -37
- sqlspec/adapters/psycopg/data_dictionary.py +77 -3
- sqlspec/adapters/psycopg/driver.py +200 -78
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/sqlite/__init__.py +2 -1
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +1 -1
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +572 -0
- sqlspec/adapters/sqlite/config.py +85 -16
- sqlspec/adapters/sqlite/data_dictionary.py +34 -2
- sqlspec/adapters/sqlite/driver.py +120 -52
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +5 -5
- sqlspec/base.py +45 -26
- sqlspec/builder/__init__.py +73 -4
- sqlspec/builder/_base.py +91 -58
- sqlspec/builder/_column.py +5 -5
- sqlspec/builder/_ddl.py +98 -89
- sqlspec/builder/_delete.py +5 -4
- sqlspec/builder/_dml.py +388 -0
- sqlspec/{_sql.py → builder/_factory.py} +41 -44
- sqlspec/builder/_insert.py +5 -82
- sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
- sqlspec/builder/_merge.py +446 -11
- sqlspec/builder/_parsing_utils.py +9 -11
- sqlspec/builder/_select.py +1313 -25
- sqlspec/builder/_update.py +11 -42
- sqlspec/cli.py +76 -69
- sqlspec/config.py +231 -60
- sqlspec/core/__init__.py +5 -4
- sqlspec/core/cache.py +18 -18
- sqlspec/core/compiler.py +6 -8
- sqlspec/core/filters.py +37 -37
- sqlspec/core/hashing.py +9 -9
- sqlspec/core/parameters.py +76 -45
- sqlspec/core/result.py +102 -46
- sqlspec/core/splitter.py +16 -17
- sqlspec/core/statement.py +32 -31
- sqlspec/core/type_conversion.py +3 -2
- sqlspec/driver/__init__.py +1 -3
- sqlspec/driver/_async.py +95 -161
- sqlspec/driver/_common.py +133 -80
- sqlspec/driver/_sync.py +95 -162
- sqlspec/driver/mixins/_result_tools.py +20 -236
- sqlspec/driver/mixins/_sql_translator.py +4 -4
- sqlspec/exceptions.py +70 -7
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/adapter.py +73 -53
- sqlspec/extensions/litestar/__init__.py +21 -4
- sqlspec/extensions/litestar/cli.py +54 -10
- sqlspec/extensions/litestar/config.py +59 -266
- sqlspec/extensions/litestar/handlers.py +46 -17
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +324 -223
- sqlspec/extensions/litestar/providers.py +25 -25
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/loader.py +30 -49
- sqlspec/migrations/base.py +200 -76
- sqlspec/migrations/commands.py +591 -62
- sqlspec/migrations/context.py +6 -9
- sqlspec/migrations/fix.py +199 -0
- sqlspec/migrations/loaders.py +47 -19
- sqlspec/migrations/runner.py +241 -75
- sqlspec/migrations/tracker.py +237 -21
- sqlspec/migrations/utils.py +51 -3
- sqlspec/migrations/validation.py +177 -0
- sqlspec/protocols.py +66 -36
- sqlspec/storage/_utils.py +98 -0
- sqlspec/storage/backends/fsspec.py +134 -106
- sqlspec/storage/backends/local.py +78 -51
- sqlspec/storage/backends/obstore.py +278 -162
- sqlspec/storage/registry.py +75 -39
- sqlspec/typing.py +14 -84
- sqlspec/utils/config_resolver.py +6 -6
- sqlspec/utils/correlation.py +4 -5
- sqlspec/utils/data_transformation.py +3 -2
- sqlspec/utils/deprecation.py +9 -8
- sqlspec/utils/fixtures.py +4 -4
- sqlspec/utils/logging.py +46 -6
- sqlspec/utils/module_loader.py +2 -2
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +3 -3
- sqlspec/utils/sync_tools.py +21 -17
- sqlspec/utils/text.py +1 -2
- sqlspec/utils/type_guards.py +111 -20
- sqlspec/utils/version.py +433 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.27.0.dist-info}/METADATA +40 -21
- sqlspec-0.27.0.dist-info/RECORD +207 -0
- sqlspec/builder/mixins/__init__.py +0 -55
- sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
- sqlspec/builder/mixins/_delete_operations.py +0 -50
- sqlspec/builder/mixins/_insert_operations.py +0 -282
- sqlspec/builder/mixins/_merge_operations.py +0 -698
- sqlspec/builder/mixins/_order_limit_operations.py +0 -145
- sqlspec/builder/mixins/_pivot_operations.py +0 -157
- sqlspec/builder/mixins/_select_operations.py +0 -930
- sqlspec/builder/mixins/_update_operations.py +0 -199
- sqlspec/builder/mixins/_where_clause.py +0 -1298
- sqlspec-0.26.0.dist-info/RECORD +0 -157
- sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
- {sqlspec-0.26.0.dist-info → sqlspec-0.27.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.26.0.dist-info → sqlspec-0.27.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,576 @@
|
|
|
1
|
+
"""BigQuery ADK store for Google Agent Development Kit session/event storage."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from google.cloud.bigquery import QueryJobConfig, ScalarQueryParameter
|
|
7
|
+
|
|
8
|
+
from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
|
|
9
|
+
from sqlspec.utils.logging import get_logger
|
|
10
|
+
from sqlspec.utils.serializers import from_json, to_json
|
|
11
|
+
from sqlspec.utils.sync_tools import async_
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from sqlspec.adapters.bigquery.config import BigQueryConfig
|
|
15
|
+
|
|
16
|
+
logger = get_logger("adapters.bigquery.adk.store")
|
|
17
|
+
|
|
18
|
+
__all__ = ("BigQueryADKStore",)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BigQueryADKStore(BaseAsyncADKStore["BigQueryConfig"]):
|
|
22
|
+
"""BigQuery ADK store using synchronous BigQuery client with async wrapper.
|
|
23
|
+
|
|
24
|
+
Implements session and event storage for Google Agent Development Kit
|
|
25
|
+
using Google Cloud BigQuery. Uses BigQuery's native JSON type for state/metadata
|
|
26
|
+
storage and async_() wrapper to provide async interface.
|
|
27
|
+
|
|
28
|
+
Provides:
|
|
29
|
+
- Serverless, scalable session state management with JSON storage
|
|
30
|
+
- Event history tracking optimized for analytics
|
|
31
|
+
- Microsecond-precision timestamps with TIMESTAMP type
|
|
32
|
+
- Cost-optimized queries with partitioning and clustering
|
|
33
|
+
- Efficient JSON handling with BigQuery's JSON type
|
|
34
|
+
- Manual cascade delete pattern (no foreign key support)
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
config: BigQueryConfig with extension_config["adk"] settings.
|
|
38
|
+
|
|
39
|
+
Example:
|
|
40
|
+
from sqlspec.adapters.bigquery import BigQueryConfig
|
|
41
|
+
from sqlspec.adapters.bigquery.adk import BigQueryADKStore
|
|
42
|
+
|
|
43
|
+
config = BigQueryConfig(
|
|
44
|
+
connection_config={
|
|
45
|
+
"project": "my-project",
|
|
46
|
+
"dataset_id": "my_dataset",
|
|
47
|
+
},
|
|
48
|
+
extension_config={
|
|
49
|
+
"adk": {
|
|
50
|
+
"session_table": "my_sessions",
|
|
51
|
+
"events_table": "my_events",
|
|
52
|
+
"owner_id_column": "tenant_id INT64 NOT NULL"
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
)
|
|
56
|
+
store = BigQueryADKStore(config)
|
|
57
|
+
await store.create_tables()
|
|
58
|
+
|
|
59
|
+
Notes:
|
|
60
|
+
- JSON type for state, content, and metadata (native BigQuery JSON)
|
|
61
|
+
- BYTES for pre-serialized actions from Google ADK
|
|
62
|
+
- TIMESTAMP for timezone-aware microsecond precision
|
|
63
|
+
- Partitioned by DATE(create_time) for cost optimization
|
|
64
|
+
- Clustered by app_name, user_id for query performance
|
|
65
|
+
- Uses to_json/from_json for serialization to JSON columns
|
|
66
|
+
- BigQuery has eventual consistency - handle appropriately
|
|
67
|
+
- No true foreign keys but implements cascade delete pattern
|
|
68
|
+
- Configuration is read from config.extension_config["adk"]
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
__slots__ = ("_dataset_id",)
|
|
72
|
+
|
|
73
|
+
def __init__(self, config: "BigQueryConfig") -> None:
    """Build the store and remember which dataset qualifies its table names.

    Args:
        config: BigQueryConfig instance. Its ``connection_config`` mapping is
            consulted for an optional ``dataset_id``.

    Notes:
        ADK-specific settings live under config.extension_config["adk"]:
        - session_table: Sessions table name (default: "adk_sessions")
        - events_table: Events table name (default: "adk_events")
        - owner_id_column: Optional owner FK column DDL (default: None)
    """
    super().__init__(config)
    connection_settings = config.connection_config
    # None when no dataset is configured; table names are then left unqualified.
    self._dataset_id = connection_settings.get("dataset_id")
|
|
87
|
+
|
|
88
|
+
def _get_full_table_name(self, table_name: str) -> str:
|
|
89
|
+
"""Get fully qualified table name for BigQuery.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
table_name: Base table name.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Fully qualified table name with backticks.
|
|
96
|
+
|
|
97
|
+
Notes:
|
|
98
|
+
BigQuery requires backtick-quoted identifiers for table names.
|
|
99
|
+
Format: `project.dataset.table` or `dataset.table`
|
|
100
|
+
"""
|
|
101
|
+
if self._dataset_id:
|
|
102
|
+
return f"`{self._dataset_id}.{table_name}`"
|
|
103
|
+
return f"`{table_name}`"
|
|
104
|
+
|
|
105
|
+
def _get_create_sessions_table_sql(self) -> str:
|
|
106
|
+
"""Get BigQuery CREATE TABLE SQL for sessions.
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
SQL statement to create adk_sessions table.
|
|
110
|
+
|
|
111
|
+
Notes:
|
|
112
|
+
- STRING for IDs and names
|
|
113
|
+
- JSON type for state storage (native BigQuery JSON)
|
|
114
|
+
- TIMESTAMP for timezone-aware microsecond precision
|
|
115
|
+
- Partitioned by DATE(create_time) for cost optimization
|
|
116
|
+
- Clustered by app_name, user_id for query performance
|
|
117
|
+
- No indexes needed (BigQuery auto-optimizes)
|
|
118
|
+
- Optional owner ID column for multi-tenant scenarios
|
|
119
|
+
- Note: BigQuery doesn't enforce FK constraints
|
|
120
|
+
"""
|
|
121
|
+
owner_id_line = ""
|
|
122
|
+
if self._owner_id_column_ddl:
|
|
123
|
+
owner_id_line = f",\n {self._owner_id_column_ddl}"
|
|
124
|
+
|
|
125
|
+
table_name = self._get_full_table_name(self._session_table)
|
|
126
|
+
return f"""
|
|
127
|
+
CREATE TABLE IF NOT EXISTS {table_name} (
|
|
128
|
+
id STRING NOT NULL,
|
|
129
|
+
app_name STRING NOT NULL,
|
|
130
|
+
user_id STRING NOT NULL{owner_id_line},
|
|
131
|
+
state JSON NOT NULL,
|
|
132
|
+
create_time TIMESTAMP NOT NULL,
|
|
133
|
+
update_time TIMESTAMP NOT NULL
|
|
134
|
+
)
|
|
135
|
+
PARTITION BY DATE(create_time)
|
|
136
|
+
CLUSTER BY app_name, user_id
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
def _get_create_events_table_sql(self) -> str:
|
|
140
|
+
"""Get BigQuery CREATE TABLE SQL for events.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
SQL statement to create adk_events table.
|
|
144
|
+
|
|
145
|
+
Notes:
|
|
146
|
+
- STRING for IDs and text fields
|
|
147
|
+
- BYTES for pickled actions
|
|
148
|
+
- JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
|
|
149
|
+
- BOOL for boolean flags
|
|
150
|
+
- TIMESTAMP for timezone-aware timestamps
|
|
151
|
+
- Partitioned by DATE(timestamp) for cost optimization
|
|
152
|
+
- Clustered by session_id, timestamp for ordered retrieval
|
|
153
|
+
"""
|
|
154
|
+
table_name = self._get_full_table_name(self._events_table)
|
|
155
|
+
return f"""
|
|
156
|
+
CREATE TABLE IF NOT EXISTS {table_name} (
|
|
157
|
+
id STRING NOT NULL,
|
|
158
|
+
session_id STRING NOT NULL,
|
|
159
|
+
app_name STRING NOT NULL,
|
|
160
|
+
user_id STRING NOT NULL,
|
|
161
|
+
invocation_id STRING,
|
|
162
|
+
author STRING,
|
|
163
|
+
actions BYTES,
|
|
164
|
+
long_running_tool_ids_json JSON,
|
|
165
|
+
branch STRING,
|
|
166
|
+
timestamp TIMESTAMP NOT NULL,
|
|
167
|
+
content JSON,
|
|
168
|
+
grounding_metadata JSON,
|
|
169
|
+
custom_metadata JSON,
|
|
170
|
+
partial BOOL,
|
|
171
|
+
turn_complete BOOL,
|
|
172
|
+
interrupted BOOL,
|
|
173
|
+
error_code STRING,
|
|
174
|
+
error_message STRING
|
|
175
|
+
)
|
|
176
|
+
PARTITION BY DATE(timestamp)
|
|
177
|
+
CLUSTER BY session_id, timestamp
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
def _get_drop_tables_sql(self) -> "list[str]":
|
|
181
|
+
"""Get BigQuery DROP TABLE SQL statements.
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
List of SQL statements to drop tables.
|
|
185
|
+
|
|
186
|
+
Notes:
|
|
187
|
+
Order matters: drop events table before sessions table.
|
|
188
|
+
BigQuery uses IF EXISTS for idempotent drops.
|
|
189
|
+
"""
|
|
190
|
+
events_table = self._get_full_table_name(self._events_table)
|
|
191
|
+
sessions_table = self._get_full_table_name(self._session_table)
|
|
192
|
+
return [f"DROP TABLE IF EXISTS {events_table}", f"DROP TABLE IF EXISTS {sessions_table}"]
|
|
193
|
+
|
|
194
|
+
def _create_tables(self) -> None:
    """Synchronous implementation of create_tables.

    Runs the sessions DDL and then the events DDL on one connection, waiting
    for each query job to finish, and logs the table names at debug level.
    """
    ddl_builders = (self._get_create_sessions_table_sql, self._get_create_events_table_sql)
    with self._config.provide_connection() as client:
        for build_ddl in ddl_builders:
            # .result() blocks until the query job has completed.
            client.query(build_ddl()).result()
    logger.debug("Created BigQuery ADK tables: %s, %s", self._session_table, self._events_table)
|
|
200
|
+
|
|
201
|
+
async def create_tables(self) -> None:
    """Create the sessions and events tables if they do not exist yet.

    Awaits the synchronous ``_create_tables`` through the ``async_`` wrapper.
    """
    create_async = async_(self._create_tables)
    await create_async()
|
|
204
|
+
|
|
205
|
+
def _create_session(
    self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
    """Synchronous implementation of create_session.

    Inserts one row into the sessions table and returns the matching record.

    Args:
        session_id: Unique session identifier.
        app_name: Application name.
        user_id: User identifier.
        state: Initial session state; an empty/None state is stored as ``{}``.
        owner_id: Value for the configured owner ID column. Ignored when no
            owner ID column is configured.

    Returns:
        SessionRecord carrying the supplied values and the insert timestamp.

    Notes:
        The owner ID used to be bound as a STRING parameter unconditionally,
        which cannot be inserted into an integer owner column such as the
        ``tenant_id INT64 NOT NULL`` example in the class docstring. The
        parameter is now bound as INT64 when the configured column DDL
        declares an integer type, and as STRING (the previous behaviour)
        for every other declared type.
    """
    now = datetime.now(timezone.utc)
    state_json = to_json(state) if state else "{}"
    table_name = self._get_full_table_name(self._session_table)

    columns = ["id", "app_name", "user_id"]
    placeholders = ["@id", "@app_name", "@user_id"]
    params = [
        ScalarQueryParameter("id", "STRING", session_id),
        ScalarQueryParameter("app_name", "STRING", app_name),
        ScalarQueryParameter("user_id", "STRING", user_id),
    ]

    if self._owner_id_column_name:
        # The DDL looks like "<name> <type> [constraints...]"; the second token is the type.
        ddl_tokens = (self._owner_id_column_ddl or "").split()
        declared_type = ddl_tokens[1].upper() if len(ddl_tokens) > 1 else ""
        if declared_type in {"INT64", "INT", "INTEGER", "SMALLINT", "BIGINT", "TINYINT", "BYTEINT"}:
            owner_type = "INT64"
            owner_value = int(owner_id) if owner_id is not None else None
        else:
            owner_type = "STRING"
            owner_value = str(owner_id) if owner_id is not None else None
        columns.append(self._owner_id_column_name)
        placeholders.append("@owner_id")
        params.append(ScalarQueryParameter("owner_id", owner_type, owner_value))

    columns += ["state", "create_time", "update_time"]
    # NOTE(review): PARSE_JSON() is the documented STRING -> JSON conversion in
    # GoogleSQL; confirm JSON(@state) is accepted by the target environment.
    placeholders += ["JSON(@state)", "@create_time", "@update_time"]
    params += [
        ScalarQueryParameter("state", "STRING", state_json),
        ScalarQueryParameter("create_time", "TIMESTAMP", now),
        ScalarQueryParameter("update_time", "TIMESTAMP", now),
    ]

    column_list = ", ".join(columns)
    value_list = ", ".join(placeholders)
    sql = f"""
    INSERT INTO {table_name} ({column_list})
    VALUES ({value_list})
    """

    with self._config.provide_connection() as conn:
        job_config = QueryJobConfig(query_parameters=params)
        conn.query(sql, job_config=job_config).result()

    return SessionRecord(
        id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now
    )
|
|
251
|
+
|
|
252
|
+
async def create_session(
    self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
) -> SessionRecord:
    """Create a new session row and return its record.

    Args:
        session_id: Unique session identifier.
        app_name: Application name.
        user_id: User identifier.
        state: Initial session state.
        owner_id: Optional owner ID value for owner_id_column (if configured).

    Returns:
        Created session record.

    Notes:
        Delegates to the synchronous ``_create_session`` through ``async_``.
        Timestamps are generated in Python as timezone-aware UTC datetimes.
        State is JSON-serialized and stored in a JSON column.
        BigQuery does not enforce FK constraints; the owner column is still
        useful for JOINs.
    """
    create_async = async_(self._create_session)
    record = await create_async(session_id, app_name, user_id, state, owner_id)
    return record
|
|
274
|
+
|
|
275
|
+
def _get_session(self, session_id: str) -> "SessionRecord | None":
    """Synchronous implementation of get_session.

    Args:
        session_id: Session identifier.

    Returns:
        The matching SessionRecord, or None when no row has that ID.

    Notes:
        The state column is read with TO_JSON_STRING so the whole JSON
        document comes back as text for ``from_json``. JSON_VALUE only
        extracts scalars and yields NULL for an object, which made every
        stored state read back as an empty dict.
    """
    table_name = self._get_full_table_name(self._session_table)
    sql = f"""
    SELECT id, app_name, user_id, TO_JSON_STRING(state) as state, create_time, update_time
    FROM {table_name}
    WHERE id = @session_id
    """

    params = [ScalarQueryParameter("session_id", "STRING", session_id)]

    with self._config.provide_connection() as conn:
        job_config = QueryJobConfig(query_parameters=params)
        query_job = conn.query(sql, job_config=job_config)
        results = list(query_job.result())

    if not results:
        return None

    row = results[0]
    return SessionRecord(
        id=row.id,
        app_name=row.app_name,
        user_id=row.user_id,
        state=from_json(row.state) if row.state else {},
        create_time=row.create_time,
        update_time=row.update_time,
    )
|
|
303
|
+
|
|
304
|
+
async def get_session(self, session_id: str) -> "SessionRecord | None":
    """Look up a single session by its ID.

    Args:
        session_id: Session identifier.

    Returns:
        Session record or None if not found.

    Notes:
        Delegates to the synchronous ``_get_session`` through ``async_``.
        The stored JSON state is decoded back into a dict.
    """
    lookup_async = async_(self._get_session)
    record = await lookup_async(session_id)
    return record
|
|
318
|
+
|
|
319
|
+
def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
    """Synchronous implementation of update_session_state.

    Overwrites the state column of the session row and stamps update_time
    with the current UTC time generated in Python.

    Args:
        session_id: Session identifier.
        state: Replacement state; an empty/None state is stored as ``{}``.
    """
    updated_at = datetime.now(timezone.utc)
    serialized_state = "{}" if not state else to_json(state)

    target = self._get_full_table_name(self._session_table)
    statement = (
        "\n"
        f"UPDATE {target}\n"
        "SET state = JSON(@state), update_time = @update_time\n"
        "WHERE id = @session_id\n"
    )

    bound = [
        ScalarQueryParameter("state", "STRING", serialized_state),
        ScalarQueryParameter("update_time", "TIMESTAMP", updated_at),
        ScalarQueryParameter("session_id", "STRING", session_id),
    ]

    with self._config.provide_connection() as client:
        job_settings = QueryJobConfig(query_parameters=bound)
        # .result() blocks until the UPDATE job has completed.
        client.query(statement, job_config=job_settings).result()
|
|
340
|
+
|
|
341
|
+
async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
    """Replace the stored state of a session.

    Args:
        session_id: Session identifier.
        state: New state dictionary (replaces existing state).

    Notes:
        Delegates to the synchronous ``_update_session_state`` through
        ``async_``. The whole state document is replaced and update_time is
        set to the current UTC time generated in Python.
    """
    update_async = async_(self._update_session_state)
    await update_async(session_id, state)
|
|
353
|
+
|
|
354
|
+
def _list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
    """Synchronous implementation of list_sessions.

    Args:
        app_name: Application name.
        user_id: User identifier.

    Returns:
        Session records for the app/user pair, most recently updated first.

    Notes:
        The state column is read with TO_JSON_STRING so the whole JSON
        document comes back as text for ``from_json``. JSON_VALUE only
        extracts scalars and yields NULL for an object, which made every
        stored state read back as an empty dict.
    """
    table_name = self._get_full_table_name(self._session_table)
    sql = f"""
    SELECT id, app_name, user_id, TO_JSON_STRING(state) as state, create_time, update_time
    FROM {table_name}
    WHERE app_name = @app_name AND user_id = @user_id
    ORDER BY update_time DESC
    """

    params = [
        ScalarQueryParameter("app_name", "STRING", app_name),
        ScalarQueryParameter("user_id", "STRING", user_id),
    ]

    with self._config.provide_connection() as conn:
        job_config = QueryJobConfig(query_parameters=params)
        query_job = conn.query(sql, job_config=job_config)
        results = list(query_job.result())

    return [
        SessionRecord(
            id=row.id,
            app_name=row.app_name,
            user_id=row.user_id,
            state=from_json(row.state) if row.state else {},
            create_time=row.create_time,
            update_time=row.update_time,
        )
        for row in results
    ]
|
|
385
|
+
|
|
386
|
+
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
    """List every session belonging to a user within an app.

    Args:
        app_name: Application name.
        user_id: User identifier.

    Returns:
        List of session records ordered by update_time DESC.

    Notes:
        Delegates to the synchronous ``_list_sessions`` through ``async_``.
        The filter columns match the sessions table clustering
        (app_name, user_id).
    """
    list_async = async_(self._list_sessions)
    records = await list_async(app_name, user_id)
    return records
|
|
400
|
+
|
|
401
|
+
def _delete_session(self, session_id: str) -> None:
    """Synchronous implementation of delete_session.

    Removes the session's events first and then the session row itself,
    reusing one job configuration (and its ``@session_id`` parameter) for
    both DELETE statements.

    Args:
        session_id: Session identifier.
    """
    events_target = self._get_full_table_name(self._events_table)
    sessions_target = self._get_full_table_name(self._session_table)

    bound = [ScalarQueryParameter("session_id", "STRING", session_id)]
    # Order matters: child rows (events) go before the parent row (session).
    cascade = (
        f"DELETE FROM {events_target} WHERE session_id = @session_id",
        f"DELETE FROM {sessions_target} WHERE id = @session_id",
    )

    with self._config.provide_connection() as client:
        job_settings = QueryJobConfig(query_parameters=bound)
        for statement in cascade:
            client.query(statement, job_config=job_settings).result()
|
|
412
|
+
|
|
413
|
+
async def delete_session(self, session_id: str) -> None:
    """Delete a session together with all of its events.

    Args:
        session_id: Session identifier.

    Notes:
        Delegates to the synchronous ``_delete_session`` through ``async_``.
        There is no foreign-key cascade, so events are deleted explicitly
        before the session row, as two sequential DELETE statements.
    """
    delete_async = async_(self._delete_session)
    await delete_async(session_id)
|
|
424
|
+
|
|
425
|
+
def _append_event(self, event_record: EventRecord) -> None:
    """Synchronous implementation of append_event.

    Inserts one row into the events table.

    Args:
        event_record: Event to store. ``id``, ``session_id``, ``app_name``,
            ``user_id`` and ``timestamp`` are read as required keys; every
            other field is optional and stored as NULL when absent.

    Notes:
        - ``content``, ``grounding_metadata`` and ``custom_metadata`` are
          serialized with ``to_json`` and wrapped in ``JSON(...)`` only when
          they hold a truthy value; otherwise a literal NULL is written and
          no parameter is bound for that column.
        - ``actions`` is bound as BYTES and passed through unchanged. The
          BigQuery client base64-encodes ``bytes`` parameter values itself;
          the previous latin-1 decode turned the payload into a ``str`` that
          was sent verbatim and is not valid base64.
    """
    table_name = self._get_full_table_name(self._events_table)

    json_columns = ("content", "grounding_metadata", "custom_metadata")
    json_payloads = {
        column: to_json(event_record.get(column)) for column in json_columns if event_record.get(column)
    }
    json_values = ",\n        ".join(
        f"JSON(@{column})" if column in json_payloads else "NULL" for column in json_columns
    )

    # NOTE(review): long_running_tool_ids_json is bound as STRING but the column
    # is declared JSON; confirm BigQuery accepts this without an explicit conversion.
    sql = f"""
    INSERT INTO {table_name} (
        id, session_id, app_name, user_id, invocation_id, author, actions,
        long_running_tool_ids_json, branch, timestamp, content,
        grounding_metadata, custom_metadata, partial, turn_complete,
        interrupted, error_code, error_message
    ) VALUES (
        @id, @session_id, @app_name, @user_id, @invocation_id, @author, @actions,
        @long_running_tool_ids_json, @branch, @timestamp,
        {json_values},
        @partial, @turn_complete, @interrupted, @error_code, @error_message
    )
    """

    params = [
        ScalarQueryParameter("id", "STRING", event_record["id"]),
        ScalarQueryParameter("session_id", "STRING", event_record["session_id"]),
        ScalarQueryParameter("app_name", "STRING", event_record["app_name"]),
        ScalarQueryParameter("user_id", "STRING", event_record["user_id"]),
        ScalarQueryParameter("invocation_id", "STRING", event_record.get("invocation_id")),
        ScalarQueryParameter("author", "STRING", event_record.get("author")),
        ScalarQueryParameter("actions", "BYTES", event_record.get("actions")),
        ScalarQueryParameter(
            "long_running_tool_ids_json", "STRING", event_record.get("long_running_tool_ids_json")
        ),
        ScalarQueryParameter("branch", "STRING", event_record.get("branch")),
        ScalarQueryParameter("timestamp", "TIMESTAMP", event_record["timestamp"]),
        ScalarQueryParameter("partial", "BOOL", event_record.get("partial")),
        ScalarQueryParameter("turn_complete", "BOOL", event_record.get("turn_complete")),
        ScalarQueryParameter("interrupted", "BOOL", event_record.get("interrupted")),
        ScalarQueryParameter("error_code", "STRING", event_record.get("error_code")),
        ScalarQueryParameter("error_message", "STRING", event_record.get("error_message")),
    ]
    # Bind a parameter only for the JSON columns that are referenced in the SQL text.
    params.extend(ScalarQueryParameter(column, "STRING", payload) for column, payload in json_payloads.items())

    with self._config.provide_connection() as conn:
        job_config = QueryJobConfig(query_parameters=params)
        conn.query(sql, job_config=job_config).result()
|
|
488
|
+
|
|
489
|
+
async def append_event(self, event_record: EventRecord) -> None:
    """Store one event for a session.

    Args:
        event_record: Event record to store.

    Notes:
        Delegates to the synchronous ``_append_event`` through ``async_``.
        The event timestamp is bound as a TIMESTAMP parameter, JSON fields
        are serialized to STRING and converted to JSON in the statement,
        and the boolean flags are bound as BOOL.
    """
    append_async = async_(self._append_event)
    await append_async(event_record)
|
|
501
|
+
|
|
502
|
+
def _get_events(
    self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
    """Synchronous implementation of get_events.

    Builds a parameterized SELECT against the events table, runs it on a
    connection obtained from ``self._config.provide_connection()`` and maps
    every returned row to an ``EventRecord``.

    Args:
        session_id: Session whose events are fetched (bound as ``@session_id``).
        after_timestamp: When not ``None``, only rows whose ``timestamp`` is
            strictly greater than this value are returned.
        limit: Maximum number of rows to return. Any falsy value (``None`` or
            ``0``) means no ``LIMIT`` clause is emitted.

    Returns:
        Event records ordered by ``timestamp`` ascending.

    Raises:
        TypeError: If ``limit`` is truthy but not convertible to ``int``.
        ValueError: If ``limit`` is a string that is not an integer literal.
    """
    table_name = self._get_full_table_name(self._events_table)

    where_clauses = ["session_id = @session_id"]
    params: list[ScalarQueryParameter] = [ScalarQueryParameter("session_id", "STRING", session_id)]

    if after_timestamp is not None:
        where_clauses.append("timestamp > @after_timestamp")
        params.append(ScalarQueryParameter("after_timestamp", "TIMESTAMP", after_timestamp))

    where_clause = " AND ".join(where_clauses)
    # LIMIT is spliced into the SQL text instead of being bound like the other
    # values, so coerce it to int: nothing but an integer literal can ever
    # reach the statement, even if a caller ignores the type hint.
    limit_clause = f" LIMIT {int(limit)}" if limit else ""

    # NOTE(review): JSON_VALUE is documented to return NULL when the selected
    # JSON value is not a scalar. This read path therefore presumably relies on
    # the JSON columns holding a JSON *string* that wraps the serialized
    # payload - confirm against the INSERT statement.
    sql = f"""
    SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
           long_running_tool_ids_json, branch, timestamp,
           JSON_VALUE(content) as content,
           JSON_VALUE(grounding_metadata) as grounding_metadata,
           JSON_VALUE(custom_metadata) as custom_metadata,
           partial, turn_complete, interrupted, error_code, error_message
    FROM {table_name}
    WHERE {where_clause}
    ORDER BY timestamp ASC{limit_clause}
    """

    with self._config.provide_connection() as conn:
        job_config = QueryJobConfig(query_parameters=params)
        query_job = conn.query(sql, job_config=job_config)
        # Materialize inside the context manager so all rows are fetched
        # while the connection is still provided.
        results = list(query_job.result())

    return [
        EventRecord(
            id=row.id,
            session_id=row.session_id,
            app_name=row.app_name,
            user_id=row.user_id,
            invocation_id=row.invocation_id,
            author=row.author,
            # Falsy (NULL or empty) actions are normalized to empty bytes.
            actions=bytes(row.actions) if row.actions else b"",
            long_running_tool_ids_json=row.long_running_tool_ids_json,
            branch=row.branch,
            timestamp=row.timestamp,
            # JSON columns arrive as strings; falsy values become None.
            content=from_json(row.content) if row.content else None,
            grounding_metadata=from_json(row.grounding_metadata) if row.grounding_metadata else None,
            custom_metadata=from_json(row.custom_metadata) if row.custom_metadata else None,
            partial=row.partial,
            turn_complete=row.turn_complete,
            interrupted=row.interrupted,
            error_code=row.error_code,
            error_message=row.error_message,
        )
        for row in results
    ]
|
|
558
|
+
|
|
559
|
+
async def get_events(
    self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
) -> "list[EventRecord]":
    """Fetch the events recorded for a session.

    Delegates to the synchronous ``_get_events`` through the ``async_``
    wrapper and returns its result.

    Args:
        session_id: Session identifier.
        after_timestamp: Restrict the result to events later than this time.
        limit: Upper bound on the number of events returned.

    Returns:
        Event records sorted by timestamp in ascending order.

    Notes:
        Retrieval is expected to benefit from clustering on
        (session_id, timestamp).
        JSON fields are parsed and BYTES actions are converted to bytes.
    """
    fetch_async = async_(self._get_events)
    events = await fetch_async(session_id, after_timestamp, limit)
    return events
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import contextlib
|
|
4
4
|
import logging
|
|
5
|
-
from typing import TYPE_CHECKING, Any,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
|
|
6
6
|
|
|
7
7
|
from google.cloud.bigquery import LoadJobConfig, QueryJobConfig
|
|
8
8
|
from typing_extensions import NotRequired
|
|
@@ -14,7 +14,7 @@ from sqlspec.exceptions import ImproperConfigurationError
|
|
|
14
14
|
from sqlspec.typing import Empty
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
|
-
from collections.abc import Generator
|
|
17
|
+
from collections.abc import Callable, Generator
|
|
18
18
|
|
|
19
19
|
from google.api_core.client_info import ClientInfo
|
|
20
20
|
from google.api_core.client_options import ClientOptions
|
|
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|
|
26
26
|
logger = logging.getLogger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
class BigQueryConnectionParams(TypedDict
|
|
29
|
+
class BigQueryConnectionParams(TypedDict):
|
|
30
30
|
"""Standard BigQuery connection parameters.
|
|
31
31
|
|
|
32
32
|
Includes both official BigQuery client parameters and BigQuery-specific configuration options.
|
|
@@ -63,7 +63,7 @@ class BigQueryConnectionParams(TypedDict, total=False):
|
|
|
63
63
|
extra: NotRequired[dict[str, Any]]
|
|
64
64
|
|
|
65
65
|
|
|
66
|
-
class BigQueryDriverFeatures(TypedDict
|
|
66
|
+
class BigQueryDriverFeatures(TypedDict):
|
|
67
67
|
"""BigQuery driver-specific features configuration.
|
|
68
68
|
|
|
69
69
|
Only non-standard BigQuery client parameters that are SQLSpec-specific extensions.
|
|
@@ -73,6 +73,8 @@ class BigQueryDriverFeatures(TypedDict, total=False):
|
|
|
73
73
|
on_job_start: NotRequired["Callable[[str], None]"]
|
|
74
74
|
on_job_complete: NotRequired["Callable[[str, Any], None]"]
|
|
75
75
|
on_connection_create: NotRequired["Callable[[Any], None]"]
|
|
76
|
+
json_serializer: NotRequired["Callable[[Any], str]"]
|
|
77
|
+
enable_uuid_conversion: NotRequired[bool]
|
|
76
78
|
|
|
77
79
|
|
|
78
80
|
__all__ = ("BigQueryConfig", "BigQueryConnectionParams", "BigQueryDriverFeatures")
|
|
@@ -86,15 +88,17 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
86
88
|
|
|
87
89
|
driver_type: ClassVar[type[BigQueryDriver]] = BigQueryDriver
|
|
88
90
|
connection_type: "ClassVar[type[BigQueryConnection]]" = BigQueryConnection
|
|
91
|
+
supports_transactional_ddl: ClassVar[bool] = False
|
|
89
92
|
|
|
90
93
|
def __init__(
|
|
91
94
|
self,
|
|
92
95
|
*,
|
|
93
|
-
connection_config: "
|
|
94
|
-
migration_config:
|
|
95
|
-
statement_config: "
|
|
96
|
-
driver_features: "
|
|
97
|
-
bind_key: "
|
|
96
|
+
connection_config: "BigQueryConnectionParams | dict[str, Any] | None" = None,
|
|
97
|
+
migration_config: dict[str, Any] | None = None,
|
|
98
|
+
statement_config: "StatementConfig | None" = None,
|
|
99
|
+
driver_features: "BigQueryDriverFeatures | dict[str, Any] | None" = None,
|
|
100
|
+
bind_key: "str | None" = None,
|
|
101
|
+
extension_config: "dict[str, dict[str, Any]] | None" = None,
|
|
98
102
|
) -> None:
|
|
99
103
|
"""Initialize BigQuery configuration.
|
|
100
104
|
|
|
@@ -104,6 +108,7 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
104
108
|
statement_config: Statement configuration override
|
|
105
109
|
driver_features: BigQuery-specific driver features
|
|
106
110
|
bind_key: Optional unique identifier for this configuration
|
|
111
|
+
extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
|
|
107
112
|
"""
|
|
108
113
|
|
|
109
114
|
self.connection_config: dict[str, Any] = dict(connection_config) if connection_config else {}
|
|
@@ -113,7 +118,15 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
113
118
|
|
|
114
119
|
self.driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
|
|
115
120
|
|
|
116
|
-
|
|
121
|
+
if "enable_uuid_conversion" not in self.driver_features:
|
|
122
|
+
self.driver_features["enable_uuid_conversion"] = True
|
|
123
|
+
|
|
124
|
+
if "json_serializer" not in self.driver_features:
|
|
125
|
+
from sqlspec.utils.serializers import to_json
|
|
126
|
+
|
|
127
|
+
self.driver_features["json_serializer"] = to_json
|
|
128
|
+
|
|
129
|
+
self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance")
|
|
117
130
|
|
|
118
131
|
if "default_query_job_config" not in self.connection_config:
|
|
119
132
|
self._setup_default_job_config()
|
|
@@ -127,6 +140,7 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
127
140
|
statement_config=statement_config,
|
|
128
141
|
driver_features=self.driver_features,
|
|
129
142
|
bind_key=bind_key,
|
|
143
|
+
extension_config=extension_config,
|
|
130
144
|
)
|
|
131
145
|
|
|
132
146
|
def _setup_default_job_config(self) -> None:
|
|
@@ -215,7 +229,7 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]):
|
|
|
215
229
|
|
|
216
230
|
@contextlib.contextmanager
|
|
217
231
|
def provide_session(
|
|
218
|
-
self, *_args: Any, statement_config: "
|
|
232
|
+
self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
|
|
219
233
|
) -> "Generator[BigQueryDriver, None, None]":
|
|
220
234
|
"""Provide a BigQuery driver session context manager.
|
|
221
235
|
|