arthur-common 1.0.1__py3-none-any.whl → 2.1.47__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry.
- arthur_common/aggregations/aggregator.py +10 -1
- arthur_common/aggregations/functions/categorical_count.py +51 -11
- arthur_common/aggregations/functions/confusion_matrix.py +117 -27
- arthur_common/aggregations/functions/inference_count.py +46 -9
- arthur_common/aggregations/functions/inference_count_by_class.py +101 -24
- arthur_common/aggregations/functions/inference_null_count.py +50 -10
- arthur_common/aggregations/functions/mean_absolute_error.py +55 -15
- arthur_common/aggregations/functions/mean_squared_error.py +55 -15
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +78 -24
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +19 -1
- arthur_common/aggregations/functions/numeric_stats.py +46 -9
- arthur_common/aggregations/functions/numeric_sum.py +52 -12
- arthur_common/models/connectors.py +6 -1
- arthur_common/models/metrics.py +5 -9
- arthur_common/models/schema_definitions.py +2 -0
- arthur_common/tools/aggregation_analyzer.py +31 -1
- arthur_common/tools/duckdb_data_loader.py +1 -1
- {arthur_common-1.0.1.dist-info → arthur_common-2.1.47.dist-info}/METADATA +1 -4
- {arthur_common-1.0.1.dist-info → arthur_common-2.1.47.dist-info}/RECORD +20 -21
- arthur_common/__version__.py +0 -1
- {arthur_common-1.0.1.dist-info → arthur_common-2.1.47.dist-info}/WHEEL +0 -0

arthur_common/aggregations/aggregator.py

@@ -73,6 +73,15 @@ class NumericAggregationFunction(AggregationFunction, ABC):
         From there, iterate over the group turning each data point to a *Point. At the end, this single instance of the group metrics
         and the list of points (values) are merged to one *TimeSeries
         """
+        if not dim_columns:
+            return [
+                NumericAggregationFunction._dimensionless_query_results_to_numeric_metrics(
+                    data,
+                    value_col,
+                    timestamp_col,
+                ),
+            ]
+
         calculated_metrics: list[NumericTimeSeries] = []
         # make sure dropna is False or rows with "null" as a dimension value will be dropped
         groups = data.groupby(dim_columns, dropna=False)

@@ -99,7 +108,7 @@ class NumericAggregationFunction(AggregationFunction, ABC):
         return calculated_metrics

     @staticmethod
-    def
+    def _dimensionless_query_results_to_numeric_metrics(
         data: pd.DataFrame,
         value_col: str,
         timestamp_col: str,
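
The aggregator change means an empty dim_columns list no longer goes through pandas groupby at all; it short-circuits into a single dimensionless time series. A minimal sketch of that dispatch pattern, using hypothetical dict stand-ins for arthur_common's NumericTimeSeries/NumericPoint models:

    import pandas as pd

    def _to_series(df: pd.DataFrame, value_col: str, ts_col: str, dims: dict) -> dict:
        # stand-in for building one *TimeSeries from one group's rows
        points = [{"timestamp": ts, "value": v} for ts, v in zip(df[ts_col], df[value_col])]
        return {"dimensions": dims, "points": points}

    def group_results(df: pd.DataFrame, value_col: str, dim_cols: list[str], ts_col: str) -> list[dict]:
        if not dim_cols:
            # the new early return: one dimensionless series, no groupby
            return [_to_series(df, value_col, ts_col, dims={})]
        series = []
        # dropna=False keeps rows whose dimension value is null, per the comment in the diff
        for key, group in df.groupby(dim_cols, dropna=False):
            key = key if isinstance(key, tuple) else (key,)
            series.append(_to_series(group, value_col, ts_col, dims=dict(zip(dim_cols, key))))
        return series

    df = pd.DataFrame({"ts": ["t0", "t0"], "count": [3, 5], "region": ["us", "eu"]})
    print(group_results(df, "count", [], "ts"))          # one dimensionless series
    print(group_results(df, "count", ["region"], "ts"))  # one series per region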

arthur_common/aggregations/functions/categorical_count.py

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction

@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )

@@ -64,25 +65,64 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
                 description="A column containing categorical values to count.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        select time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts, \
+            count(*) as count, \
+            {categorical_col_escaped} as category, \
+            {categorical_col_name_escaped} as column_name \
+        from {dataset.dataset_table_name} \
+        where ts is not null \
+        group by ts, category
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         timestamp_col_escaped = escape_identifier(timestamp_col)
         categorical_col_escaped = escape_identifier(categorical_col)
         categorical_col_name_escaped = escape_str_literal(categorical_col)
-
-
-
-
-
-
-
-
-
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {timestamp_col_escaped}) as ts",
+            f"count(*) as count",
+            f"{categorical_col_escaped} as category",
+            f"{categorical_col_name_escaped} as column_name",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts", "category"] + escaped_segmentation_cols
+        extra_dims = ["column_name", "category"]
+
+        # build query
+        count_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            where ts is not null
+            group by {", ".join(all_group_by_cols)}
+        """
+
         results = ddb_conn.sql(count_query).df()

         series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
-
+            segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
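
The query-builder refactor is the same across all of these functions: keep the fixed select list, append the escaped segmentation columns to both the select and group-by lists, and report them as extra dimensions. A self-contained approximation of the generated SQL against an in-memory DuckDB table (toy data; escape_identifier is approximated here with double-quoting, the real helper ships in arthur_common):

    import duckdb

    def escape_identifier(name: str) -> str:
        # rough stand-in for arthur_common's helper: quote and escape embedded quotes
        return '"' + name.replace('"', '""') + '"'

    timestamp_col, categorical_col, segmentation_cols = "event_time", "label", ["region"]

    escaped_segmentation_cols = [escape_identifier(c) for c in segmentation_cols]
    all_select_clause_cols = [
        f"time_bucket(INTERVAL '5 minutes', {escape_identifier(timestamp_col)}) as ts",
        "count(*) as count",
        f"{escape_identifier(categorical_col)} as category",
    ] + escaped_segmentation_cols
    all_group_by_cols = ["ts", "category"] + escaped_segmentation_cols

    conn = duckdb.connect()
    conn.sql("""
        create table inferences as
        select * from (values
            (timestamp '2024-01-01 00:01', 'cat_a', 'us'),
            (timestamp '2024-01-01 00:02', 'cat_a', 'eu'),
            (timestamp '2024-01-01 00:03', 'cat_b', 'us'))
            t(event_time, label, region)
    """)
    # DuckDB lets the 'ts' select alias be reused in WHERE and GROUP BY,
    # which is what the package's queries rely on
    print(conn.sql(f"""
        select {", ".join(all_select_clause_cols)}
        from inferences
        where ts is not null
        group by {", ".join(all_group_by_cols)}
    """).df())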

arthur_common/aggregations/functions/confusion_matrix.py

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction

@@ -9,6 +9,7 @@ from arthur_common.models.schema_definitions import (
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
     MetricLiteralParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )

@@ -26,6 +27,7 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
         prediction_normalization_case: str,
         gt_normalization_case: str,
         dataset: DatasetReference,
+        segmentation_cols: list[str],
     ) -> list[NumericMetric]:
         """
         Generate a SQL query to compute confusion matrix metrics over time.

@@ -37,59 +39,94 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
         prediction_normalization_case: SQL CASE statement for normalizing predictions to 0 / 1 / null using 'value' as the target column name
         gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null using 'value' as the target column name
         dataset: DatasetReference containing dataset metadata
+        segmentation_cols: list of columns to segment by

         Returns:
             str: SQL query that computes confusion matrix metrics
+        Without segmentation, this is the query:
+            WITH normalized_data AS (
+                SELECT
+                    {escaped_timestamp_col} AS timestamp,
+                    {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                    {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT
+                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
+            FROM normalized_data
+            GROUP BY ts
+            ORDER BY ts
         """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
         escaped_gt_values_col = escape_identifier(gt_values_col)
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        first_subquery_select_cols = [
+            f"{escaped_timestamp_col} AS timestamp",
+            f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
+            f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
+        ] + escaped_segmentation_cols
+        second_subquery_select_cols = [
+            "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
+            "SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
+            "SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count",
+            "SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
+            "SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
+        ] + escaped_segmentation_cols
+        second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
         confusion_matrix_query = f"""
-
-
-            {
-            {
-
-
-
-
-
-
-            SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
-            SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
-            SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
-            SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
-            FROM normalized_data
-            GROUP BY ts
-            ORDER BY ts
-        """
+            WITH normalized_data AS (
+                SELECT {", ".join(first_subquery_select_cols)}
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT {", ".join(second_subquery_select_cols)}
+            FROM normalized_data
+            GROUP BY {", ".join(second_subquery_group_by_cols)}
+            ORDER BY ts
+        """

         results = ddb_conn.sql(confusion_matrix_query).df()

-        tp = self.
+        tp = self.group_query_results_to_numeric_metrics(
             results,
             "true_positive_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        fp = self.
+        fp = self.group_query_results_to_numeric_metrics(
             results,
             "false_positive_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        fn = self.
+        fn = self.group_query_results_to_numeric_metrics(
             results,
             "false_negative_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        tn = self.
+        tn = self.group_query_results_to_numeric_metrics(
             results,
             "true_negative_count",
+            dim_columns=segmentation_cols,
             timestamp_col="ts",
         )
-        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count",
-        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count",
-        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count",
-        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count",
+        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count", tp)
+        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count", fp)
+        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count", fn)
+        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count", tn)
         return [tp_metric, fp_metric, fn_metric, tn_metric]
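
The four SUM(CASE ...) counters are easy to sanity-check outside the package. A throwaway DuckDB session with one row per confusion-matrix cell (toy data, not arthur-common's API):

    import duckdb

    conn = duckdb.connect()
    conn.sql("""
        create table normalized_data as
        select col0 as timestamp, col1 as prediction, col2 as actual_value
        from (values
            (timestamp '2024-01-01 00:01', 1, 1),  -- true positive
            (timestamp '2024-01-01 00:02', 1, 0),  -- false positive
            (timestamp '2024-01-01 00:03', 0, 1),  -- false negative
            (timestamp '2024-01-01 00:04', 0, 0))  -- true negative
            t(col0, col1, col2)
    """)
    print(conn.sql("""
        SELECT
            time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
            SUM(CASE WHEN prediction = actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
            SUM(CASE WHEN prediction != actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
            SUM(CASE WHEN prediction != actual_value AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
            SUM(CASE WHEN prediction = actual_value AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count
        FROM normalized_data
        GROUP BY ts
        ORDER BY ts
    """).df())

All four rows fall in the same five-minute bucket, so the result is a single row with each count equal to 1.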

@@ -157,7 +194,24 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
                 description="A column containing boolean or integer ground truth values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_prediction_col = escape_identifier(prediction_col)
         # Get the type of prediction column
         type_query = f"SELECT typeof({escaped_prediction_col}) as col_type FROM {dataset.dataset_table_name} LIMIT 1"
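
The same Annotated[Optional[list[str]], MetricMultipleColumnParameterAnnotation(...)] = None block is stamped onto every aggregation, which suggests the aggregation analyzer reads parameter metadata from type hints rather than from the call itself. A minimal sketch of that introspection pattern, with a hypothetical stand-in for the annotation class:

    from dataclasses import dataclass
    from typing import Annotated, Optional, get_type_hints

    @dataclass
    class ColumnsParam:
        # hypothetical stand-in for MetricMultipleColumnParameterAnnotation
        friendly_name: str
        description: str
        optional: bool = False

    def aggregate(
        segmentation_cols: Annotated[
            Optional[list[str]],
            ColumnsParam(
                friendly_name="Segmentation Columns",
                description="All columns to include as dimensions for segmentation.",
                optional=True,
            ),
        ] = None,
    ) -> None:
        pass

    # the metadata is recoverable from type hints without calling the function
    hints = get_type_hints(aggregate, include_extras=True)
    meta = hints["segmentation_cols"].__metadata__[0]
    print(meta.friendly_name, meta.optional)  # Segmentation Columns True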

@@ -194,6 +248,7 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
             normalization_case,
             normalization_case,
             dataset,
+            segmentation_cols,
         )


@@ -275,7 +330,24 @@ class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
                 description="The label indicating a negative classification to normalize to 0.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         normalization_case = f"""
             CASE
                 WHEN value = '{true_label}' THEN 1

@@ -291,6 +363,7 @@ class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
             normalization_case,
             normalization_case,
             dataset,
+            segmentation_cols,
         )


@@ -365,6 +438,22 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
                 description="The threshold to classify predictions to 0 or 1.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
         escaped_gt_values_col = escape_identifier(gt_values_col)
         prediction_normalization_case = f"""

@@ -409,4 +498,5 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
             prediction_normalization_case,
             gt_normalization_case,
             dataset,
+            segmentation_cols,
         )

arthur_common/aggregations/functions/inference_count.py

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction

@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )

@@ -51,19 +52,55 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
                 description="A column containing timestamp values to bucket by.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
-
-        count_query = f" \
+        """Executed SQL with no segmentation columns:
         select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-
-
-
-            "
+            count(*) as count \
+        from {dataset.dataset_table_name} \
+        group by ts \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"count(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        count_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            group by {", ".join(all_group_by_cols)}
+        """
+
         results = ddb_conn.sql(count_query).df()
-        series = self.
+        series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
+            segmentation_cols,
             "ts",
         )
-        metric = self.series_to_metric(self.METRIC_NAME,
+        metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]
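
Every user-supplied column name passes through escape_identifier (and string literals through escape_str_literal) before being interpolated into these f-string queries. The diff doesn't show those helpers, but a typical DuckDB-safe implementation looks like this sketch (an assumption about their behavior, not the package's actual code):

    def escape_identifier(name: str) -> str:
        # wrap in double quotes and double any embedded quotes, so a column
        # named weird"col becomes "weird""col" instead of breaking the query
        return '"' + name.replace('"', '""') + '"'

    def escape_str_literal(value: str) -> str:
        # single-quoted SQL string literal with embedded quotes doubled
        return "'" + value.replace("'", "''") + "'"

    assert escape_identifier('weird"col') == '"weird""col"'
    assert escape_str_literal("it's") == "'it''s'"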

arthur_common/aggregations/functions/inference_count_by_class.py

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID

 from arthur_common.aggregations.aggregator import NumericAggregationFunction

@@ -9,6 +9,7 @@ from arthur_common.models.schema_definitions import (
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
     MetricLiteralParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )

@@ -70,29 +71,66 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
                 description="A column containing boolean, integer, or string labelled prediction values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT
+            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+            {escaped_pred_col} as prediction,
+            COUNT(*) as count
+        FROM {dataset.dataset_table_name}
+        GROUP BY
+            ts,
+            -- group by raw column name instead of alias in select
+            -- in case table has a column called 'prediction'
+            {escaped_pred_col}
+        ORDER BY ts
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_pred_col = escape_identifier(prediction_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"{escaped_pred_col} as prediction",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts", f"{escaped_pred_col}"] + escaped_segmentation_cols
+        extra_dims = ["prediction"]
+
+        # build query
         query = f"""
-
-
-
-
-
-        GROUP BY
-            ts,
-            -- group by raw column name instead of alias in select
-            -- in case table has a column called 'prediction'
-            {escaped_pred_col}
-        ORDER BY ts
-        """
+            SELECT {", ".join(all_select_clause_cols)}
+            FROM {dataset.dataset_table_name}
+            GROUP BY {", ".join(all_group_by_cols)}
+            ORDER BY ts
+        """

         result = ddb_conn.sql(query).df()

         series = self.group_query_results_to_numeric_metrics(
             result,
             "count",
-
+            segmentation_cols + extra_dims,
             "ts",
         )
         metric = self.series_to_metric(self._metric_name(), series)

@@ -177,20 +215,59 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
                 description="The label denoting a negative classification.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT
+            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
+            CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
+            COUNT(*) as count
+        FROM {dataset.dataset_table_name}
+        GROUP BY
+            ts,
+            -- group by raw column name instead of alias in select
+            -- in case table has a column called 'prediction'
+            {escaped_prediction_col}
+        ORDER BY ts
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = [
+            "ts",
+            f"{escaped_prediction_col}",
+        ] + escaped_segmentation_cols
+        extra_dims = ["prediction"]
+
         query = f"""
-        SELECT
-            time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts,
-            CASE WHEN {escaped_prediction_col} >= {threshold} THEN '{true_label}' ELSE '{false_label}' END as prediction,
-            COUNT(*) as count
+            SELECT {", ".join(all_select_clause_cols)}
             FROM {dataset.dataset_table_name}
-        GROUP BY
-            ts,
-            -- group by raw column name instead of alias in select
-            -- in case table has a column called 'prediction'
-            {escaped_prediction_col}
+            GROUP BY {", ".join(all_group_by_cols)}
             ORDER BY ts
         """


@@ -199,7 +276,7 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
         series = self.group_query_results_to_numeric_metrics(
             result,
             "count",
-
+            segmentation_cols + extra_dims,
             "ts",
         )
         metric = self.series_to_metric(self._metric_name(), series)