arthur-common 2.1.53__py3-none-any.whl → 2.1.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arthur-common might be problematic.
- arthur_common/aggregations/functions/agentic_aggregations.py +875 -0
- arthur_common/aggregations/functions/categorical_count.py +2 -2
- arthur_common/aggregations/functions/confusion_matrix.py +1 -1
- arthur_common/aggregations/functions/inference_count.py +2 -2
- arthur_common/aggregations/functions/inference_count_by_class.py +3 -3
- arthur_common/aggregations/functions/inference_null_count.py +2 -2
- arthur_common/aggregations/functions/mean_absolute_error.py +3 -2
- arthur_common/aggregations/functions/mean_squared_error.py +3 -2
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +1 -1
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +2 -2
- arthur_common/aggregations/functions/numeric_stats.py +2 -2
- arthur_common/aggregations/functions/numeric_sum.py +2 -2
- arthur_common/aggregations/functions/shield_aggregations.py +14 -13
- arthur_common/models/datasets.py +1 -0
- arthur_common/models/metrics.py +7 -6
- arthur_common/models/schema_definitions.py +58 -0
- arthur_common/models/shield.py +158 -0
- arthur_common/models/task_job_specs.py +26 -2
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.54.dist-info}/METADATA +1 -1
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.54.dist-info}/RECORD +21 -20
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.54.dist-info}/WHEEL +0 -0
@@ -0,0 +1,875 @@

```python
import json
import logging
from typing import Annotated
from uuid import UUID

import pandas as pd
from duckdb import DuckDBPyConnection

from arthur_common.aggregations.aggregator import (
    NumericAggregationFunction,
    SketchAggregationFunction,
)
from arthur_common.models.datasets import ModelProblemType
from arthur_common.models.metrics import (
    BaseReportedAggregation,
    DatasetReference,
    NumericMetric,
    SketchMetric,
)
from arthur_common.models.schema_definitions import MetricDatasetParameterAnnotation

# Global threshold for pass/fail determination
RELEVANCE_SCORE_THRESHOLD = 0.5
TOOL_SCORE_PASS_VALUE = 1
TOOL_SCORE_NO_TOOL_VALUE = 2

logger = logging.getLogger(__name__)


def extract_spans_with_metrics_and_agents(root_spans):
    """Recursively extract all spans with metrics and their associated agent names from the span tree.

    Returns:
        List of tuples: (span, agent_name)
    """
    spans_with_metrics_and_agents = []

    def traverse_spans(spans, current_agent_name="unknown"):
        for span in spans:
            # Update current agent name if this span is an AGENT
            if span.get("span_kind") == "AGENT":
                try:
                    raw_data = span.get("raw_data", {})
                    if isinstance(raw_data, str):
                        raw_data = json.loads(raw_data)

                    # Try to get agent name from the span's name field
                    agent_name = raw_data.get("name", "unknown")
                    if agent_name != "unknown":
                        current_agent_name = agent_name
                except (json.JSONDecodeError, KeyError, TypeError):
                    logger.error(
                        f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
                    )

            # Check if this span has metrics
            if span.get("metric_results") and len(span.get("metric_results", [])) > 0:
                spans_with_metrics_and_agents.append((span, current_agent_name))

            # Recursively traverse children with the current agent name
            if span.get("children", []):
                traverse_spans(span["children"], current_agent_name)

    traverse_spans(root_spans)
    return spans_with_metrics_and_agents
```
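To see what the traversal returns, here is a toy span tree (field values invented for illustration; assumes the function above is in scope). The child span carries metrics and inherits the name of the nearest enclosing AGENT span:

```python
# Hypothetical toy trace; the field names match what the helper reads
# (span_kind, raw_data, metric_results, children), the values are made up.
toy_spans = [
    {
        "span_kind": "AGENT",
        "raw_data": '{"name": "planner"}',
        "children": [
            {
                "span_kind": "LLM",
                "metric_results": [{"metric_type": "QueryRelevance", "details": {}}],
            },
        ],
    },
]

pairs = extract_spans_with_metrics_and_agents(toy_spans)
# Only the child has metric_results, and it inherits the AGENT's name:
assert [agent for _, agent in pairs] == ["planner"]
```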
```python
def determine_relevance_pass_fail(score):
    """Determine pass/fail for relevance scores using global threshold"""
    if score is None:
        return None
    return "pass" if score >= RELEVANCE_SCORE_THRESHOLD else "fail"


def determine_tool_pass_fail(score):
    """Determine pass/fail for tool scores using global threshold"""
    if score is None:
        return None
    if score == TOOL_SCORE_PASS_VALUE:
        return "pass"
    elif score == TOOL_SCORE_NO_TOOL_VALUE:
        return "no_tool"
    else:
        return "fail"
```
```python
class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
    """Combined aggregation for tool selection, tool usage, query relevance, and response relevance over time"""

    METRIC_NAME = "agentic_metrics_over_time"
    TOOL_SELECTION_METRIC_NAME = "tool_selection_over_time"
    TOOL_USAGE_METRIC_NAME = "tool_usage_over_time"
    QUERY_RELEVANCE_SCORES_METRIC_NAME = "query_relevance_scores_over_time"
    RESPONSE_RELEVANCE_SCORES_METRIC_NAME = "response_relevance_scores_over_time"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000030")

    @staticmethod
    def display_name() -> str:
        return "Agentic Metrics Over Time"

    @staticmethod
    def description() -> str:
        return "Metric that reports distributions (data sketches) on tool selection, tool usage, query relevance, and response relevance scores over time."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticMetricsOverTimeAggregation.TOOL_SELECTION_METRIC_NAME,
                description="Distribution of tool selection over time.",
            ),
            BaseReportedAggregation(
                metric_name=AgenticMetricsOverTimeAggregation.TOOL_USAGE_METRIC_NAME,
                description="Distribution of tool usage over time.",
            ),
            BaseReportedAggregation(
                metric_name=AgenticMetricsOverTimeAggregation.QUERY_RELEVANCE_SCORES_METRIC_NAME,
                description="Distribution of query relevance over time.",
            ),
            BaseReportedAggregation(
                metric_name=AgenticMetricsOverTimeAggregation.RESPONSE_RELEVANCE_SCORES_METRIC_NAME,
                description="Distribution of response relevance over time.",
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces with nested spans.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[SketchMetric]:
        # Query traces by timestamp
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                root_spans
            FROM {dataset.dataset_table_name}
            WHERE root_spans IS NOT NULL AND length(root_spans) > 0
            ORDER BY ts DESC;
            """,
        ).df()

        # Process traces and extract spans with metrics
        tool_selection_data = []
        tool_usage_data = []
        query_relevance_data = []
        response_relevance_data = []

        for _, row in results.iterrows():
            ts = row["ts"]
            root_spans = row["root_spans"]

            # Parse root_spans if it's a string
            if isinstance(root_spans, str):
                root_spans = json.loads(root_spans)

            # Extract all spans with metrics and their agent names from the tree
            spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
                root_spans,
            )

            # Process each span with metrics
            for span, agent_name in spans_with_metrics_and_agents:
                metric_results = span.get("metric_results", [])

                for metric_result in metric_results:
                    metric_type = metric_result.get("metric_type")
                    details = metric_result.get("details", {})

                    if metric_type == "ToolSelection":
                        tool_selection = details.get("tool_selection", {})

                        # Extract tool selection data
                        tool_selection_score = tool_selection.get("tool_selection")
                        tool_selection_reason = tool_selection.get(
                            "tool_selection_reason",
                            "Unknown",
                        )

                        if tool_selection_score is not None:
                            tool_selection_data.append(
                                {
                                    "ts": ts,
                                    "tool_selection_score": tool_selection_score,
                                    "tool_selection_reason": tool_selection_reason,
                                },
                            )

                        # Extract tool usage data
                        tool_usage_score = tool_selection.get("tool_usage")
                        tool_usage_reason = tool_selection.get(
                            "tool_usage_reason",
                            "Unknown",
                        )

                        if tool_usage_score is not None:
                            tool_usage_data.append(
                                {
                                    "ts": ts,
                                    "tool_usage_score": tool_usage_score,
                                    "tool_usage_reason": tool_usage_reason,
                                },
                            )

                    elif metric_type == "QueryRelevance":
                        query_relevance = details.get("query_relevance", {})
                        reason = query_relevance.get("reason", "Unknown")

                        # Add individual scores if they exist
                        llm_score = query_relevance.get("llm_relevance_score")
                        reranker_score = query_relevance.get("reranker_relevance_score")
                        bert_score = query_relevance.get("bert_f_score")

                        if llm_score is not None:
                            query_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "llm_relevance_score",
                                    "score_value": llm_score,
                                    "reason": reason,
                                },
                            )

                        if reranker_score is not None:
                            query_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "reranker_relevance_score",
                                    "score_value": reranker_score,
                                    "reason": reason,
                                },
                            )

                        if bert_score is not None:
                            query_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "bert_f_score",
                                    "score_value": bert_score,
                                    "reason": reason,
                                },
                            )

                    elif metric_type == "ResponseRelevance":
                        response_relevance = details.get("response_relevance", {})
                        reason = response_relevance.get("reason", "Unknown")

                        # Add individual scores if they exist
                        llm_score = response_relevance.get("llm_relevance_score")
                        reranker_score = response_relevance.get(
                            "reranker_relevance_score",
                        )
                        bert_score = response_relevance.get("bert_f_score")

                        if llm_score is not None:
                            response_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "llm_relevance_score",
                                    "score_value": llm_score,
                                    "reason": reason,
                                },
                            )

                        if reranker_score is not None:
                            response_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "reranker_relevance_score",
                                    "score_value": reranker_score,
                                    "reason": reason,
                                },
                            )

                        if bert_score is not None:
                            response_relevance_data.append(
                                {
                                    "ts": ts,
                                    "score_type": "bert_f_score",
                                    "score_value": bert_score,
                                    "reason": reason,
                                },
                            )

        metrics = []

        # Create tool selection metric
        if tool_selection_data:
            df = pd.DataFrame(tool_selection_data)
            series = self.group_query_results_to_sketch_metrics(
                df,
                "tool_selection_score",
                ["tool_selection_reason"],
                "ts",
            )
            metrics.append(
                self.series_to_metric(self.TOOL_SELECTION_METRIC_NAME, series),
            )

        # Create tool usage metric
        if tool_usage_data:
            df = pd.DataFrame(tool_usage_data)
            series = self.group_query_results_to_sketch_metrics(
                df,
                "tool_usage_score",
                ["tool_usage_reason"],
                "ts",
            )
            metrics.append(self.series_to_metric(self.TOOL_USAGE_METRIC_NAME, series))

        # Create comprehensive query relevance metric (includes all score data)
        if query_relevance_data:
            df = pd.DataFrame(query_relevance_data)
            series = self.group_query_results_to_sketch_metrics(
                df,
                "score_value",
                ["score_type", "reason"],
                "ts",
            )
            metrics.append(
                self.series_to_metric(self.QUERY_RELEVANCE_SCORES_METRIC_NAME, series),
            )

        # Create comprehensive response relevance metric (includes all score data)
        if response_relevance_data:
            df = pd.DataFrame(response_relevance_data)
            series = self.group_query_results_to_sketch_metrics(
                df,
                "score_value",
                ["score_type", "reason"],
                "ts",
            )
            metrics.append(
                self.series_to_metric(
                    self.RESPONSE_RELEVANCE_SCORES_METRIC_NAME,
                    series,
                ),
            )

        return metrics
```
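The `details` payload nesting this method unpacks is easiest to see on a hypothetical `QueryRelevance` metric result (scores and reason invented for illustration):

```python
# Hypothetical metric_result; one result fans out into up to three rows,
# one per score_type present in the payload.
metric_result = {
    "metric_type": "QueryRelevance",
    "details": {
        "query_relevance": {
            "llm_relevance_score": 0.82,
            "reranker_relevance_score": 0.64,
            "bert_f_score": 0.71,
            "reason": "Query matches retrieved context",
        },
    },
}
# aggregate() would append three query_relevance_data rows for this result:
#   {"ts": ..., "score_type": "llm_relevance_score",      "score_value": 0.82, "reason": ...}
#   {"ts": ..., "score_type": "reranker_relevance_score", "score_value": 0.64, "reason": ...}
#   {"ts": ..., "score_type": "bert_f_score",             "score_value": 0.71, "reason": ...}
```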
```python
class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
    """Combined aggregation for query and response relevance pass/fail counts by agent"""

    METRIC_NAME = "relevance_pass_fail_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000034")

    @staticmethod
    def display_name() -> str:
        return "Relevance Pass/Fail Count by Agent"

    @staticmethod
    def description() -> str:
        return "Metric that counts the number of query and response relevance passes and failures, segmented by agent name and metric type."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticRelevancePassFailCountAggregation.METRIC_NAME,
                description=AgenticRelevancePassFailCountAggregation.description(),
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces with nested spans.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[NumericMetric]:
        # Query traces by timestamp
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                root_spans
            FROM {dataset.dataset_table_name}
            WHERE root_spans IS NOT NULL AND length(root_spans) > 0
            ORDER BY ts DESC;
            """,
        ).df()

        # Process traces and extract spans with metrics
        processed_data = []
        for _, row in results.iterrows():
            ts = row["ts"]
            root_spans = row["root_spans"]

            # Parse root_spans if it's a string
            if isinstance(root_spans, str):
                root_spans = json.loads(root_spans)

            # Extract all spans with metrics and their agent names from the tree
            spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
                root_spans,
            )

            # Process each span with metrics
            for span, agent_name in spans_with_metrics_and_agents:
                metric_results = span.get("metric_results", [])

                for metric_result in metric_results:
                    metric_type = metric_result.get("metric_type")
                    details = metric_result.get("details", {})

                    if metric_type in ["QueryRelevance", "ResponseRelevance"]:
                        relevance_data = details.get(
                            (
                                "query_relevance"
                                if metric_type == "QueryRelevance"
                                else "response_relevance"
                            ),
                            {},
                        )
                        # Check individual scores
                        for score_type in [
                            "llm_relevance_score",
                            "reranker_relevance_score",
                            "bert_f_score",
                        ]:
                            score = relevance_data.get(score_type)
                            if score is not None:
                                result = determine_relevance_pass_fail(score)
                                processed_data.append(
                                    {
                                        "ts": ts,
                                        "agent_name": agent_name,
                                        "metric_type": metric_type,
                                        "score_type": score_type,
                                        "result": result,
                                        "count": 1,
                                    },
                                )

        if not processed_data:
            return []

        # Convert to DataFrame and aggregate
        df = pd.DataFrame(processed_data)
        aggregated = (
            df.groupby(["ts", "agent_name", "metric_type", "score_type", "result"])[
                "count"
            ]
            .sum()
            .reset_index()
        )

        series = self.group_query_results_to_numeric_metrics(
            aggregated,
            "count",
            ["agent_name", "metric_type", "score_type", "result"],
            "ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, series)
        return [metric]
```
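The per-occurrence rows are rolled up with a plain pandas `groupby`. A minimal standalone sketch with made-up rows:

```python
import pandas as pd

# Two identical pass events in the same bucket collapse into count == 2.
rows = [
    {"ts": "t0", "agent_name": "planner", "metric_type": "QueryRelevance",
     "score_type": "llm_relevance_score", "result": "pass", "count": 1},
    {"ts": "t0", "agent_name": "planner", "metric_type": "QueryRelevance",
     "score_type": "llm_relevance_score", "result": "pass", "count": 1},
]
df = pd.DataFrame(rows)
agg = (
    df.groupby(["ts", "agent_name", "metric_type", "score_type", "result"])["count"]
    .sum()
    .reset_index()
)
assert agg.loc[0, "count"] == 2
```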
```python
class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
    """Combined aggregation for tool selection and usage pass/fail counts by agent"""

    METRIC_NAME = "tool_pass_fail_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000035")

    @staticmethod
    def display_name() -> str:
        return "Tool Pass/Fail Count by Agent"

    @staticmethod
    def description() -> str:
        return "Metric that counts the number of tool selection and usage passes, failures, and no-tool cases, segmented by agent name."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticToolPassFailCountAggregation.METRIC_NAME,
                description=AgenticToolPassFailCountAggregation.description(),
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces with nested spans.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[NumericMetric]:
        # Query traces by timestamp
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                root_spans
            FROM {dataset.dataset_table_name}
            WHERE root_spans IS NOT NULL AND length(root_spans) > 0
            ORDER BY ts DESC;
            """,
        ).df()

        # Process traces and extract spans with metrics
        processed_data = []
        for _, row in results.iterrows():
            ts = row["ts"]
            root_spans = row["root_spans"]

            # Parse root_spans if it's a string
            if isinstance(root_spans, str):
                root_spans = json.loads(root_spans)

            # Extract all spans with metrics and their agent names from the tree
            spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
                root_spans,
            )

            # Process each span with metrics
            for span, agent_name in spans_with_metrics_and_agents:
                metric_results = span.get("metric_results", [])

                for metric_result in metric_results:
                    if metric_result.get("metric_type") == "ToolSelection":
                        details = metric_result.get("details", {})
                        tool_selection = details.get("tool_selection", {})

                        tool_selection_score = tool_selection.get("tool_selection")
                        tool_usage_score = tool_selection.get("tool_usage")

                        # Process tool selection
                        if tool_selection_score is not None:
                            result = determine_tool_pass_fail(tool_selection_score)
                            processed_data.append(
                                {
                                    "ts": ts,
                                    "agent_name": agent_name,
                                    "tool_metric": "tool_selection",
                                    "result": result,
                                    "count": 1,
                                },
                            )

                        # Process tool usage
                        if tool_usage_score is not None:
                            result = determine_tool_pass_fail(tool_usage_score)
                            processed_data.append(
                                {
                                    "ts": ts,
                                    "agent_name": agent_name,
                                    "tool_metric": "tool_usage",
                                    "result": result,
                                    "count": 1,
                                },
                            )

        if not processed_data:
            return []

        # Convert to DataFrame and aggregate
        df = pd.DataFrame(processed_data)
        aggregated = (
            df.groupby(["ts", "agent_name", "tool_metric", "result"])["count"]
            .sum()
            .reset_index()
        )

        series = self.group_query_results_to_numeric_metrics(
            aggregated,
            "count",
            ["agent_name", "tool_metric", "result"],
            "ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, series)
        return [metric]
```
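Note the double nesting this class reads: a `"tool_selection"` key inside `details`, holding another `"tool_selection"` score key. A hypothetical `ToolSelection` metric result makes the shape concrete (values invented; assumes `determine_tool_pass_fail` from above is in scope):

```python
# Hypothetical payload, shaped like what aggregate() reads.
metric_result = {
    "metric_type": "ToolSelection",
    "details": {
        "tool_selection": {
            "tool_selection": 1,          # -> "pass"
            "tool_selection_reason": "Correct tool chosen",
            "tool_usage": 0,              # -> "fail"
            "tool_usage_reason": "Missing required argument",
        },
    },
}
tool_selection = metric_result["details"]["tool_selection"]
assert determine_tool_pass_fail(tool_selection["tool_selection"]) == "pass"
assert determine_tool_pass_fail(tool_selection["tool_usage"]) == "fail"
```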
```python
class AgenticEventCountAggregation(NumericAggregationFunction):
    METRIC_NAME = "event_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000036")

    @staticmethod
    def display_name() -> str:
        return "Number of Events"

    @staticmethod
    def description() -> str:
        return "Metric that counts the number of events over time."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticEventCountAggregation.METRIC_NAME,
                description=AgenticEventCountAggregation.description(),
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[NumericMetric]:
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                COUNT(*) as count
            FROM {dataset.dataset_table_name}
            GROUP BY ts
            ORDER BY ts DESC;
            """,
        ).df()

        series = self.group_query_results_to_numeric_metrics(
            results,
            "count",
            [],
            "ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, series)
        return [metric]
```
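All of these aggregations bucket traces into 5-minute windows with DuckDB's `time_bucket`, and this class is the simplest case: a bucketed `COUNT(*)`. A self-contained sketch of that bucketing, assuming only the `duckdb` package and a toy table with microsecond-epoch `start_time` values:

```python
import duckdb

conn = duckdb.connect()
conn.sql("CREATE TABLE traces (start_time BIGINT, root_spans VARCHAR)")
# Two toy traces 100 seconds apart, straddling a 5-minute boundary.
conn.sql(
    "INSERT INTO traces VALUES "
    "(1700000000000000, '[]'), (1700000100000000, '[]')"
)
print(
    conn.sql(
        """
        SELECT time_bucket(INTERVAL '5 minutes',
                           to_timestamp(start_time / 1000000)) AS ts,
               COUNT(*) AS count
        FROM traces GROUP BY ts ORDER BY ts DESC
        """
    ).df()
)
# Prints two rows, one trace per 5-minute bucket.
```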
```python
class AgenticLLMCallCountAggregation(NumericAggregationFunction):
    METRIC_NAME = "llm_call_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000038")

    @staticmethod
    def display_name() -> str:
        return "Number of LLM Calls"

    @staticmethod
    def description() -> str:
        return "Metric that counts the number of LLM spans (individual LLM calls) over time."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticLLMCallCountAggregation.METRIC_NAME,
                description=AgenticLLMCallCountAggregation.description(),
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces with nested spans.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[NumericMetric]:
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                root_spans
            FROM {dataset.dataset_table_name}
            WHERE root_spans IS NOT NULL AND length(root_spans) > 0
            ORDER BY ts DESC;
            """,
        ).df()

        # Process traces and count LLM spans
        llm_call_counts = {}
        for _, row in results.iterrows():
            ts = row["ts"]
            root_spans = row["root_spans"]

            # Parse root_spans if it's a string
            if isinstance(root_spans, str):
                root_spans = json.loads(root_spans)

            # Count LLM spans in the tree
            def count_llm_spans(spans):
                count = 0
                for span in spans:
                    # Check if this span is an LLM span
                    if span.get("span_kind") == "LLM":
                        count += 1

                    # Recursively count children
                    if span.get("children"):
                        count += count_llm_spans(span["children"])
                return count

            llm_count = count_llm_spans(root_spans)

            if llm_count > 0:
                if ts not in llm_call_counts:
                    llm_call_counts[ts] = 0
                llm_call_counts[ts] += llm_count

        if not llm_call_counts:
            return []

        # Convert to DataFrame format
        data = [{"ts": ts, "count": count} for ts, count in llm_call_counts.items()]
        df = pd.DataFrame(data)

        series = self.group_query_results_to_numeric_metrics(
            df,
            "count",
            [],
            "ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, series)
        return [metric]
```
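Since `count_llm_spans` is defined inside `aggregate` and is not importable on its own, here it is reproduced standalone and run on a toy span tree (span kinds invented for illustration):

```python
# Standalone copy of the recursive counter above.
def count_llm_spans(spans):
    count = 0
    for span in spans:
        if span.get("span_kind") == "LLM":
            count += 1
        if span.get("children"):
            count += count_llm_spans(span["children"])
    return count

# An AGENT with one direct LLM child and one LLM nested under a TOOL span.
toy = [
    {"span_kind": "AGENT", "children": [
        {"span_kind": "LLM"},
        {"span_kind": "TOOL", "children": [{"span_kind": "LLM"}]},
    ]},
]
assert count_llm_spans(toy) == 2
```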
```python
class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction):
    METRIC_NAME = "tool_selection_and_usage_by_agent"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000037")

    @staticmethod
    def display_name() -> str:
        return "Tool Selection and Usage by Agent"

    @staticmethod
    def description() -> str:
        return "Metric that counts tool selection and usage correctness, segmented by agent name."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=AgenticToolSelectionAndUsageByAgentAggregation.METRIC_NAME,
                description=AgenticToolSelectionAndUsageByAgentAggregation.description(),
            ),
        ]

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The agentic trace dataset containing traces with nested spans.",
                model_problem_type=ModelProblemType.AGENTIC_TRACE,
            ),
        ],
    ) -> list[NumericMetric]:
        # Query traces by timestamp
        results = ddb_conn.sql(
            f"""
            SELECT
                time_bucket(INTERVAL '5 minutes', to_timestamp(start_time / 1000000)) as ts,
                root_spans
            FROM {dataset.dataset_table_name}
            WHERE root_spans IS NOT NULL AND length(root_spans) > 0
            ORDER BY ts DESC;
            """,
        ).df()

        # Process traces and extract spans with metrics
        processed_data = []
        for _, row in results.iterrows():
            ts = row["ts"]
            root_spans = row["root_spans"]

            # Parse root_spans if it's a string
            if isinstance(root_spans, str):
                root_spans = json.loads(root_spans)

            # Extract all spans with metrics and their agent names from the tree
            spans_with_metrics_and_agents = extract_spans_with_metrics_and_agents(
                root_spans,
            )

            # Process each span with metrics
            for span, agent_name in spans_with_metrics_and_agents:
                metric_results = span.get("metric_results", [])

                for metric_result in metric_results:
                    if metric_result.get("metric_type") == "ToolSelection":
                        details = metric_result.get("details", {})
                        tool_selection = details.get("tool_selection", {})

                        tool_selection_score = tool_selection.get("tool_selection")
                        tool_usage_score = tool_selection.get("tool_usage")

                        if tool_selection_score is not None:
                            # Categorize selection
                            if tool_selection_score == 1:
                                selection_category = "correct_selection"
                            elif tool_selection_score == 0:
                                selection_category = "incorrect_selection"
                            else:
                                selection_category = "no_selection"

                            # Categorize usage
                            if tool_usage_score == 1:
                                usage_category = "correct_usage"
                            elif tool_usage_score == 0:
                                usage_category = "incorrect_usage"
                            else:
                                usage_category = "no_usage"

                            processed_data.append(
                                {
                                    "ts": ts,
                                    "agent_name": agent_name,
                                    "selection_category": selection_category,
                                    "usage_category": usage_category,
                                    "count": 1,
                                },
                            )

        if not processed_data:
            return []

        # Convert to DataFrame and aggregate
        df = pd.DataFrame(processed_data)
        aggregated = (
            df.groupby(["ts", "agent_name", "selection_category", "usage_category"])[
                "count"
            ]
            .sum()
            .reset_index()
        )

        series = self.group_query_results_to_numeric_metrics(
            aggregated,
            "count",
            ["agent_name", "selection_category", "usage_category"],
            "ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, series)
        return [metric]
```
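The selection/usage branching above amounts to a score-to-category mapping. A condensed, hypothetical restatement for illustration only (this helper is not part of the package):

```python
# Equivalent mapping: 1 -> correct, 0 -> incorrect, anything else -> no_*.
def categorize(selection_score, usage_score):
    selection = {1: "correct_selection", 0: "incorrect_selection"}.get(
        selection_score, "no_selection"
    )
    usage = {1: "correct_usage", 0: "incorrect_usage"}.get(usage_score, "no_usage")
    return selection, usage

assert categorize(1, 0) == ("correct_selection", "incorrect_usage")
assert categorize(0, None) == ("incorrect_selection", "no_usage")
```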