agno-2.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/__init__.py +8 -0
- agno/agent/__init__.py +51 -0
- agno/agent/agent.py +10405 -0
- agno/api/__init__.py +0 -0
- agno/api/agent.py +28 -0
- agno/api/api.py +40 -0
- agno/api/evals.py +22 -0
- agno/api/os.py +17 -0
- agno/api/routes.py +13 -0
- agno/api/schemas/__init__.py +9 -0
- agno/api/schemas/agent.py +16 -0
- agno/api/schemas/evals.py +16 -0
- agno/api/schemas/os.py +14 -0
- agno/api/schemas/response.py +6 -0
- agno/api/schemas/team.py +16 -0
- agno/api/schemas/utils.py +21 -0
- agno/api/schemas/workflows.py +16 -0
- agno/api/settings.py +53 -0
- agno/api/team.py +30 -0
- agno/api/workflow.py +28 -0
- agno/cloud/aws/base.py +214 -0
- agno/cloud/aws/s3/__init__.py +2 -0
- agno/cloud/aws/s3/api_client.py +43 -0
- agno/cloud/aws/s3/bucket.py +195 -0
- agno/cloud/aws/s3/object.py +57 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/__init__.py +24 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +598 -0
- agno/db/dynamo/__init__.py +3 -0
- agno/db/dynamo/dynamo.py +2042 -0
- agno/db/dynamo/schemas.py +314 -0
- agno/db/dynamo/utils.py +743 -0
- agno/db/firestore/__init__.py +3 -0
- agno/db/firestore/firestore.py +1795 -0
- agno/db/firestore/schemas.py +140 -0
- agno/db/firestore/utils.py +376 -0
- agno/db/gcs_json/__init__.py +3 -0
- agno/db/gcs_json/gcs_json_db.py +1335 -0
- agno/db/gcs_json/utils.py +228 -0
- agno/db/in_memory/__init__.py +3 -0
- agno/db/in_memory/in_memory_db.py +1160 -0
- agno/db/in_memory/utils.py +230 -0
- agno/db/json/__init__.py +3 -0
- agno/db/json/json_db.py +1328 -0
- agno/db/json/utils.py +230 -0
- agno/db/migrations/__init__.py +0 -0
- agno/db/migrations/v1_to_v2.py +635 -0
- agno/db/mongo/__init__.py +17 -0
- agno/db/mongo/async_mongo.py +2026 -0
- agno/db/mongo/mongo.py +1982 -0
- agno/db/mongo/schemas.py +87 -0
- agno/db/mongo/utils.py +259 -0
- agno/db/mysql/__init__.py +3 -0
- agno/db/mysql/mysql.py +2308 -0
- agno/db/mysql/schemas.py +138 -0
- agno/db/mysql/utils.py +355 -0
- agno/db/postgres/__init__.py +4 -0
- agno/db/postgres/async_postgres.py +1927 -0
- agno/db/postgres/postgres.py +2260 -0
- agno/db/postgres/schemas.py +139 -0
- agno/db/postgres/utils.py +442 -0
- agno/db/redis/__init__.py +3 -0
- agno/db/redis/redis.py +1660 -0
- agno/db/redis/schemas.py +123 -0
- agno/db/redis/utils.py +346 -0
- agno/db/schemas/__init__.py +4 -0
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +33 -0
- agno/db/schemas/knowledge.py +40 -0
- agno/db/schemas/memory.py +46 -0
- agno/db/schemas/metrics.py +0 -0
- agno/db/singlestore/__init__.py +3 -0
- agno/db/singlestore/schemas.py +130 -0
- agno/db/singlestore/singlestore.py +2272 -0
- agno/db/singlestore/utils.py +384 -0
- agno/db/sqlite/__init__.py +4 -0
- agno/db/sqlite/async_sqlite.py +2293 -0
- agno/db/sqlite/schemas.py +133 -0
- agno/db/sqlite/sqlite.py +2288 -0
- agno/db/sqlite/utils.py +431 -0
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1353 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +116 -0
- agno/debug.py +18 -0
- agno/eval/__init__.py +14 -0
- agno/eval/accuracy.py +834 -0
- agno/eval/performance.py +773 -0
- agno/eval/reliability.py +306 -0
- agno/eval/utils.py +119 -0
- agno/exceptions.py +161 -0
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/__init__.py +0 -0
- agno/integrations/discord/__init__.py +3 -0
- agno/integrations/discord/client.py +203 -0
- agno/knowledge/__init__.py +5 -0
- agno/knowledge/chunking/__init__.py +0 -0
- agno/knowledge/chunking/agentic.py +79 -0
- agno/knowledge/chunking/document.py +91 -0
- agno/knowledge/chunking/fixed.py +57 -0
- agno/knowledge/chunking/markdown.py +151 -0
- agno/knowledge/chunking/recursive.py +63 -0
- agno/knowledge/chunking/row.py +39 -0
- agno/knowledge/chunking/semantic.py +86 -0
- agno/knowledge/chunking/strategy.py +165 -0
- agno/knowledge/content.py +74 -0
- agno/knowledge/document/__init__.py +5 -0
- agno/knowledge/document/base.py +58 -0
- agno/knowledge/embedder/__init__.py +5 -0
- agno/knowledge/embedder/aws_bedrock.py +343 -0
- agno/knowledge/embedder/azure_openai.py +210 -0
- agno/knowledge/embedder/base.py +23 -0
- agno/knowledge/embedder/cohere.py +323 -0
- agno/knowledge/embedder/fastembed.py +62 -0
- agno/knowledge/embedder/fireworks.py +13 -0
- agno/knowledge/embedder/google.py +258 -0
- agno/knowledge/embedder/huggingface.py +94 -0
- agno/knowledge/embedder/jina.py +182 -0
- agno/knowledge/embedder/langdb.py +22 -0
- agno/knowledge/embedder/mistral.py +206 -0
- agno/knowledge/embedder/nebius.py +13 -0
- agno/knowledge/embedder/ollama.py +154 -0
- agno/knowledge/embedder/openai.py +195 -0
- agno/knowledge/embedder/sentence_transformer.py +63 -0
- agno/knowledge/embedder/together.py +13 -0
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +165 -0
- agno/knowledge/knowledge.py +1988 -0
- agno/knowledge/reader/__init__.py +7 -0
- agno/knowledge/reader/arxiv_reader.py +81 -0
- agno/knowledge/reader/base.py +95 -0
- agno/knowledge/reader/csv_reader.py +166 -0
- agno/knowledge/reader/docx_reader.py +82 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +292 -0
- agno/knowledge/reader/firecrawl_reader.py +201 -0
- agno/knowledge/reader/json_reader.py +87 -0
- agno/knowledge/reader/markdown_reader.py +137 -0
- agno/knowledge/reader/pdf_reader.py +431 -0
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +313 -0
- agno/knowledge/reader/s3_reader.py +89 -0
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +115 -0
- agno/knowledge/reader/web_search_reader.py +372 -0
- agno/knowledge/reader/website_reader.py +455 -0
- agno/knowledge/reader/wikipedia_reader.py +59 -0
- agno/knowledge/reader/youtube_reader.py +78 -0
- agno/knowledge/remote_content/__init__.py +0 -0
- agno/knowledge/remote_content/remote_content.py +88 -0
- agno/knowledge/reranker/__init__.py +3 -0
- agno/knowledge/reranker/base.py +14 -0
- agno/knowledge/reranker/cohere.py +64 -0
- agno/knowledge/reranker/infinity.py +195 -0
- agno/knowledge/reranker/sentence_transformer.py +54 -0
- agno/knowledge/types.py +39 -0
- agno/knowledge/utils.py +189 -0
- agno/media.py +462 -0
- agno/memory/__init__.py +3 -0
- agno/memory/manager.py +1327 -0
- agno/models/__init__.py +0 -0
- agno/models/aimlapi/__init__.py +5 -0
- agno/models/aimlapi/aimlapi.py +45 -0
- agno/models/anthropic/__init__.py +5 -0
- agno/models/anthropic/claude.py +757 -0
- agno/models/aws/__init__.py +15 -0
- agno/models/aws/bedrock.py +701 -0
- agno/models/aws/claude.py +378 -0
- agno/models/azure/__init__.py +18 -0
- agno/models/azure/ai_foundry.py +485 -0
- agno/models/azure/openai_chat.py +131 -0
- agno/models/base.py +2175 -0
- agno/models/cerebras/__init__.py +12 -0
- agno/models/cerebras/cerebras.py +501 -0
- agno/models/cerebras/cerebras_openai.py +112 -0
- agno/models/cohere/__init__.py +5 -0
- agno/models/cohere/chat.py +389 -0
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/__init__.py +5 -0
- agno/models/dashscope/dashscope.py +91 -0
- agno/models/deepinfra/__init__.py +5 -0
- agno/models/deepinfra/deepinfra.py +28 -0
- agno/models/deepseek/__init__.py +5 -0
- agno/models/deepseek/deepseek.py +61 -0
- agno/models/defaults.py +1 -0
- agno/models/fireworks/__init__.py +5 -0
- agno/models/fireworks/fireworks.py +26 -0
- agno/models/google/__init__.py +5 -0
- agno/models/google/gemini.py +1085 -0
- agno/models/groq/__init__.py +5 -0
- agno/models/groq/groq.py +556 -0
- agno/models/huggingface/__init__.py +5 -0
- agno/models/huggingface/huggingface.py +491 -0
- agno/models/ibm/__init__.py +5 -0
- agno/models/ibm/watsonx.py +422 -0
- agno/models/internlm/__init__.py +3 -0
- agno/models/internlm/internlm.py +26 -0
- agno/models/langdb/__init__.py +1 -0
- agno/models/langdb/langdb.py +48 -0
- agno/models/litellm/__init__.py +14 -0
- agno/models/litellm/chat.py +468 -0
- agno/models/litellm/litellm_openai.py +25 -0
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/lmstudio/__init__.py +5 -0
- agno/models/lmstudio/lmstudio.py +25 -0
- agno/models/message.py +434 -0
- agno/models/meta/__init__.py +12 -0
- agno/models/meta/llama.py +475 -0
- agno/models/meta/llama_openai.py +78 -0
- agno/models/metrics.py +120 -0
- agno/models/mistral/__init__.py +5 -0
- agno/models/mistral/mistral.py +432 -0
- agno/models/nebius/__init__.py +3 -0
- agno/models/nebius/nebius.py +54 -0
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/__init__.py +5 -0
- agno/models/nvidia/nvidia.py +28 -0
- agno/models/ollama/__init__.py +5 -0
- agno/models/ollama/chat.py +441 -0
- agno/models/openai/__init__.py +9 -0
- agno/models/openai/chat.py +883 -0
- agno/models/openai/like.py +27 -0
- agno/models/openai/responses.py +1050 -0
- agno/models/openrouter/__init__.py +5 -0
- agno/models/openrouter/openrouter.py +66 -0
- agno/models/perplexity/__init__.py +5 -0
- agno/models/perplexity/perplexity.py +187 -0
- agno/models/portkey/__init__.py +3 -0
- agno/models/portkey/portkey.py +81 -0
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +199 -0
- agno/models/sambanova/__init__.py +5 -0
- agno/models/sambanova/sambanova.py +28 -0
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/__init__.py +5 -0
- agno/models/together/together.py +25 -0
- agno/models/utils.py +266 -0
- agno/models/vercel/__init__.py +3 -0
- agno/models/vercel/v0.py +26 -0
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +70 -0
- agno/models/vllm/__init__.py +3 -0
- agno/models/vllm/vllm.py +78 -0
- agno/models/xai/__init__.py +3 -0
- agno/models/xai/xai.py +113 -0
- agno/os/__init__.py +3 -0
- agno/os/app.py +876 -0
- agno/os/auth.py +57 -0
- agno/os/config.py +104 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/__init__.py +3 -0
- agno/os/interfaces/agui/agui.py +47 -0
- agno/os/interfaces/agui/router.py +144 -0
- agno/os/interfaces/agui/utils.py +534 -0
- agno/os/interfaces/base.py +25 -0
- agno/os/interfaces/slack/__init__.py +3 -0
- agno/os/interfaces/slack/router.py +148 -0
- agno/os/interfaces/slack/security.py +30 -0
- agno/os/interfaces/slack/slack.py +47 -0
- agno/os/interfaces/whatsapp/__init__.py +3 -0
- agno/os/interfaces/whatsapp/router.py +211 -0
- agno/os/interfaces/whatsapp/security.py +53 -0
- agno/os/interfaces/whatsapp/whatsapp.py +36 -0
- agno/os/mcp.py +292 -0
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +1763 -0
- agno/os/routers/__init__.py +3 -0
- agno/os/routers/evals/__init__.py +3 -0
- agno/os/routers/evals/evals.py +430 -0
- agno/os/routers/evals/schemas.py +142 -0
- agno/os/routers/evals/utils.py +162 -0
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/__init__.py +3 -0
- agno/os/routers/knowledge/knowledge.py +997 -0
- agno/os/routers/knowledge/schemas.py +178 -0
- agno/os/routers/memory/__init__.py +3 -0
- agno/os/routers/memory/memory.py +515 -0
- agno/os/routers/memory/schemas.py +62 -0
- agno/os/routers/metrics/__init__.py +3 -0
- agno/os/routers/metrics/metrics.py +190 -0
- agno/os/routers/metrics/schemas.py +47 -0
- agno/os/routers/session/__init__.py +3 -0
- agno/os/routers/session/session.py +997 -0
- agno/os/schema.py +1055 -0
- agno/os/settings.py +43 -0
- agno/os/utils.py +630 -0
- agno/py.typed +0 -0
- agno/reasoning/__init__.py +0 -0
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +67 -0
- agno/reasoning/deepseek.py +63 -0
- agno/reasoning/default.py +97 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +71 -0
- agno/reasoning/helpers.py +63 -0
- agno/reasoning/ollama.py +67 -0
- agno/reasoning/openai.py +86 -0
- agno/reasoning/step.py +31 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +787 -0
- agno/run/base.py +229 -0
- agno/run/cancel.py +81 -0
- agno/run/messages.py +32 -0
- agno/run/team.py +753 -0
- agno/run/workflow.py +708 -0
- agno/session/__init__.py +10 -0
- agno/session/agent.py +295 -0
- agno/session/summary.py +265 -0
- agno/session/team.py +392 -0
- agno/session/workflow.py +205 -0
- agno/team/__init__.py +37 -0
- agno/team/team.py +8793 -0
- agno/tools/__init__.py +10 -0
- agno/tools/agentql.py +120 -0
- agno/tools/airflow.py +69 -0
- agno/tools/api.py +122 -0
- agno/tools/apify.py +314 -0
- agno/tools/arxiv.py +127 -0
- agno/tools/aws_lambda.py +53 -0
- agno/tools/aws_ses.py +66 -0
- agno/tools/baidusearch.py +89 -0
- agno/tools/bitbucket.py +292 -0
- agno/tools/brandfetch.py +213 -0
- agno/tools/bravesearch.py +106 -0
- agno/tools/brightdata.py +367 -0
- agno/tools/browserbase.py +209 -0
- agno/tools/calcom.py +255 -0
- agno/tools/calculator.py +151 -0
- agno/tools/cartesia.py +187 -0
- agno/tools/clickup.py +244 -0
- agno/tools/confluence.py +240 -0
- agno/tools/crawl4ai.py +158 -0
- agno/tools/csv_toolkit.py +185 -0
- agno/tools/dalle.py +110 -0
- agno/tools/daytona.py +475 -0
- agno/tools/decorator.py +262 -0
- agno/tools/desi_vocal.py +108 -0
- agno/tools/discord.py +161 -0
- agno/tools/docker.py +716 -0
- agno/tools/duckdb.py +379 -0
- agno/tools/duckduckgo.py +91 -0
- agno/tools/e2b.py +703 -0
- agno/tools/eleven_labs.py +196 -0
- agno/tools/email.py +67 -0
- agno/tools/evm.py +129 -0
- agno/tools/exa.py +396 -0
- agno/tools/fal.py +127 -0
- agno/tools/file.py +240 -0
- agno/tools/file_generation.py +350 -0
- agno/tools/financial_datasets.py +288 -0
- agno/tools/firecrawl.py +143 -0
- agno/tools/function.py +1187 -0
- agno/tools/giphy.py +93 -0
- agno/tools/github.py +1760 -0
- agno/tools/gmail.py +922 -0
- agno/tools/google_bigquery.py +117 -0
- agno/tools/google_drive.py +270 -0
- agno/tools/google_maps.py +253 -0
- agno/tools/googlecalendar.py +674 -0
- agno/tools/googlesearch.py +98 -0
- agno/tools/googlesheets.py +377 -0
- agno/tools/hackernews.py +77 -0
- agno/tools/jina.py +101 -0
- agno/tools/jira.py +170 -0
- agno/tools/knowledge.py +218 -0
- agno/tools/linear.py +426 -0
- agno/tools/linkup.py +58 -0
- agno/tools/local_file_system.py +90 -0
- agno/tools/lumalab.py +183 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +193 -0
- agno/tools/memori.py +339 -0
- agno/tools/memory.py +419 -0
- agno/tools/mlx_transcribe.py +139 -0
- agno/tools/models/__init__.py +0 -0
- agno/tools/models/azure_openai.py +190 -0
- agno/tools/models/gemini.py +203 -0
- agno/tools/models/groq.py +158 -0
- agno/tools/models/morph.py +186 -0
- agno/tools/models/nebius.py +124 -0
- agno/tools/models_labs.py +195 -0
- agno/tools/moviepy_video.py +349 -0
- agno/tools/neo4j.py +134 -0
- agno/tools/newspaper.py +46 -0
- agno/tools/newspaper4k.py +93 -0
- agno/tools/notion.py +204 -0
- agno/tools/openai.py +202 -0
- agno/tools/openbb.py +160 -0
- agno/tools/opencv.py +321 -0
- agno/tools/openweather.py +233 -0
- agno/tools/oxylabs.py +385 -0
- agno/tools/pandas.py +102 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +257 -0
- agno/tools/pubmed.py +188 -0
- agno/tools/python.py +205 -0
- agno/tools/reasoning.py +283 -0
- agno/tools/reddit.py +467 -0
- agno/tools/replicate.py +117 -0
- agno/tools/resend.py +62 -0
- agno/tools/scrapegraph.py +222 -0
- agno/tools/searxng.py +152 -0
- agno/tools/serpapi.py +116 -0
- agno/tools/serper.py +255 -0
- agno/tools/shell.py +53 -0
- agno/tools/slack.py +136 -0
- agno/tools/sleep.py +20 -0
- agno/tools/spider.py +116 -0
- agno/tools/sql.py +154 -0
- agno/tools/streamlit/__init__.py +0 -0
- agno/tools/streamlit/components.py +113 -0
- agno/tools/tavily.py +254 -0
- agno/tools/telegram.py +48 -0
- agno/tools/todoist.py +218 -0
- agno/tools/tool_registry.py +1 -0
- agno/tools/toolkit.py +146 -0
- agno/tools/trafilatura.py +388 -0
- agno/tools/trello.py +274 -0
- agno/tools/twilio.py +186 -0
- agno/tools/user_control_flow.py +78 -0
- agno/tools/valyu.py +228 -0
- agno/tools/visualization.py +467 -0
- agno/tools/webbrowser.py +28 -0
- agno/tools/webex.py +76 -0
- agno/tools/website.py +54 -0
- agno/tools/webtools.py +45 -0
- agno/tools/whatsapp.py +286 -0
- agno/tools/wikipedia.py +63 -0
- agno/tools/workflow.py +278 -0
- agno/tools/x.py +335 -0
- agno/tools/yfinance.py +257 -0
- agno/tools/youtube.py +184 -0
- agno/tools/zendesk.py +82 -0
- agno/tools/zep.py +454 -0
- agno/tools/zoom.py +382 -0
- agno/utils/__init__.py +0 -0
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +49 -0
- agno/utils/certs.py +27 -0
- agno/utils/code_execution.py +11 -0
- agno/utils/common.py +132 -0
- agno/utils/dttm.py +13 -0
- agno/utils/enum.py +22 -0
- agno/utils/env.py +11 -0
- agno/utils/events.py +696 -0
- agno/utils/format_str.py +16 -0
- agno/utils/functions.py +166 -0
- agno/utils/gemini.py +426 -0
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +74 -0
- agno/utils/json_schema.py +234 -0
- agno/utils/knowledge.py +36 -0
- agno/utils/location.py +19 -0
- agno/utils/log.py +255 -0
- agno/utils/mcp.py +214 -0
- agno/utils/media.py +352 -0
- agno/utils/merge_dict.py +41 -0
- agno/utils/message.py +118 -0
- agno/utils/models/__init__.py +0 -0
- agno/utils/models/ai_foundry.py +43 -0
- agno/utils/models/claude.py +358 -0
- agno/utils/models/cohere.py +87 -0
- agno/utils/models/llama.py +78 -0
- agno/utils/models/mistral.py +98 -0
- agno/utils/models/openai_responses.py +140 -0
- agno/utils/models/schema_utils.py +153 -0
- agno/utils/models/watsonx.py +41 -0
- agno/utils/openai.py +257 -0
- agno/utils/pickle.py +32 -0
- agno/utils/pprint.py +178 -0
- agno/utils/print_response/__init__.py +0 -0
- agno/utils/print_response/agent.py +842 -0
- agno/utils/print_response/team.py +1724 -0
- agno/utils/print_response/workflow.py +1668 -0
- agno/utils/prompts.py +111 -0
- agno/utils/reasoning.py +108 -0
- agno/utils/response.py +163 -0
- agno/utils/response_iterator.py +17 -0
- agno/utils/safe_formatter.py +24 -0
- agno/utils/serialize.py +32 -0
- agno/utils/shell.py +22 -0
- agno/utils/streamlit.py +487 -0
- agno/utils/string.py +231 -0
- agno/utils/team.py +139 -0
- agno/utils/timer.py +41 -0
- agno/utils/tools.py +102 -0
- agno/utils/web.py +23 -0
- agno/utils/whatsapp.py +305 -0
- agno/utils/yaml_io.py +25 -0
- agno/vectordb/__init__.py +3 -0
- agno/vectordb/base.py +127 -0
- agno/vectordb/cassandra/__init__.py +5 -0
- agno/vectordb/cassandra/cassandra.py +501 -0
- agno/vectordb/cassandra/extra_param_mixin.py +11 -0
- agno/vectordb/cassandra/index.py +13 -0
- agno/vectordb/chroma/__init__.py +5 -0
- agno/vectordb/chroma/chromadb.py +929 -0
- agno/vectordb/clickhouse/__init__.py +9 -0
- agno/vectordb/clickhouse/clickhousedb.py +835 -0
- agno/vectordb/clickhouse/index.py +9 -0
- agno/vectordb/couchbase/__init__.py +3 -0
- agno/vectordb/couchbase/couchbase.py +1442 -0
- agno/vectordb/distance.py +7 -0
- agno/vectordb/lancedb/__init__.py +6 -0
- agno/vectordb/lancedb/lance_db.py +995 -0
- agno/vectordb/langchaindb/__init__.py +5 -0
- agno/vectordb/langchaindb/langchaindb.py +163 -0
- agno/vectordb/lightrag/__init__.py +5 -0
- agno/vectordb/lightrag/lightrag.py +388 -0
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +166 -0
- agno/vectordb/milvus/__init__.py +4 -0
- agno/vectordb/milvus/milvus.py +1182 -0
- agno/vectordb/mongodb/__init__.py +9 -0
- agno/vectordb/mongodb/mongodb.py +1417 -0
- agno/vectordb/pgvector/__init__.py +12 -0
- agno/vectordb/pgvector/index.py +23 -0
- agno/vectordb/pgvector/pgvector.py +1462 -0
- agno/vectordb/pineconedb/__init__.py +5 -0
- agno/vectordb/pineconedb/pineconedb.py +747 -0
- agno/vectordb/qdrant/__init__.py +5 -0
- agno/vectordb/qdrant/qdrant.py +1134 -0
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/search.py +7 -0
- agno/vectordb/singlestore/__init__.py +10 -0
- agno/vectordb/singlestore/index.py +41 -0
- agno/vectordb/singlestore/singlestore.py +763 -0
- agno/vectordb/surrealdb/__init__.py +3 -0
- agno/vectordb/surrealdb/surrealdb.py +699 -0
- agno/vectordb/upstashdb/__init__.py +5 -0
- agno/vectordb/upstashdb/upstashdb.py +718 -0
- agno/vectordb/weaviate/__init__.py +8 -0
- agno/vectordb/weaviate/index.py +15 -0
- agno/vectordb/weaviate/weaviate.py +1005 -0
- agno/workflow/__init__.py +23 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +738 -0
- agno/workflow/loop.py +735 -0
- agno/workflow/parallel.py +824 -0
- agno/workflow/router.py +702 -0
- agno/workflow/step.py +1432 -0
- agno/workflow/steps.py +592 -0
- agno/workflow/types.py +520 -0
- agno/workflow/workflow.py +4321 -0
- agno-2.2.13.dist-info/METADATA +614 -0
- agno-2.2.13.dist-info/RECORD +575 -0
- agno-2.2.13.dist-info/WHEEL +5 -0
- agno-2.2.13.dist-info/licenses/LICENSE +201 -0
- agno-2.2.13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2272 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from datetime import date, datetime, timedelta, timezone
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
from agno.db.base import BaseDb, SessionType
|
|
8
|
+
from agno.db.schemas.culture import CulturalKnowledge
|
|
9
|
+
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
10
|
+
from agno.db.schemas.knowledge import KnowledgeRow
|
|
11
|
+
from agno.db.schemas.memory import UserMemory
|
|
12
|
+
from agno.db.singlestore.schemas import get_table_schema_definition
|
|
13
|
+
from agno.db.singlestore.utils import (
|
|
14
|
+
apply_sorting,
|
|
15
|
+
bulk_upsert_metrics,
|
|
16
|
+
calculate_date_metrics,
|
|
17
|
+
create_schema,
|
|
18
|
+
deserialize_cultural_knowledge_from_db,
|
|
19
|
+
fetch_all_sessions_data,
|
|
20
|
+
get_dates_to_calculate_metrics_for,
|
|
21
|
+
is_table_available,
|
|
22
|
+
is_valid_table,
|
|
23
|
+
serialize_cultural_knowledge_for_db,
|
|
24
|
+
)
|
|
25
|
+
from agno.session import AgentSession, Session, TeamSession, WorkflowSession
|
|
26
|
+
from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
27
|
+
from agno.utils.string import generate_id
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
from sqlalchemy import Index, UniqueConstraint, and_, func, update
|
|
31
|
+
from sqlalchemy.dialects import mysql
|
|
32
|
+
from sqlalchemy.engine import Engine, create_engine
|
|
33
|
+
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
34
|
+
from sqlalchemy.schema import Column, MetaData, Table
|
|
35
|
+
from sqlalchemy.sql.expression import select, text
|
|
36
|
+
except ImportError:
|
|
37
|
+
raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SingleStoreDb(BaseDb):
|
|
41
|
+
def __init__(
    self,
    id: Optional[str] = None,
    db_engine: Optional[Engine] = None,
    db_schema: Optional[str] = None,
    db_url: Optional[str] = None,
    session_table: Optional[str] = None,
    culture_table: Optional[str] = None,
    memory_table: Optional[str] = None,
    metrics_table: Optional[str] = None,
    eval_table: Optional[str] = None,
    knowledge_table: Optional[str] = None,
):
    """
    Interface for interacting with a SingleStore database.

    The following order is used to determine the database connection:
        1. Use the db_engine if provided
        2. Use the db_url
        3. Raise an error if neither is provided

    Args:
        id (Optional[str]): The ID of the database. When omitted, a deterministic ID is
            derived from the connection target (db_url, else the engine URL) and db_schema.
        db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
        db_schema (Optional[str]): The database schema to use.
        db_url (Optional[str]): The database URL to connect to.
        session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
        culture_table (Optional[str]): Name of the table to store cultural knowledge.
        memory_table (Optional[str]): Name of the table to store memories.
        metrics_table (Optional[str]): Name of the table to store metrics.
        eval_table (Optional[str]): Name of the table to store evaluation runs data.
        knowledge_table (Optional[str]): Name of the table to store knowledge content.

    Raises:
        ValueError: If neither db_url nor db_engine is provided.
        Any error raised by BaseDb.__init__ is propagated unchanged.
    """
    if id is None:
        # The conditional must be parenthesized: a bare `a or b if c else d` parses as
        # `(a or b) if c else d`. That made every instance built from only a db_url fall
        # back to the generic "singlestore" seed, so distinct databases shared one id.
        base_seed = db_url or (str(db_engine.url) if db_engine is not None else "singlestore")
        schema_suffix = db_schema if db_schema is not None else "ai"
        seed = f"{base_seed}#{schema_suffix}"
        id = generate_id(seed)

    super().__init__(
        id=id,
        session_table=session_table,
        culture_table=culture_table,
        memory_table=memory_table,
        metrics_table=metrics_table,
        eval_table=eval_table,
        knowledge_table=knowledge_table,
    )

    # An explicitly supplied engine wins; otherwise build one from the URL.
    _engine: Optional[Engine] = db_engine
    if _engine is None and db_url is not None:
        _engine = create_engine(
            db_url,
            connect_args={
                "charset": "utf8mb4",
                "ssl": {"ssl_disabled": False, "ssl_ca": None, "ssl_check_hostname": False},
            },
        )
    if _engine is None:
        raise ValueError("One of db_url or db_engine must be provided")

    self.db_url: Optional[str] = db_url
    self.db_engine: Engine = _engine
    self.db_schema: Optional[str] = db_schema
    self.metadata: MetaData = MetaData()

    # Initialize database session
    self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
|
|
113
|
+
|
|
114
|
+
# -- DB methods --
|
|
115
|
+
def table_exists(self, table_name: str) -> bool:
    """Report whether a table named ``table_name`` is present in the SingleStore database.

    The lookup is delegated to ``is_table_available`` using a fresh session and the
    schema configured on this instance (``self.db_schema``).

    Args:
        table_name: The table to look up.

    Returns:
        bool: True when the table is found, False otherwise.
    """
    with self.Session() as db_session:
        found = is_table_available(session=db_session, table_name=table_name, db_schema=self.db_schema)
    return found
|
|
126
|
+
|
|
127
|
+
def _create_table_structure_only(self, table_name: str, table_type: str, db_schema: Optional[str]) -> Table:
    """
    Build an in-memory SQLAlchemy Table for ``table_type`` without issuing any DDL.

    Used to avoid autoload issues with SingleStore JSON types. Entries of the schema
    definition whose key starts with an underscore (constraint metadata, not columns)
    are skipped, and no indexes or multi-column constraints are attached.

    Args:
        table_name (str): Name of the table
        table_type (str): Type of table (used to get schema definition)
        db_schema (Optional[str]): Database schema name

    Returns:
        Table: SQLAlchemy Table object with column definitions

    Raises:
        Exception: Any failure while building the definition is logged, then re-raised.
    """
    try:
        definition = get_table_schema_definition(table_type)

        built_columns: List[Column] = []
        for name, spec in definition.items():
            if name.startswith("_"):
                # Underscore-prefixed keys describe constraints rather than columns
                continue

            options: Dict[str, Any] = {}
            if spec.get("primary_key", False):
                options["primary_key"] = True
            if "nullable" in spec:
                options["nullable"] = spec["nullable"]
            if spec.get("unique", False):
                options["unique"] = True

            built_columns.append(Column(name, spec["type"](), **options))

        # A separate MetaData (not self.metadata) holds this standalone definition
        return Table(table_name, MetaData(schema=db_schema), *built_columns, schema=db_schema)

    except Exception as e:
        qualified_name = table_name if not db_schema else f"{db_schema}.{table_name}"
        log_error(f"Could not create table structure for {qualified_name}: {e}")
        raise
|
|
170
|
+
|
|
171
|
+
def _create_all_tables(self):
    """Create the tables managed by this database, one after another.

    Each (configured name, table type) pair below is passed to ``_create_table``
    together with ``self.db_schema``.
    """
    # NOTE(review): the culture table is not part of this list; presumably it is
    # created on demand elsewhere — confirm.
    table_specs = (
        (self.session_table_name, "sessions"),
        (self.memory_table_name, "memories"),
        (self.metrics_table_name, "metrics"),
        (self.eval_table_name, "evals"),
        (self.knowledge_table_name, "knowledge"),
    )

    for name, kind in table_specs:
        self._create_table(table_name=name, table_type=kind, db_schema=self.db_schema)
|
|
183
|
+
|
|
184
|
+
def _create_table(self, table_name: str, table_type: str, db_schema: Optional[str]) -> Table:
    """
    Create a table with the appropriate schema based on the table type.

    Builds an SQLAlchemy ``Table`` from ``get_table_schema_definition(table_type)``,
    creates ``db_schema`` first when one is given, and then issues the DDL. The
    ``sessions`` table is created with hand-written SQL; every other table goes
    through ``Table.create(checkfirst=True)``. Indexes are created one at a time,
    skipping any whose name ``information_schema.statistics`` already reports.

    Args:
        table_name (str): Name of the table to create
        table_type (str): Type of table (used to get schema definition)
        db_schema (Optional[str]): Database schema name

    Returns:
        Table: SQLAlchemy Table object

    Raises:
        Exception: Any error while building the definition or creating the table is
            logged and re-raised. Errors creating an individual index are only logged.
    """
    table_ref = f"{db_schema}.{table_name}" if db_schema else table_name
    try:
        # Work on a shallow copy: the definition may be shared state, and popping
        # "_unique_constraints" from it in place would strip those constraints from
        # any later table built from the same definition.
        table_schema = dict(get_table_schema_definition(table_type))

        log_debug(f"Creating table {table_ref}")

        columns: List[Column] = []
        indexes: List[str] = []
        schema_unique_constraints = table_schema.pop("_unique_constraints", [])

        # Get the columns and single-column indexes from the table schema
        for col_name, col_config in table_schema.items():
            # Any remaining "_"-prefixed entries are metadata, not columns
            if col_name.startswith("_"):
                continue

            column_kwargs: Dict[str, Any] = {}
            if col_config.get("primary_key", False):
                column_kwargs["primary_key"] = True
            if "nullable" in col_config:
                column_kwargs["nullable"] = col_config["nullable"]
            if col_config.get("index", False):
                indexes.append(col_name)
            if col_config.get("unique", False):
                column_kwargs["unique"] = True
            columns.append(Column(col_name, col_config["type"](), **column_kwargs))

        # Create the table object
        table_metadata = MetaData(schema=db_schema)
        table = Table(table_name, table_metadata, *columns, schema=db_schema)

        # Add multi-column unique constraints with table-specific names
        for constraint in schema_unique_constraints:
            constraint_name = f"{table_name}_{constraint['name']}"
            table.append_constraint(UniqueConstraint(*constraint["columns"], name=constraint_name))

        # Add indexes to the table definition
        for idx_col in indexes:
            table.append_constraint(Index(f"idx_{table_name}_{idx_col}", idx_col))

        # Create schema if one is specified
        if db_schema is not None:
            with self.Session() as sess, sess.begin():
                create_schema(session=sess, db_schema=db_schema)

        # SingleStore has a limitation on the number of unique multi-field constraints per table.
        # We need to work around that limitation for the sessions table.
        if table_type == "sessions":
            with self.Session() as sess, sess.begin():
                # Build column definitions
                columns_sql = []
                for col in table.columns:
                    col_sql = f"{col.name} {col.type.compile(self.db_engine.dialect)}"
                    if not col.nullable:
                        col_sql += " NOT NULL"
                    columns_sql.append(col_sql)

                columns_def = ", ".join(columns_sql)

                # Add shard key and single unique constraint
                table_sql = f"""CREATE TABLE IF NOT EXISTS {table_ref} (
    {columns_def},
    SHARD KEY (session_id),
    UNIQUE KEY uq_session_type (session_id, session_type)
)"""

                sess.execute(text(table_sql))
        else:
            table.create(self.db_engine, checkfirst=True)

        # Create indexes
        for idx in table.indexes:
            try:
                log_debug(f"Creating index: {idx.name}")

                # Check if index already exists
                with self.Session() as sess:
                    if db_schema is not None:
                        exists_query = text(
                            "SELECT 1 FROM information_schema.statistics WHERE table_schema = :schema AND index_name = :index_name"
                        )
                        exists = (
                            sess.execute(exists_query, {"schema": db_schema, "index_name": idx.name}).scalar()
                            is not None
                        )
                    else:
                        exists_query = text(
                            "SELECT 1 FROM information_schema.statistics WHERE table_schema = DATABASE() AND index_name = :index_name"
                        )
                        exists = sess.execute(exists_query, {"index_name": idx.name}).scalar() is not None
                if exists:
                    log_debug(f"Index {idx.name} already exists in {table_ref}, skipping creation")
                    continue

                idx.create(self.db_engine)

            except Exception as e:
                log_error(f"Error creating index {idx.name}: {e}")

        log_debug(f"Successfully created table {table_ref}")
        return table

    except Exception as e:
        log_error(f"Could not create table {table_ref}: {e}")
        raise
|
|
302
|
+
|
|
303
|
+
def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
    """Resolve the Table for a logical table type and cache it on the instance.

    Args:
        table_type (str): One of "sessions", "memories", "metrics", "evals",
            "knowledge" or "culture".
        create_table_if_not_found (Optional[bool]): Forwarded to
            ``_get_or_create_table``.

    Returns:
        Optional[Table]: Whatever ``_get_or_create_table`` returned. It is also stored
        on the matching ``*_table`` attribute.

    Raises:
        ValueError: If ``table_type`` is not one of the known types.
    """
    # (table type, attribute holding the table name, attribute caching the Table)
    known_tables = (
        ("sessions", "session_table_name", "session_table"),
        ("memories", "memory_table_name", "memory_table"),
        ("metrics", "metrics_table_name", "metrics_table"),
        ("evals", "eval_table_name", "eval_table"),
        ("knowledge", "knowledge_table_name", "knowledge_table"),
        ("culture", "culture_table_name", "culture_table"),
    )

    for candidate, name_attr, cache_attr in known_tables:
        if table_type != candidate:
            continue
        resolved = self._get_or_create_table(
            table_name=getattr(self, name_attr),
            table_type=candidate,
            db_schema=self.db_schema,
            create_table_if_not_found=create_table_if_not_found,
        )
        setattr(self, cache_attr, resolved)
        return getattr(self, cache_attr)

    raise ValueError(f"Unknown table type: {table_type}")
|
|
359
|
+
|
|
360
|
+
def _get_or_create_table(
    self,
    table_name: str,
    table_type: str,
    db_schema: Optional[str],
    create_table_if_not_found: Optional[bool] = False,
) -> Optional[Table]:
    """
    Return a Table for ``table_name``, optionally creating it when it is missing.

    Args:
        table_name (str): Name of the table to get or create
        table_type (str): Type of table (used to get schema definition)
        db_schema (Optional[str]): Database schema name
        create_table_if_not_found (Optional[bool]): When the table does not exist,
            create it (True) or return None (falsy). Defaults to False.

    Returns:
        Optional[Table]: None if the table is missing and creation was not requested.
        Otherwise the Table returned by ``_create_table`` (new table) or by
        ``_create_table_structure_only`` (existing table).

    Raises:
        ValueError: If the table exists but ``is_valid_table`` rejects its schema.
        Exception: Errors building the structure of an existing table are logged
            and re-raised.
    """
    table_ref = f"{db_schema}.{table_name}" if db_schema else table_name

    with self.Session() as sess, sess.begin():
        already_exists = is_table_available(session=sess, table_name=table_name, db_schema=db_schema)

    if not already_exists:
        return (
            self._create_table(table_name=table_name, table_type=table_type, db_schema=db_schema)
            if create_table_if_not_found
            else None
        )

    schema_matches = is_valid_table(
        db_engine=self.db_engine,
        table_name=table_name,
        table_type=table_type,
        db_schema=db_schema,
    )
    if not schema_matches:
        raise ValueError(f"Table {table_ref} has an invalid schema")

    try:
        return self._create_table_structure_only(table_name=table_name, table_type=table_type, db_schema=db_schema)
    except Exception as e:
        log_error(f"Error loading existing table {table_ref}: {e}")
        raise
|
|
403
|
+
|
|
404
|
+
# -- Session methods --
|
|
405
|
+
def delete_session(self, session_id: str) -> bool:
    """
    Remove the session row(s) whose ``session_id`` matches.

    Args:
        session_id (str): ID of the session to delete

    Returns:
        bool: False if the sessions table is unavailable or no row was deleted,
        True otherwise.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    try:
        sessions_table = self._get_table(table_type="sessions")
        if sessions_table is None:
            return False

        with self.Session() as sess, sess.begin():
            outcome = sess.execute(sessions_table.delete().where(sessions_table.c.session_id == session_id))
            deleted = outcome.rowcount != 0
            if deleted:
                log_debug(f"Successfully deleted session with session_id: {session_id} in table {sessions_table.name}")
            else:
                log_debug(f"No session found to delete with session_id: {session_id} in table {sessions_table.name}")
            return deleted

    except Exception as e:
        log_error(f"Error deleting session: {e}")
        raise
|
|
436
|
+
|
|
437
|
+
def delete_sessions(self, session_ids: List[str]) -> None:
    """Delete every session whose ``session_id`` is in ``session_ids``.

    Session type is not filtered on, so rows of any type are removed.

    Args:
        session_ids (List[str]): The IDs of the sessions to delete.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    try:
        sessions_table = self._get_table(table_type="sessions")
        if sessions_table is None:
            return

        matches_any_id = sessions_table.c.session_id.in_(session_ids)
        with self.Session() as sess, sess.begin():
            outcome = sess.execute(sessions_table.delete().where(matches_any_id))
            log_debug(f"Successfully deleted {outcome.rowcount} sessions")

    except Exception as e:
        log_error(f"Error deleting sessions: {e}")
        raise
|
|
461
|
+
|
|
462
|
+
def get_session(
    self,
    session_id: str,
    session_type: SessionType,
    user_id: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Fetch a single session row by id.

    Args:
        session_id (str): ID of the session to read.
        session_type (SessionType): Selects the class used for deserialization. It is
            not used as a query filter.
        user_id (Optional[str]): When given, the row must also belong to this user.
        deserialize (Optional[bool]): Return a Session object (True) or the raw row
            as a dict (False). Defaults to True.

    Returns:
        Union[Session, Dict[str, Any], None]: None if the table is unavailable or no
        row matches. Otherwise a Session object or a session dictionary.

    Raises:
        ValueError: If deserializing with an unsupported ``session_type``.
        Exception: Any error is logged and re-raised.
    """
    try:
        table = self._get_table(table_type="sessions")
        if table is None:
            return None

        query = select(table).where(table.c.session_id == session_id)
        if user_id is not None:
            query = query.where(table.c.user_id == user_id)

        with self.Session() as sess:
            row = sess.execute(query).fetchone()
            if row is None:
                return None
            record = dict(row._mapping)

            if not deserialize:
                return record

            type_to_class = (
                (SessionType.AGENT, AgentSession),
                (SessionType.TEAM, TeamSession),
                (SessionType.WORKFLOW, WorkflowSession),
            )
            for candidate, session_cls in type_to_class:
                if session_type == candidate:
                    return session_cls.from_dict(record)
            raise ValueError(f"Invalid session type: {session_type}")

    except Exception as e:
        log_error(f"Exception reading from session table: {e}")
        raise
|
|
517
|
+
|
|
518
|
+
def get_sessions(
    self,
    session_type: Optional[SessionType] = None,
    user_id: Optional[str] = None,
    component_id: Optional[str] = None,
    session_name: Optional[str] = None,
    start_timestamp: Optional[int] = None,
    end_timestamp: Optional[int] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
    """
    List sessions, optionally filtered, sorted and paginated.

    Args:
        session_type (Optional[SessionType]): Only return sessions of this type. It
            also selects which id column ``component_id`` is matched against.
        user_id (Optional[str]): Only return sessions of this user.
        component_id (Optional[str]): Agent / team / workflow id to match. Ignored
            unless ``session_type`` is AGENT, TEAM or WORKFLOW.
        session_name (Optional[str]): Substring matched (SQL LIKE) against
            ``session_data.session_name``.
        start_timestamp (Optional[int]): Minimum ``created_at`` (inclusive).
        end_timestamp (Optional[int]): Maximum ``created_at`` (inclusive).
        limit (Optional[int]): Maximum number of sessions to return.
        page (Optional[int]): 1-based page number. Only applied when ``limit`` is set.
        sort_by (Optional[str]): Field to sort by, passed to ``apply_sorting``.
        sort_order (Optional[str]): Sort order, passed to ``apply_sorting``.
        deserialize (Optional[bool]): Return Session objects (True) or raw dicts plus
            the total count (False). Defaults to True.

    Returns:
        Union[List[Session], Tuple[List[Dict], int]]:
            - When deserialize=True: List of Session objects
            - When deserialize=False: Tuple of (session dictionaries, total count of
              filtered rows before pagination)

    Raises:
        ValueError: If deserializing without a supported ``session_type``.
        Exception: Any error is logged and re-raised.
    """
    try:
        table = self._get_table(table_type="sessions")
        if table is None:
            return [] if deserialize else ([], 0)

        # Collect WHERE conditions; they are ANDed together.
        conditions = []
        if user_id is not None:
            conditions.append(table.c.user_id == user_id)
        if component_id is not None:
            if session_type == SessionType.AGENT:
                conditions.append(table.c.agent_id == component_id)
            elif session_type == SessionType.TEAM:
                conditions.append(table.c.team_id == component_id)
            elif session_type == SessionType.WORKFLOW:
                conditions.append(table.c.workflow_id == component_id)
        if start_timestamp is not None:
            conditions.append(table.c.created_at >= start_timestamp)
        if end_timestamp is not None:
            conditions.append(table.c.created_at <= end_timestamp)
        if session_name is not None:
            # SingleStore JSON extraction syntax
            extracted_name = func.coalesce(func.JSON_EXTRACT_STRING(table.c.session_data, "session_name"), "")
            conditions.append(extracted_name.like(f"%{session_name}%"))
        if session_type is not None:
            type_value = session_type.value if isinstance(session_type, SessionType) else session_type
            conditions.append(table.c.session_type == type_value)

        stmt = select(table)
        if conditions:
            stmt = stmt.where(*conditions)

        with self.Session() as sess, sess.begin():
            # Count the filtered rows before sorting and pagination are applied
            total_count = sess.execute(select(func.count()).select_from(stmt.alias())).scalar()

            stmt = apply_sorting(stmt, table, sort_by, sort_order)
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            rows = [dict(record._mapping) for record in sess.execute(stmt).fetchall()]

            if not deserialize:
                return rows, total_count

            type_to_class = (
                (SessionType.AGENT, AgentSession),
                (SessionType.TEAM, TeamSession),
                (SessionType.WORKFLOW, WorkflowSession),
            )
            for candidate, session_cls in type_to_class:
                if session_type == candidate:
                    return [session_cls.from_dict(row) for row in rows]  # type: ignore
            raise ValueError(f"Invalid session type: {session_type}")

    except Exception as e:
        log_error(f"Exception reading from session table: {e}")
        raise
|
|
622
|
+
|
|
623
|
+
def rename_session(
    self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Rename a session in the database.

    Sets ``session_data.session_name`` with ``JSON_SET_STRING`` on the row matching
    both ``session_id`` and ``session_type``, then re-reads that same row.

    Args:
        session_id (str): The ID of the session to rename.
        session_type (SessionType): The type of session to rename.
        session_name (str): The new name for the session.
        deserialize (Optional[bool]): Return a Session object (True) or a dict
            (False). Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]: None if the table is unavailable
        or no row was updated. Otherwise:
            - When deserialize=True: Session object
            - When deserialize=False: Session dictionary

    Raises:
        ValueError: If deserializing with an unsupported ``session_type``.
        Exception: If an error occurs during renaming (logged and re-raised).
    """
    try:
        table = self._get_table(table_type="sessions")
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            stmt = (
                update(table)
                .where(table.c.session_id == session_id)
                .where(table.c.session_type == session_type.value)
                .values(session_data=func.JSON_SET_STRING(table.c.session_data, "session_name", session_name))
            )
            result = sess.execute(stmt)
            if result.rowcount == 0:
                return None

            # Fetch the updated record. The sessions unique key is (session_id, session_type),
            # so filter on both; otherwise a same-id row of another type could be returned.
            select_stmt = (
                select(table)
                .where(table.c.session_id == session_id)
                .where(table.c.session_type == session_type.value)
            )
            row = sess.execute(select_stmt).fetchone()
            if not row:
                return None

        session = dict(row._mapping)

        log_debug(f"Renamed session with id '{session_id}' to '{session_name}'")

        if not deserialize:
            return session

        if session_type == SessionType.AGENT:
            return AgentSession.from_dict(session)
        elif session_type == SessionType.TEAM:
            return TeamSession.from_dict(session)
        elif session_type == SessionType.WORKFLOW:
            return WorkflowSession.from_dict(session)
        else:
            raise ValueError(f"Invalid session type: {session_type}")

    except Exception as e:
        log_error(f"Error renaming session: {e}")
        raise
|
|
684
|
+
|
|
685
|
+
def upsert_session(
    self, session: Session, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Insert or update a session in the database.

    Uses ``INSERT ... ON DUPLICATE KEY UPDATE``. On conflict it overwrites the
    component id, user id, component data, session data, summary, metadata and runs,
    and sets ``updated_at`` to the current time. On first insert, ``updated_at`` is
    set to the session's ``created_at``.

    ``AgentSession`` and ``TeamSession`` are stored as agent/team rows; any other
    session is stored as a workflow row. The written row is read back in the same
    transaction.

    Args:
        session (Session): The session data to upsert.
        deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]: None if the sessions table is
        unavailable or the row cannot be read back. Otherwise:
            - When deserialize=True: Session object
            - When deserialize=False: Session dictionary (a plain ``dict``)

    Raises:
        Exception: If an error occurs during upsert (logged and re-raised).
    """
    try:
        table = self._get_table(table_type="sessions", create_table_if_not_found=True)
        if table is None:
            return None

        session_dict = session.to_dict()

        # Per-type layout: (session type, component id column, component data column, class)
        if isinstance(session, AgentSession):
            session_type, id_key, data_key, session_cls = SessionType.AGENT, "agent_id", "agent_data", AgentSession
        elif isinstance(session, TeamSession):
            session_type, id_key, data_key, session_cls = SessionType.TEAM, "team_id", "team_data", TeamSession
        else:
            session_type, id_key, data_key, session_cls = (
                SessionType.WORKFLOW,
                "workflow_id",
                "workflow_data",
                WorkflowSession,
            )

        with self.Session() as sess, sess.begin():
            stmt = mysql.insert(table).values(
                session_id=session_dict.get("session_id"),
                session_type=session_type.value,
                user_id=session_dict.get("user_id"),
                runs=session_dict.get("runs"),
                session_data=session_dict.get("session_data"),
                summary=session_dict.get("summary"),
                metadata=session_dict.get("metadata"),
                created_at=session_dict.get("created_at"),
                updated_at=session_dict.get("created_at"),
                **{id_key: session_dict.get(id_key), data_key: session_dict.get(data_key)},
            )
            stmt = stmt.on_duplicate_key_update(
                user_id=stmt.inserted.user_id,
                session_data=stmt.inserted.session_data,
                summary=stmt.inserted.summary,
                metadata=stmt.inserted.metadata,
                runs=stmt.inserted.runs,
                updated_at=int(time.time()),
                **{id_key: stmt.inserted[id_key], data_key: stmt.inserted[data_key]},
            )
            sess.execute(stmt)

            # Read back exactly the row just written; the sessions unique key is
            # (session_id, session_type), so filter on the type as well.
            select_stmt = select(table).where(
                (table.c.session_id == session_dict.get("session_id"))
                & (table.c.session_type == session_type.value)
                & (table.c[id_key] == session_dict.get(id_key))
            )
            row = sess.execute(select_stmt).fetchone()
            if row is None:
                return None

            # Plain dict (not the read-only RowMapping), consistent with get_session
            record = dict(row._mapping)
            if not deserialize:
                return record

            return session_cls.from_dict(record)

    except Exception as e:
        log_error(f"Error upserting into sessions table: {e}")
        raise
|
|
834
|
+
|
|
835
|
+
def upsert_sessions(
    self, sessions: List[Session], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
) -> List[Union[Session, Dict[str, Any]]]:
    """
    Bulk upsert multiple sessions for improved performance on large datasets.

    Sessions are grouped by type: agent, then team, then workflow. Each group is
    written with one executemany ``INSERT ... ON DUPLICATE KEY UPDATE`` inside a
    single transaction. Objects that are not Agent/Team/Workflow sessions are
    ignored. After each group is written, its rows are read back (matching both
    session id and session type) and appended to the result in that order.

    Args:
        sessions (List[Session]): List of sessions to upsert.
        deserialize (Optional[bool]): Return Session objects (True) or row dicts
            (False). Defaults to True.
        preserve_updated_at (bool): Keep each session's own ``updated_at`` instead of
            stamping the current time. Defaults to False.

    Returns:
        List[Union[Session, Dict[str, Any]]]: List of upserted sessions. An empty
        list is returned for empty input, a missing table, or any error.

    Note:
        Errors are logged and swallowed (an empty list is returned); they are not
        re-raised.
    """
    if not sessions:
        return []

    try:
        table = self._get_table(table_type="sessions", create_table_if_not_found=True)
        if table is None:
            return []

        # (session type, session class, component id column, component data column)
        group_specs = (
            (SessionType.AGENT, AgentSession, "agent_id", "agent_data"),
            (SessionType.TEAM, TeamSession, "team_id", "team_data"),
            (SessionType.WORKFLOW, WorkflowSession, "workflow_id", "workflow_data"),
        )

        # Group sessions by type for batch processing (first matching class wins)
        grouped: List[List[Session]] = [[] for _ in group_specs]
        for session in sessions:
            for position, (_, session_cls, _, _) in enumerate(group_specs):
                if isinstance(session, session_cls):
                    grouped[position].append(session)
                    break

        results: List[Union[Session, Dict[str, Any]]] = []

        with self.Session() as sess, sess.begin():
            for (session_type, session_cls, id_key, data_key), group in zip(group_specs, grouped):
                if not group:
                    continue

                rows_to_write = []
                for session in group:
                    session_dict = session.to_dict()
                    # Use preserved updated_at if flag is set, otherwise use current time
                    updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
                    rows_to_write.append(
                        {
                            "session_id": session_dict.get("session_id"),
                            "session_type": session_type.value,
                            id_key: session_dict.get(id_key),
                            "user_id": session_dict.get("user_id"),
                            "runs": session_dict.get("runs"),
                            data_key: session_dict.get(data_key),
                            "session_data": session_dict.get("session_data"),
                            "summary": session_dict.get("summary"),
                            "metadata": session_dict.get("metadata"),
                            "created_at": session_dict.get("created_at"),
                            "updated_at": updated_at,
                        }
                    )

                stmt = mysql.insert(table)
                stmt = stmt.on_duplicate_key_update(
                    **{
                        column: stmt.inserted[column]
                        for column in (
                            id_key,
                            "user_id",
                            data_key,
                            "session_data",
                            "summary",
                            "metadata",
                            "runs",
                            "updated_at",
                        )
                    }
                )
                sess.execute(stmt, rows_to_write)

                # Fetch the stored rows for this group. Filter on session_type too, since
                # the same session_id may exist under a different type.
                group_ids = [session.session_id for session in group]
                select_stmt = (
                    select(table)
                    .where(table.c.session_id.in_(group_ids))
                    .where(table.c.session_type == session_type.value)
                )

                for row in sess.execute(select_stmt).fetchall():
                    row_dict = dict(row._mapping)
                    if not deserialize:
                        results.append(row_dict)
                        continue
                    # Deserialize from the fetched row itself (previously the stale
                    # session_dict of the last written session was reused for every row).
                    deserialized = session_cls.from_dict(row_dict)
                    if deserialized is None:
                        continue
                    results.append(deserialized)

        return results

    except Exception as e:
        log_error(f"Exception during bulk session upsert: {e}")
        return []
|
|
1033
|
+
|
|
1034
|
+
# -- Memory methods --
|
|
1035
|
+
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
    """Delete a single user memory from the database.

    Args:
        memory_id (str): The ID of the memory to delete.
        user_id (Optional[str]): If given, the delete only matches a memory owned by this user.
            Defaults to None.

    Returns:
        None. Whether a row was actually deleted is only reported via debug logs.

    Raises:
        Exception: If an error occurs during deletion.
    """
    try:
        table = self._get_table(table_type="memories")
        if table is None:
            return

        with self.Session() as sess, sess.begin():
            delete_stmt = table.delete().where(table.c.memory_id == memory_id)
            if user_id is not None:
                # Restrict the match to rows belonging to this user
                delete_stmt = delete_stmt.where(table.c.user_id == user_id)
            result = sess.execute(delete_stmt)

            if result.rowcount > 0:
                log_debug(f"Successfully deleted memory id: {memory_id}")
            else:
                log_debug(f"No memory found with id: {memory_id}")

    except Exception as e:
        log_error(f"Error deleting memory: {e}")
        raise e
|
|
1068
|
+
|
|
1069
|
+
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
    """Remove several user memories with a single DELETE statement.

    Args:
        memory_ids (List[str]): IDs of the memories to remove.
        user_id (Optional[str]): When provided, only rows owned by this user are removed.
            Defaults to None.

    Raises:
        Exception: Re-raised after logging if the deletion fails.
    """
    try:
        memories_table = self._get_table(table_type="memories")
        if memories_table is None:
            return

        # All conditions are ANDed together by .where()
        conditions = [memories_table.c.memory_id.in_(memory_ids)]
        if user_id is not None:
            conditions.append(memories_table.c.user_id == user_id)

        with self.Session() as sess, sess.begin():
            outcome = sess.execute(memories_table.delete().where(*conditions))
            if outcome.rowcount == 0:
                log_debug(f"No memories found with ids: {memory_ids}")

    except Exception as e:
        log_error(f"Error deleting memories: {e}")
        raise e
|
|
1095
|
+
|
|
1096
|
+
def get_all_memory_topics(self) -> List[str]:
    """Collect the distinct topics used across every stored memory.

    Returns:
        List[str]: Unique topics, in no guaranteed order. Empty when the memories table
        is not available.

    Raises:
        Exception: Re-raised after logging if the query fails.
    """
    try:
        memories_table = self._get_table(table_type="memories")
        if memories_table is None:
            return []

        with self.Session() as sess, sess.begin():
            rows = sess.execute(select(memories_table.c.topics)).fetchall()

            unique_topics = set()
            for row in rows:
                if row is None or row[0] is None:
                    continue
                raw_topics = row[0]
                # Topics may arrive as a serialized JSON string or as an already-decoded value
                parsed = json.loads(raw_topics) if isinstance(raw_topics, str) else raw_topics
                if isinstance(parsed, list):
                    unique_topics.update(parsed)

            return list(unique_topics)

    except Exception as e:
        log_error(f"Exception reading from memory table: {e}")
        raise e
|
|
1123
|
+
|
|
1124
|
+
def get_user_memory(
    self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Get a single memory from the database.

    Args:
        memory_id (str): The ID of the memory to get.
        deserialize (Optional[bool]): Whether to deserialize the row into a UserMemory. Defaults to True.
        user_id (Optional[str]): If given, only a memory belonging to this user matches. Defaults to None.

    Returns:
        Union[UserMemory, Dict[str, Any], None]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary (a plain dict copy of the row)
            - None when the table is unavailable or no memory matches

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="memories")
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            stmt = select(table).where(table.c.memory_id == memory_id)
            if user_id is not None:
                stmt = stmt.where(table.c.user_id == user_id)

            result = sess.execute(stmt).fetchone()
            if not result:
                return None

            # Copy into a plain dict so callers get the documented Dict[str, Any]
            # rather than a SQLAlchemy RowMapping view.
            memory_raw = dict(result._mapping)
            if not deserialize:
                return memory_raw
            return UserMemory.from_dict(memory_raw)

    except Exception as e:
        log_error(f"Exception reading from memory table: {e}")
        raise e
|
|
1164
|
+
|
|
1165
|
+
def get_user_memories(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    topics: Optional[List[str]] = None,
    search_content: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
    """Get memories from the database, optionally filtered, sorted and paginated.

    Args:
        user_id (Optional[str]): The ID of the user to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        topics (Optional[List[str]]): Topics to filter by; a memory must contain every listed topic.
        search_content (Optional[str]): Substring to search for in the memory text (SQL LIKE).
        limit (Optional[int]): The maximum number of memories to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.

    Returns:
        Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of UserMemory objects
            - When deserialize=False: Tuple of (memory dictionaries, total count). The total
              count covers all filtered rows, independent of pagination.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="memories")
        if table is None:
            return [] if deserialize else ([], 0)

        with self.Session() as sess, sess.begin():
            stmt = select(table)
            # Filtering
            if user_id is not None:
                stmt = stmt.where(table.c.user_id == user_id)
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)
            if topics is not None:
                # Every requested topic must be present (conditions are ANDed)
                topic_conditions = [func.JSON_ARRAY_CONTAINS_STRING(table.c.topics, topic) for topic in topics]
                if topic_conditions:
                    stmt = stmt.where(and_(*topic_conditions))
            if search_content is not None:
                stmt = stmt.where(table.c.memory.like(f"%{search_content}%"))

            # Get total count after applying filtering
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Sorting
            stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = sess.execute(stmt).fetchall()
            if not result:
                # A page past the end is empty while the filter can still match rows;
                # report the real total so callers can paginate correctly.
                return [] if deserialize else ([], total_count)

            memories_raw = [dict(record._mapping) for record in result]
            if not deserialize:
                return memories_raw, total_count

            return [UserMemory.from_dict(record) for record in memories_raw]

    except Exception as e:
        log_error(f"Exception reading from memory table: {e}")
        raise e
|
|
1248
|
+
|
|
1249
|
+
def get_user_memory_stats(
    self, limit: Optional[int] = None, page: Optional[int] = None
) -> Tuple[List[Dict[str, Any]], int]:
    """Get per-user memory statistics.

    Only rows with a non-null user_id are counted. Users are ordered by their most
    recently updated memory, newest first.

    Args:
        limit (Optional[int]): The maximum number of user stats to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.

    Returns:
        Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats, and the
        total number of users (independent of pagination).

    Example:
        (
            [
                {
                    "user_id": "123",
                    "total_memories": 10,
                    "last_memory_updated_at": 1714560000,
                },
            ],
            total_count: 1,
        )

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="memories")
        if table is None:
            return [], 0

        with self.Session() as sess, sess.begin():
            stmt = (
                select(
                    table.c.user_id,
                    func.count(table.c.memory_id).label("total_memories"),
                    func.max(table.c.updated_at).label("last_memory_updated_at"),
                )
                .where(table.c.user_id.is_not(None))
                .group_by(table.c.user_id)
                .order_by(func.max(table.c.updated_at).desc())
            )

            # Number of distinct users, computed before pagination is applied
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Pagination
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = sess.execute(stmt).fetchall()
            if not result:
                # A page past the end is empty, but the total is still needed for pagination
                return [], total_count

            return [
                {
                    "user_id": record.user_id,  # type: ignore
                    "total_memories": record.total_memories,
                    "last_memory_updated_at": record.last_memory_updated_at,
                }
                for record in result
            ], total_count

    except Exception as e:
        log_error(f"Exception getting user memory stats: {e}")
        raise e
|
|
1315
|
+
|
|
1316
|
+
def upsert_user_memory(
    self, memory: UserMemory, deserialize: Optional[bool] = True
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Upsert a user memory in the database.

    If ``memory.memory_id`` is None, a new UUID4 is generated and assigned to the passed object.

    Args:
        memory (UserMemory): The user memory to upsert.
        deserialize (Optional[bool]): Whether to deserialize the stored row into a UserMemory. Defaults to True.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary (a plain dict copy of the row)
            - None when the table is unavailable or the row cannot be read back

    Raises:
        Exception: If an error occurs during upsert.
    """
    try:
        table = self._get_table(table_type="memories", create_table_if_not_found=True)
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            if memory.memory_id is None:
                memory.memory_id = str(uuid4())

            # One timestamp for both the insert and the duplicate-key update paths
            current_time = int(time.time())

            stmt = mysql.insert(table).values(
                memory_id=memory.memory_id,
                memory=memory.memory,
                input=memory.input,
                user_id=memory.user_id,
                agent_id=memory.agent_id,
                team_id=memory.team_id,
                topics=memory.topics,
                updated_at=current_time,
            )
            stmt = stmt.on_duplicate_key_update(
                memory=stmt.inserted.memory,
                topics=stmt.inserted.topics,
                input=stmt.inserted.input,
                user_id=stmt.inserted.user_id,
                agent_id=stmt.inserted.agent_id,
                team_id=stmt.inserted.team_id,
                updated_at=current_time,
            )

            sess.execute(stmt)

            # Read the stored row back so the caller sees exactly what was persisted
            select_stmt = select(table).where(table.c.memory_id == memory.memory_id)
            row = sess.execute(select_stmt).fetchone()
            if row is None:
                return None

            memory_raw = dict(row._mapping)
            if not memory_raw or not deserialize:
                return memory_raw

            return UserMemory.from_dict(memory_raw)

    except Exception as e:
        log_error(f"Error upserting user memory: {e}")
        raise e
|
|
1379
|
+
|
|
1380
|
+
def upsert_memories(
    self, memories: List[UserMemory], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
) -> List[Union[UserMemory, Dict[str, Any]]]:
    """
    Bulk upsert multiple user memories in a single statement.

    Memories without a ``memory_id`` are assigned a new UUID4 (the passed objects are mutated).

    Args:
        memories (List[UserMemory]): List of memories to upsert.
        deserialize (Optional[bool]): Whether to deserialize the stored rows into UserMemory objects.
            Defaults to True.
        preserve_updated_at (bool): Keep each memory's own ``updated_at`` instead of stamping the
            current time. Memories whose ``updated_at`` is None still get the current time.
            Defaults to False.

    Returns:
        List[Union[UserMemory, Dict[str, Any]]]: The upserted memories as read back from the table,
        in the order the database returns them. Empty when ``memories`` is empty, the table is
        unavailable, or an error occurs.

    Note:
        Errors are logged and an empty list is returned; exceptions are not re-raised.
    """
    if not memories:
        return []

    try:
        table = self._get_table(table_type="memories", create_table_if_not_found=True)
        if table is None:
            return []

        # Prepare data for bulk insert
        memory_data: List[Dict[str, Any]] = []
        current_time = int(time.time())
        for memory in memories:
            if memory.memory_id is None:
                memory.memory_id = str(uuid4())
            # Fall back to the current time when preservation is requested but no timestamp
            # exists, so a row is never written with a NULL updated_at.
            if preserve_updated_at and memory.updated_at is not None:
                updated_at = memory.updated_at
            else:
                updated_at = current_time
            memory_data.append(
                {
                    "memory_id": memory.memory_id,
                    "memory": memory.memory,
                    "input": memory.input,
                    "user_id": memory.user_id,
                    "agent_id": memory.agent_id,
                    "team_id": memory.team_id,
                    "topics": memory.topics,
                    "updated_at": updated_at,
                }
            )

        results: List[Union[UserMemory, Dict[str, Any]]] = []

        with self.Session() as sess, sess.begin():
            stmt = mysql.insert(table)
            stmt = stmt.on_duplicate_key_update(
                memory=stmt.inserted.memory,
                topics=stmt.inserted.topics,
                input=stmt.inserted.input,
                user_id=stmt.inserted.user_id,
                agent_id=stmt.inserted.agent_id,
                team_id=stmt.inserted.team_id,
                updated_at=stmt.inserted.updated_at,
            )
            sess.execute(stmt, memory_data)

            # Fetch the results
            memory_ids = [memory.memory_id for memory in memories if memory.memory_id]
            select_stmt = select(table).where(table.c.memory_id.in_(memory_ids))
            result = sess.execute(select_stmt).fetchall()

            for row in result:
                memory_raw = dict(row._mapping)
                results.append(UserMemory.from_dict(memory_raw) if deserialize else memory_raw)

        return results

    except Exception as e:
        log_error(f"Exception during bulk memory upsert: {e}")
        return []
|
|
1458
|
+
|
|
1459
|
+
def clear_memories(self) -> None:
    """Remove every row from the memories table.

    Nothing happens when ``_get_table`` does not return a memories table.

    Raises:
        Exception: Re-raised after logging if the delete fails.
    """
    try:
        memories_table = self._get_table(table_type="memories")
        if memories_table is not None:
            with self.Session() as sess, sess.begin():
                sess.execute(memories_table.delete())
    except Exception as e:
        log_error(f"Exception deleting all memories: {e}")
        raise e
|
|
1476
|
+
|
|
1477
|
+
# -- Metrics methods --
|
|
1478
|
+
def _get_all_sessions_for_metrics_calculation(
    self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Load the columns needed for metrics from every session row (agent, team and workflow).

    Args:
        start_timestamp (Optional[int]): Inclusive lower bound on ``created_at``. Defaults to None.
        end_timestamp (Optional[int]): Inclusive upper bound on ``created_at``. Defaults to None.

    Returns:
        List[Dict[str, Any]]: Row mappings with user_id, session_data, runs, created_at and
        session_type. Empty when the sessions table is unavailable or the query fails
        (errors are logged, not raised).
    """
    try:
        sessions_table = self._get_table(table_type="sessions")
        if sessions_table is None:
            return []

        columns = sessions_table.c
        query = select(
            columns.user_id,
            columns.session_data,
            columns.runs,
            columns.created_at,
            columns.session_type,
        )

        if start_timestamp is not None:
            query = query.where(columns.created_at >= start_timestamp)
        if end_timestamp is not None:
            query = query.where(columns.created_at <= end_timestamp)

        with self.Session() as sess:
            rows = sess.execute(query).fetchall()
            return [row._mapping for row in rows]

    except Exception as e:
        log_error(f"Exception reading from sessions table: {e}")
        return []
|
|
1519
|
+
|
|
1520
|
+
def _get_metrics_calculation_starting_date(self, table: Table) -> Optional[date]:
    """Determine the first date that still needs metrics to be calculated.

    Resolution order:
        1. If a metrics record exists, look at the latest one (by ``date``): return the day after it
           when it is completed, otherwise its own date.
        2. Without metrics records, return the UTC date of the earliest session.
        3. Without metrics records or sessions, return None.

    Args:
        table (Table): The metrics table (read for its ``date`` and ``completed`` columns).

    Returns:
        Optional[date]: The starting date for metrics calculation, or None if there is nothing to process.

    Raises:
        ValueError: If the session listing does not come back as a list.
    """
    with self.Session() as sess:
        latest_query = select(table).order_by(table.c.date.desc()).limit(1)
        latest_metrics = sess.execute(latest_query).fetchone()

    # 1. Latest metrics record found: continue from it.
    if latest_metrics is not None:
        latest_date = latest_metrics._mapping["date"]
        return latest_date + timedelta(days=1) if latest_metrics.completed else latest_date

    # 2. No metrics records: start from the earliest recorded session.
    sessions_result, _ = self.get_sessions(sort_by="created_at", sort_order="asc", limit=1, deserialize=False)
    if not isinstance(sessions_result, list):
        raise ValueError("Error obtaining session list to calculate metrics")

    # 3. No sessions either (or no creation timestamp): nothing to calculate.
    if not sessions_result:
        return None
    first_session_date = sessions_result[0]["created_at"]  # type: ignore
    if first_session_date is None:
        return None

    return datetime.fromtimestamp(first_session_date, tz=timezone.utc).date()
|
|
1556
|
+
|
|
1557
|
+
def calculate_metrics(self) -> Optional[list[dict]]:
    """Calculate metrics for all dates without complete metrics.

    Each date to process is treated as a UTC calendar day when building the session
    timestamp window.

    Returns:
        Optional[list[dict]]: The calculated metrics, or None when there is nothing to calculate
        (no metrics table, no sessions, all dates done, or no new session data).

    Raises:
        Exception: If an error occurs during metrics calculation.
    """
    try:
        table = self._get_table(table_type="metrics", create_table_if_not_found=True)
        if table is None:
            return None

        starting_date = self._get_metrics_calculation_starting_date(table)
        if starting_date is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        dates_to_process = get_dates_to_calculate_metrics_for(starting_date)
        if not dates_to_process:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        # Build timezone-aware UTC midnights. A naive datetime's .timestamp() is interpreted in the
        # host's local timezone, which would shift the window on non-UTC servers; the starting date
        # itself is derived from session timestamps in UTC.
        start_timestamp = int(
            datetime.combine(dates_to_process[0], datetime.min.time(), tzinfo=timezone.utc).timestamp()
        )
        end_timestamp = int(
            datetime.combine(
                dates_to_process[-1] + timedelta(days=1), datetime.min.time(), tzinfo=timezone.utc
            ).timestamp()
        )

        sessions = self._get_all_sessions_for_metrics_calculation(
            start_timestamp=start_timestamp, end_timestamp=end_timestamp
        )
        all_sessions_data = fetch_all_sessions_data(
            sessions=sessions, dates_to_process=dates_to_process, start_timestamp=start_timestamp
        )
        if not all_sessions_data:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        metrics_records = []
        for date_to_process in dates_to_process:
            date_key = date_to_process.isoformat()
            sessions_for_date = all_sessions_data.get(date_key, {})

            # Skip dates with no sessions
            if not any(len(day_sessions) > 0 for day_sessions in sessions_for_date.values()):
                continue

            metrics_records.append(calculate_date_metrics(date_to_process, sessions_for_date))

        if metrics_records:
            with self.Session() as sess, sess.begin():
                bulk_upsert_metrics(session=sess, table=table, metrics_records=metrics_records)

        log_debug("Updated metrics calculations")

        return metrics_records

    except Exception as e:
        log_error(f"Error calculating metrics: {e}")
        raise e
|
|
1619
|
+
|
|
1620
|
+
def get_metrics(
    self,
    starting_date: Optional[date] = None,
    ending_date: Optional[date] = None,
) -> Tuple[List[dict], Optional[int]]:
    """Get all metrics matching the given date range.

    Args:
        starting_date (Optional[date]): Inclusive lower bound on the metrics date.
        ending_date (Optional[date]): Inclusive upper bound on the metrics date.

    Returns:
        Tuple[List[dict], Optional[int]]: The metrics rows as dicts, and the latest ``updated_at``
        across the whole metrics table (not only the requested range). The timestamp is None
        when no metrics match, and 0 when the metrics table is unavailable.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="metrics", create_table_if_not_found=True)
        if table is None:
            return [], 0

        with self.Session() as sess, sess.begin():
            stmt = select(table)
            if starting_date:
                stmt = stmt.where(table.c.date >= starting_date)
            if ending_date:
                stmt = stmt.where(table.c.date <= ending_date)
            result = sess.execute(stmt).fetchall()
            if not result:
                return [], None

            # Get the latest updated_at (table-wide, independent of the date filter)
            latest_stmt = select(func.max(table.c.updated_at))
            latest_updated_at = sess.execute(latest_stmt).scalar()

            # Copy rows into plain dicts to honour the declared List[dict] return type
            return [dict(row._mapping) for row in result], latest_updated_at

    except Exception as e:
        log_error(f"Error getting metrics: {e}")
        raise e
|
|
1661
|
+
|
|
1662
|
+
# -- Knowledge methods --
|
|
1663
|
+
|
|
1664
|
+
def delete_knowledge_content(self, id: str):
    """Remove one knowledge row identified by its ID.

    Returns early without doing anything when ``_get_table`` yields no knowledge table.

    Args:
        id (str): ID of the knowledge row to remove.

    Raises:
        Exception: Re-raised after logging if the delete fails.
    """
    try:
        knowledge_table = self._get_table(table_type="knowledge")
        if knowledge_table is None:
            return

        with self.Session() as sess, sess.begin():
            sess.execute(knowledge_table.delete().where(knowledge_table.c.id == id))

        log_debug(f"Deleted knowledge content with id '{id}'")
    except Exception as e:
        log_error(f"Error deleting knowledge content: {e}")
        raise e
|
|
1683
|
+
|
|
1684
|
+
def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Fetch a single knowledge row by its ID.

    Args:
        id (str): ID of the knowledge row to fetch.

    Returns:
        Optional[KnowledgeRow]: The validated row, or None if the table is unavailable or no row matches.

    Raises:
        Exception: Re-raised after logging if the query fails.
    """
    try:
        knowledge_table = self._get_table(table_type="knowledge")
        if knowledge_table is None:
            return None

        with self.Session() as sess, sess.begin():
            query = select(knowledge_table).where(knowledge_table.c.id == id)
            row = sess.execute(query).fetchone()
            return None if row is None else KnowledgeRow.model_validate(row._mapping)
    except Exception as e:
        log_error(f"Error getting knowledge content: {e}")
        raise e
|
|
1707
|
+
|
|
1708
|
+
def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """Get knowledge contents from the database, optionally sorted and paginated.

    Args:
        limit (Optional[int]): The maximum number of knowledge contents to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by ("asc" or "desc").

    Returns:
        Tuple[List[KnowledgeRow], int]: The knowledge contents and the total row count
        (independent of pagination).

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="knowledge")
        if table is None:
            return [], 0

        with self.Session() as sess, sess.begin():
            stmt = select(table)

            # Count before ordering/pagination; ordering does not change the total
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Use real ASC/DESC ordering via the shared helper. The previous `column * -1`
            # trick only reverses numeric columns; strings are coerced to numbers by MySQL.
            stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Apply pagination after count
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = sess.execute(stmt).fetchall()
            return [KnowledgeRow.model_validate(record._mapping) for record in result], total_count

    except Exception as e:
        log_error(f"Error getting knowledge contents: {e}")
        raise e
|
|
1760
|
+
|
|
1761
|
+
def upsert_knowledge_content(self, knowledge_row: KnowledgeRow):
    """Insert a knowledge row, or update the existing row on a duplicate key.

    On a duplicate key, only the fields of ``knowledge_row`` that are not None are written;
    columns whose new value is None keep their stored value.

    Args:
        knowledge_row (KnowledgeRow): The knowledge row to upsert.

    Returns:
        Optional[KnowledgeRow]: The same ``knowledge_row`` that was passed in, or None when the
        knowledge table is unavailable.

    Raises:
        Exception: Re-raised after logging if the upsert fails.
    """
    try:
        knowledge_table = self._get_table(table_type="knowledge", create_table_if_not_found=True)
        if knowledge_table is None:
            return None

        candidate_fields = {
            "name": knowledge_row.name,
            "description": knowledge_row.description,
            "metadata": knowledge_row.metadata,
            "type": knowledge_row.type,
            "size": knowledge_row.size,
            "linked_to": knowledge_row.linked_to,
            "access_count": knowledge_row.access_count,
            "status": knowledge_row.status,
            "status_message": knowledge_row.status_message,
            "created_at": knowledge_row.created_at,
            "updated_at": knowledge_row.updated_at,
            "external_id": knowledge_row.external_id,
        }
        changed_fields = {column: value for column, value in candidate_fields.items() if value is not None}

        with self.Session() as sess, sess.begin():
            upsert_stmt = mysql.insert(knowledge_table).values(knowledge_row.model_dump())
            upsert_stmt = upsert_stmt.on_duplicate_key_update(**changed_fields)
            sess.execute(upsert_stmt)

        return knowledge_row

    except Exception as e:
        log_error(f"Error upserting knowledge row: {e}")
        raise e
|
|
1805
|
+
|
|
1806
|
+
# -- Eval methods --
|
|
1807
|
+
|
|
1808
|
+
def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Create an EvalRunRecord in the database.

    ``created_at`` and ``updated_at`` default to the current time unless the record already
    carries a non-None value for them.

    Args:
        eval_run (EvalRunRecord): The eval run to create.

    Returns:
        Optional[EvalRunRecord]: The eval run that was passed in, or None if the evals table is unavailable.

    Raises:
        Exception: If an error occurs during creation.
    """
    try:
        table = self._get_table(table_type="evals", create_table_if_not_found=True)
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            current_time = int(time.time())
            values = eval_run.model_dump()
            # Only fill timestamps that are missing or None. A plain `{defaults, **model_dump()}`
            # merge would let an explicit None from the record override the current time.
            for timestamp_field in ("created_at", "updated_at"):
                if values.get(timestamp_field) is None:
                    values[timestamp_field] = current_time

            stmt = mysql.insert(table).values(values)
            sess.execute(stmt)

        log_debug(f"Created eval run with id '{eval_run.run_id}'")

        return eval_run

    except Exception as e:
        log_error(f"Error creating eval run: {e}")
        raise e
|
|
1839
|
+
|
|
1840
|
+
def delete_eval_run(self, eval_run_id: str) -> None:
    """Remove a single eval run, matched on its ``run_id`` column.

    Nothing happens if the evals table cannot be obtained. A warning is logged (and nothing
    is raised) when no row matched the given ID.

    Args:
        eval_run_id (str): The ID of the eval run to delete.

    Raises:
        Exception: Re-raised after logging if the delete fails.
    """
    try:
        evals_table = self._get_table(table_type="evals")
        if evals_table is None:
            return

        with self.Session() as sess, sess.begin():
            delete_query = evals_table.delete().where(evals_table.c.run_id == eval_run_id)
            deleted_rows = sess.execute(delete_query).rowcount

            if deleted_rows != 0:
                log_debug(f"Deleted eval run with ID: {eval_run_id}")
            else:
                log_warning(f"No eval run found with ID: {eval_run_id}")

    except Exception as e:
        log_error(f"Error deleting eval run {eval_run_id}: {e}")
        raise e
|
|
1862
|
+
|
|
1863
|
+
def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete multiple eval runs from the database.

    An empty ``eval_run_ids`` list is a no-op: no table lookup and no DELETE is issued.

    Args:
        eval_run_ids (List[str]): List of eval run IDs to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Nothing to delete; skip the table lookup and the empty `IN ()` round trip.
    if not eval_run_ids:
        return

    try:
        table = self._get_table(table_type="evals")
        if table is None:
            return

        with self.Session() as sess, sess.begin():
            stmt = table.delete().where(table.c.run_id.in_(eval_run_ids))
            result = sess.execute(stmt)
            if result.rowcount == 0:
                log_debug(f"No eval runs found with IDs: {eval_run_ids}")
            else:
                log_debug(f"Deleted {result.rowcount} eval runs")

    except Exception as e:
        log_error(f"Error deleting eval runs {eval_run_ids}: {e}")
        raise e
|
|
1885
|
+
|
|
1886
|
+
def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Get an eval run from the database.

    Args:
        eval_run_id (str): The ID of the eval run to get.
        deserialize (Optional[bool]): Whether to deserialize the eval run into an
            EvalRunRecord. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]: None if the evals table cannot be
        obtained or no row has the given ``run_id``; otherwise
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: the raw row as a plain dictionary

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="evals")
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            stmt = select(table).where(table.c.run_id == eval_run_id)
            result = sess.execute(stmt).fetchone()
            if result is None:
                return None

            # Copy into a plain dict so callers receive the declared Dict type rather than a
            # SQLAlchemy RowMapping.
            eval_run_raw = dict(result._mapping)
            if not deserialize:
                return eval_run_raw

            return EvalRunRecord.model_validate(eval_run_raw)

    except Exception as e:
        log_error(f"Exception getting eval run {eval_run_id}: {e}")
        raise e
|
|
1923
|
+
|
|
1924
|
+
def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Get eval runs from the database, with optional filtering, sorting and pagination.

    Args:
        limit (Optional[int]): The maximum number of eval runs to return. No limit when None.
        page (Optional[int]): The page number (first page is 1). Only applied when ``limit``
            is set, since the offset is ``(page - 1) * limit``.
        sort_by (Optional[str]): The column to sort by. Defaults to newest ``created_at`` first.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        workflow_id (Optional[str]): The ID of the workflow to filter by.
        model_id (Optional[str]): The ID of the model to filter by.
        filter_type (Optional[EvalFilterType]): Keep only runs whose agent/team/workflow ID
            column (per the filter type) is not NULL.
        eval_type (Optional[List[EvalType]]): The type(s) of eval to filter by. Ignored when empty.
        deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of EvalRunRecord objects
            - When deserialize=False: Tuple of (list of row dictionaries, total number of rows
              matching the filters, before pagination)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="evals")
        if table is None:
            return [] if deserialize else ([], 0)

        with self.Session() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)
            if workflow_id is not None:
                stmt = stmt.where(table.c.workflow_id == workflow_id)
            if model_id is not None:
                stmt = stmt.where(table.c.model_id == model_id)
            if eval_type is not None and len(eval_type) > 0:
                stmt = stmt.where(table.c.eval_type.in_(eval_type))
            if filter_type is not None:
                if filter_type == EvalFilterType.AGENT:
                    stmt = stmt.where(table.c.agent_id.is_not(None))
                elif filter_type == EvalFilterType.TEAM:
                    stmt = stmt.where(table.c.team_id.is_not(None))
                elif filter_type == EvalFilterType.WORKFLOW:
                    stmt = stmt.where(table.c.workflow_id.is_not(None))

            # Total count of matching rows, computed before sorting and pagination
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Sorting
            if sort_by is None:
                stmt = stmt.order_by(table.c.created_at.desc())
            else:
                stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating: the offset is derived from `limit`, so `page` only applies with a limit
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = sess.execute(stmt).fetchall()

            # Plain dicts rather than RowMapping objects, matching the declared return type.
            # An empty page still reports the real total so callers can paginate correctly.
            eval_runs_raw = [dict(row._mapping) for row in result]
            if not deserialize:
                return eval_runs_raw, total_count

            return [EvalRunRecord.model_validate(row) for row in eval_runs_raw]

    except Exception as e:
        log_error(f"Exception getting eval runs: {e}")
        raise e
|
|
2018
|
+
|
|
2019
|
+
def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Update the name (and ``updated_at``) of an eval run in the database.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to return an EvalRunRecord (True) or the raw
            row (False). Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]: The updated eval run as returned by
        ``get_eval_run``, or None if the evals table cannot be obtained or no eval run has
        the given ID.

    Raises:
        Exception: If an error occurs during update.
    """
    try:
        table = self._get_table(table_type="evals")
        if table is None:
            return None

        with self.Session() as sess, sess.begin():
            stmt = (
                table.update().where(table.c.run_id == eval_run_id).values(name=name, updated_at=int(time.time()))
            )
            sess.execute(stmt)

        # Read back only after the update transaction has committed: get_eval_run opens its own
        # session, so it would not see uncommitted changes. It already applies `deserialize`.
        eval_run = self.get_eval_run(eval_run_id=eval_run_id, deserialize=deserialize)
        if eval_run is None:
            log_warning(f"No eval run found with ID: {eval_run_id}")
            return None

        log_debug(f"Renamed eval run with id '{eval_run_id}' to '{name}'")
        return eval_run

    except Exception as e:
        log_error(f"Error renaming eval run {eval_run_id}: {e}")
        raise e
|
|
2057
|
+
|
|
2058
|
+
# -- Culture methods --
|
|
2059
|
+
|
|
2060
|
+
def clear_cultural_knowledge(self) -> None:
    """Delete all cultural knowledge from the database.

    Does nothing if the culture table cannot be obtained.

    Raises:
        Exception: If an error occurs during deletion.
    """
    try:
        table = self._get_table(table_type="culture")
        if table is None:
            return

        with self.Session() as sess, sess.begin():
            sess.execute(table.delete())

    except Exception as e:
        # The failure is re-raised, so report it at error level like the other methods do.
        log_error(f"Exception deleting all cultural knowledge: {e}")
        raise e
|
|
2077
|
+
|
|
2078
|
+
def delete_cultural_knowledge(self, id: str) -> None:
    """Remove the cultural knowledge row whose ``id`` column matches.

    Nothing happens if the culture table cannot be obtained. A missing row is only logged
    at debug level, not raised.

    Args:
        id (str): The ID of the cultural knowledge to delete.

    Raises:
        Exception: Re-raised after logging if the delete fails.
    """
    try:
        culture_table = self._get_table(table_type="culture")
        if culture_table is None:
            return

        with self.Session() as sess, sess.begin():
            outcome = sess.execute(culture_table.delete().where(culture_table.c.id == id))
            message = (
                f"Successfully deleted cultural knowledge id: {id}"
                if outcome.rowcount > 0
                else f"No cultural knowledge found with id: {id}"
            )
            log_debug(message)

    except Exception as e:
        log_error(f"Error deleting cultural knowledge: {e}")
        raise e
|
|
2105
|
+
|
|
2106
|
+
def get_cultural_knowledge(
    self, id: str, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Fetch one cultural knowledge entry by its ``id`` column.

    Args:
        id (str): The ID of the cultural knowledge to get.
        deserialize (Optional[bool]): When truthy, convert the row with
            ``deserialize_cultural_knowledge_from_db``; otherwise return the row as a dict.
            Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The entry, or None when the culture
        table cannot be obtained or no row matches.

    Raises:
        Exception: Re-raised after logging if the read fails.
    """
    try:
        culture_table = self._get_table(table_type="culture")
        if culture_table is None:
            return None

        with self.Session() as sess, sess.begin():
            lookup = select(culture_table).where(culture_table.c.id == id)
            row = sess.execute(lookup).fetchone()
            if row is None:
                return None

            record = dict(row._mapping)
            if deserialize and record:
                return deserialize_cultural_knowledge_from_db(record)
            return record

    except Exception as e:
        log_error(f"Exception reading from cultural knowledge table: {e}")
        raise e
|
|
2141
|
+
|
|
2142
|
+
def get_all_cultural_knowledge(
    self,
    name: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
    """Get cultural knowledge entries from the database, with optional filtering and pagination.

    Args:
        name (Optional[str]): The name of the cultural knowledge to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        limit (Optional[int]): The maximum number of cultural knowledge entries to return.
            No limit when None.
        page (Optional[int]): The page number (first page is 1). Only applied when ``limit``
            is set, since the offset is ``(page - 1) * limit``.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.

    Returns:
        Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of CulturalKnowledge objects
            - When deserialize=False: Tuple of (list of row dictionaries, total number of rows
              matching the filters, before pagination)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = self._get_table(table_type="culture")
        if table is None:
            return [] if deserialize else ([], 0)

        with self.Session() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if name is not None:
                stmt = stmt.where(table.c.name == name)
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)

            # Total count of matching rows, computed before sorting and pagination
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = sess.execute(count_stmt).scalar() or 0

            # Sorting
            stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating: the offset is derived from `limit`, so `page` only applies with a limit
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = sess.execute(stmt).fetchall()
            db_rows = [dict(record._mapping) for record in result]

            # An empty page still reports the real total so callers can paginate correctly.
            if not deserialize:
                return db_rows, total_count

            return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]

    except Exception as e:
        log_error(f"Error reading from cultural knowledge table: {e}")
        raise e
|
|
2215
|
+
|
|
2216
|
+
def upsert_cultural_knowledge(
    self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Insert a cultural knowledge entry, or update it if its ``id`` already exists.

    A missing ``id`` is filled with a new UUID4 string (mutating ``cultural_knowledge``).
    On insert, ``created_at`` falls back to the current Unix time (seconds) when the entry
    has none; on update, ``created_at`` is left untouched.

    Args:
        cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
        deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The stored entry as re-read by
        ``get_cultural_knowledge``, or None if the culture table cannot be obtained.

    Raises:
        Exception: If an error occurs during upsert.
    """
    try:
        table = self._get_table(table_type="culture", create_table_if_not_found=True)
        if table is None:
            return None

        if cultural_knowledge.id is None:
            cultural_knowledge.id = str(uuid4())

        # Serialize content, categories, and notes into a JSON dict for DB storage
        content_dict = serialize_cultural_knowledge_for_db(cultural_knowledge)

        # One timestamp per upsert so the insert and update branches agree.
        now = int(time.time())

        # Columns written identically on both insert and duplicate-key update.
        shared_values = dict(
            name=cultural_knowledge.name,
            summary=cultural_knowledge.summary,
            content=content_dict if content_dict else None,
            metadata=cultural_knowledge.metadata,
            input=cultural_knowledge.input,
            updated_at=now,
            agent_id=cultural_knowledge.agent_id,
            team_id=cultural_knowledge.team_id,
        )
        created_at = cultural_knowledge.created_at if cultural_knowledge.created_at is not None else now

        with self.Session() as sess, sess.begin():
            stmt = mysql.insert(table).values(id=cultural_knowledge.id, created_at=created_at, **shared_values)
            stmt = stmt.on_duplicate_key_update(**shared_values)
            sess.execute(stmt)

        # Fetch the inserted/updated row once the transaction has committed
        return self.get_cultural_knowledge(id=cultural_knowledge.id, deserialize=deserialize)

    except Exception as e:
        log_error(f"Error upserting cultural knowledge: {e}")
        raise e
|