agno 2.3.9__py3-none-any.whl → 2.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. agno/agent/agent.py +0 -12
  2. agno/db/base.py +5 -5
  3. agno/db/dynamo/dynamo.py +2 -2
  4. agno/db/firestore/firestore.py +2 -2
  5. agno/db/gcs_json/gcs_json_db.py +2 -2
  6. agno/db/in_memory/in_memory_db.py +2 -2
  7. agno/db/json/json_db.py +2 -2
  8. agno/db/mongo/async_mongo.py +171 -69
  9. agno/db/mongo/mongo.py +171 -77
  10. agno/db/mysql/async_mysql.py +93 -69
  11. agno/db/mysql/mysql.py +93 -68
  12. agno/db/postgres/async_postgres.py +104 -78
  13. agno/db/postgres/postgres.py +97 -69
  14. agno/db/redis/redis.py +2 -2
  15. agno/db/singlestore/singlestore.py +91 -66
  16. agno/db/sqlite/async_sqlite.py +102 -79
  17. agno/db/sqlite/sqlite.py +97 -69
  18. agno/db/surrealdb/surrealdb.py +2 -2
  19. agno/eval/accuracy.py +11 -8
  20. agno/eval/agent_as_judge.py +9 -8
  21. agno/knowledge/chunking/fixed.py +4 -1
  22. agno/knowledge/embedder/openai.py +1 -1
  23. agno/knowledge/knowledge.py +22 -4
  24. agno/knowledge/utils.py +52 -7
  25. agno/models/base.py +34 -1
  26. agno/models/google/gemini.py +69 -40
  27. agno/models/message.py +3 -0
  28. agno/models/openai/chat.py +21 -0
  29. agno/os/routers/evals/utils.py +15 -37
  30. agno/os/routers/knowledge/knowledge.py +21 -9
  31. agno/team/team.py +14 -8
  32. agno/tools/function.py +37 -23
  33. agno/tools/shopify.py +1519 -0
  34. agno/tools/spotify.py +2 -5
  35. agno/tracing/exporter.py +2 -2
  36. agno/vectordb/base.py +15 -2
  37. agno/vectordb/pgvector/pgvector.py +8 -8
  38. agno/workflow/parallel.py +2 -0
  39. {agno-2.3.9.dist-info → agno-2.3.11.dist-info}/METADATA +1 -1
  40. {agno-2.3.9.dist-info → agno-2.3.11.dist-info}/RECORD +43 -42
  41. {agno-2.3.9.dist-info → agno-2.3.11.dist-info}/WHEEL +0 -0
  42. {agno-2.3.9.dist-info → agno-2.3.11.dist-info}/licenses/LICENSE +0 -0
  43. {agno-2.3.9.dist-info → agno-2.3.11.dist-info}/top_level.txt +0 -0
agno/db/postgres/async_postgres.py CHANGED
@@ -30,8 +30,9 @@ from agno.session import AgentSession, Session, TeamSession, WorkflowSession
30
30
  from agno.utils.log import log_debug, log_error, log_info, log_warning
31
31
 
32
32
  try:
33
- from sqlalchemy import Index, String, Table, UniqueConstraint, func, update
33
+ from sqlalchemy import Index, String, Table, UniqueConstraint, and_, case, func, or_, update
34
34
  from sqlalchemy.dialects import postgresql
35
+ from sqlalchemy.dialects.postgresql import TIMESTAMP
35
36
  from sqlalchemy.exc import ProgrammingError
36
37
  from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
37
38
  from sqlalchemy.schema import Column, MetaData
@@ -1094,7 +1095,7 @@ class AsyncPostgresDb(AsyncBaseDb):
1094
1095
  await sess.execute(table.delete())
1095
1096
 
1096
1097
  except Exception as e:
1097
- log_warning(f"Exception deleting all cultural knowledge: {e}")
1098
+ log_error(f"Exception deleting all cultural knowledge: {e}")
1098
1099
 
1099
1100
  async def delete_cultural_knowledge(self, id: str) -> None:
