sqlspec 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (212)
  1. sqlspec/__init__.py +7 -15
  2. sqlspec/_serialization.py +55 -25
  3. sqlspec/_typing.py +155 -52
  4. sqlspec/adapters/adbc/_types.py +1 -1
  5. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  6. sqlspec/adapters/adbc/adk/store.py +880 -0
  7. sqlspec/adapters/adbc/config.py +62 -12
  8. sqlspec/adapters/adbc/data_dictionary.py +74 -2
  9. sqlspec/adapters/adbc/driver.py +226 -58
  10. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  11. sqlspec/adapters/adbc/litestar/store.py +504 -0
  12. sqlspec/adapters/adbc/type_converter.py +44 -50
  13. sqlspec/adapters/aiosqlite/_types.py +1 -1
  14. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  15. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  16. sqlspec/adapters/aiosqlite/config.py +86 -16
  17. sqlspec/adapters/aiosqlite/data_dictionary.py +34 -2
  18. sqlspec/adapters/aiosqlite/driver.py +127 -38
  19. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  21. sqlspec/adapters/aiosqlite/pool.py +7 -7
  22. sqlspec/adapters/asyncmy/__init__.py +7 -1
  23. sqlspec/adapters/asyncmy/_types.py +1 -1
  24. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  25. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  26. sqlspec/adapters/asyncmy/config.py +59 -17
  27. sqlspec/adapters/asyncmy/data_dictionary.py +41 -2
  28. sqlspec/adapters/asyncmy/driver.py +293 -62
  29. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  31. sqlspec/adapters/asyncpg/__init__.py +2 -1
  32. sqlspec/adapters/asyncpg/_type_handlers.py +71 -0
  33. sqlspec/adapters/asyncpg/_types.py +11 -7
  34. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  35. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  36. sqlspec/adapters/asyncpg/config.py +57 -36
  37. sqlspec/adapters/asyncpg/data_dictionary.py +48 -2
  38. sqlspec/adapters/asyncpg/driver.py +153 -23
  39. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  41. sqlspec/adapters/bigquery/_types.py +1 -1
  42. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  43. sqlspec/adapters/bigquery/adk/store.py +585 -0
  44. sqlspec/adapters/bigquery/config.py +36 -11
  45. sqlspec/adapters/bigquery/data_dictionary.py +42 -2
  46. sqlspec/adapters/bigquery/driver.py +489 -144
  47. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  48. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  49. sqlspec/adapters/bigquery/type_converter.py +55 -23
  50. sqlspec/adapters/duckdb/_types.py +2 -2
  51. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  52. sqlspec/adapters/duckdb/adk/store.py +563 -0
  53. sqlspec/adapters/duckdb/config.py +79 -21
  54. sqlspec/adapters/duckdb/data_dictionary.py +41 -2
  55. sqlspec/adapters/duckdb/driver.py +225 -44
  56. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  57. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  58. sqlspec/adapters/duckdb/pool.py +5 -5
  59. sqlspec/adapters/duckdb/type_converter.py +51 -21
  60. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  61. sqlspec/adapters/oracledb/_types.py +20 -2
  62. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  63. sqlspec/adapters/oracledb/adk/store.py +1628 -0
  64. sqlspec/adapters/oracledb/config.py +120 -36
  65. sqlspec/adapters/oracledb/data_dictionary.py +87 -20
  66. sqlspec/adapters/oracledb/driver.py +475 -86
  67. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  68. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  69. sqlspec/adapters/oracledb/migrations.py +316 -25
  70. sqlspec/adapters/oracledb/type_converter.py +91 -16
  71. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  72. sqlspec/adapters/psqlpy/_types.py +2 -1
  73. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  74. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  75. sqlspec/adapters/psqlpy/config.py +45 -19
  76. sqlspec/adapters/psqlpy/data_dictionary.py +48 -2
  77. sqlspec/adapters/psqlpy/driver.py +108 -41
  78. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  79. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  80. sqlspec/adapters/psqlpy/type_converter.py +40 -11
  81. sqlspec/adapters/psycopg/_type_handlers.py +80 -0
  82. sqlspec/adapters/psycopg/_types.py +2 -1
  83. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  84. sqlspec/adapters/psycopg/adk/store.py +962 -0
  85. sqlspec/adapters/psycopg/config.py +65 -37
  86. sqlspec/adapters/psycopg/data_dictionary.py +91 -3
  87. sqlspec/adapters/psycopg/driver.py +200 -78
  88. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  90. sqlspec/adapters/sqlite/__init__.py +2 -1
  91. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  92. sqlspec/adapters/sqlite/_types.py +1 -1
  93. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  94. sqlspec/adapters/sqlite/adk/store.py +582 -0
  95. sqlspec/adapters/sqlite/config.py +85 -16
  96. sqlspec/adapters/sqlite/data_dictionary.py +34 -2
  97. sqlspec/adapters/sqlite/driver.py +120 -52
  98. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  99. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  100. sqlspec/adapters/sqlite/pool.py +5 -5
  101. sqlspec/base.py +45 -26
  102. sqlspec/builder/__init__.py +73 -4
  103. sqlspec/builder/_base.py +91 -58
  104. sqlspec/builder/_column.py +5 -5
  105. sqlspec/builder/_ddl.py +98 -89
  106. sqlspec/builder/_delete.py +5 -4
  107. sqlspec/builder/_dml.py +388 -0
  108. sqlspec/{_sql.py → builder/_factory.py} +41 -44
  109. sqlspec/builder/_insert.py +5 -82
  110. sqlspec/builder/{mixins/_join_operations.py → _join.py} +145 -143
  111. sqlspec/builder/_merge.py +446 -11
  112. sqlspec/builder/_parsing_utils.py +9 -11
  113. sqlspec/builder/_select.py +1313 -25
  114. sqlspec/builder/_update.py +11 -42
  115. sqlspec/cli.py +76 -69
  116. sqlspec/config.py +331 -62
  117. sqlspec/core/__init__.py +5 -4
  118. sqlspec/core/cache.py +18 -18
  119. sqlspec/core/compiler.py +6 -8
  120. sqlspec/core/filters.py +55 -47
  121. sqlspec/core/hashing.py +9 -9
  122. sqlspec/core/parameters.py +76 -45
  123. sqlspec/core/result.py +234 -47
  124. sqlspec/core/splitter.py +16 -17
  125. sqlspec/core/statement.py +32 -31
  126. sqlspec/core/type_conversion.py +3 -2
  127. sqlspec/driver/__init__.py +1 -3
  128. sqlspec/driver/_async.py +183 -160
  129. sqlspec/driver/_common.py +197 -109
  130. sqlspec/driver/_sync.py +189 -161
  131. sqlspec/driver/mixins/_result_tools.py +20 -236
  132. sqlspec/driver/mixins/_sql_translator.py +4 -4
  133. sqlspec/exceptions.py +70 -7
  134. sqlspec/extensions/adk/__init__.py +53 -0
  135. sqlspec/extensions/adk/_types.py +51 -0
  136. sqlspec/extensions/adk/converters.py +172 -0
  137. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  138. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  139. sqlspec/extensions/adk/service.py +181 -0
  140. sqlspec/extensions/adk/store.py +536 -0
  141. sqlspec/extensions/aiosql/adapter.py +69 -61
  142. sqlspec/extensions/fastapi/__init__.py +21 -0
  143. sqlspec/extensions/fastapi/extension.py +331 -0
  144. sqlspec/extensions/fastapi/providers.py +543 -0
  145. sqlspec/extensions/flask/__init__.py +36 -0
  146. sqlspec/extensions/flask/_state.py +71 -0
  147. sqlspec/extensions/flask/_utils.py +40 -0
  148. sqlspec/extensions/flask/extension.py +389 -0
  149. sqlspec/extensions/litestar/__init__.py +21 -4
  150. sqlspec/extensions/litestar/cli.py +54 -10
  151. sqlspec/extensions/litestar/config.py +56 -266
  152. sqlspec/extensions/litestar/handlers.py +46 -17
  153. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  154. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  155. sqlspec/extensions/litestar/plugin.py +349 -224
  156. sqlspec/extensions/litestar/providers.py +25 -25
  157. sqlspec/extensions/litestar/store.py +265 -0
  158. sqlspec/extensions/starlette/__init__.py +10 -0
  159. sqlspec/extensions/starlette/_state.py +25 -0
  160. sqlspec/extensions/starlette/_utils.py +52 -0
  161. sqlspec/extensions/starlette/extension.py +254 -0
  162. sqlspec/extensions/starlette/middleware.py +154 -0
  163. sqlspec/loader.py +30 -49
  164. sqlspec/migrations/base.py +200 -76
  165. sqlspec/migrations/commands.py +591 -62
  166. sqlspec/migrations/context.py +6 -9
  167. sqlspec/migrations/fix.py +199 -0
  168. sqlspec/migrations/loaders.py +47 -19
  169. sqlspec/migrations/runner.py +241 -75
  170. sqlspec/migrations/tracker.py +237 -21
  171. sqlspec/migrations/utils.py +51 -3
  172. sqlspec/migrations/validation.py +177 -0
  173. sqlspec/protocols.py +106 -36
  174. sqlspec/storage/_utils.py +85 -0
  175. sqlspec/storage/backends/fsspec.py +133 -107
  176. sqlspec/storage/backends/local.py +78 -51
  177. sqlspec/storage/backends/obstore.py +276 -168
  178. sqlspec/storage/registry.py +75 -39
  179. sqlspec/typing.py +30 -84
  180. sqlspec/utils/__init__.py +25 -4
  181. sqlspec/utils/arrow_helpers.py +81 -0
  182. sqlspec/utils/config_resolver.py +6 -6
  183. sqlspec/utils/correlation.py +4 -5
  184. sqlspec/utils/data_transformation.py +3 -2
  185. sqlspec/utils/deprecation.py +9 -8
  186. sqlspec/utils/fixtures.py +4 -4
  187. sqlspec/utils/logging.py +46 -6
  188. sqlspec/utils/module_loader.py +205 -5
  189. sqlspec/utils/portal.py +311 -0
  190. sqlspec/utils/schema.py +288 -0
  191. sqlspec/utils/serializers.py +113 -4
  192. sqlspec/utils/sync_tools.py +36 -22
  193. sqlspec/utils/text.py +1 -2
  194. sqlspec/utils/type_guards.py +136 -20
  195. sqlspec/utils/version.py +433 -0
  196. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/METADATA +41 -22
  197. sqlspec-0.28.0.dist-info/RECORD +221 -0
  198. sqlspec/builder/mixins/__init__.py +0 -55
  199. sqlspec/builder/mixins/_cte_and_set_ops.py +0 -253
  200. sqlspec/builder/mixins/_delete_operations.py +0 -50
  201. sqlspec/builder/mixins/_insert_operations.py +0 -282
  202. sqlspec/builder/mixins/_merge_operations.py +0 -698
  203. sqlspec/builder/mixins/_order_limit_operations.py +0 -145
  204. sqlspec/builder/mixins/_pivot_operations.py +0 -157
  205. sqlspec/builder/mixins/_select_operations.py +0 -930
  206. sqlspec/builder/mixins/_update_operations.py +0 -199
  207. sqlspec/builder/mixins/_where_clause.py +0 -1298
  208. sqlspec-0.26.0.dist-info/RECORD +0 -157
  209. sqlspec-0.26.0.dist-info/licenses/NOTICE +0 -29
  210. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/WHEEL +0 -0
  211. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/entry_points.txt +0 -0
  212. {sqlspec-0.26.0.dist-info → sqlspec-0.28.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,503 @@
1
+ """AsyncMy ADK store for Google Agent Development Kit session/event storage."""
2
+
3
+ import json
4
+ from typing import TYPE_CHECKING, Any, Final
5
+
6
+ import asyncmy
7
+
8
+ from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
9
+ from sqlspec.utils.logging import get_logger
10
+
11
+ if TYPE_CHECKING:
12
+ from datetime import datetime
13
+
14
+ from sqlspec.adapters.asyncmy.config import AsyncmyConfig
15
+
16
__all__ = ("AsyncmyADKStore",)

# Module-level logger used for store diagnostics (e.g. table creation).
logger = get_logger("adapters.asyncmy.adk.store")

# MySQL server error code 1146 (ER_NO_SUCH_TABLE): the queried table does not exist.
# Read paths in this module treat it as "no data yet" instead of failing.
MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146
21
+
22
+
23
class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]):
    """MySQL/MariaDB ADK store using AsyncMy driver.

    Implements session and event storage for Google Agent Development Kit
    using MySQL/MariaDB via the AsyncMy driver. Provides:
    - Session state management with JSON storage
    - Event history tracking with BLOB-stored actions
    - Microsecond-precision timestamps, written as UTC via UTC_TIMESTAMP(6)
    - Foreign key from events to sessions with cascade delete
    - Optional owner ID column (with a real FOREIGN KEY) for multi-tenancy

    Args:
        config: AsyncmyConfig with extension_config["adk"] settings.

    Example:
        from sqlspec.adapters.asyncmy import AsyncmyConfig
        from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore

        config = AsyncmyConfig(
            pool_config={"host": "localhost", ...},
            extension_config={
                "adk": {
                    "session_table": "my_sessions",
                    "events_table": "my_events",
                    "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
                }
            }
        )
        store = AsyncmyADKStore(config)
        await store.create_tables()

    Notes:
        - MySQL JSON type used (not JSONB) - requires MySQL 5.7.8+
        - TIMESTAMP(6) provides microsecond precision
        - InnoDB engine required for foreign key support
        - State merging handled at application level
        - Configuration is read from config.extension_config["adk"]
        - Table names are interpolated into the SQL text, so they must come from
          trusted configuration, never from user input.
    """

    __slots__ = ()

    def __init__(self, config: "AsyncmyConfig") -> None:
        """Initialize AsyncMy ADK store.

        Args:
            config: AsyncmyConfig instance.

        Notes:
            Configuration is read from config.extension_config["adk"]:
            - session_table: Sessions table name (default: "adk_sessions")
            - events_table: Events table name (default: "adk_events")
            - owner_id_column: Optional owner FK column DDL (default: None)
        """
        super().__init__(config)

    @staticmethod
    def _is_missing_table_error(error: Exception) -> bool:
        """Return True if a driver error reports that the queried table does not exist.

        The MySQL error code is checked first (when the exception carries args),
        then the server message. ``error.args`` is guarded so an argument-less
        exception does not raise IndexError and mask the original error.
        """
        if error.args and error.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR:
            return True
        return "doesn't exist" in str(error)

    @staticmethod
    def _dumps_optional(value: Any) -> "str | None":
        """JSON-encode ``value``, mapping only ``None`` to SQL NULL.

        Empty dicts/lists are preserved as ``"{}"``/``"[]"`` instead of being
        collapsed to NULL, so they round-trip unchanged through get_events().
        """
        return None if value is None else json.dumps(value)

    @staticmethod
    def _row_to_session_record(row: Any) -> SessionRecord:
        """Build a SessionRecord from an (id, app_name, user_id, state, create_time, update_time) row."""
        session_id, app_name, user_id, state, create_time, update_time = row
        return SessionRecord(
            id=session_id,
            app_name=app_name,
            user_id=user_id,
            # The driver may return JSON columns as text; decode only in that case.
            state=json.loads(state) if isinstance(state, str) else state,
            create_time=create_time,
            update_time=update_time,
        )

    def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]":
        """Parse owner ID column DDL for MySQL FOREIGN KEY syntax.

        MySQL ignores inline REFERENCES syntax in column definitions.
        This method extracts the column definition and creates a separate
        FOREIGN KEY constraint.

        Args:
            column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"

        Returns:
            Tuple of (column_definition, foreign_key_constraint). The constraint is
            an empty string when the DDL has no REFERENCES clause.

        Raises:
            ValueError: If the DDL has a REFERENCES clause but no column name before it.

        Example:
            Input: "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE"
            Output: ("tenant_id BIGINT NOT NULL", "FOREIGN KEY (tenant_id) REFERENCES tenants(id) ON DELETE CASCADE")
        """
        import re

        # DOTALL so a multi-line DDL keeps its trailing "ON DELETE ..." actions in the FK clause.
        references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE | re.DOTALL)

        if not references_match:
            return (column_ddl.strip(), "")

        col_def = column_ddl[: references_match.start()].strip()
        fk_clause = references_match.group(1).strip()
        col_tokens = col_def.split()
        if not col_tokens:
            msg = f"owner_id_column DDL must start with a column name: {column_ddl!r}"
            raise ValueError(msg)
        fk_constraint = f"FOREIGN KEY ({col_tokens[0]}) REFERENCES {fk_clause}"

        return (col_def, fk_constraint)

    async def _get_create_sessions_table_sql(self) -> str:
        """Get MySQL CREATE TABLE SQL for sessions.

        Returns:
            SQL statement to create the sessions table with indexes.

        Notes:
            - VARCHAR(128) for IDs and names (sufficient for UUIDs and app names)
            - JSON type for state storage (MySQL 5.7.8+)
            - TIMESTAMP(6) with microsecond precision
            - ON UPDATE clause on update_time (writes from this store also set it explicitly)
            - Composite index on (app_name, user_id) for listing
            - Index on update_time DESC for recent session queries
            - Optional owner ID column for multi-tenancy
            - MySQL requires explicit FOREIGN KEY syntax (inline REFERENCES is ignored)
        """
        owner_id_col = ""
        fk_constraint = ""

        if self._owner_id_column_ddl:
            col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl)
            owner_id_col = f"{col_def},"
            if fk_def:
                fk_constraint = f",\n {fk_def}"

        return f"""
        CREATE TABLE IF NOT EXISTS {self._session_table} (
            id VARCHAR(128) PRIMARY KEY,
            app_name VARCHAR(128) NOT NULL,
            user_id VARCHAR(128) NOT NULL,
            {owner_id_col}
            state JSON NOT NULL,
            create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
            update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
            INDEX idx_{self._session_table}_app_user (app_name, user_id),
            INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint}
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """

    async def _get_create_events_table_sql(self) -> str:
        """Get MySQL CREATE TABLE SQL for events.

        Returns:
            SQL statement to create the events table with indexes.

        Notes:
            - VARCHAR sizes: id(128), session_id(128), invocation_id(256), author(256),
              branch(256), error_code(256), error_message(1024)
            - BLOB for serialized actions (BLOB holds at most 65,535 bytes;
              NOTE(review): larger action payloads would need MEDIUMBLOB)
            - JSON for content, grounding_metadata, custom_metadata, long_running_tool_ids_json
            - BOOLEAN for partial, turn_complete, interrupted
            - Foreign key to sessions with CASCADE delete
            - Index on (session_id, timestamp ASC) for ordered event retrieval
        """
        return f"""
        CREATE TABLE IF NOT EXISTS {self._events_table} (
            id VARCHAR(128) PRIMARY KEY,
            session_id VARCHAR(128) NOT NULL,
            app_name VARCHAR(128) NOT NULL,
            user_id VARCHAR(128) NOT NULL,
            invocation_id VARCHAR(256) NOT NULL,
            author VARCHAR(256) NOT NULL,
            actions BLOB NOT NULL,
            long_running_tool_ids_json JSON,
            branch VARCHAR(256),
            timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
            content JSON,
            grounding_metadata JSON,
            custom_metadata JSON,
            partial BOOLEAN,
            turn_complete BOOLEAN,
            interrupted BOOLEAN,
            error_code VARCHAR(256),
            error_message VARCHAR(1024),
            FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE,
            INDEX idx_{self._events_table}_session (session_id, timestamp ASC)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """

    def _get_drop_tables_sql(self) -> "list[str]":
        """Get MySQL DROP TABLE SQL statements.

        Returns:
            List of SQL statements to drop the events and sessions tables.

        Notes:
            Order matters: drop events table (child) before sessions (parent).
            MySQL automatically drops indexes when dropping tables.
        """
        return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"]

    async def create_tables(self) -> None:
        """Create both sessions and events tables if they don't exist.

        Sessions are created first because the events table references it.
        """
        async with self._config.provide_session() as driver:
            await driver.execute_script(await self._get_create_sessions_table_sql())
            await driver.execute_script(await self._get_create_events_table_sql())
        logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)

    async def create_session(
        self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
    ) -> SessionRecord:
        """Create a new session.

        Args:
            session_id: Unique session identifier.
            app_name: Application name.
            user_id: User identifier.
            state: Initial session state.
            owner_id: Optional owner ID value for owner_id_column (if configured).

        Returns:
            Created session record, re-read from the database.

        Notes:
            Uses INSERT with UTC_TIMESTAMP(6) for create_time and update_time.
            State is JSON-serialized before insertion.
            If owner_id_column is configured, owner_id should be provided; it is
            inserted as given (None becomes NULL).
        """
        state_json = json.dumps(state)

        params: tuple[Any, ...]
        if self._owner_id_column_name:
            sql = f"""
            INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time)
            VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6))
            """
            params = (session_id, app_name, user_id, owner_id, state_json)
        else:
            sql = f"""
            INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time)
            VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6))
            """
            params = (session_id, app_name, user_id, state_json)

        async with self._config.provide_connection() as conn, conn.cursor() as cursor:
            await cursor.execute(sql, params)
            await conn.commit()

        return await self.get_session(session_id)  # type: ignore[return-value]

    async def get_session(self, session_id: str) -> "SessionRecord | None":
        """Get session by ID.

        Args:
            session_id: Session identifier.

        Returns:
            Session record, or None if not found or the sessions table does not exist.

        Notes:
            MySQL returns datetime objects for TIMESTAMP columns.
            JSON text is parsed into Python objects.
        """
        sql = f"""
        SELECT id, app_name, user_id, state, create_time, update_time
        FROM {self._session_table}
        WHERE id = %s
        """

        try:
            async with self._config.provide_connection() as conn, conn.cursor() as cursor:
                await cursor.execute(sql, (session_id,))
                row = await cursor.fetchone()
        except asyncmy.errors.ProgrammingError as e:  # pyright: ignore[reportAttributeAccessIssue]
            if self._is_missing_table_error(e):
                return None
            raise

        if row is None:
            return None
        return self._row_to_session_record(row)

    async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None:
        """Replace a session's state and bump its update_time.

        Args:
            session_id: Session identifier.
            state: New state dictionary (replaces existing state).

        Notes:
            This replaces the entire state dictionary.
            update_time is set explicitly to UTC_TIMESTAMP(6) rather than relying on the
            column's ON UPDATE CURRENT_TIMESTAMP(6) clause. That clause writes the
            connection's session-time-zone time (create_session writes UTC), and MySQL
            skips it when the new state equals the stored one.
        """
        state_json = json.dumps(state)

        sql = f"""
        UPDATE {self._session_table}
        SET state = %s, update_time = UTC_TIMESTAMP(6)
        WHERE id = %s
        """

        async with self._config.provide_connection() as conn, conn.cursor() as cursor:
            await cursor.execute(sql, (state_json, session_id))
            await conn.commit()

    async def delete_session(self, session_id: str) -> None:
        """Delete session and all associated events (cascade).

        Args:
            session_id: Session identifier.

        Notes:
            The events table's FOREIGN KEY ... ON DELETE CASCADE removes the events.
        """
        sql = f"DELETE FROM {self._session_table} WHERE id = %s"

        async with self._config.provide_connection() as conn, conn.cursor() as cursor:
            await cursor.execute(sql, (session_id,))
            await conn.commit()

    async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
        """List sessions for an app, optionally filtered by user.

        Args:
            app_name: Application name.
            user_id: User identifier. If None, lists all sessions for the app.

        Returns:
            List of session records ordered by update_time DESC (empty if the
            sessions table does not exist).

        Notes:
            Uses composite index on (app_name, user_id) when user_id is provided.
        """
        if user_id is None:
            sql = f"""
            SELECT id, app_name, user_id, state, create_time, update_time
            FROM {self._session_table}
            WHERE app_name = %s
            ORDER BY update_time DESC
            """
            params: tuple[str, ...] = (app_name,)
        else:
            sql = f"""
            SELECT id, app_name, user_id, state, create_time, update_time
            FROM {self._session_table}
            WHERE app_name = %s AND user_id = %s
            ORDER BY update_time DESC
            """
            params = (app_name, user_id)

        try:
            async with self._config.provide_connection() as conn, conn.cursor() as cursor:
                await cursor.execute(sql, params)
                rows = await cursor.fetchall()
        except asyncmy.errors.ProgrammingError as e:  # pyright: ignore[reportAttributeAccessIssue]
            if self._is_missing_table_error(e):
                return []
            raise

        return [self._row_to_session_record(row) for row in rows]

    async def append_event(self, event_record: EventRecord) -> None:
        """Append an event to a session.

        Args:
            event_record: Event record to store.

        Notes:
            The event's own "timestamp" value is inserted as given.
            content, grounding_metadata and custom_metadata are JSON-serialized;
            only None is stored as NULL, so empty dicts/lists round-trip intact.
        """
        content_json = self._dumps_optional(event_record.get("content"))
        grounding_metadata_json = self._dumps_optional(event_record.get("grounding_metadata"))
        custom_metadata_json = self._dumps_optional(event_record.get("custom_metadata"))

        sql = f"""
        INSERT INTO {self._events_table} (
            id, session_id, app_name, user_id, invocation_id, author, actions,
            long_running_tool_ids_json, branch, timestamp, content,
            grounding_metadata, custom_metadata, partial, turn_complete,
            interrupted, error_code, error_message
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
        """

        async with self._config.provide_connection() as conn, conn.cursor() as cursor:
            await cursor.execute(
                sql,
                (
                    event_record["id"],
                    event_record["session_id"],
                    event_record["app_name"],
                    event_record["user_id"],
                    event_record["invocation_id"],
                    event_record["author"],
                    event_record["actions"],
                    event_record.get("long_running_tool_ids_json"),
                    event_record.get("branch"),
                    event_record["timestamp"],
                    content_json,
                    grounding_metadata_json,
                    custom_metadata_json,
                    event_record.get("partial"),
                    event_record.get("turn_complete"),
                    event_record.get("interrupted"),
                    event_record.get("error_code"),
                    event_record.get("error_message"),
                ),
            )
            await conn.commit()

    async def get_events(
        self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None
    ) -> "list[EventRecord]":
        """Get events for a session.

        Args:
            session_id: Session identifier.
            after_timestamp: Only return events strictly after this time.
            limit: Maximum number of events to return. None or 0 means no limit.

        Returns:
            List of event records ordered by timestamp ASC (empty if the events
            table does not exist).

        Notes:
            Uses index on (session_id, timestamp ASC).
            Parses JSON fields and converts BLOB actions to bytes.
        """
        where_clauses = ["session_id = %s"]
        params: list[Any] = [session_id]

        if after_timestamp is not None:
            where_clauses.append("timestamp > %s")
            params.append(after_timestamp)

        # Bind LIMIT as a parameter (coerced to int) instead of interpolating it
        # into the SQL text. A falsy limit keeps the "no limit" behaviour.
        limit_clause = ""
        if limit:
            limit_clause = " LIMIT %s"
            params.append(int(limit))

        where_clause = " AND ".join(where_clauses)

        sql = f"""
        SELECT id, session_id, app_name, user_id, invocation_id, author, actions,
               long_running_tool_ids_json, branch, timestamp, content,
               grounding_metadata, custom_metadata, partial, turn_complete,
               interrupted, error_code, error_message
        FROM {self._events_table}
        WHERE {where_clause}
        ORDER BY timestamp ASC{limit_clause}
        """

        try:
            async with self._config.provide_connection() as conn, conn.cursor() as cursor:
                await cursor.execute(sql, params)
                rows = await cursor.fetchall()
        except asyncmy.errors.ProgrammingError as e:  # pyright: ignore[reportAttributeAccessIssue]
            if self._is_missing_table_error(e):
                return []
            raise

        return [
            EventRecord(
                id=row[0],
                session_id=row[1],
                app_name=row[2],
                user_id=row[3],
                invocation_id=row[4],
                author=row[5],
                actions=bytes(row[6]),
                long_running_tool_ids_json=row[7],
                branch=row[8],
                timestamp=row[9],
                content=json.loads(row[10]) if row[10] and isinstance(row[10], str) else row[10],
                grounding_metadata=json.loads(row[11]) if row[11] and isinstance(row[11], str) else row[11],
                custom_metadata=json.loads(row[12]) if row[12] and isinstance(row[12], str) else row[12],
                partial=row[13],
                turn_complete=row[14],
                interrupted=row[15],
                error_code=row[16],
                error_message=row[17],
            )
            for row in rows
        ]
