agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314) hide show
  1. agno/agent/agent.py +5540 -2273
  2. agno/api/api.py +2 -0
  3. agno/api/os.py +1 -1
  4. agno/compression/__init__.py +3 -0
  5. agno/compression/manager.py +247 -0
  6. agno/culture/__init__.py +3 -0
  7. agno/culture/manager.py +956 -0
  8. agno/db/async_postgres/__init__.py +3 -0
  9. agno/db/base.py +689 -6
  10. agno/db/dynamo/dynamo.py +933 -37
  11. agno/db/dynamo/schemas.py +174 -10
  12. agno/db/dynamo/utils.py +63 -4
  13. agno/db/firestore/firestore.py +831 -9
  14. agno/db/firestore/schemas.py +51 -0
  15. agno/db/firestore/utils.py +102 -4
  16. agno/db/gcs_json/gcs_json_db.py +660 -12
  17. agno/db/gcs_json/utils.py +60 -26
  18. agno/db/in_memory/in_memory_db.py +287 -14
  19. agno/db/in_memory/utils.py +60 -2
  20. agno/db/json/json_db.py +590 -14
  21. agno/db/json/utils.py +60 -26
  22. agno/db/migrations/manager.py +199 -0
  23. agno/db/migrations/v1_to_v2.py +43 -13
  24. agno/db/migrations/versions/__init__.py +0 -0
  25. agno/db/migrations/versions/v2_3_0.py +938 -0
  26. agno/db/mongo/__init__.py +15 -1
  27. agno/db/mongo/async_mongo.py +2760 -0
  28. agno/db/mongo/mongo.py +879 -11
  29. agno/db/mongo/schemas.py +42 -0
  30. agno/db/mongo/utils.py +80 -8
  31. agno/db/mysql/__init__.py +2 -1
  32. agno/db/mysql/async_mysql.py +2912 -0
  33. agno/db/mysql/mysql.py +946 -68
  34. agno/db/mysql/schemas.py +72 -10
  35. agno/db/mysql/utils.py +198 -7
  36. agno/db/postgres/__init__.py +2 -1
  37. agno/db/postgres/async_postgres.py +2579 -0
  38. agno/db/postgres/postgres.py +942 -57
  39. agno/db/postgres/schemas.py +81 -18
  40. agno/db/postgres/utils.py +164 -2
  41. agno/db/redis/redis.py +671 -7
  42. agno/db/redis/schemas.py +50 -0
  43. agno/db/redis/utils.py +65 -7
  44. agno/db/schemas/__init__.py +2 -1
  45. agno/db/schemas/culture.py +120 -0
  46. agno/db/schemas/evals.py +1 -0
  47. agno/db/schemas/memory.py +17 -2
  48. agno/db/singlestore/schemas.py +63 -0
  49. agno/db/singlestore/singlestore.py +949 -83
  50. agno/db/singlestore/utils.py +60 -2
  51. agno/db/sqlite/__init__.py +2 -1
  52. agno/db/sqlite/async_sqlite.py +2911 -0
  53. agno/db/sqlite/schemas.py +62 -0
  54. agno/db/sqlite/sqlite.py +965 -46
  55. agno/db/sqlite/utils.py +169 -8
  56. agno/db/surrealdb/__init__.py +3 -0
  57. agno/db/surrealdb/metrics.py +292 -0
  58. agno/db/surrealdb/models.py +334 -0
  59. agno/db/surrealdb/queries.py +71 -0
  60. agno/db/surrealdb/surrealdb.py +1908 -0
  61. agno/db/surrealdb/utils.py +147 -0
  62. agno/db/utils.py +2 -0
  63. agno/eval/__init__.py +10 -0
  64. agno/eval/accuracy.py +75 -55
  65. agno/eval/agent_as_judge.py +861 -0
  66. agno/eval/base.py +29 -0
  67. agno/eval/performance.py +16 -7
  68. agno/eval/reliability.py +28 -16
  69. agno/eval/utils.py +35 -17
  70. agno/exceptions.py +27 -2
  71. agno/filters.py +354 -0
  72. agno/guardrails/prompt_injection.py +1 -0
  73. agno/hooks/__init__.py +3 -0
  74. agno/hooks/decorator.py +164 -0
  75. agno/integrations/discord/client.py +1 -1
  76. agno/knowledge/chunking/agentic.py +13 -10
  77. agno/knowledge/chunking/fixed.py +4 -1
  78. agno/knowledge/chunking/semantic.py +9 -4
  79. agno/knowledge/chunking/strategy.py +59 -15
  80. agno/knowledge/embedder/fastembed.py +1 -1
  81. agno/knowledge/embedder/nebius.py +1 -1
  82. agno/knowledge/embedder/ollama.py +8 -0
  83. agno/knowledge/embedder/openai.py +8 -8
  84. agno/knowledge/embedder/sentence_transformer.py +6 -2
  85. agno/knowledge/embedder/vllm.py +262 -0
  86. agno/knowledge/knowledge.py +1618 -318
  87. agno/knowledge/reader/base.py +6 -2
  88. agno/knowledge/reader/csv_reader.py +8 -10
  89. agno/knowledge/reader/docx_reader.py +5 -6
  90. agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
  91. agno/knowledge/reader/json_reader.py +5 -4
  92. agno/knowledge/reader/markdown_reader.py +8 -8
  93. agno/knowledge/reader/pdf_reader.py +17 -19
  94. agno/knowledge/reader/pptx_reader.py +101 -0
  95. agno/knowledge/reader/reader_factory.py +32 -3
  96. agno/knowledge/reader/s3_reader.py +3 -3
  97. agno/knowledge/reader/tavily_reader.py +193 -0
  98. agno/knowledge/reader/text_reader.py +22 -10
  99. agno/knowledge/reader/web_search_reader.py +1 -48
  100. agno/knowledge/reader/website_reader.py +10 -10
  101. agno/knowledge/reader/wikipedia_reader.py +33 -1
  102. agno/knowledge/types.py +1 -0
  103. agno/knowledge/utils.py +72 -7
  104. agno/media.py +22 -6
  105. agno/memory/__init__.py +14 -1
  106. agno/memory/manager.py +544 -83
  107. agno/memory/strategies/__init__.py +15 -0
  108. agno/memory/strategies/base.py +66 -0
  109. agno/memory/strategies/summarize.py +196 -0
  110. agno/memory/strategies/types.py +37 -0
  111. agno/models/aimlapi/aimlapi.py +17 -0
  112. agno/models/anthropic/claude.py +515 -40
  113. agno/models/aws/bedrock.py +102 -21
  114. agno/models/aws/claude.py +131 -274
  115. agno/models/azure/ai_foundry.py +41 -19
  116. agno/models/azure/openai_chat.py +39 -8
  117. agno/models/base.py +1249 -525
  118. agno/models/cerebras/cerebras.py +91 -21
  119. agno/models/cerebras/cerebras_openai.py +21 -2
  120. agno/models/cohere/chat.py +40 -6
  121. agno/models/cometapi/cometapi.py +18 -1
  122. agno/models/dashscope/dashscope.py +2 -3
  123. agno/models/deepinfra/deepinfra.py +18 -1
  124. agno/models/deepseek/deepseek.py +69 -3
  125. agno/models/fireworks/fireworks.py +18 -1
  126. agno/models/google/gemini.py +877 -80
  127. agno/models/google/utils.py +22 -0
  128. agno/models/groq/groq.py +51 -18
  129. agno/models/huggingface/huggingface.py +17 -6
  130. agno/models/ibm/watsonx.py +16 -6
  131. agno/models/internlm/internlm.py +18 -1
  132. agno/models/langdb/langdb.py +13 -1
  133. agno/models/litellm/chat.py +44 -9
  134. agno/models/litellm/litellm_openai.py +18 -1
  135. agno/models/message.py +28 -5
  136. agno/models/meta/llama.py +47 -14
  137. agno/models/meta/llama_openai.py +22 -17
  138. agno/models/mistral/mistral.py +8 -4
  139. agno/models/nebius/nebius.py +6 -7
  140. agno/models/nvidia/nvidia.py +20 -3
  141. agno/models/ollama/chat.py +24 -8
  142. agno/models/openai/chat.py +104 -29
  143. agno/models/openai/responses.py +101 -81
  144. agno/models/openrouter/openrouter.py +60 -3
  145. agno/models/perplexity/perplexity.py +17 -1
  146. agno/models/portkey/portkey.py +7 -6
  147. agno/models/requesty/requesty.py +24 -4
  148. agno/models/response.py +73 -2
  149. agno/models/sambanova/sambanova.py +20 -3
  150. agno/models/siliconflow/siliconflow.py +19 -2
  151. agno/models/together/together.py +20 -3
  152. agno/models/utils.py +254 -8
  153. agno/models/vercel/v0.py +20 -3
  154. agno/models/vertexai/__init__.py +0 -0
  155. agno/models/vertexai/claude.py +190 -0
  156. agno/models/vllm/vllm.py +19 -14
  157. agno/models/xai/xai.py +19 -2
  158. agno/os/app.py +549 -152
  159. agno/os/auth.py +190 -3
  160. agno/os/config.py +23 -0
  161. agno/os/interfaces/a2a/router.py +8 -11
  162. agno/os/interfaces/a2a/utils.py +1 -1
  163. agno/os/interfaces/agui/router.py +18 -3
  164. agno/os/interfaces/agui/utils.py +152 -39
  165. agno/os/interfaces/slack/router.py +55 -37
  166. agno/os/interfaces/slack/slack.py +9 -1
  167. agno/os/interfaces/whatsapp/router.py +0 -1
  168. agno/os/interfaces/whatsapp/security.py +3 -1
  169. agno/os/mcp.py +110 -52
  170. agno/os/middleware/__init__.py +2 -0
  171. agno/os/middleware/jwt.py +676 -112
  172. agno/os/router.py +40 -1478
  173. agno/os/routers/agents/__init__.py +3 -0
  174. agno/os/routers/agents/router.py +599 -0
  175. agno/os/routers/agents/schema.py +261 -0
  176. agno/os/routers/evals/evals.py +96 -39
  177. agno/os/routers/evals/schemas.py +65 -33
  178. agno/os/routers/evals/utils.py +80 -10
  179. agno/os/routers/health.py +10 -4
  180. agno/os/routers/knowledge/knowledge.py +196 -38
  181. agno/os/routers/knowledge/schemas.py +82 -22
  182. agno/os/routers/memory/memory.py +279 -52
  183. agno/os/routers/memory/schemas.py +46 -17
  184. agno/os/routers/metrics/metrics.py +20 -8
  185. agno/os/routers/metrics/schemas.py +16 -16
  186. agno/os/routers/session/session.py +462 -34
  187. agno/os/routers/teams/__init__.py +3 -0
  188. agno/os/routers/teams/router.py +512 -0
  189. agno/os/routers/teams/schema.py +257 -0
  190. agno/os/routers/traces/__init__.py +3 -0
  191. agno/os/routers/traces/schemas.py +414 -0
  192. agno/os/routers/traces/traces.py +499 -0
  193. agno/os/routers/workflows/__init__.py +3 -0
  194. agno/os/routers/workflows/router.py +624 -0
  195. agno/os/routers/workflows/schema.py +75 -0
  196. agno/os/schema.py +256 -693
  197. agno/os/scopes.py +469 -0
  198. agno/os/utils.py +514 -36
  199. agno/reasoning/anthropic.py +80 -0
  200. agno/reasoning/gemini.py +73 -0
  201. agno/reasoning/openai.py +5 -0
  202. agno/reasoning/vertexai.py +76 -0
  203. agno/run/__init__.py +6 -0
  204. agno/run/agent.py +155 -32
  205. agno/run/base.py +55 -3
  206. agno/run/requirement.py +181 -0
  207. agno/run/team.py +125 -38
  208. agno/run/workflow.py +72 -18
  209. agno/session/agent.py +102 -89
  210. agno/session/summary.py +56 -15
  211. agno/session/team.py +164 -90
  212. agno/session/workflow.py +405 -40
  213. agno/table.py +10 -0
  214. agno/team/team.py +3974 -1903
  215. agno/tools/dalle.py +2 -4
  216. agno/tools/eleven_labs.py +23 -25
  217. agno/tools/exa.py +21 -16
  218. agno/tools/file.py +153 -23
  219. agno/tools/file_generation.py +16 -10
  220. agno/tools/firecrawl.py +15 -7
  221. agno/tools/function.py +193 -38
  222. agno/tools/gmail.py +238 -14
  223. agno/tools/google_drive.py +271 -0
  224. agno/tools/googlecalendar.py +36 -8
  225. agno/tools/googlesheets.py +20 -5
  226. agno/tools/jira.py +20 -0
  227. agno/tools/mcp/__init__.py +10 -0
  228. agno/tools/mcp/mcp.py +331 -0
  229. agno/tools/mcp/multi_mcp.py +347 -0
  230. agno/tools/mcp/params.py +24 -0
  231. agno/tools/mcp_toolbox.py +3 -3
  232. agno/tools/models/nebius.py +5 -5
  233. agno/tools/models_labs.py +20 -10
  234. agno/tools/nano_banana.py +151 -0
  235. agno/tools/notion.py +204 -0
  236. agno/tools/parallel.py +314 -0
  237. agno/tools/postgres.py +76 -36
  238. agno/tools/redshift.py +406 -0
  239. agno/tools/scrapegraph.py +1 -1
  240. agno/tools/shopify.py +1519 -0
  241. agno/tools/slack.py +18 -3
  242. agno/tools/spotify.py +919 -0
  243. agno/tools/tavily.py +146 -0
  244. agno/tools/toolkit.py +25 -0
  245. agno/tools/workflow.py +8 -1
  246. agno/tools/yfinance.py +12 -11
  247. agno/tracing/__init__.py +12 -0
  248. agno/tracing/exporter.py +157 -0
  249. agno/tracing/schemas.py +276 -0
  250. agno/tracing/setup.py +111 -0
  251. agno/utils/agent.py +938 -0
  252. agno/utils/cryptography.py +22 -0
  253. agno/utils/dttm.py +33 -0
  254. agno/utils/events.py +151 -3
  255. agno/utils/gemini.py +15 -5
  256. agno/utils/hooks.py +118 -4
  257. agno/utils/http.py +113 -2
  258. agno/utils/knowledge.py +12 -5
  259. agno/utils/log.py +1 -0
  260. agno/utils/mcp.py +92 -2
  261. agno/utils/media.py +187 -1
  262. agno/utils/merge_dict.py +3 -3
  263. agno/utils/message.py +60 -0
  264. agno/utils/models/ai_foundry.py +9 -2
  265. agno/utils/models/claude.py +49 -14
  266. agno/utils/models/cohere.py +9 -2
  267. agno/utils/models/llama.py +9 -2
  268. agno/utils/models/mistral.py +4 -2
  269. agno/utils/print_response/agent.py +109 -16
  270. agno/utils/print_response/team.py +223 -30
  271. agno/utils/print_response/workflow.py +251 -34
  272. agno/utils/streamlit.py +1 -1
  273. agno/utils/team.py +98 -9
  274. agno/utils/tokens.py +657 -0
  275. agno/vectordb/base.py +39 -7
  276. agno/vectordb/cassandra/cassandra.py +21 -5
  277. agno/vectordb/chroma/chromadb.py +43 -12
  278. agno/vectordb/clickhouse/clickhousedb.py +21 -5
  279. agno/vectordb/couchbase/couchbase.py +29 -5
  280. agno/vectordb/lancedb/lance_db.py +92 -181
  281. agno/vectordb/langchaindb/langchaindb.py +24 -4
  282. agno/vectordb/lightrag/lightrag.py +17 -3
  283. agno/vectordb/llamaindex/llamaindexdb.py +25 -5
  284. agno/vectordb/milvus/milvus.py +50 -37
  285. agno/vectordb/mongodb/__init__.py +7 -1
  286. agno/vectordb/mongodb/mongodb.py +36 -30
  287. agno/vectordb/pgvector/pgvector.py +201 -77
  288. agno/vectordb/pineconedb/pineconedb.py +41 -23
  289. agno/vectordb/qdrant/qdrant.py +67 -54
  290. agno/vectordb/redis/__init__.py +9 -0
  291. agno/vectordb/redis/redisdb.py +682 -0
  292. agno/vectordb/singlestore/singlestore.py +50 -29
  293. agno/vectordb/surrealdb/surrealdb.py +31 -41
  294. agno/vectordb/upstashdb/upstashdb.py +34 -6
  295. agno/vectordb/weaviate/weaviate.py +53 -14
  296. agno/workflow/__init__.py +2 -0
  297. agno/workflow/agent.py +299 -0
  298. agno/workflow/condition.py +120 -18
  299. agno/workflow/loop.py +77 -10
  300. agno/workflow/parallel.py +231 -143
  301. agno/workflow/router.py +118 -17
  302. agno/workflow/step.py +609 -170
  303. agno/workflow/steps.py +73 -6
  304. agno/workflow/types.py +96 -21
  305. agno/workflow/workflow.py +2039 -262
  306. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
  307. agno-2.3.13.dist-info/RECORD +613 -0
  308. agno/tools/googlesearch.py +0 -98
  309. agno/tools/mcp.py +0 -679
  310. agno/tools/memori.py +0 -339
  311. agno-2.1.2.dist-info/RECORD +0 -543
  312. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
  313. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
  314. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1908 @@