1100
1101
  """Delete cultural knowledge by ID.
@@ -1113,8 +1114,7 @@ class AsyncPostgresDb(AsyncBaseDb):
1113
1114
  await sess.execute(stmt)
1114
1115
 
1115
1116
  except Exception as e:
1116
- log_warning(f"Exception deleting cultural knowledge: {e}")
1117
- raise e
1117
+ log_error(f"Exception deleting cultural knowledge: {e}")
1118
1118
 
1119
1119
  async def get_cultural_knowledge(
1120
1120
  self, id: str, deserialize: Optional[bool] = True
@@ -1150,8 +1150,8 @@ class AsyncPostgresDb(AsyncBaseDb):
1150
1150
  return deserialize_cultural_knowledge(db_row)
1151
1151
 
1152
1152
  except Exception as e:
1153
- log_warning(f"Exception reading cultural knowledge: {e}")
1154
- raise e
1153
+ log_error(f"Exception reading cultural knowledge: {e}")
1154
+ return None
1155
1155
 
1156
1156
  async def get_all_cultural_knowledge(
1157
1157
  self,
@@ -1185,7 +1185,7 @@ class AsyncPostgresDb(AsyncBaseDb):
1185
1185
  Exception: If an error occurs during retrieval.
1186
1186
  """
1187
1187
  try:
1188
- table = await self._get_table(table_type="culture")
1188
+ table = await self._get_table(table_type="culture", create_table_if_not_found=True)
1189
1189
 
1190
1190
  async with self.async_session_factory() as sess:
1191
1191
  # Build query with filters
@@ -1223,8 +1223,8 @@ class AsyncPostgresDb(AsyncBaseDb):
1223
1223
  return [deserialize_cultural_knowledge(row) for row in db_rows]
1224
1224
 
1225
1225
  except Exception as e:
1226
- log_warning(f"Exception reading all cultural knowledge: {e}")
1227
- raise e
1226
+ log_error(f"Exception reading all cultural knowledge: {e}")
1227
+ return [] if deserialize else ([], 0)
1228
1228
 
