sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's advisory for this release for details.

Files changed (212)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,460 @@
1
+ """AsyncPG ADK store for Google Agent Development Kit session/event storage."""
2
+
3
+ from typing import TYPE_CHECKING, Any, Final
4
+
5
+ import asyncpg
6
+
7
+ from sqlspec.config import AsyncConfigT
8
+ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
9
+ from sqlspec.utils.logging import get_logger
10
+
11
+ if TYPE_CHECKING:
12
+ from datetime import datetime
13
+
14
# Module logger; create_tables() reports the created table names at debug level.
logger = get_logger("adapters.asyncpg.adk.store")

# Public API of this module.
__all__ = ("AsyncpgADKStore",)

# PostgreSQL SQLSTATE "42P01" (undefined_table).
# NOTE(review): not referenced in this module — the store catches
# asyncpg.exceptions.UndefinedTableError instead; confirm whether anything
# imports this constant before removing it.
POSTGRES_TABLE_NOT_FOUND_ERROR: Final = "42P01"
19
+
20
+
21
class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]):
    """AsyncPG-backed ADK store for Google Agent Development Kit sessions and events.

    Persists ADK sessions and events in PostgreSQL through the connections
    returned by ``config.provide_connection()``. Queries use PostgreSQL's
    native ``$1, $2, ...`` placeholders and asyncpg connection methods
    (``execute``, ``fetch``, ``fetchrow``) and catch asyncpg exceptions, so
    this class is specific to asyncpg connections.

    Provides:
    - Session state stored as JSONB (updates replace the whole document)
    - Event history tracking with BYTEA actions (pre-serialized by the caller)
    - Microsecond-precision timestamps with TIMESTAMPTZ
    - Foreign key from events to sessions with ON DELETE CASCADE
    - Partial GIN index on non-empty session state
    - FILLFACTOR 80 on the sessions table to leave room for HOT updates
    - Optional owner ID column (caller-supplied DDL) for multi-tenancy

    Args:
        config: PostgreSQL database config with extension_config["adk"] settings.

    Example:
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore

        config = AsyncpgConfig(
            pool_config={"dsn": "postgresql://..."},
            extension_config={
                "adk": {
                    "session_table": "my_sessions",
                    "events_table": "my_events",
                    "owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
                }
            }
        )
        store = AsyncpgADKStore(config)
        await store.create_tables()

    Notes:
        - Python dicts are passed directly as JSONB parameters and JSONB
          values are returned as-is. This relies on the JSON/JSONB codecs
          that AsyncpgConfig registers on each connection when
          ``enable_json_codecs`` is True (its default).
        - ``actions`` is stored verbatim as BYTEA; no serialization happens here.
        - Reads tolerate a missing table: ``get_session`` returns None and
          ``list_sessions``/``get_events`` return [] on UndefinedTableError.
        - Configuration is read from config.extension_config["adk"].
    """

    __slots__ = ()

    def __init__(self, config: AsyncConfigT) -> None:
        """Initialize AsyncPG ADK store.

        Args:
            config: PostgreSQL database config.

        Notes:
            Configuration is read from config.extension_config["adk"]:
            - session_table: Sessions table name (default: "adk_sessions")
            - events_table: Events table name (default: "adk_events")
            - owner_id_column: Optional owner FK column DDL (default: None)
        """
        super().__init__(config)

    @staticmethod
    def _session_from_row(row: "asyncpg.Record") -> SessionRecord:
        """Build a SessionRecord from a row that exposes the session columns by name."""
        return SessionRecord(
            id=row["id"],
            app_name=row["app_name"],
            user_id=row["user_id"],
            state=row["state"],
            create_time=row["create_time"],
            update_time=row["update_time"],
        )

    async def _get_create_sessions_table_sql(self) -> str:
        """Get PostgreSQL CREATE TABLE SQL for sessions.

        Returns:
            SQL script creating the sessions table and its indexes.

        Notes:
            - VARCHAR(128) for IDs and names
            - JSONB state column defaulting to an empty object
            - TIMESTAMPTZ create/update times defaulting to CURRENT_TIMESTAMP
            - FILLFACTOR 80 to leave room for HOT updates
            - Composite index on (app_name, user_id) for listing
            - Index on update_time DESC for recent-session queries
            - Partial GIN index on state, only for non-empty state
            - Optional owner ID column inserted verbatim from configured DDL
        """
        owner_id_line = ""
        if self._owner_id_column_ddl:
            owner_id_line = f",\n            {self._owner_id_column_ddl}"

        return f"""
        CREATE TABLE IF NOT EXISTS {self._session_table} (
            id VARCHAR(128) PRIMARY KEY,
            app_name VARCHAR(128) NOT NULL,
            user_id VARCHAR(128) NOT NULL{owner_id_line},
            state JSONB NOT NULL DEFAULT '{{}}'::jsonb,
            create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
            update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
        ) WITH (fillfactor = 80);

        CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user
            ON {self._session_table}(app_name, user_id);

        CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time
            ON {self._session_table}(update_time DESC);

        CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state
            ON {self._session_table} USING GIN (state)
            WHERE state != '{{}}'::jsonb;
        """

    async def _get_create_events_table_sql(self) -> str:
        """Get PostgreSQL CREATE TABLE SQL for events.

        Returns:
            SQL script creating the events table and its index.

        Notes:
            - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
              branch(256), error_code(256), error_message(1024)
            - BYTEA for pre-serialized actions (no size limit)
            - JSONB for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
            - BOOLEAN for partial, turn_complete, interrupted
            - Foreign key to sessions with CASCADE delete
            - Index on (session_id, timestamp ASC) for ordered event retrieval
        """
        return f"""
        CREATE TABLE IF NOT EXISTS {self._events_table} (
            id VARCHAR(128) PRIMARY KEY,
            session_id VARCHAR(128) NOT NULL,
            app_name VARCHAR(128) NOT NULL,
            user_id VARCHAR(128) NOT NULL,
            invocation_id VARCHAR(256),
            author VARCHAR(256),
            actions BYTEA,
            long_running_tool_ids_json JSONB,
            branch VARCHAR(256),
            timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
            content JSONB,
            grounding_metadata JSONB,
            custom_metadata JSONB,
            partial BOOLEAN,
            turn_complete BOOLEAN,
            interrupted BOOLEAN,
            error_code VARCHAR(256),
            error_message VARCHAR(1024),
            FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE
        );

        CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session
            ON {self._events_table}(session_id, timestamp ASC);
        """

    def _get_drop_tables_sql(self) -> "list[str]":
        """Get PostgreSQL DROP TABLE SQL statements.

        Returns:
            List of SQL statements to drop the events and sessions tables.

        Notes:
            Order matters: drop events table (child) before sessions (parent).
            PostgreSQL drops a table's indexes together with the table.
        """
        return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]

    async def create_tables(self) -> None:
        """Create both sessions and events tables if they don't exist."""
        async with self.config.provide_session() as driver:
            # Sessions first: the events table has a foreign key referencing it.
            await driver.execute_script(await self._get_create_sessions_table_sql())
            await driver.execute_script(await self._get_create_events_table_sql())
        logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)

    async def create_session(
        self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
    ) -> SessionRecord:
        """Create a new session.

        Args:
            session_id: Unique session identifier.
            app_name: Application name.
            user_id: User identifier.
            state: Initial session state.
            owner_id: Optional owner ID value for owner_id_column (if configured).

        Returns:
            Created session record, as returned by the INSERT.

        Notes:
            Uses CURRENT_TIMESTAMP for create_time and update_time.
            The inserted row is read back with RETURNING in the same statement.
            This avoids a second round trip and a window in which a concurrent
            delete could make the lookup return nothing.
            If owner_id_column is configured, owner_id is written to it (None
            becomes NULL, which fails if the configured column is NOT NULL).
        """
        returning = "RETURNING id, app_name, user_id, state, create_time, update_time"
        params: "tuple[Any, ...]"
        if self._owner_id_column_name:
            sql = f"""
            INSERT INTO {self._session_table}
            (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time)
            VALUES ($1, $2, $3, $4, $5, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            {returning}
            """
            params = (session_id, app_name, user_id, owner_id, state)
        else:
            sql = f"""
            INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time)
            VALUES ($1, $2, $3, $4, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            {returning}
            """
            params = (session_id, app_name, user_id, state)

        async with self.config.provide_connection() as conn:
            row = await conn.fetchrow(sql, *params)

        return self._session_from_row(row)

    async def get_session(self, session_id: str) -> "SessionRecord | None":
        """Get session by ID.

        Args:
            session_id: Session identifier.

        Returns:
            Session record, or None if not found or the sessions table does not exist.

        Notes:
            asyncpg returns datetime objects for TIMESTAMPTZ columns.
        """
        sql = f"""
        SELECT id, app_name, user_id, state, create_time, update_time
        FROM {self._session_table}
        WHERE id = $1
        """

        try:
            async with self.config.provide_connection() as conn:
                row = await conn.fetchrow(sql, session_id)
        except asyncpg.exceptions.UndefinedTableError:
            return None

        return None if row is None else self._session_from_row(row)

    async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
        """Update session state.

        Args:
            session_id: Session identifier.
            state: New state dictionary (replaces existing state).

        Notes:
            This replaces the entire state dictionary (no JSONB merge).
            Uses CURRENT_TIMESTAMP for update_time.
            Updating a non-existent session affects no rows and does not raise.
        """
        sql = f"""
        UPDATE {self._session_table}
        SET state = $1, update_time = CURRENT_TIMESTAMP
        WHERE id = $2
        """

        async with self.config.provide_connection() as conn:
            await conn.execute(sql, state, session_id)

    async def delete_session(self, session_id: str) -> None:
        """Delete session and all associated events (cascade).

        Args:
            session_id: Session identifier.

        Notes:
            The events foreign key is declared ON DELETE CASCADE, so the
            session's events are removed by the database.
        """
        sql = f"DELETE FROM {self._session_table} WHERE id = $1"

        async with self.config.provide_connection() as conn:
            await conn.execute(sql, session_id)

    async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]":
        """List sessions for an app, optionally filtered by user.

        Args:
            app_name: Application name.
            user_id: User identifier. If None, lists all sessions for the app.

        Returns:
            List of session records ordered by update_time DESC, or [] if the
            sessions table does not exist.

        Notes:
            The user_id annotation is quoted so it is not evaluated at
            definition time. Without the quotes, ``str | None`` raises
            TypeError on Python 3.9.
        """
        params: "list[Any]" = [app_name]
        where_clause = "app_name = $1"
        if user_id is not None:
            where_clause += " AND user_id = $2"
            params.append(user_id)

        sql = f"""
        SELECT id, app_name, user_id, state, create_time, update_time
        FROM {self._session_table}
        WHERE {where_clause}
        ORDER BY update_time DESC
        """

        try:
            async with self.config.provide_connection() as conn:
                rows = await conn.fetch(sql, *params)
        except asyncpg.exceptions.UndefinedTableError:
            return []

        return [self._session_from_row(row) for row in rows]

    async def append_event(self, event_record: EventRecord) -> None:
        """Append an event to a session.

        Args:
            event_record: Event record to store. "id", "session_id", "app_name",
                "user_id" and "timestamp" must be present; other fields default to NULL.

        Notes:
            The timestamp is always inserted explicitly, so the column's
            CURRENT_TIMESTAMP default does not apply. A None timestamp
            violates the NOT NULL constraint.
            JSONB fields are passed as Python objects and encoded by the
            connection's JSON codecs.
        """
        sql = f"""
        INSERT INTO {self._events_table} (
            id, session_id, app_name, user_id, invocation_id, author, actions,
            long_running_tool_ids_json, branch, timestamp, content,
            grounding_metadata, custom_metadata, partial, turn_complete,
            interrupted, error_code, error_message
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
        )
        """

        async with self.config.provide_connection() as conn:
            await conn.execute(
                sql,
                event_record["id"],
                event_record["session_id"],
                event_record["app_name"],
                event_record["user_id"],
                event_record.get("invocation_id"),
                event_record.get("author"),
                event_record.get("actions"),
                event_record.get("long_running_tool_ids_json"),
                event_record.get("branch"),
                event_record["timestamp"],
                event_record.get("content"),
                event_record.get("grounding_metadata"),
                event_record.get("custom_metadata"),
                event_record.get("partial"),
                event_record.get("turn_complete"),
                event_record.get("interrupted"),
                event_record.get("error_code"),
                event_record.get("error_message"),
            )

    async def get_events(
        self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
    ) -> "list[EventRecord]":
        """Get events for a session.

        Args:
            session_id: Session identifier.
            after_timestamp: Only return events strictly after this time.
            limit: Maximum number of events to return. Any falsy value
                (None or 0) means no limit.

        Returns:
            List of event records ordered by timestamp ASC (oldest first), or
            [] if the events table does not exist.

        Notes:
            Served by the index on (session_id, timestamp ASC).
            A NULL/empty ``actions`` value is returned as b"".
        """
        where_clauses = ["session_id = $1"]
        params: "list[Any]" = [session_id]

        if after_timestamp is not None:
            where_clauses.append(f"timestamp > ${len(params) + 1}")
            params.append(after_timestamp)

        where_clause = " AND ".join(where_clauses)
        limit_clause = ""
        if limit:
            limit_clause = f" LIMIT ${len(params) + 1}"
            params.append(limit)

        sql = f"""
        SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
               long_running_tool_ids_json, branch, timestamp, content,
               grounding_metadata, custom_metadata, partial, turn_complete,
               interrupted, error_code, error_message
        FROM {self._events_table}
        WHERE {where_clause}
        ORDER BY timestamp ASC{limit_clause}
        """

        try:
            async with self.config.provide_connection() as conn:
                rows = await conn.fetch(sql, *params)
        except asyncpg.exceptions.UndefinedTableError:
            return []

        return [
            EventRecord(
                id=row["id"],
                session_id=row["session_id"],
                app_name=row["app_name"],
                user_id=row["user_id"],
                invocation_id=row["invocation_id"],
                author=row["author"],
                actions=bytes(row["actions"]) if row["actions"] else b"",
                long_running_tool_ids_json=row["long_running_tool_ids_json"],
                branch=row["branch"],
                timestamp=row["timestamp"],
                content=row["content"],
                grounding_metadata=row["grounding_metadata"],
                custom_metadata=row["custom_metadata"],
                partial=row["partial"],
                turn_complete=row["turn_complete"],
                interrupted=row["interrupted"],
                error_code=row["error_code"],
                error_message=row["error_message"],
            )
            for row in rows
        ]
