arthur-common 1.0.1__py3-none-any.whl → 2.1.47__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry, and reflects the changes between those versions as they appear in their respective public registries.


@@ -73,6 +73,15 @@ class NumericAggregationFunction(AggregationFunction, ABC):
         From there, iterate over the group turning each data point to a *Point. At the end, this single instance of the group metrics
         and the list of points (values) are merged to one *TimeSeries
         """
+        if not dim_columns:
+            return [
+                NumericAggregationFunction._dimensionless_query_results_to_numeric_metrics(
+                    data,
+                    value_col,
+                    timestamp_col,
+                ),
+            ]
+
         calculated_metrics: list[NumericTimeSeries] = []
         # make sure dropna is False or rows with "null" as a dimension value will be dropped
         groups = data.groupby(dim_columns, dropna=False)
@@ -99,7 +108,7 @@ class NumericAggregationFunction(AggregationFunction, ABC):
         return calculated_metrics

     @staticmethod
-    def dimensionless_query_results_to_numeric_metrics(
+    def _dimensionless_query_results_to_numeric_metrics(
         data: pd.DataFrame,
         value_col: str,
         timestamp_col: str,
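
With this change, an empty dim_columns list short-circuits to the now-private dimensionless helper instead of reaching data.groupby with an empty key list, which pandas rejects. A minimal pandas-only sketch of the dispatch (not the arthur_common implementation; the point/series types are stood in by plain tuples and lists):

import pandas as pd

def query_results_to_series(
    data: pd.DataFrame,
    value_col: str,
    dim_columns: list[str],
    timestamp_col: str,
) -> list[tuple[dict, list[tuple]]]:
    # no dimensions: a single series over the whole frame, no groupby needed
    if not dim_columns:
        points = list(zip(data[timestamp_col], data[value_col]))
        return [({}, points)]
    series = []
    # dropna=False keeps groups whose dimension value is null
    for keys, group in data.groupby(dim_columns, dropna=False):
        keys = keys if isinstance(keys, tuple) else (keys,)
        dims = dict(zip(dim_columns, keys))
        points = list(zip(group[timestamp_col], group[value_col]))
        series.append((dims, points))
    return series
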
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -64,25 +65,64 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
                 description="A column containing categorical values to count.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        select time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts, \
+            count(*) as count, \
+            {categorical_col_escaped} as category, \
+            {categorical_col_name_escaped} as column_name \
+        from {dataset.dataset_table_name} \
+        where ts is not null \
+        group by ts, category
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         timestamp_col_escaped = escape_identifier(timestamp_col)
         categorical_col_escaped = escape_identifier(categorical_col)
         categorical_col_name_escaped = escape_str_literal(categorical_col)
-        count_query = f" \
-            select time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts, \
-            count(*) as count, \
-            {categorical_col_escaped} as category, \
-            {categorical_col_name_escaped} as column_name \
-            from {dataset.dataset_table_name} \
-            where ts is not null \
-            group by ts, category \
-        "
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts",
+            f"count(*) as count",
+            f"{categorical_col_escaped} as category",
+            f"{categorical_col_name_escaped} as column_name",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts", "category"] + escaped_segmentation_cols
+        extra_dims = ["column_name", "category"]
+
+        # build query
+        count_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            where ts is not null
+            group by {", ".join(all_group_by_cols)}
+        """
+
         results = ddb_conn.sql(count_query).df()

         series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
-            ["column_name", "category"],
+            segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
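
The list-based assembly lets the segmented and unsegmented paths share one query builder: segmentation columns are simply appended to both the select list and the group by. A standalone sketch of the same construction, with a hypothetical segmentation column "region", made-up table/column names ("inferences", "created_at", "label"), and escape_identifier stubbed as DuckDB-style double-quoting:

def escape_identifier(name: str) -> str:
    # stub for illustration: DuckDB-style identifier quoting
    return '"' + name.replace('"', '""') + '"'

segmentation_cols = ["region"]
escaped_segmentation_cols = [escape_identifier(c) for c in segmentation_cols]
select_cols = [
    "time_bucket(INTERVAL '5 minutes', \"created_at\") as ts",
    "count(*) as count",
    '"label" as category',
    "'label' as column_name",
] + escaped_segmentation_cols
group_by_cols = ["ts", "category"] + escaped_segmentation_cols
count_query = (
    f"select {', '.join(select_cols)}\n"
    f"from inferences\n"
    f"where ts is not null\n"
    f"group by {', '.join(group_by_cols)}"
)
print(count_query)  # "region" appears in both the select list and the group by
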
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -9,6 +9,7 @@ from arthur_common.models.schema_definitions import (
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
     MetricLiteralParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -26,6 +27,7 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
         prediction_normalization_case: str,
         gt_normalization_case: str,
         dataset: DatasetReference,
+        segmentation_cols: list[str],
     ) -> list[NumericMetric]:
         """
         Generate a SQL query to compute confusion matrix metrics over time.
@@ -37,59 +39,94 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
             prediction_normalization_case: SQL CASE statement for normalizing predictions to 0 / 1 / null using 'value' as the target column name
             gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null using 'value' as the target column name
             dataset: DatasetReference containing dataset metadata
+            segmentation_cols: list of columns to segment by

         Returns:
             str: SQL query that computes confusion matrix metrics
+            Without segmentation, this is the query:
+            WITH normalized_data AS (
+                SELECT
+                    {escaped_timestamp_col} AS timestamp,
+                    {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                    {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT
+                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
+            FROM normalized_data
+            GROUP BY ts
+            ORDER BY ts
         """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
         escaped_gt_values_col = escape_identifier(gt_values_col)
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        first_subquery_select_cols = [
+            f"{escaped_timestamp_col} AS timestamp",
+            f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
+            f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
+        ] + escaped_segmentation_cols
+        second_subquery_select_cols = [
+            "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
+            "SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
+            "SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count",
+            "SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
+            "SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
+        ] + escaped_segmentation_cols
+        second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
         confusion_matrix_query = f"""
-        WITH normalized_data AS (
-            SELECT
-                {escaped_timestamp_col} AS timestamp,
-                {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
-                {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
-            FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
-        )
-        SELECT
-            time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
-            SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
-            SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
-            SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
-            SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
-        FROM normalized_data
-        GROUP BY ts
-        ORDER BY ts
-        """
+            WITH normalized_data AS (
+                SELECT {", ".join(first_subquery_select_cols)}
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT {", ".join(second_subquery_select_cols)}
+            FROM normalized_data
+            GROUP BY {", ".join(second_subquery_group_by_cols)}
+            ORDER BY ts
+        """

         results = ddb_conn.sql(confusion_matrix_query).df()

-        tp = self.dimensionless_query_results_to_numeric_metrics(
+        tp = self.group_query_results_to_numeric_metrics(
             results,
             "true_positive_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        fp = self.dimensionless_query_results_to_numeric_metrics(
+        fp = self.group_query_results_to_numeric_metrics(
             results,
             "false_positive_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        fn = self.dimensionless_query_results_to_numeric_metrics(
+        fn = self.group_query_results_to_numeric_metrics(
             results,
             "false_negative_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        tn = self.dimensionless_query_results_to_numeric_metrics(
+        tn = self.group_query_results_to_numeric_metrics(
             results,
             "true_negative_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count", [tp])
-        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count", [fp])
-        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count", [fn])
-        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count", [tn])
+        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count", tp)
+        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count", fp)
+        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count", fn)
+        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count", tn)
         return [tp_metric, fp_metric, fn_metric, tn_metric]
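
Per the docstring, the normalization CASE statements are written against a placeholder column named 'value', and this base function rewrites them for the concrete prediction and ground-truth columns with a plain string replace. A self-contained sketch of that mechanic, using a made-up threshold CASE and made-up column names:

# hypothetical CASE written against the placeholder column 'value'
case = "CASE WHEN value >= 0.5 THEN 1 WHEN value < 0.5 THEN 0 ELSE NULL END"

# the base function swaps in each escaped concrete column
prediction_expr = case.replace("value", '"pred_prob"')
gt_expr = case.replace("value", '"ground_truth"')

print(prediction_expr)
# CASE WHEN "pred_prob" >= 0.5 THEN 1 WHEN "pred_prob" < 0.5 THEN 0 ELSE NULL END
print(gt_expr)

Note the substitution is purely textual, so the placeholder name must not occur elsewhere in the CASE expression.
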
@@ -157,7 +194,24 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
                 description="A column containing boolean or integer ground truth values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_prediction_col = escape_identifier(prediction_col)
         # Get the type of prediction column
         type_query = f"SELECT typeof({escaped_prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
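
The type probe relies on DuckDB's typeof() over a single row to decide which normalization CASE applies. A runnable sketch against an in-memory connection (the duckdb package and the "preds" table are assumptions for illustration):

import duckdb

conn = duckdb.connect()
conn.sql("CREATE TABLE preds AS SELECT i % 2 = 0 AS pred FROM range(4) t(i)")
# same probe as above: inspect one row to learn the column's concrete type
col_type = conn.sql("SELECT typeof(pred) AS col_type FROM preds LIMIT 1").df()["col_type"][0]
print(col_type)  # BOOLEAN here, so a boolean normalization CASE would be chosen
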
@@ -194,6 +248,7 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
             normalization_case,
             normalization_case,
             dataset,
+            segmentation_cols,
         )

@@ -275,7 +330,24 @@ class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
                 description="The label indicating a negative classification to normalize to 0.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         normalization_case = f"""
             CASE
                 WHEN value = '{true_label}' THEN 1
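
Per the base function's docstring, this CASE normalizes the configured labels to 1 / 0 and everything else to null. A sketch of the rendered expression for hypothetical labels 'spam' and 'ham' (the false-label and ELSE arms are inferred from that docstring, since the hunk truncates the expression):

true_label, false_label = "spam", "ham"
normalization_case = f"""
    CASE
        WHEN value = '{true_label}' THEN 1
        WHEN value = '{false_label}' THEN 0
        ELSE NULL
    END
"""
# later rewritten against the real column via .replace('value', escaped_col)
print(normalization_case)
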
@@ -291,6 +363,7 @@ class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
             normalization_case,
             normalization_case,
             dataset,
+            segmentation_cols,
         )

@@ -365,6 +438,22 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
                 description="The threshold to classify predictions to 0 or 1.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
         escaped_gt_values_col = escape_identifier(gt_values_col)
         prediction_normalization_case = f"""
@@ -409,4 +498,5 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
             prediction_normalization_case,
             gt_normalization_case,
             dataset,
+            segmentation_cols,
         )
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -51,19 +52,55 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
                 description="A column containing timestamp values to bucket by.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
-        escaped_timestamp_col = escape_identifier(timestamp_col)
-        count_query = f" \
+        """Executed SQL with no segmentation columns:
         select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-        count(*) as count \
-        from {dataset.dataset_table_name} \
-        group by ts \
-        "
+            count(*) as count \
+            from {dataset.dataset_table_name} \
+            group by ts \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"count(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        count_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            group by {", ".join(all_group_by_cols)}
+        """
+
         results = ddb_conn.sql(count_query).df()
-        series = self.dimensionless_query_results_to_numeric_metrics(
+        series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
+            segmentation_cols,
             "ts",
         )
-        metric = self.series_to_metric(self.METRIC_NAME, [series])
+        metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]
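
All of these queries lean on DuckDB's time_bucket() to snap timestamps into five-minute buckets. A runnable sketch of the segmented inference count against an in-memory connection (the "inferences" table and its "ts_col"/"region" columns are made up for illustration):

import duckdb

conn = duckdb.connect()
conn.sql("""
    CREATE TABLE inferences AS
    SELECT TIMESTAMP '2024-01-01 00:00:00' + to_minutes(i) AS ts_col,
           CASE WHEN i % 2 = 0 THEN 'us' ELSE 'eu' END AS region
    FROM range(12) t(i)
""")
# the count query above, as built with segmentation_cols=["region"]
print(conn.sql("""
    SELECT time_bucket(INTERVAL '5 minutes', ts_col) AS ts, count(*) AS count, region
    FROM inferences
    GROUP BY ts, region
    ORDER BY ts
""").df())
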
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -9,6 +9,7 @@ from arthur_common.models.schema_definitions import (
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
     MetricLiteralParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -70,29 +71,66 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
                 description="A column containing boolean, integer, or string labelled prediction values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT
+            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+            {escaped_pred_col} as prediction,
+            COUNT(*) as count
+        FROM {dataset.dataset_table_name}
+        GROUP BY
+            ts,
+            -- group by raw column name instead of alias in select
+            -- in case table has a column called 'prediction'
+            {escaped_pred_col}
+        ORDER BY ts
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_pred_col = escape_identifier(prediction_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"{escaped_pred_col} as prediction",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts", f"{escaped_pred_col}"] + escaped_segmentation_cols
+        extra_dims = ["prediction"]
+
+        # build query
         query = f"""
-            SELECT
-                time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
-                {escaped_pred_col} as prediction,
-                COUNT(*) as count
-            FROM {dataset.dataset_table_name}
-            GROUP BY
-                ts,
-                -- group by raw column name instead of alias in select
-                -- in case table has a column called 'prediction'
-                {escaped_pred_col}
-            ORDER BY ts
-        """
+            SELECT {", ".join(all_select_clause_cols)}
+            FROM {dataset.dataset_table_name}
+            GROUP BY {", ".join(all_group_by_cols)}
+            ORDER BY ts
+        """

         result = ddb_conn.sql(query).df()

         series = self.group_query_results_to_numeric_metrics(
             result,
             "count",
-            ["prediction"],
+            segmentation_cols + extra_dims,
             "ts",
         )
         metric = self.series_to_metric(self._metric_name(), series)
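
When the dimension lists are concatenated as segmentation_cols + extra_dims, each unique (segment value, prediction) combination becomes its own time series in the grouping step. A pandas-only sketch of that grouping over a made-up result frame with a hypothetical "region" segmentation column:

import pandas as pd

result = pd.DataFrame({
    "ts": pd.to_datetime(["2024-01-01 00:00"] * 4),
    "region": ["us", "us", "eu", "eu"],       # hypothetical segmentation column
    "prediction": [True, False, True, False],
    "count": [3, 1, 2, 4],
})
segmentation_cols, extra_dims = ["region"], ["prediction"]
# one series per dimension combination, as in group_query_results_to_numeric_metrics
for dims, group in result.groupby(segmentation_cols + extra_dims, dropna=False):
    print(dict(zip(segmentation_cols + extra_dims, dims)), group["count"].tolist())
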
@@ -177,20 +215,59 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
                 description="The label denoting a negative classification.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT
+            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+            CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
+            COUNT(*) as count
+        FROM {dataset.dataset_table_name}
+        GROUP BY
+            ts,
+            -- group by raw column name instead of alias in select
+            -- in case table has a column called 'prediction'
+            {escaped_prediction_col}
+        ORDER BY ts
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = [
+            "ts",
+            f"{escaped_prediction_col}",
+        ] + escaped_segmentation_cols
+        extra_dims = ["prediction"]
+
         query = f"""
-            SELECT
-                time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
-                CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
-                COUNT(*) as count
+            SELECT {", ".join(all_select_clause_cols)}
             FROM {dataset.dataset_table_name}
-            GROUP BY
-                ts,
-                -- group by raw column name instead of alias in select
-                -- in case table has a column called 'prediction'
-                {escaped_prediction_col}
+            GROUP BY {", ".join(all_group_by_cols)}
             ORDER BY ts
         """

@@ -199,7 +276,7 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
         series = self.group_query_results_to_numeric_metrics(
             result,
             "count",
-            ["prediction"],
+            segmentation_cols + extra_dims,
             "ts",
         )
         metric = self.series_to_metric(self._metric_name(), series)