agno 2.3.10__py3-none-any.whl → 2.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. agno/compression/manager.py +87 -16
  2. agno/db/base.py +5 -5
  3. agno/db/dynamo/dynamo.py +2 -2
  4. agno/db/firestore/firestore.py +2 -2
  5. agno/db/gcs_json/gcs_json_db.py +2 -2
  6. agno/db/in_memory/in_memory_db.py +2 -2
  7. agno/db/json/json_db.py +2 -2
  8. agno/db/mongo/async_mongo.py +170 -68
  9. agno/db/mongo/mongo.py +170 -76
  10. agno/db/mysql/async_mysql.py +93 -69
  11. agno/db/mysql/mysql.py +93 -68
  12. agno/db/postgres/async_postgres.py +104 -78
  13. agno/db/postgres/postgres.py +97 -69
  14. agno/db/redis/redis.py +2 -2
  15. agno/db/singlestore/singlestore.py +91 -66
  16. agno/db/sqlite/async_sqlite.py +101 -78
  17. agno/db/sqlite/sqlite.py +97 -69
  18. agno/db/surrealdb/surrealdb.py +2 -2
  19. agno/exceptions.py +1 -0
  20. agno/knowledge/chunking/fixed.py +4 -1
  21. agno/knowledge/knowledge.py +105 -24
  22. agno/knowledge/reader/csv_reader.py +2 -2
  23. agno/knowledge/reader/text_reader.py +15 -3
  24. agno/knowledge/reader/wikipedia_reader.py +33 -1
  25. agno/knowledge/utils.py +52 -7
  26. agno/memory/strategies/base.py +3 -4
  27. agno/models/anthropic/claude.py +44 -0
  28. agno/models/aws/bedrock.py +60 -0
  29. agno/models/base.py +124 -30
  30. agno/models/google/gemini.py +141 -23
  31. agno/models/litellm/chat.py +25 -0
  32. agno/models/openai/chat.py +21 -0
  33. agno/models/openai/responses.py +44 -0
  34. agno/os/routers/knowledge/knowledge.py +20 -9
  35. agno/run/agent.py +17 -0
  36. agno/run/requirement.py +89 -6
  37. agno/tracing/exporter.py +2 -2
  38. agno/utils/print_response/agent.py +4 -4
  39. agno/utils/print_response/team.py +12 -12
  40. agno/utils/tokens.py +643 -27
  41. agno/vectordb/base.py +15 -2
  42. agno/vectordb/chroma/chromadb.py +6 -2
  43. agno/vectordb/lancedb/lance_db.py +3 -37
  44. agno/vectordb/milvus/milvus.py +6 -32
  45. agno/vectordb/mongodb/mongodb.py +0 -27
  46. agno/vectordb/pgvector/pgvector.py +21 -11
  47. agno/vectordb/pineconedb/pineconedb.py +0 -17
  48. agno/vectordb/qdrant/qdrant.py +6 -29
  49. agno/vectordb/redis/redisdb.py +0 -26
  50. agno/vectordb/singlestore/singlestore.py +16 -8
  51. agno/vectordb/surrealdb/surrealdb.py +0 -36
  52. agno/vectordb/weaviate/weaviate.py +6 -2
  53. {agno-2.3.10.dist-info → agno-2.3.12.dist-info}/METADATA +4 -1
  54. {agno-2.3.10.dist-info → agno-2.3.12.dist-info}/RECORD +57 -57
  55. {agno-2.3.10.dist-info → agno-2.3.12.dist-info}/WHEEL +0 -0
  56. {agno-2.3.10.dist-info → agno-2.3.12.dist-info}/licenses/LICENSE +0 -0
  57. {agno-2.3.10.dist-info → agno-2.3.12.dist-info}/top_level.txt +0 -0
agno/db/mongo/mongo.py CHANGED
@@ -2028,8 +2028,45 @@ class MongoDb(BaseDb):
2028
2028
  log_info(f"Migrated {len(memories)} memories to collection: {self.memory_table_name}")
