arthur-common 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arthur-common might be problematic. Click here for more details.
- arthur_common/__init__.py +0 -0
- arthur_common/__version__.py +1 -0
- arthur_common/aggregations/__init__.py +2 -0
- arthur_common/aggregations/aggregator.py +214 -0
- arthur_common/aggregations/functions/README.md +26 -0
- arthur_common/aggregations/functions/__init__.py +25 -0
- arthur_common/aggregations/functions/categorical_count.py +89 -0
- arthur_common/aggregations/functions/confusion_matrix.py +412 -0
- arthur_common/aggregations/functions/inference_count.py +69 -0
- arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
- arthur_common/aggregations/functions/inference_null_count.py +82 -0
- arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
- arthur_common/aggregations/functions/mean_squared_error.py +110 -0
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
- arthur_common/aggregations/functions/numeric_stats.py +90 -0
- arthur_common/aggregations/functions/numeric_sum.py +87 -0
- arthur_common/aggregations/functions/py.typed +0 -0
- arthur_common/aggregations/functions/shield_aggregations.py +752 -0
- arthur_common/aggregations/py.typed +0 -0
- arthur_common/models/__init__.py +0 -0
- arthur_common/models/connectors.py +41 -0
- arthur_common/models/datasets.py +22 -0
- arthur_common/models/metrics.py +227 -0
- arthur_common/models/py.typed +0 -0
- arthur_common/models/schema_definitions.py +420 -0
- arthur_common/models/shield.py +504 -0
- arthur_common/models/task_job_specs.py +78 -0
- arthur_common/py.typed +0 -0
- arthur_common/tools/__init__.py +0 -0
- arthur_common/tools/aggregation_analyzer.py +243 -0
- arthur_common/tools/aggregation_loader.py +59 -0
- arthur_common/tools/duckdb_data_loader.py +329 -0
- arthur_common/tools/functions.py +46 -0
- arthur_common/tools/py.typed +0 -0
- arthur_common/tools/schema_inferer.py +104 -0
- arthur_common/tools/time_utils.py +33 -0
- arthur_common-1.0.1.dist-info/METADATA +74 -0
- arthur_common-1.0.1.dist-info/RECORD +40 -0
- arthur_common-1.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
5
|
+
from arthur_common.models.datasets import ModelProblemType
|
|
6
|
+
from arthur_common.models.metrics import DatasetReference, NumericMetric
|
|
7
|
+
from arthur_common.models.schema_definitions import (
|
|
8
|
+
DType,
|
|
9
|
+
MetricColumnParameterAnnotation,
|
|
10
|
+
MetricDatasetParameterAnnotation,
|
|
11
|
+
MetricLiteralParameterAnnotation,
|
|
12
|
+
ScalarType,
|
|
13
|
+
ScopeSchemaTag,
|
|
14
|
+
)
|
|
15
|
+
from arthur_common.tools.duckdb_data_loader import escape_identifier
|
|
16
|
+
from duckdb import DuckDBPyConnection
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
    """Base class for binary-classification confusion matrix aggregations.

    Subclasses build SQL ``CASE`` expressions that map the raw prediction and
    ground-truth columns to ``1`` / ``0`` / ``NULL`` and delegate to
    :meth:`generate_confusion_matrix_metrics`, which buckets rows into
    5-minute windows and emits true/false positive/negative count metrics.
    """

    # (query result column, emitted metric name); this also fixes the order
    # of the list returned by generate_confusion_matrix_metrics.
    _COUNT_METRICS: tuple[tuple[str, str], ...] = (
        ("true_positive_count", "confusion_matrix_true_positive_count"),
        ("false_positive_count", "confusion_matrix_false_positive_count"),
        ("false_negative_count", "confusion_matrix_false_negative_count"),
        ("true_negative_count", "confusion_matrix_true_negative_count"),
    )

    @staticmethod
    def _substitute_value_placeholder(case_expression: str, replacement: str) -> str:
        """Replace the ``value`` placeholder token in a SQL CASE expression.

        Only whole-word ``value`` tokens OUTSIDE single-quoted SQL string
        literals are replaced. A plain ``str.replace`` would also rewrite text
        inside label literals (e.g. ``'value'`` or ``'high value'``) and corrupt
        the comparison.

        Args:
            case_expression: SQL expression using the bare token ``value`` as
                the column placeholder.
            replacement: Already-escaped SQL identifier to substitute.

        Returns:
            The expression with every placeholder token replaced.
        """
        import re

        # The capturing group makes re.split keep the literals: even indices are
        # SQL outside of string literals, odd indices are the literals themselves
        # (doubled '' escapes are part of the literal).
        parts = re.split(r"('(?:[^']|'')*')", case_expression)
        for index in range(0, len(parts), 2):
            # A callable replacement avoids re interpreting backslashes that may
            # appear inside the escaped identifier.
            parts[index] = re.sub(r"\bvalue\b", lambda _match: replacement, parts[index])
        return "".join(parts)

    def generate_confusion_matrix_metrics(
        self,
        ddb_conn: DuckDBPyConnection,
        timestamp_col: str,
        prediction_col: str,
        gt_values_col: str,
        prediction_normalization_case: str,
        gt_normalization_case: str,
        dataset: DatasetReference,
    ) -> list[NumericMetric]:
        """Compute confusion matrix counts over 5-minute time buckets.

        Args:
            ddb_conn: DuckDB connection used to run the query.
            timestamp_col: Column name containing timestamps. Rows whose
                timestamp is NULL are excluded.
            prediction_col: Column name containing predictions.
            gt_values_col: Column name containing ground truth values.
            prediction_normalization_case: SQL CASE expression normalizing
                predictions to 0 / 1 / NULL, using the bare token ``value`` as
                the placeholder for the prediction column.
            gt_normalization_case: SQL CASE expression normalizing ground truth
                values to 0 / 1 / NULL, using the bare token ``value`` as the
                placeholder for the ground truth column.
            dataset: DatasetReference whose ``dataset_table_name`` is queried.

        Returns:
            list[NumericMetric]: four metrics, in order:
            ``confusion_matrix_true_positive_count``,
            ``confusion_matrix_false_positive_count``,
            ``confusion_matrix_false_negative_count`` and
            ``confusion_matrix_true_negative_count``.
        """
        escaped_timestamp_col = escape_identifier(timestamp_col)
        normalized_prediction = self._substitute_value_placeholder(
            prediction_normalization_case,
            escape_identifier(prediction_col),
        )
        normalized_actual = self._substitute_value_placeholder(
            gt_normalization_case,
            escape_identifier(gt_values_col),
        )
        # Rows where either side normalizes to NULL fail every comparison below
        # and are therefore not counted in any cell.
        confusion_matrix_query = f"""
            WITH normalized_data AS (
                SELECT
                    {escaped_timestamp_col} AS timestamp,
                    {normalized_prediction} AS prediction,
                    {normalized_actual} AS actual_value
                FROM {dataset.dataset_table_name}
                WHERE {escaped_timestamp_col} IS NOT NULL
            )
            SELECT
                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
                SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
                SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
                SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
                SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
            FROM normalized_data
            GROUP BY ts
            ORDER BY ts
        """

        results = ddb_conn.sql(confusion_matrix_query).df()

        return [
            self.series_to_metric(
                metric_name,
                [
                    self.dimensionless_query_results_to_numeric_metrics(
                        results,
                        result_column,
                        timestamp_col="ts",
                    ),
                ],
            )
            for result_column, metric_name in self._COUNT_METRICS
        ]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
    ConfusionMatrixAggregationFunction,
):
    """Confusion matrix for binary classifiers with boolean or 0/1 integer columns.

    The DuckDB type of the prediction column and of the ground-truth column is
    detected independently, so e.g. a boolean prediction column may be paired
    with an integer ground-truth column.
    """

    # Lower-cased DuckDB integer type names accepted as 0/1 class labels.
    _INTEGER_TYPES = frozenset(
        {
            "tinyint",
            "smallint",
            "integer",
            "bigint",
            "hugeint",
            "utinyint",
            "usmallint",
            "uinteger",
            "ubigint",
            "uhugeint",
        },
    )

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-00000000001c")

    @staticmethod
    def display_name() -> str:
        return "Binary Classification Confusion Matrix - Int/Bool Prediction"

    @staticmethod
    def description() -> str:
        return "Aggregation that takes in boolean or integer prediction and ground truth values and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values."

    @staticmethod
    def _column_type(
        ddb_conn: DuckDBPyConnection,
        table_name: str,
        column: str,
    ) -> str:
        """Return the lower-cased DuckDB type name of ``column`` in ``table_name``.

        Uses ``DESCRIBE`` so the type comes from the schema. This also works for
        an empty table, where ``SELECT typeof(...) ... LIMIT 1`` returns no row.

        Raises:
            ValueError: if DESCRIBE returns no row.
        """
        type_query = f"DESCRIBE SELECT {escape_identifier(column)} FROM {table_name}"
        res = ddb_conn.sql(type_query).fetchone()
        if not res:
            raise ValueError(f"No results found for type query: {type_query}")
        # DESCRIBE rows are (column_name, column_type, null, key, default, extra).
        return str(res[1]).lower()

    @classmethod
    def _normalization_case(cls, col_type: str) -> str:
        """Return the CASE expression normalizing a column of ``col_type`` to 1 / 0 / NULL.

        The expression uses the bare token ``value`` as the column placeholder.

        Raises:
            ValueError: if ``col_type`` is neither boolean nor an integer type.
        """
        if col_type == "boolean":
            # NULL booleans stay NULL (excluded from every cell), consistent
            # with how unknown integer values are treated below.
            return """
                CASE
                    WHEN value THEN 1
                    WHEN NOT value THEN 0
                    ELSE NULL
                END
            """
        if col_type in cls._INTEGER_TYPES:
            return """
                CASE
                    WHEN value = 1 THEN 1
                    WHEN value = 0 THEN 0
                    ELSE NULL
                END
            """
        raise ValueError(f"Unsupported column type: {col_type}")

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing the prediction and ground truth values.",
                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
        prediction_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.BOOL),
                    ScalarType(dtype=DType.INT),
                ],
                tag_hints=[ScopeSchemaTag.PREDICTION],
                friendly_name="Prediction Column",
                description="A column containing boolean or integer prediction values.",
            ),
        ],
        gt_values_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.BOOL),
                    ScalarType(dtype=DType.INT),
                ],
                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
                friendly_name="Ground Truth Column",
                description="A column containing boolean or integer ground truth values.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Compute the bucketed confusion matrix for boolean/integer columns.

        Raises:
            ValueError: if either column's type cannot be determined or is not
                boolean/integer.
        """
        table_name = dataset.dataset_table_name
        # Each column gets a CASE matching its own type; the two may differ.
        prediction_case = self._normalization_case(
            self._column_type(ddb_conn, table_name, prediction_col),
        )
        gt_case = self._normalization_case(
            self._column_type(ddb_conn, table_name, gt_values_col),
        )

        return self.generate_confusion_matrix_metrics(
            ddb_conn,
            timestamp_col,
            prediction_col,
            gt_values_col,
            prediction_case,
            gt_case,
            dataset,
        )
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
    ConfusionMatrixAggregationFunction,
):
    """Confusion matrix for binary classifiers whose labels are strings.

    Values equal to ``true_label`` normalize to 1, values equal to
    ``false_label`` normalize to 0, and anything else (including NULL)
    normalizes to NULL and is not counted.
    """

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-00000000001d")

    @staticmethod
    def display_name() -> str:
        return "Binary Classification Confusion Matrix - String Class Label Prediction"

    @staticmethod
    def description() -> str:
        return "Aggregation that takes in string labelled prediction and ground truth values and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values."

    @staticmethod
    def _sql_string_literal(text: str) -> str:
        """Quote ``text`` as a SQL string literal, doubling embedded single quotes.

        Without this, a label such as ``don't`` would terminate the literal
        early and break (or inject into) the generated query.
        """
        return "'" + text.replace("'", "''") + "'"

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing the prediction and ground truth values.",
                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
        prediction_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.STRING),
                ],
                tag_hints=[ScopeSchemaTag.PREDICTION],
                friendly_name="Prediction Column",
                description="A column containing string labelled prediction values.",
            ),
        ],
        gt_values_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.STRING),
                ],
                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
                friendly_name="Ground Truth Column",
                description="A column containing string labelled ground truth values.",
            ),
        ],
        true_label: Annotated[
            str,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.STRING,
                friendly_name="True Label",
                description="The label indicating a positive classification to normalize to 1.",
            ),
        ],
        false_label: Annotated[
            str,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.STRING,
                friendly_name="False Label",
                description="The label indicating a negative classification to normalize to 0.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Compute the bucketed confusion matrix for string-labelled columns.

        The same CASE expression (with ``value`` as the column placeholder) is
        used for both the prediction and ground-truth columns.
        """
        true_literal = self._sql_string_literal(true_label)
        false_literal = self._sql_string_literal(false_label)
        normalization_case = f"""
            CASE
                WHEN value = {true_literal} THEN 1
                WHEN value = {false_literal} THEN 0
                ELSE NULL
            END
        """
        return self.generate_confusion_matrix_metrics(
            ddb_conn,
            timestamp_col,
            prediction_col,
            gt_values_col,
            normalization_case,
            normalization_case,
            dataset,
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
    ConfusionMatrixAggregationFunction,
):
    """Confusion matrix for binary classifiers that output a float score.

    Predictions ``>= threshold`` normalize to 1 and predictions ``< threshold``
    normalize to 0. The ground-truth column may be boolean or a 0/1 integer.
    """

    # Lower-cased DuckDB integer type names accepted as 0/1 ground-truth labels.
    _INTEGER_TYPES = frozenset(
        {
            "tinyint",
            "smallint",
            "integer",
            "bigint",
            "hugeint",
            "utinyint",
            "usmallint",
            "uinteger",
            "ubigint",
            "uhugeint",
        },
    )

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-00000000001e")

    @staticmethod
    def display_name() -> str:
        return "Binary Classification Confusion Matrix - Probability Threshold"

    @staticmethod
    def description() -> str:
        return "Aggregation that takes in a float prediction column, a ground truth values column, and a probability threshold and calculates the confusion matrix (True Positives, False Positives, False Negatives, True Negatives) for a binary set of predictions and values where the predictions are calculated using the probability threshold."

    @staticmethod
    def _column_type(
        ddb_conn: DuckDBPyConnection,
        table_name: str,
        column: str,
    ) -> str:
        """Return the lower-cased DuckDB type name of ``column`` in ``table_name``.

        Uses ``DESCRIBE`` so the type comes from the schema. This also works for
        an empty table, where ``SELECT typeof(...) ... LIMIT 1`` returns no row.

        Raises:
            ValueError: if DESCRIBE returns no row.
        """
        type_query = f"DESCRIBE SELECT {escape_identifier(column)} FROM {table_name}"
        res = ddb_conn.sql(type_query).fetchone()
        if not res:
            raise ValueError(f"No results found for type query: {type_query}")
        # DESCRIBE rows are (column_name, column_type, null, key, default, extra).
        return str(res[1]).lower()

    @classmethod
    def _gt_normalization_case(cls, col_type: str) -> str:
        """Return the CASE expression normalizing a ground-truth column to 1 / 0 / NULL.

        The expression uses the bare token ``value`` as the column placeholder.

        Raises:
            ValueError: if ``col_type`` is neither boolean nor an integer type.
        """
        if col_type == "boolean":
            # NULL booleans stay NULL (excluded from every cell), consistent
            # with how unknown integer values are treated below.
            return """
                CASE
                    WHEN value THEN 1
                    WHEN NOT value THEN 0
                    ELSE NULL
                END
            """
        if col_type in cls._INTEGER_TYPES:
            return """
                CASE
                    WHEN value = 1 THEN 1
                    WHEN value = 0 THEN 0
                    ELSE NULL
                END
            """
        raise ValueError(f"Unsupported column type: {col_type}")

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing the prediction and ground truth values.",
                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
        prediction_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.FLOAT),
                ],
                tag_hints=[ScopeSchemaTag.PREDICTION],
                friendly_name="Prediction Column",
                description="A column containing float prediction values.",
            ),
        ],
        gt_values_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.BOOL),
                    ScalarType(dtype=DType.INT),
                ],
                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
                friendly_name="Ground Truth Column",
                description="A column containing boolean or integer ground truth values.",
            ),
        ],
        threshold: Annotated[
            float,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.FLOAT,
                friendly_name="Threshold",
                description="The threshold to classify predictions to 0 or 1.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Compute the bucketed confusion matrix using a probability threshold.

        Raises:
            ValueError: if ``threshold`` is not finite, or if the ground-truth
                column's type cannot be determined or is not boolean/integer.
        """
        import math

        threshold = float(threshold)
        # nan/inf would be rendered as bare words that SQL treats as identifiers.
        if not math.isfinite(threshold):
            raise ValueError(f"Threshold must be a finite number, got {threshold!r}")
        # repr() of a float is a round-trippable numeric literal (e.g. 0.5, 1e-05).
        threshold_literal = repr(threshold)
        prediction_normalization_case = f"""
            CASE
                WHEN value >= {threshold_literal} THEN 1
                WHEN value < {threshold_literal} THEN 0
                ELSE NULL
            END
        """

        gt_normalization_case = self._gt_normalization_case(
            self._column_type(ddb_conn, dataset.dataset_table_name, gt_values_col),
        )

        return self.generate_confusion_matrix_metrics(
            ddb_conn,
            timestamp_col,
            prediction_col,
            gt_values_col,
            prediction_normalization_case,
            gt_normalization_case,
            dataset,
        )
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
5
|
+
from arthur_common.models.metrics import DatasetReference, NumericMetric
|
|
6
|
+
from arthur_common.models.schema_definitions import (
|
|
7
|
+
DType,
|
|
8
|
+
MetricColumnParameterAnnotation,
|
|
9
|
+
MetricDatasetParameterAnnotation,
|
|
10
|
+
ScalarType,
|
|
11
|
+
ScopeSchemaTag,
|
|
12
|
+
)
|
|
13
|
+
from arthur_common.tools.duckdb_data_loader import escape_identifier
|
|
14
|
+
from duckdb import DuckDBPyConnection
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InferenceCountAggregationFunction(NumericAggregationFunction):
    """Counts inferences per 5-minute time bucket."""

    METRIC_NAME = "inference_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-00000000000a")

    @staticmethod
    def display_name() -> str:
        return "Inference Count"

    @staticmethod
    def description() -> str:
        return "Metric that counts the number of inferences per time window."

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing the inference data.",
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Return a single ``inference_count`` metric with one point per bucket.

        Rows with a NULL timestamp belong to no time window and are excluded,
        matching the confusion-matrix aggregations.
        """
        escaped_timestamp_col = escape_identifier(timestamp_col)
        # Bucketing in a CTE means the outer GROUP BY can only see the `ts`
        # alias, even if the source table itself has a column named `ts`.
        count_query = f"""
            WITH bucketed AS (
                SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) AS ts
                FROM {dataset.dataset_table_name}
                WHERE {escaped_timestamp_col} IS NOT NULL
            )
            SELECT ts, COUNT(*) AS count
            FROM bucketed
            GROUP BY ts
            ORDER BY ts
        """
        results = ddb_conn.sql(count_query).df()
        series = self.dimensionless_query_results_to_numeric_metrics(
            results,
            "count",
            timestamp_col="ts",
        )
        metric = self.series_to_metric(self.METRIC_NAME, [series])
        return [metric]
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from typing import Annotated
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
5
|
+
from arthur_common.models.datasets import ModelProblemType
|
|
6
|
+
from arthur_common.models.metrics import DatasetReference, NumericMetric
|
|
7
|
+
from arthur_common.models.schema_definitions import (
|
|
8
|
+
DType,
|
|
9
|
+
MetricColumnParameterAnnotation,
|
|
10
|
+
MetricDatasetParameterAnnotation,
|
|
11
|
+
MetricLiteralParameterAnnotation,
|
|
12
|
+
ScalarType,
|
|
13
|
+
ScopeSchemaTag,
|
|
14
|
+
)
|
|
15
|
+
from arthur_common.tools.duckdb_data_loader import escape_identifier
|
|
16
|
+
from duckdb import DuckDBPyConnection
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
    """Counts binary-classifier predictions per class per 5-minute bucket.

    The raw prediction value (boolean, integer or string) is used directly as
    the ``prediction`` dimension.
    """

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-00000000001f")

    @staticmethod
    def display_name() -> str:
        return "Binary Classification Count by Class - Class Label"

    @staticmethod
    def description() -> str:
        return "Aggregation that counts the number of predictions by class for a binary classifier. Takes boolean, integer, or string prediction values and groups them by time bucket to show prediction distribution over time."

    @staticmethod
    def _metric_name() -> str:
        return "binary_classifier_count_by_class"

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing binary classifier prediction values.",
                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
        prediction_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.BOOL),
                    ScalarType(dtype=DType.INT),
                    ScalarType(dtype=DType.STRING),
                ],
                tag_hints=[ScopeSchemaTag.PREDICTION],
                friendly_name="Prediction Column",
                description="A column containing boolean, integer, or string labelled prediction values.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Return one metric whose series are grouped by the ``prediction`` value.

        Rows with a NULL timestamp belong to no time window and are excluded.
        """
        escaped_timestamp_col = escape_identifier(timestamp_col)
        escaped_pred_col = escape_identifier(prediction_col)
        # The CTE exposes only `ts` and `prediction`, so the outer GROUP BY is
        # unambiguous even if the source table has columns with those names.
        query = f"""
            WITH bucketed AS (
                SELECT
                    time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) AS ts,
                    {escaped_pred_col} AS prediction
                FROM {dataset.dataset_table_name}
                WHERE {escaped_timestamp_col} IS NOT NULL
            )
            SELECT ts, prediction, COUNT(*) AS count
            FROM bucketed
            GROUP BY ts, prediction
            ORDER BY ts
        """

        result = ddb_conn.sql(query).df()

        series = self.group_query_results_to_numeric_metrics(
            result,
            "count",
            ["prediction"],
            "ts",
        )
        metric = self.series_to_metric(self._metric_name(), series)
        return [metric]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class BinaryClassifierCountThresholdClassAggregationFunction(
    NumericAggregationFunction,
):
    """Counts thresholded binary-classifier predictions per class per 5-minute bucket.

    Float predictions ``>= threshold`` are labelled ``true_label``; all other
    rows are labelled ``false_label``.
    """

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000020")

    @staticmethod
    def display_name() -> str:
        return "Binary Classification Count by Class - Probability Threshold"

    @staticmethod
    def description() -> str:
        return "Aggregation that counts the number of predictions by class for a binary classifier using a probability threshold. Takes float prediction values and a threshold value to classify predictions, then groups them by time bucket to show prediction distribution over time."

    @staticmethod
    def _metric_name() -> str:
        return "binary_classifier_count_by_class"

    @staticmethod
    def _sql_string_literal(text: str) -> str:
        """Quote ``text`` as a SQL string literal, doubling embedded single quotes."""
        return "'" + text.replace("'", "''") + "'"

    def aggregate(
        self,
        ddb_conn: DuckDBPyConnection,
        dataset: Annotated[
            DatasetReference,
            MetricDatasetParameterAnnotation(
                friendly_name="Dataset",
                description="The dataset containing binary classifier prediction values.",
                model_problem_type=ModelProblemType.BINARY_CLASSIFICATION,
            ),
        ],
        timestamp_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
                allowed_column_types=[
                    ScalarType(dtype=DType.TIMESTAMP),
                ],
                friendly_name="Timestamp Column",
                description="A column containing timestamp values to bucket by.",
            ),
        ],
        prediction_col: Annotated[
            str,
            MetricColumnParameterAnnotation(
                source_dataset_parameter_key="dataset",
                allowed_column_types=[
                    ScalarType(dtype=DType.FLOAT),
                ],
                tag_hints=[ScopeSchemaTag.PREDICTION],
                friendly_name="Prediction Column",
                description="A column containing float prediction values.",
            ),
        ],
        threshold: Annotated[
            float,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.FLOAT,
                friendly_name="Threshold",
                description="The threshold to classify predictions to 0 or 1. 0 will result in the 'False Label' being assigned and 1 to the 'True Label' being assigned.",
            ),
        ],
        true_label: Annotated[
            str,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.STRING,
                friendly_name="True Label",
                description="The label denoting a positive classification.",
            ),
        ],
        false_label: Annotated[
            str,
            MetricLiteralParameterAnnotation(
                parameter_dtype=DType.STRING,
                friendly_name="False Label",
                description="The label denoting a negative classification.",
            ),
        ],
    ) -> list[NumericMetric]:
        """Return one metric whose series are grouped by the derived class label.

        Rows with a NULL timestamp are excluded.

        Raises:
            ValueError: if ``threshold`` is not finite.
        """
        import math

        threshold = float(threshold)
        # nan/inf would be rendered as bare words that SQL treats as identifiers.
        if not math.isfinite(threshold):
            raise ValueError(f"Threshold must be a finite number, got {threshold!r}")

        escaped_timestamp_col = escape_identifier(timestamp_col)
        escaped_prediction_col = escape_identifier(prediction_col)
        true_literal = self._sql_string_literal(true_label)
        false_literal = self._sql_string_literal(false_label)
        # Group by the DERIVED label, not the raw float column. Grouping by the
        # raw column produced one row per distinct score, so each bucket got
        # duplicate label rows. The CTE also keeps the outer GROUP BY from
        # resolving `ts`/`prediction` to same-named source columns.
        # NOTE: a NULL prediction falls through to ELSE and is counted under
        # false_label.
        query = f"""
            WITH labelled AS (
                SELECT
                    time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) AS ts,
                    CASE
                        WHEN {escaped_prediction_col} >= {threshold!r} THEN {true_literal}
                        ELSE {false_literal}
                    END AS prediction
                FROM {dataset.dataset_table_name}
                WHERE {escaped_timestamp_col} IS NOT NULL
            )
            SELECT ts, prediction, COUNT(*) AS count
            FROM labelled
            GROUP BY ts, prediction
            ORDER BY ts
        """

        result = ddb_conn.sql(query).df()

        series = self.group_query_results_to_numeric_metrics(
            result,
            "count",
            ["prediction"],
            "ts",
        )
        metric = self.series_to_metric(self._metric_name(), series)
        return [metric]
|