arthur-common 1.0.1 (arthur_common-1.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arthur-common might be problematic.

Files changed (40)
  1. arthur_common/__init__.py +0 -0
  2. arthur_common/__version__.py +1 -0
  3. arthur_common/aggregations/__init__.py +2 -0
  4. arthur_common/aggregations/aggregator.py +214 -0
  5. arthur_common/aggregations/functions/README.md +26 -0
  6. arthur_common/aggregations/functions/__init__.py +25 -0
  7. arthur_common/aggregations/functions/categorical_count.py +89 -0
  8. arthur_common/aggregations/functions/confusion_matrix.py +412 -0
  9. arthur_common/aggregations/functions/inference_count.py +69 -0
  10. arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
  11. arthur_common/aggregations/functions/inference_null_count.py +82 -0
  12. arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
  13. arthur_common/aggregations/functions/mean_squared_error.py +110 -0
  14. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
  15. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
  16. arthur_common/aggregations/functions/numeric_stats.py +90 -0
  17. arthur_common/aggregations/functions/numeric_sum.py +87 -0
  18. arthur_common/aggregations/functions/py.typed +0 -0
  19. arthur_common/aggregations/functions/shield_aggregations.py +752 -0
  20. arthur_common/aggregations/py.typed +0 -0
  21. arthur_common/models/__init__.py +0 -0
  22. arthur_common/models/connectors.py +41 -0
  23. arthur_common/models/datasets.py +22 -0
  24. arthur_common/models/metrics.py +227 -0
  25. arthur_common/models/py.typed +0 -0
  26. arthur_common/models/schema_definitions.py +420 -0
  27. arthur_common/models/shield.py +504 -0
  28. arthur_common/models/task_job_specs.py +78 -0
  29. arthur_common/py.typed +0 -0
  30. arthur_common/tools/__init__.py +0 -0
  31. arthur_common/tools/aggregation_analyzer.py +243 -0
  32. arthur_common/tools/aggregation_loader.py +59 -0
  33. arthur_common/tools/duckdb_data_loader.py +329 -0
  34. arthur_common/tools/functions.py +46 -0
  35. arthur_common/tools/py.typed +0 -0
  36. arthur_common/tools/schema_inferer.py +104 -0
  37. arthur_common/tools/time_utils.py +33 -0
  38. arthur_common-1.0.1.dist-info/METADATA +74 -0
  39. arthur_common-1.0.1.dist-info/RECORD +40 -0
  40. arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/aggregations/functions/confusion_matrix.py
@@ -0,0 +1,412 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    MetricLiteralParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
+    def generate_confusion_matrix_metrics(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        timestamp_col: str,
+        prediction_col: str,
+        gt_values_col: str,
+        prediction_normalization_case: str,
+        gt_normalization_case: str,
+        dataset: DatasetReference,
+    ) -> list[NumericMetric]:
+        """
+        Compute confusion matrix metrics over time.
+
+        Args:
+            timestamp_col: Column name containing timestamps
+            prediction_col: Column name containing predictions
+            gt_values_col: Column name containing ground truth values
+            prediction_normalization_case: SQL CASE statement for normalizing predictions to 0 / 1 / null using 'value' as the target column name
+            gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null using 'value' as the target column name
+            dataset: DatasetReference containing dataset metadata
+
+        Returns:
+            list[NumericMetric]: true positive, false positive, false negative, and true negative counts per time bucket
+        """
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_prediction_col = escape_identifier(prediction_col)
+        escaped_gt_values_col = escape_identifier(gt_values_col)
+        confusion_matrix_query = f"""
+            WITH normalized_data AS (
+                SELECT
+                    {escaped_timestamp_col} AS timestamp,
+                    {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                    {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT
+                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
+            FROM normalized_data
+            GROUP BY ts
+            ORDER BY ts
+        """
+
+        results = ddb_conn.sql(confusion_matrix_query).df()
+
+        tp = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "true_positive_count",
+            timestamp_col="ts",
+        )
+        fp = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "false_positive_count",
+            timestamp_col="ts",
+        )
+        fn = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "false_negative_count",
+            timestamp_col="ts",
+        )
+        tn = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "true_negative_count",
+            timestamp_col="ts",
+        )
+        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count", [tp])
+        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count", [fp])
+        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count", [fn])
+        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count", [tn])
+        return [tp_metric, fp_metric, fn_metric, tn_metric]
+
+
+class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
+    ConfusionMatrixAggregationFunction,
+):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000001c")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Binary Classification Confusion Matrix - Int/Bool Prediction"
+
+    @staticmethod
+    def description() -> str:
+        return "Aggregation that takes in boolean or integer prediction and ground truth values and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the prediction and ground truth values.",
+                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.INT),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing boolean or integer prediction values.",
+            ),
+        ],
+        gt_values_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.INT),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing boolean or integer ground truth values.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_prediction_col = escape_identifier(prediction_col)
+        # Get the type of prediction column
+        type_query = f"SELECT typeof({escaped_prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+        res = ddb_conn.sql(type_query).fetchone()
+        # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
+        if not res:
+            raise ValueError(f"No results found for type query: {type_query}")
+        col_type = res[0].lower()
+
+        match col_type:
+            case "boolean":
+                normalization_case = """
+                    CASE
+                        WHEN value THEN 1
+                        ELSE 0
+                    END
+                """
+            case "integer" | "bigint":
+                normalization_case = """
+                    CASE
+                        WHEN value = 1 THEN 1
+                        WHEN value = 0 THEN 0
+                        ELSE NULL
+                    END
+                """
+            case _:
+                raise ValueError(f"Unsupported column type: {col_type}")
+
+        return self.generate_confusion_matrix_metrics(
+            ddb_conn,
+            timestamp_col,
+            prediction_col,
+            gt_values_col,
+            normalization_case,
+            normalization_case,
+            dataset,
+        )
+
+
+class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
+    ConfusionMatrixAggregationFunction,
+):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000001d")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Binary Classification Confusion Matrix - String Class Label Prediction"
+
+    @staticmethod
+    def description() -> str:
+        return "Aggregation that takes in string labelled prediction and ground truth values and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the prediction and ground truth values.",
+                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing string labelled prediction values.",
+            ),
+        ],
+        gt_values_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing string labelled ground truth values.",
+            ),
+        ],
+        true_label: Annotated[
+            str,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.STRING,
+                friendly_name="True Label",
+                description="The label indicating a positive classification to normalize to 1.",
+            ),
+        ],
+        false_label: Annotated[
+            str,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.STRING,
+                friendly_name="False Label",
+                description="The label indicating a negative classification to normalize to 0.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        normalization_case = f"""
+            CASE
+                WHEN value = '{true_label}' THEN 1
+                WHEN value = '{false_label}' THEN 0
+                ELSE NULL
+            END
+        """
+        return self.generate_confusion_matrix_metrics(
+            ddb_conn,
+            timestamp_col,
+            prediction_col,
+            gt_values_col,
+            normalization_case,
+            normalization_case,
+            dataset,
+        )
+
+
+class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
+    ConfusionMatrixAggregationFunction,
+):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000001e")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Binary Classification Confusion Matrix - Probability Threshold"
+
+    @staticmethod
+    def description() -> str:
+        return "Aggregation that takes in a float prediction column, a ground truth values column, and a probability threshold and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values where the predictions are calculated using the probability threshold."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the prediction and ground truth values.",
+                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing float prediction values.",
+            ),
+        ],
+        gt_values_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.INT),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing boolean or integer ground truth values.",
+            ),
+        ],
+        threshold: Annotated[
+            float,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.FLOAT,
+                friendly_name="Threshold",
+                description="The threshold to classify predictions to 0 or 1.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_gt_values_col = escape_identifier(gt_values_col)
+        prediction_normalization_case = f"""
+            CASE
+                WHEN value >= {threshold} THEN 1
+                WHEN value < {threshold} THEN 0
+                ELSE NULL
+            END
+        """
+
+        type_query = f"SELECT typeof({escaped_gt_values_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
+        res = ddb_conn.sql(type_query).fetchone()
+        # As long as this column exists, we should be able to get the type. This is here to make mypy happy.
+        if not res:
+            raise ValueError(f"No results found for type query: {type_query}")
+        col_type = res[0].lower()
+
+        match col_type:
+            case "boolean":
+                gt_normalization_case = """
+                    CASE
+                        WHEN value THEN 1
+                        ELSE 0
+                    END
+                """
+            case "integer" | "bigint":
+                gt_normalization_case = """
+                    CASE
+                        WHEN value = 1 THEN 1
+                        WHEN value = 0 THEN 0
+                        ELSE NULL
+                    END
+                """
+            case _:
+                raise ValueError(f"Unsupported column type: {col_type}")
+
+        return self.generate_confusion_matrix_metrics(
+            ddb_conn,
+            timestamp_col,
+            prediction_col,
+            gt_values_col,
+            prediction_normalization_case,
+            gt_normalization_case,
+            dataset,
+        )
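
For orientation, the bucketed confusion matrix query that generate_confusion_matrix_metrics builds can be exercised directly in DuckDB. The sketch below is not part of the package; it assumes only the duckdb and pandas libraries, a DuckDB build that provides time_bucket(), and a toy table whose prediction and ground truth columns are already normalized to 0/1 (the table and column names are illustrative).

import duckdb

# Toy table standing in for a dataset with already-normalized 0/1 columns.
con = duckdb.connect()
con.execute(
    """
    CREATE TABLE inferences AS
    SELECT * FROM (VALUES
        (TIMESTAMP '2024-01-01 00:01:00', 1, 1),
        (TIMESTAMP '2024-01-01 00:02:00', 1, 0),
        (TIMESTAMP '2024-01-01 00:03:00', 0, 1),
        (TIMESTAMP '2024-01-01 00:08:00', 0, 0)
    ) AS t(ts, prediction, actual_value)
    """
)

# Same shape as the package's query: TP/FP/FN/TN counts per 5-minute bucket.
cm = con.sql(
    """
    SELECT
        time_bucket(INTERVAL '5 minutes', ts) AS ts,
        SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
        SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
        SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
        SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
    FROM inferences
    GROUP BY ts
    ORDER BY ts
    """
).df()
print(cm)  # first bucket: 1 TP, 1 FP, 1 FN; second bucket: 1 TN
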
arthur_common/aggregations/functions/inference_count.py
@@ -0,0 +1,69 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class InferenceCountAggregationFunction(NumericAggregationFunction):
+    METRIC_NAME = "inference_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000000a")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Inference Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that counts the number of inferences per time window."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the inference data.",
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        count_query = f" \
+            select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            count(*) as count \
+            from {dataset.dataset_table_name} \
+            group by ts \
+            "
+        results = ddb_conn.sql(count_query).df()
+        series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "count",
+            "ts",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, [series])
+        return [metric]
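
The 5-minute bucketed count that InferenceCountAggregationFunction issues through its count_query can be reproduced the same way. A minimal sketch, again assuming only duckdb and pandas, with an illustrative table and column name:

import duckdb

con = duckdb.connect()
con.execute(
    """
    CREATE TABLE inferences AS
    SELECT * FROM (VALUES
        (TIMESTAMP '2024-01-01 00:01:00'),
        (TIMESTAMP '2024-01-01 00:02:00'),
        (TIMESTAMP '2024-01-01 00:04:00'),
        (TIMESTAMP '2024-01-01 00:07:00')
    ) AS t(ts)
    """
)

# count(*) per 5-minute time_bucket, mirroring the aggregation's query.
counts = con.sql(
    """
    SELECT time_bucket(INTERVAL '5 minutes', ts) AS ts, count(*) AS count
    FROM inferences
    GROUP BY ts
    ORDER BY ts
    """
).df()
print(counts)  # 3 inferences in the 00:00 bucket, 1 in the 00:05 bucket
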
arthur_common/aggregations/functions/inference_count_by_class.py
@@ -0,0 +1,206 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    MetricLiteralParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000001f")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Binary Classification Count by Class - Class Label"
+
+    @staticmethod
+    def description() -> str:
+        return "Aggregation that counts the number of predictions by class for a binary classifier. Takes boolean, integer, or string prediction values and groups them by time bucket to show prediction distribution over time."
+
+    @staticmethod
+    def _metric_name() -> str:
+        return "binary_classifier_count_by_class"
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing binary classifier prediction values.",
+                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing boolean, integer, or string labelled prediction values.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_pred_col = escape_identifier(prediction_col)
+        query = f"""
+            SELECT
+                time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+                {escaped_pred_col} as prediction,
+                COUNT(*) as count
+            FROM {dataset.dataset_table_name}
+            GROUP BY
+                ts,
+                -- group by raw column name instead of alias in select
+                -- in case table has a column called 'prediction'
+                {escaped_pred_col}
+            ORDER BY ts
+        """
+
+        result = ddb_conn.sql(query).df()
+
+        series = self.group_query_results_to_numeric_metrics(
+            result,
+            "count",
+            ["prediction"],
+            "ts",
+        )
+        metric = self.series_to_metric(self._metric_name(), series)
+        return [metric]
+
+
+class BinaryClassifierCountThresholdClassAggregationFunction(
+    NumericAggregationFunction,
+):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000020")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Binary Classification Count by Class - Probability Threshold"
+
+    @staticmethod
+    def description() -> str:
+        return "Aggregation that counts the number of predictions by class for a binary classifier using a probability threshold. Takes float prediction values and a threshold value to classify predictions, then groups them by time bucket to show prediction distribution over time."
+
+    @staticmethod
+    def _metric_name() -> str:
+        return "binary_classifier_count_by_class"
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing binary classifier prediction values.",
+                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing float prediction values.",
+            ),
+        ],
+        threshold: Annotated[
+            float,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.FLOAT,
+                friendly_name="Threshold",
+                description="The threshold to classify predictions to 0 or 1. 0 will result in the 'False Label' being assigned and 1 to the 'True Label' being assigned.",
+            ),
+        ],
+        true_label: Annotated[
+            str,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.STRING,
+                friendly_name="True Label",
+                description="The label denoting a positive classification.",
+            ),
+        ],
+        false_label: Annotated[
+            str,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.STRING,
+                friendly_name="False Label",
+                description="The label denoting a negative classification.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_prediction_col = escape_identifier(prediction_col)
+        query = f"""
+            SELECT
+                time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+                CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
+                COUNT(*) as count
+            FROM {dataset.dataset_table_name}
+            GROUP BY
+                ts,
+                -- group by raw column name instead of alias in select
+                -- in case table has a column called 'prediction'
+                {escaped_prediction_col}
+            ORDER BY ts
+        """
+
+        result = ddb_conn.sql(query).df()
+
+        series = self.group_query_results_to_numeric_metrics(
+            result,
+            "count",
+            ["prediction"],
+            "ts",
+        )
+        metric = self.series_to_metric(self._metric_name(), series)
+        return [metric]
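
The probability-threshold variant's CASE mapping can also be tried directly in DuckDB. The sketch below is not part of the package and uses hypothetical stand-ins for its parameters: a 0.5 threshold and 'fraud' / 'not_fraud' for true_label and false_label; only duckdb and pandas are assumed.

import duckdb

con = duckdb.connect()
con.execute(
    """
    CREATE TABLE inferences AS
    SELECT * FROM (VALUES
        (TIMESTAMP '2024-01-01 00:01:00', 0.91),
        (TIMESTAMP '2024-01-01 00:02:00', 0.35),
        (TIMESTAMP '2024-01-01 00:03:00', 0.72),
        (TIMESTAMP '2024-01-01 00:09:00', 0.10)
    ) AS t(ts, score)
    """
)

# Map each score to a class label with the threshold, then count per bucket.
by_class = con.sql(
    """
    SELECT
        time_bucket(INTERVAL '5 minutes', ts) AS ts,
        CASE WHEN score >= 0.5 THEN 'fraud' ELSE 'not_fraud' END AS prediction,
        COUNT(*) AS count
    FROM inferences
    GROUP BY
        ts,
        -- DuckDB allows grouping by the select alias; the toy table has no
        -- column literally named 'prediction', which is the collision the
        -- package's own query avoids by grouping on the raw column instead
        prediction
    ORDER BY ts
    """
).df()
print(by_class)  # 00:00 bucket: 2 'fraud', 1 'not_fraud'; 00:05 bucket: 1 'not_fraud'
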