1229
1229
  async def upsert_cultural_knowledge(
1230
1230
  self, cultural_knowledge: CulturalKnowledge, deserialize: Optional[bool] = True
@@ -2121,84 +2121,110 @@ class AsyncPostgresDb(AsyncBaseDb):
2121
2121
  # Fallback if spans table doesn't exist
2122
2122
  return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2123
2123
 
2124
- async def create_trace(self, trace: "Trace") -> None:
2125
- """Create a single trace record in the database.
2124
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2125
+ """Build a SQL CASE expression that returns the component level for a trace.
2126
+
2127
+ Component levels (higher = more important):
2128
+ - 3: Workflow root (.run or .arun with workflow_id)
2129
+ - 2: Team root (.run or .arun with team_id)
2130
+ - 1: Agent root (.run or .arun with agent_id)
2131
+ - 0: Child span (not a root)
2126
2132
 
2127
2133
  Args:
2128
- trace: The Trace object to store (one per trace_id).
2134
+ workflow_id_col: SQL column/expression for workflow_id
2135
+ team_id_col: SQL column/expression for team_id
2136
+ agent_id_col: SQL column/expression for agent_id
2137
+ name_col: SQL column/expression for name
2138
+
2139
+ Returns:
2140
+ SQLAlchemy CASE expression returning the component level as an integer.
2129
2141
  """
2130
- try:
2131
- table = await self._get_table(table_type="traces", create_table_if_not_found=True)
2142
+ is_root_name = or_(name_col.contains(".run"), name_col.contains(".arun"))
2143
+
2144
+ return case(
2145
+ # Workflow root (level 3)
2146
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2147
+ # Team root (level 2)
2148
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2149
+ # Agent root (level 1)
2150
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2151
+ # Child span or unknown (level 0)
2152
+ else_=0,
2153
+ )
2132
2154
 
2133
- async with self.async_session_factory() as sess, sess.begin():
2134
- # Check if trace exists
2135
- result = await sess.execute(select(table).where(table.c.trace_id == trace.trace_id))
2136
- existing = result.fetchone()
2137
-
2138
- if existing:
2139
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2140
-
2141
- def get_component_level(workflow_id, team_id, agent_id, name):
2142
- # Check if name indicates a root span
2143
- is_root_name = ".run" in name or ".arun" in name
2144
-
2145
- if not is_root_name:
2146
- return 0 # Child span (not a root)
2147
- elif workflow_id:
2148
- return 3 # Workflow root
2149
- elif team_id:
2150
- return 2 # Team root
2151
- elif agent_id:
2152
- return 1 # Agent root
2153
- else:
2154
- return 0 # Unknown
2155
-
2156
- existing_level = get_component_level(
2157
- existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2158
- )
2159
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2155
+ async def upsert_trace(self, trace: "Trace") -> None:
2156
+ """Create or update a single trace record in the database.
2160
2157
 
2161
- # Only update name if new trace is from a higher or equal level
2162
- should_update_name = new_level > existing_level
2158
+ Uses INSERT ... ON CONFLICT DO UPDATE (upsert) to handle concurrent inserts
2159
+ atomically and avoid race conditions.
2163
2160
 
2164
- # Parse existing start_time to calculate correct duration
2165
- existing_start_time_str = existing.start_time
2166
- if isinstance(existing_start_time_str, str):
2167
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2168
- else:
2169
- existing_start_time = trace.start_time
2161
+ Args:
2162
+ trace: The Trace object to store (one per trace_id).
2163
+ """
2164
+ try:
2165
+ table = await self._get_table(table_type="traces", create_table_if_not_found=True)
2170
2166
 
2171
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2167
+ trace_dict = trace.to_dict()
2168
+ trace_dict.pop("total_spans", None)
2169
+ trace_dict.pop("error_count", None)
2172
2170
 
2173
- update_values = {
2174
- "end_time": trace.end_time.isoformat(),
2175
- "duration_ms": recalculated_duration_ms,
2176
- "status": trace.status,
2177
- "name": trace.name if should_update_name else existing.name,
2178
- }
2171
+ async with self.async_session_factory() as sess, sess.begin():
2172
+ # Use upsert to handle concurrent inserts atomically
2173
+ # On conflict, update fields while preserving existing non-null context values
2174
+ # and keeping the earliest start_time
2175
+ insert_stmt = postgresql.insert(table).values(trace_dict)
2176
+
2177
+ # Build component level expressions for comparing trace priority
2178
+ new_level = self._get_trace_component_level_expr(
2179
+ insert_stmt.excluded.workflow_id,
2180
+ insert_stmt.excluded.team_id,
2181
+ insert_stmt.excluded.agent_id,
2182
+ insert_stmt.excluded.name,
2183
+ )
2184
+ existing_level = self._get_trace_component_level_expr(
2185
+ table.c.workflow_id,
2186
+ table.c.team_id,
2187
+ table.c.agent_id,
2188
+ table.c.name,
2189
+ )
2179
2190
 
2180
- # Update context fields ONLY if new value is not None (preserve non-null values)
2181
- if trace.run_id is not None:
2182
- update_values["run_id"] = trace.run_id
2183
- if trace.session_id is not None:
2184
- update_values["session_id"] = trace.session_id
2185
- if trace.user_id is not None:
2186
- update_values["user_id"] = trace.user_id
2187
- if trace.agent_id is not None:
2188
- update_values["agent_id"] = trace.agent_id
2189
- if trace.team_id is not None:
2190
- update_values["team_id"] = trace.team_id
2191
- if trace.workflow_id is not None:
2192
- update_values["workflow_id"] = trace.workflow_id
2193
-
2194
- stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2195
- await sess.execute(stmt)
2196
- else:
2197
- trace_dict = trace.to_dict()
2198
- trace_dict.pop("total_spans", None)
2199
- trace_dict.pop("error_count", None)
2200
- stmt = postgresql.insert(table).values(trace_dict)
2201
- await sess.execute(stmt)
2191
+ # Build the ON CONFLICT DO UPDATE clause
2192
+ # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
2193
+ # Use COALESCE to preserve existing non-null context values
2194
+ upsert_stmt = insert_stmt.on_conflict_do_update(
2195
+ index_elements=["trace_id"],
2196
+ set_={
2197
+ "end_time": func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
2198
+ "start_time": func.least(table.c.start_time, insert_stmt.excluded.start_time),
2199
+ "duration_ms": func.extract(
2200
+ "epoch",
2201
+ func.cast(
2202
+ func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
2203
+ TIMESTAMP(timezone=True),
2204
+ )
2205
+ - func.cast(
2206
+ func.least(table.c.start_time, insert_stmt.excluded.start_time),
2207
+ TIMESTAMP(timezone=True),
2208
+ ),
2209
+ )
2210
+ * 1000,
2211
+ "status": insert_stmt.excluded.status,
2212
+ # Update name only if new trace is from a higher-level component
2213
+ # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
2214
+ "name": case(
2215
+ (new_level > existing_level, insert_stmt.excluded.name),
2216
+ else_=table.c.name,
2217
+ ),
2218
+ # Preserve existing non-null context values using COALESCE
2219
+ "run_id": func.coalesce(insert_stmt.excluded.run_id, table.c.run_id),
2220
+ "session_id": func.coalesce(insert_stmt.excluded.session_id, table.c.session_id),
2221
+ "user_id": func.coalesce(insert_stmt.excluded.user_id, table.c.user_id),
2222
+ "agent_id": func.coalesce(insert_stmt.excluded.agent_id, table.c.agent_id),
2223
+ "team_id": func.coalesce(insert_stmt.excluded.team_id, table.c.team_id),
2224
+ "workflow_id": func.coalesce(insert_stmt.excluded.workflow_id, table.c.workflow_id),
2225
+ },
2226
+ )
2227
+ await sess.execute(upsert_stmt)
2202
2228
 
2203
2229
  except Exception as e:
2204
2230
  log_error(f"Error creating trace: {e}")
agno/db/postgres/postgres.py CHANGED
@@ -30,8 +30,9 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
30
30
  from agno.utils.string import generate_id
31
31
 
32
32
  try:
33
- from sqlalchemy import ForeignKey, Index, String, UniqueConstraint, func, select, update
33
+ from sqlalchemy import ForeignKey, Index, String, UniqueConstraint, and_, case, func, or_, select, update
34
34
  from sqlalchemy.dialects import postgresql
35
+ from sqlalchemy.dialects.postgresql import TIMESTAMP
35
36
  from sqlalchemy.engine import Engine, create_engine
36
37
  from sqlalchemy.exc import ProgrammingError
37
38
  from sqlalchemy.orm import scoped_session, sessionmaker
@@ -2400,8 +2401,42 @@ class PostgresDb(BaseDb):
2400
2401
  # Fallback if spans table doesn't exist
2401
2402
  return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2402
2403
 
2403
- def create_trace(self, trace: "Trace") -> None:
2404
- """Create a single trace record in the database.
2404
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2405
+ """Build a SQL CASE expression that returns the component level for a trace.
2406
+
2407
+ Component levels (higher = more important):
2408
+ - 3: Workflow root (.run or .arun with workflow_id)
2409
+ - 2: Team root (.run or .arun with team_id)
2410
+ - 1: Agent root (.run or .arun with agent_id)
2411
+ - 0: Child span (not a root)
2412
+
2413
+ Args:
2414
+ workflow_id_col: SQL column/expression for workflow_id
2415
+ team_id_col: SQL column/expression for team_id
2416
+ agent_id_col: SQL column/expression for agent_id
2417
+ name_col: SQL column/expression for name
2418
+
2419
+ Returns:
2420
+ SQLAlchemy CASE expression returning the component level as an integer.
2421
+ """
2422
+ is_root_name = or_(name_col.contains(".run"), name_col.contains(".arun"))
2423
+
2424
+ return case(
2425
+ # Workflow root (level 3)
2426
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2427
+ # Team root (level 2)
2428
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2429
+ # Agent root (level 1)
2430
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2431
+ # Child span or unknown (level 0)
2432
+ else_=0,
2433
+ )
2434
+
2435
+ def upsert_trace(self, trace: "Trace") -> None:
2436
+ """Create or update a single trace record in the database.
2437
+
2438
+ Uses INSERT ... ON CONFLICT DO UPDATE (upsert) to handle concurrent inserts
2439
+ atomically and avoid race conditions.
2405
2440
 
2406
2441
  Args:
2407
2442
  trace: The Trace object to store (one per trace_id).
@@ -2411,74 +2446,67 @@ class PostgresDb(BaseDb):
2411
2446
  if table is None:
2412
2447
  return
2413
2448
 
2414
- with self.Session() as sess, sess.begin():
2415
- # Check if trace exists
2416
- existing = sess.execute(select(table).where(table.c.trace_id == trace.trace_id)).fetchone()
2417
-
2418
- if existing:
2419
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2420
-
2421
- def get_component_level(workflow_id, team_id, agent_id, name):
2422
- # Check if name indicates a root span
2423
- is_root_name = ".run" in name or ".arun" in name
2424
-
2425
- if not is_root_name:
2426
- return 0 # Child span (not a root)
2427
- elif workflow_id:
2428
- return 3 # Workflow root
2429
- elif team_id:
2430
- return 2 # Team root
2431
- elif agent_id:
2432
- return 1 # Agent root
2433
- else:
2434
- return 0 # Unknown
2449
+ trace_dict = trace.to_dict()
2450
+ trace_dict.pop("total_spans", None)
2451
+ trace_dict.pop("error_count", None)
2435
2452
 
2436
- existing_level = get_component_level(
2437
- existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2438
- )
2439
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2440
-
2441
- # Only update name if new trace is from a higher or equal level
2442
- should_update_name = new_level > existing_level
2443
-
2444
- # Parse existing start_time to calculate correct duration
2445
- existing_start_time_str = existing.start_time
2446
- if isinstance(existing_start_time_str, str):
2447
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2448
- else:
2449
- existing_start_time = trace.start_time
2450
-
2451
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2452
-
2453
- update_values = {
2454
- "end_time": trace.end_time.isoformat(),
2455
- "duration_ms": recalculated_duration_ms,
2456
- "status": trace.status,
2457
- "name": trace.name if should_update_name else existing.name,
2458
- }
2453
+ with self.Session() as sess, sess.begin():
2454
+ # Use upsert to handle concurrent inserts atomically
2455
+ # On conflict, update fields while preserving existing non-null context values
2456
+ # and keeping the earliest start_time
2457
+ insert_stmt = postgresql.insert(table).values(trace_dict)
2458
+
2459
+ # Build component level expressions for comparing trace priority
2460
+ new_level = self._get_trace_component_level_expr(
2461
+ insert_stmt.excluded.workflow_id,
2462
+ insert_stmt.excluded.team_id,
2463
+ insert_stmt.excluded.agent_id,
2464
+ insert_stmt.excluded.name,
2465
+ )
2466
+ existing_level = self._get_trace_component_level_expr(
2467
+ table.c.workflow_id,
2468
+ table.c.team_id,
2469
+ table.c.agent_id,
2470
+ table.c.name,
2471
+ )
2459
2472
 
2460
- # Update context fields ONLY if new value is not None (preserve non-null values)
2461
- if trace.run_id is not None:
2462
- update_values["run_id"] = trace.run_id
2463
- if trace.session_id is not None:
2464
- update_values["session_id"] = trace.session_id
2465
- if trace.user_id is not None:
2466
- update_values["user_id"] = trace.user_id
2467
- if trace.agent_id is not None:
2468
- update_values["agent_id"] = trace.agent_id
2469
- if trace.team_id is not None:
2470
- update_values["team_id"] = trace.team_id
2471
- if trace.workflow_id is not None:
2472
- update_values["workflow_id"] = trace.workflow_id
2473
-
2474
- stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2475
- sess.execute(stmt)
2476
- else:
2477
- trace_dict = trace.to_dict()
2478
- trace_dict.pop("total_spans", None)
2479
- trace_dict.pop("error_count", None)
2480
- stmt = postgresql.insert(table).values(trace_dict)
2481
- sess.execute(stmt)
2473
+ # Build the ON CONFLICT DO UPDATE clause
2474
+ # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
2475
+ # Use COALESCE to preserve existing non-null context values
2476
+ upsert_stmt = insert_stmt.on_conflict_do_update(
2477
+ index_elements=["trace_id"],
2478
+ set_={
2479
+ "end_time": func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
2480
+ "start_time": func.least(table.c.start_time, insert_stmt.excluded.start_time),
2481
+ "duration_ms": func.extract(
2482
+ "epoch",
2483
+ func.cast(
2484
+ func.greatest(table.c.end_time, insert_stmt.excluded.end_time),
2485
+ TIMESTAMP(timezone=True),
2486
+ )
2487
+ - func.cast(
2488
+ func.least(table.c.start_time, insert_stmt.excluded.start_time),
2489
+ TIMESTAMP(timezone=True),
2490
+ ),
2491
+ )
2492
+ * 1000,
2493
+ "status": insert_stmt.excluded.status,
2494
+ # Update name only if new trace is from a higher-level component
2495
+ # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
2496
+ "name": case(
2497
+ (new_level > existing_level, insert_stmt.excluded.name),
2498
+ else_=table.c.name,
2499
+ ),
2500
+ # Preserve existing non-null context values using COALESCE
2501
+ "run_id": func.coalesce(insert_stmt.excluded.run_id, table.c.run_id),
2502
+ "session_id": func.coalesce(insert_stmt.excluded.session_id, table.c.session_id),
2503
+ "user_id": func.coalesce(insert_stmt.excluded.user_id, table.c.user_id),
2504
+ "agent_id": func.coalesce(insert_stmt.excluded.agent_id, table.c.agent_id),
2505
+ "team_id": func.coalesce(insert_stmt.excluded.team_id, table.c.team_id),
2506
+ "workflow_id": func.coalesce(insert_stmt.excluded.workflow_id, table.c.workflow_id),
2507
+ },
2508
+ )
2509
+ sess.execute(upsert_stmt)
2482
2510
 
2483
2511
  except Exception as e:
2484
2512
  log_error(f"Error creating trace: {e}")
agno/db/redis/redis.py CHANGED
@@ -1693,8 +1693,8 @@ class RedisDb(BaseDb):
1693
1693
  raise e
1694
1694
 
1695
1695
  # --- Traces ---
1696
- def create_trace(self, trace: "Trace") -> None:
1697
- """Create a single trace record in the database.
1696
+ def upsert_trace(self, trace: "Trace") -> None:
1697
+ """Create or update a single trace record in the database.
1698
1698
 
1699
1699
  Args:
1700
1700
  trace: The Trace object to store (one per trace_id).
agno/db/singlestore/singlestore.py CHANGED
@@ -2395,84 +2395,109 @@ class SingleStoreDb(BaseDb):
2395
2395
  # Fallback if spans table doesn't exist
2396
2396
  return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2397
2397
 
2398
- def create_trace(self, trace: "Trace") -> None:
2399
- """Create a single trace record in the database.
2398
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2399
+ """Build a SQL CASE expression that returns the component level for a trace.
2400
+
2401
+ Component levels (higher = more important):
2402
+ - 3: Workflow root (.run or .arun with workflow_id)
2403
+ - 2: Team root (.run or .arun with team_id)
2404
+ - 1: Agent root (.run or .arun with agent_id)
2405
+ - 0: Child span (not a root)
2406
+
2407
+ Args:
2408
+ workflow_id_col: SQL column/expression for workflow_id
2409
+ team_id_col: SQL column/expression for team_id
2410
+ agent_id_col: SQL column/expression for agent_id
2411
+ name_col: SQL column/expression for name
2412
+
2413
+ Returns:
2414
+ SQLAlchemy CASE expression returning the component level as an integer.
2415
+ """
2416
+ from sqlalchemy import case, or_
2417
+
2418
+ is_root_name = or_(name_col.like("%.run%"), name_col.like("%.arun%"))
2419
+
2420
+ return case(
2421
+ # Workflow root (level 3)
2422
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2423
+ # Team root (level 2)
2424
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2425
+ # Agent root (level 1)
2426
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2427
+ # Child span or unknown (level 0)
2428
+ else_=0,
2429
+ )
2430
+
2431
+ def upsert_trace(self, trace: "Trace") -> None:
2432
+ """Create or update a single trace record in the database.
2433
+
2434
+ Uses INSERT ... ON DUPLICATE KEY UPDATE (upsert) to handle concurrent inserts
2435
+ atomically and avoid race conditions.
2400
2436
 
2401
2437
  Args:
2402
2438
  trace: The Trace object to store (one per trace_id).
2403
2439
  """
2440
+ from sqlalchemy import case
2441
+
2404
2442
  try:
2405
2443
  table = self._get_table(table_type="traces", create_table_if_not_found=True)
2406
2444
  if table is None:
2407
2445
  return
2408
2446
 
2447
+ trace_dict = trace.to_dict()
2448
+ trace_dict.pop("total_spans", None)
2449
+ trace_dict.pop("error_count", None)
2450
+
2409
2451
  with self.Session() as sess, sess.begin():
2410
- existing = sess.execute(select(table).where(table.c.trace_id == trace.trace_id)).fetchone()
2411
-
2412
- if existing:
2413
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2414
-
2415
- def get_component_level(workflow_id, team_id, agent_id, name):
2416
- # Check if name indicates a root span
2417
- is_root_name = ".run" in name or ".arun" in name
2418
-
2419
- if not is_root_name:
2420
- return 0 # Child span (not a root)
2421
- elif workflow_id:
2422
- return 3 # Workflow root
2423
- elif team_id:
2424
- return 2 # Team root
2425
- elif agent_id:
2426
- return 1 # Agent root
2427
- else:
2428
- return 0 # Unknown
2452
+ # Use upsert to handle concurrent inserts atomically
2453
+ # On conflict, update fields while preserving existing non-null context values
2454
+ # and keeping the earliest start_time
2455
+ insert_stmt = mysql.insert(table).values(trace_dict)
2456
+
2457
+ # Build component level expressions for comparing trace priority
2458
+ new_level = self._get_trace_component_level_expr(
2459
+ insert_stmt.inserted.workflow_id,
2460
+ insert_stmt.inserted.team_id,
2461
+ insert_stmt.inserted.agent_id,
2462
+ insert_stmt.inserted.name,
2463
+ )
2464
+ existing_level = self._get_trace_component_level_expr(
2465
+ table.c.workflow_id,
2466
+ table.c.team_id,
2467
+ table.c.agent_id,
2468
+ table.c.name,
2469
+ )
2429
2470
 
2430
- existing_level = get_component_level(
2431
- existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2471
+ # Build the ON DUPLICATE KEY UPDATE clause
2472
+ # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
2473
+ # Duration is calculated using TIMESTAMPDIFF in microseconds then converted to ms
2474
+ upsert_stmt = insert_stmt.on_duplicate_key_update(
2475
+ end_time=func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2476
+ start_time=func.least(table.c.start_time, insert_stmt.inserted.start_time),
2477
+ # Calculate duration in milliseconds using TIMESTAMPDIFF
2478
+ # TIMESTAMPDIFF(MICROSECOND, start, end) / 1000 gives milliseconds
2479
+ duration_ms=func.timestampdiff(
2480
+ text("MICROSECOND"),
2481
+ func.least(table.c.start_time, insert_stmt.inserted.start_time),
2482
+ func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2432
2483
  )
2433
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2434
-
2435
- # Only update name if new trace is from a higher or equal level
2436
- should_update_name = new_level > existing_level
2437
-
2438
- # Parse existing start_time to calculate correct duration
2439
- existing_start_time_str = existing.start_time
2440
- if isinstance(existing_start_time_str, str):
2441
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2442
- else:
2443
- existing_start_time = trace.start_time
2444
-
2445
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2446
-
2447
- update_values = {
2448
- "end_time": trace.end_time.isoformat(),
2449
- "duration_ms": recalculated_duration_ms,
2450
- "status": trace.status,
2451
- "name": trace.name if should_update_name else existing.name,
2452
- }
2453
-
2454
- # Update context fields ONLY if new value is not None (preserve non-null values)
2455
- if trace.run_id is not None:
2456
- update_values["run_id"] = trace.run_id
2457
- if trace.session_id is not None:
2458
- update_values["session_id"] = trace.session_id
2459
- if trace.user_id is not None:
2460
- update_values["user_id"] = trace.user_id
2461
- if trace.agent_id is not None:
2462
- update_values["agent_id"] = trace.agent_id
2463
- if trace.team_id is not None:
2464
- update_values["team_id"] = trace.team_id
2465
- if trace.workflow_id is not None:
2466
- update_values["workflow_id"] = trace.workflow_id
2467
-
2468
- update_stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2469
- sess.execute(update_stmt)
2470
- else:
2471
- trace_dict = trace.to_dict()
2472
- trace_dict.pop("total_spans", None)
2473
- trace_dict.pop("error_count", None)
2474
- insert_stmt = mysql.insert(table).values(trace_dict)
2475
- sess.execute(insert_stmt)
2484
+ / 1000,
2485
+ status=insert_stmt.inserted.status,
2486
+ # Update name only if new trace is from a higher-level component
2487
+ # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
2488
+ name=case(
2489
+ (new_level > existing_level, insert_stmt.inserted.name),
2490
+ else_=table.c.name,
2491
+ ),
2492
+ # Preserve existing non-null context values using COALESCE
2493
+ run_id=func.coalesce(insert_stmt.inserted.run_id, table.c.run_id),
2494
+ session_id=func.coalesce(insert_stmt.inserted.session_id, table.c.session_id),
2495
+ user_id=func.coalesce(insert_stmt.inserted.user_id, table.c.user_id),
2496
+ agent_id=func.coalesce(insert_stmt.inserted.agent_id, table.c.agent_id),
2497
+ team_id=func.coalesce(insert_stmt.inserted.team_id, table.c.team_id),
2498
+ workflow_id=func.coalesce(insert_stmt.inserted.workflow_id, table.c.workflow_id),
2499
+ )
2500
+ sess.execute(upsert_stmt)
2476
2501
 
2477
2502
  except Exception as e:
2478
2503
  log_error(f"Error creating trace: {e}")