2029
2029
 
2030
2030
  # --- Traces ---
2031
- def create_trace(self, trace: "Trace") -> None:
2032
- """Create a single trace record in the database.
2031
+ def _get_component_level(
2032
+ self, workflow_id: Optional[str], team_id: Optional[str], agent_id: Optional[str], name: str
2033
+ ) -> int:
2034
+ """Get the component level for a trace based on its context.
2035
+
2036
+ Component levels (higher = more important):
2037
+ - 3: Workflow root (.run or .arun with workflow_id)
2038
+ - 2: Team root (.run or .arun with team_id)
2039
+ - 1: Agent root (.run or .arun with agent_id)
2040
+ - 0: Child span (not a root)
2041
+
2042
+ Args:
2043
+ workflow_id: The workflow ID of the trace.
2044
+ team_id: The team ID of the trace.
2045
+ agent_id: The agent ID of the trace.
2046
+ name: The name of the trace.
2047
+
2048
+ Returns:
2049
+ int: The component level (0-3).
2050
+ """
2051
+ # Check if name indicates a root span
2052
+ is_root_name = ".run" in name or ".arun" in name
2053
+
2054
+ if not is_root_name:
2055
+ return 0 # Child span (not a root)
2056
+ elif workflow_id:
2057
+ return 3 # Workflow root
2058
+ elif team_id:
2059
+ return 2 # Team root
2060
+ elif agent_id:
2061
+ return 1 # Agent root
2062
+ else:
2063
+ return 0 # Unknown
2064
+
2065
+ def upsert_trace(self, trace: "Trace") -> None:
2066
+ """Create or update a single trace record in the database.
2067
+
2068
+ Uses MongoDB's update_one with upsert=True and aggregation pipeline
2069
+ to handle concurrent inserts atomically and avoid race conditions.
2033
2070
 
2034
2071
  Args:
2035
2072
  trace: The Trace object to store (one per trace_id).
@@ -2039,83 +2076,140 @@ class MongoDb(BaseDb):
2039
2076
  if collection is None:
2040
2077
  return
2041
2078
 
2042
- # Check if trace already exists
2043
- existing = collection.find_one({"trace_id": trace.trace_id})
2044
-
2045
- if existing:
2046
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2047
- def get_component_level(
2048
- workflow_id: Optional[str], team_id: Optional[str], agent_id: Optional[str], name: str
2049
- ) -> int:
2050
- # Check if name indicates a root span
2051
- is_root_name = ".run" in name or ".arun" in name
2052
-
2053
- if not is_root_name:
2054
- return 0 # Child span (not a root)
2055
- elif workflow_id:
2056
- return 3 # Workflow root
2057
- elif team_id:
2058
- return 2 # Team root
2059
- elif agent_id:
2060
- return 1 # Agent root
2061
- else:
2062
- return 0 # Unknown
2063
-
2064
- existing_level = get_component_level(
2065
- existing.get("workflow_id"),
2066
- existing.get("team_id"),
2067
- existing.get("agent_id"),
2068
- existing.get("name", ""),
2069
- )
2070
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2071
-
2072
- # Only update name if new trace is from a higher or equal level
2073
- should_update_name = new_level > existing_level
2074
-
2075
- # Parse existing start_time to calculate correct duration
2076
- existing_start_time_str = existing.get("start_time")
2077
- if isinstance(existing_start_time_str, str):
2078
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2079
- else:
2080
- existing_start_time = trace.start_time
2081
-
2082
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2079
+ trace_dict = trace.to_dict()
2080
+ trace_dict.pop("total_spans", None)
2081
+ trace_dict.pop("error_count", None)
2083
2082
 
2084
- update_values: Dict[str, Any] = {
2085
- "end_time": trace.end_time.isoformat(),
2086
- "duration_ms": recalculated_duration_ms,
2087
- "status": trace.status,
2088
- "name": trace.name if should_update_name else existing.get("name"),
2089
- }
2083
+ # Calculate the component level for the new trace
2084
+ new_level = self._get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2090
2085
 
