arthur_common-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. arthur_common/__init__.py +0 -0
  2. arthur_common/__version__.py +1 -0
  3. arthur_common/aggregations/__init__.py +2 -0
  4. arthur_common/aggregations/aggregator.py +214 -0
  5. arthur_common/aggregations/functions/README.md +26 -0
  6. arthur_common/aggregations/functions/__init__.py +25 -0
  7. arthur_common/aggregations/functions/categorical_count.py +89 -0
  8. arthur_common/aggregations/functions/confusion_matrix.py +412 -0
  9. arthur_common/aggregations/functions/inference_count.py +69 -0
  10. arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
  11. arthur_common/aggregations/functions/inference_null_count.py +82 -0
  12. arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
  13. arthur_common/aggregations/functions/mean_squared_error.py +110 -0
  14. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
  15. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
  16. arthur_common/aggregations/functions/numeric_stats.py +90 -0
  17. arthur_common/aggregations/functions/numeric_sum.py +87 -0
  18. arthur_common/aggregations/functions/py.typed +0 -0
  19. arthur_common/aggregations/functions/shield_aggregations.py +752 -0
  20. arthur_common/aggregations/py.typed +0 -0
  21. arthur_common/models/__init__.py +0 -0
  22. arthur_common/models/connectors.py +41 -0
  23. arthur_common/models/datasets.py +22 -0
  24. arthur_common/models/metrics.py +227 -0
  25. arthur_common/models/py.typed +0 -0
  26. arthur_common/models/schema_definitions.py +420 -0
  27. arthur_common/models/shield.py +504 -0
  28. arthur_common/models/task_job_specs.py +78 -0
  29. arthur_common/py.typed +0 -0
  30. arthur_common/tools/__init__.py +0 -0
  31. arthur_common/tools/aggregation_analyzer.py +243 -0
  32. arthur_common/tools/aggregation_loader.py +59 -0
  33. arthur_common/tools/duckdb_data_loader.py +329 -0
  34. arthur_common/tools/functions.py +46 -0
  35. arthur_common/tools/py.typed +0 -0
  36. arthur_common/tools/schema_inferer.py +104 -0
  37. arthur_common/tools/time_utils.py +33 -0
  38. arthur_common-1.0.1.dist-info/METADATA +74 -0
  39. arthur_common-1.0.1.dist-info/RECORD +40 -0
  40. arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/aggregations/functions/inference_null_count.py
@@ -0,0 +1,82 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.aggregator import NumericAggregationFunction
+ from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from duckdb import DuckDBPyConnection
+
+
+ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
+     METRIC_NAME = "null_count"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-00000000000b")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Null Value Count"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that counts the number of null values in the column per time window."
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing the inference data.",
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         nullable_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allow_any_column_type=True,
+                 friendly_name="Nullable Column",
+                 description="A column containing nullable values to count.",
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         escaped_timestamp_col = escape_identifier(timestamp_col)
+         escaped_nullable_col = escape_identifier(nullable_col)
+         count_query = f" \
+             select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+             count(*) as count \
+             from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
+             group by ts \
+         "
+         results = ddb_conn.sql(count_query).df()
+
+         series = self.dimensionless_query_results_to_numeric_metrics(
+             results,
+             "count",
+             "ts",
+         )
+         series.dimensions = [Dimension(name="column_name", value=nullable_col)]
+
+         metric = self.series_to_metric(self.METRIC_NAME, [series])
+         return [metric]
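
As a quick orientation to how these aggregation functions are driven, here is a minimal, hypothetical usage sketch for the null-count aggregation above. It assumes a no-argument constructor and that DatasetReference can be built from just a table name; the real DatasetReference model (see arthur_common/models/metrics.py) may require more fields, so treat the setup below as illustrative only.

import duckdb

from arthur_common.aggregations.functions.inference_null_count import (
    InferenceNullCountAggregationFunction,
)
from arthur_common.models.metrics import DatasetReference

# Load some toy inference data into an in-memory DuckDB table.
conn = duckdb.connect()
conn.sql(
    "CREATE TABLE inferences AS "
    "SELECT TIMESTAMP '2024-01-01 00:00:00' + i * INTERVAL 1 MINUTE AS ts, "
    "CASE WHEN i % 3 = 0 THEN NULL ELSE i END AS score "
    "FROM range(30) t(i)"
)

# Assumption: DatasetReference is constructed with (at least) the table name the
# queries read from; other required fields are omitted here for brevity.
dataset_ref = DatasetReference(dataset_table_name="inferences")

metrics = InferenceNullCountAggregationFunction().aggregate(
    ddb_conn=conn,
    dataset=dataset_ref,
    timestamp_col="ts",
    nullable_col="score",
)
# Expect a single "null_count" NumericMetric with counts bucketed into 5-minute windows.
print(metrics)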

arthur_common/aggregations/functions/mean_absolute_error.py
@@ -0,0 +1,110 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.aggregator import NumericAggregationFunction
+ from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.metrics import DatasetReference, NumericMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from duckdb import DuckDBPyConnection
+
+
+ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-00000000000e")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Mean Absolute Error"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that sums the absolute error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing the inference data.",
+                 model_problem_type=ModelProblemType.REGRESSION,
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         prediction_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.FLOAT),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PREDICTION],
+                 friendly_name="Prediction Column",
+                 description="A column containing float typed prediction values.",
+             ),
+         ],
+         ground_truth_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.FLOAT),
+                 ],
+                 tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                 friendly_name="Ground Truth Column",
+                 description="A column containing float typed ground truth values.",
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         escaped_timestamp_col = escape_identifier(timestamp_col)
+         escaped_prediction_col = escape_identifier(prediction_col)
+         escaped_ground_truth_col = escape_identifier(ground_truth_col)
+         count_query = f" \
+             SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+             SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
+             COUNT(*) as count \
+             FROM {dataset.dataset_table_name} \
+             WHERE {escaped_prediction_col} IS NOT NULL \
+             AND {escaped_ground_truth_col} IS NOT NULL \
+             GROUP BY ts order by ts desc \
+         "
+
+         results = ddb_conn.sql(count_query).df()
+         count_series = self.dimensionless_query_results_to_numeric_metrics(
+             results,
+             "count",
+             "ts",
+         )
+         absolute_error_series = self.dimensionless_query_results_to_numeric_metrics(
+             results,
+             "ae",
+             "ts",
+         )
+
+         count_metric = self.series_to_metric("absolute_error_count", [count_series])
+         absolute_error_metric = self.series_to_metric(
+             "absolute_error_sum",
+             [absolute_error_series],
+         )
+
+         return [count_metric, absolute_error_metric]
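
Note that, despite the display name, this aggregation emits the absolute-error sum and the contributing row count rather than the mean itself; the mean absolute error is recoverable downstream by dividing the two series. A small arithmetic sketch, using made-up per-bucket values in place of the returned metric series:

# Hypothetical per-bucket values taken from the "absolute_error_sum" and
# "absolute_error_count" metrics returned by the aggregation above.
absolute_error_sum = {"2024-01-01T00:00": 12.5, "2024-01-01T00:05": 7.0}
absolute_error_count = {"2024-01-01T00:00": 10, "2024-01-01T00:05": 4}

# Per-bucket MAE: divide the sum by the count within each window.
mae_per_bucket = {
    ts: absolute_error_sum[ts] / absolute_error_count[ts]
    for ts in absolute_error_sum
}

# MAE over a longer period: sum numerators and denominators first, then divide.
# Averaging the per-bucket means would weight sparse windows too heavily.
overall_mae = sum(absolute_error_sum.values()) / sum(absolute_error_count.values())
print(mae_per_bucket)  # {'2024-01-01T00:00': 1.25, '2024-01-01T00:05': 1.75}
print(overall_mae)     # 19.5 / 14 ≈ 1.39

The same division applies to the mean squared error aggregation that follows, which emits squared_error_sum and squared_error_count in the same shape.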

arthur_common/aggregations/functions/mean_squared_error.py
@@ -0,0 +1,110 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.aggregator import NumericAggregationFunction
+ from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.metrics import DatasetReference, NumericMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from arthur_common.tools.duckdb_data_loader import escape_identifier
+ from duckdb import DuckDBPyConnection
+
+
+ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-000000000010")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Mean Squared Error"
+
+     @staticmethod
+     def description() -> str:
+         return "Metric that sums the squared error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing the inference data.",
+                 model_problem_type=ModelProblemType.REGRESSION,
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         prediction_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.FLOAT),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PREDICTION],
+                 friendly_name="Prediction Column",
+                 description="A column containing float typed prediction values.",
+             ),
+         ],
+         ground_truth_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.FLOAT),
+                 ],
+                 tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                 friendly_name="Ground Truth Column",
+                 description="A column containing float typed ground truth values.",
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         escaped_timestamp_col = escape_identifier(timestamp_col)
+         escaped_prediction_col = escape_identifier(prediction_col)
+         escaped_ground_truth_col = escape_identifier(ground_truth_col)
+         count_query = f" \
+             SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+             SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
+             COUNT(*) as count \
+             FROM {dataset.dataset_table_name} \
+             WHERE {escaped_prediction_col} IS NOT NULL \
+             AND {escaped_ground_truth_col} IS NOT NULL \
+             GROUP BY ts order by ts desc \
+         "
+
+         results = ddb_conn.sql(count_query).df()
+         count_series = self.dimensionless_query_results_to_numeric_metrics(
+             results,
+             "count",
+             "ts",
+         )
+         squared_error_series = self.dimensionless_query_results_to_numeric_metrics(
+             results,
+             "squared_error",
+             "ts",
+         )
+
+         count_metric = self.series_to_metric("squared_error_count", [count_series])
+         squared_error_metric = self.series_to_metric(
+             "squared_error_sum",
+             [squared_error_series],
+         )
+
+         return [count_metric, squared_error_metric]

arthur_common/aggregations/functions/multiclass_confusion_matrix.py
@@ -0,0 +1,205 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.aggregator import NumericAggregationFunction
+ from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.metrics import DatasetReference, NumericMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     MetricLiteralParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+ from duckdb import DuckDBPyConnection
+
+
+ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
+     NumericAggregationFunction,
+ ):
+     @staticmethod
+     def id() -> UUID:
+         return UUID("dc728927-6928-4a3b-b174-8c1ec8b58d62")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Multiclass Classification Confusion Matrix Single Class - String Class Label Prediction"
+
+     @staticmethod
+     def description() -> str:
+         return (
+             "Aggregation that takes in the string label for the positive class, "
+             "and calculates the confusion matrix (True Positives, False Positives, "
+             "False Negatives, True Negatives) for that class compared to all others."
+         )
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing the prediction and ground truth values.",
+                 model_problem_type=ModelProblemType.MULTICLASS_CLASSIFICATION,
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         prediction_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.STRING),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PREDICTION],
+                 friendly_name="Prediction Column",
+                 description="A column containing the predicted string class label.",
+             ),
+         ],
+         gt_values_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.STRING),
+                 ],
+                 tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                 friendly_name="Ground Truth Column",
+                 description="A column containing the ground truth string class label.",
+             ),
+         ],
+         positive_class_label: Annotated[
+             str,
+             MetricLiteralParameterAnnotation(
+                 parameter_dtype=DType.STRING,
+                 friendly_name="Positive Class Label",
+                 description="The label indicating a positive class.",
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         escaped_positive_class_label = escape_str_literal(positive_class_label)
+         normalization_case = f"""
+             CASE
+                 WHEN value = {escaped_positive_class_label} THEN 1
+                 ELSE 0
+             END
+         """
+         return self.generate_confusion_matrix_metrics(
+             ddb_conn,
+             timestamp_col,
+             prediction_col,
+             gt_values_col,
+             normalization_case,
+             normalization_case,
+             dataset,
+             escaped_positive_class_label,
+         )
+
+     def generate_confusion_matrix_metrics(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         timestamp_col: str,
+         prediction_col: str,
+         gt_values_col: str,
+         prediction_normalization_case: str,
+         gt_normalization_case: str,
+         dataset: DatasetReference,
+         escaped_positive_class_label: str,
+     ) -> list[NumericMetric]:
+         """
+         Compute confusion matrix metrics over time for a single positive class.
+
+         Args:
+             ddb_conn: DuckDB connection
+             timestamp_col: Column name containing timestamps
+             prediction_col: Column name containing predictions
+             gt_values_col: Column name containing ground truth values
+             prediction_normalization_case: SQL CASE statement for normalizing predictions to 0 / 1 / null using 'value' as the target column name
+             gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null using 'value' as the target column name
+             dataset: DatasetReference containing dataset metadata
+             escaped_positive_class_label: escaped label for the class to include in the dimensions
+
+         Returns:
+             list[NumericMetric]: true positive, false positive, false negative, and true negative count metrics
+         """
+         escaped_timestamp_col = escape_identifier(timestamp_col)
+         escaped_prediction_col = escape_identifier(prediction_col)
+         escaped_gt_values_col = escape_identifier(gt_values_col)
+         confusion_matrix_query = f"""
+             WITH normalized_data AS (
+                 SELECT
+                     {escaped_timestamp_col} AS timestamp,
+                     {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                     {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                 FROM {dataset.dataset_table_name}
+                 WHERE {escaped_timestamp_col} IS NOT NULL
+             )
+             SELECT
+                 time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                 SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                 SUM(CASE WHEN prediction = 1 AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                 SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                 SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
+                 any_value({escaped_positive_class_label}) as class_label
+             FROM normalized_data
+             GROUP BY ts
+             ORDER BY ts
+         """
+
+         results = ddb_conn.sql(confusion_matrix_query).df()
+
+         tp = self.group_query_results_to_numeric_metrics(
+             results,
+             "true_positive_count",
+             dim_columns=["class_label"],
+             timestamp_col="ts",
+         )
+         fp = self.group_query_results_to_numeric_metrics(
+             results,
+             "false_positive_count",
+             dim_columns=["class_label"],
+             timestamp_col="ts",
+         )
+         fn = self.group_query_results_to_numeric_metrics(
+             results,
+             "false_negative_count",
+             dim_columns=["class_label"],
+             timestamp_col="ts",
+         )
+         tn = self.group_query_results_to_numeric_metrics(
+             results,
+             "true_negative_count",
+             dim_columns=["class_label"],
+             timestamp_col="ts",
+         )
+         tp_metric = self.series_to_metric(
+             "multiclass_confusion_matrix_single_class_true_positive_count",
+             tp,
+         )
+         fp_metric = self.series_to_metric(
+             "multiclass_confusion_matrix_single_class_false_positive_count",
+             fp,
+         )
+         fn_metric = self.series_to_metric(
+             "multiclass_confusion_matrix_single_class_false_negative_count",
+             fn,
+         )
+         tn_metric = self.series_to_metric(
+             "multiclass_confusion_matrix_single_class_true_negative_count",
+             tn,
+         )
+         return [tp_metric, fp_metric, fn_metric, tn_metric]
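
The normalization CASE templates passed to generate_confusion_matrix_metrics use the bare word 'value' as a stand-in for the target column, and the query builder swaps in the escaped column name with str.replace. A small sketch of that expansion, with hypothetical column and label inputs (the exact quoting produced by escape_identifier and escape_str_literal is assumed here, not confirmed from the package):

from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal

# Hypothetical inputs for illustration.
prediction_col = "predicted_class"
positive_class_label = "fraud"

escaped_col = escape_identifier(prediction_col)            # assumed to quote the identifier, e.g. "predicted_class"
escaped_label = escape_str_literal(positive_class_label)   # assumed to quote the literal, e.g. 'fraud'

# Same template aggregate() builds for the positive class.
normalization_case = f"""
CASE
    WHEN value = {escaped_label} THEN 1
    ELSE 0
END
"""

# generate_confusion_matrix_metrics substitutes the real column for the placeholder
# before embedding the CASE in the normalized_data CTE.
rendered = normalization_case.replace("value", escaped_col)
print(rendered)
# -> CASE WHEN "predicted_class" = 'fraud' THEN 1 ELSE 0 END  (modulo whitespace)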

arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
@@ -0,0 +1,90 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.functions.inference_count_by_class import (
+     BinaryClassifierCountByClassAggregationFunction,
+ )
+ from arthur_common.models.datasets import ModelProblemType
+ from arthur_common.models.metrics import DatasetReference, NumericMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from duckdb import DuckDBPyConnection
+
+
+ class MulticlassClassifierCountByClassAggregationFunction(
+     BinaryClassifierCountByClassAggregationFunction,
+ ):
+     """
+     This class simply exposes the same calculation as the BinaryClassifierCountByClassAggregationFunction
+     but using the MULTICLASS_CLASSIFICATION tags
+     """
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("64a338fb-6c99-4c40-ba39-81ab8baa8687")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Multiclass Classification Count by Class - Class Label"
+
+     @staticmethod
+     def description() -> str:
+         return (
+             "Aggregation that counts the number of predictions by class for a multiclass classifier. "
+             "Takes boolean, integer, or string prediction values and groups them by time bucket "
+             "to show prediction distribution over time."
+         )
+
+     @staticmethod
+     def _metric_name() -> str:
+         return "multiclass_classifier_count_by_class"
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing multiclass classifier prediction values.",
+                 model_problem_type=ModelProblemType.MULTICLASS_CLASSIFICATION,
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         prediction_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.BOOL),
+                     ScalarType(dtype=DType.INT),
+                     ScalarType(dtype=DType.STRING),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PREDICTION],
+                 friendly_name="Prediction Column",
+                 description="A column containing boolean, integer, or string labelled prediction values.",
+             ),
+         ],
+     ) -> list[NumericMetric]:
+         return super().aggregate(
+             ddb_conn=ddb_conn,
+             dataset=dataset,
+             timestamp_col=timestamp_col,
+             prediction_col=prediction_col,
+         )

arthur_common/aggregations/functions/numeric_stats.py
@@ -0,0 +1,90 @@
+ from typing import Annotated
+ from uuid import UUID
+
+ from arthur_common.aggregations.aggregator import SketchAggregationFunction
+ from arthur_common.models.metrics import DatasetReference, SketchMetric
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     ScalarType,
+     ScopeSchemaTag,
+ )
+ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+ from duckdb import DuckDBPyConnection
+
+
+ class NumericSketchAggregationFunction(SketchAggregationFunction):
+     METRIC_NAME = "numeric_sketch"
+
+     @staticmethod
+     def id() -> UUID:
+         return UUID("00000000-0000-0000-0000-00000000000d")
+
+     @staticmethod
+     def display_name() -> str:
+         return "Numeric Distribution"
+
+     @staticmethod
+     def description() -> str:
+         return (
+             "Metric that calculates a distribution (data sketch) on a numeric column."
+         )
+
+     def aggregate(
+         self,
+         ddb_conn: DuckDBPyConnection,
+         dataset: Annotated[
+             DatasetReference,
+             MetricDatasetParameterAnnotation(
+                 friendly_name="Dataset",
+                 description="The dataset containing the numeric data.",
+             ),
+         ],
+         timestamp_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.TIMESTAMP),
+                 ],
+                 tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                 friendly_name="Timestamp Column",
+                 description="A column containing timestamp values to bucket by.",
+             ),
+         ],
+         numeric_col: Annotated[
+             str,
+             MetricColumnParameterAnnotation(
+                 source_dataset_parameter_key="dataset",
+                 allowed_column_types=[
+                     ScalarType(dtype=DType.INT),
+                     ScalarType(dtype=DType.FLOAT),
+                 ],
+                 tag_hints=[ScopeSchemaTag.CONTINUOUS],
+                 friendly_name="Numeric Column",
+                 description="A column containing numeric values to calculate a data sketch on.",
+             ),
+         ],
+     ) -> list[SketchMetric]:
+         escaped_timestamp_col_id = escape_identifier(timestamp_col)
+         escaped_numeric_col_id = escape_identifier(numeric_col)
+         numeric_col_name_str = escape_str_literal(numeric_col)
+         data_query = f" \
+             select {escaped_timestamp_col_id} as ts, \
+             {escaped_numeric_col_id}, \
+             {numeric_col_name_str} as column_name \
+             from {dataset.dataset_table_name} \
+             where {escaped_numeric_col_id} is not null \
+         "
+         results = ddb_conn.sql(data_query).df()
+
+         series = self.group_query_results_to_sketch_metrics(
+             results,
+             numeric_col,
+             ["column_name"],
+             "ts",
+         )
+
+         metric = self.series_to_metric(self.METRIC_NAME, series)
+         return [metric]