agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
agno/db/sqlite/sqlite.py CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
1
|
import time
|
|
2
2
|
from datetime import date, datetime, timedelta, timezone
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
|
5
5
|
from uuid import uuid4
|
|
6
6
|
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from agno.tracing.schemas import Span, Trace
|
|
9
|
+
|
|
7
10
|
from agno.db.base import BaseDb, SessionType
|
|
11
|
+
from agno.db.migrations.manager import MigrationManager
|
|
12
|
+
from agno.db.schemas.culture import CulturalKnowledge
|
|
8
13
|
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
9
14
|
from agno.db.schemas.knowledge import KnowledgeRow
|
|
10
15
|
from agno.db.schemas.memory import UserMemory
|
|
@@ -13,10 +18,12 @@ from agno.db.sqlite.utils import (
|
|
|
13
18
|
apply_sorting,
|
|
14
19
|
bulk_upsert_metrics,
|
|
15
20
|
calculate_date_metrics,
|
|
21
|
+
deserialize_cultural_knowledge_from_db,
|
|
16
22
|
fetch_all_sessions_data,
|
|
17
23
|
get_dates_to_calculate_metrics_for,
|
|
18
24
|
is_table_available,
|
|
19
25
|
is_valid_table,
|
|
26
|
+
serialize_cultural_knowledge_for_db,
|
|
20
27
|
)
|
|
21
28
|
from agno.db.utils import deserialize_session_json_fields, serialize_session_json_fields
|
|
22
29
|
from agno.session import AgentSession, Session, TeamSession, WorkflowSession
|
|
@@ -24,11 +31,11 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
|
24
31
|
from agno.utils.string import generate_id
|
|
25
32
|
|
|
26
33
|
try:
|
|
27
|
-
from sqlalchemy import Column, MetaData,
|
|
34
|
+
from sqlalchemy import Column, MetaData, String, Table, func, select, text
|
|
28
35
|
from sqlalchemy.dialects import sqlite
|
|
29
36
|
from sqlalchemy.engine import Engine, create_engine
|
|
30
37
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
31
|
-
from sqlalchemy.schema import Index, UniqueConstraint
|
|
38
|
+
from sqlalchemy.schema import ForeignKey, Index, UniqueConstraint
|
|
32
39
|
except ImportError:
|
|
33
40
|
raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
|
|
34
41
|
|
|
@@ -36,14 +43,18 @@ except ImportError:
|
|
|
36
43
|
class SqliteDb(BaseDb):
|
|
37
44
|
def __init__(
|
|
38
45
|
self,
|
|
46
|
+
db_file: Optional[str] = None,
|
|
39
47
|
db_engine: Optional[Engine] = None,
|
|
40
48
|
db_url: Optional[str] = None,
|
|
41
|
-
db_file: Optional[str] = None,
|
|
42
49
|
session_table: Optional[str] = None,
|
|
50
|
+
culture_table: Optional[str] = None,
|
|
43
51
|
memory_table: Optional[str] = None,
|
|
44
52
|
metrics_table: Optional[str] = None,
|
|
45
53
|
eval_table: Optional[str] = None,
|
|
46
54
|
knowledge_table: Optional[str] = None,
|
|
55
|
+
traces_table: Optional[str] = None,
|
|
56
|
+
spans_table: Optional[str] = None,
|
|
57
|
+
versions_table: Optional[str] = None,
|
|
47
58
|
id: Optional[str] = None,
|
|
48
59
|
):
|
|
49
60
|
"""
|
|
@@ -56,14 +67,18 @@ class SqliteDb(BaseDb):
|
|
|
56
67
|
4. Create a new database in the current directory
|
|
57
68
|
|
|
58
69
|
Args:
|
|
70
|
+
db_file (Optional[str]): The database file to connect to.
|
|
59
71
|
db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
|
|
60
72
|
db_url (Optional[str]): The database URL to connect to.
|
|
61
|
-
db_file (Optional[str]): The database file to connect to.
|
|
62
73
|
session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
|
|
74
|
+
culture_table (Optional[str]): Name of the table to store cultural notions.
|
|
63
75
|
memory_table (Optional[str]): Name of the table to store user memories.
|
|
64
76
|
metrics_table (Optional[str]): Name of the table to store metrics.
|
|
65
77
|
eval_table (Optional[str]): Name of the table to store evaluation runs data.
|
|
66
78
|
knowledge_table (Optional[str]): Name of the table to store knowledge documents data.
|
|
79
|
+
traces_table (Optional[str]): Name of the table to store run traces.
|
|
80
|
+
spans_table (Optional[str]): Name of the table to store span events.
|
|
81
|
+
versions_table (Optional[str]): Name of the table to store schema versions.
|
|
67
82
|
id (Optional[str]): ID of the database.
|
|
68
83
|
|
|
69
84
|
Raises:
|
|
@@ -76,10 +91,14 @@ class SqliteDb(BaseDb):
|
|
|
76
91
|
super().__init__(
|
|
77
92
|
id=id,
|
|
78
93
|
session_table=session_table,
|
|
94
|
+
culture_table=culture_table,
|
|
79
95
|
memory_table=memory_table,
|
|
80
96
|
metrics_table=metrics_table,
|
|
81
97
|
eval_table=eval_table,
|
|
82
98
|
knowledge_table=knowledge_table,
|
|
99
|
+
traces_table=traces_table,
|
|
100
|
+
spans_table=spans_table,
|
|
101
|
+
versions_table=versions_table,
|
|
83
102
|
)
|
|
84
103
|
|
|
85
104
|
_engine: Optional[Engine] = db_engine
|
|
@@ -107,6 +126,31 @@ class SqliteDb(BaseDb):
|
|
|
107
126
|
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
|
|
108
127
|
|
|
109
128
|
# -- DB methods --
|
|
129
|
+
def table_exists(self, table_name: str) -> bool:
|
|
130
|
+
"""Check if a table with the given name exists in the SQLite database.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
table_name: Name of the table to check
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
bool: True if the table exists in the database, False otherwise
|
|
137
|
+
"""
|
|
138
|
+
with self.Session() as sess:
|
|
139
|
+
return is_table_available(session=sess, table_name=table_name)
|
|
140
|
+
|
|
141
|
+
def _create_all_tables(self):
|
|
142
|
+
"""Create all tables for the database."""
|
|
143
|
+
tables_to_create = [
|
|
144
|
+
(self.session_table_name, "sessions"),
|
|
145
|
+
(self.memory_table_name, "memories"),
|
|
146
|
+
(self.metrics_table_name, "metrics"),
|
|
147
|
+
(self.eval_table_name, "evals"),
|
|
148
|
+
(self.knowledge_table_name, "knowledge"),
|
|
149
|
+
(self.versions_table_name, "versions"),
|
|
150
|
+
]
|
|
151
|
+
|
|
152
|
+
for table_name, table_type in tables_to_create:
|
|
153
|
+
self._get_or_create_table(table_name=table_name, table_type=table_type, create_table_if_not_found=True)
|
|
110
154
|
|
|
111
155
|
def _create_table(self, table_name: str, table_type: str) -> Table:
|
|
112
156
|
"""
|
|
@@ -120,8 +164,7 @@ class SqliteDb(BaseDb):
|
|
|
120
164
|
Table: SQLAlchemy Table object
|
|
121
165
|
"""
|
|
122
166
|
try:
|
|
123
|
-
table_schema = get_table_schema_definition(table_type)
|
|
124
|
-
log_debug(f"Creating table {table_name} with schema: {table_schema}")
|
|
167
|
+
table_schema = get_table_schema_definition(table_type).copy()
|
|
125
168
|
|
|
126
169
|
columns: List[Column] = []
|
|
127
170
|
indexes: List[str] = []
|
|
@@ -143,11 +186,19 @@ class SqliteDb(BaseDb):
|
|
|
143
186
|
column_kwargs["unique"] = True
|
|
144
187
|
unique_constraints.append(col_name)
|
|
145
188
|
|
|
189
|
+
# Handle foreign key constraint
|
|
190
|
+
if "foreign_key" in col_config:
|
|
191
|
+
fk_ref = col_config["foreign_key"]
|
|
192
|
+
# For spans table, dynamically replace the traces table reference
|
|
193
|
+
# with the actual trace table name configured for this db instance
|
|
194
|
+
if table_type == "spans" and "trace_id" in fk_ref:
|
|
195
|
+
fk_ref = f"{self.trace_table_name}.trace_id"
|
|
196
|
+
column_args.append(ForeignKey(fk_ref))
|
|
197
|
+
|
|
146
198
|
columns.append(Column(*column_args, **column_kwargs)) # type: ignore
|
|
147
199
|
|
|
148
200
|
# Create the table object
|
|
149
|
-
|
|
150
|
-
table = Table(table_name, table_metadata, *columns)
|
|
201
|
+
table = Table(table_name, self.metadata, *columns)
|
|
151
202
|
|
|
152
203
|
# Add multi-column unique constraints with table-specific names
|
|
153
204
|
for constraint in schema_unique_constraints:
|
|
@@ -161,12 +212,17 @@ class SqliteDb(BaseDb):
|
|
|
161
212
|
table.append_constraint(Index(idx_name, idx_col))
|
|
162
213
|
|
|
163
214
|
# Create table
|
|
164
|
-
|
|
215
|
+
table_created = False
|
|
216
|
+
if not self.table_exists(table_name):
|
|
217
|
+
table.create(self.db_engine, checkfirst=True)
|
|
218
|
+
log_debug(f"Successfully created table '{table_name}'")
|
|
219
|
+
table_created = True
|
|
220
|
+
else:
|
|
221
|
+
log_debug(f"Table '{table_name}' already exists, skipping creation")
|
|
165
222
|
|
|
166
223
|
# Create indexes
|
|
167
224
|
for idx in table.indexes:
|
|
168
225
|
try:
|
|
169
|
-
log_debug(f"Creating index: {idx.name}")
|
|
170
226
|
# Check if index already exists
|
|
171
227
|
with self.Session() as sess:
|
|
172
228
|
exists_query = text("SELECT 1 FROM sqlite_master WHERE type = 'index' AND name = :index_name")
|
|
@@ -177,13 +233,21 @@ class SqliteDb(BaseDb):
|
|
|
177
233
|
|
|
178
234
|
idx.create(self.db_engine)
|
|
179
235
|
|
|
236
|
+
log_debug(f"Created index: {idx.name} for table {table_name}")
|
|
180
237
|
except Exception as e:
|
|
181
238
|
log_warning(f"Error creating index {idx.name}: {e}")
|
|
182
239
|
|
|
183
|
-
|
|
240
|
+
# Store the schema version for the created table
|
|
241
|
+
if table_name != self.versions_table_name and table_created:
|
|
242
|
+
latest_schema_version = MigrationManager(self).latest_schema_version
|
|
243
|
+
self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
|
|
244
|
+
|
|
184
245
|
return table
|
|
185
246
|
|
|
186
247
|
except Exception as e:
|
|
248
|
+
from traceback import print_exc
|
|
249
|
+
|
|
250
|
+
print_exc()
|
|
187
251
|
log_error(f"Could not create table '{table_name}': {e}")
|
|
188
252
|
raise e
|
|
189
253
|
|
|
@@ -229,11 +293,50 @@ class SqliteDb(BaseDb):
|
|
|
229
293
|
)
|
|
230
294
|
return self.knowledge_table
|
|
231
295
|
|
|
296
|
+
elif table_type == "traces":
|
|
297
|
+
self.traces_table = self._get_or_create_table(
|
|
298
|
+
table_name=self.trace_table_name,
|
|
299
|
+
table_type="traces",
|
|
300
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
301
|
+
)
|
|
302
|
+
return self.traces_table
|
|
303
|
+
|
|
304
|
+
elif table_type == "spans":
|
|
305
|
+
# Ensure traces table exists first (spans has FK to traces)
|
|
306
|
+
if create_table_if_not_found:
|
|
307
|
+
self._get_table(table_type="traces", create_table_if_not_found=True)
|
|
308
|
+
|
|
309
|
+
self.spans_table = self._get_or_create_table(
|
|
310
|
+
table_name=self.span_table_name,
|
|
311
|
+
table_type="spans",
|
|
312
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
313
|
+
)
|
|
314
|
+
return self.spans_table
|
|
315
|
+
|
|
316
|
+
elif table_type == "culture":
|
|
317
|
+
self.culture_table = self._get_or_create_table(
|
|
318
|
+
table_name=self.culture_table_name,
|
|
319
|
+
table_type="culture",
|
|
320
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
321
|
+
)
|
|
322
|
+
return self.culture_table
|
|
323
|
+
|
|
324
|
+
elif table_type == "versions":
|
|
325
|
+
self.versions_table = self._get_or_create_table(
|
|
326
|
+
table_name=self.versions_table_name,
|
|
327
|
+
table_type="versions",
|
|
328
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
329
|
+
)
|
|
330
|
+
return self.versions_table
|
|
331
|
+
|
|
232
332
|
else:
|
|
233
333
|
raise ValueError(f"Unknown table type: '{table_type}'")
|
|
234
334
|
|
|
235
335
|
def _get_or_create_table(
|
|
236
|
-
self,
|
|
336
|
+
self,
|
|
337
|
+
table_name: str,
|
|
338
|
+
table_type: str,
|
|
339
|
+
create_table_if_not_found: Optional[bool] = False,
|
|
237
340
|
) -> Optional[Table]:
|
|
238
341
|
"""
|
|
239
342
|
Check if the table exists and is valid, else create it.
|
|
@@ -259,13 +362,48 @@ class SqliteDb(BaseDb):
|
|
|
259
362
|
|
|
260
363
|
try:
|
|
261
364
|
table = Table(table_name, self.metadata, autoload_with=self.db_engine)
|
|
262
|
-
log_debug(f"Loaded existing table {table_name}")
|
|
263
365
|
return table
|
|
264
366
|
|
|
265
367
|
except Exception as e:
|
|
266
368
|
log_error(f"Error loading existing table {table_name}: {e}")
|
|
267
369
|
raise e
|
|
268
370
|
|
|
371
|
+
def get_latest_schema_version(self, table_name: str):
|
|
372
|
+
"""Get the latest version of the database schema."""
|
|
373
|
+
table = self._get_table(table_type="versions", create_table_if_not_found=True)
|
|
374
|
+
if table is None:
|
|
375
|
+
return "2.0.0"
|
|
376
|
+
with self.Session() as sess:
|
|
377
|
+
stmt = select(table)
|
|
378
|
+
# Latest version for the given table
|
|
379
|
+
stmt = stmt.where(table.c.table_name == table_name)
|
|
380
|
+
stmt = stmt.order_by(table.c.version.desc()).limit(1)
|
|
381
|
+
result = sess.execute(stmt).fetchone()
|
|
382
|
+
if result is None:
|
|
383
|
+
return "2.0.0"
|
|
384
|
+
version_dict = dict(result._mapping)
|
|
385
|
+
return version_dict.get("version") or "2.0.0"
|
|
386
|
+
|
|
387
|
+
def upsert_schema_version(self, table_name: str, version: str) -> None:
|
|
388
|
+
"""Upsert the schema version into the database."""
|
|
389
|
+
table = self._get_table(table_type="versions", create_table_if_not_found=True)
|
|
390
|
+
if table is None:
|
|
391
|
+
return
|
|
392
|
+
current_datetime = datetime.now().isoformat()
|
|
393
|
+
with self.Session() as sess, sess.begin():
|
|
394
|
+
stmt = sqlite.insert(table).values(
|
|
395
|
+
table_name=table_name,
|
|
396
|
+
version=version,
|
|
397
|
+
created_at=current_datetime, # Store as ISO format string
|
|
398
|
+
updated_at=current_datetime,
|
|
399
|
+
)
|
|
400
|
+
# Update version if table_name already exists
|
|
401
|
+
stmt = stmt.on_conflict_do_update(
|
|
402
|
+
index_elements=["table_name"],
|
|
403
|
+
set_=dict(version=version, updated_at=current_datetime),
|
|
404
|
+
)
|
|
405
|
+
sess.execute(stmt)
|
|
406
|
+
|
|
269
407
|
# -- Session methods --
|
|
270
408
|
|
|
271
409
|
def delete_session(self, session_id: str) -> bool:
|
|
@@ -357,8 +495,6 @@ class SqliteDb(BaseDb):
|
|
|
357
495
|
# Filtering
|
|
358
496
|
if user_id is not None:
|
|
359
497
|
stmt = stmt.where(table.c.user_id == user_id)
|
|
360
|
-
if session_type is not None:
|
|
361
|
-
stmt = stmt.where(table.c.session_type == session_type)
|
|
362
498
|
|
|
363
499
|
result = sess.execute(stmt).fetchone()
|
|
364
500
|
if result is None:
|
|
@@ -483,7 +619,11 @@ class SqliteDb(BaseDb):
|
|
|
483
619
|
raise e
|
|
484
620
|
|
|
485
621
|
def rename_session(
|
|
486
|
-
self,
|
|
622
|
+
self,
|
|
623
|
+
session_id: str,
|
|
624
|
+
session_type: SessionType,
|
|
625
|
+
session_name: str,
|
|
626
|
+
deserialize: Optional[bool] = True,
|
|
487
627
|
) -> Optional[Union[Session, Dict[str, Any]]]:
|
|
488
628
|
"""
|
|
489
629
|
Rename a session in the database.
|
|
@@ -664,7 +804,10 @@ class SqliteDb(BaseDb):
|
|
|
664
804
|
raise e
|
|
665
805
|
|
|
666
806
|
def upsert_sessions(
|
|
667
|
-
self,
|
|
807
|
+
self,
|
|
808
|
+
sessions: List[Session],
|
|
809
|
+
deserialize: Optional[bool] = True,
|
|
810
|
+
preserve_updated_at: bool = False,
|
|
668
811
|
) -> List[Union[Session, Dict[str, Any]]]:
|
|
669
812
|
"""
|
|
670
813
|
Bulk upsert multiple sessions for improved performance on large datasets.
|
|
@@ -672,6 +815,7 @@ class SqliteDb(BaseDb):
|
|
|
672
815
|
Args:
|
|
673
816
|
sessions (List[Session]): List of sessions to upsert.
|
|
674
817
|
deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
|
|
818
|
+
preserve_updated_at (bool): If True, preserve the updated_at from the session object.
|
|
675
819
|
|
|
676
820
|
Returns:
|
|
677
821
|
List[Union[Session, Dict[str, Any]]]: List of upserted sessions.
|
|
@@ -715,6 +859,8 @@ class SqliteDb(BaseDb):
|
|
|
715
859
|
agent_data = []
|
|
716
860
|
for session in agent_sessions:
|
|
717
861
|
serialized_session = serialize_session_json_fields(session.to_dict())
|
|
862
|
+
# Use preserved updated_at if flag is set and value exists, otherwise use current time
|
|
863
|
+
updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
|
|
718
864
|
agent_data.append(
|
|
719
865
|
{
|
|
720
866
|
"session_id": serialized_session.get("session_id"),
|
|
@@ -727,7 +873,7 @@ class SqliteDb(BaseDb):
|
|
|
727
873
|
"runs": serialized_session.get("runs"),
|
|
728
874
|
"summary": serialized_session.get("summary"),
|
|
729
875
|
"created_at": serialized_session.get("created_at"),
|
|
730
|
-
"updated_at":
|
|
876
|
+
"updated_at": updated_at,
|
|
731
877
|
}
|
|
732
878
|
)
|
|
733
879
|
|
|
@@ -743,7 +889,7 @@ class SqliteDb(BaseDb):
|
|
|
743
889
|
metadata=stmt.excluded.metadata,
|
|
744
890
|
runs=stmt.excluded.runs,
|
|
745
891
|
summary=stmt.excluded.summary,
|
|
746
|
-
updated_at=
|
|
892
|
+
updated_at=stmt.excluded.updated_at,
|
|
747
893
|
),
|
|
748
894
|
)
|
|
749
895
|
sess.execute(stmt, agent_data)
|
|
@@ -768,6 +914,8 @@ class SqliteDb(BaseDb):
|
|
|
768
914
|
team_data = []
|
|
769
915
|
for session in team_sessions:
|
|
770
916
|
serialized_session = serialize_session_json_fields(session.to_dict())
|
|
917
|
+
# Use preserved updated_at if flag is set and value exists, otherwise use current time
|
|
918
|
+
updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
|
|
771
919
|
team_data.append(
|
|
772
920
|
{
|
|
773
921
|
"session_id": serialized_session.get("session_id"),
|
|
@@ -777,7 +925,7 @@ class SqliteDb(BaseDb):
|
|
|
777
925
|
"runs": serialized_session.get("runs"),
|
|
778
926
|
"summary": serialized_session.get("summary"),
|
|
779
927
|
"created_at": serialized_session.get("created_at"),
|
|
780
|
-
"updated_at":
|
|
928
|
+
"updated_at": updated_at,
|
|
781
929
|
"team_data": serialized_session.get("team_data"),
|
|
782
930
|
"session_data": serialized_session.get("session_data"),
|
|
783
931
|
"metadata": serialized_session.get("metadata"),
|
|
@@ -796,7 +944,7 @@ class SqliteDb(BaseDb):
|
|
|
796
944
|
metadata=stmt.excluded.metadata,
|
|
797
945
|
runs=stmt.excluded.runs,
|
|
798
946
|
summary=stmt.excluded.summary,
|
|
799
|
-
updated_at=
|
|
947
|
+
updated_at=stmt.excluded.updated_at,
|
|
800
948
|
),
|
|
801
949
|
)
|
|
802
950
|
sess.execute(stmt, team_data)
|
|
@@ -821,6 +969,8 @@ class SqliteDb(BaseDb):
|
|
|
821
969
|
workflow_data = []
|
|
822
970
|
for session in workflow_sessions:
|
|
823
971
|
serialized_session = serialize_session_json_fields(session.to_dict())
|
|
972
|
+
# Use preserved updated_at if flag is set and value exists, otherwise use current time
|
|
973
|
+
updated_at = serialized_session.get("updated_at") if preserve_updated_at else int(time.time())
|
|
824
974
|
workflow_data.append(
|
|
825
975
|
{
|
|
826
976
|
"session_id": serialized_session.get("session_id"),
|
|
@@ -830,7 +980,7 @@ class SqliteDb(BaseDb):
|
|
|
830
980
|
"runs": serialized_session.get("runs"),
|
|
831
981
|
"summary": serialized_session.get("summary"),
|
|
832
982
|
"created_at": serialized_session.get("created_at"),
|
|
833
|
-
"updated_at":
|
|
983
|
+
"updated_at": updated_at,
|
|
834
984
|
"workflow_data": serialized_session.get("workflow_data"),
|
|
835
985
|
"session_data": serialized_session.get("session_data"),
|
|
836
986
|
"metadata": serialized_session.get("metadata"),
|
|
@@ -849,7 +999,7 @@ class SqliteDb(BaseDb):
|
|
|
849
999
|
metadata=stmt.excluded.metadata,
|
|
850
1000
|
runs=stmt.excluded.runs,
|
|
851
1001
|
summary=stmt.excluded.summary,
|
|
852
|
-
updated_at=
|
|
1002
|
+
updated_at=stmt.excluded.updated_at,
|
|
853
1003
|
),
|
|
854
1004
|
)
|
|
855
1005
|
sess.execute(stmt, workflow_data)
|
|
@@ -958,9 +1108,8 @@ class SqliteDb(BaseDb):
|
|
|
958
1108
|
|
|
959
1109
|
with self.Session() as sess, sess.begin():
|
|
960
1110
|
# Select topics from all results
|
|
961
|
-
stmt = select(
|
|
1111
|
+
stmt = select(table.c.topics)
|
|
962
1112
|
result = sess.execute(stmt).fetchall()
|
|
963
|
-
|
|
964
1113
|
return list(set([record[0] for record in result]))
|
|
965
1114
|
|
|
966
1115
|
except Exception as e:
|
|
@@ -968,7 +1117,10 @@ class SqliteDb(BaseDb):
|
|
|
968
1117
|
raise e
|
|
969
1118
|
|
|
970
1119
|
def get_user_memory(
|
|
971
|
-
self,
|
|
1120
|
+
self,
|
|
1121
|
+
memory_id: str,
|
|
1122
|
+
deserialize: Optional[bool] = True,
|
|
1123
|
+
user_id: Optional[str] = None,
|
|
972
1124
|
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
|
|
973
1125
|
"""Get a memory from the database.
|
|
974
1126
|
|
|
@@ -1060,8 +1212,8 @@ class SqliteDb(BaseDb):
|
|
|
1060
1212
|
if team_id is not None:
|
|
1061
1213
|
stmt = stmt.where(table.c.team_id == team_id)
|
|
1062
1214
|
if topics is not None:
|
|
1063
|
-
|
|
1064
|
-
|
|
1215
|
+
for topic in topics:
|
|
1216
|
+
stmt = stmt.where(func.cast(table.c.topics, String).like(f'%"{topic}"%'))
|
|
1065
1217
|
if search_content is not None:
|
|
1066
1218
|
stmt = stmt.where(table.c.memory.ilike(f"%{search_content}%"))
|
|
1067
1219
|
|
|
@@ -1096,12 +1248,14 @@ class SqliteDb(BaseDb):
|
|
|
1096
1248
|
self,
|
|
1097
1249
|
limit: Optional[int] = None,
|
|
1098
1250
|
page: Optional[int] = None,
|
|
1251
|
+
user_id: Optional[str] = None,
|
|
1099
1252
|
) -> Tuple[List[Dict[str, Any]], int]:
|
|
1100
1253
|
"""Get user memories stats.
|
|
1101
1254
|
|
|
1102
1255
|
Args:
|
|
1103
1256
|
limit (Optional[int]): The maximum number of user stats to return.
|
|
1104
1257
|
page (Optional[int]): The page number.
|
|
1258
|
+
user_id (Optional[str]): User ID for filtering.
|
|
1105
1259
|
|
|
1106
1260
|
Returns:
|
|
1107
1261
|
Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
|
|
@@ -1124,19 +1278,20 @@ class SqliteDb(BaseDb):
|
|
|
1124
1278
|
return [], 0
|
|
1125
1279
|
|
|
1126
1280
|
with self.Session() as sess, sess.begin():
|
|
1127
|
-
stmt = (
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
func.max(table.c.updated_at).label("last_memory_updated_at"),
|
|
1132
|
-
)
|
|
1133
|
-
.where(table.c.user_id.is_not(None))
|
|
1134
|
-
.group_by(table.c.user_id)
|
|
1135
|
-
.order_by(func.max(table.c.updated_at).desc())
|
|
1281
|
+
stmt = select(
|
|
1282
|
+
table.c.user_id,
|
|
1283
|
+
func.count(table.c.memory_id).label("total_memories"),
|
|
1284
|
+
func.max(table.c.updated_at).label("last_memory_updated_at"),
|
|
1136
1285
|
)
|
|
1286
|
+
if user_id is not None:
|
|
1287
|
+
stmt = stmt.where(table.c.user_id == user_id)
|
|
1288
|
+
else:
|
|
1289
|
+
stmt = stmt.where(table.c.user_id.is_not(None))
|
|
1290
|
+
stmt = stmt.group_by(table.c.user_id)
|
|
1291
|
+
stmt = stmt.order_by(func.max(table.c.updated_at).desc())
|
|
1137
1292
|
|
|
1138
1293
|
count_stmt = select(func.count()).select_from(stmt.alias())
|
|
1139
|
-
total_count = sess.execute(count_stmt).scalar()
|
|
1294
|
+
total_count = sess.execute(count_stmt).scalar() or 0
|
|
1140
1295
|
|
|
1141
1296
|
# Pagination
|
|
1142
1297
|
if limit is not None:
|
|
@@ -1186,6 +1341,8 @@ class SqliteDb(BaseDb):
|
|
|
1186
1341
|
if memory.memory_id is None:
|
|
1187
1342
|
memory.memory_id = str(uuid4())
|
|
1188
1343
|
|
|
1344
|
+
current_time = int(time.time())
|
|
1345
|
+
|
|
1189
1346
|
with self.Session() as sess, sess.begin():
|
|
1190
1347
|
stmt = sqlite.insert(table).values(
|
|
1191
1348
|
user_id=memory.user_id,
|
|
@@ -1195,7 +1352,9 @@ class SqliteDb(BaseDb):
|
|
|
1195
1352
|
memory=memory.memory,
|
|
1196
1353
|
topics=memory.topics,
|
|
1197
1354
|
input=memory.input,
|
|
1198
|
-
|
|
1355
|
+
feedback=memory.feedback,
|
|
1356
|
+
created_at=memory.created_at,
|
|
1357
|
+
updated_at=memory.created_at,
|
|
1199
1358
|
)
|
|
1200
1359
|
stmt = stmt.on_conflict_do_update( # type: ignore
|
|
1201
1360
|
index_elements=["memory_id"],
|
|
@@ -1203,7 +1362,12 @@ class SqliteDb(BaseDb):
|
|
|
1203
1362
|
memory=memory.memory,
|
|
1204
1363
|
topics=memory.topics,
|
|
1205
1364
|
input=memory.input,
|
|
1206
|
-
|
|
1365
|
+
agent_id=memory.agent_id,
|
|
1366
|
+
team_id=memory.team_id,
|
|
1367
|
+
feedback=memory.feedback,
|
|
1368
|
+
updated_at=current_time,
|
|
1369
|
+
# Preserve created_at on update - don't overwrite existing value
|
|
1370
|
+
created_at=table.c.created_at,
|
|
1207
1371
|
),
|
|
1208
1372
|
).returning(table)
|
|
1209
1373
|
|
|
@@ -1224,7 +1388,10 @@ class SqliteDb(BaseDb):
|
|
|
1224
1388
|
raise e
|
|
1225
1389
|
|
|
1226
1390
|
def upsert_memories(
|
|
1227
|
-
self,
|
|
1391
|
+
self,
|
|
1392
|
+
memories: List[UserMemory],
|
|
1393
|
+
deserialize: Optional[bool] = True,
|
|
1394
|
+
preserve_updated_at: bool = False,
|
|
1228
1395
|
) -> List[Union[UserMemory, Dict[str, Any]]]:
|
|
1229
1396
|
"""
|
|
1230
1397
|
Bulk upsert multiple user memories for improved performance on large datasets.
|
|
@@ -1255,10 +1422,15 @@ class SqliteDb(BaseDb):
|
|
|
1255
1422
|
]
|
|
1256
1423
|
# Prepare bulk data
|
|
1257
1424
|
bulk_data = []
|
|
1425
|
+
current_time = int(time.time())
|
|
1426
|
+
|
|
1258
1427
|
for memory in memories:
|
|
1259
1428
|
if memory.memory_id is None:
|
|
1260
1429
|
memory.memory_id = str(uuid4())
|
|
1261
1430
|
|
|
1431
|
+
# Use preserved updated_at if flag is set and value exists, otherwise use current time
|
|
1432
|
+
updated_at = memory.updated_at if preserve_updated_at else current_time
|
|
1433
|
+
|
|
1262
1434
|
bulk_data.append(
|
|
1263
1435
|
{
|
|
1264
1436
|
"user_id": memory.user_id,
|
|
@@ -1267,7 +1439,10 @@ class SqliteDb(BaseDb):
|
|
|
1267
1439
|
"memory_id": memory.memory_id,
|
|
1268
1440
|
"memory": memory.memory,
|
|
1269
1441
|
"topics": memory.topics,
|
|
1270
|
-
"
|
|
1442
|
+
"input": memory.input,
|
|
1443
|
+
"feedback": memory.feedback,
|
|
1444
|
+
"created_at": memory.created_at,
|
|
1445
|
+
"updated_at": updated_at,
|
|
1271
1446
|
}
|
|
1272
1447
|
)
|
|
1273
1448
|
|
|
@@ -1284,7 +1459,10 @@ class SqliteDb(BaseDb):
|
|
|
1284
1459
|
input=stmt.excluded.input,
|
|
1285
1460
|
agent_id=stmt.excluded.agent_id,
|
|
1286
1461
|
team_id=stmt.excluded.team_id,
|
|
1287
|
-
|
|
1462
|
+
feedback=stmt.excluded.feedback,
|
|
1463
|
+
updated_at=stmt.excluded.updated_at,
|
|
1464
|
+
# Preserve created_at on update
|
|
1465
|
+
created_at=table.c.created_at,
|
|
1288
1466
|
),
|
|
1289
1467
|
)
|
|
1290
1468
|
sess.execute(stmt, bulk_data)
|
|
@@ -1450,7 +1628,9 @@ class SqliteDb(BaseDb):
|
|
|
1450
1628
|
start_timestamp=start_timestamp, end_timestamp=end_timestamp
|
|
1451
1629
|
)
|
|
1452
1630
|
all_sessions_data = fetch_all_sessions_data(
|
|
1453
|
-
sessions=sessions,
|
|
1631
|
+
sessions=sessions,
|
|
1632
|
+
dates_to_process=dates_to_process,
|
|
1633
|
+
start_timestamp=start_timestamp,
|
|
1454
1634
|
)
|
|
1455
1635
|
if not all_sessions_data:
|
|
1456
1636
|
log_info("No new session data found. Won't calculate metrics.")
|
|
@@ -1697,7 +1877,11 @@ class SqliteDb(BaseDb):
|
|
|
1697
1877
|
with self.Session() as sess, sess.begin():
|
|
1698
1878
|
current_time = int(time.time())
|
|
1699
1879
|
stmt = sqlite.insert(table).values(
|
|
1700
|
-
{
|
|
1880
|
+
{
|
|
1881
|
+
"created_at": current_time,
|
|
1882
|
+
"updated_at": current_time,
|
|
1883
|
+
**eval_run.model_dump(),
|
|
1884
|
+
}
|
|
1701
1885
|
)
|
|
1702
1886
|
sess.execute(stmt)
|
|
1703
1887
|
sess.commit()
|
|
@@ -1930,6 +2114,510 @@ class SqliteDb(BaseDb):
|
|
|
1930
2114
|
log_error(f"Error renaming eval run {eval_run_id}: {e}")
|
|
1931
2115
|
raise e
|
|
1932
2116
|
|
|
2117
|
+
# -- Trace methods --
|
|
2118
|
+
|
|
2119
|
+
def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
|
|
2120
|
+
"""Build base query for traces with aggregated span counts.
|
|
2121
|
+
|
|
2122
|
+
Args:
|
|
2123
|
+
table: The traces table.
|
|
2124
|
+
spans_table: The spans table (optional).
|
|
2125
|
+
|
|
2126
|
+
Returns:
|
|
2127
|
+
SQLAlchemy select statement with total_spans and error_count calculated dynamically.
|
|
2128
|
+
"""
|
|
2129
|
+
from sqlalchemy import case, func, literal
|
|
2130
|
+
|
|
2131
|
+
if spans_table is not None:
|
|
2132
|
+
# JOIN with spans table to calculate total_spans and error_count
|
|
2133
|
+
return (
|
|
2134
|
+
select(
|
|
2135
|
+
table,
|
|
2136
|
+
func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
|
|
2137
|
+
func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
|
|
2138
|
+
"error_count"
|
|
2139
|
+
),
|
|
2140
|
+
)
|
|
2141
|
+
.select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
|
|
2142
|
+
.group_by(table.c.trace_id)
|
|
2143
|
+
)
|
|
2144
|
+
else:
|
|
2145
|
+
# Fallback if spans table doesn't exist
|
|
2146
|
+
return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
|
|
2147
|
+
|
|
2148
|
+
def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
|
|
2149
|
+
"""Build a SQL CASE expression that returns the component level for a trace.
|
|
2150
|
+
|
|
2151
|
+
Component levels (higher = more important):
|
|
2152
|
+
- 3: Workflow root (.run or .arun with workflow_id)
|
|
2153
|
+
- 2: Team root (.run or .arun with team_id)
|
|
2154
|
+
- 1: Agent root (.run or .arun with agent_id)
|
|
2155
|
+
- 0: Child span (not a root)
|
|
2156
|
+
|
|
2157
|
+
Args:
|
|
2158
|
+
workflow_id_col: SQL column/expression for workflow_id
|
|
2159
|
+
team_id_col: SQL column/expression for team_id
|
|
2160
|
+
agent_id_col: SQL column/expression for agent_id
|
|
2161
|
+
name_col: SQL column/expression for name
|
|
2162
|
+
|
|
2163
|
+
Returns:
|
|
2164
|
+
SQLAlchemy CASE expression returning the component level as an integer.
|
|
2165
|
+
"""
|
|
2166
|
+
from sqlalchemy import and_, case, or_
|
|
2167
|
+
|
|
2168
|
+
is_root_name = or_(name_col.contains(".run"), name_col.contains(".arun"))
|
|
2169
|
+
|
|
2170
|
+
return case(
|
|
2171
|
+
# Workflow root (level 3)
|
|
2172
|
+
(and_(workflow_id_col.isnot(None), is_root_name), 3),
|
|
2173
|
+
# Team root (level 2)
|
|
2174
|
+
(and_(team_id_col.isnot(None), is_root_name), 2),
|
|
2175
|
+
# Agent root (level 1)
|
|
2176
|
+
(and_(agent_id_col.isnot(None), is_root_name), 1),
|
|
2177
|
+
# Child span or unknown (level 0)
|
|
2178
|
+
else_=0,
|
|
2179
|
+
)
|
|
2180
|
+
|
|
2181
|
+
def upsert_trace(self, trace: "Trace") -> None:
|
|
2182
|
+
"""Create or update a single trace record in the database.
|
|
2183
|
+
|
|
2184
|
+
Uses INSERT ... ON CONFLICT DO UPDATE (upsert) to handle concurrent inserts
|
|
2185
|
+
atomically and avoid race conditions.
|
|
2186
|
+
|
|
2187
|
+
Args:
|
|
2188
|
+
trace: The Trace object to store (one per trace_id).
|
|
2189
|
+
"""
|
|
2190
|
+
from sqlalchemy import case
|
|
2191
|
+
|
|
2192
|
+
try:
|
|
2193
|
+
table = self._get_table(table_type="traces", create_table_if_not_found=True)
|
|
2194
|
+
if table is None:
|
|
2195
|
+
return
|
|
2196
|
+
|
|
2197
|
+
trace_dict = trace.to_dict()
|
|
2198
|
+
trace_dict.pop("total_spans", None)
|
|
2199
|
+
trace_dict.pop("error_count", None)
|
|
2200
|
+
|
|
2201
|
+
with self.Session() as sess, sess.begin():
|
|
2202
|
+
# Use upsert to handle concurrent inserts atomically
|
|
2203
|
+
# On conflict, update fields while preserving existing non-null context values
|
|
2204
|
+
# and keeping the earliest start_time
|
|
2205
|
+
insert_stmt = sqlite.insert(table).values(trace_dict)
|
|
2206
|
+
|
|
2207
|
+
# Build component level expressions for comparing trace priority
|
|
2208
|
+
new_level = self._get_trace_component_level_expr(
|
|
2209
|
+
insert_stmt.excluded.workflow_id,
|
|
2210
|
+
insert_stmt.excluded.team_id,
|
|
2211
|
+
insert_stmt.excluded.agent_id,
|
|
2212
|
+
insert_stmt.excluded.name,
|
|
2213
|
+
)
|
|
2214
|
+
existing_level = self._get_trace_component_level_expr(
|
|
2215
|
+
table.c.workflow_id,
|
|
2216
|
+
table.c.team_id,
|
|
2217
|
+
table.c.agent_id,
|
|
2218
|
+
table.c.name,
|
|
2219
|
+
)
|
|
2220
|
+
|
|
2221
|
+
# Build the ON CONFLICT DO UPDATE clause
|
|
2222
|
+
# Use MIN for start_time, MAX for end_time to capture full trace duration
|
|
2223
|
+
# SQLite stores timestamps as ISO strings, so string comparison works for ISO format
|
|
2224
|
+
# Duration is calculated as: (MAX(end_time) - MIN(start_time)) in milliseconds
|
|
2225
|
+
# SQLite doesn't have epoch extraction, so we calculate duration using julianday
|
|
2226
|
+
upsert_stmt = insert_stmt.on_conflict_do_update(
|
|
2227
|
+
index_elements=["trace_id"],
|
|
2228
|
+
set_={
|
|
2229
|
+
"end_time": func.max(table.c.end_time, insert_stmt.excluded.end_time),
|
|
2230
|
+
"start_time": func.min(table.c.start_time, insert_stmt.excluded.start_time),
|
|
2231
|
+
# Calculate duration in milliseconds using julianday (SQLite-specific)
|
|
2232
|
+
# julianday returns days, so multiply by 86400000 to get milliseconds
|
|
2233
|
+
"duration_ms": (
|
|
2234
|
+
func.julianday(func.max(table.c.end_time, insert_stmt.excluded.end_time))
|
|
2235
|
+
- func.julianday(func.min(table.c.start_time, insert_stmt.excluded.start_time))
|
|
2236
|
+
)
|
|
2237
|
+
* 86400000,
|
|
2238
|
+
"status": insert_stmt.excluded.status,
|
|
2239
|
+
# Update name only if new trace is from a higher-level component
|
|
2240
|
+
# Priority: workflow (3) > team (2) > agent (1) > child spans (0)
|
|
2241
|
+
"name": case(
|
|
2242
|
+
(new_level > existing_level, insert_stmt.excluded.name),
|
|
2243
|
+
else_=table.c.name,
|
|
2244
|
+
),
|
|
2245
|
+
# Preserve existing non-null context values using COALESCE
|
|
2246
|
+
"run_id": func.coalesce(insert_stmt.excluded.run_id, table.c.run_id),
|
|
2247
|
+
"session_id": func.coalesce(insert_stmt.excluded.session_id, table.c.session_id),
|
|
2248
|
+
"user_id": func.coalesce(insert_stmt.excluded.user_id, table.c.user_id),
|
|
2249
|
+
"agent_id": func.coalesce(insert_stmt.excluded.agent_id, table.c.agent_id),
|
|
2250
|
+
"team_id": func.coalesce(insert_stmt.excluded.team_id, table.c.team_id),
|
|
2251
|
+
"workflow_id": func.coalesce(insert_stmt.excluded.workflow_id, table.c.workflow_id),
|
|
2252
|
+
},
|
|
2253
|
+
)
|
|
2254
|
+
sess.execute(upsert_stmt)
|
|
2255
|
+
|
|
2256
|
+
except Exception as e:
|
|
2257
|
+
log_error(f"Error creating trace: {e}")
|
|
2258
|
+
# Don't raise - tracing should not break the main application flow
|
|
2259
|
+
|
|
2260
|
+
def get_trace(
|
|
2261
|
+
self,
|
|
2262
|
+
trace_id: Optional[str] = None,
|
|
2263
|
+
run_id: Optional[str] = None,
|
|
2264
|
+
):
|
|
2265
|
+
"""Get a single trace by trace_id or other filters.
|
|
2266
|
+
|
|
2267
|
+
Args:
|
|
2268
|
+
trace_id: The unique trace identifier.
|
|
2269
|
+
run_id: Filter by run ID (returns first match).
|
|
2270
|
+
|
|
2271
|
+
Returns:
|
|
2272
|
+
Optional[Trace]: The trace if found, None otherwise.
|
|
2273
|
+
|
|
2274
|
+
Note:
|
|
2275
|
+
If multiple filters are provided, trace_id takes precedence.
|
|
2276
|
+
For other filters, the most recent trace is returned.
|
|
2277
|
+
"""
|
|
2278
|
+
try:
|
|
2279
|
+
from agno.tracing.schemas import Trace
|
|
2280
|
+
|
|
2281
|
+
table = self._get_table(table_type="traces")
|
|
2282
|
+
if table is None:
|
|
2283
|
+
return None
|
|
2284
|
+
|
|
2285
|
+
# Get spans table for JOIN
|
|
2286
|
+
spans_table = self._get_table(table_type="spans")
|
|
2287
|
+
|
|
2288
|
+
with self.Session() as sess:
|
|
2289
|
+
# Build query with aggregated span counts
|
|
2290
|
+
stmt = self._get_traces_base_query(table, spans_table)
|
|
2291
|
+
|
|
2292
|
+
if trace_id:
|
|
2293
|
+
stmt = stmt.where(table.c.trace_id == trace_id)
|
|
2294
|
+
elif run_id:
|
|
2295
|
+
stmt = stmt.where(table.c.run_id == run_id)
|
|
2296
|
+
else:
|
|
2297
|
+
log_debug("get_trace called without any filter parameters")
|
|
2298
|
+
return None
|
|
2299
|
+
|
|
2300
|
+
# Order by most recent and get first result
|
|
2301
|
+
stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
|
|
2302
|
+
result = sess.execute(stmt).fetchone()
|
|
2303
|
+
|
|
2304
|
+
if result:
|
|
2305
|
+
return Trace.from_dict(dict(result._mapping))
|
|
2306
|
+
return None
|
|
2307
|
+
|
|
2308
|
+
except Exception as e:
|
|
2309
|
+
log_error(f"Error getting trace: {e}")
|
|
2310
|
+
return None
|
|
2311
|
+
|
|
2312
|
+
def get_traces(
|
|
2313
|
+
self,
|
|
2314
|
+
run_id: Optional[str] = None,
|
|
2315
|
+
session_id: Optional[str] = None,
|
|
2316
|
+
user_id: Optional[str] = None,
|
|
2317
|
+
agent_id: Optional[str] = None,
|
|
2318
|
+
team_id: Optional[str] = None,
|
|
2319
|
+
workflow_id: Optional[str] = None,
|
|
2320
|
+
status: Optional[str] = None,
|
|
2321
|
+
start_time: Optional[datetime] = None,
|
|
2322
|
+
end_time: Optional[datetime] = None,
|
|
2323
|
+
limit: Optional[int] = 20,
|
|
2324
|
+
page: Optional[int] = 1,
|
|
2325
|
+
) -> tuple[List, int]:
|
|
2326
|
+
"""Get traces matching the provided filters with pagination.
|
|
2327
|
+
|
|
2328
|
+
Args:
|
|
2329
|
+
run_id: Filter by run ID.
|
|
2330
|
+
session_id: Filter by session ID.
|
|
2331
|
+
user_id: Filter by user ID.
|
|
2332
|
+
agent_id: Filter by agent ID.
|
|
2333
|
+
team_id: Filter by team ID.
|
|
2334
|
+
workflow_id: Filter by workflow ID.
|
|
2335
|
+
status: Filter by status (OK, ERROR, UNSET).
|
|
2336
|
+
start_time: Filter traces starting after this datetime.
|
|
2337
|
+
end_time: Filter traces ending before this datetime.
|
|
2338
|
+
limit: Maximum number of traces to return per page.
|
|
2339
|
+
page: Page number (1-indexed).
|
|
2340
|
+
|
|
2341
|
+
Returns:
|
|
2342
|
+
tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
|
|
2343
|
+
"""
|
|
2344
|
+
try:
|
|
2345
|
+
from sqlalchemy import func
|
|
2346
|
+
|
|
2347
|
+
from agno.tracing.schemas import Trace
|
|
2348
|
+
|
|
2349
|
+
log_debug(
|
|
2350
|
+
f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
|
|
2351
|
+
)
|
|
2352
|
+
|
|
2353
|
+
table = self._get_table(table_type="traces")
|
|
2354
|
+
if table is None:
|
|
2355
|
+
log_debug(" Traces table not found")
|
|
2356
|
+
return [], 0
|
|
2357
|
+
|
|
2358
|
+
# Get spans table for JOIN
|
|
2359
|
+
spans_table = self._get_table(table_type="spans")
|
|
2360
|
+
|
|
2361
|
+
with self.Session() as sess:
|
|
2362
|
+
# Build base query with aggregated span counts
|
|
2363
|
+
base_stmt = self._get_traces_base_query(table, spans_table)
|
|
2364
|
+
|
|
2365
|
+
# Apply filters
|
|
2366
|
+
if run_id:
|
|
2367
|
+
base_stmt = base_stmt.where(table.c.run_id == run_id)
|
|
2368
|
+
if session_id:
|
|
2369
|
+
base_stmt = base_stmt.where(table.c.session_id == session_id)
|
|
2370
|
+
if user_id:
|
|
2371
|
+
base_stmt = base_stmt.where(table.c.user_id == user_id)
|
|
2372
|
+
if agent_id:
|
|
2373
|
+
base_stmt = base_stmt.where(table.c.agent_id == agent_id)
|
|
2374
|
+
if team_id:
|
|
2375
|
+
base_stmt = base_stmt.where(table.c.team_id == team_id)
|
|
2376
|
+
if workflow_id:
|
|
2377
|
+
base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
|
|
2378
|
+
if status:
|
|
2379
|
+
base_stmt = base_stmt.where(table.c.status == status)
|
|
2380
|
+
if start_time:
|
|
2381
|
+
# Convert datetime to ISO string for comparison
|
|
2382
|
+
base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
|
|
2383
|
+
if end_time:
|
|
2384
|
+
# Convert datetime to ISO string for comparison
|
|
2385
|
+
base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())
|
|
2386
|
+
|
|
2387
|
+
# Get total count
|
|
2388
|
+
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2389
|
+
total_count = sess.execute(count_stmt).scalar() or 0
|
|
2390
|
+
|
|
2391
|
+
# Apply pagination
|
|
2392
|
+
offset = (page - 1) * limit if page and limit else 0
|
|
2393
|
+
paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
|
|
2394
|
+
|
|
2395
|
+
results = sess.execute(paginated_stmt).fetchall()
|
|
2396
|
+
|
|
2397
|
+
traces = [Trace.from_dict(dict(row._mapping)) for row in results]
|
|
2398
|
+
return traces, total_count
|
|
2399
|
+
|
|
2400
|
+
except Exception as e:
|
|
2401
|
+
log_error(f"Error getting traces: {e}")
|
|
2402
|
+
return [], 0
|
|
2403
|
+
|
|
2404
|
+
def get_trace_stats(
|
|
2405
|
+
self,
|
|
2406
|
+
user_id: Optional[str] = None,
|
|
2407
|
+
agent_id: Optional[str] = None,
|
|
2408
|
+
team_id: Optional[str] = None,
|
|
2409
|
+
workflow_id: Optional[str] = None,
|
|
2410
|
+
start_time: Optional[datetime] = None,
|
|
2411
|
+
end_time: Optional[datetime] = None,
|
|
2412
|
+
limit: Optional[int] = 20,
|
|
2413
|
+
page: Optional[int] = 1,
|
|
2414
|
+
) -> tuple[List[Dict[str, Any]], int]:
|
|
2415
|
+
"""Get trace statistics grouped by session.
|
|
2416
|
+
|
|
2417
|
+
Args:
|
|
2418
|
+
user_id: Filter by user ID.
|
|
2419
|
+
agent_id: Filter by agent ID.
|
|
2420
|
+
team_id: Filter by team ID.
|
|
2421
|
+
workflow_id: Filter by workflow ID.
|
|
2422
|
+
start_time: Filter sessions with traces created after this datetime.
|
|
2423
|
+
end_time: Filter sessions with traces created before this datetime.
|
|
2424
|
+
+            limit: Maximum number of sessions to return per page.
+            page: Page number (1-indexed).
+
+        Returns:
+            tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
+        """
+        try:
+            from sqlalchemy import func
+
+            table = self._get_table(table_type="traces")
+            if table is None:
+                log_debug("Traces table not found")
+                return [], 0
+
+            with self.Session() as sess:
+                # Build base query grouped by session_id
+                base_stmt = (
+                    select(
+                        table.c.session_id,
+                        table.c.user_id,
+                        table.c.agent_id,
+                        table.c.team_id,
+                        table.c.workflow_id,
+                        func.count(table.c.trace_id).label("total_traces"),
+                        func.min(table.c.created_at).label("first_trace_at"),
+                        func.max(table.c.created_at).label("last_trace_at"),
+                    )
+                    .where(table.c.session_id.isnot(None))  # Only sessions with session_id
+                    .group_by(
+                        table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
+                    )
+                )
+
+                # Apply filters
+                if user_id:
+                    base_stmt = base_stmt.where(table.c.user_id == user_id)
+                if workflow_id:
+                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
+                if team_id:
+                    base_stmt = base_stmt.where(table.c.team_id == team_id)
+                if agent_id:
+                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
+                if start_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
+                if end_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())
+
+                # Get total count of sessions
+                count_stmt = select(func.count()).select_from(base_stmt.alias())
+                total_count = sess.execute(count_stmt).scalar() or 0
+
+                # Apply pagination and ordering
+                offset = (page - 1) * limit if page and limit else 0
+                paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
+
+                results = sess.execute(paginated_stmt).fetchall()
+
+                # Convert to list of dicts with datetime objects
+                from datetime import datetime
+
+                stats_list = []
+                for row in results:
+                    # Convert ISO strings to datetime objects
+                    first_trace_at_str = row.first_trace_at
+                    last_trace_at_str = row.last_trace_at
+
+                    # Parse ISO format strings to datetime objects
+                    first_trace_at = datetime.fromisoformat(first_trace_at_str.replace("Z", "+00:00"))
+                    last_trace_at = datetime.fromisoformat(last_trace_at_str.replace("Z", "+00:00"))
+
+                    stats_list.append(
+                        {
+                            "session_id": row.session_id,
+                            "user_id": row.user_id,
+                            "agent_id": row.agent_id,
+                            "team_id": row.team_id,
+                            "workflow_id": row.workflow_id,
+                            "total_traces": row.total_traces,
+                            "first_trace_at": first_trace_at,
+                            "last_trace_at": last_trace_at,
+                        }
+                    )
+
+                return stats_list, total_count
+
+        except Exception as e:
+            log_error(f"Error getting trace stats: {e}")
+            return [], 0
+
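The shape of this query is a reusable SQLAlchemy Core pattern rather than anything agno-specific: build one grouped SELECT, count its groups by wrapping the statement in an alias, then order, limit, and offset the same statement for the requested page. A minimal self-contained sketch of the pattern follows; the toy table and column names are illustrative, not agno's traces schema.

    from sqlalchemy import Column, MetaData, String, Table, create_engine, func, select

    engine = create_engine("sqlite:///:memory:")
    metadata = MetaData()
    traces = Table(
        "traces",
        metadata,
        Column("trace_id", String, primary_key=True),
        Column("session_id", String),
        Column("created_at", String),
    )
    metadata.create_all(engine)

    # One grouped statement serves both the count and the page fetch.
    base_stmt = (
        select(traces.c.session_id, func.count(traces.c.trace_id).label("total_traces"))
        .where(traces.c.session_id.isnot(None))
        .group_by(traces.c.session_id)
    )

    page, limit = 1, 20
    with engine.connect() as conn:
        # Wrapping the grouped select in an alias counts the groups themselves;
        # a bare func.count() on the grouped select would count rows per group.
        total = conn.execute(select(func.count()).select_from(base_stmt.alias())).scalar() or 0
        rows = conn.execute(base_stmt.limit(limit).offset((page - 1) * limit)).fetchall()

The method above additionally orders by func.max(created_at) descending, so the most recently active sessions appear on the first page.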
+    # -- Span methods --
+
+    def create_span(self, span: "Span") -> None:
+        """Create a single span in the database.
+
+        Args:
+            span: The Span object to store.
+        """
+        try:
+            table = self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                stmt = sqlite.insert(table).values(span.to_dict())
+                sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating span: {e}")
+
+    def create_spans(self, spans: List) -> None:
+        """Create multiple spans in the database as a batch.
+
+        Args:
+            spans: List of Span objects to store.
+        """
+        if not spans:
+            return
+
+        try:
+            table = self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                for span in spans:
+                    stmt = sqlite.insert(table).values(span.to_dict())
+                    sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating spans batch: {e}")
+
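create_spans issues one INSERT per span, all inside a single transaction, which is simple and correct. SQLAlchemy can also take a list of parameter dictionaries and run the statement in executemany form, one statement submission instead of a Python-level loop. A hedged sketch of that alternative; insert_spans_bulk, engine, and spans_table are illustrative names, not agno APIs.

    from typing import Any, Dict, List

    from sqlalchemy import Table, insert
    from sqlalchemy.engine import Engine

    def insert_spans_bulk(engine: Engine, spans_table: Table, span_dicts: List[Dict[str, Any]]) -> None:
        """Insert all span rows in one executemany call instead of a loop."""
        if not span_dicts:
            return
        with engine.begin() as conn:  # commits on success, rolls back on error
            conn.execute(insert(spans_table), span_dicts)

For the handful of spans a single agent run produces, the loop is fine; the batched form matters mainly for bulk imports.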
+    def get_span(self, span_id: str):
+        """Get a single span by its span_id.
+
+        Args:
+            span_id: The unique span identifier.
+
+        Returns:
+            Optional[Span]: The span if found, None otherwise.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = self._get_table(table_type="spans")
+            if table is None:
+                return None
+
+            with self.Session() as sess:
+                stmt = table.select().where(table.c.span_id == span_id)
+                result = sess.execute(stmt).fetchone()
+                if result:
+                    return Span.from_dict(dict(result._mapping))
+                return None
+
+        except Exception as e:
+            log_error(f"Error getting span: {e}")
+            return None
+
+    def get_spans(
+        self,
+        trace_id: Optional[str] = None,
+        parent_span_id: Optional[str] = None,
+    ) -> List:
+        """Get spans matching the provided filters.
+
+        Args:
+            trace_id: Filter by trace ID.
+            parent_span_id: Filter by parent span ID.
+
+        Returns:
+            List[Span]: List of matching spans.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = self._get_table(table_type="spans")
+            if table is None:
+                return []
+
+            with self.Session() as sess:
+                stmt = table.select()
+
+                # Apply filters
+                if trace_id:
+                    stmt = stmt.where(table.c.trace_id == trace_id)
+                if parent_span_id:
+                    stmt = stmt.where(table.c.parent_span_id == parent_span_id)
+
+                results = sess.execute(stmt).fetchall()
+                return [Span.from_dict(dict(row._mapping)) for row in results]
+
+        except Exception as e:
+            log_error(f"Error getting spans: {e}")
+            return []
+
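get_spans(trace_id=...) returns a flat list, so a caller that wants the span tree has to rebuild parent/child links itself. A hypothetical helper, not part of agno, assuming each returned span exposes a parent_span_id attribute that is None for root spans:

    from collections import defaultdict
    from typing import Any, Dict, List, Optional

    def index_spans_by_parent(spans: List[Any]) -> Dict[Optional[str], List[Any]]:
        """Group a flat span list into a parent_span_id -> children index."""
        children: Dict[Optional[str], List[Any]] = defaultdict(list)
        for span in spans:
            children[span.parent_span_id].append(span)
        return children

Root spans then sit under children[None], and a tree walk recurses through children[span.span_id].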
     # -- Migrations --
 
     def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
@@ -1987,3 +2675,234 @@ class SqliteDb(BaseDb):
         for memory in memories:
             self.upsert_user_memory(memory)
         log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
+
+    # -- Culture methods --
+
+    def clear_cultural_knowledge(self) -> None:
+        """Delete all cultural artifacts from the database.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                sess.execute(table.delete())
+
+        except Exception as e:
+            from agno.utils.log import log_warning
+
+            log_warning(f"Exception deleting all cultural artifacts: {e}")
+            raise e
+
+    def delete_cultural_knowledge(self, id: str) -> None:
+        """Delete a cultural artifact from the database.
+
+        Args:
+            id (str): The ID of the cultural artifact to delete.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                delete_stmt = table.delete().where(table.c.id == id)
+                result = sess.execute(delete_stmt)
+
+                success = result.rowcount > 0
+                if success:
+                    log_debug(f"Successfully deleted cultural artifact id: {id}")
+                else:
+                    log_debug(f"No cultural artifact found with id: {id}")
+
+        except Exception as e:
+            log_error(f"Error deleting cultural artifact: {e}")
+            raise e
+
+    def get_cultural_knowledge(
+        self, id: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Get a cultural artifact from the database.
+
+        Args:
+            id (str): The ID of the cultural artifact to get.
+            deserialize (Optional[bool]): Whether to deserialize the cultural artifact. Defaults to True.
+
+        Returns:
+            Optional[CulturalKnowledge]: The cultural artifact, or None if it doesn't exist.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return None
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table).where(table.c.id == id)
+                result = sess.execute(stmt).fetchone()
+                if result is None:
+                    return None
+
+                db_row = dict(result._mapping)
+                if not db_row or not deserialize:
+                    return db_row
+
+                return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_error(f"Exception reading from cultural artifacts table: {e}")
+            raise e
+
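The deserialize flag switches between a parsed object and the raw stored row. A quick usage sketch; db stands for any configured SqliteDb instance, and the id is made up.

    artifact = db.get_cultural_knowledge(id="ck_123")  # CulturalKnowledge, or None
    raw_row = db.get_cultural_knowledge(id="ck_123", deserialize=False)  # plain row dict, or None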
+    def get_all_cultural_knowledge(
+        self,
+        name: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+        """Get all cultural artifacts from the database as CulturalKnowledge objects.
+
+        Args:
+            name (Optional[str]): The name of the cultural artifact to filter by.
+            agent_id (Optional[str]): The ID of the agent to filter by.
+            team_id (Optional[str]): The ID of the team to filter by.
+            limit (Optional[int]): The maximum number of cultural artifacts to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+            deserialize (Optional[bool]): Whether to deserialize the cultural artifacts. Defaults to True.
+
+        Returns:
+            Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of CulturalKnowledge objects
+                - When deserialize=False: List of row dictionaries and the total count
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return [] if deserialize else ([], 0)
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table)
+
+                # Filtering
+                if name is not None:
+                    stmt = stmt.where(table.c.name == name)
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+
+                # Get total count after applying filtering
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = sess.execute(count_stmt).scalar()
+
+                # Sorting
+                stmt = apply_sorting(stmt, table, sort_by, sort_order)
+                # Paginating (page is an offset in units of limit, so it needs limit set)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                if page is not None and limit is not None:
+                    stmt = stmt.offset((page - 1) * limit)
+
+                result = sess.execute(stmt).fetchall()
+                if not result:
+                    return [] if deserialize else ([], 0)
+
+                db_rows = [dict(record._mapping) for record in result]
+
+                if not deserialize:
+                    return db_rows, total_count
+
+                return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]
+
+        except Exception as e:
+            log_error(f"Error reading from cultural artifacts table: {e}")
+            raise e
+
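Since page is applied as an offset of (page - 1) * limit, pagination only takes effect when limit is also set, and the count is taken before limit/offset so it reflects the full filtered result. A usage sketch; the db handle and filter values are illustrative.

    rows, total = db.get_all_cultural_knowledge(
        agent_id="agent_1",
        limit=20,
        page=1,
        sort_by="updated_at",
        sort_order="desc",
        deserialize=False,  # raw rows plus the pre-pagination total count
    )
    num_pages = (total + 20 - 1) // 20  # ceil(total / page_size) for page controls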
+    def upsert_cultural_knowledge(
+        self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Upsert a cultural artifact into the database.
+
+        Args:
+            cultural_knowledge (CulturalKnowledge): The cultural artifact to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the cultural artifact. Defaults to True.
+
+        Returns:
+            Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+                - When deserialize=True: CulturalKnowledge object
+                - When deserialize=False: row dictionary
+
+        Raises:
+            Exception: If an error occurs during upsert.
+        """
+        try:
+            table = self._get_table(table_type="culture", create_table_if_not_found=True)
+            if table is None:
+                return None
+
+            if cultural_knowledge.id is None:
+                cultural_knowledge.id = str(uuid4())
+
+            # Serialize content, categories, and notes into a JSON string for DB storage (SQLite requires strings)
+            content_json_str = serialize_cultural_knowledge_for_db(cultural_knowledge)
+
+            with self.Session() as sess, sess.begin():
+                stmt = sqlite.insert(table).values(
+                    id=cultural_knowledge.id,
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_json_str,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    created_at=cultural_knowledge.created_at,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+                stmt = stmt.on_conflict_do_update(  # type: ignore
+                    index_elements=["id"],
+                    set_=dict(
+                        name=cultural_knowledge.name,
+                        summary=cultural_knowledge.summary,
+                        content=content_json_str,
+                        metadata=cultural_knowledge.metadata,
+                        input=cultural_knowledge.input,
+                        updated_at=int(time.time()),
+                        agent_id=cultural_knowledge.agent_id,
+                        team_id=cultural_knowledge.team_id,
+                    ),
+                ).returning(table)
+
+                result = sess.execute(stmt)
+                row = result.fetchone()
+
+                if row is None:
+                    return None
+
+                db_row: Dict[str, Any] = dict(row._mapping)
+                if not db_row or not deserialize:
+                    return db_row
+
+                return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_error(f"Error upserting cultural knowledge: {e}")
+            raise e
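The upsert relies on SQLite's ON CONFLICT DO UPDATE, reached through SQLAlchemy's dialect-specific insert construct, with RETURNING used to read back the final row (SQLAlchemy 2.x supports RETURNING on SQLite 3.35+). A standalone sketch of just that pattern on a toy table, separate from agno's schema:

    from sqlalchemy import Column, MetaData, String, Table, create_engine
    from sqlalchemy.dialects import sqlite

    engine = create_engine("sqlite:///:memory:")
    metadata = MetaData()
    culture = Table(
        "culture",
        metadata,
        Column("id", String, primary_key=True),
        Column("name", String),
    )
    metadata.create_all(engine)

    stmt = sqlite.insert(culture).values(id="ck_1", name="greetings")
    stmt = stmt.on_conflict_do_update(
        index_elements=["id"],  # conflict target: the primary key
        set_=dict(name=stmt.excluded["name"]),  # on conflict, take the incoming value
    ).returning(culture)

    with engine.begin() as conn:
        row = conn.execute(stmt).fetchone()  # the inserted-or-updated row

Doing the insert-or-update in one statement keeps the operation atomic; a separate SELECT-then-INSERT/UPDATE pair would race with concurrent writers.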