@@ -3,7 +3,7 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union
6
+ from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
7
7
 
8
8
  import asyncmy
9
9
  from asyncmy.cursors import Cursor, DictCursor # pyright: ignore
@@ -13,20 +13,23 @@ from typing_extensions import NotRequired
13
13
  from sqlspec.adapters.asyncmy._types import AsyncmyConnection
14
14
  from sqlspec.adapters.asyncmy.driver import AsyncmyCursor, AsyncmyDriver, asyncmy_statement_config
15
15
  from sqlspec.config import AsyncDatabaseConfig
16
+ from sqlspec.utils.serializers import from_json, to_json
16
17
 
17
18
  if TYPE_CHECKING:
19
+ from collections.abc import Callable
20
+
18
21
  from asyncmy.cursors import Cursor, DictCursor # pyright: ignore
19
22
  from asyncmy.pool import Pool # pyright: ignore
20
23
 
21
24
  from sqlspec.core.statement import StatementConfig
22
25
 
23
26
 
24
- __all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyPoolParams")
27
+ __all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams")
25
28
 
26
29
  logger = logging.getLogger(__name__)
27
30
 
28
31
 
29
- class AsyncmyConnectionParams(TypedDict, total=False):
32
+ class AsyncmyConnectionParams(TypedDict):
30
33
  """Asyncmy connection parameters."""