1
+ from datetime import date, datetime, timedelta, timezone
2
+ from textwrap import dedent
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ if TYPE_CHECKING:
6
+ from agno.tracing.schemas import Span, Trace
7
+
8
+ from agno.db.base import BaseDb, SessionType
9
+ from agno.db.postgres.utils import (
10
+ get_dates_to_calculate_metrics_for,
11
+ )
12
+ from agno.db.schemas import UserMemory
13
+ from agno.db.schemas.culture import CulturalKnowledge
14
+ from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
15
+ from agno.db.schemas.knowledge import KnowledgeRow
16
+ from agno.db.surrealdb import utils
17
+ from agno.db.surrealdb.metrics import (
18
+ bulk_upsert_metrics,
19
+ calculate_date_metrics,
20
+ fetch_all_sessions_data,
21
+ get_all_sessions_for_metrics_calculation,
22
+ get_metrics_calculation_starting_date,
23
+ )
24
+ from agno.db.surrealdb.models import (
25
+ TableType,
26
+ deserialize_cultural_knowledge,
27
+ deserialize_eval_run_record,
28
+ deserialize_knowledge_row,
29
+ deserialize_session,
30
+ deserialize_sessions,
31
+ deserialize_user_memories,
32
+ deserialize_user_memory,
33
+ desurrealize_eval_run_record,
34
+ desurrealize_session,
35
+ desurrealize_user_memory,
36
+ get_schema,
37
+ get_session_type,
38
+ serialize_cultural_knowledge,
39
+ serialize_eval_run_record,
40
+ serialize_knowledge_row,
41
+ serialize_session,
42
+ serialize_user_memory,
43
+ )
44
+ from agno.db.surrealdb.queries import COUNT_QUERY, WhereClause, order_limit_start
45
+ from agno.db.surrealdb.utils import build_client
46
+ from agno.session import Session
47
+ from agno.utils.log import log_debug, log_error, log_info
48
+ from agno.utils.string import generate_id
49
+
50
+ try:
51
+ from surrealdb import BlockingHttpSurrealConnection, BlockingWsSurrealConnection, RecordID
52
+ except ImportError:
53
+ raise ImportError("The `surrealdb` package is not installed. Please install it via `pip install surrealdb`.")
54
+
55
+
56
+ class SurrealDb(BaseDb):
57
def __init__(
    self,
    client: Optional[Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]],
    db_url: str,
    db_creds: dict[str, str],
    db_ns: str,
    db_db: str,
    session_table: Optional[str] = None,
    memory_table: Optional[str] = None,
    metrics_table: Optional[str] = None,
    eval_table: Optional[str] = None,
    knowledge_table: Optional[str] = None,
    culture_table: Optional[str] = None,
    traces_table: Optional[str] = None,
    spans_table: Optional[str] = None,
    id: Optional[str] = None,
):
    """
    Set up a SurrealDB-backed database interface.

    Args:
        client: A blocking SurrealDB connection (HTTP or WS). May be None, in which
            case the `client` property builds one from the settings below on first use.
        db_url: URL of the SurrealDB server.
        db_creds: Credentials used when building a connection.
        db_ns: SurrealDB namespace.
        db_db: SurrealDB database name.
        session_table: Name of the session table.
        memory_table: Name of the memory table.
        metrics_table: Name of the metrics table.
        eval_table: Name of the eval table.
        knowledge_table: Name of the knowledge table.
        culture_table: Name of the culture table.
        traces_table: Name of the traces table.
        spans_table: Name of the spans table.
        id: Identifier of this database. When omitted, one is generated from
            `db_url` and `db_db`.
    """
    # No explicit ID: derive one from "<url>#<database>".
    # NOTE(review): the namespace is not part of the seed — confirm that is intended.
    resolved_id = id if id is not None else generate_id(f"{db_url}#{db_db}")

    super().__init__(
        id=resolved_id,
        session_table=session_table,
        memory_table=memory_table,
        metrics_table=metrics_table,
        eval_table=eval_table,
        knowledge_table=knowledge_table,
        culture_table=culture_table,
        traces_table=traces_table,
        spans_table=spans_table,
    )

    # Connection settings; kept so the client can be built lazily.
    self._client = client
    self._db_url = db_url
    self._db_creds = db_creds
    self._db_ns = db_ns
    self._db_db = db_db

    # Fixed names of the tables that are not configurable through the constructor.
    self._users_table_name: str = "agno_users"
    self._agents_table_name: str = "agno_agents"
    self._teams_table_name: str = "agno_teams"
    self._workflows_table_name: str = "agno_workflows"
