agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314) hide show
  1. agno/agent/agent.py +5540 -2273
  2. agno/api/api.py +2 -0
  3. agno/api/os.py +1 -1
  4. agno/compression/__init__.py +3 -0
  5. agno/compression/manager.py +247 -0
  6. agno/culture/__init__.py +3 -0
  7. agno/culture/manager.py +956 -0
  8. agno/db/async_postgres/__init__.py +3 -0
  9. agno/db/base.py +689 -6
  10. agno/db/dynamo/dynamo.py +933 -37
  11. agno/db/dynamo/schemas.py +174 -10
  12. agno/db/dynamo/utils.py +63 -4
  13. agno/db/firestore/firestore.py +831 -9
  14. agno/db/firestore/schemas.py +51 -0
  15. agno/db/firestore/utils.py +102 -4
  16. agno/db/gcs_json/gcs_json_db.py +660 -12
  17. agno/db/gcs_json/utils.py +60 -26
  18. agno/db/in_memory/in_memory_db.py +287 -14
  19. agno/db/in_memory/utils.py +60 -2
  20. agno/db/json/json_db.py +590 -14
  21. agno/db/json/utils.py +60 -26
  22. agno/db/migrations/manager.py +199 -0
  23. agno/db/migrations/v1_to_v2.py +43 -13
  24. agno/db/migrations/versions/__init__.py +0 -0
  25. agno/db/migrations/versions/v2_3_0.py +938 -0
  26. agno/db/mongo/__init__.py +15 -1
  27. agno/db/mongo/async_mongo.py +2760 -0
  28. agno/db/mongo/mongo.py +879 -11
  29. agno/db/mongo/schemas.py +42 -0
  30. agno/db/mongo/utils.py +80 -8
  31. agno/db/mysql/__init__.py +2 -1
  32. agno/db/mysql/async_mysql.py +2912 -0
  33. agno/db/mysql/mysql.py +946 -68
  34. agno/db/mysql/schemas.py +72 -10
  35. agno/db/mysql/utils.py +198 -7
  36. agno/db/postgres/__init__.py +2 -1
  37. agno/db/postgres/async_postgres.py +2579 -0
  38. agno/db/postgres/postgres.py +942 -57
  39. agno/db/postgres/schemas.py +81 -18
  40. agno/db/postgres/utils.py +164 -2
  41. agno/db/redis/redis.py +671 -7
  42. agno/db/redis/schemas.py +50 -0
  43. agno/db/redis/utils.py +65 -7
  44. agno/db/schemas/__init__.py +2 -1
  45. agno/db/schemas/culture.py +120 -0
  46. agno/db/schemas/evals.py +1 -0
  47. agno/db/schemas/memory.py +17 -2
  48. agno/db/singlestore/schemas.py +63 -0
  49. agno/db/singlestore/singlestore.py +949 -83
  50. agno/db/singlestore/utils.py +60 -2
  51. agno/db/sqlite/__init__.py +2 -1
  52. agno/db/sqlite/async_sqlite.py +2911 -0
  53. agno/db/sqlite/schemas.py +62 -0
  54. agno/db/sqlite/sqlite.py +965 -46
  55. agno/db/sqlite/utils.py +169 -8
  56. agno/db/surrealdb/__init__.py +3 -0
  57. agno/db/surrealdb/metrics.py +292 -0
  58. agno/db/surrealdb/models.py +334 -0
  59. agno/db/surrealdb/queries.py +71 -0
  60. agno/db/surrealdb/surrealdb.py +1908 -0
  61. agno/db/surrealdb/utils.py +147 -0
  62. agno/db/utils.py +2 -0
  63. agno/eval/__init__.py +10 -0
  64. agno/eval/accuracy.py +75 -55
  65. agno/eval/agent_as_judge.py +861 -0
  66. agno/eval/base.py +29 -0
  67. agno/eval/performance.py +16 -7
  68. agno/eval/reliability.py +28 -16
  69. agno/eval/utils.py +35 -17
  70. agno/exceptions.py +27 -2
  71. agno/filters.py +354 -0
  72. agno/guardrails/prompt_injection.py +1 -0
  73. agno/hooks/__init__.py +3 -0
  74. agno/hooks/decorator.py +164 -0
  75. agno/integrations/discord/client.py +1 -1
  76. agno/knowledge/chunking/agentic.py +13 -10
  77. agno/knowledge/chunking/fixed.py +4 -1
  78. agno/knowledge/chunking/semantic.py +9 -4
  79. agno/knowledge/chunking/strategy.py +59 -15
  80. agno/knowledge/embedder/fastembed.py +1 -1
  81. agno/knowledge/embedder/nebius.py +1 -1
  82. agno/knowledge/embedder/ollama.py +8 -0
  83. agno/knowledge/embedder/openai.py +8 -8
  84. agno/knowledge/embedder/sentence_transformer.py +6 -2
  85. agno/knowledge/embedder/vllm.py +262 -0
  86. agno/knowledge/knowledge.py +1618 -318
  87. agno/knowledge/reader/base.py +6 -2
  88. agno/knowledge/reader/csv_reader.py +8 -10
  89. agno/knowledge/reader/docx_reader.py +5 -6
  90. agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
  91. agno/knowledge/reader/json_reader.py +5 -4
  92. agno/knowledge/reader/markdown_reader.py +8 -8
  93. agno/knowledge/reader/pdf_reader.py +17 -19
  94. agno/knowledge/reader/pptx_reader.py +101 -0
  95. agno/knowledge/reader/reader_factory.py +32 -3
  96. agno/knowledge/reader/s3_reader.py +3 -3
  97. agno/knowledge/reader/tavily_reader.py +193 -0
  98. agno/knowledge/reader/text_reader.py +22 -10
  99. agno/knowledge/reader/web_search_reader.py +1 -48
  100. agno/knowledge/reader/website_reader.py +10 -10
  101. agno/knowledge/reader/wikipedia_reader.py +33 -1
  102. agno/knowledge/types.py +1 -0
  103. agno/knowledge/utils.py +72 -7
  104. agno/media.py +22 -6
  105. agno/memory/__init__.py +14 -1
  106. agno/memory/manager.py +544 -83
  107. agno/memory/strategies/__init__.py +15 -0
  108. agno/memory/strategies/base.py +66 -0
  109. agno/memory/strategies/summarize.py +196 -0
  110. agno/memory/strategies/types.py +37 -0
  111. agno/models/aimlapi/aimlapi.py +17 -0
  112. agno/models/anthropic/claude.py +515 -40
  113. agno/models/aws/bedrock.py +102 -21
  114. agno/models/aws/claude.py +131 -274
  115. agno/models/azure/ai_foundry.py +41 -19
  116. agno/models/azure/openai_chat.py +39 -8
  117. agno/models/base.py +1249 -525
  118. agno/models/cerebras/cerebras.py +91 -21
  119. agno/models/cerebras/cerebras_openai.py +21 -2
  120. agno/models/cohere/chat.py +40 -6
  121. agno/models/cometapi/cometapi.py +18 -1
  122. agno/models/dashscope/dashscope.py +2 -3
  123. agno/models/deepinfra/deepinfra.py +18 -1
  124. agno/models/deepseek/deepseek.py +69 -3
  125. agno/models/fireworks/fireworks.py +18 -1
  126. agno/models/google/gemini.py +877 -80
  127. agno/models/google/utils.py +22 -0
  128. agno/models/groq/groq.py +51 -18
  129. agno/models/huggingface/huggingface.py +17 -6
  130. agno/models/ibm/watsonx.py +16 -6
  131. agno/models/internlm/internlm.py +18 -1
  132. agno/models/langdb/langdb.py +13 -1
  133. agno/models/litellm/chat.py +44 -9
  134. agno/models/litellm/litellm_openai.py +18 -1
  135. agno/models/message.py +28 -5
  136. agno/models/meta/llama.py +47 -14
  137. agno/models/meta/llama_openai.py +22 -17
  138. agno/models/mistral/mistral.py +8 -4
  139. agno/models/nebius/nebius.py +6 -7
  140. agno/models/nvidia/nvidia.py +20 -3
  141. agno/models/ollama/chat.py +24 -8
  142. agno/models/openai/chat.py +104 -29
  143. agno/models/openai/responses.py +101 -81
  144. agno/models/openrouter/openrouter.py +60 -3
  145. agno/models/perplexity/perplexity.py +17 -1
  146. agno/models/portkey/portkey.py +7 -6
  147. agno/models/requesty/requesty.py +24 -4
  148. agno/models/response.py +73 -2
  149. agno/models/sambanova/sambanova.py +20 -3
  150. agno/models/siliconflow/siliconflow.py +19 -2
  151. agno/models/together/together.py +20 -3
  152. agno/models/utils.py +254 -8
  153. agno/models/vercel/v0.py +20 -3
  154. agno/models/vertexai/__init__.py +0 -0
  155. agno/models/vertexai/claude.py +190 -0
  156. agno/models/vllm/vllm.py +19 -14
  157. agno/models/xai/xai.py +19 -2
  158. agno/os/app.py +549 -152
  159. agno/os/auth.py +190 -3
  160. agno/os/config.py +23 -0
  161. agno/os/interfaces/a2a/router.py +8 -11
  162. agno/os/interfaces/a2a/utils.py +1 -1
  163. agno/os/interfaces/agui/router.py +18 -3
  164. agno/os/interfaces/agui/utils.py +152 -39
  165. agno/os/interfaces/slack/router.py +55 -37
  166. agno/os/interfaces/slack/slack.py +9 -1
  167. agno/os/interfaces/whatsapp/router.py +0 -1
  168. agno/os/interfaces/whatsapp/security.py +3 -1
  169. agno/os/mcp.py +110 -52
  170. agno/os/middleware/__init__.py +2 -0
  171. agno/os/middleware/jwt.py +676 -112
  172. agno/os/router.py +40 -1478
  173. agno/os/routers/agents/__init__.py +3 -0
  174. agno/os/routers/agents/router.py +599 -0
  175. agno/os/routers/agents/schema.py +261 -0
  176. agno/os/routers/evals/evals.py +96 -39
  177. agno/os/routers/evals/schemas.py +65 -33
  178. agno/os/routers/evals/utils.py +80 -10
  179. agno/os/routers/health.py +10 -4
  180. agno/os/routers/knowledge/knowledge.py +196 -38
  181. agno/os/routers/knowledge/schemas.py +82 -22
  182. agno/os/routers/memory/memory.py +279 -52
  183. agno/os/routers/memory/schemas.py +46 -17
  184. agno/os/routers/metrics/metrics.py +20 -8
  185. agno/os/routers/metrics/schemas.py +16 -16
  186. agno/os/routers/session/session.py +462 -34
  187. agno/os/routers/teams/__init__.py +3 -0
  188. agno/os/routers/teams/router.py +512 -0
  189. agno/os/routers/teams/schema.py +257 -0
  190. agno/os/routers/traces/__init__.py +3 -0
  191. agno/os/routers/traces/schemas.py +414 -0
  192. agno/os/routers/traces/traces.py +499 -0
  193. agno/os/routers/workflows/__init__.py +3 -0
  194. agno/os/routers/workflows/router.py +624 -0
  195. agno/os/routers/workflows/schema.py +75 -0
  196. agno/os/schema.py +256 -693
  197. agno/os/scopes.py +469 -0
  198. agno/os/utils.py +514 -36
  199. agno/reasoning/anthropic.py +80 -0
  200. agno/reasoning/gemini.py +73 -0
  201. agno/reasoning/openai.py +5 -0
  202. agno/reasoning/vertexai.py +76 -0
  203. agno/run/__init__.py +6 -0
  204. agno/run/agent.py +155 -32
  205. agno/run/base.py +55 -3
  206. agno/run/requirement.py +181 -0
  207. agno/run/team.py +125 -38
  208. agno/run/workflow.py +72 -18
  209. agno/session/agent.py +102 -89
  210. agno/session/summary.py +56 -15
  211. agno/session/team.py +164 -90
  212. agno/session/workflow.py +405 -40
  213. agno/table.py +10 -0
  214. agno/team/team.py +3974 -1903
  215. agno/tools/dalle.py +2 -4
  216. agno/tools/eleven_labs.py +23 -25
  217. agno/tools/exa.py +21 -16
  218. agno/tools/file.py +153 -23
  219. agno/tools/file_generation.py +16 -10
  220. agno/tools/firecrawl.py +15 -7
  221. agno/tools/function.py +193 -38
  222. agno/tools/gmail.py +238 -14
  223. agno/tools/google_drive.py +271 -0
  224. agno/tools/googlecalendar.py +36 -8
  225. agno/tools/googlesheets.py +20 -5
  226. agno/tools/jira.py +20 -0
  227. agno/tools/mcp/__init__.py +10 -0
  228. agno/tools/mcp/mcp.py +331 -0
  229. agno/tools/mcp/multi_mcp.py +347 -0
  230. agno/tools/mcp/params.py +24 -0
  231. agno/tools/mcp_toolbox.py +3 -3
  232. agno/tools/models/nebius.py +5 -5
  233. agno/tools/models_labs.py +20 -10
  234. agno/tools/nano_banana.py +151 -0
  235. agno/tools/notion.py +204 -0
  236. agno/tools/parallel.py +314 -0
  237. agno/tools/postgres.py +76 -36
  238. agno/tools/redshift.py +406 -0
  239. agno/tools/scrapegraph.py +1 -1
  240. agno/tools/shopify.py +1519 -0
  241. agno/tools/slack.py +18 -3
  242. agno/tools/spotify.py +919 -0
  243. agno/tools/tavily.py +146 -0
  244. agno/tools/toolkit.py +25 -0
  245. agno/tools/workflow.py +8 -1
  246. agno/tools/yfinance.py +12 -11
  247. agno/tracing/__init__.py +12 -0
  248. agno/tracing/exporter.py +157 -0
  249. agno/tracing/schemas.py +276 -0
  250. agno/tracing/setup.py +111 -0
  251. agno/utils/agent.py +938 -0
  252. agno/utils/cryptography.py +22 -0
  253. agno/utils/dttm.py +33 -0
  254. agno/utils/events.py +151 -3
  255. agno/utils/gemini.py +15 -5
  256. agno/utils/hooks.py +118 -4
  257. agno/utils/http.py +113 -2
  258. agno/utils/knowledge.py +12 -5
  259. agno/utils/log.py +1 -0
  260. agno/utils/mcp.py +92 -2
  261. agno/utils/media.py +187 -1
  262. agno/utils/merge_dict.py +3 -3
  263. agno/utils/message.py +60 -0
  264. agno/utils/models/ai_foundry.py +9 -2
  265. agno/utils/models/claude.py +49 -14
  266. agno/utils/models/cohere.py +9 -2
  267. agno/utils/models/llama.py +9 -2
  268. agno/utils/models/mistral.py +4 -2
  269. agno/utils/print_response/agent.py +109 -16
  270. agno/utils/print_response/team.py +223 -30
  271. agno/utils/print_response/workflow.py +251 -34
  272. agno/utils/streamlit.py +1 -1
  273. agno/utils/team.py +98 -9
  274. agno/utils/tokens.py +657 -0
  275. agno/vectordb/base.py +39 -7
  276. agno/vectordb/cassandra/cassandra.py +21 -5
  277. agno/vectordb/chroma/chromadb.py +43 -12
  278. agno/vectordb/clickhouse/clickhousedb.py +21 -5
  279. agno/vectordb/couchbase/couchbase.py +29 -5
  280. agno/vectordb/lancedb/lance_db.py +92 -181
  281. agno/vectordb/langchaindb/langchaindb.py +24 -4
  282. agno/vectordb/lightrag/lightrag.py +17 -3
  283. agno/vectordb/llamaindex/llamaindexdb.py +25 -5
  284. agno/vectordb/milvus/milvus.py +50 -37
  285. agno/vectordb/mongodb/__init__.py +7 -1
  286. agno/vectordb/mongodb/mongodb.py +36 -30
  287. agno/vectordb/pgvector/pgvector.py +201 -77
  288. agno/vectordb/pineconedb/pineconedb.py +41 -23
  289. agno/vectordb/qdrant/qdrant.py +67 -54
  290. agno/vectordb/redis/__init__.py +9 -0
  291. agno/vectordb/redis/redisdb.py +682 -0
  292. agno/vectordb/singlestore/singlestore.py +50 -29
  293. agno/vectordb/surrealdb/surrealdb.py +31 -41
  294. agno/vectordb/upstashdb/upstashdb.py +34 -6
  295. agno/vectordb/weaviate/weaviate.py +53 -14
  296. agno/workflow/__init__.py +2 -0
  297. agno/workflow/agent.py +299 -0
  298. agno/workflow/condition.py +120 -18
  299. agno/workflow/loop.py +77 -10
  300. agno/workflow/parallel.py +231 -143
  301. agno/workflow/router.py +118 -17
  302. agno/workflow/step.py +609 -170
  303. agno/workflow/steps.py +73 -6
  304. agno/workflow/types.py +96 -21
  305. agno/workflow/workflow.py +2039 -262
  306. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
  307. agno-2.3.13.dist-info/RECORD +613 -0
  308. agno/tools/googlesearch.py +0 -98
  309. agno/tools/mcp.py +0 -679
  310. agno/tools/memori.py +0 -339
  311. agno-2.1.2.dist-info/RECORD +0 -543
  312. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
  313. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
  314. {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
agno/db/sqlite/sqlite.py CHANGED
@@ -1,10 +1,15 @@
1
1
  import time
2
2
  from datetime import date, datetime, timedelta, timezone
3
3
  from pathlib import Path
4
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union, cast
5
5
  from uuid import uuid4
6
6
 
7
+ if TYPE_CHECKING:
8
+ from agno.tracing.schemas import Span, Trace
9
+
7
10
  from agno.db.base import BaseDb, SessionType
11
+ from agno.db.migrations.manager import MigrationManager
12
+ from agno.db.schemas.culture import CulturalKnowledge
8
13
  from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
9
14
  from agno.db.schemas.knowledge import KnowledgeRow
10
15
  from agno.db.schemas.memory import UserMemory
@@ -13,10 +18,12 @@ from agno.db.sqlite.utils import (
13
18
  apply_sorting,
14
19
  bulk_upsert_metrics,
15
20
  calculate_date_metrics,
21
+ deserialize_cultural_knowledge_from_db,
16
22
  fetch_all_sessions_data,
17
23
  get_dates_to_calculate_metrics_for,
18
24
  is_table_available,
19
25
  is_valid_table,
26
+ serialize_cultural_knowledge_for_db,
20
27
  )
21
28
  from agno.db.utils import deserialize_session_json_fields, serialize_session_json_fields
22
29
  from agno.session import AgentSession, Session, TeamSession, WorkflowSession
@@ -24,11 +31,11 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
24
31
  from agno.utils.string import generate_id
25
32
 
26
33
  try:
27
- from sqlalchemy import Column, MetaData, Table, and_, func, select, text
34
+ from sqlalchemy import Column, MetaData, String, Table, func, select, text
28
35
  from sqlalchemy.dialects import sqlite
29
36
  from sqlalchemy.engine import Engine, create_engine
30
37
  from sqlalchemy.orm import scoped_session, sessionmaker
31
- from sqlalchemy.schema import Index, UniqueConstraint
38
+ from sqlalchemy.schema import ForeignKey, Index, UniqueConstraint
32
39
  except ImportError:
33
40
  raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
34
41
 
@@ -36,14 +43,18 @@ except ImportError:
36
43
  class SqliteDb(BaseDb):
37
44
  def __init__(
38
45
  self,
46
+ db_file: Optional[str] = None,
39
47
  db_engine: Optional[Engine] = None,
40
48
  db_url: Optional[str] = None,
41
- db_file: Optional[str] = None,
42
49
  session_table: Optional[str] = None,
50
+ culture_table: Optional[str] = None,
43
51
  memory_table: Optional[str] = None,
44
52
  metrics_table: Optional[str] = None,
45
53
  eval_table: Optional[str] = None,
46
54
  knowledge_table: Optional[str] = None,
55
+ traces_table: Optional[str] = None,
56
+ spans_table: Optional[str] = None,
57
+ versions_table: Optional[str] = None,
47
58
  id: Optional[str] = None,
48
59
  ):
49
60
  """
@@ -56,14 +67,18 @@ class SqliteDb(BaseDb):
56
67
  4. Create a new database in the current directory
57
68
 
58
69
  Args:
70
+ db_file (Optional[str]): The database file to connect to.
59
71
  db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
60
72
  db_url (Optional[str]): The database URL to connect to.
61
- db_file (Optional[str]): The database file to connect to.
62
73
  session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
74
+ culture_table (Optional[str]): Name of the table to store cultural notions.
63
75
  memory_table (Optional[str]): Name of the table to store user memories.
64
76
  metrics_table (Optional[str]): Name of the table to store metrics.
65
77
  eval_table (Optional[str]): Name of the table to store evaluation runs data.
66
78
  knowledge_table (Optional[str]): Name of the table to store knowledge documents data.
79
+ traces_table (Optional[str]): Name of the table to store run traces.
80
+ spans_table (Optional[str]): Name of the table to store span events.
81
+ versions_table (Optional[str]): Name of the table to store schema versions.
67
82
  id (Optional[str]): ID of the database.
68
83
 
69
84
  Raises:
@@ -76,10 +91,14 @@ class SqliteDb(BaseDb):
76
91
  super().__init__(
77
92
  id=id,
78
93
  session_table=session_table,
94
+ culture_table=culture_table,
79
95
  memory_table=memory_table,
80
96
  metrics_table=metrics_table,
81
97
  eval_table=eval_table,
82
98
  knowledge_table=knowledge_table,
99
+ traces_table=traces_table,
100
+ spans_table=spans_table,
101
+ versions_table=versions_table,
83
102
  )
84
103
 
85
104
  _engine: Optional[Engine] = db_engine
@@ -107,6 +126,31 @@ class SqliteDb(BaseDb):
107
126
  self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
108
127
 
109
128
  # -- DB methods --
129
+ def table_exists(self, table_name: str) -> bool:
130
+ """Check if a table with the given name exists in the SQLite database.
131
+
132
+ Args:
133
+ table_name: Name of the table to check
134
+
135
+ Returns:
136
+ bool: True if the table exists in the database, False otherwise
137
+ """
138
+ with self.Session() as sess:
139
+ return is_table_available(session=sess, table_name=table_name)
140
+
141
+ def _create_all_tables(self):
142
+ """Create all tables for the database."""
143
+ tables_to_create = [
144
+ (self.session_table_name, "sessions"),
145
+ (self.memory_table_name, "memories"),
146
+ (self.metrics_table_name, "metrics"),
147
+ (self.eval_table_name, "evals"),
148
+ (self.knowledge_table_name, "knowledge"),
149
+ (self.versions_table_name, "versions"),
150
+ ]
151
+
152
+ for table_name, table_type in tables_to_create:
153
+ self._get_or_create_table(table_name=table_name, table_type=table_type, create_table_if_not_found=True)
110
154
 
111
155
  def _create_table(self, table_name: str, table_type: str) -> Table:
112
156
  """
@@ -120,8 +164,7 @@ class SqliteDb(BaseDb):
120
164
  Table: SQLAlchemy Table object
121
165
  """
122
166
  try:
123
- table_schema = get_table_schema_definition(table_type)
124
- log_debug(f"Creating table {table_name} with schema: {table_schema}")
167
+ table_schema = get_table_schema_definition(table_type).copy()
125
168
 
126
169
  columns: List[Column] = []
127
170
  indexes: List[str] = []
@@ -143,11 +186,19 @@ class SqliteDb(BaseDb):
143
186
  column_kwargs["unique"] = True
144
187
  unique_constraints.append(col_name)
145
188
 
189
+ # Handle foreign key constraint
190
+ if "foreign_key" in col_config:
191
+ fk_ref = col_config["foreign_key"]
192
+ # For spans table, dynamically replace the traces table reference
193
+ # with the actual trace table name configured for this db instance
194
+ if table_type == "spans" and "trace_id" in fk_ref:
195
+ fk_ref = f"{self.trace_table_name}.trace_id"
196
+ column_args.append(ForeignKey(fk_ref))
197
+
146
198
  columns.append(Column(*column_args, **column_kwargs)) # type: ignore
147
199
 
148
200
  # Create the table object
149
- table_metadata = MetaData()
150
- table = Table(table_name, table_metadata, *columns)
201
+ table = Table(table_name, self.metadata, *columns)
151
202
 
152
203
  # Add multi-column unique constraints with table-specific names
153
204
  for constraint in schema_unique_constraints:
@@ -161,12 +212,17 @@ class SqliteDb(BaseDb):
161
212
  table.append_constraint(Index(idx_name, idx_col))
162
213
 
163
214
  # Create table
164
- table.create(self.db_engine, checkfirst=True)
215
+ table_created = False
216
+ if not self.table_exists(table_name):
217
+ table.create(self.db_engine, checkfirst=True)
218
+ log_debug(f"Successfully created table '{table_name}'")
219
+ table_created = True
220
+ else:
221
+ log_debug(f"Table '{table_name}' already exists, skipping creation")
165
222
 
166
223
  # Create indexes
167
224
  for idx in table.indexes:
168
225
  try:
169
- log_debug(f"Creating index: {idx.name}")
170
226
  # Check if index already exists
171
227
  with self.Session() as sess:
172
228
  exists_query = text("SELECT 1 FROM sqlite_master WHERE type = 'index' AND name = :index_name")
@@ -177,13 +233,21 @@ class SqliteDb(BaseDb):
177
233
 
178
234
  idx.create(self.db_engine)
179
235
 
236
+ log_debug(f"Created index: {idx.name} for table {table_name}")
180
237
  except Exception as e:
181
238
  log_warning(f"Error creating index {idx.name}: {e}")
182
239
 
183
- log_info(f"Successfully created table '{table_name}'")
240
+ # Store the schema version for the created table
241
+ if table_name != self.versions_table_name and table_created:
242
+ latest_schema_version = MigrationManager(self).latest_schema_version
243
+ self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
244
+
184
245
  return table
185
246
 
186
247
  except Exception as e:
248
+ from traceback import print_exc
249
+
250
+ print_exc()
187
251
  log_error(f"Could not create table '{table_name}': {e}")
188
252
  raise e
189
253
 
@@ -229,11 +293,50 @@ class SqliteDb(BaseDb):
229
293
  )
230
294
  return self.knowledge_table
231
295
 
296
+ elif table_type == "traces":
297
+ self.traces_table = self._get_or_create_table(
298
+ table_name=self.trace_table_name,
299
+ table_type="traces",
300
+ create_table_if_not_found=create_table_if_not_found,
301
+ )
302
+ return self.traces_table
303
+
304
+ elif table_type == "spans":
305
+ # Ensure traces table exists first (spans has FK to traces)
306
+ if create_table_if_not_found:
307
+ self._get_table(table_type="traces", create_table_if_not_found=True)
308
+
309
+ self.spans_table = self._get_or_create_table(
310
+ table_name=self.span_table_name,
311
+ table_type="spans",
312
+ create_table_if_not_found=create_table_if_not_found,
313
+ )
314
+ return self.spans_table
315
+
316
+ elif table_type == "culture":
317
+ self.culture_table = self._get_or_create_table(
318
+ table_name=self.culture_table_name,
319
+ table_type="culture",
320
+ create_table_if_not_found=create_table_if_not_found,
321
+ )
322
+ return self.culture_table
323
+
324
+ elif table_type == "versions":
325
+ self.versions_table = self._get_or_create_table(
326
+ table_name=self.versions_table_name,
327
+ table_type="versions",
328
+ create_table_if_not_found=create_table_if_not_found,
329
+ )
330
+ return self.versions_table
331
+
232
332
  else:
233
333
  raise ValueError(f"Unknown table type: '{table_type}'")
234
334
 
235
335
  def _get_or_create_table(
236
- self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
336
+ self,
337
+ table_name: str,
338
+ table_type: str,
339
+ create_table_if_not_found: Optional[bool] = False,
237
340
  ) -> Optional[Table]:
238
341
  """
239
342
  Check if the table exists and is valid, else create it.
@@ -259,13 +362,48 @@ class SqliteDb(BaseDb):
259
362
 
260
363
  try:
261
364
  table = Table(table_name, self.metadata, autoload_with=self.db_engine)
262
- log_debug(f"Loaded existing table {table_name}")
263
365
  return table
264
366
 
265
367
  except Exception as e:
266
368
  log_error(f"Error loading existing table {table_name}: {e}")
267
369
  raise e
268
370
 
371
+ def get_latest_schema_version(self, table_name: str):
372
+ """Get the latest version of the database schema."""
373
+ table = self._get_table(table_type="versions", create_table_if_not_found=True)
374
+ if table is None:
375
+ return "2.0.0"
376
+ with self.Session() as sess:
377
+ stmt = select(table)
378
+ # Latest version for the given table
379
+ stmt = stmt.where(table.c.table_name == table_name)
380
+ stmt = stmt.order_by(table.c.version.desc()).limit(1)
381
+ result = sess.execute(stmt).fetchone()
382
+ if result is None:
383
+ return "2.0.0"
384
+ version_dict = dict(result._mapping)
385
+ return version_dict.get("version") or "2.0.0"
386
+
387
+ def upsert_schema_version(self, table_name: str, version: str) -> None:
388
+ """Upsert the schema version into the database."""
389
+ table = self._get_table(table_type="versions", create_table_if_not_found=True)
390
+ if table is None:
391
+ return
392
+ current_datetime = datetime.now().isoformat()
393
+ with self.Session() as sess, sess.begin():
394
+ stmt = sqlite.insert(table).values(
395
+ table_name=table_name,
396
+ version=version,
397
+ created_at=current_datetime, # Store as ISO format string
398
+ updated_at=current_datetime,
399
+ )
400
+ # Update version if table_name already exists
401
+ stmt = stmt.on_conflict_do_update(
402
+ index_elements=["table_name"],
403
+ set_=dict(version=version, updated_at=current_datetime),
404
+ )
405
+ sess.execute(stmt)
406
+
269
407
  # -- Session methods --
270
408
 
271
409
  def delete_session(self, session_id: str) -> bool:
@@ -357,8 +495,6 @@ class SqliteDb(BaseDb):
357
495
  # Filtering
358
496
  if user_id is not None:
359
497
  stmt = stmt.where(table.c.user_id == user_id)
360
- if session_type is not None:
361
- stmt = stmt.where(table.c.session_type == session_type)
362
498
 
363
499
  result = sess.execute(stmt).fetchone()
364
500
  if result is None:
@@ -483,7 +619,11 @@ class SqliteDb(BaseDb):
483
619
  raise e
484
620
 
485
621
  def rename_session(
486
- self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
622
+ self,
623
+ session_id: str,
624
+ session_type: SessionType,
625
+ session_name: str,
626
+ deserialize: Optional[bool] = True,
487
627
  ) -> Optional[Union[Session, Dict[str, Any]]]:
488
628
  """
489
629
  Rename a session in the database.
@@ -664,7 +804,10 @@ class SqliteDb(BaseDb):
664
804
  raise e
665
805
 
666
806
  def upsert_sessions(
667
- self, sessions: List[Session], deserialize: Optional[bool] = True
807
+ self,
808
+ sessions: List[Session],
809
+ deserialize: Optional[bool] = True,
810
+ preserve_updated_at: bool = False,
668
811
  ) -> List[Union[Session, Dict[str, Any]]]:
669
812
  """
670
813
  Bulk upsert multiple sessions for improved performance on large datasets.
@@ -672,6 +815,7 @@ class SqliteDb(BaseDb):
672
815
  Args:
673
816
  sessions (List[Session]): List of sessions to upsert.
674
817
  deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
818
+ preserve_updated_at (bool): If True, preserve the updated_at from the session object.
675
819
 
676
820
  Returns:
677
821
  List[Union[Session, Dict[str, Any]]]: List of upserted sessions.
@@ -715,6 +859,8 @@ class SqliteDb(BaseDb):
715
859
  agent_data = []
716
860
  for session in agent_sessions:
717
861
  serialized_session = serialize_session_json_fields(session.to_dict())
862
+ # Use preserved updated_at if flag is set and value exists, otherwise use current time
863
+ updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
718
864
  agent_data.append(
719
865
  {
720
866
  "session_id": serialized_session.get("session_id"),
@@ -727,7 +873,7 @@ class SqliteDb(BaseDb):
727
873
  "runs": serialized_session.get("runs"),
728
874
  "summary": serialized_session.get("summary"),
729
875
  "created_at": serialized_session.get("created_at"),
730
- "updated_at": serialized_session.get("created_at"),
876
+ "updated_at": updated_at,
731
877
  }
732
878
  )
733
879
 
@@ -743,7 +889,7 @@ class SqliteDb(BaseDb):
743
889
  metadata=stmt.excluded.metadata,
744
890
  runs=stmt.excluded.runs,
745
891
  summary=stmt.excluded.summary,
746
- updated_at=int(time.time()),
892
+ updated_at=stmt.excluded.updated_at,
747
893
  ),