@@ -3,7 +3,7 @@
3
3
  import logging
4
4
  from collections.abc import Callable
5
5
  from contextlib import asynccontextmanager
6
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union
6
+ from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
7
7
 
8
8
  from asyncpg import Connection, Record
9
9
  from asyncpg import create_pool as asyncpg_create_pool
@@ -14,6 +14,7 @@ from typing_extensions import NotRequired
14
14
  from sqlspec.adapters.asyncpg._types import AsyncpgConnection
15
15
  from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, asyncpg_statement_config
16
16
  from sqlspec.config import AsyncDatabaseConfig
17
+ from sqlspec.typing import PGVECTOR_INSTALLED
17
18
  from sqlspec.utils.serializers import from_json, to_json
18
19
 
19
20
  if TYPE_CHECKING:
@@ -28,7 +29,7 @@ __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures",
28
29
  logger = logging.getLogger("sqlspec")
29
30
 
30
31
 
31
- class AsyncpgConnectionConfig(TypedDict, total=False):
32
+ class AsyncpgConnectionConfig(TypedDict):
32
33
  """TypedDict for AsyncPG connection parameters."""
33
34
 
34
35
  dsn: NotRequired[str]
@@ -48,7 +49,7 @@ class AsyncpgConnectionConfig(TypedDict, total=False):
48
49
  server_settings: NotRequired[dict[str, str]]
