arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arthur_common/aggregations/aggregator.py +73 -9
- arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
- arthur_common/aggregations/functions/categorical_count.py +15 -15
- arthur_common/aggregations/functions/confusion_matrix.py +24 -26
- arthur_common/aggregations/functions/inference_count.py +5 -9
- arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
- arthur_common/aggregations/functions/inference_null_count.py +10 -13
- arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
- arthur_common/aggregations/functions/mean_squared_error.py +12 -18
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
- arthur_common/aggregations/functions/numeric_stats.py +13 -15
- arthur_common/aggregations/functions/numeric_sum.py +12 -15
- arthur_common/aggregations/functions/shield_aggregations.py +457 -215
- arthur_common/models/common_schemas.py +214 -0
- arthur_common/models/connectors.py +10 -2
- arthur_common/models/constants.py +24 -0
- arthur_common/models/datasets.py +0 -9
- arthur_common/models/enums.py +177 -0
- arthur_common/models/metric_schemas.py +63 -0
- arthur_common/models/metrics.py +2 -9
- arthur_common/models/request_schemas.py +870 -0
- arthur_common/models/response_schemas.py +785 -0
- arthur_common/models/schema_definitions.py +6 -1
- arthur_common/models/task_job_specs.py +3 -12
- arthur_common/tools/duckdb_data_loader.py +34 -2
- arthur_common/tools/duckdb_utils.py +3 -6
- arthur_common/tools/schema_inferer.py +3 -6
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
- arthur_common-2.4.13.dist-info/RECORD +49 -0
- arthur_common/models/shield.py +0 -642
- arthur_common-2.1.58.dist-info/RECORD +0 -44
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
1
3
|
from abc import ABC, abstractmethod
|
|
2
4
|
from base64 import b64encode
|
|
3
5
|
from typing import Any, Type, Union
|
|
@@ -10,6 +12,8 @@ from arthur_common.models.metrics import *
|
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
class AggregationFunction(ABC):
|
|
15
|
+
FEATURE_FLAG_NAME: str | None = None
|
|
16
|
+
|
|
13
17
|
@staticmethod
|
|
14
18
|
@abstractmethod
|
|
15
19
|
def id() -> UUID:
|
|
@@ -35,6 +39,31 @@ class AggregationFunction(ABC):
|
|
|
35
39
|
"""Returns the list of aggregations reported by the aggregate function."""
|
|
36
40
|
raise NotImplementedError
|
|
37
41
|
|
|
42
|
+
@staticmethod
|
|
43
|
+
def get_innermost_segmentation_columns(segmentation_cols: list[str]) -> list[str]:
|
|
44
|
+
"""
|
|
45
|
+
Extracts the innermost column name for nested segmentation columns or
|
|
46
|
+
returns the top-level column name for non-nested segmentation columns.
|
|
47
|
+
"""
|
|
48
|
+
for i, col in enumerate(segmentation_cols):
|
|
49
|
+
# extract the innermost column for escaped column names (e.g. '"nested.col"."name"')
|
|
50
|
+
# otherwise return the name since it's a top-level column
|
|
51
|
+
if col.startswith('"') and col.endswith('"'):
|
|
52
|
+
identifier = col[1:-1]
|
|
53
|
+
identifier_split_in_struct_fields = re.split(r'"\."', identifier)
|
|
54
|
+
|
|
55
|
+
# For nested columns, take just the innermost field name
|
|
56
|
+
# Otherwise for top-level columns, take the whole name
|
|
57
|
+
if len(identifier_split_in_struct_fields) > 1:
|
|
58
|
+
innermost_field = identifier_split_in_struct_fields[-1]
|
|
59
|
+
segmentation_cols[i] = innermost_field.replace('""', '"')
|
|
60
|
+
else:
|
|
61
|
+
segmentation_cols[i] = identifier.replace('""', '"')
|
|
62
|
+
else:
|
|
63
|
+
segmentation_cols[i] = col
|
|
64
|
+
|
|
65
|
+
return segmentation_cols
|
|
66
|
+
|
|
38
67
|
@abstractmethod
|
|
39
68
|
def aggregate(
|
|
40
69
|
self,
|
|
@@ -50,6 +79,13 @@ class AggregationFunction(ABC):
|
|
|
50
79
|
value = "null"
|
|
51
80
|
return Dimension(name=name, value=str(value))
|
|
52
81
|
|
|
82
|
+
def is_feature_flag_enabled(self, feature_flag_name: str) -> bool:
|
|
83
|
+
if feature_flag_name is None:
|
|
84
|
+
value = os.getenv(self.FEATURE_FLAG_NAME, "false")
|
|
85
|
+
else:
|
|
86
|
+
value = os.getenv(feature_flag_name, "false")
|
|
87
|
+
return value.lower() in ("true", "1", "yes")
|
|
88
|
+
|
|
53
89
|
|
|
54
90
|
class NumericAggregationFunction(AggregationFunction, ABC):
|
|
55
91
|
def aggregation_type(self) -> Type[NumericMetric]:
|
|
@@ -89,6 +125,11 @@ class NumericAggregationFunction(AggregationFunction, ABC):
|
|
|
89
125
|
),
|
|
90
126
|
]
|
|
91
127
|
|
|
128
|
+
# get innermost column name for nested segmentation columns
|
|
129
|
+
dim_columns = AggregationFunction.get_innermost_segmentation_columns(
|
|
130
|
+
dim_columns,
|
|
131
|
+
)
|
|
132
|
+
|
|
92
133
|
calculated_metrics: list[NumericTimeSeries] = []
|
|
93
134
|
# make sure dropna is False or rows with "null" as a dimension value will be dropped
|
|
94
135
|
groups = data.groupby(dim_columns, dropna=False)
|
|
@@ -168,11 +209,33 @@ class SketchAggregationFunction(AggregationFunction, ABC):
|
|
|
168
209
|
"""
|
|
169
210
|
|
|
170
211
|
calculated_metrics: list[SketchTimeSeries] = []
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
212
|
+
|
|
213
|
+
# get innermost column name for nested segmentation columns
|
|
214
|
+
dim_columns = AggregationFunction.get_innermost_segmentation_columns(
|
|
215
|
+
dim_columns,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
if dim_columns:
|
|
219
|
+
# make sure dropna is False or rows with "null" as a dimension value will be dropped
|
|
220
|
+
# call _group_to_series for each grouped DF
|
|
221
|
+
groups = data.groupby(dim_columns, dropna=False)
|
|
222
|
+
for _, group in groups:
|
|
223
|
+
calculated_metrics.append(
|
|
224
|
+
SketchAggregationFunction._group_to_series(
|
|
225
|
+
group,
|
|
226
|
+
timestamp_col,
|
|
227
|
+
dim_columns,
|
|
228
|
+
value_col,
|
|
229
|
+
),
|
|
230
|
+
)
|
|
231
|
+
else:
|
|
174
232
|
calculated_metrics.append(
|
|
175
|
-
SketchAggregationFunction._group_to_series(
|
|
233
|
+
SketchAggregationFunction._group_to_series(
|
|
234
|
+
data,
|
|
235
|
+
timestamp_col,
|
|
236
|
+
dim_columns,
|
|
237
|
+
value_col,
|
|
238
|
+
),
|
|
176
239
|
)
|
|
177
240
|
|
|
178
241
|
return calculated_metrics
|
|
@@ -193,11 +256,12 @@ class SketchAggregationFunction(AggregationFunction, ABC):
|
|
|
193
256
|
return s
|
|
194
257
|
|
|
195
258
|
dimensions: list[Dimension] = []
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
259
|
+
if dim_columns:
|
|
260
|
+
# Get the first row of the group to determine the group level dimensions
|
|
261
|
+
dims_row = group.iloc[0]
|
|
262
|
+
for dim in dim_columns:
|
|
263
|
+
d = AggregationFunction.string_to_dimension(name=dim, value=dims_row[dim])
|
|
264
|
+
dimensions.append(d)
|
|
201
265
|
|
|
202
266
|
values: list[SketchPoint] = []
|
|
203
267
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
-
from
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Annotated, Any
|
|
4
5
|
from uuid import UUID
|
|
5
6
|
|
|
6
7
|
import pandas as pd
|
|
@@ -10,7 +11,7 @@ from arthur_common.aggregations.aggregator import (
|
|
|
10
11
|
NumericAggregationFunction,
|
|
11
12
|
SketchAggregationFunction,
|
|
12
13
|
)
|
|
13
|
-
from arthur_common.models.
|
|
14
|
+
from arthur_common.models.enums import ModelProblemType
|
|
14
15
|
from arthur_common.models.metrics import (
|
|
15
16
|
BaseReportedAggregation,
|
|
16
17
|
DatasetReference,
|
|
@@ -27,7 +28,50 @@ TOOL_SCORE_NO_TOOL_VALUE = 2
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def
|
|
31
|
+
def root_span_in_time_buckets(
|
|
32
|
+
ddb_conn: DuckDBPyConnection, dataset: DatasetReference
|
|
33
|
+
) -> pd.DataFrame:
|
|
34
|
+
return ddb_conn.sql(
|
|
35
|
+
f"""
|
|
36
|
+
SELECT
|
|
37
|
+
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
38
|
+
root_spans
|
|
39
|
+
FROM {dataset.dataset_table_name}
|
|
40
|
+
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
41
|
+
ORDER BY ts DESC;
|
|
42
|
+
""",
|
|
43
|
+
).df()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def span_parser(span_to_parse: str | dict[str, Any]) -> dict[str, Any]:
|
|
47
|
+
if isinstance(span_to_parse, str):
|
|
48
|
+
return json.loads(span_to_parse) # type: ignore[no-any-return]
|
|
49
|
+
|
|
50
|
+
return span_to_parse
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def extract_agent_name_from_span(span: dict[str, Any]) -> str | None:
|
|
54
|
+
try:
|
|
55
|
+
raw_data = span.get("raw_data", {})
|
|
56
|
+
if isinstance(raw_data, str):
|
|
57
|
+
raw_data = json.loads(raw_data)
|
|
58
|
+
|
|
59
|
+
# Try to get agent name from the span's name field
|
|
60
|
+
agent_name = raw_data.get("name", "unknown")
|
|
61
|
+
if agent_name != "unknown":
|
|
62
|
+
return str(agent_name)
|
|
63
|
+
except (json.JSONDecodeError, KeyError, TypeError):
|
|
64
|
+
logger.error(
|
|
65
|
+
f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
return None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# TODO: create TypedDict for span
|
|
72
|
+
def extract_spans_with_metrics_and_agents(
|
|
73
|
+
root_spans: list[str | dict[str, Any]],
|
|
74
|
+
) -> list[tuple[dict[str, Any], str]]:
|
|
31
75
|
"""Recursively extract all spans with metrics and their associated agent names from the span tree.
|
|
32
76
|
|
|
33
77
|
Returns:
|
|
@@ -35,46 +79,42 @@ def extract_spans_with_metrics_and_agents(root_spans):
|
|
|
35
79
|
"""
|
|
36
80
|
spans_with_metrics_and_agents = []
|
|
37
81
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
82
|
+
# TODO: Improve function so it won't modify variable outside of its scope
|
|
83
|
+
def traverse_spans(
|
|
84
|
+
spans: list[str | dict[str, Any]],
|
|
85
|
+
current_agent: str = "unknown",
|
|
86
|
+
) -> None:
|
|
87
|
+
for span_to_parse in spans:
|
|
88
|
+
parsed_span = span_parser(span_to_parse)
|
|
41
89
|
|
|
42
90
|
# Update current agent name if this span is an AGENT
|
|
43
|
-
if
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
raw_data = json.loads(raw_data)
|
|
48
|
-
|
|
49
|
-
# Try to get agent name from the span's name field
|
|
50
|
-
agent_name = raw_data.get("name", "unknown")
|
|
51
|
-
if agent_name != "unknown":
|
|
52
|
-
current_agent_name = agent_name
|
|
53
|
-
except (json.JSONDecodeError, KeyError, TypeError):
|
|
54
|
-
logger.error(
|
|
55
|
-
f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
|
|
56
|
-
)
|
|
91
|
+
if parsed_span.get("span_kind") == "AGENT":
|
|
92
|
+
agent_name = extract_agent_name_from_span(parsed_span)
|
|
93
|
+
if agent_name:
|
|
94
|
+
current_agent = agent_name
|
|
57
95
|
|
|
58
96
|
# Check if this span has metrics
|
|
59
|
-
if
|
|
60
|
-
spans_with_metrics_and_agents.append(
|
|
97
|
+
if parsed_span.get("metric_results", []):
|
|
98
|
+
spans_with_metrics_and_agents.append(
|
|
99
|
+
(parsed_span, current_agent),
|
|
100
|
+
)
|
|
61
101
|
|
|
62
102
|
# Recursively traverse children with the current agent name
|
|
63
|
-
if
|
|
64
|
-
traverse_spans(
|
|
103
|
+
if children_span := parsed_span.get("children", []):
|
|
104
|
+
traverse_spans(children_span, current_agent)
|
|
65
105
|
|
|
66
106
|
traverse_spans(root_spans)
|
|
67
107
|
return spans_with_metrics_and_agents
|
|
68
108
|
|
|
69
109
|
|
|
70
|
-
def determine_relevance_pass_fail(score):
|
|
110
|
+
def determine_relevance_pass_fail(score: float | None) -> str | None:
|
|
71
111
|
"""Determine pass/fail for relevance scores using global threshold"""
|
|
72
112
|
if score is None:
|
|
73
113
|
return None
|
|
74
114
|
return "pass" if score >= RELEVANCE_SCORE_THRESHOLD else "fail"
|
|
75
115
|
|
|
76
116
|
|
|
77
|
-
def determine_tool_pass_fail(score):
|
|
117
|
+
def determine_tool_pass_fail(score: int | None) -> str | None:
|
|
78
118
|
"""Determine pass/fail for tool scores using global threshold"""
|
|
79
119
|
if score is None:
|
|
80
120
|
return None
|
|
@@ -141,16 +181,7 @@ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
|
|
|
141
181
|
],
|
|
142
182
|
) -> list[SketchMetric]:
|
|
143
183
|
# Query traces by timestamp
|
|
144
|
-
results = ddb_conn
|
|
145
|
-
f"""
|
|
146
|
-
SELECT
|
|
147
|
-
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
148
|
-
root_spans
|
|
149
|
-
FROM {dataset.dataset_table_name}
|
|
150
|
-
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
151
|
-
ORDER BY ts DESC;
|
|
152
|
-
""",
|
|
153
|
-
).df()
|
|
184
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
154
185
|
|
|
155
186
|
# Process traces and extract spans with metrics
|
|
156
187
|
tool_selection_data = []
|
|
@@ -177,7 +208,7 @@ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
|
|
|
177
208
|
|
|
178
209
|
for metric_result in metric_results:
|
|
179
210
|
metric_type = metric_result.get("metric_type")
|
|
180
|
-
details = json.loads(metric_result.get("details",
|
|
211
|
+
details = json.loads(metric_result.get("details", "{}"))
|
|
181
212
|
|
|
182
213
|
if metric_type == "ToolSelection":
|
|
183
214
|
tool_selection = details.get("tool_selection", {})
|
|
@@ -397,17 +428,7 @@ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
|
|
|
397
428
|
),
|
|
398
429
|
],
|
|
399
430
|
) -> list[NumericMetric]:
|
|
400
|
-
|
|
401
|
-
results = ddb_conn.sql(
|
|
402
|
-
f"""
|
|
403
|
-
SELECT
|
|
404
|
-
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
405
|
-
root_spans
|
|
406
|
-
FROM {dataset.dataset_table_name}
|
|
407
|
-
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
408
|
-
ORDER BY ts DESC;
|
|
409
|
-
""",
|
|
410
|
-
).df()
|
|
431
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
411
432
|
|
|
412
433
|
# Process traces and extract spans with metrics
|
|
413
434
|
processed_data = []
|
|
@@ -430,7 +451,7 @@ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
|
|
|
430
451
|
|
|
431
452
|
for metric_result in metric_results:
|
|
432
453
|
metric_type = metric_result.get("metric_type")
|
|
433
|
-
details = json.loads(metric_result.get("details",
|
|
454
|
+
details = json.loads(metric_result.get("details", "{}"))
|
|
434
455
|
|
|
435
456
|
if metric_type in ["QueryRelevance", "ResponseRelevance"]:
|
|
436
457
|
relevance_data = details.get(
|
|
@@ -522,17 +543,7 @@ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
|
|
|
522
543
|
),
|
|
523
544
|
],
|
|
524
545
|
) -> list[NumericMetric]:
|
|
525
|
-
|
|
526
|
-
results = ddb_conn.sql(
|
|
527
|
-
f"""
|
|
528
|
-
SELECT
|
|
529
|
-
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
530
|
-
root_spans
|
|
531
|
-
FROM {dataset.dataset_table_name}
|
|
532
|
-
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
533
|
-
ORDER BY ts DESC;
|
|
534
|
-
""",
|
|
535
|
-
).df()
|
|
546
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
536
547
|
|
|
537
548
|
# Process traces and extract spans with metrics
|
|
538
549
|
processed_data = []
|
|
@@ -555,7 +566,7 @@ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
|
|
|
555
566
|
|
|
556
567
|
for metric_result in metric_results:
|
|
557
568
|
if metric_result.get("metric_type") == "ToolSelection":
|
|
558
|
-
details = json.loads(metric_result.get("details",
|
|
569
|
+
details = json.loads(metric_result.get("details", "{}"))
|
|
559
570
|
tool_selection = details.get("tool_selection", {})
|
|
560
571
|
|
|
561
572
|
tool_selection_score = tool_selection.get("tool_selection")
|
|
@@ -701,16 +712,7 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
|
|
|
701
712
|
),
|
|
702
713
|
],
|
|
703
714
|
) -> list[NumericMetric]:
|
|
704
|
-
results = ddb_conn
|
|
705
|
-
f"""
|
|
706
|
-
SELECT
|
|
707
|
-
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
708
|
-
root_spans
|
|
709
|
-
FROM {dataset.dataset_table_name}
|
|
710
|
-
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
711
|
-
ORDER BY ts DESC;
|
|
712
|
-
""",
|
|
713
|
-
).df()
|
|
715
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
714
716
|
|
|
715
717
|
# Process traces and count LLM spans
|
|
716
718
|
llm_call_counts = {}
|
|
@@ -723,10 +725,10 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
|
|
|
723
725
|
root_spans = json.loads(root_spans)
|
|
724
726
|
|
|
725
727
|
# Count LLM spans in the tree
|
|
726
|
-
def count_llm_spans(spans):
|
|
728
|
+
def count_llm_spans(spans: list[str | dict[str, Any]]) -> int:
|
|
727
729
|
count = 0
|
|
728
|
-
for
|
|
729
|
-
span =
|
|
730
|
+
for span_to_parse in spans:
|
|
731
|
+
span = span_parser(span_to_parse)
|
|
730
732
|
|
|
731
733
|
# Check if this span is an LLM span
|
|
732
734
|
if span.get("span_kind") == "LLM":
|
|
@@ -798,16 +800,7 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
|
|
|
798
800
|
],
|
|
799
801
|
) -> list[NumericMetric]:
|
|
800
802
|
# Query traces by timestamp
|
|
801
|
-
results = ddb_conn
|
|
802
|
-
f"""
|
|
803
|
-
SELECT
|
|
804
|
-
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
805
|
-
root_spans
|
|
806
|
-
FROM {dataset.dataset_table_name}
|
|
807
|
-
WHERE root_spans IS NOT NULL AND length(root_spans) > 0
|
|
808
|
-
ORDER BY ts DESC;
|
|
809
|
-
""",
|
|
810
|
-
).df()
|
|
803
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
811
804
|
|
|
812
805
|
# Process traces and extract spans with metrics
|
|
813
806
|
processed_data = []
|
|
@@ -830,7 +823,7 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
|
|
|
830
823
|
|
|
831
824
|
for metric_result in metric_results:
|
|
832
825
|
if metric_result.get("metric_type") == "ToolSelection":
|
|
833
|
-
details = json.loads(metric_result.get("details",
|
|
826
|
+
details = json.loads(metric_result.get("details", "{}"))
|
|
834
827
|
tool_selection = details.get("tool_selection", {})
|
|
835
828
|
|
|
836
829
|
tool_selection_score = tool_selection.get("tool_selection")
|
|
@@ -884,3 +877,185 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
|
|
|
884
877
|
)
|
|
885
878
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
886
879
|
return [metric]
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
class AgenticTraceLatencyAggregation(SketchAggregationFunction):
|
|
883
|
+
METRIC_NAME = "trace_latency"
|
|
884
|
+
|
|
885
|
+
@staticmethod
|
|
886
|
+
def id() -> UUID:
|
|
887
|
+
return UUID("00000000-0000-0000-0000-000000000039")
|
|
888
|
+
|
|
889
|
+
@staticmethod
|
|
890
|
+
def display_name() -> str:
|
|
891
|
+
return "Trace Latency"
|
|
892
|
+
|
|
893
|
+
@staticmethod
|
|
894
|
+
def description() -> str:
|
|
895
|
+
return "Aggregation that reports the latency of the agentic trace in ms."
|
|
896
|
+
|
|
897
|
+
@staticmethod
|
|
898
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
899
|
+
return [
|
|
900
|
+
BaseReportedAggregation(
|
|
901
|
+
metric_name=AgenticTraceLatencyAggregation.METRIC_NAME,
|
|
902
|
+
description=AgenticTraceLatencyAggregation.description(),
|
|
903
|
+
),
|
|
904
|
+
]
|
|
905
|
+
|
|
906
|
+
def aggregate(
|
|
907
|
+
self,
|
|
908
|
+
ddb_conn: DuckDBPyConnection,
|
|
909
|
+
dataset: Annotated[
|
|
910
|
+
DatasetReference,
|
|
911
|
+
MetricDatasetParameterAnnotation(
|
|
912
|
+
friendly_name="Dataset",
|
|
913
|
+
description="The agentic trace dataset containing traces with nested spans.",
|
|
914
|
+
model_problem_type=ModelProblemType.AGENTIC_TRACE,
|
|
915
|
+
),
|
|
916
|
+
],
|
|
917
|
+
) -> list[SketchMetric]:
|
|
918
|
+
# Query traces by timestamp and calculate latency directly in SQL
|
|
919
|
+
results = ddb_conn.sql(
|
|
920
|
+
f"""
|
|
921
|
+
SELECT
|
|
922
|
+
time_bucket(INTERVAL '5 minutes', start_time) as ts,
|
|
923
|
+
CAST(EXTRACT(EPOCH FROM (end_time - start_time)) * 1000 AS INTEGER) as latency_ms
|
|
924
|
+
FROM {dataset.dataset_table_name}
|
|
925
|
+
WHERE start_time IS NOT NULL
|
|
926
|
+
AND end_time IS NOT NULL
|
|
927
|
+
AND end_time > start_time
|
|
928
|
+
ORDER BY ts DESC;
|
|
929
|
+
""",
|
|
930
|
+
).df()
|
|
931
|
+
|
|
932
|
+
if results.empty:
|
|
933
|
+
return []
|
|
934
|
+
|
|
935
|
+
df = results
|
|
936
|
+
# Create a single time series without grouping dimensions
|
|
937
|
+
# Since we have no dimensions to group by, we create one time series for all data
|
|
938
|
+
series = [self._group_to_series(df, "ts", [], "latency_ms")]
|
|
939
|
+
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
940
|
+
return [metric]
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
class AgenticSpanLatencyAggregation(SketchAggregationFunction):
|
|
944
|
+
METRIC_NAME = "span_latency"
|
|
945
|
+
|
|
946
|
+
@staticmethod
|
|
947
|
+
def id() -> UUID:
|
|
948
|
+
return UUID("00000000-0000-0000-0000-000000000040")
|
|
949
|
+
|
|
950
|
+
@staticmethod
|
|
951
|
+
def display_name() -> str:
|
|
952
|
+
return "Span Latency"
|
|
953
|
+
|
|
954
|
+
@staticmethod
|
|
955
|
+
def description() -> str:
|
|
956
|
+
return "Aggregation that reports the latency of the agentic span in ms."
|
|
957
|
+
|
|
958
|
+
@staticmethod
|
|
959
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
960
|
+
return [
|
|
961
|
+
BaseReportedAggregation(
|
|
962
|
+
metric_name=AgenticSpanLatencyAggregation.METRIC_NAME,
|
|
963
|
+
description=AgenticSpanLatencyAggregation.description(),
|
|
964
|
+
),
|
|
965
|
+
]
|
|
966
|
+
|
|
967
|
+
def aggregate(
|
|
968
|
+
self,
|
|
969
|
+
ddb_conn: DuckDBPyConnection,
|
|
970
|
+
dataset: Annotated[
|
|
971
|
+
DatasetReference,
|
|
972
|
+
MetricDatasetParameterAnnotation(
|
|
973
|
+
friendly_name="Dataset",
|
|
974
|
+
description="The agentic trace dataset containing traces with nested spans.",
|
|
975
|
+
model_problem_type=ModelProblemType.AGENTIC_TRACE,
|
|
976
|
+
),
|
|
977
|
+
],
|
|
978
|
+
) -> list[SketchMetric]:
|
|
979
|
+
results = root_span_in_time_buckets(ddb_conn, dataset)
|
|
980
|
+
|
|
981
|
+
latency_data = []
|
|
982
|
+
for _, row in results.iterrows():
|
|
983
|
+
ts = row["ts"]
|
|
984
|
+
root_spans = row["root_spans"]
|
|
985
|
+
|
|
986
|
+
# Parse root_spans if it's a string
|
|
987
|
+
if isinstance(root_spans, str):
|
|
988
|
+
root_spans = json.loads(root_spans)
|
|
989
|
+
|
|
990
|
+
# Extract all spans with their timing data
|
|
991
|
+
spans_with_timing = self._extract_spans_with_timing(root_spans)
|
|
992
|
+
|
|
993
|
+
for span_data in spans_with_timing:
|
|
994
|
+
span, current_agent, latency_ms = span_data
|
|
995
|
+
span_kind = span.get("span_kind", "unknown")
|
|
996
|
+
|
|
997
|
+
if latency_ms is not None and latency_ms > 0:
|
|
998
|
+
latency_data.append(
|
|
999
|
+
{
|
|
1000
|
+
"ts": ts,
|
|
1001
|
+
"latency_ms": latency_ms,
|
|
1002
|
+
"span_kind": span_kind,
|
|
1003
|
+
"agent_name": current_agent,
|
|
1004
|
+
}
|
|
1005
|
+
)
|
|
1006
|
+
|
|
1007
|
+
if not latency_data:
|
|
1008
|
+
return []
|
|
1009
|
+
|
|
1010
|
+
# Convert to DataFrame and create sketch metrics
|
|
1011
|
+
df = pd.DataFrame(latency_data)
|
|
1012
|
+
series = self.group_query_results_to_sketch_metrics(
|
|
1013
|
+
df,
|
|
1014
|
+
"latency_ms",
|
|
1015
|
+
["span_kind", "agent_name"],
|
|
1016
|
+
"ts",
|
|
1017
|
+
)
|
|
1018
|
+
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
1019
|
+
return [metric]
|
|
1020
|
+
|
|
1021
|
+
def _extract_spans_with_timing(
|
|
1022
|
+
self, spans: list[str | dict[str, Any]], current_agent: str = "unknown"
|
|
1023
|
+
) -> list[tuple[dict[str, Any], str, int | None]]:
|
|
1024
|
+
"""Recursively extract spans with calculated latency in milliseconds"""
|
|
1025
|
+
spans_with_timing = []
|
|
1026
|
+
|
|
1027
|
+
for span_to_parse in spans:
|
|
1028
|
+
span = span_parser(span_to_parse)
|
|
1029
|
+
|
|
1030
|
+
# Update current agent name if this span is an AGENT
|
|
1031
|
+
if span.get("span_kind") == "AGENT":
|
|
1032
|
+
agent_name = extract_agent_name_from_span(span)
|
|
1033
|
+
if agent_name:
|
|
1034
|
+
current_agent = agent_name
|
|
1035
|
+
|
|
1036
|
+
# Calculate latency if both start_time and end_time exist
|
|
1037
|
+
start_time = span.get("start_time")
|
|
1038
|
+
end_time = span.get("end_time")
|
|
1039
|
+
latency_ms = None
|
|
1040
|
+
|
|
1041
|
+
if start_time and end_time:
|
|
1042
|
+
try:
|
|
1043
|
+
# Parse ISO format timestamps and calculate latency in milliseconds
|
|
1044
|
+
# Assume same timezone for start and end time, specific TZ not important for latency calculation
|
|
1045
|
+
start_dt = datetime.fromisoformat(start_time)
|
|
1046
|
+
end_dt = datetime.fromisoformat(end_time)
|
|
1047
|
+
latency_ms = int((end_dt - start_dt).total_seconds() * 1000)
|
|
1048
|
+
except (ValueError, TypeError) as e:
|
|
1049
|
+
logger.warning(
|
|
1050
|
+
f"Error calculating latency for span {span.get('span_id')}: {e}"
|
|
1051
|
+
)
|
|
1052
|
+
|
|
1053
|
+
spans_with_timing.append((span, current_agent, latency_ms))
|
|
1054
|
+
|
|
1055
|
+
# Recursively process children
|
|
1056
|
+
if children := span.get("children", []):
|
|
1057
|
+
spans_with_timing.extend(
|
|
1058
|
+
self._extract_spans_with_timing(children, current_agent)
|
|
1059
|
+
)
|
|
1060
|
+
|
|
1061
|
+
return spans_with_timing
|
|
@@ -18,7 +18,10 @@ from arthur_common.models.schema_definitions import (
|
|
|
18
18
|
ScalarType,
|
|
19
19
|
ScopeSchemaTag,
|
|
20
20
|
)
|
|
21
|
-
from arthur_common.tools.duckdb_data_loader import
|
|
21
|
+
from arthur_common.tools.duckdb_data_loader import (
|
|
22
|
+
escape_str_literal,
|
|
23
|
+
unescape_identifier,
|
|
24
|
+
)
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class CategoricalCountAggregationFunction(NumericAggregationFunction):
|
|
@@ -93,30 +96,27 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
|
|
|
93
96
|
] = None,
|
|
94
97
|
) -> list[NumericMetric]:
|
|
95
98
|
"""Executed SQL with no segmentation columns:
|
|
96
|
-
select time_bucket(INTERVAL '5 minutes', {
|
|
99
|
+
select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
|
|
97
100
|
count(*) as count, \
|
|
98
|
-
{
|
|
99
|
-
{
|
|
101
|
+
{categorical_col} as category, \
|
|
102
|
+
{categorical_col_name_unescaped} as column_name \
|
|
100
103
|
from {dataset.dataset_table_name} \
|
|
101
104
|
where ts is not null \
|
|
102
105
|
group by ts, category
|
|
103
106
|
"""
|
|
104
107
|
segmentation_cols = [] if not segmentation_cols else segmentation_cols
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
+
categorical_col_name_unescaped = escape_str_literal(
|
|
109
|
+
unescape_identifier(categorical_col),
|
|
110
|
+
)
|
|
108
111
|
|
|
109
112
|
# build query components with segmentation columns
|
|
110
|
-
escaped_segmentation_cols = [
|
|
111
|
-
escape_identifier(col) for col in segmentation_cols
|
|
112
|
-
]
|
|
113
113
|
all_select_clause_cols = [
|
|
114
|
-
f"time_bucket(INTERVAL '5 minutes', {
|
|
114
|
+
f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
|
|
115
115
|
f"count(*) as count",
|
|
116
|
-
f"{
|
|
117
|
-
f"{
|
|
118
|
-
] +
|
|
119
|
-
all_group_by_cols = ["ts", "category"] +
|
|
116
|
+
f"{categorical_col} as category",
|
|
117
|
+
f"{categorical_col_name_unescaped} as column_name",
|
|
118
|
+
] + segmentation_cols
|
|
119
|
+
all_group_by_cols = ["ts", "category"] + segmentation_cols
|
|
120
120
|
extra_dims = ["column_name", "category"]
|
|
121
121
|
|
|
122
122
|
# build query
|