2091
- # Update context fields ONLY if new value is not None (preserve non-null values)
2092
- if trace.run_id is not None:
2093
- update_values["run_id"] = trace.run_id
2094
- if trace.session_id is not None:
2095
- update_values["session_id"] = trace.session_id
2096
- if trace.user_id is not None:
2097
- update_values["user_id"] = trace.user_id
2098
- if trace.agent_id is not None:
2099
- update_values["agent_id"] = trace.agent_id
2100
- if trace.team_id is not None:
2101
- update_values["team_id"] = trace.team_id
2102
- if trace.workflow_id is not None:
2103
- update_values["workflow_id"] = trace.workflow_id
2104
-
2105
- log_debug(
2106
- f" Updating trace with context: run_id={update_values.get('run_id', 'unchanged')}, "
2107
- f"session_id={update_values.get('session_id', 'unchanged')}, "
2108
- f"user_id={update_values.get('user_id', 'unchanged')}, "
2109
- f"agent_id={update_values.get('agent_id', 'unchanged')}, "
2110
- f"team_id={update_values.get('team_id', 'unchanged')}, "
2111
- )
2086
+ # Use MongoDB aggregation pipeline update for atomic upsert
2087
+ # This allows conditional logic within a single atomic operation
2088
+ pipeline: List[Dict[str, Any]] = [
2089
+ {
2090
+ "$set": {
2091
+ # Always update these fields
2092
+ "status": trace.status,
2093
+ "created_at": {"$ifNull": ["$created_at", trace_dict.get("created_at")]},
2094
+ # Use $min for start_time (keep earliest)
2095
+ "start_time": {
2096
+ "$cond": {
2097
+ "if": {"$eq": [{"$type": "$start_time"}, "missing"]},
2098
+ "then": trace_dict.get("start_time"),
2099
+ "else": {"$min": ["$start_time", trace_dict.get("start_time")]},
2100
+ }
2101
+ },
2102
+ # Use $max for end_time (keep latest)
2103
+ "end_time": {
2104
+ "$cond": {
2105
+ "if": {"$eq": [{"$type": "$end_time"}, "missing"]},
2106
+ "then": trace_dict.get("end_time"),
2107
+ "else": {"$max": ["$end_time", trace_dict.get("end_time")]},
2108
+ }
2109
+ },
2110
+ # Preserve existing non-null context values using $ifNull
2111
+ "run_id": {"$ifNull": [trace.run_id, "$run_id"]},
2112
+ "session_id": {"$ifNull": [trace.session_id, "$session_id"]},
2113
+ "user_id": {"$ifNull": [trace.user_id, "$user_id"]},
2114
+ "agent_id": {"$ifNull": [trace.agent_id, "$agent_id"]},
2115
+ "team_id": {"$ifNull": [trace.team_id, "$team_id"]},
2116
+ "workflow_id": {"$ifNull": [trace.workflow_id, "$workflow_id"]},
2117
+ }
2118
+ },
2119
+ {
2120
+ "$set": {
2121
+ # Calculate duration_ms from the (potentially updated) start_time and end_time
2122
+ # MongoDB stores dates as strings in ISO format, so we need to parse them
2123
+ "duration_ms": {
2124
+ "$cond": {
2125
+ "if": {
2126
+ "$and": [
2127
+ {"$ne": [{"$type": "$start_time"}, "missing"]},
2128
+ {"$ne": [{"$type": "$end_time"}, "missing"]},
2129
+ ]
2130
+ },
2131
+ "then": {
2132
+ "$subtract": [
2133
+ {"$toLong": {"$toDate": "$end_time"}},
2134
+ {"$toLong": {"$toDate": "$start_time"}},
2135
+ ]
2136
+ },
2137
+ "else": trace_dict.get("duration_ms", 0),
2138
+ }
2139
+ },
2140
+ # Update name based on component level priority
2141
+ # Only update if new trace is from a higher-level component
2142
+ "name": {
2143
+ "$cond": {
2144
+ "if": {"$eq": [{"$type": "$name"}, "missing"]},
2145
+ "then": trace.name,
2146
+ "else": {
2147
+ "$cond": {
2148
+ "if": {
2149
+ "$gt": [
2150
+ new_level,
2151
+ {
2152
+ "$switch": {
2153
+ "branches": [
2154
+ # Check if existing name is a root span
2155
+ {
2156
+ "case": {
2157
+ "$not": {
2158
+ "$or": [
2159
+ {
2160
+ "$regexMatch": {
2161
+ "input": {"$ifNull": ["$name", ""]},
2162
+ "regex": "\\.run",
2163
+ }
2164
+ },
2165
+ {
2166
+ "$regexMatch": {
2167
+ "input": {"$ifNull": ["$name", ""]},
2168
+ "regex": "\\.arun",
2169
+ }
2170
+ },
2171
+ ]
2172
+ }
2173
+ },
2174
+ "then": 0,
2175
+ },
2176
+ # Workflow root (level 3)
2177
+ {
2178
+ "case": {"$ne": ["$workflow_id", None]},
2179
+ "then": 3,
2180
+ },
2181
+ # Team root (level 2)
2182
+ {
2183
+ "case": {"$ne": ["$team_id", None]},
2184
+ "then": 2,
2185
+ },
2186
+ # Agent root (level 1)
2187
+ {
2188
+ "case": {"$ne": ["$agent_id", None]},
2189
+ "then": 1,
2190
+ },
2191
+ ],
2192
+ "default": 0,
2193
+ }
2194
+ },
2195
+ ]
2196
+ },
2197
+ "then": trace.name,
2198
+ "else": "$name",
2199
+ }
2200
+ },
2201
+ }
2202
+ },
2203
+ }
2204
+ },
2205
+ ]
2112
2206
 
