agno 2.3.3__py3-none-any.whl → 2.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. agno/agent/agent.py +177 -41
  2. agno/culture/manager.py +2 -2
  3. agno/db/base.py +330 -8
  4. agno/db/dynamo/dynamo.py +722 -2
  5. agno/db/dynamo/schemas.py +127 -0
  6. agno/db/firestore/firestore.py +573 -1
  7. agno/db/firestore/schemas.py +40 -0
  8. agno/db/gcs_json/gcs_json_db.py +446 -1
  9. agno/db/in_memory/in_memory_db.py +143 -1
  10. agno/db/json/json_db.py +438 -1
  11. agno/db/mongo/async_mongo.py +522 -0
  12. agno/db/mongo/mongo.py +523 -1
  13. agno/db/mongo/schemas.py +29 -0
  14. agno/db/mysql/mysql.py +536 -3
  15. agno/db/mysql/schemas.py +38 -0
  16. agno/db/postgres/async_postgres.py +546 -14
  17. agno/db/postgres/postgres.py +535 -2
  18. agno/db/postgres/schemas.py +38 -0
  19. agno/db/redis/redis.py +468 -1
  20. agno/db/redis/schemas.py +32 -0
  21. agno/db/singlestore/schemas.py +38 -0
  22. agno/db/singlestore/singlestore.py +523 -1
  23. agno/db/sqlite/async_sqlite.py +548 -9
  24. agno/db/sqlite/schemas.py +38 -0
  25. agno/db/sqlite/sqlite.py +537 -5
  26. agno/db/sqlite/utils.py +6 -8
  27. agno/db/surrealdb/models.py +25 -0
  28. agno/db/surrealdb/surrealdb.py +548 -1
  29. agno/eval/accuracy.py +10 -4
  30. agno/eval/performance.py +10 -4
  31. agno/eval/reliability.py +22 -13
  32. agno/exceptions.py +11 -0
  33. agno/hooks/__init__.py +3 -0
  34. agno/hooks/decorator.py +164 -0
  35. agno/knowledge/chunking/semantic.py +2 -2
  36. agno/models/aimlapi/aimlapi.py +17 -0
  37. agno/models/anthropic/claude.py +19 -12
  38. agno/models/aws/bedrock.py +3 -4
  39. agno/models/aws/claude.py +5 -1
  40. agno/models/azure/ai_foundry.py +2 -2
  41. agno/models/azure/openai_chat.py +8 -0
  42. agno/models/cerebras/cerebras.py +61 -4
  43. agno/models/cerebras/cerebras_openai.py +17 -0
  44. agno/models/cohere/chat.py +5 -1
  45. agno/models/cometapi/cometapi.py +18 -1
  46. agno/models/dashscope/dashscope.py +2 -3
  47. agno/models/deepinfra/deepinfra.py +18 -1
  48. agno/models/deepseek/deepseek.py +2 -3
  49. agno/models/fireworks/fireworks.py +18 -1
  50. agno/models/google/gemini.py +8 -2
  51. agno/models/groq/groq.py +5 -2
  52. agno/models/internlm/internlm.py +18 -1
  53. agno/models/langdb/langdb.py +13 -1
  54. agno/models/litellm/chat.py +2 -2
  55. agno/models/litellm/litellm_openai.py +18 -1
  56. agno/models/meta/llama_openai.py +19 -2
  57. agno/models/nebius/nebius.py +2 -3
  58. agno/models/nvidia/nvidia.py +20 -3
  59. agno/models/openai/chat.py +17 -2
  60. agno/models/openai/responses.py +17 -2
  61. agno/models/openrouter/openrouter.py +21 -2
  62. agno/models/perplexity/perplexity.py +17 -1
  63. agno/models/portkey/portkey.py +7 -6
  64. agno/models/requesty/requesty.py +19 -2
  65. agno/models/response.py +2 -1
  66. agno/models/sambanova/sambanova.py +20 -3
  67. agno/models/siliconflow/siliconflow.py +19 -2
  68. agno/models/together/together.py +20 -3
  69. agno/models/vercel/v0.py +20 -3
  70. agno/models/vllm/vllm.py +19 -14
  71. agno/models/xai/xai.py +19 -2
  72. agno/os/app.py +104 -0
  73. agno/os/config.py +13 -0
  74. agno/os/interfaces/whatsapp/router.py +0 -1
  75. agno/os/mcp.py +1 -0
  76. agno/os/router.py +31 -0
  77. agno/os/routers/traces/__init__.py +3 -0
  78. agno/os/routers/traces/schemas.py +414 -0
  79. agno/os/routers/traces/traces.py +499 -0
  80. agno/os/schema.py +22 -1
  81. agno/os/utils.py +57 -0
  82. agno/run/agent.py +1 -0
  83. agno/run/base.py +17 -0
  84. agno/run/team.py +4 -0
  85. agno/session/team.py +1 -0
  86. agno/table.py +10 -0
  87. agno/team/team.py +215 -65
  88. agno/tools/function.py +10 -8
  89. agno/tools/nano_banana.py +1 -1
  90. agno/tracing/__init__.py +12 -0
  91. agno/tracing/exporter.py +157 -0
  92. agno/tracing/schemas.py +276 -0
  93. agno/tracing/setup.py +111 -0
  94. agno/utils/agent.py +4 -4
  95. agno/utils/hooks.py +56 -1
  96. agno/vectordb/qdrant/qdrant.py +22 -22
  97. agno/workflow/condition.py +8 -0
  98. agno/workflow/loop.py +8 -0
  99. agno/workflow/parallel.py +8 -0
  100. agno/workflow/router.py +8 -0
  101. agno/workflow/step.py +20 -0
  102. agno/workflow/steps.py +8 -0
  103. agno/workflow/workflow.py +83 -17
  104. {agno-2.3.3.dist-info → agno-2.3.5.dist-info}/METADATA +2 -2
  105. {agno-2.3.3.dist-info → agno-2.3.5.dist-info}/RECORD +108 -98
  106. {agno-2.3.3.dist-info → agno-2.3.5.dist-info}/WHEEL +0 -0
  107. {agno-2.3.3.dist-info → agno-2.3.5.dist-info}/licenses/LICENSE +0 -0
  108. {agno-2.3.3.dist-info → agno-2.3.5.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,12 @@
1
1
  import time
2
2
  import warnings
3
3
  from datetime import date, datetime, timedelta, timezone
4
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
5
5
  from uuid import uuid4
6
6
 
7
+ if TYPE_CHECKING:
8
+ from agno.tracing.schemas import Span, Trace
9
+
7
10
  from agno.db.base import AsyncBaseDb, SessionType
8
11
  from agno.db.migrations.manager import MigrationManager
9
12
  from agno.db.postgres.schemas import get_table_schema_definition
@@ -27,10 +30,10 @@ from agno.session import AgentSession, Session, TeamSession, WorkflowSession
27
30
  from agno.utils.log import log_debug, log_error, log_info, log_warning
28
31
 
29
32
  try:
30
- from sqlalchemy import Index, String, UniqueConstraint, func, update
33
+ from sqlalchemy import Index, String, Table, UniqueConstraint, func, update
31
34
  from sqlalchemy.dialects import postgresql
32
35
  from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
33
- from sqlalchemy.schema import Column, MetaData, Table
36
+ from sqlalchemy.schema import Column, MetaData
34
37
  from sqlalchemy.sql.expression import select, text
35
38
  except ImportError:
36
39
  raise ImportError("`sqlalchemy` not installed. Please install it using `pip install sqlalchemy`")
@@ -49,6 +52,8 @@ class AsyncPostgresDb(AsyncBaseDb):
49
52
  eval_table: Optional[str] = None,
50
53
  knowledge_table: Optional[str] = None,
51
54
  culture_table: Optional[str] = None,
55
+ traces_table: Optional[str] = None,
56
+ spans_table: Optional[str] = None,
52
57
  versions_table: Optional[str] = None,
53
58
  db_id: Optional[str] = None, # Deprecated, use id instead.
54
59
  ):
