arthur-common 2.4.1__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arthur-common might be problematic.

@@ -1,3 +1,4 @@
+ import re
  from abc import ABC, abstractmethod
  from base64 import b64encode
  from typing import Any, Type, Union
@@ -35,6 +36,31 @@ class AggregationFunction(ABC):
          """Returns the list of aggregations reported by the aggregate function."""
          raise NotImplementedError

+     @staticmethod
+     def get_innermost_segmentation_columns(segmentation_cols: list[str]) -> list[str]:
+         """
+         Extracts the innermost column name for nested segmentation columns or
+         returns the top-level column name for non-nested segmentation columns.
+         """
+         for i, col in enumerate(segmentation_cols):
+             # extract the innermost column for escaped column names (e.g. '"nested.col"."name"')
+             # otherwise return the name since it's a top-level column
+             if col.startswith('"') and col.endswith('"'):
+                 identifier = col[1:-1]
+                 identifier_split_in_struct_fields = re.split(r'"\."', identifier)
+
+                 # For nested columns, take just the innermost field name
+                 # Otherwise for top-level columns, take the whole name
+                 if len(identifier_split_in_struct_fields) > 1:
+                     innermost_field = identifier_split_in_struct_fields[-1]
+                     segmentation_cols[i] = innermost_field.replace('""', '"')
+                 else:
+                     segmentation_cols[i] = identifier.replace('""', '"')
+             else:
+                 segmentation_cols[i] = col
+
+         return segmentation_cols
+
      @abstractmethod
      def aggregate(
          self,
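For illustration only (not part of the diff): a minimal sketch of how the new helper resolves escaped, possibly nested column identifiers. The import path is assumed from the wheel's RECORD listing (arthur_common/aggregations/aggregator.py); note that the helper rewrites the list in place and also returns it.

from arthur_common.aggregations.aggregator import AggregationFunction  # assumed module path

# Hypothetical inputs: a nested struct field, an escaped top-level name, and a plain name.
cols = ['"nested.col"."name"', '"top level"', 'plain_col']
print(AggregationFunction.get_innermost_segmentation_columns(cols))
# expected: ['name', 'top level', 'plain_col']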
@@ -89,6 +115,11 @@ class NumericAggregationFunction(AggregationFunction, ABC):
              ),
          ]

+         # get innermost column name for nested segmentation columns
+         dim_columns = AggregationFunction.get_innermost_segmentation_columns(
+             dim_columns,
+         )
+
          calculated_metrics: list[NumericTimeSeries] = []
          # make sure dropna is False or rows with "null" as a dimension value will be dropped
          groups = data.groupby(dim_columns, dropna=False)
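For context on the dropna comment above (illustrative, not from the package): pandas drops null group keys by default, so rows whose dimension value is null would silently disappear without dropna=False.

import pandas as pd

df = pd.DataFrame({"dim": ["a", None, "a"], "value": [1, 2, 3]})
print(len(df.groupby("dim")))                # 1 group: the None key is dropped
print(len(df.groupby("dim", dropna=False)))  # 2 groups: the None key is kept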
@@ -168,12 +199,21 @@ class SketchAggregationFunction(AggregationFunction, ABC):
          """

          calculated_metrics: list[SketchTimeSeries] = []
+
+         # get innermost column name for nested segmentation columns
+         dim_columns = AggregationFunction.get_innermost_segmentation_columns(
+             dim_columns,
+         )
+
          # make sure dropna is False or rows with "null" as a dimension value will be dropped
          groups = data.groupby(dim_columns, dropna=False)
          for _, group in groups:
              calculated_metrics.append(
                  SketchAggregationFunction._group_to_series(
-                     group, timestamp_col, dim_columns, value_col
+                     group,
+                     timestamp_col,
+                     dim_columns,
+                     value_col,
                  ),
              )
@@ -1,5 +1,6 @@
  import json
  import logging
+ from datetime import datetime
  from typing import Annotated, Any
  from uuid import UUID

@@ -27,6 +28,46 @@ TOOL_SCORE_NO_TOOL_VALUE = 2
  logger = logging.getLogger(__name__)


+ def root_span_in_time_buckets(
+     ddb_conn: DuckDBPyConnection, dataset: DatasetReference
+ ) -> pd.DataFrame:
+     return ddb_conn.sql(
+         f"""
+         SELECT
+             time_bucket(INTERVAL '5 minutes', start_time) as ts,
+             root_spans
+         FROM {dataset.dataset_table_name}
+         WHERE root_spans IS NOT NULL AND length(root_spans) > 0
+         ORDER BY ts DESC;
+         """,
+     ).df()
+
+
+ def span_parser(span_to_parse: str | dict[str, Any]) -> dict[str, Any]:
+     if isinstance(span_to_parse, str):
+         return json.loads(span_to_parse)  # type: ignore[no-any-return]
+
+     return span_to_parse
+
+
+ def extract_agent_name_from_span(span: dict[str, Any]) -> str | None:
+     try:
+         raw_data = span.get("raw_data", {})
+         if isinstance(raw_data, str):
+             raw_data = json.loads(raw_data)
+
+         # Try to get agent name from the span's name field
+         agent_name = raw_data.get("name", "unknown")
+         if agent_name != "unknown":
+             return str(agent_name)
+     except (json.JSONDecodeError, KeyError, TypeError):
+         logger.error(
+             f"Error parsing attributes from span (span_id: {span.get('span_id')}) in trace {span.get('trace_id')}",
+         )
+
+     return None
+
+
  # TODO: create TypedDict for span
  def extract_spans_with_metrics_and_agents(
      root_spans: list[str | dict[str, Any]],
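For illustration only: how the two small helpers added above behave on a hypothetical span payload (the field names follow the shapes used elsewhere in this file).

# Hypothetical AGENT span in the shape the helpers expect.
span_json = '{"span_id": "s1", "span_kind": "AGENT", "raw_data": {"name": "planner"}}'

span = span_parser(span_json)                           # JSON string -> dict (dicts pass through)
print(extract_agent_name_from_span(span))               # -> "planner"
print(extract_agent_name_from_span({"raw_data": {}}))   # no name field -> None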
@@ -41,39 +82,26 @@ def extract_spans_with_metrics_and_agents(
      # TODO: Improve function so it won't modify variable outside of its scope
      def traverse_spans(
          spans: list[str | dict[str, Any]],
-         current_agent_name: str = "unknown",
+         current_agent: str = "unknown",
      ) -> None:
          for span_to_parse in spans:
-             if isinstance(span_to_parse, str):
-                 parsed_span = json.loads(span_to_parse)
-             else:
-                 parsed_span = span_to_parse
+             parsed_span = span_parser(span_to_parse)

              # Update current agent name if this span is an AGENT
              if parsed_span.get("span_kind") == "AGENT":
-                 try:
-                     raw_data = parsed_span.get("raw_data", {})
-                     if isinstance(raw_data, str):
-                         raw_data = json.loads(raw_data)
-
-                     # Try to get agent name from the span's name field
-                     agent_name = raw_data.get("name", "unknown")
-                     if agent_name != "unknown":
-                         current_agent_name = agent_name
-                 except (json.JSONDecodeError, KeyError, TypeError):
-                     logger.error(
-                         f"Error parsing attributes from span (span_id: {parsed_span.get('span_id')}) in trace {parsed_span.get('trace_id')}",
-                     )
+                 agent_name = extract_agent_name_from_span(parsed_span)
+                 if agent_name:
+                     current_agent = agent_name

              # Check if this span has metrics
              if parsed_span.get("metric_results", []):
                  spans_with_metrics_and_agents.append(
-                     (parsed_span, current_agent_name),
+                     (parsed_span, current_agent),
                  )

              # Recursively traverse children with the current agent name
              if children_span := parsed_span.get("children", []):
-                 traverse_spans(children_span, current_agent_name)
+                 traverse_spans(children_span, current_agent)

      traverse_spans(root_spans)
      return spans_with_metrics_and_agents
@@ -153,16 +181,7 @@ class AgenticMetricsOverTimeAggregation(SketchAggregationFunction):
          ],
      ) -> list[SketchMetric]:
          # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)

          # Process traces and extract spans with metrics
          tool_selection_data = []
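For illustration only (not part of the diff): what DuckDB's time_bucket does to a single timestamp, shown on a hypothetical literal; root_span_in_time_buckets relies on it to snap each trace's start_time to its 5-minute bucket.

import duckdb

conn = duckdb.connect()
# Snap a timestamp to the start of its 5-minute bucket.
print(
    conn.sql(
        "SELECT time_bucket(INTERVAL '5 minutes', TIMESTAMP '2024-01-01 00:07:30') AS ts"
    ).df()
)
# ts -> 2024-01-01 00:05:00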
@@ -409,17 +428,7 @@ class AgenticRelevancePassFailCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)

          # Process traces and extract spans with metrics
          processed_data = []
@@ -534,17 +543,7 @@ class AgenticToolPassFailCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)

          # Process traces and extract spans with metrics
          processed_data = []
@@ -713,16 +712,7 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
              ),
          ],
      ) -> list[NumericMetric]:
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)

          # Process traces and count LLM spans
          llm_call_counts = {}
@@ -738,10 +728,7 @@ class AgenticLLMCallCountAggregation(NumericAggregationFunction):
          def count_llm_spans(spans: list[str | dict[str, Any]]) -> int:
              count = 0
              for span_to_parse in spans:
-                 if isinstance(span_to_parse, str):
-                     span = json.loads(span_to_parse)
-                 else:
-                     span = span_to_parse
+                 span = span_parser(span_to_parse)

                  # Check if this span is an LLM span
                  if span.get("span_kind") == "LLM":
@@ -813,16 +800,7 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
          ],
      ) -> list[NumericMetric]:
          # Query traces by timestamp
-         results = ddb_conn.sql(
-             f"""
-             SELECT
-                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
-                 root_spans
-             FROM {dataset.dataset_table_name}
-             WHERE root_spans IS NOT NULL AND length(root_spans) > 0
-             ORDER BY ts DESC;
-             """,
-         ).df()
+         results = root_span_in_time_buckets(ddb_conn, dataset)

          # Process traces and extract spans with metrics
          processed_data = []
@@ -899,3 +877,185 @@ class AgenticToolSelectionAndUsageByAgentAggregation(NumericAggregationFunction)
          )
          metric = self.series_to_metric(self.METRIC_NAME, series)
          return [metric]
+
+
+ class AgenticTraceLatencyAggregation(SketchAggregationFunction):
+     METRIC_NAME = "trace_latency"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000039")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Trace Latency"
+
+     @staticmethod
+     def description() -> str:
+         return "Aggregation that reports the latency of the agentic trace in ms."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticTraceLatencyAggregation.METRIC_NAME,
+                 description=AgenticTraceLatencyAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         # Query traces by timestamp and calculate latency directly in SQL
+         results = ddb_conn.sql(
+             f"""
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', start_time) as ts,
+                 CAST(EXTRACT(EPOCH FROM (end_time - start_time)) * 1000 AS INTEGER) as latency_ms
+             FROM {dataset.dataset_table_name}
+             WHERE start_time IS NOT NULL
+                 AND end_time IS NOT NULL
+                 AND end_time > start_time
+             ORDER BY ts DESC;
+             """,
+         ).df()
+
+         if results.empty:
+             return []
+
+         df = results
+         # Create a single time series without grouping dimensions
+         # Since we have no dimensions to group by, we create one time series for all data
+         series = [self._group_to_series(df, "ts", [], "latency_ms")]
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+
+ class AgenticSpanLatencyAggregation(SketchAggregationFunction):
+     METRIC_NAME = "span_latency"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000040")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Span Latency"
+
+     @staticmethod
+     def description() -> str:
+         return "Aggregation that reports the latency of the agentic span in ms."
+
+     @staticmethod
+     def reported_aggregations() -> list[BaseReportedAggregation]:
+         return [
+             BaseReportedAggregation(
+                 metric_name=AgenticSpanLatencyAggregation.METRIC_NAME,
+                 description=AgenticSpanLatencyAggregation.description(),
+             ),
+         ]
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The agentic trace dataset containing traces with nested spans.",
+                 model_problem_type=ModelProblemType.AGENTIC_TRACE,
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         results = root_span_in_time_buckets(ddb_conn, dataset)
+
+         latency_data = []
+         for _, row in results.iterrows():
+             ts = row["ts"]
+             root_spans = row["root_spans"]
+
+             # Parse root_spans if it's a string
+             if isinstance(root_spans, str):
+                 root_spans = json.loads(root_spans)
+
+             # Extract all spans with their timing data
+             spans_with_timing = self._extract_spans_with_timing(root_spans)
+
+             for span_data in spans_with_timing:
+                 span, current_agent, latency_ms = span_data
+                 span_kind = span.get("span_kind", "unknown")
+
+                 if latency_ms is not None and latency_ms > 0:
+                     latency_data.append(
+                         {
+                             "ts": ts,
+                             "latency_ms": latency_ms,
+                             "span_kind": span_kind,
+                             "agent_name": current_agent,
+                         }
+                     )
+
+         if not latency_data:
+             return []
+
+         # Convert to DataFrame and create sketch metrics
+         df = pd.DataFrame(latency_data)
+         series = self.group_query_results_to_sketch_metrics(
+             df,
+             "latency_ms",
+             ["span_kind", "agent_name"],
+             "ts",
+         )
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]
+
+     def _extract_spans_with_timing(
+         self, spans: list[str | dict[str, Any]], current_agent: str = "unknown"
+     ) -> list[tuple[dict[str, Any], str, int | None]]:
+         """Recursively extract spans with calculated latency in milliseconds"""
+         spans_with_timing = []
+
+         for span_to_parse in spans:
+             span = span_parser(span_to_parse)
+
+             # Update current agent name if this span is an AGENT
+             if span.get("span_kind") == "AGENT":
+                 agent_name = extract_agent_name_from_span(span)
+                 if agent_name:
+                     current_agent = agent_name
+
+             # Calculate latency if both start_time and end_time exist
+             start_time = span.get("start_time")
+             end_time = span.get("end_time")
+             latency_ms = None
+
+             if start_time and end_time:
+                 try:
+                     # Parse ISO format timestamps and calculate latency in milliseconds
+                     # Assume same timezone for start and end time, specific TZ not important for latency calculation
+                     start_dt = datetime.fromisoformat(start_time)
+                     end_dt = datetime.fromisoformat(end_time)
+                     latency_ms = int((end_dt - start_dt).total_seconds() * 1000)
+                 except (ValueError, TypeError) as e:
+                     logger.warning(
+                         f"Error calculating latency for span {span.get('span_id')}: {e}"
+                     )
+
+             spans_with_timing.append((span, current_agent, latency_ms))
+
+             # Recursively process children
+             if children := span.get("children", []):
+                 spans_with_timing.extend(
+                     self._extract_spans_with_timing(children, current_agent)
+                 )
+
+         return spans_with_timing
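For illustration only: the millisecond-latency arithmetic used by _extract_spans_with_timing, on hypothetical ISO-8601 timestamps (the code assumes both ends share a timezone).

from datetime import datetime

start_dt = datetime.fromisoformat("2024-01-01T00:00:00.000+00:00")
end_dt = datetime.fromisoformat("2024-01-01T00:00:01.250+00:00")
latency_ms = int((end_dt - start_dt).total_seconds() * 1000)
print(latency_ms)  # 1250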
@@ -18,8 +18,10 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
-
- from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal
+ from arthur_common.tools.duckdb_data_loader import (
+     escape_str_literal,
+     unescape_identifier,
+ )


  class CategoricalCountAggregationFunction(NumericAggregationFunction):
@@ -103,7 +105,9 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
          group by ts, category
          """
          segmentation_cols = [] if not segmentation_cols else segmentation_cols
-         categorical_col_name_unescaped = escape_str_literal(unescape_identifier(categorical_col))
+         categorical_col_name_unescaped = escape_str_literal(
+             unescape_identifier(categorical_col),
+         )

          # build query components with segmentation columns
          all_select_clause_cols = [
@@ -125,11 +129,10 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):

          results = ddb_conn.sql(count_query).df()

-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
          series = self.group_query_results_to_numeric_metrics(
              results,
              "count",
-             unescaped_segmentation_cols + extra_dims,
+             segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -20,8 +20,10 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
-
- from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal
+ from arthur_common.tools.duckdb_data_loader import (
+     escape_str_literal,
+     unescape_identifier,
+ )


  class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
@@ -97,7 +99,9 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
          ORDER BY ts
          """
          segmentation_cols = [] if not segmentation_cols else segmentation_cols
-         unescaped_prediction_col_name = escape_str_literal(unescape_identifier(prediction_col))
+         unescaped_prediction_col_name = escape_str_literal(
+             unescape_identifier(prediction_col),
+         )

          # build query components with segmentation columns
          first_subquery_select_cols = [
@@ -131,29 +135,28 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):

          results = ddb_conn.sql(confusion_matrix_query).df()

-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
          tp = self.group_query_results_to_numeric_metrics(
              results,
              "true_positive_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          fp = self.group_query_results_to_numeric_metrics(
              results,
              "false_positive_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          fn = self.group_query_results_to_numeric_metrics(
              results,
              "false_negative_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          tn = self.group_query_results_to_numeric_metrics(
              results,
              "true_negative_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          tp_metric = self.series_to_metric(self.TRUE_POSITIVE_METRIC_NAME, tp)
@@ -18,7 +18,6 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class InferenceCountAggregationFunction(NumericAggregationFunction):
@@ -102,11 +101,11 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
          """

          results = ddb_conn.sql(count_query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
+
          series = self.group_query_results_to_numeric_metrics(
              results,
              "count",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
          metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -20,7 +20,6 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
@@ -132,11 +131,10 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction

          result = ddb_conn.sql(query).df()

-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
          series = self.group_query_results_to_numeric_metrics(
              result,
              "count",
-             unescaped_segmentation_cols + extra_dims,
+             segmentation_cols + extra_dims,
              "ts",
          )
          metric = self.series_to_metric(self._metric_name(), series)
@@ -278,11 +276,10 @@ class BinaryClassifierCountThresholdClassAggregationFunction(

          result = ddb_conn.sql(query).df()

-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
          series = self.group_query_results_to_numeric_metrics(
              result,
              "count",
-             unescaped_segmentation_cols + extra_dims,
+             segmentation_cols + extra_dims,
              "ts",
          )
          metric = self.series_to_metric(self._metric_name(), series)
@@ -114,16 +114,17 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):

          results = ddb_conn.sql(count_query).df()

-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
          series = self.group_query_results_to_numeric_metrics(
              results,
              "count",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
          # preserve dimension that identifies the name of the nullable column used for the aggregation
          for point in series:
-             point.dimensions.append(Dimension(name="column_name", value=unescape_identifier(nullable_col)))
+             point.dimensions.append(
+                 Dimension(name="column_name", value=unescape_identifier(nullable_col)),
+             )

          metric = self.series_to_metric(self.METRIC_NAME, series)
          return [metric]
@@ -19,7 +19,6 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
@@ -139,17 +138,17 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
          """

          results = ddb_conn.sql(mae_query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
+
          count_series = self.group_query_results_to_numeric_metrics(
              results,
              "count",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
          absolute_error_series = self.group_query_results_to_numeric_metrics(
              results,
              "ae",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
@@ -19,7 +19,6 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
@@ -139,17 +138,17 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
          """

          results = ddb_conn.sql(mse_query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
+
          count_series = self.group_query_results_to_numeric_metrics(
              results,
              "count",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
          squared_error_series = self.group_query_results_to_numeric_metrics(
              results,
              "squared_error",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
@@ -20,8 +20,7 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
-
- from arthur_common.tools.duckdb_data_loader import escape_str_literal, unescape_identifier
+ from arthur_common.tools.duckdb_data_loader import escape_str_literal


  class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
@@ -244,30 +243,29 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
          """

          results = ddb_conn.sql(confusion_matrix_query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

          tp = self.group_query_results_to_numeric_metrics(
              results,
              "true_positive_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          fp = self.group_query_results_to_numeric_metrics(
              results,
              "false_positive_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          fn = self.group_query_results_to_numeric_metrics(
              results,
              "false_negative_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          tn = self.group_query_results_to_numeric_metrics(
              results,
              "true_negative_count",
-             dim_columns=unescaped_segmentation_cols + extra_dims,
+             dim_columns=segmentation_cols + extra_dims,
              timestamp_col="ts",
          )
          tp_metric = self.series_to_metric(
@@ -18,8 +18,10 @@ from arthur_common.models.schema_definitions import (
      ScalarType,
      ScopeSchemaTag,
  )
-
- from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal
+ from arthur_common.tools.duckdb_data_loader import (
+     escape_str_literal,
+     unescape_identifier,
+ )


  class NumericSketchAggregationFunction(SketchAggregationFunction):
@@ -121,12 +123,11 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
          """

          results = ddb_conn.sql(data_query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

          series = self.group_query_results_to_sketch_metrics(
              results,
              unescape_identifier(numeric_col),
-             unescaped_segmentation_cols + extra_dims,
+             segmentation_cols + extra_dims,
              "ts",
          )
@@ -118,17 +118,18 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
          """

          results = ddb_conn.sql(query).df()
-         unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

          series = self.group_query_results_to_numeric_metrics(
              results,
              "sum",
-             unescaped_segmentation_cols,
+             segmentation_cols,
              "ts",
          )
          # preserve dimension that identifies the name of the numeric column used for the aggregation
          for point in series:
-             point.dimensions.append(Dimension(name="column_name", value=unescape_identifier(numeric_col)))
+             point.dimensions.append(
+                 Dimension(name="column_name", value=unescape_identifier(numeric_col)),
+             )

          metric = self.series_to_metric(self.METRIC_NAME, series)
          return [metric]
@@ -1,4 +1,4 @@
- from pydantic import BaseModel, ConfigDict, Field, computed_field
+ from pydantic import BaseModel, Field


  class ConnectorPaginationOptions(BaseModel):
@@ -624,6 +624,8 @@ class SpanWithMetricsResponse(BaseModel):
      start_time: datetime
      end_time: datetime
      task_id: Optional[str] = None
+     session_id: Optional[str] = None
+     status_code: str = Field(description="Status code for the span (Unset, Error, Ok)")
      created_at: datetime
      updated_at: datetime
      raw_data: dict[str, Any]
@@ -650,6 +652,8 @@ class NestedSpanWithMetricsResponse(BaseModel):
      start_time: datetime
      end_time: datetime
      task_id: Optional[str] = None
+     session_id: Optional[str] = None
+     status_code: str = Field(description="Status code for the span (Unset, Error, Ok)")
      created_at: datetime
      updated_at: datetime
      raw_data: dict[str, Any]
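For illustration only: a trimmed-down Pydantic sketch (a hypothetical stand-in, not the actual SpanWithMetricsResponse model) showing how the two new fields behave: session_id is optional and defaults to None, while status_code is required.

from typing import Optional
from pydantic import BaseModel, Field

class SpanStatusSketch(BaseModel):  # hypothetical stand-in for the real response models
    session_id: Optional[str] = None
    status_code: str = Field(description="Status code for the span (Unset, Error, Ok)")

print(SpanStatusSketch(status_code="Ok").model_dump())
# {'session_id': None, 'status_code': 'Ok'}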
@@ -2,7 +2,6 @@ import duckdb

  from arthur_common.config.config import Config
  from arthur_common.models.schema_definitions import SEGMENTATION_ALLOWED_DTYPES, DType
- from arthur_common.tools.duckdb_data_loader import escape_identifier


  def is_column_possible_segmentation(
@@ -40,12 +40,11 @@ class SchemaInferer:
          self.conn.sql(
              f"CREATE OR REPLACE TEMP TABLE {escaped_col} AS SELECT UNNEST({escaped_col}) as {escaped_col} FROM {table}",
          )
-         return self._infer_schema(escaped_col, is_nested_col=True)
+         return self._infer_schema(escaped_col)

      def _infer_schema(
          self,
          table: str = "root",
-         is_nested_col: bool = False,
      ) -> DatasetObjectType:
          """is_nested_col indicates whether the function is being called on an unnested/flattened table that represents
          a struct column or list column in the root table."""
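For illustration only: what the UNNEST statement above does, on a toy in-memory DuckDB table with hypothetical names (the real code builds the statement from escaped column identifiers).

import duckdb

conn = duckdb.connect()
conn.sql("CREATE TABLE root AS SELECT ['a', 'b', 'c'] AS tags")
# Flatten the list column into its own temp table, as SchemaInferer does before inferring its schema.
conn.sql('CREATE OR REPLACE TEMP TABLE "tags" AS SELECT UNNEST("tags") AS "tags" FROM root')
print(conn.sql('SELECT * FROM "tags"').df())  # three rows: a, b, c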
@@ -105,9 +104,7 @@ class SchemaInferer:
              raise NotImplementedError(f"Type {col_type} not mappable.")

          # tag column as a possible segmentation column if it meets criteria
-         # we only support top-level column aggregations right now (ie you can't aggregate on a nested column)
-         # so we don't want to tag nested columns as possible segmentation columns
-         if not is_nested_col and is_column_possible_segmentation(
+         if is_column_possible_segmentation(
              self.conn,
              table,
              escape_identifier(col_name),
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: arthur-common
- Version: 2.4.1
+ Version: 2.4.3
  Summary: Utility code common to Arthur platform components.
  License: MIT
  Author: Arthur
@@ -1,20 +1,20 @@
  arthur_common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/aggregations/__init__.py,sha256=vISWyciQAtksa71OKeHNP-QyFGd1NzBKq_LBsG0QSG8,67
- arthur_common/aggregations/aggregator.py,sha256=AhyNqBDEbKtS3ZrnSIT9iZ1SK_TAuiUNg9s9loDvek0,8007
+ arthur_common/aggregations/aggregator.py,sha256=1kMyP89biBSR6omD1R9fsAwfjbau0gozcirQOwYVYjg,9680
  arthur_common/aggregations/functions/README.md,sha256=MkZoTAJ94My96R5Z8GAxud7S6vyR0vgVi9gqdt9a4XY,5460
  arthur_common/aggregations/functions/__init__.py,sha256=HqC3UNRURX7ZQHgamTrQvfA8u_FiZGZ4I4eQW7Ooe5o,1299
- arthur_common/aggregations/functions/agentic_aggregations.py,sha256=09th4RPRf-ogtVWpbcqqmITN2UFtfqXhQ7Rr6IBqQHo,33995
- arthur_common/aggregations/functions/categorical_count.py,sha256=_TD0s0JAtqC5RmT6ZNWLEBZm-dU4akm-Aor7EDVazzA,5176
- arthur_common/aggregations/functions/confusion_matrix.py,sha256=n33kyyZuxo8k6jUYnBUsc1fLotTmcw0H8rsX_x_oeJ0,21733
- arthur_common/aggregations/functions/inference_count.py,sha256=D49SpwFywipMqeC93gc3_ZGwBoGL89yKuA9_55dBWBw,3984
- arthur_common/aggregations/functions/inference_count_by_class.py,sha256=mYL6xMTb-_VO6mKGWHOtFAvWzTt-C_4vKf8KgioJGDg,11191
- arthur_common/aggregations/functions/inference_null_count.py,sha256=UlE5EZa3k2nKIv6Yzrnjq1MsZEzrau7Olumny8hsHtg,4672
- arthur_common/aggregations/functions/mean_absolute_error.py,sha256=YzrNHox_4HEGWn33E12d6eiQ8A9Rwct7AW3hOWrTW7I,6544
- arthur_common/aggregations/functions/mean_squared_error.py,sha256=b_is7FKRSninYs1ilAXeLPJFfmyCaiKvCC9Ev_OERio,6565
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256=e1KEyxIZocWMkDbnW0zfJHd5PUi_kyzwNUVFOD0l5Nk,12359
+ arthur_common/aggregations/functions/agentic_aggregations.py,sha256=82OJ174uGcDqf7OLXY7dwnnv1g4kubkjazpc7Yj0xw0,39531
+ arthur_common/aggregations/functions/categorical_count.py,sha256=jxV2w2Itmoh02VuazWN5z94PmQ-bRZjZpSoODGeBulQ,5099
+ arthur_common/aggregations/functions/confusion_matrix.py,sha256=2fIqo50TcbUlGgPXxGtfFr6ehyZn69R8sphigGuMDgo,21626
+ arthur_common/aggregations/functions/inference_count.py,sha256=Pxe5WT_Zgnn_wSDcm48l-flh-M5Zr72SbR4tQyNBk-o,3802
+ arthur_common/aggregations/functions/inference_count_by_class.py,sha256=fmzrbRxiWgmutJYrBs7JY1iIRF7F6kozBzcsMypatlE,10896
+ arthur_common/aggregations/functions/inference_null_count.py,sha256=X8mfeKb46VxUQFrjukSlVpM9AZCNvStsBHU3LsUbcEM,4591
+ arthur_common/aggregations/functions/mean_absolute_error.py,sha256=P9H0rRvpObnWQiu4p7-yW6y6R7_-Ju23y2YlZQgxvHA,6352
+ arthur_common/aggregations/functions/mean_squared_error.py,sha256=hZrHzfCscNnGKp_SqOeHEebzjMych1EXtnI1K70EYZE,6373
+ arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256=eA4y0xJikErkRww5OudUAMG9Y6cYztkO4w561nWVh5w,12195
  arthur_common/aggregations/functions/multiclass_inference_count_by_class.py,sha256=yiMpdz4VuX1ELprXYupFu4B9aDLIhgfEi3ma8jZsT_M,4261
- arthur_common/aggregations/functions/numeric_stats.py,sha256=mMpVH1PvElGaz5mIQWy8sIkKPZ5kyeNOAM2iM2IlBvY,4760
- arthur_common/aggregations/functions/numeric_sum.py,sha256=Vq-dQonKTdLt8pYFwT5tCXyyL_FvVQxb6b3nFNRSqus,4861
+ arthur_common/aggregations/functions/numeric_stats.py,sha256=28y0Zdhk3kLFiJYVWq_uev1C1yBZDn1aTUEdvLkqo3k,4660
+ arthur_common/aggregations/functions/numeric_sum.py,sha256=TAeVVd5NqF7X9_hnMzbNVOVxdExcra4EZDkubtWHyAs,4780
  arthur_common/aggregations/functions/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/aggregations/functions/shield_aggregations.py,sha256=BzPkpbhZRy16iFOobuusGKHfov5DxnXS2v_WThpw2fk,35659
  arthur_common/aggregations/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -23,7 +23,7 @@ arthur_common/config/config.py,sha256=fcpjOYjPKu4Duk63CuTHrOWKQKAlAhVUR60kF_2_Xo
  arthur_common/config/settings.yaml,sha256=0CrygUwJzC5mGcO5Xnvv2ttp-P7LIsx682jllYA96NQ,161
  arthur_common/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/models/common_schemas.py,sha256=31Br7DbIgrwHwzgiyMXrgPYrANhqSqle7kmismcy4TY,6770
- arthur_common/models/connectors.py,sha256=RwjY74cs0KTKw7Opywehg46SZ4vwN3xm6ujHRsRIQ8Y,2292
+ arthur_common/models/connectors.py,sha256=gRdX4lNz0ObU64FqMmoffHVBwEgO3JfOf3wjn3tKv0Q,2264
  arthur_common/models/constants.py,sha256=munkU0LrLsDs9BtAfozzw30FCguIowmAUKg_9vqwX24,1049
  arthur_common/models/datasets.py,sha256=7p1tyJEPwXjBs2ZRoai8hTzNl6MK9jU1DluzASApE_4,254
  arthur_common/models/enums.py,sha256=J2beHEMjLfOGgc-vh1aDpE7KmBGKzLoOUGYLtuciJro,3870
@@ -31,7 +31,7 @@ arthur_common/models/metric_schemas.py,sha256=Xf-1RTzg7iYtnBMLkUUUuMPzAujzzNvQx_
  arthur_common/models/metrics.py,sha256=mCa0aN-nuNHYcqGfkyKFeriI0krz0-ScgmXWXHlKoEI,11109
  arthur_common/models/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/models/request_schemas.py,sha256=LEybzPoOzaaTyua48mr9sAVzrBK6dIeYhR158kMp0o8,29749
- arthur_common/models/response_schemas.py,sha256=qc6DDfY4GxtXtUiBllsQglvPnZzhe2Vw8D-S76B_r_0,25393
+ arthur_common/models/response_schemas.py,sha256=kY3NZceFaBRkxGDF5-W1CYDjtwFdf3xyLxdzrwHlmJI,25643
  arthur_common/models/schema_definitions.py,sha256=dcUSLjBmvyloStcBFmT_rHdXbKdvA8Yxi_avYUbps3E,16876
  arthur_common/models/task_job_specs.py,sha256=p7jsSb97ylHYNkwoHXNOJvx2zcnh2kxLeh3m0pddo4M,3442
  arthur_common/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -39,11 +39,11 @@ arthur_common/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSu
  arthur_common/tools/aggregation_analyzer.py,sha256=UfMtvFWXV2Dqly8S6nneGgomuvEGN-1tBz81tfkMcAE,11206
  arthur_common/tools/aggregation_loader.py,sha256=3CF46bNi-GdJBNOXkjYfCQ1Aung8lf65L532sdWmR_s,2351
  arthur_common/tools/duckdb_data_loader.py,sha256=A80wpATSc4VJLghoHwxpBEuUsxY93OZS0Qo4cFX7cRw,12462
- arthur_common/tools/duckdb_utils.py,sha256=8l8bUmjqJyj84DXyEOzO_DsD8VsO25DWYK_IYF--Zek,1211
+ arthur_common/tools/duckdb_utils.py,sha256=PZ3AKoBUaU6papqNiNQ4Sm2ugg5bGyXfaC_1I-E2q3s,1142
  arthur_common/tools/functions.py,sha256=FWL4eWO5-vLp86WudT-MGUKvf2B8f02IdoXQFKd6d8k,1093
  arthur_common/tools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arthur_common/tools/schema_inferer.py,sha256=9teI67umlGn0izp6pZ5UBuWxJthaWEmw3wRj2KPIbf4,5207
+ arthur_common/tools/schema_inferer.py,sha256=8ehIqAxuGlgM08RtwPB43a7TfenZyEIf1R0p1RYrkng,4920
  arthur_common/tools/time_utils.py,sha256=4gfiu9NXfvPZltiVNLSIQGylX6h2W0viNi9Kv4bKyfw,1410
- arthur_common-2.4.1.dist-info/METADATA,sha256=LA7R2B8LGE78eJrwpQTHP0nFRPzLTnntZvHQUNtcVm4,2146
- arthur_common-2.4.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- arthur_common-2.4.1.dist-info/RECORD,,
+ arthur_common-2.4.3.dist-info/METADATA,sha256=IwWpahpZ5U0mEeZ1YdeR-nBn0tOeYM938q8wIn8Vb-0,2146
+ arthur_common-2.4.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ arthur_common-2.4.3.dist-info/RECORD,,