2113
- collection.update_one({"trace_id": trace.trace_id}, {"$set": update_values})
2114
- else:
2115
- trace_dict = trace.to_dict()
2116
- trace_dict.pop("total_spans", None)
2117
- trace_dict.pop("error_count", None)
2118
- collection.insert_one(trace_dict)
2207
+ # Perform atomic upsert using aggregation pipeline
2208
+ collection.update_one(
2209
+ {"trace_id": trace.trace_id},
2210
+ pipeline,
2211
+ upsert=True,
2212
+ )
2119
2213
 
2120
2214
  except Exception as e:
2121
2215
  log_error(f"Error creating trace: {e}")
agno/db/mysql/async_mysql.py CHANGED
@@ -2434,86 +2434,110 @@ class AsyncMySQLDb(AsyncBaseDb):
2434
2434
  # Fallback if spans table doesn't exist
2435
2435
  return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2436
2436
 
2437
- async def create_trace(self, trace: "Trace") -> None:
2438
- """Create a single trace record in the database.
2437
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2438
+ """Build a SQL CASE expression that returns the component level for a trace.
2439
+
2440
+ Component levels (higher = more important):
2441
+ - 3: Workflow root (.run or .arun with workflow_id)
2442
+ - 2: Team root (.run or .arun with team_id)
2443
+ - 1: Agent root (.run or .arun with agent_id)
2444
+ - 0: Child span (not a root)
2445
+
2446
+ Args:
2447
+ workflow_id_col: SQL column/expression for workflow_id
2448
+ team_id_col: SQL column/expression for team_id
2449
+ agent_id_col: SQL column/expression for agent_id
2450
+ name_col: SQL column/expression for name
2451
+
2452
+ Returns:
2453
+ SQLAlchemy CASE expression returning the component level as an integer.
2454
+ """
2455
+ from sqlalchemy import and_, case, or_
2456
+
2457
+ is_root_name = or_(name_col.like("%.run%"), name_col.like("%.arun%"))
2458
+
2459
+ return case(
2460
+ # Workflow root (level 3)
2461
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2462
+ # Team root (level 2)
2463
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2464
+ # Agent root (level 1)
2465
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2466
+ # Child span or unknown (level 0)
2467
+ else_=0,
2468
+ )
2469
+
2470
+ async def upsert_trace(self, trace: "Trace") -> None:
2471
+ """Create or update a single trace record in the database.
2472
+
2473
+ Uses INSERT ... ON DUPLICATE KEY UPDATE (upsert) to handle concurrent inserts
2474
+ atomically and avoid race conditions.
2439
2475
 