748
894
  )
749
895
  sess.execute(stmt, agent_data)
@@ -768,6 +914,8 @@ class SqliteDb(BaseDb):
768
914
  team_data = []
769
915
  for session in team_sessions:
770
916
  serialized_session = serialize_session_json_fields(session.to_dict())
917
+ # Use preserved updated_at if flag is set and value exists, otherwise use current time
918
+ updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
771
919
  team_data.append(
772
920
  {
773
921
  "session_id": serialized_session.get("session_id"),
@@ -777,7 +925,7 @@ class SqliteDb(BaseDb):
777
925
  "runs": serialized_session.get("runs"),
778
926
  "summary": serialized_session.get("summary"),
779
927
  "created_at": serialized_session.get("created_at"),
780
- "updated_at": serialized_session.get("created_at"),
928
+ "updated_at": updated_at,
781
929
  "team_data": serialized_session.get("team_data"),
782
930
  "session_data": serialized_session.get("session_data"),
783
931
  "metadata": serialized_session.get("metadata"),
@@ -796,7 +944,7 @@ class SqliteDb(BaseDb):
796
944
  metadata=stmt.excluded.metadata,
797
945
  runs=stmt.excluded.runs,
798
946
  summary=stmt.excluded.summary,
799
- updated_at=int(time.time()),
947
+ updated_at=stmt.excluded.updated_at,
800
948
  ),
801
949
  )
802
950
  sess.execute(stmt, team_data)
@@ -821,6 +969,8 @@ class SqliteDb(BaseDb):
821
969
  workflow_data = []
822
970
  for session in workflow_sessions:
823
971
  serialized_session = serialize_session_json_fields(session.to_dict())
972
+ # Use preserved updated_at if flag is set and value exists, otherwise use current time
973
+ updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
824
974
  workflow_data.append(
825
975
  {
826
976
  "session_id": serialized_session.get("session_id"),
@@ -830,7 +980,7 @@ class SqliteDb(BaseDb):
830
980
  "runs": serialized_session.get("runs"),
831
981
  "summary": serialized_session.get("summary"),
832
982
  "created_at": serialized_session.get("created_at"),
833
- "updated_at": serialized_session.get("created_at"),
983
+ "updated_at": updated_at,
834
984
  "workflow_data": serialized_session.get("workflow_data"),
835
985
  "session_data": serialized_session.get("session_data"),
836
986
  "metadata": serialized_session.get("metadata"),
@@ -849,7 +999,7 @@ class SqliteDb(BaseDb):
849
999
  metadata=stmt.excluded.metadata,
850
1000
  runs=stmt.excluded.runs,
851
1001
  summary=stmt.excluded.summary,
852
- updated_at=int(time.time()),
1002
+ updated_at=stmt.excluded.updated_at,
853
1003
  ),
854
1004
  )
