agno 2.2.13__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (383)
  1. agno/agent/__init__.py +6 -0
  2. agno/agent/agent.py +5252 -3145
  3. agno/agent/remote.py +525 -0
  4. agno/api/api.py +2 -0
  5. agno/client/__init__.py +3 -0
  6. agno/client/a2a/__init__.py +10 -0
  7. agno/client/a2a/client.py +554 -0
  8. agno/client/a2a/schemas.py +112 -0
  9. agno/client/a2a/utils.py +369 -0
  10. agno/client/os.py +2669 -0
  11. agno/compression/__init__.py +3 -0
  12. agno/compression/manager.py +247 -0
  13. agno/culture/manager.py +2 -2
  14. agno/db/base.py +927 -6
  15. agno/db/dynamo/dynamo.py +788 -2
  16. agno/db/dynamo/schemas.py +128 -0
  17. agno/db/dynamo/utils.py +26 -3
  18. agno/db/firestore/firestore.py +674 -50
  19. agno/db/firestore/schemas.py +41 -0
  20. agno/db/firestore/utils.py +25 -10
  21. agno/db/gcs_json/gcs_json_db.py +506 -3
  22. agno/db/gcs_json/utils.py +14 -2
  23. agno/db/in_memory/in_memory_db.py +203 -4
  24. agno/db/in_memory/utils.py +14 -2
  25. agno/db/json/json_db.py +498 -2
  26. agno/db/json/utils.py +14 -2
  27. agno/db/migrations/manager.py +199 -0
  28. agno/db/migrations/utils.py +19 -0
  29. agno/db/migrations/v1_to_v2.py +54 -16
  30. agno/db/migrations/versions/__init__.py +0 -0
  31. agno/db/migrations/versions/v2_3_0.py +977 -0
  32. agno/db/mongo/async_mongo.py +1013 -39
  33. agno/db/mongo/mongo.py +684 -4
  34. agno/db/mongo/schemas.py +48 -0
  35. agno/db/mongo/utils.py +17 -0
  36. agno/db/mysql/__init__.py +2 -1
  37. agno/db/mysql/async_mysql.py +2958 -0
  38. agno/db/mysql/mysql.py +722 -53
  39. agno/db/mysql/schemas.py +77 -11
  40. agno/db/mysql/utils.py +151 -8
  41. agno/db/postgres/async_postgres.py +1254 -137
  42. agno/db/postgres/postgres.py +2316 -93
  43. agno/db/postgres/schemas.py +153 -21
  44. agno/db/postgres/utils.py +22 -7
  45. agno/db/redis/redis.py +531 -3
  46. agno/db/redis/schemas.py +36 -0
  47. agno/db/redis/utils.py +31 -15
  48. agno/db/schemas/evals.py +1 -0
  49. agno/db/schemas/memory.py +20 -9
  50. agno/db/singlestore/schemas.py +70 -1
  51. agno/db/singlestore/singlestore.py +737 -74
  52. agno/db/singlestore/utils.py +13 -3
  53. agno/db/sqlite/async_sqlite.py +1069 -89
  54. agno/db/sqlite/schemas.py +133 -1
  55. agno/db/sqlite/sqlite.py +2203 -165
  56. agno/db/sqlite/utils.py +21 -11
  57. agno/db/surrealdb/models.py +25 -0
  58. agno/db/surrealdb/surrealdb.py +603 -1
  59. agno/db/utils.py +60 -0
  60. agno/eval/__init__.py +26 -3
  61. agno/eval/accuracy.py +25 -12
  62. agno/eval/agent_as_judge.py +871 -0
  63. agno/eval/base.py +29 -0
  64. agno/eval/performance.py +10 -4
  65. agno/eval/reliability.py +22 -13
  66. agno/eval/utils.py +2 -1
  67. agno/exceptions.py +42 -0
  68. agno/hooks/__init__.py +3 -0
  69. agno/hooks/decorator.py +164 -0
  70. agno/integrations/discord/client.py +13 -2
  71. agno/knowledge/__init__.py +4 -0
  72. agno/knowledge/chunking/code.py +90 -0
  73. agno/knowledge/chunking/document.py +65 -4
  74. agno/knowledge/chunking/fixed.py +4 -1
  75. agno/knowledge/chunking/markdown.py +102 -11
  76. agno/knowledge/chunking/recursive.py +2 -2
  77. agno/knowledge/chunking/semantic.py +130 -48
  78. agno/knowledge/chunking/strategy.py +18 -0
  79. agno/knowledge/embedder/azure_openai.py +0 -1
  80. agno/knowledge/embedder/google.py +1 -1
  81. agno/knowledge/embedder/mistral.py +1 -1
  82. agno/knowledge/embedder/nebius.py +1 -1
  83. agno/knowledge/embedder/openai.py +16 -12
  84. agno/knowledge/filesystem.py +412 -0
  85. agno/knowledge/knowledge.py +4261 -1199
  86. agno/knowledge/protocol.py +134 -0
  87. agno/knowledge/reader/arxiv_reader.py +3 -2
  88. agno/knowledge/reader/base.py +9 -7
  89. agno/knowledge/reader/csv_reader.py +91 -42
  90. agno/knowledge/reader/docx_reader.py +9 -10
  91. agno/knowledge/reader/excel_reader.py +225 -0
  92. agno/knowledge/reader/field_labeled_csv_reader.py +38 -48
  93. agno/knowledge/reader/firecrawl_reader.py +3 -2
  94. agno/knowledge/reader/json_reader.py +16 -22
  95. agno/knowledge/reader/markdown_reader.py +15 -14
  96. agno/knowledge/reader/pdf_reader.py +33 -28
  97. agno/knowledge/reader/pptx_reader.py +9 -10
  98. agno/knowledge/reader/reader_factory.py +135 -1
  99. agno/knowledge/reader/s3_reader.py +8 -16
  100. agno/knowledge/reader/tavily_reader.py +3 -3
  101. agno/knowledge/reader/text_reader.py +15 -14
  102. agno/knowledge/reader/utils/__init__.py +17 -0
  103. agno/knowledge/reader/utils/spreadsheet.py +114 -0
  104. agno/knowledge/reader/web_search_reader.py +8 -65
  105. agno/knowledge/reader/website_reader.py +16 -13
  106. agno/knowledge/reader/wikipedia_reader.py +36 -3
  107. agno/knowledge/reader/youtube_reader.py +3 -2
  108. agno/knowledge/remote_content/__init__.py +33 -0
  109. agno/knowledge/remote_content/config.py +266 -0
  110. agno/knowledge/remote_content/remote_content.py +105 -17
  111. agno/knowledge/utils.py +76 -22
  112. agno/learn/__init__.py +71 -0
  113. agno/learn/config.py +463 -0
  114. agno/learn/curate.py +185 -0
  115. agno/learn/machine.py +725 -0
  116. agno/learn/schemas.py +1114 -0
  117. agno/learn/stores/__init__.py +38 -0
  118. agno/learn/stores/decision_log.py +1156 -0
  119. agno/learn/stores/entity_memory.py +3275 -0
  120. agno/learn/stores/learned_knowledge.py +1583 -0
  121. agno/learn/stores/protocol.py +117 -0
  122. agno/learn/stores/session_context.py +1217 -0
  123. agno/learn/stores/user_memory.py +1495 -0
  124. agno/learn/stores/user_profile.py +1220 -0
  125. agno/learn/utils.py +209 -0
  126. agno/media.py +22 -6
  127. agno/memory/__init__.py +14 -1
  128. agno/memory/manager.py +223 -8
  129. agno/memory/strategies/__init__.py +15 -0
  130. agno/memory/strategies/base.py +66 -0
  131. agno/memory/strategies/summarize.py +196 -0
  132. agno/memory/strategies/types.py +37 -0
  133. agno/models/aimlapi/aimlapi.py +17 -0
  134. agno/models/anthropic/claude.py +434 -59
  135. agno/models/aws/bedrock.py +121 -20
  136. agno/models/aws/claude.py +131 -274
  137. agno/models/azure/ai_foundry.py +10 -6
  138. agno/models/azure/openai_chat.py +33 -10
  139. agno/models/base.py +1162 -561
  140. agno/models/cerebras/cerebras.py +120 -24
  141. agno/models/cerebras/cerebras_openai.py +21 -2
  142. agno/models/cohere/chat.py +65 -6
  143. agno/models/cometapi/cometapi.py +18 -1
  144. agno/models/dashscope/dashscope.py +2 -3
  145. agno/models/deepinfra/deepinfra.py +18 -1
  146. agno/models/deepseek/deepseek.py +69 -3
  147. agno/models/fireworks/fireworks.py +18 -1
  148. agno/models/google/gemini.py +959 -89
  149. agno/models/google/utils.py +22 -0
  150. agno/models/groq/groq.py +48 -18
  151. agno/models/huggingface/huggingface.py +17 -6
  152. agno/models/ibm/watsonx.py +16 -6
  153. agno/models/internlm/internlm.py +18 -1
  154. agno/models/langdb/langdb.py +13 -1
  155. agno/models/litellm/chat.py +88 -9
  156. agno/models/litellm/litellm_openai.py +18 -1
  157. agno/models/message.py +24 -5
  158. agno/models/meta/llama.py +40 -13
  159. agno/models/meta/llama_openai.py +22 -21
  160. agno/models/metrics.py +12 -0
  161. agno/models/mistral/mistral.py +8 -4
  162. agno/models/n1n/__init__.py +3 -0
  163. agno/models/n1n/n1n.py +57 -0
  164. agno/models/nebius/nebius.py +6 -7
  165. agno/models/nvidia/nvidia.py +20 -3
  166. agno/models/ollama/__init__.py +2 -0
  167. agno/models/ollama/chat.py +17 -6
  168. agno/models/ollama/responses.py +100 -0
  169. agno/models/openai/__init__.py +2 -0
  170. agno/models/openai/chat.py +117 -26
  171. agno/models/openai/open_responses.py +46 -0
  172. agno/models/openai/responses.py +110 -32
  173. agno/models/openrouter/__init__.py +2 -0
  174. agno/models/openrouter/openrouter.py +67 -2
  175. agno/models/openrouter/responses.py +146 -0
  176. agno/models/perplexity/perplexity.py +19 -1
  177. agno/models/portkey/portkey.py +7 -6
  178. agno/models/requesty/requesty.py +19 -2
  179. agno/models/response.py +20 -2
  180. agno/models/sambanova/sambanova.py +20 -3
  181. agno/models/siliconflow/siliconflow.py +19 -2
  182. agno/models/together/together.py +20 -3
  183. agno/models/vercel/v0.py +20 -3
  184. agno/models/vertexai/claude.py +124 -4
  185. agno/models/vllm/vllm.py +19 -14
  186. agno/models/xai/xai.py +19 -2
  187. agno/os/app.py +467 -137
  188. agno/os/auth.py +253 -5
  189. agno/os/config.py +22 -0
  190. agno/os/interfaces/a2a/a2a.py +7 -6
  191. agno/os/interfaces/a2a/router.py +635 -26
  192. agno/os/interfaces/a2a/utils.py +32 -33
  193. agno/os/interfaces/agui/agui.py +5 -3
  194. agno/os/interfaces/agui/router.py +26 -16
  195. agno/os/interfaces/agui/utils.py +97 -57
  196. agno/os/interfaces/base.py +7 -7
  197. agno/os/interfaces/slack/router.py +16 -7
  198. agno/os/interfaces/slack/slack.py +7 -7
  199. agno/os/interfaces/whatsapp/router.py +35 -7
  200. agno/os/interfaces/whatsapp/security.py +3 -1
  201. agno/os/interfaces/whatsapp/whatsapp.py +11 -8
  202. agno/os/managers.py +326 -0
  203. agno/os/mcp.py +652 -79
  204. agno/os/middleware/__init__.py +4 -0
  205. agno/os/middleware/jwt.py +718 -115
  206. agno/os/middleware/trailing_slash.py +27 -0
  207. agno/os/router.py +105 -1558
  208. agno/os/routers/agents/__init__.py +3 -0
  209. agno/os/routers/agents/router.py +655 -0
  210. agno/os/routers/agents/schema.py +288 -0
  211. agno/os/routers/components/__init__.py +3 -0
  212. agno/os/routers/components/components.py +475 -0
  213. agno/os/routers/database.py +155 -0
  214. agno/os/routers/evals/evals.py +111 -18
  215. agno/os/routers/evals/schemas.py +38 -5
  216. agno/os/routers/evals/utils.py +80 -11
  217. agno/os/routers/health.py +3 -3
  218. agno/os/routers/knowledge/knowledge.py +284 -35
  219. agno/os/routers/knowledge/schemas.py +14 -2
  220. agno/os/routers/memory/memory.py +274 -11
  221. agno/os/routers/memory/schemas.py +44 -3
  222. agno/os/routers/metrics/metrics.py +30 -15
  223. agno/os/routers/metrics/schemas.py +10 -6
  224. agno/os/routers/registry/__init__.py +3 -0
  225. agno/os/routers/registry/registry.py +337 -0
  226. agno/os/routers/session/session.py +143 -14
  227. agno/os/routers/teams/__init__.py +3 -0
  228. agno/os/routers/teams/router.py +550 -0
  229. agno/os/routers/teams/schema.py +280 -0
  230. agno/os/routers/traces/__init__.py +3 -0
  231. agno/os/routers/traces/schemas.py +414 -0
  232. agno/os/routers/traces/traces.py +549 -0
  233. agno/os/routers/workflows/__init__.py +3 -0
  234. agno/os/routers/workflows/router.py +757 -0
  235. agno/os/routers/workflows/schema.py +139 -0
  236. agno/os/schema.py +157 -584
  237. agno/os/scopes.py +469 -0
  238. agno/os/settings.py +3 -0
  239. agno/os/utils.py +574 -185
  240. agno/reasoning/anthropic.py +85 -1
  241. agno/reasoning/azure_ai_foundry.py +93 -1
  242. agno/reasoning/deepseek.py +102 -2
  243. agno/reasoning/default.py +6 -7
  244. agno/reasoning/gemini.py +87 -3
  245. agno/reasoning/groq.py +109 -2
  246. agno/reasoning/helpers.py +6 -7
  247. agno/reasoning/manager.py +1238 -0
  248. agno/reasoning/ollama.py +93 -1
  249. agno/reasoning/openai.py +115 -1
  250. agno/reasoning/vertexai.py +85 -1
  251. agno/registry/__init__.py +3 -0
  252. agno/registry/registry.py +68 -0
  253. agno/remote/__init__.py +3 -0
  254. agno/remote/base.py +581 -0
  255. agno/run/__init__.py +2 -4
  256. agno/run/agent.py +134 -19
  257. agno/run/base.py +49 -1
  258. agno/run/cancel.py +65 -52
  259. agno/run/cancellation_management/__init__.py +9 -0
  260. agno/run/cancellation_management/base.py +78 -0
  261. agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
  262. agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
  263. agno/run/requirement.py +181 -0
  264. agno/run/team.py +111 -19
  265. agno/run/workflow.py +2 -1
  266. agno/session/agent.py +57 -92
  267. agno/session/summary.py +1 -1
  268. agno/session/team.py +62 -115
  269. agno/session/workflow.py +353 -57
  270. agno/skills/__init__.py +17 -0
  271. agno/skills/agent_skills.py +377 -0
  272. agno/skills/errors.py +32 -0
  273. agno/skills/loaders/__init__.py +4 -0
  274. agno/skills/loaders/base.py +27 -0
  275. agno/skills/loaders/local.py +216 -0
  276. agno/skills/skill.py +65 -0
  277. agno/skills/utils.py +107 -0
  278. agno/skills/validator.py +277 -0
  279. agno/table.py +10 -0
  280. agno/team/__init__.py +5 -1
  281. agno/team/remote.py +447 -0
  282. agno/team/team.py +3769 -2202
  283. agno/tools/brandfetch.py +27 -18
  284. agno/tools/browserbase.py +225 -16
  285. agno/tools/crawl4ai.py +3 -0
  286. agno/tools/duckduckgo.py +25 -71
  287. agno/tools/exa.py +0 -21
  288. agno/tools/file.py +14 -13
  289. agno/tools/file_generation.py +12 -6
  290. agno/tools/firecrawl.py +15 -7
  291. agno/tools/function.py +94 -113
  292. agno/tools/google_bigquery.py +11 -2
  293. agno/tools/google_drive.py +4 -3
  294. agno/tools/knowledge.py +9 -4
  295. agno/tools/mcp/mcp.py +301 -18
  296. agno/tools/mcp/multi_mcp.py +269 -14
  297. agno/tools/mem0.py +11 -10
  298. agno/tools/memory.py +47 -46
  299. agno/tools/mlx_transcribe.py +10 -7
  300. agno/tools/models/nebius.py +5 -5
  301. agno/tools/models_labs.py +20 -10
  302. agno/tools/nano_banana.py +151 -0
  303. agno/tools/parallel.py +0 -7
  304. agno/tools/postgres.py +76 -36
  305. agno/tools/python.py +14 -6
  306. agno/tools/reasoning.py +30 -23
  307. agno/tools/redshift.py +406 -0
  308. agno/tools/shopify.py +1519 -0
  309. agno/tools/spotify.py +919 -0
  310. agno/tools/tavily.py +4 -1
  311. agno/tools/toolkit.py +253 -18
  312. agno/tools/websearch.py +93 -0
  313. agno/tools/website.py +1 -1
  314. agno/tools/wikipedia.py +1 -1
  315. agno/tools/workflow.py +56 -48
  316. agno/tools/yfinance.py +12 -11
  317. agno/tracing/__init__.py +12 -0
  318. agno/tracing/exporter.py +161 -0
  319. agno/tracing/schemas.py +276 -0
  320. agno/tracing/setup.py +112 -0
  321. agno/utils/agent.py +251 -10
  322. agno/utils/cryptography.py +22 -0
  323. agno/utils/dttm.py +33 -0
  324. agno/utils/events.py +264 -7
  325. agno/utils/hooks.py +111 -3
  326. agno/utils/http.py +161 -2
  327. agno/utils/mcp.py +49 -8
  328. agno/utils/media.py +22 -1
  329. agno/utils/models/ai_foundry.py +9 -2
  330. agno/utils/models/claude.py +20 -5
  331. agno/utils/models/cohere.py +9 -2
  332. agno/utils/models/llama.py +9 -2
  333. agno/utils/models/mistral.py +4 -2
  334. agno/utils/os.py +0 -0
  335. agno/utils/print_response/agent.py +99 -16
  336. agno/utils/print_response/team.py +223 -24
  337. agno/utils/print_response/workflow.py +0 -2
  338. agno/utils/prompts.py +8 -6
  339. agno/utils/remote.py +23 -0
  340. agno/utils/response.py +1 -13
  341. agno/utils/string.py +91 -2
  342. agno/utils/team.py +62 -12
  343. agno/utils/tokens.py +657 -0
  344. agno/vectordb/base.py +15 -2
  345. agno/vectordb/cassandra/cassandra.py +1 -1
  346. agno/vectordb/chroma/__init__.py +2 -1
  347. agno/vectordb/chroma/chromadb.py +468 -23
  348. agno/vectordb/clickhouse/clickhousedb.py +1 -1
  349. agno/vectordb/couchbase/couchbase.py +6 -2
  350. agno/vectordb/lancedb/lance_db.py +7 -38
  351. agno/vectordb/lightrag/lightrag.py +7 -6
  352. agno/vectordb/milvus/milvus.py +118 -84
  353. agno/vectordb/mongodb/__init__.py +2 -1
  354. agno/vectordb/mongodb/mongodb.py +14 -31
  355. agno/vectordb/pgvector/pgvector.py +120 -66
  356. agno/vectordb/pineconedb/pineconedb.py +2 -19
  357. agno/vectordb/qdrant/__init__.py +2 -1
  358. agno/vectordb/qdrant/qdrant.py +33 -56
  359. agno/vectordb/redis/__init__.py +2 -1
  360. agno/vectordb/redis/redisdb.py +19 -31
  361. agno/vectordb/singlestore/singlestore.py +17 -9
  362. agno/vectordb/surrealdb/surrealdb.py +2 -38
  363. agno/vectordb/weaviate/__init__.py +2 -1
  364. agno/vectordb/weaviate/weaviate.py +7 -3
  365. agno/workflow/__init__.py +5 -1
  366. agno/workflow/agent.py +2 -2
  367. agno/workflow/condition.py +12 -10
  368. agno/workflow/loop.py +28 -9
  369. agno/workflow/parallel.py +21 -13
  370. agno/workflow/remote.py +362 -0
  371. agno/workflow/router.py +12 -9
  372. agno/workflow/step.py +261 -36
  373. agno/workflow/steps.py +12 -8
  374. agno/workflow/types.py +40 -77
  375. agno/workflow/workflow.py +939 -213
  376. {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/METADATA +134 -181
  377. agno-2.4.3.dist-info/RECORD +677 -0
  378. {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/WHEEL +1 -1
  379. agno/tools/googlesearch.py +0 -98
  380. agno/tools/memori.py +0 -339
  381. agno-2.2.13.dist-info/RECORD +0 -575
  382. {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/licenses/LICENSE +0 -0
  383. {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/top_level.txt +0 -0
agno/db/mysql/async_mysql.py (new file)
@@ -0,0 +1,2958 @@
+ import time
+ from datetime import date, datetime, timedelta, timezone
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+ from uuid import uuid4
+
+ if TYPE_CHECKING:
+     from agno.tracing.schemas import Span, Trace
+
+ from agno.db.base import AsyncBaseDb, SessionType
+ from agno.db.migrations.manager import MigrationManager
+ from agno.db.mysql.schemas import get_table_schema_definition
+ from agno.db.mysql.utils import (
+     abulk_upsert_metrics,
+     acreate_schema,
+     ais_table_available,
+     ais_valid_table,
+     apply_sorting,
+     calculate_date_metrics,
+     deserialize_cultural_knowledge_from_db,
+     fetch_all_sessions_data,
+     get_dates_to_calculate_metrics_for,
+     serialize_cultural_knowledge_for_db,
+ )
+ from agno.db.schemas.culture import CulturalKnowledge
+ from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
+ from agno.db.schemas.knowledge import KnowledgeRow
+ from agno.db.schemas.memory import UserMemory
+ from agno.session import AgentSession, Session, TeamSession, WorkflowSession
+ from agno.utils.log import log_debug, log_error, log_info, log_warning
+ from agno.utils.string import generate_id
+
+ try:
+     from sqlalchemy import TEXT, ForeignKey, Index, UniqueConstraint, and_, cast, func, update
+     from sqlalchemy.dialects import mysql
+     from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
+     from sqlalchemy.schema import Column, MetaData, Table
+     from sqlalchemy.sql.expression import select, text
+ except ImportError:
+     raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
+
+
+ class AsyncMySQLDb(AsyncBaseDb):
+     def __init__(
+         self,
+         id: Optional[str] = None,
+         db_url: Optional[str] = None,
+         db_engine: Optional[AsyncEngine] = None,
+         db_schema: Optional[str] = None,
+         session_table: Optional[str] = None,
+         memory_table: Optional[str] = None,
+         metrics_table: Optional[str] = None,
+         eval_table: Optional[str] = None,
+         knowledge_table: Optional[str] = None,
+         culture_table: Optional[str] = None,
+         traces_table: Optional[str] = None,
+         spans_table: Optional[str] = None,
+         versions_table: Optional[str] = None,
+         create_schema: bool = True,
+     ):
+         """
+         Async interface for interacting with a MySQL database.
+
+         The following order is used to determine the database connection:
+             1. Use the db_engine if provided
+             2. Use the db_url
+             3. Raise an error if neither is provided
+
+         Args:
+             id (Optional[str]): The ID of the database.
+             db_url (Optional[str]): The database URL to connect to. Should use the asyncmy driver (e.g. mysql+asyncmy://...).
+             db_engine (Optional[AsyncEngine]): The SQLAlchemy async database engine to use.
+             db_schema (Optional[str]): The database schema to use.
+             session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
+             memory_table (Optional[str]): Name of the table to store memories.
+             metrics_table (Optional[str]): Name of the table to store metrics.
+             eval_table (Optional[str]): Name of the table to store evaluation runs data.
+             knowledge_table (Optional[str]): Name of the table to store knowledge content.
+             culture_table (Optional[str]): Name of the table to store cultural knowledge.
+             traces_table (Optional[str]): Name of the table to store run traces.
+             spans_table (Optional[str]): Name of the table to store span events.
+             versions_table (Optional[str]): Name of the table to store schema versions.
+             create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
+                 Set to False if the schema is managed externally (e.g., via migrations). Defaults to True.
+
+         Raises:
+             ValueError: If neither db_url nor db_engine is provided.
+             ValueError: If none of the tables are provided.
+         """
+         if id is None:
+             base_seed = db_url or str(db_engine.url) if db_engine else ""  # type: ignore
+             schema_suffix = db_schema if db_schema is not None else "ai"
+             seed = f"{base_seed}#{schema_suffix}"
+             id = generate_id(seed)
+
+         super().__init__(
+             id=id,
+             session_table=session_table,
+             memory_table=memory_table,
+             metrics_table=metrics_table,
+             eval_table=eval_table,
+             knowledge_table=knowledge_table,
+             culture_table=culture_table,
+             traces_table=traces_table,
+             spans_table=spans_table,
+             versions_table=versions_table,
+         )
+
+         _engine: Optional[AsyncEngine] = db_engine
+         if _engine is None and db_url is not None:
+             _engine = create_async_engine(db_url)
+         if _engine is None:
+             raise ValueError("One of db_url or db_engine must be provided")
+
+         self.db_url: Optional[str] = db_url
+         self.db_engine: AsyncEngine = _engine
+         self.db_schema: str = db_schema if db_schema is not None else "ai"
+         self.metadata: MetaData = MetaData(schema=self.db_schema)
+         self.create_schema: bool = create_schema
+
+         # Initialize the database session factory
+         self.async_session_factory = async_sessionmaker(
+             bind=self.db_engine,
+             expire_on_commit=False,
+         )
+
+     async def close(self) -> None:
+         """Close database connections and dispose of the connection pool.
+
+         Should be called during application shutdown to properly release
+         all database connections.
+         """
+         if self.db_engine is not None:
+             await self.db_engine.dispose()
+
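For orientation, a minimal usage sketch of the constructor and close() shown above; the connection URL, credentials, and table name are illustrative assumptions, not values taken from the package:

    import asyncio

    from agno.db.mysql.async_mysql import AsyncMySQLDb

    async def main() -> None:
        # Illustrative asyncmy URL, per the docstring above; not a real credential.
        db = AsyncMySQLDb(
            db_url="mysql+asyncmy://user:password@localhost:3306/agno",
            session_table="agno_sessions",  # hypothetical table name
        )
        try:
            print(await db.table_exists("agno_sessions"))
        finally:
            await db.close()  # dispose of the connection pool on shutdown

    asyncio.run(main())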
+     # -- DB methods --
+     async def table_exists(self, table_name: str) -> bool:
+         """Check if a table with the given name exists in the MySQL database.
+
+         Args:
+             table_name: Name of the table to check
+
+         Returns:
+             bool: True if the table exists in the database, False otherwise
+         """
+         async with self.async_session_factory() as sess:
+             return await ais_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
+
+     async def _create_table(self, table_name: str, table_type: str) -> Table:
+         """
+         Create a table with the appropriate schema based on the table type.
+
+         Args:
+             table_name (str): Name of the table to create
+             table_type (str): Type of table (used to get schema definition)
+
+         Returns:
+             Table: SQLAlchemy Table object
+         """
+         try:
+             # Pass traces_table_name and db_schema for spans table foreign key resolution
+             table_schema = get_table_schema_definition(
+                 table_type, traces_table_name=self.trace_table_name, db_schema=self.db_schema
+             ).copy()
+
+             log_debug(f"Creating table {self.db_schema}.{table_name} with schema: {table_schema}")
+
+             columns: List[Column] = []
+             indexes: List[str] = []
+             unique_constraints: List[str] = []
+             schema_unique_constraints = table_schema.pop("_unique_constraints", [])
+
+             # Get the columns, indexes, and unique constraints from the table schema
+             for col_name, col_config in table_schema.items():
+                 column_args = [col_name, col_config["type"]()]
+                 column_kwargs = {}
+                 if col_config.get("primary_key", False):
+                     column_kwargs["primary_key"] = True
+                 if "nullable" in col_config:
+                     column_kwargs["nullable"] = col_config["nullable"]
+                 if col_config.get("index", False):
+                     indexes.append(col_name)
+                 if col_config.get("unique", False):
+                     column_kwargs["unique"] = True
+                     unique_constraints.append(col_name)
+
+                 # Handle foreign key constraint
+                 if "foreign_key" in col_config:
+                     column_args.append(ForeignKey(col_config["foreign_key"]))
+
+                 columns.append(Column(*column_args, **column_kwargs))  # type: ignore
+
+             # Create the table object - use self.metadata to maintain FK references
+             table = Table(table_name, self.metadata, *columns, schema=self.db_schema)
+
+             # Add multi-column unique constraints with table-specific names
+             for constraint in schema_unique_constraints:
+                 constraint_name = f"{table_name}_{constraint['name']}"
+                 constraint_columns = constraint["columns"]
+                 table.append_constraint(UniqueConstraint(*constraint_columns, name=constraint_name))
+
+             # Add indexes to the table definition
+             for idx_col in indexes:
+                 idx_name = f"idx_{table_name}_{idx_col}"
+                 table.append_constraint(Index(idx_name, idx_col))
+
+             # Create schema if not exists
+             if self.create_schema:
+                 async with self.async_session_factory() as sess, sess.begin():
+                     await acreate_schema(session=sess, db_schema=self.db_schema)
+
+             # Create table
+             table_created = False
+             if not await self.table_exists(table_name):
+                 async with self.db_engine.begin() as conn:
+                     await conn.run_sync(table.create, checkfirst=True)
+                 log_debug(f"Successfully created table '{table_name}'")
+                 table_created = True
+             else:
+                 log_debug(f"Table {self.db_schema}.{table_name} already exists, skipping creation")
+
+             # Create indexes
+             for idx in table.indexes:
+                 try:
+                     # Check if index already exists
+                     async with self.async_session_factory() as sess:
+                         exists_query = text(
+                             "SELECT 1 FROM information_schema.statistics WHERE table_schema = :schema "
+                             "AND table_name = :table_name AND index_name = :index_name"
+                         )
+                         result = await sess.execute(
+                             exists_query, {"schema": self.db_schema, "table_name": table_name, "index_name": idx.name}
+                         )
+                         exists = result.scalar() is not None
+                         if exists:
+                             log_debug(
+                                 f"Index {idx.name} already exists in {self.db_schema}.{table_name}, skipping creation"
+                             )
+                             continue
+
+                     async with self.db_engine.begin() as conn:
+                         await conn.run_sync(idx.create)
+                         log_debug(f"Created index: {idx.name} for table {self.db_schema}.{table_name}")
+
+                 except Exception as e:
+                     log_error(f"Error creating index {idx.name}: {e}")
+
+             log_debug(f"Successfully created table {table_name} in schema {self.db_schema}")
+
+             # Store the schema version for the created table
+             if table_name != self.versions_table_name and table_created:
+                 latest_schema_version = MigrationManager(self).latest_schema_version
+                 await self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
+                 log_info(
+                     f"Successfully stored version {latest_schema_version.public} in database for table {table_name}"
+                 )
+
+             return table
+
+         except Exception as e:
+             log_error(f"Could not create table {self.db_schema}.{table_name}: {e}")
+             raise
+
+     async def _create_all_tables(self):
+         """Create all tables for the database."""
+         tables_to_create = [
+             (self.session_table_name, "sessions"),
+             (self.memory_table_name, "memories"),
+             (self.metrics_table_name, "metrics"),
+             (self.eval_table_name, "evals"),
+             (self.knowledge_table_name, "knowledge"),
+             (self.culture_table_name, "culture"),
+             (self.trace_table_name, "traces"),
+             (self.span_table_name, "spans"),
+             (self.versions_table_name, "versions"),
+         ]
+
+         for table_name, table_type in tables_to_create:
+             await self._get_or_create_table(
+                 table_name=table_name, table_type=table_type, create_table_if_not_found=True
+             )
+
+     async def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Table:
+         if table_type == "sessions":
+             self.session_table = await self._get_or_create_table(
+                 table_name=self.session_table_name,
+                 table_type="sessions",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.session_table
+
+         if table_type == "memories":
+             self.memory_table = await self._get_or_create_table(
+                 table_name=self.memory_table_name,
+                 table_type="memories",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.memory_table
+
+         if table_type == "metrics":
+             self.metrics_table = await self._get_or_create_table(
+                 table_name=self.metrics_table_name,
+                 table_type="metrics",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.metrics_table
+
+         if table_type == "evals":
+             self.eval_table = await self._get_or_create_table(
+                 table_name=self.eval_table_name,
+                 table_type="evals",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.eval_table
+
+         if table_type == "knowledge":
+             self.knowledge_table = await self._get_or_create_table(
+                 table_name=self.knowledge_table_name,
+                 table_type="knowledge",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.knowledge_table
+
+         if table_type == "culture":
+             self.culture_table = await self._get_or_create_table(
+                 table_name=self.culture_table_name,
+                 table_type="culture",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.culture_table
+
+         if table_type == "versions":
+             self.versions_table = await self._get_or_create_table(
+                 table_name=self.versions_table_name,
+                 table_type="versions",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.versions_table
+
+         if table_type == "traces":
+             self.traces_table = await self._get_or_create_table(
+                 table_name=self.trace_table_name,
+                 table_type="traces",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.traces_table
+
+         if table_type == "spans":
+             # Ensure traces table exists first (spans has FK to traces)
+             if create_table_if_not_found:
+                 await self._get_table(table_type="traces", create_table_if_not_found=True)
+             self.spans_table = await self._get_or_create_table(
+                 table_name=self.span_table_name,
+                 table_type="spans",
+                 create_table_if_not_found=create_table_if_not_found,
+             )
+             return self.spans_table
+
+         raise ValueError(f"Unknown table type: {table_type}")
+
+     async def _get_or_create_table(
+         self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
+     ) -> Table:
+         """
+         Check if the table exists and is valid, else create it.
+
+         Args:
+             table_name (str): Name of the table to get or create
+             table_type (str): Type of table (used to get schema definition)
+             create_table_if_not_found (Optional[bool]): Whether to create the table if it doesn't exist
+
+         Returns:
+             Table: SQLAlchemy Table object representing the schema.
+         """
+         async with self.async_session_factory() as sess, sess.begin():
+             table_is_available = await ais_table_available(
+                 session=sess, table_name=table_name, db_schema=self.db_schema
+             )
+
+         if (not table_is_available) and create_table_if_not_found:
+             return await self._create_table(table_name=table_name, table_type=table_type)
+
+         if not await ais_valid_table(
+             db_engine=self.db_engine,
+             table_name=table_name,
+             table_type=table_type,
+             db_schema=self.db_schema,
+         ):
+             raise ValueError(f"Table {self.db_schema}.{table_name} has an invalid schema")
+
+         try:
+             async with self.db_engine.connect() as conn:
+
+                 def create_table(connection):
+                     return Table(table_name, self.metadata, schema=self.db_schema, autoload_with=connection)
+
+                 table = await conn.run_sync(create_table)
+                 return table
+
+         except Exception as e:
+             log_error(f"Error loading existing table {self.db_schema}.{table_name}: {e}")
+             raise
+
+     async def get_latest_schema_version(self, table_name: str) -> str:
+         """Get the latest schema version stored for the given table."""
+         table = await self._get_table(table_type="versions", create_table_if_not_found=True)
+         async with self.async_session_factory() as sess:
+             # Latest version for the given table
+             stmt = select(table).where(table.c.table_name == table_name).order_by(table.c.version.desc()).limit(1)  # type: ignore
+             result = await sess.execute(stmt)
+             row = result.fetchone()
+             if row is None:
+                 return "2.0.0"
+             version_dict = dict(row._mapping)
+             return version_dict.get("version") or "2.0.0"
+
+     async def upsert_schema_version(self, table_name: str, version: str) -> None:
+         """Upsert the schema version into the database."""
+         table = await self._get_table(table_type="versions", create_table_if_not_found=True)
+         current_datetime = datetime.now().isoformat()
+         async with self.async_session_factory() as sess, sess.begin():
+             stmt = mysql.insert(table).values(  # type: ignore
+                 table_name=table_name,
+                 version=version,
+                 created_at=current_datetime,  # Store as ISO format string
+                 updated_at=current_datetime,
+             )
+             # Update version if table_name already exists
+             stmt = stmt.on_duplicate_key_update(
+                 version=version,
+                 created_at=current_datetime,
+                 updated_at=current_datetime,
+             )
+             await sess.execute(stmt)
+
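The two version helpers above form a simple round-trip. A sketch, assuming `db` is the AsyncMySQLDb instance from the earlier example; the table name and version string are illustrative:

    # Inside an async function, with `db` as above.
    await db.upsert_schema_version(table_name="agno_sessions", version="2.3.0")
    version = await db.get_latest_schema_version(table_name="agno_sessions")
    # Falls back to "2.0.0" when no version row exists for the table.
    print(version)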
+     # -- Session methods --
+     async def delete_session(self, session_id: str) -> bool:
+         """
+         Delete a session from the database.
+
+         Args:
+             session_id (str): ID of the session to delete
+
+         Returns:
+             bool: True if the session was deleted, False otherwise.
+
+         Raises:
+             Exception: If an error occurs during deletion.
+         """
+         try:
+             table = await self._get_table(table_type="sessions")
+
+             async with self.async_session_factory() as sess, sess.begin():
+                 delete_stmt = table.delete().where(table.c.session_id == session_id)
+                 result = await sess.execute(delete_stmt)
+
+                 if result.rowcount == 0:  # type: ignore
+                     log_debug(f"No session found to delete with session_id: {session_id} in table {table.name}")
+                     return False
+
+                 else:
+                     log_debug(f"Successfully deleted session with session_id: {session_id} in table {table.name}")
+                     return True
+
+         except Exception as e:
+             log_error(f"Error deleting session: {e}")
+             return False
+
+     async def delete_sessions(self, session_ids: List[str]) -> None:
+         """Delete all given sessions from the database.
+         Can handle multiple session types in the same run.
+
+         Args:
+             session_ids (List[str]): The IDs of the sessions to delete.
+
+         Raises:
+             Exception: If an error occurs during deletion.
+         """
+         try:
+             table = await self._get_table(table_type="sessions")
+
+             async with self.async_session_factory() as sess, sess.begin():
+                 delete_stmt = table.delete().where(table.c.session_id.in_(session_ids))
+                 result = await sess.execute(delete_stmt)
+
+                 log_debug(f"Successfully deleted {result.rowcount} sessions")  # type: ignore
+
+         except Exception as e:
+             log_error(f"Error deleting sessions: {e}")
+
+     async def get_session(
+         self,
+         session_id: str,
+         session_type: SessionType,
+         user_id: Optional[str] = None,
+         deserialize: Optional[bool] = True,
+     ) -> Optional[Union[Session, Dict[str, Any]]]:
+         """
+         Read a session from the database.
+
+         Args:
+             session_id (str): ID of the session to read.
+             session_type (SessionType): Type of session to read.
+             user_id (Optional[str]): User ID to filter by. Defaults to None.
+             deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.
+
+         Returns:
+             Union[Session, Dict[str, Any], None]:
+                 - When deserialize=True: Session object
+                 - When deserialize=False: Session dictionary
+
+         Raises:
+             Exception: If an error occurs during retrieval.
+         """
+         try:
+             table = await self._get_table(table_type="sessions")
+
+             async with self.async_session_factory() as sess:
+                 stmt = select(table).where(table.c.session_id == session_id)
+
+                 if user_id is not None:
+                     stmt = stmt.where(table.c.user_id == user_id)
+                 result = await sess.execute(stmt)
+                 row = result.fetchone()
+                 if row is None:
+                     return None
+
+                 session = dict(row._mapping)
+
+                 if not deserialize:
+                     return session
+
+                 if session_type == SessionType.AGENT:
+                     return AgentSession.from_dict(session)
+                 elif session_type == SessionType.TEAM:
+                     return TeamSession.from_dict(session)
+                 elif session_type == SessionType.WORKFLOW:
+                     return WorkflowSession.from_dict(session)
+                 else:
+                     raise ValueError(f"Invalid session type: {session_type}")
+
+         except Exception as e:
+             log_error(f"Exception reading from session table: {e}")
+             return None
+
+     async def get_sessions(
+         self,
+         session_type: Optional[SessionType] = None,
+         user_id: Optional[str] = None,
+         component_id: Optional[str] = None,
+         session_name: Optional[str] = None,
+         start_timestamp: Optional[int] = None,
+         end_timestamp: Optional[int] = None,
+         limit: Optional[int] = None,
+         page: Optional[int] = None,
+         sort_by: Optional[str] = None,
+         sort_order: Optional[str] = None,
+         deserialize: Optional[bool] = True,
+     ) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
+         """
+         Get all sessions in the given table. Can filter by user_id and component_id.
+
+         Args:
+             session_type (Optional[SessionType]): The type of session to filter by.
+             user_id (Optional[str]): The ID of the user to filter by.
+             component_id (Optional[str]): The ID of the agent / team / workflow to filter by.
+             session_name (Optional[str]): The name of the session to filter by.
+             start_timestamp (Optional[int]): The start timestamp to filter by.
+             end_timestamp (Optional[int]): The end timestamp to filter by.
+             limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
+             page (Optional[int]): The page number to return. Defaults to None.
+             sort_by (Optional[str]): The field to sort by. Defaults to None.
+             sort_order (Optional[str]): The sort order. Defaults to None.
+             deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
+
+         Returns:
+             Union[List[Session], Tuple[List[Dict], int]]:
+                 - When deserialize=True: List of Session objects
+                 - When deserialize=False: Tuple of (session dictionaries, total count)
+
+         Raises:
+             Exception: If an error occurs during retrieval.
+         """
+         try:
+             table = await self._get_table(table_type="sessions")
+
+             async with self.async_session_factory() as sess, sess.begin():
+                 stmt = select(table)
+
+                 # Filtering
+                 if user_id is not None:
+                     stmt = stmt.where(table.c.user_id == user_id)
+                 if component_id is not None:
+                     if session_type == SessionType.AGENT:
+                         stmt = stmt.where(table.c.agent_id == component_id)
+                     elif session_type == SessionType.TEAM:
+                         stmt = stmt.where(table.c.team_id == component_id)
+                     elif session_type == SessionType.WORKFLOW:
+                         stmt = stmt.where(table.c.workflow_id == component_id)
+                 if start_timestamp is not None:
+                     stmt = stmt.where(table.c.created_at >= start_timestamp)
+                 if end_timestamp is not None:
+                     stmt = stmt.where(table.c.created_at <= end_timestamp)
+                 if session_name is not None:
+                     # MySQL JSON extraction syntax
+                     stmt = stmt.where(
+                         func.coalesce(
+                             func.json_unquote(func.json_extract(table.c.session_data, "$.session_name")), ""
+                         ).ilike(f"%{session_name}%")
+                     )
+                 if session_type is not None:
+                     session_type_value = session_type.value if isinstance(session_type, SessionType) else session_type
+                     stmt = stmt.where(table.c.session_type == session_type_value)
+
+                 count_stmt = select(func.count()).select_from(stmt.alias())
+                 total_count = await sess.scalar(count_stmt) or 0
+
+                 # Sorting
+                 stmt = apply_sorting(stmt, table, sort_by, sort_order)
+
+                 # Paginating
+                 if limit is not None:
+                     stmt = stmt.limit(limit)
+                     if page is not None:
+                         stmt = stmt.offset((page - 1) * limit)
+
+                 result = await sess.execute(stmt)
+                 records = result.fetchall()
+                 if records is None:
+                     return [], 0
+
+                 sessions = [dict(record._mapping) for record in records]
+                 if not deserialize:
+                     return sessions, total_count
+
+                 if session_type == SessionType.AGENT:
+                     return [AgentSession.from_dict(record) for record in sessions]  # type: ignore
+                 elif session_type == SessionType.TEAM:
+                     return [TeamSession.from_dict(record) for record in sessions]  # type: ignore
+                 elif session_type == SessionType.WORKFLOW:
+                     return [WorkflowSession.from_dict(record) for record in sessions]  # type: ignore
+                 else:
+                     raise ValueError(f"Invalid session type: {session_type}")
+
+         except Exception as e:
+             log_error(f"Exception reading from session table: {e}")
+             return [] if deserialize else ([], 0)
+
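The count-then-paginate pattern above returns a total row count alongside each page when deserialize=False. A sketch of a call site, with illustrative filter values:

    # Inside an async function, with `db` as above.
    from agno.db.base import SessionType

    rows, total = await db.get_sessions(
        session_type=SessionType.AGENT,
        user_id="user_123",  # illustrative ID
        limit=20,
        page=1,
        sort_by="created_at",
        sort_order="desc",
        deserialize=False,  # returns (session dicts, total count)
    )
    print(f"showing {len(rows)} of {total} sessions")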
+     async def rename_session(
+         self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
+     ) -> Optional[Union[Session, Dict[str, Any]]]:
+         """
+         Rename a session in the database.
+
+         Args:
+             session_id (str): The ID of the session to rename.
+             session_type (SessionType): The type of session to rename.
+             session_name (str): The new name for the session.
+             deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.
+
+         Returns:
+             Optional[Union[Session, Dict[str, Any]]]:
+                 - When deserialize=True: Session object
+                 - When deserialize=False: Session dictionary
+
+         Raises:
+             Exception: If an error occurs during renaming.
+         """
+         try:
+             table = await self._get_table(table_type="sessions")
+
+             async with self.async_session_factory() as sess, sess.begin():
+                 # MySQL JSON_SET syntax
+                 stmt = (
+                     update(table)
+                     .where(table.c.session_id == session_id)
+                     .where(table.c.session_type == session_type.value)
+                     .values(session_data=func.json_set(table.c.session_data, "$.session_name", session_name))
+                 )
+                 await sess.execute(stmt)
+
+                 # Fetch the updated row
+                 select_stmt = select(table).where(table.c.session_id == session_id)
+                 result = await sess.execute(select_stmt)
+                 row = result.fetchone()
+                 if not row:
+                     return None
+
+                 log_debug(f"Renamed session with id '{session_id}' to '{session_name}'")
+
+                 session = dict(row._mapping)
+                 if not deserialize:
+                     return session
+
+                 # Return the appropriate session type
+                 if session_type == SessionType.AGENT:
+                     return AgentSession.from_dict(session)
+                 elif session_type == SessionType.TEAM:
+                     return TeamSession.from_dict(session)
+                 elif session_type == SessionType.WORKFLOW:
+                     return WorkflowSession.from_dict(session)
+                 else:
+                     raise ValueError(f"Invalid session type: {session_type}")
+
+         except Exception as e:
+             log_error(f"Exception renaming session: {e}")
+             return None
+
+     async def upsert_session(
+         self, session: Session, deserialize: Optional[bool] = True
+     ) -> Optional[Union[Session, Dict[str, Any]]]:
+         """
+         Insert or update a session in the database.
+
+         Args:
+             session (Session): The session data to upsert.
+             deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.
+
+         Returns:
+             Optional[Union[Session, Dict[str, Any]]]:
+                 - When deserialize=True: Session object
+                 - When deserialize=False: Session dictionary
+
+         Raises:
+             Exception: If an error occurs during upsert.
+         """
+         try:
+             table = await self._get_table(table_type="sessions", create_table_if_not_found=True)
+             session_dict = session.to_dict()
+
+             if isinstance(session, AgentSession):
+                 async with self.async_session_factory() as sess, sess.begin():
+                     current_time = int(time.time())
+                     stmt = mysql.insert(table).values(
+                         session_id=session_dict.get("session_id"),
+                         session_type=SessionType.AGENT.value,
+                         agent_id=session_dict.get("agent_id"),
+                         user_id=session_dict.get("user_id"),
+                         runs=session_dict.get("runs"),
+                         agent_data=session_dict.get("agent_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         created_at=session_dict.get("created_at") or current_time,
+                         updated_at=session_dict.get("updated_at") or current_time,
+                     )
+                     stmt = stmt.on_duplicate_key_update(
+                         agent_id=session_dict.get("agent_id"),
+                         user_id=session_dict.get("user_id"),
+                         agent_data=session_dict.get("agent_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         runs=session_dict.get("runs"),
+                         updated_at=int(time.time()),
+                     )
+                     await sess.execute(stmt)
+
+                     # Fetch the row
+                     select_stmt = select(table).where(table.c.session_id == session_dict.get("session_id"))
+                     result = await sess.execute(select_stmt)
+                     row = result.fetchone()
+                     if row is None:
+                         return None
+                     session_dict = dict(row._mapping)
+
+                 log_debug(f"Upserted agent session with id '{session_dict.get('session_id')}'")
+
+                 if not deserialize:
+                     return session_dict
+                 return AgentSession.from_dict(session_dict)
+
+             elif isinstance(session, TeamSession):
+                 async with self.async_session_factory() as sess, sess.begin():
+                     current_time = int(time.time())
+                     stmt = mysql.insert(table).values(
+                         session_id=session_dict.get("session_id"),
+                         session_type=SessionType.TEAM.value,
+                         team_id=session_dict.get("team_id"),
+                         user_id=session_dict.get("user_id"),
+                         runs=session_dict.get("runs"),
+                         team_data=session_dict.get("team_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         created_at=session_dict.get("created_at") or current_time,
+                         updated_at=session_dict.get("updated_at") or current_time,
+                     )
+                     stmt = stmt.on_duplicate_key_update(
+                         team_id=session_dict.get("team_id"),
+                         user_id=session_dict.get("user_id"),
+                         team_data=session_dict.get("team_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         runs=session_dict.get("runs"),
+                         updated_at=int(time.time()),
+                     )
+                     await sess.execute(stmt)
+
+                     # Fetch the row
+                     select_stmt = select(table).where(table.c.session_id == session_dict.get("session_id"))
+                     result = await sess.execute(select_stmt)
+                     row = result.fetchone()
+                     if row is None:
+                         return None
+                     session_dict = dict(row._mapping)
+
+                 log_debug(f"Upserted team session with id '{session_dict.get('session_id')}'")
+
+                 if not deserialize:
+                     return session_dict
+                 return TeamSession.from_dict(session_dict)
+
+             elif isinstance(session, WorkflowSession):
+                 async with self.async_session_factory() as sess, sess.begin():
+                     current_time = int(time.time())
+                     stmt = mysql.insert(table).values(
+                         session_id=session_dict.get("session_id"),
+                         session_type=SessionType.WORKFLOW.value,
+                         workflow_id=session_dict.get("workflow_id"),
+                         user_id=session_dict.get("user_id"),
+                         runs=session_dict.get("runs"),
+                         workflow_data=session_dict.get("workflow_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         created_at=session_dict.get("created_at") or current_time,
+                         updated_at=session_dict.get("updated_at") or current_time,
+                     )
+                     stmt = stmt.on_duplicate_key_update(
+                         workflow_id=session_dict.get("workflow_id"),
+                         user_id=session_dict.get("user_id"),
+                         workflow_data=session_dict.get("workflow_data"),
+                         session_data=session_dict.get("session_data"),
+                         summary=session_dict.get("summary"),
+                         metadata=session_dict.get("metadata"),
+                         runs=session_dict.get("runs"),
+                         updated_at=int(time.time()),
+                     )
+                     await sess.execute(stmt)
+
+                     # Fetch the row
+                     select_stmt = select(table).where(table.c.session_id == session_dict.get("session_id"))
+                     result = await sess.execute(select_stmt)
+                     row = result.fetchone()
+                     if row is None:
+                         return None
+                     session_dict = dict(row._mapping)
+
+                 log_debug(f"Upserted workflow session with id '{session_dict.get('session_id')}'")
+
+                 if not deserialize:
+                     return session_dict
+                 return WorkflowSession.from_dict(session_dict)
+
+             else:
+                 raise ValueError(f"Invalid session type: {session.session_type}")
+
+         except Exception as e:
+             log_error(f"Exception upserting into sessions table: {e}")
+             return None
+
863
+ async def upsert_sessions(
864
+ self, sessions: List[Session], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
865
+ ) -> List[Union[Session, Dict[str, Any]]]:
866
+ """
867
+ Bulk upsert multiple sessions for improved performance on large datasets.
868
+
869
+ Args:
870
+ sessions (List[Session]): List of sessions to upsert.
871
+ deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
872
+ preserve_updated_at (bool): If True, preserve the updated_at from the session object.
873
+
874
+ Returns:
875
+ List[Union[Session, Dict[str, Any]]]: List of upserted sessions.
876
+
877
+ Raises:
878
+ Exception: If an error occurs during bulk upsert.
879
+ """
880
+ if not sessions:
881
+ return []
882
+
883
+ try:
884
+ table = await self._get_table(table_type="sessions")
885
+
886
+ # Group sessions by type for batch processing
887
+ agent_sessions = []
888
+ team_sessions = []
889
+ workflow_sessions = []
890
+
891
+ for session in sessions:
892
+ if isinstance(session, AgentSession):
893
+ agent_sessions.append(session)
894
+ elif isinstance(session, TeamSession):
895
+ team_sessions.append(session)
896
+ elif isinstance(session, WorkflowSession):
897
+ workflow_sessions.append(session)
898
+
899
+ results: List[Union[Session, Dict[str, Any]]] = []
900
+
901
+ # Process each session type in bulk
902
+ async with self.async_session_factory() as sess, sess.begin():
903
+ # Bulk upsert agent sessions
904
+ if agent_sessions:
905
+ agent_data = []
906
+ for session in agent_sessions:
907
+ session_dict = session.to_dict()
908
+ # Use preserved updated_at if flag is set and value exists, otherwise use current time
909
+ updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
910
+ agent_data.append(
911
+ {
912
+ "session_id": session_dict.get("session_id"),
913
+ "session_type": SessionType.AGENT.value,
914
+ "agent_id": session_dict.get("agent_id"),
915
+ "user_id": session_dict.get("user_id"),
916
+ "runs": session_dict.get("runs"),
917
+ "agent_data": session_dict.get("agent_data"),
918
+ "session_data": session_dict.get("session_data"),
919
+ "summary": session_dict.get("summary"),
920
+ "metadata": session_dict.get("metadata"),
921
+ "created_at": session_dict.get("created_at"),
922
+ "updated_at": updated_at,
923
+ }
924
+ )
925
+
926
+ if agent_data:
927
+ stmt = mysql.insert(table)
928
+ stmt = stmt.on_duplicate_key_update(
929
+ agent_id=stmt.inserted.agent_id,
930
+ user_id=stmt.inserted.user_id,
931
+ agent_data=stmt.inserted.agent_data,
932
+ session_data=stmt.inserted.session_data,
933
+ summary=stmt.inserted.summary,
934
+ metadata=stmt.inserted.metadata,
935
+ runs=stmt.inserted.runs,
936
+ updated_at=stmt.inserted.updated_at,
937
+ )
938
+ await sess.execute(stmt, agent_data)
939
+
940
+ # Fetch the results for agent sessions
941
+ agent_ids = [session.session_id for session in agent_sessions]
942
+ select_stmt = select(table).where(table.c.session_id.in_(agent_ids))
943
+ result = await sess.execute(select_stmt)
944
+ fetched_rows = result.fetchall()
945
+
946
+ for row in fetched_rows:
947
+ session_dict = dict(row._mapping)
948
+ if deserialize:
949
+ deserialized_agent_session = AgentSession.from_dict(session_dict)
950
+ if deserialized_agent_session is None:
951
+ continue
952
+ results.append(deserialized_agent_session)
953
+ else:
954
+ results.append(session_dict)
955
+
956
+                # Bulk upsert team sessions
+                if team_sessions:
+                    team_data = []
+                    for session in team_sessions:
+                        session_dict = session.to_dict()
+                        # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                        updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
+                        team_data.append(
+                            {
+                                "session_id": session_dict.get("session_id"),
+                                "session_type": SessionType.TEAM.value,
+                                "team_id": session_dict.get("team_id"),
+                                "user_id": session_dict.get("user_id"),
+                                "runs": session_dict.get("runs"),
+                                "team_data": session_dict.get("team_data"),
+                                "session_data": session_dict.get("session_data"),
+                                "summary": session_dict.get("summary"),
+                                "metadata": session_dict.get("metadata"),
+                                "created_at": session_dict.get("created_at"),
+                                "updated_at": updated_at,
+                            }
+                        )
+
+                    if team_data:
+                        stmt = mysql.insert(table)
+                        stmt = stmt.on_duplicate_key_update(
+                            team_id=stmt.inserted.team_id,
+                            user_id=stmt.inserted.user_id,
+                            team_data=stmt.inserted.team_data,
+                            session_data=stmt.inserted.session_data,
+                            summary=stmt.inserted.summary,
+                            metadata=stmt.inserted.metadata,
+                            runs=stmt.inserted.runs,
+                            updated_at=stmt.inserted.updated_at,
+                        )
+                        await sess.execute(stmt, team_data)
+
+                    # Fetch the results for team sessions
+                    team_ids = [session.session_id for session in team_sessions]
+                    select_stmt = select(table).where(table.c.session_id.in_(team_ids))
+                    result = await sess.execute(select_stmt)
+                    fetched_rows = result.fetchall()
+
+                    for row in fetched_rows:
+                        session_dict = dict(row._mapping)
+                        if deserialize:
+                            deserialized_team_session = TeamSession.from_dict(session_dict)
+                            if deserialized_team_session is None:
+                                continue
+                            results.append(deserialized_team_session)
+                        else:
+                            results.append(session_dict)
+
+                # Bulk upsert workflow sessions
+                if workflow_sessions:
+                    workflow_data = []
+                    for session in workflow_sessions:
+                        session_dict = session.to_dict()
+                        # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                        updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
+                        workflow_data.append(
+                            {
+                                "session_id": session_dict.get("session_id"),
+                                "session_type": SessionType.WORKFLOW.value,
+                                "workflow_id": session_dict.get("workflow_id"),
+                                "user_id": session_dict.get("user_id"),
+                                "runs": session_dict.get("runs"),
+                                "workflow_data": session_dict.get("workflow_data"),
+                                "session_data": session_dict.get("session_data"),
+                                "summary": session_dict.get("summary"),
+                                "metadata": session_dict.get("metadata"),
+                                "created_at": session_dict.get("created_at"),
+                                "updated_at": updated_at,
+                            }
+                        )
+
+                    if workflow_data:
+                        stmt = mysql.insert(table)
+                        stmt = stmt.on_duplicate_key_update(
+                            workflow_id=stmt.inserted.workflow_id,
+                            user_id=stmt.inserted.user_id,
+                            workflow_data=stmt.inserted.workflow_data,
+                            session_data=stmt.inserted.session_data,
+                            summary=stmt.inserted.summary,
+                            metadata=stmt.inserted.metadata,
+                            runs=stmt.inserted.runs,
+                            updated_at=stmt.inserted.updated_at,
+                        )
+                        await sess.execute(stmt, workflow_data)
+
+                    # Fetch the results for workflow sessions
+                    workflow_ids = [session.session_id for session in workflow_sessions]
+                    select_stmt = select(table).where(table.c.session_id.in_(workflow_ids))
+                    result = await sess.execute(select_stmt)
+                    fetched_rows = result.fetchall()
+
+                    for row in fetched_rows:
+                        session_dict = dict(row._mapping)
+                        if deserialize:
+                            deserialized_workflow_session = WorkflowSession.from_dict(session_dict)
+                            if deserialized_workflow_session is None:
+                                continue
+                            results.append(deserialized_workflow_session)
+                        else:
+                            results.append(session_dict)
+
+                return results
+
+        except Exception as e:
+            log_error(f"Exception during bulk session upsert, falling back to individual upserts: {e}")
+            # Fallback to individual upserts
+            return [
+                result
+                for session in sessions
+                if session is not None
+                for result in [await self.upsert_session(session, deserialize=deserialize)]
+                if result is not None
+            ]
+
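The bulk paths above hinge on MySQL's INSERT ... ON DUPLICATE KEY UPDATE: `stmt.inserted.<col>` compiles to `VALUES(col)`, i.e. the value of the incoming row, and passing a list of dicts to `execute()` runs the statement in executemany mode, so one round-trip covers the whole batch. A minimal compile-only sketch of the idiom (the table here is illustrative, not the package's real schema):

from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects import mysql as mysql_dialect
from sqlalchemy.dialects.mysql import insert

metadata = MetaData()
demo = Table(  # hypothetical table, just to show the statement shape
    "demo_sessions",
    metadata,
    Column("session_id", String(64), primary_key=True),
    Column("updated_at", Integer),
)

stmt = insert(demo)
# stmt.inserted.updated_at compiles to VALUES(updated_at): the incoming row's value
stmt = stmt.on_duplicate_key_update(updated_at=stmt.inserted.updated_at)

# Prints: INSERT INTO demo_sessions ... ON DUPLICATE KEY UPDATE updated_at = VALUES(updated_at)
print(stmt.compile(dialect=mysql_dialect.dialect()))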
+    # -- Memory methods --
+    async def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
+        """Delete a user memory from the database.
+
+        Args:
+            memory_id (str): The ID of the memory to delete.
+            user_id (Optional[str]): Optional user ID to scope the deletion.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                delete_stmt = table.delete().where(table.c.memory_id == memory_id)
+                if user_id is not None:
+                    delete_stmt = delete_stmt.where(table.c.user_id == user_id)
+                result = await sess.execute(delete_stmt)
+
+                success = result.rowcount > 0  # type: ignore
+                if success:
+                    log_debug(f"Successfully deleted user memory id: {memory_id}")
+                else:
+                    log_debug(f"No user memory found with id: {memory_id}")
+
+        except Exception as e:
+            log_error(f"Error deleting user memory: {e}")
+
+    async def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
+        """Delete user memories from the database.
+
+        Args:
+            memory_ids (List[str]): The IDs of the memories to delete.
+            user_id (Optional[str]): Optional user ID to filter deletions.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                delete_stmt = table.delete().where(table.c.memory_id.in_(memory_ids))
+                if user_id is not None:
+                    delete_stmt = delete_stmt.where(table.c.user_id == user_id)
+                result = await sess.execute(delete_stmt)
+
+                if result.rowcount == 0:  # type: ignore
+                    log_debug(f"No user memories found with ids: {memory_ids}")
+                else:
+                    log_debug(f"Successfully deleted {result.rowcount} user memories")  # type: ignore
+
+        except Exception as e:
+            log_error(f"Error deleting user memories: {e}")
+
+    async def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
+        """Get all memory topics from the database.
+
+        Args:
+            user_id (Optional[str]): Optional user ID to filter topics.
+
+        Returns:
+            List[str]: List of memory topics.
+        """
+        try:
+            import json
+
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                # Fetch the raw JSON topic arrays and aggregate distinct topics in Python
+                stmt = select(table.c.topics)
+                # Scope to a single user when requested
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+
+                topics_set = set()
+                for row in records:
+                    if row[0]:
+                        # Parse the JSON array (some drivers return it as a string) and add topics to the set
+                        try:
+                            topics = json.loads(row[0]) if isinstance(row[0], str) else row[0]
+                            if isinstance(topics, list):
+                                topics_set.update(topics)
+                        except Exception:
+                            pass
+
+                return list(topics_set)
+
+        except Exception as e:
+            log_error(f"Exception reading from memory table: {e}")
+            return []
+
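Depending on the driver, a MySQL JSON column can come back either as an already-decoded list or as a raw JSON string, which is why the loop above tolerates both. A self-contained illustration of that normalization:

import json
from typing import Any, List, Set

def collect_topics(rows: List[Any]) -> List[str]:
    """Mimics the tolerant parsing above: accept JSON strings or lists."""
    topics: Set[str] = set()
    for value in rows:
        if not value:
            continue
        try:
            parsed = json.loads(value) if isinstance(value, str) else value
            if isinstance(parsed, list):
                topics.update(parsed)
        except Exception:
            pass  # Ignore malformed rows, as the method above does
    return list(topics)

print(collect_topics(['["a", "b"]', ["b", "c"], None, "{bad"]))  # ['a', 'b', 'c'] in some order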
+    async def get_user_memory(
+        self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
+    ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
+        """Get a memory from the database.
+
+        Args:
+            memory_id (str): The ID of the memory to get.
+            deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
+            user_id (Optional[str]): Optional user ID to scope the lookup.
+
+        Returns:
+            Union[UserMemory, Dict[str, Any], None]:
+                - When deserialize=True: UserMemory object
+                - When deserialize=False: memory dictionary
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table).where(table.c.memory_id == memory_id)
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)
+
+                result = await sess.execute(stmt)
+                row = result.fetchone()
+                if not row:
+                    return None
+
+                memory_raw = dict(row._mapping)
+                if not deserialize:
+                    return memory_raw
+
+                return UserMemory.from_dict(memory_raw)
+
+        except Exception as e:
+            log_error(f"Exception reading from memory table: {e}")
+            return None
+
+    async def get_user_memories(
+        self,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        topics: Optional[List[str]] = None,
+        search_content: Optional[str] = None,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
+        """Get all memories from the database as UserMemory objects.
+
+        Args:
+            user_id (Optional[str]): The ID of the user to filter by.
+            agent_id (Optional[str]): The ID of the agent to filter by.
+            team_id (Optional[str]): The ID of the team to filter by.
+            topics (Optional[List[str]]): The topics to filter by.
+            search_content (Optional[str]): The content to search for.
+            limit (Optional[int]): The maximum number of memories to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+            deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
+
+        Returns:
+            Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of UserMemory objects
+                - When deserialize=False: Tuple of (memory dictionaries, total count)
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table)
+                # Filtering
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+                if topics is not None:
+                    # MySQL JSON_CONTAINS: the row must contain every requested topic
+                    topic_conditions = []
+                    for topic in topics:
+                        topic_conditions.append(func.json_contains(table.c.topics, f'"{topic}"'))
+                    stmt = stmt.where(and_(*topic_conditions))
+                if search_content is not None:
+                    stmt = stmt.where(cast(table.c.memory, TEXT).ilike(f"%{search_content}%"))
+
+                # Get total count after applying filtering
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = await sess.scalar(count_stmt) or 0
+
+                # Sorting
+                stmt = apply_sorting(stmt, table, sort_by, sort_order)
+
+                # Paginating (page offsets require a limit)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+                if not records:
+                    return [] if deserialize else ([], 0)
+
+                memories_raw = [dict(record._mapping) for record in records]
+                if not deserialize:
+                    return memories_raw, total_count
+
+                return [UserMemory.from_dict(record) for record in memories_raw]
+
+        except Exception as e:
+            log_error(f"Exception reading from memory table: {e}")
+            return [] if deserialize else ([], 0)
+
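For the topics filter, `func.json_contains(col, '"topic"')` calls MySQL's JSON_CONTAINS, and AND-ing one condition per topic means a row must contain every requested topic. A compile-only sketch (illustrative table, not the real schema):

from sqlalchemy import JSON, Column, MetaData, String, Table, and_, func, select
from sqlalchemy.dialects import mysql as mysql_dialect

metadata = MetaData()
memories = Table(  # hypothetical stand-in for the memories table
    "demo_memories",
    metadata,
    Column("memory_id", String(64), primary_key=True),
    Column("topics", JSON),
)

conditions = [func.json_contains(memories.c.topics, f'"{t}"') for t in ["work", "health"]]
stmt = select(memories).where(and_(*conditions))
print(stmt.compile(dialect=mysql_dialect.dialect()))  # ... WHERE json_contains(...) AND json_contains(...)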
+    async def clear_memories(self) -> None:
+        """Delete all memories from the database.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                await sess.execute(table.delete())
+
+        except Exception as e:
+            log_warning(f"Exception deleting all memories: {e}")
+
+    # -- Cultural Knowledge methods --
+    async def clear_cultural_knowledge(self) -> None:
+        """Delete all cultural knowledge from the database.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = await self._get_table(table_type="culture")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                await sess.execute(table.delete())
+
+        except Exception as e:
+            log_warning(f"Exception deleting all cultural knowledge: {e}")
+
+    async def delete_cultural_knowledge(self, id: str) -> None:
+        """Delete cultural knowledge by ID.
+
+        Args:
+            id (str): The ID of the cultural knowledge to delete.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = await self._get_table(table_type="culture")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = table.delete().where(table.c.id == id)
+                await sess.execute(stmt)
+
+        except Exception as e:
+            log_warning(f"Exception deleting cultural knowledge: {e}")
+            raise e
+
+    async def get_cultural_knowledge(
+        self, id: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Get cultural knowledge by ID.
+
+        Args:
+            id (str): The ID of the cultural knowledge to retrieve.
+            deserialize (Optional[bool]): Whether to deserialize to a CulturalKnowledge object. Defaults to True.
+
+        Returns:
+            Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural knowledge if found, None otherwise.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="culture")
+
+            async with self.async_session_factory() as sess:
+                stmt = select(table).where(table.c.id == id)
+                result = await sess.execute(stmt)
+                row = result.fetchone()
+
+                if row is None:
+                    return None
+
+                db_row = dict(row._mapping)
+
+                if not deserialize:
+                    return db_row
+
+                return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_warning(f"Exception reading cultural knowledge: {e}")
+            raise e
+
+    async def get_all_cultural_knowledge(
+        self,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        name: Optional[str] = None,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+        """Get all cultural knowledge with filtering and pagination.
+
+        Args:
+            agent_id (Optional[str]): Filter by agent ID.
+            team_id (Optional[str]): Filter by team ID.
+            name (Optional[str]): Filter by name (case-insensitive partial match).
+            limit (Optional[int]): Maximum number of results to return.
+            page (Optional[int]): Page number for pagination.
+            sort_by (Optional[str]): Field to sort by.
+            sort_order (Optional[str]): Sort order ('asc' or 'desc').
+            deserialize (Optional[bool]): Whether to deserialize to CulturalKnowledge objects. Defaults to True.
+
+        Returns:
+            Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of CulturalKnowledge objects
+                - When deserialize=False: Tuple with list of dictionaries and total count
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="culture")
+
+            async with self.async_session_factory() as sess:
+                # Build query with filters
+                stmt = select(table)
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+                if name is not None:
+                    stmt = stmt.where(table.c.name.ilike(f"%{name}%"))
+
+                # Get total count
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count_result = await sess.execute(count_stmt)
+                total_count = total_count_result.scalar() or 0
+
+                # Apply sorting
+                stmt = apply_sorting(stmt, table, sort_by, sort_order)
+
+                # Apply pagination (page offsets require a limit)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                # Execute query
+                result = await sess.execute(stmt)
+                rows = result.fetchall()
+
+                db_rows = [dict(row._mapping) for row in rows]
+
+                if not deserialize:
+                    return db_rows, total_count
+
+                return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]
+
+        except Exception as e:
+            log_warning(f"Exception reading all cultural knowledge: {e}")
+            raise e
+
+    async def upsert_cultural_knowledge(
+        self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Upsert cultural knowledge in the database.
+
+        Args:
+            cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the result. Defaults to True.
+
+        Returns:
+            Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The upserted cultural knowledge.
+
+        Raises:
+            Exception: If an error occurs during upsert.
+        """
+        try:
+            table = await self._get_table(table_type="culture")
+
+            # Generate ID if not present
+            if cultural_knowledge.id is None:
+                cultural_knowledge.id = str(uuid4())
+
+            # Serialize content, categories, and notes into a JSON dict for DB storage
+            content_dict = serialize_cultural_knowledge_for_db(cultural_knowledge)
+
+            async with self.async_session_factory() as sess, sess.begin():
+                # Use MySQL-specific insert with on_duplicate_key_update
+                insert_stmt = mysql.insert(table).values(
+                    id=cultural_knowledge.id,
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    created_at=cultural_knowledge.created_at,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+
+                # Update all fields except id on conflict
+                upsert_stmt = insert_stmt.on_duplicate_key_update(
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+
+                await sess.execute(upsert_stmt)
+
+                # Fetch the inserted/updated row
+                select_stmt = select(table).where(table.c.id == cultural_knowledge.id)
+                result = await sess.execute(select_stmt)
+                row = result.fetchone()
+
+                if row is None:
+                    return None
+
+                db_row = dict(row._mapping)
+
+                if not deserialize:
+                    return db_row
+
+                # Deserialize from DB format to model format
+                return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_warning(f"Exception upserting cultural knowledge: {e}")
+            raise e
+
+    async def get_user_memory_stats(
+        self, limit: Optional[int] = None, page: Optional[int] = None, user_id: Optional[str] = None
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Get user memories stats.
+
+        Args:
+            limit (Optional[int]): The maximum number of user stats to return.
+            page (Optional[int]): The page number.
+            user_id (Optional[str]): Optional user ID to filter by.
+
+        Returns:
+            Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and the total count.
+
+        Example:
+            (
+                [
+                    {
+                        "user_id": "123",
+                        "total_memories": 10,
+                        "last_memory_updated_at": 1714560000,
+                    },
+                ],
+                1,  # total_count
+            )
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(
+                    table.c.user_id,
+                    func.count(table.c.memory_id).label("total_memories"),
+                    func.max(table.c.updated_at).label("last_memory_updated_at"),
+                )
+
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)
+                else:
+                    stmt = stmt.where(table.c.user_id.is_not(None))
+
+                stmt = stmt.group_by(table.c.user_id)
+                stmt = stmt.order_by(func.max(table.c.updated_at).desc())
+
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = await sess.scalar(count_stmt) or 0
+
+                # Pagination (page offsets require a limit)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+                if not records:
+                    return [], 0
+
+                return [
+                    {
+                        "user_id": record.user_id,  # type: ignore
+                        "total_memories": record.total_memories,
+                        "last_memory_updated_at": record.last_memory_updated_at,
+                    }
+                    for record in records
+                ], total_count
+
+        except Exception as e:
+            log_error(f"Exception getting user memory stats: {e}")
+            return [], 0
+
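The stats query is a single GROUP BY over the memories table: one row per user with a COUNT and a MAX, ordered so the most recently active users come first. A standalone sketch of the same statement shape (illustrative table, not the real schema):

from sqlalchemy import Column, Integer, MetaData, String, Table, func, select

metadata = MetaData()
mem = Table(  # hypothetical stand-in for the memories table
    "demo_memories",
    metadata,
    Column("memory_id", String(64), primary_key=True),
    Column("user_id", String(64)),
    Column("updated_at", Integer),
)

# One row per user: memory count plus the most recent update, newest activity first
stmt = (
    select(
        mem.c.user_id,
        func.count(mem.c.memory_id).label("total_memories"),
        func.max(mem.c.updated_at).label("last_memory_updated_at"),
    )
    .where(mem.c.user_id.is_not(None))
    .group_by(mem.c.user_id)
    .order_by(func.max(mem.c.updated_at).desc())
)
print(stmt)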
+    async def upsert_user_memory(
+        self, memory: UserMemory, deserialize: Optional[bool] = True
+    ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
+        """Upsert a user memory in the database.
+
+        Args:
+            memory (UserMemory): The user memory to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
+
+        Returns:
+            Optional[Union[UserMemory, Dict[str, Any]]]:
+                - When deserialize=True: UserMemory object
+                - When deserialize=False: memory dictionary
+
+        Raises:
+            Exception: If an error occurs during upsert.
+        """
+        try:
+            table = await self._get_table(table_type="memories")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                if memory.memory_id is None:
+                    memory.memory_id = str(uuid4())
+
+                current_time = int(time.time())
+
+                stmt = mysql.insert(table).values(
+                    memory_id=memory.memory_id,
+                    memory=memory.memory,
+                    input=memory.input,
+                    user_id=memory.user_id,
+                    agent_id=memory.agent_id,
+                    team_id=memory.team_id,
+                    topics=memory.topics,
+                    feedback=memory.feedback,
+                    created_at=memory.created_at,
+                    updated_at=memory.created_at,
+                )
+                stmt = stmt.on_duplicate_key_update(
+                    memory=memory.memory,
+                    topics=memory.topics,
+                    input=memory.input,
+                    agent_id=memory.agent_id,
+                    team_id=memory.team_id,
+                    feedback=memory.feedback,
+                    updated_at=current_time,
+                    # Preserve created_at on update - don't overwrite existing value
+                    created_at=table.c.created_at,
+                )
+                await sess.execute(stmt)
+
+                # Fetch the row
+                select_stmt = select(table).where(table.c.memory_id == memory.memory_id)
+                result = await sess.execute(select_stmt)
+                row = result.fetchone()
+                if row is None:
+                    return None
+
+                memory_raw = dict(row._mapping)
+
+                log_debug(f"Upserted user memory with id '{memory.memory_id}'")
+
+                if not memory_raw or not deserialize:
+                    return memory_raw
+
+                return UserMemory.from_dict(memory_raw)
+
+        except Exception as e:
+            log_error(f"Exception upserting user memory: {e}")
+            return None
+
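Note the `created_at=table.c.created_at` entry in the update clause: assigning the column to itself is the MySQL upsert idiom for keeping the original insertion timestamp when a row is updated. A compile-only sketch (illustrative table):

from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects import mysql as mysql_dialect
from sqlalchemy.dialects.mysql import insert

metadata = MetaData()
t = Table(  # hypothetical
    "demo_memories",
    metadata,
    Column("memory_id", String(64), primary_key=True),
    Column("created_at", Integer),
    Column("updated_at", Integer),
)

stmt = insert(t).values(memory_id="m1", created_at=100, updated_at=100)
stmt = stmt.on_duplicate_key_update(
    updated_at=200,
    created_at=t.c.created_at,  # self-assignment: the existing value is preserved on update
)
print(stmt.compile(dialect=mysql_dialect.dialect()))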
+    async def upsert_memories(
+        self, memories: List[UserMemory], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
+    ) -> List[Union[UserMemory, Dict[str, Any]]]:
+        """
+        Bulk upsert multiple user memories for improved performance on large datasets.
+
+        Args:
+            memories (List[UserMemory]): List of memories to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
+            preserve_updated_at (bool): If True, preserve the updated_at from the memory object.
+
+        Returns:
+            List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories.
+
+        Raises:
+            Exception: If an error occurs during bulk upsert.
+        """
+        if not memories:
+            return []
+
+        try:
+            table = await self._get_table(table_type="memories")
+
+            # Prepare bulk data
+            bulk_data = []
+            current_time = int(time.time())
+            for memory in memories:
+                if memory.memory_id is None:
+                    memory.memory_id = str(uuid4())
+
+                # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                updated_at = memory.updated_at if preserve_updated_at else current_time
+                bulk_data.append(
+                    {
+                        "memory_id": memory.memory_id,
+                        "memory": memory.memory,
+                        "input": memory.input,
+                        "user_id": memory.user_id,
+                        "agent_id": memory.agent_id,
+                        "team_id": memory.team_id,
+                        "topics": memory.topics,
+                        "feedback": memory.feedback,
+                        "created_at": memory.created_at,
+                        "updated_at": updated_at,
+                    }
+                )
+
+            results: List[Union[UserMemory, Dict[str, Any]]] = []
+
+            async with self.async_session_factory() as sess, sess.begin():
+                # Bulk upsert memories using MySQL ON DUPLICATE KEY UPDATE
+                stmt = mysql.insert(table)
+                stmt = stmt.on_duplicate_key_update(
+                    memory=stmt.inserted.memory,
+                    topics=stmt.inserted.topics,
+                    input=stmt.inserted.input,
+                    agent_id=stmt.inserted.agent_id,
+                    team_id=stmt.inserted.team_id,
+                    feedback=stmt.inserted.feedback,
+                    updated_at=stmt.inserted.updated_at,
+                    # Preserve created_at on update
+                    created_at=table.c.created_at,
+                )
+                await sess.execute(stmt, bulk_data)
+
+                # Fetch results
+                memory_ids = [memory.memory_id for memory in memories if memory.memory_id]
+                select_stmt = select(table).where(table.c.memory_id.in_(memory_ids))
+                result = await sess.execute(select_stmt)
+                fetched_rows = result.fetchall()
+
+                for row in fetched_rows:
+                    memory_dict = dict(row._mapping)
+                    if deserialize:
+                        results.append(UserMemory.from_dict(memory_dict))
+                    else:
+                        results.append(memory_dict)
+
+                return results
+
+        except Exception as e:
+            log_error(f"Exception during bulk memory upsert, falling back to individual upserts: {e}")
+            # Fallback to individual upserts
+            return [
+                result
+                for memory in memories
+                if memory is not None
+                for result in [await self.upsert_user_memory(memory, deserialize=deserialize)]
+                if result is not None
+            ]
+
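The fallback comprehension binds each awaited result through a one-element list so it can be filtered inline. A self-contained illustration of the same shape, with stand-in functions in place of the real upsert:

import asyncio
from typing import List, Optional

async def upsert_one(item: Optional[int]) -> Optional[int]:
    """Stand-in for an upsert call that may return None on failure."""
    return item * 2 if item is not None and item != 3 else None

async def fallback(items: List[Optional[int]]) -> List[int]:
    # Same shape as the comprehension above: skip None inputs and None results
    return [
        result
        for item in items
        if item is not None
        for result in [await upsert_one(item)]
        if result is not None
    ]

print(asyncio.run(fallback([1, None, 3, 4])))  # [2, 8]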
1757
+ # -- Metrics methods --
1758
+ async def _get_all_sessions_for_metrics_calculation(
1759
+ self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
1760
+ ) -> List[Dict[str, Any]]:
1761
+ """
1762
+ Get all sessions of all types (agent, team, workflow) as raw dictionaries.
1763
+
1764
+ Args:
1765
+ start_timestamp (Optional[int]): The start timestamp to filter by. Defaults to None.
1766
+ end_timestamp (Optional[int]): The end timestamp to filter by. Defaults to None.
1767
+
1768
+ Returns:
1769
+ List[Dict[str, Any]]: List of session dictionaries with session_type field.
1770
+
1771
+ Raises:
1772
+ Exception: If an error occurs during retrieval.
1773
+ """
1774
+ try:
1775
+ table = await self._get_table(table_type="sessions")
1776
+
1777
+ stmt = select(
1778
+ table.c.user_id,
1779
+ table.c.session_data,
1780
+ table.c.runs,
1781
+ table.c.created_at,
1782
+ table.c.session_type,
1783
+ )
1784
+
1785
+ if start_timestamp is not None:
1786
+ stmt = stmt.where(table.c.created_at >= start_timestamp)
1787
+ if end_timestamp is not None:
1788
+ stmt = stmt.where(table.c.created_at <= end_timestamp)
1789
+
1790
+ async with self.async_session_factory() as sess:
1791
+ result = await sess.execute(stmt)
1792
+ records = result.fetchall()
1793
+
1794
+ return [dict(record._mapping) for record in records]
1795
+
1796
+ except Exception as e:
1797
+ log_error(f"Exception reading from sessions table: {e}")
1798
+ return []
1799
+
1800
+    async def _get_metrics_calculation_starting_date(self, table: Table) -> Optional[date]:
+        """Get the first date for which metrics calculation is needed:
+
+        1. If there are metrics records, return the date of the first day without a complete metrics record.
+        2. If there are no metrics records, return the date of the first recorded session.
+        3. If there are no metrics records and no session records, return None.
+
+        Args:
+            table (Table): The table to get the starting date for.
+
+        Returns:
+            Optional[date]: The starting date for which metrics calculation is needed.
+        """
+        async with self.async_session_factory() as sess:
+            stmt = select(table).order_by(table.c.date.desc()).limit(1)
+            result = await sess.execute(stmt)
+            row = result.fetchone()
+
+            # 1. Return the date of the first day without a complete metrics record.
+            if row is not None:
+                if row.completed:
+                    return row._mapping["date"] + timedelta(days=1)
+                else:
+                    return row._mapping["date"]
+
+            # 2. No metrics records. Return the date of the first recorded session.
+            first_session, _ = await self.get_sessions(sort_by="created_at", sort_order="asc", limit=1, deserialize=False)
+
+            first_session_date = first_session[0]["created_at"] if first_session else None  # type: ignore[index]
+
+            # 3. No metrics records and no session records. Return None.
+            if first_session_date is None:
+                return None
+
+            return datetime.fromtimestamp(first_session_date, tz=timezone.utc).date()
+
+    async def calculate_metrics(self) -> Optional[list[dict]]:
+        """Calculate metrics for all dates without complete metrics.
+
+        Returns:
+            Optional[list[dict]]: The calculated metrics.
+
+        Raises:
+            Exception: If an error occurs during metrics calculation.
+        """
+        try:
+            table = await self._get_table(table_type="metrics")
+
+            starting_date = await self._get_metrics_calculation_starting_date(table)
+
+            if starting_date is None:
+                log_info("No session data found. Won't calculate metrics.")
+                return None
+
+            dates_to_process = get_dates_to_calculate_metrics_for(starting_date)
+            if not dates_to_process:
+                log_info("Metrics already calculated for all relevant dates.")
+                return None
+
+            start_timestamp = int(
+                datetime.combine(dates_to_process[0], datetime.min.time()).replace(tzinfo=timezone.utc).timestamp()
+            )
+            end_timestamp = int(
+                datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time())
+                .replace(tzinfo=timezone.utc)
+                .timestamp()
+            )
+
+            sessions = await self._get_all_sessions_for_metrics_calculation(
+                start_timestamp=start_timestamp, end_timestamp=end_timestamp
+            )
+
+            all_sessions_data = fetch_all_sessions_data(
+                sessions=sessions, dates_to_process=dates_to_process, start_timestamp=start_timestamp
+            )
+            if not all_sessions_data:
+                log_info("No new session data found. Won't calculate metrics.")
+                return None
+
+            results = []
+            metrics_records = []
+
+            for date_to_process in dates_to_process:
+                date_key = date_to_process.isoformat()
+                sessions_for_date = all_sessions_data.get(date_key, {})
+
+                # Skip dates with no sessions
+                if not any(len(sessions) > 0 for sessions in sessions_for_date.values()):
+                    continue
+
+                metrics_record = calculate_date_metrics(date_to_process, sessions_for_date)
+
+                metrics_records.append(metrics_record)
+
+            if metrics_records:
+                async with self.async_session_factory() as sess, sess.begin():
+                    results = await abulk_upsert_metrics(session=sess, table=table, metrics_records=metrics_records)
+
+            log_debug("Updated metrics calculations")
+
+            return results
+
+        except Exception as e:
+            log_error(f"Exception refreshing metrics: {e}")
+            return None
+
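The metrics window spans whole UTC days: midnight of the first pending date up to midnight after the last one. A runnable check of that timestamp arithmetic:

from datetime import date, datetime, timedelta, timezone

dates_to_process = [date(2024, 1, 1), date(2024, 1, 2)]

start_timestamp = int(
    datetime.combine(dates_to_process[0], datetime.min.time()).replace(tzinfo=timezone.utc).timestamp()
)
end_timestamp = int(
    datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time())
    .replace(tzinfo=timezone.utc)
    .timestamp()
)

# Two full UTC days: 2 * 86400 seconds
print(end_timestamp - start_timestamp)  # 172800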
+    async def get_metrics(
+        self, starting_date: Optional[date] = None, ending_date: Optional[date] = None
+    ) -> Tuple[List[dict], Optional[int]]:
+        """Get all metrics matching the given date range.
+
+        Args:
+            starting_date (Optional[date]): The starting date to filter metrics by.
+            ending_date (Optional[date]): The ending date to filter metrics by.
+
+        Returns:
+            Tuple[List[dict], Optional[int]]: A tuple containing the metrics and the timestamp of the latest update.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="metrics")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table)
+                if starting_date:
+                    stmt = stmt.where(table.c.date >= starting_date)
+                if ending_date:
+                    stmt = stmt.where(table.c.date <= ending_date)
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+                if not records:
+                    return [], None
+
+                # Get the latest updated_at
+                latest_stmt = select(func.max(table.c.updated_at))
+                latest_result = await sess.execute(latest_stmt)
+                latest_updated_at = latest_result.scalar()
+
+                return [dict(row._mapping) for row in records], latest_updated_at
+
+        except Exception as e:
+            log_warning(f"Exception getting metrics: {e}")
+            return [], None
+
+    # -- Knowledge methods --
+    async def delete_knowledge_content(self, id: str):
+        """Delete a knowledge row from the database.
+
+        Args:
+            id (str): The ID of the knowledge row to delete.
+        """
+        table = await self._get_table(table_type="knowledge")
+
+        try:
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = table.delete().where(table.c.id == id)
+                await sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Exception deleting knowledge content: {e}")
+
+    async def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
+        """Get a knowledge row from the database.
+
+        Args:
+            id (str): The ID of the knowledge row to get.
+
+        Returns:
+            Optional[KnowledgeRow]: The knowledge row, or None if it doesn't exist.
+        """
+        table = await self._get_table(table_type="knowledge")
+
+        try:
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table).where(table.c.id == id)
+                result = await sess.execute(stmt)
+                row = result.fetchone()
+                if row is None:
+                    return None
+
+                return KnowledgeRow.model_validate(row._mapping)
+
+        except Exception as e:
+            log_error(f"Exception getting knowledge content: {e}")
+            return None
+
+    async def get_knowledge_contents(
+        self,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+    ) -> Tuple[List[KnowledgeRow], int]:
+        """Get all knowledge contents from the database.
+
+        Args:
+            limit (Optional[int]): The maximum number of knowledge contents to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+
+        Returns:
+            Tuple[List[KnowledgeRow], int]: The knowledge contents and the total count.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        table = await self._get_table(table_type="knowledge")
+
+        try:
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table)
+
+                # Apply sorting via asc()/desc() (multiplying a column by -1 only works for numeric columns)
+                if sort_by is not None:
+                    sort_col = getattr(table.c, sort_by)
+                    stmt = stmt.order_by(sort_col.asc() if sort_order == "asc" else sort_col.desc())
+
+                # Get total count before applying limit and pagination
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = await sess.scalar(count_stmt) or 0
+
+                # Apply pagination after count (page offsets require a limit)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+                return [KnowledgeRow.model_validate(record._mapping) for record in records], total_count
+
+        except Exception as e:
+            log_error(f"Exception getting knowledge contents: {e}")
+            return [], 0
+
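Sorting here uses the column's `.asc()`/`.desc()` modifiers, which work for any column type. A compile-only sketch (illustrative table):

from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()
k = Table(  # hypothetical stand-in for the knowledge table
    "demo_knowledge",
    metadata,
    Column("id", String(64), primary_key=True),
    Column("name", String(255)),
    Column("created_at", Integer),
)

sort_by, sort_order = "name", "desc"
sort_col = getattr(k.c, sort_by)
stmt = select(k).order_by(sort_col.asc() if sort_order == "asc" else sort_col.desc())
print(stmt)  # ... ORDER BY demo_knowledge.name DESC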
+    async def upsert_knowledge_content(self, knowledge_row: KnowledgeRow):
+        """Upsert knowledge content in the database.
+
+        Args:
+            knowledge_row (KnowledgeRow): The knowledge row to upsert.
+
+        Returns:
+            Optional[KnowledgeRow]: The upserted knowledge row, or None if the operation fails.
+        """
+        try:
+            table = await self._get_table(table_type="knowledge")
+            async with self.async_session_factory() as sess, sess.begin():
+                # Get the actual table columns to avoid "unconsumed column names" error
+                table_columns = set(table.columns.keys())
+
+                # Only include fields that exist in the table and are not None
+                insert_data = {}
+                update_fields = {}
+
+                # Map of KnowledgeRow fields to table columns
+                field_mapping = {
+                    "id": "id",
+                    "name": "name",
+                    "description": "description",
+                    "metadata": "metadata",
+                    "type": "type",
+                    "size": "size",
+                    "linked_to": "linked_to",
+                    "access_count": "access_count",
+                    "status": "status",
+                    "status_message": "status_message",
+                    "created_at": "created_at",
+                    "updated_at": "updated_at",
+                    "external_id": "external_id",
+                }
+
+                # Build insert and update data only for fields that exist in the table
+                for model_field, table_column in field_mapping.items():
+                    if table_column in table_columns:
+                        value = getattr(knowledge_row, model_field, None)
+                        if value is not None:
+                            insert_data[table_column] = value
+                            # Don't include ID in update_fields since it's the primary key
+                            if table_column != "id":
+                                update_fields[table_column] = value
+
+                # Ensure id is always included for the insert
+                if "id" in table_columns and knowledge_row.id:
+                    insert_data["id"] = knowledge_row.id
+
+                # Handle case where update_fields is empty (all fields are None or don't exist in table)
+                if not update_fields:
+                    # If we have insert_data, just do an insert without conflict resolution
+                    if insert_data:
+                        stmt = mysql.insert(table).values(insert_data)
+                        await sess.execute(stmt)
+                    else:
+                        # If we have no data at all, this is an error
+                        log_error("No valid fields found for knowledge row upsert")
+                        return None
+                else:
+                    # Normal upsert with conflict resolution
+                    stmt = mysql.insert(table).values(insert_data).on_duplicate_key_update(**update_fields)
+                    await sess.execute(stmt)
+
+            log_debug(f"Upserted knowledge row with id '{knowledge_row.id}'")
+
+            return knowledge_row
+
+        except Exception as e:
+            log_error(f"Error upserting knowledge row: {e}")
+            return None
+
+    # -- Eval methods --
+    async def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
+        """Create an EvalRunRecord in the database.
+
+        Args:
+            eval_run (EvalRunRecord): The eval run to create.
+
+        Returns:
+            Optional[EvalRunRecord]: The created eval run, or None if the operation fails.
+
+        Raises:
+            Exception: If an error occurs during creation.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                current_time = int(time.time())
+                stmt = mysql.insert(table).values(
+                    {"created_at": current_time, "updated_at": current_time, **eval_run.model_dump()}
+                )
+                await sess.execute(stmt)
+
+            log_debug(f"Created eval run with id '{eval_run.run_id}'")
+
+            return eval_run
+
+        except Exception as e:
+            log_error(f"Error creating eval run: {e}")
+            return None
+
+    async def delete_eval_run(self, eval_run_id: str) -> None:
+        """Delete an eval run from the database.
+
+        Args:
+            eval_run_id (str): The ID of the eval run to delete.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = table.delete().where(table.c.run_id == eval_run_id)
+                result = await sess.execute(stmt)
+
+                if result.rowcount == 0:  # type: ignore
+                    log_warning(f"No eval run found with ID: {eval_run_id}")
+                else:
+                    log_debug(f"Deleted eval run with ID: {eval_run_id}")
+
+        except Exception as e:
+            log_error(f"Error deleting eval run {eval_run_id}: {e}")
+
+    async def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
+        """Delete multiple eval runs from the database.
+
+        Args:
+            eval_run_ids (List[str]): List of eval run IDs to delete.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = table.delete().where(table.c.run_id.in_(eval_run_ids))
+                result = await sess.execute(stmt)
+
+                if result.rowcount == 0:  # type: ignore
+                    log_warning(f"No eval runs found with IDs: {eval_run_ids}")
+                else:
+                    log_debug(f"Deleted {result.rowcount} eval runs")  # type: ignore
+
+        except Exception as e:
+            log_error(f"Error deleting eval runs {eval_run_ids}: {e}")
+
+    async def get_eval_run(
+        self, eval_run_id: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
+        """Get an eval run from the database.
+
+        Args:
+            eval_run_id (str): The ID of the eval run to get.
+            deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.
+
+        Returns:
+            Optional[Union[EvalRunRecord, Dict[str, Any]]]:
+                - When deserialize=True: EvalRunRecord object
+                - When deserialize=False: eval run dictionary
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table).where(table.c.run_id == eval_run_id)
+                result = await sess.execute(stmt)
+                row = result.fetchone()
+                if row is None:
+                    return None
+
+                eval_run_raw = dict(row._mapping)
+                if not deserialize:
+                    return eval_run_raw
+
+                return EvalRunRecord.model_validate(eval_run_raw)
+
+        except Exception as e:
+            log_error(f"Exception getting eval run {eval_run_id}: {e}")
+            return None
+
+    async def get_eval_runs(
+        self,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        workflow_id: Optional[str] = None,
+        model_id: Optional[str] = None,
+        filter_type: Optional[EvalFilterType] = None,
+        eval_type: Optional[List[EvalType]] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
+        """Get all eval runs from the database.
+
+        Args:
+            limit (Optional[int]): The maximum number of eval runs to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+            agent_id (Optional[str]): The ID of the agent to filter by.
+            team_id (Optional[str]): The ID of the team to filter by.
+            workflow_id (Optional[str]): The ID of the workflow to filter by.
+            model_id (Optional[str]): The ID of the model to filter by.
+            eval_type (Optional[List[EvalType]]): The type(s) of eval to filter by.
+            filter_type (Optional[EvalFilterType]): Filter by component type (agent, team, workflow).
+            deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.
+
+        Returns:
+            Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of EvalRunRecord objects
+                - When deserialize=False: Tuple of (eval run dictionaries, total count)
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = select(table)
+
+                # Filtering
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+                if workflow_id is not None:
+                    stmt = stmt.where(table.c.workflow_id == workflow_id)
+                if model_id is not None:
+                    stmt = stmt.where(table.c.model_id == model_id)
+                if eval_type is not None and len(eval_type) > 0:
+                    stmt = stmt.where(table.c.eval_type.in_(eval_type))
+                if filter_type is not None:
+                    if filter_type == EvalFilterType.AGENT:
+                        stmt = stmt.where(table.c.agent_id.is_not(None))
+                    elif filter_type == EvalFilterType.TEAM:
+                        stmt = stmt.where(table.c.team_id.is_not(None))
+                    elif filter_type == EvalFilterType.WORKFLOW:
+                        stmt = stmt.where(table.c.workflow_id.is_not(None))
+
+                # Get total count after applying filtering
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = await sess.scalar(count_stmt) or 0
+
+                # Sorting
+                if sort_by is None:
+                    stmt = stmt.order_by(table.c.created_at.desc())
+                else:
+                    stmt = apply_sorting(stmt, table, sort_by, sort_order)
+
+                # Paginating (page offsets require a limit)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                result = await sess.execute(stmt)
+                records = result.fetchall()
+                if not records:
+                    return [] if deserialize else ([], 0)
+
+                eval_runs_raw = [dict(row._mapping) for row in records]
+                if not deserialize:
+                    return eval_runs_raw, total_count
+
+                return [EvalRunRecord.model_validate(row) for row in eval_runs_raw]
+
+        except Exception as e:
+            log_error(f"Exception getting eval runs: {e}")
+            return [] if deserialize else ([], 0)
+
+    async def rename_eval_run(
+        self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
+        """Update the name of an eval run in the database.
+
+        Args:
+            eval_run_id (str): The ID of the eval run to update.
+            name (str): The new name of the eval run.
+            deserialize (Optional[bool]): Whether to deserialize the result. Defaults to True.
+
+        Returns:
+            Optional[Union[EvalRunRecord, Dict[str, Any]]]: The updated eval run, or None if the operation fails.
+
+        Raises:
+            Exception: If an error occurs during update.
+        """
+        try:
+            table = await self._get_table(table_type="evals")
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = (
+                    table.update().where(table.c.run_id == eval_run_id).values(name=name, updated_at=int(time.time()))
+                )
+                await sess.execute(stmt)
+
+            # Fetch the raw row once; validate only when deserialization is requested
+            eval_run_raw = await self.get_eval_run(eval_run_id=eval_run_id, deserialize=False)
+            if not eval_run_raw or not deserialize:
+                return eval_run_raw
+
+            return EvalRunRecord.model_validate(eval_run_raw)
+
+        except Exception as e:
+            log_error(f"Error renaming eval run {eval_run_id}: {e}")
+            return None
+
+    # -- Migrations --
+
+    async def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
+        """Migrate all content in the given table to the right v2 table"""
+
+        from typing import Sequence
+
+        from agno.db.migrations.v1_to_v2 import (
+            get_all_table_content,
+            parse_agent_sessions,
+            parse_memories,
+            parse_team_sessions,
+            parse_workflow_sessions,
+        )
+
+        # Get all content from the old table
+        old_content: list[dict[str, Any]] = get_all_table_content(
+            db=self,
+            db_schema=v1_db_schema,
+            table_name=v1_table_name,
+        )
+        if not old_content:
+            log_info(f"No content to migrate from table {v1_table_name}")
+            return
+
+        # Parse the content into the new format
+        memories: List[UserMemory] = []
+        sessions: Sequence[Union[AgentSession, TeamSession, WorkflowSession]] = []
+        if v1_table_type == "agent_sessions":
+            sessions = parse_agent_sessions(old_content)
+        elif v1_table_type == "team_sessions":
+            sessions = parse_team_sessions(old_content)
+        elif v1_table_type == "workflow_sessions":
+            sessions = parse_workflow_sessions(old_content)
+        elif v1_table_type == "memories":
+            memories = parse_memories(old_content)
+        else:
+            raise ValueError(f"Invalid table type: {v1_table_type}")
+
+        # Insert the new content into the new table
+        if v1_table_type == "agent_sessions":
+            for session in sessions:
+                await self.upsert_session(session)
+            log_info(f"Migrated {len(sessions)} Agent sessions to table: {self.session_table_name}")
+
+        elif v1_table_type == "team_sessions":
+            for session in sessions:
+                await self.upsert_session(session)
+            log_info(f"Migrated {len(sessions)} Team sessions to table: {self.session_table_name}")
+
+        elif v1_table_type == "workflow_sessions":
+            for session in sessions:
+                await self.upsert_session(session)
+            log_info(f"Migrated {len(sessions)} Workflow sessions to table: {self.session_table_name}")
+
+        elif v1_table_type == "memories":
+            for memory in memories:
+                await self.upsert_user_memory(memory)
+            log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
+
+    # --- Traces ---
+    def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
+        """Build base query for traces with aggregated span counts.
+
+        Args:
+            table: The traces table.
+            spans_table: The spans table (optional).
+
+        Returns:
+            SQLAlchemy select statement with total_spans and error_count calculated dynamically.
+        """
+        from sqlalchemy import case, literal
+
+        if spans_table is not None:
+            # JOIN with spans table to calculate total_spans and error_count
+            return (
+                select(
+                    table,
+                    func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
+                    func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
+                        "error_count"
+                    ),
+                )
+                .select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
+                .group_by(table.c.trace_id)
+            )
+        else:
+            # Fallback if spans table doesn't exist
+            return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
+
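`COUNT(span_id)` over an outer join yields 0 for traces with no spans, and the conditional SUM counts only error spans, so both aggregates come from one query. A standalone sketch of the same statement (illustrative tables):

from sqlalchemy import Column, MetaData, String, Table, case, func, select
from sqlalchemy.dialects import mysql as mysql_dialect

metadata = MetaData()
traces = Table("demo_traces", metadata, Column("trace_id", String(64), primary_key=True))
spans = Table(
    "demo_spans",
    metadata,
    Column("span_id", String(64), primary_key=True),
    Column("trace_id", String(64)),
    Column("status_code", String(16)),
)

stmt = (
    select(
        traces,
        func.coalesce(func.count(spans.c.span_id), 0).label("total_spans"),
        func.coalesce(func.sum(case((spans.c.status_code == "ERROR", 1), else_=0)), 0).label("error_count"),
    )
    .select_from(traces.outerjoin(spans, traces.c.trace_id == spans.c.trace_id))
    .group_by(traces.c.trace_id)
)
print(stmt.compile(dialect=mysql_dialect.dialect()))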
+    def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
+        """Build a SQL CASE expression that returns the component level for a trace.
+
+        Component levels (higher = more important):
+        - 3: Workflow root (.run or .arun with workflow_id)
+        - 2: Team root (.run or .arun with team_id)
+        - 1: Agent root (.run or .arun with agent_id)
+        - 0: Child span (not a root)
+
+        Args:
+            workflow_id_col: SQL column/expression for workflow_id
+            team_id_col: SQL column/expression for team_id
+            agent_id_col: SQL column/expression for agent_id
+            name_col: SQL column/expression for name
+
+        Returns:
+            SQLAlchemy CASE expression returning the component level as an integer.
+        """
+        from sqlalchemy import and_, case, or_
+
+        is_root_name = or_(name_col.like("%.run%"), name_col.like("%.arun%"))
+
+        return case(
+            # Workflow root (level 3)
+            (and_(workflow_id_col.isnot(None), is_root_name), 3),
+            # Team root (level 2)
+            (and_(team_id_col.isnot(None), is_root_name), 2),
+            # Agent root (level 1)
+            (and_(agent_id_col.isnot(None), is_root_name), 1),
+            # Child span or unknown (level 0)
+            else_=0,
+        )
+
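In plain Python terms, the CASE expression ranks a root trace (a name containing `.run` or `.arun`) as workflow > team > agent, with everything else treated as a child span. A runnable mirror for intuition:

from typing import Optional

def component_level(workflow_id: Optional[str], team_id: Optional[str],
                    agent_id: Optional[str], name: str) -> int:
    """Python mirror of the SQL CASE: the higher level wins when naming a trace."""
    is_root = ".run" in name or ".arun" in name
    if workflow_id is not None and is_root:
        return 3
    if team_id is not None and is_root:
        return 2
    if agent_id is not None and is_root:
        return 1
    return 0

print(component_level(None, None, "agent_1", "Agent.arun"))      # 1
print(component_level("wf_1", None, "agent_1", "Workflow.run"))  # 3
print(component_level(None, "team_1", None, "some_tool_call"))   # 0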
+    async def upsert_trace(self, trace: "Trace") -> None:
+        """Create or update a single trace record in the database.
+
+        Uses INSERT ... ON DUPLICATE KEY UPDATE (upsert) to handle concurrent inserts
+        atomically and avoid race conditions.
+
+        Args:
+            trace: The Trace object to store (one per trace_id).
+        """
+        from sqlalchemy import case
+
+        try:
+            table = await self._get_table(table_type="traces", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            trace_dict = trace.to_dict()
+            trace_dict.pop("total_spans", None)
+            trace_dict.pop("error_count", None)
+
+            async with self.async_session_factory() as sess, sess.begin():
+                # Use upsert to handle concurrent inserts atomically
+                # On conflict, update fields while preserving existing non-null context values
+                # and keeping the earliest start_time
+                insert_stmt = mysql.insert(table).values(trace_dict)
+
+                # Build component level expressions for comparing trace priority
+                new_level = self._get_trace_component_level_expr(
+                    insert_stmt.inserted.workflow_id,
+                    insert_stmt.inserted.team_id,
+                    insert_stmt.inserted.agent_id,
+                    insert_stmt.inserted.name,
+                )
+                existing_level = self._get_trace_component_level_expr(
+                    table.c.workflow_id,
+                    table.c.team_id,
+                    table.c.agent_id,
+                    table.c.name,
+                )
+
+                # Build the ON DUPLICATE KEY UPDATE clause
+                # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
+                # MySQL stores timestamps as ISO strings, so string comparison works for ISO format
+                # Duration is calculated using TIMESTAMPDIFF in microseconds then converted to ms
+                upsert_stmt = insert_stmt.on_duplicate_key_update(
+                    end_time=func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
+                    start_time=func.least(table.c.start_time, insert_stmt.inserted.start_time),
+                    # Calculate duration in milliseconds using TIMESTAMPDIFF
+                    # TIMESTAMPDIFF(MICROSECOND, start, end) / 1000 gives milliseconds
+                    duration_ms=func.timestampdiff(
+                        text("MICROSECOND"),
+                        func.least(table.c.start_time, insert_stmt.inserted.start_time),
+                        func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
+                    )
+                    / 1000,
+                    status=insert_stmt.inserted.status,
+                    # Update name only if new trace is from a higher-level component
+                    # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
+                    name=case(
+                        (new_level > existing_level, insert_stmt.inserted.name),
+                        else_=table.c.name,
+                    ),
+                    # Preserve existing non-null context values using COALESCE
+                    run_id=func.coalesce(insert_stmt.inserted.run_id, table.c.run_id),
+                    session_id=func.coalesce(insert_stmt.inserted.session_id, table.c.session_id),
+                    user_id=func.coalesce(insert_stmt.inserted.user_id, table.c.user_id),
+                    agent_id=func.coalesce(insert_stmt.inserted.agent_id, table.c.agent_id),
+                    team_id=func.coalesce(insert_stmt.inserted.team_id, table.c.team_id),
+                    workflow_id=func.coalesce(insert_stmt.inserted.workflow_id, table.c.workflow_id),
+                )
+                await sess.execute(upsert_stmt)
+
+        except Exception as e:
+            log_error(f"Error creating trace: {e}")
+            # Don't raise - tracing should not break the main application flow
+
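The duration update computes TIMESTAMPDIFF(MICROSECOND, LEAST(starts), GREATEST(ends)) / 1000, i.e. the widest observed window in milliseconds. The same arithmetic in Python, for intuition:

from datetime import datetime

existing_start = datetime.fromisoformat("2024-01-01T00:00:00.000000")
new_start = datetime.fromisoformat("2024-01-01T00:00:00.250000")
existing_end = datetime.fromisoformat("2024-01-01T00:00:01.000000")
new_end = datetime.fromisoformat("2024-01-01T00:00:01.500000")

start = min(existing_start, new_start)  # LEAST(start_time, VALUES(start_time))
end = max(existing_end, new_end)        # GREATEST(end_time, VALUES(end_time))

duration_ms = (end - start).total_seconds() * 1_000_000 / 1000  # microseconds -> ms
print(duration_ms)  # 1500.0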
2545
+ async def get_trace(
2546
+ self,
2547
+ trace_id: Optional[str] = None,
2548
+ run_id: Optional[str] = None,
2549
+ ):
2550
+ """Get a single trace by trace_id or other filters.
2551
+
2552
+ Args:
2553
+ trace_id: The unique trace identifier.
2554
+ run_id: Filter by run ID (returns first match).
2555
+
2556
+ Returns:
2557
+ Optional[Trace]: The trace if found, None otherwise.
2558
+
2559
+ Note:
2560
+ If multiple filters are provided, trace_id takes precedence.
2561
+ For other filters, the most recent trace is returned.
2562
+ """
2563
+ try:
2564
+ from agno.tracing.schemas import Trace
2565
+
2566
+ table = await self._get_table(table_type="traces")
2567
+ if table is None:
2568
+ return None
2569
+
2570
+ # Get spans table for JOIN
2571
+ spans_table = await self._get_table(table_type="spans")
2572
+
2573
+ async with self.async_session_factory() as sess:
2574
+ # Build query with aggregated span counts
2575
+ stmt = self._get_traces_base_query(table, spans_table)
2576
+
2577
+ if trace_id:
2578
+ stmt = stmt.where(table.c.trace_id == trace_id)
2579
+ elif run_id:
2580
+ stmt = stmt.where(table.c.run_id == run_id)
2581
+ else:
2582
+ log_debug("get_trace called without any filter parameters")
2583
+ return None
2584
+
2585
+ # Order by most recent and get first result
2586
+ stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
2587
+ result = await sess.execute(stmt)
2588
+ row = result.fetchone()
2589
+
2590
+ if row:
2591
+ return Trace.from_dict(dict(row._mapping))
2592
+ return None
2593
+
2594
+ except Exception as e:
2595
+ log_error(f"Error getting trace: {e}")
2596
+ return None
2597
+
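+    # Usage sketch (hypothetical IDs; `db` is an AsyncMySQLDb instance):
+    #     trace = await db.get_trace(trace_id="tr_abc123")
+    #     if trace is None:
+    #         trace = await db.get_trace(run_id="run_xyz")  # most recent trace for that run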
+    async def get_traces(
+        self,
+        run_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        workflow_id: Optional[str] = None,
+        status: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        limit: Optional[int] = 20,
+        page: Optional[int] = 1,
+    ) -> tuple[List, int]:
+        """Get traces matching the provided filters with pagination.
+
+        Args:
+            run_id: Filter by run ID.
+            session_id: Filter by session ID.
+            user_id: Filter by user ID.
+            agent_id: Filter by agent ID.
+            team_id: Filter by team ID.
+            workflow_id: Filter by workflow ID.
+            status: Filter by status (OK, ERROR, UNSET).
+            start_time: Filter traces starting after this datetime.
+            end_time: Filter traces ending before this datetime.
+            limit: Maximum number of traces to return per page.
+            page: Page number (1-indexed).
+
+        Returns:
+            tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
+        """
+        try:
+            from agno.tracing.schemas import Trace
+
+            log_debug(
+                f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
+            )
+
+            table = await self._get_table(table_type="traces")
+            if table is None:
+                log_debug("Traces table not found")
+                return [], 0
+
+            # Get spans table for JOIN
+            spans_table = await self._get_table(table_type="spans")
+
+            async with self.async_session_factory() as sess:
+                # Build base query with aggregated span counts
+                base_stmt = self._get_traces_base_query(table, spans_table)
+
+                # Apply filters
+                if run_id:
+                    base_stmt = base_stmt.where(table.c.run_id == run_id)
+                if session_id:
+                    base_stmt = base_stmt.where(table.c.session_id == session_id)
+                if user_id:
+                    base_stmt = base_stmt.where(table.c.user_id == user_id)
+                if agent_id:
+                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
+                if team_id:
+                    base_stmt = base_stmt.where(table.c.team_id == team_id)
+                if workflow_id:
+                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
+                if status:
+                    base_stmt = base_stmt.where(table.c.status == status)
+                if start_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
+                if end_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())
+
+                # Get total count
+                count_stmt = select(func.count()).select_from(base_stmt.alias())
+                count_result = await sess.execute(count_stmt)
+                total_count = count_result.scalar() or 0
+
+                # Apply pagination
+                offset = (page - 1) * limit if page and limit else 0
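+                # e.g. page=3, limit=20 -> offset=40, so rows 41-60 are returned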
+                paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
+
+                result = await sess.execute(paginated_stmt)
+                results = result.fetchall()
+
+                traces = [Trace.from_dict(dict(row._mapping)) for row in results]
+                return traces, total_count
+
+        except Exception as e:
+            log_error(f"Error getting traces: {e}")
+            return [], 0
+
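+    # Usage sketch (hypothetical session ID): page through a session's failed traces
+    #     traces, total = await db.get_traces(session_id="sess_1", status="ERROR", limit=20, page=1)
+    #     total_pages = -(-total // 20)  # ceiling division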
+    async def get_trace_stats(
+        self,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        workflow_id: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        limit: Optional[int] = 20,
+        page: Optional[int] = 1,
+    ) -> tuple[List[Dict[str, Any]], int]:
+        """Get trace statistics grouped by session.
+
+        Args:
+            user_id: Filter by user ID.
+            agent_id: Filter by agent ID.
+            team_id: Filter by team ID.
+            workflow_id: Filter by workflow ID.
+            start_time: Filter sessions with traces created after this datetime.
+            end_time: Filter sessions with traces created before this datetime.
+            limit: Maximum number of sessions to return per page.
+            page: Page number (1-indexed).
+
+        Returns:
+            tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
+            Each dict contains: session_id, user_id, agent_id, team_id, workflow_id,
+            total_traces, first_trace_at, last_trace_at.
+        """
+        try:
+            table = await self._get_table(table_type="traces")
+            if table is None:
+                log_debug("Traces table not found")
+                return [], 0
+
+            async with self.async_session_factory() as sess:
+                # Build base query grouped by session_id
+                base_stmt = (
+                    select(
+                        table.c.session_id,
+                        table.c.user_id,
+                        table.c.agent_id,
+                        table.c.team_id,
+                        table.c.workflow_id,
+                        func.count(table.c.trace_id).label("total_traces"),
+                        func.min(table.c.created_at).label("first_trace_at"),
+                        func.max(table.c.created_at).label("last_trace_at"),
+                    )
+                    .where(table.c.session_id.isnot(None))  # Only sessions with a session_id
+                    .group_by(
+                        table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
+                    )
+                )
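+                # Illustrative sketch of the roughly equivalent SQL:
+                #   SELECT session_id, user_id, agent_id, team_id, workflow_id,
+                #          COUNT(trace_id) AS total_traces,
+                #          MIN(created_at) AS first_trace_at,
+                #          MAX(created_at) AS last_trace_at
+                #   FROM <traces table>
+                #   WHERE session_id IS NOT NULL
+                #   GROUP BY session_id, user_id, agent_id, team_id, workflow_id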
+                # Apply filters
+                if user_id:
+                    base_stmt = base_stmt.where(table.c.user_id == user_id)
+                if workflow_id:
+                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
+                if team_id:
+                    base_stmt = base_stmt.where(table.c.team_id == team_id)
+                if agent_id:
+                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
+                if start_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
+                if end_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())
+
+                # Get total count of sessions
+                count_stmt = select(func.count()).select_from(base_stmt.alias())
+                count_result = await sess.execute(count_stmt)
+                total_count = count_result.scalar() or 0
+
+                # Apply pagination and ordering
+                offset = (page - 1) * limit if page and limit else 0
+                paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
+
+                result = await sess.execute(paginated_stmt)
+                results = result.fetchall()
+
+                # Convert to list of dicts with datetime objects
+                stats_list = []
+                for row in results:
+                    # Convert ISO strings to datetime objects
+                    first_trace_at_str = row.first_trace_at
+                    last_trace_at_str = row.last_trace_at
+
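+                    # datetime.fromisoformat() rejects a trailing "Z" (UTC designator) on
+                    # Python < 3.11, so it is normalized to the "+00:00" offset first,
+                    # e.g. "2024-01-01T00:00:00Z" -> "2024-01-01T00:00:00+00:00"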
+                    # Parse ISO format strings to datetime objects
+                    first_trace_at = datetime.fromisoformat(first_trace_at_str.replace("Z", "+00:00"))
+                    last_trace_at = datetime.fromisoformat(last_trace_at_str.replace("Z", "+00:00"))
+
+                    stats_list.append(
+                        {
+                            "session_id": row.session_id,
+                            "user_id": row.user_id,
+                            "agent_id": row.agent_id,
+                            "team_id": row.team_id,
+                            "workflow_id": row.workflow_id,
+                            "total_traces": row.total_traces,
+                            "first_trace_at": first_trace_at,
+                            "last_trace_at": last_trace_at,
+                        }
+                    )
+
+                return stats_list, total_count
+
+        except Exception as e:
+            log_error(f"Error getting trace stats: {e}")
+            return [], 0
+
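+    # Usage sketch (hypothetical agent ID): per-session trace counts for one agent
+    #     stats, total = await db.get_trace_stats(agent_id="agent_1", limit=10, page=1)
+    #     for s in stats:
+    #         print(s["session_id"], s["total_traces"], s["last_trace_at"])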
+    # --- Spans ---
+    async def create_span(self, span: "Span") -> None:
+        """Create a single span in the database.
+
+        Args:
+            span: The Span object to store.
+        """
+        try:
+            table = await self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            async with self.async_session_factory() as sess, sess.begin():
+                stmt = mysql.insert(table).values(span.to_dict())
+                await sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating span: {e}")
+
+    async def create_spans(self, spans: List) -> None:
+        """Create multiple spans in the database as a batch.
+
+        Args:
+            spans: List of Span objects to store.
+        """
+        if not spans:
+            return
+
+        try:
+            table = await self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            async with self.async_session_factory() as sess, sess.begin():
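+                # Executes one INSERT statement per span; sess.begin() commits them atomically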
+                for span in spans:
+                    stmt = mysql.insert(table).values(span.to_dict())
+                    await sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating spans batch: {e}")
+
+    async def get_span(self, span_id: str):
+        """Get a single span by its span_id.
+
+        Args:
+            span_id: The unique span identifier.
+
+        Returns:
+            Optional[Span]: The span if found, None otherwise.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = await self._get_table(table_type="spans")
+            if table is None:
+                return None
+
+            async with self.async_session_factory() as sess:
+                stmt = select(table).where(table.c.span_id == span_id)
+                result = await sess.execute(stmt)
+                row = result.fetchone()
+                if row:
+                    return Span.from_dict(dict(row._mapping))
+                return None
+
+        except Exception as e:
+            log_error(f"Error getting span: {e}")
+            return None
+
+    async def get_spans(
+        self,
+        trace_id: Optional[str] = None,
+        parent_span_id: Optional[str] = None,
+        limit: Optional[int] = 1000,
+    ) -> List:
+        """Get spans matching the provided filters.
+
+        Args:
+            trace_id: Filter by trace ID.
+            parent_span_id: Filter by parent span ID.
+            limit: Maximum number of spans to return.
+
+        Returns:
+            List[Span]: List of matching spans.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = await self._get_table(table_type="spans")
+            if table is None:
+                return []
+
+            async with self.async_session_factory() as sess:
+                stmt = select(table)
+
+                # Apply filters
+                if trace_id:
+                    stmt = stmt.where(table.c.trace_id == trace_id)
+                if parent_span_id:
+                    stmt = stmt.where(table.c.parent_span_id == parent_span_id)
+
+                if limit:
+                    stmt = stmt.limit(limit)
+
+                result = await sess.execute(stmt)
+                results = result.fetchall()
+                return [Span.from_dict(dict(row._mapping)) for row in results]
+
+        except Exception as e:
+            log_error(f"Error getting spans: {e}")
+            return []
+
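+    # Usage sketch (hypothetical trace ID): fetch a trace's spans and index them by parent
+    #     spans = await db.get_spans(trace_id="tr_abc123")
+    #     by_parent: Dict[Optional[str], List] = {}
+    #     for s in spans:
+    #         by_parent.setdefault(s.parent_span_id, []).append(s)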
+    # -- Learning methods (stubs) --
+    async def get_learning(
+        self,
+        learning_type: str,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        entity_id: Optional[str] = None,
+        entity_type: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:
+        raise NotImplementedError("Learning methods not yet implemented for AsyncMySQLDb")
+
+    async def upsert_learning(
+        self,
+        id: str,
+        learning_type: str,
+        content: Dict[str, Any],
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        entity_id: Optional[str] = None,
+        entity_type: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        raise NotImplementedError("Learning methods not yet implemented for AsyncMySQLDb")
+
+    async def delete_learning(self, id: str) -> bool:
+        raise NotImplementedError("Learning methods not yet implemented for AsyncMySQLDb")
+
+    async def get_learnings(
+        self,
+        learning_type: Optional[str] = None,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        entity_id: Optional[str] = None,
+        entity_type: Optional[str] = None,
+        limit: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        raise NotImplementedError("Learning methods not yet implemented for AsyncMySQLDb")