agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2911 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from datetime import date, datetime, timedelta, timezone
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from agno.tracing.schemas import Span, Trace
|
|
9
|
+
|
|
10
|
+
from agno.db.base import AsyncBaseDb, SessionType
|
|
11
|
+
from agno.db.migrations.manager import MigrationManager
|
|
12
|
+
from agno.db.schemas.culture import CulturalKnowledge
|
|
13
|
+
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
14
|
+
from agno.db.schemas.knowledge import KnowledgeRow
|
|
15
|
+
from agno.db.schemas.memory import UserMemory
|
|
16
|
+
from agno.db.sqlite.schemas import get_table_schema_definition
|
|
17
|
+
from agno.db.sqlite.utils import (
|
|
18
|
+
abulk_upsert_metrics,
|
|
19
|
+
ais_table_available,
|
|
20
|
+
ais_valid_table,
|
|
21
|
+
apply_sorting,
|
|
22
|
+
calculate_date_metrics,
|
|
23
|
+
deserialize_cultural_knowledge_from_db,
|
|
24
|
+
fetch_all_sessions_data,
|
|
25
|
+
get_dates_to_calculate_metrics_for,
|
|
26
|
+
serialize_cultural_knowledge_for_db,
|
|
27
|
+
)
|
|
28
|
+
from agno.db.utils import deserialize_session_json_fields, serialize_session_json_fields
|
|
29
|
+
from agno.session import AgentSession, Session, TeamSession, WorkflowSession
|
|
30
|
+
from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
31
|
+
from agno.utils.string import generate_id
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from sqlalchemy import Column, MetaData, String, Table, func, select, text
|
|
35
|
+
from sqlalchemy.dialects import sqlite
|
|
36
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
|
37
|
+
from sqlalchemy.schema import Index, UniqueConstraint
|
|
38
|
+
except ImportError:
|
|
39
|
+
raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AsyncSqliteDb(AsyncBaseDb):
|
|
43
|
+
def __init__(
    self,
    db_file: Optional[str] = None,
    db_engine: Optional[AsyncEngine] = None,
    db_url: Optional[str] = None,
    session_table: Optional[str] = None,
    culture_table: Optional[str] = None,
    memory_table: Optional[str] = None,
    metrics_table: Optional[str] = None,
    eval_table: Optional[str] = None,
    knowledge_table: Optional[str] = None,
    traces_table: Optional[str] = None,
    spans_table: Optional[str] = None,
    versions_table: Optional[str] = None,
    id: Optional[str] = None,
):
    """
    Async interface for interacting with a SQLite database.

    The following order is used to determine the database connection:
        1. Use the db_engine
        2. Use the db_url
        3. Use the db_file
        4. Create a new database (./agno.db) in the current directory

    Args:
        db_file (Optional[str]): The database file to connect to. Missing parent directories are created.
        db_engine (Optional[AsyncEngine]): The SQLAlchemy async database engine to use.
        db_url (Optional[str]): The database URL to connect to.
        session_table (Optional[str]): Name of the table to store Agent, Team and Workflow sessions.
        culture_table (Optional[str]): Name of the table to store cultural notions.
        memory_table (Optional[str]): Name of the table to store user memories.
        metrics_table (Optional[str]): Name of the table to store metrics.
        eval_table (Optional[str]): Name of the table to store evaluation runs data.
        knowledge_table (Optional[str]): Name of the table to store knowledge documents data.
        traces_table (Optional[str]): Name of the table to store run traces.
        spans_table (Optional[str]): Name of the table to store span events.
        versions_table (Optional[str]): Name of the table to store schema versions.
        id (Optional[str]): ID of the database. When omitted, it is generated from the first
            available of db_url, db_file or the engine URL, falling back to the default
            "sqlite+aiosqlite:///agno.db" URL.
    """
    if id is None:
        # The conditional expression must be parenthesised: without the parentheses this parses
        # as `(db_url or db_file or str(db_engine.url)) if db_engine else <default>`, which gave
        # every instance built from a db_url/db_file (but no engine) the same default-derived ID.
        seed = db_url or db_file or (str(db_engine.url) if db_engine is not None else "sqlite+aiosqlite:///agno.db")
        id = generate_id(seed)

    super().__init__(
        id=id,
        session_table=session_table,
        culture_table=culture_table,
        memory_table=memory_table,
        metrics_table=metrics_table,
        eval_table=eval_table,
        knowledge_table=knowledge_table,
        traces_table=traces_table,
        spans_table=spans_table,
        versions_table=versions_table,
    )

    _engine: Optional[AsyncEngine] = db_engine
    if _engine is None:
        if db_url is not None:
            _engine = create_async_engine(db_url)
        elif db_file is not None:
            db_path = Path(db_file).resolve()
            # Ensure the directory that will hold the database file exists
            db_path.parent.mkdir(parents=True, exist_ok=True)
            db_file = str(db_path)
            _engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
        else:
            # If none of db_engine, db_url, or db_file are provided, create a db in the current directory
            default_db_path = Path("./agno.db").resolve()
            _engine = create_async_engine(f"sqlite+aiosqlite:///{default_db_path}")
            db_file = str(default_db_path)
            log_debug(f"Created SQLite database: {default_db_path}")

    self.db_engine: AsyncEngine = _engine
    self.db_url: Optional[str] = db_url
    self.db_file: Optional[str] = db_file
    self.metadata: MetaData = MetaData()

    # Initialize database session factory
    self.async_session_factory = async_sessionmaker(bind=self.db_engine, expire_on_commit=False)
|
|
126
|
+
|
|
127
|
+
# -- DB methods --
|
|
128
|
+
async def table_exists(self, table_name: str) -> bool:
    """Report whether a table named ``table_name`` is present in the SQLite database.

    Args:
        table_name: Name of the table to look up

    Returns:
        bool: True when the table is present, False otherwise
    """
    async with self.async_session_factory() as db_session:
        is_present = await ais_table_available(session=db_session, table_name=table_name)
    return is_present
|
|
139
|
+
|
|
140
|
+
async def _create_all_tables(self):
    """Ensure the core tables exist, creating any that are missing.

    Covers the sessions, memories, metrics, evals, knowledge and versions tables. Tables that
    already exist are validated and loaded rather than re-created. The culture, traces and spans
    tables are not in this list (presumably they are created on demand via `_get_table`).
    """
    core_tables = (
        ("sessions", self.session_table_name),
        ("memories", self.memory_table_name),
        ("metrics", self.metrics_table_name),
        ("evals", self.eval_table_name),
        ("knowledge", self.knowledge_table_name),
        ("versions", self.versions_table_name),
    )

    for kind, name in core_tables:
        await self._get_or_create_table(table_name=name, table_type=kind, create_table_if_not_found=True)
|
|
155
|
+
|
|
156
|
+
async def _create_table(self, table_name: str, table_type: str) -> Table:
    """
    Create a table with the appropriate schema based on the table type.

    Builds a SQLAlchemy ``Table`` from the schema definition for ``table_type`` (columns,
    multi-column unique constraints and single-column indexes). It then issues CREATE TABLE
    if the table is not present yet and creates any index not already listed in
    ``sqlite_master``. For a newly created table other than the versions table, it also
    records the latest schema version.

    Args:
        table_name (str): Name of the table to create
        table_type (str): Type of table (used to get schema definition)

    Returns:
        Table: SQLAlchemy Table object

    Raises:
        Exception: Any error while building or creating the table is logged and re-raised.
            Failures creating individual indexes are only logged as warnings.
    """
    try:
        # Work on a shallow copy: the `_unique_constraints` meta key is popped below, and doing
        # that on the returned definition itself would strip it for later calls if the schema
        # helper hands out a shared dict.
        table_schema = dict(get_table_schema_definition(table_type))
        schema_unique_constraints = table_schema.pop("_unique_constraints", [])

        columns: List[Column] = []
        indexes: List[str] = []

        # Translate each column config into a Column, collecting the columns to be indexed
        for col_name, col_config in table_schema.items():
            column_args = [col_name, col_config["type"]()]
            column_kwargs = {}

            if col_config.get("primary_key", False):
                column_kwargs["primary_key"] = True
            if "nullable" in col_config:
                column_kwargs["nullable"] = col_config["nullable"]
            if col_config.get("index", False):
                indexes.append(col_name)
            if col_config.get("unique", False):
                column_kwargs["unique"] = True

            columns.append(Column(*column_args, **column_kwargs))  # type: ignore

        # Create the table object
        table = Table(table_name, self.metadata, *columns)

        # Add multi-column unique constraints, prefixed with the table name so that constraint
        # names stay distinct across tables
        for constraint in schema_unique_constraints:
            constraint_name = f"{table_name}_{constraint['name']}"
            constraint_columns = constraint["columns"]
            table.append_constraint(UniqueConstraint(*constraint_columns, name=constraint_name))

        # Add indexes to the table definition
        for idx_col in indexes:
            idx_name = f"idx_{table_name}_{idx_col}"
            table.append_constraint(Index(idx_name, idx_col))

        # Create table
        table_created = False
        if not await self.table_exists(table_name):
            async with self.db_engine.begin() as conn:
                await conn.run_sync(table.create, checkfirst=True)
            log_debug(f"Successfully created table '{table_name}'")
            table_created = True
        else:
            log_debug(f"Table {table_name} already exists, skipping creation")

        # Create every index that is not yet registered in sqlite_master
        for idx in table.indexes:
            try:
                async with self.async_session_factory() as sess:
                    exists_query = text("SELECT 1 FROM sqlite_master WHERE type = 'index' AND name = :index_name")
                    result = await sess.execute(exists_query, {"index_name": idx.name})
                    exists = result.scalar() is not None
                if exists:
                    log_debug(f"Index {idx.name} already exists in table {table_name}, skipping creation")
                    continue

                async with self.db_engine.begin() as conn:
                    await conn.run_sync(idx.create)
                log_debug(f"Created index: {idx.name} for table {table_name}")

            except Exception as e:
                log_warning(f"Error creating index {idx.name}: {e}")

        # Store the schema version for the created table
        if table_name != self.versions_table_name and table_created:
            latest_schema_version = MigrationManager(self).latest_schema_version
            await self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)

        return table

    except Exception as e:
        log_error(f"Could not create table '{table_name}': {e}")
        raise
|
|
245
|
+
|
|
246
|
+
async def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
    """Return the Table object for ``table_type``, loading (or creating) it on first use.

    The result is cached on the instance (e.g. ``self.session_table``) after the first
    successful lookup, and the cached object is returned on later calls. For "spans", the
    traces table is always ensured first (created if needed), because spans reference traces.

    Args:
        table_type: One of "sessions", "memories", "metrics", "evals", "knowledge",
            "culture", "versions", "traces" or "spans".
        create_table_if_not_found: Create the table when it does not exist yet.

    Returns:
        Optional[Table]: The SQLAlchemy Table for the requested type.

    Raises:
        ValueError: If ``table_type`` is not one of the known table types.
    """
    # table_type -> (attribute caching the Table, attribute holding the configured table name)
    table_attributes = {
        "sessions": ("session_table", "session_table_name"),
        "memories": ("memory_table", "memory_table_name"),
        "metrics": ("metrics_table", "metrics_table_name"),
        "evals": ("eval_table", "eval_table_name"),
        "knowledge": ("knowledge_table", "knowledge_table_name"),
        "culture": ("culture_table", "culture_table_name"),
        "versions": ("versions_table", "versions_table_name"),
        "traces": ("traces_table", "trace_table_name"),
        "spans": ("spans_table", "span_table_name"),
    }

    if table_type not in table_attributes:
        raise ValueError(f"Unknown table type: '{table_type}'")

    cache_attr, name_attr = table_attributes[table_type]
    if not hasattr(self, cache_attr):
        if table_type == "spans":
            # Ensure traces table exists first (spans has FK to traces)
            await self._get_table(table_type="traces", create_table_if_not_found=True)
        resolved_table = await self._get_or_create_table(
            table_name=getattr(self, name_attr),
            table_type=table_type,
            create_table_if_not_found=create_table_if_not_found,
        )
        setattr(self, cache_attr, resolved_table)

    return getattr(self, cache_attr)
|
|
332
|
+
|
|
333
|
+
async def _get_or_create_table(
    self,
    table_name: str,
    table_type: str,
    create_table_if_not_found: Optional[bool] = False,
) -> Table:
    """
    Check if the table exists and is valid, else create it.

    Args:
        table_name (str): Name of the table to get or create
        table_type (str): Type of table (used to get schema definition)
        create_table_if_not_found (Optional[bool]): Whether to create the table when it is missing.

    Returns:
        Table: SQLAlchemy Table object

    Raises:
        ValueError: If the table is missing and creation was not requested, or if the existing
            table fails schema validation.
        Exception: Any error raised while reflecting the existing table (logged, then re-raised).
    """
    async with self.async_session_factory() as sess, sess.begin():
        table_is_available = await ais_table_available(session=sess, table_name=table_name)

    if not table_is_available:
        if create_table_if_not_found:
            return await self._create_table(table_name=table_name, table_type=table_type)
        # A missing table could only end in an error further down (schema validation or
        # reflection), so fail fast with a message that states the actual problem.
        raise ValueError(f"Table {table_name} does not exist")

    # SQLite version of table validation (no schema)
    if not await ais_valid_table(db_engine=self.db_engine, table_name=table_name, table_type=table_type):
        raise ValueError(f"Table {table_name} has an invalid schema")

    try:
        async with self.db_engine.connect() as conn:

            def load_table(connection):
                # Reflect the existing table definition into this instance's MetaData
                return Table(table_name, self.metadata, autoload_with=connection)

            return await conn.run_sync(load_table)

    except Exception as e:
        log_error(f"Error loading existing table {table_name}: {e}")
        raise
|
|
371
|
+
|
|
372
|
+
async def get_latest_schema_version(self, table_name: str) -> str:
    """Return the schema version recorded for ``table_name``.

    Falls back to "2.0.0" when the versions table is unavailable, when no row exists for
    ``table_name``, or when the stored version is empty.
    """
    fallback_version = "2.0.0"

    versions_table = await self._get_table(table_type="versions", create_table_if_not_found=True)
    if versions_table is None:
        return fallback_version

    # Highest version row for the requested table
    query = (
        select(versions_table)
        .where(versions_table.c.table_name == table_name)
        .order_by(versions_table.c.version.desc())
        .limit(1)
    )

    async with self.async_session_factory() as sess:
        latest_row = (await sess.execute(query)).fetchone()

    if latest_row is None:
        return fallback_version

    row_values = dict(latest_row._mapping)
    return row_values.get("version") or fallback_version
|
|
388
|
+
|
|
389
|
+
async def upsert_schema_version(self, table_name: str, version: str) -> None:
    """Record ``version`` as the schema version of ``table_name``.

    Inserts a new row, or, when a row for ``table_name`` already exists, updates its
    ``version`` and ``updated_at`` columns. Timestamps are ISO-8601 strings taken from
    ``datetime.now()`` (naive local time). Does nothing if the versions table is unavailable.
    """
    versions_table = await self._get_table(table_type="versions", create_table_if_not_found=True)
    if versions_table is None:
        return

    timestamp = datetime.now().isoformat()

    insert_stmt = sqlite.insert(versions_table).values(
        table_name=table_name,
        version=version,
        created_at=timestamp,  # Store as ISO format string
        updated_at=timestamp,
    )
    # Update version if table_name already exists
    upsert_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["table_name"],
        set_={"version": version, "updated_at": timestamp},
    )

    async with self.async_session_factory() as sess, sess.begin():
        await sess.execute(upsert_stmt)
|
|
408
|
+
|
|
409
|
+
# -- Session methods --
|
|
410
|
+
|
|
411
|
+
async def delete_session(self, session_id: str) -> bool:
    """
    Delete a single session from the database.

    Args:
        session_id (str): ID of the session to delete

    Returns:
        bool: True if a matching session was deleted. False if no session matched, if the
        sessions table is unavailable, or if an error occurred. Errors are logged and
        not raised.
    """
    try:
        sessions_table = await self._get_table(table_type="sessions")
        if sessions_table is None:
            return False

        delete_stmt = sessions_table.delete().where(sessions_table.c.session_id == session_id)

        async with self.async_session_factory() as sess, sess.begin():
            outcome = await sess.execute(delete_stmt)
            removed_any = outcome.rowcount != 0  # type: ignore
            if removed_any:
                log_debug(f"Successfully deleted session with session_id: {session_id}")
            else:
                log_debug(f"No session found to delete with session_id: {session_id}")
            return removed_any

    except Exception as e:
        log_error(f"Error deleting session: {e}")
        return False
|
|
442
|
+
|
|
443
|
+
async def delete_sessions(self, session_ids: List[str]) -> None:
    """Remove every session whose ID appears in ``session_ids``.

    Sessions of different types (agent / team / workflow) may be mixed in one call.

    Args:
        session_ids (List[str]): IDs of the sessions to remove.

    Note:
        Errors are logged and swallowed; nothing is raised to the caller.
    """
    try:
        sessions_table = await self._get_table(table_type="sessions")
        if sessions_table is None:
            return

        matching_ids = sessions_table.c.session_id.in_(session_ids)
        async with self.async_session_factory() as sess, sess.begin():
            outcome = await sess.execute(sessions_table.delete().where(matching_ids))

        log_debug(f"Successfully deleted {outcome.rowcount} sessions")  # type: ignore

    except Exception as e:
        log_error(f"Error deleting sessions: {e}")
|
|
466
|
+
|
|
467
|
+
async def get_session(
    self,
    session_id: str,
    session_type: SessionType,
    user_id: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Fetch one session by its ID.

    Args:
        session_id (str): ID of the session to fetch.
        session_type (SessionType): Selects the Session class built when deserializing.
        user_id (Optional[str]): If given, the row must also belong to this user.
        deserialize (Optional[bool]): True (default) returns a Session object;
            False returns the row as a dictionary.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]: The session object or dictionary,
            or None when the sessions table or the row does not exist.

    Raises:
        ValueError: If ``session_type`` is not a known SessionType (only when deserializing).
        Exception: Any error during retrieval is logged and re-raised.
    """
    try:
        sessions_table = await self._get_table(table_type="sessions")
        if sessions_table is None:
            return None

        async with self.async_session_factory() as sess, sess.begin():
            conditions = [sessions_table.c.session_id == session_id]
            if user_id is not None:
                conditions.append(sessions_table.c.user_id == user_id)
            query = select(sessions_table).where(*conditions)

            found = (await sess.execute(query)).fetchone()
            if found is None:
                return None

            raw = deserialize_session_json_fields(dict(found._mapping))
            if not raw or not deserialize:
                return raw

            if session_type == SessionType.AGENT:
                return AgentSession.from_dict(raw)
            if session_type == SessionType.TEAM:
                return TeamSession.from_dict(raw)
            if session_type == SessionType.WORKFLOW:
                return WorkflowSession.from_dict(raw)
            raise ValueError(f"Invalid session type: {session_type}")

    except Exception as e:
        log_debug(f"Exception reading from sessions table: {e}")
        raise e
|
|
524
|
+
|
|
525
|
+
async def get_sessions(
    self,
    session_type: Optional[SessionType] = None,
    user_id: Optional[str] = None,
    component_id: Optional[str] = None,
    session_name: Optional[str] = None,
    start_timestamp: Optional[int] = None,
    end_timestamp: Optional[int] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
    """
    Get sessions from the sessions table, optionally filtered, sorted and paginated.

    Args:
        session_type (Optional[SessionType]): Only return sessions of this type. When None,
            sessions of every type are returned. When deserializing, each row is rebuilt
            with the Session class that matches its stored ``session_type`` column.
        user_id (Optional[str]): Only return sessions of this user.
        component_id (Optional[str]): Only return sessions of this agent / team / workflow.
            Applied only when ``session_type`` is given, because it selects the id column.
        session_name (Optional[str]): Substring matched with SQL LIKE against the
            ``session_data`` column.
        start_timestamp (Optional[int]): Only return sessions with created_at >= this value.
        end_timestamp (Optional[int]): Only return sessions with created_at <= this value.
        limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
        page (Optional[int]): 1-based page number. Only used together with ``limit``.
        sort_by (Optional[str]): The field to sort by. Defaults to None.
        sort_order (Optional[str]): The sort order. Defaults to None.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.

    Returns:
        Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of Session objects matching the criteria.
            - When deserialize=False: Tuple of (session dictionaries, total count of
              matching rows before pagination).

    Raises:
        Exception: Any error during retrieval is logged and re-raised.
    """
    try:
        table = await self._get_table(table_type="sessions")
        if table is None:
            return [] if deserialize else ([], 0)

        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if user_id is not None:
                stmt = stmt.where(table.c.user_id == user_id)
            if component_id is not None:
                if session_type == SessionType.AGENT:
                    stmt = stmt.where(table.c.agent_id == component_id)
                elif session_type == SessionType.TEAM:
                    stmt = stmt.where(table.c.team_id == component_id)
                elif session_type == SessionType.WORKFLOW:
                    stmt = stmt.where(table.c.workflow_id == component_id)
            if start_timestamp is not None:
                stmt = stmt.where(table.c.created_at >= start_timestamp)
            if end_timestamp is not None:
                stmt = stmt.where(table.c.created_at <= end_timestamp)
            if session_name is not None:
                stmt = stmt.where(table.c.session_data.like(f"%{session_name}%"))
            if session_type is not None:
                stmt = stmt.where(table.c.session_type == session_type.value)

            # Total count of matching rows, computed before pagination
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Sorting
            stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating: an offset only makes sense with a limit (and None * int would fail)
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            records = (await sess.execute(stmt)).fetchall()

            sessions_raw = [deserialize_session_json_fields(dict(record._mapping)) for record in records]
            if not deserialize:
                return sessions_raw, total_count
            if not sessions_raw:
                return []

            if session_type == SessionType.AGENT:
                return [AgentSession.from_dict(record) for record in sessions_raw]  # type: ignore
            elif session_type == SessionType.TEAM:
                return [TeamSession.from_dict(record) for record in sessions_raw]  # type: ignore
            elif session_type == SessionType.WORKFLOW:
                return [WorkflowSession.from_dict(record) for record in sessions_raw]  # type: ignore
            elif session_type is None:
                # No type filter: rows may be of mixed types, so choose the class per row
                # from the session_type value stored by upsert_session(s).
                sessions: List[Session] = []
                for record in sessions_raw:
                    row_type = record.get("session_type")
                    if row_type == SessionType.AGENT.value:
                        sessions.append(AgentSession.from_dict(record))  # type: ignore
                    elif row_type == SessionType.TEAM.value:
                        sessions.append(TeamSession.from_dict(record))  # type: ignore
                    elif row_type == SessionType.WORKFLOW.value:
                        sessions.append(WorkflowSession.from_dict(record))  # type: ignore
                    else:
                        log_warning(f"Skipping session with unknown session_type: {row_type}")
                return sessions
            else:
                raise ValueError(f"Invalid session type: {session_type}")

    except Exception as e:
        log_debug(f"Exception reading from sessions table: {e}")
        raise e
|
|
626
|
+
|
|
627
|
+
async def rename_session(
    self,
    session_id: str,
    session_type: SessionType,
    session_name: str,
    deserialize: Optional[bool] = True,
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Give an existing session a new name.

    The name is stored under the ``"session_name"`` key of ``session_data``. The
    session is loaded, that key is set, and the session is written back with
    ``upsert_session``.

    Args:
        session_id (str): ID of the session to rename.
        session_type (SessionType): Type of the session, used to load it.
        session_name (str): The new name.
        deserialize (Optional[bool]): Passed through to ``upsert_session``. True
            (default) returns a Session object; False returns a dictionary.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]: The renamed session, or None when
            no session with that ID exists.

    Raises:
        Exception: Any error while loading or saving is logged and re-raised.
    """
    try:
        # Always load the full Session object so it can be upserted back
        loaded = await self.get_session(session_id, session_type, deserialize=True)
        if loaded is None:
            return None

        target = cast(Session, loaded)
        if target.session_data is None:
            target.session_data = {}
        target.session_data["session_name"] = session_name

        return await self.upsert_session(target, deserialize=deserialize)

    except Exception as e:
        log_error(f"Exception renaming session: {e}")
        raise e
|
|
669
|
+
|
|
670
|
+
async def upsert_session(
    self, session: Session, deserialize: Optional[bool] = True
) -> Optional[Union[Session, Dict[str, Any]]]:
    """
    Insert a session, or update the existing row with the same session_id.

    Agent, team and workflow sessions share one code path. Only the type-specific
    columns (``agent_id``/``agent_data``, ``team_id``/``team_data`` or
    ``workflow_id``/``workflow_data``) and the class used to rebuild the stored row
    differ. Any session that is neither an AgentSession nor a TeamSession is
    stored as a workflow session.

    On insert, a missing ``created_at`` falls back to the current time. Agent and
    team rows start with ``updated_at`` equal to ``created_at``. Workflow rows use
    the session's own ``updated_at``, or the current time if it is missing. On
    conflict, ``updated_at`` is set to the current time, and ``created_at`` /
    ``session_type`` are left untouched.

    Args:
        session (Session): The session to persist.
        deserialize (Optional[bool]): Whether to deserialize the stored row. Defaults to True.

    Returns:
        Optional[Union[Session, Dict[str, Any]]]:
            - When deserialize=True: Session object
            - When deserialize=False: Session dictionary
            - None if the sessions table is unavailable or no row was returned

    Raises:
        Exception: Any error during upserting is logged and re-raised.
    """
    try:
        table = await self._get_table(table_type="sessions", create_table_if_not_found=True)
        if table is None:
            return None

        serialized_session = serialize_session_json_fields(session.to_dict())

        # The session kind fixes the session_type value, the "<kind>_id" / "<kind>_data"
        # columns, and the class used to rebuild the stored row.
        session_cls: Any
        if isinstance(session, AgentSession):
            session_type, kind, session_cls = SessionType.AGENT, "agent", AgentSession
        elif isinstance(session, TeamSession):
            session_type, kind, session_cls = SessionType.TEAM, "team", TeamSession
        else:
            session_type, kind, session_cls = SessionType.WORKFLOW, "workflow", WorkflowSession
        id_column, data_column = f"{kind}_id", f"{kind}_data"

        now = int(time.time())
        created_at = serialized_session.get("created_at") or now
        if session_type == SessionType.WORKFLOW:
            updated_at = serialized_session.get("updated_at") or now
        else:
            updated_at = created_at

        # Columns written on insert and overwritten on conflict
        shared_values: Dict[str, Any] = {
            id_column: serialized_session.get(id_column),
            "user_id": serialized_session.get("user_id"),
            data_column: serialized_session.get(data_column),
            "session_data": serialized_session.get("session_data"),
            "metadata": serialized_session.get("metadata"),
            "runs": serialized_session.get("runs"),
            "summary": serialized_session.get("summary"),
        }

        async with self.async_session_factory() as sess, sess.begin():
            stmt = sqlite.insert(table).values(
                session_id=serialized_session.get("session_id"),
                session_type=session_type.value,
                created_at=created_at,
                updated_at=updated_at,
                **shared_values,
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=["session_id"],
                set_={**shared_values, "updated_at": now},
            )
            stmt = stmt.returning(*table.columns)  # type: ignore
            row = (await sess.execute(stmt)).fetchone()

            session_raw = deserialize_session_json_fields(dict(row._mapping)) if row else None
            if session_raw is None or not deserialize:
                return session_raw
            return session_cls.from_dict(session_raw)

    except Exception as e:
        log_warning(f"Exception upserting into table: {e}")
        raise e
|
|
810
|
+
|
|
811
|
+
async def upsert_sessions(
    self,
    sessions: List[Session],
    deserialize: Optional[bool] = True,
    preserve_updated_at: bool = False,
) -> List[Union[Session, Dict[str, Any]]]:
    """
    Bulk upsert multiple sessions for improved performance on large datasets.

    Sessions are grouped by type (agent / team / workflow). Each group is written
    with one ``INSERT ... ON CONFLICT(session_id) DO UPDATE`` statement, and all
    groups share a single transaction. Sessions of any other type are ignored.
    If the table is unavailable or the bulk path fails, every session is
    upserted individually via ``upsert_session`` instead.

    Args:
        sessions (List[Session]): List of sessions to upsert.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
        preserve_updated_at (bool): If True, keep each session's own ``updated_at`` when it
            is set. Sessions without one, and all sessions when False, get the current time.

    Returns:
        List[Union[Session, Dict[str, Any]]]: The upserted sessions, grouped by type
            (agent, then team, then workflow) rather than in input order.

    Raises:
        Exception: Errors on the bulk path are logged and trigger the per-session
            fallback; errors raised by that fallback propagate.
    """
    if not sessions:
        return []

    try:
        table = await self._get_table(table_type="sessions", create_table_if_not_found=True)
        if table is None:
            log_info("Sessions table not available, falling back to individual upserts")
            return await self._upsert_sessions_individually(sessions, deserialize)

        # Group sessions by type for batch processing
        agent_sessions: List[Session] = []
        team_sessions: List[Session] = []
        workflow_sessions: List[Session] = []
        for session in sessions:
            if isinstance(session, AgentSession):
                agent_sessions.append(session)
            elif isinstance(session, TeamSession):
                team_sessions.append(session)
            elif isinstance(session, WorkflowSession):
                workflow_sessions.append(session)

        groups = (
            (SessionType.AGENT, "agent", AgentSession, agent_sessions),
            (SessionType.TEAM, "team", TeamSession, team_sessions),
            (SessionType.WORKFLOW, "workflow", WorkflowSession, workflow_sessions),
        )

        results: List[Union[Session, Dict[str, Any]]] = []
        async with self.async_session_factory() as sess, sess.begin():
            for session_type, kind, session_cls, group in groups:
                if not group:
                    continue
                results.extend(
                    await self._bulk_upsert_session_group(
                        sess,
                        table,
                        group,
                        session_type=session_type,
                        kind=kind,
                        session_cls=session_cls,
                        deserialize=deserialize,
                        preserve_updated_at=preserve_updated_at,
                    )
                )
        return results

    except Exception as e:
        log_error(f"Exception during bulk session upsert, falling back to individual upserts: {e}")
        return await self._upsert_sessions_individually(sessions, deserialize)


async def _bulk_upsert_session_group(
    self,
    sess: Any,
    table: Any,
    sessions: List[Session],
    session_type: SessionType,
    kind: str,
    session_cls: Any,
    deserialize: Optional[bool],
    preserve_updated_at: bool,
) -> List[Union[Session, Dict[str, Any]]]:
    """Bulk upsert sessions of one type inside ``sess`` and return the stored rows.

    ``kind`` is "agent", "team" or "workflow" and selects the ``<kind>_id`` /
    ``<kind>_data`` columns. ``session_cls`` rebuilds rows when ``deserialize`` is
    truthy; rows for which ``from_dict`` returns None are dropped.
    """
    id_column, data_column = f"{kind}_id", f"{kind}_data"

    rows: List[Dict[str, Any]] = []
    for session in sessions:
        serialized_session = serialize_session_json_fields(session.to_dict())
        now = int(time.time())
        # Keep the session's own updated_at only when asked to AND it is set; otherwise use now
        updated_at = (serialized_session.get("updated_at") if preserve_updated_at else None) or now
        rows.append(
            {
                "session_id": serialized_session.get("session_id"),
                "session_type": session_type.value,
                id_column: serialized_session.get(id_column),
                "user_id": serialized_session.get("user_id"),
                data_column: serialized_session.get(data_column),
                "session_data": serialized_session.get("session_data"),
                "metadata": serialized_session.get("metadata"),
                "runs": serialized_session.get("runs"),
                "summary": serialized_session.get("summary"),
                "created_at": serialized_session.get("created_at") or now,
                "updated_at": updated_at,
            }
        )

    # On conflict overwrite everything except session_id, session_type and created_at
    updatable_columns = (id_column, "user_id", data_column, "session_data", "metadata", "runs", "summary", "updated_at")
    stmt = sqlite.insert(table)
    stmt = stmt.on_conflict_do_update(
        index_elements=["session_id"],
        set_={column: stmt.excluded[column] for column in updatable_columns},
    )
    await sess.execute(stmt, rows)

    # Read back the stored rows for this group
    session_ids = [session.session_id for session in sessions]
    select_stmt = select(table).where(table.c.session_id.in_(session_ids))
    stored_rows = (await sess.execute(select_stmt)).fetchall()

    results: List[Union[Session, Dict[str, Any]]] = []
    for row in stored_rows:
        session_dict = deserialize_session_json_fields(dict(row._mapping))
        if not deserialize:
            results.append(session_dict)
            continue
        deserialized_session = session_cls.from_dict(session_dict)
        if deserialized_session is not None:
            results.append(deserialized_session)
    return results


async def _upsert_sessions_individually(
    self, sessions: List[Session], deserialize: Optional[bool]
) -> List[Union[Session, Dict[str, Any]]]:
    """Upsert sessions one at a time via ``upsert_session``, dropping None inputs and None results."""
    results: List[Union[Session, Dict[str, Any]]] = []
    for session in sessions:
        if session is None:
            continue
        upserted = await self.upsert_session(session, deserialize=deserialize)
        if upserted is not None:
            results.append(upserted)
    return results
|
|
1039
|
+
|
|
1040
|
+
# -- Memory methods --
|
|
1041
|
+
|
|
1042
|
+
async def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
    """Delete a user memory from the database.

    Args:
        memory_id (str): The ID of the memory to delete.
        user_id (Optional[str]): If given, the memory is only deleted when it belongs
            to this user. Defaults to None.

    Returns:
        bool: True if a memory row was deleted. False if no row matched or the
            memories table does not exist.

    Raises:
        Exception: Any error during deletion is logged and re-raised.
    """
    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            return False

        async with self.async_session_factory() as sess, sess.begin():
            delete_stmt = table.delete().where(table.c.memory_id == memory_id)
            if user_id is not None:
                delete_stmt = delete_stmt.where(table.c.user_id == user_id)
            result = await sess.execute(delete_stmt)

            success = result.rowcount > 0  # type: ignore
            if success:
                log_debug(f"Successfully deleted user memory id: {memory_id}")
            else:
                log_debug(f"No user memory found with id: {memory_id}")

        # Previously the documented bool was computed but never returned
        return success

    except Exception as e:
        log_error(f"Error deleting user memory: {e}")
        raise e
|
|
1075
|
+
|
|
1076
|
+
async def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
    """Remove several user memories in one statement.

    Args:
        memory_ids (List[str]): IDs of the memories to remove.
        user_id (Optional[str]): If given, only memories owned by this user are removed.

    Raises:
        Exception: Any error during deletion is logged and re-raised.
    """
    try:
        memories_table = await self._get_table(table_type="memories")
        if memories_table is None:
            return

        criteria = [memories_table.c.memory_id.in_(memory_ids)]
        if user_id is not None:
            criteria.append(memories_table.c.user_id == user_id)

        async with self.async_session_factory() as sess, sess.begin():
            outcome = await sess.execute(memories_table.delete().where(*criteria))
            if outcome.rowcount == 0:  # type: ignore
                log_debug(f"No user memories found with ids: {memory_ids}")

    except Exception as e:
        log_error(f"Error deleting user memories: {e}")
        raise e
|
|
1102
|
+
|
|
1103
|
+
async def get_all_memory_topics(self) -> List[str]:
    """Get the distinct memory topics stored across all memories.

    The ``topics`` column is read for every row. List/tuple values are flattened,
    so each topic appears once. Any other non-NULL value is kept as a single
    entry, and NULLs are skipped.

    Returns:
        List[str]: Distinct topics, in no guaranteed order. Empty if the memories
            table does not exist.

    Raises:
        Exception: Any error during retrieval is logged and re-raised.
    """
    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            return []

        async with self.async_session_factory() as sess, sess.begin():
            # Select topics from all results
            result = (await sess.execute(select(table.c.topics))).fetchall()

        # A dict de-duplicates like a set; list values (unhashable) must be flattened
        # first, otherwise building a set of them raises TypeError.
        unique_topics: Dict[str, None] = {}
        for (row_topics,) in result:
            if row_topics is None:
                continue
            if isinstance(row_topics, (list, tuple)):
                for topic in row_topics:
                    unique_topics[topic] = None
            else:
                unique_topics[row_topics] = None
        return list(unique_topics)

    except Exception as e:
        log_debug(f"Exception reading from memory table: {e}")
        raise e
|
|
1124
|
+
|
|
1125
|
+
async def get_user_memory(
    self,
    memory_id: str,
    deserialize: Optional[bool] = True,
    user_id: Optional[str] = None,
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Fetch one user memory by its ID.

    Args:
        memory_id (str): ID of the memory to fetch.
        deserialize (Optional[bool]): True (default) returns a UserMemory object;
            False returns the row as a dictionary.
        user_id (Optional[str]): If given, the memory must also belong to this user.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]: The memory object or dictionary,
            or None when the memories table or the row does not exist.

    Raises:
        Exception: Any error during retrieval is logged and re-raised.
    """
    try:
        memories_table = await self._get_table(table_type="memories")
        if memories_table is None:
            return None

        async with self.async_session_factory() as sess, sess.begin():
            criteria = [memories_table.c.memory_id == memory_id]
            if user_id is not None:
                criteria.append(memories_table.c.user_id == user_id)
            found = (await sess.execute(select(memories_table).where(*criteria))).fetchone()

            if found is None:
                return None

            memory_raw = dict(found._mapping)
            if memory_raw and deserialize:
                return UserMemory.from_dict(memory_raw)
            return memory_raw

    except Exception as e:
        log_debug(f"Exception reading from memorytable: {e}")
        raise e
|
|
1168
|
+
|
|
1169
|
+
async def get_user_memories(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    topics: Optional[List[str]] = None,
    search_content: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
    """Get memories from the database, optionally filtered, sorted and paginated.

    Args:
        user_id (Optional[str]): The ID of the user to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        topics (Optional[List[str]]): Topics that must all appear in a memory's topics.
        search_content (Optional[str]): Substring to search for in the memory text (ILIKE).
        limit (Optional[int]): The maximum number of memories to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to deserialize rows into UserMemory objects. Defaults to True.

    Returns:
        Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of UserMemory objects
            - When deserialize=False: List of memory dictionaries and the total number of rows
              matching the filters (independent of pagination, also when the page is empty)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            return [] if deserialize else ([], 0)

        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if user_id is not None:
                stmt = stmt.where(table.c.user_id == user_id)
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)
            if topics is not None:
                # Topics are matched as quoted substrings of the serialized topics column;
                # each topic adds another WHERE clause, so all of them must be present.
                for topic in topics:
                    stmt = stmt.where(func.cast(table.c.topics, String).like(f'%"{topic}"%'))
            if search_content is not None:
                stmt = stmt.where(table.c.memory.ilike(f"%{search_content}%"))

            # Get total count after applying filtering
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Sorting
            stmt = apply_sorting(stmt, table, sort_by, sort_order)
            # Paginating
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = (await sess.execute(stmt)).fetchall()
            if not result:
                # A page past the end is empty, but the filtered total is still meaningful
                return [] if deserialize else ([], total_count)

            memories_raw = [dict(record._mapping) for record in result]

            if not deserialize:
                return memories_raw, total_count

            return [UserMemory.from_dict(record) for record in memories_raw]

    except Exception as e:
        log_error(f"Error reading from memory table: {e}")
        raise e
|
|
1252
|
+
|
|
1253
|
+
async def get_user_memory_stats(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    user_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """Get per-user memory statistics.

    Rows are grouped by user_id. Rows with a NULL user_id are excluded unless a specific
    user_id is requested. Groups are ordered by their most recent ``updated_at``, newest first.

    Args:
        limit (Optional[int]): The maximum number of user stats to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.
        user_id (Optional[str]): User ID for filtering.

    Returns:
        Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and the
        total number of users matching the filter (reported even when the requested page is empty).

    Example:
        (
            [
                {
                    "user_id": "123",
                    "total_memories": 10,
                    "last_memory_updated_at": 1714560000,
                },
            ],
            total_count: 1,
        )
    """
    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            return [], 0

        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(
                table.c.user_id,
                func.count(table.c.memory_id).label("total_memories"),
                func.max(table.c.updated_at).label("last_memory_updated_at"),
            )

            if user_id is not None:
                stmt = stmt.where(table.c.user_id == user_id)
            else:
                stmt = stmt.where(table.c.user_id.is_not(None))
            stmt = stmt.group_by(table.c.user_id)
            stmt = stmt.order_by(func.max(table.c.updated_at).desc())

            # Count the grouped rows (i.e. distinct users) before paginating
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Pagination
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = (await sess.execute(stmt)).fetchall()

            # An empty page yields an empty list but keeps the real total_count
            return [
                {
                    "user_id": record.user_id,  # type: ignore
                    "total_memories": record.total_memories,
                    "last_memory_updated_at": record.last_memory_updated_at,
                }
                for record in result
            ], total_count

    except Exception as e:
        log_error(f"Error getting user memory stats: {e}")
        raise e
|
|
1325
|
+
|
|
1326
|
+
async def upsert_user_memory(
    self, memory: UserMemory, deserialize: Optional[bool] = True
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Insert a user memory, or update the existing row with the same memory_id.

    A missing ``memory.memory_id`` is filled in (on the passed object) with a new UUID4.
    A new row gets ``updated_at`` equal to ``memory.created_at``. On conflict, the mutable
    columns are overwritten, ``updated_at`` becomes the current time and the stored
    ``created_at`` is kept.

    Args:
        memory (UserMemory): The user memory to upsert.
        deserialize (Optional[bool]): Whether to return a UserMemory instead of a dict. Defaults to True.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary
            - None when the memories table is unavailable or no row is returned

    Raises:
        Exception: If an error occurs during upsert.
    """
    try:
        memories_table = await self._get_table(table_type="memories")
        if memories_table is None:
            return None

        if memory.memory_id is None:
            memory.memory_id = str(uuid4())

        now = int(time.time())

        async with self.async_session_factory() as sess, sess.begin():
            insert_values = {
                "user_id": memory.user_id,
                "agent_id": memory.agent_id,
                "team_id": memory.team_id,
                "memory_id": memory.memory_id,
                "memory": memory.memory,
                "topics": memory.topics,
                "input": memory.input,
                "feedback": memory.feedback,
                "created_at": memory.created_at,
                "updated_at": memory.created_at,
            }
            conflict_updates = {
                "memory": memory.memory,
                "topics": memory.topics,
                "input": memory.input,
                "agent_id": memory.agent_id,
                "team_id": memory.team_id,
                "feedback": memory.feedback,
                "updated_at": now,
                # Assigning the column to itself keeps the original creation time on update
                "created_at": memories_table.c.created_at,
            }

            upsert_stmt = (
                sqlite.insert(memories_table)
                .values(**insert_values)
                .on_conflict_do_update(  # type: ignore
                    index_elements=["memory_id"],
                    set_=conflict_updates,
                )
                .returning(memories_table)
            )

            stored_row = (await sess.execute(upsert_stmt)).fetchone()
            if stored_row is None:
                return None

            stored = dict(stored_row._mapping)
            if stored and deserialize:
                return UserMemory.from_dict(stored)
            return stored

    except Exception as e:
        log_error(f"Error upserting user memory: {e}")
        raise e
|
|
1397
|
+
|
|
1398
|
+
async def upsert_memories(
    self,
    memories: List[UserMemory],
    deserialize: Optional[bool] = True,
    preserve_updated_at: bool = False,
) -> List[Union[UserMemory, Dict[str, Any]]]:
    """
    Bulk upsert multiple user memories for improved performance on large datasets.

    Memories without a memory_id get a new UUID4 (assigned on the passed objects). On conflict,
    the mutable columns and ``updated_at`` are overwritten, and the stored ``created_at`` is kept.
    If the bulk statement fails, each memory is upserted individually via ``upsert_user_memory``.

    Args:
        memories (List[UserMemory]): List of memories to upsert. ``None`` entries are ignored.
        deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
        preserve_updated_at (bool): Keep each memory's own ``updated_at`` when it is set,
            instead of stamping the current time. Memories whose ``updated_at`` is None
            still get the current time. Defaults to False.

    Returns:
        List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories.

    Raises:
        Exception: If an individual upsert fails while falling back from the bulk path.
    """
    # Drop None entries once, so the bulk path and the fallback path see the same input
    memories = [memory for memory in memories if memory is not None] if memories else []
    if not memories:
        return []

    async def _upsert_individually() -> List[Union[UserMemory, Dict[str, Any]]]:
        # Fallback: one upsert per memory, keeping only successful results
        upserted: List[Union[UserMemory, Dict[str, Any]]] = []
        for memory in memories:
            upserted_memory = await self.upsert_user_memory(memory, deserialize=deserialize)
            if upserted_memory is not None:
                upserted.append(upserted_memory)
        return upserted

    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            log_info("Memories table not available, falling back to individual upserts")
            return await _upsert_individually()

        # Prepare bulk data
        bulk_data = []
        current_time = int(time.time())

        for memory in memories:
            if memory.memory_id is None:
                memory.memory_id = str(uuid4())

            # Use preserved updated_at if flag is set and value exists, otherwise use current time
            if preserve_updated_at and memory.updated_at is not None:
                updated_at = memory.updated_at
            else:
                updated_at = current_time

            bulk_data.append(
                {
                    "user_id": memory.user_id,
                    "agent_id": memory.agent_id,
                    "team_id": memory.team_id,
                    "memory_id": memory.memory_id,
                    "memory": memory.memory,
                    "topics": memory.topics,
                    "input": memory.input,
                    "feedback": memory.feedback,
                    "created_at": memory.created_at,
                    "updated_at": updated_at,
                }
            )

        results: List[Union[UserMemory, Dict[str, Any]]] = []

        async with self.async_session_factory() as sess, sess.begin():
            # Bulk upsert memories using SQLite ON CONFLICT DO UPDATE
            stmt = sqlite.insert(table)
            stmt = stmt.on_conflict_do_update(
                index_elements=["memory_id"],
                set_=dict(
                    memory=stmt.excluded.memory,
                    topics=stmt.excluded.topics,
                    input=stmt.excluded.input,
                    agent_id=stmt.excluded.agent_id,
                    team_id=stmt.excluded.team_id,
                    feedback=stmt.excluded.feedback,
                    updated_at=stmt.excluded.updated_at,
                    # Preserve created_at on update
                    created_at=table.c.created_at,
                ),
            )
            await sess.execute(stmt, bulk_data)

            # Re-read the affected rows, so callers get the stored values
            memory_ids = [memory.memory_id for memory in memories if memory.memory_id]
            select_stmt = select(table).where(table.c.memory_id.in_(memory_ids))
            result = (await sess.execute(select_stmt)).fetchall()

            for row in result:
                memory_dict = dict(row._mapping)
                if deserialize:
                    results.append(UserMemory.from_dict(memory_dict))
                else:
                    results.append(memory_dict)

        return results

    except Exception as e:
        log_error(f"Exception during bulk memory upsert, falling back to individual upserts: {e}")
        return await _upsert_individually()
|
|
1503
|
+
|
|
1504
|
+
async def clear_memories(self) -> None:
    """Delete all memories from the database.

    Does nothing if the memories table is not available.

    Raises:
        Exception: If an error occurs during deletion (logged as a warning, then re-raised).
    """
    try:
        table = await self._get_table(table_type="memories")
        if table is None:
            return

        async with self.async_session_factory() as sess, sess.begin():
            # Unconditional DELETE: removes every row in the memories table
            await sess.execute(table.delete())

    except Exception as e:
        # log_warning is available at module scope (also used by delete_eval_run),
        # so no function-local import is needed here.
        log_warning(f"Exception deleting all memories: {e}")
        raise e
|
|
1523
|
+
|
|
1524
|
+
# -- Metrics methods --
|
|
1525
|
+
|
|
1526
|
+
async def _get_all_sessions_for_metrics_calculation(
|
|
1527
|
+
self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None
|
|
1528
|
+
) -> List[Dict[str, Any]]:
|
|
1529
|
+
"""
|
|
1530
|
+
Get all sessions of all types (agent, team, workflow) as raw dictionaries.
|
|
1531
|
+
|
|
1532
|
+
Args:
|
|
1533
|
+
start_timestamp (Optional[int]): The start timestamp to filter by. Defaults to None.
|
|
1534
|
+
end_timestamp (Optional[int]): The end timestamp to filter by. Defaults to None.
|
|
1535
|
+
|
|
1536
|
+
Returns:
|
|
1537
|
+
List[Dict[str, Any]]: List of session dictionaries with session_type field.
|
|
1538
|
+
|
|
1539
|
+
Raises:
|
|
1540
|
+
Exception: If an error occurs during retrieval.
|
|
1541
|
+
"""
|
|
1542
|
+
try:
|
|
1543
|
+
table = await self._get_table(table_type="sessions", create_table_if_not_found=True)
|
|
1544
|
+
if table is None:
|
|
1545
|
+
return []
|
|
1546
|
+
|
|
1547
|
+
stmt = select(
|
|
1548
|
+
table.c.user_id,
|
|
1549
|
+
table.c.session_data,
|
|
1550
|
+
table.c.runs,
|
|
1551
|
+
table.c.created_at,
|
|
1552
|
+
table.c.session_type,
|
|
1553
|
+
)
|
|
1554
|
+
|
|
1555
|
+
if start_timestamp is not None:
|
|
1556
|
+
stmt = stmt.where(table.c.created_at >= start_timestamp)
|
|
1557
|
+
if end_timestamp is not None:
|
|
1558
|
+
stmt = stmt.where(table.c.created_at <= end_timestamp)
|
|
1559
|
+
|
|
1560
|
+
async with self.async_session_factory() as sess:
|
|
1561
|
+
result = (await sess.execute(stmt)).fetchall()
|
|
1562
|
+
return [dict(record._mapping) for record in result]
|
|
1563
|
+
|
|
1564
|
+
except Exception as e:
|
|
1565
|
+
log_error(f"Error reading from sessions table: {e}")
|
|
1566
|
+
raise e
|
|
1567
|
+
|
|
1568
|
+
async def _get_metrics_calculation_starting_date(self, table: Table) -> Optional[date]:
    """Return the first date that still needs metrics calculated.

    Resolution order:
      1. If a metrics row exists, look at the one with the latest ``date``: if it is
         completed, the day after it; otherwise that same day.
      2. Otherwise, the UTC date of the earliest session (by ``created_at``).
      3. Otherwise (no metrics and no sessions), None.

    Args:
        table (Table): The metrics table to inspect.

    Returns:
        Optional[date]: The starting date for which metrics calculation is needed.
    """
    async with self.async_session_factory() as sess:
        latest_query = select(table).order_by(table.c.date.desc()).limit(1)
        latest_record = (await sess.execute(latest_query)).fetchone()

        # Case 1: resume from the most recent metrics row
        if latest_record is not None:
            latest_date = latest_record._mapping["date"]
            return latest_date + timedelta(days=1) if latest_record.completed else latest_date

        # Case 2: no metrics yet, so start at the oldest session
        oldest_sessions, _ = await self.get_sessions(
            sort_by="created_at", sort_order="asc", limit=1, deserialize=False
        )
        oldest_created_at = oldest_sessions[0]["created_at"] if oldest_sessions else None  # type: ignore

        # Case 3: nothing recorded at all
        if not oldest_created_at:
            return None

        return datetime.fromtimestamp(oldest_created_at, tz=timezone.utc).date()
|
|
1601
|
+
|
|
1602
|
+
async def calculate_metrics(self) -> Optional[list[dict]]:
    """Compute and persist daily metrics for every date without a completed metrics record.

    Sessions created between 00:00 UTC of the first pending day and 00:00 UTC of the day
    after the last pending day (both ends inclusive) are loaded and bucketed per date. Only
    dates with at least one non-empty session bucket produce a metrics record. All records
    are bulk-upserted in a single transaction.

    Returns:
        Optional[list[dict]]: The upserted metrics, an empty list if no date had sessions,
        or None when there is no metrics table or nothing to calculate.

    Raises:
        Exception: If an error occurs during metrics calculation.
    """
    try:
        metrics_table = await self._get_table(table_type="metrics")
        if metrics_table is None:
            return None

        first_pending_day = await self._get_metrics_calculation_starting_date(metrics_table)
        if first_pending_day is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        pending_days = get_dates_to_calculate_metrics_for(first_pending_day)
        if not pending_days:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        midnight = datetime.min.time()
        window_start = int(datetime.combine(pending_days[0], midnight).replace(tzinfo=timezone.utc).timestamp())
        window_end = int(
            datetime.combine(pending_days[-1] + timedelta(days=1), midnight).replace(tzinfo=timezone.utc).timestamp()
        )

        raw_sessions = await self._get_all_sessions_for_metrics_calculation(
            start_timestamp=window_start, end_timestamp=window_end
        )
        sessions_by_date = fetch_all_sessions_data(
            sessions=raw_sessions,
            dates_to_process=pending_days,
            start_timestamp=window_start,
        )
        if not sessions_by_date:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        metrics_records = []
        for day in pending_days:
            day_sessions = sessions_by_date.get(day.isoformat(), {})
            has_sessions = any(len(bucket) > 0 for bucket in day_sessions.values())
            if has_sessions:
                metrics_records.append(calculate_date_metrics(day, day_sessions))

        upserted = []
        if metrics_records:
            async with self.async_session_factory() as sess, sess.begin():
                upserted = await abulk_upsert_metrics(
                    session=sess, table=metrics_table, metrics_records=metrics_records
                )

        log_debug("Updated metrics calculations")

        return upserted

    except Exception as e:
        log_error(f"Error refreshing metrics: {e}")
        raise e
|
|
1672
|
+
|
|
1673
|
+
async def get_metrics(
    self,
    starting_date: Optional[date] = None,
    ending_date: Optional[date] = None,
) -> Tuple[List[dict], Optional[int]]:
    """Fetch metrics rows whose ``date`` falls inside an optional, inclusive range.

    Args:
        starting_date (Optional[date]): Earliest date to include (ignored when falsy).
        ending_date (Optional[date]): Latest date to include (ignored when falsy).

    Returns:
        Tuple[List[dict], Optional[int]]: The matching rows as dictionaries, plus the maximum
        ``updated_at`` over the whole metrics table (not just the selected range). Returns
        ``([], None)`` when the table is missing or no row matches.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        metrics_table = await self._get_table(table_type="metrics")
        if metrics_table is None:
            return [], None

        async with self.async_session_factory() as sess, sess.begin():
            range_query = select(metrics_table)
            if starting_date:
                range_query = range_query.where(metrics_table.c.date >= starting_date)
            if ending_date:
                range_query = range_query.where(metrics_table.c.date <= ending_date)

            rows = (await sess.execute(range_query)).fetchall()
            if not rows:
                return [], None

            newest_update_query = select(func.max(metrics_table.c.updated_at))
            newest_update = (await sess.execute(newest_update_query)).scalar()

            metrics = [dict(row._mapping) for row in rows]
            return metrics, newest_update

    except Exception as e:
        log_error(f"Error getting metrics: {e}")
        raise e
|
|
1714
|
+
|
|
1715
|
+
# -- Knowledge methods --
|
|
1716
|
+
|
|
1717
|
+
async def delete_knowledge_content(self, id: str):
    """Remove the knowledge row with the given ID.

    Silently does nothing if the knowledge table is unavailable or the ID does not exist.

    Args:
        id (str): The ID of the knowledge row to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    knowledge_table = await self._get_table(table_type="knowledge")
    if knowledge_table is not None:
        try:
            async with self.async_session_factory() as sess, sess.begin():
                delete_stmt = knowledge_table.delete().where(knowledge_table.c.id == id)
                await sess.execute(delete_stmt)
        except Exception as e:
            log_error(f"Error deleting knowledge content: {e}")
            raise e
|
|
1738
|
+
|
|
1739
|
+
async def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Look up a single knowledge row by ID.

    Args:
        id (str): The ID of the knowledge row to get.

    Returns:
        Optional[KnowledgeRow]: The validated row, or None if the table is unavailable
        or no row has this ID.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    knowledge_table = await self._get_table(table_type="knowledge")
    if knowledge_table is None:
        return None

    try:
        async with self.async_session_factory() as sess, sess.begin():
            lookup = select(knowledge_table).where(knowledge_table.c.id == id)
            row = (await sess.execute(lookup)).fetchone()
            return None if row is None else KnowledgeRow.model_validate(row._mapping)

    except Exception as e:
        log_error(f"Error getting knowledge content: {e}")
        raise e
|
|
1767
|
+
|
|
1768
|
+
async def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """Get knowledge contents from the database, optionally sorted and paginated.

    Args:
        limit (Optional[int]): The maximum number of knowledge contents to return.
        page (Optional[int]): The 1-based page number. Only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by. No ordering is applied when None.
        sort_order (Optional[str]): "asc" for ascending; anything else is treated as descending.

    Returns:
        Tuple[List[KnowledgeRow], int]: The knowledge contents on the requested page and the
        total number of rows in the table.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = await self._get_table(table_type="knowledge")
    if table is None:
        return [], 0

    try:
        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(table)

            # The total does not depend on ordering, so count the unsorted query
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Sort with a real ASC/DESC clause via the shared helper (as get_user_memories does).
            # The previous "column * (1 | -1)" trick only ordered numeric columns: SQLite coerces
            # non-numeric text (e.g. `name`) to 0, so text sorts were silently ignored.
            # NOTE(review): assumes apply_sorting treats non-"asc" orders as descending,
            # matching the previous default; confirm against its implementation.
            if sort_by is not None:
                stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Apply pagination after count
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = (await sess.execute(stmt)).fetchall()
            return [KnowledgeRow.model_validate(record._mapping) for record in result], total_count

    except Exception as e:
        log_error(f"Error getting knowledge contents: {e}")
        raise e
|
|
1817
|
+
|
|
1818
|
+
async def upsert_knowledge_content(self, knowledge_row: KnowledgeRow):
    """Insert a knowledge row, or update the existing row with the same ``id``.

    The whole ``model_dump()`` of the row is used for the insert. On conflict, only the
    listed fields whose value is not None are overwritten, so None values never clear
    stored data. The knowledge table is created if it does not exist.

    Args:
        knowledge_row (KnowledgeRow): The knowledge row to upsert.

    Returns:
        Optional[KnowledgeRow]: The same knowledge_row that was passed in, or None if the
        table is unavailable.

    Raises:
        Exception: If an error occurs during upsert.
    """
    try:
        knowledge_table = await self._get_table(table_type="knowledge", create_table_if_not_found=True)
        if knowledge_table is None:
            return None

        async with self.async_session_factory() as sess, sess.begin():
            candidate_updates = {
                "name": knowledge_row.name,
                "description": knowledge_row.description,
                "metadata": knowledge_row.metadata,
                "type": knowledge_row.type,
                "size": knowledge_row.size,
                "linked_to": knowledge_row.linked_to,
                "access_count": knowledge_row.access_count,
                "status": knowledge_row.status,
                "status_message": knowledge_row.status_message,
                "created_at": knowledge_row.created_at,
                "updated_at": knowledge_row.updated_at,
                "external_id": knowledge_row.external_id,
            }
            # Skip None fields so an update only touches the values that were provided
            changed_fields = {column: value for column, value in candidate_updates.items() if value is not None}

            insert_stmt = sqlite.insert(knowledge_table).values(knowledge_row.model_dump())
            upsert_stmt = insert_stmt.on_conflict_do_update(index_elements=["id"], set_=changed_fields)
            await sess.execute(upsert_stmt)

        return knowledge_row

    except Exception as e:
        log_error(f"Error upserting knowledge content: {e}")
        raise e
|
|
1865
|
+
|
|
1866
|
+
# -- Eval methods --
|
|
1867
|
+
|
|
1868
|
+
async def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Create an EvalRunRecord in the database.

    The row is built from ``eval_run.model_dump()``. ``created_at`` / ``updated_at`` values
    from the record are kept when set. Missing or None values are stamped with the current
    time. The evals table is created if it does not exist.

    Args:
        eval_run (EvalRunRecord): The eval run to create.

    Returns:
        Optional[EvalRunRecord]: The eval run that was passed in, or None if the table is unavailable.

    Raises:
        Exception: If an error occurs during creation.
    """
    try:
        table = await self._get_table(table_type="evals", create_table_if_not_found=True)
        if table is None:
            return None

        async with self.async_session_factory() as sess, sess.begin():
            current_time = int(time.time())
            values = eval_run.model_dump()
            # Previously the dump was spread *after* the default timestamps, so a record whose
            # model_dump() contains created_at/updated_at=None overwrote them and stored NULL.
            # Only fill the timestamps that are absent or None.
            for timestamp_field in ("created_at", "updated_at"):
                if values.get(timestamp_field) is None:
                    values[timestamp_field] = current_time

            stmt = sqlite.insert(table).values(values)
            await sess.execute(stmt)

            log_debug(f"Created eval run with id '{eval_run.run_id}'")

            return eval_run

    except Exception as e:
        log_error(f"Error creating eval run: {e}")
        raise e
|
|
1903
|
+
|
|
1904
|
+
async def delete_eval_run(self, eval_run_id: str) -> None:
    """Delete the eval run with the given run_id.

    Logs a warning when no row matched, and a debug message otherwise. Does nothing if the
    evals table is unavailable.

    Args:
        eval_run_id (str): The ID of the eval run to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    try:
        evals_table = await self._get_table(table_type="evals")
        if evals_table is None:
            return

        async with self.async_session_factory() as sess, sess.begin():
            outcome = await sess.execute(evals_table.delete().where(evals_table.c.run_id == eval_run_id))
            if outcome.rowcount != 0:  # type: ignore
                log_debug(f"Deleted eval run with ID: {eval_run_id}")
            else:
                log_warning(f"No eval run found with ID: {eval_run_id}")

    except Exception as e:
        log_error(f"Error deleting eval run {eval_run_id}: {e}")
        raise e
|
|
1926
|
+
|
|
1927
|
+
async def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete every eval run whose run_id is in ``eval_run_ids``.

    Logs the number of deleted rows at debug level. Does nothing if the evals table is
    unavailable.

    Args:
        eval_run_ids (List[str]): List of eval run IDs to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    try:
        evals_table = await self._get_table(table_type="evals")
        if evals_table is None:
            return

        async with self.async_session_factory() as sess, sess.begin():
            bulk_delete = evals_table.delete().where(evals_table.c.run_id.in_(eval_run_ids))
            outcome = await sess.execute(bulk_delete)
            deleted = outcome.rowcount  # type: ignore
            if deleted != 0:
                log_debug(f"Deleted {deleted} eval runs")
            else:
                log_debug(f"No eval runs found with IDs: {eval_run_ids}")

    except Exception as e:
        log_error(f"Error deleting eval runs {eval_run_ids}: {e}")
        raise e
|
|
1949
|
+
|
|
1950
|
+
async def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Fetch a single eval run by its run ID.

    Args:
        eval_run_id (str): Run ID of the eval run to look up.
        deserialize (Optional[bool]): When truthy, validate the row into an EvalRunRecord;
            otherwise return the raw row mapping. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: EvalRun dictionary
            - None when the evals table or the row does not exist

    Raises:
        Exception: Re-raised after logging if the query fails.
    """
    try:
        evals_table = await self._get_table(table_type="evals")
        if evals_table is None:
            return None

        query = select(evals_table).where(evals_table.c.run_id == eval_run_id)
        async with self.async_session_factory() as sess, sess.begin():
            row = (await sess.execute(query)).fetchone()
            if row is None:
                return None

            raw = dict(row._mapping)
            if raw and deserialize:
                return EvalRunRecord.model_validate(raw)
            return raw

    except Exception as e:
        log_error(f"Exception getting eval run {eval_run_id}: {e}")
        raise
|
|
1987
|
+
|
|
1988
|
+
async def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Get eval runs from the database, with optional filtering, sorting and pagination.

    Args:
        limit (Optional[int]): The maximum number of eval runs to return.
        page (Optional[int]): The page number (1-indexed); only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by. Defaults to ``created_at`` descending.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        workflow_id (Optional[str]): The ID of the workflow to filter by.
        model_id (Optional[str]): The ID of the model to filter by.
        filter_type (Optional[EvalFilterType]): Filter by component type (agent, team, workflow).
        eval_type (Optional[List[EvalType]]): The type(s) of eval to filter by.
        deserialize (Optional[bool]): Whether to deserialize the eval runs into EvalRunRecord
            objects. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of EvalRunRecord objects
            - When deserialize=False: List of EvalRun dictionaries and the total number of
              rows matching the filters (independent of limit/page)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    try:
        table = await self._get_table(table_type="evals")
        if table is None:
            return [] if deserialize else ([], 0)

        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)
            if workflow_id is not None:
                stmt = stmt.where(table.c.workflow_id == workflow_id)
            if model_id is not None:
                stmt = stmt.where(table.c.model_id == model_id)
            if eval_type is not None and len(eval_type) > 0:
                stmt = stmt.where(table.c.eval_type.in_(eval_type))
            if filter_type is not None:
                if filter_type == EvalFilterType.AGENT:
                    stmt = stmt.where(table.c.agent_id.is_not(None))
                elif filter_type == EvalFilterType.TEAM:
                    stmt = stmt.where(table.c.team_id.is_not(None))
                elif filter_type == EvalFilterType.WORKFLOW:
                    stmt = stmt.where(table.c.workflow_id.is_not(None))

            # Total count is taken from the filtered query, before sorting and pagination
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Sorting - apply default sort by created_at desc if no sort parameters provided
            if sort_by is None:
                stmt = stmt.order_by(table.c.created_at.desc())
            else:
                stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = (await sess.execute(stmt)).fetchall()

        # No early return on an empty page: total_count must still be reported so that
        # callers paginating past the last page see the real number of matching rows.
        eval_runs_raw = [dict(row._mapping) for row in result]
        if not deserialize:
            return eval_runs_raw, total_count

        return [EvalRunRecord.model_validate(row) for row in eval_runs_raw]

    except Exception as e:
        log_error(f"Exception getting eval runs: {e}")
        raise e
|
|
2081
|
+
|
|
2082
|
+
async def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Update the name of an eval run and return the updated eval run.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to deserialize the returned eval run into an
            EvalRunRecord. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: EvalRun dictionary
            - None when the evals table or the eval run does not exist

    Raises:
        Exception: If an error occurs during update.
    """
    try:
        table = await self._get_table(table_type="evals")
        if table is None:
            return None

        stmt = table.update().where(table.c.run_id == eval_run_id).values(name=name, updated_at=int(time.time()))
        async with self.async_session_factory() as sess, sess.begin():
            result = await sess.execute(stmt)
            updated_rows = result.rowcount  # type: ignore

        if updated_rows == 0:
            log_debug(f"No eval run found with id '{eval_run_id}' to rename")
        else:
            log_debug(f"Renamed eval run with id '{eval_run_id}' to '{name}'")

        # Read back only after the UPDATE transaction above has committed: get_eval_run opens
        # its own session, which would otherwise not see the uncommitted rename. get_eval_run
        # already returns an EvalRunRecord or a raw dict according to `deserialize`.
        return await self.get_eval_run(eval_run_id=eval_run_id, deserialize=deserialize)

    except Exception as e:
        log_error(f"Error renaming eval run {eval_run_id}: {e}")
        raise e
|
|
2123
|
+
|
|
2124
|
+
# -- Migrations --
|
|
2125
|
+
|
|
2126
|
+
async def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
    """Migrate all content in the given v1 table to the right v2 table.

    Reads every row of ``v1_table_name`` (in ``v1_db_schema``), parses the rows according to
    ``v1_table_type`` and upserts the parsed objects one at a time.

    Args:
        v1_db_schema (str): Schema containing the v1 table.
        v1_table_name (str): Name of the v1 table to read from.
        v1_table_type (str): One of "agent_sessions", "team_sessions", "workflow_sessions"
            or "memories".

    Raises:
        ValueError: If ``v1_table_type`` is not supported (only checked when the v1 table
            has content to migrate).
    """
    from agno.db.migrations.v1_to_v2 import (
        get_all_table_content,
        parse_agent_sessions,
        parse_memories,
        parse_team_sessions,
        parse_workflow_sessions,
    )

    # Get all content from the old table
    old_content: list[dict[str, Any]] = get_all_table_content(
        db=self,
        db_schema=v1_db_schema,
        table_name=v1_table_name,
    )
    if not old_content:
        log_info(f"No content to migrate from table {v1_table_name}")
        return

    # Session table types: (parser, label used in the log message)
    session_parsers = {
        "agent_sessions": (parse_agent_sessions, "Agent"),
        "team_sessions": (parse_team_sessions, "Team"),
        "workflow_sessions": (parse_workflow_sessions, "Workflow"),
    }

    if v1_table_type in session_parsers:
        parse_sessions, label = session_parsers[v1_table_type]
        sessions: Sequence[Union[AgentSession, TeamSession, WorkflowSession]] = parse_sessions(old_content)
        for session in sessions:
            await self.upsert_session(session)
        log_info(f"Migrated {len(sessions)} {label} sessions to table: {self.session_table_name}")

    elif v1_table_type == "memories":
        memories: List[UserMemory] = parse_memories(old_content)
        for memory in memories:
            await self.upsert_user_memory(memory)
        # Log the configured table name, consistent with the session branch above
        log_info(f"Migrated {len(memories)} memories to table: {self.memory_table_name}")

    else:
        raise ValueError(f"Invalid table type: {v1_table_type}")
|
|
2181
|
+
|
|
2182
|
+
# -- Culture methods --
|
|
2183
|
+
|
|
2184
|
+
async def clear_cultural_knowledge(self) -> None:
    """Remove every row from the culture table.

    Best-effort: nothing happens when the culture table cannot be resolved, and any
    exception raised while deleting is logged rather than propagated to the caller.
    """
    try:
        culture_table = await self._get_table(table_type="culture")
        if culture_table is not None:
            async with self.async_session_factory() as sess, sess.begin():
                await sess.execute(culture_table.delete())
    except Exception as e:
        log_error(f"Exception deleting all cultural artifacts: {e}")
|
|
2200
|
+
|
|
2201
|
+
async def delete_cultural_knowledge(self, id: str) -> None:
    """Remove the cultural knowledge entry with the given ID.

    Best-effort: nothing happens when the culture table cannot be resolved, and any
    exception raised while deleting is logged rather than propagated to the caller.

    Args:
        id (str): The ID of the cultural artifact to delete.
    """
    try:
        culture_table = await self._get_table(table_type="culture")
        if culture_table is None:
            return

        delete_stmt = culture_table.delete().where(culture_table.c.id == id)
        async with self.async_session_factory() as sess, sess.begin():
            outcome = await sess.execute(delete_stmt)
            if outcome.rowcount > 0:  # type: ignore
                log_debug(f"Successfully deleted cultural artifact id: {id}")
            else:
                log_debug(f"No cultural artifact found with id: {id}")

    except Exception as e:
        log_error(f"Error deleting cultural artifact: {e}")
|
|
2227
|
+
|
|
2228
|
+
async def get_cultural_knowledge(
    self, id: str, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Fetch a single cultural knowledge entry by ID.

    Args:
        id (str): The ID of the cultural artifact to get.
        deserialize (Optional[bool]): When truthy, convert the row with
            deserialize_cultural_knowledge_from_db; otherwise return the raw row mapping.
            Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The entry (or raw dict), or None
        when the table or row does not exist, or when reading fails (the error is logged,
        not raised).
    """
    try:
        culture_table = await self._get_table(table_type="culture")
        if culture_table is None:
            return None

        query = select(culture_table).where(culture_table.c.id == id)
        async with self.async_session_factory() as sess, sess.begin():
            row = (await sess.execute(query)).fetchone()
            if row is None:
                return None

            db_row = dict(row._mapping)
            if db_row and deserialize:
                return deserialize_cultural_knowledge_from_db(db_row)
            return db_row

    except Exception as e:
        log_error(f"Exception reading from cultural artifacts table: {e}")
        return None
|
|
2263
|
+
|
|
2264
|
+
async def get_all_cultural_knowledge(
    self,
    name: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
    """Get cultural knowledge entries from the database, with optional filtering and pagination.

    Args:
        name (Optional[str]): The name of the cultural artifact to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        limit (Optional[int]): The maximum number of cultural artifacts to return.
        page (Optional[int]): The page number (1-indexed); only applied when ``limit`` is set.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to deserialize the rows into CulturalKnowledge
            objects. Defaults to True.

    Returns:
        Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of CulturalKnowledge objects
            - When deserialize=False: List of row dictionaries and the total number of rows
              matching the filters (independent of limit/page)

    Errors are logged and an empty result (``[]`` or ``([], 0)``) is returned instead of raising.
    """
    try:
        table = await self._get_table(table_type="culture")
        if table is None:
            return [] if deserialize else ([], 0)

        async with self.async_session_factory() as sess, sess.begin():
            stmt = select(table)

            # Filtering
            if name is not None:
                stmt = stmt.where(table.c.name == name)
            if agent_id is not None:
                stmt = stmt.where(table.c.agent_id == agent_id)
            if team_id is not None:
                stmt = stmt.where(table.c.team_id == team_id)

            # Total count is taken from the filtered query, before sorting and pagination
            count_stmt = select(func.count()).select_from(stmt.alias())
            total_count = (await sess.execute(count_stmt)).scalar() or 0

            # Sorting
            stmt = apply_sorting(stmt, table, sort_by, sort_order)

            # Paginating
            if limit is not None:
                stmt = stmt.limit(limit)
                if page is not None:
                    stmt = stmt.offset((page - 1) * limit)

            result = (await sess.execute(stmt)).fetchall()

        # No early return on an empty page: total_count must still be reported so that
        # callers paginating past the last page see the real number of matching rows.
        db_rows = [dict(record._mapping) for record in result]
        if not deserialize:
            return db_rows, total_count

        return [deserialize_cultural_knowledge_from_db(row) for row in db_rows]

    except Exception as e:
        log_error(f"Error reading from cultural artifacts table: {e}")
        return [] if deserialize else ([], 0)
|
|
2337
|
+
|
|
2338
|
+
async def upsert_cultural_knowledge(
    self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Insert a cultural knowledge entry, or update the existing row with the same ID.

    If ``cultural_knowledge.id`` is None, a new UUID4 string is assigned to it (the passed
    object is mutated) before writing. ``created_at`` is only written on insert; the other
    columns are written on both insert and update, with ``updated_at`` set to the current
    epoch seconds.

    Args:
        cultural_knowledge (CulturalKnowledge): The cultural artifact to upsert.
        deserialize (Optional[bool]): When truthy, convert the returned row with
            deserialize_cultural_knowledge_from_db; otherwise return the raw row mapping.
            Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
            - When deserialize=True: CulturalKnowledge object
            - When deserialize=False: row dictionary
            - None when the table could not be resolved or no row was returned

    Raises:
        Exception: Re-raised after logging if the upsert fails.
    """
    try:
        culture_table = await self._get_table(table_type="culture", create_table_if_not_found=True)
        if culture_table is None:
            return None

        if cultural_knowledge.id is None:
            cultural_knowledge.id = str(uuid4())

        # Serialize content, categories, and notes into a JSON string for DB storage (SQLite requires strings)
        content_json_str = serialize_cultural_knowledge_for_db(cultural_knowledge)

        # Columns written both on insert and on conflict-update
        shared_columns: Dict[str, Any] = {
            "name": cultural_knowledge.name,
            "summary": cultural_knowledge.summary,
            "content": content_json_str,
            "metadata": cultural_knowledge.metadata,
            "input": cultural_knowledge.input,
            "updated_at": int(time.time()),
            "agent_id": cultural_knowledge.agent_id,
            "team_id": cultural_knowledge.team_id,
        }

        async with self.async_session_factory() as sess, sess.begin():
            insert_stmt = sqlite.insert(culture_table).values(
                id=cultural_knowledge.id,
                created_at=cultural_knowledge.created_at,
                **shared_columns,
            )
            upsert_stmt = insert_stmt.on_conflict_do_update(  # type: ignore
                index_elements=["id"],
                set_=shared_columns,
            ).returning(culture_table)

            row = (await sess.execute(upsert_stmt)).fetchone()
            if row is None:
                return None

            db_row: Dict[str, Any] = dict(row._mapping)
            if db_row and deserialize:
                return deserialize_cultural_knowledge_from_db(db_row)
            return db_row

    except Exception as e:
        log_error(f"Error upserting cultural knowledge: {e}")
        raise
|
|
2408
|
+
|
|
2409
|
+
# --- Traces ---
|
|
2410
|
+
def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
    """Build a SELECT over the traces table plus ``total_spans`` and ``error_count`` columns.

    With a spans table, traces are LEFT OUTER JOINed to their spans on ``trace_id`` and
    grouped by ``trace_id``. ``total_spans`` counts joined spans and ``error_count`` counts
    spans whose ``status_code`` is "ERROR" (both coalesced to 0). Without a spans table,
    both columns are the literal 0.

    Args:
        table: The traces table.
        spans_table: The spans table (optional).

    Returns:
        SQLAlchemy select statement with total_spans and error_count calculated dynamically.
    """
    from sqlalchemy import case, literal

    if spans_table is None:
        # Fallback if spans table doesn't exist
        return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))

    is_error = case((spans_table.c.status_code == "ERROR", 1), else_=0)
    total_spans = func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans")
    error_count = func.coalesce(func.sum(is_error), 0).label("error_count")
    traces_with_spans = table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id)

    return select(table, total_spans, error_count).select_from(traces_with_spans).group_by(table.c.trace_id)
|
|
2438
|
+
|
|
2439
|
+
def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
    """Build a SQL CASE expression that returns the component level for a trace.

    A row counts as a "root" when its name contains ".run" or ".arun". Levels
    (higher = more important):
        - 3: root with a non-null workflow_id
        - 2: root with a non-null team_id
        - 1: root with a non-null agent_id
        - 0: anything else (child span or unknown)

    Args:
        workflow_id_col: SQL column/expression for workflow_id
        team_id_col: SQL column/expression for team_id
        agent_id_col: SQL column/expression for agent_id
        name_col: SQL column/expression for name

    Returns:
        SQLAlchemy CASE expression returning the component level as an integer.
    """
    from sqlalchemy import and_, case, or_

    looks_like_root = or_(name_col.contains(".run"), name_col.contains(".arun"))

    # Listed highest priority first: CASE yields the first matching branch.
    ranked_ids = (
        (workflow_id_col, 3),
        (team_id_col, 2),
        (agent_id_col, 1),
    )
    branches = [(and_(id_col.isnot(None), looks_like_root), level) for id_col, level in ranked_ids]

    return case(*branches, else_=0)
|
|
2471
|
+
|
|
2472
|
+
async def upsert_trace(self, trace: "Trace") -> None:
    """Create or update a single trace record in the database.

    Issues one INSERT ... ON CONFLICT(trace_id) DO UPDATE, so concurrent writers for the
    same trace are merged by the database in a single statement. On conflict:
        - start_time / end_time become the earliest start and latest end, and duration_ms
          is recomputed from them;
        - status is taken from the incoming trace;
        - name is replaced only if the incoming trace has a higher component level
          (workflow 3 > team 2 > agent 1 > child 0);
        - run/session/user/agent/team/workflow IDs keep their existing value when the
          incoming one is NULL.

    Errors are logged and swallowed so that tracing never breaks the caller.

    Args:
        trace: The Trace object to store (one per trace_id).
    """
    from sqlalchemy import case

    try:
        traces_table = await self._get_table(table_type="traces", create_table_if_not_found=True)
        if traces_table is None:
            return

        row_values = trace.to_dict()
        # Span aggregates are not stored on the trace row
        for aggregate_key in ("total_spans", "error_count"):
            row_values.pop(aggregate_key, None)

        async with self.async_session_factory() as sess, sess.begin():
            insert_stmt = sqlite.insert(traces_table).values(row_values)
            incoming = insert_stmt.excluded
            current = traces_table.c

            # Component levels used to decide which trace name wins
            incoming_level = self._get_trace_component_level_expr(
                incoming.workflow_id, incoming.team_id, incoming.agent_id, incoming.name
            )
            current_level = self._get_trace_component_level_expr(
                current.workflow_id, current.team_id, current.agent_id, current.name
            )

            # SQLite's multi-argument max()/min() are scalar functions. Timestamps are stored as
            # ISO strings, which order correctly as text.
            latest_end = func.max(current.end_time, incoming.end_time)
            earliest_start = func.min(current.start_time, incoming.start_time)

            on_conflict_set = {
                "end_time": latest_end,
                "start_time": earliest_start,
                # julianday() returns days; multiply by 86400000 to get milliseconds
                "duration_ms": (func.julianday(latest_end) - func.julianday(earliest_start)) * 86400000,
                "status": incoming.status,
                "name": case((incoming_level > current_level, incoming.name), else_=current.name),
            }
            # Preserve existing non-null context values using COALESCE
            for context_col in ("run_id", "session_id", "user_id", "agent_id", "team_id", "workflow_id"):
                on_conflict_set[context_col] = func.coalesce(
                    getattr(incoming, context_col), getattr(current, context_col)
                )

            upsert_stmt = insert_stmt.on_conflict_do_update(index_elements=["trace_id"], set_=on_conflict_set)
            await sess.execute(upsert_stmt)

    except Exception as e:
        log_error(f"Error creating trace: {e}")
        # Don't raise - tracing should not break the main application flow
|
|
2550
|
+
|
|
2551
|
+
async def get_trace(
    self,
    trace_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Get a single trace by trace_id, or by run_id as a fallback.

    Args:
        trace_id: The unique trace identifier.
        run_id: Filter by run ID (returns first match).

    Returns:
        Optional[Trace]: The trace if found, None otherwise. None is also returned when no
        filter is given, when the traces table is missing, or on error (which is logged).

    Note:
        If both filters are provided, trace_id takes precedence. Matches are ordered by
        start_time descending and only the first is returned.
    """
    try:
        from agno.tracing.schemas import Trace

        traces_table = await self._get_table(table_type="traces")
        if traces_table is None:
            return None

        # Get spans table for JOIN
        spans_table = await self._get_table(table_type="spans")

        if trace_id:
            condition = traces_table.c.trace_id == trace_id
        elif run_id:
            condition = traces_table.c.run_id == run_id
        else:
            log_debug("get_trace called without any filter parameters")
            return None

        query = (
            self._get_traces_base_query(traces_table, spans_table)
            .where(condition)
            .order_by(traces_table.c.start_time.desc())
            .limit(1)
        )
        async with self.async_session_factory() as sess:
            row = (await sess.execute(query)).fetchone()
            return Trace.from_dict(dict(row._mapping)) if row else None

    except Exception as e:
        log_error(f"Error getting trace: {e}")
        return None
|
|
2603
|
+
|
|
2604
|
+
async def get_traces(
    self,
    run_id: Optional[str] = None,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    status: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List, int]:
    """Get traces matching the provided filters with pagination.

    String filters are only applied when truthy. Results are ordered by start_time descending.

    Args:
        run_id: Filter by run ID.
        session_id: Filter by session ID.
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        status: Filter by status (OK, ERROR, UNSET).
        start_time: Filter traces starting at or after this datetime.
        end_time: Filter traces ending at or before this datetime.
        limit: Maximum number of traces to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Trace], int]: Tuple of (list of matching traces, total count). ``([], 0)``
        when the traces table is missing or on error (which is logged).
    """
    try:
        from agno.tracing.schemas import Trace

        traces_table = await self._get_table(table_type="traces")
        if traces_table is None:
            log_debug("Traces table not found")
            return [], 0

        # Get spans table for JOIN
        spans_table = await self._get_table(table_type="spans")
        cols = traces_table.c

        equality_filters = (
            (cols.run_id, run_id),
            (cols.session_id, session_id),
            (cols.user_id, user_id),
            (cols.agent_id, agent_id),
            (cols.team_id, team_id),
            (cols.workflow_id, workflow_id),
            (cols.status, status),
        )
        conditions = [column == value for column, value in equality_filters if value]
        # Timestamps are compared as ISO strings
        if start_time:
            conditions.append(cols.start_time >= start_time.isoformat())
        if end_time:
            conditions.append(cols.end_time <= end_time.isoformat())

        base_stmt = self._get_traces_base_query(traces_table, spans_table)
        if conditions:
            base_stmt = base_stmt.where(*conditions)

        async with self.async_session_factory() as sess:
            total_count = await sess.scalar(select(func.count()).select_from(base_stmt.alias())) or 0

            offset = (page - 1) * limit if page and limit else 0
            page_stmt = base_stmt.order_by(cols.start_time.desc()).limit(limit).offset(offset)
            rows = (await sess.execute(page_stmt)).fetchall()

            return [Trace.from_dict(dict(row._mapping)) for row in rows], total_count

    except Exception as e:
        log_error(f"Error getting traces: {e}")
        return [], 0
|
|
2690
|
+
|
|
2691
|
+
async def get_trace_stats(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List[Dict[str, Any]], int]:
    """Get trace statistics grouped by session.

    Traces are grouped by (session_id, user_id, agent_id, team_id, workflow_id).
    Rows whose session_id is NULL are excluded. Groups are ordered by their most
    recent ``created_at`` value, newest first.

    Args:
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        start_time: Only include traces whose created_at is at or after this datetime.
        end_time: Only include traces whose created_at is at or before this datetime.
        limit: Maximum number of sessions to return per page (None means no limit).
        page: Page number (1-indexed).

    Returns:
        tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
            Each dict contains: session_id, user_id, agent_id, team_id, workflow_id,
            total_traces, first_trace_at, last_trace_at. The two timestamps are
            datetime objects, or None when every created_at in the group is NULL.
            On any error the error is logged and ([], 0) is returned.
    """

    def _to_datetime(value: Any) -> Optional[datetime]:
        # MIN/MAX over a group whose created_at values are all NULL yield NULL,
        # and some drivers may already hand back datetime objects; pass both
        # through instead of letting one row abort the whole page.
        if value is None or isinstance(value, datetime):
            return value
        # created_at is compared against ISO strings below, so it is normally an
        # ISO-8601 string. datetime.fromisoformat() rejects a trailing "Z" before
        # Python 3.11, hence the normalisation to an explicit UTC offset.
        return datetime.fromisoformat(str(value).replace("Z", "+00:00"))

    try:
        table = await self._get_table(table_type="traces")
        if table is None:
            log_debug("Traces table not found")
            return [], 0

        async with self.async_session_factory() as sess:
            # Build base query grouped by session_id
            base_stmt = (
                select(
                    table.c.session_id,
                    table.c.user_id,
                    table.c.agent_id,
                    table.c.team_id,
                    table.c.workflow_id,
                    func.count(table.c.trace_id).label("total_traces"),
                    func.min(table.c.created_at).label("first_trace_at"),
                    func.max(table.c.created_at).label("last_trace_at"),
                )
                .where(table.c.session_id.isnot(None))  # Only sessions with session_id
                .group_by(
                    table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
                )
            )

            # Apply filters (these land in the WHERE clause, i.e. before grouping)
            if user_id:
                base_stmt = base_stmt.where(table.c.user_id == user_id)
            if workflow_id:
                base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
            if team_id:
                base_stmt = base_stmt.where(table.c.team_id == team_id)
            if agent_id:
                base_stmt = base_stmt.where(table.c.agent_id == agent_id)
            if start_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
            if end_time:
                # Convert datetime to ISO string for comparison
                base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())

            # Get total count of sessions (number of groups, not traces)
            count_stmt = select(func.count()).select_from(base_stmt.alias())
            total_count = await sess.scalar(count_stmt) or 0

            # Apply pagination and ordering
            offset = (page - 1) * limit if page and limit else 0
            paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)

            result = await sess.execute(paginated_stmt)
            results = result.fetchall()

            # Convert to list of dicts with datetime objects
            stats_list = [
                {
                    "session_id": row.session_id,
                    "user_id": row.user_id,
                    "agent_id": row.agent_id,
                    "team_id": row.team_id,
                    "workflow_id": row.workflow_id,
                    "total_traces": row.total_traces,
                    "first_trace_at": _to_datetime(row.first_trace_at),
                    "last_trace_at": _to_datetime(row.last_trace_at),
                }
                for row in results
            ]

            return stats_list, total_count

    except Exception as e:
        log_error(f"Error getting trace stats: {e}")
        return [], 0
|
|
2800
|
+
|
|
2801
|
+
# --- Spans ---
|
|
2802
|
+
async def create_span(self, span: "Span") -> None:
    """Insert one span as a new row of the spans table.

    The spans table is requested with ``create_table_if_not_found=True``. When no
    table is available the call is a no-op. Any exception is reported through
    ``log_error`` and swallowed; nothing is returned.

    Args:
        span: The Span object to store; its ``to_dict()`` output is the row inserted.
    """
    try:
        spans_table = await self._get_table(table_type="spans", create_table_if_not_found=True)
        if spans_table is not None:
            async with self.async_session_factory() as sess, sess.begin():
                row_values = span.to_dict()
                insert_stmt = sqlite.insert(spans_table).values(row_values)
                await sess.execute(insert_stmt)
    except Exception as e:
        log_error(f"Error creating span: {e}")
|
|
2819
|
+
|
|
2820
|
+
async def create_spans(self, spans: List) -> None:
|
|
2821
|
+
"""Create multiple spans in the database as a batch.
|
|
2822
|
+
|
|
2823
|
+
Args:
|
|
2824
|
+
spans: List of Span objects to store.
|
|
2825
|
+
"""
|
|
2826
|
+
if not spans:
|
|
2827
|
+
return
|
|
2828
|
+
|
|
2829
|
+
try:
|
|
2830
|
+
table = await self._get_table(table_type="spans", create_table_if_not_found=True)
|
|
2831
|
+
if table is None:
|
|
2832
|
+
return
|
|
2833
|
+
|
|
2834
|
+
async with self.async_session_factory() as sess, sess.begin():
|
|
2835
|
+
for span in spans:
|
|
2836
|
+
stmt = sqlite.insert(table).values(span.to_dict())
|
|
2837
|
+
await sess.execute(stmt)
|
|
2838
|
+
|
|
2839
|
+
except Exception as e:
|
|
2840
|
+
log_error(f"Error creating spans batch: {e}")
|
|
2841
|
+
|
|
2842
|
+
async def get_span(self, span_id: str):
    """Fetch the span whose ``span_id`` column equals the given identifier.

    Args:
        span_id: The unique span identifier.

    Returns:
        Optional[Span]: The first matching row converted with ``Span.from_dict``,
        or None when the spans table is unavailable, no row matches, or any
        exception occurs (the exception is reported through ``log_error``).
    """
    try:
        from agno.tracing.schemas import Span

        spans_table = await self._get_table(table_type="spans")
        if spans_table is None:
            return None

        async with self.async_session_factory() as sess:
            lookup = select(spans_table).where(spans_table.c.span_id == span_id)
            match = (await sess.execute(lookup)).fetchone()
            return Span.from_dict(dict(match._mapping)) if match else None

    except Exception as e:
        log_error(f"Error getting span: {e}")
        return None
|
|
2869
|
+
|
|
2870
|
+
async def get_spans(
    self,
    trace_id: Optional[str] = None,
    parent_span_id: Optional[str] = None,
    limit: Optional[int] = 1000,
) -> List:
    """Return spans, optionally narrowed by trace and/or parent span.

    Filters whose value is falsy are not applied. No ORDER BY is issued, so the
    row order is whatever the database returns.

    Args:
        trace_id: Only return spans with this trace ID.
        parent_span_id: Only return spans with this parent span ID.
        limit: Maximum number of spans to return; a falsy value means no LIMIT.

    Returns:
        List[Span]: Matching rows converted with ``Span.from_dict``. An empty list
        is returned when the spans table is unavailable or any exception occurs
        (the exception is reported through ``log_error``).
    """
    try:
        from agno.tracing.schemas import Span

        spans_table = await self._get_table(table_type="spans")
        if spans_table is None:
            return []

        async with self.async_session_factory() as sess:
            query = select(spans_table)

            # Column name -> requested value; applied in this order when truthy.
            equality_filters = {
                "trace_id": trace_id,
                "parent_span_id": parent_span_id,
            }
            for column_name, wanted in equality_filters.items():
                if wanted:
                    query = query.where(spans_table.c[column_name] == wanted)

            if limit:
                query = query.limit(limit)

            rows = (await sess.execute(query)).fetchall()
            return [Span.from_dict(dict(r._mapping)) for r in rows]

    except Exception as e:
        log_error(f"Error getting spans: {e}")
        return []
|