31
34
 
32
35
  host: NotRequired[str]
@@ -44,11 +47,11 @@ class AsyncmyConnectionParams(TypedDict, total=False):
44
47
  ssl: NotRequired[Any]
45
48
  sql_mode: NotRequired[str]
46
49
  init_command: NotRequired[str]
47
- cursor_class: NotRequired[Union[type["Cursor"], type["DictCursor"]]]
50
+ cursor_class: NotRequired[type["Cursor"] | type["DictCursor"]]
48
51
  extra: NotRequired[dict[str, Any]]
49
52
 
50
53
 
51
- class AsyncmyPoolParams(AsyncmyConnectionParams, total=False):
54
+ class AsyncmyPoolParams(AsyncmyConnectionParams):
52
55
  """Asyncmy pool parameters."""
53
56
 
54
57
  minsize: NotRequired[int]
@@ -57,21 +60,41 @@ class AsyncmyPoolParams(AsyncmyConnectionParams, total=False):
57
60
  pool_recycle: NotRequired[int]
58
61
 
59
62
 
63
+ class AsyncmyDriverFeatures(TypedDict):
64
+ """Asyncmy driver feature flags.
65
+
66
+ MySQL/MariaDB handle JSON natively, but custom serializers can be provided
67
+ for specialized use cases (e.g., orjson for performance, msgspec for type safety).
68
+
69
+ json_serializer: Custom JSON serializer function.
70
+ Defaults to sqlspec.utils.serializers.to_json.
71
+ Use for performance (orjson) or custom encoding.
72
+ json_deserializer: Custom JSON deserializer function.
73
+ Defaults to sqlspec.utils.serializers.from_json.
74
+ Use for performance (orjson) or custom decoding.
75
+ """
76
+
77
+ json_serializer: NotRequired["Callable[[Any], str]"]
78
+ json_deserializer: NotRequired["Callable[[str], Any]"]
79
+
80
+
60
81
  class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", AsyncmyDriver]): # pyright: ignore
