arthur-common 2.1.53__py3-none-any.whl → 2.1.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,875 @@
+ import json
+ import logging
+ from typing import Annotated
+ from uuid import UUID
+
+ import pandas as pd
+ from duckdb import DuckDBPyConnection
+
+ from arthur_common.aggregations.aggregator import (
+     NumericAggregationFunction,
+     SketchAggregationFunction,
+ )
+ from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.metrics import (
+     BaseReportedAggregation,
+     DatasetReference,
+     NumericMetric,
+     SketchMetric,
+ )
+ from arthur_common.models.schema_definitions import MetricDatasetParameterAnnotation
+
+ # Global score threshold and tool-score constants for pass/fail determination
+ RELEVANCE_SCORE_THRESHOLD = 0.5
+ TOOL_SCORE_PASS_VALUE = 1
+ TOOL_SCORE_NO_TOOL_VALUE = 2
+
+ logger = logging.getLogger(__name__)
+
+
+ def extract_spans_with_metrics_and_agents(root_spans):
+     """Recursively extract all spans with metrics and their associated agent names from the span tree.
+
+     Returns:
+         List of tuples: (span, agent_name)
+     """
+     spans_with_metrics_and_agents = []
+
+     def traverse_spans(spans, current_agent_name="unknown"):
+         for span in spans:
+             # Update current agent name if this span is an AGENT
+             if span.get("span_kind") == "AGENT":
+                 try:
+                     raw_data = span.get("raw_data", {})
+                     if isinstance(raw_data, str):
+                         raw_data = json.loads(raw_data)
+
+                     # Try to get agent name from the span's name field
+                     agent_name = raw_data.get("name", "unknown")
+                     if agent_name != "unknown":
+                         current_agent_name = agent_name
+                 except (json.JSONDecodeError, KeyError, TypeError):
+                     logger.error(
+                         f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
+                     )
+
+             # Check if this span has metrics (non-empty metric_results)
+             if span.get("metric_results"):
+                 spans_with_metrics_and_agents.append((span, current_agent_name))
+
+             # Recursively traverse children with the current agent name
+             if span.get("children", []):
+                 traverse_spans(span["children"], current_agent_name)
+
+     traverse_spans(root_spans)
+     return spans_with_metrics_and_agents
+
+
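For reference, a minimal sketch of the span-tree shape this helper walks, assuming the function above is in scope; the field names mirror the accessors in the code, while the sample values are hypothetical:

    # Hypothetical trace: an AGENT span whose child LLM span carries metric_results.
    sample_root_spans = [
        {
            "span_kind": "AGENT",
            "span_id": "s1",
            "trace_id": "t1",
            "raw_data": {"name": "planner"},
            "metric_results": [],
            "children": [
                {
                    "span_kind": "LLM",
                    "span_id": "s2",
                    "trace_id": "t1",
                    "metric_results": [{"metric_type": "QueryRelevance", "details": {}}],
                    "children": [],
                },
            ],
        },
    ]

    # Each span holding metric_results is paired with the nearest enclosing AGENT name.
    pairs = extract_spans_with_metrics_and_agents(sample_root_spans)
    assert [(span["span_id"], agent) for span, agent in pairs] == [("s2", "planner")]
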
+ def determine_relevance_pass_fail(score):
+     """Determine pass/fail for relevance scores using the global threshold"""
+     if score is None:
+         return None
+     return "pass" if score >= RELEVANCE_SCORE_THRESHOLD else "fail"
+
+
+ def determine_tool_pass_fail(score):
+     """Determine pass/no_tool/fail for tool scores using the global score constants"""
+     if score is None:
+         return None
+     if score == TOOL_SCORE_PASS_VALUE:
+         return "pass"
+     elif score == TOOL_SCORE_NO_TOOL_VALUE:
+         return "no_tool"
+     else:
+         return "fail"
+
+
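A quick sanity check of the two helpers, using the constants defined at the top of the module:

    # Relevance: scores at or above RELEVANCE_SCORE_THRESHOLD (0.5) pass.
    assert determine_relevance_pass_fail(0.7) == "pass"
    assert determine_relevance_pass_fail(0.4) == "fail"
    assert determine_relevance_pass_fail(None) is None

    # Tool scores: 1 maps to "pass", 2 to "no_tool", anything else to "fail".
    assert determine_tool_pass_fail(TOOL_SCORE_PASS_VALUE) == "pass"
    assert determine_tool_pass_fail(TOOL_SCORE_NO_TOOL_VALUE) == "no_tool"
    assert determine_tool_pass_fail(0) == "fail"
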
+ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
+     """Combined aggregation for tool selection, tool usage, query relevance, and response relevance over time"""
+
+     METRIC_NAME = "agentic_metrics_over_time"
+     TOOL_SELECTION_METRIC_NAME = "tool_selection_over_time"
+     TOOL_USAGE_METRIC_NAME = "tool_usage_over_time"
+     QUERY_RELEVANCE_SCORES_METRIC_NAME = "query_relevance_scores_over_time"
+     RESPONSE_RELEVANCE_SCORES_METRIC_NAME = "response_relevance_scores_over_time"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000030")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Agentic Metrics Over Time"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that reports distributions (data sketches) on tool selection, tool usage, query relevance, and response relevance scores over time."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticMetricsOverTimeAggregation.TOOL_SELECTION_METRIC_NAME,
+                 description="Distribution of tool selection over time.",
+             ),
+             BaseReportedAggregation(
+                 metric_name=AgenticMetricsOverTimeAggregation.TOOL_USAGE_METRIC_NAME,
+                 description="Distribution of tool usage over time.",
+             ),
+             BaseReportedAggregation(
+                 metric_name=AgenticMetricsOverTimeAggregation.QUERY_RELEVANCE_SCORES_METRIC_NAME,
+                 description="Distribution of query relevance over time.",
+             ),
+             BaseReportedAggregation(
+                 metric_name=AgenticMetricsOverTimeAggregation.RESPONSE_RELEVANCE_SCORES_METRIC_NAME,
+                 description="Distribution of response relevance over time.",
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         # Query traces by timestamp
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 root_spans
+             FROM {dataset.dataset_table_name}
+             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         # Process traces and extract spans with metrics
+         tool_selection_data = []
+         tool_usage_data = []
+         query_relevance_data = []
+         response_relevance_data = []
+
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with metrics and their agent names from the tree
+             spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
+                 root_spans,
+             )
+
+             # Process each span with metrics
+             for span, agent_name in spans_with_metrics_and_agents:
+                 metric_results = span.get("metric_results", [])
+
+                 for metric_result in metric_results:
+                     metric_type = metric_result.get("metric_type")
+                     details = metric_result.get("details", {})
+
+                     if metric_type == "ToolSelection":
+                         tool_selection = details.get("tool_selection", {})
+
+                         # Extract tool selection data
+                         tool_selection_score = tool_selection.get("tool_selection")
+                         tool_selection_reason = tool_selection.get(
+                             "tool_selection_reason",
+                             "Unknown",
+                         )
+
+                         if tool_selection_score is not None:
+                             tool_selection_data.append(
+                                 {
+                                     "ts": ts,
+                                     "tool_selection_score": tool_selection_score,
+                                     "tool_selection_reason": tool_selection_reason,
+                                 },
+                             )
+
+                         # Extract tool usage data
+                         tool_usage_score = tool_selection.get("tool_usage")
+                         tool_usage_reason = tool_selection.get(
+                             "tool_usage_reason",
+                             "Unknown",
+                         )
+
+                         if tool_usage_score is not None:
+                             tool_usage_data.append(
+                                 {
+                                     "ts": ts,
+                                     "tool_usage_score": tool_usage_score,
+                                     "tool_usage_reason": tool_usage_reason,
+                                 },
+                             )
+
+                     elif metric_type == "QueryRelevance":
+                         query_relevance = details.get("query_relevance", {})
+                         reason = query_relevance.get("reason", "Unknown")
+
+                         # Add individual scores if they exist
+                         llm_score = query_relevance.get("llm_relevance_score")
+                         reranker_score = query_relevance.get("reranker_relevance_score")
+                         bert_score = query_relevance.get("bert_f_score")
+
+                         if llm_score is not None:
+                             query_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "llm_relevance_score",
+                                     "score_value": llm_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+                         if reranker_score is not None:
+                             query_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "reranker_relevance_score",
+                                     "score_value": reranker_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+                         if bert_score is not None:
+                             query_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "bert_f_score",
+                                     "score_value": bert_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+                     elif metric_type == "ResponseRelevance":
+                         response_relevance = details.get("response_relevance", {})
+                         reason = response_relevance.get("reason", "Unknown")
+
+                         # Add individual scores if they exist
+                         llm_score = response_relevance.get("llm_relevance_score")
+                         reranker_score = response_relevance.get(
+                             "reranker_relevance_score",
+                         )
+                         bert_score = response_relevance.get("bert_f_score")
+
+                         if llm_score is not None:
+                             response_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "llm_relevance_score",
+                                     "score_value": llm_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+                         if reranker_score is not None:
+                             response_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "reranker_relevance_score",
+                                     "score_value": reranker_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+                         if bert_score is not None:
+                             response_relevance_data.append(
+                                 {
+                                     "ts": ts,
+                                     "score_type": "bert_f_score",
+                                     "score_value": bert_score,
+                                     "reason": reason,
+                                 },
+                             )
+
+         metrics = []
+
+         # Create tool selection metric
+         if tool_selection_data:
+             df = pd.DataFrame(tool_selection_data)
+             series = self.group_query_results_to_sketch_metrics(
+                 df,
+                 "tool_selection_score",
+                 ["tool_selection_reason"],
+                 "ts",
+             )
+             metrics.append(
+                 self.series_to_metric(self.TOOL_SELECTION_METRIC_NAME, series),
+             )
+
+         # Create tool usage metric
+         if tool_usage_data:
+             df = pd.DataFrame(tool_usage_data)
+             series = self.group_query_results_to_sketch_metrics(
+                 df,
+                 "tool_usage_score",
+                 ["tool_usage_reason"],
+                 "ts",
+             )
+             metrics.append(self.series_to_metric(self.TOOL_USAGE_METRIC_NAME, series))
+
+         # Create comprehensive query relevance metric (includes all score data)
+         if query_relevance_data:
+             df = pd.DataFrame(query_relevance_data)
+             series = self.group_query_results_to_sketch_metrics(
+                 df,
+                 "score_value",
+                 ["score_type", "reason"],
+                 "ts",
+             )
+             metrics.append(
+                 self.series_to_metric(self.QUERY_RELEVANCE_SCORES_METRIC_NAME, series),
+             )
+
+         # Create comprehensive response relevance metric (includes all score data)
+         if response_relevance_data:
+             df = pd.DataFrame(response_relevance_data)
+             series = self.group_query_results_to_sketch_metrics(
+                 df,
+                 "score_value",
+                 ["score_type", "reason"],
+                 "ts",
+             )
+             metrics.append(
+                 self.series_to_metric(
+                     self.RESPONSE_RELEVANCE_SCORES_METRIC_NAME,
+                     series,
+                 ),
+             )
+
+         return metrics
+
+
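The 5-minute bucketing used by the query above can be tried in isolation; a self-contained DuckDB sketch (the table name and timestamps are made up):

    import duckdb

    con = duckdb.connect()
    con.sql("CREATE TABLE traces (start_time BIGINT)")  # epoch microseconds
    # Two trace timestamps 50 seconds apart land in the same 5-minute bucket.
    con.sql("INSERT INTO traces VALUES (1700000000000000), (1700000050000000)")
    print(
        con.sql(
            "SELECT time_bucket(INTERVAL '5 minutes', "
            "to_timestamp(start_time / 1000000)) AS ts, COUNT(*) AS count "
            "FROM traces GROUP BY ts ORDER BY ts DESC"
        ).df()
    )
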
+ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
+     """Combined aggregation for query and response relevance pass/fail counts by agent"""
+
+     METRIC_NAME = "relevance_pass_fail_count"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000034")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Relevance Pass/Fail Count by Agent"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts the number of query and response relevance passes and failures, segmented by agent name and metric type."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticRelevancePassFailCountAggregation.METRIC_NAME,
+                 description=AgenticRelevancePassFailCountAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         # Query traces by timestamp
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 root_spans
+             FROM {dataset.dataset_table_name}
+             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         # Process traces and extract spans with metrics
+         processed_data = []
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with metrics and their agent names from the tree
+             spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
+                 root_spans,
+             )
+
+             # Process each span with metrics
+             for span, agent_name in spans_with_metrics_and_agents:
+                 metric_results = span.get("metric_results", [])
+
+                 for metric_result in metric_results:
+                     metric_type = metric_result.get("metric_type")
+                     details = metric_result.get("details", {})
+
+                     if metric_type in ["QueryRelevance", "ResponseRelevance"]:
+                         relevance_data = details.get(
+                             (
+                                 "query_relevance"
+                                 if metric_type == "QueryRelevance"
+                                 else "response_relevance"
+                             ),
+                             {},
+                         )
+                         # Check individual scores
+                         for score_type in [
+                             "llm_relevance_score",
+                             "reranker_relevance_score",
+                             "bert_f_score",
+                         ]:
+                             score = relevance_data.get(score_type)
+                             if score is not None:
+                                 result = determine_relevance_pass_fail(score)
+                                 processed_data.append(
+                                     {
+                                         "ts": ts,
+                                         "agent_name": agent_name,
+                                         "metric_type": metric_type,
+                                         "score_type": score_type,
+                                         "result": result,
+                                         "count": 1,
+                                     },
+                                 )
+
+         if not processed_data:
+             return []
+
+         # Convert to DataFrame and aggregate
+         df = pd.DataFrame(processed_data)
+         aggregated = (
+             df.groupby(["ts", "agent_name", "metric_type", "score_type", "result"])[
+                 "count"
+             ]
+             .sum()
+             .reset_index()
+         )
+
+         series = self.group_query_results_to_numeric_metrics(
+             aggregated,
+             "count",
+             ["agent_name", "metric_type", "score_type", "result"],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
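The groupby/sum step above collapses the one-row-per-score records into counts per (timestamp, agent, metric type, score type, result) combination; a toy pandas illustration with fabricated rows:

    import pandas as pd

    rows = [
        {"ts": "10:00", "agent_name": "planner", "metric_type": "QueryRelevance",
         "score_type": "llm_relevance_score", "result": "pass", "count": 1},
        {"ts": "10:00", "agent_name": "planner", "metric_type": "QueryRelevance",
         "score_type": "llm_relevance_score", "result": "pass", "count": 1},
        {"ts": "10:00", "agent_name": "planner", "metric_type": "QueryRelevance",
         "score_type": "llm_relevance_score", "result": "fail", "count": 1},
    ]
    aggregated = (
        pd.DataFrame(rows)
        .groupby(["ts", "agent_name", "metric_type", "score_type", "result"])["count"]
        .sum()
        .reset_index()
    )
    # The two identical "pass" rows collapse to count == 2; the "fail" row stays at 1.
    print(aggregated)
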
+ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
+     """Combined aggregation for tool selection and usage pass/fail counts by agent"""
+
+     METRIC_NAME = "tool_pass_fail_count"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000035")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Tool Pass/Fail Count by Agent"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts the number of tool selection and usage passes, failures, and no-tool cases, segmented by agent name."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticToolPassFailCountAggregation.METRIC_NAME,
+                 description=AgenticToolPassFailCountAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         # Query traces by timestamp
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 root_spans
+             FROM {dataset.dataset_table_name}
+             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         # Process traces and extract spans with metrics
+         processed_data = []
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with metrics and their agent names from the tree
+             spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
+                 root_spans,
+             )
+
+             # Process each span with metrics
+             for span, agent_name in spans_with_metrics_and_agents:
+                 metric_results = span.get("metric_results", [])
+
+                 for metric_result in metric_results:
+                     if metric_result.get("metric_type") == "ToolSelection":
+                         details = metric_result.get("details", {})
+                         tool_selection = details.get("tool_selection", {})
+
+                         tool_selection_score = tool_selection.get("tool_selection")
+                         tool_usage_score = tool_selection.get("tool_usage")
+
+                         # Process tool selection
+                         if tool_selection_score is not None:
+                             result = determine_tool_pass_fail(tool_selection_score)
+                             processed_data.append(
+                                 {
+                                     "ts": ts,
+                                     "agent_name": agent_name,
+                                     "tool_metric": "tool_selection",
+                                     "result": result,
+                                     "count": 1,
+                                 },
+                             )
+
+                         # Process tool usage
+                         if tool_usage_score is not None:
+                             result = determine_tool_pass_fail(tool_usage_score)
+                             processed_data.append(
+                                 {
+                                     "ts": ts,
+                                     "agent_name": agent_name,
+                                     "tool_metric": "tool_usage",
+                                     "result": result,
+                                     "count": 1,
+                                 },
+                             )
+
+         if not processed_data:
+             return []
+
+         # Convert to DataFrame and aggregate
+         df = pd.DataFrame(processed_data)
+         aggregated = (
+             df.groupby(["ts", "agent_name", "tool_metric", "result"])["count"]
+             .sum()
+             .reset_index()
+         )
+
+         series = self.group_query_results_to_numeric_metrics(
+             aggregated,
+             "count",
+             ["agent_name", "tool_metric", "result"],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
+ class AgenticEventCountAggregation(NumericAggregationFunction):
+     METRIC_NAME = "event_count"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000036")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Number of Events"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts the number of events over time."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticEventCountAggregation.METRIC_NAME,
+                 description=AgenticEventCountAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 COUNT(*) as count
+             FROM {dataset.dataset_table_name}
+             GROUP BY ts
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         series = self.group_query_results_to_numeric_metrics(
+             results,
+             "count",
+             [],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
+ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
+     METRIC_NAME = "llm_call_count"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000038")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Number of LLM Calls"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts the number of LLM spans (individual LLM calls) over time."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticLLMCallCountAggregation.METRIC_NAME,
+                 description=AgenticLLMCallCountAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 root_spans
+             FROM {dataset.dataset_table_name}
+             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         # Process traces and count LLM spans
+         llm_call_counts = {}
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Count LLM spans in the tree
+             def count_llm_spans(spans):
+                 count = 0
+                 for span in spans:
+                     # Check if this span is an LLM span
+                     if span.get("span_kind") == "LLM":
+                         count += 1
+
+                     # Recursively count children
+                     if span.get("children"):
+                         count += count_llm_spans(span["children"])
+                 return count
+
+             llm_count = count_llm_spans(root_spans)
+
+             if llm_count > 0:
+                 if ts not in llm_call_counts:
+                     llm_call_counts[ts] = 0
+                 llm_call_counts[ts] += llm_count
+
+         if not llm_call_counts:
+             return []
+
+         # Convert to DataFrame format
+         data = [{"ts": ts, "count": count} for ts, count in llm_call_counts.items()]
+         df = pd.DataFrame(data)
+
+         series = self.group_query_results_to_numeric_metrics(
+             df,
+             "count",
+             [],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
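The nested count_llm_spans above is a plain recursive walk; restated standalone and applied to a hypothetical two-level tree:

    def count_llm_spans(spans):
        """Count spans with span_kind == "LLM" anywhere in the tree."""
        count = 0
        for span in spans:
            if span.get("span_kind") == "LLM":
                count += 1
            if span.get("children"):
                count += count_llm_spans(span["children"])
        return count

    tree = [
        {
            "span_kind": "AGENT",
            "children": [
                {"span_kind": "LLM", "children": []},
                {"span_kind": "LLM", "children": []},
            ],
        },
    ]
    assert count_llm_spans(tree) == 2
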
+ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction):
+     METRIC_NAME = "tool_selection_and_usage_by_agent"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000037")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Tool Selection and Usage by Agent"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts tool selection and usage correctness, segmented by agent name."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticToolSelectionAndUsageByAgentAggregation.METRIC_NAME,
+                 description=AgenticToolSelectionAndUsageByAgentAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         # Query traces by timestamp
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
+                 root_spans
+             FROM {dataset.dataset_table_name}
+             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         # Process traces and extract spans with metrics
+         processed_data = []
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with metrics and their agent names from the tree
+             spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
+                 root_spans,
+             )
+
+             # Process each span with metrics
+             for span, agent_name in spans_with_metrics_and_agents:
+                 metric_results = span.get("metric_results", [])
+
+                 for metric_result in metric_results:
+                     if metric_result.get("metric_type") == "ToolSelection":
+                         details = metric_result.get("details", {})
+                         tool_selection = details.get("tool_selection", {})
+
+                         tool_selection_score = tool_selection.get("tool_selection")
+                         tool_usage_score = tool_selection.get("tool_usage")
+
+                         if tool_selection_score is not None:
+                             # Categorize selection
+                             if tool_selection_score == 1:
+                                 selection_category = "correct_selection"
+                             elif tool_selection_score == 0:
+                                 selection_category = "incorrect_selection"
+                             else:
+                                 selection_category = "no_selection"
+
+                             # Categorize usage
+                             if tool_usage_score == 1:
+                                 usage_category = "correct_usage"
+                             elif tool_usage_score == 0:
+                                 usage_category = "incorrect_usage"
+                             else:
+                                 usage_category = "no_usage"
+
+                             processed_data.append(
+                                 {
+                                     "ts": ts,
+                                     "agent_name": agent_name,
+                                     "selection_category": selection_category,
+                                     "usage_category": usage_category,
+                                     "count": 1,
+                                 },
+                             )
+
+         if not processed_data:
+             return []
+
+         # Convert to DataFrame and aggregate
+         df = pd.DataFrame(processed_data)
+         aggregated = (
+             df.groupby(["ts", "agent_name", "selection_category", "usage_category"])[
+                 "count"
+             ]
+             .sum()
+             .reset_index()
+         )
+
+         series = self.group_query_results_to_numeric_metrics(
+             aggregated,
+             "count",
+             ["agent_name", "selection_category", "usage_category"],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]