agno 2.1.2__py3-none-any.whl → 2.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +5540 -2273
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +689 -6
- agno/db/dynamo/dynamo.py +933 -37
- agno/db/dynamo/schemas.py +174 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +831 -9
- agno/db/firestore/schemas.py +51 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +660 -12
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +287 -14
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +590 -14
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +43 -13
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2760 -0
- agno/db/mongo/mongo.py +879 -11
- agno/db/mongo/schemas.py +42 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2912 -0
- agno/db/mysql/mysql.py +946 -68
- agno/db/mysql/schemas.py +72 -10
- agno/db/mysql/utils.py +198 -7
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2579 -0
- agno/db/postgres/postgres.py +942 -57
- agno/db/postgres/schemas.py +81 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +671 -7
- agno/db/redis/schemas.py +50 -0
- agno/db/redis/utils.py +65 -7
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +17 -2
- agno/db/singlestore/schemas.py +63 -0
- agno/db/singlestore/singlestore.py +949 -83
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2911 -0
- agno/db/sqlite/schemas.py +62 -0
- agno/db/sqlite/sqlite.py +965 -46
- agno/db/sqlite/utils.py +169 -8
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +334 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1908 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +2 -0
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +75 -55
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +16 -7
- agno/eval/reliability.py +28 -16
- agno/eval/utils.py +35 -17
- agno/exceptions.py +27 -2
- agno/filters.py +354 -0
- agno/guardrails/prompt_injection.py +1 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +1 -1
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/semantic.py +9 -4
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +8 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/embedder/sentence_transformer.py +6 -2
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/knowledge.py +1618 -318
- agno/knowledge/reader/base.py +6 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +16 -20
- agno/knowledge/reader/json_reader.py +5 -4
- agno/knowledge/reader/markdown_reader.py +8 -8
- agno/knowledge/reader/pdf_reader.py +17 -19
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +32 -3
- agno/knowledge/reader/s3_reader.py +3 -3
- agno/knowledge/reader/tavily_reader.py +193 -0
- agno/knowledge/reader/text_reader.py +22 -10
- agno/knowledge/reader/web_search_reader.py +1 -48
- agno/knowledge/reader/website_reader.py +10 -10
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/knowledge/types.py +1 -0
- agno/knowledge/utils.py +72 -7
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +544 -83
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +515 -40
- agno/models/aws/bedrock.py +102 -21
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +41 -19
- agno/models/azure/openai_chat.py +39 -8
- agno/models/base.py +1249 -525
- agno/models/cerebras/cerebras.py +91 -21
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +40 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +877 -80
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +51 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +44 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +28 -5
- agno/models/meta/llama.py +47 -14
- agno/models/meta/llama_openai.py +22 -17
- agno/models/mistral/mistral.py +8 -4
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/chat.py +24 -8
- agno/models/openai/chat.py +104 -29
- agno/models/openai/responses.py +101 -81
- agno/models/openrouter/openrouter.py +60 -3
- agno/models/perplexity/perplexity.py +17 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +24 -4
- agno/models/response.py +73 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +190 -0
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +549 -152
- agno/os/auth.py +190 -3
- agno/os/config.py +23 -0
- agno/os/interfaces/a2a/router.py +8 -11
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +18 -3
- agno/os/interfaces/agui/utils.py +152 -39
- agno/os/interfaces/slack/router.py +55 -37
- agno/os/interfaces/slack/slack.py +9 -1
- agno/os/interfaces/whatsapp/router.py +0 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/mcp.py +110 -52
- agno/os/middleware/__init__.py +2 -0
- agno/os/middleware/jwt.py +676 -112
- agno/os/router.py +40 -1478
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +599 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +96 -39
- agno/os/routers/evals/schemas.py +65 -33
- agno/os/routers/evals/utils.py +80 -10
- agno/os/routers/health.py +10 -4
- agno/os/routers/knowledge/knowledge.py +196 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +279 -52
- agno/os/routers/memory/schemas.py +46 -17
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +462 -34
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +512 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +499 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +624 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +256 -693
- agno/os/scopes.py +469 -0
- agno/os/utils.py +514 -36
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +155 -32
- agno/run/base.py +55 -3
- agno/run/requirement.py +181 -0
- agno/run/team.py +125 -38
- agno/run/workflow.py +72 -18
- agno/session/agent.py +102 -89
- agno/session/summary.py +56 -15
- agno/session/team.py +164 -90
- agno/session/workflow.py +405 -40
- agno/table.py +10 -0
- agno/team/team.py +3974 -1903
- agno/tools/dalle.py +2 -4
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +16 -10
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +193 -38
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +271 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/scrapegraph.py +1 -1
- agno/tools/shopify.py +1519 -0
- agno/tools/slack.py +18 -3
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +146 -0
- agno/tools/toolkit.py +25 -0
- agno/tools/workflow.py +8 -1
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +157 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +111 -0
- agno/utils/agent.py +938 -0
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +151 -3
- agno/utils/gemini.py +15 -5
- agno/utils/hooks.py +118 -4
- agno/utils/http.py +113 -2
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +187 -1
- agno/utils/merge_dict.py +3 -3
- agno/utils/message.py +60 -0
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +49 -14
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/print_response/agent.py +109 -16
- agno/utils/print_response/team.py +223 -30
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/streamlit.py +1 -1
- agno/utils/team.py +98 -9
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +39 -7
- agno/vectordb/cassandra/cassandra.py +21 -5
- agno/vectordb/chroma/chromadb.py +43 -12
- agno/vectordb/clickhouse/clickhousedb.py +21 -5
- agno/vectordb/couchbase/couchbase.py +29 -5
- agno/vectordb/lancedb/lance_db.py +92 -181
- agno/vectordb/langchaindb/langchaindb.py +24 -4
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/llamaindexdb.py +25 -5
- agno/vectordb/milvus/milvus.py +50 -37
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +36 -30
- agno/vectordb/pgvector/pgvector.py +201 -77
- agno/vectordb/pineconedb/pineconedb.py +41 -23
- agno/vectordb/qdrant/qdrant.py +67 -54
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +682 -0
- agno/vectordb/singlestore/singlestore.py +50 -29
- agno/vectordb/surrealdb/surrealdb.py +31 -41
- agno/vectordb/upstashdb/upstashdb.py +34 -6
- agno/vectordb/weaviate/weaviate.py +53 -14
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +120 -18
- agno/workflow/loop.py +77 -10
- agno/workflow/parallel.py +231 -143
- agno/workflow/router.py +118 -17
- agno/workflow/step.py +609 -170
- agno/workflow/steps.py +73 -6
- agno/workflow/types.py +96 -21
- agno/workflow/workflow.py +2039 -262
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/METADATA +201 -66
- agno-2.3.13.dist-info/RECORD +613 -0
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -679
- agno/tools/memori.py +0 -339
- agno-2.1.2.dist-info/RECORD +0 -543
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/WHEEL +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.2.dist-info → agno-2.3.13.dist-info}/top_level.txt +0 -0
agno/db/postgres/postgres.py
CHANGED
|
@@ -1,20 +1,27 @@
|
|
|
1
1
|
import time
|
|
2
2
|
from datetime import date, datetime, timedelta, timezone
|
|
3
|
-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from agno.tracing.schemas import Span, Trace
|
|
8
|
+
|
|
6
9
|
from agno.db.base import BaseDb, SessionType
|
|
10
|
+
from agno.db.migrations.manager import MigrationManager
|
|
7
11
|
from agno.db.postgres.schemas import get_table_schema_definition
|
|
8
12
|
from agno.db.postgres.utils import (
|
|
9
13
|
apply_sorting,
|
|
10
14
|
bulk_upsert_metrics,
|
|
11
15
|
calculate_date_metrics,
|
|
12
16
|
create_schema,
|
|
17
|
+
deserialize_cultural_knowledge,
|
|
13
18
|
fetch_all_sessions_data,
|
|
14
19
|
get_dates_to_calculate_metrics_for,
|
|
15
20
|
is_table_available,
|
|
16
21
|
is_valid_table,
|
|
22
|
+
serialize_cultural_knowledge,
|
|
17
23
|
)
|
|
24
|
+
from agno.db.schemas.culture import CulturalKnowledge
|
|
18
25
|
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
19
26
|
from agno.db.schemas.knowledge import KnowledgeRow
|
|
20
27
|
from agno.db.schemas.memory import UserMemory
|
|
@@ -23,12 +30,14 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
|
23
30
|
from agno.utils.string import generate_id
|
|
24
31
|
|
|
25
32
|
try:
|
|
26
|
-
from sqlalchemy import Index, String, UniqueConstraint, func, update
|
|
33
|
+
from sqlalchemy import ForeignKey, Index, String, UniqueConstraint, and_, case, func, or_, select, update
|
|
27
34
|
from sqlalchemy.dialects import postgresql
|
|
35
|
+
from sqlalchemy.dialects.postgresql import TIMESTAMP
|
|
28
36
|
from sqlalchemy.engine import Engine, create_engine
|
|
37
|
+
from sqlalchemy.exc import ProgrammingError
|
|
29
38
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
30
39
|
from sqlalchemy.schema import Column, MetaData, Table
|
|
31
|
-
from sqlalchemy.sql.expression import
|
|
40
|
+
from sqlalchemy.sql.expression import text
|
|
32
41
|
except ImportError:
|
|
33
42
|
raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
|
|
34
43
|
|
|
@@ -40,11 +49,16 @@ class PostgresDb(BaseDb):
|
|
|
40
49
|
db_engine: Optional[Engine] = None,
|
|
41
50
|
db_schema: Optional[str] = None,
|
|
42
51
|
session_table: Optional[str] = None,
|
|
52
|
+
culture_table: Optional[str] = None,
|
|
43
53
|
memory_table: Optional[str] = None,
|
|
44
54
|
metrics_table: Optional[str] = None,
|
|
45
55
|
eval_table: Optional[str] = None,
|
|
46
56
|
knowledge_table: Optional[str] = None,
|
|
57
|
+
traces_table: Optional[str] = None,
|
|
58
|
+
spans_table: Optional[str] = None,
|
|
59
|
+
versions_table: Optional[str] = None,
|
|
47
60
|
id: Optional[str] = None,
|
|
61
|
+
create_schema: bool = True,
|
|
48
62
|
):
|
|
49
63
|
"""
|
|
50
64
|
Interface for interacting with a PostgreSQL database.
|
|
@@ -63,7 +77,13 @@ class PostgresDb(BaseDb):
|
|
|
63
77
|
metrics_table (Optional[str]): Name of the table to store metrics.
|
|
64
78
|
eval_table (Optional[str]): Name of the table to store evaluation runs data.
|
|
65
79
|
knowledge_table (Optional[str]): Name of the table to store knowledge content.
|
|
80
|
+
culture_table (Optional[str]): Name of the table to store cultural knowledge.
|
|
81
|
+
traces_table (Optional[str]): Name of the table to store run traces.
|
|
82
|
+
spans_table (Optional[str]): Name of the table to store span events.
|
|
83
|
+
versions_table (Optional[str]): Name of the table to store schema versions.
|
|
66
84
|
id (Optional[str]): ID of the database.
|
|
85
|
+
create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
|
|
86
|
+
Set to False if schema is managed externally (e.g., via migrations). Defaults to True.
|
|
67
87
|
|
|
68
88
|
Raises:
|
|
69
89
|
ValueError: If neither db_url nor db_engine is provided.
|
|
@@ -91,23 +111,53 @@ class PostgresDb(BaseDb):
|
|
|
91
111
|
metrics_table=metrics_table,
|
|
92
112
|
eval_table=eval_table,
|
|
93
113
|
knowledge_table=knowledge_table,
|
|
114
|
+
culture_table=culture_table,
|
|
115
|
+
traces_table=traces_table,
|
|
116
|
+
spans_table=spans_table,
|
|
117
|
+
versions_table=versions_table,
|
|
94
118
|
)
|
|
95
119
|
|
|
96
120
|
self.db_schema: str = db_schema if db_schema is not None else "ai"
|
|
97
|
-
self.metadata: MetaData = MetaData()
|
|
121
|
+
self.metadata: MetaData = MetaData(schema=self.db_schema)
|
|
122
|
+
self.create_schema: bool = create_schema
|
|
98
123
|
|
|
99
124
|
# Initialize database session
|
|
100
|
-
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
|
|
125
|
+
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine, expire_on_commit=False))
|
|
101
126
|
|
|
102
127
|
# -- DB methods --
|
|
103
|
-
def
|
|
128
|
+
def table_exists(self, table_name: str) -> bool:
|
|
129
|
+
"""Check if a table with the given name exists in the Postgres database.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
table_name: Name of the table to check
|
|
133
|
+
|
|
134
|
+
Returns:
|
|
135
|
+
bool: True if the table exists in the database, False otherwise
|
|
136
|
+
"""
|
|
137
|
+
with self.Session() as sess:
|
|
138
|
+
return is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
|
|
139
|
+
|
|
140
|
+
def _create_all_tables(self):
|
|
141
|
+
"""Create all tables for the database."""
|
|
142
|
+
tables_to_create = [
|
|
143
|
+
(self.session_table_name, "sessions"),
|
|
144
|
+
(self.memory_table_name, "memories"),
|
|
145
|
+
(self.metrics_table_name, "metrics"),
|
|
146
|
+
(self.eval_table_name, "evals"),
|
|
147
|
+
(self.knowledge_table_name, "knowledge"),
|
|
148
|
+
(self.versions_table_name, "versions"),
|
|
149
|
+
]
|
|
150
|
+
|
|
151
|
+
for table_name, table_type in tables_to_create:
|
|
152
|
+
self._get_or_create_table(table_name=table_name, table_type=table_type, create_table_if_not_found=True)
|
|
153
|
+
|
|
154
|
+
def _create_table(self, table_name: str, table_type: str) -> Table:
|
|
104
155
|
"""
|
|
105
156
|
Create a table with the appropriate schema based on the table type.
|
|
106
157
|
|
|
107
158
|
Args:
|
|
108
159
|
table_name (str): Name of the table to create
|
|
109
160
|
table_type (str): Type of table (used to get schema definition)
|
|
110
|
-
db_schema (str): Database schema name
|
|
111
161
|
|
|
112
162
|
Returns:
|
|
113
163
|
Table: SQLAlchemy Table object
|
|
@@ -133,11 +183,20 @@ class PostgresDb(BaseDb):
|
|
|
133
183
|
if col_config.get("unique", False):
|
|
134
184
|
column_kwargs["unique"] = True
|
|
135
185
|
unique_constraints.append(col_name)
|
|
186
|
+
|
|
187
|
+
# Handle foreign key constraint
|
|
188
|
+
if "foreign_key" in col_config:
|
|
189
|
+
fk_ref = col_config["foreign_key"]
|
|
190
|
+
# For spans table, dynamically replace the traces table reference
|
|
191
|
+
# with the actual trace table name configured for this db instance
|
|
192
|
+
if table_type == "spans" and "trace_id" in fk_ref:
|
|
193
|
+
fk_ref = f"{self.db_schema}.{self.trace_table_name}.trace_id"
|
|
194
|
+
column_args.append(ForeignKey(fk_ref))
|
|
195
|
+
|
|
136
196
|
columns.append(Column(*column_args, **column_kwargs)) # type: ignore
|
|
137
197
|
|
|
138
198
|
# Create the table object
|
|
139
|
-
|
|
140
|
-
table = Table(table_name, table_metadata, *columns, schema=db_schema)
|
|
199
|
+
table = Table(table_name, self.metadata, *columns, schema=self.db_schema)
|
|
141
200
|
|
|
142
201
|
# Add multi-column unique constraints with table-specific names
|
|
143
202
|
for constraint in schema_unique_constraints:
|
|
@@ -150,11 +209,18 @@ class PostgresDb(BaseDb):
|
|
|
150
209
|
idx_name = f"idx_{table_name}_{idx_col}"
|
|
151
210
|
table.append_constraint(Index(idx_name, idx_col))
|
|
152
211
|
|
|
153
|
-
|
|
154
|
-
|
|
212
|
+
if self.create_schema:
|
|
213
|
+
with self.Session() as sess, sess.begin():
|
|
214
|
+
create_schema(session=sess, db_schema=self.db_schema)
|
|
155
215
|
|
|
156
216
|
# Create table
|
|
157
|
-
|
|
217
|
+
table_created = False
|
|
218
|
+
if not self.table_exists(table_name):
|
|
219
|
+
table.create(self.db_engine, checkfirst=True)
|
|
220
|
+
log_debug(f"Successfully created table '{table_name}'")
|
|
221
|
+
table_created = True
|
|
222
|
+
else:
|
|
223
|
+
log_debug(f"Table {self.db_schema}.{table_name} already exists, skipping creation")
|
|
158
224
|
|
|
159
225
|
# Create indexes
|
|
160
226
|
for idx in table.indexes:
|
|
@@ -165,24 +231,29 @@ class PostgresDb(BaseDb):
|
|
|
165
231
|
"SELECT 1 FROM pg_indexes WHERE schemaname = :schema AND indexname = :index_name"
|
|
166
232
|
)
|
|
167
233
|
exists = (
|
|
168
|
-
sess.execute(exists_query, {"schema": db_schema, "index_name": idx.name}).scalar()
|
|
234
|
+
sess.execute(exists_query, {"schema": self.db_schema, "index_name": idx.name}).scalar()
|
|
169
235
|
is not None
|
|
170
236
|
)
|
|
171
237
|
if exists:
|
|
172
|
-
log_debug(
|
|
238
|
+
log_debug(
|
|
239
|
+
f"Index {idx.name} already exists in {self.db_schema}.{table_name}, skipping creation"
|
|
240
|
+
)
|
|
173
241
|
continue
|
|
174
242
|
|
|
175
243
|
idx.create(self.db_engine)
|
|
176
|
-
log_debug(f"Created index: {idx.name} for table {db_schema}.{table_name}")
|
|
244
|
+
log_debug(f"Created index: {idx.name} for table {self.db_schema}.{table_name}")
|
|
177
245
|
|
|
178
246
|
except Exception as e:
|
|
179
247
|
log_error(f"Error creating index {idx.name}: {e}")
|
|
180
248
|
|
|
181
|
-
|
|
249
|
+
# Store the schema version for the created table
|
|
250
|
+
if table_name != self.versions_table_name and table_created:
|
|
251
|
+
latest_schema_version = MigrationManager(self).latest_schema_version
|
|
252
|
+
self.upsert_schema_version(table_name=table_name, version=latest_schema_version.public)
|
|
182
253
|
return table
|
|
183
254
|
|
|
184
255
|
except Exception as e:
|
|
185
|
-
log_error(f"Could not create table {db_schema}.{table_name}: {e}")
|
|
256
|
+
log_error(f"Could not create table {self.db_schema}.{table_name}: {e}")
|
|
186
257
|
raise
|
|
187
258
|
|
|
188
259
|
def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Optional[Table]:
|
|
@@ -190,7 +261,6 @@ class PostgresDb(BaseDb):
|
|
|
190
261
|
self.session_table = self._get_or_create_table(
|
|
191
262
|
table_name=self.session_table_name,
|
|
192
263
|
table_type="sessions",
|
|
193
|
-
db_schema=self.db_schema,
|
|
194
264
|
create_table_if_not_found=create_table_if_not_found,
|
|
195
265
|
)
|
|
196
266
|
return self.session_table
|
|
@@ -199,7 +269,6 @@ class PostgresDb(BaseDb):
|
|
|
199
269
|
self.memory_table = self._get_or_create_table(
|
|
200
270
|
table_name=self.memory_table_name,
|
|
201
271
|
table_type="memories",
|
|
202
|
-
db_schema=self.db_schema,
|
|
203
272
|
create_table_if_not_found=create_table_if_not_found,
|
|
204
273
|
)
|
|
205
274
|
return self.memory_table
|
|
@@ -208,7 +277,6 @@ class PostgresDb(BaseDb):
|
|
|
208
277
|
self.metrics_table = self._get_or_create_table(
|
|
209
278
|
table_name=self.metrics_table_name,
|
|
210
279
|
table_type="metrics",
|
|
211
|
-
db_schema=self.db_schema,
|
|
212
280
|
create_table_if_not_found=create_table_if_not_found,
|
|
213
281
|
)
|
|
214
282
|
return self.metrics_table
|
|
@@ -217,7 +285,6 @@ class PostgresDb(BaseDb):
|
|
|
217
285
|
self.eval_table = self._get_or_create_table(
|
|
218
286
|
table_name=self.eval_table_name,
|
|
219
287
|
table_type="evals",
|
|
220
|
-
db_schema=self.db_schema,
|
|
221
288
|
create_table_if_not_found=create_table_if_not_found,
|
|
222
289
|
)
|
|
223
290
|
return self.eval_table
|
|
@@ -226,15 +293,50 @@ class PostgresDb(BaseDb):
|
|
|
226
293
|
self.knowledge_table = self._get_or_create_table(
|
|
227
294
|
table_name=self.knowledge_table_name,
|
|
228
295
|
table_type="knowledge",
|
|
229
|
-
db_schema=self.db_schema,
|
|
230
296
|
create_table_if_not_found=create_table_if_not_found,
|
|
231
297
|
)
|
|
232
298
|
return self.knowledge_table
|
|
233
299
|
|
|
300
|
+
if table_type == "culture":
|
|
301
|
+
self.culture_table = self._get_or_create_table(
|
|
302
|
+
table_name=self.culture_table_name,
|
|
303
|
+
table_type="culture",
|
|
304
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
305
|
+
)
|
|
306
|
+
return self.culture_table
|
|
307
|
+
|
|
308
|
+
if table_type == "versions":
|
|
309
|
+
self.versions_table = self._get_or_create_table(
|
|
310
|
+
table_name=self.versions_table_name,
|
|
311
|
+
table_type="versions",
|
|
312
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
313
|
+
)
|
|
314
|
+
return self.versions_table
|
|
315
|
+
|
|
316
|
+
if table_type == "traces":
|
|
317
|
+
self.traces_table = self._get_or_create_table(
|
|
318
|
+
table_name=self.trace_table_name,
|
|
319
|
+
table_type="traces",
|
|
320
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
321
|
+
)
|
|
322
|
+
return self.traces_table
|
|
323
|
+
|
|
324
|
+
if table_type == "spans":
|
|
325
|
+
# Ensure traces table exists first (spans has FK to traces)
|
|
326
|
+
if create_table_if_not_found:
|
|
327
|
+
self._get_table(table_type="traces", create_table_if_not_found=True)
|
|
328
|
+
|
|
329
|
+
self.spans_table = self._get_or_create_table(
|
|
330
|
+
table_name=self.span_table_name,
|
|
331
|
+
table_type="spans",
|
|
332
|
+
create_table_if_not_found=create_table_if_not_found,
|
|
333
|
+
)
|
|
334
|
+
return self.spans_table
|
|
335
|
+
|
|
234
336
|
raise ValueError(f"Unknown table type: {table_type}")
|
|
235
337
|
|
|
236
338
|
def _get_or_create_table(
|
|
237
|
-
self, table_name: str, table_type: str,
|
|
339
|
+
self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
|
|
238
340
|
) -> Optional[Table]:
|
|
239
341
|
"""
|
|
240
342
|
Check if the table exists and is valid, else create it.
|
|
@@ -242,39 +344,72 @@ class PostgresDb(BaseDb):
|
|
|
242
344
|
Args:
|
|
243
345
|
table_name (str): Name of the table to get or create
|
|
244
346
|
table_type (str): Type of table (used to get schema definition)
|
|
245
|
-
db_schema (str): Database schema name
|
|
246
347
|
|
|
247
348
|
Returns:
|
|
248
349
|
Optional[Table]: SQLAlchemy Table object representing the schema.
|
|
249
350
|
"""
|
|
250
351
|
|
|
251
352
|
with self.Session() as sess, sess.begin():
|
|
252
|
-
table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=db_schema)
|
|
353
|
+
table_is_available = is_table_available(session=sess, table_name=table_name, db_schema=self.db_schema)
|
|
253
354
|
|
|
254
355
|
if not table_is_available:
|
|
255
356
|
if not create_table_if_not_found:
|
|
256
357
|
return None
|
|
257
|
-
|
|
258
|
-
return self._create_table(table_name=table_name, table_type=table_type, db_schema=db_schema)
|
|
358
|
+
return self._create_table(table_name=table_name, table_type=table_type)
|
|
259
359
|
|
|
260
360
|
if not is_valid_table(
|
|
261
361
|
db_engine=self.db_engine,
|
|
262
362
|
table_name=table_name,
|
|
263
363
|
table_type=table_type,
|
|
264
|
-
db_schema=db_schema,
|
|
364
|
+
db_schema=self.db_schema,
|
|
265
365
|
):
|
|
266
|
-
raise ValueError(f"Table {db_schema}.{table_name} has an invalid schema")
|
|
366
|
+
raise ValueError(f"Table {self.db_schema}.{table_name} has an invalid schema")
|
|
267
367
|
|
|
268
368
|
try:
|
|
269
|
-
table = Table(table_name, self.metadata, schema=db_schema, autoload_with=self.db_engine)
|
|
369
|
+
table = Table(table_name, self.metadata, schema=self.db_schema, autoload_with=self.db_engine)
|
|
270
370
|
return table
|
|
271
371
|
|
|
272
372
|
except Exception as e:
|
|
273
|
-
log_error(f"Error loading existing table {db_schema}.{table_name}: {e}")
|
|
373
|
+
log_error(f"Error loading existing table {self.db_schema}.{table_name}: {e}")
|
|
274
374
|
raise
|
|
275
375
|
|
|
276
|
-
|
|
376
|
+
def get_latest_schema_version(self, table_name: str):
|
|
377
|
+
"""Get the latest version of the database schema."""
|
|
378
|
+
table = self._get_table(table_type="versions", create_table_if_not_found=True)
|
|
379
|
+
if table is None:
|
|
380
|
+
return "2.0.0"
|
|
381
|
+
with self.Session() as sess:
|
|
382
|
+
stmt = select(table)
|
|
383
|
+
# Latest version for the given table
|
|
384
|
+
stmt = stmt.where(table.c.table_name == table_name)
|
|
385
|
+
stmt = stmt.order_by(table.c.version.desc()).limit(1)
|
|
386
|
+
result = sess.execute(stmt).fetchone()
|
|
387
|
+
if result is None:
|
|
388
|
+
return "2.0.0"
|
|
389
|
+
version_dict = dict(result._mapping)
|
|
390
|
+
return version_dict.get("version") or "2.0.0"
|
|
391
|
+
|
|
392
|
+
def upsert_schema_version(self, table_name: str, version: str) -> None:
|
|
393
|
+
"""Upsert the schema version into the database."""
|
|
394
|
+
table = self._get_table(table_type="versions", create_table_if_not_found=True)
|
|
395
|
+
if table is None:
|
|
396
|
+
return
|
|
397
|
+
current_datetime = datetime.now().isoformat()
|
|
398
|
+
with self.Session() as sess, sess.begin():
|
|
399
|
+
stmt = postgresql.insert(table).values(
|
|
400
|
+
table_name=table_name,
|
|
401
|
+
version=version,
|
|
402
|
+
created_at=current_datetime, # Store as ISO format string
|
|
403
|
+
updated_at=current_datetime,
|
|
404
|
+
)
|
|
405
|
+
# Update version if table_name already exists
|
|
406
|
+
stmt = stmt.on_conflict_do_update(
|
|
407
|
+
index_elements=["table_name"],
|
|
408
|
+
set_=dict(version=version, updated_at=current_datetime),
|
|
409
|
+
)
|
|
410
|
+
sess.execute(stmt)
|
|
277
411
|
|
|
412
|
+
# -- Session methods --
|
|
278
413
|
def delete_session(self, session_id: str) -> bool:
|
|
279
414
|
"""
|
|
280
415
|
Delete a session from the database.
|
|
@@ -368,9 +503,6 @@ class PostgresDb(BaseDb):
|
|
|
368
503
|
|
|
369
504
|
if user_id is not None:
|
|
370
505
|
stmt = stmt.where(table.c.user_id == user_id)
|
|
371
|
-
if session_type is not None:
|
|
372
|
-
session_type_value = session_type.value if isinstance(session_type, SessionType) else session_type
|
|
373
|
-
stmt = stmt.where(table.c.session_type == session_type_value)
|
|
374
506
|
result = sess.execute(stmt).fetchone()
|
|
375
507
|
if result is None:
|
|
376
508
|
return None
|
|
@@ -704,7 +836,7 @@ class PostgresDb(BaseDb):
|
|
|
704
836
|
raise e
|
|
705
837
|
|
|
706
838
|
def upsert_sessions(
|
|
707
|
-
self, sessions: List[Session], deserialize: Optional[bool] = True
|
|
839
|
+
self, sessions: List[Session], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
|
|
708
840
|
) -> List[Union[Session, Dict[str, Any]]]:
|
|
709
841
|
"""
|
|
710
842
|
Bulk insert or update multiple sessions.
|
|
@@ -712,6 +844,7 @@ class PostgresDb(BaseDb):
|
|
|
712
844
|
Args:
|
|
713
845
|
sessions (List[Session]): The list of session data to upsert.
|
|
714
846
|
deserialize (Optional[bool]): Whether to deserialize the sessions. Defaults to True.
|
|
847
|
+
preserve_updated_at (bool): If True, preserve the updated_at from the session object.
|
|
715
848
|
|
|
716
849
|
Returns:
|
|
717
850
|
List[Union[Session, Dict[str, Any]]]: List of upserted sessions
|
|
@@ -739,6 +872,8 @@ class PostgresDb(BaseDb):
|
|
|
739
872
|
session_records = []
|
|
740
873
|
for agent_session in agent_sessions:
|
|
741
874
|
session_dict = agent_session.to_dict()
|
|
875
|
+
# Use preserved updated_at if flag is set (even if None), otherwise use current time
|
|
876
|
+
updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
|
|
742
877
|
session_records.append(
|
|
743
878
|
{
|
|
744
879
|
"session_id": session_dict.get("session_id"),
|
|
@@ -751,7 +886,7 @@ class PostgresDb(BaseDb):
|
|
|
751
886
|
"metadata": session_dict.get("metadata"),
|
|
752
887
|
"runs": session_dict.get("runs"),
|
|
753
888
|
"created_at": session_dict.get("created_at"),
|
|
754
|
-
"updated_at":
|
|
889
|
+
"updated_at": updated_at,
|
|
755
890
|
}
|
|
756
891
|
)
|
|
757
892
|
|
|
@@ -782,6 +917,8 @@ class PostgresDb(BaseDb):
|
|
|
782
917
|
session_records = []
|
|
783
918
|
for team_session in team_sessions:
|
|
784
919
|
session_dict = team_session.to_dict()
|
|
920
|
+
# Use preserved updated_at if flag is set (even if None), otherwise use current time
|
|
921
|
+
updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
|
|
785
922
|
session_records.append(
|
|
786
923
|
{
|
|
787
924
|
"session_id": session_dict.get("session_id"),
|
|
@@ -794,7 +931,7 @@ class PostgresDb(BaseDb):
|
|
|
794
931
|
"metadata": session_dict.get("metadata"),
|
|
795
932
|
"runs": session_dict.get("runs"),
|
|
796
933
|
"created_at": session_dict.get("created_at"),
|
|
797
|
-
"updated_at":
|
|
934
|
+
"updated_at": updated_at,
|
|
798
935
|
}
|
|
799
936
|
)
|
|
800
937
|
|
|
@@ -825,6 +962,8 @@ class PostgresDb(BaseDb):
|
|
|
825
962
|
session_records = []
|
|
826
963
|
for workflow_session in workflow_sessions:
|
|
827
964
|
session_dict = workflow_session.to_dict()
|
|
965
|
+
# Use preserved updated_at if flag is set (even if None), otherwise use current time
|
|
966
|
+
updated_at = session_dict.get("updated_at") if preserve_updated_at else int(time.time())
|
|
828
967
|
session_records.append(
|
|
829
968
|
{
|
|
830
969
|
"session_id": session_dict.get("session_id"),
|
|
@@ -837,7 +976,7 @@ class PostgresDb(BaseDb):
|
|
|
837
976
|
"metadata": session_dict.get("metadata"),
|
|
838
977
|
"runs": session_dict.get("runs"),
|
|
839
978
|
"created_at": session_dict.get("created_at"),
|
|
840
|
-
"updated_at":
|
|
979
|
+
"updated_at": updated_at,
|
|
841
980
|
}
|
|
842
981
|
)
|
|
843
982
|
|
|
@@ -950,9 +1089,14 @@ class PostgresDb(BaseDb):
|
|
|
950
1089
|
return []
|
|
951
1090
|
|
|
952
1091
|
with self.Session() as sess, sess.begin():
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
1092
|
+
try:
|
|
1093
|
+
stmt = select(func.jsonb_array_elements_text(table.c.topics))
|
|
1094
|
+
result = sess.execute(stmt).fetchall()
|
|
1095
|
+
except ProgrammingError:
|
|
1096
|
+
# Retrying with json_array_elements_text. This works in older versions,
|
|
1097
|
+
# where the topics column was of type JSON instead of JSONB
|
|
1098
|
+
stmt = select(func.json_array_elements_text(table.c.topics))
|
|
1099
|
+
result = sess.execute(stmt).fetchall()
|
|
956
1100
|
|
|
957
1101
|
return list(set([record[0] for record in result]))
|
|
958
1102
|
|
|
@@ -1105,13 +1249,14 @@ class PostgresDb(BaseDb):
|
|
|
1105
1249
|
raise e
|
|
1106
1250
|
|
|
1107
1251
|
def get_user_memory_stats(
|
|
1108
|
-
self, limit: Optional[int] = None, page: Optional[int] = None
|
|
1252
|
+
self, limit: Optional[int] = None, page: Optional[int] = None, user_id: Optional[str] = None
|
|
1109
1253
|
) -> Tuple[List[Dict[str, Any]], int]:
|
|
1110
1254
|
"""Get user memories stats.
|
|
1111
1255
|
|
|
1112
1256
|
Args:
|
|
1113
1257
|
limit (Optional[int]): The maximum number of user stats to return.
|
|
1114
1258
|
page (Optional[int]): The page number.
|
|
1259
|
+
user_id (Optional[str]): User ID for filtering.
|
|
1115
1260
|
|
|
1116
1261
|
Returns:
|
|
1117
1262
|
Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
|
|
@@ -1134,16 +1279,17 @@ class PostgresDb(BaseDb):
|
|
|
1134
1279
|
return [], 0
|
|
1135
1280
|
|
|
1136
1281
|
with self.Session() as sess, sess.begin():
|
|
1137
|
-
stmt = (
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
func.max(table.c.updated_at).label("last_memory_updated_at"),
|
|
1142
|
-
)
|
|
1143
|
-
.where(table.c.user_id.is_not(None))
|
|
1144
|
-
.group_by(table.c.user_id)
|
|
1145
|
-
.order_by(func.max(table.c.updated_at).desc())
|
|
1282
|
+
stmt = select(
|
|
1283
|
+
table.c.user_id,
|
|
1284
|
+
func.count(table.c.memory_id).label("total_memories"),
|
|
1285
|
+
func.max(table.c.updated_at).label("last_memory_updated_at"),
|
|
1146
1286
|
)
|
|
1287
|
+
if user_id is not None:
|
|
1288
|
+
stmt = stmt.where(table.c.user_id == user_id)
|
|
1289
|
+
else:
|
|
1290
|
+
stmt = stmt.where(table.c.user_id.is_not(None))
|
|
1291
|
+
stmt = stmt.group_by(table.c.user_id)
|
|
1292
|
+
stmt = stmt.order_by(func.max(table.c.updated_at).desc())
|
|
1147
1293
|
|
|
1148
1294
|
count_stmt = select(func.count()).select_from(stmt.alias())
|
|
1149
1295
|
total_count = sess.execute(count_stmt).scalar()
|
|
@@ -1197,6 +1343,8 @@ class PostgresDb(BaseDb):
|
|
|
1197
1343
|
if memory.memory_id is None:
|
|
1198
1344
|
memory.memory_id = str(uuid4())
|
|
1199
1345
|
|
|
1346
|
+
current_time = int(time.time())
|
|
1347
|
+
|
|
1200
1348
|
stmt = postgresql.insert(table).values(
|
|
1201
1349
|
memory_id=memory.memory_id,
|
|
1202
1350
|
memory=memory.memory,
|
|
@@ -1205,7 +1353,9 @@ class PostgresDb(BaseDb):
|
|
|
1205
1353
|
agent_id=memory.agent_id,
|
|
1206
1354
|
team_id=memory.team_id,
|
|
1207
1355
|
topics=memory.topics,
|
|
1208
|
-
|
|
1356
|
+
feedback=memory.feedback,
|
|
1357
|
+
created_at=memory.created_at,
|
|
1358
|
+
updated_at=memory.created_at,
|
|
1209
1359
|
)
|
|
1210
1360
|
stmt = stmt.on_conflict_do_update( # type: ignore
|
|
1211
1361
|
index_elements=["memory_id"],
|
|
@@ -1215,7 +1365,10 @@ class PostgresDb(BaseDb):
|
|
|
1215
1365
|
input=memory.input,
|
|
1216
1366
|
agent_id=memory.agent_id,
|
|
1217
1367
|
team_id=memory.team_id,
|
|
1218
|
-
|
|
1368
|
+
feedback=memory.feedback,
|
|
1369
|
+
updated_at=current_time,
|
|
1370
|
+
# Preserve created_at on update - don't overwrite existing value
|
|
1371
|
+
created_at=table.c.created_at,
|
|
1219
1372
|
),
|
|
1220
1373
|
).returning(table)
|
|
1221
1374
|
|
|
@@ -1234,7 +1387,7 @@ class PostgresDb(BaseDb):
|
|
|
1234
1387
|
raise e
|
|
1235
1388
|
|
|
1236
1389
|
def upsert_memories(
|
|
1237
|
-
self, memories: List[UserMemory], deserialize: Optional[bool] = True
|
|
1390
|
+
self, memories: List[UserMemory], deserialize: Optional[bool] = True, preserve_updated_at: bool = False
|
|
1238
1391
|
) -> List[Union[UserMemory, Dict[str, Any]]]:
|
|
1239
1392
|
"""
|
|
1240
1393
|
Bulk insert or update multiple memories in the database for improved performance.
|
|
@@ -1242,6 +1395,8 @@ class PostgresDb(BaseDb):
|
|
|
1242
1395
|
Args:
|
|
1243
1396
|
memories (List[UserMemory]): The list of memories to upsert.
|
|
1244
1397
|
deserialize (Optional[bool]): Whether to deserialize the memories. Defaults to True.
|
|
1398
|
+
preserve_updated_at (bool): If True, preserve the updated_at from the memory object.
|
|
1399
|
+
If False (default), set updated_at to current time.
|
|
1245
1400
|
|
|
1246
1401
|
Returns:
|
|
1247
1402
|
List[Union[UserMemory, Dict[str, Any]]]: List of upserted memories
|
|
@@ -1265,6 +1420,9 @@ class PostgresDb(BaseDb):
|
|
|
1265
1420
|
if memory.memory_id is None:
|
|
1266
1421
|
memory.memory_id = str(uuid4())
|
|
1267
1422
|
|
|
1423
|
+
# Use preserved updated_at if flag is set (even if None), otherwise use current time
|
|
1424
|
+
updated_at = memory.updated_at if preserve_updated_at else current_time
|
|
1425
|
+
|
|
1268
1426
|
memory_records.append(
|
|
1269
1427
|
{
|
|
1270
1428
|
"memory_id": memory.memory_id,
|
|
@@ -1274,7 +1432,9 @@ class PostgresDb(BaseDb):
|
|
|
1274
1432
|
"agent_id": memory.agent_id,
|
|
1275
1433
|
"team_id": memory.team_id,
|
|
1276
1434
|
"topics": memory.topics,
|
|
1277
|
-
"
|
|
1435
|
+
"feedback": memory.feedback,
|
|
1436
|
+
"created_at": memory.created_at,
|
|
1437
|
+
"updated_at": updated_at,
|
|
1278
1438
|
}
|
|
1279
1439
|
)
|
|
1280
1440
|
|
|
@@ -1285,7 +1445,7 @@ class PostgresDb(BaseDb):
|
|
|
1285
1445
|
update_columns = {
|
|
1286
1446
|
col.name: insert_stmt.excluded[col.name]
|
|
1287
1447
|
for col in table.columns
|
|
1288
|
-
if col.name not in ["memory_id"] # Don't update primary key
|
|
1448
|
+
if col.name not in ["memory_id", "created_at"] # Don't update primary key or created_at
|
|
1289
1449
|
}
|
|
1290
1450
|
stmt = insert_stmt.on_conflict_do_update(index_elements=["memory_id"], set_=update_columns).returning(
|
|
1291
1451
|
table
|
|
@@ -1926,6 +2086,233 @@ class PostgresDb(BaseDb):
|
|
|
1926
2086
|
log_error(f"Error upserting eval run name {eval_run_id}: {e}")
|
|
1927
2087
|
raise e
|
|
1928
2088
|
|
|
2089
|
+
# -- Culture methods --
|
|
2090
|
+
|
|
2091
|
+
def clear_cultural_knowledge(self) -> None:
|
|
2092
|
+
"""Delete all cultural knowledge from the database.
|
|
2093
|
+
|
|
2094
|
+
Raises:
|
|
2095
|
+
Exception: If an error occurs during deletion.
|
|
2096
|
+
"""
|
|
2097
|
+
try:
|
|
2098
|
+
table = self._get_table(table_type="culture")
|
|
2099
|
+
if table is None:
|
|
2100
|
+
return
|
|
2101
|
+
|
|
2102
|
+
with self.Session() as sess, sess.begin():
|
|
2103
|
+
sess.execute(table.delete())
|
|
2104
|
+
|
|
2105
|
+
except Exception as e:
|
|
2106
|
+
log_warning(f"Exception deleting all cultural knowledge: {e}")
|
|
2107
|
+
raise e
|
|
2108
|
+
|
|
2109
|
+
def delete_cultural_knowledge(self, id: str) -> None:
|
|
2110
|
+
"""Delete a cultural knowledge entry from the database.
|
|
2111
|
+
|
|
2112
|
+
Args:
|
|
2113
|
+
id (str): The ID of the cultural knowledge to delete.
|
|
2114
|
+
|
|
2115
|
+
Raises:
|
|
2116
|
+
Exception: If an error occurs during deletion.
|
|
2117
|
+
"""
|
|
2118
|
+
try:
|
|
2119
|
+
table = self._get_table(table_type="culture")
|
|
2120
|
+
if table is None:
|
|
2121
|
+
return
|
|
2122
|
+
|
|
2123
|
+
with self.Session() as sess, sess.begin():
|
|
2124
|
+
delete_stmt = table.delete().where(table.c.id == id)
|
|
2125
|
+
result = sess.execute(delete_stmt)
|
|
2126
|
+
|
|
2127
|
+
success = result.rowcount > 0
|
|
2128
|
+
if success:
|
|
2129
|
+
log_debug(f"Successfully deleted cultural knowledge id: {id}")
|
|
2130
|
+
else:
|
|
2131
|
+
log_debug(f"No cultural knowledge found with id: {id}")
|
|
2132
|
+
|
|
2133
|
+
except Exception as e:
|
|
2134
|
+
log_error(f"Error deleting cultural knowledge: {e}")
|
|
2135
|
+
raise e
|
|
2136
|
+
|
|
2137
|
+
def get_cultural_knowledge(
|
|
2138
|
+
self, id: str, deserialize: Optional[bool] = True
|
|
2139
|
+
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
|
|
2140
|
+
"""Get a cultural knowledge entry from the database.
|
|
2141
|
+
|
|
2142
|
+
Args:
|
|
2143
|
+
id (str): The ID of the cultural knowledge to get.
|
|
2144
|
+
deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
|
|
2145
|
+
|
|
2146
|
+
Returns:
|
|
2147
|
+
Optional[Union[CulturalKnowledge, Dict[str, Any]]]: The cultural knowledge entry, or None if it doesn't exist.
|
|
2148
|
+
|
|
2149
|
+
Raises:
|
|
2150
|
+
Exception: If an error occurs during retrieval.
|
|
2151
|
+
"""
|
|
2152
|
+
try:
|
|
2153
|
+
table = self._get_table(table_type="culture")
|
|
2154
|
+
if table is None:
|
|
2155
|
+
return None
|
|
2156
|
+
|
|
2157
|
+
with self.Session() as sess, sess.begin():
|
|
2158
|
+
stmt = select(table).where(table.c.id == id)
|
|
2159
|
+
result = sess.execute(stmt).fetchone()
|
|
2160
|
+
if result is None:
|
|
2161
|
+
return None
|
|
2162
|
+
|
|
2163
|
+
db_row = dict(result._mapping)
|
|
2164
|
+
if not db_row or not deserialize:
|
|
2165
|
+
return db_row
|
|
2166
|
+
|
|
2167
|
+
return deserialize_cultural_knowledge(db_row)
|
|
2168
|
+
|
|
2169
|
+
except Exception as e:
|
|
2170
|
+
log_error(f"Exception reading from cultural knowledge table: {e}")
|
|
2171
|
+
raise e
|
|
2172
|
+
|
|
2173
|
+
def get_all_cultural_knowledge(
|
|
2174
|
+
self,
|
|
2175
|
+
name: Optional[str] = None,
|
|
2176
|
+
agent_id: Optional[str] = None,
|
|
2177
|
+
team_id: Optional[str] = None,
|
|
2178
|
+
limit: Optional[int] = None,
|
|
2179
|
+
page: Optional[int] = None,
|
|
2180
|
+
sort_by: Optional[str] = None,
|
|
2181
|
+
sort_order: Optional[str] = None,
|
|
2182
|
+
deserialize: Optional[bool] = True,
|
|
2183
|
+
) -> Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
|
|
2184
|
+
"""Get all cultural knowledge from the database as CulturalKnowledge objects.
|
|
2185
|
+
|
|
2186
|
+
Args:
|
|
2187
|
+
name (Optional[str]): The name of the cultural knowledge to filter by.
|
|
2188
|
+
agent_id (Optional[str]): The ID of the agent to filter by.
|
|
2189
|
+
team_id (Optional[str]): The ID of the team to filter by.
|
|
2190
|
+
limit (Optional[int]): The maximum number of cultural knowledge entries to return.
|
|
2191
|
+
page (Optional[int]): The page number.
|
|
2192
|
+
sort_by (Optional[str]): The column to sort by.
|
|
2193
|
+
sort_order (Optional[str]): The order to sort by.
|
|
2194
|
+
deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
|
|
2195
|
+
|
|
2196
|
+
Returns:
|
|
2197
|
+
Union[List[CulturalKnowledge], Tuple[List[Dict[str, Any]], int]]:
|
|
2198
|
+
- When deserialize=True: List of CulturalKnowledge objects
|
|
2199
|
+
- When deserialize=False: List of CulturalKnowledge dictionaries and total count
|
|
2200
|
+
|
|
2201
|
+
Raises:
|
|
2202
|
+
Exception: If an error occurs during retrieval.
|
|
2203
|
+
"""
|
|
2204
|
+
try:
|
|
2205
|
+
table = self._get_table(table_type="culture")
|
|
2206
|
+
if table is None:
|
|
2207
|
+
return [] if deserialize else ([], 0)
|
|
2208
|
+
|
|
2209
|
+
with self.Session() as sess, sess.begin():
|
|
2210
|
+
stmt = select(table)
|
|
2211
|
+
|
|
2212
|
+
# Filtering
|
|
2213
|
+
if name is not None:
|
|
2214
|
+
stmt = stmt.where(table.c.name == name)
|
|
2215
|
+
if agent_id is not None:
|
|
2216
|
+
stmt = stmt.where(table.c.agent_id == agent_id)
|
|
2217
|
+
if team_id is not None:
|
|
2218
|
+
stmt = stmt.where(table.c.team_id == team_id)
|
|
2219
|
+
|
|
2220
|
+
# Get total count after applying filtering
|
|
2221
|
+
count_stmt = select(func.count()).select_from(stmt.alias())
|
|
2222
|
+
total_count = sess.execute(count_stmt).scalar()
|
|
2223
|
+
|
|
2224
|
+
# Sorting
|
|
2225
|
+
stmt = apply_sorting(stmt, table, sort_by, sort_order)
|
|
2226
|
+
# Paginating
|
|
2227
|
+
if limit is not None:
|
|
2228
|
+
stmt = stmt.limit(limit)
|
|
2229
|
+
if page is not None:
|
|
2230
|
+
stmt = stmt.offset((page - 1) * limit)
|
|
2231
|
+
|
|
2232
|
+
result = sess.execute(stmt).fetchall()
|
|
2233
|
+
if not result:
|
|
2234
|
+
return [] if deserialize else ([], 0)
|
|
2235
|
+
|
|
2236
|
+
db_rows = [dict(record._mapping) for record in result]
|
|
2237
|
+
|
|
2238
|
+
if not deserialize:
|
|
2239
|
+
return db_rows, total_count
|
|
2240
|
+
|
|
2241
|
+
return [deserialize_cultural_knowledge(row) for row in db_rows]
|
|
2242
|
+
|
|
2243
|
+
except Exception as e:
|
|
2244
|
+
log_error(f"Error reading from cultural knowledge table: {e}")
|
|
2245
|
+
raise e
|
|
2246
|
+
|
|
2247
|
+
def upsert_cultural_knowledge(
|
|
2248
|
+
self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
|
|
2249
|
+
) -> Optional[Union[CulturalKnowledge, Dict[str, Any]]]:
|
|
2250
|
+
"""Upsert a cultural knowledge entry into the database.
|
|
2251
|
+
|
|
2252
|
+
Args:
|
|
2253
|
+
cultural_knowledge (CulturalKnowledge): The cultural knowledge to upsert.
|
|
2254
|
+
deserialize (Optional[bool]): Whether to deserialize the cultural knowledge. Defaults to True.
|
|
2255
|
+
|
|
2256
|
+
Returns:
|
|
2257
|
+
Optional[CulturalKnowledge]: The upserted cultural knowledge entry.
|
|
2258
|
+
|
|
2259
|
+
Raises:
|
|
2260
|
+
Exception: If an error occurs during upsert.
|
|
2261
|
+
"""
|
|
2262
|
+
try:
|
|
2263
|
+
table = self._get_table(table_type="culture", create_table_if_not_found=True)
|
|
2264
|
+
if table is None:
|
|
2265
|
+
return None
|
|
2266
|
+
|
|
2267
|
+
if cultural_knowledge.id is None:
|
|
2268
|
+
cultural_knowledge.id = str(uuid4())
|
|
2269
|
+
|
|
2270
|
+
# Serialize content, categories, and notes into a JSON dict for DB storage
|
|
2271
|
+
content_dict = serialize_cultural_knowledge(cultural_knowledge)
|
|
2272
|
+
|
|
2273
|
+
with self.Session() as sess, sess.begin():
|
|
2274
|
+
stmt = postgresql.insert(table).values(
|
|
2275
|
+
id=cultural_knowledge.id,
|
|
2276
|
+
name=cultural_knowledge.name,
|
|
2277
|
+
summary=cultural_knowledge.summary,
|
|
2278
|
+
content=content_dict if content_dict else None,
|
|
2279
|
+
metadata=cultural_knowledge.metadata,
|
|
2280
|
+
input=cultural_knowledge.input,
|
|
2281
|
+
created_at=cultural_knowledge.created_at,
|
|
2282
|
+
updated_at=int(time.time()),
|
|
2283
|
+
agent_id=cultural_knowledge.agent_id,
|
|
2284
|
+
team_id=cultural_knowledge.team_id,
|
|
2285
|
+
)
|
|
2286
|
+
stmt = stmt.on_conflict_do_update( # type: ignore
|
|
2287
|
+
index_elements=["id"],
|
|
2288
|
+
set_=dict(
|
|
2289
|
+
name=cultural_knowledge.name,
|
|
2290
|
+
summary=cultural_knowledge.summary,
|
|
2291
|
+
content=content_dict if content_dict else None,
|
|
2292
|
+
metadata=cultural_knowledge.metadata,
|
|
2293
|
+
input=cultural_knowledge.input,
|
|
2294
|
+
updated_at=int(time.time()),
|
|
2295
|
+
agent_id=cultural_knowledge.agent_id,
|
|
2296
|
+
team_id=cultural_knowledge.team_id,
|
|
2297
|
+
),
|
|
2298
|
+
).returning(table)
|
|
2299
|
+
|
|
2300
|
+
result = sess.execute(stmt)
|
|
2301
|
+
row = result.fetchone()
|
|
2302
|
+
|
|
2303
|
+
if row is None:
|
|
2304
|
+
return None
|
|
2305
|
+
|
|
2306
|
+
db_row = dict(row._mapping)
|
|
2307
|
+
if not db_row or not deserialize:
|
|
2308
|
+
return db_row
|
|
2309
|
+
|
|
2310
|
+
return deserialize_cultural_knowledge(db_row)
|
|
2311
|
+
|
|
2312
|
+
except Exception as e:
|
|
2313
|
+
log_error(f"Error upserting cultural knowledge: {e}")
|
|
2314
|
+
raise e
|
|
2315
|
+
|
|
1929
2316
|
# -- Migrations --
|
|
1930
2317
|
|
|
1931
2318
|
def migrate_table_from_v1_to_v2(self, v1_db_schema: str, v1_table_name: str, v1_table_type: str):
|
|
@@ -1983,3 +2370,501 @@ class PostgresDb(BaseDb):
|
|
|
1983
2370
|
for memory in memories:
|
|
1984
2371
|
self.upsert_user_memory(memory)
|
|
1985
2372
|
log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
|
|
2373
|
+
|
|
2374
|
+
# --- Traces ---
|
|
2375
|
+
def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
|
|
2376
|
+
"""Build base query for traces with aggregated span counts.
|
|
2377
|
+
|
|
2378
|
+
Args:
|
|
2379
|
+
table: The traces table.
|
|
2380
|
+
spans_table: The spans table (optional).
|
|
2381
|
+
|
|
2382
|
+
Returns:
|
|
2383
|
+
SQLAlchemy select statement with total_spans and error_count calculated dynamically.
|
|
2384
|
+
"""
|
|
2385
|
+
from sqlalchemy import case, literal
|
|
2386
|
+
|
|
2387
|
+
if spans_table is not None:
|
|
2388
|
+
# JOIN with spans table to calculate total_spans and error_count
|
|
2389
|
+
return (
|
|
2390
|
+
select(
|
|
2391
|
+
table,
|
|
2392
|
+
func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
|
|
2393
|
+
func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
|
|
2394
|
+
"error_count"
|
|
2395
|
+
),
|
|
2396
|
+
)
|
|
2397
|
+
.select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
|
|
2398
|
+
.group_by(table.c.trace_id)
|
|
2399
|
+
)
|
|
2400
|
+
else:
|
|
2401
|
+
# Fallback if spans table doesn't exist
|
|
2402
|
+
return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
|
|
2403
|
+
|
|
2404
|
+
def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
|
|
2405
|
+
"""Build a SQL CASE expression that returns the component level for a trace.
|
|
2406
|
+
|
|
2407
|
+
Component levels (higher = more important):
|
|
2408
|
+
- 3: Workflow root (.run or .arun with workflow_id)
|
|
2409
|
+
- 2: Team root (.run or .arun with team_id)
|
|
2410
|
+
- 1: Agent root (.run or .arun with agent_id)
|
|
2411
|
+
- 0: Child span (not a root)
|
|
2412
|
+
|
|
2413
|
+
Args:
|
|
2414
|
+
workflow_id_col: SQL column/expression for workflow_id
|
|
2415
|
+
team_id_col: SQL column/expression for team_id
|
|
2416
|
+
agent_id_col: SQL column/expression for agent_id
|
|
2417
|
+
name_col: SQL column/expression for name
|
|
2418
|
+
|
|
2419
|
+
Returns:
|
|
2420
|
+
SQLAlchemy CASE expression returning the component level as an integer.
|
|
2421
|
+
"""
|
|
2422
|
+
is_root_name = or_(name_col.contains(".run"), name_col.contains(".arun"))
|
|
2423
|
+
|
|
2424
|
+
return case(
|
|
2425
|
+
# Workflow root (level 3)
|
|
2426
|
+
(and_(workflow_id_col.isnot(None), is_root_name), 3),
|
|
2427
|
+
# Team root (level 2)
|
|
2428
|
+
(and_(team_id_col.isnot(None), is_root_name), 2),
|
|
2429
|
+
# Agent root (level 1)
|
|
2430
|
+
(and_(agent_id_col.isnot(None), is_root_name), 1),
|
|
2431
|
+
# Child span or unknown (level 0)
|
|
2432
|
+
else_=0,
|
|
2433
|
+
)
|
|
2434
|
+
|
|
2435
|
+
def upsert_trace(self, trace: "Trace") -> None:
|
|
2436
|
+
"""Create or update a single trace record in the database.
|
|
2437
|
```diff
+        Uses INSERT ... ON CONFLICT DO UPDATE (upsert) to handle concurrent inserts
+        atomically and avoid race conditions.
+
+        Args:
+            trace: The Trace object to store (one per trace_id).
+        """
+        try:
+            table = self._get_table(table_type="traces", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            trace_dict = trace.to_dict()
+            trace_dict.pop("total_spans", None)
+            trace_dict.pop("error_count", None)
+
+            with self.Session() as sess, sess.begin():
+                # Use upsert to handle concurrent inserts atomically
+                # On conflict, update fields while preserving existing non-null context values
+                # and keeping the earliest start_time
+                insert_stmt = postgresql.insert(table).values(trace_dict)
+
+                # Build component level expressions for comparing trace priority
+                new_level = self._get_trace_component_level_expr(
+                    insert_stmt.excluded.workflow_id,
+                    insert_stmt.excluded.team_id,
+                    insert_stmt.excluded.agent_id,
+                    insert_stmt.excluded.name,
+                )
+                existing_level = self._get_trace_component_level_expr(
+                    table.c.workflow_id,
+                    table.c.team_id,
+                    table.c.agent_id,
+                    table.c.name,
+                )
+
+                # Build the ON CONFLICT DO UPDATE clause
+                # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
+                # Use COALESCE to preserve existing non-null context values
+                upsert_stmt = insert_stmt.on_conflict_do_update(
+                    index_elements=["trace_id"],
+                    set_={
+                        "end_time": func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
+                        "start_time": func.least(table.c.start_time, insert_stmt.excluded.start_time),
+                        "duration_ms": func.extract(
+                            "epoch",
+                            func.cast(
+                                func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
+                                TIMESTAMP(timezone=True),
+                            )
+                            - func.cast(
+                                func.least(table.c.start_time, insert_stmt.excluded.start_time),
+                                TIMESTAMP(timezone=True),
+                            ),
+                        )
+                        * 1000,
+                        "status": insert_stmt.excluded.status,
+                        # Update name only if new trace is from a higher-level component
+                        # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
+                        "name": case(
+                            (new_level > existing_level, insert_stmt.excluded.name),
+                            else_=table.c.name,
+                        ),
+                        # Preserve existing non-null context values using COALESCE
+                        "run_id": func.coalesce(insert_stmt.excluded.run_id, table.c.run_id),
+                        "session_id": func.coalesce(insert_stmt.excluded.session_id, table.c.session_id),
+                        "user_id": func.coalesce(insert_stmt.excluded.user_id, table.c.user_id),
+                        "agent_id": func.coalesce(insert_stmt.excluded.agent_id, table.c.agent_id),
+                        "team_id": func.coalesce(insert_stmt.excluded.team_id, table.c.team_id),
+                        "workflow_id": func.coalesce(insert_stmt.excluded.workflow_id, table.c.workflow_id),
+                    },
+                )
+                sess.execute(upsert_stmt)
+
+        except Exception as e:
+            log_error(f"Error creating trace: {e}")
+            # Don't raise - tracing should not break the main application flow
+
```
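Two details worth noting in the upsert above (the tail of the trace-creation method): `duration_ms` compiles to `EXTRACT(EPOCH FROM GREATEST(...) - LEAST(...)) * 1000`, i.e. the widened start/end window converted to milliseconds, and every context column falls back to the previously stored value via COALESCE. A minimal standalone sketch of the same PostgreSQL upsert pattern, assuming SQLAlchemy 2.x and a reachable PostgreSQL instance; the toy table and DSN are illustrative and not part of agno:

```python
# Minimal sketch of the upsert pattern, not agno code. Assumes SQLAlchemy 2.x
# and a reachable PostgreSQL database at the placeholder DSN below.
from sqlalchemy import Column, MetaData, String, Table, create_engine, func
from sqlalchemy.dialects import postgresql

metadata = MetaData()
toy_traces = Table(
    "toy_traces",
    metadata,
    Column("trace_id", String, primary_key=True),
    Column("start_time", String),
    Column("end_time", String),
    Column("run_id", String),
)

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # placeholder DSN
metadata.create_all(engine)

insert_stmt = postgresql.insert(toy_traces).values(
    trace_id="t1", start_time="2024-01-01T00:00:01", end_time="2024-01-01T00:00:05"
)
upsert_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["trace_id"],
    set_={
        # Widen the stored window; keep any previously stored run_id
        "start_time": func.least(toy_traces.c.start_time, insert_stmt.excluded.start_time),
        "end_time": func.greatest(toy_traces.c.end_time, insert_stmt.excluded.end_time),
        "run_id": func.coalesce(insert_stmt.excluded.run_id, toy_traces.c.run_id),
    },
)
with engine.begin() as conn:
    conn.execute(upsert_stmt)
```

Because the conflict resolution happens inside a single INSERT statement, two writers racing on the same trace_id cannot interleave a read-modify-write; PostgreSQL serializes them on the unique index.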
```diff
+    def get_trace(
+        self,
+        trace_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+    ):
+        """Get a single trace by trace_id or other filters.
+
+        Args:
+            trace_id: The unique trace identifier.
+            run_id: Filter by run ID (returns first match).
+
+        Returns:
+            Optional[Trace]: The trace if found, None otherwise.
+
+        Note:
+            If multiple filters are provided, trace_id takes precedence.
+            For other filters, the most recent trace is returned.
+        """
+        try:
+            from agno.tracing.schemas import Trace
+
+            table = self._get_table(table_type="traces")
+            if table is None:
+                return None
+
+            # Get spans table for JOIN
+            spans_table = self._get_table(table_type="spans")
+
+            with self.Session() as sess:
+                # Build query with aggregated span counts
+                stmt = self._get_traces_base_query(table, spans_table)
+
+                if trace_id:
+                    stmt = stmt.where(table.c.trace_id == trace_id)
+                elif run_id:
+                    stmt = stmt.where(table.c.run_id == run_id)
+                else:
+                    log_debug("get_trace called without any filter parameters")
+                    return None
+
+                # Order by most recent and get first result
+                stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
+                result = sess.execute(stmt).fetchone()
+
+                if result:
+                    return Trace.from_dict(dict(result._mapping))
+                return None
+
+        except Exception as e:
+            log_error(f"Error getting trace: {e}")
+            return None
+
```
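A hypothetical usage sketch, assuming `db` is an initialized instance of this database class; the IDs are illustrative:

```python
# trace_id takes precedence; a run_id alone returns the most recent trace for that run.
trace = db.get_trace(trace_id="abc123")
if trace is None:
    trace = db.get_trace(run_id="run_42")
```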
```diff
+    def get_traces(
+        self,
+        run_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        workflow_id: Optional[str] = None,
+        status: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        limit: Optional[int] = 20,
+        page: Optional[int] = 1,
+    ) -> tuple[List, int]:
+        """Get traces matching the provided filters with pagination.
+
+        Args:
+            run_id: Filter by run ID.
+            session_id: Filter by session ID.
+            user_id: Filter by user ID.
+            agent_id: Filter by agent ID.
+            team_id: Filter by team ID.
+            workflow_id: Filter by workflow ID.
+            status: Filter by status (OK, ERROR, UNSET).
+            start_time: Filter traces starting after this datetime.
+            end_time: Filter traces ending before this datetime.
+            limit: Maximum number of traces to return per page.
+            page: Page number (1-indexed).
+
+        Returns:
+            tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
+        """
+        try:
+            from agno.tracing.schemas import Trace
+
+            table = self._get_table(table_type="traces")
+            if table is None:
+                log_debug("Traces table not found")
+                return [], 0
+
+            # Get spans table for JOIN
+            spans_table = self._get_table(table_type="spans")
+
+            with self.Session() as sess:
+                # Build base query with aggregated span counts
+                base_stmt = self._get_traces_base_query(table, spans_table)
+
+                # Apply filters
+                if run_id:
+                    base_stmt = base_stmt.where(table.c.run_id == run_id)
+                if session_id:
+                    base_stmt = base_stmt.where(table.c.session_id == session_id)
+                if user_id:
+                    base_stmt = base_stmt.where(table.c.user_id == user_id)
+                if agent_id:
+                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
+                if team_id:
+                    base_stmt = base_stmt.where(table.c.team_id == team_id)
+                if workflow_id:
+                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
+                if status:
+                    base_stmt = base_stmt.where(table.c.status == status)
+                if start_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
+                if end_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())
+
+                # Get total count
+                count_stmt = select(func.count()).select_from(base_stmt.alias())
+                total_count = sess.execute(count_stmt).scalar() or 0
+
+                # Apply pagination
+                offset = (page - 1) * limit if page and limit else 0
+                paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
+
+                results = sess.execute(paginated_stmt).fetchall()
+
+                traces = [Trace.from_dict(dict(row._mapping)) for row in results]
+                return traces, total_count
+
+        except Exception as e:
+            log_error(f"Error getting traces: {e}")
+            return [], 0
+
```
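Pagination is offset-based and 1-indexed: `offset = (page - 1) * limit`, so page 3 with the default limit of 20 skips 40 rows. A hypothetical usage sketch, again assuming an initialized `db`:

```python
# Page 3 with limit=20 skips the first (3 - 1) * 20 = 40 rows.
traces, total = db.get_traces(user_id="u1", status="ERROR", limit=20, page=3)
total_pages = (total + 20 - 1) // 20  # ceiling division
```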
```diff
+    def get_trace_stats(
+        self,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        workflow_id: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        limit: Optional[int] = 20,
+        page: Optional[int] = 1,
+    ) -> tuple[List[Dict[str, Any]], int]:
+        """Get trace statistics grouped by session.
+
+        Args:
+            user_id: Filter by user ID.
+            agent_id: Filter by agent ID.
+            team_id: Filter by team ID.
+            workflow_id: Filter by workflow ID.
+            start_time: Filter sessions with traces created after this datetime.
+            end_time: Filter sessions with traces created before this datetime.
+            limit: Maximum number of sessions to return per page.
+            page: Page number (1-indexed).
+
+        Returns:
+            tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
+            Each dict contains: session_id, user_id, agent_id, team_id, total_traces,
+            first_trace_at, last_trace_at.
+        """
+        try:
+            table = self._get_table(table_type="traces")
+            if table is None:
+                log_debug("Traces table not found")
+                return [], 0
+
+            with self.Session() as sess:
+                # Build base query grouped by session_id
+                base_stmt = (
+                    select(
+                        table.c.session_id,
+                        table.c.user_id,
+                        table.c.agent_id,
+                        table.c.team_id,
+                        table.c.workflow_id,
+                        func.count(table.c.trace_id).label("total_traces"),
+                        func.min(table.c.created_at).label("first_trace_at"),
+                        func.max(table.c.created_at).label("last_trace_at"),
+                    )
+                    .where(table.c.session_id.isnot(None))  # Only sessions with session_id
+                    .group_by(
+                        table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
+                    )
+                )
+
+                # Apply filters
+                if user_id:
+                    base_stmt = base_stmt.where(table.c.user_id == user_id)
+                if workflow_id:
+                    base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
+                if team_id:
+                    base_stmt = base_stmt.where(table.c.team_id == team_id)
+                if agent_id:
+                    base_stmt = base_stmt.where(table.c.agent_id == agent_id)
+                if start_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
+                if end_time:
+                    # Convert datetime to ISO string for comparison
+                    base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())
+
+                # Get total count of sessions
+                count_stmt = select(func.count()).select_from(base_stmt.alias())
+                total_count = sess.execute(count_stmt).scalar() or 0
+
+                # Apply pagination and ordering
+                offset = (page - 1) * limit if page and limit else 0
+                paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
+
+                results = sess.execute(paginated_stmt).fetchall()
+
+                # Convert to list of dicts with datetime objects
+                stats_list = []
+                for row in results:
+                    # Convert ISO strings to datetime objects
+                    first_trace_at_str = row.first_trace_at
+                    last_trace_at_str = row.last_trace_at
+
+                    # Parse ISO format strings to datetime objects
+                    first_trace_at = datetime.fromisoformat(first_trace_at_str.replace("Z", "+00:00"))
+                    last_trace_at = datetime.fromisoformat(last_trace_at_str.replace("Z", "+00:00"))
+
+                    stats_list.append(
+                        {
+                            "session_id": row.session_id,
+                            "user_id": row.user_id,
+                            "agent_id": row.agent_id,
+                            "team_id": row.team_id,
+                            "workflow_id": row.workflow_id,
+                            "total_traces": row.total_traces,
+                            "first_trace_at": first_trace_at,
+                            "last_trace_at": last_trace_at,
+                        }
+                    )
+
+                return stats_list, total_count
+
+        except Exception as e:
+            log_error(f"Error getting trace stats: {e}")
+            return [], 0
+
```
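The `replace("Z", "+00:00")` normalization above matters because `datetime.fromisoformat()` accepts a trailing `Z` only from Python 3.11 onward; on older interpreters it raises `ValueError`. A small illustration:

```python
from datetime import datetime

iso = "2024-01-01T00:00:00Z"
# datetime.fromisoformat(iso)  # ValueError on Python < 3.11
dt = datetime.fromisoformat(iso.replace("Z", "+00:00"))  # portable across versions
print(dt.tzinfo)  # UTC (offset +00:00)
```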
```diff
+    # --- Spans ---
+    def create_span(self, span: "Span") -> None:
+        """Create a single span in the database.
+
+        Args:
+            span: The Span object to store.
+        """
+        try:
+            table = self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                stmt = postgresql.insert(table).values(span.to_dict())
+                sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating span: {e}")
+
```
```diff
+    def create_spans(self, spans: List) -> None:
+        """Create multiple spans in the database as a batch.
+
+        Args:
+            spans: List of Span objects to store.
+        """
+        if not spans:
+            return
+
+        try:
+            table = self._get_table(table_type="spans", create_table_if_not_found=True)
+            if table is None:
+                return
+
+            with self.Session() as sess, sess.begin():
+                for span in spans:
+                    stmt = postgresql.insert(table).values(span.to_dict())
+                    sess.execute(stmt)
+
+        except Exception as e:
+            log_error(f"Error creating spans batch: {e}")
+
```
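Note that the batch shares one transaction (`sess.begin()`), so the spans commit all-or-nothing, but each span is still sent as its own INSERT statement. For illustration only, a sketch of the equivalent single multi-row INSERT using SQLAlchemy's executemany-style execution (this is not the code in the diff):

```python
# Sketch only: one INSERT carrying all parameter sets, executemany-style.
# Behavior is equivalent to the loop above, inside the same transaction.
with self.Session() as sess, sess.begin():
    sess.execute(postgresql.insert(table), [span.to_dict() for span in spans])
```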
```diff
+    def get_span(self, span_id: str):
+        """Get a single span by its span_id.
+
+        Args:
+            span_id: The unique span identifier.
+
+        Returns:
+            Optional[Span]: The span if found, None otherwise.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = self._get_table(table_type="spans")
+            if table is None:
+                return None
+
+            with self.Session() as sess:
+                stmt = select(table).where(table.c.span_id == span_id)
+                result = sess.execute(stmt).fetchone()
+                if result:
+                    return Span.from_dict(dict(result._mapping))
+                return None
+
+        except Exception as e:
+            log_error(f"Error getting span: {e}")
+            return None
+
```
```diff
+    def get_spans(
+        self,
+        trace_id: Optional[str] = None,
+        parent_span_id: Optional[str] = None,
+        limit: Optional[int] = 1000,
+    ) -> List:
+        """Get spans matching the provided filters.
+
+        Args:
+            trace_id: Filter by trace ID.
+            parent_span_id: Filter by parent span ID.
+            limit: Maximum number of spans to return.
+
+        Returns:
+            List[Span]: List of matching spans.
+        """
+        try:
+            from agno.tracing.schemas import Span
+
+            table = self._get_table(table_type="spans")
+            if table is None:
+                return []
+
+            with self.Session() as sess:
+                stmt = select(table)
+
+                # Apply filters
+                if trace_id:
+                    stmt = stmt.where(table.c.trace_id == trace_id)
+                if parent_span_id:
+                    stmt = stmt.where(table.c.parent_span_id == parent_span_id)
+
+                if limit:
+                    stmt = stmt.limit(limit)
+
+                results = sess.execute(stmt).fetchall()
+                return [Span.from_dict(dict(row._mapping)) for row in results]
+
+        except Exception as e:
+            log_error(f"Error getting spans: {e}")
+            return []
```
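Finally, a hypothetical sketch that rebuilds a span tree from the flat list `get_spans` returns. It assumes an initialized `db` and that the returned Span objects expose a `parent_span_id` attribute mirroring the column the method filters on; that attribute name is an assumption here:

```python
from collections import defaultdict

spans = db.get_spans(trace_id="abc123")
children = defaultdict(list)
for span in spans:
    # parent_span_id is assumed to mirror the DB column of the same name
    children[span.parent_span_id].append(span)
root_spans = children[None]  # spans without a parent start the trace
```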