agno 2.0.0rc2__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +6009 -2874
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +385 -6
- agno/db/dynamo/dynamo.py +388 -81
- agno/db/dynamo/schemas.py +47 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +435 -64
- agno/db/firestore/schemas.py +11 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +384 -42
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +351 -66
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +339 -48
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +510 -37
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2036 -0
- agno/db/mongo/mongo.py +653 -76
- agno/db/mongo/schemas.py +13 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/mysql.py +687 -25
- agno/db/mysql/schemas.py +61 -37
- agno/db/mysql/utils.py +60 -2
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2001 -0
- agno/db/postgres/postgres.py +676 -57
- agno/db/postgres/schemas.py +43 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +344 -38
- agno/db/redis/schemas.py +18 -0
- agno/db/redis/utils.py +60 -2
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/memory.py +13 -0
- agno/db/singlestore/schemas.py +26 -1
- agno/db/singlestore/singlestore.py +687 -53
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2371 -0
- agno/db/sqlite/schemas.py +24 -0
- agno/db/sqlite/sqlite.py +774 -85
- agno/db/sqlite/utils.py +168 -5
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1361 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +50 -22
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +68 -1
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/discord/client.py +1 -0
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +1 -1
- agno/knowledge/chunking/semantic.py +40 -8
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +13 -0
- agno/knowledge/embedder/openai.py +37 -65
- agno/knowledge/embedder/sentence_transformer.py +8 -4
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +595 -187
- agno/knowledge/reader/base.py +9 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
- agno/knowledge/reader/json_reader.py +6 -5
- agno/knowledge/reader/markdown_reader.py +13 -13
- agno/knowledge/reader/pdf_reader.py +43 -68
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +51 -6
- agno/knowledge/reader/s3_reader.py +3 -15
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +13 -13
- agno/knowledge/reader/web_search_reader.py +2 -43
- agno/knowledge/reader/website_reader.py +43 -25
- agno/knowledge/reranker/__init__.py +3 -0
- agno/knowledge/types.py +9 -0
- agno/knowledge/utils.py +20 -0
- agno/media.py +339 -266
- agno/memory/manager.py +336 -82
- agno/models/aimlapi/aimlapi.py +2 -2
- agno/models/anthropic/claude.py +183 -37
- agno/models/aws/bedrock.py +52 -112
- agno/models/aws/claude.py +33 -1
- agno/models/azure/ai_foundry.py +33 -15
- agno/models/azure/openai_chat.py +25 -8
- agno/models/base.py +1011 -566
- agno/models/cerebras/cerebras.py +19 -13
- agno/models/cerebras/cerebras_openai.py +8 -5
- agno/models/cohere/chat.py +27 -1
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/deepinfra/deepinfra.py +2 -2
- agno/models/deepseek/deepseek.py +2 -2
- agno/models/fireworks/fireworks.py +2 -2
- agno/models/google/gemini.py +110 -37
- agno/models/groq/groq.py +28 -11
- agno/models/huggingface/huggingface.py +2 -1
- agno/models/internlm/internlm.py +2 -2
- agno/models/langdb/langdb.py +4 -4
- agno/models/litellm/chat.py +18 -1
- agno/models/litellm/litellm_openai.py +2 -2
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/message.py +143 -4
- agno/models/meta/llama.py +27 -10
- agno/models/meta/llama_openai.py +5 -17
- agno/models/nebius/nebius.py +6 -6
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/nvidia.py +2 -2
- agno/models/ollama/chat.py +60 -6
- agno/models/openai/chat.py +102 -43
- agno/models/openai/responses.py +103 -106
- agno/models/openrouter/openrouter.py +41 -3
- agno/models/perplexity/perplexity.py +4 -5
- agno/models/portkey/portkey.py +3 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +81 -5
- agno/models/sambanova/sambanova.py +2 -2
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/together.py +2 -2
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +2 -2
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +96 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +3 -2
- agno/os/app.py +543 -175
- agno/os/auth.py +24 -14
- agno/os/config.py +1 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +23 -7
- agno/os/interfaces/agui/router.py +27 -3
- agno/os/interfaces/agui/utils.py +242 -142
- agno/os/interfaces/base.py +6 -2
- agno/os/interfaces/slack/router.py +81 -23
- agno/os/interfaces/slack/slack.py +29 -14
- agno/os/interfaces/whatsapp/router.py +11 -4
- agno/os/interfaces/whatsapp/whatsapp.py +14 -7
- agno/os/mcp.py +111 -54
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +556 -139
- agno/os/routers/evals/evals.py +71 -34
- agno/os/routers/evals/schemas.py +31 -31
- agno/os/routers/evals/utils.py +6 -5
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/knowledge.py +185 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +158 -53
- agno/os/routers/memory/schemas.py +20 -16
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +499 -38
- agno/os/schema.py +308 -198
- agno/os/utils.py +401 -41
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +2 -2
- agno/reasoning/deepseek.py +2 -2
- agno/reasoning/default.py +3 -1
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +2 -2
- agno/reasoning/ollama.py +2 -2
- agno/reasoning/openai.py +7 -2
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +266 -112
- agno/run/base.py +53 -24
- agno/run/team.py +252 -111
- agno/run/workflow.py +156 -45
- agno/session/agent.py +105 -89
- agno/session/summary.py +65 -25
- agno/session/team.py +176 -96
- agno/session/workflow.py +406 -40
- agno/team/team.py +3854 -1692
- agno/tools/brightdata.py +3 -3
- agno/tools/cartesia.py +3 -5
- agno/tools/dalle.py +9 -8
- agno/tools/decorator.py +4 -2
- agno/tools/desi_vocal.py +2 -2
- agno/tools/duckduckgo.py +15 -11
- agno/tools/e2b.py +20 -13
- agno/tools/eleven_labs.py +26 -28
- agno/tools/exa.py +21 -16
- agno/tools/fal.py +4 -4
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +350 -0
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +257 -37
- agno/tools/giphy.py +2 -2
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +270 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/knowledge.py +3 -3
- agno/tools/lumalab.py +3 -3
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +11 -17
- agno/tools/memori.py +1 -53
- agno/tools/memory.py +419 -0
- agno/tools/models/azure_openai.py +2 -2
- agno/tools/models/gemini.py +3 -3
- agno/tools/models/groq.py +3 -5
- agno/tools/models/nebius.py +7 -7
- agno/tools/models_labs.py +25 -15
- agno/tools/notion.py +204 -0
- agno/tools/openai.py +4 -9
- agno/tools/opencv.py +3 -3
- agno/tools/parallel.py +314 -0
- agno/tools/replicate.py +7 -7
- agno/tools/scrapegraph.py +58 -31
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/slack.py +18 -3
- agno/tools/spider.py +2 -2
- agno/tools/tavily.py +146 -0
- agno/tools/whatsapp.py +1 -1
- agno/tools/workflow.py +278 -0
- agno/tools/yfinance.py +12 -11
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +27 -0
- agno/utils/common.py +90 -1
- agno/utils/events.py +222 -7
- agno/utils/gemini.py +181 -23
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +111 -0
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +95 -5
- agno/utils/media.py +188 -10
- agno/utils/merge_dict.py +22 -1
- agno/utils/message.py +60 -0
- agno/utils/models/claude.py +40 -11
- agno/utils/models/cohere.py +1 -1
- agno/utils/models/watsonx.py +1 -1
- agno/utils/openai.py +1 -1
- agno/utils/print_response/agent.py +105 -21
- agno/utils/print_response/team.py +103 -38
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/reasoning.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +16 -10
- agno/utils/string.py +41 -0
- agno/utils/team.py +98 -9
- agno/utils/tools.py +1 -1
- agno/vectordb/base.py +23 -4
- agno/vectordb/cassandra/cassandra.py +65 -9
- agno/vectordb/chroma/chromadb.py +182 -38
- agno/vectordb/clickhouse/clickhousedb.py +64 -11
- agno/vectordb/couchbase/couchbase.py +105 -10
- agno/vectordb/lancedb/lance_db.py +183 -135
- agno/vectordb/langchaindb/langchaindb.py +25 -7
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +46 -7
- agno/vectordb/milvus/milvus.py +126 -9
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +112 -7
- agno/vectordb/pgvector/pgvector.py +142 -21
- agno/vectordb/pineconedb/pineconedb.py +80 -8
- agno/vectordb/qdrant/qdrant.py +125 -39
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/singlestore/singlestore.py +111 -25
- agno/vectordb/surrealdb/surrealdb.py +31 -5
- agno/vectordb/upstashdb/upstashdb.py +76 -8
- agno/vectordb/weaviate/weaviate.py +86 -15
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +112 -18
- agno/workflow/loop.py +69 -10
- agno/workflow/parallel.py +266 -118
- agno/workflow/router.py +110 -17
- agno/workflow/step.py +645 -136
- agno/workflow/steps.py +65 -6
- agno/workflow/types.py +71 -33
- agno/workflow/workflow.py +2113 -300
- agno-2.3.0.dist-info/METADATA +618 -0
- agno-2.3.0.dist-info/RECORD +577 -0
- agno-2.3.0.dist-info/licenses/LICENSE +201 -0
- agno/knowledge/reader/url_reader.py +0 -128
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -610
- agno/utils/models/aws_claude.py +0 -170
- agno-2.0.0rc2.dist-info/METADATA +0 -355
- agno-2.0.0rc2.dist-info/RECORD +0 -515
- agno-2.0.0rc2.dist-info/licenses/LICENSE +0 -375
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
- {agno-2.0.0rc2.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
agno/db/mysql/mysql.py
CHANGED
@@ -6,22 +6,27 @@ from uuid import uuid4
 from sqlalchemy import Index, UniqueConstraint

 from agno.db.base import BaseDb, SessionType
+from agno.db.migrations.manager import MigrationManager
 from agno.db.mysql.schemas import get_table_schema_definition
 from agno.db.mysql.utils import (
     apply_sorting,
     bulk_upsert_metrics,
     calculate_date_metrics,
     create_schema,
+    deserialize_cultural_knowledge_from_db,
     fetch_all_sessions_data,
     get_dates_to_calculate_metrics_for,
     is_table_available,
     is_valid_table,
+    serialize_cultural_knowledge_for_db,
 )
+from agno.db.schemas.culture import CulturalKnowledge
 from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
 from agno.db.schemas.knowledge import KnowledgeRow
 from agno.db.schemas.memory import UserMemory
 from agno.session import AgentSession, Session, TeamSession, WorkflowSession
-from agno.utils.log import log_debug, log_error, log_info
+from agno.utils.log import log_debug, log_error, log_info, log_warning
+from agno.utils.string import generate_id

 try:
     from sqlalchemy import TEXT, and_, cast, func, update
@@ -41,10 +46,13 @@ class MySQLDb(BaseDb):
         db_schema: Optional[str] = None,
         db_url: Optional[str] = None,
         session_table: Optional[str] = None,
+        culture_table: Optional[str] = None,
         memory_table: Optional[str] = None,
         metrics_table: Optional[str] = None,
         eval_table: Optional[str] = None,
         knowledge_table: Optional[str] = None,
+        versions_table: Optional[str] = None,
+        id: Optional[str] = None,
     ):
         """
         Interface for interacting with a MySQL database.
@@ -59,21 +67,33 @@ class MySQLDb(BaseDb):
             db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
             db_schema (Optional[str]): The database schema to use.
             session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
+            culture_table (Optional[str]): Name of the table to store cultural knowledge.
             memory_table (Optional[str]): Name of the table to store memories.
             metrics_table (Optional[str]): Name of the table to store metrics.
             eval_table (Optional[str]): Name of the table to store evaluation runs data.
             knowledge_table (Optional[str]): Name of the table to store knowledge content.
+            versions_table (Optional[str]): Name of the table to store schema versions.
+            id (Optional[str]): ID of the database.

         Raises:
             ValueError: If neither db_url nor db_engine is provided.
             ValueError: If none of the tables are provided.
         """