2440
2476
  Args:
2441
2477
  trace: The Trace object to store (one per trace_id).
2442
2478
  """
2479
+ from sqlalchemy import case
2480
+
2443
2481
  try:
2444
2482
  table = await self._get_table(table_type="traces", create_table_if_not_found=True)
2445
2483
  if table is None:
2446
2484
  return
2447
2485
 
2448
- async with self.async_session_factory() as sess, sess.begin():
2449
- # Check if trace exists
2450
- result = await sess.execute(select(table).where(table.c.trace_id == trace.trace_id))
2451
- existing = result.fetchone()
2452
-
2453
- if existing:
2454
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2455
-
2456
- def get_component_level(workflow_id, team_id, agent_id, name):
2457
- # Check if name indicates a root span
2458
- is_root_name = ".run" in name or ".arun" in name
2459
-
2460
- if not is_root_name:
2461
- return 0 # Child span (not a root)
2462
- elif workflow_id:
2463
- return 3 # Workflow root
2464
- elif team_id:
2465
- return 2 # Team root
2466
- elif agent_id:
2467
- return 1 # Agent root
2468
- else:
2469
- return 0 # Unknown
2470
-
2471
- existing_level = get_component_level(
2472
- existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2473
- )
2474
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2475
-
2476
- # Only update name if new trace is from a higher or equal level
2477
- should_update_name = new_level > existing_level
2478
-
2479
- # Parse existing start_time to calculate correct duration
2480
- existing_start_time_str = existing.start_time
2481
- if isinstance(existing_start_time_str, str):
2482
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2483
- else:
2484
- existing_start_time = trace.start_time
2485
-
2486
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2486
+ trace_dict = trace.to_dict()
2487
+ trace_dict.pop("total_spans", None)
2488
+ trace_dict.pop("error_count", None)
2487
2489
 
2488
- update_values = {
2489
- "end_time": trace.end_time.isoformat(),
2490
- "duration_ms": recalculated_duration_ms,
2491
- "status": trace.status,
2492
- "name": trace.name if should_update_name else existing.name,
2493
- }
2490
+ async with self.async_session_factory() as sess, sess.begin():
2491
+ # Use upsert to handle concurrent inserts atomically
2492
+ # On conflict, update fields while preserving existing non-null context values
2493
+ # and keeping the earliest start_time
2494
+ insert_stmt = mysql.insert(table).values(trace_dict)
2495
+
2496
+ # Build component level expressions for comparing trace priority
2497
+ new_level = self._get_trace_component_level_expr(
2498
+ insert_stmt.inserted.workflow_id,
2499
+ insert_stmt.inserted.team_id,
2500
+ insert_stmt.inserted.agent_id,
2501
+ insert_stmt.inserted.name,
2502
+ )
2503
+ existing_level = self._get_trace_component_level_expr(
2504
+ table.c.workflow_id,
2505
+ table.c.team_id,
2506
+ table.c.agent_id,
2507
+ table.c.name,
2508
+ )
2494
2509
 
2495
- # Update context fields ONLY if new value is not None (preserve non-null values)
2496
- if trace.run_id is not None:
2497
- update_values["run_id"] = trace.run_id
2498
- if trace.session_id is not None:
2499
- update_values["session_id"] = trace.session_id
2500
- if trace.user_id is not None:
2501
- update_values["user_id"] = trace.user_id
2502
- if trace.agent_id is not None:
2503
- update_values["agent_id"] = trace.agent_id
2504
- if trace.team_id is not None:
2505
- update_values["team_id"] = trace.team_id
2506
- if trace.workflow_id is not None:
2507
- update_values["workflow_id"] = trace.workflow_id
2508
-
2509
- stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2510
- await sess.execute(stmt)
2511
- else:
2512
- trace_dict = trace.to_dict()
2513
- trace_dict.pop("total_spans", None)
2514
- trace_dict.pop("error_count", None)
2515
- stmt = mysql.insert(table).values(trace_dict)
2516
- await sess.execute(stmt)
2510
+ # Build the ON DUPLICATE KEY UPDATE clause
2511
+ # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
2512
+ # MySQL stores timestamps as ISO strings, so string comparison works for ISO format
2513
+ # Duration is calculated using TIMESTAMPDIFF in microseconds then converted to ms
2514
+ upsert_stmt = insert_stmt.on_duplicate_key_update(
2515
+ end_time=func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2516
+ start_time=func.least(table.c.start_time, insert_stmt.inserted.start_time),
2517
+ # Calculate duration in milliseconds using TIMESTAMPDIFF
2518
+ # TIMESTAMPDIFF(MICROSECOND, start, end) / 1000 gives milliseconds
2519
+ duration_ms=func.timestampdiff(
2520
+ text("MICROSECOND"),
2521
+ func.least(table.c.start_time, insert_stmt.inserted.start_time),
2522
+ func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2523
+ )
2524
+ / 1000,
2525
+ status=insert_stmt.inserted.status,
2526
+ # Update name only if new trace is from a higher-level component
2527
+ # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
2528
+ name=case(
2529
+ (new_level > existing_level, insert_stmt.inserted.name),
2530
+ else_=table.c.name,
2531
+ ),
2532
+ # Preserve existing non-null context values using COALESCE
2533
+ run_id=func.coalesce(insert_stmt.inserted.run_id, table.c.run_id),
2534
+ session_id=func.coalesce(insert_stmt.inserted.session_id, table.c.session_id),
2535
+ user_id=func.coalesce(insert_stmt.inserted.user_id, table.c.user_id),
2536
+ agent_id=func.coalesce(insert_stmt.inserted.agent_id, table.c.agent_id),
2537
+ team_id=func.coalesce(insert_stmt.inserted.team_id, table.c.team_id),
2538
+ workflow_id=func.coalesce(insert_stmt.inserted.workflow_id, table.c.workflow_id),
2539
+ )
2540
+ await sess.execute(upsert_stmt)
2517
2541
 
2518
2542
  except Exception as e:
2519
2543
  log_error(f"Error creating trace: {e}")
agno/db/mysql/mysql.py CHANGED
@@ -2452,85 +2452,110 @@ class MySQLDb(BaseDb):
2452
2452
  # Fallback if spans table doesn't exist
2453
2453
  return select(table, literal(0).label("total_spans"), literal(0).label("error_count"))
2454
2454
 
2455
- def create_trace(self, trace: "Trace") -> None:
2456
- """Create a single trace record in the database.
2455
+ def _get_trace_component_level_expr(self, workflow_id_col, team_id_col, agent_id_col, name_col):
2456
+ """Build a SQL CASE expression that returns the component level for a trace.
2457
+
2458
+ Component levels (higher = more important):
2459
+ - 3: Workflow root (.run or .arun with workflow_id)
2460
+ - 2: Team root (.run or .arun with team_id)
2461
+ - 1: Agent root (.run or .arun with agent_id)
2462
+ - 0: Child span (not a root)
2463
+
2464
+ Args:
2465
+ workflow_id_col: SQL column/expression for workflow_id
2466
+ team_id_col: SQL column/expression for team_id
2467
+ agent_id_col: SQL column/expression for agent_id
2468
+ name_col: SQL column/expression for name
2469
+
2470
+ Returns:
2471
+ SQLAlchemy CASE expression returning the component level as an integer.
2472
+ """
2473
+ from sqlalchemy import and_, case, or_
2474
+
2475
+ is_root_name = or_(name_col.like("%.run%"), name_col.like("%.arun%"))
2476
+
2477
+ return case(
2478
+ # Workflow root (level 3)
2479
+ (and_(workflow_id_col.isnot(None), is_root_name), 3),
2480
+ # Team root (level 2)
2481
+ (and_(team_id_col.isnot(None), is_root_name), 2),
2482
+ # Agent root (level 1)
2483
+ (and_(agent_id_col.isnot(None), is_root_name), 1),
2484
+ # Child span or unknown (level 0)
2485
+ else_=0,
2486
+ )
2487
+
2488
+ def upsert_trace(self, trace: "Trace") -> None:
2489
+ """Create or update a single trace record in the database.
2490
+
2491
+ Uses INSERT ... ON DUPLICATE KEY UPDATE (upsert) to handle concurrent inserts
2492
+ atomically and avoid race conditions.
2457
2493
 
2458
2494
  Args:
2459
2495
  trace: The Trace object to store (one per trace_id).
2460
2496
  """
