agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1908 @@
|
|
|
1
|
+
from datetime import date, datetime, timedelta, timezone
|
|
2
|
+
from textwrap import dedent
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from agno.tracing.schemas import Span, Trace
|
|
7
|
+
|
|
8
|
+
from agno.db.base import BaseDb, SessionType
|
|
9
|
+
from agno.db.postgres.utils import (
|
|
10
|
+
get_dates_to_calculate_metrics_for,
|
|
11
|
+
)
|
|
12
|
+
from agno.db.schemas import UserMemory
|
|
13
|
+
from agno.db.schemas.culture import CulturalKnowledge
|
|
14
|
+
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
15
|
+
from agno.db.schemas.knowledge import KnowledgeRow
|
|
16
|
+
from agno.db.surrealdb import utils
|
|
17
|
+
from agno.db.surrealdb.metrics import (
|
|
18
|
+
bulk_upsert_metrics,
|
|
19
|
+
calculate_date_metrics,
|
|
20
|
+
fetch_all_sessions_data,
|
|
21
|
+
get_all_sessions_for_metrics_calculation,
|
|
22
|
+
get_metrics_calculation_starting_date,
|
|
23
|
+
)
|
|
24
|
+
from agno.db.surrealdb.models import (
|
|
25
|
+
TableType,
|
|
26
|
+
deserialize_cultural_knowledge,
|
|
27
|
+
deserialize_eval_run_record,
|
|
28
|
+
deserialize_knowledge_row,
|
|
29
|
+
deserialize_session,
|
|
30
|
+
deserialize_sessions,
|
|
31
|
+
deserialize_user_memories,
|
|
32
|
+
deserialize_user_memory,
|
|
33
|
+
desurrealize_eval_run_record,
|
|
34
|
+
desurrealize_session,
|
|
35
|
+
desurrealize_user_memory,
|
|
36
|
+
get_schema,
|
|
37
|
+
get_session_type,
|
|
38
|
+
serialize_cultural_knowledge,
|
|
39
|
+
serialize_eval_run_record,
|
|
40
|
+
serialize_knowledge_row,
|
|
41
|
+
serialize_session,
|
|
42
|
+
serialize_user_memory,
|
|
43
|
+
)
|
|
44
|
+
from agno.db.surrealdb.queries import COUNT_QUERY, WhereClause, order_limit_start
|
|
45
|
+
from agno.db.surrealdb.utils import build_client
|
|
46
|
+
from agno.session import Session
|
|
47
|
+
from agno.utils.log import log_debug, log_error, log_info
|
|
48
|
+
from agno.utils.string import generate_id
|
|
49
|
+
|
|
50
|
+
try:
|
|
51
|
+
from surrealdb import BlockingHttpSurrealConnection, BlockingWsSurrealConnection, RecordID
|
|
52
|
+
except ImportError:
|
|
53
|
+
raise ImportError("The `surrealdb` package is not installed. Please install it via `pip install surrealdb`.")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class SurrealDb(BaseDb):
|
|
57
|
+
def __init__(
|
|
58
|
+
self,
|
|
59
|
+
client: Optional[Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]],
|
|
60
|
+
db_url: str,
|
|
61
|
+
db_creds: dict[str, str],
|
|
62
|
+
db_ns: str,
|
|
63
|
+
db_db: str,
|
|
64
|
+
session_table: Optional[str] = None,
|
|
65
|
+
memory_table: Optional[str] = None,
|
|
66
|
+
metrics_table: Optional[str] = None,
|
|
67
|
+
eval_table: Optional[str] = None,
|
|
68
|
+
knowledge_table: Optional[str] = None,
|
|
69
|
+
culture_table: Optional[str] = None,
|
|
70
|
+
traces_table: Optional[str] = None,
|
|
71
|
+
spans_table: Optional[str] = None,
|
|
72
|
+
id: Optional[str] = None,
|
|
73
|
+
):
|
|
74
|
+
"""
|
|
75
|
+
Interface for interacting with a SurrealDB database.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
client: A blocking connection, either HTTP or WS
|
|
79
|
+
db_url: The URL of the SurrealDB database.
|
|
80
|
+
db_creds: The credentials for the SurrealDB database.
|
|
81
|
+
db_ns: The namespace for the SurrealDB database.
|
|
82
|
+
db_db: The database name for the SurrealDB database.
|
|
83
|
+
session_table: The name of the session table.
|
|
84
|
+
memory_table: The name of the memory table.
|
|
85
|
+
metrics_table: The name of the metrics table.
|
|
86
|
+
eval_table: The name of the eval table.
|
|
87
|
+
knowledge_table: The name of the knowledge table.
|
|
88
|
+
culture_table: The name of the culture table.
|
|
89
|
+
traces_table: The name of the traces table.
|
|
90
|
+
spans_table: The name of the spans table.
|
|
91
|
+
id: The ID of the database.
|
|
92
|
+
"""
|
|
93
|
+
if id is None:
|
|
94
|
+
base_seed = db_url
|
|
95
|
+
seed = f"{base_seed}#{db_db}"
|
|
96
|
+
id = generate_id(seed)
|
|
97
|
+
|
|
98
|
+
super().__init__(
|
|
99
|
+
id=id,
|
|
100
|
+
session_table=session_table,
|
|
101
|
+
memory_table=memory_table,
|
|
102
|
+
metrics_table=metrics_table,
|
|
103
|
+
eval_table=eval_table,
|
|
104
|
+
knowledge_table=knowledge_table,
|
|
105
|
+
culture_table=culture_table,
|
|
106
|
+
traces_table=traces_table,
|
|
107
|
+
spans_table=spans_table,
|
|
108
|
+
)
|
|
109
|
+
self._client = client
|
|
110
|
+
self._db_url = db_url
|
|
111
|
+
self._db_creds = db_creds
|
|
112
|
+
self._db_ns = db_ns
|
|
113
|
+
self._db_db = db_db
|
|
114
|
+
self._users_table_name: str = "agno_users"
|
|
115
|
+
self._agents_table_name: str = "agno_agents"
|
|
116
|
+
self._teams_table_name: str = "agno_teams"
|
|
117
|
+
self._workflows_table_name: str = "agno_workflows"
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def client(self) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
|
|
121
|
+
if self._client is None:
|
|
122
|
+
self._client = build_client(self._db_url, self._db_creds, self._db_ns, self._db_db)
|
|
123
|
+
return self._client
|
|
124
|
+
|
|
125
|
+
@property
|
|
126
|
+
def table_names(self) -> dict[TableType, str]:
|
|
127
|
+
return {
|
|
128
|
+
"agents": self._agents_table_name,
|
|
129
|
+
"culture": self.culture_table_name,
|
|
130
|
+
"evals": self.eval_table_name,
|
|
131
|
+
"knowledge": self.knowledge_table_name,
|
|
132
|
+
"memories": self.memory_table_name,
|
|
133
|
+
"sessions": self.session_table_name,
|
|
134
|
+
"spans": self.span_table_name,
|
|
135
|
+
"teams": self._teams_table_name,
|
|
136
|
+
"traces": self.trace_table_name,
|
|
137
|
+
"users": self._users_table_name,
|
|
138
|
+
"workflows": self._workflows_table_name,
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
def table_exists(self, table_name: str) -> bool:
|
|
142
|
+
"""Check if a table with the given name exists in the SurrealDB database.
|
|
143
|
+
|
|
144
|
+
Args:
|
|
145
|
+
table_name: Name of the table to check
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
bool: True if the table exists in the database, False otherwise
|
|
149
|
+
"""
|
|
150
|
+
response = self._query_one("INFO FOR DB", {}, dict)
|
|
151
|
+
if response is None:
|
|
152
|
+
raise Exception("Failed to retrieve database information")
|
|
153
|
+
return table_name in response.get("tables", [])
|
|
154
|
+
|
|
155
|
+
    def _table_exists(self, table_name: str) -> bool:
        # Kept for backward compatibility with callers of the old private name.
        """Deprecated: Use table_exists() instead."""
        return self.table_exists(table_name)
|
|
158
|
+
|
|
159
|
+
def _create_table(self, table_type: TableType, table_name: str):
|
|
160
|
+
query = get_schema(table_type, table_name)
|
|
161
|
+
self.client.query(query)
|
|
162
|
+
|
|
163
|
+
def _get_table(self, table_type: TableType, create_table_if_not_found: bool = True):
|
|
164
|
+
if table_type == "sessions":
|
|
165
|
+
table_name = self.session_table_name
|
|
166
|
+
elif table_type == "memories":
|
|
167
|
+
table_name = self.memory_table_name
|
|
168
|
+
elif table_type == "knowledge":
|
|
169
|
+
table_name = self.knowledge_table_name
|
|
170
|
+
elif table_type == "culture":
|
|
171
|
+
table_name = self.culture_table_name
|
|
172
|
+
elif table_type == "users":
|
|
173
|
+
table_name = self._users_table_name
|
|
174
|
+
elif table_type == "agents":
|
|
175
|
+
table_name = self._agents_table_name
|
|
176
|
+
elif table_type == "teams":
|
|
177
|
+
table_name = self._teams_table_name
|
|
178
|
+
elif table_type == "workflows":
|
|
179
|
+
table_name = self._workflows_table_name
|
|
180
|
+
elif table_type == "evals":
|
|
181
|
+
table_name = self.eval_table_name
|
|
182
|
+
elif table_type == "metrics":
|
|
183
|
+
table_name = self.metrics_table_name
|
|
184
|
+
elif table_type == "traces":
|
|
185
|
+
table_name = self.trace_table_name
|
|
186
|
+
elif table_type == "spans":
|
|
187
|
+
# Ensure traces table exists before spans (for foreign key-like relationship)
|
|
188
|
+
if create_table_if_not_found:
|
|
189
|
+
self._get_table("traces", create_table_if_not_found=True)
|
|
190
|
+
table_name = self.span_table_name
|
|
191
|
+
else:
|
|
192
|
+
raise NotImplementedError(f"Unknown table type: {table_type}")
|
|
193
|
+
|
|
194
|
+
if create_table_if_not_found and not self._table_exists(table_name):
|
|
195
|
+
self._create_table(table_type, table_name)
|
|
196
|
+
|
|
197
|
+
return table_name
|
|
198
|
+
|
|
199
|
+
    def get_latest_schema_version(self):
        """Get the latest version of the database schema."""
        # No-op: schema versioning is not implemented for the SurrealDB backend yet.
        pass
|
|
202
|
+
|
|
203
|
+
    def upsert_schema_version(self, version: str) -> None:
        """Upsert the schema version into the database."""
        # No-op: schema versioning is not implemented for the SurrealDB backend yet.
        pass
|
|
206
|
+
|
|
207
|
+
    def _query(
        self,
        query: str,
        vars: dict[str, Any],
        record_type: type[utils.RecordType],
    ) -> Sequence[utils.RecordType]:
        """Run a SurrealQL query and return all result rows.

        Thin wrapper around utils.query bound to this instance's client.
        """
        return utils.query(self.client, query, vars, record_type)
|
|
214
|
+
|
|
215
|
+
    def _query_one(
        self,
        query: str,
        vars: dict[str, Any],
        record_type: type[utils.RecordType],
    ) -> Optional[utils.RecordType]:
        """Run a SurrealQL query and return a single result row, or None.

        Thin wrapper around utils.query_one bound to this instance's client.
        """
        return utils.query_one(self.client, query, vars, record_type)
|
|
222
|
+
|
|
223
|
+
def _count(self, table: str, where_clause: str, where_vars: dict[str, Any], group_by: Optional[str] = None) -> int:
|
|
224
|
+
total_count_query = COUNT_QUERY.format(
|
|
225
|
+
table=table,
|
|
226
|
+
where_clause=where_clause,
|
|
227
|
+
group_clause="GROUP ALL" if group_by is None else f"GROUP BY {group_by}",
|
|
228
|
+
group_fields="" if group_by is None else f", {group_by}",
|
|
229
|
+
)
|
|
230
|
+
count_result = self._query_one(total_count_query, where_vars, dict)
|
|
231
|
+
total_count = count_result.get("count") if count_result else 0
|
|
232
|
+
assert isinstance(total_count, int), f"Expected int, got {type(total_count)}"
|
|
233
|
+
total_count = int(total_count)
|
|
234
|
+
return total_count
|
|
235
|
+
|
|
236
|
+
# --- Sessions ---
|
|
237
|
+
def clear_sessions(self) -> None:
|
|
238
|
+
"""Delete all session rows from the database.
|
|
239
|
+
|
|
240
|
+
Raises:
|
|
241
|
+
Exception: If an error occurs during deletion.
|
|
242
|
+
"""
|
|
243
|
+
table = self._get_table("sessions")
|
|
244
|
+
_ = self.client.delete(table)
|
|
245
|
+
|
|
246
|
+
def delete_session(self, session_id: str) -> bool:
|
|
247
|
+
table = self._get_table(table_type="sessions")
|
|
248
|
+
if table is None:
|
|
249
|
+
return False
|
|
250
|
+
res = self.client.delete(RecordID(table, session_id))
|
|
251
|
+
return bool(res)
|
|
252
|
+
|
|
253
|
+
def delete_sessions(self, session_ids: list[str]) -> None:
|
|
254
|
+
table = self._get_table(table_type="sessions")
|
|
255
|
+
if table is None:
|
|
256
|
+
return
|
|
257
|
+
|
|
258
|
+
records = [RecordID(table, id) for id in session_ids]
|
|
259
|
+
self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": records})
|
|
260
|
+
|
|
261
|
+
    def get_session(
        self,
        session_id: str,
        session_type: SessionType,
        user_id: Optional[str] = None,
        deserialize: Optional[bool] = True,
    ) -> Optional[Union[Session, Dict[str, Any]]]:
        r"""
        Read a session from the database.

        Args:
            session_id (str): ID of the session to read.
            session_type (SessionType): Type of session to get.
            user_id (Optional[str]): User ID to filter by. Defaults to None.
            deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.

        Returns:
            Optional[Union[Session, Dict[str, Any]]]:
                - When deserialize=True: Session object
                - When deserialize=False: Session dictionary

        Raises:
            Exception: If an error occurs during retrieval.
        """
        sessions_table = self._get_table("sessions")
        # Address the row directly by record id; `FROM ONLY` yields a single
        # object instead of a list.
        record = RecordID(sessions_table, session_id)
        where = WhereClause()
        if user_id is not None:
            where = where.and_("user_id", user_id)
        where_clause, where_vars = where.build()
        query = dedent(f"""
            SELECT *
            FROM ONLY $record
            {where_clause}
        """)
        vars = {"record": record, **where_vars}
        raw = self._query_one(query, vars, dict)
        if raw is None or not deserialize:
            return raw

        return deserialize_session(session_type, raw)
|
|
302
|
+
|
|
303
|
+
    def get_sessions(
        self,
        session_type: Optional[SessionType] = None,
        user_id: Optional[str] = None,
        component_id: Optional[str] = None,
        session_name: Optional[str] = None,
        start_timestamp: Optional[int] = None,
        end_timestamp: Optional[int] = None,
        limit: Optional[int] = None,
        page: Optional[int] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
        deserialize: Optional[bool] = True,
    ) -> Union[List[Session], Tuple[List[Dict[str, Any]], int]]:
        r"""
        Get all sessions in the given table. Can filter by user_id and entity_id.

        Args:
            session_type (SessionType): The type of session to get.
            user_id (Optional[str]): The ID of the user to filter by.
            component_id (Optional[str]): The ID of the agent / team / workflow to filter by.
                NOTE: this filter is silently ignored when session_type is None,
                since the component kind cannot be determined.
            session_name (Optional[str]): The name of the session to filter by (substring match).
            start_timestamp (Optional[int]): The start timestamp to filter by.
            end_timestamp (Optional[int]): The end timestamp to filter by.
            limit (Optional[int]): The maximum number of sessions to return. Defaults to None.
            page (Optional[int]): The page number to return. Defaults to None.
            sort_by (Optional[str]): The field to sort by. Defaults to None.
            sort_order (Optional[str]): The sort order. Defaults to None.
            deserialize (Optional[bool]): Whether to serialize the sessions. Defaults to True.

        Returns:
            Union[List[Session], Tuple[List[Dict], int]]:
                - When deserialize=True: List of Session objects
                - When deserialize=False: Tuple of (session dictionaries, total count)

        Raises:
            ValueError: If deserialize=True and session_type is None.
            Exception: If an error occurs during retrieval.
        """
        table = self._get_table("sessions")
        # users_table = self._get_table("users", False) # Not used, commenting out for now.
        # Companion tables are only resolved (not created) to build record-id filters.
        agents_table = self._get_table("agents", False)
        teams_table = self._get_table("teams", False)
        workflows_table = self._get_table("workflows", False)

        # -- Filters
        where = WhereClause()

        # user_id
        if user_id is not None:
            where = where.and_("user_id", user_id)

        # component_id
        # The component filter targets a different link field per session type.
        if component_id is not None:
            if session_type == SessionType.AGENT:
                where = where.and_("agent", RecordID(agents_table, component_id))
            elif session_type == SessionType.TEAM:
                where = where.and_("team", RecordID(teams_table, component_id))
            elif session_type == SessionType.WORKFLOW:
                where = where.and_("workflow", RecordID(workflows_table, component_id))

        # session_name
        # "~" is SurrealQL's fuzzy/contains match operator.
        if session_name is not None:
            where = where.and_("session_name", session_name, "~")

        # start_timestamp
        if start_timestamp is not None:
            where = where.and_("start_timestamp", start_timestamp, ">=")

        # end_timestamp
        if end_timestamp is not None:
            where = where.and_("end_timestamp", end_timestamp, "<=")

        where_clause, where_vars = where.build()

        # Total count (computed before pagination so it reflects all matches)
        total_count = self._count(table, where_clause, where_vars)

        # Query
        order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
        query = dedent(f"""
            SELECT *
            FROM {table}
            {where_clause}
            {order_limit_start_clause}
        """)
        sessions_raw = self._query(query, where_vars, dict)
        # Raw dicts returned to callers are converted out of SurrealDB-specific types.
        converted_sessions_raw = [desurrealize_session(session, session_type) for session in sessions_raw]

        if not deserialize:
            return list(converted_sessions_raw), total_count

        if session_type is None:
            raise ValueError("session_type is required when deserialize=True")

        # The deserializer consumes the unconverted rows; it handles Surreal types itself.
        return deserialize_sessions(session_type, list(sessions_raw))
|
|
398
|
+
|
|
399
|
+
    def rename_session(
        self, session_id: str, session_type: SessionType, session_name: str, deserialize: Optional[bool] = True
    ) -> Optional[Union[Session, Dict[str, Any]]]:
        """
        Rename a session in the database.

        Args:
            session_id (str): The ID of the session to rename.
            session_type (SessionType): The type of session to rename.
            session_name (str): The new name for the session.
            deserialize (Optional[bool]): Whether to serialize the session. Defaults to True.

        Returns:
            Optional[Union[Session, Dict[str, Any]]]:
                - When deserialize=True: Session object
                - When deserialize=False: Session dictionary

        Raises:
            Exception: If an error occurs during renaming.
        """
        table = self._get_table("sessions")
        vars = {"record": RecordID(table, session_id), "name": session_name}

        # Query
        # UPDATE ONLY returns the single updated record (nothing if the id is unknown).
        query = dedent("""
            UPDATE ONLY $record
            SET session_name = $name
        """)
        session_raw = self._query_one(query, vars, dict)

        if session_raw is None or not deserialize:
            return session_raw
        return deserialize_session(session_type, session_raw)
|
|
432
|
+
|
|
433
|
+
    def upsert_session(
        self, session: Session, deserialize: Optional[bool] = True
    ) -> Optional[Union[Session, Dict[str, Any]]]:
        """
        Insert or update a session in the database.

        Args:
            session (Session): The session data to upsert.
            deserialize (Optional[bool]): Whether to deserialize the session. Defaults to True.

        Returns:
            Optional[Union[Session, Dict[str, Any]]]:
                - When deserialize=True: Session object
                - When deserialize=False: Session dictionary

        Raises:
            Exception: If an error occurs during upsert.
        """
        session_type = get_session_type(session)
        table = self._get_table("sessions")
        # UPSERT ONLY targets the record id directly: inserts when missing,
        # replaces the content when present.
        session_raw = self._query_one(
            "UPSERT ONLY $record CONTENT $content",
            {
                "record": RecordID(table, session.session_id),
                "content": serialize_session(session, self.table_names),
            },
            dict,
        )
        if session_raw is None or not deserialize:
            return session_raw

        return deserialize_session(session_type, session_raw)
|
|
465
|
+
|
|
466
|
+
def upsert_sessions(
    self, sessions: List[Session], deserialize: Optional[bool] = True
) -> List[Union[Session, Dict[str, Any]]]:
    """
    Bulk insert or update multiple sessions.

    Args:
        sessions (List[Session]): The list of session data to upsert.
        deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.

    Returns:
        List[Union[Session, Dict[str, Any]]]: List of upserted sessions

    Raises:
        Exception: If an error occurs during bulk upsert.
    """
    if not sessions:
        return []
    session_type = get_session_type(sessions[0])
    sessions_table = self._get_table("sessions")

    upserted_rows: List[Dict[str, Any]] = []
    # SurrealDB's UPSERT statement only handles a single record, so loop.
    for current in sessions:
        row = self._query_one(
            "UPSERT ONLY $record CONTENT $content",
            {
                "record": RecordID(sessions_table, current.session_id),
                "content": serialize_session(current, self.table_names),
            },
            dict,
        )
        if row:
            upserted_rows.append(row)

    if not deserialize:
        return list(upserted_rows)

    # list() keeps the covariant return type annotation satisfied
    # (List[Session] is not assignable to List[Session | Dict[str, Any]]).
    return list(deserialize_sessions(session_type, upserted_rows))
|
|
506
|
+
|
|
507
|
+
# --- Memory ---
|
|
508
|
+
def clear_memories(self) -> None:
    """Delete all memories from the database.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Deleting a whole table removes every record in it.
    _ = self.client.delete(self._get_table("memories"))
|
|
516
|
+
|
|
517
|
+
# -- Cultural Knowledge methods --
|
|
518
|
+
def clear_cultural_knowledge(self) -> None:
    """Delete all cultural knowledge from the database.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Deleting a whole table removes every record in it.
    _ = self.client.delete(self._get_table("culture"))
|
|
526
|
+
|
|
527
|
+
def delete_cultural_knowledge(self, id: str) -> None:
    """Delete cultural knowledge by ID.

    Args:
        id (str): The ID of the cultural knowledge to delete.

    Raises:
        Exception: If an error occurs during deletion.
    """
    culture_table = self._get_table("culture")
    self.client.delete(RecordID(culture_table, id))
|
|
539
|
+
|
|
540
|
+
def get_cultural_knowledge(
    self, id: str, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Get cultural knowledge by ID.

    Args:
        id (str): The ID of the cultural knowledge to retrieve.
        deserialize (Optional[bool]): Whether to deserialize to CulturalKnowledge object. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural knowledge if found, None otherwise.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    culture_table = self._get_table("culture")
    row = self.client.select(RecordID(culture_table, id))

    if row is None:
        return None

    if deserialize:
        return deserialize_cultural_knowledge(row)  # type: ignore
    return row  # type: ignore
|
|
566
|
+
|
|
567
|
+
def get_all_cultural_knowledge(
    self,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    name: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
    """Get all cultural knowledge with filtering and pagination.

    Args:
        agent_id (Optional[str]): Filter by agent ID.
        team_id (Optional[str]): Filter by team ID.
        name (Optional[str]): Filter by name (case-insensitive partial match).
        limit (Optional[int]): Maximum number of results to return.
        page (Optional[int]): Page number for pagination.
        sort_by (Optional[str]): Field to sort by.
        sort_order (Optional[str]): Sort order ('asc' or 'desc').
        deserialize (Optional[bool]): Whether to deserialize to CulturalKnowledge objects. Defaults to True.

    Returns:
        Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of CulturalKnowledge objects
            - When deserialize=False: Tuple with list of dictionaries and total count

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("culture")

    # Collect (field, operator, value) filter triples.
    where_clauses: List[WhereClause] = []
    if agent_id is not None:
        where_clauses.append(("agent", "=", RecordID(self._get_table("agents"), agent_id)))  # type: ignore
    if team_id is not None:
        where_clauses.append(("team", "=", RecordID(self._get_table("teams"), team_id)))  # type: ignore
    if name is not None:
        # Case-insensitive partial match on the name column.
        where_clauses.append(("string::lowercase(name)", "CONTAINS", name.lower()))  # type: ignore

    # Bind each filter value to a single-letter parameter: $a, $b, $c, ...
    conditions = " AND ".join(f"{w[0]} {w[1]} ${chr(97 + i)}" for i, w in enumerate(where_clauses))  # type: ignore
    where_sql = "" if not where_clauses else f"WHERE {conditions}"
    params = {chr(97 + i): w[2] for i, w in enumerate(where_clauses)}  # type: ignore

    # Total count (ignores pagination).
    count_query = COUNT_QUERY.format(table=table, where=where_sql)
    total_count = self._query_one(count_query, params, int) or 0

    # Main query with optional filter, ordering and pagination applied.
    query = f"SELECT * FROM {table}"
    if where_clauses:
        query += f" WHERE {conditions}"
    query += order_limit_start(sort_by, sort_order, limit, page)

    results = self._query(query, params, list) or []

    if not deserialize:
        return results, total_count  # type: ignore

    return [deserialize_cultural_knowledge(r) for r in results]  # type: ignore
|
|
634
|
+
|
|
635
|
+
def upsert_cultural_knowledge(
    self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
    """Upsert cultural knowledge in SurrealDB.

    Args:
        cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
        deserialize (Optional[bool]): Whether to deserialize the result. Defaults to True.

    Returns:
        Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The upserted cultural knowledge.

    Raises:
        Exception: If an error occurs during upsert.
    """
    # Create the table on demand so a first-time write does not fail.
    table = self._get_table("culture", create_table_if_not_found=True)
    payload = serialize_cultural_knowledge(cultural_knowledge, table)

    row = self.client.upsert(payload["id"], payload)
    if row is None:
        return None

    if deserialize:
        return deserialize_cultural_knowledge(row)  # type: ignore
    return row  # type: ignore
|
|
662
|
+
|
|
663
|
+
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
    """Delete a user memory from the database.

    Args:
        memory_id (str): The ID of the memory to delete.
        user_id (Optional[str]): If given, the memory is only deleted when it
            belongs to this user. Defaults to None.

    Raises:
        Exception: If an error occurs during deletion.
    """
    table = self._get_table("memories")
    memory_record = RecordID(table, memory_id)

    if user_id is None:
        # No ownership constraint: delete the record directly.
        self.client.delete(memory_record)
        return

    # Scope the delete to the given user so another user's memory is never removed.
    owner_record = RecordID(self._get_table("users"), user_id)
    self.client.query(
        f"DELETE FROM {table} WHERE user = $user AND id = $memory",
        {"user": owner_record, "memory": memory_record},
    )
|
|
686
|
+
|
|
687
|
+
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
    """Delete user memories from the database.

    Args:
        memory_ids (List[str]): The IDs of the memories to delete.
        user_id (Optional[str]): If given, only memories belonging to this user
            are deleted. Defaults to None.

    Raises:
        Exception: If an error occurs during deletion.
    """
    table = self._get_table("memories")
    records = [RecordID(table, mid) for mid in memory_ids]

    if user_id is None:
        _ = self.client.query(f"DELETE FROM {table} WHERE id IN $records", {"records": records})
        return

    # Scope the bulk delete to the given user.
    owner_record = RecordID(self._get_table("users"), user_id)
    _ = self.client.query(
        f"DELETE FROM {table} WHERE id IN $records AND user = $user", {"records": records, "user": owner_record}
    )
|
|
706
|
+
|
|
707
|
+
def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
    """Get all memory topics from the database.

    Args:
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        List[str]: List of memory topics.
    """
    table = self._get_table("memories")
    bindings: dict[str, Any] = {}

    # Flatten every record's topic list server-side, then de-duplicate.
    if user_id is not None:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                WHERE user = $user
                GROUP ALL
            ).topics.distinct();
        """)
        bindings["user"] = RecordID(self._get_table("users"), user_id)
    else:
        query = dedent(f"""
            RETURN (
                SELECT
                    array::flatten(topics) as topics
                FROM ONLY {table}
                GROUP ALL
            ).topics.distinct();
        """)

    return list(self._query(query, bindings, str))
|
|
743
|
+
|
|
744
|
+
def get_user_memory(
    self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Get a memory from the database.

    Args:
        memory_id (str): The ID of the memory to get.
        deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary

    Raises:
        Exception: If an error occurs during retrieval.
    """
    memories_table = self._get_table("memories")
    bindings: Dict[str, Any] = {"record": RecordID(memories_table, memory_id)}

    if user_id is not None:
        # Restrict the lookup to memories owned by the given user.
        bindings["user"] = RecordID(self._get_table("users"), user_id)
        query = "SELECT * FROM ONLY $record WHERE user = $user"
    else:
        query = "SELECT * FROM ONLY $record"

    row = self._query_one(query, bindings, dict)
    if row is None or not deserialize:
        return row
    return deserialize_user_memory(row)
|
|
776
|
+
|
|
777
|
+
def get_user_memories(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    topics: Optional[List[str]] = None,
    search_content: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
    """Get all memories from the database as UserMemory objects.

    Args:
        user_id (Optional[str]): The ID of the user to filter by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        topics (Optional[List[str]]): The topics to filter by.
        search_content (Optional[str]): The content to search for.
        limit (Optional[int]): The maximum number of memories to return.
        page (Optional[int]): The page number.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.
        deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.

    Returns:
        Union[List[UserMemory], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of UserMemory objects
            - When deserialize=False: Tuple of (memory dictionaries, total count)

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("memories")

    # Assemble the filter conditions.
    filters = WhereClause()
    if user_id is not None:
        filters.and_("user", RecordID(self._get_table("users"), user_id))
    if agent_id is not None:
        filters.and_("agent", RecordID(self._get_table("agents"), agent_id))
    if team_id is not None:
        filters.and_("team", RecordID(self._get_table("teams"), team_id))
    if topics is not None:
        # Match memories sharing at least one of the requested topics.
        filters.and_("topics", topics, "CONTAINSANY")
    if search_content is not None:
        # "~" runs SurrealDB's fuzzy match against the memory content.
        filters.and_("memory", search_content, "~")
    where_clause, where_vars = filters.build()

    # Total count (ignores pagination).
    total_count = self._count(table, where_clause, where_vars)

    # Main query with ordering and pagination.
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    rows = self._query(query, where_vars, dict)

    if deserialize:
        return deserialize_user_memories(rows)
    return [desurrealize_user_memory(row) for row in rows], total_count
|
|
845
|
+
|
|
846
|
+
def get_user_memory_stats(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    user_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
    """Get user memories stats.

    Args:
        limit (Optional[int]): The maximum number of user stats to return.
        page (Optional[int]): The page number.
        user_id (Optional[str]): The ID of the user to filter by. Defaults to None.

    Returns:
        Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.

    Example:
        (
            [
                {
                    "user_id": "123",
                    "total_memories": 10,
                    "last_memory_updated_at": 1714560000,
                },
            ],
            total_count: 1,
        )
    """
    memories_table_name = self._get_table("memories")
    where = WhereClause()

    if user_id is None:
        where.and_("!!user", True, "=")  # this checks that user is not falsy
    else:
        where.and_("user", RecordID(self._get_table("users"), user_id), "=")

    where_clause, where_vars = where.build()
    # Group
    group_clause = "GROUP BY user"
    # Order
    order_limit_start_clause = order_limit_start("last_memory_updated_at", "DESC", limit, page)
    # Total count of distinct users.
    # Fix: apply the same WHERE filter as the stats query below. Previously the
    # count ignored the filter, so a user_id filter (or the not-falsy check)
    # could make the count disagree with the rows returned.
    total_count = (
        self._query_one(
            f"(SELECT user FROM {memories_table_name} {where_clause} GROUP BY user).map(|$x| $x.user).len()",
            where_vars,
            int,
        )
        or 0
    )
    # Query
    query = dedent(f"""
        SELECT
            user,
            count(id) AS total_memories,
            time::max(updated_at) AS last_memory_updated_at
        FROM {memories_table_name}
        {where_clause}
        {group_clause}
        {order_limit_start_clause}
    """)
    result = self._query(query, where_vars, dict)

    # Deserialize RecordIDs and datetimes into plain values.
    for row in result:
        row["user_id"] = row["user"].id
        del row["user"]
        row["last_memory_updated_at"] = int(row["last_memory_updated_at"].timestamp())

    return list(result), total_count
|
|
913
|
+
|
|
914
|
+
def upsert_user_memory(
    self, memory: UserMemory, deserialize: Optional[bool] = True
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
    """Upsert a user memory in the database.

    Args:
        memory (UserMemory): The user memory to upsert.
        deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.

    Returns:
        Optional[Union[UserMemory, Dict[str, Any]]]:
            - When deserialize=True: UserMemory object
            - When deserialize=False: UserMemory dictionary

    Raises:
        Exception: If an error occurs during upsert.
    """
    memories_table = self._get_table("memories")
    users_table = self._get_table("users")
    payload = serialize_user_memory(memory, memories_table, users_table)

    if memory.memory_id:
        # Known id: upsert the record in place.
        row = self._query_one(
            "UPSERT ONLY $record CONTENT $content",
            {"record": RecordID(memories_table, memory.memory_id), "content": payload},
            dict,
        )
    else:
        # No id yet: let SurrealDB create the record (and generate an id).
        row = self._query_one(f"CREATE ONLY {memories_table} CONTENT $content", {"content": payload}, dict)

    if row is None:
        return None
    if not deserialize:
        return desurrealize_user_memory(row)
    return deserialize_user_memory(row)
|
|
947
|
+
|
|
948
|
+
def upsert_memories(
    self, memories: List[UserMemory], deserialize: Optional[bool] = True
) -> List[Union[UserMemory, Dict[str, Any]]]:
    """
    Bulk insert or update multiple memories in the database for improved performance.

    Args:
        memories (List[UserMemory]): The list of memories to upsert.
        deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.

    Returns:
        List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories

    Raises:
        Exception: If an error occurs during bulk upsert.
    """
    if not memories:
        return []
    table = self._get_table("memories")
    user_table_name = self._get_table("users")
    raw: list[dict] = []
    for memory in memories:
        content = serialize_user_memory(memory, table, user_table_name)
        if memory.memory_id:
            # UPSERT does only work for one record at a time
            memory_raw = self._query_one(
                "UPSERT ONLY $record CONTENT $content",
                {
                    "record": RecordID(table, memory.memory_id),
                    "content": content,
                },
                dict,
            )
        else:
            # No id yet: let SurrealDB create the record (and generate an id).
            memory_raw = self._query_one(
                f"CREATE ONLY {table} CONTENT $content",
                {"content": content},
                dict,
            )
        if memory_raw is not None:
            raw.append(memory_raw)
    # Note: `raw` is always a list here; the previous `raw is None` guard was dead code.
    if not deserialize:
        return [desurrealize_user_memory(x) for x in raw]
    # wrapping with list because of:
    # Type "List[UserMemory]" is not assignable to return type "List[UserMemory | Dict[str, Any]]"
    # Consider switching from "list" to "Sequence" which is covariant
    return list(deserialize_user_memories(raw))
|
|
994
|
+
|
|
995
|
+
# --- Metrics ---
|
|
996
|
+
def get_metrics(
    self,
    starting_date: Optional[date] = None,
    ending_date: Optional[date] = None,
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    """Get all metrics matching the given date range.

    Args:
        starting_date (Optional[date]): The starting date to filter metrics by.
        ending_date (Optional[date]): The ending date to filter metrics by.

    Returns:
        Tuple[List[dict], Optional[int]]: A tuple containing the metrics and the timestamp of the latest update.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("metrics")

    where = WhereClause()

    # Dates are stored as datetimes: convert each date bound to midnight UTC for comparison.
    if starting_date is not None:
        starting_datetime = datetime.combine(starting_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        where = where.and_("date", starting_datetime, ">=")

    if ending_date is not None:
        ending_datetime = datetime.combine(ending_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        where = where.and_("date", ending_datetime, "<=")

    where_clause, where_vars = where.build()

    # Query
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        ORDER BY date ASC
    """)

    results = self._query(query, where_vars, dict)

    # Get the latest updated_at from all results
    latest_update = None
    if results:
        # Find the maximum updated_at timestamp
        latest_update = max(int(r["updated_at"].timestamp()) for r in results)

        # Transform results to match expected format
        transformed_results = []
        for r in results:
            transformed = dict(r)

            # Convert a RecordID to its plain id. RecordID instances always expose
            # `.id`, so this single check suffices (the former
            # `elif isinstance(..., RecordID)` branch was unreachable).
            if hasattr(transformed.get("id"), "id"):
                transformed["id"] = transformed["id"].id

            # Convert datetime objects to Unix timestamps
            if isinstance(transformed.get("created_at"), datetime):
                transformed["created_at"] = int(transformed["created_at"].timestamp())
            if isinstance(transformed.get("updated_at"), datetime):
                transformed["updated_at"] = int(transformed["updated_at"].timestamp())
            if isinstance(transformed.get("date"), datetime):
                transformed["date"] = int(transformed["date"].timestamp())

            transformed_results.append(transformed)

        return transformed_results, latest_update

    return [], latest_update
|
|
1069
|
+
|
|
1070
|
+
def calculate_metrics(self) -> Optional[List[Dict[str, Any]]]:
    """Calculate metrics for all dates without complete metrics.

    Returns:
        Optional[List[Dict[str, Any]]]: The calculated metrics, or None when there is nothing to do.

    Raises:
        Exception: If an error occurs during metrics calculation.
    """
    try:
        metrics_table = self._get_table("metrics")

        first_date = get_metrics_calculation_starting_date(self.client, metrics_table, self.get_sessions)
        if first_date is None:
            log_info("No session data found. Won't calculate metrics.")
            return None

        dates_to_process = get_dates_to_calculate_metrics_for(first_date)
        if not dates_to_process:
            log_info("Metrics already calculated for all relevant dates.")
            return None

        # Window spanning midnight (UTC) of the first date up to midnight after the last.
        start_timestamp = datetime.combine(dates_to_process[0], datetime.min.time()).replace(tzinfo=timezone.utc)
        end_timestamp = datetime.combine(dates_to_process[-1] + timedelta(days=1), datetime.min.time()).replace(
            tzinfo=timezone.utc
        )

        sessions = get_all_sessions_for_metrics_calculation(
            self.client, self._get_table("sessions"), start_timestamp, end_timestamp
        )

        all_sessions_data = fetch_all_sessions_data(
            sessions=sessions,
            dates_to_process=dates_to_process,
            start_timestamp=int(start_timestamp.timestamp()),  # expects a Unix timestamp
        )
        if not all_sessions_data:
            log_info("No new session data found. Won't calculate metrics.")
            return None

        metrics_records = []
        for date_to_process in dates_to_process:
            # Session buckets for this date are keyed by the ISO date string.
            sessions_for_date = all_sessions_data.get(date_to_process.isoformat(), {})

            # Skip dates with no sessions of any kind.
            if not any(len(sessions) > 0 for sessions in sessions_for_date.values()):
                continue

            metrics_records.append(calculate_date_metrics(date_to_process, sessions_for_date))

        results = []
        if metrics_records:
            results = bulk_upsert_metrics(self.client, metrics_table, metrics_records)

        log_debug("Updated metrics calculations")
        return results

    except Exception as e:
        log_error(f"Exception refreshing metrics: {e}")
        raise e
|
|
1134
|
+
|
|
1135
|
+
# --- Knowledge ---
|
|
1136
|
+
def clear_knowledge(self) -> None:
    """Delete all knowledge rows from the database.

    Raises:
        Exception: If an error occurs during deletion.
    """
    # Deleting a whole table removes every record in it.
    _ = self.client.delete(self._get_table("knowledge"))
|
|
1144
|
+
|
|
1145
|
+
def delete_knowledge_content(self, id: str):
    """Delete a knowledge row from the database.

    Args:
        id (str): The ID of the knowledge row to delete.
    """
    knowledge_table = self._get_table("knowledge")
    self.client.delete(RecordID(knowledge_table, id))
|
|
1153
|
+
|
|
1154
|
+
def get_knowledge_content(self, id: str) -> Optional[KnowledgeRow]:
    """Get a knowledge row from the database.

    Args:
        id (str): The ID of the knowledge row to get.

    Returns:
        Optional[KnowledgeRow]: The knowledge row, or None if it doesn't exist.
    """
    knowledge_table = self._get_table("knowledge")
    row = self._query_one("SELECT * FROM ONLY $record_id", {"record_id": RecordID(knowledge_table, id)}, dict)
    if not row:
        return None
    return deserialize_knowledge_row(row)
|
|
1167
|
+
|
|
1168
|
+
def get_knowledge_contents(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> Tuple[List[KnowledgeRow], int]:
    """Get all knowledge contents from the database.

    Args:
        limit (Optional[int]): The maximum number of knowledge contents to return.
        page (Optional[int]): The page number.
        sort_by (Optional[str]): The column to sort by.
        sort_order (Optional[str]): The order to sort by.

    Returns:
        Tuple[List[KnowledgeRow], int]: The knowledge contents and total count.

    Raises:
        Exception: If an error occurs during retrieval.
    """
    table = self._get_table("knowledge")
    # No filters apply here; an empty WhereClause builds an empty clause.
    where_clause, where_vars = WhereClause().build()

    # Total count (ignores pagination).
    total_count = self._count(table, where_clause, where_vars)

    # Main query with ordering and pagination.
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)
    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    rows = self._query(query, where_vars, dict)
    return [deserialize_knowledge_row(row) for row in rows], total_count
|
|
1206
|
+
|
|
1207
|
+
def upsert_knowledge_content(self, knowledge_row: KnowledgeRow) -> Optional[KnowledgeRow]:
    """Upsert knowledge content in the database.

    Args:
        knowledge_row (KnowledgeRow): The knowledge row to upsert.

    Returns:
        Optional[KnowledgeRow]: The upserted knowledge row, or None if the operation fails.
    """
    knowledge_table_name = self._get_table("knowledge")
    row = self._query_one(
        "UPSERT ONLY $record CONTENT $content",
        {
            "record": RecordID(knowledge_table_name, knowledge_row.id),
            "content": serialize_knowledge_row(knowledge_row, knowledge_table_name),
        },
        dict,
    )
    return deserialize_knowledge_row(row) if row else None
|
|
1223
|
+
|
|
1224
|
+
# --- Evals ---
|
|
1225
|
+
def clear_evals(self) -> None:
    """Delete every eval row from the database.

    Raises:
        Exception: If an error occurs during deletion.
    """
    evals_table = self._get_table("evals")
    _ = self.client.delete(evals_table)
|
|
1233
|
+
|
|
1234
|
+
def create_eval_run(self, eval_run: EvalRunRecord) -> Optional[EvalRunRecord]:
    """Persist a new EvalRunRecord, keyed by its run_id.

    Args:
        eval_run (EvalRunRecord): The eval run to create.

    Returns:
        Optional[EvalRunRecord]: The created eval run, or None if the operation fails.

    Raises:
        Exception: If an error occurs during creation.
    """
    evals_table = self._get_table("evals")
    target = RecordID(evals_table, eval_run.run_id)
    payload = serialize_eval_run_record(eval_run, self.table_names)
    row = self._query_one("CREATE ONLY $record CONTENT $content", {"record": target, "content": payload}, dict)
    if not row:
        return None
    return deserialize_eval_run_record(row)
|
|
1253
|
+
|
|
1254
|
+
def delete_eval_runs(self, eval_run_ids: List[str]) -> None:
    """Delete all eval runs with the given IDs in a single query.

    Args:
        eval_run_ids (List[str]): List of eval run IDs to delete.
    """
    evals_table = self._get_table("evals")
    targets = [RecordID(evals_table, run_id) for run_id in eval_run_ids]
    _ = self.client.query(f"DELETE FROM {evals_table} WHERE id IN $records", {"records": targets})
|
|
1263
|
+
|
|
1264
|
+
def get_eval_run(
    self, eval_run_id: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Fetch a single eval run by its ID.

    Args:
        eval_run_id (str): The ID of the eval run to get.
        deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: eval run dictionary

    Raises:
        Exception: If an error occurs during retrieval.
    """
    evals_table = self._get_table("evals")
    target = RecordID(evals_table, eval_run_id)
    raw = self._query_one("SELECT * FROM ONLY $record", {"record": target}, dict)
    if raw is None:
        return None
    # Empty or non-deserialized results are returned as plain dicts with
    # SurrealDB-specific values stripped.
    if not raw or not deserialize:
        return desurrealize_eval_run_record(raw)
    return deserialize_eval_run_record(raw)
|
|
1287
|
+
|
|
1288
|
+
def get_eval_runs(
    self,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    model_id: Optional[str] = None,
    filter_type: Optional[EvalFilterType] = None,
    eval_type: Optional[List[EvalType]] = None,
    deserialize: Optional[bool] = True,
) -> Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
    """Fetch eval runs, optionally filtered, sorted and paginated.

    Args:
        limit (Optional[int]): The maximum number of eval runs to return.
        page (Optional[int]): The page number to return.
        sort_by (Optional[str]): The field to sort by.
        sort_order (Optional[str]): The order to sort by.
        agent_id (Optional[str]): The ID of the agent to filter by.
        team_id (Optional[str]): The ID of the team to filter by.
        workflow_id (Optional[str]): The ID of the workflow to filter by.
        model_id (Optional[str]): The ID of the model to filter by.
        filter_type (Optional[EvalFilterType]): The type of filter to apply.
        eval_type (Optional[List[EvalType]]): The type of eval to filter by.
        deserialize (Optional[bool]): Whether to deserialize the eval runs. Defaults to True.

    Returns:
        Union[List[EvalRunRecord], Tuple[List[Dict[str, Any]], int]]:
            - When deserialize=True: List of EvalRunRecord objects
            - When deserialize=False: List of eval run dictionaries and the total count

    Raises:
        Exception: If there is an error getting the eval runs.
    """
    table = self._get_table("evals")

    # Only one component filter applies, selected by filter_type.
    where = WhereClause()
    if filter_type == EvalFilterType.AGENT:
        where.and_("agent", RecordID(self._get_table("agents"), agent_id))
    elif filter_type == EvalFilterType.TEAM:
        where.and_("team", RecordID(self._get_table("teams"), team_id))
    elif filter_type == EvalFilterType.WORKFLOW:
        where.and_("workflow", RecordID(self._get_table("workflows"), workflow_id))
    if model_id is not None:
        where.and_("model_id", model_id)
    if eval_type is not None:
        where.and_("eval_type", eval_type)
    where_clause, where_vars = where.build()

    # Sorting and pagination clause
    order_limit_start_clause = order_limit_start(sort_by, sort_order, limit, page)

    # Total count, computed before pagination
    total_count = self._count(table, where_clause, where_vars)

    query = dedent(f"""
        SELECT *
        FROM {table}
        {where_clause}
        {order_limit_start_clause}
    """)
    rows = self._query(query, where_vars, dict)

    if deserialize:
        return [deserialize_eval_run_record(row) for row in rows]
    # NOTE(review): raw rows are returned as-is (id is a RecordID), while
    # get_eval_run passes through desurrealize_eval_run_record — confirm
    # whether callers rely on this difference.
    return list(rows), total_count
|
|
1359
|
+
|
|
1360
|
+
def rename_eval_run(
    self, eval_run_id: str, name: str, deserialize: Optional[bool] = True
) -> Optional[Union[EvalRunRecord, Dict[str, Any]]]:
    """Rename an existing eval run.

    Args:
        eval_run_id (str): The ID of the eval run to update.
        name (str): The new name of the eval run.
        deserialize (Optional[bool]): Whether to deserialize the eval run. Defaults to True.

    Returns:
        Optional[Union[EvalRunRecord, Dict[str, Any]]]:
            - When deserialize=True: EvalRunRecord object
            - When deserialize=False: eval run dictionary

    Raises:
        Exception: If there is an error updating the eval run.
    """
    evals_table = self._get_table("evals")
    params = {"record": RecordID(evals_table, eval_run_id), "name": name}

    query = dedent("""
        UPDATE ONLY $record
        SET name = $name
    """)
    updated = self._query_one(query, params, dict)

    if not updated or not deserialize:
        # NOTE(review): the raw dict is returned without desurrealization,
        # unlike get_eval_run — confirm callers expect RecordID values here.
        return updated
    return deserialize_eval_run_record(updated)
|
|
1391
|
+
|
|
1392
|
+
# --- Traces ---
|
|
1393
|
+
def upsert_trace(self, trace: "Trace") -> None:
    """Create or update the single trace record stored per trace_id.

    Args:
        trace: The Trace object to store (one row per trace_id).
    """
    try:
        table = self._get_table("traces", create_table_if_not_found=True)
        record = RecordID(table, trace.trace_id)

        # Is there already a row for this trace_id?
        existing = self._query_one("SELECT * FROM ONLY $record", {"record": record}, dict)

        if existing:
            # Rank which component "owns" the trace name:
            # workflow (3) > team (2) > agent (1) > child/unknown (0).
            def _component_level(workflow_id: Any, team_id: Any, agent_id: Any, name: str) -> int:
                # Only root-style names (containing ".run" / ".arun") may claim ownership.
                if ".run" not in name and ".arun" not in name:
                    return 0
                if workflow_id:
                    return 3
                if team_id:
                    return 2
                if agent_id:
                    return 1
                return 0

            current_level = _component_level(
                existing.get("workflow_id"),
                existing.get("team_id"),
                existing.get("agent_id"),
                existing.get("name", ""),
            )
            incoming_level = _component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)

            # Recompute duration from the stored start_time when available, so
            # the duration covers the whole trace rather than just this update.
            stored_start = existing.get("start_time")
            if isinstance(stored_start, datetime):
                duration_ms = int((trace.end_time - stored_start).total_seconds() * 1000)
            else:
                duration_ms = trace.duration_ms

            update_fields = [
                "end_time = $end_time",
                "duration_ms = $duration_ms",
                "status = $status",
            ]
            update_vars: Dict[str, Any] = {
                "record": record,
                "end_time": trace.end_time,
                "duration_ms": duration_ms,
                "status": trace.status,
            }

            # The higher-level component wins the trace name.
            if incoming_level > current_level:
                update_fields.append("name = $name")
                update_vars["name"] = trace.name

            # Context fields are only overwritten with non-None values.
            for field_name, value in (
                ("run_id", trace.run_id),
                ("session_id", trace.session_id),
                ("user_id", trace.user_id),
                ("agent_id", trace.agent_id),
                ("team_id", trace.team_id),
                ("workflow_id", trace.workflow_id),
            ):
                if value is not None:
                    update_fields.append(f"{field_name} = ${field_name}")
                    update_vars[field_name] = value

            update_query = f"UPDATE ONLY $record SET {', '.join(update_fields)}"
            self._query_one(update_query, update_vars, dict)
        else:
            # First sighting of this trace_id: insert a fresh record.
            trace_dict = trace.to_dict()
            # total_spans / error_count are derived on read, not stored.
            trace_dict.pop("total_spans", None)
            trace_dict.pop("error_count", None)

            # Normalize ISO-8601 strings into datetime objects.
            for field_name in ("start_time", "end_time", "created_at"):
                value = trace_dict.get(field_name)
                if isinstance(value, str):
                    trace_dict[field_name] = datetime.fromisoformat(value.replace("Z", "+00:00"))

            self._query_one(
                "CREATE ONLY $record CONTENT $content",
                {"record": record, "content": trace_dict},
                dict,
            )

    except Exception as e:
        log_error(f"Error creating trace: {e}")
|
|
1498
|
+
|
|
1499
|
+
def get_trace(
    self,
    trace_id: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """Fetch a single trace by trace_id or, failing that, by run_id.

    Args:
        trace_id: The unique trace identifier.
        run_id: Filter by run ID (returns first match).

    Returns:
        Optional[Trace]: The trace if found, None otherwise.

    Note:
        If multiple filters are provided, trace_id takes precedence.
        For other filters, the most recent trace is returned.
    """
    try:
        table = self._get_table("traces", create_table_if_not_found=False)
        spans_table = self._get_table("spans", create_table_if_not_found=False)

        if trace_id:
            target = RecordID(table, trace_id)
            row = self._query_one("SELECT * FROM ONLY $record", {"record": target}, dict)
        elif run_id:
            query = dedent(f"""
                SELECT * FROM {table}
                WHERE run_id = $run_id
                ORDER BY start_time DESC
                LIMIT 1
            """)
            row = self._query_one(query, {"run_id": run_id}, dict)
        else:
            log_debug("get_trace called without any filter parameters")
            return None

        if not row:
            return None

        # Derive span statistics; fall back to the record id when the
        # trace_id column is absent.
        rid = row.get("id")
        resolved_id = row.get("trace_id") or (rid.id if rid is not None else None)
        if resolved_id:
            count_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id GROUP ALL"
            counted = self._query_one(count_query, {"trace_id": resolved_id}, dict)
            row["total_spans"] = counted.get("total", 0) if counted else 0

            error_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id AND status_code = 'ERROR' GROUP ALL"
            errored = self._query_one(error_query, {"trace_id": resolved_id}, dict)
            row["error_count"] = errored.get("total", 0) if errored else 0

        return self._deserialize_trace(row)

    except Exception as e:
        log_error(f"Error getting trace: {e}")
        return None
|
|
1557
|
+
|
|
1558
|
+
def get_traces(
    self,
    run_id: Optional[str] = None,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    status: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List, int]:
    """Fetch traces matching the provided filters, paginated and newest first.

    Args:
        run_id: Filter by run ID.
        session_id: Filter by session ID.
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        status: Filter by status (OK, ERROR, UNSET).
        start_time: Filter traces starting after this datetime.
        end_time: Filter traces ending before this datetime.
        limit: Maximum number of traces to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
    """
    try:
        table = self._get_table("traces", create_table_if_not_found=False)
        spans_table = self._get_table("spans", create_table_if_not_found=False)

        # Equality filters are only applied for truthy values.
        where = WhereClause()
        for column, value in (
            ("run_id", run_id),
            ("session_id", session_id),
            ("user_id", user_id),
            ("agent_id", agent_id),
            ("team_id", team_id),
            ("workflow_id", workflow_id),
            ("status", status),
        ):
            if value:
                where.and_(column, value)
        if start_time:
            where.and_("start_time", start_time, ">=")
        if end_time:
            where.and_("end_time", end_time, "<=")

        where_clause, where_vars = where.build()

        # Total count before pagination
        total_count = self._count(table, where_clause, where_vars)

        # Newest traces first, paginated
        order_limit_start_clause = order_limit_start("start_time", "DESC", limit, page)
        query = dedent(f"""
            SELECT * FROM {table}
            {where_clause}
            {order_limit_start_clause}
        """)
        rows = self._query(query, where_vars, dict)

        # Enrich every row with derived span statistics before deserializing.
        result_traces = []
        for row in rows:
            rid = row.get("id")
            resolved_id = row.get("trace_id") or (rid.id if rid is not None else None)
            if resolved_id:
                count_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id GROUP ALL"
                counted = self._query_one(count_query, {"trace_id": resolved_id}, dict)
                row["total_spans"] = counted.get("total", 0) if counted else 0

                error_query = f"SELECT count() as total FROM {spans_table} WHERE trace_id = $trace_id AND status_code = 'ERROR' GROUP ALL"
                errored = self._query_one(error_query, {"trace_id": resolved_id}, dict)
                row["error_count"] = errored.get("total", 0) if errored else 0

            result_traces.append(self._deserialize_trace(row))

        return result_traces, total_count

    except Exception as e:
        log_error(f"Error getting traces: {e}")
        return [], 0
|
|
1650
|
+
|
|
1651
|
+
def get_trace_stats(
    self,
    user_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    limit: Optional[int] = 20,
    page: Optional[int] = 1,
) -> tuple[List[Dict[str, Any]], int]:
    """Get trace statistics grouped by session.

    Args:
        user_id: Filter by user ID.
        agent_id: Filter by agent ID.
        team_id: Filter by team ID.
        workflow_id: Filter by workflow ID.
        start_time: Filter sessions with traces created after this datetime.
        end_time: Filter sessions with traces created before this datetime.
        limit: Maximum number of sessions to return per page.
        page: Page number (1-indexed).

    Returns:
        tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
        Each dict contains: session_id, user_id, agent_id, team_id, workflow_id, total_traces,
        first_trace_at, last_trace_at (the *_trace_at values stay datetime objects).
    """
    try:
        table = self._get_table("traces", create_table_if_not_found=False)

        # Build where clause
        where = WhereClause()
        where.and_("!!session_id", True, "=")  # Ensure session_id is not null
        if user_id:
            where.and_("user_id", user_id)
        if agent_id:
            where.and_("agent_id", agent_id)
        if team_id:
            where.and_("team_id", team_id)
        if workflow_id:
            where.and_("workflow_id", workflow_id)
        if start_time:
            where.and_("created_at", start_time, ">=")
        if end_time:
            where.and_("created_at", end_time, "<=")

        where_clause, where_vars = where.build()

        # Get total count of unique sessions
        count_query = dedent(f"""
            SELECT count() as total FROM (
                SELECT session_id FROM {table}
                {where_clause}
                GROUP BY session_id
            ) GROUP ALL
        """)
        count_result = self._query_one(count_query, where_vars, dict)
        total_count = count_result.get("total", 0) if count_result else 0

        # Aggregate per session, most recently active sessions first
        order_limit_start_clause = order_limit_start("last_trace_at", "DESC", limit, page)
        query = dedent(f"""
            SELECT
                session_id,
                user_id,
                agent_id,
                team_id,
                workflow_id,
                count() AS total_traces,
                time::min(created_at) AS first_trace_at,
                time::max(created_at) AS last_trace_at
            FROM {table}
            {where_clause}
            GROUP BY session_id, user_id, agent_id, team_id, workflow_id
            {order_limit_start_clause}
        """)
        results = self._query(query, where_vars, dict)

        # Copy each row into a plain dict; first_trace_at / last_trace_at are
        # intentionally left as datetime objects (the previous no-op isinstance
        # checks were dead code and have been removed).
        stats_list = [dict(row) for row in results]

        return stats_list, total_count

    except Exception as e:
        log_error(f"Error getting trace stats: {e}")
        return [], 0
|
|
1745
|
+
|
|
1746
|
+
def _deserialize_trace(self, trace_data: dict) -> "Trace":
    """Convert a raw SurrealDB trace row into a Trace object."""
    from agno.tracing.schemas import Trace

    # SurrealDB returns the primary key as a RecordID; fold it into trace_id
    # and drop the non-serializable id field.
    rid = trace_data.get("id")
    if isinstance(rid, RecordID):
        if not trace_data.get("trace_id"):
            trace_data["trace_id"] = rid.id
        del trace_data["id"]

    # Trace.from_dict expects ISO-formatted strings for timestamps.
    for field in ("start_time", "end_time", "created_at"):
        if isinstance(trace_data.get(field), datetime):
            trace_data[field] = trace_data[field].isoformat()

    return Trace.from_dict(trace_data)
|
|
1762
|
+
|
|
1763
|
+
# --- Spans ---
|
|
1764
|
+
def create_span(self, span: "Span") -> None:
    """Store a single span in the database.

    Args:
        span: The Span object to store.
    """
    try:
        table = self._get_table("spans", create_table_if_not_found=True)
        target = RecordID(table, span.span_id)

        span_dict = span.to_dict()

        # Normalize ISO-8601 strings into datetime objects.
        for field_name in ("start_time", "end_time", "created_at"):
            value = span_dict.get(field_name)
            if isinstance(value, str):
                span_dict[field_name] = datetime.fromisoformat(value.replace("Z", "+00:00"))

        self._query_one(
            "CREATE ONLY $record CONTENT $content",
            {"record": target, "content": span_dict},
            dict,
        )

    except Exception as e:
        log_error(f"Error creating span: {e}")
|
|
1792
|
+
|
|
1793
|
+
def create_spans(self, spans: List) -> None:
    """Store multiple spans in the database.

    Args:
        spans: List of Span objects to store.
    """
    if not spans:
        return

    try:
        table = self._get_table("spans", create_table_if_not_found=True)

        # NOTE(review): spans are inserted one query per span; a true bulk
        # insert may be worth exploring if this becomes a hot path.
        for span in spans:
            target = RecordID(table, span.span_id)
            span_dict = span.to_dict()

            # Normalize ISO-8601 strings into datetime objects.
            for field_name in ("start_time", "end_time", "created_at"):
                value = span_dict.get(field_name)
                if isinstance(value, str):
                    span_dict[field_name] = datetime.fromisoformat(value.replace("Z", "+00:00"))

            self._query_one(
                "CREATE ONLY $record CONTENT $content",
                {"record": target, "content": span_dict},
                dict,
            )

    except Exception as e:
        log_error(f"Error creating spans batch: {e}")
|
|
1825
|
+
|
|
1826
|
+
def get_span(self, span_id: str):
    """Fetch a single span by its span_id.

    Args:
        span_id: The unique span identifier.

    Returns:
        Optional[Span]: The span if found, None otherwise.
    """
    try:
        table = self._get_table("spans", create_table_if_not_found=False)
        target = RecordID(table, span_id)

        row = self._query_one("SELECT * FROM ONLY $record", {"record": target}, dict)
        if not row:
            return None

        return self._deserialize_span(row)

    except Exception as e:
        log_error(f"Error getting span: {e}")
        return None
|
|
1848
|
+
|
|
1849
|
+
def get_spans(
    self,
    trace_id: Optional[str] = None,
    parent_span_id: Optional[str] = None,
    limit: Optional[int] = 1000,
) -> List:
    """Fetch spans matching the provided filters, ordered by start time.

    Args:
        trace_id: Filter by trace ID.
        parent_span_id: Filter by parent span ID.
        limit: Maximum number of spans to return.

    Returns:
        List[Span]: List of matching spans.
    """
    try:
        table = self._get_table("spans", create_table_if_not_found=False)

        # Equality filters are only applied for truthy values.
        where = WhereClause()
        if trace_id:
            where.and_("trace_id", trace_id)
        if parent_span_id:
            where.and_("parent_span_id", parent_span_id)

        where_clause, where_vars = where.build()

        limit_clause = f"LIMIT {limit}" if limit else ""
        query = dedent(f"""
            SELECT * FROM {table}
            {where_clause}
            ORDER BY start_time ASC
            {limit_clause}
        """)
        rows = self._query(query, where_vars, dict)

        return [self._deserialize_span(row) for row in rows]

    except Exception as e:
        log_error(f"Error getting spans: {e}")
        return []
|
|
1892
|
+
|
|
1893
|
+
def _deserialize_span(self, span_data: dict) -> "Span":
    """Convert a raw SurrealDB span row into a Span object."""
    from agno.tracing.schemas import Span

    # SurrealDB returns the primary key as a RecordID; fold it into span_id
    # and drop the non-serializable id field.
    rid = span_data.get("id")
    if isinstance(rid, RecordID):
        if not span_data.get("span_id"):
            span_data["span_id"] = rid.id
        del span_data["id"]

    # Span.from_dict expects ISO-formatted strings for timestamps.
    for field in ("start_time", "end_time", "created_at"):
        if isinstance(span_data.get(field), datetime):
            span_data[field] = span_data[field].isoformat()

    return Span.from_dict(span_data)
|