+        if id is None:
+            base_seed = db_url or str(db_engine.url)  # type: ignore
+            schema_suffix = db_schema if db_schema is not None else "ai"
+            seed = f"{base_seed}#{schema_suffix}"
+            id = generate_id(seed)
+
         super().__init__(
+            id=id,
             session_table=session_table,
+            culture_table=culture_table,
             memory_table=memory_table,
             metrics_table=metrics_table,
             eval_table=eval_table,
             knowledge_table=knowledge_table,
+            versions_table=versions_table,
         )

         _engine: Optional[Engine] = db_engine
@@ -91,6 +111,18 @@ class MySQLDb(BaseDb):
         self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))

     # -- DB methods --
+    def table_exists(self, table_name: str) -> bool:
+        """Check if a table with the given name exists in the MySQL database.
+
+        Args:
+            table_name: Name of the table to check
+
+        Returns:
+            bool: True if the table exists in the database, False otherwise
+        """
+        with self.Session() as sess:
+            return is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
+
     def _create_table(self, table_name: str, table_type: str, db_schema: str) -> Table:
         """
         Create a table with the appropriate schema based on the table type.
@@ -106,7 +138,7 @@ class MySQLDb(BaseDb):
         try:
             table_schema = get_table_schema_definition(table_type)

-            log_debug(f"Creating table {
+            log_debug(f"Creating table {table_name}")

             columns: List[Column] = []
             indexes: List[str] = []
@@ -175,13 +207,32 @@ class MySQLDb(BaseDb):
                     except Exception as e:
                         log_error(f"Error creating index {idx.name}: {e}")

-
+            log_debug(f"Successfully created table {db_schema}.{table_name}")
             return table

         except Exception as e:
             log_error(f"Could not create table {db_schema}.{table_name}: {e}")
             raise

+    def _create_all_tables(self):
+        """Create all tables for the database."""
+        tables_to_create = [
+            (self.session_table_name, "sessions"),
+            (self.memory_table_name, "memories"),
+            (self.metrics_table_name, "metrics"),
+            (self.eval_table_name, "evals"),
+            (self.knowledge_table_name, "knowledge"),
+            (self.versions_table_name, "versions"),
+        ]
+
+        for table_name, table_type in tables_to_create:
+            if table_name != self.versions_table_name:
+                # Also store the schema version for the created table
+                latest_schema_version = MigrationManager(self).latest_schema_version
+                self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
+
+            self._create_table(table_name=table_name, table_type=table_type, db_schema=self.db_schema)
+
     def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
         if table_type == "sessions":
             self.session_table = self._get_or_create_table(
@@ -228,6 +279,24 @@ class MySQLDb(BaseDb):
             )
             return self.knowledge_table

+        if table_type == "culture":
+            self.culture_table = self._get_or_create_table(
+                table_name=self.culture_table_name,
+                table_type="culture",
+                db_schema=self.db_schema,
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.culture_table
+
+        if table_type == "versions":
+            self.versions_table = self._get_or_create_table(
+                table_name=self.versions_table_name,
+                table_type="versions",
+                db_schema=self.db_schema,
+                create_table_if_not_found=create_table_if_not_found,
+            )
+            return self.versions_table
+
         raise ValueError(f"Unknown table type: {table_type}")

     def _get_or_create_table(
@@ -252,7 +321,14 @@ class MySQLDb(BaseDb):
             if not create_table_if_not_found:
                 return None

-
+            created_table = self._create_table(table_name=table_name, table_type=table_type, db_schema=db_schema)
+
+            if table_name != self.versions_table_name:
+                # Also store the schema version for the created table
+                latest_schema_version = MigrationManager(self).latest_schema_version
+                self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
+
+            return created_table

         if not is_valid_table(
             db_engine=self.db_engine,
@@ -271,6 +347,39 @@ class MySQLDb(BaseDb):
             log_error(f"Error loading existing table {db_schema}.{table_name}: {e}")
             raise

+    def get_latest_schema_version(self, table_name: str) -> str:
+        """Get the latest version of the database schema."""
+        table = self._get_table(table_type="versions", create_table_if_not_found=True)
+        with self.Session() as sess:
+            # Latest version for the given table
+            stmt = select(table).where(table.c.table_name == table_name).order_by(table.c.version.desc()).limit(1)  # type: ignore
+            result = sess.execute(stmt).fetchone()
+            if result is None:
+                return "2.0.0"
+            version_dict = dict(result._mapping)
+            return version_dict.get("version") or "2.0.0"
+
+    def upsert_schema_version(self, table_name: str, version: str) -> None:
+        """Upsert the schema version into the database."""
+        table = self._get_table(table_type="versions", create_table_if_not_found=True)
+        if table is None:
+            return
+        current_datetime = datetime.now().isoformat()
+        with self.Session() as sess, sess.begin():
+            stmt = mysql.insert(table).values(  # type: ignore
+                table_name=table_name,
+                version=version,
+                created_at=current_datetime,  # Store as ISO format string
+                updated_at=current_datetime,
+            )
+            # Update version if table_name already exists
+            stmt = stmt.on_duplicate_key_update(
+                version=version,
+                created_at=current_datetime,
+                updated_at=current_datetime,
+            )
+            sess.execute(stmt)
+
     # -- Session methods --
     def delete_session(self, session_id: str) -> bool:
         """
