arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Files changed (33)
  1. arthur_common/aggregations/aggregator.py +73 -9
  2. arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
  3. arthur_common/aggregations/functions/categorical_count.py +15 -15
  4. arthur_common/aggregations/functions/confusion_matrix.py +24 -26
  5. arthur_common/aggregations/functions/inference_count.py +5 -9
  6. arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
  7. arthur_common/aggregations/functions/inference_null_count.py +10 -13
  8. arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
  9. arthur_common/aggregations/functions/mean_squared_error.py +12 -18
  10. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
  11. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
  12. arthur_common/aggregations/functions/numeric_stats.py +13 -15
  13. arthur_common/aggregations/functions/numeric_sum.py +12 -15
  14. arthur_common/aggregations/functions/shield_aggregations.py +457 -215
  15. arthur_common/models/common_schemas.py +214 -0
  16. arthur_common/models/connectors.py +10 -2
  17. arthur_common/models/constants.py +24 -0
  18. arthur_common/models/datasets.py +0 -9
  19. arthur_common/models/enums.py +177 -0
  20. arthur_common/models/metric_schemas.py +63 -0
  21. arthur_common/models/metrics.py +2 -9
  22. arthur_common/models/request_schemas.py +870 -0
  23. arthur_common/models/response_schemas.py +785 -0
  24. arthur_common/models/schema_definitions.py +6 -1
  25. arthur_common/models/task_job_specs.py +3 -12
  26. arthur_common/tools/duckdb_data_loader.py +34 -2
  27. arthur_common/tools/duckdb_utils.py +3 -6
  28. arthur_common/tools/schema_inferer.py +3 -6
  29. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
  30. arthur_common-2.4.13.dist-info/RECORD +49 -0
  31. arthur_common/models/shield.py +0 -642
  32. arthur_common-2.1.58.dist-info/RECORD +0 -44
  33. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
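Note: the hunks below are not labeled with file paths in this view; judging by the class names they appear to come from confusion_matrix.py, inference_count.py, inference_count_by_class.py, inference_null_count.py, mean_absolute_error.py, mean_squared_error.py, multiclass_confusion_matrix.py, multiclass_inference_count_by_class.py, and numeric_stats.py. Two changes recur throughout: ModelProblemType now imports from arthur_common.models.enums instead of arthur_common.models.datasets, and the aggregation functions no longer call escape_identifier on their column parameters. Columns are now expected to arrive already escaped, with unescape_identifier used wherever a human-readable column name is needed (dimension values and string literals embedded in the SQL). The sketch below is a hypothetical illustration of what the arthur_common.tools.duckdb_data_loader helpers plausibly do, assuming DuckDB's standard quoting rules; the real implementations (and which of them remain exported) are not part of this diff.

# Hypothetical sketch only -- the actual duckdb_data_loader implementations are not shown in this diff.
# Assumes DuckDB's standard quoting: double quotes for identifiers, single quotes for string literals.
def escape_identifier(name: str) -> str:
    # Quote a column name so it can be interpolated into SQL safely.
    return '"' + name.replace('"', '""') + '"'

def unescape_identifier(name: str) -> str:
    # Reverse of escape_identifier: recover the raw column name for display values.
    if name.startswith('"') and name.endswith('"'):
        return name[1:-1].replace('""', '"')
    return name

def escape_str_literal(value: str) -> str:
    # Render a Python string as a single-quoted SQL string literal.
    return "'" + value.replace("'", "''") + "'"

# Round-trip example with an awkward column name:
escaped = escape_identifier('my "odd" column')        # '"my ""odd"" column"'
assert unescape_identifier(escaped) == 'my "odd" column'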
@@ -4,7 +4,7 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -20,7 +20,10 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+from arthur_common.tools.duckdb_data_loader import (
+    escape_str_literal,
+    unescape_identifier,
+)
 
 
 class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
@@ -78,11 +81,11 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
         Without segmentation, this is the query:
         WITH normalized_data AS (
             SELECT
-                {escaped_timestamp_col} AS timestamp,
-                {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
-                {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                {timestamp_col} AS timestamp,
+                {prediction_normalization_case.replace('value', prediction_col)} AS prediction,
+                {gt_normalization_case.replace('value', gt_values_col)} AS actual_value
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
+            WHERE {timestamp_col} IS NOT NULL
         )
         SELECT
             time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
@@ -90,34 +93,31 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
             SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
             SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
             SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
-            {escaped_prediction_col_name} as prediction_column_name
+            {unescaped_prediction_col_name} as prediction_column_name
         FROM normalized_data
         GROUP BY ts
         ORDER BY ts
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_prediction_col = escape_identifier(prediction_col)
-        escaped_prediction_col_name = escape_str_literal(prediction_col)
-        escaped_gt_values_col = escape_identifier(gt_values_col)
+        unescaped_prediction_col_name = escape_str_literal(
+            unescape_identifier(prediction_col),
+        )
+
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         first_subquery_select_cols = [
-            f"{escaped_timestamp_col} AS timestamp",
-            f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
-            f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
-        ] + escaped_segmentation_cols
+            f"{timestamp_col} AS timestamp",
+            f"{prediction_normalization_case.replace('value', prediction_col)} AS prediction",
+            f"{gt_normalization_case.replace('value', gt_values_col)} AS actual_value",
+        ] + segmentation_cols
         second_subquery_select_cols = [
             "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
             "SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
             "SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count",
             "SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
             "SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
-            f"{escaped_prediction_col_name} as prediction_column_name",
-        ] + escaped_segmentation_cols
-        second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+            f"{unescaped_prediction_col_name} as prediction_column_name",
+        ] + segmentation_cols
+        second_subquery_group_by_cols = ["ts"] + segmentation_cols
         extra_dims = ["prediction_column_name"]
 
         # build query
@@ -125,7 +125,7 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
         WITH normalized_data AS (
             SELECT {", ".join(first_subquery_select_cols)}
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
+            WHERE {timestamp_col} IS NOT NULL
         )
         SELECT {", ".join(second_subquery_select_cols)}
         FROM normalized_data
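Note on the renamed variable above: the old code escaped the raw prediction_col both as an identifier and as a string literal, while the new code assumes prediction_col already arrives as an escaped identifier and recovers the display name first. A hedged illustration using the helpers sketched at the top of this diff (the column name is invented):

# Illustrative only: the new display-name pattern; names are invented.
prediction_col = '"prediction score"'                # assumed to arrive pre-escaped from the caller
unescaped_prediction_col_name = escape_str_literal(
    unescape_identifier(prediction_col),
)                                                    # -> "'prediction score'"
select_expr = f"{unescaped_prediction_col_name} as prediction_column_name"
# -> 'prediction score' as prediction_column_name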
@@ -243,9 +243,8 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
         ] = None,
     ) -> list[NumericMetric]:
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_prediction_col = escape_identifier(prediction_col)
         # Get the type of prediction column
-        type_query = f"SELECT typeof({escaped_prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+        type_query = f"SELECT typeof({prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
         res = ddb_conn.sql(type_query).fetchone()
         # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
         if not res:
@@ -476,7 +475,6 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
             ),
         ] = None,
     ) -> list[NumericMetric]:
-        escaped_gt_values_col = escape_identifier(gt_values_col)
         prediction_normalization_case = f"""
             CASE
                 WHEN value >= {threshold} THEN 1
@@ -485,7 +483,7 @@
             END
         """
 
-        type_query = f"SELECT typeof({escaped_gt_values_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+        type_query = f"SELECT typeof({gt_values_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
         res = ddb_conn.sql(type_query).fetchone()
         # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
         if not res:
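For context on the typeof() probes above: DuckDB's typeof() returns the value's type name, which these aggregations use to pick the right normalization CASE for boolean versus numeric or string columns. A minimal, self-contained demo against an in-memory DuckDB connection (table and column names are invented; assumes the duckdb Python package is installed):

# Minimal demo of the typeof() probe pattern.
import duckdb

conn = duckdb.connect()
conn.sql('CREATE TABLE inferences AS SELECT TRUE AS "ground_truth", 0.92 AS "pred"')
res = conn.sql('SELECT typeof("ground_truth") as col_type FROM inferences LIMIT 1').fetchone()
col_type = res[0] if res else None   # e.g. 'BOOLEAN'; the normalization CASE is chosen based on this
print(col_type)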
@@ -18,7 +18,6 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier
 
 
 class InferenceCountAggregationFunction(NumericAggregationFunction):
@@ -80,23 +79,19 @@ class InferenceCountAggregationFunction
         ] = None,
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
-        select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+        select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
             count(*) as count \
             from {dataset.dataset_table_name} \
             group by ts \
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
             f"count(*) as count",
-        ] + escaped_segmentation_cols
-        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        all_group_by_cols = ["ts"] + segmentation_cols
 
         # build query
         count_query = f"""
@@ -106,6 +101,7 @@ class InferenceCountAggregationFunction
         """
 
         results = ddb_conn.sql(count_query).df()
+
         series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
@@ -4,7 +4,7 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -20,7 +20,6 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier
 
 
 class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
@@ -100,31 +99,26 @@ class BinaryClassifierCountByClassAggregationFunction
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
         SELECT
-            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
-            {escaped_pred_col} as prediction,
+            time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts,
+            {prediction_col} as prediction,
             COUNT(*) as count
         FROM {dataset.dataset_table_name}
         GROUP BY
             ts,
             -- group by raw column name instead of alias in select
             -- in case table has a column called 'prediction'
-            {escaped_pred_col}
+            {prediction_col}
         ORDER BY ts
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_pred_col = escape_identifier(prediction_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
-            f"{escaped_pred_col} as prediction",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+            f"{prediction_col} as prediction",
             f"COUNT(*) as count",
-        ] + escaped_segmentation_cols
-        all_group_by_cols = ["ts", f"{escaped_pred_col}"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        all_group_by_cols = ["ts", f"{prediction_col}"] + segmentation_cols
         extra_dims = ["prediction"]
 
         # build query
@@ -248,34 +242,29 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
         SELECT
-            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
-            CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
+            time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts,
+            CASE WHEN {prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
             COUNT(*) as count
         FROM {dataset.dataset_table_name}
         GROUP BY
             ts,
             -- group by raw column name instead of alias in select
             -- in case table has a column called 'prediction'
-            {escaped_prediction_col}
+            {prediction_col}
         ORDER BY ts
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_prediction_col = escape_identifier(prediction_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
-            f"CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+            f"CASE WHEN {prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
             f"COUNT(*) as count",
-        ] + escaped_segmentation_cols
+        ] + segmentation_cols
         all_group_by_cols = [
             "ts",
-            f"{escaped_prediction_col}",
-        ] + escaped_segmentation_cols
+            f"{prediction_col}",
+        ] + segmentation_cols
         extra_dims = ["prediction"]
 
         query = f"""
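The threshold variant above labels each row by comparing the (now pre-escaped) prediction column against a numeric threshold. A small sketch of the CASE expression it builds, with invented names and values:

# Illustrative only: how the thresholded prediction label is constructed.
prediction_col = '"churn_probability"'   # assumed to arrive pre-escaped
threshold, true_label, false_label = 0.5, "churn", "no_churn"
case_expr = (
    f"CASE WHEN {prediction_col} >= {threshold} "
    f"THEN '{true_label}' ELSE '{false_label}' END as prediction"
)
# -> CASE WHEN "churn_probability" >= 0.5 THEN 'churn' ELSE 'no_churn' END as prediction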
@@ -19,7 +19,7 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier
+from arthur_common.tools.duckdb_data_loader import unescape_identifier
 
 
 class InferenceNullCountAggregationFunction(NumericAggregationFunction):
@@ -90,30 +90,25 @@ class InferenceNullCountAggregationFunction
         ] = None,
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
-        select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+        select time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
             count(*) as count \
-            from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
+            from {dataset.dataset_table_name} where {nullable_col} is null \
             group by ts \
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_nullable_col = escape_identifier(nullable_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
             f"count(*) as count",
-        ] + escaped_segmentation_cols
-        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        all_group_by_cols = ["ts"] + segmentation_cols
 
         # build query
         count_query = f"""
             select {", ".join(all_select_clause_cols)}
             from {dataset.dataset_table_name}
-            where {escaped_nullable_col} is null
+            where {nullable_col} is null
             group by {", ".join(all_group_by_cols)}
         """
 
@@ -127,7 +122,9 @@ class InferenceNullCountAggregationFunction
         )
         # preserve dimension that identifies the name of the nullable column used for the aggregation
         for point in series:
-            point.dimensions.append(Dimension(name="column_name", value=nullable_col))
+            point.dimensions.append(
+                Dimension(name="column_name", value=unescape_identifier(nullable_col)),
+            )
 
         metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]
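The change above keeps the reported dimension readable: since nullable_col now arrives as an escaped identifier, the raw name is recovered with unescape_identifier before being attached to each point (the Dimension constructor call is taken from the diff itself). A brief sketch using the helper sketched at the top of this diff, with an invented column name:

# Before vs. after for the reported dimension value; names are invented.
nullable_col = '"user age"'                      # pre-escaped identifier passed in by the caller
# old behavior: value='"user age"'               # the quoting leaked into the reported dimension
value = unescape_identifier(nullable_col)        # new behavior: 'user age'
# point.dimensions.append(Dimension(name="column_name", value=value))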
@@ -4,7 +4,7 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -19,7 +19,6 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier
 
 
 class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
@@ -111,40 +110,35 @@ class MeanAbsoluteErrorAggregationFunction
         ] = None,
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
-        SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
+        SELECT time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
+            SUM(ABS({prediction_col} - {ground_truth_col})) as ae, \
             COUNT(*) as count \
         FROM {dataset.dataset_table_name} \
-        WHERE {escaped_prediction_col} IS NOT NULL \
-            AND {escaped_ground_truth_col} IS NOT NULL \
+        WHERE {prediction_col} IS NOT NULL \
+            AND {ground_truth_col} IS NOT NULL \
         GROUP BY ts order by ts desc \
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_prediction_col = escape_identifier(prediction_col)
-        escaped_ground_truth_col = escape_identifier(ground_truth_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
-            f"SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+            f"SUM(ABS({prediction_col} - {ground_truth_col})) as ae",
             f"COUNT(*) as count",
-        ] + escaped_segmentation_cols
-        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        all_group_by_cols = ["ts"] + segmentation_cols
 
         # build query
         mae_query = f"""
             SELECT {", ".join(all_select_clause_cols)}
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_prediction_col} IS NOT NULL
-            AND {escaped_ground_truth_col} IS NOT NULL
+            WHERE {prediction_col} IS NOT NULL
+            AND {ground_truth_col} IS NOT NULL
             GROUP BY {", ".join(all_group_by_cols)} order by ts desc
         """
 
         results = ddb_conn.sql(mae_query).df()
+
         count_series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
@@ -4,7 +4,7 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -19,7 +19,6 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier
 
 
 class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
@@ -111,40 +110,35 @@ class MeanSquaredErrorAggregationFunction
         ] = None,
     ) -> list[NumericMetric]:
         """Executed SQL with no segmentation columns:
-        SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
+        SELECT time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts, \
+            SUM(POW({prediction_col} - {ground_truth_col}, 2)) as squared_error, \
             COUNT(*) as count \
         FROM {dataset.dataset_table_name} \
-        WHERE {escaped_prediction_col} IS NOT NULL \
-            AND {escaped_ground_truth_col} IS NOT NULL \
+        WHERE {prediction_col} IS NOT NULL \
+            AND {ground_truth_col} IS NOT NULL \
         GROUP BY ts order by ts desc \
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_prediction_col = escape_identifier(prediction_col)
-        escaped_ground_truth_col = escape_identifier(ground_truth_col)
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
-            f"SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error",
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col}) as ts",
+            f"SUM(POW({prediction_col} - {ground_truth_col}, 2)) as squared_error",
             f"COUNT(*) as count",
-        ] + escaped_segmentation_cols
-        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        all_group_by_cols = ["ts"] + segmentation_cols
 
         # build query
         mse_query = f"""
             SELECT {", ".join(all_select_clause_cols)}
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_prediction_col} IS NOT NULL
-            AND {escaped_ground_truth_col} IS NOT NULL
+            WHERE {prediction_col} IS NOT NULL
+            AND {ground_truth_col} IS NOT NULL
             GROUP BY {", ".join(all_group_by_cols)} order by ts desc
         """
 
         results = ddb_conn.sql(mse_query).df()
+
         count_series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
@@ -4,7 +4,7 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -20,7 +20,7 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+from arthur_common.tools.duckdb_data_loader import escape_str_literal
 
 
 class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
@@ -194,11 +194,11 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
         Returns the following SQL with no segmentation:
         WITH normalized_data AS (
             SELECT
-                {escaped_timestamp_col} AS timestamp,
-                {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
-                {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                {timestamp_col} AS timestamp,
+                {prediction_normalization_case.replace('value', prediction_col)} AS prediction,
+                {gt_normalization_case.replace('value', gt_values_col)} AS actual_value
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
+            WHERE {timestamp_col} IS NOT NULL
         )
         SELECT
             time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
@@ -212,19 +212,12 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
         ORDER BY ts
 
         """
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        escaped_prediction_col = escape_identifier(prediction_col)
-        escaped_gt_values_col = escape_identifier(gt_values_col)
-
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         first_subquery_select_cols = [
-            f"{escaped_timestamp_col} AS timestamp",
-            f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
-            f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
-        ] + escaped_segmentation_cols
+            f"{timestamp_col} AS timestamp",
+            f"{prediction_normalization_case.replace('value', prediction_col)} AS prediction",
+            f"{gt_normalization_case.replace('value', gt_values_col)} AS actual_value",
+        ] + segmentation_cols
         second_subquery_select_cols = [
             "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
             "SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
@@ -232,8 +225,8 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
             "SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
             "SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
             f"any_value({escaped_positive_class_label}) as class_label",
-        ] + escaped_segmentation_cols
-        second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+        ] + segmentation_cols
+        second_subquery_group_by_cols = ["ts"] + segmentation_cols
         extra_dims = ["class_label"]
 
         # build query
@@ -241,7 +234,7 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
         WITH normalized_data AS (
             SELECT {", ".join(first_subquery_select_cols)}
             FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
+            WHERE {timestamp_col} IS NOT NULL
         )
         SELECT {", ".join(second_subquery_select_cols)}
         FROM normalized_data
@@ -6,7 +6,7 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.functions.inference_count_by_class import (
     BinaryClassifierCountByClassAggregationFunction,
 )
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,
@@ -18,7 +18,10 @@ from arthur_common.models.schema_definitions import (
     ScalarType,
     ScopeSchemaTag,
 )
-from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+from arthur_common.tools.duckdb_data_loader import (
+    escape_str_literal,
+    unescape_identifier,
+)
 
 
 class NumericSketchAggregationFunction(SketchAggregationFunction):
@@ -95,40 +98,35 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
         ] = None,
     ) -> list[SketchMetric]:
         """Executed SQL with no segmentation columns:
-        select {escaped_timestamp_col_id} as ts, \
-            {escaped_numeric_col_id}, \
+        select {timestamp_col} as ts, \
+            {numeric_col}, \
             {numeric_col_name_str} as column_name \
             from {dataset.dataset_table_name} \
-            where {escaped_numeric_col_id} is not null \
+            where {numeric_col} is not null \
         """
         segmentation_cols = [] if not segmentation_cols else segmentation_cols
-        escaped_timestamp_col_id = escape_identifier(timestamp_col)
-        escaped_numeric_col_id = escape_identifier(numeric_col)
-        numeric_col_name_str = escape_str_literal(numeric_col)
+        numeric_col_name_str = escape_str_literal(unescape_identifier(numeric_col))
 
         # build query components with segmentation columns
-        escaped_segmentation_cols = [
-            escape_identifier(col) for col in segmentation_cols
-        ]
         all_select_clause_cols = [
-            f"{escaped_timestamp_col_id} as ts",
-            f"{escaped_numeric_col_id}",
+            f"{timestamp_col} as ts",
+            f"{numeric_col}",
             f"{numeric_col_name_str} as column_name",
-        ] + escaped_segmentation_cols
+        ] + segmentation_cols
         extra_dims = ["column_name"]
 
         # build query
         data_query = f"""
             select {", ".join(all_select_clause_cols)}
             from {dataset.dataset_table_name}
-            where {escaped_numeric_col_id} is not null
+            where {numeric_col} is not null
         """
 
         results = ddb_conn.sql(data_query).df()
 
         series = self.group_query_results_to_sketch_metrics(
             results,
-            numeric_col,
+            unescape_identifier(numeric_col),
             segmentation_cols + extra_dims,
             "ts",
         )