agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
agno/db/mysql/mysql.py
CHANGED
```diff
@@ -1,31 +1,36 @@
 import time
 from datetime import date, datetime, timedelta, timezone
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
 from uuid import uuid4
 
-
+if TYPE_CHECKING:
+    from agno.tracing.schemas import Span, Trace
 
 from agno.db.base import BaseDb, SessionType
+from agno.db.migrations.manager import MigrationManager
 from agno.db.mysql.schemas import get_table_schema_definition
 from agno.db.mysql.utils import (
     apply_sorting,
     bulk_upsert_metrics,
     calculate_date_metrics,
     create_schema,
+    deserialize_cultural_knowledge_from_db,
     fetch_all_sessions_data,
     get_dates_to_calculate_metrics_for,
     is_table_available,
     is_valid_table,
+    serialize_cultural_knowledge_for_db,
 )
+from agno.db.schemas.culture import CulturalKnowledge
 from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
 from agno.db.schemas.knowledge import KnowledgeRow
 from agno.db.schemas.memory import UserMemory
 from agno.session import AgentSession, Session, TeamSession, WorkflowSession
-from agno.utils.log import log_debug, log_error, log_info
+from agno.utils.log import log_debug, log_error, log_info, log_warning
 from agno.utils.string import generate_id
 
 try:
-    from sqlalchemy import TEXT, and_, cast, func, update
+    from sqlalchemy import TEXT, ForeignKey, Index, UniqueConstraint, and_, cast, func, update
     from sqlalchemy.dialects import mysql
     from sqlalchemy.engine import Engine, create_engine
     from sqlalchemy.orm import scoped_session, sessionmaker
@@ -38,15 +43,20 @@ except ImportError:
 class MySQLDb(BaseDb):
     def __init__(
         self,
+        id: Optional[str] = None,
         db_engine: Optional[Engine] = None,
         db_schema: Optional[str] = None,
         db_url: Optional[str] = None,
         session_table: Optional[str] = None,
+        culture_table: Optional[str] = None,
         memory_table: Optional[str] = None,
         metrics_table: Optional[str] = None,
         eval_table: Optional[str] = None,
         knowledge_table: Optional[str] = None,
-
+        traces_table: Optional[str] = None,
+        spans_table: Optional[str] = None,
+        versions_table: Optional[str] = None,
+        create_schema: bool = True,
     ):
         """
         Interface for interacting with a MySQL database.
@@ -57,15 +67,21 @@ class MySQLDb(BaseDb):
         3. Raise an error if neither is provided
 
         Args:
+            id (Optional[str]): ID of the database.
            db_url (Optional[str]): The database URL to connect to.
            db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
            db_schema (Optional[str]): The database schema to use.
            session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
+            culture_table (Optional[str]): Name of the table to store cultural knowledge.
            memory_table (Optional[str]): Name of the table to store memories.
            metrics_table (Optional[str]): Name of the table to store metrics.
            eval_table (Optional[str]): Name of the table to store evaluation runs data.
            knowledge_table (Optional[str]): Name of the table to store knowledge content.
-
+            traces_table (Optional[str]): Name of the table to store run traces.
+            spans_table (Optional[str]): Name of the table to store span events.
+            versions_table (Optional[str]): Name of the table to store schema versions.
+            create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
+                Set to False if schema is managed externally (e.g., via migrations). Defaults to True.
 
         Raises:
             ValueError: If neither db_url nor db_engine is provided.
@@ -80,10 +96,14 @@ class MySQLDb(BaseDb):
         super().__init__(
             id=id,
             session_table=session_table,
+            culture_table=culture_table,
             memory_table=memory_table,
             metrics_table=metrics_table,
             eval_table=eval_table,
             knowledge_table=knowledge_table,
+            traces_table=traces_table,
+            spans_table=spans_table,
+            versions_table=versions_table,
        )
 
         _engine: Optional[Engine] = db_engine
```
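The constructor now accepts culture, traces, spans, and versions table names alongside the existing ones, plus a `create_schema` switch. A minimal usage sketch, assuming the class is re-exported from `agno.db.mysql` as its `__init__.py` suggests (the DSN and table names below are hypothetical; the keyword arguments come straight from the signature above):

```python
from agno.db.mysql import MySQLDb

db = MySQLDb(
    db_url="mysql+pymysql://user:pass@localhost:3306/ai",  # hypothetical DSN
    db_schema="ai",
    culture_table="agno_culture",
    traces_table="agno_traces",
    spans_table="agno_spans",
    versions_table="agno_versions",
    # Added in this release range: skip automatic schema creation
    # when migrations own the schema.
    create_schema=False,
)
```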
```diff
@@ -95,28 +115,38 @@
         self.db_url: Optional[str] = db_url
         self.db_engine: Engine = _engine
         self.db_schema: str = db_schema if db_schema is not None else "ai"
-        self.metadata: MetaData = MetaData()
+        self.metadata: MetaData = MetaData(schema=self.db_schema)
+        self.create_schema: bool = create_schema
 
         # Initialize database session
         self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
 
     # -- DB methods --
-    def
+    def table_exists(self, table_name: str) -> bool:
+        """Check if a table with the given name exists in the MySQL database.
+
+        Args:
+            table_name: Name of the table to check
+
+        Returns:
+            bool: True if the table exists in the database, False otherwise
+        """
+        with self.Session() as sess:
+            return is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
+
+    def _create_table(self, table_name: str, table_type: str) -> Table:
         """
         Create a table with the appropriate schema based on the table type.
 
         Args:
             table_name (str): Name of the table to create
             table_type (str): Type of table (used to get schema definition)
-            db_schema (str): Database schema name
 
         Returns:
             Table: SQLAlchemy Table object
         """
         try:
-            table_schema = get_table_schema_definition(table_type)
-
-            log_debug(f"Creating table {db_schema}.{table_name} with schema: {table_schema}")
+            table_schema = get_table_schema_definition(table_type).copy()
 
             columns: List[Column] = []
             indexes: List[str] = []
@@ -136,11 +166,20 @@ class MySQLDb(BaseDb):
                 if col_config.get("unique", False):
                     column_kwargs["unique"] = True
                     unique_constraints.append(col_name)
+
+                # Handle foreign key constraint
+                if "foreign_key" in col_config:
+                    fk_ref = col_config["foreign_key"]
+                    # For spans table, dynamically replace the traces table reference
+                    # with the actual trace table name configured for this db instance
+                    if table_type == "spans" and "trace_id" in fk_ref:
+                        fk_ref = f"{self.db_schema}.{self.trace_table_name}.trace_id"
+                    column_args.append(ForeignKey(fk_ref))
 
                 columns.append(Column(*column_args, **column_kwargs))  # type: ignore
 
             # Create the table object
-
-            table = Table(table_name, table_metadata, *columns, schema=db_schema)
+            table = Table(table_name, self.metadata, *columns, schema=self.db_schema)
 
             # Add multi-column unique constraints with table-specific names
             for constraint in schema_unique_constraints:
@@ -153,17 +192,22 @@ class MySQLDb(BaseDb):
                 idx_name = f"idx_{table_name}_{idx_col}"
                 table.append_constraint(Index(idx_name, idx_col))
 
-
-
+            if self.create_schema:
+                with self.Session() as sess, sess.begin():
+                    create_schema(session=sess, db_schema=self.db_schema)
 
             # Create table
-
+            table_created = False
+            if not self.table_exists(table_name):
+                table.create(self.db_engine, checkfirst=True)
+                log_debug(f"Successfully created table '{table_name}'")
+                table_created = True
+            else:
+                log_debug(f"Table {self.db_schema}.{table_name} already exists, skipping creation")
 
             # Create indexes
             for idx in table.indexes:
                 try:
-                    log_debug(f"Creating index: {idx.name}")
-
                     # Check if index already exists
                     with self.Session() as sess:
                         exists_query = text(
@@ -172,32 +216,59 @@ class MySQLDb(BaseDb):
                         )
                         exists = (
                             sess.execute(
-                                exists_query,
+                                exists_query,
+                                {"schema": self.db_schema, "table_name": table_name, "index_name": idx.name},
                             ).scalar()
                             is not None
                         )
                         if exists:
-                            log_debug(
+                            log_debug(
+                                f"Index {idx.name} already exists in {self.db_schema}.{table_name}, skipping creation"
+                            )
                             continue
 
                     idx.create(self.db_engine)
 
+                    log_debug(f"Created index: {idx.name} for table {self.db_schema}.{table_name}")
                 except Exception as e:
                     log_error(f"Error creating index {idx.name}: {e}")
 
-
+            # Store the schema version for the created table
+            if table_name != self.versions_table_name and table_created:
+                latest_schema_version = MigrationManager(self).latest_schema_version
+                self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
+                log_info(
+                    f"Successfully stored version {latest_schema_version.public} in database for table {table_name}"
+                )
+
             return table
 
         except Exception as e:
-            log_error(f"Could not create table {db_schema}.{table_name}: {e}")
+            log_error(f"Could not create table {self.db_schema}.{table_name}: {e}")
             raise
 
+    def _create_all_tables(self):
+        """Create all tables for the database."""
+        tables_to_create = [
+            (self.session_table_name, "sessions"),
+            (self.memory_table_name, "memories"),
+            (self.metrics_table_name, "metrics"),
+            (self.eval_table_name, "evals"),
+            (self.knowledge_table_name, "knowledge"),
+            (self.culture_table_name, "culture"),
+            (self.trace_table_name, "traces"),
+            (self.span_table_name, "spans"),
+            (self.versions_table_name, "versions"),
+        ]
+
+        for table_name, table_type in tables_to_create:
+            self._get_or_create_table(table_name=table_name, table_type=table_type, create_table_if_not_found=True)
+
     def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
         if table_type == "sessions":
             self.session_table = self._get_or_create_table(
                 table_name=self.session_table_name,
                 table_type="sessions",
-                db_schema=self.db_schema,
                 create_table_if_not_found=create_table_if_not_found,
             )
             return self.session_table
```
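The spans table's foreign key is built from a fully qualified `"<schema>.<table>.<column>"` string so it follows whatever trace table name the instance was configured with. A standalone SQLAlchemy sketch of that reference style (table and column names here are illustrative, not the package's actual schema):

```python
from sqlalchemy import Column, ForeignKey, MetaData, String, Table

metadata = MetaData(schema="ai")

traces = Table("agno_traces", metadata, Column("trace_id", String(64), primary_key=True))

spans = Table(
    "agno_spans",
    metadata,
    Column("span_id", String(64), primary_key=True),
    # Fully qualified "<schema>.<table>.<column>" reference, like the diff's fk_ref.
    Column("trace_id", String(64), ForeignKey("ai.agno_traces.trace_id")),
)

print([fk.target_fullname for fk in spans.c.trace_id.foreign_keys])
# ['ai.agno_traces.trace_id']
```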
```diff
@@ -206,7 +277,6 @@ class MySQLDb(BaseDb):
             self.memory_table = self._get_or_create_table(
                 table_name=self.memory_table_name,
                 table_type="memories",
-                db_schema=self.db_schema,
                 create_table_if_not_found=create_table_if_not_found,
             )
             return self.memory_table
@@ -215,7 +285,6 @@ class MySQLDb(BaseDb):
             self.metrics_table = self._get_or_create_table(
                 table_name=self.metrics_table_name,
                 table_type="metrics",
-                db_schema=self.db_schema,
                 create_table_if_not_found=create_table_if_not_found,
             )
             return self.metrics_table
@@ -224,7 +293,6 @@ class MySQLDb(BaseDb):
             self.eval_table = self._get_or_create_table(
                 table_name=self.eval_table_name,
                 table_type="evals",
-                db_schema=self.db_schema,
                 create_table_if_not_found=create_table_if_not_found,
             )
             return self.eval_table
@@ -233,15 +301,50 @@ class MySQLDb(BaseDb):
             self.knowledge_table = self._get_or_create_table(
                 table_name=self.knowledge_table_name,
                 table_type="knowledge",
-                db_schema=self.db_schema,
                 create_table_if_not_found=create_table_if_not_found,
             )
             return self.knowledge_table
 
+        if table_type == "culture":
+            self.culture_table = self._get_or_create_table(
+                table_name=self.culture_table_name,
+                table_type="culture",
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.culture_table
+
+        if table_type == "versions":
+            self.versions_table = self._get_or_create_table(
+                table_name=self.versions_table_name,
+                table_type="versions",
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.versions_table
+
+        if table_type == "traces":
+            self.traces_table = self._get_or_create_table(
+                table_name=self.trace_table_name,
+                table_type="traces",
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.traces_table
+
+        if table_type == "spans":
+            # Ensure traces table exists first (spans has FK to traces)
+            if create_table_if_not_found:
+                self._get_table(table_type="traces", create_table_if_not_found=True)
+
+            self.spans_table = self._get_or_create_table(
+                table_name=self.span_table_name,
+                table_type="spans",
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.spans_table
+
         raise ValueError(f"Unknown table type: {table_type}")
 
     def _get_or_create_table(
-        self, table_name: str, table_type: str,
+        self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
     ) -> Optional[Table]:
         """
         Check if the table exists and is valid, else create it.
@@ -249,38 +352,71 @@ class MySQLDb(BaseDb):
         Args:
             table_name (str): Name of the table to get or create
             table_type (str): Type of table (used to get schema definition)
-            db_schema (str): Database schema name
 
         Returns:
             Table: SQLAlchemy Table object representing the schema.
         """
 
         with self.Session() as sess, sess.begin():
-            table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=db_schema)
+            table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
 
             if not table_is_available:
                 if not create_table_if_not_found:
                     return None
 
-
+                created_table = self._create_table(table_name=table_name, table_type=table_type)
+
+                return created_table
 
             if not is_valid_table(
                 db_engine=self.db_engine,
                 table_name=table_name,
                 table_type=table_type,
-                db_schema=db_schema,
+                db_schema=self.db_schema,
             ):
-                raise ValueError(f"Table {db_schema}.{table_name} has an invalid schema")
+                raise ValueError(f"Table {self.db_schema}.{table_name} has an invalid schema")
 
         try:
-            table = Table(table_name, self.metadata, schema=db_schema, autoload_with=self.db_engine)
-            log_debug(f"Loaded existing table {db_schema}.{table_name}")
+            table = Table(table_name, self.metadata, schema=self.db_schema, autoload_with=self.db_engine)
             return table
 
         except Exception as e:
-            log_error(f"Error loading existing table {db_schema}.{table_name}: {e}")
+            log_error(f"Error loading existing table {self.db_schema}.{table_name}: {e}")
             raise
 
+    def get_latest_schema_version(self, table_name: str) -> str:
+        """Get the latest version of the database schema."""
+        table = self._get_table(table_type="versions", create_table_if_not_found=True)
+        with self.Session() as sess:
+            # Latest version for the given table
+            stmt = select(table).where(table.c.table_name == table_name).order_by(table.c.version.desc()).limit(1)  # type: ignore
+            result = sess.execute(stmt).fetchone()
+            if result is None:
+                return "2.0.0"
+            version_dict = dict(result._mapping)
+            return version_dict.get("version") or "2.0.0"
+
+    def upsert_schema_version(self, table_name: str, version: str) -> None:
+        """Upsert the schema version into the database."""
+        table = self._get_table(table_type="versions", create_table_if_not_found=True)
+        if table is None:
+            return
+        current_datetime = datetime.now().isoformat()
+        with self.Session() as sess, sess.begin():
+            stmt = mysql.insert(table).values(  # type: ignore
+                table_name=table_name,
+                version=version,
+                created_at=current_datetime,  # Store as ISO format string
+                updated_at=current_datetime,
+            )
+            # Update version if table_name already exists
+            stmt = stmt.on_duplicate_key_update(
+                version=version,
+                created_at=current_datetime,
+                updated_at=current_datetime,
+            )
+            sess.execute(stmt)
+
     # -- Session methods --
     def delete_session(self, session_id: str) -> bool:
         """
```
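`upsert_schema_version` leans on MySQL's `INSERT ... ON DUPLICATE KEY UPDATE`. A self-contained sketch of the same idiom, compiled without a database (the table shape is illustrative; the real column definitions live in `agno/db/mysql/schemas.py`):

```python
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql

metadata = MetaData()
versions = Table(
    "agno_versions",
    metadata,
    Column("table_name", String(255), primary_key=True),  # conflict target
    Column("version", String(32)),
    Column("updated_at", String(64)),
)

stmt = mysql.insert(versions).values(
    table_name="agno_sessions", version="2.3.0", updated_at="2025-01-01T00:00:00"
)
# Re-running with the same table_name updates the row instead of failing.
stmt = stmt.on_duplicate_key_update(version="2.3.0", updated_at="2025-01-01T00:00:00")

print(stmt.compile(dialect=mysql.dialect()))
```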
```diff
@@ -372,9 +508,6 @@ class MySQLDb(BaseDb):
 
             if user_id is not None:
                 stmt = stmt.where(table.c.user_id == user_id)
-            if session_type is not None:
-                session_type_value = session_type.value if isinstance(session_type, SessionType) else session_type
-                stmt = stmt.where(table.c.session_type == session_type_value)
             result = sess.execute(stmt).fetchone()
             if result is None:
                 return None
@@ -417,7 +550,7 @@ class MySQLDb(BaseDb):
         Args:
             session_type (Optional[SessionType]): The type of sessions to get.
             user_id (Optional[str]): The ID of the user to filter by.
-
+            component_id (Optional[str]): The ID of the agent / workflow to filter by.
             start_timestamp (Optional[int]): The start timestamp to filter by.
             end_timestamp (Optional[int]): The end timestamp to filter by.
             session_name (Optional[str]): The name of the session to filter by.
@@ -426,7 +559,6 @@ class MySQLDb(BaseDb):
             sort_by (Optional[str]): The field to sort by. Defaults to None.
             sort_order (Optional[str]): The sort order. Defaults to None.
             deserialize (Optional[bool]): Whether to serialize the sessions. Defaults to True.
-            create_table_if_not_found (Optional[bool]): Whether to create the table if it doesn't exist.
 
         Returns:
             Union[List[Session], Tuple[List[Dict], int]]:
@@ -706,7 +838,7 @@ class MySQLDb(BaseDb):
             return None
 
     def upsert_sessions(
-        self, sessions: List[Session], deserialize: Optional[bool] = True
+        self, sessions: List[Session], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
     ) -> List[Union[Session, Dict[str, Any]]]:
         """
         Bulk upsert multiple sessions for improved performance on large datasets.
@@ -714,6 +846,7 @@ class MySQLDb(BaseDb):
         Args:
             sessions (List[Session]): List of sessions to upsert.
             deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
+            preserve_updated_at (bool): If True, preserve the updated_at from the session object.
 
         Returns:
             List[Union[Session, Dict[str, Any]]]: List of upserted sessions.
@@ -758,6 +891,8 @@ class MySQLDb(BaseDb):
                 agent_data = []
                 for session in agent_sessions:
                     session_dict = session.to_dict()
+                    # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                    updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
                     agent_data.append(
                         {
                             "session_id": session_dict.get("session_id"),
@@ -770,7 +905,7 @@ class MySQLDb(BaseDb):
                             "summary": session_dict.get("summary"),
                             "metadata": session_dict.get("metadata"),
                             "created_at": session_dict.get("created_at"),
-                            "updated_at":
+                            "updated_at": updated_at,
                         }
                     )
 
@@ -784,7 +919,7 @@ class MySQLDb(BaseDb):
                     summary=stmt.inserted.summary,
                     metadata=stmt.inserted.metadata,
                     runs=stmt.inserted.runs,
-                    updated_at=
+                    updated_at=stmt.inserted.updated_at,
                 )
                 sess.execute(stmt, agent_data)
 
@@ -808,6 +943,8 @@ class MySQLDb(BaseDb):
                 team_data = []
                 for session in team_sessions:
                     session_dict = session.to_dict()
+                    # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                    updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
                     team_data.append(
                         {
                             "session_id": session_dict.get("session_id"),
@@ -820,7 +957,7 @@ class MySQLDb(BaseDb):
                             "summary": session_dict.get("summary"),
                             "metadata": session_dict.get("metadata"),
                             "created_at": session_dict.get("created_at"),
-                            "updated_at":
+                            "updated_at": updated_at,
                         }
                     )
 
@@ -834,7 +971,7 @@ class MySQLDb(BaseDb):
                     summary=stmt.inserted.summary,
                     metadata=stmt.inserted.metadata,
                     runs=stmt.inserted.runs,
-                    updated_at=
+                    updated_at=stmt.inserted.updated_at,
                 )
                 sess.execute(stmt, team_data)
 
@@ -858,6 +995,8 @@ class MySQLDb(BaseDb):
                 workflow_data = []
                 for session in workflow_sessions:
                     session_dict = session.to_dict()
+                    # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                    updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
                     workflow_data.append(
                         {
                             "session_id": session_dict.get("session_id"),
@@ -870,7 +1009,7 @@ class MySQLDb(BaseDb):
                             "summary": session_dict.get("summary"),
                             "metadata": session_dict.get("metadata"),
                             "created_at": session_dict.get("created_at"),
-                            "updated_at":
+                            "updated_at": updated_at,
                         }
                     )
 
@@ -884,7 +1023,7 @@ class MySQLDb(BaseDb):
                     summary=stmt.inserted.summary,
                     metadata=stmt.inserted.metadata,
                     runs=stmt.inserted.runs,
-                    updated_at=
+                    updated_at=stmt.inserted.updated_at,
                 )
                 sess.execute(stmt, workflow_data)
 
```
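Each of the three session branches repeats the same timestamp rule; in isolation it reduces to the sketch below (the helper name is ours, not the package's):

```python
import time
from typing import Optional


def resolve_updated_at(session_dict: dict, preserve_updated_at: bool) -> Optional[int]:
    # Keep the caller-supplied timestamp (e.g. when replaying or migrating data);
    # otherwise stamp the row with the current time.
    return session_dict.get("updated_at") if preserve_updated_at else int(time.time())


assert resolve_updated_at({"updated_at": 1700000000}, preserve_updated_at=True) == 1700000000
assert resolve_updated_at({"updated_at": 1700000000}, preserve_updated_at=False) >= 1700000000
```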
```diff
@@ -976,9 +1115,12 @@ class MySQLDb(BaseDb):
         except Exception as e:
             log_error(f"Error deleting user memories: {e}")
 
-    def get_all_memory_topics(self) -> List[str]:
+    def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
         """Get all memory topics from the database.
 
+        Args:
+            user_id (Optional[str]): Optional user ID to filter topics.
+
         Returns:
             List[str]: List of memory topics.
         """
@@ -1151,7 +1293,7 @@ class MySQLDb(BaseDb):
             log_error(f"Exception clearing user memories: {e}")
 
     def get_user_memory_stats(
-        self, limit: Optional[int] = None, page: Optional[int] = None
+        self, limit: Optional[int] = None, page: Optional[int] = None, user_id: Optional[str] = None
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get user memories stats.
 
@@ -1180,17 +1322,20 @@ class MySQLDb(BaseDb):
                 return [], 0
 
             with self.Session() as sess, sess.begin():
-                stmt = (
-
-
-
-                    func.max(table.c.updated_at).label("last_memory_updated_at"),
-                )
-                .where(table.c.user_id.is_not(None))
-                .group_by(table.c.user_id)
-                .order_by(func.max(table.c.updated_at).desc())
+                stmt = select(
+                    table.c.user_id,
+                    func.count(table.c.memory_id).label("total_memories"),
+                    func.max(table.c.updated_at).label("last_memory_updated_at"),
                 )
 
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)
+                else:
+                    stmt = stmt.where(table.c.user_id.is_not(None))
+
+                stmt = stmt.group_by(table.c.user_id)
+                stmt = stmt.order_by(func.max(table.c.updated_at).desc())
+
                 count_stmt = select(func.count()).select_from(stmt.alias())
                 total_count = sess.execute(count_stmt).scalar()
 
@@ -1243,6 +1388,8 @@ class MySQLDb(BaseDb):
             if memory.memory_id is None:
                 memory.memory_id = str(uuid4())
 
+            current_time = int(time.time())
+
             stmt = mysql.insert(table).values(
                 memory_id=memory.memory_id,
                 memory=memory.memory,
@@ -1251,7 +1398,9 @@ class MySQLDb(BaseDb):
                 agent_id=memory.agent_id,
                 team_id=memory.team_id,
                 topics=memory.topics,
-
+                feedback=memory.feedback,
+                created_at=memory.created_at,
+                updated_at=memory.created_at,
             )
             stmt = stmt.on_duplicate_key_update(
                 memory=memory.memory,
@@ -1259,7 +1408,10 @@ class MySQLDb(BaseDb):
                 input=memory.input,
                 agent_id=memory.agent_id,
                 team_id=memory.team_id,
-
+                feedback=memory.feedback,
+                updated_at=current_time,
+                # Preserve created_at on update - don't overwrite existing value
+                created_at=table.c.created_at,
             )
             sess.execute(stmt)
 
@@ -1281,7 +1433,7 @@ class MySQLDb(BaseDb):
             return None
 
     def upsert_memories(
-        self, memories: List[UserMemory], deserialize: Optional[bool] = True
+        self, memories: List[UserMemory], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
     ) -> List[Union[UserMemory, Dict[str, Any]]]:
         """
         Bulk upsert multiple user memories for improved performance on large datasets.
@@ -1313,10 +1465,15 @@ class MySQLDb(BaseDb):
 
             # Prepare bulk data
             bulk_data = []
+            current_time = int(time.time())
+
             for memory in memories:
                 if memory.memory_id is None:
                     memory.memory_id = str(uuid4())
 
+                # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                updated_at = memory.updated_at if preserve_updated_at else current_time
+
                 bulk_data.append(
                     {
                         "memory_id": memory.memory_id,
@@ -1326,7 +1483,9 @@ class MySQLDb(BaseDb):
                         "agent_id": memory.agent_id,
                         "team_id": memory.team_id,
                         "topics": memory.topics,
-                        "
+                        "feedback": memory.feedback,
+                        "created_at": memory.created_at,
+                        "updated_at": updated_at,
                     }
                 )
 
@@ -1341,7 +1500,10 @@ class MySQLDb(BaseDb):
                 input=stmt.inserted.input,
                 agent_id=stmt.inserted.agent_id,
                 team_id=stmt.inserted.team_id,
-
+                feedback=stmt.inserted.feedback,
+                updated_at=stmt.inserted.updated_at,
+                # Preserve created_at on update
+                created_at=table.c.created_at,
             )
             sess.execute(stmt, bulk_data)
 
@@ -1654,9 +1816,9 @@ class MySQLDb(BaseDb):
                 if page is not None:
                     stmt = stmt.offset((page - 1) * limit)
 
-
-
-
+                result = sess.execute(stmt).fetchall()
+                if not result:
+                    return [], 0
 
                 return [KnowledgeRow.model_validate(record._mapping) for record in result], total_count
 
```
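The rebuilt `get_user_memory_stats` query is a plain grouped aggregate. The same shape, compiled standalone (the table definition is illustrative only):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, func, select

metadata = MetaData()
memories = Table(
    "agno_memories",
    metadata,
    Column("memory_id", String(64), primary_key=True),
    Column("user_id", String(64)),
    Column("updated_at", Integer),
)

# One row per user: memory count plus most recent update, newest first.
stmt = (
    select(
        memories.c.user_id,
        func.count(memories.c.memory_id).label("total_memories"),
        func.max(memories.c.updated_at).label("last_memory_updated_at"),
    )
    .where(memories.c.user_id.is_not(None))
    .group_by(memories.c.user_id)
    .order_by(func.max(memories.c.updated_at).desc())
)
print(stmt)
```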
```diff
@@ -1986,6 +2148,222 @@ class MySQLDb(BaseDb):
             log_error(f"Error upserting eval run name {eval_run_id}: {e}")
             return None
 
+    # -- Culture methods --
+
+    def clear_cultural_knowledge(self) -> None:
+        """Delete all cultural knowledge from the database.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                sess.execute(table.delete())
+
+        except Exception as e:
+            log_warning(f"Exception deleting all cultural knowledge: {e}")
+            raise e
+
+    def delete_cultural_knowledge(self, id: str) -> None:
+        """Delete a cultural knowledge entry from the database.
+
+        Args:
+            id (str): The ID of the cultural knowledge to delete.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                delete_stmt = table.delete().where(table.c.id == id)
+                result = sess.execute(delete_stmt)
+
+                success = result.rowcount > 0
+                if success:
+                    log_debug(f"Successfully deleted cultural knowledge id: {id}")
+                else:
+                    log_debug(f"No cultural knowledge found with id: {id}")
+
+        except Exception as e:
+            log_error(f"Error deleting cultural knowledge: {e}")
+            raise e
+
+    def get_cultural_knowledge(
+        self, id: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Get a cultural knowledge entry from the database.
+
+        Args:
+            id (str): The ID of the cultural knowledge to get.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural knowledge entry, or None if it doesn't exist.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return None
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table).where(table.c.id == id)
+                result = sess.execute(stmt).fetchone()
+                if result is None:
+                    return None
+
+                db_row = dict(result._mapping)
+                if not db_row or not deserialize:
+                    return db_row
+
+            return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_error(f"Exception reading from cultural knowledge table: {e}")
+            raise e
+
+    def get_all_cultural_knowledge(
+        self,
+        name: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+        """Get all cultural knowledge from the database as CulturalKnowledge objects.
+
+        Args:
+            name (Optional[str]): The name of the cultural knowledge to filter by.
+            agent_id (Optional[str]): The ID of the agent to filter by.
+            team_id (Optional[str]): The ID of the team to filter by.
+            limit (Optional[int]): The maximum number of cultural knowledge entries to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of CulturalKnowledge objects
+                - When deserialize=False: List of CulturalKnowledge dictionaries and total count
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return [] if deserialize else ([], 0)
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table)
+
+                # Filtering
+                if name is not None:
+                    stmt = stmt.where(table.c.name == name)
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+
+                # Get total count after applying filtering
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = sess.execute(count_stmt).scalar()
+
+                # Sorting
+                stmt = apply_sorting(stmt, table, sort_by, sort_order)
+                # Paginating
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                    if page is not None:
+                        stmt = stmt.offset((page - 1) * limit)
+
+                result = sess.execute(stmt).fetchall()
+                if not result:
+                    return [] if deserialize else ([], 0)
+
+                db_rows = [dict(record._mapping) for record in result]
+
+                if not deserialize:
+                    return db_rows, total_count
+
+            return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]
+
+        except Exception as e:
+            log_error(f"Error reading from cultural knowledge table: {e}")
+            raise e
+
+    def upsert_cultural_knowledge(
+        self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Upsert a cultural knowledge entry into the database.
+
+        Args:
+            cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Optional[CulturalKnowledge]: The upserted cultural knowledge entry.
+
+        Raises:
+            Exception: If an error occurs during upsert.
+        """
+        try:
+            table = self._get_table(table_type="culture", create_table_if_not_found=True)
+            if table is None:
+                return None
+
+            if cultural_knowledge.id is None:
+                cultural_knowledge.id = str(uuid4())
+
+            # Serialize content, categories, and notes into a JSON dict for DB storage
+            content_dict = serialize_cultural_knowledge_for_db(cultural_knowledge)
+
+            with self.Session() as sess, sess.begin():
+                stmt = mysql.insert(table).values(
+                    id=cultural_knowledge.id,
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    created_at=cultural_knowledge.created_at,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+                stmt = stmt.on_duplicate_key_update(
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+                sess.execute(stmt)
+
+            # Fetch the inserted/updated row
+            return self.get_cultural_knowledge(id=cultural_knowledge.id, deserialize=deserialize)
+
+        except Exception as e:
+            log_error(f"Error upserting cultural knowledge: {e}")
+            raise e
+
     # -- Migrations --
 
     def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
```
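A usage sketch for the new culture API against a configured `MySQLDb` instance. The field names follow the `values()` call above; treating them as constructor keywords of `CulturalKnowledge` is an assumption:

```python
from agno.db.schemas.culture import CulturalKnowledge

# `db` is a configured MySQLDb instance, as in the earlier sketch.
entry = CulturalKnowledge(  # keyword construction assumed
    name="code-review norms",
    summary="How review feedback is phrased on this team.",
    agent_id="agent-1",
)

saved = db.upsert_cultural_knowledge(entry)  # INSERT ... ON DUPLICATE KEY UPDATE
fetched = db.get_cultural_knowledge(id=saved.id)
rows, total = db.get_all_cultural_knowledge(agent_id="agent-1", deserialize=False)
```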
@@ -2043,3 +2421,503 @@ class MySQLDb(BaseDb):
|
|
|
2043
2421
|
for memory in memories:
|
|
2044
2422
|
self.upsert_user_memory(memory)
|
|
2045
2423
|
log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
|
|
2424
|
+
|
|
2425
|
+
# --- Traces ---
|
|
2426
|
+
def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
|
|
2427
|
+
"""Build base query for traces with aggregated span counts.
|
|
2428
|
+
|
|
2429
|
+
Args:
|
|
2430
|
+
table: The traces table.
|
|
2431
|
+
spans_table: The spans table (optional).
|
|
2432
|
+
|
|
2433
|
+
Returns:
|
|
2434
|
+
SQLAlchemy select statement with total_spans and error_count calculated dynamically.
|
|
2435
|
+
"""
|
|
2436
|
+
from sqlalchemy import case, literal
|
|
2437
|
+
|
|
2438
|
+
if spans_table is not None:
|
|
2439
|
+
# JOIN with spans table to calculate total_spans and error_count
|
|
2440
|
+
return (
|
|
2441
|
+
select(
|
|
2442
|
+
table,
|
|
2443
|
+
func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
|
|
2444
|
+
func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
|
|
2445
|
+
"error_count"
|
|
2446
|
+
),
|
|
2447
|
+
)
|
|
2448
|
+
.select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
|
|
2449
|
+
.group_by(table.c.trace_id)
|
|
2450
|
+
)
|
|
2451
|
+
else:
|
|
2452
|
+
# Fallback if spans table doesn't exist
|
|
2453
|
+
return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
|
|
2454
|
+
|
|
2455
|
+
def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
|
|
2456
|
+
"""Build a SQL CASE expression that returns the component level for a trace.
|
|
2457
|
+
|
|
2458
|
+
Component levels (higher = more important):
|
|
2459
|
+
- 3: Workflow root (.run or .arun with workflow_id)
|
|
2460
|
+
- 2: Team root (.run or .arun with team_id)
|
|
2461
|
+
- 1: Agent root (.run or .arun with agent_id)
|
|
2462
|
+
- 0: Child span (not a root)
|
|
2463
|
+
|
|
2464
|
+
Args:
|
|
2465
|
+
workflow_id_col: SQL column/expression for workflow_id
|
|
2466
|
+
team_id_col: SQL column/expression for team_id
|
|
2467
|
+
agent_id_col: SQL column/expression for agent_id
|
|
2468
|
+
name_col: SQL column/expression for name
|
|
2469
|
+
|
|
2470
|
+
Returns:
|
|
2471
|
+
SQLAlchemy CASE expression returning the component level as an integer.
|
|
2472
|
+
"""
|
|
2473
|
+
from sqlalchemy import and_, case, or_
|
|
2474
|
+
|
|
2475
|
+
is_root_name = or_(name_col.like("%.run%"), name_col.like("%.arun%"))
|
|
2476
|
+
|
|
2477
|
+
return case(
|
|
2478
|
+
# Workflow root (level 3)
|
|
2479
|
+
(and_(workflow_id_col.isnot(None), is_root_name), 3),
|
|
2480
|
+
# Team root (level 2)
|
|
2481
|
+
(and_(team_id_col.isnot(None), is_root_name), 2),
|
|
2482
|
+
# Agent root (level 1)
|
|
2483
|
+
(and_(agent_id_col.isnot(None), is_root_name), 1),
|
|
2484
|
+
# Child span or unknown (level 0)
|
|
2485
|
+
else_=0,
|
|
2486
|
+
)
|
|
2487
|
+
|
|
2488
|
+
def upsert_trace(self, trace: "Trace") -> None:
    """Create or update a single trace record in the database.

    Uses INSERT ... ON DUPLICATE KEY UPDATE (upsert) to handle concurrent inserts
    atomically and avoid race conditions.

    Args:
        trace: The Trace object to store (one per trace_id).
    """
    from sqlalchemy import case

    try:
        table = self._get_table(table_type="traces", create_table_if_not_found=True)
        if table is None:
            return

        trace_dict = trace.to_dict()
        trace_dict.pop("total_spans", None)
        trace_dict.pop("error_count", None)

        with self.Session() as sess, sess.begin():
            # Use upsert to handle concurrent inserts atomically.
            # On conflict, update fields while preserving existing non-null context
            # values and keeping the earliest start_time.
            insert_stmt = mysql.insert(table).values(trace_dict)

            # Build component level expressions for comparing trace priority
            new_level = self._get_trace_component_level_expr(
                insert_stmt.inserted.workflow_id,
                insert_stmt.inserted.team_id,
                insert_stmt.inserted.agent_id,
                insert_stmt.inserted.name,
            )
            existing_level = self._get_trace_component_level_expr(
                table.c.workflow_id,
                table.c.team_id,
                table.c.agent_id,
                table.c.name,
            )

            # Build the ON DUPLICATE KEY UPDATE clause.
            # Use LEAST for start_time, GREATEST for end_time to capture the full
            # trace duration. MySQL stores timestamps as ISO strings, so string
            # comparison works for the ISO format. Duration is calculated with
            # TIMESTAMPDIFF in microseconds, then converted to milliseconds.
            upsert_stmt = insert_stmt.on_duplicate_key_update(
                end_time=func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
                start_time=func.least(table.c.start_time, insert_stmt.inserted.start_time),
                # TIMESTAMPDIFF(MICROSECOND, start, end) / 1000 gives milliseconds
                duration_ms=func.timestampdiff(
                    text("MICROSECOND"),
                    func.least(table.c.start_time, insert_stmt.inserted.start_time),
                    func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
                )
                / 1000,
                status=insert_stmt.inserted.status,
                # Update name only if the new trace is from a higher-level component.
                # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
                name=case(
                    (new_level > existing_level, insert_stmt.inserted.name),
                    else_=table.c.name,
                ),
                # Preserve existing non-null context values using COALESCE
                run_id=func.coalesce(insert_stmt.inserted.run_id, table.c.run_id),
                session_id=func.coalesce(insert_stmt.inserted.session_id, table.c.session_id),
                user_id=func.coalesce(insert_stmt.inserted.user_id, table.c.user_id),
                agent_id=func.coalesce(insert_stmt.inserted.agent_id, table.c.agent_id),
                team_id=func.coalesce(insert_stmt.inserted.team_id, table.c.team_id),
                workflow_id=func.coalesce(insert_stmt.inserted.workflow_id, table.c.workflow_id),
            )
            sess.execute(upsert_stmt)

    except Exception as e:
        log_error(f"Error creating trace: {e}")
        # Don't raise - tracing should not break the main application flow
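
Because every conflicting write applies LEAST to start_time and GREATEST to end_time, concurrent writers converge on the widest [start_time, end_time] window regardless of arrival order. A compilable sketch of the same MySQL upsert pattern, assuming SQLAlchemy 2.x and a placeholder table:

```python
# Sketch of the LEAST/GREATEST merge upsert -- placeholder table, not agno's schema.
from sqlalchemy import Column, MetaData, String, Table, func
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.mysql import insert as mysql_insert

metadata = MetaData()
traces = Table(
    "traces",
    metadata,
    Column("trace_id", String(64), primary_key=True),
    Column("start_time", String(32)),  # ISO-8601 strings compare correctly as text
    Column("end_time", String(32)),
)

ins = mysql_insert(traces).values(
    trace_id="tr_1", start_time="2024-01-01T00:00:00", end_time="2024-01-01T00:00:05"
)
upsert = ins.on_duplicate_key_update(
    # Widen the window: keep the earliest start and the latest end seen so far
    start_time=func.least(traces.c.start_time, ins.inserted.start_time),
    end_time=func.greatest(traces.c.end_time, ins.inserted.end_time),
)
print(upsert.compile(dialect=mysql.dialect()))
```

LEAST/GREATEST over ISO-8601 strings is safe here because, at a fixed precision and offset, those strings sort lexicographically in timestamp order.
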
def get_trace(
    self,
    trace_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Get a single trace by trace_id or other filters.

    Args:
        trace_id: The unique trace identifier.
        run_id: Filter by run ID (returns first match).

    Returns:
        Optional[Trace]: The trace if found, None otherwise.

    Note:
        If multiple filters are provided, trace_id takes precedence.
        For other filters, the most recent trace is returned.
    """
    try:
        from agno.tracing.schemas import Trace

        table = self._get_table(table_type="traces")
        if table is None:
            return None

        # Get spans table for JOIN
        spans_table = self._get_table(table_type="spans")

        with self.Session() as sess:
            # Build query with aggregated span counts
            stmt = self._get_traces_base_query(table, spans_table)

            if trace_id:
                stmt = stmt.where(table.c.trace_id == trace_id)
            elif run_id:
                stmt = stmt.where(table.c.run_id == run_id)
            else:
                log_debug("get_trace called without any filter parameters")
                return None

            # Order by most recent and get first result
            stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
            result = sess.execute(stmt).fetchone()

            if result:
                return Trace.from_dict(dict(result._mapping))
            return None

    except Exception as e:
        log_error(f"Error getting trace: {e}")
        return None
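
On the caller side this looks roughly like the sketch below. The `MySQLDb` class name, import path, and `db_url` argument are assumptions based on the package layout, not verified against the released API:

```python
# Hypothetical usage sketch -- MySQLDb name and constructor args are assumptions.
from agno.db.mysql import MySQLDb  # import path assumed from the package layout

db = MySQLDb(db_url="mysql+pymysql://user:pass@localhost:3306/agno")

trace = db.get_trace(trace_id="tr_abc123")  # trace_id takes precedence
if trace is None:
    # Falls back to the most recent trace for the run, ordered by start_time.
    # Returns None (and logs) on a missing table, no filters, or any error.
    trace = db.get_trace(run_id="run_xyz")
```
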
def get_traces(
    self,
    run_id: Optional[str] = None,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    status: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List, int]:
    """Get traces matching the provided filters with pagination.

    Args:
        run_id: Filter by run ID.
        session_id: Filter by session ID.
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        status: Filter by status (OK, ERROR, UNSET).
        start_time: Filter traces starting after this datetime.
        end_time: Filter traces ending before this datetime.
        limit: Maximum number of traces to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
    """
    try:
        from agno.tracing.schemas import Trace

        log_debug(
            f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
        )

        table = self._get_table(table_type="traces")
        if table is None:
            log_debug("Traces table not found")
            return [], 0

        # Get spans table for JOIN
        spans_table = self._get_table(table_type="spans")

        with self.Session() as sess:
            # Build base query with aggregated span counts
            base_stmt = self._get_traces_base_query(table, spans_table)

            # Apply filters
            if run_id:
                base_stmt = base_stmt.where(table.c.run_id == run_id)
            if session_id:
                base_stmt = base_stmt.where(table.c.session_id == session_id)
            if user_id:
                base_stmt = base_stmt.where(table.c.user_id == user_id)
            if agent_id:
                base_stmt = base_stmt.where(table.c.agent_id == agent_id)
            if team_id:
                base_stmt = base_stmt.where(table.c.team_id == team_id)
            if workflow_id:
                base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
            if status:
                base_stmt = base_stmt.where(table.c.status == status)
            if start_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
            if end_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())

            # Get total count
            count_stmt = select(func.count()).select_from(base_stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Apply pagination
            offset = (page - 1) * limit if page and limit else 0
            paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)

            results = sess.execute(paginated_stmt).fetchall()

            traces = [Trace.from_dict(dict(row._mapping)) for row in results]
            return traces, total_count

    except Exception as e:
        log_error(f"Error getting traces: {e}")
        return [], 0
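
Pagination is plain LIMIT/OFFSET (`offset = (page - 1) * limit`), and the total count wraps the same filtered statement in a subquery so it honors the filters. Continuing the hypothetical `db` from the earlier sketch, paging through all matches might look like:

```python
# Hedged paging sketch -- reuses the hypothetical `db` instance from above.
import math

page, limit = 1, 50
traces, total = db.get_traces(user_id="user_1", status="ERROR", limit=limit, page=page)
pages = math.ceil(total / limit)  # total reflects the same filters
while page < pages:
    page += 1
    more, _ = db.get_traces(user_id="user_1", status="ERROR", limit=limit, page=page)
    traces.extend(more)
```
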
def get_trace_stats(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List[Dict[str, Any]], int]:
    """Get trace statistics grouped by session.

    Args:
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        start_time: Filter sessions with traces created after this datetime.
        end_time: Filter sessions with traces created before this datetime.
        limit: Maximum number of sessions to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
        Each dict contains: session_id, user_id, agent_id, team_id, workflow_id,
        total_traces, first_trace_at, last_trace_at.
    """
    try:
        table = self._get_table(table_type="traces")
        if table is None:
            log_debug("Traces table not found")
            return [], 0

        with self.Session() as sess:
            # Build base query grouped by session_id
            base_stmt = (
                select(
                    table.c.session_id,
                    table.c.user_id,
                    table.c.agent_id,
                    table.c.team_id,
                    table.c.workflow_id,
                    func.count(table.c.trace_id).label("total_traces"),
                    func.min(table.c.created_at).label("first_trace_at"),
                    func.max(table.c.created_at).label("last_trace_at"),
                )
                .where(table.c.session_id.isnot(None))  # Only sessions with a session_id
                .group_by(
                    table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
                )
            )

            # Apply filters
            if user_id:
                base_stmt = base_stmt.where(table.c.user_id == user_id)
            if workflow_id:
                base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
            if team_id:
                base_stmt = base_stmt.where(table.c.team_id == team_id)
            if agent_id:
                base_stmt = base_stmt.where(table.c.agent_id == agent_id)
            if start_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
            if end_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())

            # Get total count of sessions
            count_stmt = select(func.count()).select_from(base_stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Apply pagination and ordering
            offset = (page - 1) * limit if page and limit else 0
            paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)

            results = sess.execute(paginated_stmt).fetchall()

            # Convert rows to dicts, parsing ISO timestamp strings to datetime objects
            stats_list = []
            for row in results:
                first_trace_at = datetime.fromisoformat(row.first_trace_at.replace("Z", "+00:00"))
                last_trace_at = datetime.fromisoformat(row.last_trace_at.replace("Z", "+00:00"))

                stats_list.append(
                    {
                        "session_id": row.session_id,
                        "user_id": row.user_id,
                        "agent_id": row.agent_id,
                        "team_id": row.team_id,
                        "workflow_id": row.workflow_id,
                        "total_traces": row.total_traces,
                        "first_trace_at": first_trace_at,
                        "last_trace_at": last_trace_at,
                    }
                )

            return stats_list, total_count

    except Exception as e:
        log_error(f"Error getting trace stats: {e}")
        return [], 0
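
Note the `.replace("Z", "+00:00")` before parsing: prior to Python 3.11, `datetime.fromisoformat` rejects a trailing `Z` suffix. A standalone illustration:

```python
# Why the Z-normalization is needed before datetime.fromisoformat (Python < 3.11).
from datetime import datetime, timezone

raw = "2024-05-01T12:30:00Z"  # ISO-8601 timestamp with a Zulu suffix
parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
assert parsed == datetime(2024, 5, 1, 12, 30, tzinfo=timezone.utc)
```
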
# --- Spans ---
def create_span(self, span: "Span") -> None:
    """Create a single span in the database.

    Args:
        span: The Span object to store.
    """
    try:
        table = self._get_table(table_type="spans", create_table_if_not_found=True)
        if table is None:
            return

        with self.Session() as sess, sess.begin():
            stmt = mysql.insert(table).values(span.to_dict())
            sess.execute(stmt)

    except Exception as e:
        log_error(f"Error creating span: {e}")
def create_spans(self, spans: List) -> None:
    """Create multiple spans in the database as a batch.

    Args:
        spans: List of Span objects to store.
    """
    if not spans:
        return

    try:
        table = self._get_table(table_type="spans", create_table_if_not_found=True)
        if table is None:
            return

        with self.Session() as sess, sess.begin():
            for span in spans:
                stmt = mysql.insert(table).values(span.to_dict())
                sess.execute(stmt)

    except Exception as e:
        log_error(f"Error creating spans batch: {e}")
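
The batch path still issues one INSERT per span, relying on the shared transaction for atomicity. If per-statement round trips become a bottleneck, SQLAlchemy's executemany form can send the whole list in one call; a minimal alternative sketch (not the package's implementation):

```python
# Alternative batching sketch using SQLAlchemy executemany -- not agno's code.
from typing import Any, Dict, List
from sqlalchemy import Table, insert
from sqlalchemy.orm import Session

def insert_spans_batched(session: Session, table: Table, rows: List[Dict[str, Any]]) -> None:
    """Send all span rows in one executemany call instead of one INSERT each."""
    if not rows:
        return
    with session.begin():
        session.execute(insert(table), rows)  # a list of dicts triggers executemany
```
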
def get_span(self, span_id: str):
    """Get a single span by its span_id.

    Args:
        span_id: The unique span identifier.

    Returns:
        Optional[Span]: The span if found, None otherwise.
    """
    try:
        from agno.tracing.schemas import Span

        table = self._get_table(table_type="spans")
        if table is None:
            return None

        with self.Session() as sess:
            stmt = select(table).where(table.c.span_id == span_id)
            result = sess.execute(stmt).fetchone()
            if result:
                return Span.from_dict(dict(result._mapping))
            return None

    except Exception as e:
        log_error(f"Error getting span: {e}")
        return None
def get_spans(
    self,
    trace_id: Optional[str] = None,
    parent_span_id: Optional[str] = None,
    limit: Optional[int] = 1000,
) -> List:
    """Get spans matching the provided filters.

    Args:
        trace_id: Filter by trace ID.
        parent_span_id: Filter by parent span ID.
        limit: Maximum number of spans to return.

    Returns:
        List[Span]: List of matching spans.
    """
    try:
        from agno.tracing.schemas import Span

        table = self._get_table(table_type="spans")
        if table is None:
            return []

        with self.Session() as sess:
            stmt = select(table)

            # Apply filters
            if trace_id:
                stmt = stmt.where(table.c.trace_id == trace_id)
            if parent_span_id:
                stmt = stmt.where(table.c.parent_span_id == parent_span_id)

            if limit:
                stmt = stmt.limit(limit)

            results = sess.execute(stmt).fetchall()
            return [Span.from_dict(dict(row._mapping)) for row in results]

    except Exception as e:
        log_error(f"Error getting spans: {e}")
        return []
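
Spans reference their parents through `parent_span_id`, so one `get_spans(trace_id=...)` fetch is enough to rebuild the call tree client-side. A sketch, assuming `Span` objects expose `span_id` and `parent_span_id` attributes matching the columns queried above, and reusing the hypothetical `db` from earlier:

```python
# Rebuild a parent -> children index from one trace's spans.
# Assumes Span exposes .span_id and .parent_span_id (attribute names assumed).
from collections import defaultdict

spans = db.get_spans(trace_id="tr_abc123")  # hypothetical `db` from the earlier sketch

children = defaultdict(list)
roots = []
for span in spans:
    if span.parent_span_id is None:
        roots.append(span)
    else:
        children[span.parent_span_id].append(span)

def walk(span, depth=0):
    """Print the span tree, indenting children under their parents."""
    print("  " * depth + span.span_id)
    for child in children[span.span_id]:
        walk(child, depth + 1)

for root in roots:
    walk(root)
```
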