61
82
  """Configuration for Asyncmy database connections."""
62
83
 
63
84
  driver_type: ClassVar[type[AsyncmyDriver]] = AsyncmyDriver
64
85
  connection_type: "ClassVar[type[AsyncmyConnection]]" = AsyncmyConnection # pyright: ignore
86
+ supports_transactional_ddl: ClassVar[bool] = False
65
87
 
66
88
  def __init__(
67
89
  self,
68
90
  *,
69
- pool_config: "Optional[Union[AsyncmyPoolParams, dict[str, Any]]]" = None,
70
- pool_instance: "Optional[AsyncmyPool]" = None,
71
- migration_config: Optional[dict[str, Any]] = None,
72
- statement_config: "Optional[StatementConfig]" = None,
73
- driver_features: "Optional[dict[str, Any]]" = None,
74
- bind_key: "Optional[str]" = None,
91
+ pool_config: "AsyncmyPoolParams | dict[str, Any] | None" = None,
92
+ pool_instance: "AsyncmyPool | None" = None,
93
+ migration_config: dict[str, Any] | None = None,
94
+ statement_config: "StatementConfig | None" = None,
95
+ driver_features: "AsyncmyDriverFeatures | dict[str, Any] | None" = None,
96
+ bind_key: "str | None" = None,
97
+ extension_config: "dict[str, dict[str, Any]] | None" = None,
75
98
  ) -> None:
