arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
Files changed (33)
  1. arthur_common/aggregations/aggregator.py +73 -9
  2. arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
  3. arthur_common/aggregations/functions/categorical_count.py +15 -15
  4. arthur_common/aggregations/functions/confusion_matrix.py +24 -26
  5. arthur_common/aggregations/functions/inference_count.py +5 -9
  6. arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
  7. arthur_common/aggregations/functions/inference_null_count.py +10 -13
  8. arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
  9. arthur_common/aggregations/functions/mean_squared_error.py +12 -18
  10. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
  11. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
  12. arthur_common/aggregations/functions/numeric_stats.py +13 -15
  13. arthur_common/aggregations/functions/numeric_sum.py +12 -15
  14. arthur_common/aggregations/functions/shield_aggregations.py +457 -215
  15. arthur_common/models/common_schemas.py +214 -0
  16. arthur_common/models/connectors.py +10 -2
  17. arthur_common/models/constants.py +24 -0
  18. arthur_common/models/datasets.py +0 -9
  19. arthur_common/models/enums.py +177 -0
  20. arthur_common/models/metric_schemas.py +63 -0
  21. arthur_common/models/metrics.py +2 -9
  22. arthur_common/models/request_schemas.py +870 -0
  23. arthur_common/models/response_schemas.py +785 -0
  24. arthur_common/models/schema_definitions.py +6 -1
  25. arthur_common/models/task_job_specs.py +3 -12
  26. arthur_common/tools/duckdb_data_loader.py +34 -2
  27. arthur_common/tools/duckdb_utils.py +3 -6
  28. arthur_common/tools/schema_inferer.py +3 -6
  29. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
  30. arthur_common-2.4.13.dist-info/RECORD +49 -0
  31. arthur_common/models/shield.py +0 -642
  32. arthur_common-2.1.58.dist-info/RECORD +0 -44
  33. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
arthur_common/aggregations/aggregator.py
@@ -1,3 +1,5 @@
+ import os
+ import re
  from abc import ABC, abstractmethod
  from base64 import b64encode
  from typing import Any, Type, Union
@@ -10,6 +12,8 @@ from arthur_common.models.metrics import *
  
  
  class AggregationFunction(ABC):
+     FEATURE_FLAG_NAME: str | None = None
+
      @staticmethod
      @abstractmethod
      def id() -> UUID:
@@ -35,6 +39,31 @@ class AggregationFunction(ABC):
          """Returns the list of aggregations reported by the aggregate function."""
          raise NotImplementedError
  
+     @staticmethod
+     def get_innermost_segmentation_columns(segmentation_cols: list[str]) -> list[str]:
+         """
+         Extracts the innermost column name for nested segmentation columns or
+         returns the top-level column name for non-nested segmentation columns.
+         """
+         for i, col in enumerate(segmentation_cols):
+             # extract the innermost column for escaped column names (e.g. '"nested.col"."name"')
+             # otherwise return the name since it's a top-level column
+             if col.startswith('"') and col.endswith('"'):
+                 identifier = col[1:-1]
+                 identifier_split_in_struct_fields = re.split(r'"\."', identifier)
+
+                 # For nested columns, take just the innermost field name
+                 # Otherwise for top-level columns, take the whole name
+                 if len(identifier_split_in_struct_fields) > 1:
+                     innermost_field = identifier_split_in_struct_fields[-1]
+                     segmentation_cols[i] = innermost_field.replace('""', '"')
+                 else:
+                     segmentation_cols[i] = identifier.replace('""', '"')
+             else:
+                 segmentation_cols[i] = col
+
+         return segmentation_cols
+
      @abstractmethod
      def aggregate(
          self,
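
A minimal sketch of how the new helper behaves (column names below are invented for illustration; note that it rewrites the list in place and returns it):

    cols = ['"nested_col"."inner_field"', '"top level col"', 'plain_col']
    AggregationFunction.get_innermost_segmentation_columns(cols)
    # -> ['inner_field', 'top level col', 'plain_col']
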
@@ -50,6 +79,13 @@ class AggregationFunction(ABC):
              value = "null"
          return Dimension(name=name, value=str(value))
  
+     def is_feature_flag_enabled(self, feature_flag_name: str) -> bool:
+         if feature_flag_name is None:
+             value = os.getenv(self.FEATURE_FLAG_NAME, "false")
+         else:
+             value = os.getenv(feature_flag_name, "false")
+         return value.lower() in ("true", "1", "yes")
+
  
  class NumericAggregationFunction(AggregationFunction, ABC):
      def aggregation_type(self) -> Type[NumericMetric]:
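
For reference, the feature-flag check added above reads a plain environment variable; a small sketch of the accepted values (the flag name and instance below are hypothetical):

    import os

    os.environ["ENABLE_MY_AGGREGATION"] = "yes"   # "true", "1", or "yes" (any case) enable the flag
    some_aggregation.is_feature_flag_enabled("ENABLE_MY_AGGREGATION")  # -> True
    # unset or any other value (e.g. "false") -> False; passing None falls back to the
    # class-level FEATURE_FLAG_NAME
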
@@ -89,6 +125,11 @@ class NumericAggregationFunction(AggregationFunction, ABC):
              ),
          ]
  
+         # get innermost column name for nested segmentation columns
+         dim_columns = AggregationFunction.get_innermost_segmentation_columns(
+             dim_columns,
+         )
+
          calculated_metrics: list[NumericTimeSeries] = []
          # make sure dropna is False or rows with "null" as a dimension value will be dropped
          groups = data.groupby(dim_columns, dropna=False)
@@ -168,11 +209,33 @@ class SketchAggregationFunction(AggregationFunction, ABC):
          """
  
          calculated_metrics: list[SketchTimeSeries] = []
-         # make sure dropna is False or rows with "null" as a dimension value will be dropped
-         groups = data.groupby(dim_columns, dropna=False)
-         for _, group in groups:
+
+         # get innermost column name for nested segmentation columns
+         dim_columns = AggregationFunction.get_innermost_segmentation_columns(
+             dim_columns,
+         )
+
+         if dim_columns:
+             # make sure dropna is False or rows with "null" as a dimension value will be dropped
+             # call _group_to_series for each grouped DF
+             groups = data.groupby(dim_columns, dropna=False)
+             for _, group in groups:
+                 calculated_metrics.append(
+                     SketchAggregationFunction._group_to_series(
+                         group,
+                         timestamp_col,
+                         dim_columns,
+                         value_col,
+                     ),
+                 )
+         else:
              calculated_metrics.append(
-                 SketchAggregationFunction._group_to_series(group, timestamp_col, dim_columns, value_col),
+                 SketchAggregationFunction._group_to_series(
+                     data,
+                     timestamp_col,
+                     dim_columns,
+                     value_col,
+                 ),
              )
  
          return calculated_metrics
@@ -193,11 +256,12 @@ class SketchAggregationFunction(AggregationFunction, ABC):
              return s
  
          dimensions: list[Dimension] = []
-         # Get the first row of the group to determine the group level dimensions
-         dims_row = group.iloc[0]
-         for dim in dim_columns:
-             d = AggregationFunction.string_to_dimension(name=dim, value=dims_row[dim])
-             dimensions.append(d)
+         if dim_columns:
+             # Get the first row of the group to determine the group level dimensions
+             dims_row = group.iloc[0]
+             for dim in dim_columns:
+                 d = AggregationFunction.string_to_dimension(name=dim, value=dims_row[dim])
+                 dimensions.append(d)
  
          values: list[SketchPoint] = []
  
arthur_common/aggregations/functions/agentic_aggregations.py
@@ -1,6 +1,7 @@
  import json
  import logging
- from typing import Annotated
+ from datetime import datetime
+ from typing import Annotated, Any
  from uuid import UUID
  
  import pandas as pd
@@ -10,7 +11,7 @@ from arthur_common.aggregations.aggregator import (
      NumericAggregationFunction,
      SketchAggregationFunction,
  )
- from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.enums import ModelProblemType
  from arthur_common.models.metrics import (
      BaseReportedAggregation,
      DatasetReference,
@@ -27,7 +28,50 @@ TOOL_SCORE_NO_TOOL_VALUE = 2
  logger = logging.getLogger(__name__)
  
  
- def extract_spans_with_metrics_and_agents(root_spans):
+ def root_span_in_time_buckets(
+     ddb_conn: DuckDBPyConnection, dataset: DatasetReference
+ ) -> pd.DataFrame:
+     return ddb_conn.sql(
+         f"""
+         SELECT
+             time_bucket(INTERVAL '5 minutes', start_time) as ts,
+             root_spans
+         FROM {dataset.dataset_table_name}
+         WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+         ORDER BY ts DESC;
+         """,
+     ).df()
+
+
+ def span_parser(span_to_parse: str | dict[str, Any]) -> dict[str, Any]:
+     if isinstance(span_to_parse, str):
+         return json.loads(span_to_parse)  # type: ignore[no-any-return]
+
+     return span_to_parse
+
+
+ def extract_agent_name_from_span(span: dict[str, Any]) -> str | None:
+     try:
+         raw_data = span.get("raw_data", {})
+         if isinstance(raw_data, str):
+             raw_data = json.loads(raw_data)
+
+         # Try to get agent name from the span's name field
+         agent_name = raw_data.get("name", "unknown")
+         if agent_name != "unknown":
+             return str(agent_name)
+     except (json.JSONDecodeError, KeyError, TypeError):
+         logger.error(
+             f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
+         )
+
+     return None
+
+
+ # TODO: create TypedDict for span
+ def extract_spans_with_metrics_and_agents(
+     root_spans: list[str | dict[str, Any]],
+ ) -> list[tuple[dict[str, Any], str]]:
      """Recursively extract all spans with metrics and their associated agent names from the span tree.
  
      Returns:
@@ -35,46 +79,42 @@ def extract_spans_with_metrics_and_agents(root_spans):
      """
      spans_with_metrics_and_agents = []
  
-     def traverse_spans(spans, current_agent_name="unknown"):
-         for span_str in spans:
-             span = json.loads(span_str) if type(span_str) == str else span_str
+     # TODO: Improve function so it won't modify variable outside of its scope
+     def traverse_spans(
+         spans: list[str | dict[str, Any]],
+         current_agent: str = "unknown",
+     ) -> None:
+         for span_to_parse in spans:
+             parsed_span = span_parser(span_to_parse)
  
              # Update current agent name if this span is an AGENT
-             if span.get("span_kind") == "AGENT":
-                 try:
-                     raw_data = span.get("raw_data", {})
-                     if isinstance(raw_data, str):
-                         raw_data = json.loads(raw_data)
-
-                     # Try to get agent name from the span's name field
-                     agent_name = raw_data.get("name", "unknown")
-                     if agent_name != "unknown":
-                         current_agent_name = agent_name
-                 except (json.JSONDecodeError, KeyError, TypeError):
-                     logger.error(
-                         f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
-                     )
+             if parsed_span.get("span_kind") == "AGENT":
+                 agent_name = extract_agent_name_from_span(parsed_span)
+                 if agent_name:
+                     current_agent = agent_name
  
              # Check if this span has metrics
-             if span.get("metric_results") and len(span.get("metric_results", [])) > 0:
-                 spans_with_metrics_and_agents.append((span, current_agent_name))
+             if parsed_span.get("metric_results", []):
+                 spans_with_metrics_and_agents.append(
+                     (parsed_span, current_agent),
+                 )
  
              # Recursively traverse children with the current agent name
-             if span.get("children", []):
-                 traverse_spans(span["children"], current_agent_name)
+             if children_span := parsed_span.get("children", []):
+                 traverse_spans(children_span, current_agent)
  
      traverse_spans(root_spans)
      return spans_with_metrics_and_agents
  
  
- def determine_relevance_pass_fail(score):
+ def determine_relevance_pass_fail(score: float | None) -> str | None:
      """Determine pass/fail for relevance scores using global threshold"""
      if score is None:
          return None
      return "pass" if score >= RELEVANCE_SCORE_THRESHOLD else "fail"
  
  
- def determine_tool_pass_fail(score):
+ def determine_tool_pass_fail(score: int | None) -> str | None:
      """Determine pass/fail for tool scores using global threshold"""
      if score is None:
          return None
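
A quick sketch of the new span_parser helper with both input shapes it accepts (the payload is invented for illustration):

    span_parser('{"span_kind": "LLM"}')  # JSON string -> {"span_kind": "LLM"}
    span_parser({"span_kind": "LLM"})    # dict passes through unchanged
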
@@ -141,16 +181,7 @@ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
          ],
      ) -> list[SketchMetric]:
          # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)
  
          # Process traces and extract spans with metrics
          tool_selection_data = []
@@ -177,7 +208,7 @@ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
  
                  for metric_result in metric_results:
                      metric_type = metric_result.get("metric_type")
-                     details = json.loads(metric_result.get("details", '{}'))
+                     details = json.loads(metric_result.get("details", "{}"))
  
                      if metric_type == "ToolSelection":
                          tool_selection = details.get("tool_selection", {})
@@ -397,17 +428,7 @@ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)
  
          # Process traces and extract spans with metrics
          processed_data = []
@@ -430,7 +451,7 @@ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
  
                  for metric_result in metric_results:
                      metric_type = metric_result.get("metric_type")
-                     details = json.loads(metric_result.get("details", '{}'))
+                     details = json.loads(metric_result.get("details", "{}"))
  
                      if metric_type in ["QueryRelevance", "ResponseRelevance"]:
                          relevance_data = details.get(
@@ -522,17 +543,7 @@ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)
  
          # Process traces and extract spans with metrics
          processed_data = []
@@ -555,7 +566,7 @@ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
  
                  for metric_result in metric_results:
                      if metric_result.get("metric_type") == "ToolSelection":
-                         details = json.loads(metric_result.get("details", '{}'))
+                         details = json.loads(metric_result.get("details", "{}"))
                          tool_selection = details.get("tool_selection", {})
  
                          tool_selection_score = tool_selection.get("tool_selection")
@@ -701,16 +712,7 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)
  
          # Process traces and count LLM spans
          llm_call_counts = {}
@@ -723,10 +725,10 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
                  root_spans = json.loads(root_spans)
  
              # Count LLM spans in the tree
-             def count_llm_spans(spans):
+             def count_llm_spans(spans: list[str | dict[str, Any]]) -> int:
                  count = 0
-                 for span_str in spans:
-                     span = json.loads(span_str) if type(span_str) == str else span_str
+                 for span_to_parse in spans:
+                     span = span_parser(span_to_parse)
  
                      # Check if this span is an LLM span
                      if span.get("span_kind") == "LLM":
@@ -798,16 +800,7 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
          ],
      ) -> list[NumericMetric]:
          # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)
  
          # Process traces and extract spans with metrics
          processed_data = []
@@ -830,7 +823,7 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
  
                  for metric_result in metric_results:
                      if metric_result.get("metric_type") == "ToolSelection":
-                         details = json.loads(metric_result.get("details", '{}'))
+                         details = json.loads(metric_result.get("details", "{}"))
                          tool_selection = details.get("tool_selection", {})
  
                          tool_selection_score = tool_selection.get("tool_selection")
@@ -884,3 +877,185 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
          )
          metric = self.series_to_metric(self.METRIC_NAME, series)
          return [metric]
+
+
+ class AgenticTraceLatencyAggregation(SketchAggregationFunction):
+     METRIC_NAME = "trace_latency"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000039")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Trace Latency"
+
+     @staticmethod
+     def description() -> str:
+         return "Aggregation that reports the latency of the agentic trace in ms."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticTraceLatencyAggregation.METRIC_NAME,
+                 description=AgenticTraceLatencyAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         # Query traces by timestamp and calculate latency directly in SQL
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
+                 CAST(EXTRACT(EPOCH FROM (end_time - start_time)) * 1000 AS INTEGER) as latency_ms
+             FROM {dataset.dataset_table_name}
+             WHERE start_time IS NOT NULL
+                 AND end_time IS NOT NULL
+                 AND end_time > start_time
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         if results.empty:
+             return []
+
+         df = results
+         # Create a single time series without grouping dimensions
+         # Since we have no dimensions to group by, we create one time series for all data
+         series = [self._group_to_series(df, "ts", [], "latency_ms")]
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
+ class AgenticSpanLatencyAggregation(SketchAggregationFunction):
+     METRIC_NAME = "span_latency"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000040")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Span Latency"
+
+     @staticmethod
+     def description() -> str:
+         return "Aggregation that reports the latency of the agentic span in ms."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticSpanLatencyAggregation.METRIC_NAME,
+                 description=AgenticSpanLatencyAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         results = root_span_in_time_buckets(ddb_conn, dataset)
+
+         latency_data = []
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with their timing data
+             spans_with_timing = self._extract_spans_with_timing(root_spans)
+
+             for span_data in spans_with_timing:
+                 span, current_agent, latency_ms = span_data
+                 span_kind = span.get("span_kind", "unknown")
+
+                 if latency_ms is not None and latency_ms > 0:
+                     latency_data.append(
+                         {
+                             "ts": ts,
+                             "latency_ms": latency_ms,
+                             "span_kind": span_kind,
+                             "agent_name": current_agent,
+                         }
+                     )
+
+         if not latency_data:
+             return []
+
+         # Convert to DataFrame and create sketch metrics
+         df = pd.DataFrame(latency_data)
+         series = self.group_query_results_to_sketch_metrics(
+             df,
+             "latency_ms",
+             ["span_kind", "agent_name"],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+     def _extract_spans_with_timing(
+         self, spans: list[str | dict[str, Any]], current_agent: str = "unknown"
+     ) -> list[tuple[dict[str, Any], str, int | None]]:
+         """Recursively extract spans with calculated latency in milliseconds"""
+         spans_with_timing = []
+
+         for span_to_parse in spans:
+             span = span_parser(span_to_parse)
+
+             # Update current agent name if this span is an AGENT
+             if span.get("span_kind") == "AGENT":
+                 agent_name = extract_agent_name_from_span(span)
+                 if agent_name:
+                     current_agent = agent_name
+
+             # Calculate latency if both start_time and end_time exist
+             start_time = span.get("start_time")
+             end_time = span.get("end_time")
+             latency_ms = None
+
+             if start_time and end_time:
+                 try:
+                     # Parse ISO format timestamps and calculate latency in milliseconds
+                     # Assume same timezone for start and end time, specific TZ not important for latency calculation
+                     start_dt = datetime.fromisoformat(start_time)
+                     end_dt = datetime.fromisoformat(end_time)
+                     latency_ms = int((end_dt - start_dt).total_seconds() * 1000)
+                 except (ValueError, TypeError) as e:
+                     logger.warning(
+                         f"Error calculating latency for span {span.get('span_id')}: {e}"
+                     )
+
+             spans_with_timing.append((span, current_agent, latency_ms))
+
+             # Recursively process children
+             if children := span.get("children", []):
+                 spans_with_timing.extend(
+                     self._extract_spans_with_timing(children, current_agent)
+                 )
+
+         return spans_with_timing
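
For context, the per-span latency above is derived from ISO-8601 timestamps; a minimal sketch of the same calculation (timestamps invented for illustration):

    from datetime import datetime

    start_dt = datetime.fromisoformat("2025-01-01T00:00:00.000+00:00")
    end_dt = datetime.fromisoformat("2025-01-01T00:00:01.250+00:00")
    int((end_dt - start_dt).total_seconds() * 1000)  # -> 1250
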
arthur_common/aggregations/functions/categorical_count.py
@@ -18,7 +18,10 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+ from arthur_common.tools.duckdb_data_loader import (
+     escape_str_literal,
+     unescape_identifier,
+ )
  
  
  class CategoricalCountAggregationFunction(NumericAggregationFunction):
@@ -93,30 +96,27 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
          ] = None,
      ) -> list[NumericMetric]:
          """Executed SQL with no segmentation columns:
-         select time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts, \
+         select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
              count(*) as count, \
-             {categorical_col_escaped} as category, \
-             {categorical_col_name_escaped} as column_name \
+             {categorical_col} as category, \
+             {categorical_col_name_unescaped} as column_name \
          from {dataset.dataset_table_name} \
          where ts is not null \
          group by ts, category
          """
          segmentation_cols = [] if not segmentation_cols else segmentation_cols
-         timestamp_col_escaped = escape_identifier(timestamp_col)
-         categorical_col_escaped = escape_identifier(categorical_col)
-         categorical_col_name_escaped = escape_str_literal(categorical_col)
+         categorical_col_name_unescaped = escape_str_literal(
+             unescape_identifier(categorical_col),
+         )
  
          # build query components with segmentation columns
-         escaped_segmentation_cols = [
-             escape_identifier(col) for col in segmentation_cols
-         ]
          all_select_clause_cols = [
-             f"time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts",
+             f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
              f"count(*) as count",
-             f"{categorical_col_escaped} as category",
-             f"{categorical_col_name_escaped} as column_name",
-         ] + escaped_segmentation_cols
+             f"{categorical_col} as category",
+             f"{categorical_col_name_unescaped} as column_name",
+         ] + segmentation_cols
-         all_group_by_cols = ["ts", "category"] + escaped_segmentation_cols
+         all_group_by_cols = ["ts", "category"] + segmentation_cols
          extra_dims = ["column_name", "category"]
  
          # build query