arthur-common 2.1.52__py3-none-any.whl → 2.1.53__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arthur-common might be problematic. Click here for more details.
- arthur_common/aggregations/aggregator.py +6 -0
- arthur_common/aggregations/functions/categorical_count.py +14 -1
- arthur_common/aggregations/functions/confusion_matrix.py +35 -5
- arthur_common/aggregations/functions/inference_count.py +14 -1
- arthur_common/aggregations/functions/inference_count_by_class.py +23 -1
- arthur_common/aggregations/functions/inference_null_count.py +15 -1
- arthur_common/aggregations/functions/mean_absolute_error.py +25 -3
- arthur_common/aggregations/functions/mean_squared_error.py +25 -3
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +43 -5
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +14 -1
- arthur_common/aggregations/functions/numeric_stats.py +14 -1
- arthur_common/aggregations/functions/numeric_sum.py +15 -1
- arthur_common/aggregations/functions/shield_aggregations.py +126 -16
- arthur_common/models/metrics.py +34 -18
- arthur_common/tools/aggregation_analyzer.py +2 -1
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.53.dist-info}/METADATA +1 -1
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.53.dist-info}/RECORD +18 -18
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.53.dist-info}/WHEEL +0 -0
|
@@ -29,6 +29,12 @@ class AggregationFunction(ABC):
|
|
|
29
29
|
def aggregation_type(self) -> Type[SketchMetric] | Type[NumericMetric]:
|
|
30
30
|
raise NotImplementedError
|
|
31
31
|
|
|
32
|
+
@staticmethod
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
35
|
+
"""Returns the list of aggregations reported by the aggregate function."""
|
|
36
|
+
raise NotImplementedError
|
|
37
|
+
|
|
32
38
|
@abstractmethod
|
|
33
39
|
def aggregate(
|
|
34
40
|
self,
|
|
@@ -4,7 +4,11 @@ from uuid import UUID
|
|
|
4
4
|
from duckdb import DuckDBPyConnection
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
|
-
from arthur_common.models.metrics import
|
|
7
|
+
from arthur_common.models.metrics import (
|
|
8
|
+
DatasetReference,
|
|
9
|
+
NumericMetric,
|
|
10
|
+
BaseReportedAggregation,
|
|
11
|
+
)
|
|
8
12
|
from arthur_common.models.schema_definitions import (
|
|
9
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
10
14
|
DType,
|
|
@@ -32,6 +36,15 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
|
|
|
32
36
|
def description() -> str:
|
|
33
37
|
return "Metric that counts the number of discrete values of each category in a string column. Creates a separate dimension for each category and the values are the count of occurrences of that category in the time window."
|
|
34
38
|
|
|
39
|
+
@staticmethod
|
|
40
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
41
|
+
return [
|
|
42
|
+
BaseReportedAggregation(
|
|
43
|
+
metric_name=CategoricalCountAggregationFunction.METRIC_NAME,
|
|
44
|
+
description=CategoricalCountAggregationFunction.description(),
|
|
45
|
+
)
|
|
46
|
+
]
|
|
47
|
+
|
|
35
48
|
def aggregate(
|
|
36
49
|
self,
|
|
37
50
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
|
-
from arthur_common.models.metrics import
|
|
8
|
+
from arthur_common.models.metrics import (
|
|
9
|
+
DatasetReference,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
9
13
|
from arthur_common.models.schema_definitions import (
|
|
10
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
11
15
|
DType,
|
|
@@ -20,6 +24,32 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str
|
|
|
20
24
|
|
|
21
25
|
|
|
22
26
|
class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
|
|
27
|
+
TRUE_POSITIVE_METRIC_NAME = "confusion_matrix_true_positive_count"
|
|
28
|
+
FALSE_POSITIVE_METRIC_NAME = "confusion_matrix_false_positive_count"
|
|
29
|
+
FALSE_NEGATIVE_METRIC_NAME = "confusion_matrix_false_negative_count"
|
|
30
|
+
TRUE_NEGATIVE_METRIC_NAME = "confusion_matrix_true_negative_count"
|
|
31
|
+
|
|
32
|
+
@staticmethod
|
|
33
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
34
|
+
return [
|
|
35
|
+
BaseReportedAggregation(
|
|
36
|
+
metric_name=ConfusionMatrixAggregationFunction.TRUE_POSITIVE_METRIC_NAME,
|
|
37
|
+
description="Confusion matrix true positives count.",
|
|
38
|
+
),
|
|
39
|
+
BaseReportedAggregation(
|
|
40
|
+
metric_name=ConfusionMatrixAggregationFunction.FALSE_POSITIVE_METRIC_NAME,
|
|
41
|
+
description="Confusion matrix false positives count.",
|
|
42
|
+
),
|
|
43
|
+
BaseReportedAggregation(
|
|
44
|
+
metric_name=ConfusionMatrixAggregationFunction.FALSE_NEGATIVE_METRIC_NAME,
|
|
45
|
+
description="Confusion matrix false negatives count.",
|
|
46
|
+
),
|
|
47
|
+
BaseReportedAggregation(
|
|
48
|
+
metric_name=ConfusionMatrixAggregationFunction.TRUE_NEGATIVE_METRIC_NAME,
|
|
49
|
+
description="Confusion matrix true negatives count.",
|
|
50
|
+
),
|
|
51
|
+
]
|
|
52
|
+
|
|
23
53
|
def generate_confusion_matrix_metrics(
|
|
24
54
|
self,
|
|
25
55
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -129,10 +159,10 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
|
|
|
129
159
|
dim_columns=segmentation_cols + extra_dims,
|
|
130
160
|
timestamp_col="ts",
|
|
131
161
|
)
|
|
132
|
-
tp_metric = self.series_to_metric(
|
|
133
|
-
fp_metric = self.series_to_metric(
|
|
134
|
-
fn_metric = self.series_to_metric(
|
|
135
|
-
tn_metric = self.series_to_metric(
|
|
162
|
+
tp_metric = self.series_to_metric(self.TRUE_POSITIVE_METRIC_NAME, tp)
|
|
163
|
+
fp_metric = self.series_to_metric(self.FALSE_POSITIVE_METRIC_NAME, fp)
|
|
164
|
+
fn_metric = self.series_to_metric(self.FALSE_NEGATIVE_METRIC_NAME, fn)
|
|
165
|
+
tn_metric = self.series_to_metric(self.TRUE_NEGATIVE_METRIC_NAME, tn)
|
|
136
166
|
return [tp_metric, fp_metric, fn_metric, tn_metric]
|
|
137
167
|
|
|
138
168
|
|
|
@@ -4,7 +4,11 @@ from uuid import UUID
|
|
|
4
4
|
from duckdb import DuckDBPyConnection
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
|
-
from arthur_common.models.metrics import
|
|
7
|
+
from arthur_common.models.metrics import (
|
|
8
|
+
DatasetReference,
|
|
9
|
+
NumericMetric,
|
|
10
|
+
BaseReportedAggregation,
|
|
11
|
+
)
|
|
8
12
|
from arthur_common.models.schema_definitions import (
|
|
9
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
10
14
|
DType,
|
|
@@ -32,6 +36,15 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
|
|
|
32
36
|
def description() -> str:
|
|
33
37
|
return "Metric that counts the number of inferences per time window."
|
|
34
38
|
|
|
39
|
+
@staticmethod
|
|
40
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
41
|
+
return [
|
|
42
|
+
BaseReportedAggregation(
|
|
43
|
+
metric_name=InferenceCountAggregationFunction.METRIC_NAME,
|
|
44
|
+
description=InferenceCountAggregationFunction.description(),
|
|
45
|
+
)
|
|
46
|
+
]
|
|
47
|
+
|
|
35
48
|
def aggregate(
|
|
36
49
|
self,
|
|
37
50
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
|
-
from arthur_common.models.metrics import
|
|
8
|
+
from arthur_common.models.metrics import (
|
|
9
|
+
DatasetReference,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
9
13
|
from arthur_common.models.schema_definitions import (
|
|
10
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
11
15
|
DType,
|
|
@@ -36,6 +40,15 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
|
|
|
36
40
|
def _metric_name() -> str:
|
|
37
41
|
return "binary_classifier_count_by_class"
|
|
38
42
|
|
|
43
|
+
@staticmethod
|
|
44
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
45
|
+
return [
|
|
46
|
+
BaseReportedAggregation(
|
|
47
|
+
metric_name=BinaryClassifierCountByClassAggregationFunction._metric_name(),
|
|
48
|
+
description=BinaryClassifierCountByClassAggregationFunction.description(),
|
|
49
|
+
)
|
|
50
|
+
]
|
|
51
|
+
|
|
39
52
|
def aggregate(
|
|
40
53
|
self,
|
|
41
54
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -153,6 +166,15 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
|
|
|
153
166
|
def _metric_name() -> str:
|
|
154
167
|
return "binary_classifier_count_by_class"
|
|
155
168
|
|
|
169
|
+
@staticmethod
|
|
170
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
171
|
+
return [
|
|
172
|
+
BaseReportedAggregation(
|
|
173
|
+
metric_name=BinaryClassifierCountThresholdClassAggregationFunction._metric_name(),
|
|
174
|
+
description=BinaryClassifierCountThresholdClassAggregationFunction.description(),
|
|
175
|
+
)
|
|
176
|
+
]
|
|
177
|
+
|
|
156
178
|
def aggregate(
|
|
157
179
|
self,
|
|
158
180
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -4,7 +4,12 @@ from uuid import UUID
|
|
|
4
4
|
from duckdb import DuckDBPyConnection
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
|
-
from arthur_common.models.metrics import
|
|
7
|
+
from arthur_common.models.metrics import (
|
|
8
|
+
DatasetReference,
|
|
9
|
+
Dimension,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
8
13
|
from arthur_common.models.schema_definitions import (
|
|
9
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
10
15
|
DType,
|
|
@@ -32,6 +37,15 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
|
|
|
32
37
|
def description() -> str:
|
|
33
38
|
return "Metric that counts the number of null values in the column per time window."
|
|
34
39
|
|
|
40
|
+
@staticmethod
|
|
41
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
42
|
+
return [
|
|
43
|
+
BaseReportedAggregation(
|
|
44
|
+
metric_name=InferenceNullCountAggregationFunction.METRIC_NAME,
|
|
45
|
+
description=InferenceNullCountAggregationFunction.description(),
|
|
46
|
+
)
|
|
47
|
+
]
|
|
48
|
+
|
|
35
49
|
def aggregate(
|
|
36
50
|
self,
|
|
37
51
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
|
-
from arthur_common.models.metrics import
|
|
8
|
+
from arthur_common.models.metrics import (
|
|
9
|
+
DatasetReference,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
9
13
|
from arthur_common.models.schema_definitions import (
|
|
10
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
11
15
|
DType,
|
|
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier
|
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
|
|
26
|
+
ABSOLUTE_ERROR_COUNT_METRIC_NAME = "absolute_error_count"
|
|
27
|
+
ABSOLUTE_ERROR_SUM_METRIC_NAME = "absolute_error_sum"
|
|
28
|
+
|
|
22
29
|
@staticmethod
|
|
23
30
|
def id() -> UUID:
|
|
24
31
|
return UUID("00000000-0000-0000-0000-00000000000e")
|
|
@@ -31,6 +38,19 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
|
|
|
31
38
|
def description() -> str:
|
|
32
39
|
return "Metric that sums the absolute error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."
|
|
33
40
|
|
|
41
|
+
@staticmethod
|
|
42
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
43
|
+
return [
|
|
44
|
+
BaseReportedAggregation(
|
|
45
|
+
metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
|
|
46
|
+
description="Sum of the absolute error of a prediction and ground truth column, omitting rows where either column is null.",
|
|
47
|
+
),
|
|
48
|
+
BaseReportedAggregation(
|
|
49
|
+
metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME,
|
|
50
|
+
description=f"Count of non-null rows used in the calculation of the {MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME} metric.",
|
|
51
|
+
),
|
|
52
|
+
]
|
|
53
|
+
|
|
34
54
|
def aggregate(
|
|
35
55
|
self,
|
|
36
56
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -138,9 +158,11 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
|
|
|
138
158
|
"ts",
|
|
139
159
|
)
|
|
140
160
|
|
|
141
|
-
count_metric = self.series_to_metric(
|
|
161
|
+
count_metric = self.series_to_metric(
|
|
162
|
+
self.ABSOLUTE_ERROR_COUNT_METRIC_NAME, count_series
|
|
163
|
+
)
|
|
142
164
|
absolute_error_metric = self.series_to_metric(
|
|
143
|
-
|
|
165
|
+
self.ABSOLUTE_ERROR_SUM_METRIC_NAME,
|
|
144
166
|
absolute_error_series,
|
|
145
167
|
)
|
|
146
168
|
|
|
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
|
-
from arthur_common.models.metrics import
|
|
8
|
+
from arthur_common.models.metrics import (
|
|
9
|
+
DatasetReference,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
9
13
|
from arthur_common.models.schema_definitions import (
|
|
10
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
11
15
|
DType,
|
|
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier
|
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
|
|
26
|
+
SQUARED_ERROR_COUNT_METRIC_NAME = "squared_error_count"
|
|
27
|
+
SQUARED_ERROR_SUM_METRIC_NAME = "squared_error_sum"
|
|
28
|
+
|
|
22
29
|
@staticmethod
|
|
23
30
|
def id() -> UUID:
|
|
24
31
|
return UUID("00000000-0000-0000-0000-000000000010")
|
|
@@ -31,6 +38,19 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
|
|
|
31
38
|
def description() -> str:
|
|
32
39
|
return "Metric that sums the squared error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."
|
|
33
40
|
|
|
41
|
+
@staticmethod
|
|
42
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
43
|
+
return [
|
|
44
|
+
BaseReportedAggregation(
|
|
45
|
+
metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME,
|
|
46
|
+
description="Sum of the squared error of a prediction and ground truth column, omitting rows where either column is null.",
|
|
47
|
+
),
|
|
48
|
+
BaseReportedAggregation(
|
|
49
|
+
metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_COUNT_METRIC_NAME,
|
|
50
|
+
description=f"Count of non-null rows used in the calculation of the {MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME} metric.",
|
|
51
|
+
),
|
|
52
|
+
]
|
|
53
|
+
|
|
34
54
|
def aggregate(
|
|
35
55
|
self,
|
|
36
56
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -138,9 +158,11 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
|
|
|
138
158
|
"ts",
|
|
139
159
|
)
|
|
140
160
|
|
|
141
|
-
count_metric = self.series_to_metric(
|
|
161
|
+
count_metric = self.series_to_metric(
|
|
162
|
+
self.SQUARED_ERROR_COUNT_METRIC_NAME, count_series
|
|
163
|
+
)
|
|
142
164
|
absolute_error_metric = self.series_to_metric(
|
|
143
|
-
|
|
165
|
+
self.SQUARED_ERROR_SUM_METRIC_NAME,
|
|
144
166
|
squared_error_series,
|
|
145
167
|
)
|
|
146
168
|
|
|
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
|
-
from arthur_common.models.metrics import
|
|
8
|
+
from arthur_common.models.metrics import (
|
|
9
|
+
DatasetReference,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
9
13
|
from arthur_common.models.schema_definitions import (
|
|
10
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
11
15
|
DType,
|
|
@@ -22,6 +26,19 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str
|
|
|
22
26
|
class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
|
|
23
27
|
NumericAggregationFunction,
|
|
24
28
|
):
|
|
29
|
+
MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME = (
|
|
30
|
+
"multiclass_confusion_matrix_single_class_true_positive_count"
|
|
31
|
+
)
|
|
32
|
+
MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME = (
|
|
33
|
+
"multiclass_confusion_matrix_single_class_false_positive_count"
|
|
34
|
+
)
|
|
35
|
+
MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME = (
|
|
36
|
+
"multiclass_confusion_matrix_single_class_false_negative_count"
|
|
37
|
+
)
|
|
38
|
+
MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME = (
|
|
39
|
+
"multiclass_confusion_matrix_single_class_true_negative_count"
|
|
40
|
+
)
|
|
41
|
+
|
|
25
42
|
@staticmethod
|
|
26
43
|
def id() -> UUID:
|
|
27
44
|
return UUID("dc728927-6928-4a3b-b174-8c1ec8b58d62")
|
|
@@ -38,6 +55,27 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
|
|
|
38
55
|
"False Negatives, True Negatives) for that class compared to all others."
|
|
39
56
|
)
|
|
40
57
|
|
|
58
|
+
@staticmethod
|
|
59
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
60
|
+
return [
|
|
61
|
+
BaseReportedAggregation(
|
|
62
|
+
metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
|
|
63
|
+
description="Confusion matrix true positives count.",
|
|
64
|
+
),
|
|
65
|
+
BaseReportedAggregation(
|
|
66
|
+
metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
|
|
67
|
+
description="Confusion matrix false positives count.",
|
|
68
|
+
),
|
|
69
|
+
BaseReportedAggregation(
|
|
70
|
+
metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
|
|
71
|
+
description="Confusion matrix false negatives count.",
|
|
72
|
+
),
|
|
73
|
+
BaseReportedAggregation(
|
|
74
|
+
metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
|
|
75
|
+
description="Confusion matrix true negatives count.",
|
|
76
|
+
),
|
|
77
|
+
]
|
|
78
|
+
|
|
41
79
|
def aggregate(
|
|
42
80
|
self,
|
|
43
81
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -238,19 +276,19 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
|
|
|
238
276
|
timestamp_col="ts",
|
|
239
277
|
)
|
|
240
278
|
tp_metric = self.series_to_metric(
|
|
241
|
-
|
|
279
|
+
self.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
|
|
242
280
|
tp,
|
|
243
281
|
)
|
|
244
282
|
fp_metric = self.series_to_metric(
|
|
245
|
-
|
|
283
|
+
self.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
|
|
246
284
|
fp,
|
|
247
285
|
)
|
|
248
286
|
fn_metric = self.series_to_metric(
|
|
249
|
-
|
|
287
|
+
self.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
|
|
250
288
|
fn,
|
|
251
289
|
)
|
|
252
290
|
tn_metric = self.series_to_metric(
|
|
253
|
-
|
|
291
|
+
self.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
|
|
254
292
|
tn,
|
|
255
293
|
)
|
|
256
294
|
return [tp_metric, fp_metric, fn_metric, tn_metric]
|
|
@@ -7,7 +7,11 @@ from arthur_common.aggregations.functions.inference_count_by_class import (
|
|
|
7
7
|
BinaryClassifierCountByClassAggregationFunction,
|
|
8
8
|
)
|
|
9
9
|
from arthur_common.models.datasets import ModelProblemType
|
|
10
|
-
from arthur_common.models.metrics import
|
|
10
|
+
from arthur_common.models.metrics import (
|
|
11
|
+
DatasetReference,
|
|
12
|
+
NumericMetric,
|
|
13
|
+
BaseReportedAggregation,
|
|
14
|
+
)
|
|
11
15
|
from arthur_common.models.schema_definitions import (
|
|
12
16
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
13
17
|
DType,
|
|
@@ -47,6 +51,15 @@ class MulticlassClassifierCountByClassAggregationFunction(
|
|
|
47
51
|
def _metric_name() -> str:
|
|
48
52
|
return "multiclass_classifier_count_by_class"
|
|
49
53
|
|
|
54
|
+
@staticmethod
|
|
55
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
56
|
+
return [
|
|
57
|
+
BaseReportedAggregation(
|
|
58
|
+
metric_name=MulticlassClassifierCountByClassAggregationFunction._metric_name(),
|
|
59
|
+
description=MulticlassClassifierCountByClassAggregationFunction.description(),
|
|
60
|
+
)
|
|
61
|
+
]
|
|
62
|
+
|
|
50
63
|
def aggregate(
|
|
51
64
|
self,
|
|
52
65
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -4,7 +4,11 @@ from uuid import UUID
|
|
|
4
4
|
from duckdb import DuckDBPyConnection
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import SketchAggregationFunction
|
|
7
|
-
from arthur_common.models.metrics import
|
|
7
|
+
from arthur_common.models.metrics import (
|
|
8
|
+
DatasetReference,
|
|
9
|
+
SketchMetric,
|
|
10
|
+
BaseReportedAggregation,
|
|
11
|
+
)
|
|
8
12
|
from arthur_common.models.schema_definitions import (
|
|
9
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
10
14
|
DType,
|
|
@@ -34,6 +38,15 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
|
|
|
34
38
|
"Metric that calculates a distribution (data sketch) on a numeric column."
|
|
35
39
|
)
|
|
36
40
|
|
|
41
|
+
@staticmethod
|
|
42
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
43
|
+
return [
|
|
44
|
+
BaseReportedAggregation(
|
|
45
|
+
metric_name=NumericSketchAggregationFunction.METRIC_NAME,
|
|
46
|
+
description=NumericSketchAggregationFunction.description(),
|
|
47
|
+
)
|
|
48
|
+
]
|
|
49
|
+
|
|
37
50
|
def aggregate(
|
|
38
51
|
self,
|
|
39
52
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -4,7 +4,12 @@ from uuid import UUID
|
|
|
4
4
|
from duckdb import DuckDBPyConnection
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
|
-
from arthur_common.models.metrics import
|
|
7
|
+
from arthur_common.models.metrics import (
|
|
8
|
+
DatasetReference,
|
|
9
|
+
Dimension,
|
|
10
|
+
NumericMetric,
|
|
11
|
+
BaseReportedAggregation,
|
|
12
|
+
)
|
|
8
13
|
from arthur_common.models.schema_definitions import (
|
|
9
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
10
15
|
DType,
|
|
@@ -32,6 +37,15 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
|
|
|
32
37
|
def description() -> str:
|
|
33
38
|
return "Metric that reports the sum of the numeric column per time window."
|
|
34
39
|
|
|
40
|
+
@staticmethod
|
|
41
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
42
|
+
return [
|
|
43
|
+
BaseReportedAggregation(
|
|
44
|
+
metric_name=NumericSumAggregationFunction.METRIC_NAME,
|
|
45
|
+
description=NumericSumAggregationFunction.description(),
|
|
46
|
+
)
|
|
47
|
+
]
|
|
48
|
+
|
|
35
49
|
def aggregate(
|
|
36
50
|
self,
|
|
37
51
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -10,7 +10,12 @@ from arthur_common.aggregations.aggregator import (
|
|
|
10
10
|
SketchAggregationFunction,
|
|
11
11
|
)
|
|
12
12
|
from arthur_common.models.datasets import ModelProblemType
|
|
13
|
-
from arthur_common.models.metrics import
|
|
13
|
+
from arthur_common.models.metrics import (
|
|
14
|
+
DatasetReference,
|
|
15
|
+
NumericMetric,
|
|
16
|
+
SketchMetric,
|
|
17
|
+
BaseReportedAggregation,
|
|
18
|
+
)
|
|
14
19
|
from arthur_common.models.schema_definitions import (
|
|
15
20
|
SHIELD_RESPONSE_SCHEMA,
|
|
16
21
|
MetricColumnParameterAnnotation,
|
|
@@ -33,6 +38,15 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
|
|
|
33
38
|
def description() -> str:
|
|
34
39
|
return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."
|
|
35
40
|
|
|
41
|
+
@staticmethod
|
|
42
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
43
|
+
return [
|
|
44
|
+
BaseReportedAggregation(
|
|
45
|
+
metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
|
|
46
|
+
description=ShieldInferencePassFailCountAggregation.description(),
|
|
47
|
+
)
|
|
48
|
+
]
|
|
49
|
+
|
|
36
50
|
def aggregate(
|
|
37
51
|
self,
|
|
38
52
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -93,6 +107,15 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
|
|
|
93
107
|
def description() -> str:
|
|
94
108
|
return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."
|
|
95
109
|
|
|
110
|
+
@staticmethod
|
|
111
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
112
|
+
return [
|
|
113
|
+
BaseReportedAggregation(
|
|
114
|
+
metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
|
|
115
|
+
description=ShieldInferenceRuleCountAggregation.description(),
|
|
116
|
+
)
|
|
117
|
+
]
|
|
118
|
+
|
|
96
119
|
def aggregate(
|
|
97
120
|
self,
|
|
98
121
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -176,6 +199,15 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
|
|
|
176
199
|
def description() -> str:
|
|
177
200
|
return "Metric that counts the number of Shield hallucination evaluations that failed."
|
|
178
201
|
|
|
202
|
+
@staticmethod
|
|
203
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
204
|
+
return [
|
|
205
|
+
BaseReportedAggregation(
|
|
206
|
+
metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
|
|
207
|
+
description=ShieldInferenceHallucinationCountAggregation.description(),
|
|
208
|
+
)
|
|
209
|
+
]
|
|
210
|
+
|
|
179
211
|
def aggregate(
|
|
180
212
|
self,
|
|
181
213
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -231,6 +263,15 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
|
|
|
231
263
|
def description() -> str:
|
|
232
264
|
return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."
|
|
233
265
|
|
|
266
|
+
@staticmethod
|
|
267
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
268
|
+
return [
|
|
269
|
+
BaseReportedAggregation(
|
|
270
|
+
metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
|
|
271
|
+
description=ShieldInferenceRuleToxicityScoreAggregation.description(),
|
|
272
|
+
)
|
|
273
|
+
]
|
|
274
|
+
|
|
234
275
|
def aggregate(
|
|
235
276
|
self,
|
|
236
277
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -307,6 +348,15 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
|
|
|
307
348
|
def description() -> str:
|
|
308
349
|
return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."
|
|
309
350
|
|
|
351
|
+
@staticmethod
|
|
352
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
353
|
+
return [
|
|
354
|
+
BaseReportedAggregation(
|
|
355
|
+
metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
|
|
356
|
+
description=ShieldInferenceRulePIIDataScoreAggregation.description(),
|
|
357
|
+
)
|
|
358
|
+
]
|
|
359
|
+
|
|
310
360
|
def aggregate(
|
|
311
361
|
self,
|
|
312
362
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -389,6 +439,15 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
|
|
|
389
439
|
def description() -> str:
|
|
390
440
|
return "Metric that reports a distribution (data sketch) on over the number of claims identified by the Shield hallucination rule."
|
|
391
441
|
|
|
442
|
+
@staticmethod
|
|
443
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
444
|
+
return [
|
|
445
|
+
BaseReportedAggregation(
|
|
446
|
+
metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
|
|
447
|
+
description=ShieldInferenceRuleClaimCountAggregation.description(),
|
|
448
|
+
)
|
|
449
|
+
]
|
|
450
|
+
|
|
392
451
|
def aggregate(
|
|
393
452
|
self,
|
|
394
453
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -453,6 +512,15 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
|
|
|
453
512
|
def description() -> str:
|
|
454
513
|
return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."
|
|
455
514
|
|
|
515
|
+
@staticmethod
|
|
516
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
517
|
+
return [
|
|
518
|
+
BaseReportedAggregation(
|
|
519
|
+
metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
|
|
520
|
+
description=ShieldInferenceRuleClaimPassCountAggregation.description(),
|
|
521
|
+
)
|
|
522
|
+
]
|
|
523
|
+
|
|
456
524
|
def aggregate(
|
|
457
525
|
self,
|
|
458
526
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -517,6 +585,15 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
|
|
|
517
585
|
def description() -> str:
|
|
518
586
|
return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."
|
|
519
587
|
|
|
588
|
+
@staticmethod
|
|
589
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
590
|
+
return [
|
|
591
|
+
BaseReportedAggregation(
|
|
592
|
+
metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
|
|
593
|
+
description=ShieldInferenceRuleClaimFailCountAggregation.description(),
|
|
594
|
+
)
|
|
595
|
+
]
|
|
596
|
+
|
|
520
597
|
def aggregate(
|
|
521
598
|
self,
|
|
522
599
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -581,6 +658,15 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
|
581
658
|
def description() -> str:
|
|
582
659
|
return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."
|
|
583
660
|
|
|
661
|
+
@staticmethod
|
|
662
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
663
|
+
return [
|
|
664
|
+
BaseReportedAggregation(
|
|
665
|
+
metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
|
|
666
|
+
description=ShieldInferenceRuleLatencyAggregation.description(),
|
|
667
|
+
)
|
|
668
|
+
]
|
|
669
|
+
|
|
584
670
|
def aggregate(
|
|
585
671
|
self,
|
|
586
672
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -643,6 +729,18 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
|
643
729
|
|
|
644
730
|
class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
645
731
|
METRIC_NAME = "token_count"
|
|
732
|
+
SUPPORTED_MODELS = [
|
|
733
|
+
"gpt-4o",
|
|
734
|
+
"gpt-4o-mini",
|
|
735
|
+
"gpt-3.5-turbo",
|
|
736
|
+
"o1-mini",
|
|
737
|
+
"deepseek-chat",
|
|
738
|
+
"claude-3-5-sonnet-20241022",
|
|
739
|
+
"gemini/gemini-1.5-pro",
|
|
740
|
+
"meta.llama3-1-8b-instruct-v1:0",
|
|
741
|
+
"meta.llama3-1-70b-instruct-v1:0",
|
|
742
|
+
"meta.llama3-2-11b-instruct-v1:0",
|
|
743
|
+
]
|
|
646
744
|
|
|
647
745
|
@staticmethod
|
|
648
746
|
def id() -> UUID:
|
|
@@ -656,6 +754,27 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
656
754
|
def description() -> str:
|
|
657
755
|
return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."
|
|
658
756
|
|
|
757
|
+
@staticmethod
|
|
758
|
+
def _series_name_from_model_name(model_name: str) -> str:
|
|
759
|
+
"""Calculates name of reported series based on the model name considered."""
|
|
760
|
+
return f"token_cost.{model_name}"
|
|
761
|
+
|
|
762
|
+
@staticmethod
|
|
763
|
+
def reported_aggregations() -> list[BaseReportedAggregation]:
|
|
764
|
+
base_token_count_agg = BaseReportedAggregation(
|
|
765
|
+
metric_name=ShieldInferenceTokenCountAggregation.METRIC_NAME,
|
|
766
|
+
description=f"Metric that reports the number of tokens in the Shield response and prompt schemas.",
|
|
767
|
+
)
|
|
768
|
+
return [base_token_count_agg] + [
|
|
769
|
+
BaseReportedAggregation(
|
|
770
|
+
metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
|
|
771
|
+
model_name
|
|
772
|
+
),
|
|
773
|
+
description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
|
|
774
|
+
)
|
|
775
|
+
for model_name in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
|
|
776
|
+
]
|
|
777
|
+
|
|
659
778
|
def aggregate(
|
|
660
779
|
self,
|
|
661
780
|
ddb_conn: DuckDBPyConnection,
|
|
@@ -708,25 +827,12 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
708
827
|
resp = [metric]
|
|
709
828
|
|
|
710
829
|
# Compute Cost for each model
|
|
711
|
-
models = [
|
|
712
|
-
"gpt-4o",
|
|
713
|
-
"gpt-4o-mini",
|
|
714
|
-
"gpt-3.5-turbo",
|
|
715
|
-
"o1-mini",
|
|
716
|
-
"deepseek-chat",
|
|
717
|
-
"claude-3-5-sonnet-20241022",
|
|
718
|
-
"gemini/gemini-1.5-pro",
|
|
719
|
-
"meta.llama3-1-8b-instruct-v1:0",
|
|
720
|
-
"meta.llama3-1-70b-instruct-v1:0",
|
|
721
|
-
"meta.llama3-2-11b-instruct-v1:0",
|
|
722
|
-
]
|
|
723
|
-
|
|
724
830
|
# Precompute input/output classification to avoid recalculating in loop
|
|
725
831
|
location_type = results["location"].apply(
|
|
726
832
|
lambda x: "input" if x == "prompt" else "output",
|
|
727
833
|
)
|
|
728
834
|
|
|
729
|
-
for model in
|
|
835
|
+
for model in self.SUPPORTED_MODELS:
|
|
730
836
|
# Efficient list comprehension instead of apply
|
|
731
837
|
cost_values = [
|
|
732
838
|
calculate_cost_by_tokens(int(tokens), model, loc_type)
|
|
@@ -747,5 +853,9 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
747
853
|
["location"],
|
|
748
854
|
"ts",
|
|
749
855
|
)
|
|
750
|
-
resp.append(
|
|
856
|
+
resp.append(
|
|
857
|
+
self.series_to_metric(
|
|
858
|
+
self._series_name_from_model_name(model), model_series
|
|
859
|
+
)
|
|
860
|
+
)
|
|
751
861
|
return resp
|
arthur_common/models/metrics.py
CHANGED
|
@@ -193,7 +193,9 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSc
|
|
|
193
193
|
parameter_type: Literal["column"] = "column"
|
|
194
194
|
|
|
195
195
|
|
|
196
|
-
class MetricsColumnListParameterSchema(
|
|
196
|
+
class MetricsColumnListParameterSchema(
|
|
197
|
+
MetricsParameterSchema, BaseColumnParameterSchema
|
|
198
|
+
):
|
|
197
199
|
# list column parameter schema specific to default metrics
|
|
198
200
|
parameter_type: Literal["column_list"] = "column_list"
|
|
199
201
|
|
|
@@ -211,9 +213,7 @@ MetricsColumnSchemaUnion = (
|
|
|
211
213
|
|
|
212
214
|
|
|
213
215
|
CustomAggregationParametersSchemaUnion = (
|
|
214
|
-
BaseDatasetParameterSchema
|
|
215
|
-
| BaseLiteralParameterSchema
|
|
216
|
-
| BaseColumnParameterSchema
|
|
216
|
+
BaseDatasetParameterSchema | BaseLiteralParameterSchema | BaseColumnParameterSchema
|
|
217
217
|
)
|
|
218
218
|
|
|
219
219
|
|
|
@@ -224,6 +224,14 @@ class DatasetReference:
|
|
|
224
224
|
dataset_id: UUID
|
|
225
225
|
|
|
226
226
|
|
|
227
|
+
class BaseReportedAggregation(BaseModel):
|
|
228
|
+
# in future will be used by default metrics
|
|
229
|
+
metric_name: str = Field(description="Name of the reported aggregation metric.")
|
|
230
|
+
description: str = Field(
|
|
231
|
+
description="Description of the reported aggregation metric and what it aggregates.",
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
|
|
227
235
|
class AggregationSpecSchema(BaseModel):
|
|
228
236
|
name: str = Field(description="Name of the aggregation function.")
|
|
229
237
|
id: UUID = Field(description="Unique identifier of the aggregation function.")
|
|
@@ -240,6 +248,17 @@ class AggregationSpecSchema(BaseModel):
|
|
|
240
248
|
aggregate_args: list[MetricsParameterSchemaUnion] = Field(
|
|
241
249
|
description="List of parameters to the aggregation's aggregate function.",
|
|
242
250
|
)
|
|
251
|
+
reported_aggregations: list[BaseReportedAggregation] = Field(
|
|
252
|
+
description="List of aggregations reported by the metric."
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
@model_validator(mode="after")
|
|
256
|
+
def at_least_one_reported_agg(self) -> Self:
|
|
257
|
+
if len(self.reported_aggregations) < 1:
|
|
258
|
+
raise ValueError(
|
|
259
|
+
"Aggregation spec must specify at least one reported aggregation."
|
|
260
|
+
)
|
|
261
|
+
return self
|
|
243
262
|
|
|
244
263
|
@model_validator(mode="after")
|
|
245
264
|
def column_dataset_references_exist(self) -> Self:
|
|
@@ -262,26 +281,23 @@ class AggregationSpecSchema(BaseModel):
|
|
|
262
281
|
return self
|
|
263
282
|
|
|
264
283
|
|
|
265
|
-
class BaseReportedAggregation(BaseModel):
|
|
266
|
-
# in future will be used by default metrics
|
|
267
|
-
metric_name: str = Field(description="Name of the reported aggregation metric.")
|
|
268
|
-
description: str = Field(
|
|
269
|
-
description="Description of the reported aggregation metric and what it aggregates.",
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
|
|
273
284
|
class ReportedCustomAggregation(BaseReportedAggregation):
|
|
274
|
-
value_column: str = Field(
|
|
275
|
-
|
|
285
|
+
value_column: str = Field(
|
|
286
|
+
description="Name of the column returned from the SQL query holding the metric value."
|
|
287
|
+
)
|
|
288
|
+
timestamp_column: str = Field(
|
|
289
|
+
description="Name of the column returned from the SQL query holding the timestamp buckets."
|
|
290
|
+
)
|
|
276
291
|
metric_kind: AggregationMetricType = Field(
|
|
277
292
|
description="Return type of the reported aggregation metric value.",
|
|
278
293
|
)
|
|
279
|
-
dimension_columns: list[str] = Field(
|
|
294
|
+
dimension_columns: list[str] = Field(
|
|
295
|
+
description="Name of any dimension columns returned from the SQL query. Max length is 1."
|
|
296
|
+
)
|
|
280
297
|
|
|
281
|
-
@field_validator(
|
|
298
|
+
@field_validator("dimension_columns")
|
|
282
299
|
@classmethod
|
|
283
300
|
def validate_dimension_columns_length(cls, v: list[str]) -> str:
|
|
284
301
|
if len(v) > 1:
|
|
285
|
-
raise ValueError(
|
|
302
|
+
raise ValueError("Only one dimension column can be specified.")
|
|
286
303
|
return v
|
|
287
|
-
|
|
@@ -207,7 +207,7 @@ class FunctionAnalyzer:
|
|
|
207
207
|
)
|
|
208
208
|
# Check if X implements the required methods
|
|
209
209
|
required_methods = ["aggregate", "id", "description", "display_name"]
|
|
210
|
-
static_methods = ["description", "id", "display_name"]
|
|
210
|
+
static_methods = ["description", "id", "display_name", "reported_aggregations"]
|
|
211
211
|
for method in required_methods:
|
|
212
212
|
if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
|
|
213
213
|
raise AttributeError(
|
|
@@ -253,6 +253,7 @@ class FunctionAnalyzer:
|
|
|
253
253
|
metric_type=metric_type,
|
|
254
254
|
init_args=aggregation_init_args,
|
|
255
255
|
aggregate_args=aggregate_args,
|
|
256
|
+
reported_aggregations=agg_func.reported_aggregations(),
|
|
256
257
|
)
|
|
257
258
|
|
|
258
259
|
|
|
@@ -1,21 +1,21 @@
|
|
|
1
1
|
arthur_common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
arthur_common/aggregations/__init__.py,sha256=vISWyciQAtksa71OKeHNP-QyFGd1NzBKq_LBsG0QSG8,67
|
|
3
|
-
arthur_common/aggregations/aggregator.py,sha256=
|
|
3
|
+
arthur_common/aggregations/aggregator.py,sha256=kS9Qru0AhZzZz4Ym20NT7aNrbcQaqg2zgBVYFogFbbg,7936
|
|
4
4
|
arthur_common/aggregations/functions/README.md,sha256=MkZoTAJ94My96R5Z8GAxud7S6vyR0vgVi9gqdt9a4XY,5460
|
|
5
5
|
arthur_common/aggregations/functions/__init__.py,sha256=HqC3UNRURX7ZQHgamTrQvfA8u_FiZGZ4I4eQW7Ooe5o,1299
|
|
6
|
-
arthur_common/aggregations/functions/categorical_count.py,sha256=
|
|
7
|
-
arthur_common/aggregations/functions/confusion_matrix.py,sha256=
|
|
8
|
-
arthur_common/aggregations/functions/inference_count.py,sha256=
|
|
9
|
-
arthur_common/aggregations/functions/inference_count_by_class.py,sha256=
|
|
10
|
-
arthur_common/aggregations/functions/inference_null_count.py,sha256=
|
|
11
|
-
arthur_common/aggregations/functions/mean_absolute_error.py,sha256
|
|
12
|
-
arthur_common/aggregations/functions/mean_squared_error.py,sha256=
|
|
13
|
-
arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256
|
|
14
|
-
arthur_common/aggregations/functions/multiclass_inference_count_by_class.py,sha256=
|
|
15
|
-
arthur_common/aggregations/functions/numeric_stats.py,sha256=
|
|
16
|
-
arthur_common/aggregations/functions/numeric_sum.py,sha256=
|
|
6
|
+
arthur_common/aggregations/functions/categorical_count.py,sha256=na22lBhxASMMR0R9Z-3qBvToYN875tJm8u2ULVdrdYQ,5327
|
|
7
|
+
arthur_common/aggregations/functions/confusion_matrix.py,sha256=MbtS_Nge7dgjNutdtzd0hx756qzLQlHS2MQxuwSuwxc,22108
|
|
8
|
+
arthur_common/aggregations/functions/inference_count.py,sha256=lO-IgcmnsfRR1qmHbWjENJUSnQT-dXwZd9rVFOtKYrs,4078
|
|
9
|
+
arthur_common/aggregations/functions/inference_count_by_class.py,sha256=sOgrMyeZh71U9uGvq8w-bYlXNPRI6jtR2jP-oV81hHo,11552
|
|
10
|
+
arthur_common/aggregations/functions/inference_null_count.py,sha256=6dfkumX8NJjTB633Pt-shY5x99TXaqSyLcYVHk_DxHc,4824
|
|
11
|
+
arthur_common/aggregations/functions/mean_absolute_error.py,sha256=-Nihcl_QcwZPn-LrHX6KgG9O-QSfoa6SY3LHt2xDCbg,6821
|
|
12
|
+
arthur_common/aggregations/functions/mean_squared_error.py,sha256=kpADLvsJkg7C07nj5X1drk8ChRXvur_PjkzMB2uLazg,6842
|
|
13
|
+
arthur_common/aggregations/functions/multiclass_confusion_matrix.py,sha256=zfKK5maUy3TXmVEkqXwtXs6NM3fjp0W0yc-zS0uXZT4,12615
|
|
14
|
+
arthur_common/aggregations/functions/multiclass_inference_count_by_class.py,sha256=ZJU_GDMsq4XvqbhCAiH2J-DKrGVjXlz-E2nxrd7pM6c,4263
|
|
15
|
+
arthur_common/aggregations/functions/numeric_stats.py,sha256=4auKDwtTNxqw86gA0q3AqOf0-IM9uYWZ_tMquuug_sE,4920
|
|
16
|
+
arthur_common/aggregations/functions/numeric_sum.py,sha256=LcV2MWL-EOl0JPCozIGIoHkvohu7d2S1PHuAik-cAo4,5027
|
|
17
17
|
arthur_common/aggregations/functions/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
arthur_common/aggregations/functions/shield_aggregations.py,sha256=
|
|
18
|
+
arthur_common/aggregations/functions/shield_aggregations.py,sha256=KQzi97ILgn6UQhpQPyerrQ3CXMxs1vpuSUAAvIWf_zg,35630
|
|
19
19
|
arthur_common/aggregations/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
20
|
arthur_common/config/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
21
21
|
arthur_common/config/config.py,sha256=fcpjOYjPKu4Duk63CuTHrOWKQKAlAhVUR60kF_2_Xog,1247
|
|
@@ -23,14 +23,14 @@ arthur_common/config/settings.yaml,sha256=0CrygUwJzC5mGcO5Xnvv2ttp-P7LIsx682jllY
|
|
|
23
23
|
arthur_common/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
24
|
arthur_common/models/connectors.py,sha256=5f5DUgOQ16P3lBPZ0zpUv9kTAqw45Agrl526F-iFJes,1862
|
|
25
25
|
arthur_common/models/datasets.py,sha256=giG_8mv_3ilBf7cIvRV0_TDCDdb4qxRbYZvl7hRb6l8,491
|
|
26
|
-
arthur_common/models/metrics.py,sha256=
|
|
26
|
+
arthur_common/models/metrics.py,sha256=8_7ec0oFIjFGJpgRWS0Y28aaGCSd3j7dqa_QTYCNGus,11343
|
|
27
27
|
arthur_common/models/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
28
28
|
arthur_common/models/schema_definitions.py,sha256=0zXZKHKr49d7ATml2Tzw1AIFfM0i0HjIblM-qOwNxk8,14878
|
|
29
29
|
arthur_common/models/shield.py,sha256=62SKLzlsUsuP3u7EnibtI1CrRYg3TummP4Wbwg5ZPUs,18310
|
|
30
30
|
arthur_common/models/task_job_specs.py,sha256=uZo8eiTBHWf2EZGEQrDfJGVyYg_8wd9MHWLxn-5oNUk,2797
|
|
31
31
|
arthur_common/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
32
|
arthur_common/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
33
|
-
arthur_common/tools/aggregation_analyzer.py,sha256=
|
|
33
|
+
arthur_common/tools/aggregation_analyzer.py,sha256=UfMtvFWXV2Dqly8S6nneGgomuvEGN-1tBz81tfkMcAE,11206
|
|
34
34
|
arthur_common/tools/aggregation_loader.py,sha256=3CF46bNi-GdJBNOXkjYfCQ1Aung8lf65L532sdWmR_s,2351
|
|
35
35
|
arthur_common/tools/duckdb_data_loader.py,sha256=nscmarfP5FeL8p-9e3uZhpGEV0xFqDJmR3t77HdR26U,11081
|
|
36
36
|
arthur_common/tools/duckdb_utils.py,sha256=1i-kRXu95gh4Sf9Osl2LFUpdb0yZifOjLDtIgSfSmfs,1197
|
|
@@ -38,6 +38,6 @@ arthur_common/tools/functions.py,sha256=FWL4eWO5-vLp86WudT-MGUKvf2B8f02IdoXQFKd6
|
|
|
38
38
|
arthur_common/tools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
39
|
arthur_common/tools/schema_inferer.py,sha256=Ur4CXGAkd6ZMSU0nMNrkOEElsBopHXq0lctTV8X92W8,5188
|
|
40
40
|
arthur_common/tools/time_utils.py,sha256=4gfiu9NXfvPZltiVNLSIQGylX6h2W0viNi9Kv4bKyfw,1410
|
|
41
|
-
arthur_common-2.1.
|
|
42
|
-
arthur_common-2.1.
|
|
43
|
-
arthur_common-2.1.
|
|
41
|
+
arthur_common-2.1.53.dist-info/METADATA,sha256=ezzkiB4FHTSRLK3rzvj1mqiRxYqxedgouR0q55zsVLk,1609
|
|
42
|
+
arthur_common-2.1.53.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
43
|
+
arthur_common-2.1.53.dist-info/RECORD,,
|
|
File without changes
|