@@ -71,6 +76,8 @@ class AsyncPostgresDb(AsyncBaseDb):
71
76
  eval_table (Optional[str]): Name of the table to store evaluation runs data.
72
77
  knowledge_table (Optional[str]): Name of the table to store knowledge content.
73
78
  culture_table (Optional[str]): Name of the table to store cultural knowledge.
79
+ traces_table (Optional[str]): Name of the table to store run traces.
80
+ spans_table (Optional[str]): Name of the table to store span events.
74
81
  versions_table (Optional[str]): Name of the table to store schema versions.
75
82
  db_id: Deprecated, use id instead.
76
83
 
@@ -93,6 +100,8 @@ class AsyncPostgresDb(AsyncBaseDb):
93
100
  eval_table=eval_table,
94
101
  knowledge_table=knowledge_table,
95
102
  culture_table=culture_table,
103
+ traces_table=traces_table,
104
+ spans_table=spans_table,
96
105
  versions_table=versions_table,
97
106
  )
98
107
 
@@ -138,7 +147,9 @@ class AsyncPostgresDb(AsyncBaseDb):
138
147
  ]
139
148
 
140
149
  for table_name, table_type in tables_to_create:
141
- await self._get_or_create_table(table_name=table_name, table_type=table_type)
150
+ await self._get_or_create_table(
151
+ table_name=table_name, table_type=table_type, create_table_if_not_found=True
152
+ )
142
153
 