49
50
 
50
51
 
51
- class AsyncpgPoolConfig(AsyncpgConnectionConfig, total=False):
52
+ class AsyncpgPoolConfig(AsyncpgConnectionConfig):
52
53
  """TypedDict for AsyncPG pool parameters, inheriting connection parameters."""
53
54
 
54
55
  min_size: NotRequired[int]
@@ -63,11 +64,31 @@ class AsyncpgPoolConfig(AsyncpgConnectionConfig, total=False):
63
64
  extra: NotRequired[dict[str, Any]]
64
65
 
65
66
 
66
- class AsyncpgDriverFeatures(TypedDict, total=False):
67
- """TypedDict for AsyncPG driver features configuration."""
67
+ class AsyncpgDriverFeatures(TypedDict):
68
+ """AsyncPG driver feature flags.
69
+
70
+ json_serializer: Custom JSON serializer function for PostgreSQL JSON/JSONB types.
71
+ Defaults to sqlspec.utils.serializers.to_json.
72
+ Use for performance optimization (e.g., orjson) or custom encoding behavior.
73
+ Applied when enable_json_codecs is True.
74
+ json_deserializer: Custom JSON deserializer function for PostgreSQL JSON/JSONB types.
75
+ Defaults to sqlspec.utils.serializers.from_json.
76
+ Use for performance optimization (e.g., orjson) or custom decoding behavior.
77
+ Applied when enable_json_codecs is True.
78
+ enable_json_codecs: Enable automatic JSON/JSONB codec registration on connections.
79
+ Defaults to True for seamless Python dict/list to PostgreSQL JSON/JSONB conversion.
80
+ Set to False to disable automatic codec registration (manual handling required).
81
+ enable_pgvector: Enable pgvector extension support for vector similarity search.
82
+ Requires pgvector-python package (pip install pgvector) and PostgreSQL with pgvector extension.
83
+ Defaults to True when pgvector-python is installed.
84
+ Provides automatic conversion between Python objects and PostgreSQL vector types.
85
+ Enables vector similarity operations and index support.
86
+ """
68
87
 