118
+
119
@property
def client(self) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
    """Return the SurrealDB connection, building and caching it on first access if none was supplied."""
    connection = self._client
    if connection is None:
        connection = build_client(self._db_url, self._db_creds, self._db_ns, self._db_db)
        self._client = connection
    return connection
124
+
125
@property
def table_names(self) -> dict[TableType, str]:
    """Map each table type to the table name configured on this instance.

    NOTE(review): "metrics" is handled by `_get_table` but has no entry here —
    confirm that omission is intentional.
    """
    return dict(
        agents=self._agents_table_name,
        culture=self.culture_table_name,
        evals=self.eval_table_name,
        knowledge=self.knowledge_table_name,
        memories=self.memory_table_name,
        sessions=self.session_table_name,
        spans=self.span_table_name,
        teams=self._teams_table_name,
        traces=self.trace_table_name,
        users=self._users_table_name,
        workflows=self._workflows_table_name,
    )
140
+
141
def table_exists(self, table_name: str) -> bool:
    """Report whether `table_name` is listed by the database's `INFO FOR DB` response.

    Args:
        table_name: Name of the table to look for.

    Returns:
        bool: True when the name appears under the response's "tables" entry,
        False otherwise (including when that entry is absent).

    Raises:
        Exception: If the `INFO FOR DB` query yields no result.
    """
    db_info = self._query_one("INFO FOR DB", {}, dict)
    if db_info is None:
        raise Exception("Failed to retrieve database information")
    known_tables = db_info.get("tables", [])
    return table_name in known_tables
154
+
155
def _table_exists(self, table_name: str) -> bool:
    """Deprecated alias kept for existing callers; use table_exists() instead."""
    exists = self.table_exists(table_name)
    return exists
158
+
159
def _create_table(self, table_type: TableType, table_name: str):
    """Create `table_name` by running the schema statement(s) that `get_schema` returns for `table_type`."""
    schema_statements = get_schema(table_type, table_name)
    self.client.query(schema_statements)
162
+
163
def _get_table(self, table_type: TableType, create_table_if_not_found: bool = True):
    """Resolve the table name for `table_type`, optionally creating the table first.

    Args:
        table_type: Logical table kind (e.g. "sessions", "memories", "spans").
        create_table_if_not_found: When True, the table is created if
            `_table_exists` reports it missing.

    Returns:
        The configured table name for `table_type`.

    Raises:
        NotImplementedError: If `table_type` is not one of the known kinds.
    """
    # Attribute on this instance that holds the name for each table type.
    # Looked up lazily so only the requested attribute is read.
    name_attr_by_type = {
        "sessions": "session_table_name",
        "memories": "memory_table_name",
        "knowledge": "knowledge_table_name",
        "culture": "culture_table_name",
        "users": "_users_table_name",
        "agents": "_agents_table_name",
        "teams": "_teams_table_name",
        "workflows": "_workflows_table_name",
        "evals": "eval_table_name",
        "metrics": "metrics_table_name",
        "traces": "trace_table_name",
        "spans": "span_table_name",
    }

    name_attr = name_attr_by_type.get(table_type)
    if name_attr is None:
        raise NotImplementedError(f"Unknown table type: {table_type}")

    # Ensure traces table exists before spans (for foreign key-like relationship)
    if table_type == "spans" and create_table_if_not_found:
        self._get_table("traces", create_table_if_not_found=True)

    table_name = getattr(self, name_attr)

    if create_table_if_not_found:
        if not self._table_exists(table_name):
            self._create_table(table_type, table_name)

    return table_name
198
+
199
def get_latest_schema_version(self):
    """Get the latest version of the database schema.

    Placeholder: this implementation performs no lookup and returns None.
    """
    return None
202
+
203
def upsert_schema_version(self, version: str) -> None:
    """Upsert the schema version into the database.

    Placeholder: this implementation stores nothing and returns None.
    """
    return None
206
+
207
def _query(
    self,
    query: str,
    vars: dict[str, Any],
    record_type: type[utils.RecordType],
) -> Sequence[utils.RecordType]:
    """Run `query` with `vars` on this instance's client via `utils.query` and return its result sequence."""
    connection = self.client
    return utils.query(connection, query, vars, record_type)
214
+
215
def _query_one(
    self,
    query: str,
    vars: dict[str, Any],
    record_type: type[utils.RecordType],
) -> Optional[utils.RecordType]:
    """Run `query` with `vars` on this instance's client via `utils.query_one`; the result may be None."""
    connection = self.client
    return utils.query_one(connection, query, vars, record_type)
222
+
223
def _count(self, table: str, where_clause: str, where_vars: dict[str, Any], group_by: Optional[str] = None) -> int:
    """Count the rows of `table` that match `where_clause`.

    Args:
        table: Table to count rows in.
        where_clause: Pre-built WHERE clause text (may be empty).
        where_vars: Query variables referenced by `where_clause`.
        group_by: Optional field to group by; when None the query uses GROUP ALL.

    Returns:
        int: The "count" value of the first result row, or 0 when the query
        yields no row.

    Raises:
        TypeError: If the returned "count" value is not an integer.
    """
    if group_by is None:
        group_clause, group_fields = "GROUP ALL", ""
    else:
        group_clause, group_fields = f"GROUP BY {group_by}", f", {group_by}"

    count_query = COUNT_QUERY.format(
        table=table,
        where_clause=where_clause,
        group_clause=group_clause,
        group_fields=group_fields,
    )
    count_row = self._query_one(count_query, where_vars, dict)
    total_count = count_row.get("count") if count_row else 0

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let a malformed result slip through unvalidated.
    if not isinstance(total_count, int):
        raise TypeError(f"Expected int, got {type(total_count)}")
    return int(total_count)
235
+
236
+ # --- Sessions ---
237
def clear_sessions(self) -> None:
    """Delete every row of the sessions table.

    The table name is resolved through `_get_table("sessions")`. Errors from
    the client are not caught here and propagate to the caller.
    """
    sessions_table = self._get_table("sessions")
    self.client.delete(sessions_table)
245
+
246
def delete_session(self, session_id: str) -> bool:
    """Delete the session record with the given ID.

    Args:
        session_id: ID of the session to delete.

    Returns:
        bool: Truthiness of the client's delete result; False when no sessions
        table name could be resolved.
    """
    sessions_table = self._get_table(table_type="sessions")
    if sessions_table is None:
        return False

    target = RecordID(sessions_table, session_id)
    deleted = self.client.delete(target)
    return bool(deleted)
252
+
253
def delete_sessions(self, session_ids: list[str]) -> None:
    """Delete several sessions in a single query.

    Args:
        session_ids: IDs of the sessions to delete. An empty list is a no-op
            and does not touch the database.
    """
    # Nothing to delete: skip the table lookup and the DELETE round trip.
    if not session_ids:
        return

    table = self._get_table(table_type="sessions")
    if table is None:
        return

    # IDs are passed as a bound variable; only the table name is interpolated.
    records = [RecordID(table, session_id) for session_id in session_ids]
    self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": records})