855
1005
  sess.execute(stmt, workflow_data)
@@ -958,9 +1108,8 @@ class SqliteDb(BaseDb):
958
1108
 
959
1109
  with self.Session() as sess, sess.begin():
960
1110
  # Select topics from all results
961
- stmt = select(func.json_array_elements_text(table.c.topics)).select_from(table)
1111
+ stmt = select(table.c.topics)
962
1112
  result = sess.execute(stmt).fetchall()
963
-
964
1113
  return list(set([record[0] for record in result]))
965
1114
 
966
1115
  except Exception as e:
@@ -968,7 +1117,10 @@ class SqliteDb(BaseDb):
968
1117
  raise e
969
1118
 
970
1119
  def get_user_memory(
971
- self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
1120
+ self,
1121
+ memory_id: str,
1122
+ deserialize: Optional[bool] = True,
1123
+ user_id: Optional[str] = None,
972
1124
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
973
1125
  """Get a memory from the database.
974
1126
 
@@ -1060,8 +1212,8 @@ class SqliteDb(BaseDb):
1060
1212
  if team_id is not None:
1061
1213
  stmt = stmt.where(table.c.team_id == team_id)
1062
1214
  if topics is not None:
1063
- topic_conditions = [text(f"topics::text LIKE '%\"{topic}\"%'") for topic in topics]
1064
- stmt = stmt.where(and_(*topic_conditions))
1215
+ for topic in topics:
1216
+ stmt = stmt.where(func.cast(table.c.topics, String).like(f'%"{topic}"%'))
1065
1217
  if search_content is not None:
1066
1218
  stmt = stmt.where(table.c.memory.ilike(f"%{search_content}%"))
1067
1219
 
@@ -1096,12 +1248,14 @@ class SqliteDb(BaseDb):
1096
1248
  self,
1097
1249
  limit: Optional[int] = None,
1098
1250
  page: Optional[int] = None,
1251
+ user_id: Optional[str] = None,
1099
1252
  ) -> Tuple[List[Dict[str, Any]], int]:
1100
1253
  """Get user memories stats.
1101
1254
 
1102
1255
  Args:
1103
1256
  limit (Optional[int]): The maximum number of user stats to return.
1104
1257
  page (Optional[int]): The page number.
1258
+ user_id (Optional[str]): User ID for filtering.
1105
1259
 
1106
1260
  Returns:
1107
1261
  Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
@@ -1124,19 +1278,20 @@ class SqliteDb(BaseDb):
1124
1278
  return [], 0
1125
1279
 
1126
1280
  with self.Session() as sess, sess.begin():
1127
- stmt = (
1128
- select(
1129
- table.c.user_id,
1130
- func.count(table.c.memory_id).label("total_memories"),
1131
- func.max(table.c.updated_at).label("last_memory_updated_at"),
1132
- )
1133
- .where(table.c.user_id.is_not(None))
1134
- .group_by(table.c.user_id)
1135
- .order_by(func.max(table.c.updated_at).desc())
1281
+ stmt = select(
1282
+ table.c.user_id,
1283
+ func.count(table.c.memory_id).label("total_memories"),
1284
+ func.max(table.c.updated_at).label("last_memory_updated_at"),
1136
1285
  )
1286
+ if user_id is not None:
1287
+ stmt = stmt.where(table.c.user_id == user_id)
1288
+ else:
1289
+ stmt = stmt.where(table.c.user_id.is_not(None))
1290
+ stmt = stmt.group_by(table.c.user_id)
1291
+ stmt = stmt.order_by(func.max(table.c.updated_at).desc())
1137
1292
 
1138
1293
  count_stmt = select(func.count()).select_from(stmt.alias())
1139
- total_count = sess.execute(count_stmt).scalar()
1294
+ total_count = sess.execute(count_stmt).scalar() or 0
1140
1295
 
1141
1296
  # Pagination
1142
1297
  if limit is not None:
@@ -1186,6 +1341,8 @@ class SqliteDb(BaseDb):
1186
1341
  if memory.memory_id is None:
1187
1342
  memory.memory_id = str(uuid4())
1188
1343
 
1344
+ current_time = int(time.time())
1345
+
1189
1346
  with self.Session() as sess, sess.begin():
1190
1347
  stmt = sqlite.insert(table).values(
1191
1348
  user_id=memory.user_id,
@@ -1195,7 +1352,9 @@ class SqliteDb(BaseDb):
1195
1352
  memory=memory.memory,
1196
1353
  topics=memory.topics,
1197
1354
  input=memory.input,
1198
- updated_at=int(time.time()),
1355
+ feedback=memory.feedback,
1356
+ created_at=memory.created_at,
1357
+ updated_at=memory.created_at,
1199
1358
  )
1200
1359
  stmt = stmt.on_conflict_do_update( # type: ignore
1201
1360
  index_elements=["memory_id"],
@@ -1203,7 +1362,12 @@ class SqliteDb(BaseDb):
1203
1362
  memory=memory.memory,
1204
1363
  topics=memory.topics,
1205
1364
  input=memory.input,
1206
- updated_at=int(time.time()),
1365
+ agent_id=memory.agent_id,
1366
+ team_id=memory.team_id,
1367
+ feedback=memory.feedback,
1368
+ updated_at=current_time,
1369
+ # Preserve created_at on update - don't overwrite existing value
1370
+ created_at=table.c.created_at,
1207
1371
  ),
1208
1372
  ).returning(table)
1209
1373
 
@@ -1224,7 +1388,10 @@ class SqliteDb(BaseDb):
1224
1388
  raise e
1225
1389
 
1226
1390
  def upsert_memories(
1227
- self, memories: List[UserMemory], deserialize: Optional[bool] = True
1391
+ self,
1392
+ memories: List[UserMemory],
1393
+ deserialize: Optional[bool] = True,
1394
+ preserve_updated_at: bool = False,
1228
1395
  ) -> List[Union[UserMemory, Dict[str, Any]]]:
1229
1396
  """
1230
1397
  Bulk upsert multiple user memories for improved performance on large datasets.
@@ -1255,10 +1422,15 @@ class SqliteDb(BaseDb):
1255
1422
  ]
1256
1423
  # Prepare bulk data
1257
1424
  bulk_data = []
1425
+ current_time = int(time.time())
1426
+
1258
1427
  for memory in memories:
1259
1428
  if memory.memory_id is None:
1260
1429
  memory.memory_id = str(uuid4())
1261
1430
 
1431
+ # Use preserved updated_at if flag is set and value exists, otherwise use current time
1432
+ updated_at = memory.updated_at if preserve_updated_at else current_time
1433
+
1262
1434
  bulk_data.append(
1263
1435
  {
1264
1436
  "user_id": memory.user_id,
@@ -1267,7 +1439,10 @@ class SqliteDb(BaseDb):
1267
1439
  "memory_id": memory.memory_id,
1268
1440
  "memory": memory.memory,
1269
1441
  "topics": memory.topics,
1270
- "updated_at": int(time.time()),
1442
+ "input": memory.input,
1443
+ "feedback": memory.feedback,
1444
+ "created_at": memory.created_at,
1445
+ "updated_at": updated_at,
1271
1446
  }
1272
1447
  )
1273
1448
 
@@ -1284,7 +1459,10 @@ class SqliteDb(BaseDb):
1284
1459
  input=stmt.excluded.input,
1285
1460
  agent_id=stmt.excluded.agent_id,
1286
1461
  team_id=stmt.excluded.team_id,
1287
- updated_at=int(time.time()),
1462
+ feedback=stmt.excluded.feedback,
1463
+ updated_at=stmt.excluded.updated_at,
1464
+ # Preserve created_at on update
1465
+ created_at=table.c.created_at,
1288
1466
  ),
1289
1467
  )
1290
1468
  sess.execute(stmt, bulk_data)
@@ -1450,7 +1628,9 @@ class SqliteDb(BaseDb):
1450
1628
  start_timestamp=start_timestamp, end_timestamp=end_timestamp
1451
1629
  )
1452
1630
  all_sessions_data = fetch_all_sessions_data(
1453
- sessions=sessions, dates_to_process=dates_to_process, start_timestamp=start_timestamp
1631
+ sessions=sessions,
1632
+ dates_to_process=dates_to_process,
1633
+ start_timestamp=start_timestamp,
1454
1634
  )
1455
1635
  if not all_sessions_data:
1456
1636
  log_info("No new session data found. Won't calculate metrics.")
@@ -1697,7 +1877,11 @@ class SqliteDb(BaseDb):
1697
1877
  with self.Session() as sess, sess.begin():
1698
1878
  current_time = int(time.time())
1699
1879
  stmt = sqlite.insert(table).values(
1700
- {"created_at": current_time, "updated_at": current_time, **eval_run.model_dump()}
1880
+ {
1881
+ "created_at": current_time,
1882
+ "updated_at": current_time,
1883
+ **eval_run.model_dump(),
1884
+ }
1701
1885
  )
1702
1886
  sess.execute(stmt)
1703
1887
  sess.commit()
@@ -1930,6 +2114,510 @@ class SqliteDb(BaseDb):
1930
2114
  log_error(f"Error renaming eval run {eval_run_id}: {e}")
1931
2115
  raise e
1932
2116
 
2117
+ # -- Trace methods --
2118
+
2119
+ def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
2120
+ """Build base query for traces with aggregated span counts.
2121
+
2122
+ Args:
2123
+ table: The traces table.
2124
+ spans_table: The spans table (optional).
2125
+
2126
+ Returns:
2127
+ SQLAlchemy select statement with total_spans and error_count calculated dynamically.
2128
+ """
2129
+ from sqlalchemy import case, func, literal
2130
+
2131
+ if spans_table is not None:
2132
+ # JOIN with spans table to calculate total_spans and error_count
2133
+ return (
2134
+ select(
2135
+ table,
2136
+ func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
2137
+ func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
2138
+ "error_count"
2139
+ ),
2140
+ )
2141
+ .select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
2142
+ .group_by(table.c.trace_id)
2143
+ )
2144
+ else:
2145
+ # Fallback if spans table doesn't exist
2146
+ return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2147
+
2148
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2149
+ """Build a SQL CASE expression that returns the component level for a trace.
2150
+
2151
+ Component levels (higher = more important):
2152
+ - 3: Workflow root (.run or .arun with workflow_id)
2153
+ - 2: Team root (.run or .arun with team_id)
2154
+ - 1: Agent root (.run or .arun with agent_id)
2155
+ - 0: Child span (not a root)
2156
+
2157
+ Args:
2158
+ workflow_id_col: SQL column/expression for workflow_id
2159
+ team_id_col: SQL column/expression for team_id
2160
+ agent_id_col: SQL column/expression for agent_id
2161
+ name_col: SQL column/expression for name
2162
+
2163
+ Returns:
2164
+ SQLAlchemy CASE expression returning the component level as an integer.
2165
+ """
2166
+ from sqlalchemy import and_, case, or_
2167
+
2168
+ is_root_name = or_(name_col.contains(".run"), name_col.contains(".arun"))
2169
+
2170
+ return case(
2171
+ # Workflow root (level 3)
2172
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2173
+ # Team root (level 2)
2174
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2175
+ # Agent root (level 1)
2176
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2177
+ # Child span or unknown (level 0)
2178
+ else_=0,
2179
+ )
2180
+
2181
    def upsert_trace(self, trace: "Trace") -> None:
        """Create or update a single trace record in the database.

        Uses INSERT ... ON CONFLICT DO UPDATE (upsert) to handle concurrent inserts
        atomically and avoid race conditions.

        Never raises: failures are logged and swallowed, so tracing cannot break
        the main application flow.

        Args:
            trace: The Trace object to store (one per trace_id).
        """
        from sqlalchemy import case

        try:
            table = self._get_table(table_type="traces", create_table_if_not_found=True)
            if table is None:
                return

            trace_dict = trace.to_dict()
            # total_spans / error_count are derived at read time from the spans
            # table (see _get_traces_base_query), so they are never persisted here.
            trace_dict.pop("total_spans", None)
            trace_dict.pop("error_count", None)

            with self.Session() as sess, sess.begin():
                # Use upsert to handle concurrent inserts atomically
                # On conflict, update fields while preserving existing non-null context values
                # and keeping the earliest start_time
                insert_stmt = sqlite.insert(table).values(trace_dict)

                # Build component level expressions for comparing trace priority
                # (workflow root > team root > agent root > child span).
                new_level = self._get_trace_component_level_expr(
                    insert_stmt.excluded.workflow_id,
                    insert_stmt.excluded.team_id,
                    insert_stmt.excluded.agent_id,
                    insert_stmt.excluded.name,
                )
                existing_level = self._get_trace_component_level_expr(
                    table.c.workflow_id,
                    table.c.team_id,
                    table.c.agent_id,
                    table.c.name,
                )

                # Build the ON CONFLICT DO UPDATE clause
                # Use MIN for start_time, MAX for end_time to capture full trace duration
                # SQLite stores timestamps as ISO strings, so string comparison works for ISO format
                # Duration is calculated as: (MAX(end_time) - MIN(start_time)) in milliseconds
                # SQLite doesn't have epoch extraction, so we calculate duration using julianday
                upsert_stmt = insert_stmt.on_conflict_do_update(
                    index_elements=["trace_id"],
                    set_={
                        "end_time": func.max(table.c.end_time, insert_stmt.excluded.end_time),
                        "start_time": func.min(table.c.start_time, insert_stmt.excluded.start_time),
                        # Calculate duration in milliseconds using julianday (SQLite-specific)
                        # julianday returns days, so multiply by 86400000 to get milliseconds
                        "duration_ms": (
                            func.julianday(func.max(table.c.end_time, insert_stmt.excluded.end_time))
                            - func.julianday(func.min(table.c.start_time, insert_stmt.excluded.start_time))
                        )
                        * 86400000,
                        "status": insert_stmt.excluded.status,
                        # Update name only if new trace is from a higher-level component
                        # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
                        "name": case(
                            (new_level > existing_level, insert_stmt.excluded.name),
                            else_=table.c.name,
                        ),
                        # Preserve existing non-null context values using COALESCE
                        # (an incoming NULL never clobbers a known id).
                        "run_id": func.coalesce(insert_stmt.excluded.run_id, table.c.run_id),
                        "session_id": func.coalesce(insert_stmt.excluded.session_id, table.c.session_id),
                        "user_id": func.coalesce(insert_stmt.excluded.user_id, table.c.user_id),
                        "agent_id": func.coalesce(insert_stmt.excluded.agent_id, table.c.agent_id),
                        "team_id": func.coalesce(insert_stmt.excluded.team_id, table.c.team_id),
                        "workflow_id": func.coalesce(insert_stmt.excluded.workflow_id, table.c.workflow_id),
                    },
                )
                sess.execute(upsert_stmt)

        except Exception as e:
            log_error(f"Error creating trace: {e}")
            # Don't raise - tracing should not break the main application flow
2259
+
2260
+ def get_trace(
2261
+ self,
2262
+ trace_id: Optional[str] = None,
2263
+ run_id: Optional[str] = None,
2264
+ ):
2265
+ """Get a single trace by trace_id or other filters.
2266
+
2267
+ Args:
2268
+ trace_id: The unique trace identifier.
2269
+ run_id: Filter by run ID (returns first match).
2270
+
2271
+ Returns:
2272
+ Optional[Trace]: The trace if found, None otherwise.
2273
+
2274
+ Note:
2275
+ If multiple filters are provided, trace_id takes precedence.
2276
+ For other filters, the most recent trace is returned.
2277
+ """
2278
+ try:
2279
+ from agno.tracing.schemas import Trace
2280
+
2281
+ table = self._get_table(table_type="traces")
2282
+ if table is None:
2283
+ return None
2284
+
2285
+ # Get spans table for JOIN
2286
+ spans_table = self._get_table(table_type="spans")
2287
+
2288
+ with self.Session() as sess:
2289
+ # Build query with aggregated span counts
2290
+ stmt = self._get_traces_base_query(table, spans_table)
2291
+
2292
+ if trace_id:
2293
+ stmt = stmt.where(table.c.trace_id == trace_id)
2294
+ elif run_id:
2295
+ stmt = stmt.where(table.c.run_id == run_id)
2296
+ else:
2297
+ log_debug("get_trace called without any filter parameters")
2298
+ return None
2299
+
2300
+ # Order by most recent and get first result
2301
+ stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
2302
+ result = sess.execute(stmt).fetchone()
2303
+
2304
+ if result:
2305
+ return Trace.from_dict(dict(result._mapping))
2306
+ return None
2307
+
2308
+ except Exception as e:
2309
+ log_error(f"Error getting trace: {e}")
2310
+ return None
2311
+
2312
+ def get_traces(
2313
+ self,
2314
+ run_id: Optional[str] = None,
2315
+ session_id: Optional[str] = None,
2316
+ user_id: Optional[str] = None,
2317
+ agent_id: Optional[str] = None,
2318
+ team_id: Optional[str] = None,
2319
+ workflow_id: Optional[str] = None,
2320
+ status: Optional[str] = None,
2321
+ start_time: Optional[datetime] = None,
2322
+ end_time: Optional[datetime] = None,
2323
+ limit: Optional[int] = 20,
2324
+ page: Optional[int] = 1,
2325
+ ) -> tuple[List, int]:
2326
+ """Get traces matching the provided filters with pagination.
2327
+
2328
+ Args:
2329
+ run_id: Filter by run ID.
2330
+ session_id: Filter by session ID.
2331
+ user_id: Filter by user ID.
2332
+ agent_id: Filter by agent ID.
2333
+ team_id: Filter by team ID.
2334
+ workflow_id: Filter by workflow ID.
2335
+ status: Filter by status (OK, ERROR, UNSET).
2336
+ start_time: Filter traces starting after this datetime.
2337
+ end_time: Filter traces ending before this datetime.
2338
+ limit: Maximum number of traces to return per page.
2339
+ page: Page number (1-indexed).
2340
+
2341
+ Returns:
2342
+ tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
2343
+ """
2344
+ try:
2345
+ from sqlalchemy import func
2346
+
2347
+ from agno.tracing.schemas import Trace
2348
+
2349
+ log_debug(
2350
+ f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
2351
+ )
2352
+
2353
+ table = self._get_table(table_type="traces")
2354
+ if table is None:
2355
+ log_debug(" Traces table not found")
2356
+ return [], 0
2357
+
2358
+ # Get spans table for JOIN
2359
+ spans_table = self._get_table(table_type="spans")
2360
+
2361
+ with self.Session() as sess:
2362
+ # Build base query with aggregated span counts
2363
+ base_stmt = self._get_traces_base_query(table, spans_table)
2364
+
2365
+ # Apply filters
2366
+ if run_id:
2367
+ base_stmt = base_stmt.where(table.c.run_id == run_id)
2368
+ if session_id:
2369
+ base_stmt = base_stmt.where(table.c.session_id == session_id)
2370
+ if user_id:
2371
+ base_stmt = base_stmt.where(table.c.user_id == user_id)
2372
+ if agent_id:
2373
+ base_stmt = base_stmt.where(table.c.agent_id == agent_id)
2374
+ if team_id:
2375
+ base_stmt = base_stmt.where(table.c.team_id == team_id)
2376
+ if workflow_id:
2377
+ base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
2378
+ if status:
2379
+ base_stmt = base_stmt.where(table.c.status == status)
2380
+ if start_time:
2381
+ # Convert datetime to ISO string for comparison
2382
+ base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
2383
+ if end_time:
2384
+ # Convert datetime to ISO string for comparison
2385
+ base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())
2386
+
2387
+ # Get total count
2388
+ count_stmt = select(func.count()).select_from(base_stmt.alias())
2389
+ total_count = sess.execute(count_stmt).scalar() or 0
2390
+
2391
+ # Apply pagination
2392
+ offset = (page - 1) * limit if page and limit else 0
2393
+ paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
2394
+
2395
+ results = sess.execute(paginated_stmt).fetchall()
2396
+
2397
+ traces = [Trace.from_dict(dict(row._mapping)) for row in results]
2398
+ return traces, total_count
2399
+
2400
+ except Exception as e:
2401
+ log_error(f"Error getting traces: {e}")
2402
+ return [], 0
2403
+
2404
    def get_trace_stats(
        self,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        team_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: Optional[int] = 20,
        page: Optional[int] = 1,
    ) -> tuple[List[Dict[str, Any]], int]:
        """Get trace statistics grouped by session.

        Args:
            user_id: Filter by user ID.
            agent_id: Filter by agent ID.
            team_id: Filter by team ID.
            workflow_id: Filter by workflow ID.
            start_time: Filter sessions with traces created after this datetime.
            end_time: Filter sessions with traces created before this datetime.
            limit: Maximum number of sessions to return per page.
            page: Page number (1-indexed).

        Returns:
            tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
        """
        try:
            from sqlalchemy import func

            table = self._get_table(table_type="traces")
            if table is None:
                log_debug("Traces table not found")
                return [], 0

            with self.Session() as sess:
                # Build base query grouped by session_id.
                # NOTE(review): grouping also includes user/agent/team/workflow ids,
                # so a session whose traces span multiple agents yields multiple rows.
                base_stmt = (
                    select(
                        table.c.session_id,
                        table.c.user_id,
                        table.c.agent_id,
                        table.c.team_id,
                        table.c.workflow_id,
                        func.count(table.c.trace_id).label("total_traces"),
                        func.min(table.c.created_at).label("first_trace_at"),
                        func.max(table.c.created_at).label("last_trace_at"),
                    )
                    .where(table.c.session_id.isnot(None))  # Only sessions with session_id
                    .group_by(
                        table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
                    )
                )

                # Apply filters
                if user_id:
                    base_stmt = base_stmt.where(table.c.user_id == user_id)
                if workflow_id:
                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
                if team_id:
                    base_stmt = base_stmt.where(table.c.team_id == team_id)
                if agent_id:
                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
                if start_time:
                    # Convert datetime to ISO string for comparison
                    base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
                if end_time:
                    # Convert datetime to ISO string for comparison
                    base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())

                # Get total count of sessions (count rows of the grouped query)
                count_stmt = select(func.count()).select_from(base_stmt.alias())
                total_count = sess.execute(count_stmt).scalar() or 0

                # Apply pagination and ordering (most recently active session first)
                offset = (page - 1) * limit if page and limit else 0
                paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)

                results = sess.execute(paginated_stmt).fetchall()

                # Convert to list of dicts with datetime objects
                from datetime import datetime

                stats_list = []
                for row in results:
                    # Convert ISO strings to datetime objects
                    # NOTE(review): assumes created_at is stored as an ISO-8601
                    # string; other tables in this module store epoch ints for
                    # created_at — confirm the traces schema before relying on this.
                    first_trace_at_str = row.first_trace_at
                    last_trace_at_str = row.last_trace_at

                    # Parse ISO format strings to datetime objects
                    # ("Z" suffix normalized to an explicit UTC offset first).
                    first_trace_at = datetime.fromisoformat(first_trace_at_str.replace("Z", "+00:00"))
                    last_trace_at = datetime.fromisoformat(last_trace_at_str.replace("Z", "+00:00"))

                    stats_list.append(
                        {
                            "session_id": row.session_id,
                            "user_id": row.user_id,
                            "agent_id": row.agent_id,
                            "team_id": row.team_id,
                            "workflow_id": row.workflow_id,
                            "total_traces": row.total_traces,
                            "first_trace_at": first_trace_at,
                            "last_trace_at": last_trace_at,
                        }
                    )

                return stats_list, total_count

        except Exception as e:
            log_error(f"Error getting trace stats: {e}")
            return [], 0
2514
+
2515
+ # -- Span methods --
2516
+
2517
+ def create_span(self, span: "Span") -> None:
2518
+ """Create a single span in the database.
2519
+
2520
+ Args:
2521
+ span: The Span object to store.
2522
+ """
2523
+ try:
2524
+ table = self._get_table(table_type="spans", create_table_if_not_found=True)
2525
+ if table is None:
2526
+ return
2527
+
2528
+ with self.Session() as sess, sess.begin():
2529
+ stmt = sqlite.insert(table).values(span.to_dict())
2530
+ sess.execute(stmt)
2531
+
2532
+ except Exception as e:
2533
+ log_error(f"Error creating span: {e}")
2534
+
2535
+ def create_spans(self, spans: List) -> None:
2536
+ """Create multiple spans in the database as a batch.
2537
+
2538
+ Args:
2539
+ spans: List of Span objects to store.
2540
+ """
2541
+ if not spans:
2542
+ return
2543
+
2544
+ try:
2545
+ table = self._get_table(table_type="spans", create_table_if_not_found=True)
2546
+ if table is None:
2547
+ return
2548
+
2549
+ with self.Session() as sess, sess.begin():
2550
+ for span in spans:
2551
+ stmt = sqlite.insert(table).values(span.to_dict())
2552
+ sess.execute(stmt)
2553
+
2554
+ except Exception as e:
2555
+ log_error(f"Error creating spans batch: {e}")
2556
+
2557
+ def get_span(self, span_id: str):
2558
+ """Get a single span by its span_id.
2559
+
2560
+ Args:
2561
+ span_id: The unique span identifier.
2562
+
2563
+ Returns:
2564
+ Optional[Span]: The span if found, None otherwise.
2565
+ """
2566
+ try:
2567
+ from agno.tracing.schemas import Span
2568
+
2569
+ table = self._get_table(table_type="spans")
2570
+ if table is None:
2571
+ return None
2572
+
2573
+ with self.Session() as sess:
2574
+ stmt = table.select().where(table.c.span_id == span_id)
2575
+ result = sess.execute(stmt).fetchone()
2576
+ if result:
2577
+ return Span.from_dict(dict(result._mapping))
2578
+ return None
2579
+
2580
+ except Exception as e:
2581
+ log_error(f"Error getting span: {e}")
2582
+ return None
2583
+
2584
+ def get_spans(
2585
+ self,
2586
+ trace_id: Optional[str] = None,
2587
+ parent_span_id: Optional[str] = None,
2588
+ ) -> List:
2589
+ """Get spans matching the provided filters.
2590
+
2591
+ Args:
2592
+ trace_id: Filter by trace ID.
2593
+ parent_span_id: Filter by parent span ID.
2594
+
2595
+ Returns:
2596
+ List[Span]: List of matching spans.
2597
+ """
2598
+ try:
2599
+ from agno.tracing.schemas import Span
2600
+
2601
+ table = self._get_table(table_type="spans")
2602
+ if table is None:
2603
+ return []
2604
+
2605
+ with self.Session() as sess:
2606
+ stmt = table.select()
2607
+
2608
+ # Apply filters
2609
+ if trace_id:
2610
+ stmt = stmt.where(table.c.trace_id == trace_id)
2611
+ if parent_span_id:
2612
+ stmt = stmt.where(table.c.parent_span_id == parent_span_id)
2613
+
2614
+ results = sess.execute(stmt).fetchall()
2615
+ return [Span.from_dict(dict(row._mapping)) for row in results]
2616
+
2617
+ except Exception as e:
2618
+ log_error(f"Error getting spans: {e}")
2619
+ return []
2620
+
1933
2621
  # -- Migrations --
1934
2622
 
1935
2623
  def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
@@ -1987,3 +2675,234 @@ class SqliteDb(BaseDb):
1987
2675
  for memory in memories:
1988
2676
  self.upsert_user_memory(memory)
1989
2677
  log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
2678
+
2679
+ # -- Culture methods --
2680
+
2681
+ def clear_cultural_knowledge(self) -> None:
2682
+ """Delete all cultural artifacts from the database.
2683
+
2684
+ Raises:
2685
+ Exception: If an error occurs during deletion.
2686
+ """
2687
+ try:
2688
+ table = self._get_table(table_type="culture")
2689
+ if table is None:
2690
+ return
2691
+
2692
+ with self.Session() as sess, sess.begin():
2693
+ sess.execute(table.delete())
2694
+
2695
+ except Exception as e:
2696
+ from agno.utils.log import log_warning
2697
+
2698
+ log_warning(f"Exception deleting all cultural artifacts: {e}")
2699
+ raise e
2700
+
2701
+ def delete_cultural_knowledge(self, id: str) -> None:
2702
+ """Delete a cultural artifact from the database.
2703
+
2704
+ Args:
2705
+ id (str): The ID of the cultural artifact to delete.
2706
+
2707
+ Raises:
2708
+ Exception: If an error occurs during deletion.
2709
+ """
2710
+ try:
2711
+ table = self._get_table(table_type="culture")
2712
+ if table is None:
2713
+ return
2714
+
2715
+ with self.Session() as sess, sess.begin():
2716
+ delete_stmt = table.delete().where(table.c.id == id)
2717
+ result = sess.execute(delete_stmt)
2718
+
2719
+ success = result.rowcount > 0
2720
+ if success:
2721
+ log_debug(f"Successfully deleted cultural artifact id: {id}")
2722
+ else:
2723
+ log_debug(f"No cultural artifact found with id: {id}")
2724
+
2725
+ except Exception as e:
2726
+ log_error(f"Error deleting cultural artifact: {e}")
2727
+ raise e
2728
+
2729
+ def get_cultural_knowledge(
2730
+ self, id: str, deserialize: Optional[bool] = True
2731
+ ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
2732
+ """Get a cultural artifact from the database.
2733
+
2734
+ Args:
2735
+ id (str): The ID of the cultural artifact to get.
2736
+ deserialize (Optional[bool]): Whether to serialize the cultural artifact. Defaults to True.
2737
+
2738
+ Returns:
2739
+ Optional[CulturalKnowledge]: The cultural artifact, or None if it doesn't exist.
2740
+
2741
+ Raises:
2742
+ Exception: If an error occurs during retrieval.
2743
+ """
2744
+ try:
2745
+ table = self._get_table(table_type="culture")
2746
+ if table is None:
2747
+ return None
2748
+
2749
+ with self.Session() as sess, sess.begin():
2750
+ stmt = select(table).where(table.c.id == id)
2751
+ result = sess.execute(stmt).fetchone()
2752
+ if result is None:
2753
+ return None
2754
+
2755
+ db_row = dict(result._mapping)
2756
+ if not db_row or not deserialize:
2757
+ return db_row
2758
+
2759
+ return deserialize_cultural_knowledge_from_db(db_row)
2760
+
2761
+ except Exception as e:
2762
+ log_error(f"Exception reading from cultural artifacts table: {e}")
2763
+ raise e
2764
+
2765
+ def get_all_cultural_knowledge(
2766
+ self,
2767
+ name: Optional[str] = None,
2768
+ agent_id: Optional[str] = None,
2769
+ team_id: Optional[str] = None,
2770
+ limit: Optional[int] = None,
2771
+ page: Optional[int] = None,
2772
+ sort_by: Optional[str] = None,
2773
+ sort_order: Optional[str] = None,
2774
+ deserialize: Optional[bool] = True,
2775
+ ) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
2776
+ """Get all cultural artifacts from the database as CulturalNotion objects.
2777
+
2778
+ Args:
2779
+ name (Optional[str]): The name of the cultural artifact to filter by.
2780
+ agent_id (Optional[str]): The ID of the agent to filter by.
2781
+ team_id (Optional[str]): The ID of the team to filter by.
2782
+ limit (Optional[int]): The maximum number of cultural artifacts to return.
2783
+ page (Optional[int]): The page number.
2784
+ sort_by (Optional[str]): The column to sort by.
2785
+ sort_order (Optional[str]): The order to sort by.
2786
+ deserialize (Optional[bool]): Whether to serialize the cultural artifacts. Defaults to True.
2787
+
2788
+ Returns:
2789
+ Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
2790
+ - When deserialize=True: List of CulturalNotion objects
2791
+ - When deserialize=False: List of CulturalNotion dictionaries and total count
2792
+
2793
+ Raises:
2794
+ Exception: If an error occurs during retrieval.
2795
+ """
2796
+ try:
2797
+ table = self._get_table(table_type="culture")
2798
+ if table is None:
2799
+ return [] if deserialize else ([], 0)
2800
+
2801
+ with self.Session() as sess, sess.begin():
2802
+ stmt = select(table)
2803
+
2804
+ # Filtering
2805
+ if name is not None:
2806
+ stmt = stmt.where(table.c.name == name)
2807
+ if agent_id is not None:
2808
+ stmt = stmt.where(table.c.agent_id == agent_id)
2809
+ if team_id is not None:
2810
+ stmt = stmt.where(table.c.team_id == team_id)
2811
+
2812
+ # Get total count after applying filtering
2813
+ count_stmt = select(func.count()).select_from(stmt.alias())
2814
+ total_count = sess.execute(count_stmt).scalar()
2815
+
2816
+ # Sorting
2817
+ stmt = apply_sorting(stmt, table, sort_by, sort_order)
2818
+ # Paginating
2819
+ if limit is not None:
2820
+ stmt = stmt.limit(limit)
2821
+ if page is not None:
2822
+ stmt = stmt.offset((page - 1) * limit)
2823
+
2824
+ result = sess.execute(stmt).fetchall()
2825
+ if not result:
2826
+ return [] if deserialize else ([], 0)
2827
+
2828
+ db_rows = [dict(record._mapping) for record in result]
2829
+
2830
+ if not deserialize:
2831
+ return db_rows, total_count
2832
+
2833
+ return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]
2834
+
2835
+ except Exception as e:
2836
+ log_error(f"Error reading from cultural artifacts table: {e}")
2837
+ raise e
2838
+
2839
    def upsert_cultural_knowledge(
        self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
        """Upsert a cultural artifact into the database.

        Args:
            cultural_knowledge (CulturalKnowledge): The cultural artifact to upsert.
            deserialize (Optional[bool]): Whether to deserialize the cultural artifact. Defaults to True.

        Returns:
            Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
                - When deserialize=True: CulturalKnowledge object
                - When deserialize=False: CulturalKnowledge dictionary

        Raises:
            Exception: If an error occurs during upsert.
        """
        try:
            table = self._get_table(table_type="culture", create_table_if_not_found=True)
            if table is None:
                return None

            # Assign an id for brand-new artifacts so the conflict target exists.
            if cultural_knowledge.id is None:
                cultural_knowledge.id = str(uuid4())

            # Serialize content, categories, and notes into a JSON string for DB storage (SQLite requires strings)
            content_json_str = serialize_cultural_knowledge_for_db(cultural_knowledge)

            with self.Session() as sess, sess.begin():
                stmt = sqlite.insert(table).values(
                    id=cultural_knowledge.id,
                    name=cultural_knowledge.name,
                    summary=cultural_knowledge.summary,
                    content=content_json_str,
                    metadata=cultural_knowledge.metadata,
                    input=cultural_knowledge.input,
                    created_at=cultural_knowledge.created_at,
                    updated_at=int(time.time()),
                    agent_id=cultural_knowledge.agent_id,
                    team_id=cultural_knowledge.team_id,
                )
                # On id conflict, refresh every mutable field; created_at is
                # intentionally absent from set_ so the original value survives.
                stmt = stmt.on_conflict_do_update(  # type: ignore
                    index_elements=["id"],
                    set_=dict(
                        name=cultural_knowledge.name,
                        summary=cultural_knowledge.summary,
                        content=content_json_str,
                        metadata=cultural_knowledge.metadata,
                        input=cultural_knowledge.input,
                        updated_at=int(time.time()),
                        agent_id=cultural_knowledge.agent_id,
                        team_id=cultural_knowledge.team_id,
                    ),
                ).returning(table)

                result = sess.execute(stmt)
                row = result.fetchone()

                if row is None:
                    return None

                db_row: Dict[str, Any] = dict(row._mapping)
                if not db_row or not deserialize:
                    return db_row

                return deserialize_cultural_knowledge_from_db(db_row)

        except Exception as e:
            log_error(f"Error upserting cultural knowledge: {e}")
            raise e