69
88
  json_serializer: NotRequired[Callable[[Any], str]]
70
89
  json_deserializer: NotRequired[Callable[[str], Any]]
90
+ enable_json_codecs: NotRequired[bool]
91
+ enable_pgvector: NotRequired[bool]
71
92
 
72
93
 
73
94
  class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
@@ -75,16 +96,18 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
75
96
 
76
97
  driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
77
98
  connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
99
+ supports_transactional_ddl: "ClassVar[bool]" = True
78
100
 
79
101
  def __init__(
80
102
  self,
81
103
  *,
82
- pool_config: "Optional[Union[AsyncpgPoolConfig, dict[str, Any]]]" = None,
83
- pool_instance: "Optional[Pool[Record]]" = None,
84
- migration_config: "Optional[dict[str, Any]]" = None,
85
- statement_config: "Optional[StatementConfig]" = None,
86
- driver_features: "Optional[Union[AsyncpgDriverFeatures, dict[str, Any]]]" = None,
87
- bind_key: "Optional[str]" = None,
104
+ pool_config: "AsyncpgPoolConfig | dict[str, Any] | None" = None,
105
+ pool_instance: "Pool[Record] | None" = None,
106
+ migration_config: "dict[str, Any] | None" = None,
107
+ statement_config: "StatementConfig | None" = None,
108
+ driver_features: "AsyncpgDriverFeatures | dict[str, Any] | None" = None,
109
+ bind_key: "str | None" = None,
110
+ extension_config: "dict[str, dict[str, Any]] | None" = None,
88
111
  ) -> None:
89
112
  """Initialize AsyncPG configuration.
90
113
 
@@ -95,6 +118,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
95
118
  statement_config: Statement configuration override
96
119
  driver_features: Driver features configuration (TypedDict or dict)
97
120
  bind_key: Optional unique identifier for this configuration
121
+ extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
98
122
  """
99
123
  features_dict: dict[str, Any] = dict(driver_features) if driver_features else {}
100
124
 
@@ -102,6 +126,11 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
102
126
  features_dict["json_serializer"] = to_json
103
127
  if "json_deserializer" not in features_dict:
104
128
  features_dict["json_deserializer"] = from_json
129
+ if "enable_json_codecs" not in features_dict:
130
+ features_dict["enable_json_codecs"] = True
131
+ if "enable_pgvector" not in features_dict:
132
+ features_dict["enable_pgvector"] = PGVECTOR_INSTALLED
133
+
105
134
  super().__init__(
106
135
  pool_config=dict(pool_config) if pool_config else {},
107
136
  pool_instance=pool_instance,
@@ -109,6 +138,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
109
138
  statement_config=statement_config or asyncpg_statement_config,
110
139
  driver_features=features_dict,
111
140
  bind_key=bind_key,
141
+ extension_config=extension_config,
112
142
  )
113
143
 
114
144
  def _get_pool_config_dict(self) -> "dict[str, Any]":