76
99
  """Initialize Asyncmy configuration.
77
100
 
@@ -80,8 +103,9 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm
80
103
  pool_instance: Existing pool instance to use
81
104
  migration_config: Migration configuration
82
105
  statement_config: Statement configuration override
83
- driver_features: Driver feature configuration
106
+ driver_features: Driver feature configuration (TypedDict or dict)
84
107
  bind_key: Optional unique identifier for this configuration
108
+ extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
85
109
  """
86
110
  processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
87
111
  if "extra" in processed_pool_config:
@@ -96,17 +120,33 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm
96
120
  if statement_config is None:
97
121
  statement_config = asyncmy_statement_config
98
122
 
123
+ processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
124
+
125
+ if "json_serializer" not in processed_driver_features:
126
+ processed_driver_features["json_serializer"] = to_json
127
+ if "json_deserializer" not in processed_driver_features:
128
+ processed_driver_features["json_deserializer"] = from_json
129
+
99
130
  super().__init__(
100
131
  pool_config=processed_pool_config,
101
132
  pool_instance=pool_instance,
102
133
  migration_config=migration_config,
103
134
  statement_config=statement_config,
104
- driver_features=driver_features or {},
135
+ driver_features=processed_driver_features,
105
136
  bind_key=bind_key,
137
+ extension_config=extension_config,
106
138
  )