143
154
  async def _create_table(self, table_name: str, table_type: str) -> Table:
144
155
  """
@@ -239,7 +250,7 @@ class AsyncPostgresDb(AsyncBaseDb):
239
250
  log_error(f"Could not create table {self.db_schema}.{table_name}: {e}")
240
251
  raise
241
252
 
242
- async def _get_table(self, table_type: str) -> Table:
253
+ async def _get_table(self, table_type: str, create_table_if_not_found: Optional[bool] = False) -> Table:
243
254
  if table_type == "sessions":
244
255
  if not hasattr(self, "session_table"):
245
256
  self.session_table = await self._get_or_create_table(
@@ -287,9 +298,31 @@ class AsyncPostgresDb(AsyncBaseDb):
287
298
  )
288
299
  return self.versions_table
289
300
 
301
+ if table_type == "traces":
302
+ if not hasattr(self, "traces_table"):
303
+ self.traces_table = await self._get_or_create_table(
304
+ table_name=self.trace_table_name,
305
+ table_type="traces",
306
+ create_table_if_not_found=create_table_if_not_found,
307
+ )
308
+ return self.traces_table
309
+
310
+ if table_type == "spans":
311
+ if not hasattr(self, "spans_table"):
312
+ # Ensure traces table exists first (spans has FK to traces)
313
+ await self._get_table(table_type="traces", create_table_if_not_found=True)
314
+ self.spans_table = await self._get_or_create_table(
315
+ table_name=self.span_table_name,
316
+ table_type="spans",
317
+ create_table_if_not_found=create_table_if_not_found,
318
+ )
319
+ return self.spans_table
320
+
290
321
  raise ValueError(f"Unknown table type: {table_type}")
291
322
 
292
- async def _get_or_create_table(self, table_name: str, table_type: str) -> Table:
323
+ async def _get_or_create_table(
324
+ self, table_name: str, table_type: str, create_table_if_not_found: Optional[bool] = False
325
+ ) -> Table:
293
326
  """
294
327
  Check if the table exists and is valid, else create it.
295
328
 
@@ -306,7 +339,7 @@ class AsyncPostgresDb(AsyncBaseDb):
306
339
  session=sess, table_name=table_name, db_schema=self.db_schema
307
340
  )
308
341
 
309
- if not table_is_available:
342
+ if (not table_is_available) and create_table_if_not_found:
310
343
  return await self._create_table(table_name=table_name, table_type=table_type)
311
344
 