@@ -132,35 +162,24 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
132
162
  return await asyncpg_create_pool(**config)
133
163
 
134
164
  async def _init_connection(self, connection: "AsyncpgConnection") -> None:
135
- """Initialize connection with JSON codecs and pgvector support."""
165
+ """Initialize connection with JSON codecs and pgvector support.
136
166
 
137
- try:
138
- # Set up JSON type codec
139
- await connection.set_type_codec(
140
- "json",
141
- encoder=self.driver_features.get("json_serializer", to_json),
142
- decoder=self.driver_features.get("json_deserializer", from_json),
143
- schema="pg_catalog",
144
- )
145
- # Set up JSONB type codec
146
- await connection.set_type_codec(
147
- "jsonb",
167
+ Args:
168
+ connection: AsyncPG connection to initialize.
169
+ """
170
+ if self.driver_features.get("enable_json_codecs", True):
171
+ from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs
172
+
173
+ await register_json_codecs(
174
+ connection,
148
175
  encoder=self.driver_features.get("json_serializer", to_json),
149
176
  decoder=self.driver_features.get("json_deserializer", from_json),
150
- schema="pg_catalog",
151
177
  )
152
- except Exception as e:
153
- logger.debug("Failed to configure JSON type codecs for asyncpg: %s", e)
154
178
 
155
- # Initialize pgvector support
156
- try:
157
- import pgvector.asyncpg
179
+ if self.driver_features.get("enable_pgvector", False):
180
+ from sqlspec.adapters.asyncpg._type_handlers import register_pgvector_support
158
181
 
159
- await pgvector.asyncpg.register_vector(connection)
160
- except ImportError:
161
- pass
162
- except Exception as e:
163
- logger.debug("Failed to register pgvector for asyncpg: %s", e)
182
+ await register_pgvector_support(connection)
164
183
 
165
184
  async def _close_pool(self) -> None:
166
185
  """Close the actual async connection pool."""
@@ -204,7 +223,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
204
223
 
205
224
  @asynccontextmanager
206
225
  async def provide_session(
207
- self, *args: Any, statement_config: "Optional[StatementConfig]" = None, **kwargs: Any
226
+ self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
208
227
  ) -> "AsyncGenerator[AsyncpgDriver, None]":
209
228
  """Provide an async driver session context manager.
210
229
 
@@ -218,7 +237,9 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async
218
237
  """
219
238
  async with self.provide_connection(*args, **kwargs) as connection:
220
239
  final_statement_config = statement_config or self.statement_config or asyncpg_statement_config
221
- yield self.driver_type(connection=connection, statement_config=final_statement_config)
240
+ yield self.driver_type(
241
+ connection=connection, statement_config=final_statement_config, driver_features=self.driver_features
242
+ )
222
243
 
223
244
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
224
245
  """Provide async pool instance.