2497
+ from sqlalchemy import case
2498
+
2461
2499
  try:
2462
2500
  table = self._get_table(table_type="traces", create_table_if_not_found=True)
2463
2501
  if table is None:
2464
2502
  return
2465
2503
 
2466
- with self.Session() as sess, sess.begin():
2467
- # Check if trace exists
2468
- existing = sess.execute(select(table).where(table.c.trace_id == trace.trace_id)).fetchone()
2469
-
2470
- if existing:
2471
- # workflow (level 3) > team (level 2) > agent (level 1) > child/unknown (level 0)
2472
-
2473
- def get_component_level(workflow_id, team_id, agent_id, name):
2474
- # Check if name indicates a root span
2475
- is_root_name = ".run" in name or ".arun" in name
2476
-
2477
- if not is_root_name:
2478
- return 0 # Child span (not a root)
2479
- elif workflow_id:
2480
- return 3 # Workflow root
2481
- elif team_id:
2482
- return 2 # Team root
2483
- elif agent_id:
2484
- return 1 # Agent root
2485
- else:
2486
- return 0 # Unknown
2487
-
2488
- existing_level = get_component_level(
2489
- existing.workflow_id, existing.team_id, existing.agent_id, existing.name
2490
- )
2491
- new_level = get_component_level(trace.workflow_id, trace.team_id, trace.agent_id, trace.name)
2492
-
2493
- # Only update name if new trace is from a higher or equal level
2494
- should_update_name = new_level > existing_level
2504
+ trace_dict = trace.to_dict()
2505
+ trace_dict.pop("total_spans", None)
2506
+ trace_dict.pop("error_count", None)
2495
2507
 
