arthur_common-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arthur-common might be problematic.
- arthur_common/__init__.py +0 -0
- arthur_common/__version__.py +1 -0
- arthur_common/aggregations/__init__.py +2 -0
- arthur_common/aggregations/aggregator.py +214 -0
- arthur_common/aggregations/functions/README.md +26 -0
- arthur_common/aggregations/functions/__init__.py +25 -0
- arthur_common/aggregations/functions/categorical_count.py +89 -0
- arthur_common/aggregations/functions/confusion_matrix.py +412 -0
- arthur_common/aggregations/functions/inference_count.py +69 -0
- arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
- arthur_common/aggregations/functions/inference_null_count.py +82 -0
- arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
- arthur_common/aggregations/functions/mean_squared_error.py +110 -0
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
- arthur_common/aggregations/functions/numeric_stats.py +90 -0
- arthur_common/aggregations/functions/numeric_sum.py +87 -0
- arthur_common/aggregations/functions/py.typed +0 -0
- arthur_common/aggregations/functions/shield_aggregations.py +752 -0
- arthur_common/aggregations/py.typed +0 -0
- arthur_common/models/__init__.py +0 -0
- arthur_common/models/connectors.py +41 -0
- arthur_common/models/datasets.py +22 -0
- arthur_common/models/metrics.py +227 -0
- arthur_common/models/py.typed +0 -0
- arthur_common/models/schema_definitions.py +420 -0
- arthur_common/models/shield.py +504 -0
- arthur_common/models/task_job_specs.py +78 -0
- arthur_common/py.typed +0 -0
- arthur_common/tools/__init__.py +0 -0
- arthur_common/tools/aggregation_analyzer.py +243 -0
- arthur_common/tools/aggregation_loader.py +59 -0
- arthur_common/tools/duckdb_data_loader.py +329 -0
- arthur_common/tools/functions.py +46 -0
- arthur_common/tools/py.typed +0 -0
- arthur_common/tools/schema_inferer.py +104 -0
- arthur_common/tools/time_utils.py +33 -0
- arthur_common-1.0.1.dist-info/METADATA +74 -0
- arthur_common-1.0.1.dist-info/RECORD +40 -0
- arthur_common-1.0.1.dist-info/WHEEL +4 -0
--- /dev/null
+++ arthur_common/aggregations/functions/inference_null_count.py
@@ -0,0 +1,82 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class InferenceNullCountAggregationFunction(NumericAggregationFunction):
+    METRIC_NAME = "null_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000000b")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Null Value Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that counts the number of null values in the column per time window."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the inference data.",
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        nullable_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allow_any_column_type=True,
+                friendly_name="Nullable Column",
+                description="A column containing nullable values to count.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_nullable_col = escape_identifier(nullable_col)
+        count_query = f" \
+            select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            count(*) as count \
+            from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
+            group by ts \
+        "
+        results = ddb_conn.sql(count_query).df()
+
+        series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "count",
+            "ts",
+        )
+        series.dimensions = [Dimension(name="column_name", value=nullable_col)]
+
+        metric = self.series_to_metric(self.METRIC_NAME, [series])
+        return [metric]
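For context, a minimal sketch of how this aggregation could be exercised directly against an in-memory DuckDB table. The FakeDatasetRef stand-in is hypothetical: aggregate() only reads the dataset_table_name attribute from the dataset reference, so any object exposing that attribute works for a local experiment; the real DatasetReference model in arthur_common.models.metrics may require more fields.

import duckdb
import pandas as pd

from arthur_common.aggregations.functions.inference_null_count import (
    InferenceNullCountAggregationFunction,
)


class FakeDatasetRef:
    # Hypothetical stand-in: only dataset_table_name is read by aggregate().
    dataset_table_name = "inferences"


conn = duckdb.connect()
conn.register(
    "inferences",
    pd.DataFrame(
        {
            "ts": pd.to_datetime(["2024-01-01 00:00", "2024-01-01 00:03"]),
            "score": [0.7, None],
        },
    ),
)

metrics = InferenceNullCountAggregationFunction().aggregate(
    ddb_conn=conn,
    dataset=FakeDatasetRef(),
    timestamp_col="ts",
    nullable_col="score",
)
# Both rows fall in the same 5-minute bucket, one of which has a null score,
# so we expect a single "null_count" metric with one bucket of value 1.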
--- /dev/null
+++ arthur_common/aggregations/functions/mean_absolute_error.py
@@ -0,0 +1,110 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000000e")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Mean Absolute Error"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that sums the absolute error of a prediction and ground truth column. It omits any rows where either the prediction or the ground truth is null. It reports the count of non-null rows used in the calculation in a second metric."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the inference data.",
+                model_problem_type=ModelProblemType.REGRESSION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing float typed prediction values.",
+            ),
+        ],
+        ground_truth_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing float typed ground truth values.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_prediction_col = escape_identifier(prediction_col)
+        escaped_ground_truth_col = escape_identifier(ground_truth_col)
+        count_query = f" \
+            SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
+            COUNT(*) as count \
+            FROM {dataset.dataset_table_name} \
+            WHERE {escaped_prediction_col} IS NOT NULL \
+            AND {escaped_ground_truth_col} IS NOT NULL \
+            GROUP BY ts ORDER BY ts DESC \
+        "
+
+        results = ddb_conn.sql(count_query).df()
+        count_series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "count",
+            "ts",
+        )
+        absolute_error_series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "ae",
+            "ts",
+        )
+
+        count_metric = self.series_to_metric("absolute_error_count", [count_series])
+        absolute_error_metric = self.series_to_metric(
+            "absolute_error_sum",
+            [absolute_error_series],
+        )
+
+        return [count_metric, absolute_error_metric]
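Note the design here: the function reports a per-bucket error sum and a per-bucket row count rather than the mean itself, so buckets can be re-aggregated over larger windows without losing their relative weights. Recovering MAE downstream is then a ratio; a sketch, assuming the two reported series have been loaded into aligned pandas Series indexed by bucket timestamp (hypothetical variable names):

# Per-bucket mean absolute error.
mae_per_bucket = absolute_error_sum / absolute_error_count

# Over a wider window, sum both sides before dividing; averaging the
# per-bucket MAEs instead would overweight sparsely populated buckets.
window_mae = absolute_error_sum.sum() / absolute_error_count.sum()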
--- /dev/null
+++ arthur_common/aggregations/functions/mean_squared_error.py
@@ -0,0 +1,110 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier
+from duckdb import DuckDBPyConnection
+
+
+class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000010")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Mean Squared Error"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that sums the squared error of a prediction and ground truth column. It omits any rows where either the prediction or the ground truth is null. It reports the count of non-null rows used in the calculation in a second metric."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the inference data.",
+                model_problem_type=ModelProblemType.REGRESSION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing float typed prediction values.",
+            ),
+        ],
+        ground_truth_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing float typed ground truth values.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_prediction_col = escape_identifier(prediction_col)
+        escaped_ground_truth_col = escape_identifier(ground_truth_col)
+        count_query = f" \
+            SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
+            COUNT(*) as count \
+            FROM {dataset.dataset_table_name} \
+            WHERE {escaped_prediction_col} IS NOT NULL \
+            AND {escaped_ground_truth_col} IS NOT NULL \
+            GROUP BY ts ORDER BY ts DESC \
+        "
+
+        results = ddb_conn.sql(count_query).df()
+        count_series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "count",
+            "ts",
+        )
+        squared_error_series = self.dimensionless_query_results_to_numeric_metrics(
+            results,
+            "squared_error",
+            "ts",
+        )
+
+        count_metric = self.series_to_metric("squared_error_count", [count_series])
+        squared_error_metric = self.series_to_metric(
+            "squared_error_sum",
+            [squared_error_series],
+        )
+
+        return [count_metric, squared_error_metric]
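The same recombination applies here: divide the squared-error sum by the count per bucket for MSE, and take a square root for RMSE (under the same hypothetical aligned-Series assumption as the MAE note above).

mse_per_bucket = squared_error_sum / squared_error_count
rmse_per_bucket = mse_per_bucket ** 0.5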
--- /dev/null
+++ arthur_common/aggregations/functions/multiclass_confusion_matrix.py
@@ -0,0 +1,205 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import NumericAggregationFunction
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    MetricLiteralParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+from duckdb import DuckDBPyConnection
+
+
+class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
+    NumericAggregationFunction,
+):
+    @staticmethod
+    def id() -> UUID:
+        return UUID("dc728927-6928-4a3b-b174-8c1ec8b58d62")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Multiclass Classification Confusion Matrix Single Class - String Class Label Prediction"
+
+    @staticmethod
+    def description() -> str:
+        return (
+            "Aggregation that takes in the string label for the positive class, "
+            "and calculates the confusion matrix (True Positives, False Positives, "
+            "False Negatives, True Negatives) for that class compared to all others."
+        )
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the prediction and ground truth values.",
+                model_problem_type=ModelProblemType.MULTICLASS_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing the predicted string class label.",
+            ),
+        ],
+        gt_values_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.GROUND_TRUTH],
+                friendly_name="Ground Truth Column",
+                description="A column containing the ground truth string class label.",
+            ),
+        ],
+        positive_class_label: Annotated[
+            str,
+            MetricLiteralParameterAnnotation(
+                parameter_dtype=DType.STRING,
+                friendly_name="Positive Class Label",
+                description="The label indicating a positive class.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        escaped_positive_class_label = escape_str_literal(positive_class_label)
+        normalization_case = f"""
+            CASE
+                WHEN value = {escaped_positive_class_label} THEN 1
+                ELSE 0
+            END
+        """
+        return self.generate_confusion_matrix_metrics(
+            ddb_conn,
+            timestamp_col,
+            prediction_col,
+            gt_values_col,
+            normalization_case,
+            normalization_case,
+            dataset,
+            escaped_positive_class_label,
+        )
+
+    def generate_confusion_matrix_metrics(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        timestamp_col: str,
+        prediction_col: str,
+        gt_values_col: str,
+        prediction_normalization_case: str,
+        gt_normalization_case: str,
+        dataset: DatasetReference,
+        escaped_positive_class_label: str,
+    ) -> list[NumericMetric]:
+        """
+        Compute confusion matrix metrics over time.
+
+        Args:
+            ddb_conn: DuckDB connection
+            timestamp_col: Column name containing timestamps
+            prediction_col: Column name containing predictions
+            gt_values_col: Column name containing ground truth values
+            prediction_normalization_case: SQL CASE statement for normalizing predictions to 0 / 1 / null, using 'value' as the target column name
+            gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null, using 'value' as the target column name
+            dataset: DatasetReference containing dataset metadata
+            escaped_positive_class_label: escaped label for the class to include in the dimensions
+
+        Returns:
+            list[NumericMetric]: the true positive, false positive, false negative, and true negative count metrics
+        """
+        escaped_timestamp_col = escape_identifier(timestamp_col)
+        escaped_prediction_col = escape_identifier(prediction_col)
+        escaped_gt_values_col = escape_identifier(gt_values_col)
+        confusion_matrix_query = f"""
+            WITH normalized_data AS (
+                SELECT
+                    {escaped_timestamp_col} AS timestamp,
+                    {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                    {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT
+                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                SUM(CASE WHEN prediction = 1 AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
+                any_value({escaped_positive_class_label}) AS class_label
+            FROM normalized_data
+            GROUP BY ts
+            ORDER BY ts
+        """
+
+        results = ddb_conn.sql(confusion_matrix_query).df()
+
+        tp = self.group_query_results_to_numeric_metrics(
+            results,
+            "true_positive_count",
+            dim_columns=["class_label"],
+            timestamp_col="ts",
+        )
+        fp = self.group_query_results_to_numeric_metrics(
+            results,
+            "false_positive_count",
+            dim_columns=["class_label"],
+            timestamp_col="ts",
+        )
+        fn = self.group_query_results_to_numeric_metrics(
+            results,
+            "false_negative_count",
+            dim_columns=["class_label"],
+            timestamp_col="ts",
+        )
+        tn = self.group_query_results_to_numeric_metrics(
+            results,
+            "true_negative_count",
+            dim_columns=["class_label"],
+            timestamp_col="ts",
+        )
+        tp_metric = self.series_to_metric(
+            "multiclass_confusion_matrix_single_class_true_positive_count",
+            tp,
+        )
+        fp_metric = self.series_to_metric(
+            "multiclass_confusion_matrix_single_class_false_positive_count",
+            fp,
+        )
+        fn_metric = self.series_to_metric(
+            "multiclass_confusion_matrix_single_class_false_negative_count",
+            fn,
+        )
+        tn_metric = self.series_to_metric(
+            "multiclass_confusion_matrix_single_class_true_negative_count",
+            tn,
+        )
+        return [tp_metric, fp_metric, fn_metric, tn_metric]
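The normalization template above uses the bare token value as a placeholder column name, and generate_confusion_matrix_metrics substitutes it with str.replace. A sketch of the rendering for a hypothetical positive label "fraud" and prediction column pred_label, assuming escape_str_literal single-quotes string literals and escape_identifier double-quotes identifiers (both assumptions; their implementations live in duckdb_data_loader.py, which is not shown in this excerpt):

from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal

escaped_label = escape_str_literal("fraud")  # assumed: 'fraud'
normalization_case = f"""
    CASE
        WHEN value = {escaped_label} THEN 1
        ELSE 0
    END
"""
# The same template normalizes both columns; .replace swaps in the identifier.
rendered = normalization_case.replace("value", escape_identifier("pred_label"))
# assumed result: CASE WHEN "pred_label" = 'fraud' THEN 1 ELSE 0 END

Because this is plain string substitution, it relies on the token value not occurring anywhere else in the template or in the escaped label itself.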
--- /dev/null
+++ arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
@@ -0,0 +1,90 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.functions.inference_count_by_class import (
+    BinaryClassifierCountByClassAggregationFunction,
+)
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from duckdb import DuckDBPyConnection
+
+
+class MulticlassClassifierCountByClassAggregationFunction(
+    BinaryClassifierCountByClassAggregationFunction,
+):
+    """
+    This class simply exposes the same calculation as the BinaryClassifierCountByClassAggregationFunction
+    but using the MULTICLASS_CLASSIFICATION tags
+    """
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("64a338fb-6c99-4c40-ba39-81ab8baa8687")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Multiclass Classification Count by Class - Class Label"
+
+    @staticmethod
+    def description() -> str:
+        return (
+            "Aggregation that counts the number of predictions by class for a multiclass classifier. "
+            "Takes boolean, integer, or string prediction values and groups them by time bucket "
+            "to show prediction distribution over time."
+        )
+
+    @staticmethod
+    def _metric_name() -> str:
+        return "multiclass_classifier_count_by_class"
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing multiclass classifier prediction values.",
+                model_problem_type=ModelProblemType.MULTICLASS_CLASSIFICATION,
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        prediction_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.STRING),
+                ],
+                tag_hints=[ScopeSchemaTag.PREDICTION],
+                friendly_name="Prediction Column",
+                description="A column containing boolean, integer, or string labelled prediction values.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        return super().aggregate(
+            ddb_conn=ddb_conn,
+            dataset=dataset,
+            timestamp_col=timestamp_col,
+            prediction_col=prediction_col,
+        )
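The override changes no computation; it exists so the Annotated parameter metadata (problem type, tag hints, allowed column types) differs from the binary parent. Presumably tooling such as aggregation_analyzer.py reads this metadata to describe each function's parameters (an assumption; that module is not shown in this excerpt). A sketch of how the annotations can be inspected at runtime:

import inspect
from typing import get_args

from arthur_common.aggregations.functions.multiclass_inference_count_by_class import (
    MulticlassClassifierCountByClassAggregationFunction,
)

sig = inspect.signature(MulticlassClassifierCountByClassAggregationFunction.aggregate)
for name, param in sig.parameters.items():
    if name in ("self", "ddb_conn"):
        continue
    # get_args on Annotated[T, meta, ...] yields (T, meta, ...).
    base_type, *metadata = get_args(param.annotation)
    print(name, base_type, metadata)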
--- /dev/null
+++ arthur_common/aggregations/functions/numeric_stats.py
@@ -0,0 +1,90 @@
+from typing import Annotated
+from uuid import UUID
+
+from arthur_common.aggregations.aggregator import SketchAggregationFunction
+from arthur_common.models.metrics import DatasetReference, SketchMetric
+from arthur_common.models.schema_definitions import (
+    DType,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+    ScalarType,
+    ScopeSchemaTag,
+)
+from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
+from duckdb import DuckDBPyConnection
+
+
+class NumericSketchAggregationFunction(SketchAggregationFunction):
+    METRIC_NAME = "numeric_sketch"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-00000000000d")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Numeric Distribution"
+
+    @staticmethod
+    def description() -> str:
+        return (
+            "Metric that calculates a distribution (data sketch) on a numeric column."
+        )
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The dataset containing the numeric data.",
+            ),
+        ],
+        timestamp_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.TIMESTAMP),
+                ],
+                tag_hints=[ScopeSchemaTag.PRIMARY_TIMESTAMP],
+                friendly_name="Timestamp Column",
+                description="A column containing timestamp values to bucket by.",
+            ),
+        ],
+        numeric_col: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.FLOAT),
+                ],
+                tag_hints=[ScopeSchemaTag.CONTINUOUS],
+                friendly_name="Numeric Column",
+                description="A column containing numeric values to calculate a data sketch on.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        escaped_timestamp_col_id = escape_identifier(timestamp_col)
+        escaped_numeric_col_id = escape_identifier(numeric_col)
+        numeric_col_name_str = escape_str_literal(numeric_col)
+        data_query = f" \
+            select {escaped_timestamp_col_id} as ts, \
+            {escaped_numeric_col_id}, \
+            {numeric_col_name_str} as column_name \
+            from {dataset.dataset_table_name} \
+            where {escaped_numeric_col_id} is not null \
+        "
+        results = ddb_conn.sql(data_query).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            numeric_col,
+            ["column_name"],
+            "ts",
+        )
+
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
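Unlike the counting aggregations, this one returns raw rows rather than SQL aggregates and leaves sketch construction to group_query_results_to_sketch_metrics; selecting the escaped column name as a constant column_name column is what lets it surface as a grouping dimension. Under the same hypothetical FakeDatasetRef stand-in used in the null-count sketch above, it can be exercised the same way:

metrics = NumericSketchAggregationFunction().aggregate(
    ddb_conn=conn,
    dataset=FakeDatasetRef(),
    timestamp_col="ts",
    numeric_col="score",
)
# Expect one "numeric_sketch" metric whose series carries a
# column_name="score" dimension.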