agno 2.0.1__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. agno/agent/agent.py +6015 -2823
  2. agno/api/api.py +2 -0
  3. agno/api/os.py +1 -1
  4. agno/culture/__init__.py +3 -0
  5. agno/culture/manager.py +956 -0
  6. agno/db/async_postgres/__init__.py +3 -0
  7. agno/db/base.py +385 -6
  8. agno/db/dynamo/dynamo.py +388 -81
  9. agno/db/dynamo/schemas.py +47 -10
  10. agno/db/dynamo/utils.py +63 -4
  11. agno/db/firestore/firestore.py +435 -64
  12. agno/db/firestore/schemas.py +11 -0
  13. agno/db/firestore/utils.py +102 -4
  14. agno/db/gcs_json/gcs_json_db.py +384 -42
  15. agno/db/gcs_json/utils.py +60 -26
  16. agno/db/in_memory/in_memory_db.py +351 -66
  17. agno/db/in_memory/utils.py +60 -2
  18. agno/db/json/json_db.py +339 -48
  19. agno/db/json/utils.py +60 -26
  20. agno/db/migrations/manager.py +199 -0
  21. agno/db/migrations/v1_to_v2.py +510 -37
  22. agno/db/migrations/versions/__init__.py +0 -0
  23. agno/db/migrations/versions/v2_3_0.py +938 -0
  24. agno/db/mongo/__init__.py +15 -1
  25. agno/db/mongo/async_mongo.py +2036 -0
  26. agno/db/mongo/mongo.py +653 -76
  27. agno/db/mongo/schemas.py +13 -0
  28. agno/db/mongo/utils.py +80 -8
  29. agno/db/mysql/mysql.py +687 -25
  30. agno/db/mysql/schemas.py +61 -37
  31. agno/db/mysql/utils.py +60 -2
  32. agno/db/postgres/__init__.py +2 -1
  33. agno/db/postgres/async_postgres.py +2001 -0
  34. agno/db/postgres/postgres.py +676 -57
  35. agno/db/postgres/schemas.py +43 -18
  36. agno/db/postgres/utils.py +164 -2
  37. agno/db/redis/redis.py +344 -38
  38. agno/db/redis/schemas.py +18 -0
  39. agno/db/redis/utils.py +60 -2
  40. agno/db/schemas/__init__.py +2 -1
  41. agno/db/schemas/culture.py +120 -0
  42. agno/db/schemas/memory.py +13 -0
  43. agno/db/singlestore/schemas.py +26 -1
  44. agno/db/singlestore/singlestore.py +687 -53
  45. agno/db/singlestore/utils.py +60 -2
  46. agno/db/sqlite/__init__.py +2 -1
  47. agno/db/sqlite/async_sqlite.py +2371 -0
  48. agno/db/sqlite/schemas.py +24 -0
  49. agno/db/sqlite/sqlite.py +774 -85
  50. agno/db/sqlite/utils.py +168 -5
  51. agno/db/surrealdb/__init__.py +3 -0
  52. agno/db/surrealdb/metrics.py +292 -0
  53. agno/db/surrealdb/models.py +309 -0
  54. agno/db/surrealdb/queries.py +71 -0
  55. agno/db/surrealdb/surrealdb.py +1361 -0
  56. agno/db/surrealdb/utils.py +147 -0
  57. agno/db/utils.py +50 -22
  58. agno/eval/accuracy.py +50 -43
  59. agno/eval/performance.py +6 -3
  60. agno/eval/reliability.py +6 -3
  61. agno/eval/utils.py +33 -16
  62. agno/exceptions.py +68 -1
  63. agno/filters.py +354 -0
  64. agno/guardrails/__init__.py +6 -0
  65. agno/guardrails/base.py +19 -0
  66. agno/guardrails/openai.py +144 -0
  67. agno/guardrails/pii.py +94 -0
  68. agno/guardrails/prompt_injection.py +52 -0
  69. agno/integrations/discord/client.py +1 -0
  70. agno/knowledge/chunking/agentic.py +13 -10
  71. agno/knowledge/chunking/fixed.py +1 -1
  72. agno/knowledge/chunking/semantic.py +40 -8
  73. agno/knowledge/chunking/strategy.py +59 -15
  74. agno/knowledge/embedder/aws_bedrock.py +9 -4
  75. agno/knowledge/embedder/azure_openai.py +54 -0
  76. agno/knowledge/embedder/base.py +2 -0
  77. agno/knowledge/embedder/cohere.py +184 -5
  78. agno/knowledge/embedder/fastembed.py +1 -1
  79. agno/knowledge/embedder/google.py +79 -1
  80. agno/knowledge/embedder/huggingface.py +9 -4
  81. agno/knowledge/embedder/jina.py +63 -0
  82. agno/knowledge/embedder/mistral.py +78 -11
  83. agno/knowledge/embedder/nebius.py +1 -1
  84. agno/knowledge/embedder/ollama.py +13 -0
  85. agno/knowledge/embedder/openai.py +37 -65
  86. agno/knowledge/embedder/sentence_transformer.py +8 -4
  87. agno/knowledge/embedder/vllm.py +262 -0
  88. agno/knowledge/embedder/voyageai.py +69 -16
  89. agno/knowledge/knowledge.py +594 -186
  90. agno/knowledge/reader/base.py +9 -2
  91. agno/knowledge/reader/csv_reader.py +8 -10
  92. agno/knowledge/reader/docx_reader.py +5 -6
  93. agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
  94. agno/knowledge/reader/json_reader.py +6 -5
  95. agno/knowledge/reader/markdown_reader.py +13 -13
  96. agno/knowledge/reader/pdf_reader.py +43 -68
  97. agno/knowledge/reader/pptx_reader.py +101 -0
  98. agno/knowledge/reader/reader_factory.py +51 -6
  99. agno/knowledge/reader/s3_reader.py +3 -15
  100. agno/knowledge/reader/tavily_reader.py +194 -0
  101. agno/knowledge/reader/text_reader.py +13 -13
  102. agno/knowledge/reader/web_search_reader.py +2 -43
  103. agno/knowledge/reader/website_reader.py +43 -25
  104. agno/knowledge/reranker/__init__.py +2 -8
  105. agno/knowledge/types.py +9 -0
  106. agno/knowledge/utils.py +20 -0
  107. agno/media.py +72 -0
  108. agno/memory/manager.py +336 -82
  109. agno/models/aimlapi/aimlapi.py +2 -2
  110. agno/models/anthropic/claude.py +183 -37
  111. agno/models/aws/bedrock.py +52 -112
  112. agno/models/aws/claude.py +33 -1
  113. agno/models/azure/ai_foundry.py +33 -15
  114. agno/models/azure/openai_chat.py +25 -8
  115. agno/models/base.py +999 -519
  116. agno/models/cerebras/cerebras.py +19 -13
  117. agno/models/cerebras/cerebras_openai.py +8 -5
  118. agno/models/cohere/chat.py +27 -1
  119. agno/models/cometapi/__init__.py +5 -0
  120. agno/models/cometapi/cometapi.py +57 -0
  121. agno/models/dashscope/dashscope.py +1 -0
  122. agno/models/deepinfra/deepinfra.py +2 -2
  123. agno/models/deepseek/deepseek.py +2 -2
  124. agno/models/fireworks/fireworks.py +2 -2
  125. agno/models/google/gemini.py +103 -31
  126. agno/models/groq/groq.py +28 -11
  127. agno/models/huggingface/huggingface.py +2 -1
  128. agno/models/internlm/internlm.py +2 -2
  129. agno/models/langdb/langdb.py +4 -4
  130. agno/models/litellm/chat.py +18 -1
  131. agno/models/litellm/litellm_openai.py +2 -2
  132. agno/models/llama_cpp/__init__.py +5 -0
  133. agno/models/llama_cpp/llama_cpp.py +22 -0
  134. agno/models/message.py +139 -0
  135. agno/models/meta/llama.py +27 -10
  136. agno/models/meta/llama_openai.py +5 -17
  137. agno/models/nebius/nebius.py +6 -6
  138. agno/models/nexus/__init__.py +3 -0
  139. agno/models/nexus/nexus.py +22 -0
  140. agno/models/nvidia/nvidia.py +2 -2
  141. agno/models/ollama/chat.py +59 -5
  142. agno/models/openai/chat.py +69 -29
  143. agno/models/openai/responses.py +103 -106
  144. agno/models/openrouter/openrouter.py +41 -3
  145. agno/models/perplexity/perplexity.py +4 -5
  146. agno/models/portkey/portkey.py +3 -3
  147. agno/models/requesty/__init__.py +5 -0
  148. agno/models/requesty/requesty.py +52 -0
  149. agno/models/response.py +77 -1
  150. agno/models/sambanova/sambanova.py +2 -2
  151. agno/models/siliconflow/__init__.py +5 -0
  152. agno/models/siliconflow/siliconflow.py +25 -0
  153. agno/models/together/together.py +2 -2
  154. agno/models/utils.py +254 -8
  155. agno/models/vercel/v0.py +2 -2
  156. agno/models/vertexai/__init__.py +0 -0
  157. agno/models/vertexai/claude.py +96 -0
  158. agno/models/vllm/vllm.py +1 -0
  159. agno/models/xai/xai.py +3 -2
  160. agno/os/app.py +543 -178
  161. agno/os/auth.py +24 -14
  162. agno/os/config.py +1 -0
  163. agno/os/interfaces/__init__.py +1 -0
  164. agno/os/interfaces/a2a/__init__.py +3 -0
  165. agno/os/interfaces/a2a/a2a.py +42 -0
  166. agno/os/interfaces/a2a/router.py +250 -0
  167. agno/os/interfaces/a2a/utils.py +924 -0
  168. agno/os/interfaces/agui/agui.py +23 -7
  169. agno/os/interfaces/agui/router.py +27 -3
  170. agno/os/interfaces/agui/utils.py +242 -142
  171. agno/os/interfaces/base.py +6 -2
  172. agno/os/interfaces/slack/router.py +81 -23
  173. agno/os/interfaces/slack/slack.py +29 -14
  174. agno/os/interfaces/whatsapp/router.py +11 -4
  175. agno/os/interfaces/whatsapp/whatsapp.py +14 -7
  176. agno/os/mcp.py +111 -54
  177. agno/os/middleware/__init__.py +7 -0
  178. agno/os/middleware/jwt.py +233 -0
  179. agno/os/router.py +556 -139
  180. agno/os/routers/evals/evals.py +71 -34
  181. agno/os/routers/evals/schemas.py +31 -31
  182. agno/os/routers/evals/utils.py +6 -5
  183. agno/os/routers/health.py +31 -0
  184. agno/os/routers/home.py +52 -0
  185. agno/os/routers/knowledge/knowledge.py +185 -38
  186. agno/os/routers/knowledge/schemas.py +82 -22
  187. agno/os/routers/memory/memory.py +158 -53
  188. agno/os/routers/memory/schemas.py +20 -16
  189. agno/os/routers/metrics/metrics.py +20 -8
  190. agno/os/routers/metrics/schemas.py +16 -16
  191. agno/os/routers/session/session.py +499 -38
  192. agno/os/schema.py +308 -198
  193. agno/os/utils.py +401 -41
  194. agno/reasoning/anthropic.py +80 -0
  195. agno/reasoning/azure_ai_foundry.py +2 -2
  196. agno/reasoning/deepseek.py +2 -2
  197. agno/reasoning/default.py +3 -1
  198. agno/reasoning/gemini.py +73 -0
  199. agno/reasoning/groq.py +2 -2
  200. agno/reasoning/ollama.py +2 -2
  201. agno/reasoning/openai.py +7 -2
  202. agno/reasoning/vertexai.py +76 -0
  203. agno/run/__init__.py +6 -0
  204. agno/run/agent.py +248 -94
  205. agno/run/base.py +44 -5
  206. agno/run/team.py +238 -97
  207. agno/run/workflow.py +144 -33
  208. agno/session/agent.py +105 -89
  209. agno/session/summary.py +65 -25
  210. agno/session/team.py +176 -96
  211. agno/session/workflow.py +406 -40
  212. agno/team/team.py +3854 -1610
  213. agno/tools/dalle.py +2 -4
  214. agno/tools/decorator.py +4 -2
  215. agno/tools/duckduckgo.py +15 -11
  216. agno/tools/e2b.py +14 -7
  217. agno/tools/eleven_labs.py +23 -25
  218. agno/tools/exa.py +21 -16
  219. agno/tools/file.py +153 -23
  220. agno/tools/file_generation.py +350 -0
  221. agno/tools/firecrawl.py +4 -4
  222. agno/tools/function.py +250 -30
  223. agno/tools/gmail.py +238 -14
  224. agno/tools/google_drive.py +270 -0
  225. agno/tools/googlecalendar.py +36 -8
  226. agno/tools/googlesheets.py +20 -5
  227. agno/tools/jira.py +20 -0
  228. agno/tools/knowledge.py +3 -3
  229. agno/tools/mcp/__init__.py +10 -0
  230. agno/tools/mcp/mcp.py +331 -0
  231. agno/tools/mcp/multi_mcp.py +347 -0
  232. agno/tools/mcp/params.py +24 -0
  233. agno/tools/mcp_toolbox.py +284 -0
  234. agno/tools/mem0.py +11 -17
  235. agno/tools/memori.py +1 -53
  236. agno/tools/memory.py +419 -0
  237. agno/tools/models/nebius.py +5 -5
  238. agno/tools/models_labs.py +20 -10
  239. agno/tools/notion.py +204 -0
  240. agno/tools/parallel.py +314 -0
  241. agno/tools/scrapegraph.py +58 -31
  242. agno/tools/searxng.py +2 -2
  243. agno/tools/serper.py +2 -2
  244. agno/tools/slack.py +18 -3
  245. agno/tools/spider.py +2 -2
  246. agno/tools/tavily.py +146 -0
  247. agno/tools/whatsapp.py +1 -1
  248. agno/tools/workflow.py +278 -0
  249. agno/tools/yfinance.py +12 -11
  250. agno/utils/agent.py +820 -0
  251. agno/utils/audio.py +27 -0
  252. agno/utils/common.py +90 -1
  253. agno/utils/events.py +217 -2
  254. agno/utils/gemini.py +180 -22
  255. agno/utils/hooks.py +57 -0
  256. agno/utils/http.py +111 -0
  257. agno/utils/knowledge.py +12 -5
  258. agno/utils/log.py +1 -0
  259. agno/utils/mcp.py +92 -2
  260. agno/utils/media.py +188 -10
  261. agno/utils/merge_dict.py +22 -1
  262. agno/utils/message.py +60 -0
  263. agno/utils/models/claude.py +40 -11
  264. agno/utils/print_response/agent.py +105 -21
  265. agno/utils/print_response/team.py +103 -38
  266. agno/utils/print_response/workflow.py +251 -34
  267. agno/utils/reasoning.py +22 -1
  268. agno/utils/serialize.py +32 -0
  269. agno/utils/streamlit.py +16 -10
  270. agno/utils/string.py +41 -0
  271. agno/utils/team.py +98 -9
  272. agno/utils/tools.py +1 -1
  273. agno/vectordb/base.py +23 -4
  274. agno/vectordb/cassandra/cassandra.py +65 -9
  275. agno/vectordb/chroma/chromadb.py +182 -38
  276. agno/vectordb/clickhouse/clickhousedb.py +64 -11
  277. agno/vectordb/couchbase/couchbase.py +105 -10
  278. agno/vectordb/lancedb/lance_db.py +124 -133
  279. agno/vectordb/langchaindb/langchaindb.py +25 -7
  280. agno/vectordb/lightrag/lightrag.py +17 -3
  281. agno/vectordb/llamaindex/__init__.py +3 -0
  282. agno/vectordb/llamaindex/llamaindexdb.py +46 -7
  283. agno/vectordb/milvus/milvus.py +126 -9
  284. agno/vectordb/mongodb/__init__.py +7 -1
  285. agno/vectordb/mongodb/mongodb.py +112 -7
  286. agno/vectordb/pgvector/pgvector.py +142 -21
  287. agno/vectordb/pineconedb/pineconedb.py +80 -8
  288. agno/vectordb/qdrant/qdrant.py +125 -39
  289. agno/vectordb/redis/__init__.py +9 -0
  290. agno/vectordb/redis/redisdb.py +694 -0
  291. agno/vectordb/singlestore/singlestore.py +111 -25
  292. agno/vectordb/surrealdb/surrealdb.py +31 -5
  293. agno/vectordb/upstashdb/upstashdb.py +76 -8
  294. agno/vectordb/weaviate/weaviate.py +86 -15
  295. agno/workflow/__init__.py +2 -0
  296. agno/workflow/agent.py +299 -0
  297. agno/workflow/condition.py +112 -18
  298. agno/workflow/loop.py +69 -10
  299. agno/workflow/parallel.py +266 -118
  300. agno/workflow/router.py +110 -17
  301. agno/workflow/step.py +638 -129
  302. agno/workflow/steps.py +65 -6
  303. agno/workflow/types.py +61 -23
  304. agno/workflow/workflow.py +2085 -272
  305. {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/METADATA +182 -58
  306. agno-2.3.0.dist-info/RECORD +577 -0
  307. agno/knowledge/reader/url_reader.py +0 -128
  308. agno/tools/googlesearch.py +0 -98
  309. agno/tools/mcp.py +0 -610
  310. agno/utils/models/aws_claude.py +0 -170
  311. agno-2.0.1.dist-info/RECORD +0 -515
  312. {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
  313. {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/licenses/LICENSE +0 -0
  314. {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1361 @@
1
+ from datetime import date, datetime, timedelta, timezone
2
+ from textwrap import dedent
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ from agno.db.base import BaseDb, SessionType
6
+ from agno.db.postgres.utils import (
7
+ get_dates_to_calculate_metrics_for,
8
+ )
9
+ from agno.db.schemas import UserMemory
10
+ from agno.db.schemas.culture import CulturalKnowledge
11
+ from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
12
+ from agno.db.schemas.knowledge import KnowledgeRow
13
+ from agno.db.surrealdb import utils
14
+ from agno.db.surrealdb.metrics import (
15
+ bulk_upsert_metrics,
16
+ calculate_date_metrics,
17
+ fetch_all_sessions_data,
18
+ get_all_sessions_for_metrics_calculation,
19
+ get_metrics_calculation_starting_date,
20
+ )
21
+ from agno.db.surrealdb.models import (
22
+ TableType,
23
+ deserialize_cultural_knowledge,
24
+ deserialize_eval_run_record,
25
+ deserialize_knowledge_row,
26
+ deserialize_session,
27
+ deserialize_sessions,
28
+ deserialize_user_memories,
29
+ deserialize_user_memory,
30
+ desurrealize_eval_run_record,
31
+ desurrealize_session,
32
+ desurrealize_user_memory,
33
+ get_schema,
34
+ get_session_type,
35
+ serialize_cultural_knowledge,
36
+ serialize_eval_run_record,
37
+ serialize_knowledge_row,
38
+ serialize_session,
39
+ serialize_user_memory,
40
+ )
41
+ from agno.db.surrealdb.queries import COUNT_QUERY, WhereClause, order_limit_start
42
+ from agno.db.surrealdb.utils import build_client
43
+ from agno.session import Session
44
+ from agno.utils.log import log_debug, log_error, log_info
45
+ from agno.utils.string import generate_id
46
+
47
+ try:
48
+ from surrealdb import BlockingHttpSurrealConnection, BlockingWsSurrealConnection, RecordID
49
+ except ImportError:
50
+ raise ImportError("The `surrealdb` package is not installed. Please install it via `pip install surrealdb`.")
51
+
52
+
53
+ class SurrealDb(BaseDb):
54
def __init__(
    self,
    client: Optional[Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]],
    db_url: str,
    db_creds: dict[str, str],
    db_ns: str,
    db_db: str,
    session_table: Optional[str] = None,
    memory_table: Optional[str] = None,
    metrics_table: Optional[str] = None,
    eval_table: Optional[str] = None,
    knowledge_table: Optional[str] = None,
    culture_table: Optional[str] = None,
    id: Optional[str] = None,
):
    """
    Interface for interacting with a SurrealDB database.

    Args:
        client: An already-open blocking connection, either HTTP or WS. When None,
            a connection is built from the URL/credentials the first time the
            `client` property is read.
        db_url: URL of the SurrealDB server.
        db_creds: Credentials handed to `build_client` when a connection is built.
        db_ns: SurrealDB namespace to use.
        db_db: SurrealDB database to use.
        session_table, memory_table, metrics_table, eval_table, knowledge_table,
        culture_table: Optional table-name overrides, forwarded to `BaseDb`.
        id: Identifier for this database instance. When omitted, it is generated
            from `db_url` and `db_db`.
    """
    # Derive a deterministic id from the connection target when none is supplied.
    # NOTE(review): `db_ns` is not part of the seed, so instances that differ only by
    # namespace end up with the same id — confirm this is intended.
    resolved_id = id if id is not None else generate_id(f"{db_url}#{db_db}")

    super().__init__(
        id=resolved_id,
        session_table=session_table,
        memory_table=memory_table,
        metrics_table=metrics_table,
        eval_table=eval_table,
        knowledge_table=knowledge_table,
        culture_table=culture_table,
    )

    # Connection settings; `client` stays None until it is first needed.
    self._client = client
    self._db_url = db_url
    self._db_creds = db_creds
    self._db_ns = db_ns
    self._db_db = db_db

    # Fixed names for the component tables that sessions link to.
    self._users_table_name: str = "agno_users"
    self._agents_table_name: str = "agno_agents"
    self._teams_table_name: str = "agno_teams"
    self._workflows_table_name: str = "agno_workflows"
98
+
99
@property
def client(self) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
    """Return the SurrealDB connection, creating and caching it on first access.

    If no connection was passed to the constructor, one is built from the stored
    URL, credentials, namespace and database via `build_client`, and the same
    connection object is returned on every later access.
    """
    if self._client is not None:
        return self._client
    self._client = build_client(self._db_url, self._db_creds, self._db_ns, self._db_db)
    return self._client
104
+
105
@property
def table_names(self) -> dict[TableType, str]:
    """Map each table type to the concrete table name this instance uses for it.

    Covers agents, culture, evals, knowledge, memories, sessions, teams, users and
    workflows; the metrics table is not part of this mapping.
    """
    # The underscore-prefixed names are fixed in __init__; the others are
    # presumably populated by BaseDb from the constructor's table overrides.
    mapping: dict[TableType, str] = {
        "agents": self._agents_table_name,
        "culture": self.culture_table_name,
        "evals": self.eval_table_name,
        "knowledge": self.knowledge_table_name,
        "memories": self.memory_table_name,
        "sessions": self.session_table_name,
        "teams": self._teams_table_name,
        "users": self._users_table_name,
        "workflows": self._workflows_table_name,
    }
    return mapping
118
+
119
def table_exists(self, table_name: str) -> bool:
    """Report whether `table_name` is defined in the current SurrealDB database.

    Runs `INFO FOR DB` and checks for the name in the `tables` entry of the result.

    Args:
        table_name: Name of the table to look for.

    Returns:
        bool: True if the table is listed, False otherwise (including when the
        result has no `tables` entry).

    Raises:
        Exception: If `INFO FOR DB` yields no result.
    """
    db_info = self._query_one("INFO FOR DB", {}, dict)
    if db_info is not None:
        return table_name in db_info.get("tables", [])
    raise Exception("Failed to retrieve database information")
132
+
133
def _table_exists(self, table_name: str) -> bool:
    """Deprecated alias kept for backward compatibility.

    Use `table_exists()` instead; this only forwards to it.
    """
    exists = self.table_exists(table_name)
    return exists
136
+
137
def _create_table(self, table_type: TableType, table_name: str):
    """Create `table_name` by executing the schema definition for `table_type`.

    The SurrealQL statement comes from `get_schema`; the query result is discarded.
    """
    self.client.query(get_schema(table_type, table_name))
140
+
141
def _get_table(self, table_type: TableType, create_table_if_not_found: bool = True):
    """Resolve the configured table name for `table_type`, optionally creating the table.

    Args:
        table_type: Logical table kind: sessions, memories, knowledge, culture, users,
            agents, teams, workflows, evals or metrics.
        create_table_if_not_found: When True, the table is created via `_create_table`
            if `table_exists` reports that it is missing.

    Returns:
        The table name configured for `table_type`.

    Raises:
        NotImplementedError: If `table_type` is not one of the known table kinds.
    """
    table_names_by_type = {
        "sessions": self.session_table_name,
        "memories": self.memory_table_name,
        "knowledge": self.knowledge_table_name,
        "culture": self.culture_table_name,
        "users": self._users_table_name,
        "agents": self._agents_table_name,
        "teams": self._teams_table_name,
        "workflows": self._workflows_table_name,
        "evals": self.eval_table_name,
        "metrics": self.metrics_table_name,
    }
    # Membership test (rather than `.get(...) is None`) so that a table kind whose
    # configured name is None is still returned as-is, not reported as unknown.
    if table_type not in table_names_by_type:
        raise NotImplementedError(f"Unknown table type: {table_type}")
    table_name = table_names_by_type[table_type]

    # Use the public `table_exists`; `_table_exists` is only a deprecated alias for it.
    if create_table_if_not_found and not self.table_exists(table_name):
        self._create_table(table_type, table_name)

    return table_name
169
+
170
def get_latest_schema_version(self):
    """Get the latest version of the database schema.

    This implementation does not track schema versions; it always returns None.
    """
    return None
173
+
174
def upsert_schema_version(self, version: str) -> None:
    """Upsert the schema version into the database.

    No-op in this implementation: `version` is accepted and ignored.
    """
    return None
177
+
178
def _query(
    self,
    query: str,
    vars: dict[str, Any],
    record_type: type[utils.RecordType],
) -> Sequence[utils.RecordType]:
    """Run `query` with the bound `vars` and return all results typed as `record_type`.

    Thin wrapper around `utils.query` that supplies this instance's client.
    """
    connection = self.client
    return utils.query(connection, query, vars, record_type)
185
+
186
def _query_one(
    self,
    query: str,
    vars: dict[str, Any],
    record_type: type[utils.RecordType],
) -> Optional[utils.RecordType]:
    """Run `query` with the bound `vars` and return a single result, or None.

    Thin wrapper around `utils.query_one` that supplies this instance's client.
    """
    connection = self.client
    return utils.query_one(connection, query, vars, record_type)
193
+
194
def _count(self, table: str, where_clause: str, where_vars: dict[str, Any], group_by: Optional[str] = None) -> int:
    """Count the records in `table` that match `where_clause`.

    Args:
        table: Table to count records in.
        where_clause: Pre-built WHERE clause (possibly empty), interpolated into COUNT_QUERY.
        where_vars: Bound variables referenced by `where_clause`.
        group_by: Optional field to group by; when None the query uses `GROUP ALL`.

    Returns:
        int: The `count` value of the returned record, or 0 when the query yields no
        (or an empty) record.

    Raises:
        TypeError: If the returned `count` value is not an integer.
    """
    if group_by is None:
        group_clause, group_fields = "GROUP ALL", ""
    else:
        group_clause, group_fields = f"GROUP BY {group_by}", f", {group_by}"

    total_count_query = COUNT_QUERY.format(
        table=table,
        where_clause=where_clause,
        group_clause=group_clause,
        group_fields=group_fields,
    )
    # NOTE(review): `_query_one` yields a single record, so with `group_by` set only
    # one group's count is returned — confirm that is what callers expect.
    count_result = self._query_one(total_count_query, where_vars, dict)
    if not count_result:
        return 0

    total_count = count_result.get("count")
    # Explicit check instead of `assert`: assertions are stripped under `python -O`,
    # which would let a malformed result slip through unchecked.
    if not isinstance(total_count, int):
        raise TypeError(f"Expected int, got {type(total_count)}")
    return int(total_count)
206
+
207
+ # --- Sessions ---
208
def clear_sessions(self) -> None:
    """Remove every record from the sessions table.

    The table is resolved through `_get_table` (which creates it if missing), and
    then all of its records are deleted with the client's `delete`.

    Raises:
        Exception: If an error occurs during deletion.
    """
    sessions_table = self._get_table("sessions")
    self.client.delete(sessions_table)
216
+
217
def delete_session(self, session_id: str) -> bool:
    """Delete a single session record by its id.

    Args:
        session_id: ID of the session record to delete.

    Returns:
        bool: The truthiness of the client's delete response, or False when no
        sessions table name is resolved.
    """
    sessions_table = self._get_table(table_type="sessions")
    if sessions_table is None:
        return False
    deleted = self.client.delete(RecordID(sessions_table, session_id))
    return bool(deleted)
223
+
224
def delete_sessions(self, session_ids: list[str]) -> None:
    """Delete every session whose id appears in `session_ids`, using one query.

    Args:
        session_ids: IDs of the session records to delete.
    """
    sessions_table = self._get_table(table_type="sessions")
    if sessions_table is None:
        return

    record_ids = [RecordID(sessions_table, session_id) for session_id in session_ids]
    self.client.query(
        f"DELETE FROM {sessions_table} WHERE id IN $records",
        {"records": record_ids},
    )
231
+
232
def get_session(
    self,
    session_id: str,
    session_type: SessionType,
    user_id: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Optional[Union[Session, Dict[str, Any]]]:
    r"""
    Read a single session from the database.

    Args:
        session_id (str): ID of the session to read.
        session_type (SessionType): Type used to deserialize the session.
        user_id (Optional[str]): When given, the session is only returned if its
            `user_id` matches. Defaults to None.
        deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None if the query returns no record
            - When deserialize=True: Session object
            - When deserialize=False: the record dictionary as returned by the query

    Raises:
        Exception: If an error occurs during retrieval.
    """
    record = RecordID(self._get_table("sessions"), session_id)

    filters = WhereClause()
    if user_id is not None:
        filters = filters.and_("user_id", user_id)
    where_clause, where_vars = filters.build()

    query = dedent(f"""
        SELECT *
        FROM ONLY $record
        {where_clause}
        """)
    raw = self._query_one(query, {"record": record, **where_vars}, dict)

    # NOTE(review): with deserialize=False the record is returned without
    # `desurrealize_session`, whereas get_sessions applies it — confirm intended.
    if raw is not None and deserialize:
        return deserialize_session(session_type, raw)
    return raw
273
+
274
def get_sessions(
    self,
    session_type: Optional[SessionType] = None,
    user_id: Optional[str] = None,
    component_id: Optional[str] = None,
    session_name: Optional[str] = None,
    start_timestamp: Optional[int] = None,
    end_timestamp: Optional[int] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
    r"""
    Get sessions from the sessions table with optional filtering, sorting and pagination.

    Args:
        session_type (Optional[SessionType]): The type of session to get. Required when
            deserialize=True; also selects which link field `component_id` filters on.
        user_id (Optional[str]): The ID of the user to filter by.
        component_id (Optional[str]): The ID of the agent / team / workflow to filter by.
            Ignored unless `session_type` is AGENT, TEAM or WORKFLOW.
        session_name (Optional[str]): Session name to match with SurrealDB's `~` operator.
        start_timestamp (Optional[int]): Keep sessions whose `start_timestamp` is >= this value.
        end_timestamp (Optional[int]): Keep sessions whose `end_timestamp` is <= this value.
        limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
        page (Optional[int]): The page number to return. Defaults to None.
        sort_by (Optional[str]): The field to sort by. Defaults to None.
        sort_order (Optional[str]): The sort order. Defaults to None.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.

    Returns:
        Union[List[Session], Tuple[List[Dict], int]]:
            - When deserialize=True: List of Session objects
            - When deserialize=False: Tuple of (session dictionaries, total count)

    Raises:
        ValueError: If deserialize=True and session_type is None. This is checked
            before any table lookup or query is made.
        Exception: If an error occurs during retrieval.
    """
    # Fail fast: deserialization needs a session type, so don't run the count and
    # select queries (or create the sessions table) when the call cannot succeed.
    if deserialize and session_type is None:
        raise ValueError("session_type is required when deserialize=True")

    table = self._get_table("sessions")

    # -- Filters
    where = WhereClause()

    if user_id is not None:
        where = where.and_("user_id", user_id)

    if component_id is not None:
        # Session records link to their component via a record id in the
        # agent/team/workflow field, depending on the session type.
        component_links = {
            SessionType.AGENT: ("agent", "agents"),
            SessionType.TEAM: ("team", "teams"),
            SessionType.WORKFLOW: ("workflow", "workflows"),
        }
        link = component_links.get(session_type)
        if link is not None:
            field, component_table_type = link
            component_table = self._get_table(component_table_type, False)
            where = where.and_(field, RecordID(component_table, component_id))

    if session_name is not None:
        where = where.and_("session_name", session_name, "~")

    # NOTE(review): these filter on `start_timestamp` / `end_timestamp` fields of the
    # session record — confirm those field names match the stored session schema.
    if start_timestamp is not None:
        where = where.and_("start_timestamp", start_timestamp, ">=")

    if end_timestamp is not None:
        where = where.and_("end_timestamp", end_timestamp, "<=")

    where_clause, where_vars = where.build()

    # Total count (ignores limit/page)
    total_count = self._count(table, where_clause, where_vars)

    # Query
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
        """)
    sessions_raw = self._query(query, where_vars, dict)

    if not deserialize:
        # Convert SurrealDB-specific values only when plain dictionaries are returned.
        return [desurrealize_session(session, session_type) for session in sessions_raw], total_count

    # Already validated at the top; repeated here so type checkers narrow the Optional.
    if session_type is None:
        raise ValueError("session_type is required when deserialize=True")

    return deserialize_sessions(session_type, list(sessions_raw))
369
+
370
def rename_session(
    self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Set a new `session_name` on an existing session record.

    Args:
        session_id (str): The ID of the session to rename.
        session_type (SessionType): Type used to deserialize the updated session.
        session_name (str): The new name for the session.
        deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None if the update returns no record
            - When deserialize=True: Session object
            - When deserialize=False: the updated record dictionary

    Raises:
        Exception: If an error occurs during renaming.
    """
    sessions_table = self._get_table("sessions")
    params = {"record": RecordID(sessions_table, session_id), "name": session_name}

    update_query = dedent("""
        UPDATE ONLY $record
        SET session_name = $name
        """)
    updated = self._query_one(update_query, params, dict)

    if updated is not None and deserialize:
        return deserialize_session(session_type, updated)
    return updated
403
+
404
def upsert_session(
    self, session: Session, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Insert or update a single session record, keyed by `session.session_id`.

    Args:
        session (Session): The session data to upsert.
        deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - None if the upsert returns no record
            - When deserialize=True: Session object
            - When deserialize=False: the stored record dictionary

    Raises:
        Exception: If an error occurs during upsert.
    """
    session_type = get_session_type(session)
    sessions_table = self._get_table("sessions")
    params = {
        "record": RecordID(sessions_table, session.session_id),
        "content": serialize_session(session, self.table_names),
    }
    upserted = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)

    if upserted is not None and deserialize:
        return deserialize_session(session_type, upserted)
    return upserted
436
+
437
def upsert_sessions(
    self, sessions: List[Session], deserialize: Optional[bool] = True
) -> List[Union[Session, Dict[str, Any]]]:
    """
    Bulk insert or update multiple sessions.

    Each session is written with its own `UPSERT ONLY` query. The batch may mix
    session kinds: every returned record is deserialized with the type of the
    session it was written from, not the type of the first session in the list.

    Args:
        sessions (List[Session]): The list of session data to upsert.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.

    Returns:
        List[Union[Session, Dict[str, Any]]]: The upserted sessions in input order.
        Sessions for which the query returned no record are omitted, as are records
        that deserialize to None.

    Raises:
        Exception: If an error occurs during bulk upsert.
    """
    if not sessions:
        return []

    # Resolve every session's type before writing anything, so an error from
    # `get_session_type` surfaces before any record has been upserted.
    session_types = [get_session_type(session) for session in sessions]
    table = self._get_table("sessions")

    upserted: List[Tuple[Any, Dict[str, Any]]] = []
    for session, session_type in zip(sessions, session_types):
        # UPSERT ONLY works on a single record, so each session gets its own query.
        session_raw = self._query_one(
            "UPSERT ONLY $record CONTENT $content",
            {
                "record": RecordID(table, session.session_id),
                "content": serialize_session(session, self.table_names),
            },
            dict,
        )
        if session_raw:
            upserted.append((session_type, session_raw))

    if not deserialize:
        return [session_raw for _, session_raw in upserted]

    results: List[Union[Session, Dict[str, Any]]] = []
    for session_type, session_raw in upserted:
        deserialized = deserialize_session(session_type, session_raw)
        if deserialized is not None:
            results.append(deserialized)
    return results
477
+
478
+ # --- Memory ---
479
def clear_memories(self) -> None:
    """Remove every record from the memories table.

    The table is resolved through `_get_table` (which creates it if missing), and
    then all of its records are deleted with the client's `delete`.

    Raises:
        Exception: If an error occurs during deletion.
    """
    memories_table = self._get_table("memories")
    self.client.delete(memories_table)
487
+
488
+ # -- Cultural Knowledge methods --
489
def clear_cultural_knowledge(self) -> None:
    """Remove every record from the cultural knowledge table.

    The table is resolved through `_get_table` (which creates it if missing), and
    then all of its records are deleted with the client's `delete`.

    Raises:
        Exception: If an error occurs during deletion.
    """
    culture_table = self._get_table("culture")
    self.client.delete(culture_table)
497
+
498
def delete_cultural_knowledge(self, id: str) -> None:
    """Delete a single cultural knowledge record by its id.

    Args:
        id (str): The ID of the cultural knowledge record to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    culture_table = self._get_table("culture")
    target = RecordID(culture_table, id)
    self.client.delete(target)
510
+
511
def get_cultural_knowledge(
    self, id: str, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Look up one cultural knowledge record by its ID.

    Args:
        id (str): Identifier of the record to fetch.
        deserialize (Optional[bool]): When truthy, convert the stored record into a
            CulturalKnowledge object; otherwise return what the client returned. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The record, or None when the
        client's select returns None.

    Raises:
        Exception: Propagated from the client if the select fails.
    """
    record = self.client.select(RecordID(self._get_table("culture"), id))
    if record is None:
        return None
    return deserialize_cultural_knowledge(record) if deserialize else record  # type: ignore
537
+
538
def get_all_cultural_knowledge(
    self,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    name: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
    """Get all cultural knowledge with filtering and pagination.

    Args:
        agent_id (Optional[str]): Filter by agent ID.
        team_id (Optional[str]): Filter by team ID.
        name (Optional[str]): Filter by name (case-insensitive partial match).
        limit (Optional[int]): Maximum number of results to return.
        page (Optional[int]): Page number for pagination.
        sort_by (Optional[str]): Field to sort by.
        sort_order (Optional[str]): Sort order ('asc' or 'desc').
        deserialize (Optional[bool]): Whether to deserialize to CulturalKnowledge objects. Defaults to True.

    Returns:
        Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of CulturalKnowledge objects
            - When deserialize=False: Tuple with list of dictionaries and total count

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("culture")

    # Each filter is (field expression, operator, value). Values are bound as
    # query variables ($a, $b, ...) so caller input is never spliced into the query text.
    filters: List[Tuple[str, str, Any]] = []
    if agent_id is not None:
        filters.append(("agent", "=", RecordID(self._get_table("agents"), agent_id)))
    if team_id is not None:
        filters.append(("team", "=", RecordID(self._get_table("teams"), team_id)))
    if name is not None:
        # Case-insensitive partial match: both sides are lower-cased.
        filters.append(("string::lowercase(name)", "CONTAINS", name.lower()))

    # Build the WHERE text and its variables once; the count query and the page
    # query must use exactly the same condition.
    params: Dict[str, Any] = {}
    conditions: List[str] = []
    for i, (field, op, value) in enumerate(filters):
        var_name = chr(97 + i)
        conditions.append(f"{field} {op} ${var_name}")
        params[var_name] = value
    where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    # Total count (ignores pagination)
    total_count = self._query_one(COUNT_QUERY.format(table=table, where=where_sql), params, int) or 0

    # Page query
    query = f"SELECT * FROM {table}"
    if where_sql:
        query += f" {where_sql}"
    query += order_limit_start(sort_by, sort_order, limit, page)

    results = self._query(query, params, list) or []

    if not deserialize:
        return results, total_count  # type: ignore

    return [deserialize_cultural_knowledge(r) for r in results]  # type: ignore
605
+
606
def upsert_cultural_knowledge(
    self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Insert or replace a cultural knowledge record.

    The record is serialized first; its ``"id"`` entry is used as the upsert target.

    Args:
        cultural_knowledge (CulturalKnowledge): The knowledge to store.
        deserialize (Optional[bool]): When truthy, return a CulturalKnowledge object;
            otherwise return what the client returned. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The stored record, or None when
        the client's upsert returns None.

    Raises:
        Exception: Propagated from the client if the upsert fails.
    """
    culture_table = self._get_table("culture", create_table_if_not_found=True)
    payload = serialize_cultural_knowledge(cultural_knowledge, culture_table)

    stored = self.client.upsert(payload["id"], payload)
    if stored is None:
        return None
    return deserialize_cultural_knowledge(stored) if deserialize else stored  # type: ignore
633
+
634
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
    """Delete a user memory from the database.

    When ``user_id`` is given, the delete only matches a memory that also
    belongs to that user; a memory owned by someone else is left untouched.

    Args:
        memory_id (str): The ID of the memory to delete.
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        None: The outcome of the delete (e.g. whether a row matched) is not reported.

    Raises:
        Exception: If an error occurs during deletion.
    """
    table = self._get_table("memories")
    mem_rec_id = RecordID(table, memory_id)
    if user_id is None:
        self.client.delete(mem_rec_id)
        return

    # Scope the delete to the owning user.
    user_rec_id = RecordID(self._get_table("users"), user_id)
    self.client.query(
        f"DELETE FROM {table} WHERE user = $user AND id = $memory",
        {"user": user_rec_id, "memory": mem_rec_id},
    )
657
+
658
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
    """Delete several user memories in one query.

    Args:
        memory_ids (List[str]): IDs of the memories to delete.
        user_id (Optional[str]): When given, only memories owned by this user are deleted.
            Defaults to None.

    Raises:
        Exception: If the client query fails.
    """
    table = self._get_table("memories")
    variables: Dict[str, Any] = {"records": [RecordID(table, mid) for mid in memory_ids]}
    condition = "id IN $records"
    if user_id is not None:
        variables["user"] = RecordID(self._get_table("users"), user_id)
        condition += " AND user = $user"
    self.client.query(f"DELETE FROM {table} WHERE {condition}", variables)
677
+
678
def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
    """Return the distinct topics found on stored memories.

    Args:
        user_id (Optional[str]): When given, only that user's memories are scanned.
            Defaults to None.

    Returns:
        List[str]: Distinct topic strings (deduplicated by the query's ``.distinct()``).
    """
    table = self._get_table("memories")

    if user_id is not None:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                WHERE user = $user
                GROUP ALL
            ).topics.distinct();
        """)
        variables: dict[str, Any] = {"user": RecordID(self._get_table("users"), user_id)}
    else:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                GROUP ALL
            ).topics.distinct();
        """)
        variables = {}

    topics = self._query(query, variables, str)
    return list(topics)
714
+
715
def get_user_memory(
    self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Fetch a single memory by ID, optionally restricted to one user.

    Args:
        memory_id (str): ID of the memory to fetch.
        deserialize (Optional[bool]): When truthy, return a UserMemory object; otherwise the
            raw record dict. Defaults to True.
        user_id (Optional[str]): When given, the memory is only returned if it belongs to this
            user. Defaults to None.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]: The memory, or None if no record matched.

    Raises:
        Exception: If the query fails.
    """
    record = RecordID(self._get_table("memories"), memory_id)

    if user_id is None:
        row = self._query_one("SELECT * FROM ONLY $record", {"record": record}, dict)
    else:
        row = self._query_one(
            "SELECT * FROM ONLY $record WHERE user = $user",
            {"record": record, "user": RecordID(self._get_table("users"), user_id)},
            dict,
        )

    if row is not None and deserialize:
        return deserialize_user_memory(row)
    return row
747
+
748
def get_user_memories(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    topics: Optional[List[str]] = None,
    search_content: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
    """Fetch a filtered, paginated page of user memories.

    Args:
        user_id (Optional[str]): Only memories linked to this user.
        agent_id (Optional[str]): Only memories linked to this agent.
        team_id (Optional[str]): Only memories linked to this team.
        topics (Optional[List[str]]): Only memories whose topics contain any of these.
        search_content (Optional[str]): Fuzzy-match (``~``) against the memory text.
        limit (Optional[int]): Page size.
        page (Optional[int]): Page number.
        sort_by (Optional[str]): Column to sort by.
        sort_order (Optional[str]): Sort direction.
        deserialize (Optional[bool]): Return UserMemory objects when truthy. Defaults to True.

    Returns:
        Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
            - deserialize=True: list of UserMemory objects
            - deserialize=False: (list of memory dicts, total matching count)

    Raises:
        Exception: If a query fails.
    """
    table = self._get_table("memories")

    where = WhereClause()
    # Link filters: each supplied ID becomes a RecordID in its owning table.
    for field, linked_table, linked_id in (
        ("user", "users", user_id),
        ("agent", "agents", agent_id),
        ("team", "teams", team_id),
    ):
        if linked_id is not None:
            where.and_(field, RecordID(self._get_table(linked_table), linked_id))
    if topics is not None:
        where.and_("topics", topics, "CONTAINSANY")
    if search_content is not None:
        where.and_("memory", search_content, "~")
    where_clause, where_vars = where.build()

    # Count across all pages, using the same filter.
    total_count = self._count(table, where_clause, where_vars)

    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    rows = self._query(
        dedent(f"""
            SELECT *
            FROM {table}
            {where_clause}
            {order_limit_start_clause}
        """),
        where_vars,
        dict,
    )

    if not deserialize:
        return [desurrealize_user_memory(row) for row in rows], total_count
    return deserialize_user_memories(rows)
816
+
817
def get_user_memory_stats(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    user_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """Get user memories stats.

    Memories are grouped per user and ordered by the most recent ``updated_at``
    (descending). When ``user_id`` is None, only memories with a set (truthy)
    user are included.

    Args:
        limit (Optional[int]): The maximum number of user stats to return.
        page (Optional[int]): The page number.
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and
        the total number of users matching the same filter.

    Example:
        (
            [
                {
                    "user_id": "123",
                    "total_memories": 10,
                    "last_memory_updated_at": 1714560000,
                },
            ],
            total_count: 1,
        )
    """
    memories_table_name = self._get_table("memories")
    where = WhereClause()

    if user_id is None:
        where.and_("!!user", True, "=")  # this checks that user is not falsy
    else:
        where.and_("user", RecordID(self._get_table("users"), user_id), "=")

    where_clause, where_vars = where.build()
    order_limit_start_clause = order_limit_start("last_memory_updated_at", "DESC", limit, page)

    # Total count: number of user groups under the SAME filter as the page query,
    # so the count agrees with the rows returned (it used to count every user,
    # including memories with no user, regardless of user_id).
    total_count = (
        self._query_one(
            f"(SELECT user FROM {memories_table_name} {where_clause} GROUP BY user).map(|$x| $x.user).len()",
            where_vars,
            int,
        )
        or 0
    )

    query = dedent(f"""
        SELECT
            user,
            count(id) AS total_memories,
            time::max(updated_at) AS last_memory_updated_at
        FROM {memories_table_name}
        {where_clause}
        GROUP BY user
        {order_limit_start_clause}
    """)
    result = self._query(query, where_vars, dict)

    # Flatten the user RecordID to its bare ID and turn the datetime into Unix seconds.
    for row in result:
        row["user_id"] = row["user"].id
        del row["user"]
        last_updated = row["last_memory_updated_at"]
        # Guard: a group with no updated_at values would otherwise crash on .timestamp().
        row["last_memory_updated_at"] = int(last_updated.timestamp()) if last_updated is not None else None

    return list(result), total_count
884
+
885
def upsert_user_memory(
    self, memory: UserMemory, deserialize: Optional[bool] = True
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Store a user memory, updating it in place when it already has an ID.

    A memory with a ``memory_id`` is UPSERTed at that record; one without is
    CREATEd in the memories table, letting the database assign the ID.

    Args:
        memory (UserMemory): The memory to store.
        deserialize (Optional[bool]): When truthy, return a UserMemory object; otherwise a
            desurrealized dict. Defaults to True.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]: The stored memory, or None if the query
        returned nothing.

    Raises:
        Exception: If the query fails.
    """
    table = self._get_table("memories")
    user_table = self._get_table("users")

    if not memory.memory_id:
        row = self._query_one(
            f"CREATE ONLY {table} CONTENT $content",
            {"content": serialize_user_memory(memory, table, user_table)},
            dict,
        )
    else:
        target = RecordID(table, memory.memory_id)
        row = self._query_one(
            "UPSERT ONLY $record CONTENT $content",
            {"record": target, "content": serialize_user_memory(memory, table, user_table)},
            dict,
        )

    if row is None:
        return None
    return deserialize_user_memory(row) if deserialize else desurrealize_user_memory(row)
918
+
919
def upsert_memories(
    self, memories: List[UserMemory], deserialize: Optional[bool] = True
) -> List[Union[UserMemory, Dict[str, Any]]]:
    """
    Insert or update multiple memories in the database.

    Memories with a ``memory_id`` are UPSERTed at that record; memories without
    one are CREATEd and receive a database-assigned ID. Records for which the
    query returns nothing are omitted from the result.

    Args:
        memories (List[UserMemory]): The list of memories to upsert.
        deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.

    Returns:
        List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories
            (UserMemory objects, or desurrealized dicts when deserialize is falsy).

    Raises:
        Exception: If an error occurs during bulk upsert.
    """
    if not memories:
        return []
    table = self._get_table("memories")
    user_table_name = self._get_table("users")

    memories_raw: List[Dict[str, Any]] = []
    for memory in memories:
        # UPSERT only works for one record at a time, so each memory is its own query.
        if memory.memory_id:
            memory_raw = self._query_one(
                "UPSERT ONLY $record CONTENT $content",
                {
                    "record": RecordID(table, memory.memory_id),
                    "content": serialize_user_memory(memory, table, user_table_name),
                },
                dict,
            )
        else:
            memory_raw = self._query_one(
                f"CREATE ONLY {table} CONTENT $content",
                {"content": serialize_user_memory(memory, table, user_table_name)},
                dict,
            )
        if memory_raw is not None:
            memories_raw.append(memory_raw)

    if not deserialize:
        return [desurrealize_user_memory(x) for x in memories_raw]
    # list() widens List[UserMemory] to the declared List[Union[...]] return type.
    return list(deserialize_user_memories(memories_raw))
965
+
966
+ # --- Metrics ---
967
def get_metrics(
    self,
    starting_date: Optional[date] = None,
    ending_date: Optional[date] = None,
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    """Get all metrics matching the given date range.

    Dates are compared as UTC midnight datetimes, and rows are ordered by date
    ascending. In each returned row, a RecordID ``id`` is replaced by its bare ID,
    and ``created_at`` / ``updated_at`` / ``date`` datetimes become Unix seconds.

    Args:
        starting_date (Optional[date]): Inclusive lower bound (compared at 00:00 UTC).
        ending_date (Optional[date]): Inclusive upper bound (compared at 00:00 UTC).

    Returns:
        Tuple[List[dict], Optional[int]]: The metrics rows and the latest ``updated_at``
        (Unix seconds) among them, or ``([], None)`` when nothing matched.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("metrics")

    # WhereClause.and_ mutates the clause in place (as every sibling method relies on),
    # so its return value is deliberately not re-assigned to `where`.
    where = WhereClause()
    if starting_date is not None:
        starting_datetime = datetime.combine(starting_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        where.and_("date", starting_datetime, ">=")
    if ending_date is not None:
        ending_datetime = datetime.combine(ending_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        where.and_("date", ending_datetime, "<=")
    where_clause, where_vars = where.build()

    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        ORDER BY date ASC
    """)
    results = self._query(query, where_vars, dict)
    if not results:
        return [], None

    latest_update = max(int(r["updated_at"].timestamp()) for r in results)

    transformed_results = []
    for r in results:
        transformed = dict(r)

        # RecordID (or anything exposing .id) -> its bare ID.
        record_id = transformed.get("id")
        if hasattr(record_id, "id"):
            transformed["id"] = record_id.id

        # Datetimes -> Unix timestamps.
        for key in ("created_at", "updated_at", "date"):
            value = transformed.get(key)
            if isinstance(value, datetime):
                transformed[key] = int(value.timestamp())

        transformed_results.append(transformed)

    return transformed_results, latest_update
1040
+
1041
def calculate_metrics(self) -> Optional[List[Dict[str, Any]]]:
    """Calculate and store daily metrics for dates that do not have them yet.

    Returns:
        Optional[List[Dict[str, Any]]]: Whatever ``bulk_upsert_metrics`` returned;
        ``[]`` when none of the pending dates had sessions; ``None`` when there was
        no session data or nothing left to calculate.

    Raises:
        Exception: Any error is logged and then re-raised unchanged.
    """
    try:
        table = self._get_table("metrics")

        starting_date = get_metrics_calculation_starting_date(self.client, table, self.get_sessions)
        if starting_date is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        dates_to_process = get_dates_to_calculate_metrics_for(starting_date)
        if not dates_to_process:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        # Session window: 00:00 UTC on the first date up to 00:00 UTC after the last date.
        start_timestamp = datetime.combine(dates_to_process[0], datetime.min.time()).replace(tzinfo=timezone.utc)
        end_timestamp = datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time()).replace(
            tzinfo=timezone.utc
        )

        sessions = get_all_sessions_for_metrics_calculation(
            self.client, self._get_table("sessions"), start_timestamp, end_timestamp
        )

        all_sessions_data = fetch_all_sessions_data(
            sessions=sessions,
            dates_to_process=dates_to_process,
            start_timestamp=int(start_timestamp.timestamp()),  # expects Unix seconds
        )
        if not all_sessions_data:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        metrics_records = []
        for date_to_process in dates_to_process:
            sessions_for_date = all_sessions_data.get(date_to_process.isoformat(), {})
            # Skip dates where every session bucket is empty.
            if not any(len(bucket) > 0 for bucket in sessions_for_date.values()):
                continue
            metrics_records.append(calculate_date_metrics(date_to_process, sessions_for_date))

        results = bulk_upsert_metrics(self.client, table, metrics_records) if metrics_records else []

        log_debug("Updated metrics calculations")
        return results

    except Exception as e:
        log_error(f"Exception refreshing metrics: {e}")
        raise
1105
+
1106
+ # --- Knowledge ---
1107
def clear_knowledge(self) -> None:
    """Remove every record from the knowledge table.

    Raises:
        Exception: Propagated from the client if the delete fails.
    """
    knowledge_table = self._get_table("knowledge")
    self.client.delete(knowledge_table)
1115
+
1116
def delete_knowledge_content(self, id: str):
    """Remove one knowledge row.

    Args:
        id (str): ID of the knowledge row to remove.
    """
    target = RecordID(self._get_table("knowledge"), id)
    self.client.delete(target)
1124
+
1125
def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Fetch one knowledge row by ID.

    Args:
        id (str): ID of the knowledge row.

    Returns:
        Optional[KnowledgeRow]: The row, or None when the query returned nothing (or an empty record).
    """
    target = RecordID(self._get_table("knowledge"), id)
    row = self._query_one("SELECT * FROM ONLY $record_id", {"record_id": target}, dict)
    if not row:
        return None
    return deserialize_knowledge_row(row)
1138
+
1139
def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """Fetch a sorted, paginated page of knowledge rows.

    No filters are applied; the total count covers the whole table.

    Args:
        limit (Optional[int]): Page size.
        page (Optional[int]): Page number.
        sort_by (Optional[str]): Column to sort by.
        sort_order (Optional[str]): Sort direction.

    Returns:
        Tuple[List[KnowledgeRow], int]: The page of rows and the total count.

    Raises:
        Exception: If a query fails.
    """
    table = self._get_table("knowledge")
    where_clause, where_vars = WhereClause().build()

    total_count = self._count(table, where_clause, where_vars)

    paging = order_limit_start(sort_by, sort_order, limit, page)
    rows = self._query(
        dedent(f"""
            SELECT *
            FROM {table}
            {where_clause}
            {paging}
        """),
        where_vars,
        dict,
    )
    knowledge_rows = [deserialize_knowledge_row(r) for r in rows]
    return knowledge_rows, total_count
1177
+
1178
def upsert_knowledge_content(self, knowledge_row: KnowledgeRow) -> Optional[KnowledgeRow]:
    """Insert or replace a knowledge row at the record derived from its ``id``.

    Args:
        knowledge_row (KnowledgeRow): The row to store.

    Returns:
        Optional[KnowledgeRow]: The stored row, or None when the query returned nothing.
    """
    table = self._get_table("knowledge")
    params = {
        "record": RecordID(table, knowledge_row.id),
        "content": serialize_knowledge_row(knowledge_row, table),
    }
    stored = self._query_one("UPSERT ONLY $record CONTENT $content", params, dict)
    if not stored:
        return None
    return deserialize_knowledge_row(stored)
1194
+
1195
+ # --- Evals ---
1196
def clear_evals(self) -> None:
    """Remove every record from the evals table.

    Raises:
        Exception: Propagated from the client if the delete fails.
    """
    evals_table = self._get_table("evals")
    self.client.delete(evals_table)
1204
+
1205
def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Create a new eval run record keyed by ``eval_run.run_id``.

    Args:
        eval_run (EvalRunRecord): The eval run to persist.

    Returns:
        Optional[EvalRunRecord]: The created record, or None when the query returned nothing.

    Raises:
        Exception: If the CREATE query fails.
    """
    params = {
        "record": RecordID(self._get_table("evals"), eval_run.run_id),
        "content": serialize_eval_run_record(eval_run, self.table_names),
    }
    created = self._query_one("CREATE ONLY $record CONTENT $content", params, dict)
    if not created:
        return None
    return deserialize_eval_run_record(created)
1224
+
1225
def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete the eval runs with the given IDs in a single query.

    Args:
        eval_run_ids (List[str]): IDs of the eval runs to delete.
    """
    table = self._get_table("evals")
    targets = [RecordID(table, run_id) for run_id in eval_run_ids]
    self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": targets})
1234
+
1235
def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Fetch a single eval run by ID.

    Args:
        eval_run_id (str): ID of the eval run.
        deserialize (Optional[bool]): When truthy, return an EvalRunRecord; otherwise a
            desurrealized dict. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]: The eval run, or None if not found.
        An empty record is always returned desurrealized.

    Raises:
        Exception: If the query fails.
    """
    target = RecordID(self._get_table("evals"), eval_run_id)
    row = self._query_one("SELECT * FROM ONLY $record", {"record": target}, dict)

    if row is None:
        return None
    if row and deserialize:
        return deserialize_eval_run_record(row)
    return desurrealize_eval_run_record(row)
1258
+
1259
def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Get all eval runs from the database.

    ID filters (``agent_id``, ``team_id``, ``workflow_id``, ``model_id``) apply
    whenever they are given. ``filter_type`` additionally restricts results to
    eval runs linked to that kind of component (agent, team or workflow).

    Args:
        limit (Optional[int]): The maximum number of eval runs to return.
        page (Optional[int]): The page number to return.
        sort_by (Optional[str]): The field to sort by.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        workflow_id (Optional[str]): The ID of the workflow to filter by.
        model_id (Optional[str]): The ID of the model to filter by.
        filter_type (Optional[EvalFilterType]): Restrict to agent, team or workflow eval runs.
        eval_type (Optional[List[EvalType]]): Accepted eval types; a run matches if its
            type is any of these. An empty list applies no filter.
        deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of EvalRunRecord objects
            - When deserialize=False: List of eval run dictionaries and the total count

    Raises:
        Exception: If there is an error getting the eval runs.
    """
    table = self._get_table("evals")

    where = WhereClause()
    # Explicit ID filters apply regardless of filter_type (previously they were
    # ignored unless the matching filter_type was also passed).
    if agent_id is not None:
        where.and_("agent", RecordID(self._get_table("agents"), agent_id))
    if team_id is not None:
        where.and_("team", RecordID(self._get_table("teams"), team_id))
    if workflow_id is not None:
        where.and_("workflow", RecordID(self._get_table("workflows"), workflow_id))
    if model_id is not None:
        where.and_("model_id", model_id)
    if eval_type:
        # eval_type is a list of accepted types: use membership, not equality.
        # Enum members are reduced to their values before binding.
        where.and_("eval_type", [getattr(t, "value", t) for t in eval_type], "IN")
    # filter_type: only runs whose link field is set ("!!field" is truthy only when present),
    # instead of comparing against RecordID(table, None) when no ID was given.
    if filter_type == EvalFilterType.AGENT:
        where.and_("!!agent", True, "=")
    elif filter_type == EvalFilterType.TEAM:
        where.and_("!!team", True, "=")
    elif filter_type == EvalFilterType.WORKFLOW:
        where.and_("!!workflow", True, "=")
    where_clause, where_vars = where.build()

    # Order
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)

    # Total count
    total_count = self._count(table, where_clause, where_vars)

    # Query
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    result = self._query(query, where_vars, dict)

    if not deserialize:
        return list(result), total_count
    return [deserialize_eval_run_record(x) for x in result]
1330
+
1331
def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Update the name of an eval run in the database.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: desurrealized EvalRun dictionary (same shape as
              ``get_eval_run(..., deserialize=False)``)
            - None when the update returned nothing

    Raises:
        Exception: If there is an error updating the eval run.
    """
    table = self._get_table("evals")
    vars = {"record": RecordID(table, eval_run_id), "name": name}

    # Query
    query = dedent("""
        UPDATE ONLY $record
        SET name = $name
    """)
    raw = self._query_one(query, vars, dict)

    if not raw:
        return None
    if not deserialize:
        # Convert SurrealDB-specific values (e.g. RecordIDs) like get_eval_run does,
        # instead of leaking the raw record to the caller.
        return desurrealize_eval_run_record(raw)
    return deserialize_eval_run_record(raw)