2496
- # Parse existing start_time to calculate correct duration
2497
- existing_start_time_str = existing.start_time
2498
- if isinstance(existing_start_time_str, str):
2499
- existing_start_time = datetime.fromisoformat(existing_start_time_str.replace("Z", "+00:00"))
2500
- else:
2501
- existing_start_time = trace.start_time
2502
-
2503
- recalculated_duration_ms = int((trace.end_time - existing_start_time).total_seconds() * 1000)
2504
-
2505
- update_values = {
2506
- "end_time": trace.end_time.isoformat(),
2507
- "duration_ms": recalculated_duration_ms,
2508
- "status": trace.status,
2509
- "name": trace.name if should_update_name else existing.name,
2510
- }
2508
+ with self.Session() as sess, sess.begin():
2509
+ # Use upsert to handle concurrent inserts atomically
2510
+ # On conflict, update fields while preserving existing non-null context values
2511
+ # and keeping the earliest start_time
2512
+ insert_stmt = mysql.insert(table).values(trace_dict)
2513
+
2514
+ # Build component level expressions for comparing trace priority
2515
+ new_level = self._get_trace_component_level_expr(
2516
+ insert_stmt.inserted.workflow_id,
2517
+ insert_stmt.inserted.team_id,
2518
+ insert_stmt.inserted.agent_id,
2519
+ insert_stmt.inserted.name,
2520
+ )
2521
+ existing_level = self._get_trace_component_level_expr(
2522
+ table.c.workflow_id,
2523
+ table.c.team_id,
2524
+ table.c.agent_id,
2525
+ table.c.name,
2526
+ )
2511
2527
 
