agno 2.3.7__py3-none-any.whl → 2.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +391 -335
- agno/db/mongo/async_mongo.py +0 -24
- agno/db/mongo/mongo.py +0 -16
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2888 -0
- agno/db/mysql/mysql.py +17 -27
- agno/db/mysql/utils.py +139 -6
- agno/db/postgres/async_postgres.py +10 -26
- agno/db/postgres/postgres.py +7 -25
- agno/db/redis/redis.py +0 -4
- agno/db/schemas/evals.py +1 -0
- agno/db/singlestore/singlestore.py +5 -12
- agno/db/sqlite/async_sqlite.py +2 -26
- agno/db/sqlite/sqlite.py +0 -20
- agno/eval/__init__.py +10 -0
- agno/eval/agent_as_judge.py +860 -0
- agno/eval/base.py +29 -0
- agno/eval/utils.py +2 -1
- agno/exceptions.py +7 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/knowledge.py +1142 -176
- agno/media.py +22 -6
- agno/models/aws/claude.py +8 -7
- agno/models/base.py +160 -11
- agno/models/deepseek/deepseek.py +67 -0
- agno/models/google/gemini.py +65 -11
- agno/models/google/utils.py +22 -0
- agno/models/message.py +2 -0
- agno/models/openai/chat.py +4 -0
- agno/models/openai/responses.py +3 -2
- agno/os/app.py +64 -74
- agno/os/interfaces/a2a/router.py +3 -4
- agno/os/interfaces/a2a/utils.py +1 -1
- agno/os/interfaces/agui/router.py +2 -0
- agno/os/middleware/jwt.py +8 -6
- agno/os/router.py +3 -1607
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +581 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +26 -6
- agno/os/routers/evals/schemas.py +34 -2
- agno/os/routers/evals/utils.py +101 -20
- agno/os/routers/knowledge/knowledge.py +1 -1
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +496 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +545 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +1 -559
- agno/os/utils.py +139 -2
- agno/team/team.py +159 -100
- agno/tools/file_generation.py +12 -6
- agno/tools/firecrawl.py +15 -7
- agno/tools/workflow.py +8 -1
- agno/utils/hooks.py +64 -5
- agno/utils/http.py +2 -2
- agno/utils/media.py +11 -1
- agno/utils/print_response/agent.py +8 -0
- agno/utils/print_response/team.py +8 -0
- agno/vectordb/pgvector/pgvector.py +88 -51
- agno/workflow/parallel.py +11 -5
- agno/workflow/step.py +17 -5
- agno/workflow/types.py +38 -2
- agno/workflow/workflow.py +12 -4
- {agno-2.3.7.dist-info → agno-2.3.9.dist-info}/METADATA +8 -3
- {agno-2.3.7.dist-info → agno-2.3.9.dist-info}/RECORD +70 -58
- agno/tools/memori.py +0 -339
- {agno-2.3.7.dist-info → agno-2.3.9.dist-info}/WHEEL +0 -0
- {agno-2.3.7.dist-info → agno-2.3.9.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.7.dist-info → agno-2.3.9.dist-info}/top_level.txt +0 -0
agno/db/mysql/mysql.py
CHANGED
|
@@ -3,8 +3,6 @@ from datetime import date, datetime, timedelta, timezone
|
|
|
3
3
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
|
-
from sqlalchemy import ForeignKey, Index, UniqueConstraint
|
|
7
|
-
|
|
8
6
|
if TYPE_CHECKING:
|
|
9
7
|
from agno.tracing.schemas import Span, Trace
|
|
10
8
|
|
|
@@ -32,7 +30,7 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
|
32
30
|
from agno.utils.string import generate_id
|
|
33
31
|
|
|
34
32
|
try:
|
|
35
|
-
from sqlalchemy import TEXT, and_, cast, func, update
|
|
33
|
+
from sqlalchemy import TEXT, ForeignKey, Index, UniqueConstraint, and_, cast, func, update
|
|
36
34
|
from sqlalchemy.dialects import mysql
|
|
37
35
|
from sqlalchemy.engine import Engine, create_engine
|
|
38
36
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
@@ -45,6 +43,7 @@ except ImportError:
|
|
|
45
43
|
class MySQLDb(BaseDb):
|
|
46
44
|
def __init__(
|
|
47
45
|
self,
|
|
46
|
+
id: Optional[str] = None,
|
|
48
47
|
db_engine: Optional[Engine] = None,
|
|
49
48
|
db_schema: Optional[str] = None,
|
|
50
49
|
db_url: Optional[str] = None,
|
|
@@ -57,7 +56,7 @@ class MySQLDb(BaseDb):
|
|
|
57
56
|
traces_table: Optional[str] = None,
|
|
58
57
|
spans_table: Optional[str] = None,
|
|
59
58
|
versions_table: Optional[str] = None,
|
|
60
|
-
|
|
59
|
+
create_schema: bool = True,
|
|
61
60
|
):
|
|
62
61
|
"""
|
|
63
62
|
Interface for interacting with a MySQL database.
|
|
@@ -68,6 +67,7 @@ class MySQLDb(BaseDb):
|
|
|
68
67
|
3. Raise an error if neither is provided
|
|
69
68
|
|
|
70
69
|
Args:
|
|
70
|
+
id (Optional[str]): ID of the database.
|
|
71
71
|
db_url (Optional[str]): The database URL to connect to.
|
|
72
72
|
db_engine (Optional[Engine]): The SQLAlchemy database engine to use.
|
|
73
73
|
db_schema (Optional[str]): The database schema to use.
|
|
@@ -80,7 +80,8 @@ class MySQLDb(BaseDb):
|
|
|
80
80
|
traces_table (Optional[str]): Name of the table to store run traces.
|
|
81
81
|
spans_table (Optional[str]): Name of the table to store span events.
|
|
82
82
|
versions_table (Optional[str]): Name of the table to store schema versions.
|
|
83
|
-
|
|
83
|
+
create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
|
|
84
|
+
Set to False if schema is managed externally (e.g., via migrations). Defaults to True.
|
|
84
85
|
|
|
85
86
|
Raises:
|
|
86
87
|
ValueError: If neither db_url nor db_engine is provided.
|
|
@@ -115,6 +116,7 @@ class MySQLDb(BaseDb):
|
|
|
115
116
|
self.db_engine: Engine = _engine
|
|
116
117
|
self.db_schema: str = db_schema if db_schema is not None else "ai"
|
|
117
118
|
self.metadata: MetaData = MetaData(schema=self.db_schema)
|
|
119
|
+
self.create_schema: bool = create_schema
|
|
118
120
|
|
|
119
121
|
# Initialize database session
|
|
120
122
|
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine))
|
|
@@ -190,8 +192,9 @@ class MySQLDb(BaseDb):
|
|
|
190
192
|
idx_name = f"idx_{table_name}_{idx_col}"
|
|
191
193
|
table.append_constraint(Index(idx_name, idx_col))
|
|
192
194
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
+
if self.create_schema:
|
|
196
|
+
with self.Session() as sess, sess.begin():
|
|
197
|
+
create_schema(session=sess, db_schema=self.db_schema)
|
|
195
198
|
|
|
196
199
|
# Create table
|
|
197
200
|
table_created = False
|
|
@@ -252,6 +255,9 @@ class MySQLDb(BaseDb):
|
|
|
252
255
|
(self.metrics_table_name, "metrics"),
|
|
253
256
|
(self.eval_table_name, "evals"),
|
|
254
257
|
(self.knowledge_table_name, "knowledge"),
|
|
258
|
+
(self.culture_table_name, "culture"),
|
|
259
|
+
(self.trace_table_name, "traces"),
|
|
260
|
+
(self.span_table_name, "spans"),
|
|
255
261
|
(self.versions_table_name, "versions"),
|
|
256
262
|
]
|
|
257
263
|
|
|
@@ -1109,9 +1115,12 @@ class MySQLDb(BaseDb):
|
|
|
1109
1115
|
except Exception as e:
|
|
1110
1116
|
log_error(f"Error deleting user memories: {e}")
|
|
1111
1117
|
|
|
1112
|
-
def get_all_memory_topics(self) -> List[str]:
|
|
1118
|
+
def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
|
|
1113
1119
|
"""Get all memory topics from the database.
|
|
1114
1120
|
|
|
1121
|
+
Args:
|
|
1122
|
+
user_id (Optional[str]): Optional user ID to filter topics.
|
|
1123
|
+
|
|
1115
1124
|
Returns:
|
|
1116
1125
|
List[str]: List of memory topics.
|
|
1117
1126
|
"""
|
|
@@ -2514,14 +2523,6 @@ class MySQLDb(BaseDb):
|
|
|
2514
2523
|
if trace.workflow_id is not None:
|
|
2515
2524
|
update_values["workflow_id"] = trace.workflow_id
|
|
2516
2525
|
|
|
2517
|
-
log_debug(
|
|
2518
|
-
f" Updating trace with context: run_id={update_values.get('run_id', 'unchanged')}, "
|
|
2519
|
-
f"session_id={update_values.get('session_id', 'unchanged')}, "
|
|
2520
|
-
f"user_id={update_values.get('user_id', 'unchanged')}, "
|
|
2521
|
-
f"agent_id={update_values.get('agent_id', 'unchanged')}, "
|
|
2522
|
-
f"team_id={update_values.get('team_id', 'unchanged')}, "
|
|
2523
|
-
)
|
|
2524
|
-
|
|
2525
2526
|
stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
|
|
2526
2527
|
sess.execute(stmt)
|
|
2527
2528
|
else:
|
|
@@ -2642,7 +2643,6 @@ class MySQLDb(BaseDb):
|
|
|
2642
2643
|
if run_id:
|
|
2643
2644
|
base_stmt = base_stmt.where(table.c.run_id == run_id)
|
|
2644
2645
|
if session_id:
|
|
2645
|
-
log_debug(f"Filtering by session_id={session_id}")
|
|
2646
2646
|
base_stmt = base_stmt.where(table.c.session_id == session_id)
|
|
2647
2647
|
if user_id:
|
|
2648
2648
|
base_stmt = base_stmt.where(table.c.user_id == user_id)
|
|
@@ -2664,14 +2664,12 @@ class MySQLDb(BaseDb):
|
|
|
2664
2664
|
# Get total count
|
|
2665
2665
|
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2666
2666
|
total_count = sess.execute(count_stmt).scalar() or 0
|
|
2667
|
-
log_debug(f"Total matching traces: {total_count}")
|
|
2668
2667
|
|
|
2669
2668
|
# Apply pagination
|
|
2670
2669
|
offset = (page - 1) * limit if page and limit else 0
|
|
2671
2670
|
paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
|
|
2672
2671
|
|
|
2673
2672
|
results = sess.execute(paginated_stmt).fetchall()
|
|
2674
|
-
log_debug(f"Returning page {page} with {len(results)} traces")
|
|
2675
2673
|
|
|
2676
2674
|
traces = [Trace.from_dict(dict(row._mapping)) for row in results]
|
|
2677
2675
|
return traces, total_count
|
|
@@ -2709,12 +2707,6 @@ class MySQLDb(BaseDb):
|
|
|
2709
2707
|
workflow_id, first_trace_at, last_trace_at.
|
|
2710
2708
|
"""
|
|
2711
2709
|
try:
|
|
2712
|
-
log_debug(
|
|
2713
|
-
f"get_trace_stats called with filters: user_id={user_id}, agent_id={agent_id}, "
|
|
2714
|
-
f"workflow_id={workflow_id}, team_id={team_id}, "
|
|
2715
|
-
f"start_time={start_time}, end_time={end_time}, page={page}, limit={limit}"
|
|
2716
|
-
)
|
|
2717
|
-
|
|
2718
2710
|
table = self._get_table(table_type="traces")
|
|
2719
2711
|
if table is None:
|
|
2720
2712
|
log_debug("Traces table not found")
|
|
@@ -2758,14 +2750,12 @@ class MySQLDb(BaseDb):
|
|
|
2758
2750
|
# Get total count of sessions
|
|
2759
2751
|
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2760
2752
|
total_count = sess.execute(count_stmt).scalar() or 0
|
|
2761
|
-
log_debug(f"Total matching sessions: {total_count}")
|
|
2762
2753
|
|
|
2763
2754
|
# Apply pagination and ordering
|
|
2764
2755
|
offset = (page - 1) * limit if page and limit else 0
|
|
2765
2756
|
paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
|
|
2766
2757
|
|
|
2767
2758
|
results = sess.execute(paginated_stmt).fetchall()
|
|
2768
|
-
log_debug(f"Returning page {page} with {len(results)} session stats")
|
|
2769
2759
|
|
|
2770
2760
|
# Convert to list of dicts with datetime objects
|
|
2771
2761
|
stats_list = []
|
agno/db/mysql/utils.py
CHANGED
|
@@ -5,15 +5,14 @@ from datetime import date, datetime, timedelta, timezone
|
|
|
5
5
|
from typing import Any, Dict, List, Optional
|
|
6
6
|
from uuid import uuid4
|
|
7
7
|
|
|
8
|
-
from sqlalchemy import Engine
|
|
9
|
-
|
|
10
8
|
from agno.db.mysql.schemas import get_table_schema_definition
|
|
11
9
|
from agno.db.schemas.culture import CulturalKnowledge
|
|
12
10
|
from agno.utils.log import log_debug, log_error, log_warning
|
|
13
11
|
|
|
14
12
|
try:
|
|
15
|
-
from sqlalchemy import Table
|
|
13
|
+
from sqlalchemy import Engine, Table
|
|
16
14
|
from sqlalchemy.dialects import mysql
|
|
15
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
|
17
16
|
from sqlalchemy.inspection import inspect
|
|
18
17
|
from sqlalchemy.orm import Session
|
|
19
18
|
from sqlalchemy.sql.expression import text
|
|
@@ -91,8 +90,10 @@ def is_valid_table(db_engine: Engine, table_name: str, table_type: str, db_schem
|
|
|
91
90
|
Check if the existing table has the expected column names.
|
|
92
91
|
|
|
93
92
|
Args:
|
|
93
|
+
db_engine: Database engine
|
|
94
94
|
table_name (str): Name of the table to validate
|
|
95
|
-
|
|
95
|
+
table_type (str): Type of table (for schema lookup)
|
|
96
|
+
db_schema (str): Database schema name
|
|
96
97
|
|
|
97
98
|
Returns:
|
|
98
99
|
bool: True if table has all expected columns, False otherwise
|
|
@@ -123,6 +124,7 @@ def bulk_upsert_metrics(session: Session, table: Table, metrics_records: list[di
|
|
|
123
124
|
"""Bulk upsert metrics into the database.
|
|
124
125
|
|
|
125
126
|
Args:
|
|
127
|
+
session (Session): The SQLAlchemy session
|
|
126
128
|
table (Table): The table to upsert into.
|
|
127
129
|
metrics_records (list[dict]): The metrics records to upsert.
|
|
128
130
|
|
|
@@ -156,7 +158,10 @@ def bulk_upsert_metrics(session: Session, table: Table, metrics_records: list[di
|
|
|
156
158
|
|
|
157
159
|
for record in metrics_records:
|
|
158
160
|
select_stmt = select(table).where(
|
|
159
|
-
and_(
|
|
161
|
+
and_(
|
|
162
|
+
table.c.date == record["date"],
|
|
163
|
+
table.c.aggregation_period == record["aggregation_period"],
|
|
164
|
+
)
|
|
160
165
|
)
|
|
161
166
|
result = session.execute(select_stmt).fetchone()
|
|
162
167
|
if result:
|
|
@@ -165,6 +170,55 @@ def bulk_upsert_metrics(session: Session, table: Table, metrics_records: list[di
|
|
|
165
170
|
return results # type: ignore
|
|
166
171
|
|
|
167
172
|
|
|
173
|
+
async def abulk_upsert_metrics(session: AsyncSession, table: Table, metrics_records: list[dict]) -> list[dict]:
|
|
174
|
+
"""Async bulk upsert metrics into the database.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
session (AsyncSession): The async SQLAlchemy session
|
|
178
|
+
table (Table): The table to upsert into.
|
|
179
|
+
metrics_records (list[dict]): The metrics records to upsert.
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
list[dict]: The upserted metrics records.
|
|
183
|
+
"""
|
|
184
|
+
if not metrics_records:
|
|
185
|
+
return []
|
|
186
|
+
|
|
187
|
+
results = []
|
|
188
|
+
|
|
189
|
+
# MySQL doesn't support returning in the same way as PostgreSQL
|
|
190
|
+
# We'll need to insert/update and then fetch the records
|
|
191
|
+
for record in metrics_records:
|
|
192
|
+
stmt = mysql.insert(table).values(record)
|
|
193
|
+
|
|
194
|
+
# Columns to update in case of conflict
|
|
195
|
+
update_dict = {
|
|
196
|
+
col.name: record.get(col.name)
|
|
197
|
+
for col in table.columns
|
|
198
|
+
if col.name not in ["id", "date", "created_at", "aggregation_period"] and col.name in record
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
stmt = stmt.on_duplicate_key_update(**update_dict)
|
|
202
|
+
await session.execute(stmt)
|
|
203
|
+
|
|
204
|
+
# Fetch the updated records
|
|
205
|
+
from sqlalchemy import and_, select
|
|
206
|
+
|
|
207
|
+
for record in metrics_records:
|
|
208
|
+
select_stmt = select(table).where(
|
|
209
|
+
and_(
|
|
210
|
+
table.c.date == record["date"],
|
|
211
|
+
table.c.aggregation_period == record["aggregation_period"],
|
|
212
|
+
)
|
|
213
|
+
)
|
|
214
|
+
result = await session.execute(select_stmt)
|
|
215
|
+
fetched_row = result.fetchone()
|
|
216
|
+
if fetched_row:
|
|
217
|
+
results.append(dict(fetched_row._mapping))
|
|
218
|
+
|
|
219
|
+
return results
|
|
220
|
+
|
|
221
|
+
|
|
168
222
|
def calculate_date_metrics(date_to_process: date, sessions_data: dict) -> dict:
|
|
169
223
|
"""Calculate metrics for the given single date.
|
|
170
224
|
|
|
@@ -299,7 +353,9 @@ def get_dates_to_calculate_metrics_for(starting_date: date) -> list[date]:
|
|
|
299
353
|
|
|
300
354
|
|
|
301
355
|
# -- Cultural Knowledge util methods --
|
|
302
|
-
def serialize_cultural_knowledge_for_db(
|
|
356
|
+
def serialize_cultural_knowledge_for_db(
|
|
357
|
+
cultural_knowledge: CulturalKnowledge,
|
|
358
|
+
) -> Dict[str, Any]:
|
|
303
359
|
"""Serialize a CulturalKnowledge object for database storage.
|
|
304
360
|
|
|
305
361
|
Converts the model's separate content, categories, and notes fields
|
|
@@ -353,3 +409,80 @@ def deserialize_cultural_knowledge_from_db(db_row: Dict[str, Any]) -> CulturalKn
|
|
|
353
409
|
"team_id": db_row.get("team_id"),
|
|
354
410
|
}
|
|
355
411
|
)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# -- Async DB util methods --
|
|
415
|
+
async def acreate_schema(session: AsyncSession, db_schema: str) -> None:
|
|
416
|
+
"""Async version: Create the database schema if it doesn't exist.
|
|
417
|
+
|
|
418
|
+
Args:
|
|
419
|
+
session: The async SQLAlchemy session to use
|
|
420
|
+
db_schema (str): The definition of the database schema to create
|
|
421
|
+
"""
|
|
422
|
+
try:
|
|
423
|
+
log_debug(f"Creating database if not exists: {db_schema}")
|
|
424
|
+
# MySQL uses CREATE DATABASE instead of CREATE SCHEMA
|
|
425
|
+
await session.execute(text(f"CREATE DATABASE IF NOT EXISTS `{db_schema}`;"))
|
|
426
|
+
except Exception as e:
|
|
427
|
+
log_warning(f"Could not create database {db_schema}: {e}")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
async def ais_table_available(session: AsyncSession, table_name: str, db_schema: str) -> bool:
|
|
431
|
+
"""Async version: Check if a table with the given name exists in the given schema.
|
|
432
|
+
|
|
433
|
+
Returns:
|
|
434
|
+
bool: True if the table exists, False otherwise.
|
|
435
|
+
"""
|
|
436
|
+
try:
|
|
437
|
+
exists_query = text(
|
|
438
|
+
"SELECT 1 FROM information_schema.tables WHERE table_schema = :schema AND table_name = :table"
|
|
439
|
+
)
|
|
440
|
+
result = await session.execute(exists_query, {"schema": db_schema, "table": table_name})
|
|
441
|
+
exists = result.scalar() is not None
|
|
442
|
+
if not exists:
|
|
443
|
+
log_debug(f"Table {db_schema}.{table_name} {'exists' if exists else 'does not exist'}")
|
|
444
|
+
|
|
445
|
+
return exists
|
|
446
|
+
|
|
447
|
+
except Exception as e:
|
|
448
|
+
log_error(f"Error checking if table exists: {e}")
|
|
449
|
+
return False
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
async def ais_valid_table(db_engine: AsyncEngine, table_name: str, table_type: str, db_schema: str) -> bool:
|
|
453
|
+
"""Async version: Check if the existing table has the expected column names.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
db_engine: Async database engine
|
|
457
|
+
table_name (str): Name of the table to validate
|
|
458
|
+
table_type (str): Type of table (for schema lookup)
|
|
459
|
+
db_schema (str): Database schema name
|
|
460
|
+
|
|
461
|
+
Returns:
|
|
462
|
+
bool: True if table has all expected columns, False otherwise
|
|
463
|
+
"""
|
|
464
|
+
try:
|
|
465
|
+
expected_table_schema = get_table_schema_definition(table_type)
|
|
466
|
+
expected_columns = {col_name for col_name in expected_table_schema.keys() if not col_name.startswith("_")}
|
|
467
|
+
|
|
468
|
+
# Get existing columns from the async engine
|
|
469
|
+
async with db_engine.connect() as conn:
|
|
470
|
+
existing_columns = await conn.run_sync(_get_table_columns, table_name, db_schema)
|
|
471
|
+
|
|
472
|
+
# Check if all expected columns exist
|
|
473
|
+
missing_columns = expected_columns - existing_columns
|
|
474
|
+
if missing_columns:
|
|
475
|
+
log_warning(f"Missing columns {missing_columns} in table {db_schema}.{table_name}")
|
|
476
|
+
return False
|
|
477
|
+
|
|
478
|
+
return True
|
|
479
|
+
except Exception as e:
|
|
480
|
+
log_error(f"Error validating table schema for {db_schema}.{table_name}: {e}")
|
|
481
|
+
return False
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _get_table_columns(connection, table_name: str, db_schema: str) -> set[str]:
|
|
485
|
+
"""Helper function to get table columns using sync inspector."""
|
|
486
|
+
inspector = inspect(connection)
|
|
487
|
+
columns_info = inspector.get_columns(table_name, schema=db_schema)
|
|
488
|
+
return {col["name"] for col in columns_info}
|
|
agno/db/postgres/async_postgres.py
CHANGED
|
@@ -56,6 +56,7 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
56
56
|
traces_table: Optional[str] = None,
|
|
57
57
|
spans_table: Optional[str] = None,
|
|
58
58
|
versions_table: Optional[str] = None,
|
|
59
|
+
create_schema: bool = True,
|
|
59
60
|
db_id: Optional[str] = None, # Deprecated, use id instead.
|
|
60
61
|
):
|
|
61
62
|
"""
|
|
@@ -80,6 +81,8 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
80
81
|
traces_table (Optional[str]): Name of the table to store run traces.
|
|
81
82
|
spans_table (Optional[str]): Name of the table to store span events.
|
|
82
83
|
versions_table (Optional[str]): Name of the table to store schema versions.
|
|
84
|
+
create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
|
|
85
|
+
Set to False if schema is managed externally (e.g., via migrations). Defaults to True.
|
|
83
86
|
db_id: Deprecated, use id instead.
|
|
84
87
|
|
|
85
88
|
Raises:
|
|
@@ -116,6 +119,7 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
116
119
|
self.db_engine: AsyncEngine = _engine
|
|
117
120
|
self.db_schema: str = db_schema if db_schema is not None else "ai"
|
|
118
121
|
self.metadata: MetaData = MetaData(schema=self.db_schema)
|
|
122
|
+
self.create_schema: bool = create_schema
|
|
119
123
|
|
|
120
124
|
# Initialize database session factory
|
|
121
125
|
self.async_session_factory = async_sessionmaker(
|
|
@@ -200,8 +204,9 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
200
204
|
idx_name = f"idx_{table_name}_{idx_col}"
|
|
201
205
|
table.append_constraint(Index(idx_name, idx_col))
|
|
202
206
|
|
|
203
|
-
|
|
204
|
-
|
|
207
|
+
if self.create_schema:
|
|
208
|
+
async with self.async_session_factory() as sess, sess.begin():
|
|
209
|
+
await acreate_schema(session=sess, db_schema=self.db_schema)
|
|
205
210
|
|
|
206
211
|
# Create table
|
|
207
212
|
table_created = False
|
|
@@ -1237,7 +1242,7 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
1237
1242
|
Exception: If an error occurs during upsert.
|
|
1238
1243
|
"""
|
|
1239
1244
|
try:
|
|
1240
|
-
table = await self._get_table(table_type="culture")
|
|
1245
|
+
table = await self._get_table(table_type="culture", create_table_if_not_found=True)
|
|
1241
1246
|
|
|
1242
1247
|
# Generate ID if not present
|
|
1243
1248
|
if cultural_knowledge.id is None:
|
|
@@ -1381,7 +1386,7 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
1381
1386
|
Exception: If an error occurs during upsert.
|
|
1382
1387
|
"""
|
|
1383
1388
|
try:
|
|
1384
|
-
table = await self._get_table(table_type="memories")
|
|
1389
|
+
table = await self._get_table(table_type="memories", create_table_if_not_found=True)
|
|
1385
1390
|
|
|
1386
1391
|
current_time = int(time.time())
|
|
1387
1392
|
|
|
@@ -1725,7 +1730,7 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
1725
1730
|
Optional[KnowledgeRow]: The upserted knowledge row, or None if the operation fails.
|
|
1726
1731
|
"""
|
|
1727
1732
|
try:
|
|
1728
|
-
table = await self._get_table(table_type="knowledge")
|
|
1733
|
+
table = await self._get_table(table_type="knowledge", create_table_if_not_found=True)
|
|
1729
1734
|
async with self.async_session_factory() as sess, sess.begin():
|
|
1730
1735
|
# Get the actual table columns to avoid "unconsumed column names" error
|
|
1731
1736
|
table_columns = set(table.columns.keys())
|
|
@@ -2186,14 +2191,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2186
2191
|
if trace.workflow_id is not None:
|
|
2187
2192
|
update_values["workflow_id"] = trace.workflow_id
|
|
2188
2193
|
|
|
2189
|
-
log_debug(
|
|
2190
|
-
f" Updating trace with context: run_id={update_values.get('run_id', 'unchanged')}, "
|
|
2191
|
-
f"session_id={update_values.get('session_id', 'unchanged')}, "
|
|
2192
|
-
f"user_id={update_values.get('user_id', 'unchanged')}, "
|
|
2193
|
-
f"agent_id={update_values.get('agent_id', 'unchanged')}, "
|
|
2194
|
-
f"team_id={update_values.get('team_id', 'unchanged')}, "
|
|
2195
|
-
)
|
|
2196
|
-
|
|
2197
2194
|
stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
|
|
2198
2195
|
await sess.execute(stmt)
|
|
2199
2196
|
else:
|
|
@@ -2293,10 +2290,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2293
2290
|
try:
|
|
2294
2291
|
from agno.tracing.schemas import Trace
|
|
2295
2292
|
|
|
2296
|
-
log_debug(
|
|
2297
|
-
f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
|
|
2298
|
-
)
|
|
2299
|
-
|
|
2300
2293
|
table = await self._get_table(table_type="traces")
|
|
2301
2294
|
|
|
2302
2295
|
# Get spans table for JOIN
|
|
@@ -2310,7 +2303,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2310
2303
|
if run_id:
|
|
2311
2304
|
base_stmt = base_stmt.where(table.c.run_id == run_id)
|
|
2312
2305
|
if session_id:
|
|
2313
|
-
log_debug(f"Filtering by session_id={session_id}")
|
|
2314
2306
|
base_stmt = base_stmt.where(table.c.session_id == session_id)
|
|
2315
2307
|
if user_id:
|
|
2316
2308
|
base_stmt = base_stmt.where(table.c.user_id == user_id)
|
|
@@ -2378,12 +2370,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2378
2370
|
workflow_id, first_trace_at, last_trace_at.
|
|
2379
2371
|
"""
|
|
2380
2372
|
try:
|
|
2381
|
-
log_debug(
|
|
2382
|
-
f"get_trace_stats called with filters: user_id={user_id}, agent_id={agent_id}, "
|
|
2383
|
-
f"workflow_id={workflow_id}, team_id={team_id}, "
|
|
2384
|
-
f"start_time={start_time}, end_time={end_time}, page={page}, limit={limit}"
|
|
2385
|
-
)
|
|
2386
|
-
|
|
2387
2373
|
table = await self._get_table(table_type="traces")
|
|
2388
2374
|
|
|
2389
2375
|
async with self.async_session_factory() as sess:
|
|
@@ -2424,7 +2410,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2424
2410
|
# Get total count of sessions
|
|
2425
2411
|
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2426
2412
|
total_count = await sess.scalar(count_stmt) or 0
|
|
2427
|
-
log_debug(f"Total matching sessions: {total_count}")
|
|
2428
2413
|
|
|
2429
2414
|
# Apply pagination and ordering
|
|
2430
2415
|
offset = (page - 1) * limit if page and limit else 0
|
|
@@ -2432,7 +2417,6 @@ class AsyncPostgresDb(AsyncBaseDb):
|
|
|
2432
2417
|
|
|
2433
2418
|
result = await sess.execute(paginated_stmt)
|
|
2434
2419
|
results = result.fetchall()
|
|
2435
|
-
log_debug(f"Returning page {page} with {len(results)} session stats")
|
|
2436
2420
|
|
|
2437
2421
|
# Convert to list of dicts with datetime objects
|
|
2438
2422
|
stats_list = []
|
agno/db/postgres/postgres.py
CHANGED
|
@@ -57,6 +57,7 @@ class PostgresDb(BaseDb):
|
|
|
57
57
|
spans_table: Optional[str] = None,
|
|
58
58
|
versions_table: Optional[str] = None,
|
|
59
59
|
id: Optional[str] = None,
|
|
60
|
+
create_schema: bool = True,
|
|
60
61
|
):
|
|
61
62
|
"""
|
|
62
63
|
Interface for interacting with a PostgreSQL database.
|
|
@@ -80,6 +81,8 @@ class PostgresDb(BaseDb):
|
|
|
80
81
|
spans_table (Optional[str]): Name of the table to store span events.
|
|
81
82
|
versions_table (Optional[str]): Name of the table to store schema versions.
|
|
82
83
|
id (Optional[str]): ID of the database.
|
|
84
|
+
create_schema (bool): Whether to automatically create the database schema if it doesn't exist.
|
|
85
|
+
Set to False if schema is managed externally (e.g., via migrations). Defaults to True.
|
|
83
86
|
|
|
84
87
|
Raises:
|
|
85
88
|
ValueError: If neither db_url nor db_engine is provided.
|
|
@@ -115,6 +118,7 @@ class PostgresDb(BaseDb):
|
|
|
115
118
|
|
|
116
119
|
self.db_schema: str = db_schema if db_schema is not None else "ai"
|
|
117
120
|
self.metadata: MetaData = MetaData(schema=self.db_schema)
|
|
121
|
+
self.create_schema: bool = create_schema
|
|
118
122
|
|
|
119
123
|
# Initialize database session
|
|
120
124
|
self.Session: scoped_session = scoped_session(sessionmaker(bind=self.db_engine, expire_on_commit=False))
|
|
@@ -204,8 +208,9 @@ class PostgresDb(BaseDb):
|
|
|
204
208
|
idx_name = f"idx_{table_name}_{idx_col}"
|
|
205
209
|
table.append_constraint(Index(idx_name, idx_col))
|
|
206
210
|
|
|
207
|
-
|
|
208
|
-
|
|
211
|
+
if self.create_schema:
|
|
212
|
+
with self.Session() as sess, sess.begin():
|
|
213
|
+
create_schema(session=sess, db_schema=self.db_schema)
|
|
209
214
|
|
|
210
215
|
# Create table
|
|
211
216
|
table_created = False
|
|
@@ -2466,14 +2471,6 @@ class PostgresDb(BaseDb):
|
|
|
2466
2471
|
if trace.workflow_id is not None:
|
|
2467
2472
|
update_values["workflow_id"] = trace.workflow_id
|
|
2468
2473
|
|
|
2469
|
-
log_debug(
|
|
2470
|
-
f" Updating trace with context: run_id={update_values.get('run_id', 'unchanged')}, "
|
|
2471
|
-
f"session_id={update_values.get('session_id', 'unchanged')}, "
|
|
2472
|
-
f"user_id={update_values.get('user_id', 'unchanged')}, "
|
|
2473
|
-
f"agent_id={update_values.get('agent_id', 'unchanged')}, "
|
|
2474
|
-
f"team_id={update_values.get('team_id', 'unchanged')}, "
|
|
2475
|
-
)
|
|
2476
|
-
|
|
2477
2474
|
stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
|
|
2478
2475
|
sess.execute(stmt)
|
|
2479
2476
|
else:
|
|
@@ -2574,10 +2571,6 @@ class PostgresDb(BaseDb):
|
|
|
2574
2571
|
try:
|
|
2575
2572
|
from agno.tracing.schemas import Trace
|
|
2576
2573
|
|
|
2577
|
-
log_debug(
|
|
2578
|
-
f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
|
|
2579
|
-
)
|
|
2580
|
-
|
|
2581
2574
|
table = self._get_table(table_type="traces")
|
|
2582
2575
|
if table is None:
|
|
2583
2576
|
log_debug("Traces table not found")
|
|
@@ -2594,7 +2587,6 @@ class PostgresDb(BaseDb):
|
|
|
2594
2587
|
if run_id:
|
|
2595
2588
|
base_stmt = base_stmt.where(table.c.run_id == run_id)
|
|
2596
2589
|
if session_id:
|
|
2597
|
-
log_debug(f"Filtering by session_id={session_id}")
|
|
2598
2590
|
base_stmt = base_stmt.where(table.c.session_id == session_id)
|
|
2599
2591
|
if user_id:
|
|
2600
2592
|
base_stmt = base_stmt.where(table.c.user_id == user_id)
|
|
@@ -2616,14 +2608,12 @@ class PostgresDb(BaseDb):
|
|
|
2616
2608
|
# Get total count
|
|
2617
2609
|
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2618
2610
|
total_count = sess.execute(count_stmt).scalar() or 0
|
|
2619
|
-
log_debug(f"Total matching traces: {total_count}")
|
|
2620
2611
|
|
|
2621
2612
|
# Apply pagination
|
|
2622
2613
|
offset = (page - 1) * limit if page and limit else 0
|
|
2623
2614
|
paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
|
|
2624
2615
|
|
|
2625
2616
|
results = sess.execute(paginated_stmt).fetchall()
|
|
2626
|
-
log_debug(f"Returning page {page} with {len(results)} traces")
|
|
2627
2617
|
|
|
2628
2618
|
traces = [Trace.from_dict(dict(row._mapping)) for row in results]
|
|
2629
2619
|
return traces, total_count
|
|
@@ -2661,12 +2651,6 @@ class PostgresDb(BaseDb):
|
|
|
2661
2651
|
first_trace_at, last_trace_at.
|
|
2662
2652
|
"""
|
|
2663
2653
|
try:
|
|
2664
|
-
log_debug(
|
|
2665
|
-
f"get_trace_stats called with filters: user_id={user_id}, agent_id={agent_id}, "
|
|
2666
|
-
f"workflow_id={workflow_id}, team_id={team_id}, "
|
|
2667
|
-
f"start_time={start_time}, end_time={end_time}, page={page}, limit={limit}"
|
|
2668
|
-
)
|
|
2669
|
-
|
|
2670
2654
|
table = self._get_table(table_type="traces")
|
|
2671
2655
|
if table is None:
|
|
2672
2656
|
log_debug("Traces table not found")
|
|
@@ -2710,14 +2694,12 @@ class PostgresDb(BaseDb):
|
|
|
2710
2694
|
# Get total count of sessions
|
|
2711
2695
|
count_stmt = select(func.count()).select_from(base_stmt.alias())
|
|
2712
2696
|
total_count = sess.execute(count_stmt).scalar() or 0
|
|
2713
|
-
log_debug(f"Total matching sessions: {total_count}")
|
|
2714
2697
|
|
|
2715
2698
|
# Apply pagination and ordering
|
|
2716
2699
|
offset = (page - 1) * limit if page and limit else 0
|
|
2717
2700
|
paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
|
|
2718
2701
|
|
|
2719
2702
|
results = sess.execute(paginated_stmt).fetchall()
|
|
2720
|
-
log_debug(f"Returning page {page} with {len(results)} session stats")
|
|
2721
2703
|
|
|
2722
2704
|
# Convert to list of dicts with datetime objects
|
|
2723
2705
|
stats_list = []
|
agno/db/redis/redis.py
CHANGED
|
@@ -1919,14 +1919,12 @@ class RedisDb(BaseDb):
|
|
|
1919
1919
|
filtered_traces.append(trace)
|
|
1920
1920
|
|
|
1921
1921
|
total_count = len(filtered_traces)
|
|
1922
|
-
log_debug(f"Total matching traces: {total_count}")
|
|
1923
1922
|
|
|
1924
1923
|
# Sort by start_time descending
|
|
1925
1924
|
filtered_traces.sort(key=lambda x: x.get("start_time", ""), reverse=True)
|
|
1926
1925
|
|
|
1927
1926
|
# Apply pagination
|
|
1928
1927
|
paginated_traces = apply_pagination(records=filtered_traces, limit=limit, page=page)
|
|
1929
|
-
log_debug(f"Returning page {page} with {len(paginated_traces)} traces")
|
|
1930
1928
|
|
|
1931
1929
|
traces = []
|
|
1932
1930
|
for row in paginated_traces:
|
|
@@ -2025,11 +2023,9 @@ class RedisDb(BaseDb):
|
|
|
2025
2023
|
stats_list.sort(key=lambda x: x.get("last_trace_at", ""), reverse=True)
|
|
2026
2024
|
|
|
2027
2025
|
total_count = len(stats_list)
|
|
2028
|
-
log_debug(f"Total matching sessions: {total_count}")
|
|
2029
2026
|
|
|
2030
2027
|
# Apply pagination
|
|
2031
2028
|
paginated_stats = apply_pagination(records=stats_list, limit=limit, page=page)
|
|
2032
|
-
log_debug(f"Returning page {page} with {len(paginated_stats)} session stats")
|
|
2033
2029
|
|
|
2034
2030
|
# Convert ISO strings to datetime objects
|
|
2035
2031
|
for stat in paginated_stats:
|