260
+
261
def get_session(
    self,
    session_id: str,
    session_type: SessionType,
    user_id: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Optional[Union[Session, Dict[str, Any]]]:
    r"""
    Fetch a single session by ID.

    Args:
        session_id (str): ID of the session to read.
        session_type (SessionType): Type of session, used when deserializing.
        user_id (Optional[str]): If given, the record must also match this user ID.
        deserialize (Optional[bool]): Whether to deserialize the result into a Session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None when the query yields no record
            - When deserialize=True: Session object
            - When deserialize=False: the raw record dictionary as returned by the query
    """
    sessions_table = self._get_table("sessions")
    target = RecordID(sessions_table, session_id)

    # Optional ownership filter.
    filters = WhereClause()
    if user_id is not None:
        filters = filters.and_("user_id", user_id)
    where_clause, where_vars = filters.build()

    query = dedent(f"""
        SELECT *
        FROM ONLY $record
        {where_clause}
    """)
    params = {"record": target, **where_vars}
    session_raw = self._query_one(query, params, dict)

    if session_raw is not None and deserialize:
        return deserialize_session(session_type, session_raw)
    # NOTE(review): unlike get_sessions, the raw dict is returned without
    # desurrealize_session — confirm callers expect that.
    return session_raw
302
+
303
def get_sessions(
    self,
    session_type: Optional[SessionType] = None,
    user_id: Optional[str] = None,
    component_id: Optional[str] = None,
    session_name: Optional[str] = None,
    start_timestamp: Optional[int] = None,
    end_timestamp: Optional[int] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
    r"""
    Get sessions from the sessions table, optionally filtered, sorted and paginated.

    Args:
        session_type (Optional[SessionType]): The type of session to get. Required when deserialize=True.
        user_id (Optional[str]): The ID of the user to filter by.
        component_id (Optional[str]): The ID of the agent / team / workflow to filter by.
            Only applied when session_type is AGENT, TEAM or WORKFLOW.
        session_name (Optional[str]): The name of the session to filter by (matched with the `~` operator).
        start_timestamp (Optional[int]): The start timestamp to filter by.
        end_timestamp (Optional[int]): The end timestamp to filter by.
        limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
        page (Optional[int]): The page number to return. Defaults to None.
        sort_by (Optional[str]): The field to sort by. Defaults to None.
        sort_order (Optional[str]): The sort order. Defaults to None.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.

    Returns:
        Union[List[Session], Tuple[List[Dict], int]]:
            - When deserialize=True: List of Session objects
            - When deserialize=False: Tuple of (session dictionaries, total count)

    Raises:
        ValueError: If deserialize=True and session_type is None.
    """
    # Fail fast: this combination can never succeed, so reject it before
    # touching the database instead of after running the count and select queries.
    if deserialize and session_type is None:
        raise ValueError("session_type is required when deserialize=True")

    table = self._get_table("sessions")
    # Related tables are only needed for their names (record links), so never create them here.
    agents_table = self._get_table("agents", False)
    teams_table = self._get_table("teams", False)
    workflows_table = self._get_table("workflows", False)

    # -- Filters
    where = WhereClause()

    if user_id is not None:
        where = where.and_("user_id", user_id)

    # NOTE(review): session_type on its own adds no filter; it only selects which
    # link field component_id is compared against. Confirm that is intended.
    if component_id is not None:
        if session_type == SessionType.AGENT:
            where = where.and_("agent", RecordID(agents_table, component_id))
        elif session_type == SessionType.TEAM:
            where = where.and_("team", RecordID(teams_table, component_id))
        elif session_type == SessionType.WORKFLOW:
            where = where.and_("workflow", RecordID(workflows_table, component_id))

    if session_name is not None:
        where = where.and_("session_name", session_name, "~")

    # NOTE(review): these compare against fields literally named "start_timestamp" /
    # "end_timestamp" — verify the session records actually carry those fields.
    if start_timestamp is not None:
        where = where.and_("start_timestamp", start_timestamp, ">=")
    if end_timestamp is not None:
        where = where.and_("end_timestamp", end_timestamp, "<=")

    where_clause, where_vars = where.build()

    # Total count of matching rows, ignoring pagination.
    total_count = self._count(table, where_clause, where_vars)

    # Query
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    sessions_raw = self._query(query, where_vars, dict)

    # Kept unconditional, as in the original flow, in case desurrealize_session
    # adjusts the records in place — NOTE(review): confirm before making it lazy.
    converted_sessions_raw = [desurrealize_session(session, session_type) for session in sessions_raw]

    if not deserialize:
        return list(converted_sessions_raw), total_count

    return deserialize_sessions(session_type, list(sessions_raw))
398
+
399
def rename_session(
    self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Set a new name on an existing session.

    Args:
        session_id (str): The ID of the session to rename.
        session_type (SessionType): The type of session, used when deserializing.
        session_name (str): The new name for the session.
        deserialize (Optional[bool]): Whether to deserialize the result into a Session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None when the update yields no record
            - When deserialize=True: Session object
            - When deserialize=False: Session dictionary
    """
    sessions_table = self._get_table("sessions")
    params = {"record": RecordID(sessions_table, session_id), "name": session_name}

    update_query = dedent("""
        UPDATE ONLY $record
        SET session_name = $name
    """)
    renamed_raw = self._query_one(update_query, params, dict)

    if renamed_raw is not None and deserialize:
        return deserialize_session(session_type, renamed_raw)
    return renamed_raw
432
+
433
def upsert_session(
    self, session: Session, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Insert a session, or replace its content if a record with the same ID exists.

    Args:
        session (Session): The session data to upsert.
        deserialize (Optional[bool]): Whether to deserialize the result into a Session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None when the upsert yields no record
            - When deserialize=True: Session object
            - When deserialize=False: Session dictionary
    """
    session_type = get_session_type(session)
    sessions_table = self._get_table("sessions")

    params = {
        "record": RecordID(sessions_table, session.session_id),
        "content": serialize_session(session, self.table_names),
    }
    upserted_raw = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)

    if upserted_raw is not None and deserialize:
        return deserialize_session(session_type, upserted_raw)
    return upserted_raw
465
+
466
def upsert_sessions(
    self, sessions: List[Session], deserialize: Optional[bool] = True
) -> List[Union[Session, Dict[str, Any]]]:
    """
    Upsert several sessions, issuing one UPSERT per session.

    Args:
        sessions (List[Session]): The sessions to upsert. An empty list returns [].
        deserialize (Optional[bool]): Whether to deserialize the results into Session objects. Defaults to True.

    Returns:
        List[Union[Session, Dict[str, Any]]]: The upserted sessions; sessions whose
        upsert returned a falsy result are left out.
    """
    if not sessions:
        return []

    # The session type of the first element is used to deserialize the whole batch.
    # NOTE(review): presumably callers pass homogeneous lists — confirm.
    session_type = get_session_type(sessions[0])
    sessions_table = self._get_table("sessions")

    upserted: List[Dict[str, Any]] = []
    for item in sessions:
        # UPSERT does only work for one record at a time
        params = {
            "record": RecordID(sessions_table, item.session_id),
            "content": serialize_session(item, self.table_names),
        }
        row = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)
        if not row:
            continue
        upserted.append(row)

    if deserialize:
        # Re-wrapped in list() because List[Session] is not assignable to
        # List[Session | Dict[str, Any]] (list is invariant).
        return list(deserialize_sessions(session_type, upserted))
    return list(upserted)
506
+
507
+ # --- Memory ---
508
def clear_memories(self) -> None:
    """Remove every record from the memories table.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Passing the table (rather than a record ID) to delete targets the whole table.
    self.client.delete(self._get_table("memories"))
516
+
517
+ # -- Cultural Knowledge methods --
518
def clear_cultural_knowledge(self) -> None:
    """Remove every record from the culture table.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Passing the table (rather than a record ID) to delete targets the whole table.
    self.client.delete(self._get_table("culture"))
526
+
527
def delete_cultural_knowledge(self, id: str) -> None:
    """Delete a single cultural knowledge record.

    Args:
        id (str): The ID of the cultural knowledge to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    target = RecordID(self._get_table("culture"), id)
    self.client.delete(target)
539
+
540
def get_cultural_knowledge(
    self, id: str, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Look up one cultural knowledge record by its ID.

    Args:
        id (str): The ID of the cultural knowledge to retrieve.
        deserialize (Optional[bool]): Whether to return a CulturalKnowledge
            object instead of the raw record. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural
            knowledge, or None when the select call returns None.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    culture_table = self._get_table("culture")
    record = self.client.select(RecordID(culture_table, id))

    if record is None:
        return None
    if deserialize:
        return deserialize_cultural_knowledge(record)  # type: ignore
    return record  # type: ignore
566
+
567
def get_all_cultural_knowledge(
    self,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    name: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
    """List cultural knowledge records with optional filters, sorting and paging.

    Args:
        agent_id (Optional[str]): Only return records linked to this agent.
        team_id (Optional[str]): Only return records linked to this team.
        name (Optional[str]): Only return records whose lower-cased name
            contains this value (lower-cased).
        limit (Optional[int]): Maximum number of results to return.
        page (Optional[int]): Page number for pagination.
        sort_by (Optional[str]): Field to sort by.
        sort_order (Optional[str]): Sort order ('asc' or 'desc').
        deserialize (Optional[bool]): Whether to return CulturalKnowledge
            objects. Defaults to True.

    Returns:
        Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of CulturalKnowledge objects
            - When deserialize=False: Tuple of (raw rows, total count)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("culture")

    # Each filter is a (field expression, operator, value) triple.
    filters: List[Tuple[str, str, Any]] = []
    if agent_id is not None:
        filters.append(("agent", "=", RecordID(self._get_table("agents"), agent_id)))
    if team_id is not None:
        filters.append(("team", "=", RecordID(self._get_table("teams"), team_id)))
    if name is not None:
        filters.append(("string::lowercase(name)", "CONTAINS", name.lower()))

    # Filter values are bound to single-letter parameters: $a, $b, $c, ...
    conditions = " AND ".join(
        f"{field} {operator} ${chr(97 + idx)}" for idx, (field, operator, _) in enumerate(filters)
    )
    params = {chr(97 + idx): value for idx, (_, _, value) in enumerate(filters)}

    # Total number of matching rows, ignoring pagination.
    count_query = COUNT_QUERY.format(table=table, where=f"WHERE {conditions}" if filters else "")
    total_count = self._query_one(count_query, params, int) or 0

    # Page of matching rows.
    order_limit = order_limit_start(sort_by, sort_order, limit, page)
    query = f"SELECT * FROM {table}"
    if filters:
        query += f" WHERE {conditions}"
    # NOTE(review): appended without a separator; presumably order_limit_start
    # returns a string that begins with whitespace - confirm.
    query += order_limit

    rows = self._query(query, params, list) or []

    if deserialize:
        return [deserialize_cultural_knowledge(row) for row in rows]  # type: ignore
    return rows, total_count  # type: ignore
634
+
635
def upsert_cultural_knowledge(
    self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Insert or update a cultural knowledge record in SurrealDB.

    Args:
        cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
        deserialize (Optional[bool]): Whether to return a CulturalKnowledge
            object instead of the raw result. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The upserted
            cultural knowledge, or None when the client returns None.

    Raises:
        Exception: If an error occurs during upsert.
    """
    culture_table = self._get_table("culture", create_table_if_not_found=True)
    payload = serialize_cultural_knowledge(cultural_knowledge, culture_table)

    # The serialized payload carries its own record id under "id".
    stored = self.client.upsert(payload["id"], payload)

    if stored is None:
        return None
    if deserialize:
        return deserialize_cultural_knowledge(stored)  # type: ignore
    return stored  # type: ignore
662
+
663
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
    """Delete one memory, optionally only if it belongs to the given user.

    Args:
        memory_id (str): The ID of the memory to delete.
        user_id (Optional[str]): When given, the memory is only deleted if
            its ``user`` field matches this user. Defaults to None.

    Raises:
        Exception: If an error occurs during deletion.
    """
    table = self._get_table("memories")
    memory_record = RecordID(table, memory_id)

    if user_id is not None:
        # Scoped delete: both the record id and the owning user must match.
        owner_record = RecordID(self._get_table("users"), user_id)
        self.client.query(
            f"DELETE FROM {table} WHERE user = $user AND id = $memory",
            {"user": owner_record, "memory": memory_record},
        )
        return

    self.client.delete(memory_record)
686
+
687
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
    """Delete several memories, optionally restricted to one user's memories.

    Args:
        memory_ids (List[str]): The IDs of the memories to delete.
        user_id (Optional[str]): When given, only memories whose ``user``
            field matches this user are deleted. Defaults to None.

    Raises:
        Exception: If an error occurs during deletion.
    """
    table = self._get_table("memories")
    params: Dict[str, Any] = {"records": [RecordID(table, mid) for mid in memory_ids]}
    statement = f"DELETE FROM {table} WHERE id IN $records"

    if user_id is not None:
        # Narrow the delete to the given owner.
        statement += " AND user = $user"
        params["user"] = RecordID(self._get_table("users"), user_id)

    self.client.query(statement, params)
706
+
707
def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
    """Collect the distinct topics across memories.

    Args:
        user_id (Optional[str]): When given, only this user's memories are
            considered. Defaults to None.

    Returns:
        List[str]: List of memory topics.
    """
    table = self._get_table("memories")
    params: Dict[str, Any] = {}

    # The query flattens every memory's topics into one array and de-duplicates it.
    if user_id is not None:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                WHERE user = $user
                GROUP ALL
            ).topics.distinct();
        """)
        params["user"] = RecordID(self._get_table("users"), user_id)
    else:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                GROUP ALL
            ).topics.distinct();
        """)

    topics = self._query(query, params, str)
    return list(topics)
743
+
744
def get_user_memory(
    self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Get a single memory from the database.

    Args:
        memory_id (str): The ID of the memory to get.
        deserialize (Optional[bool]): Whether to return a UserMemory object
            instead of a dictionary. Defaults to True.
        user_id (Optional[str]): When given, the memory is only returned if
            its ``user`` field matches this user. Defaults to None.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary
            - None when no matching record is found

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table_name = self._get_table("memories")
    params: Dict[str, Any] = {"record": RecordID(table_name, memory_id)}

    query = "SELECT * FROM ONLY $record"
    if user_id is not None:
        query += " WHERE user = $user"
        params["user"] = RecordID(self._get_table("users"), user_id)

    result = self._query_one(query, params, dict)
    if result is None:
        return None
    if not deserialize:
        # Pass the raw row through the same conversion that get_user_memories
        # and upsert_user_memory apply on their dictionary paths, so callers
        # get a consistently shaped dictionary.
        return desurrealize_user_memory(result)
    return deserialize_user_memory(result)
776
+
777
def get_user_memories(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    topics: Optional[List[str]] = None,
    search_content: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
    """List memories matching the given filters, with sorting and paging.

    Args:
        user_id (Optional[str]): The ID of the user to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        topics (Optional[List[str]]): Topics to filter by (CONTAINSANY match).
        search_content (Optional[str]): Content to match against the
            ``memory`` field using the ``~`` operator.
        limit (Optional[int]): The maximum number of memories to return.
        page (Optional[int]): The page number.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to return UserMemory objects.
            Defaults to True.

    Returns:
        Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of UserMemory objects
            - When deserialize=False: Tuple of (memory dictionaries, total count)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("memories")

    where = WhereClause()
    # Reference filters: (memory field, name of the referenced table, requested id).
    reference_filters = (
        ("user", "users", user_id),
        ("agent", "agents", agent_id),
        ("team", "teams", team_id),
    )
    for field, ref_table, ref_id in reference_filters:
        if ref_id is not None:
            where.and_(field, RecordID(self._get_table(ref_table), ref_id))
    if topics is not None:
        where.and_("topics", topics, "CONTAINSANY")
    if search_content is not None:
        where.and_("memory", search_content, "~")
    where_clause, where_vars = where.build()

    # Count of all matching rows, independent of pagination.
    total_count = self._count(table, where_clause, where_vars)

    paging_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {paging_clause}
    """)
    rows = self._query(query, where_vars, dict)

    if not deserialize:
        return [desurrealize_user_memory(row) for row in rows], total_count
    return deserialize_user_memories(rows)
845
+
846
def get_user_memory_stats(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    user_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """Get per-user memory statistics.

    Args:
        limit (Optional[int]): The maximum number of user stats to return.
        page (Optional[int]): The page number.
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing
            user stats, and the total number of users matching the filter.

    Example:
        (
            [
                {
                    "user_id": "123",
                    "total_memories": 10,
                    "last_memory_updated_at": 1714560000,
                },
            ],
            total_count: 1,
        )
    """
    memories_table_name = self._get_table("memories")

    where = WhereClause()
    if user_id is None:
        where.and_("!!user", True, "=")  # this checks that user is not falsy
    else:
        where.and_("user", RecordID(self._get_table("users"), user_id), "=")
    where_clause, where_vars = where.build()

    group_clause = "GROUP BY user"
    order_limit_start_clause = order_limit_start("last_memory_updated_at", "DESC", limit, page)

    # Total count: apply the same filter as the main query so the count agrees
    # with the rows that can actually be paged through.
    count_query = (
        f"(SELECT user FROM {memories_table_name} {where_clause} GROUP BY user).map(|$x| $x.user).len()"
    )
    total_count = self._query_one(count_query, where_vars, int) or 0

    query = dedent(f"""
        SELECT
            user,
            count(id) AS total_memories,
            time::max(updated_at) AS last_memory_updated_at
        FROM {memories_table_name}
        {where_clause}
        {group_clause}
        {order_limit_start_clause}
    """)
    result = self._query(query, where_vars, dict)

    # Replace the user RecordID with its plain id and the datetime with a
    # whole-second Unix timestamp.
    for row in result:
        row["user_id"] = row["user"].id
        del row["user"]
        row["last_memory_updated_at"] = int(row["last_memory_updated_at"].timestamp())

    return list(result), total_count
913
+
914
def upsert_user_memory(
    self, memory: UserMemory, deserialize: Optional[bool] = True
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Insert or update a single user memory.

    A memory with a ``memory_id`` is upserted under that ID; a memory without
    one is created as a new record.

    Args:
        memory (UserMemory): The user memory to upsert.
        deserialize (Optional[bool]): Whether to return a UserMemory object
            instead of a dictionary. Defaults to True.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary
            - None when the database returns no record

    Raises:
        Exception: If an error occurs during upsert.
    """
    table = self._get_table("memories")
    user_table = self._get_table("users")
    params: Dict[str, Any] = {"content": serialize_user_memory(memory, table, user_table)}

    if memory.memory_id:
        statement = "UPSERT ONLY $record CONTENT $content"
        params["record"] = RecordID(table, memory.memory_id)
    else:
        statement = f"CREATE ONLY {table} CONTENT $content"

    stored = self._query_one(statement, params, dict)
    if stored is None:
        return None
    if deserialize:
        return deserialize_user_memory(stored)
    return desurrealize_user_memory(stored)
947
+
948
def upsert_memories(
    self, memories: List[UserMemory], deserialize: Optional[bool] = True
) -> List[Union[UserMemory, Dict[str, Any]]]:
    """Insert or update several memories, one statement per memory.

    Args:
        memories (List[UserMemory]): The list of memories to upsert.
        deserialize (Optional[bool]): Whether to return UserMemory objects
            instead of dictionaries. Defaults to True.

    Returns:
        List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories.
            Memories for which the database returned None are omitted.

    Raises:
        Exception: If an error occurs during bulk upsert.
    """
    if not memories:
        return []

    table = self._get_table("memories")
    user_table_name = self._get_table("users")

    stored_rows: List[Dict[str, Any]] = []
    for item in memories:
        params: Dict[str, Any] = {"content": serialize_user_memory(item, table, user_table_name)}
        if item.memory_id:
            # UPSERT only works on a single record at a time, hence the loop.
            statement = "UPSERT ONLY $record CONTENT $content"
            params["record"] = RecordID(table, item.memory_id)
        else:
            statement = f"CREATE ONLY {table} CONTENT $content"

        row = self._query_one(statement, params, dict)
        if row is not None:
            stored_rows.append(row)

    if deserialize:
        # Re-wrapped in list() because List is invariant: List[UserMemory] is
        # not assignable to List[UserMemory | Dict[str, Any]].
        return list(deserialize_user_memories(stored_rows))
    return [desurrealize_user_memory(row) for row in stored_rows]
994
+
995
+ # --- Metrics ---
996
def get_metrics(
    self,
    starting_date: Optional[date] = None,
    ending_date: Optional[date] = None,
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    """Get all metrics matching the given date range, ordered by date ascending.

    Args:
        starting_date (Optional[date]): The starting date to filter metrics by.
        ending_date (Optional[date]): The ending date to filter metrics by.

    Returns:
        Tuple[List[dict], Optional[int]]: The metrics, and the Unix timestamp
            of the most recent ``updated_at`` among them (None when there are
            no metrics).

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("metrics")

    # Dates are compared as UTC midnight datetimes.
    midnight = datetime.min.time()
    where = WhereClause()
    if starting_date is not None:
        starting_datetime = datetime.combine(starting_date, midnight).replace(tzinfo=timezone.utc)
        where = where.and_("date", starting_datetime, ">=")
    if ending_date is not None:
        ending_datetime = datetime.combine(ending_date, midnight).replace(tzinfo=timezone.utc)
        where = where.and_("date", ending_datetime, "<=")
    where_clause, where_vars = where.build()

    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        ORDER BY date ASC
    """)
    results = self._query(query, where_vars, dict)

    if not results:
        return [], None

    latest_update = max(int(r["updated_at"].timestamp()) for r in results)

    transformed_results: List[Dict[str, Any]] = []
    for r in results:
        transformed = dict(r)

        # Replace a RecordID-like id with its inner id. The hasattr check
        # already covers RecordID, so no separate isinstance branch is needed.
        record_id = transformed.get("id")
        if hasattr(record_id, "id"):
            transformed["id"] = record_id.id

        # Convert datetime fields to whole-second Unix timestamps.
        for key in ("created_at", "updated_at", "date"):
            value = transformed.get(key)
            if isinstance(value, datetime):
                transformed[key] = int(value.timestamp())

        transformed_results.append(transformed)

    return transformed_results, latest_update
1069
+
1070
def calculate_metrics(self) -> Optional[List[Dict[str, Any]]]:
    """Compute and store metrics for every date that still needs them.

    Returns:
        Optional[List[Dict[str, Any]]]: The upserted metrics records (an empty
            list when no date had sessions), or None when there is nothing to
            calculate.

    Raises:
        Exception: If an error occurs during metrics calculation. The error
            is logged and re-raised.
    """
    try:
        metrics_table = self._get_table("metrics")

        first_date = get_metrics_calculation_starting_date(self.client, metrics_table, self.get_sessions)
        if first_date is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        pending_dates = get_dates_to_calculate_metrics_for(first_date)
        if not pending_dates:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        # Window: midnight UTC of the first pending date up to midnight UTC of
        # the day after the last pending date.
        midnight = datetime.min.time()
        window_start = datetime.combine(pending_dates[0], midnight).replace(tzinfo=timezone.utc)
        window_end = datetime.combine(pending_dates[-1] + timedelta(days=1), midnight).replace(
            tzinfo=timezone.utc
        )

        raw_sessions = get_all_sessions_for_metrics_calculation(
            self.client, self._get_table("sessions"), window_start, window_end
        )
        sessions_by_date = fetch_all_sessions_data(
            sessions=raw_sessions,
            dates_to_process=pending_dates,
            start_timestamp=int(window_start.timestamp()),
        )
        if not sessions_by_date:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        records = []
        for day in pending_dates:
            day_sessions = sessions_by_date.get(day.isoformat(), {})
            # Only dates with at least one session produce a metrics record.
            if any(len(group) > 0 for group in day_sessions.values()):
                records.append(calculate_date_metrics(day, day_sessions))

        results = bulk_upsert_metrics(self.client, metrics_table, records) if records else []

        log_debug("Updated metrics calculations")
        return results

    except Exception as e:
        log_error(f"Exception refreshing metrics: {e}")
        raise e
1134
+
1135
+ # --- Knowledge ---
1136
def clear_knowledge(self) -> None:
    """Remove every record from the knowledge table.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Passing the table (rather than a record ID) to delete targets the whole table.
    self.client.delete(self._get_table("knowledge"))
1144
+
1145
def delete_knowledge_content(self, id: str):
    """Delete a single knowledge row.

    Args:
        id (str): The ID of the knowledge row to delete.
    """
    target = RecordID(self._get_table("knowledge"), id)
    self.client.delete(target)
1153
+
1154
def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Fetch a single knowledge row by ID.

    Args:
        id (str): The ID of the knowledge row to get.

    Returns:
        Optional[KnowledgeRow]: The knowledge row, or None if the query
            returns nothing.
    """
    target = RecordID(self._get_table("knowledge"), id)
    row = self._query_one("SELECT * FROM ONLY $record_id", {"record_id": target}, dict)
    if not row:
        return None
    return deserialize_knowledge_row(row)
1167
+
1168
def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """List knowledge rows with sorting and paging.

    Args:
        limit (Optional[int]): The maximum number of knowledge contents to return.
        page (Optional[int]): The page number.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.

    Returns:
        Tuple[List[KnowledgeRow], int]: The knowledge contents and total count.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("knowledge")

    # No filters are applied here; an empty WhereClause is still built so the
    # count and the select share the same clause and variables.
    where_clause, where_vars = WhereClause().build()

    total_count = self._count(table, where_clause, where_vars)

    paging_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {paging_clause}
    """)
    rows = self._query(query, where_vars, dict)

    knowledge_rows = [deserialize_knowledge_row(row) for row in rows]
    return knowledge_rows, total_count
1206
+
1207
def upsert_knowledge_content(self, knowledge_row: KnowledgeRow) -> Optional[KnowledgeRow]:
    """Insert or update a knowledge row under its own ID.

    Args:
        knowledge_row (KnowledgeRow): The knowledge row to upsert.

    Returns:
        Optional[KnowledgeRow]: The upserted knowledge row, or None if the
            database returns nothing.
    """
    table = self._get_table("knowledge")
    params = {
        "record": RecordID(table, knowledge_row.id),
        "content": serialize_knowledge_row(knowledge_row, table),
    }
    stored = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)
    if not stored:
        return None
    return deserialize_knowledge_row(stored)
1223
+
1224
+ # --- Evals ---
1225
def clear_evals(self) -> None:
    """Remove every record from the evals table.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Passing the table (rather than a record ID) to delete targets the whole table.
    self.client.delete(self._get_table("evals"))
1233
+
1234
def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Create a new eval run record keyed by its run ID.

    Args:
        eval_run (EvalRunRecord): The eval run to create.

    Returns:
        Optional[EvalRunRecord]: The created eval run, or None if the
            database returns nothing.

    Raises:
        Exception: If an error occurs during creation.
    """
    table = self._get_table("evals")
    params = {
        "record": RecordID(table, eval_run.run_id),
        "content": serialize_eval_run_record(eval_run, self.table_names),
    }
    created = self._query_one("CREATE ONLY $record CONTENT $content", params, dict)
    if not created:
        return None
    return deserialize_eval_run_record(created)
1253
+
1254
def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete several eval runs in a single statement.

    Args:
        eval_run_ids (List[str]): List of eval run IDs to delete.
    """
    table = self._get_table("evals")
    targets = [RecordID(table, run_id) for run_id in eval_run_ids]
    self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": targets})
1263
+
1264
def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Fetch a single eval run by ID.

    Args:
        eval_run_id (str): The ID of the eval run to get.
        deserialize (Optional[bool]): Whether to return an EvalRunRecord
            object instead of a dictionary. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: EvalRun dictionary
            - None when the query returns None

    Raises:
        Exception: If an error occurs during retrieval.
    """
    target = RecordID(self._get_table("evals"), eval_run_id)
    row = self._query_one("SELECT * FROM ONLY $record", {"record": target}, dict)

    if row and deserialize:
        return deserialize_eval_run_record(row)
    if row is None:
        return None
    # Dictionary path (also reached for an empty, non-None row).
    return desurrealize_eval_run_record(row)
1287
+
1288
def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Get all eval runs from the database.

    Args:
        limit (Optional[int]): The maximum number of eval runs to return.
        page (Optional[int]): The page number to return.
        sort_by (Optional[str]): The field to sort by.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): The ID of the agent to filter by. Only used
            when filter_type is EvalFilterType.AGENT.
        team_id (Optional[str]): The ID of the team to filter by. Only used
            when filter_type is EvalFilterType.TEAM.
        workflow_id (Optional[str]): The ID of the workflow to filter by. Only
            used when filter_type is EvalFilterType.WORKFLOW.
        model_id (Optional[str]): The ID of the model to filter by.
        filter_type (Optional[EvalFilterType]): Selects which of agent_id,
            team_id or workflow_id is applied as a filter.
        eval_type (Optional[List[EvalType]]): The type of eval to filter by.
        deserialize (Optional[bool]): Whether to return EvalRunRecord objects
            instead of dictionaries. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of EvalRunRecord objects
            - When deserialize=False: List of eval run dictionaries and the total count

    Raises:
        Exception: If there is an error getting the eval runs.
    """
    table = self._get_table("evals")

    where = WhereClause()
    # NOTE(review): the component IDs are ignored unless filter_type is set,
    # and a matching filter_type with a None ID builds a RecordID with a None
    # id - confirm this is the intended contract.
    if filter_type is not None:
        if filter_type == EvalFilterType.AGENT:
            where.and_("agent", RecordID(self._get_table("agents"), agent_id))
        elif filter_type == EvalFilterType.TEAM:
            where.and_("team", RecordID(self._get_table("teams"), team_id))
        elif filter_type == EvalFilterType.WORKFLOW:
            where.and_("workflow", RecordID(self._get_table("workflows"), workflow_id))
    if model_id is not None:
        where.and_("model_id", model_id)
    if eval_type is not None:
        # NOTE(review): eval_type is a list but is passed without an explicit
        # operator - verify WhereClause's default operator handles lists.
        where.and_("eval_type", eval_type)
    where_clause, where_vars = where.build()

    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)

    # Count of all matching rows, independent of pagination.
    total_count = self._count(table, where_clause, where_vars)

    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    result = self._query(query, where_vars, dict)

    if not deserialize:
        # Convert each raw row the same way get_eval_run does on its
        # dictionary path, so single and list lookups return the same shape.
        return [desurrealize_eval_run_record(x) for x in result], total_count
    return [deserialize_eval_run_record(x) for x in result]
1359
+
1360
def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Update the name of an eval run in the database.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to return an EvalRunRecord
            object instead of a dictionary. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: EvalRun dictionary
            - The falsy query result (e.g. None) when nothing was updated

    Raises:
        Exception: If there is an error updating the eval run.
    """
    table = self._get_table("evals")
    params = {"record": RecordID(table, eval_run_id), "name": name}

    query = dedent("""
        UPDATE ONLY $record
        SET name = $name
    """)
    raw = self._query_one(query, params, dict)

    if not raw:
        return raw
    if not deserialize:
        # Convert the raw row the same way get_eval_run does on its
        # dictionary path, so both methods return the same shape.
        return desurrealize_eval_run_record(raw)
    return deserialize_eval_run_record(raw)
1391
+
1392
+ # --- Traces ---
1393
def upsert_trace(self, trace: "Trace") -> None:
    """Create a trace record, or refresh the stored one if it already exists.

    Update path (a record with this trace_id is already stored):
        - end_time, status and duration_ms are always rewritten. The duration
          is recomputed from the stored start_time when that value is a
          datetime; otherwise the incoming trace.duration_ms is used.
        - name is replaced only when the incoming trace ranks higher than the
          stored one (workflow > team > agent > anything else).
        - run_id, session_id, user_id, agent_id, team_id and workflow_id are
          overwritten only when the incoming value is not None.

    Create path (no record stored): the record is created from
    trace.to_dict() without the total_spans and error_count keys, with ISO
    string timestamps parsed into datetime objects.

    Any exception is logged and swallowed; nothing is raised to the caller.

    Args:
        trace: The Trace object to store (one per trace_id).
    """
    try:
        traces_table = self._get_table("traces", create_table_if_not_found=True)
        record = RecordID(traces_table, trace.trace_id)

        stored = self._query_one("SELECT * FROM ONLY $record", {"record": record}, dict)

        if not stored:
            # Nothing stored yet: create the record from the serialized trace.
            content = trace.to_dict()
            for derived_key in ("total_spans", "error_count"):
                content.pop(derived_key, None)

            # Timestamps that arrive as ISO strings are stored as datetimes.
            for time_key in ("start_time", "end_time", "created_at"):
                raw_value = content.get(time_key)
                if isinstance(raw_value, str):
                    content[time_key] = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))

            self._query_one(
                "CREATE ONLY $record CONTENT $content",
                {"record": record, "content": content},
                dict,
            )
            return

        def component_rank(workflow_id: Any, team_id: Any, agent_id: Any, name: str) -> int:
            # workflow (3) > team (2) > agent (1) > child/unknown (0);
            # only names containing ".run" / ".arun" can rank above 0.
            if ".run" not in name and ".arun" not in name:
                return 0
            for owner_id, level in ((workflow_id, 3), (team_id, 2), (agent_id, 1)):
                if owner_id:
                    return level
            return 0

        stored_rank = component_rank(
            stored.get("workflow_id"),
            stored.get("team_id"),
            stored.get("agent_id"),
            stored.get("name", ""),
        )
        incoming_rank = component_rank(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)

        # Measure the duration from the stored start_time when it is usable.
        stored_start = stored.get("start_time")
        if isinstance(stored_start, datetime):
            duration_ms = int((trace.end_time - stored_start).total_seconds() * 1000)
        else:
            duration_ms = trace.duration_ms

        assignments = [
            "end_time = $end_time",
            "duration_ms = $duration_ms",
            "status = $status",
        ]
        params: Dict[str, Any] = {
            "record": record,
            "end_time": trace.end_time,
            "duration_ms": duration_ms,
            "status": trace.status,
        }

        if incoming_rank > stored_rank:
            assignments.append("name = $name")
            params["name"] = trace.name

        # Context fields never get overwritten with None.
        for context_field in ("run_id", "session_id", "user_id", "agent_id", "team_id", "workflow_id"):
            incoming_value = getattr(trace, context_field)
            if incoming_value is not None:
                assignments.append(f"{context_field} = ${context_field}")
                params[context_field] = incoming_value

        statement = f"UPDATE ONLY $record SET {', '.join(assignments)}"
        self._query_one(statement, params, dict)

    except Exception as e:
        log_error(f"Error creating trace: {e}")
1498
+
1499
def get_trace(
    self,
    trace_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Fetch one trace, looked up by trace_id or by run_id.

    Args:
        trace_id: The unique trace identifier. Takes precedence over run_id.
        run_id: Run ID to look up; the matching trace with the latest
            start_time is returned.

    Returns:
        Optional[Trace]: The trace, enriched with total_spans and error_count
        counted from the spans table. None when no filter is given, when
        nothing matches, or when an exception occurs (the error is logged).
    """
    try:
        traces_table = self._get_table("traces", create_table_if_not_found=False)
        spans_table = self._get_table("spans", create_table_if_not_found=False)

        if trace_id:
            row = self._query_one(
                "SELECT * FROM ONLY $record",
                {"record": RecordID(traces_table, trace_id)},
                dict,
            )
        elif run_id:
            latest_for_run = dedent(f"""
                SELECT * FROM {traces_table}
                WHERE run_id = $run_id
                ORDER BY start_time DESC
                LIMIT 1
            """)
            row = self._query_one(latest_for_run, {"run_id": run_id}, dict)
        else:
            log_debug("get_trace called without any filter parameters")
            return None

        if not row:
            return None

        # Prefer the stored trace_id field; fall back to the record's own id.
        resolved_id = row.get("trace_id")
        if not resolved_id:
            record_id = row.get("id")
            resolved_id = record_id.id if record_id is not None else None

        if resolved_id:
            total_row = self._query_one(
                f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id GROUP ALL",
                {"trace_id": resolved_id},
                dict,
            )
            row["total_spans"] = total_row.get("total", 0) if total_row else 0

            errors_row = self._query_one(
                f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id AND status_code = 'ERROR' GROUP ALL",
                {"trace_id": resolved_id},
                dict,
            )
            row["error_count"] = errors_row.get("total", 0) if errors_row else 0

        return self._deserialize_trace(row)

    except Exception as e:
        log_error(f"Error getting trace: {e}")
        return None
1557
+
1558
def get_traces(
    self,
    run_id: Optional[str] = None,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    status: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List, int]:
    """List traces that match the given filters, one page at a time.

    Results are ordered by start_time, newest first. Each returned trace is
    enriched with total_spans and error_count, obtained with two count
    queries against the spans table per trace.

    Args:
        run_id: Filter by run ID.
        session_id: Filter by session ID.
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        status: Filter by status (OK, ERROR, UNSET).
        start_time: Keep traces whose start_time is >= this datetime.
        end_time: Keep traces whose end_time is <= this datetime.
        limit: Maximum number of traces to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Trace], int]: (traces on the requested page, total number
        of traces matching the filters). On any exception the error is logged
        and ([], 0) is returned.
    """
    try:
        traces_table = self._get_table("traces", create_table_if_not_found=False)
        spans_table = self._get_table("spans", create_table_if_not_found=False)

        # Equality filters are applied only for truthy values.
        where = WhereClause()
        equality_filters = (
            ("run_id", run_id),
            ("session_id", session_id),
            ("user_id", user_id),
            ("agent_id", agent_id),
            ("team_id", team_id),
            ("workflow_id", workflow_id),
            ("status", status),
        )
        for column, wanted in equality_filters:
            if wanted:
                where.and_(column, wanted)
        if start_time:
            where.and_("start_time", start_time, ">=")
        if end_time:
            where.and_("end_time", end_time, "<=")

        where_clause, where_vars = where.build()

        # Count is taken over the filtered set, before pagination.
        total_count = self._count(traces_table, where_clause, where_vars)

        paging_clause = order_limit_start("start_time", "DESC", limit, page)
        page_query = dedent(f"""
            SELECT * FROM {traces_table}
            {where_clause}
            {paging_clause}
        """)
        rows = self._query(page_query, where_vars, dict)

        # The span-count statements do not depend on the row, so build them once.
        spans_total_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id GROUP ALL"
        spans_error_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id AND status_code = 'ERROR' GROUP ALL"

        traces = []
        for row in rows:
            # Prefer the stored trace_id field; fall back to the record's own id.
            resolved_id = row.get("trace_id")
            if not resolved_id:
                record_id = row.get("id")
                resolved_id = record_id.id if record_id is not None else None

            if resolved_id:
                total_row = self._query_one(spans_total_query, {"trace_id": resolved_id}, dict)
                row["total_spans"] = total_row.get("total", 0) if total_row else 0

                errors_row = self._query_one(spans_error_query, {"trace_id": resolved_id}, dict)
                row["error_count"] = errors_row.get("total", 0) if errors_row else 0

            traces.append(self._deserialize_trace(row))

        return traces, total_count

    except Exception as e:
        log_error(f"Error getting traces: {e}")
        return [], 0
1650
+
1651
def get_trace_stats(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List[Dict[str, Any]], int]:
    """Get aggregated trace statistics for traces that have a session_id.

    Args:
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        start_time: Keep traces whose created_at is >= this datetime.
        end_time: Keep traces whose created_at is <= this datetime.
        limit: Maximum number of rows to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Dict], int]: Tuple of (list of stats dicts, total count).
            Each dict contains: session_id, user_id, agent_id, team_id, workflow_id,
            total_traces, first_trace_at, last_trace_at. The two timestamp values
            are returned exactly as the query produced them. Rows are ordered by
            last_trace_at, newest first. On any exception the error is logged
            and ([], 0) is returned.
    """
    try:
        table = self._get_table("traces", create_table_if_not_found=False)

        # Build where clause
        where = WhereClause()
        where.and_("!!session_id", True, "=")  # Ensure session_id is not null
        if user_id:
            where.and_("user_id", user_id)
        if agent_id:
            where.and_("agent_id", agent_id)
        if team_id:
            where.and_("team_id", team_id)
        if workflow_id:
            where.and_("workflow_id", workflow_id)
        if start_time:
            where.and_("created_at", start_time, ">=")
        if end_time:
            where.and_("created_at", end_time, "<=")

        where_clause, where_vars = where.build()

        # Get total count of unique sessions.
        # NOTE(review): this groups by session_id only, while the row query below
        # groups by session_id, user_id, agent_id, team_id and workflow_id. If one
        # session can span several of those combinations, the total will be smaller
        # than the number of rows available across pages - confirm which is intended.
        count_query = dedent(f"""
            SELECT count() as total FROM (
                SELECT session_id FROM {table}
                {where_clause}
                GROUP BY session_id
            ) GROUP ALL
        """)
        count_result = self._query_one(count_query, where_vars, dict)
        total_count = count_result.get("total", 0) if count_result else 0

        # Query with aggregation
        order_limit_start_clause = order_limit_start("last_trace_at", "DESC", limit, page)
        query = dedent(f"""
            SELECT
                session_id,
                user_id,
                agent_id,
                team_id,
                workflow_id,
                count() AS total_traces,
                time::min(created_at) AS first_trace_at,
                time::max(created_at) AS last_trace_at
            FROM {table}
            {where_clause}
            GROUP BY session_id, user_id, agent_id, team_id, workflow_id
            {order_limit_start_clause}
        """)
        results = self._query(query, where_vars, dict)

        # Return plain dict copies of the rows; timestamp values are passed
        # through untouched (the previous per-row isinstance checks were no-ops).
        stats_list = [dict(row) for row in results]

        return stats_list, total_count

    except Exception as e:
        log_error(f"Error getting trace stats: {e}")
        return [], 0
1745
+
1746
def _deserialize_trace(self, trace_data: dict) -> "Trace":
    """Turn a raw trace row into a Trace object.

    The input dict is modified in place: a RecordID stored under "id" is
    removed (after being used to fill in a missing or empty "trace_id"), and
    datetime values in start_time, end_time and created_at are replaced by
    their ISO-format strings before the dict is handed to Trace.from_dict.
    """
    from agno.tracing.schemas import Trace

    record_id = trace_data.get("id")
    if isinstance(record_id, RecordID):
        # Fall back to the record's own id when no usable trace_id is present.
        if not trace_data.get("trace_id"):
            trace_data["trace_id"] = record_id.id
        del trace_data["id"]

    for time_key in ("start_time", "end_time", "created_at"):
        moment = trace_data.get(time_key)
        if isinstance(moment, datetime):
            trace_data[time_key] = moment.isoformat()

    return Trace.from_dict(trace_data)
1762
+
1763
+ # --- Spans ---
1764
def create_span(self, span: "Span") -> None:
    """Store one span as a new record keyed by its span_id.

    The record content is span.to_dict(), with start_time, end_time and
    created_at parsed into datetime objects when they arrive as ISO strings.
    Any exception is logged and swallowed.

    Args:
        span: The Span object to store.
    """
    try:
        spans_table = self._get_table("spans", create_table_if_not_found=True)
        record = RecordID(spans_table, span.span_id)

        content = span.to_dict()

        # A trailing "Z" is rewritten to "+00:00" before parsing.
        for time_key in ("start_time", "end_time", "created_at"):
            raw_value = content.get(time_key)
            if isinstance(raw_value, str):
                content[time_key] = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))

        self._query_one(
            "CREATE ONLY $record CONTENT $content",
            {"record": record, "content": content},
            dict,
        )

    except Exception as e:
        log_error(f"Error creating span: {e}")
1792
+
1793
def create_spans(self, spans: List) -> None:
    """Store several spans, issuing one CREATE statement per span.

    Does nothing for an empty input. The whole loop sits inside a single
    try block, so the first exception is logged and ends processing; spans
    after the failing one are not written.

    Args:
        spans: List of Span objects to store.
    """
    if not spans:
        return

    try:
        spans_table = self._get_table("spans", create_table_if_not_found=True)

        for item in spans:
            record = RecordID(spans_table, item.span_id)
            content = item.to_dict()

            # ISO string timestamps are stored as datetime objects.
            for time_key in ("start_time", "end_time", "created_at"):
                raw_value = content.get(time_key)
                if isinstance(raw_value, str):
                    content[time_key] = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))

            self._query_one(
                "CREATE ONLY $record CONTENT $content",
                {"record": record, "content": content},
                dict,
            )

    except Exception as e:
        log_error(f"Error creating spans batch: {e}")
1825
+
1826
def get_span(self, span_id: str):
    """Look up one span by its span_id.

    Args:
        span_id: The unique span identifier.

    Returns:
        Optional[Span]: The deserialized span, or None when no record is
        found or an exception occurs (the error is logged).
    """
    try:
        spans_table = self._get_table("spans", create_table_if_not_found=False)
        row = self._query_one(
            "SELECT * FROM ONLY $record",
            {"record": RecordID(spans_table, span_id)},
            dict,
        )
        return self._deserialize_span(row) if row else None

    except Exception as e:
        log_error(f"Error getting span: {e}")
        return None
1848
+
1849
def get_spans(
    self,
    trace_id: Optional[str] = None,
    parent_span_id: Optional[str] = None,
    limit: Optional[int] = 1000,
) -> List:
    """List spans matching the given filters, oldest start_time first.

    Args:
        trace_id: Filter by trace ID.
        parent_span_id: Filter by parent span ID.
        limit: Maximum number of spans to return. A falsy value (None or 0)
            omits the LIMIT clause entirely.

    Returns:
        List[Span]: The matching spans. On any exception the error is logged
        and an empty list is returned.
    """
    try:
        spans_table = self._get_table("spans", create_table_if_not_found=False)

        # Filters are applied only for truthy values.
        where = WhereClause()
        for column, wanted in (("trace_id", trace_id), ("parent_span_id", parent_span_id)):
            if wanted:
                where.and_(column, wanted)
        where_clause, where_vars = where.build()

        limit_clause = ""
        if limit:
            limit_clause = f"LIMIT {limit}"

        statement = dedent(f"""
            SELECT * FROM {spans_table}
            {where_clause}
            ORDER BY start_time ASC
            {limit_clause}
        """)
        rows = self._query(statement, where_vars, dict)

        return [self._deserialize_span(row) for row in rows]

    except Exception as e:
        log_error(f"Error getting spans: {e}")
        return []
1892
+
1893
def _deserialize_span(self, span_data: dict) -> "Span":
    """Turn a raw span row into a Span object.

    The input dict is modified in place: a RecordID stored under "id" is
    removed (after being used to fill in a missing or empty "span_id"), and
    datetime values in start_time, end_time and created_at are replaced by
    their ISO-format strings before the dict is handed to Span.from_dict.
    """
    from agno.tracing.schemas import Span

    record_id = span_data.get("id")
    if isinstance(record_id, RecordID):
        # Fall back to the record's own id when no usable span_id is present.
        if not span_data.get("span_id"):
            span_data["span_id"] = record_id.id
        del span_data["id"]

    for time_key in ("start_time", "end_time", "created_at"):
        moment = span_data.get(time_key)
        if isinstance(moment, datetime):
            span_data[time_key] = moment.isoformat()

    return Span.from_dict(span_data)