2512
- # Update context fields ONLY if new value is not None (preserve non-null values)
2513
- if trace.run_id is not None:
2514
- update_values["run_id"] = trace.run_id
2515
- if trace.session_id is not None:
2516
- update_values["session_id"] = trace.session_id
2517
- if trace.user_id is not None:
2518
- update_values["user_id"] = trace.user_id
2519
- if trace.agent_id is not None:
2520
- update_values["agent_id"] = trace.agent_id
2521
- if trace.team_id is not None:
2522
- update_values["team_id"] = trace.team_id
2523
- if trace.workflow_id is not None:
2524
- update_values["workflow_id"] = trace.workflow_id
2525
-
2526
- stmt = update(table).where(table.c.trace_id == trace.trace_id).values(**update_values)
2527
- sess.execute(stmt)
2528
- else:
2529
- trace_dict = trace.to_dict()
2530
- trace_dict.pop("total_spans", None)
2531
- trace_dict.pop("error_count", None)
2532
- stmt = mysql.insert(table).values(trace_dict)
2533
- sess.execute(stmt)
2528
+ # Build the ON DUPLICATE KEY UPDATE clause
2529
+ # Use LEAST for start_time, GREATEST for end_time to capture full trace duration
2530
+ # MySQL stores timestamps as ISO strings, so string comparison works for ISO format
2531
+ # Duration is calculated using TIMESTAMPDIFF in microseconds then converted to ms
2532
+ upsert_stmt = insert_stmt.on_duplicate_key_update(
2533
+ end_time=func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2534
+ start_time=func.least(table.c.start_time, insert_stmt.inserted.start_time),
2535
+ # Calculate duration in milliseconds using TIMESTAMPDIFF
2536
+ # TIMESTAMPDIFF(MICROSECOND, start, end) / 1000 gives milliseconds
2537
+ duration_ms=func.timestampdiff(
2538
+ text("MICROSECOND"),
2539
+ func.least(table.c.start_time, insert_stmt.inserted.start_time),
2540
+ func.greatest(table.c.end_time, insert_stmt.inserted.end_time),
2541
+ )
2542
+ / 1000,
2543
+ status=insert_stmt.inserted.status,
2544
+ # Update name only if new trace is from a higher-level component
2545
+ # Priority: workflow (3) > team (2) > agent (1) > child spans (0)
2546
+ name=case(
2547
+ (new_level > existing_level, insert_stmt.inserted.name),
2548
+ else_=table.c.name,
2549
+ ),
2550
+ # Preserve existing non-null context values using COALESCE
2551
+ run_id=func.coalesce(insert_stmt.inserted.run_id, table.c.run_id),
2552
+ session_id=func.coalesce(insert_stmt.inserted.session_id, table.c.session_id),
2553
+ user_id=func.coalesce(insert_stmt.inserted.user_id, table.c.user_id),
2554
+ agent_id=func.coalesce(insert_stmt.inserted.agent_id, table.c.agent_id),
2555
+ team_id=func.coalesce(insert_stmt.inserted.team_id, table.c.team_id),
2556
+ workflow_id=func.coalesce(insert_stmt.inserted.workflow_id, table.c.workflow_id),
2557
+ )
2558
+ sess.execute(upsert_stmt)
2534
2559
 
2535
2560
  except Exception as e:
2536
2561
  log_error(f"Error creating trace: {e}")