107
139
 
108
140
  async def _create_pool(self) -> "AsyncmyPool": # pyright: ignore
109
- """Create the actual async connection pool."""
141
+ """Create the actual async connection pool.
142
+
143
+ MySQL/MariaDB handle JSON types natively without requiring connection-level
144
+ type handlers. JSON serialization is handled via type_coercion_map in the
145
+ driver's statement_config (see driver.py).
146
+
147
+ Future driver_features can be added here if needed (e.g., custom connection
148
+ initialization, specialized type handling).
149
+ """
110
150
  return await asyncmy.create_pool(**dict(self.pool_config)) # pyright: ignore
111
151
 
112
152
  async def _close_pool(self) -> None:
@@ -146,7 +186,7 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm
146
186
 
147
187
  @asynccontextmanager
148
188
  async def provide_session(
149
- self, *args: Any, statement_config: "Optional[StatementConfig]" = None, **kwargs: Any
189
+ self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
150
190
  ) -> AsyncGenerator[AsyncmyDriver, None]:
151
191
  """Provide an async driver session context manager.
152
192
 
@@ -159,8 +199,10 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm
159
199
  An AsyncmyDriver instance.
160
200
  """
161
201
  async with self.provide_connection(*args, **kwargs) as connection:
162
- final_statement_config = statement_config or asyncmy_statement_config
163
- yield self.driver_type(connection=connection, statement_config=final_statement_config)
202
+ final_statement_config = statement_config or self.statement_config or asyncmy_statement_config
203
+ yield self.driver_type(
204
+ connection=connection, statement_config=final_statement_config, driver_features=self.driver_features
205
+ )
164
206
 
165
207
  async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore
166
208
  """Provide async pool instance.