arthur-common 2.1.68__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -18,7 +18,8 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal


  class CategoricalCountAggregationFunction(NumericAggregationFunction):
@@ -93,30 +94,25 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- select time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts, \
+ select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
  count(*) as count, \
- {categorical_col_escaped} as category, \
- {categorical_col_name_escaped} as column_name \
+ {categorical_col} as category, \
+ {categorical_col_name_unescaped} as column_name \
  from {dataset.dataset_table_name} \
  where ts is not null \
  group by ts, category
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- timestamp_col_escaped = escape_identifier(timestamp_col)
- categorical_col_escaped = escape_identifier(categorical_col)
- categorical_col_name_escaped = escape_str_literal(categorical_col)
+ categorical_col_name_unescaped = escape_str_literal(unescape_identifier(categorical_col))

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
  f"count(*) as count",
- f"{categorical_col_escaped} as category",
- f"{categorical_col_name_escaped} as column_name",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts", "category"] + escaped_segmentation_cols
+ f"{categorical_col} as category",
+ f"{categorical_col_name_unescaped} as column_name",
+ ] + segmentation_cols
+ all_group_by_cols = ["ts", "category"] + segmentation_cols
  extra_dims = ["column_name", "category"]

  # build query
@@ -129,10 +125,11 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):

  results = ddb_conn.sql(count_query).df()

+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  series = self.group_query_results_to_numeric_metrics(
  results,
  "count",
- segmentation_cols + extra_dims,
+ unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -20,7 +20,8 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal


  class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
@@ -78,11 +79,11 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
  Without segmentation, this is the query:
  WITH normalized_data AS (
  SELECT
- {escaped_timestamp_col} AS timestamp,
- {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
- {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+ {timestamp_col} AS timestamp,
+ {prediction_normalization_case.replace('value', prediction_col)} AS prediction,
+ {gt_normalization_case.replace('value', gt_values_col)} AS actual_value
  FROM {dataset.dataset_table_name}
- WHERE {escaped_timestamp_col} IS NOT NULL
+ WHERE {timestamp_col} IS NOT NULL
  )
  SELECT
  time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
@@ -90,34 +91,29 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
  SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
  SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
  SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
- {escaped_prediction_col_name} as prediction_column_name
+ {unescaped_prediction_col_name} as prediction_column_name
  FROM normalized_data
  GROUP BY ts
  ORDER BY ts
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_prediction_col = escape_identifier(prediction_col)
- escaped_prediction_col_name = escape_str_literal(prediction_col)
- escaped_gt_values_col = escape_identifier(gt_values_col)
+ unescaped_prediction_col_name = escape_str_literal(unescape_identifier(prediction_col))
+
  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  first_subquery_select_cols = [
- f"{escaped_timestamp_col} AS timestamp",
- f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
- f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
- ] + escaped_segmentation_cols
+ f"{timestamp_col} AS timestamp",
+ f"{prediction_normalization_case.replace('value', prediction_col)} AS prediction",
+ f"{gt_normalization_case.replace('value', gt_values_col)} AS actual_value",
+ ] + segmentation_cols
  second_subquery_select_cols = [
  "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
  "SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
  "SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count",
  "SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
  "SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
- f"{escaped_prediction_col_name} as prediction_column_name",
- ] + escaped_segmentation_cols
- second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+ f"{unescaped_prediction_col_name} as prediction_column_name",
+ ] + segmentation_cols
+ second_subquery_group_by_cols = ["ts"] + segmentation_cols
  extra_dims = ["prediction_column_name"]

  # build query
@@ -125,7 +121,7 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
  WITH normalized_data AS (
  SELECT {", ".join(first_subquery_select_cols)}
  FROM {dataset.dataset_table_name}
- WHERE {escaped_timestamp_col} IS NOT NULL
+ WHERE {timestamp_col} IS NOT NULL
  )
  SELECT {", ".join(second_subquery_select_cols)}
  FROM normalized_data
@@ -135,28 +131,29 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):

  results = ddb_conn.sql(confusion_matrix_query).df()

+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  tp = self.group_query_results_to_numeric_metrics(
  results,
  "true_positive_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  fp = self.group_query_results_to_numeric_metrics(
  results,
  "false_positive_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  fn = self.group_query_results_to_numeric_metrics(
  results,
  "false_negative_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  tn = self.group_query_results_to_numeric_metrics(
  results,
  "true_negative_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  tp_metric = self.series_to_metric(self.TRUE_POSITIVE_METRIC_NAME, tp)
@@ -243,9 +240,8 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
  ] = None,
  ) -> list[NumericMetric]:
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_prediction_col = escape_identifier(prediction_col)
  # Get the type of prediction column
- type_query = f"SELECT typeof({escaped_prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+ type_query = f"SELECT typeof({prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
  res = ddb_conn.sql(type_query).fetchone()
  # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
  if not res:
@@ -476,7 +472,6 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
  ),
  ] = None,
  ) -> list[NumericMetric]:
- escaped_gt_values_col = escape_identifier(gt_values_col)
  prediction_normalization_case = f"""
  CASE
  WHEN value >= {threshold} THEN 1
@@ -485,7 +480,7 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
  END
  """

- type_query = f"SELECT typeof({escaped_gt_values_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+ type_query = f"SELECT typeof({gt_values_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
  res = ddb_conn.sql(type_query).fetchone()
  # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
  if not res:
@@ -18,7 +18,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class InferenceCountAggregationFunction(NumericAggregationFunction):
@@ -80,23 +80,19 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+ select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
  count(*) as count \
  from {dataset.dataset_table_name} \
  group by ts \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
  f"count(*) as count",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ all_group_by_cols = ["ts"] + segmentation_cols

  # build query
  count_query = f"""
@@ -106,10 +102,11 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
  """

  results = ddb_conn.sql(count_query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  series = self.group_query_results_to_numeric_metrics(
  results,
  "count",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )
  metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -20,7 +20,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
@@ -100,31 +100,26 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
  SELECT
- time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
- {escaped_pred_col} as prediction,
+ time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts,
+ {prediction_col} as prediction,
  COUNT(*) as count
  FROM {dataset.dataset_table_name}
  GROUP BY
  ts,
  -- group by raw column name instead of alias in select
  -- in case table has a column called 'prediction'
- {escaped_pred_col}
+ {prediction_col}
  ORDER BY ts
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_pred_col = escape_identifier(prediction_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
- f"{escaped_pred_col} as prediction",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+ f"{prediction_col} as prediction",
  f"COUNT(*) as count",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts", f"{escaped_pred_col}"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ all_group_by_cols = ["ts", f"{prediction_col}"] + segmentation_cols
  extra_dims = ["prediction"]

  # build query
@@ -137,10 +132,11 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction

  result = ddb_conn.sql(query).df()

+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  series = self.group_query_results_to_numeric_metrics(
  result,
  "count",
- segmentation_cols + extra_dims,
+ unescaped_segmentation_cols + extra_dims,
  "ts",
  )
  metric = self.series_to_metric(self._metric_name(), series)
@@ -248,34 +244,29 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
  SELECT
- time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
- CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
+ time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts,
+ CASE WHEN {prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
  COUNT(*) as count
  FROM {dataset.dataset_table_name}
  GROUP BY
  ts,
  -- group by raw column name instead of alias in select
  -- in case table has a column called 'prediction'
- {escaped_prediction_col}
+ {prediction_col}
  ORDER BY ts
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_prediction_col = escape_identifier(prediction_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
- f"CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+ f"CASE WHEN {prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
  f"COUNT(*) as count",
- ] + escaped_segmentation_cols
+ ] + segmentation_cols
  all_group_by_cols = [
  "ts",
- f"{escaped_prediction_col}",
- ] + escaped_segmentation_cols
+ f"{prediction_col}",
+ ] + segmentation_cols
  extra_dims = ["prediction"]

  query = f"""
@@ -287,10 +278,11 @@ class BinaryClassifierCountThresholdClassAggregationFunction(

  result = ddb_conn.sql(query).df()

+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  series = self.group_query_results_to_numeric_metrics(
  result,
  "count",
- segmentation_cols + extra_dims,
+ unescaped_segmentation_cols + extra_dims,
  "ts",
  )
  metric = self.series_to_metric(self._metric_name(), series)
@@ -19,7 +19,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class InferenceNullCountAggregationFunction(NumericAggregationFunction):
@@ -90,44 +90,40 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+ select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
  count(*) as count \
- from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
+ from {dataset.dataset_table_name} where {nullable_col} is null \
  group by ts \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_nullable_col = escape_identifier(nullable_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
  f"count(*) as count",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ all_group_by_cols = ["ts"] + segmentation_cols

  # build query
  count_query = f"""
  select {", ".join(all_select_clause_cols)}
  from {dataset.dataset_table_name}
- where {escaped_nullable_col} is null
+ where {nullable_col} is null
  group by {", ".join(all_group_by_cols)}
  """

  results = ddb_conn.sql(count_query).df()

+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  series = self.group_query_results_to_numeric_metrics(
  results,
  "count",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )
  # preserve dimension that identifies the name of the nullable column used for the aggregation
  for point in series:
- point.dimensions.append(Dimension(name="column_name", value=nullable_col))
+ point.dimensions.append(Dimension(name="column_name", value=unescape_identifier(nullable_col)))

  metric = self.series_to_metric(self.METRIC_NAME, series)
  return [metric]
@@ -19,7 +19,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
@@ -111,50 +111,45 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
- SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
+ SELECT time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
+ SUM(ABS({prediction_col} - {ground_truth_col})) as ae, \
  COUNT(*) as count \
  FROM {dataset.dataset_table_name} \
- WHERE {escaped_prediction_col} IS NOT NULL \
- AND {escaped_ground_truth_col} IS NOT NULL \
+ WHERE {prediction_col} IS NOT NULL \
+ AND {ground_truth_col} IS NOT NULL \
  GROUP BY ts order by ts desc \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_prediction_col = escape_identifier(prediction_col)
- escaped_ground_truth_col = escape_identifier(ground_truth_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
- f"SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+ f"SUM(ABS({prediction_col} - {ground_truth_col})) as ae",
  f"COUNT(*) as count",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ all_group_by_cols = ["ts"] + segmentation_cols

  # build query
  mae_query = f"""
  SELECT {", ".join(all_select_clause_cols)}
  FROM {dataset.dataset_table_name}
- WHERE {escaped_prediction_col} IS NOT NULL
- AND {escaped_ground_truth_col} IS NOT NULL
+ WHERE {prediction_col} IS NOT NULL
+ AND {ground_truth_col} IS NOT NULL
  GROUP BY {", ".join(all_group_by_cols)} order by ts desc
  """

  results = ddb_conn.sql(mae_query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  count_series = self.group_query_results_to_numeric_metrics(
  results,
  "count",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )
  absolute_error_series = self.group_query_results_to_numeric_metrics(
  results,
  "ae",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )

@@ -19,7 +19,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
@@ -111,50 +111,45 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
- SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
+ SELECT time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
+ SUM(POW({prediction_col} - {ground_truth_col}, 2)) as squared_error, \
  COUNT(*) as count \
  FROM {dataset.dataset_table_name} \
- WHERE {escaped_prediction_col} IS NOT NULL \
- AND {escaped_ground_truth_col} IS NOT NULL \
+ WHERE {prediction_col} IS NOT NULL \
+ AND {ground_truth_col} IS NOT NULL \
  GROUP BY ts order by ts desc \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_prediction_col = escape_identifier(prediction_col)
- escaped_ground_truth_col = escape_identifier(ground_truth_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
- f"SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error",
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+ f"SUM(POW({prediction_col} - {ground_truth_col}, 2)) as squared_error",
  f"COUNT(*) as count",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ all_group_by_cols = ["ts"] + segmentation_cols

  # build query
  mse_query = f"""
  SELECT {", ".join(all_select_clause_cols)}
  FROM {dataset.dataset_table_name}
- WHERE {escaped_prediction_col} IS NOT NULL
- AND {escaped_ground_truth_col} IS NOT NULL
+ WHERE {prediction_col} IS NOT NULL
+ AND {ground_truth_col} IS NOT NULL
  GROUP BY {", ".join(all_group_by_cols)} order by ts desc
  """

  results = ddb_conn.sql(mse_query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]
  count_series = self.group_query_results_to_numeric_metrics(
  results,
  "count",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )
  squared_error_series = self.group_query_results_to_numeric_metrics(
  results,
  "squared_error",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )

@@ -20,7 +20,8 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+
+ from arthur_common.tools.duckdb_data_loader import escape_str_literal, unescape_identifier


  class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
@@ -194,11 +195,11 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  Returns the following SQL with no segmentation:
  WITH normalized_data AS (
  SELECT
- {escaped_timestamp_col} AS timestamp,
- {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
- {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+ {timestamp_col} AS timestamp,
+ {prediction_normalization_case.replace('value', prediction_col)} AS prediction,
+ {gt_normalization_case.replace('value', gt_values_col)} AS actual_value
  FROM {dataset.dataset_table_name}
- WHERE {escaped_timestamp_col} IS NOT NULL
+ WHERE {timestamp_col} IS NOT NULL
  )
  SELECT
  time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
@@ -212,19 +213,12 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  ORDER BY ts

  """
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_prediction_col = escape_identifier(prediction_col)
- escaped_gt_values_col = escape_identifier(gt_values_col)
-
  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  first_subquery_select_cols = [
- f"{escaped_timestamp_col} AS timestamp",
- f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
- f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
- ] + escaped_segmentation_cols
+ f"{timestamp_col} AS timestamp",
+ f"{prediction_normalization_case.replace('value', prediction_col)} AS prediction",
+ f"{gt_normalization_case.replace('value', gt_values_col)} AS actual_value",
+ ] + segmentation_cols
  second_subquery_select_cols = [
  "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
  "SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
@@ -232,8 +226,8 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  "SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
  "SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
  f"any_value({escaped_positive_class_label}) as class_label",
- ] + escaped_segmentation_cols
- second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+ ] + segmentation_cols
+ second_subquery_group_by_cols = ["ts"] + segmentation_cols
  extra_dims = ["class_label"]

  # build query
@@ -241,7 +235,7 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  WITH normalized_data AS (
  SELECT {", ".join(first_subquery_select_cols)}
  FROM {dataset.dataset_table_name}
- WHERE {escaped_timestamp_col} IS NOT NULL
+ WHERE {timestamp_col} IS NOT NULL
  )
  SELECT {", ".join(second_subquery_select_cols)}
  FROM normalized_data
@@ -250,29 +244,30 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  """

  results = ddb_conn.sql(confusion_matrix_query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

  tp = self.group_query_results_to_numeric_metrics(
  results,
  "true_positive_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  fp = self.group_query_results_to_numeric_metrics(
  results,
  "false_positive_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  fn = self.group_query_results_to_numeric_metrics(
  results,
  "false_negative_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  tn = self.group_query_results_to_numeric_metrics(
  results,
  "true_negative_count",
- dim_columns=segmentation_cols + extra_dims,
+ dim_columns=unescaped_segmentation_cols + extra_dims,
  timestamp_col="ts",
  )
  tp_metric = self.series_to_metric(
@@ -18,7 +18,8 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier, escape_str_literal


  class NumericSketchAggregationFunction(SketchAggregationFunction):
@@ -95,41 +96,37 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
  ] = None,
  ) -> list[SketchMetric]:
  """Executed SQL with no segmentation columns:
- select {escaped_timestamp_col_id} as ts, \
- {escaped_numeric_col_id}, \
+ select {timestamp_col} as ts, \
+ {numeric_col}, \
  {numeric_col_name_str} as column_name \
  from {dataset.dataset_table_name} \
- where {escaped_numeric_col_id} is not null \
+ where {numeric_col} is not null \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col_id = escape_identifier(timestamp_col)
- escaped_numeric_col_id = escape_identifier(numeric_col)
- numeric_col_name_str = escape_str_literal(numeric_col)
+ numeric_col_name_str = escape_str_literal(unescape_identifier(numeric_col))

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"{escaped_timestamp_col_id} as ts",
- f"{escaped_numeric_col_id}",
+ f"{timestamp_col} as ts",
+ f"{numeric_col}",
  f"{numeric_col_name_str} as column_name",
- ] + escaped_segmentation_cols
+ ] + segmentation_cols
  extra_dims = ["column_name"]

  # build query
  data_query = f"""
  select {", ".join(all_select_clause_cols)}
  from {dataset.dataset_table_name}
- where {escaped_numeric_col_id} is not null
+ where {numeric_col} is not null
  """

  results = ddb_conn.sql(data_query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

  series = self.group_query_results_to_sketch_metrics(
  results,
- numeric_col,
- segmentation_cols + extra_dims,
+ unescape_identifier(numeric_col),
+ unescaped_segmentation_cols + extra_dims,
  "ts",
  )

@@ -19,7 +19,7 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from arthur_common.tools.duckdb_data_loader import unescape_identifier


  class NumericSumAggregationFunction(NumericAggregationFunction):
@@ -94,45 +94,41 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
  ] = None,
  ) -> list[NumericMetric]:
  """Executed SQL with no segmentation columns:
- select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
- sum({escaped_numeric_col}) as sum \
+ select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
+ sum({numeric_col}) as sum \
  from {dataset.dataset_table_name} \
- where {escaped_numeric_col} is not null \
+ where {numeric_col} is not null \
  group by ts \
  """
  segmentation_cols = [] if not segmentation_cols else segmentation_cols
- escaped_timestamp_col = escape_identifier(timestamp_col)
- escaped_numeric_col = escape_identifier(numeric_col)

  # build query components with segmentation columns
- escaped_segmentation_cols = [
- escape_identifier(col) for col in segmentation_cols
- ]
  all_select_clause_cols = [
- f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
- f"sum({escaped_numeric_col}) as sum",
- ] + escaped_segmentation_cols
- all_group_by_cols = ["ts"] + escaped_segmentation_cols
+ f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+ f"sum({numeric_col}) as sum",
+ ] + segmentation_cols
+ all_group_by_cols = ["ts"] + segmentation_cols

  # build query
  query = f"""
  select {", ".join(all_select_clause_cols)}
  from {dataset.dataset_table_name}
- where {escaped_numeric_col} is not null
+ where {numeric_col} is not null
  group by {", ".join(all_group_by_cols)}
  """

  results = ddb_conn.sql(query).df()
+ unescaped_segmentation_cols = [unescape_identifier(seg_col) for seg_col in segmentation_cols]

  series = self.group_query_results_to_numeric_metrics(
  results,
  "sum",
- segmentation_cols,
+ unescaped_segmentation_cols,
  "ts",
  )
  # preserve dimension that identifies the name of the numeric column used for the aggregation
  for point in series:
- point.dimensions.append(Dimension(name="column_name", value=numeric_col))
+ point.dimensions.append(Dimension(name="column_name", value=unescape_identifier(numeric_col)))

  metric = self.series_to_metric(self.METRIC_NAME, series)
  return [metric]
@@ -119,9 +119,9 @@ class TokenUsageScope(BaseEnum):


  class ToolClassEnum(IntEnum):
- WRONG_TOOL_SELECTED = 0
- CORRECT_TOOL_SELECTED = 1
- NO_TOOL_SELECTED = 2
+ INCORRECT = 0
+ CORRECT = 1
+ NA = 2

  def __str__(self) -> str:
  return str(self.value)
@@ -147,3 +147,11 @@ class UserPermissionResource(BaseEnum):
  RESPONSES = "responses"
  RULES = "rules"
  TASKS = "tasks"
+
+
+ class ComparisonOperatorEnum(BaseEnum):
+ EQUAL = "eq"
+ GREATER_THAN = "gt"
+ GREATER_THAN_OR_EQUAL = "gte"
+ LESS_THAN = "lt"
+ LESS_THAN_OR_EQUAL = "lte"
@@ -1,9 +1,16 @@
  from datetime import datetime
- from typing import Any, Dict, List, Optional, Self, Type, Union
+ from typing import Any, Dict, List, Optional, Self, Type

  from fastapi import HTTPException
  from openinference.semconv.trace import OpenInferenceSpanKindValues
- from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+ from pydantic import (
+ BaseModel,
+ ConfigDict,
+ Field,
+ ValidationInfo,
+ field_validator,
+ model_validator,
+ )

  from arthur_common.models.common_schemas import (
  ExamplesConfig,
@@ -25,6 +32,7 @@ from arthur_common.models.enums import (
  PIIEntityTypes,
  RuleScope,
  RuleType,
+ ToolClassEnum,
  )
  from arthur_common.models.metric_schemas import RelevanceMetricConfig

@@ -50,12 +58,12 @@ class NewRuleRequest(BaseModel):
  examples=[False],
  )
  config: (
- KeywordsConfig
- | RegexConfig
- | ExamplesConfig
- | ToxicityConfig
- | PIIConfig
- | None
+ KeywordsConfig
+ | RegexConfig
+ | ExamplesConfig
+ | ToxicityConfig
+ | PIIConfig
+ | None
  ) = Field(description="Config of the rule", default=None)

  model_config = ConfigDict(
@@ -554,3 +562,250 @@ class SpanQueryRequest(BaseModel):
  f"Valid values: {', '.join(sorted(valid_span_kinds))}",
  )
  return value
+
+
+ class TraceQueryRequest(BaseModel):
+ """Request schema for querying traces with comprehensive filtering."""
+
+ # Required
+ task_ids: list[str] = Field(
+ ...,
+ description="Task IDs to filter on. At least one is required.",
+ min_length=1,
+ )
+
+ # Common optional filters
+ trace_ids: Optional[list[str]] = Field(
+ None,
+ description="Trace IDs to filter on. Optional.",
+ )
+ start_time: Optional[datetime] = Field(
+ None,
+ description="Inclusive start date in ISO8601 string format. Use local time (not UTC).",
+ )
+ end_time: Optional[datetime] = Field(
+ None,
+ description="Exclusive end date in ISO8601 string format. Use local time (not UTC).",
+ )
+
+ # New trace-level filters
+ tool_name: Optional[str] = Field(
+ None,
+ description="Return only results with this tool name.",
+ )
+ span_types: Optional[list[str]] = Field(
+ None,
+ description="Span types to filter on. Optional.",
+ )
+
+ # Query relevance filters
+ query_relevance_eq: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Equal to this value.",
+ )
+ query_relevance_gt: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Greater than this value.",
+ )
+ query_relevance_gte: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Greater than or equal to this value.",
+ )
+ query_relevance_lt: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Less than this value.",
+ )
+ query_relevance_lte: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Less than or equal to this value.",
+ )
+
+ # Response relevance filters
+ response_relevance_eq: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Equal to this value.",
+ )
+ response_relevance_gt: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Greater than this value.",
+ )
+ response_relevance_gte: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Greater than or equal to this value.",
+ )
+ response_relevance_lt: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Less than this value.",
+ )
+ response_relevance_lte: Optional[float] = Field(
+ None,
+ ge=0,
+ le=1,
+ description="Less than or equal to this value.",
+ )
+
+ # Tool classification filters
+ tool_selection: Optional[ToolClassEnum] = Field(
+ None,
+ description="Tool selection evaluation result.",
+ )
+ tool_usage: Optional[ToolClassEnum] = Field(
+ None,
+ description="Tool usage evaluation result.",
+ )
+
+ # Trace duration filters
+ trace_duration_eq: Optional[float] = Field(
+ None,
+ ge=0,
+ description="Duration exactly equal to this value (seconds).",
+ )
+ trace_duration_gt: Optional[float] = Field(
+ None,
+ ge=0,
+ description="Duration greater than this value (seconds).",
+ )
+ trace_duration_gte: Optional[float] = Field(
+ None,
+ ge=0,
+ description="Duration greater than or equal to this value (seconds).",
+ )
+ trace_duration_lt: Optional[float] = Field(
+ None,
+ ge=0,
+ description="Duration less than this value (seconds).",
+ )
+ trace_duration_lte: Optional[float] = Field(
+ None,
+ ge=0,
+ description="Duration less than or equal to this value (seconds).",
+ )
+
+ @field_validator(
+ "query_relevance_eq",
+ "query_relevance_gt",
+ "query_relevance_gte",
+ "query_relevance_lt",
+ "query_relevance_lte",
+ "response_relevance_eq",
+ "response_relevance_gt",
+ "response_relevance_gte",
+ "response_relevance_lt",
+ "response_relevance_lte",
+ mode="before",
+ )
+ @classmethod
+ def validate_relevance_scores(
+ cls,
+ value: Optional[float],
+ info: ValidationInfo,
+ ) -> Optional[float]:
+ """Validate that relevance scores are between 0 and 1 (inclusive)."""
+ if value is not None:
+ if not (0.0 <= value <= 1.0):
+ raise ValueError(
+ f"{info.field_name} value must be between 0 and 1 (inclusive)",
+ )
+ return value
+
+ @field_validator(
+ "trace_duration_eq",
+ "trace_duration_gt",
+ "trace_duration_gte",
+ "trace_duration_lt",
+ "trace_duration_lte",
+ mode="before",
+ )
+ @classmethod
+ def validate_trace_duration(
+ cls,
+ value: Optional[float],
+ info: ValidationInfo,
+ ) -> Optional[float]:
+ """Validate that trace duration values are non-negative."""
+ if value is not None:
+ if value < 0:
+ raise ValueError(
+ f"{info.field_name} value must be non-negative (greater than or equal to 0)",
+ )
+ return value
+
+ @field_validator("tool_selection", "tool_usage", mode="before")
+ @classmethod
+ def validate_tool_classification(cls, value: Any) -> Optional[ToolClassEnum]:
+ """Validate tool classification enum values."""
+ if value is not None:
+ # Handle both integer and enum inputs
+ if isinstance(value, int):
+ if value not in [0, 1, 2]:
+ raise ValueError(
+ "Tool classification must be 0 (INCORRECT), "
+ "1 (CORRECT), or 2 (NA)",
+ )
+ return ToolClassEnum(value)
+ elif isinstance(value, ToolClassEnum):
+ return value
+ else:
+ raise ValueError(
+ "Tool classification must be an integer (0, 1, 2) or ToolClassEnum instance",
+ )
+ return value
+
+ @field_validator("span_types")
+ @classmethod
+ def validate_span_types(cls, value: Optional[list[str]]) -> Optional[list[str]]:
+ """Validate that all span_types are valid OpenInference span kinds."""
+ if not value:
+ return value
+
+ # Get all valid span kind values
+ valid_span_kinds = [kind.value for kind in OpenInferenceSpanKindValues]
+ invalid_types = [st for st in value if st not in valid_span_kinds]
+
+ if invalid_types:
+ raise ValueError(
+ f"Invalid span_types received: {invalid_types}. "
+ f"Valid values: {', '.join(sorted(valid_span_kinds))}",
+ )
+ return value
+
+ @model_validator(mode="after")
+ def validate_filter_combinations(self) -> Self:
+ """Validate that filter combinations are logically valid."""
+ # Check mutually exclusive filters for each metric type
+ for prefix in ["query_relevance", "response_relevance", "trace_duration"]:
+ eq_field = f"{prefix}_eq"
+ comparison_fields = [f"{prefix}_{op}" for op in ["gt", "gte", "lt", "lte"]]
+
+ if getattr(self, eq_field) and any(
+ getattr(self, field) for field in comparison_fields
+ ):
+ raise ValueError(
+ f"{eq_field} cannot be combined with other {prefix} comparison operators",
+ )
+
+ # Check for incompatible operator combinations
+ if getattr(self, f"{prefix}_gt") and getattr(self, f"{prefix}_gte"):
+ raise ValueError(f"Cannot combine {prefix}_gt with {prefix}_gte")
+ if getattr(self, f"{prefix}_lt") and getattr(self, f"{prefix}_lte"):
+ raise ValueError(f"Cannot combine {prefix}_lt with {prefix}_lte")
+
+ return self
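Illustrative usage (not part of the diff above): a minimal sketch showing how the new TraceQueryRequest validators behave, assuming the class is importable from arthur_common.models.request_schemas as the RECORD section below suggests.

# Sketch only: exercises the TraceQueryRequest filters added in 2.3.0.
# The import path is an assumption based on the RECORD entry for request_schemas.py.
from arthur_common.models.request_schemas import TraceQueryRequest

# Valid: range filters on different metrics can be combined freely.
req = TraceQueryRequest(
    task_ids=["task-123"],
    query_relevance_gte=0.5,
    trace_duration_lt=30.0,
)

# Invalid: the model_validator rejects mixing an _eq filter with a range filter
# on the same metric, so pydantic raises a ValidationError here.
try:
    TraceQueryRequest(
        task_ids=["task-123"],
        trace_duration_eq=5.0,
        trace_duration_gt=1.0,
    )
except Exception as exc:
    print(exc)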
@@ -1,4 +1,5 @@
  import json
+ import re
  from typing import Any

  import duckdb
@@ -314,6 +315,9 @@ def escape_identifier(identifier: str) -> str:
  """
  Escape an identifier (e.g., column name) for use in a SQL query.
  This method handles special characters and ensures proper quoting.
+
+ For struct fields, the identifiers must be escaped as following:
+ "struct_column_name"."struct_field"
  """
  # Replace any double quotes with two double quotes
  escaped = identifier.replace('"', '""')
@@ -321,6 +325,32 @@ def escape_identifier(identifier: str) -> str:
  return f'"{escaped}"'


+ def unescape_identifier(identifier: str) -> str:
+ """
+ Unescape an identifier (e.g., column name).
+
+ This removes the double quotes and properly handles struct fields, which may be escaped as follows:
+ "struct_column_name"."struct_field"
+
+ Here's a hard case for help understanding this function: "struct "" column name with quotes"."struct.field.name.with.dots"
+ """
+ unescaped_identifiers = []
+ # strip top-level quotes
+ identifier = identifier[1:-1]
+ # split identifier into struct fields based on delimiter pattern "."
+ # at this point there are no external double quotes left; any remaining are escaped double quotes belonging to
+ # the column name
+ identifier_split_in_struct_fields = re.split(r'"\."', identifier)
+
+ for identifier in identifier_split_in_struct_fields:
+ # replace any escaped double quotes in the column
+ unescaped_identifier = identifier.replace('""', '"')
+ unescaped_identifiers.append(unescaped_identifier)
+
+ # join back any struct fields via dot syntax without the escape identifiers
+ return ".".join(unescaped_identifiers)
+
+
  def escape_str_literal(literal: str) -> str:
  """
  Escape a duckDB string literal for use in a SQL query.
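Illustrative round trip (not part of the diff above): a minimal sketch of how the existing escape_identifier and the new unescape_identifier relate, including the "hard case" quoted in the docstring; the printed result is derived from the function body shown, not from running the released package.

# Sketch only: pairing escape_identifier with the unescape_identifier added in 2.3.0.
from arthur_common.tools.duckdb_data_loader import escape_identifier, unescape_identifier

col = 'weird "quoted" column'
escaped = escape_identifier(col)          # '"weird ""quoted"" column"'
assert unescape_identifier(escaped) == col  # doubled quotes are collapsed back

# The docstring's hard case: a quoted struct column plus a field name containing dots
# collapses to plain dot syntax with the inner quotes restored.
hard = '"struct "" column name with quotes"."struct.field.name.with.dots"'
print(unescape_identifier(hard))
# struct " column name with quotes.struct.field.name.with.dots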
@@ -16,17 +16,15 @@ def is_column_possible_segmentation(
  2. Has an allowed DType.

  PreReq: Table with column should already be loaded in DuckDB
+ column_name already has DuckDB escape identifier for the query syntax
  """
  segmentation_col_unique_val_limit = Config.segmentation_col_unique_values_limit()
  if column_dtype not in SEGMENTATION_ALLOWED_DTYPES:
  return False

- # check column for unique value count
- escaped_column = escape_identifier(column_name)
-
- # count distinct values in this column
+ # check column for unique value count - count distinct values in this column
  distinct_count_query = f"""
- SELECT COUNT(DISTINCT {escaped_column}) as distinct_count
+ SELECT COUNT(DISTINCT {column_name}) as distinct_count
  FROM {table}
  """
  result = conn.sql(distinct_count_query).fetchone()
@@ -110,7 +110,7 @@ class SchemaInferer:
  if not is_nested_col and is_column_possible_segmentation(
  self.conn,
  table,
- col_name,
+ escape_identifier(col_name),
  scalar_schema.dtype,
  ):
  scalar_schema.tag_hints.append(ScopeSchemaTag.POSSIBLE_SEGMENTATION)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: arthur-common
- Version: 2.1.68
+ Version: 2.3.0
  Summary: Utility code common to Arthur platform components.
  License: MIT
  Author: Arthur
@@ -4,17 +4,17 @@ arthur_common/aggregations/aggregator.py,sha256=AhyNqBDEbKtS3ZrnSIT9iZ1SK_TAuiUN
  arthur_common/aggregations/functions/README.md,sha256=MkZoTAJ94My96R5Z8GAxud7S6vyR0vgVi9gqdt9a4XY,5460
  arthur_common/aggregations/functions/__init__.py,sha256=HqC3UNRURX7ZQHgamTrQvfA8u_FiZGZ4I4eQW7Ooe5o,1299
  arthur_common/aggregations/functions/agentic_aggregations.py,sha256=09th4RPRf-ogtVWpbcqqmITN2UFtfqXhQ7Rr6IBqQHo,33995
- arthur_common/aggregations/functions/categorical_count.py,sha256=wc1ovL8JoiSeoSTk9h1fgrLj1QuQeYYZmEqgffGc2cw,5328
- arthur_common/aggregations/functions/confusion_matrix.py,sha256=aPL8DaXpflt0z1u1KIeFw9geZLJ6qTuTosCNFV54y8M,22105
- arthur_common/aggregations/functions/inference_count.py,sha256=SrRfxQVnX-wRTZ1zbqUKupPdACvfKeUpZDidZs45ZUY,4079
- arthur_common/aggregations/functions/inference_count_by_class.py,sha256=H64-pZIU1bJ2BPNJl64_H97BASAjGact10AjW_gkvaY,11551
- arthur_common/aggregations/functions/inference_null_count.py,sha256=w9sfu1QDlVBJwMW5EEkgda65nyMAABzd-FBKtj8amw4,4825
- arthur_common/aggregations/functions/mean_absolute_error.py,sha256=mOqE7XO2h7JtTLEKG5gTXu-pQJJIMYKWbUyqWA2dcxk,6831
- arthur_common/aggregations/functions/mean_squared_error.py,sha256=9WFBIhmAg1FZ7tdQYFWsS3yp3kyCYMJVAk-uLSb41Ck,6852
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256=rXXvXCIb30j_ofsMfp2yjLEdf8LmfKTqOLM3NQowzaU,12612
+ arthur_common/aggregations/functions/categorical_count.py,sha256=_TD0s0JAtqC5RmT6ZNWLEBZm-dU4akm-Aor7EDVazzA,5176
+ arthur_common/aggregations/functions/confusion_matrix.py,sha256=n33kyyZuxo8k6jUYnBUsc1fLotTmcw0H8rsX_x_oeJ0,21733
+ arthur_common/aggregations/functions/inference_count.py,sha256=D49SpwFywipMqeC93gc3_ZGwBoGL89yKuA9_55dBWBw,3984
+ arthur_common/aggregations/functions/inference_count_by_class.py,sha256=mYL6xMTb-_VO6mKGWHOtFAvWzTt-C_4vKf8KgioJGDg,11191
+ arthur_common/aggregations/functions/inference_null_count.py,sha256=UlE5EZa3k2nKIv6Yzrnjq1MsZEzrau7Olumny8hsHtg,4672
+ arthur_common/aggregations/functions/mean_absolute_error.py,sha256=YzrNHox_4HEGWn33E12d6eiQ8A9Rwct7AW3hOWrTW7I,6544
+ arthur_common/aggregations/functions/mean_squared_error.py,sha256=b_is7FKRSninYs1ilAXeLPJFfmyCaiKvCC9Ev_OERio,6565
+ arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256=e1KEyxIZocWMkDbnW0zfJHd5PUi_kyzwNUVFOD0l5Nk,12359
  arthur_common/aggregations/functions/multiclass_inference_count_by_class.py,sha256=yiMpdz4VuX1ELprXYupFu4B9aDLIhgfEi3ma8jZsT_M,4261
- arthur_common/aggregations/functions/numeric_stats.py,sha256=uHTyOAHW6xF6D-TeFLtY16iVR-Ju_6lmXSSY77mH0Qs,4921
- arthur_common/aggregations/functions/numeric_sum.py,sha256=kGE6Jjnjwf2E4TKE3NwPyrlEKgygfCxv1z_YGDCOcCQ,5028
+ arthur_common/aggregations/functions/numeric_stats.py,sha256=mMpVH1PvElGaz5mIQWy8sIkKPZ5kyeNOAM2iM2IlBvY,4760
+ arthur_common/aggregations/functions/numeric_sum.py,sha256=Vq-dQonKTdLt8pYFwT5tCXyyL_FvVQxb6b3nFNRSqus,4861
  arthur_common/aggregations/functions/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/aggregations/functions/shield_aggregations.py,sha256=BzPkpbhZRy16iFOobuusGKHfov5DxnXS2v_WThpw2fk,35659
  arthur_common/aggregations/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,11 +26,11 @@ arthur_common/models/common_schemas.py,sha256=31Br7DbIgrwHwzgiyMXrgPYrANhqSqle7k
  arthur_common/models/connectors.py,sha256=RwjY74cs0KTKw7Opywehg46SZ4vwN3xm6ujHRsRIQ8Y,2292
  arthur_common/models/constants.py,sha256=munkU0LrLsDs9BtAfozzw30FCguIowmAUKg_9vqwX24,1049
  arthur_common/models/datasets.py,sha256=7p1tyJEPwXjBs2ZRoai8hTzNl6MK9jU1DluzASApE_4,254
- arthur_common/models/enums.py,sha256=f--GnBHo7_PEISrIS18lCxOhZUZ-BcaBvTlq0kX4tsU,3739
+ arthur_common/models/enums.py,sha256=J2beHEMjLfOGgc-vh1aDpE7KmBGKzLoOUGYLtuciJro,3870
  arthur_common/models/metric_schemas.py,sha256=Xf-1RTzg7iYtnBMLkUUUuMPzAujzzNvQx_pe-CksEdU,2484
  arthur_common/models/metrics.py,sha256=87LUU7-8duoKCzaffw9GHMyjsKMNoxKa5n5Hyg_ZK1s,11931
  arthur_common/models/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arthur_common/models/request_schemas.py,sha256=l6GvEtUcIJW5GGy9L3jR2djNqg4f1bBFwl2wBpVfL10,21467
+ arthur_common/models/request_schemas.py,sha256=ihrWK0SRVXsRmNaiLibbAEWi_RHl440JJvm09WRdNxQ,29329
  arthur_common/models/response_schemas.py,sha256=eZCgxnfOht8isUunAA4rosLFA-tgXRZIcj2CYa5XqOE,24362
  arthur_common/models/schema_definitions.py,sha256=dcUSLjBmvyloStcBFmT_rHdXbKdvA8Yxi_avYUbps3E,16876
  arthur_common/models/task_job_specs.py,sha256=p7jsSb97ylHYNkwoHXNOJvx2zcnh2kxLeh3m0pddo4M,3442
@@ -38,12 +38,12 @@ arthur_common/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  arthur_common/tools/aggregation_analyzer.py,sha256=UfMtvFWXV2Dqly8S6nneGgomuvEGN-1tBz81tfkMcAE,11206
  arthur_common/tools/aggregation_loader.py,sha256=3CF46bNi-GdJBNOXkjYfCQ1Aung8lf65L532sdWmR_s,2351
- arthur_common/tools/duckdb_data_loader.py,sha256=OwuvppwcBB9qQxyWr86mH7Gz2FBIuyDl0UpQ7TulhlU,11220
- arthur_common/tools/duckdb_utils.py,sha256=1i-kRXu95gh4Sf9Osl2LFUpdb0yZifOjLDtIgSfSmfs,1197
+ arthur_common/tools/duckdb_data_loader.py,sha256=A80wpATSc4VJLghoHwxpBEuUsxY93OZS0Qo4cFX7cRw,12462
+ arthur_common/tools/duckdb_utils.py,sha256=8l8bUmjqJyj84DXyEOzO_DsD8VsO25DWYK_IYF--Zek,1211
  arthur_common/tools/functions.py,sha256=FWL4eWO5-vLp86WudT-MGUKvf2B8f02IdoXQFKd6d8k,1093
  arthur_common/tools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- arthur_common/tools/schema_inferer.py,sha256=Ur4CXGAkd6ZMSU0nMNrkOEElsBopHXq0lctTV8X92W8,5188
+ arthur_common/tools/schema_inferer.py,sha256=9teI67umlGn0izp6pZ5UBuWxJthaWEmw3wRj2KPIbf4,5207
  arthur_common/tools/time_utils.py,sha256=4gfiu9NXfvPZltiVNLSIQGylX6h2W0viNi9Kv4bKyfw,1410
- arthur_common-2.1.68.dist-info/METADATA,sha256=Dmyvy60ivlka8sQwvlFN2-XNAybMTRbHsNu9RUd5FkU,2147
- arthur_common-2.1.68.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- arthur_common-2.1.68.dist-info/RECORD,,
+ arthur_common-2.3.0.dist-info/METADATA,sha256=AfXaXNFya5qwUZcaI_QBG7b1gTLCgLSoza5kzgyGb0E,2146
+ arthur_common-2.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ arthur_common-2.3.0.dist-info/RECORD,,