312
345
  if not await ais_valid_table(
@@ -333,7 +366,7 @@ class AsyncPostgresDb(AsyncBaseDb):
333
366
 
334
367
  async def get_latest_schema_version(self, table_name: str) -> str:
335
368
  """Get the latest version of the database schema."""
336
- table = await self._get_table(table_type="versions")
369
+ table = await self._get_table(table_type="versions", create_table_if_not_found=True)
337
370
  if table is None:
338
371
  return "2.0.0"
339
372
 
@@ -352,7 +385,7 @@ class AsyncPostgresDb(AsyncBaseDb):
352
385
 
353
386
  async def upsert_schema_version(self, table_name: str, version: str) -> None:
354
387
  """Upsert the schema version into the database."""
355
- table = await self._get_table(table_type="versions")
388
+ table = await self._get_table(table_type="versions", create_table_if_not_found=True)
356
389
  if table is None:
357
390
  return
358
391
  current_datetime = datetime.now().isoformat()
@@ -665,7 +698,7 @@ class AsyncPostgresDb(AsyncBaseDb):
665
698
  Exception: If an error occurs during upsert.
666
699
  """
667
700
  try:
668
- table = await self._get_table(table_type="sessions")
701
+ table = await self._get_table(table_type="sessions", create_table_if_not_found=True)
669
702
  session_dict = session.to_dict()
670
703
 
671
704
  if isinstance(session, AgentSession):
@@ -796,7 +829,7 @@ class AsyncPostgresDb(AsyncBaseDb):
796
829
  return None
797
830
 
798
831
  # -- Memory methods --
799
- async def delete_user_memory(self, memory_id: str):
832
+ async def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
800
833
  """Delete a user memory from the database.
801
834
 
802
835
  Returns:
@@ -810,6 +843,8 @@ class AsyncPostgresDb(AsyncBaseDb):
810
843
 
811
844
  async with self.async_session_factory() as sess, sess.begin():
812
845
  delete_stmt = table.delete().where(table.c.memory_id == memory_id)
846
+ if user_id is not None:
847
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
813
848
  result = await sess.execute(delete_stmt)
814
849
 
815
850
  success = result.rowcount > 0 # type: ignore
@@ -821,7 +856,7 @@ class AsyncPostgresDb(AsyncBaseDb):
821
856
  except Exception as e:
822
857
  log_error(f"Error deleting user memory: {e}")
823
858
 
824
- async def delete_user_memories(self, memory_ids: List[str]) -> None:
859
+ async def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
825
860
  """Delete user memories from the database.
826
861
 
827
862
  Args:
@@ -835,6 +870,10 @@ class AsyncPostgresDb(AsyncBaseDb):
835
870
 
836
871
  async with self.async_session_factory() as sess, sess.begin():
837
872
  delete_stmt = table.delete().where(table.c.memory_id.in_(memory_ids))
873
+
874
+ if user_id is not None:
875
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
876
+
838
877
  result = await sess.execute(delete_stmt)
839
878
 
840
879
  if result.rowcount == 0: # type: ignore
@@ -845,9 +884,12 @@ class AsyncPostgresDb(AsyncBaseDb):
845
884
  except Exception as e:
846
885
  log_error(f"Error deleting user memories: {e}")
847
886
 
848
- async def get_all_memory_topics(self) -> List[str]:
887
+ async def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
849
888
  """Get all memory topics from the database.
850
889
 
890
+ Args:
891
+ user_id (Optional[str]): The ID of the user to filter by.
892
+
851
893
  Returns:
852
894
  List[str]: List of memory topics.
853
895
  """
@@ -856,6 +898,8 @@ class AsyncPostgresDb(AsyncBaseDb):
856
898
 
857
899
  async with self.async_session_factory() as sess, sess.begin():
858
900
  stmt = select(func.json_array_elements_text(table.c.topics))
901
+ if user_id is not None:
902
+ stmt = stmt.where(table.c.user_id == user_id)
859
903
  result = await sess.execute(stmt)
860
904
  records = result.fetchall()
861
905
 
@@ -866,13 +910,17 @@ class AsyncPostgresDb(AsyncBaseDb):
866
910
  return []
867
911
 
868
912
  async def get_user_memory(
869
- self, memory_id: str, deserialize: Optional[bool] = True
913
+ self,
914
+ memory_id: str,
915
+ deserialize: Optional[bool] = True,
916
+ user_id: Optional[str] = None,
870
917
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
871
918
  """Get a memory from the database.
872
919
 
873
920
  Args:
874
921
  memory_id (str): The ID of the memory to get.
875
922
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
923
+ user_id (Optional[str]): The ID of the user to filter by.
876
924
 
877
925
  Returns:
878
926
  Union[UserMemory, Dict[str, Any], None]:
@@ -887,6 +935,8 @@ class AsyncPostgresDb(AsyncBaseDb):
887
935
 
888
936
  async with self.async_session_factory() as sess, sess.begin():
889
937
  stmt = select(table).where(table.c.memory_id == memory_id)
938
+ if user_id is not None:
939
+ stmt = stmt.where(table.c.user_id == user_id)
890
940
 
891
941
  result = await sess.execute(stmt)
892
942
  row = result.fetchone()
@@ -2009,3 +2059,485 @@ class AsyncPostgresDb(AsyncBaseDb):
2009
2059
  for memory in memories:
2010
2060
  await self.upsert_user_memory(memory)
2011
2061
  log_info(f"Migrated {len(memories)} memories to table: {self.memory_table}")
2062
+
2063
+ # --- Traces ---
2064
+ def _get_traces_base_query(self, table: Table, spans_table: Optional[Table] = None):
2065
+ """Build base query for traces with aggregated span counts.
2066
+
2067
+ Args:
2068
+ table: The traces table.
2069
+ spans_table: The spans table (optional).
2070
+
2071
+ Returns:
2072
+ SQLAlchemy select statement with total_spans and error_count calculated dynamically.
2073
+ """
2074
+ from sqlalchemy import case, literal
2075
+
2076
+ if spans_table is not None:
2077
+ # JOIN with spans table to calculate total_spans and error_count
2078
+ return (
2079
+ select(
2080
+ table,
2081
+ func.coalesce(func.count(spans_table.c.span_id), 0).label("total_spans"),
2082
+ func.coalesce(func.sum(case((spans_table.c.status_code == "ERROR", 1), else_=0)), 0).label(
2083
+ "error_count"
2084
+ ),
2085
+ )
2086
+ .select_from(table.outerjoin(spans_table, table.c.trace_id == spans_table.c.trace_id))
2087
+ .group_by(table.c.trace_id)
2088
+ )
2089
+ else:
2090
+ # Fallback if spans table doesn't exist
2091
+ return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2092
+
2093
+ async def create_trace(self, trace: "Trace") -> None:
2094
+ """Create a single trace record in the database.
2095
+
2096
+ Args:
2097
+ trace: The Trace object to store (one per trace_id).
2098
+ """
2099
+ try:
2100
+ table = await self._get_table(table_type="traces", create_table_if_not_found=True)
2101
+
2102
+ async with self.async_session_factory() as sess, sess.begin():
2103
+ # Check if trace exists
2104
+ result = await sess.execute(select(table).where(table.c.trace_id == trace.trace_id))
2105
+ existing = result.fetchone()
2106
+
2107
+ if existing:
2108
+ # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2109
+
2110
+ def get_component_level(workflow_id, team_id, agent_id, name):
2111
+ # Check if name indicates a root span
2112
+ is_root_name = ".run" in name or ".arun" in name
2113
+
2114
+ if not is_root_name:
2115
+ return 0 # Child span (not a root)
2116
+ elif workflow_id:
2117
+ return 3 # Workflow root
2118
+ elif team_id:
2119
+ return 2 # Team root
2120
+ elif agent_id:
2121
+ return 1 # Agent root
2122
+ else:
2123
+ return 0 # Unknown
2124
+
2125
+ existing_level = get_component_level(
2126
+ existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2127
+ )
2128
+ new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2129
+
2130
+ # Only update name if new trace is from a higher or equal level
2131
+ should_update_name = new_level > existing_level
2132
+
2133
+ # Parse existing start_time to calculate correct duration
2134
+ existing_start_time_str = existing.start_time
2135
+ if isinstance(existing_start_time_str, str):
2136
+ existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2137
+ else:
2138
+ existing_start_time = trace.start_time
2139
+
2140
+ recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2141
+
2142
+ update_values = {
2143
+ "end_time": trace.end_time.isoformat(),
2144
+ "duration_ms": recalculated_duration_ms,
2145
+ "status": trace.status,
2146
+ "name": trace.name if should_update_name else existing.name,
2147
+ }
2148
+
2149
+ # Update context fields ONLY if new value is not None (preserve non-null values)
2150
+ if trace.run_id is not None:
2151
+ update_values["run_id"] = trace.run_id
2152
+ if trace.session_id is not None:
2153
+ update_values["session_id"] = trace.session_id
2154
+ if trace.user_id is not None:
2155
+ update_values["user_id"] = trace.user_id
2156
+ if trace.agent_id is not None:
2157
+ update_values["agent_id"] = trace.agent_id
2158
+ if trace.team_id is not None:
2159
+ update_values["team_id"] = trace.team_id
2160
+ if trace.workflow_id is not None:
2161
+ update_values["workflow_id"] = trace.workflow_id
2162
+
2163
+ log_debug(
2164
+ f" Updating trace with context: run_id={update_values.get('run_id', 'unchanged')}, "
2165
+ f"session_id={update_values.get('session_id', 'unchanged')}, "
2166
+ f"user_id={update_values.get('user_id', 'unchanged')}, "
2167
+ f"agent_id={update_values.get('agent_id', 'unchanged')}, "
2168
+ f"team_id={update_values.get('team_id', 'unchanged')}, "
2169
+ )
2170
+
2171
+ stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2172
+ await sess.execute(stmt)
2173
+ else:
2174
+ trace_dict = trace.to_dict()
2175
+ trace_dict.pop("total_spans", None)
2176
+ trace_dict.pop("error_count", None)
2177
+ stmt = postgresql.insert(table).values(trace_dict)
2178
+ await sess.execute(stmt)
2179
+
2180
+ except Exception as e:
2181
+ log_error(f"Error creating trace: {e}")
2182
+ # Don't raise - tracing should not break the main application flow
2183
+
2184
+ async def get_trace(
2185
+ self,
2186
+ trace_id: Optional[str] = None,
2187
+ run_id: Optional[str] = None,
2188
+ ):
2189
+ """Get a single trace by trace_id or other filters.
2190
+
2191
+ Args:
2192
+ trace_id: The unique trace identifier.
2193
+ run_id: Filter by run ID (returns first match).
2194
+
2195
+ Returns:
2196
+ Optional[Trace]: The trace if found, None otherwise.
2197
+
2198
+ Note:
2199
+ If multiple filters are provided, trace_id takes precedence.
2200
+ For other filters, the most recent trace is returned.
2201
+ """
2202
+ try:
2203
+ from agno.tracing.schemas import Trace
2204
+
2205
+ table = await self._get_table(table_type="traces")
2206
+
2207
+ # Get spans table for JOIN
2208
+ spans_table = await self._get_table(table_type="spans")
2209
+
2210
+ async with self.async_session_factory() as sess:
2211
+ # Build query with aggregated span counts
2212
+ stmt = self._get_traces_base_query(table, spans_table)
2213
+
2214
+ if trace_id:
2215
+ stmt = stmt.where(table.c.trace_id == trace_id)
2216
+ elif run_id:
2217
+ stmt = stmt.where(table.c.run_id == run_id)
2218
+ else:
2219
+ log_debug("get_trace called without any filter parameters")
2220
+ return None
2221
+
2222
+ # Order by most recent and get first result
2223
+ stmt = stmt.order_by(table.c.start_time.desc()).limit(1)
2224
+ result = await sess.execute(stmt)
2225
+ row = result.fetchone()
2226
+
2227
+ if row:
2228
+ return Trace.from_dict(dict(row._mapping))
2229
+ return None
2230
+
2231
+ except Exception as e:
2232
+ log_error(f"Error getting trace: {e}")
2233
+ return None
2234
+
2235
+ async def get_traces(
2236
+ self,
2237
+ run_id: Optional[str] = None,
2238
+ session_id: Optional[str] = None,
2239
+ user_id: Optional[str] = None,
2240
+ agent_id: Optional[str] = None,
2241
+ team_id: Optional[str] = None,
2242
+ workflow_id: Optional[str] = None,
2243
+ status: Optional[str] = None,
2244
+ start_time: Optional[datetime] = None,
2245
+ end_time: Optional[datetime] = None,
2246
+ limit: Optional[int] = 20,
2247
+ page: Optional[int] = 1,
2248
+ ) -> tuple[List, int]:
2249
+ """Get traces matching the provided filters with pagination.
2250
+
2251
+ Args:
2252
+ run_id: Filter by run ID.
2253
+ session_id: Filter by session ID.
2254
+ user_id: Filter by user ID.
2255
+ agent_id: Filter by agent ID.
2256
+ team_id: Filter by team ID.
2257
+ workflow_id: Filter by workflow ID.
2258
+ status: Filter by status (OK, ERROR, UNSET).
2259
+ start_time: Filter traces starting after this datetime.
2260
+ end_time: Filter traces ending before this datetime.
2261
+ limit: Maximum number of traces to return per page.
2262
+ page: Page number (1-indexed).
2263
+
2264
+ Returns:
2265
+ tuple[List[Trace], int]: Tuple of (list of matching traces, total count).
2266
+ """
2267
+ try:
2268
+ from agno.tracing.schemas import Trace
2269
+
2270
+ log_debug(
2271
+ f"get_traces called with filters: run_id={run_id}, session_id={session_id}, user_id={user_id}, agent_id={agent_id}, page={page}, limit={limit}"
2272
+ )
2273
+
2274
+ table = await self._get_table(table_type="traces")
2275
+
2276
+ # Get spans table for JOIN
2277
+ spans_table = await self._get_table(table_type="spans")
2278
+
2279
+ async with self.async_session_factory() as sess:
2280
+ # Build base query with aggregated span counts
2281
+ base_stmt = self._get_traces_base_query(table, spans_table)
2282
+
2283
+ # Apply filters
2284
+ if run_id:
2285
+ base_stmt = base_stmt.where(table.c.run_id == run_id)
2286
+ if session_id:
2287
+ log_debug(f"Filtering by session_id={session_id}")
2288
+ base_stmt = base_stmt.where(table.c.session_id == session_id)
2289
+ if user_id:
2290
+ base_stmt = base_stmt.where(table.c.user_id == user_id)
2291
+ if agent_id:
2292
+ base_stmt = base_stmt.where(table.c.agent_id == agent_id)
2293
+ if team_id:
2294
+ base_stmt = base_stmt.where(table.c.team_id == team_id)
2295
+ if workflow_id:
2296
+ base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
2297
+ if status:
2298
+ base_stmt = base_stmt.where(table.c.status == status)
2299
+ if start_time:
2300
+ # Convert datetime to ISO string for comparison
2301
+ base_stmt = base_stmt.where(table.c.start_time >= start_time.isoformat())
2302
+ if end_time:
2303
+ # Convert datetime to ISO string for comparison
2304
+ base_stmt = base_stmt.where(table.c.end_time <= end_time.isoformat())
2305
+
2306
+ # Get total count
2307
+ count_stmt = select(func.count()).select_from(base_stmt.alias())
2308
+ total_count = await sess.scalar(count_stmt) or 0
2309
+ log_debug(f"Total matching traces: {total_count}")
2310
+
2311
+ # Apply pagination
2312
+ offset = (page - 1) * limit if page and limit else 0
2313
+ paginated_stmt = base_stmt.order_by(table.c.start_time.desc()).limit(limit).offset(offset)
2314
+
2315
+ result = await sess.execute(paginated_stmt)
2316
+ results = result.fetchall()
2317
+ log_debug(f"Returning page {page} with {len(results)} traces")
2318
+
2319
+ traces = [Trace.from_dict(dict(row._mapping)) for row in results]
2320
+ return traces, total_count
2321
+
2322
+ except Exception as e:
2323
+ log_error(f"Error getting traces: {e}")
2324
+ return [], 0
2325
+
2326
+ async def get_trace_stats(
2327
+ self,
2328
+ user_id: Optional[str] = None,
2329
+ agent_id: Optional[str] = None,
2330
+ team_id: Optional[str] = None,
2331
+ workflow_id: Optional[str] = None,
2332
+ start_time: Optional[datetime] = None,
2333
+ end_time: Optional[datetime] = None,
2334
+ limit: Optional[int] = 20,
2335
+ page: Optional[int] = 1,
2336
+ ) -> tuple[List[Dict[str, Any]], int]:
2337
+ """Get trace statistics grouped by session.
2338
+
2339
+ Args:
2340
+ user_id: Filter by user ID.
2341
+ agent_id: Filter by agent ID.
2342
+ team_id: Filter by team ID.
2343
+ workflow_id: Filter by workflow ID.
2344
+ start_time: Filter sessions with traces created after this datetime.
2345
+ end_time: Filter sessions with traces created before this datetime.
2346
+ limit: Maximum number of sessions to return per page.
2347
+ page: Page number (1-indexed).
2348
+
2349
+ Returns:
2350
+ tuple[List[Dict], int]: Tuple of (list of session stats dicts, total count).
2351
+ Each dict contains: session_id, user_id, agent_id, team_id, total_traces,
2352
+ workflow_id, first_trace_at, last_trace_at.
2353
+ """
2354
+ try:
2355
+ log_debug(
2356
+ f"get_trace_stats called with filters: user_id={user_id}, agent_id={agent_id}, "
2357
+ f"workflow_id={workflow_id}, team_id={team_id}, "
2358
+ f"start_time={start_time}, end_time={end_time}, page={page}, limit={limit}"
2359
+ )
2360
+
2361
+ table = await self._get_table(table_type="traces")
2362
+
2363
+ async with self.async_session_factory() as sess:
2364
+ # Build base query grouped by session_id
2365
+ base_stmt = (
2366
+ select(
2367
+ table.c.session_id,
2368
+ table.c.user_id,
2369
+ table.c.agent_id,
2370
+ table.c.team_id,
2371
+ table.c.workflow_id,
2372
+ func.count(table.c.trace_id).label("total_traces"),
2373
+ func.min(table.c.created_at).label("first_trace_at"),
2374
+ func.max(table.c.created_at).label("last_trace_at"),
2375
+ )
2376
+ .where(table.c.session_id.isnot(None)) # Only sessions with session_id
2377
+ .group_by(
2378
+ table.c.session_id, table.c.user_id, table.c.agent_id, table.c.team_id, table.c.workflow_id
2379
+ )
2380
+ )
2381
+
2382
+ # Apply filters
2383
+ if user_id:
2384
+ base_stmt = base_stmt.where(table.c.user_id == user_id)
2385
+ if workflow_id:
2386
+ base_stmt = base_stmt.where(table.c.workflow_id == workflow_id)
2387
+ if team_id:
2388
+ base_stmt = base_stmt.where(table.c.team_id == team_id)
2389
+ if agent_id:
2390
+ base_stmt = base_stmt.where(table.c.agent_id == agent_id)
2391
+ if start_time:
2392
+ # Convert datetime to ISO string for comparison
2393
+ base_stmt = base_stmt.where(table.c.created_at >= start_time.isoformat())
2394
+ if end_time:
2395
+ # Convert datetime to ISO string for comparison
2396
+ base_stmt = base_stmt.where(table.c.created_at <= end_time.isoformat())
2397
+
2398
+ # Get total count of sessions
2399
+ count_stmt = select(func.count()).select_from(base_stmt.alias())
2400
+ total_count = await sess.scalar(count_stmt) or 0
2401
+ log_debug(f"Total matching sessions: {total_count}")
2402
+
2403
+ # Apply pagination and ordering
2404
+ offset = (page - 1) * limit if page and limit else 0
2405
+ paginated_stmt = base_stmt.order_by(func.max(table.c.created_at).desc()).limit(limit).offset(offset)
2406
+
2407
+ result = await sess.execute(paginated_stmt)
2408
+ results = result.fetchall()
2409
+ log_debug(f"Returning page {page} with {len(results)} session stats")
2410
+
2411
+ # Convert to list of dicts with datetime objects
2412
+ stats_list = []
2413
+ for row in results:
2414
+ # Convert ISO strings to datetime objects
2415
+ first_trace_at_str = row.first_trace_at
2416
+ last_trace_at_str = row.last_trace_at
2417
+
2418
+ # Parse ISO format strings to datetime objects
2419
+ first_trace_at = datetime.fromisoformat(first_trace_at_str.replace("Z", "+00:00"))
2420
+ last_trace_at = datetime.fromisoformat(last_trace_at_str.replace("Z", "+00:00"))
2421
+
2422
+ stats_list.append(
2423
+ {
2424
+ "session_id": row.session_id,
2425
+ "user_id": row.user_id,
2426
+ "agent_id": row.agent_id,
2427
+ "team_id": row.team_id,
2428
+ "workflow_id": row.workflow_id,
2429
+ "total_traces": row.total_traces,
2430
+ "first_trace_at": first_trace_at,
2431
+ "last_trace_at": last_trace_at,
2432
+ }
2433
+ )
2434
+
2435
+ return stats_list, total_count
2436
+
2437
+ except Exception as e:
2438
+ log_error(f"Error getting trace stats: {e}")
2439
+ return [], 0
2440
+
2441
+ # --- Spans ---
2442
+ async def create_span(self, span: "Span") -> None:
2443
+ """Create a single span in the database.
2444
+
2445
+ Args:
2446
+ span: The Span object to store.
2447
+ """
2448
+ try:
2449
+ table = await self._get_table(table_type="spans", create_table_if_not_found=True)
2450
+
2451
+ async with self.async_session_factory() as sess, sess.begin():
2452
+ stmt = postgresql.insert(table).values(span.to_dict())
2453
+ await sess.execute(stmt)
2454
+
2455
+ except Exception as e:
2456
+ log_error(f"Error creating span: {e}")
2457
+
2458
+ async def create_spans(self, spans: List) -> None:
2459
+ """Create multiple spans in the database as a batch.
2460
+
2461
+ Args:
2462
+ spans: List of Span objects to store.
2463
+ """
2464
+ if not spans:
2465
+ return
2466
+
2467
+ try:
2468
+ table = await self._get_table(table_type="spans", create_table_if_not_found=True)
2469
+
2470
+ async with self.async_session_factory() as sess, sess.begin():
2471
+ for span in spans:
2472
+ stmt = postgresql.insert(table).values(span.to_dict())
2473
+ await sess.execute(stmt)
2474
+
2475
+ except Exception as e:
2476
+ log_error(f"Error creating spans batch: {e}")
2477
+
2478
+ async def get_span(self, span_id: str):
2479
+ """Get a single span by its span_id.
2480
+
2481
+ Args:
2482
+ span_id: The unique span identifier.
2483
+
2484
+ Returns:
2485
+ Optional[Span]: The span if found, None otherwise.
2486
+ """
2487
+ try:
2488
+ from agno.tracing.schemas import Span
2489
+
2490
+ table = await self._get_table(table_type="spans")
2491
+
2492
+ async with self.async_session_factory() as sess:
2493
+ stmt = select(table).where(table.c.span_id == span_id)
2494
+ result = await sess.execute(stmt)
2495
+ row = result.fetchone()
2496
+ if row:
2497
+ return Span.from_dict(dict(row._mapping))
2498
+ return None
2499
+
2500
+ except Exception as e:
2501
+ log_error(f"Error getting span: {e}")
2502
+ return None
2503
+
2504
+ async def get_spans(
2505
+ self,
2506
+ trace_id: Optional[str] = None,
2507
+ parent_span_id: Optional[str] = None,
2508
+ limit: Optional[int] = 1000,
2509
+ ) -> List:
2510
+ """Get spans matching the provided filters.
2511
+
2512
+ Args:
2513
+ trace_id: Filter by trace ID.
2514
+ parent_span_id: Filter by parent span ID.
2515
+ limit: Maximum number of spans to return.
2516
+
2517
+ Returns:
2518
+ List[Span]: List of matching spans.
2519
+ """
2520
+ try:
2521
+ from agno.tracing.schemas import Span
2522
+
2523
+ table = await self._get_table(table_type="spans")
2524
+
2525
+ async with self.async_session_factory() as sess:
2526
+ stmt = select(table)
2527
+
2528
+ # Apply filters
2529
+ if trace_id:
2530
+ stmt = stmt.where(table.c.trace_id == trace_id)
2531
+ if parent_span_id:
2532
+ stmt = stmt.where(table.c.parent_span_id == parent_span_id)
2533
+
2534
+ if limit:
2535
+ stmt = stmt.limit(limit)
2536
+
2537
+ result = await sess.execute(stmt)
2538
+ results = result.fetchall()
2539
+ return [Span.from_dict(dict(row._mapping)) for row in results]
2540
+
2541
+ except Exception as e:
2542
+ log_error(f"Error getting spans: {e}")
2543
+ return []