@@ -340,8 +449,8 @@ class MySQLDb(BaseDb):

         Args:
             session_id (str): ID of the session to read.
+            session_type (SessionType): Type of session to get.
             user_id (Optional[str]): User ID to filter by. Defaults to None.
-            session_type (Optional[SessionType]): Type of session to read. Defaults to None.
             deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.

         Returns:
@@ -362,9 +471,6 @@ class MySQLDb(BaseDb):

             if user_id is not None:
                 stmt = stmt.where(table.c.user_id == user_id)
-            if session_type is not None:
-                session_type_value = session_type.value if isinstance(session_type, SessionType) else session_type
-                stmt = stmt.where(table.c.session_type == session_type_value)
             result = sess.execute(stmt).fetchone()
             if result is None:
                 return None
@@ -405,6 +511,7 @@ class MySQLDb(BaseDb):
         Get all sessions in the given table. Can filter by user_id and entity_id.

         Args:
+            session_type (Optional[SessionType]): The type of sessions to get.
             user_id (Optional[str]): The ID of the user to filter by.
             entity_id (Optional[str]): The ID of the agent / workflow to filter by.
             start_timestamp (Optional[int]): The start timestamp to filter by.
@@ -488,8 +595,8 @@ class MySQLDb(BaseDb):
                 raise ValueError(f"Invalid session type: {session_type}")

         except Exception as e:
-            log_error(f"Exception getting
-
+            log_error(f"Exception getting sessions: {e}")
+            raise e

     def rename_session(
         self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
@@ -694,10 +801,232 @@ class MySQLDb(BaseDb):
             log_error(f"Exception upserting into sessions table: {e}")
             return None

+    def upsert_sessions(
+        self, sessions: List[Session], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
+    ) -> List[Union[Session, Dict[str, Any]]]:
+        """
+        Bulk upsert multiple sessions for improved performance on large datasets.
+
+        Args:
+            sessions (List[Session]): List of sessions to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
+            preserve_updated_at (bool): If True, preserve the updated_at from the session object.
+
+        Returns:
+            List[Union[Session, Dict[str, Any]]]: List of upserted sessions.
+
+        Raises:
+            Exception: If an error occurs during bulk upsert.
+        """
+        if not sessions:
+            return []
+
+        try:
+            table = self._get_table(table_type="sessions", create_table_if_not_found=True)
+            if table is None:
+                log_info("Sessions table not available, falling back to individual upserts")
+                return [
+                    result
+                    for session in sessions
+                    if session is not None
+                    for result in [self.upsert_session(session, deserialize=deserialize)]
+                    if result is not None
+                ]
+
+            # Group sessions by type for batch processing
+            agent_sessions = []
+            team_sessions = []
+            workflow_sessions = []
+
+            for session in sessions:
+                if isinstance(session, AgentSession):
+                    agent_sessions.append(session)
+                elif isinstance(session, TeamSession):
+                    team_sessions.append(session)
+                elif isinstance(session, WorkflowSession):
+                    workflow_sessions.append(session)
+
+            results: List[Union[Session, Dict[str, Any]]] = []
+
+            # Process each session type in bulk
+            with self.Session() as sess, sess.begin():
+                # Bulk upsert agent sessions
+                if agent_sessions:
+                    agent_data = []
+                    for session in agent_sessions:
+                        session_dict = session.to_dict()
+                        # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                        updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
+                        agent_data.append(
+                            {
+                                "session_id": session_dict.get("session_id"),
+                                "session_type": SessionType.AGENT.value,
+                                "agent_id": session_dict.get("agent_id"),
+                                "user_id": session_dict.get("user_id"),
+                                "runs": session_dict.get("runs"),
+                                "agent_data": session_dict.get("agent_data"),
+                                "session_data": session_dict.get("session_data"),
+                                "summary": session_dict.get("summary"),
+                                "metadata": session_dict.get("metadata"),
+                                "created_at": session_dict.get("created_at"),
+                                "updated_at": updated_at,
+                            }
+                        )
+
+                    if agent_data:
+                        stmt = mysql.insert(table)
+                        stmt = stmt.on_duplicate_key_update(
+                            agent_id=stmt.inserted.agent_id,
+                            user_id=stmt.inserted.user_id,
+                            agent_data=stmt.inserted.agent_data,
+                            session_data=stmt.inserted.session_data,
+                            summary=stmt.inserted.summary,
+                            metadata=stmt.inserted.metadata,
+                            runs=stmt.inserted.runs,
+                            updated_at=stmt.inserted.updated_at,
+                        )
+                        sess.execute(stmt, agent_data)
+
+                    # Fetch the results for agent sessions
+                    agent_ids = [session.session_id for session in agent_sessions]
+                    select_stmt = select(table).where(table.c.session_id.in_(agent_ids))
+                    result = sess.execute(select_stmt).fetchall()
+
+                    for row in result:
+                        session_dict = dict(row._mapping)
+                        if deserialize:
+                            deserialized_agent_session = AgentSession.from_dict(session_dict)
+                            if deserialized_agent_session is None:
+                                continue
+                            results.append(deserialized_agent_session)
+                        else:
+                            results.append(session_dict)
+
+                # Bulk upsert team sessions
+                if team_sessions:
+                    team_data = []
+                    for session in team_sessions:
+                        session_dict = session.to_dict()
+                        # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                        updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
+                        team_data.append(
+                            {
+                                "session_id": session_dict.get("session_id"),
+                                "session_type": SessionType.TEAM.value,
+                                "team_id": session_dict.get("team_id"),
+                                "user_id": session_dict.get("user_id"),
+                                "runs": session_dict.get("runs"),
+                                "team_data": session_dict.get("team_data"),
+                                "session_data": session_dict.get("session_data"),
+                                "summary": session_dict.get("summary"),
+                                "metadata": session_dict.get("metadata"),
+                                "created_at": session_dict.get("created_at"),
+                                "updated_at": updated_at,
+                            }
+                        )
+
+                    if team_data:
+                        stmt = mysql.insert(table)
+                        stmt = stmt.on_duplicate_key_update(
+                            team_id=stmt.inserted.team_id,
+                            user_id=stmt.inserted.user_id,
+                            team_data=stmt.inserted.team_data,
+                            session_data=stmt.inserted.session_data,
+                            summary=stmt.inserted.summary,
+                            metadata=stmt.inserted.metadata,
+                            runs=stmt.inserted.runs,
+                            updated_at=stmt.inserted.updated_at,
+                        )
+                        sess.execute(stmt, team_data)
+
+                    # Fetch the results for team sessions
+                    team_ids = [session.session_id for session in team_sessions]
+                    select_stmt = select(table).where(table.c.session_id.in_(team_ids))
+                    result = sess.execute(select_stmt).fetchall()
+
+                    for row in result:
+                        session_dict = dict(row._mapping)
+                        if deserialize:
+                            deserialized_team_session = TeamSession.from_dict(session_dict)
+                            if deserialized_team_session is None:
+                                continue
+                            results.append(deserialized_team_session)
+                        else:
+                            results.append(session_dict)
+
+                # Bulk upsert workflow sessions
+                if workflow_sessions:
+                    workflow_data = []
+                    for session in workflow_sessions:
+                        session_dict = session.to_dict()
+                        # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                        updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
+                        workflow_data.append(
+                            {
+                                "session_id": session_dict.get("session_id"),
+                                "session_type": SessionType.WORKFLOW.value,
+                                "workflow_id": session_dict.get("workflow_id"),
+                                "user_id": session_dict.get("user_id"),
+                                "runs": session_dict.get("runs"),
+                                "workflow_data": session_dict.get("workflow_data"),
+                                "session_data": session_dict.get("session_data"),
+                                "summary": session_dict.get("summary"),
+                                "metadata": session_dict.get("metadata"),
+                                "created_at": session_dict.get("created_at"),
+                                "updated_at": updated_at,
+                            }
+                        )
+
+                    if workflow_data:
+                        stmt = mysql.insert(table)
+                        stmt = stmt.on_duplicate_key_update(
+                            workflow_id=stmt.inserted.workflow_id,
+                            user_id=stmt.inserted.user_id,
+                            workflow_data=stmt.inserted.workflow_data,
+                            session_data=stmt.inserted.session_data,
+                            summary=stmt.inserted.summary,
+                            metadata=stmt.inserted.metadata,
+                            runs=stmt.inserted.runs,
+                            updated_at=stmt.inserted.updated_at,
+                        )
+                        sess.execute(stmt, workflow_data)
+
+                    # Fetch the results for workflow sessions
+                    workflow_ids = [session.session_id for session in workflow_sessions]
+                    select_stmt = select(table).where(table.c.session_id.in_(workflow_ids))
+                    result = sess.execute(select_stmt).fetchall()
+
+                    for row in result:
+                        session_dict = dict(row._mapping)
+                        if deserialize:
+                            deserialized_workflow_session = WorkflowSession.from_dict(session_dict)
+                            if deserialized_workflow_session is None:
+                                continue
+                            results.append(deserialized_workflow_session)
+                        else:
+                            results.append(session_dict)
+
+            return results
+
+        except Exception as e:
+            log_error(f"Exception during bulk session upsert, falling back to individual upserts: {e}")
+            # Fallback to individual upserts
+            return [
+                result
+                for session in sessions
+                if session is not None
+                for result in [self.upsert_session(session, deserialize=deserialize)]
+                if result is not None
+            ]
+
     # -- Memory methods --
-    def delete_user_memory(self, memory_id: str):
+    def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
         """Delete a user memory from the database.

+        Args:
+            memory_id (str): The ID of the memory to delete.
+            user_id (Optional[str]): The user ID to filter by. Defaults to None.
+
         Returns:
             bool: True if deletion was successful, False otherwise.

@@ -711,6 +1040,8 @@ class MySQLDb(BaseDb):

             with self.Session() as sess, sess.begin():
                 delete_stmt = table.delete().where(table.c.memory_id == memory_id)
+                if user_id is not None:
+                    delete_stmt = delete_stmt.where(table.c.user_id == user_id)
                 result = sess.execute(delete_stmt)

                 success = result.rowcount > 0
@@ -722,11 +1053,12 @@ class MySQLDb(BaseDb):
         except Exception as e:
             log_error(f"Error deleting user memory: {e}")

-    def delete_user_memories(self, memory_ids: List[str]) -> None:
+    def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
         """Delete user memories from the database.

         Args:
             memory_ids (List[str]): The IDs of the memories to delete.
+            user_id (Optional[str]): The user ID to filter by. Defaults to None.

         Raises:
             Exception: If an error occurs during deletion.
@@ -738,6 +1070,8 @@ class MySQLDb(BaseDb):

             with self.Session() as sess, sess.begin():
                 delete_stmt = table.delete().where(table.c.memory_id.in_(memory_ids))
+                if user_id is not None:
+                    delete_stmt = delete_stmt.where(table.c.user_id == user_id)
                 result = sess.execute(delete_stmt)
                 if result.rowcount == 0:
                     log_debug(f"No user memories found with ids: {memory_ids}")
@@ -778,14 +1112,17 @@ class MySQLDb(BaseDb):

         except Exception as e:
             log_error(f"Exception reading from memory table: {e}")
-
+            raise e

-    def get_user_memory(
+    def get_user_memory(
+        self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
+    ) -> Optional[UserMemory]:
         """Get a memory from the database.

         Args:
             memory_id (str): The ID of the memory to get.
             deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
+            user_id (Optional[str]): The user ID to filter by. Defaults to None.

         Returns:
             Union[UserMemory, Dict[str, Any], None]:
@@ -802,6 +1139,8 @@ class MySQLDb(BaseDb):

             with self.Session() as sess, sess.begin():
                 stmt = select(table).where(table.c.memory_id == memory_id)
+                if user_id is not None:
+                    stmt = stmt.where(table.c.user_id == user_id)

                 result = sess.execute(stmt).fetchone()
                 if not result:
@@ -900,7 +1239,7 @@ class MySQLDb(BaseDb):

         except Exception as e:
             log_error(f"Exception reading from memory table: {e}")
-
+            raise e

     def clear_memories(self) -> None:
         """Clear all user memories from the database."""
@@ -1007,6 +1346,8 @@ class MySQLDb(BaseDb):
             if memory.memory_id is None:
                 memory.memory_id = str(uuid4())

+            current_time = int(time.time())
+
             stmt = mysql.insert(table).values(
                 memory_id=memory.memory_id,
                 memory=memory.memory,
@@ -1015,7 +1356,9 @@ class MySQLDb(BaseDb):
                 agent_id=memory.agent_id,
                 team_id=memory.team_id,
                 topics=memory.topics,
-
+                feedback=memory.feedback,
+                created_at=memory.created_at,
+                updated_at=memory.created_at,
             )
             stmt = stmt.on_duplicate_key_update(
                 memory=memory.memory,
@@ -1023,7 +1366,10 @@ class MySQLDb(BaseDb):
                 input=memory.input,
                 agent_id=memory.agent_id,
                 team_id=memory.team_id,
-
+                feedback=memory.feedback,
+                updated_at=current_time,
+                # Preserve created_at on update - don't overwrite existing value
+                created_at=table.c.created_at,
             )
             sess.execute(stmt)

@@ -1044,6 +1390,106 @@ class MySQLDb(BaseDb):
             log_error(f"Exception upserting user memory: {e}")
             return None

+    def upsert_memories(
+        self, memories: List[UserMemory], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
+    ) -> List[Union[UserMemory, Dict[str, Any]]]:
+        """
+        Bulk upsert multiple user memories for improved performance on large datasets.
+
+        Args:
+            memories (List[UserMemory]): List of memories to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
+
+        Returns:
+            List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories.
+
+        Raises:
+            Exception: If an error occurs during bulk upsert.
+        """
+        if not memories:
+            return []
+
+        try:
+            table = self._get_table(table_type="memories", create_table_if_not_found=True)
+            if table is None:
+                log_info("Memories table not available, falling back to individual upserts")
+                return [
+                    result
+                    for memory in memories
+                    if memory is not None
+                    for result in [self.upsert_user_memory(memory, deserialize=deserialize)]
+                    if result is not None
+                ]
+
+            # Prepare bulk data
+            bulk_data = []
+            current_time = int(time.time())
+
+            for memory in memories:
+                if memory.memory_id is None:
+                    memory.memory_id = str(uuid4())
+
+                # Use preserved updated_at if flag is set and value exists, otherwise use current time
+                updated_at = memory.updated_at if preserve_updated_at else current_time
+
+                bulk_data.append(
+                    {
+                        "memory_id": memory.memory_id,
+                        "memory": memory.memory,
+                        "input": memory.input,
+                        "user_id": memory.user_id,
+                        "agent_id": memory.agent_id,
+                        "team_id": memory.team_id,
+                        "topics": memory.topics,
+                        "feedback": memory.feedback,
+                        "created_at": memory.created_at,
+                        "updated_at": updated_at,
+                    }
+                )
+
+            results: List[Union[UserMemory, Dict[str, Any]]] = []
+
+            with self.Session() as sess, sess.begin():
+                # Bulk upsert memories using MySQL ON DUPLICATE KEY UPDATE
+                stmt = mysql.insert(table)
+                stmt = stmt.on_duplicate_key_update(
+                    memory=stmt.inserted.memory,
+                    topics=stmt.inserted.topics,
+                    input=stmt.inserted.input,
+                    agent_id=stmt.inserted.agent_id,
+                    team_id=stmt.inserted.team_id,
+                    feedback=stmt.inserted.feedback,
+                    updated_at=stmt.inserted.updated_at,
+                    # Preserve created_at on update
+                    created_at=table.c.created_at,
+                )
+                sess.execute(stmt, bulk_data)
+
+                # Fetch results
+                memory_ids = [memory.memory_id for memory in memories if memory.memory_id]
+                select_stmt = select(table).where(table.c.memory_id.in_(memory_ids))
+                result = sess.execute(select_stmt).fetchall()
+
+                for row in result:
+                    memory_dict = dict(row._mapping)
+                    if deserialize:
+                        results.append(UserMemory.from_dict(memory_dict))
+                    else:
+                        results.append(memory_dict)
+
+            return results
+
+        except Exception as e:
+            log_error(f"Exception during bulk memory upsert, falling back to individual upserts: {e}")
+            # Fallback to individual upserts
+            return [
+                result
+                for memory in memories
+                if memory is not None
+                for result in [self.upsert_user_memory(memory, deserialize=deserialize)]
+                if result is not None
+            ]
+
     # -- Metrics methods --
     def _get_all_sessions_for_metrics_calculation(
         self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
@@ -1085,7 +1531,7 @@ class MySQLDb(BaseDb):

         except Exception as e:
             log_error(f"Exception reading from sessions table: {e}")
-
+            raise e

     def _get_metrics_calculation_starting_date(self, table: Table) -> Optional[date]:
         """Get the first date for which metrics calculation is needed:
@@ -1328,9 +1774,9 @@ class MySQLDb(BaseDb):
                 if page is not None:
                     stmt = stmt.offset((page - 1) * limit)

-
-
-
+                result = sess.execute(stmt).fetchall()
+                if not result:
+                    return [], 0

                 return [KnowledgeRow.model_validate(record._mapping) for record in result], total_count

@@ -1622,7 +2068,7 @@ class MySQLDb(BaseDb):

         except Exception as e:
             log_error(f"Exception getting eval runs: {e}")
-
+            raise e

     def rename_eval_run(
         self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
@@ -1660,6 +2106,222 @@ class MySQLDb(BaseDb):
             log_error(f"Error upserting eval run name {eval_run_id}: {e}")
             return None

+    # -- Culture methods --
+
+    def clear_cultural_knowledge(self) -> None:
+        """Delete all cultural knowledge from the database.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                sess.execute(table.delete())
+
+        except Exception as e:
+            log_warning(f"Exception deleting all cultural knowledge: {e}")
+            raise e
+
+    def delete_cultural_knowledge(self, id: str) -> None:
+        """Delete a cultural knowledge entry from the database.
+
+        Args:
+            id (str): The ID of the cultural knowledge to delete.
+
+        Raises:
+            Exception: If an error occurs during deletion.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                delete_stmt = table.delete().where(table.c.id == id)
+                result = sess.execute(delete_stmt)
+
+                success = result.rowcount > 0
+                if success:
+                    log_debug(f"Successfully deleted cultural knowledge id: {id}")
+                else:
+                    log_debug(f"No cultural knowledge found with id: {id}")
+
+        except Exception as e:
+            log_error(f"Error deleting cultural knowledge: {e}")
+            raise e
+
+    def get_cultural_knowledge(
+        self, id: str, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Get a cultural knowledge entry from the database.
+
+        Args:
+            id (str): The ID of the cultural knowledge to get.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural knowledge entry, or None if it doesn't exist.
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return None
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table).where(table.c.id == id)
+                result = sess.execute(stmt).fetchone()
+                if result is None:
+                    return None
+
+                db_row = dict(result._mapping)
+                if not db_row or not deserialize:
+                    return db_row
+
+                return deserialize_cultural_knowledge_from_db(db_row)
+
+        except Exception as e:
+            log_error(f"Exception reading from cultural knowledge table: {e}")
+            raise e
+
+    def get_all_cultural_knowledge(
+        self,
+        name: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        limit: Optional[int] = None,
+        page: Optional[int] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Optional[str] = None,
+        deserialize: Optional[bool] = True,
+    ) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+        """Get all cultural knowledge from the database as CulturalKnowledge objects.
+
+        Args:
+            name (Optional[str]): The name of the cultural knowledge to filter by.
+            agent_id (Optional[str]): The ID of the agent to filter by.
+            team_id (Optional[str]): The ID of the team to filter by.
+            limit (Optional[int]): The maximum number of cultural knowledge entries to return.
+            page (Optional[int]): The page number.
+            sort_by (Optional[str]): The column to sort by.
+            sort_order (Optional[str]): The order to sort by.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
+                - When deserialize=True: List of CulturalKnowledge objects
+                - When deserialize=False: List of CulturalKnowledge dictionaries and total count
+
+        Raises:
+            Exception: If an error occurs during retrieval.
+        """
+        try:
+            table = self._get_table(table_type="culture")
+            if table is None:
+                return [] if deserialize else ([], 0)
+
+            with self.Session() as sess, sess.begin():
+                stmt = select(table)
+
+                # Filtering
+                if name is not None:
+                    stmt = stmt.where(table.c.name == name)
+                if agent_id is not None:
+                    stmt = stmt.where(table.c.agent_id == agent_id)
+                if team_id is not None:
+                    stmt = stmt.where(table.c.team_id == team_id)
+
+                # Get total count after applying filtering
+                count_stmt = select(func.count()).select_from(stmt.alias())
+                total_count = sess.execute(count_stmt).scalar()
+
+                # Sorting
+                stmt = apply_sorting(stmt, table, sort_by, sort_order)
+                # Paginating
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                if page is not None:
+                    stmt = stmt.offset((page - 1) * limit)
+
+                result = sess.execute(stmt).fetchall()
+                if not result:
+                    return [] if deserialize else ([], 0)
+
+                db_rows = [dict(record._mapping) for record in result]
+
+                if not deserialize:
+                    return db_rows, total_count
+
+                return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]
+
+        except Exception as e:
+            log_error(f"Error reading from cultural knowledge table: {e}")
+            raise e
+
+    def upsert_cultural_knowledge(
+        self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
+    ) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
+        """Upsert a cultural knowledge entry into the database.
+
+        Args:
+            cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
+            deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
+
+        Returns:
+            Optional[CulturalKnowledge]: The upserted cultural knowledge entry.
+
+        Raises:
+            Exception: If an error occurs during upsert.
+        """
+        try:
+            table = self._get_table(table_type="culture", create_table_if_not_found=True)
+            if table is None:
+                return None
+
+            if cultural_knowledge.id is None:
+                cultural_knowledge.id = str(uuid4())
+
+            # Serialize content, categories, and notes into a JSON dict for DB storage
+            content_dict = serialize_cultural_knowledge_for_db(cultural_knowledge)
+
+            with self.Session() as sess, sess.begin():
+                stmt = mysql.insert(table).values(
+                    id=cultural_knowledge.id,
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    created_at=cultural_knowledge.created_at,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+                stmt = stmt.on_duplicate_key_update(
+                    name=cultural_knowledge.name,
+                    summary=cultural_knowledge.summary,
+                    content=content_dict if content_dict else None,
+                    metadata=cultural_knowledge.metadata,
+                    input=cultural_knowledge.input,
+                    updated_at=int(time.time()),
+                    agent_id=cultural_knowledge.agent_id,
+                    team_id=cultural_knowledge.team_id,
+                )
+                sess.execute(stmt)
+
+                # Fetch the inserted/updated row
+                return self.get_cultural_knowledge(id=cultural_knowledge.id, deserialize=deserialize)
+
+        except Exception as e:
+            log_error(f"Error upserting cultural knowledge: {e}")
+            raise e
+
     # -- Migrations --

     def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
@@ -1701,17 +2363,17 @@ class MySQLDb(BaseDb):
             if v1_table_type == "agent_sessions":
                 for session in sessions:
                     self.upsert_session(session)
-                log_info(f"Migrated {len(sessions)} Agent sessions to table: {self.
+                log_info(f"Migrated {len(sessions)} Agent sessions to table: {self.session_table_name}")

             elif v1_table_type == "team_sessions":
                 for session in sessions:
                     self.upsert_session(session)
-                log_info(f"Migrated {len(sessions)} Team sessions to table: {self.
+                log_info(f"Migrated {len(sessions)} Team sessions to table: {self.session_table_name}")

             elif v1_table_type == "workflow_sessions":
                 for session in sessions:
                     self.upsert_session(session)
-                log_info(f"Migrated {len(sessions)} Workflow sessions to table: {self.
+                log_info(f"Migrated {len(sessions)} Workflow sessions to table: {self.session_table_name}")

             elif v1_table_type == "memories":
                 for memory in memories: