agno 1.8.0__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- agno/__init__.py +8 -0
- agno/agent/__init__.py +19 -27
- agno/agent/agent.py +2781 -4126
- agno/api/agent.py +9 -65
- agno/api/api.py +5 -46
- agno/api/evals.py +6 -17
- agno/api/os.py +17 -0
- agno/api/routes.py +6 -41
- agno/api/schemas/__init__.py +9 -0
- agno/api/schemas/agent.py +5 -21
- agno/api/schemas/evals.py +7 -16
- agno/api/schemas/os.py +14 -0
- agno/api/schemas/team.py +5 -21
- agno/api/schemas/utils.py +21 -0
- agno/api/schemas/workflows.py +11 -7
- agno/api/settings.py +53 -0
- agno/api/team.py +9 -64
- agno/api/workflow.py +28 -0
- agno/cloud/aws/base.py +214 -0
- agno/cloud/aws/s3/__init__.py +2 -0
- agno/cloud/aws/s3/api_client.py +43 -0
- agno/cloud/aws/s3/bucket.py +195 -0
- agno/cloud/aws/s3/object.py +57 -0
- agno/db/__init__.py +24 -0
- agno/db/base.py +245 -0
- agno/db/dynamo/__init__.py +3 -0
- agno/db/dynamo/dynamo.py +1749 -0
- agno/db/dynamo/schemas.py +278 -0
- agno/db/dynamo/utils.py +684 -0
- agno/db/firestore/__init__.py +3 -0
- agno/db/firestore/firestore.py +1438 -0
- agno/db/firestore/schemas.py +130 -0
- agno/db/firestore/utils.py +278 -0
- agno/db/gcs_json/__init__.py +3 -0
- agno/db/gcs_json/gcs_json_db.py +1001 -0
- agno/db/gcs_json/utils.py +194 -0
- agno/db/in_memory/__init__.py +3 -0
- agno/db/in_memory/in_memory_db.py +888 -0
- agno/db/in_memory/utils.py +172 -0
- agno/db/json/__init__.py +3 -0
- agno/db/json/json_db.py +1051 -0
- agno/db/json/utils.py +196 -0
- agno/db/migrations/v1_to_v2.py +162 -0
- agno/db/mongo/__init__.py +3 -0
- agno/db/mongo/mongo.py +1417 -0
- agno/db/mongo/schemas.py +77 -0
- agno/db/mongo/utils.py +204 -0
- agno/db/mysql/__init__.py +3 -0
- agno/db/mysql/mysql.py +1719 -0
- agno/db/mysql/schemas.py +124 -0
- agno/db/mysql/utils.py +298 -0
- agno/db/postgres/__init__.py +3 -0
- agno/db/postgres/postgres.py +1720 -0
- agno/db/postgres/schemas.py +124 -0
- agno/db/postgres/utils.py +281 -0
- agno/db/redis/__init__.py +3 -0
- agno/db/redis/redis.py +1371 -0
- agno/db/redis/schemas.py +109 -0
- agno/db/redis/utils.py +288 -0
- agno/db/schemas/__init__.py +3 -0
- agno/db/schemas/evals.py +33 -0
- agno/db/schemas/knowledge.py +40 -0
- agno/db/schemas/memory.py +46 -0
- agno/db/singlestore/__init__.py +3 -0
- agno/db/singlestore/schemas.py +116 -0
- agno/db/singlestore/singlestore.py +1722 -0
- agno/db/singlestore/utils.py +327 -0
- agno/db/sqlite/__init__.py +3 -0
- agno/db/sqlite/schemas.py +119 -0
- agno/db/sqlite/sqlite.py +1680 -0
- agno/db/sqlite/utils.py +269 -0
- agno/db/utils.py +88 -0
- agno/eval/__init__.py +14 -0
- agno/eval/accuracy.py +142 -43
- agno/eval/performance.py +88 -23
- agno/eval/reliability.py +73 -20
- agno/eval/utils.py +23 -13
- agno/integrations/discord/__init__.py +3 -0
- agno/{app → integrations}/discord/client.py +10 -10
- agno/knowledge/__init__.py +2 -2
- agno/{document → knowledge}/chunking/agentic.py +2 -2
- agno/{document → knowledge}/chunking/document.py +2 -2
- agno/{document → knowledge}/chunking/fixed.py +3 -3
- agno/{document → knowledge}/chunking/markdown.py +2 -2
- agno/{document → knowledge}/chunking/recursive.py +2 -2
- agno/{document → knowledge}/chunking/row.py +2 -2
- agno/knowledge/chunking/semantic.py +59 -0
- agno/knowledge/chunking/strategy.py +121 -0
- agno/knowledge/content.py +74 -0
- agno/knowledge/document/__init__.py +5 -0
- agno/{document → knowledge/document}/base.py +12 -2
- agno/knowledge/embedder/__init__.py +5 -0
- agno/{embedder → knowledge/embedder}/aws_bedrock.py +127 -1
- agno/{embedder → knowledge/embedder}/azure_openai.py +65 -1
- agno/{embedder → knowledge/embedder}/base.py +6 -0
- agno/{embedder → knowledge/embedder}/cohere.py +72 -1
- agno/{embedder → knowledge/embedder}/fastembed.py +17 -1
- agno/{embedder → knowledge/embedder}/fireworks.py +1 -1
- agno/{embedder → knowledge/embedder}/google.py +74 -1
- agno/{embedder → knowledge/embedder}/huggingface.py +36 -2
- agno/{embedder → knowledge/embedder}/jina.py +48 -2
- agno/knowledge/embedder/langdb.py +22 -0
- agno/knowledge/embedder/mistral.py +139 -0
- agno/{embedder → knowledge/embedder}/nebius.py +1 -1
- agno/{embedder → knowledge/embedder}/ollama.py +54 -3
- agno/knowledge/embedder/openai.py +223 -0
- agno/{embedder → knowledge/embedder}/sentence_transformer.py +16 -1
- agno/{embedder → knowledge/embedder}/together.py +1 -1
- agno/{embedder → knowledge/embedder}/voyageai.py +49 -1
- agno/knowledge/knowledge.py +1515 -0
- agno/knowledge/reader/__init__.py +7 -0
- agno/{document → knowledge}/reader/arxiv_reader.py +32 -4
- agno/knowledge/reader/base.py +88 -0
- agno/{document → knowledge}/reader/csv_reader.py +68 -15
- agno/knowledge/reader/docx_reader.py +83 -0
- agno/{document → knowledge}/reader/firecrawl_reader.py +42 -21
- agno/knowledge/reader/gcs_reader.py +67 -0
- agno/{document → knowledge}/reader/json_reader.py +30 -9
- agno/{document → knowledge}/reader/markdown_reader.py +36 -9
- agno/{document → knowledge}/reader/pdf_reader.py +79 -21
- agno/knowledge/reader/reader_factory.py +275 -0
- agno/knowledge/reader/s3_reader.py +171 -0
- agno/{document → knowledge}/reader/text_reader.py +31 -10
- agno/knowledge/reader/url_reader.py +84 -0
- agno/knowledge/reader/web_search_reader.py +389 -0
- agno/{document → knowledge}/reader/website_reader.py +37 -10
- agno/knowledge/reader/wikipedia_reader.py +59 -0
- agno/knowledge/reader/youtube_reader.py +78 -0
- agno/knowledge/remote_content/remote_content.py +88 -0
- agno/{reranker → knowledge/reranker}/base.py +1 -1
- agno/{reranker → knowledge/reranker}/cohere.py +2 -2
- agno/{reranker → knowledge/reranker}/infinity.py +2 -2
- agno/{reranker → knowledge/reranker}/sentence_transformer.py +2 -2
- agno/knowledge/types.py +30 -0
- agno/knowledge/utils.py +169 -0
- agno/media.py +2 -2
- agno/memory/__init__.py +2 -10
- agno/memory/manager.py +1003 -148
- agno/models/aimlapi/__init__.py +2 -2
- agno/models/aimlapi/aimlapi.py +6 -6
- agno/models/anthropic/claude.py +129 -82
- agno/models/aws/bedrock.py +107 -175
- agno/models/aws/claude.py +64 -18
- agno/models/azure/ai_foundry.py +73 -23
- agno/models/base.py +347 -287
- agno/models/cerebras/cerebras.py +84 -27
- agno/models/cohere/chat.py +106 -98
- agno/models/dashscope/dashscope.py +14 -5
- agno/models/google/gemini.py +123 -53
- agno/models/groq/groq.py +97 -35
- agno/models/huggingface/huggingface.py +92 -27
- agno/models/ibm/watsonx.py +72 -13
- agno/models/litellm/chat.py +85 -13
- agno/models/message.py +38 -144
- agno/models/meta/llama.py +85 -49
- agno/models/metrics.py +120 -0
- agno/models/mistral/mistral.py +90 -21
- agno/models/ollama/__init__.py +0 -2
- agno/models/ollama/chat.py +84 -46
- agno/models/openai/chat.py +135 -27
- agno/models/openai/responses.py +233 -115
- agno/models/perplexity/perplexity.py +26 -2
- agno/models/portkey/portkey.py +0 -7
- agno/models/response.py +14 -8
- agno/models/utils.py +20 -0
- agno/models/vercel/__init__.py +2 -2
- agno/models/vercel/v0.py +1 -1
- agno/models/vllm/__init__.py +2 -2
- agno/models/vllm/vllm.py +3 -3
- agno/models/xai/xai.py +10 -10
- agno/os/__init__.py +3 -0
- agno/os/app.py +393 -0
- agno/os/auth.py +47 -0
- agno/os/config.py +103 -0
- agno/os/interfaces/agui/__init__.py +3 -0
- agno/os/interfaces/agui/agui.py +31 -0
- agno/{app/agui/async_router.py → os/interfaces/agui/router.py} +16 -16
- agno/{app → os/interfaces}/agui/utils.py +65 -28
- agno/os/interfaces/base.py +21 -0
- agno/os/interfaces/slack/__init__.py +3 -0
- agno/{app/slack/async_router.py → os/interfaces/slack/router.py} +3 -5
- agno/os/interfaces/slack/slack.py +33 -0
- agno/os/interfaces/whatsapp/__init__.py +3 -0
- agno/{app/whatsapp/async_router.py → os/interfaces/whatsapp/router.py} +4 -7
- agno/os/interfaces/whatsapp/whatsapp.py +30 -0
- agno/os/router.py +843 -0
- agno/os/routers/__init__.py +3 -0
- agno/os/routers/evals/__init__.py +3 -0
- agno/os/routers/evals/evals.py +204 -0
- agno/os/routers/evals/schemas.py +142 -0
- agno/os/routers/evals/utils.py +161 -0
- agno/os/routers/knowledge/__init__.py +3 -0
- agno/os/routers/knowledge/knowledge.py +413 -0
- agno/os/routers/knowledge/schemas.py +118 -0
- agno/os/routers/memory/__init__.py +3 -0
- agno/os/routers/memory/memory.py +179 -0
- agno/os/routers/memory/schemas.py +58 -0
- agno/os/routers/metrics/__init__.py +3 -0
- agno/os/routers/metrics/metrics.py +58 -0
- agno/os/routers/metrics/schemas.py +47 -0
- agno/os/routers/session/__init__.py +3 -0
- agno/os/routers/session/session.py +163 -0
- agno/os/schema.py +892 -0
- agno/{app/playground → os}/settings.py +8 -15
- agno/os/utils.py +270 -0
- agno/reasoning/azure_ai_foundry.py +4 -4
- agno/reasoning/deepseek.py +4 -4
- agno/reasoning/default.py +6 -11
- agno/reasoning/groq.py +4 -4
- agno/reasoning/helpers.py +4 -6
- agno/reasoning/ollama.py +4 -4
- agno/reasoning/openai.py +4 -4
- agno/run/{response.py → agent.py} +144 -72
- agno/run/base.py +44 -58
- agno/run/cancel.py +83 -0
- agno/run/team.py +133 -77
- agno/run/workflow.py +537 -12
- agno/session/__init__.py +10 -0
- agno/session/agent.py +244 -0
- agno/session/summary.py +225 -0
- agno/session/team.py +262 -0
- agno/{storage/session/v2 → session}/workflow.py +47 -24
- agno/team/__init__.py +15 -16
- agno/team/team.py +2967 -4243
- agno/tools/agentql.py +14 -5
- agno/tools/airflow.py +9 -4
- agno/tools/api.py +7 -3
- agno/tools/apify.py +2 -46
- agno/tools/arxiv.py +8 -3
- agno/tools/aws_lambda.py +7 -5
- agno/tools/aws_ses.py +7 -1
- agno/tools/baidusearch.py +4 -1
- agno/tools/bitbucket.py +4 -4
- agno/tools/brandfetch.py +14 -11
- agno/tools/bravesearch.py +4 -1
- agno/tools/brightdata.py +42 -22
- agno/tools/browserbase.py +13 -4
- agno/tools/calcom.py +12 -10
- agno/tools/calculator.py +10 -27
- agno/tools/cartesia.py +18 -13
- agno/tools/{clickup_tool.py → clickup.py} +12 -25
- agno/tools/confluence.py +71 -18
- agno/tools/crawl4ai.py +7 -1
- agno/tools/csv_toolkit.py +9 -8
- agno/tools/dalle.py +18 -11
- agno/tools/daytona.py +13 -16
- agno/tools/decorator.py +6 -3
- agno/tools/desi_vocal.py +16 -7
- agno/tools/discord.py +11 -8
- agno/tools/docker.py +30 -42
- agno/tools/duckdb.py +34 -53
- agno/tools/duckduckgo.py +8 -7
- agno/tools/e2b.py +62 -62
- agno/tools/eleven_labs.py +35 -28
- agno/tools/email.py +4 -1
- agno/tools/evm.py +7 -1
- agno/tools/exa.py +19 -14
- agno/tools/fal.py +29 -29
- agno/tools/file.py +9 -8
- agno/tools/financial_datasets.py +25 -44
- agno/tools/firecrawl.py +22 -22
- agno/tools/function.py +68 -17
- agno/tools/giphy.py +22 -10
- agno/tools/github.py +48 -126
- agno/tools/gmail.py +46 -62
- agno/tools/google_bigquery.py +7 -6
- agno/tools/google_maps.py +11 -26
- agno/tools/googlesearch.py +7 -2
- agno/tools/googlesheets.py +21 -17
- agno/tools/hackernews.py +9 -5
- agno/tools/jina.py +5 -4
- agno/tools/jira.py +18 -9
- agno/tools/knowledge.py +31 -32
- agno/tools/linear.py +18 -33
- agno/tools/linkup.py +5 -1
- agno/tools/local_file_system.py +8 -5
- agno/tools/lumalab.py +31 -19
- agno/tools/mem0.py +18 -12
- agno/tools/memori.py +14 -10
- agno/tools/mlx_transcribe.py +3 -2
- agno/tools/models/azure_openai.py +32 -14
- agno/tools/models/gemini.py +58 -31
- agno/tools/models/groq.py +29 -20
- agno/tools/models/nebius.py +27 -11
- agno/tools/models_labs.py +39 -15
- agno/tools/moviepy_video.py +7 -6
- agno/tools/neo4j.py +134 -0
- agno/tools/newspaper.py +7 -2
- agno/tools/newspaper4k.py +8 -3
- agno/tools/openai.py +57 -26
- agno/tools/openbb.py +12 -11
- agno/tools/opencv.py +62 -46
- agno/tools/openweather.py +14 -12
- agno/tools/pandas.py +11 -3
- agno/tools/postgres.py +4 -12
- agno/tools/pubmed.py +4 -1
- agno/tools/python.py +9 -22
- agno/tools/reasoning.py +35 -27
- agno/tools/reddit.py +11 -26
- agno/tools/replicate.py +54 -41
- agno/tools/resend.py +4 -1
- agno/tools/scrapegraph.py +15 -14
- agno/tools/searxng.py +10 -23
- agno/tools/serpapi.py +6 -3
- agno/tools/serper.py +13 -4
- agno/tools/shell.py +9 -2
- agno/tools/slack.py +12 -11
- agno/tools/sleep.py +3 -2
- agno/tools/spider.py +24 -4
- agno/tools/sql.py +7 -6
- agno/tools/tavily.py +6 -4
- agno/tools/telegram.py +12 -4
- agno/tools/todoist.py +11 -31
- agno/tools/toolkit.py +1 -1
- agno/tools/trafilatura.py +22 -6
- agno/tools/trello.py +9 -22
- agno/tools/twilio.py +10 -3
- agno/tools/user_control_flow.py +6 -1
- agno/tools/valyu.py +34 -5
- agno/tools/visualization.py +19 -28
- agno/tools/webbrowser.py +4 -3
- agno/tools/webex.py +11 -7
- agno/tools/website.py +15 -46
- agno/tools/webtools.py +12 -4
- agno/tools/whatsapp.py +5 -9
- agno/tools/wikipedia.py +20 -13
- agno/tools/x.py +14 -13
- agno/tools/yfinance.py +13 -40
- agno/tools/youtube.py +26 -20
- agno/tools/zendesk.py +7 -2
- agno/tools/zep.py +10 -7
- agno/tools/zoom.py +10 -9
- agno/utils/common.py +1 -19
- agno/utils/events.py +95 -118
- agno/utils/knowledge.py +29 -0
- agno/utils/location.py +2 -2
- agno/utils/log.py +2 -2
- agno/utils/mcp.py +11 -5
- agno/utils/media.py +39 -0
- agno/utils/message.py +12 -1
- agno/utils/models/claude.py +6 -4
- agno/utils/models/mistral.py +8 -7
- agno/utils/models/schema_utils.py +3 -3
- agno/utils/pprint.py +33 -32
- agno/utils/print_response/agent.py +779 -0
- agno/utils/print_response/team.py +1565 -0
- agno/utils/print_response/workflow.py +1451 -0
- agno/utils/prompts.py +14 -14
- agno/utils/reasoning.py +87 -0
- agno/utils/response.py +42 -42
- agno/utils/string.py +8 -22
- agno/utils/team.py +50 -0
- agno/utils/timer.py +2 -2
- agno/vectordb/base.py +33 -21
- agno/vectordb/cassandra/cassandra.py +287 -23
- agno/vectordb/chroma/chromadb.py +482 -59
- agno/vectordb/clickhouse/clickhousedb.py +270 -63
- agno/vectordb/couchbase/couchbase.py +309 -29
- agno/vectordb/lancedb/lance_db.py +360 -21
- agno/vectordb/langchaindb/__init__.py +5 -0
- agno/vectordb/langchaindb/langchaindb.py +145 -0
- agno/vectordb/lightrag/__init__.py +5 -0
- agno/vectordb/lightrag/lightrag.py +374 -0
- agno/vectordb/llamaindex/llamaindexdb.py +127 -0
- agno/vectordb/milvus/milvus.py +242 -32
- agno/vectordb/mongodb/mongodb.py +200 -24
- agno/vectordb/pgvector/pgvector.py +319 -37
- agno/vectordb/pineconedb/pineconedb.py +221 -27
- agno/vectordb/qdrant/qdrant.py +356 -14
- agno/vectordb/singlestore/singlestore.py +286 -29
- agno/vectordb/surrealdb/surrealdb.py +187 -7
- agno/vectordb/upstashdb/upstashdb.py +342 -26
- agno/vectordb/weaviate/weaviate.py +227 -165
- agno/workflow/__init__.py +17 -13
- agno/workflow/{v2/condition.py → condition.py} +135 -32
- agno/workflow/{v2/loop.py → loop.py} +115 -28
- agno/workflow/{v2/parallel.py → parallel.py} +138 -108
- agno/workflow/{v2/router.py → router.py} +133 -32
- agno/workflow/{v2/step.py → step.py} +200 -42
- agno/workflow/{v2/steps.py → steps.py} +147 -66
- agno/workflow/types.py +482 -0
- agno/workflow/workflow.py +2394 -696
- agno-2.0.0a1.dist-info/METADATA +355 -0
- agno-2.0.0a1.dist-info/RECORD +514 -0
- agno/agent/metrics.py +0 -107
- agno/api/app.py +0 -35
- agno/api/playground.py +0 -92
- agno/api/schemas/app.py +0 -12
- agno/api/schemas/playground.py +0 -22
- agno/api/schemas/user.py +0 -35
- agno/api/schemas/workspace.py +0 -46
- agno/api/user.py +0 -160
- agno/api/workflows.py +0 -33
- agno/api/workspace.py +0 -175
- agno/app/agui/__init__.py +0 -3
- agno/app/agui/app.py +0 -17
- agno/app/agui/sync_router.py +0 -120
- agno/app/base.py +0 -186
- agno/app/discord/__init__.py +0 -3
- agno/app/fastapi/__init__.py +0 -3
- agno/app/fastapi/app.py +0 -107
- agno/app/fastapi/async_router.py +0 -457
- agno/app/fastapi/sync_router.py +0 -448
- agno/app/playground/app.py +0 -228
- agno/app/playground/async_router.py +0 -1050
- agno/app/playground/deploy.py +0 -249
- agno/app/playground/operator.py +0 -183
- agno/app/playground/schemas.py +0 -220
- agno/app/playground/serve.py +0 -55
- agno/app/playground/sync_router.py +0 -1042
- agno/app/playground/utils.py +0 -46
- agno/app/settings.py +0 -15
- agno/app/slack/__init__.py +0 -3
- agno/app/slack/app.py +0 -19
- agno/app/slack/sync_router.py +0 -92
- agno/app/utils.py +0 -54
- agno/app/whatsapp/__init__.py +0 -3
- agno/app/whatsapp/app.py +0 -15
- agno/app/whatsapp/sync_router.py +0 -197
- agno/cli/auth_server.py +0 -249
- agno/cli/config.py +0 -274
- agno/cli/console.py +0 -88
- agno/cli/credentials.py +0 -23
- agno/cli/entrypoint.py +0 -571
- agno/cli/operator.py +0 -357
- agno/cli/settings.py +0 -96
- agno/cli/ws/ws_cli.py +0 -817
- agno/constants.py +0 -13
- agno/document/__init__.py +0 -5
- agno/document/chunking/semantic.py +0 -45
- agno/document/chunking/strategy.py +0 -31
- agno/document/reader/__init__.py +0 -5
- agno/document/reader/base.py +0 -47
- agno/document/reader/docx_reader.py +0 -60
- agno/document/reader/gcs/pdf_reader.py +0 -44
- agno/document/reader/s3/pdf_reader.py +0 -59
- agno/document/reader/s3/text_reader.py +0 -63
- agno/document/reader/url_reader.py +0 -59
- agno/document/reader/youtube_reader.py +0 -58
- agno/embedder/__init__.py +0 -5
- agno/embedder/langdb.py +0 -80
- agno/embedder/mistral.py +0 -82
- agno/embedder/openai.py +0 -78
- agno/file/__init__.py +0 -5
- agno/file/file.py +0 -16
- agno/file/local/csv.py +0 -32
- agno/file/local/txt.py +0 -19
- agno/infra/app.py +0 -240
- agno/infra/base.py +0 -144
- agno/infra/context.py +0 -20
- agno/infra/db_app.py +0 -52
- agno/infra/resource.py +0 -205
- agno/infra/resources.py +0 -55
- agno/knowledge/agent.py +0 -698
- agno/knowledge/arxiv.py +0 -33
- agno/knowledge/combined.py +0 -36
- agno/knowledge/csv.py +0 -144
- agno/knowledge/csv_url.py +0 -124
- agno/knowledge/document.py +0 -223
- agno/knowledge/docx.py +0 -137
- agno/knowledge/firecrawl.py +0 -34
- agno/knowledge/gcs/__init__.py +0 -0
- agno/knowledge/gcs/base.py +0 -39
- agno/knowledge/gcs/pdf.py +0 -125
- agno/knowledge/json.py +0 -137
- agno/knowledge/langchain.py +0 -71
- agno/knowledge/light_rag.py +0 -273
- agno/knowledge/llamaindex.py +0 -66
- agno/knowledge/markdown.py +0 -154
- agno/knowledge/pdf.py +0 -164
- agno/knowledge/pdf_bytes.py +0 -42
- agno/knowledge/pdf_url.py +0 -148
- agno/knowledge/s3/__init__.py +0 -0
- agno/knowledge/s3/base.py +0 -64
- agno/knowledge/s3/pdf.py +0 -33
- agno/knowledge/s3/text.py +0 -34
- agno/knowledge/text.py +0 -141
- agno/knowledge/url.py +0 -46
- agno/knowledge/website.py +0 -179
- agno/knowledge/wikipedia.py +0 -32
- agno/knowledge/youtube.py +0 -35
- agno/memory/agent.py +0 -423
- agno/memory/classifier.py +0 -104
- agno/memory/db/__init__.py +0 -5
- agno/memory/db/base.py +0 -42
- agno/memory/db/mongodb.py +0 -189
- agno/memory/db/postgres.py +0 -203
- agno/memory/db/sqlite.py +0 -193
- agno/memory/memory.py +0 -22
- agno/memory/row.py +0 -36
- agno/memory/summarizer.py +0 -201
- agno/memory/summary.py +0 -19
- agno/memory/team.py +0 -415
- agno/memory/v2/__init__.py +0 -2
- agno/memory/v2/db/__init__.py +0 -1
- agno/memory/v2/db/base.py +0 -42
- agno/memory/v2/db/firestore.py +0 -339
- agno/memory/v2/db/mongodb.py +0 -196
- agno/memory/v2/db/postgres.py +0 -214
- agno/memory/v2/db/redis.py +0 -187
- agno/memory/v2/db/schema.py +0 -54
- agno/memory/v2/db/sqlite.py +0 -209
- agno/memory/v2/manager.py +0 -437
- agno/memory/v2/memory.py +0 -1097
- agno/memory/v2/schema.py +0 -55
- agno/memory/v2/summarizer.py +0 -215
- agno/memory/workflow.py +0 -38
- agno/models/ollama/tools.py +0 -430
- agno/models/qwen/__init__.py +0 -5
- agno/playground/__init__.py +0 -10
- agno/playground/deploy.py +0 -3
- agno/playground/playground.py +0 -3
- agno/playground/serve.py +0 -3
- agno/playground/settings.py +0 -3
- agno/reranker/__init__.py +0 -0
- agno/run/v2/__init__.py +0 -0
- agno/run/v2/workflow.py +0 -567
- agno/storage/__init__.py +0 -0
- agno/storage/agent/__init__.py +0 -0
- agno/storage/agent/dynamodb.py +0 -1
- agno/storage/agent/json.py +0 -1
- agno/storage/agent/mongodb.py +0 -1
- agno/storage/agent/postgres.py +0 -1
- agno/storage/agent/singlestore.py +0 -1
- agno/storage/agent/sqlite.py +0 -1
- agno/storage/agent/yaml.py +0 -1
- agno/storage/base.py +0 -60
- agno/storage/dynamodb.py +0 -673
- agno/storage/firestore.py +0 -297
- agno/storage/gcs_json.py +0 -261
- agno/storage/in_memory.py +0 -234
- agno/storage/json.py +0 -237
- agno/storage/mongodb.py +0 -328
- agno/storage/mysql.py +0 -685
- agno/storage/postgres.py +0 -682
- agno/storage/redis.py +0 -336
- agno/storage/session/__init__.py +0 -16
- agno/storage/session/agent.py +0 -64
- agno/storage/session/team.py +0 -63
- agno/storage/session/v2/__init__.py +0 -5
- agno/storage/session/workflow.py +0 -61
- agno/storage/singlestore.py +0 -606
- agno/storage/sqlite.py +0 -646
- agno/storage/workflow/__init__.py +0 -0
- agno/storage/workflow/mongodb.py +0 -1
- agno/storage/workflow/postgres.py +0 -1
- agno/storage/workflow/sqlite.py +0 -1
- agno/storage/yaml.py +0 -241
- agno/tools/thinking.py +0 -73
- agno/utils/defaults.py +0 -57
- agno/utils/filesystem.py +0 -39
- agno/utils/git.py +0 -52
- agno/utils/json_io.py +0 -30
- agno/utils/load_env.py +0 -19
- agno/utils/py_io.py +0 -19
- agno/utils/pyproject.py +0 -18
- agno/utils/resource_filter.py +0 -31
- agno/workflow/v2/__init__.py +0 -21
- agno/workflow/v2/types.py +0 -357
- agno/workflow/v2/workflow.py +0 -3312
- agno/workspace/__init__.py +0 -0
- agno/workspace/config.py +0 -325
- agno/workspace/enums.py +0 -6
- agno/workspace/helpers.py +0 -52
- agno/workspace/operator.py +0 -757
- agno/workspace/settings.py +0 -158
- agno-1.8.0.dist-info/METADATA +0 -979
- agno-1.8.0.dist-info/RECORD +0 -565
- agno-1.8.0.dist-info/entry_points.txt +0 -3
- /agno/{app → db/migrations}/__init__.py +0 -0
- /agno/{app/playground/__init__.py → db/schemas/metrics.py} +0 -0
- /agno/{cli → integrations}/__init__.py +0 -0
- /agno/{cli/ws → knowledge/chunking}/__init__.py +0 -0
- /agno/{document/chunking → knowledge/remote_content}/__init__.py +0 -0
- /agno/{document/reader/gcs → knowledge/reranker}/__init__.py +0 -0
- /agno/{document/reader/s3 → os/interfaces}/__init__.py +0 -0
- /agno/{app → os/interfaces}/slack/security.py +0 -0
- /agno/{app → os/interfaces}/whatsapp/security.py +0 -0
- /agno/{file/local → utils/print_response}/__init__.py +0 -0
- /agno/{infra → vectordb/llamaindex}/__init__.py +0 -0
- {agno-1.8.0.dist-info → agno-2.0.0a1.dist-info}/WHEEL +0 -0
- {agno-1.8.0.dist-info → agno-2.0.0a1.dist-info}/licenses/LICENSE +0 -0
- {agno-1.8.0.dist-info → agno-2.0.0a1.dist-info}/top_level.txt +0 -0
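
The rename entries above mean that most v1 import paths change in v2. As a rough orientation only — module paths are taken from the file moves listed above, but class names such as `PostgresDb` are inferred from the new file names and should be verified against the 2.0.0a1 API:

```python
# agno 1.8.0 (v1) imports, per the removed paths above:
# from agno.embedder.openai import OpenAIEmbedder
# from agno.document.chunking.fixed import FixedSizeChunking
# from agno.storage.postgres import PostgresStorage

# agno 2.0.0a1 (v2) equivalents, per the moved/added paths above:
from agno.knowledge.embedder.openai import OpenAIEmbedder  # embedder/ -> knowledge/embedder/
from agno.knowledge.chunking.fixed import FixedSizeChunking  # document/chunking/ -> knowledge/chunking/
from agno.db.postgres import PostgresDb  # storage/postgres.py -> db/postgres/ (class name inferred)
```

The largest single addition is the new Firestore database backend (`agno/db/firestore/firestore.py`, +1438 lines), whose diff follows.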
|
@@ -0,0 +1,1438 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from datetime import date, datetime, timedelta, timezone
|
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
6
|
+
from agno.db.base import BaseDb, SessionType
|
|
7
|
+
from agno.db.firestore.utils import (
|
|
8
|
+
apply_pagination,
|
|
9
|
+
apply_sorting,
|
|
10
|
+
bulk_upsert_metrics,
|
|
11
|
+
calculate_date_metrics,
|
|
12
|
+
create_collection_indexes,
|
|
13
|
+
fetch_all_sessions_data,
|
|
14
|
+
get_dates_to_calculate_metrics_for,
|
|
15
|
+
)
|
|
16
|
+
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
17
|
+
from agno.db.schemas.knowledge import KnowledgeRow
|
|
18
|
+
from agno.db.schemas.memory import UserMemory
|
|
19
|
+
from agno.db.utils import deserialize_session_json_fields, serialize_session_json_fields
|
|
20
|
+
from agno.session import AgentSession, Session, TeamSession, WorkflowSession
|
|
21
|
+
from agno.utils.log import log_debug, log_error, log_info
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from google.cloud.firestore import Client, FieldFilter # type: ignore[import-untyped]
|
|
25
|
+
except ImportError:
|
|
26
|
+
raise ImportError(
|
|
27
|
+
"`google-cloud-firestore` not installed. Please install it using `pip install google-cloud-firestore`"
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FirestoreDb(BaseDb):
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
db_client: Optional[Client] = None,
|
|
35
|
+
project_id: Optional[str] = None,
|
|
36
|
+
session_collection: Optional[str] = None,
|
|
37
|
+
memory_collection: Optional[str] = None,
|
|
38
|
+
metrics_collection: Optional[str] = None,
|
|
39
|
+
eval_collection: Optional[str] = None,
|
|
40
|
+
knowledge_collection: Optional[str] = None,
|
|
41
|
+
):
|
|
42
|
+
"""
|
|
43
|
+
Interface for interacting with a Firestore database.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
db_client (Optional[Client]): The Firestore client to use.
|
|
47
|
+
project_id (Optional[str]): The GCP project ID for Firestore.
|
|
48
|
+
session_collection (Optional[str]): Name of the collection to store sessions.
|
|
49
|
+
memory_collection (Optional[str]): Name of the collection to store memories.
|
|
50
|
+
metrics_collection (Optional[str]): Name of the collection to store metrics.
|
|
51
|
+
eval_collection (Optional[str]): Name of the collection to store evaluation runs.
|
|
52
|
+
knowledge_collection (Optional[str]): Name of the collection to store knowledge documents.
|
|
53
|
+
|
|
54
|
+
Raises:
|
|
55
|
+
ValueError: If neither project_id nor db_client is provided.
|
|
56
|
+
"""
|
|
57
|
+
super().__init__(
|
|
58
|
+
session_table=session_collection,
|
|
59
|
+
memory_table=memory_collection,
|
|
60
|
+
metrics_table=metrics_collection,
|
|
61
|
+
eval_table=eval_collection,
|
|
62
|
+
knowledge_table=knowledge_collection,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
_client: Optional[Client] = db_client
|
|
66
|
+
if _client is None and project_id is not None:
|
|
67
|
+
_client = Client(project=project_id)
|
|
68
|
+
if _client is None:
|
|
69
|
+
raise ValueError("One of project_id or db_client must be provided")
|
|
70
|
+
|
|
71
|
+
self.project_id: Optional[str] = project_id
|
|
72
|
+
self.db_client: Client = _client
|
|
73
|
+
|
|
74
|
+
# -- DB methods --
|
|
75
|
+
|
|
76
|
+
def _get_collection(self, table_type: str, create_collection_if_not_found: Optional[bool] = True):
|
|
77
|
+
"""Get or create a collection based on table type.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
table_type (str): The type of table to get or create.
|
|
81
|
+
create_collection_if_not_found (Optional[bool]): Whether to create the collection if it doesn't exist.
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
CollectionReference: The collection reference.
|
|
85
|
+
"""
|
|
86
|
+
if table_type == "sessions":
|
|
87
|
+
if not hasattr(self, "session_collection"):
|
|
88
|
+
if self.session_table_name is None:
|
|
89
|
+
raise ValueError("Session collection was not provided on initialization")
|
|
90
|
+
self.session_collection = self._get_or_create_collection(
|
|
91
|
+
collection_name=self.session_table_name,
|
|
92
|
+
collection_type="sessions",
|
|
93
|
+
create_collection_if_not_found=create_collection_if_not_found,
|
|
94
|
+
)
|
|
95
|
+
return self.session_collection
|
|
96
|
+
|
|
97
|
+
if table_type == "memories":
|
|
98
|
+
if not hasattr(self, "memory_collection"):
|
|
99
|
+
if self.memory_table_name is None:
|
|
100
|
+
raise ValueError("Memory collection was not provided on initialization")
|
|
101
|
+
self.memory_collection = self._get_or_create_collection(
|
|
102
|
+
collection_name=self.memory_table_name,
|
|
103
|
+
collection_type="memories",
|
|
104
|
+
create_collection_if_not_found=create_collection_if_not_found,
|
|
105
|
+
)
|
|
106
|
+
return self.memory_collection
|
|
107
|
+
|
|
108
|
+
if table_type == "metrics":
|
|
109
|
+
if not hasattr(self, "metrics_collection"):
|
|
110
|
+
if self.metrics_table_name is None:
|
|
111
|
+
raise ValueError("Metrics collection was not provided on initialization")
|
|
112
|
+
self.metrics_collection = self._get_or_create_collection(
|
|
113
|
+
collection_name=self.metrics_table_name,
|
|
114
|
+
collection_type="metrics",
|
|
115
|
+
create_collection_if_not_found=create_collection_if_not_found,
|
|
116
|
+
)
|
|
117
|
+
return self.metrics_collection
|
|
118
|
+
|
|
119
|
+
if table_type == "evals":
|
|
120
|
+
if not hasattr(self, "eval_collection"):
|
|
121
|
+
if self.eval_table_name is None:
|
|
122
|
+
raise ValueError("Eval collection was not provided on initialization")
|
|
123
|
+
self.eval_collection = self._get_or_create_collection(
|
|
124
|
+
collection_name=self.eval_table_name,
|
|
125
|
+
collection_type="evals",
|
|
126
|
+
create_collection_if_not_found=create_collection_if_not_found,
|
|
127
|
+
)
|
|
128
|
+
return self.eval_collection
|
|
129
|
+
|
|
130
|
+
if table_type == "knowledge":
|
|
131
|
+
if not hasattr(self, "knowledge_collection"):
|
|
132
|
+
if self.knowledge_table_name is None:
|
|
133
|
+
raise ValueError("Knowledge collection was not provided on initialization")
|
|
134
|
+
self.knowledge_collection = self._get_or_create_collection(
|
|
135
|
+
collection_name=self.knowledge_table_name,
|
|
136
|
+
collection_type="knowledge",
|
|
137
|
+
create_collection_if_not_found=create_collection_if_not_found,
|
|
138
|
+
)
|
|
139
|
+
return self.knowledge_collection
|
|
140
|
+
|
|
141
|
+
raise ValueError(f"Unknown table type: {table_type}")
|
|
142
|
+
|
|
143
|
+
def _get_or_create_collection(
|
|
144
|
+
self, collection_name: str, collection_type: str, create_collection_if_not_found: Optional[bool] = True
|
|
145
|
+
):
|
|
146
|
+
"""Get or create a collection with proper indexes.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
collection_name (str): The name of the collection to get or create.
|
|
150
|
+
collection_type (str): The type of collection to get or create.
|
|
151
|
+
create_collection_if_not_found (Optional[bool]): Whether to create the collection if it doesn't exist.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
Optional[CollectionReference]: The collection reference.
|
|
155
|
+
"""
|
|
156
|
+
try:
|
|
157
|
+
collection_ref = self.db_client.collection(collection_name)
|
|
158
|
+
|
|
159
|
+
if not hasattr(self, f"_{collection_name}_initialized"):
|
|
160
|
+
if not create_collection_if_not_found:
|
|
161
|
+
return None
|
|
162
|
+
create_collection_indexes(self.db_client, collection_name, collection_type)
|
|
163
|
+
setattr(self, f"_{collection_name}_initialized", True)
|
|
164
|
+
|
|
165
|
+
return collection_ref
|
|
166
|
+
|
|
167
|
+
except Exception as e:
|
|
168
|
+
log_error(f"Error getting collection {collection_name}: {e}")
|
|
169
|
+
raise
|
|
170
|
+
|
|
171
|
+
# -- Session methods --
|
|
172
|
+
|
|
173
|
+
def delete_session(self, session_id: str) -> bool:
|
|
174
|
+
"""Delete a session from the database.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
session_id (str): The ID of the session to delete.
|
|
178
|
+
session_type (SessionType): The type of session to delete. Defaults to SessionType.AGENT.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
bool: True if the session was deleted, False otherwise.
|
|
182
|
+
|
|
183
|
+
Raises:
|
|
184
|
+
Exception: If there is an error deleting the session.
|
|
185
|
+
"""
|
|
186
|
+
try:
|
|
187
|
+
collection_ref = self._get_collection(table_type="sessions")
|
|
188
|
+
docs = collection_ref.where(filter=FieldFilter("session_id", "==", session_id)).stream()
|
|
189
|
+
|
|
190
|
+
for doc in docs:
|
|
191
|
+
doc.reference.delete()
|
|
192
|
+
log_debug(f"Successfully deleted session with session_id: {session_id}")
|
|
193
|
+
return True
|
|
194
|
+
|
|
195
|
+
log_debug(f"No session found to delete with session_id: {session_id}")
|
|
196
|
+
return False
|
|
197
|
+
|
|
198
|
+
except Exception as e:
|
|
199
|
+
log_error(f"Error deleting session: {e}")
|
|
200
|
+
return False
|
|
201
|
+
|
|
202
|
+
def delete_sessions(self, session_ids: List[str]) -> None:
|
|
203
|
+
"""Delete multiple sessions from the database.
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
session_ids (List[str]): The IDs of the sessions to delete.
|
|
207
|
+
"""
|
|
208
|
+
try:
|
|
209
|
+
collection_ref = self._get_collection(table_type="sessions")
|
|
210
|
+
batch = self.db_client.batch()
|
|
211
|
+
|
|
212
|
+
deleted_count = 0
|
|
213
|
+
for session_id in session_ids:
|
|
214
|
+
docs = collection_ref.where(filter=FieldFilter("session_id", "==", session_id)).stream()
|
|
215
|
+
for doc in docs:
|
|
216
|
+
batch.delete(doc.reference)
|
|
217
|
+
deleted_count += 1
|
|
218
|
+
|
|
219
|
+
batch.commit()
|
|
220
|
+
|
|
221
|
+
log_debug(f"Successfully deleted {deleted_count} sessions")
|
|
222
|
+
|
|
223
|
+
except Exception as e:
|
|
224
|
+
log_error(f"Error deleting sessions: {e}")
|
|
225
|
+
|
|
226
|
+
def get_session(
|
|
227
|
+
self,
|
|
228
|
+
session_id: str,
|
|
229
|
+
session_type: SessionType,
|
|
230
|
+
user_id: Optional[str] = None,
|
|
231
|
+
deserialize: Optional[bool] = True,
|
|
232
|
+
) -> Optional[Union[Session, Dict[str, Any]]]:
|
|
233
|
+
"""Read a session from the database.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
session_id (str): The ID of the session to get.
|
|
237
|
+
user_id (Optional[str]): The ID of the user to get the session for.
|
|
238
|
+
session_type (Optional[SessionType]): The type of session to get.
|
|
239
|
+
deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
Union[Session, Dict[str, Any], None]:
|
|
243
|
+
- When deserialize=True: Session object
|
|
244
|
+
- When deserialize=False: Session dictionary
|
|
245
|
+
|
|
246
|
+
Raises:
|
|
247
|
+
Exception: If there is an error reading the session.
|
|
248
|
+
"""
|
|
249
|
+
try:
|
|
250
|
+
collection_ref = self._get_collection(table_type="sessions")
|
|
251
|
+
query = collection_ref.where(filter=FieldFilter("session_id", "==", session_id))
|
|
252
|
+
|
|
253
|
+
if user_id is not None:
|
|
254
|
+
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
|
255
|
+
if session_type is not None:
|
|
256
|
+
query = query.where(filter=FieldFilter("session_type", "==", session_type.value))
|
|
257
|
+
|
|
258
|
+
docs = query.stream()
|
|
259
|
+
result = None
|
|
260
|
+
for doc in docs:
|
|
261
|
+
result = doc.to_dict()
|
|
262
|
+
break
|
|
263
|
+
|
|
264
|
+
if result is None:
|
|
265
|
+
return None
|
|
266
|
+
|
|
267
|
+
session = deserialize_session_json_fields(result)
|
|
268
|
+
|
|
269
|
+
if not deserialize:
|
|
270
|
+
return session
|
|
271
|
+
|
|
272
|
+
if session_type == SessionType.AGENT:
|
|
273
|
+
return AgentSession.from_dict(session)
|
|
274
|
+
elif session_type == SessionType.TEAM:
|
|
275
|
+
return TeamSession.from_dict(session)
|
|
276
|
+
else:
|
|
277
|
+
return WorkflowSession.from_dict(session)
|
|
278
|
+
|
|
279
|
+
except Exception as e:
|
|
280
|
+
log_error(f"Exception reading session: {e}")
|
|
281
|
+
return None
|
|
282
|
+
|
|
283
|
+
def get_sessions(
|
|
284
|
+
self,
|
|
285
|
+
session_type: Optional[SessionType] = None,
|
|
286
|
+
user_id: Optional[str] = None,
|
|
287
|
+
component_id: Optional[str] = None,
|
|
288
|
+
session_name: Optional[str] = None,
|
|
289
|
+
start_timestamp: Optional[int] = None,
|
|
290
|
+
end_timestamp: Optional[int] = None,
|
|
291
|
+
limit: Optional[int] = None,
|
|
292
|
+
page: Optional[int] = None,
|
|
293
|
+
sort_by: Optional[str] = None,
|
|
294
|
+
sort_order: Optional[str] = None,
|
|
295
|
+
deserialize: Optional[bool] = True,
|
|
296
|
+
) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
|
|
297
|
+
"""Get all sessions.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
session_type (Optional[SessionType]): The type of session to get.
|
|
301
|
+
user_id (Optional[str]): The ID of the user to get the session for.
|
|
302
|
+
component_id (Optional[str]): The ID of the component to get the session for.
|
|
303
|
+
session_name (Optional[str]): The name of the session to filter by.
|
|
304
|
+
start_timestamp (Optional[int]): The start timestamp to filter sessions by.
|
|
305
|
+
end_timestamp (Optional[int]): The end timestamp to filter sessions by.
|
|
306
|
+
limit (Optional[int]): The limit of the sessions to get.
|
|
307
|
+
page (Optional[int]): The page number to get.
|
|
308
|
+
sort_by (Optional[str]): The field to sort the sessions by.
|
|
309
|
+
sort_order (Optional[str]): The order to sort the sessions by.
|
|
310
|
+
deserialize (Optional[bool]): Whether to serialize the sessions. Defaults to True.
|
|
311
|
+
|
|
312
|
+
Returns:
|
|
313
|
+
Union[List[AgentSession], List[TeamSession], List[WorkflowSession], Tuple[List[Dict[str, Any]], int]]:
|
|
314
|
+
- When deserialize=True: List of Session objects
|
|
315
|
+
- When deserialize=False: List of session dictionaries and the total count
|
|
316
|
+
|
|
317
|
+
Raises:
|
|
318
|
+
Exception: If there is an error reading the sessions.
|
|
319
|
+
"""
|
|
320
|
+
try:
|
|
321
|
+
collection_ref = self._get_collection(table_type="sessions")
|
|
322
|
+
if collection_ref is None:
|
|
323
|
+
return [] if deserialize else ([], 0)
|
|
324
|
+
|
|
325
|
+
query = collection_ref
|
|
326
|
+
|
|
327
|
+
if user_id is not None:
|
|
328
|
+
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
|
329
|
+
if session_type is not None:
|
|
330
|
+
query = query.where(filter=FieldFilter("session_type", "==", session_type.value))
|
|
331
|
+
if component_id is not None:
|
|
332
|
+
if session_type == SessionType.AGENT:
|
|
333
|
+
query = query.where(filter=FieldFilter("agent_id", "==", component_id))
|
|
334
|
+
elif session_type == SessionType.TEAM:
|
|
335
|
+
query = query.where(filter=FieldFilter("team_id", "==", component_id))
|
|
336
|
+
elif session_type == SessionType.WORKFLOW:
|
|
337
|
+
query = query.where(filter=FieldFilter("workflow_id", "==", component_id))
|
|
338
|
+
if start_timestamp is not None:
|
|
339
|
+
query = query.where(filter=FieldFilter("created_at", ">=", start_timestamp))
|
|
340
|
+
if end_timestamp is not None:
|
|
341
|
+
query = query.where(filter=FieldFilter("created_at", "<=", end_timestamp))
|
|
342
|
+
if session_name is not None:
|
|
343
|
+
query = query.where(filter=FieldFilter("session_data.session_name", "==", session_name))
|
|
344
|
+
|
|
345
|
+
# Apply sorting
|
|
346
|
+
query = apply_sorting(query, sort_by, sort_order)
|
|
347
|
+
|
|
348
|
+
# Get all documents for counting before pagination
|
|
349
|
+
all_docs = query.stream()
|
|
350
|
+
all_records = [doc.to_dict() for doc in all_docs]
|
|
351
|
+
|
|
352
|
+
if not all_records:
|
|
353
|
+
return [] if deserialize else ([], 0)
|
|
354
|
+
|
|
355
|
+
all_sessions_raw = [deserialize_session_json_fields(record) for record in all_records]
|
|
356
|
+
|
|
357
|
+
# Get total count before pagination
|
|
358
|
+
total_count = len(all_sessions_raw)
|
|
359
|
+
|
|
360
|
+
# Apply pagination to the results
|
|
361
|
+
if limit is not None and page is not None:
|
|
362
|
+
start_index = (page - 1) * limit
|
|
363
|
+
end_index = start_index + limit
|
|
364
|
+
sessions_raw = all_sessions_raw[start_index:end_index]
|
|
365
|
+
elif limit is not None:
|
|
366
|
+
sessions_raw = all_sessions_raw[:limit]
|
|
367
|
+
else:
|
|
368
|
+
sessions_raw = all_sessions_raw
|
|
369
|
+
|
|
370
|
+
if not deserialize:
|
|
371
|
+
return sessions_raw, total_count
|
|
372
|
+
|
|
373
|
+
sessions: List[Union[AgentSession, TeamSession, WorkflowSession]] = []
|
|
374
|
+
for session in sessions_raw:
|
|
375
|
+
if session["session_type"] == SessionType.AGENT.value:
|
|
376
|
+
agent_session = AgentSession.from_dict(session)
|
|
377
|
+
if agent_session is not None:
|
|
378
|
+
sessions.append(agent_session)
|
|
379
|
+
elif session["session_type"] == SessionType.TEAM.value:
|
|
380
|
+
team_session = TeamSession.from_dict(session)
|
|
381
|
+
if team_session is not None:
|
|
382
|
+
sessions.append(team_session)
|
|
383
|
+
elif session["session_type"] == SessionType.WORKFLOW.value:
|
|
384
|
+
workflow_session = WorkflowSession.from_dict(session)
|
|
385
|
+
if workflow_session is not None:
|
|
386
|
+
sessions.append(workflow_session)
|
|
387
|
+
|
|
388
|
+
if not sessions:
|
|
389
|
+
return [] if deserialize else ([], 0)
|
|
390
|
+
|
|
391
|
+
return sessions
|
|
392
|
+
|
|
393
|
+
except Exception as e:
|
|
394
|
+
log_error(f"Exception reading sessions: {e}")
|
|
395
|
+
return [] if deserialize else ([], 0)
|
|
396
|
+
|
|
397
|
+
def rename_session(
|
|
398
|
+
self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
|
|
399
|
+
) -> Optional[Union[Session, Dict[str, Any]]]:
|
|
400
|
+
"""Rename a session in the database.
|
|
401
|
+
|
|
402
|
+
Args:
|
|
403
|
+
session_id (str): The ID of the session to rename.
|
|
404
|
+
session_type (SessionType): The type of session to rename.
|
|
405
|
+
session_name (str): The new name of the session.
|
|
406
|
+
deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.
|
|
407
|
+
|
|
408
|
+
Returns:
|
|
409
|
+
Optional[Union[Session, Dict[str, Any]]]:
|
|
410
|
+
- When deserialize=True: Session object
|
|
411
|
+
- When deserialize=False: Session dictionary
|
|
412
|
+
|
|
413
|
+
Raises:
|
|
414
|
+
Exception: If there is an error renaming the session.
|
|
415
|
+
"""
|
|
416
|
+
try:
|
|
417
|
+
collection_ref = self._get_collection(table_type="sessions")
|
|
418
|
+
|
|
419
|
+
docs = collection_ref.where(filter=FieldFilter("session_id", "==", session_id)).stream()
|
|
420
|
+
doc_ref = next((doc.reference for doc in docs), None)
|
|
421
|
+
|
|
422
|
+
if doc_ref is None:
|
|
423
|
+
return None
|
|
424
|
+
|
|
425
|
+
doc_ref.update({"session_data.session_name": session_name, "updated_at": int(time.time())})
|
|
426
|
+
|
|
427
|
+
updated_doc = doc_ref.get()
|
|
428
|
+
if not updated_doc.exists:
|
|
429
|
+
return None
|
|
430
|
+
|
|
431
|
+
result = updated_doc.to_dict()
|
|
432
|
+
if result is None:
|
|
433
|
+
return None
|
|
434
|
+
deserialized_session = deserialize_session_json_fields(result)
|
|
435
|
+
|
|
436
|
+
log_debug(f"Renamed session with id '{session_id}' to '{session_name}'")
|
|
437
|
+
|
|
438
|
+
if not deserialize:
|
|
439
|
+
return deserialized_session
|
|
440
|
+
|
|
441
|
+
if session_type == SessionType.AGENT:
|
|
442
|
+
return AgentSession.from_dict(deserialized_session)
|
|
443
|
+
elif session_type == SessionType.TEAM:
|
|
444
|
+
return TeamSession.from_dict(deserialized_session)
|
|
445
|
+
else:
|
|
446
|
+
return WorkflowSession.from_dict(deserialized_session)
|
|
447
|
+
|
|
448
|
+
except Exception as e:
|
|
449
|
+
log_error(f"Exception renaming session: {e}")
|
|
450
|
+
return None
|
|
451
|
+
|
|
452
|
+
def upsert_session(
|
|
453
|
+
self, session: Session, deserialize: Optional[bool] = True
|
|
454
|
+
) -> Optional[Union[Session, Dict[str, Any]]]:
|
|
455
|
+
"""Insert or update a session in the database.
|
|
456
|
+
|
|
457
|
+
Args:
|
|
458
|
+
session (Session): The session to upsert.
|
|
459
|
+
|
|
460
|
+
Returns:
|
|
461
|
+
Optional[Session]: The upserted session.
|
|
462
|
+
|
|
463
|
+
Raises:
|
|
464
|
+
Exception: If there is an error upserting the session.
|
|
465
|
+
"""
|
|
466
|
+
try:
|
|
467
|
+
collection_ref = self._get_collection(table_type="sessions", create_collection_if_not_found=True)
|
|
468
|
+
serialized_session_dict = serialize_session_json_fields(session.to_dict())
|
|
469
|
+
|
|
470
|
+
if isinstance(session, AgentSession):
|
|
471
|
+
record = {
|
|
472
|
+
"session_id": serialized_session_dict.get("session_id"),
|
|
473
|
+
"session_type": SessionType.AGENT.value,
|
|
474
|
+
"agent_id": serialized_session_dict.get("agent_id"),
|
|
475
|
+
"user_id": serialized_session_dict.get("user_id"),
|
|
476
|
+
"runs": serialized_session_dict.get("runs"),
|
|
477
|
+
"agent_data": serialized_session_dict.get("agent_data"),
|
|
478
|
+
"session_data": serialized_session_dict.get("session_data"),
|
|
479
|
+
"summary": serialized_session_dict.get("summary"),
|
|
480
|
+
"metadata": serialized_session_dict.get("metadata"),
|
|
481
|
+
"created_at": serialized_session_dict.get("created_at"),
|
|
482
|
+
"updated_at": int(time.time()),
|
|
483
|
+
}
|
|
484
|
+
|
|
485
|
+
elif isinstance(session, TeamSession):
|
|
486
|
+
record = {
|
|
487
|
+
"session_id": serialized_session_dict.get("session_id"),
|
|
488
|
+
"session_type": SessionType.TEAM.value,
|
|
489
|
+
"team_id": serialized_session_dict.get("team_id"),
|
|
490
|
+
"user_id": serialized_session_dict.get("user_id"),
|
|
491
|
+
"runs": serialized_session_dict.get("runs"),
|
|
492
|
+
"team_data": serialized_session_dict.get("team_data"),
|
|
493
|
+
"session_data": serialized_session_dict.get("session_data"),
|
|
494
|
+
"summary": serialized_session_dict.get("summary"),
|
|
495
|
+
"metadata": serialized_session_dict.get("metadata"),
|
|
496
|
+
"created_at": serialized_session_dict.get("created_at"),
|
|
497
|
+
"updated_at": int(time.time()),
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
elif isinstance(session, WorkflowSession):
|
|
501
|
+
record = {
|
|
502
|
+
"session_id": serialized_session_dict.get("session_id"),
|
|
503
|
+
"session_type": SessionType.WORKFLOW.value,
|
|
504
|
+
"workflow_id": serialized_session_dict.get("workflow_id"),
|
|
505
|
+
"user_id": serialized_session_dict.get("user_id"),
|
|
506
|
+
"runs": serialized_session_dict.get("runs"),
|
|
507
|
+
"workflow_data": serialized_session_dict.get("workflow_data"),
|
|
508
|
+
"session_data": serialized_session_dict.get("session_data"),
|
|
509
|
+
"summary": serialized_session_dict.get("summary"),
|
|
510
|
+
"metadata": serialized_session_dict.get("metadata"),
|
|
511
|
+
"created_at": serialized_session_dict.get("created_at"),
|
|
512
|
+
"updated_at": int(time.time()),
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
# Find existing document or create new one
|
|
516
|
+
docs = collection_ref.where(filter=FieldFilter("session_id", "==", record["session_id"])).stream()
|
|
517
|
+
doc_ref = next((doc.reference for doc in docs), None)
|
|
518
|
+
|
|
519
|
+
if doc_ref is None:
|
|
520
|
+
# Create new document
|
|
521
|
+
doc_ref = collection_ref.document()
|
|
522
|
+
|
|
523
|
+
doc_ref.set(record, merge=True)
|
|
524
|
+
|
|
525
|
+
# Get the updated document
|
|
526
|
+
updated_doc = doc_ref.get()
|
|
527
|
+
if not updated_doc.exists:
|
|
528
|
+
return None
|
|
529
|
+
|
|
530
|
+
result = updated_doc.to_dict()
|
|
531
|
+
if result is None:
|
|
532
|
+
return None
|
|
533
|
+
deserialized_session = deserialize_session_json_fields(result)
|
|
534
|
+
|
|
535
|
+
log_debug(f"Upserted session with id '{session.session_id}'")
|
|
536
|
+
|
|
537
|
+
if not deserialize:
|
|
538
|
+
return deserialized_session
|
|
539
|
+
|
|
540
|
+
if isinstance(session, AgentSession):
|
|
541
|
+
return AgentSession.from_dict(deserialized_session)
|
|
542
|
+
elif isinstance(session, TeamSession):
|
|
543
|
+
return TeamSession.from_dict(deserialized_session)
|
|
544
|
+
else:
|
|
545
|
+
return WorkflowSession.from_dict(deserialized_session)
|
|
546
|
+
|
|
547
|
+
except Exception as e:
|
|
548
|
+
log_error(f"Exception upserting session: {e}")
|
|
549
|
+
return None
|
|
550
|
+
|
|
551
|
+
# -- Memory methods --
|
|
552
|
+
|
|
553
|
+
def delete_user_memory(self, memory_id: str):
|
|
554
|
+
"""Delete a user memory from the database.
|
|
555
|
+
|
|
556
|
+
Args:
|
|
557
|
+
memory_id (str): The ID of the memory to delete.
|
|
558
|
+
|
|
559
|
+
Returns:
|
|
560
|
+
bool: True if the memory was deleted, False otherwise.
|
|
561
|
+
|
|
562
|
+
Raises:
|
|
563
|
+
Exception: If there is an error deleting the memory.
|
|
564
|
+
"""
|
|
565
|
+
try:
|
|
566
|
+
collection_ref = self._get_collection(table_type="memories")
|
|
567
|
+
docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
|
|
568
|
+
|
|
569
|
+
deleted_count = 0
|
|
570
|
+
for doc in docs:
|
|
571
|
+
doc.reference.delete()
|
|
572
|
+
deleted_count += 1
|
|
573
|
+
|
|
574
|
+
success = deleted_count > 0
|
|
575
|
+
if success:
|
|
576
|
+
log_debug(f"Successfully deleted user memory id: {memory_id}")
|
|
577
|
+
else:
|
|
578
|
+
log_debug(f"No user memory found with id: {memory_id}")
|
|
579
|
+
|
|
580
|
+
except Exception as e:
|
|
581
|
+
log_error(f"Error deleting user memory: {e}")
|
|
582
|
+
|
|
583
|
+
def delete_user_memories(self, memory_ids: List[str]) -> None:
|
|
584
|
+
"""Delete user memories from the database.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
memory_ids (List[str]): The IDs of the memories to delete.
|
|
588
|
+
|
|
589
|
+
Raises:
|
|
590
|
+
Exception: If there is an error deleting the memories.
|
|
591
|
+
"""
|
|
592
|
+
try:
|
|
593
|
+
collection_ref = self._get_collection(table_type="memories")
|
|
594
|
+
batch = self.db_client.batch()
|
|
595
|
+
deleted_count = 0
|
|
596
|
+
|
|
597
|
+
for memory_id in memory_ids:
|
|
598
|
+
docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
|
|
599
|
+
for doc in docs:
|
|
600
|
+
batch.delete(doc.reference)
|
|
601
|
+
deleted_count += 1
|
|
602
|
+
|
|
603
|
+
batch.commit()
|
|
604
|
+
|
|
605
|
+
if deleted_count == 0:
|
|
606
|
+
log_info(f"No memories found with ids: {memory_ids}")
|
|
607
|
+
else:
|
|
608
|
+
log_info(f"Successfully deleted {deleted_count} memories")
|
|
609
|
+
|
|
610
|
+
except Exception as e:
|
|
611
|
+
log_error(f"Error deleting memories: {e}")
|
|
612
|
+
|
|
613
|
+
def get_all_memory_topics(self, create_collection_if_not_found: Optional[bool] = True) -> List[str]:
|
|
614
|
+
"""Get all memory topics from the database.
|
|
615
|
+
|
|
616
|
+
Returns:
|
|
617
|
+
List[str]: The topics.
|
|
618
|
+
|
|
619
|
+
Raises:
|
|
620
|
+
Exception: If there is an error getting the topics.
|
|
621
|
+
"""
|
|
622
|
+
try:
|
|
623
|
+
collection_ref = self._get_collection(table_type="memories")
|
|
624
|
+
if collection_ref is None:
|
|
625
|
+
return []
|
|
626
|
+
|
|
627
|
+
docs = collection_ref.stream()
|
|
628
|
+
|
|
629
|
+
all_topics = set()
|
|
630
|
+
for doc in docs:
|
|
631
|
+
data = doc.to_dict()
|
|
632
|
+
topics = data.get("topics", [])
|
|
633
|
+
if topics:
|
|
634
|
+
all_topics.update(topics)
|
|
635
|
+
|
|
636
|
+
return [topic for topic in all_topics if topic]
|
|
637
|
+
|
|
638
|
+
except Exception as e:
|
|
639
|
+
log_error(f"Exception reading from collection: {e}")
|
|
640
|
+
return []
|
|
641
|
+
|
|
642
|
+
def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
|
|
643
|
+
"""Get a memory from the database.
|
|
644
|
+
|
|
645
|
+
Args:
|
|
646
|
+
memory_id (str): The ID of the memory to get.
|
|
647
|
+
deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
|
|
648
|
+
|
|
649
|
+
Returns:
|
|
650
|
+
Optional[UserMemory]:
|
|
651
|
+
- When deserialize=True: UserMemory object
|
|
652
|
+
- When deserialize=False: Memory dictionary
|
|
653
|
+
|
|
654
|
+
Raises:
|
|
655
|
+
Exception: If there is an error getting the memory.
|
|
656
|
+
"""
|
|
657
|
+
try:
|
|
658
|
+
collection_ref = self._get_collection(table_type="memories")
|
|
659
|
+
docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
|
|
660
|
+
|
|
661
|
+
result = None
|
|
662
|
+
for doc in docs:
|
|
663
|
+
result = doc.to_dict()
|
|
664
|
+
break
|
|
665
|
+
|
|
666
|
+
if result is None or not deserialize:
|
|
667
|
+
return result
|
|
668
|
+
|
|
669
|
+
return UserMemory.from_dict(result)
|
|
670
|
+
|
|
671
|
+
except Exception as e:
|
|
672
|
+
log_error(f"Exception reading from collection: {e}")
|
|
673
|
+
return None
|
|
674
|
+
|
|
675
|
+
def get_user_memories(
|
|
676
|
+
self,
|
|
677
|
+
user_id: Optional[str] = None,
|
|
678
|
+
agent_id: Optional[str] = None,
|
|
679
|
+
team_id: Optional[str] = None,
|
|
680
|
+
topics: Optional[List[str]] = None,
|
|
681
|
+
search_content: Optional[str] = None,
|
|
682
|
+
limit: Optional[int] = None,
|
|
683
|
+
page: Optional[int] = None,
|
|
684
|
+
sort_by: Optional[str] = None,
|
|
685
|
+
sort_order: Optional[str] = None,
|
|
686
|
+
deserialize: Optional[bool] = True,
|
|
687
|
+
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
|
|
688
|
+
"""Get all memories from the database as UserMemory objects.
|
|
689
|
+
|
|
690
|
+
Args:
|
|
691
|
+
user_id (Optional[str]): The ID of the user to get the memories for.
|
|
692
|
+
agent_id (Optional[str]): The ID of the agent to get the memories for.
|
|
693
|
+
team_id (Optional[str]): The ID of the team to get the memories for.
|
|
694
|
+
topics (Optional[List[str]]): The topics to filter the memories by.
|
|
695
|
+
search_content (Optional[str]): The content to filter the memories by.
|
|
696
|
+
limit (Optional[int]): The limit of the memories to get.
|
|
697
|
+
page (Optional[int]): The page number to get.
|
|
698
|
+
sort_by (Optional[str]): The field to sort the memories by.
|
|
699
|
+
sort_order (Optional[str]): The order to sort the memories by.
|
|
700
|
+
deserialize (Optional[bool]): Whether to serialize the memories. Defaults to True.
|
|
701
|
+
create_table_if_not_found: Whether to create the index if it doesn't exist.
|
|
702
|
+
|
|
703
|
+
Returns:
|
|
704
|
+
Tuple[List[Dict[str, Any]], int]: A tuple containing the memories and the total count.
|
|
705
|
+
|
|
706
|
+
Raises:
|
|
707
|
+
Exception: If there is an error getting the memories.
|
|
708
|
+
"""
|
|
709
|
+
try:
|
|
710
|
+
collection_ref = self._get_collection(table_type="memories")
|
|
711
|
+
if collection_ref is None:
|
|
712
|
+
return [] if deserialize else ([], 0)
|
|
713
|
+
|
|
714
|
+
query = collection_ref
|
|
715
|
+
|
|
716
|
+
if user_id is not None:
|
|
717
|
+
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
|
718
|
+
if agent_id is not None:
|
|
719
|
+
query = query.where(filter=FieldFilter("agent_id", "==", agent_id))
|
|
720
|
+
if team_id is not None:
|
|
721
|
+
query = query.where(filter=FieldFilter("team_id", "==", team_id))
|
|
722
|
+
if topics is not None and len(topics) > 0:
|
|
723
|
+
query = query.where(filter=FieldFilter("topics", "array_contains_any", topics))
|
|
724
|
+
if search_content is not None:
|
|
725
|
+
query = query.where(filter=FieldFilter("memory", "==", search_content))
|
|
726
|
+
|
|
727
|
+
# Apply sorting
|
|
728
|
+
query = apply_sorting(query, sort_by, sort_order)
|
|
729
|
+
|
|
730
|
+
# Get all documents
|
|
731
|
+
docs = query.stream()
|
|
732
|
+
all_records = [doc.to_dict() for doc in docs]
|
|
733
|
+
|
|
734
|
+
total_count = len(all_records)
|
|
735
|
+
|
|
736
|
+
# Apply pagination to the filtered results
|
|
737
|
+
if limit is not None and page is not None:
|
|
738
|
+
start_index = (page - 1) * limit
|
|
739
|
+
end_index = start_index + limit
|
|
740
|
+
records = all_records[start_index:end_index]
|
|
741
|
+
elif limit is not None:
|
|
742
|
+
records = all_records[:limit]
|
|
743
|
+
else:
|
|
744
|
+
records = all_records
|
|
745
|
+
if not deserialize:
|
|
746
|
+
return records, total_count
|
|
747
|
+
|
|
748
|
+
return [UserMemory.from_dict(record) for record in records]
|
|
749
|
+
|
|
750
|
+
except Exception as e:
|
|
751
|
+
log_error(f"Exception reading from collection: {e}")
|
|
752
|
+
return []
|
|
753
|
+
|
|
754
|
+
    def get_user_memory_stats(
        self,
        limit: Optional[int] = None,
        page: Optional[int] = None,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Get user memories stats.

        Args:
            limit (Optional[int]): The maximum number of user stat rows to return.
            page (Optional[int]): The page number to get.

        Returns:
            Tuple[List[Dict[str, Any]], int]: A tuple containing the memories stats and the total count.

        Raises:
            Exception: If there is an error getting the memories stats.
        """
        try:
            collection_ref = self._get_collection(table_type="memories")
            docs = collection_ref.where(filter=FieldFilter("user_id", "!=", None)).stream()

            user_stats = {}
            for doc in docs:
                data = doc.to_dict()
                user_id = data.get("user_id")
                if user_id:
                    if user_id not in user_stats:
                        user_stats[user_id] = {
                            "user_id": user_id,
                            "total_memories": 0,
                            "last_memory_updated_at": 0,
                        }
                    user_stats[user_id]["total_memories"] += 1
                    updated_at = data.get("updated_at", 0)
                    if updated_at > user_stats[user_id]["last_memory_updated_at"]:
                        user_stats[user_id]["last_memory_updated_at"] = updated_at

            # Convert to list and sort
            formatted_results = list(user_stats.values())
            formatted_results.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)

            total_count = len(formatted_results)

            # Apply pagination
            if limit is not None:
                start_idx = 0
                if page is not None:
                    start_idx = (page - 1) * limit
                formatted_results = formatted_results[start_idx : start_idx + limit]

            return formatted_results, total_count

        except Exception as e:
            log_error(f"Exception getting user memory stats: {e}")
            return [], 0

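Firestore has no server-side GROUP BY, so the per-user aggregation above runs in application code over the full stream. Assuming a populated FirestoreDb instance named `db` (illustrative), the call and result shape look roughly like:

    stats, total = db.get_user_memory_stats(limit=10, page=1)
    # Each entry resembles (values illustrative):
    # {"user_id": "user-123", "total_memories": 4, "last_memory_updated_at": 1718000000}
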
    def upsert_user_memory(
        self, memory: UserMemory, deserialize: Optional[bool] = True
    ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
        """Upsert a user memory in the database.

        Args:
            memory (UserMemory): The memory to upsert.
            deserialize (Optional[bool]): Whether to deserialize the returned memory. Defaults to True.

        Returns:
            Optional[Union[UserMemory, Dict[str, Any]]]:
                - When deserialize=True: UserMemory object
                - When deserialize=False: Memory dictionary

        Raises:
            Exception: If there is an error upserting the memory.
        """
        try:
            collection_ref = self._get_collection(table_type="memories", create_collection_if_not_found=True)
            if collection_ref is None:
                return None

            if memory.memory_id is None:
                memory.memory_id = str(uuid4())

            update_doc = memory.to_dict()
            update_doc["updated_at"] = int(time.time())

            # Find existing document or create new one
            docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory.memory_id)).stream()
            doc_ref = next((doc.reference for doc in docs), None)

            if doc_ref is None:
                doc_ref = collection_ref.document()

            doc_ref.set(update_doc, merge=True)

            log_debug(f"Upserted user memory with id '{memory.memory_id}'")

            if not deserialize:
                return update_doc

            return UserMemory.from_dict(update_doc)

        except Exception as e:
            log_error(f"Exception upserting user memory: {e}")
            return None

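The upsert keys on a logical `memory_id` field rather than using it as the Firestore document ID, so every write is preceded by a lookup query. A minimal sketch of that find-or-create step in isolation (collection and field names illustrative):

    from uuid import uuid4
    from google.cloud import firestore
    from google.cloud.firestore_v1.base_query import FieldFilter

    client = firestore.Client()
    memory_id = str(uuid4())
    docs = client.collection("memories").where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
    doc_ref = next((doc.reference for doc in docs), None)
    if doc_ref is None:
        doc_ref = client.collection("memories").document()  # auto-generated document ID
    doc_ref.set({"memory_id": memory_id, "memory": "likes hiking"}, merge=True)

Keying the document directly by `memory_id` would turn the lookup into a direct `document(memory_id).set(...)`, saving a query per upsert; the auto-ID approach trades that for flexibility in the ID scheme.
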
    def clear_memories(self) -> None:
        """Delete all memories from the database.

        Raises:
            Exception: If an error occurs during deletion.
        """
        try:
            collection_ref = self._get_collection(table_type="memories")

            # Get all documents in the collection
            docs = collection_ref.stream()

            # Delete all documents in batches
            batch = self.db_client.batch()
            batch_count = 0

            for doc in docs:
                batch.delete(doc.reference)
                batch_count += 1

                # Firestore batch has a limit of 500 operations
                if batch_count >= 500:
                    batch.commit()
                    batch = self.db_client.batch()
                    batch_count = 0

            # Commit remaining operations
            if batch_count > 0:
                batch.commit()

        except Exception as e:
            from agno.utils.log import log_warning

            log_warning(f"Exception deleting all memories: {e}")

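The 500-operation cap is a hard Firestore limit on write batches, which is why the method commits and reopens the batch mid-loop. The same pattern as a standalone, runnable helper (function name illustrative):

    from google.cloud import firestore

    def delete_collection_in_batches(client: firestore.Client, name: str, batch_size: int = 500) -> None:
        """Delete every document in a collection, committing at Firestore's 500-write batch cap."""
        batch = client.batch()
        count = 0
        for doc in client.collection(name).stream():
            batch.delete(doc.reference)
            count += 1
            if count >= batch_size:
                batch.commit()
                batch = client.batch()
                count = 0
        if count > 0:
            batch.commit()
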
    # -- Metrics methods --

    def _get_all_sessions_for_metrics_calculation(
        self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Get all sessions of all types for metrics calculation."""
        try:
            collection_ref = self._get_collection(table_type="sessions")

            query = collection_ref
            if start_timestamp is not None:
                query = query.where(filter=FieldFilter("created_at", ">=", start_timestamp))
            if end_timestamp is not None:
                query = query.where(filter=FieldFilter("created_at", "<=", end_timestamp))

            docs = query.stream()
            results = []
            for doc in docs:
                data = doc.to_dict()
                # Only include required fields for metrics
                result = {
                    "user_id": data.get("user_id"),
                    "session_data": data.get("session_data"),
                    "runs": data.get("runs"),
                    "created_at": data.get("created_at"),
                    "session_type": data.get("session_type"),
                }
                results.append(result)

            return results

        except Exception as e:
            log_error(f"Exception reading from sessions collection: {e}")
            return []

    def _get_metrics_calculation_starting_date(self, collection_ref) -> Optional[date]:
        """Get the first date for which metrics calculation is needed."""
        try:
            query = collection_ref.order_by("date", direction="DESCENDING").limit(1)
            docs = query.stream()

            for doc in docs:
                data = doc.to_dict()
                result_date = datetime.strptime(data["date"], "%Y-%m-%d").date()
                if data.get("completed"):
                    return result_date + timedelta(days=1)
                else:
                    return result_date

            # No metrics records. Return the date of the first recorded session.
            first_session_result = self.get_sessions(sort_by="created_at", sort_order="asc", limit=1, deserialize=False)
            first_session_date = None

            if isinstance(first_session_result, list) and len(first_session_result) > 0:
                first_session_date = first_session_result[0].created_at  # type: ignore
            elif isinstance(first_session_result, tuple) and len(first_session_result[0]) > 0:
                first_session_date = first_session_result[0][0].get("created_at")

            if first_session_date is None:
                return None

            return datetime.fromtimestamp(first_session_date, tz=timezone.utc).date()

        except Exception as e:
            log_error(f"Exception getting metrics calculation starting date: {e}")
            return None

    def calculate_metrics(self) -> Optional[list[dict]]:
        """Calculate metrics for all dates without complete metrics."""
        try:
            collection_ref = self._get_collection(table_type="metrics", create_collection_if_not_found=True)

            starting_date = self._get_metrics_calculation_starting_date(collection_ref)
            if starting_date is None:
                log_info("No session data found. Won't calculate metrics.")
                return None

            dates_to_process = get_dates_to_calculate_metrics_for(starting_date)
            if not dates_to_process:
                log_info("Metrics already calculated for all relevant dates.")
                return None

            start_timestamp = int(datetime.combine(dates_to_process[0], datetime.min.time()).timestamp())
            end_timestamp = int(
                datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time()).timestamp()
            )

            sessions = self._get_all_sessions_for_metrics_calculation(
                start_timestamp=start_timestamp, end_timestamp=end_timestamp
            )
            all_sessions_data = fetch_all_sessions_data(
                sessions=sessions, dates_to_process=dates_to_process, start_timestamp=start_timestamp
            )
            if not all_sessions_data:
                log_info("No new session data found. Won't calculate metrics.")
                return None

            results = []
            metrics_records = []

            for date_to_process in dates_to_process:
                date_key = date_to_process.isoformat()
                sessions_for_date = all_sessions_data.get(date_key, {})

                # Skip dates with no sessions
                if not any(len(sessions) > 0 for sessions in sessions_for_date.values()):
                    continue

                metrics_record = calculate_date_metrics(date_to_process, sessions_for_date)
                metrics_records.append(metrics_record)

            if metrics_records:
                results = bulk_upsert_metrics(collection_ref, metrics_records)

            log_debug("Updated metrics calculations")

            return results

        except Exception as e:
            log_error(f"Exception calculating metrics: {e}")
            raise e

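The session window above runs from midnight of the first pending date to midnight after the last one. A small worked example of that computation, taken directly from the code with illustrative dates:

    from datetime import date, datetime, timedelta

    dates_to_process = [date(2024, 6, 1), date(2024, 6, 2)]
    start_timestamp = int(datetime.combine(dates_to_process[0], datetime.min.time()).timestamp())
    end_timestamp = int(datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time()).timestamp())
    # The window [start_timestamp, end_timestamp] spans both days in local time
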
    def get_metrics(
        self,
        starting_date: Optional[date] = None,
        ending_date: Optional[date] = None,
    ) -> Tuple[List[dict], Optional[int]]:
        """Get all metrics matching the given date range."""
        try:
            collection_ref = self._get_collection(table_type="metrics")
            if collection_ref is None:
                return [], None

            query = collection_ref
            if starting_date:
                query = query.where(filter=FieldFilter("date", ">=", starting_date.isoformat()))
            if ending_date:
                query = query.where(filter=FieldFilter("date", "<=", ending_date.isoformat()))

            docs = query.stream()
            records = []
            latest_updated_at = 0

            for doc in docs:
                data = doc.to_dict()
                records.append(data)
                updated_at = data.get("updated_at", 0)
                if updated_at > latest_updated_at:
                    latest_updated_at = updated_at

            if not records:
                return [], None

            return records, latest_updated_at

        except Exception as e:
            log_error(f"Exception getting metrics: {e}")
            return [], None

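The date range filters compare strings, which is sound here because metrics dates are stored as ISO-8601 `"YYYY-MM-DD"` strings, and those sort lexicographically in chronological order:

    # Why string ">="/"<=" filters work for ISO dates:
    assert "2024-06-01" < "2024-06-15" < "2024-07-01"
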
    # -- Knowledge methods --

    def delete_knowledge_content(self, id: str):
        """Delete a knowledge row from the database.

        Args:
            id (str): The ID of the knowledge row to delete.

        Raises:
            Exception: If an error occurs during deletion.
        """
        try:
            collection_ref = self._get_collection(table_type="knowledge")
            docs = collection_ref.where(filter=FieldFilter("id", "==", id)).stream()

            for doc in docs:
                doc.reference.delete()

        except Exception as e:
            log_error(f"Error deleting knowledge source: {e}")

    def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
        """Get a knowledge row from the database.

        Args:
            id (str): The ID of the knowledge row to get.

        Returns:
            Optional[KnowledgeRow]: The knowledge row, or None if it doesn't exist.

        Raises:
            Exception: If an error occurs during retrieval.
        """
        try:
            collection_ref = self._get_collection(table_type="knowledge")
            docs = collection_ref.where(filter=FieldFilter("id", "==", id)).stream()

            for doc in docs:
                data = doc.to_dict()
                return KnowledgeRow.model_validate(data)

            return None

        except Exception as e:
            log_error(f"Error getting knowledge source: {e}")
            return None

    def get_knowledge_contents(
        self,
        limit: Optional[int] = None,
        page: Optional[int] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
    ) -> Tuple[List[KnowledgeRow], int]:
        """Get all knowledge contents from the database.

        Args:
            limit (Optional[int]): The maximum number of knowledge contents to return.
            page (Optional[int]): The page number.
            sort_by (Optional[str]): The column to sort by.
            sort_order (Optional[str]): The order to sort by.

        Returns:
            Tuple[List[KnowledgeRow], int]: The knowledge contents and total count.

        Raises:
            Exception: If an error occurs during retrieval.
        """
        try:
            collection_ref = self._get_collection(table_type="knowledge")
            if collection_ref is None:
                return [], 0

            query = collection_ref

            # Apply sorting
            query = apply_sorting(query, sort_by, sort_order)

            # Apply pagination
            query = apply_pagination(query, limit, page)

            docs = query.stream()
            records = []
            for doc in docs:
                records.append(doc.to_dict())

            knowledge_rows = [KnowledgeRow.model_validate(record) for record in records]
            total_count = len(knowledge_rows)  # Simplified count: size of the returned page, not the collection

            return knowledge_rows, total_count

        except Exception as e:
            log_error(f"Error getting knowledge sources: {e}")
            return [], 0

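Unlike the memory and eval queries, this method pushes pagination into the Firestore query via the module's `apply_pagination` helper (not shown in this hunk). On Firestore, page-number pagination is typically expressed with `offset` and `limit`; a rough sketch of what such a helper can look like (this is an assumption, not the package's actual implementation):

    def apply_pagination_sketch(query, limit, page):
        # Illustrative offset/limit paginator for a Firestore query
        if limit is not None:
            if page is not None:
                query = query.offset((page - 1) * limit)
            query = query.limit(limit)
        return query
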
    def upsert_knowledge_content(self, knowledge_row: KnowledgeRow):
        """Upsert knowledge content in the database.

        Args:
            knowledge_row (KnowledgeRow): The knowledge row to upsert.

        Returns:
            Optional[KnowledgeRow]: The upserted knowledge row, or None if the operation fails.
        """
        try:
            collection_ref = self._get_collection(table_type="knowledge", create_collection_if_not_found=True)
            if collection_ref is None:
                return None

            update_doc = knowledge_row.model_dump()

            # Find existing document or create new one
            docs = collection_ref.where(filter=FieldFilter("id", "==", knowledge_row.id)).stream()
            doc_ref = next((doc.reference for doc in docs), None)

            if doc_ref is None:
                doc_ref = collection_ref.document()

            doc_ref.set(update_doc, merge=True)

            log_debug(f"Upserted knowledge source with id '{knowledge_row.id}'")

            return knowledge_row

        except Exception as e:
            log_error(f"Error upserting knowledge document: {e}")
            return None

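`KnowledgeRow` is a Pydantic model, as the `model_dump`/`model_validate` calls show: rows are dumped to plain dicts before storage and validated back into models on read. The round-trip in miniature, with a stand-in model since KnowledgeRow's fields are not shown in this hunk:

    from typing import Optional
    from pydantic import BaseModel

    class Row(BaseModel):  # stand-in for KnowledgeRow (fields illustrative)
        id: str
        name: Optional[str] = None

    row = Row(id="abc", name="docs")
    as_dict = row.model_dump()               # plain dict, storable as a Firestore document
    restored = Row.model_validate(as_dict)   # validated model on the way back out
    assert restored == row
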
    # -- Eval methods --

    def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
        """Create an EvalRunRecord in the database."""
        try:
            collection_ref = self._get_collection(table_type="evals", create_collection_if_not_found=True)

            current_time = int(time.time())
            eval_dict = eval_run.model_dump()
            eval_dict["created_at"] = current_time
            eval_dict["updated_at"] = current_time

            doc_ref = collection_ref.document()
            doc_ref.set(eval_dict)

            log_debug(f"Created eval run with id '{eval_run.run_id}'")

            return eval_run

        except Exception as e:
            log_error(f"Error creating eval run: {e}")
            return None

    def delete_eval_run(self, eval_run_id: str) -> None:
        """Delete an eval run from the database."""
        try:
            collection_ref = self._get_collection(table_type="evals")
            docs = collection_ref.where(filter=FieldFilter("run_id", "==", eval_run_id)).stream()

            deleted_count = 0
            for doc in docs:
                doc.reference.delete()
                deleted_count += 1

            if deleted_count == 0:
                log_info(f"No eval run found with ID: {eval_run_id}")
            else:
                log_info(f"Deleted eval run with ID: {eval_run_id}")

        except Exception as e:
            log_error(f"Error deleting eval run {eval_run_id}: {e}")
            raise

    def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
        """Delete multiple eval runs from the database.

        Args:
            eval_run_ids (List[str]): The IDs of the eval runs to delete.

        Raises:
            Exception: If there is an error deleting the eval runs.
        """
        try:
            collection_ref = self._get_collection(table_type="evals")
            batch = self.db_client.batch()
            deleted_count = 0

            for eval_run_id in eval_run_ids:
                docs = collection_ref.where(filter=FieldFilter("run_id", "==", eval_run_id)).stream()
                for doc in docs:
                    batch.delete(doc.reference)
                    deleted_count += 1

            batch.commit()

            if deleted_count == 0:
                log_info(f"No eval runs found with IDs: {eval_run_ids}")
            else:
                log_info(f"Deleted {deleted_count} eval runs")

        except Exception as e:
            log_error(f"Error deleting eval runs {eval_run_ids}: {e}")
            raise

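Note that, unlike `clear_memories`, this method accumulates all deletes into a single batch, so a call touching more than 500 documents would hit Firestore's batch cap. A chunked variant along the lines of the memory-clearing code (illustrative, not part of the diff):

    from google.cloud.firestore_v1.base_query import FieldFilter

    def delete_eval_runs_chunked(client, collection_ref, eval_run_ids, batch_size=500):
        # Same delete logic, committing whenever the 500-write batch cap is reached
        batch, count = client.batch(), 0
        for eval_run_id in eval_run_ids:
            for doc in collection_ref.where(filter=FieldFilter("run_id", "==", eval_run_id)).stream():
                batch.delete(doc.reference)
                count += 1
                if count >= batch_size:
                    batch.commit()
                    batch, count = client.batch(), 0
        if count > 0:
            batch.commit()
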
    def get_eval_run(
        self, eval_run_id: str, deserialize: Optional[bool] = True
    ) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
        """Get an eval run from the database.

        Args:
            eval_run_id (str): The ID of the eval run to get.
            deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

        Returns:
            Optional[Union[EvalRunRecord, Dict[str, Any]]]:
                - When deserialize=True: EvalRunRecord object
                - When deserialize=False: EvalRun dictionary

        Raises:
            Exception: If there is an error getting the eval run.
        """
        try:
            collection_ref = self._get_collection(table_type="evals")
            docs = collection_ref.where(filter=FieldFilter("run_id", "==", eval_run_id)).stream()

            eval_run_raw = None
            for doc in docs:
                eval_run_raw = doc.to_dict()
                break

            if not eval_run_raw:
                return None

            if not deserialize:
                return eval_run_raw

            return EvalRunRecord.model_validate(eval_run_raw)

        except Exception as e:
            log_error(f"Exception getting eval run {eval_run_id}: {e}")
            return None

    def get_eval_runs(
        self,
        limit: Optional[int] = None,
        page: Optional[int] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
        agent_id: Optional[str] = None,
        team_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        model_id: Optional[str] = None,
        filter_type: Optional[EvalFilterType] = None,
        eval_type: Optional[List[EvalType]] = None,
        deserialize: Optional[bool] = True,
    ) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
        """Get all eval runs from the database.

        Args:
            limit (Optional[int]): The maximum number of eval runs to return.
            page (Optional[int]): The page number to return.
            sort_by (Optional[str]): The field to sort by.
            sort_order (Optional[str]): The order to sort by.
            agent_id (Optional[str]): The ID of the agent to filter by.
            team_id (Optional[str]): The ID of the team to filter by.
            workflow_id (Optional[str]): The ID of the workflow to filter by.
            model_id (Optional[str]): The ID of the model to filter by.
            filter_type (Optional[EvalFilterType]): The type of filter to apply.
            eval_type (Optional[List[EvalType]]): The type of eval to filter by.
            deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.

        Returns:
            Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
                - When deserialize=True: List of EvalRunRecord objects
                - When deserialize=False: List of eval run dictionaries and the total count

        Raises:
            Exception: If there is an error getting the eval runs.
        """
        try:
            collection_ref = self._get_collection(table_type="evals")
            if collection_ref is None:
                return [] if deserialize else ([], 0)

            query = collection_ref

            if agent_id is not None:
                query = query.where(filter=FieldFilter("agent_id", "==", agent_id))
            if team_id is not None:
                query = query.where(filter=FieldFilter("team_id", "==", team_id))
            if workflow_id is not None:
                query = query.where(filter=FieldFilter("workflow_id", "==", workflow_id))
            if model_id is not None:
                query = query.where(filter=FieldFilter("model_id", "==", model_id))
            if eval_type is not None and len(eval_type) > 0:
                eval_values = [et.value for et in eval_type]
                query = query.where(filter=FieldFilter("eval_type", "in", eval_values))
            if filter_type is not None:
                if filter_type == EvalFilterType.AGENT:
                    query = query.where(filter=FieldFilter("agent_id", "!=", None))
                elif filter_type == EvalFilterType.TEAM:
                    query = query.where(filter=FieldFilter("team_id", "!=", None))
                elif filter_type == EvalFilterType.WORKFLOW:
                    query = query.where(filter=FieldFilter("workflow_id", "!=", None))

            # Apply default sorting by created_at desc if no sort parameters provided
            if sort_by is None:
                from google.cloud.firestore import Query

                query = query.order_by("created_at", direction=Query.DESCENDING)
            else:
                query = apply_sorting(query, sort_by, sort_order)

            # Get all documents for counting before pagination
            all_docs = query.stream()
            all_records = [doc.to_dict() for doc in all_docs]

            if not all_records:
                return [] if deserialize else ([], 0)

            # Get total count before pagination
            total_count = len(all_records)

            # Apply pagination to the results
            if limit is not None and page is not None:
                start_index = (page - 1) * limit
                end_index = start_index + limit
                records = all_records[start_index:end_index]
            elif limit is not None:
                records = all_records[:limit]
            else:
                records = all_records

            if not deserialize:
                return records, total_count

            return [EvalRunRecord.model_validate(row) for row in records]

        except Exception as e:
            log_error(f"Exception getting eval runs: {e}")
            return [] if deserialize else ([], 0)

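For a sense of the call shapes, assuming a FirestoreDb instance named `db` and the `EvalType`/`EvalFilterType` enums from the eval schemas (the specific member names below, such as `EvalType.ACCURACY`, are assumed from the eval modules in this release rather than shown in this hunk):

    runs = db.get_eval_runs(agent_id="agent-1", eval_type=[EvalType.ACCURACY])
    rows, total = db.get_eval_runs(filter_type=EvalFilterType.TEAM, limit=20, page=1, deserialize=False)
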
    def rename_eval_run(
        self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
    ) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
        """Update the name of an eval run in the database.

        Args:
            eval_run_id (str): The ID of the eval run to update.
            name (str): The new name of the eval run.
            deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

        Returns:
            Optional[Union[EvalRunRecord, Dict[str, Any]]]:
                - When deserialize=True: EvalRunRecord object
                - When deserialize=False: EvalRun dictionary

        Raises:
            Exception: If there is an error updating the eval run.
        """
        try:
            collection_ref = self._get_collection(table_type="evals")

            docs = collection_ref.where(filter=FieldFilter("run_id", "==", eval_run_id)).stream()
            doc_ref = next((doc.reference for doc in docs), None)

            if doc_ref is None:
                return None

            doc_ref.update({"name": name, "updated_at": int(time.time())})

            updated_doc = doc_ref.get()
            if not updated_doc.exists:
                return None

            result = updated_doc.to_dict()

            log_debug(f"Renamed eval run with id '{eval_run_id}' to '{name}'")

            if not result or not deserialize:
                return result

            return EvalRunRecord.model_validate(result)

        except Exception as e:
            log_error(f"Error updating eval run name {eval_run_id}: {e}")
            raise