arthur-common 2.1.51__tar.gz → 2.1.53__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arthur-common might be problematic.
- {arthur_common-2.1.51 → arthur_common-2.1.53}/PKG-INFO +1 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/pyproject.toml +1 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/aggregator.py +6 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/categorical_count.py +14 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/confusion_matrix.py +35 -5
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count.py +14 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count_by_class.py +23 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_null_count.py +15 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_absolute_error.py +25 -3
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_squared_error.py +25 -3
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py +43 -5
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +14 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_stats.py +14 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_sum.py +15 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/shield_aggregations.py +126 -16
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/metrics.py +84 -12
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/shield.py +0 -18
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/aggregation_analyzer.py +2 -1
- {arthur_common-2.1.51 → arthur_common-2.1.53}/README.md +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/README.md +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/py.typed +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/py.typed +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/config.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/settings.yaml +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/connectors.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/datasets.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/py.typed +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/schema_definitions.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/task_job_specs.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/py.typed +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/__init__.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/aggregation_loader.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/duckdb_data_loader.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/duckdb_utils.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/functions.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/py.typed +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/schema_inferer.py +0 -0
- {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/time_utils.py +0 -0
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/aggregator.py
@@ -29,6 +29,12 @@ class AggregationFunction(ABC):
     def aggregation_type(self) -> Type[SketchMetric] | Type[NumericMetric]:
         raise NotImplementedError

+    @staticmethod
+    @abstractmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        """Returns the list of aggregations reported by the aggregate function."""
+        raise NotImplementedError
+
     @abstractmethod
     def aggregate(
         self,
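The hunk above makes reported_aggregations() a required static hook on every aggregation function. As a minimal, purely illustrative sketch of the pattern the per-function diffs below apply (the class name and metric name here are hypothetical and the other abstract methods such as aggregate() and id() are omitted):

# Hypothetical subclass, not part of the package; shows only the new hook.
from arthur_common.aggregations.aggregator import NumericAggregationFunction
from arthur_common.models.metrics import BaseReportedAggregation


class ExampleCountAggregation(NumericAggregationFunction):
    METRIC_NAME = "example_count"

    @staticmethod
    def description() -> str:
        return "Counts rows per time window."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        # one entry per metric series that aggregate() emits
        return [
            BaseReportedAggregation(
                metric_name=ExampleCountAggregation.METRIC_NAME,
                description=ExampleCountAggregation.description(),
            )
        ]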
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/categorical_count.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +36,15 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of discrete values of each category in a string column. Creates a separate dimension for each category and the values are the count of occurrences of that category in the time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=CategoricalCountAggregationFunction.METRIC_NAME,
+                description=CategoricalCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/confusion_matrix.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -20,6 +24,32 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str


 class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
+    TRUE_POSITIVE_METRIC_NAME = "confusion_matrix_true_positive_count"
+    FALSE_POSITIVE_METRIC_NAME = "confusion_matrix_false_positive_count"
+    FALSE_NEGATIVE_METRIC_NAME = "confusion_matrix_false_negative_count"
+    TRUE_NEGATIVE_METRIC_NAME = "confusion_matrix_true_negative_count"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.TRUE_POSITIVE_METRIC_NAME,
+                description="Confusion matrix true positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.FALSE_POSITIVE_METRIC_NAME,
+                description="Confusion matrix false positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.FALSE_NEGATIVE_METRIC_NAME,
+                description="Confusion matrix false negatives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.TRUE_NEGATIVE_METRIC_NAME,
+                description="Confusion matrix true negatives count.",
+            ),
+        ]
+
     def generate_confusion_matrix_metrics(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -129,10 +159,10 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
             dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
-        tp_metric = self.series_to_metric(
-        fp_metric = self.series_to_metric(
-        fn_metric = self.series_to_metric(
-        tn_metric = self.series_to_metric(
+        tp_metric = self.series_to_metric(self.TRUE_POSITIVE_METRIC_NAME, tp)
+        fp_metric = self.series_to_metric(self.FALSE_POSITIVE_METRIC_NAME, fp)
+        fn_metric = self.series_to_metric(self.FALSE_NEGATIVE_METRIC_NAME, fn)
+        tn_metric = self.series_to_metric(self.TRUE_NEGATIVE_METRIC_NAME, tn)
         return [tp_metric, fp_metric, fn_metric, tn_metric]

{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +36,15 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of inferences per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=InferenceCountAggregationFunction.METRIC_NAME,
+                description=InferenceCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count_by_class.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -36,6 +40,15 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
     def _metric_name() -> str:
         return "binary_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=BinaryClassifierCountByClassAggregationFunction._metric_name(),
+                description=BinaryClassifierCountByClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -153,6 +166,15 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
     def _metric_name() -> str:
         return "binary_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=BinaryClassifierCountThresholdClassAggregationFunction._metric_name(),
+                description=BinaryClassifierCountThresholdClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_null_count.py
@@ -4,7 +4,12 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    Dimension,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +37,15 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of null values in the column per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=InferenceNullCountAggregationFunction.METRIC_NAME,
+                description=InferenceNullCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_absolute_error.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier


 class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
+    ABSOLUTE_ERROR_COUNT_METRIC_NAME = "absolute_error_count"
+    ABSOLUTE_ERROR_SUM_METRIC_NAME = "absolute_error_sum"
+
     @staticmethod
     def id() -> UUID:
         return UUID("00000000-0000-0000-0000-00000000000e")
@@ -31,6 +38,19 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that sums the absolute error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
+                description="Sum of the absolute error of a prediction and ground truth column, omitting rows where either column is null.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME,
+                description=f"Count of non-null rows used in the calculation of the {MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME} metric.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -138,9 +158,11 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
             "ts",
         )

-        count_metric = self.series_to_metric(
+        count_metric = self.series_to_metric(
+            self.ABSOLUTE_ERROR_COUNT_METRIC_NAME, count_series
+        )
         absolute_error_metric = self.series_to_metric(
-
+            self.ABSOLUTE_ERROR_SUM_METRIC_NAME,
             absolute_error_series,
         )

{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_squared_error.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier


 class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
+    SQUARED_ERROR_COUNT_METRIC_NAME = "squared_error_count"
+    SQUARED_ERROR_SUM_METRIC_NAME = "squared_error_sum"
+
     @staticmethod
     def id() -> UUID:
         return UUID("00000000-0000-0000-0000-000000000010")
@@ -31,6 +38,19 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that sums the squared error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME,
+                description="Sum of the squared error of a prediction and ground truth column, omitting rows where either column is null.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_COUNT_METRIC_NAME,
+                description=f"Count of non-null rows used in the calculation of the {MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME} metric.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -138,9 +158,11 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
             "ts",
         )

-        count_metric = self.series_to_metric(
+        count_metric = self.series_to_metric(
+            self.SQUARED_ERROR_COUNT_METRIC_NAME, count_series
+        )
         absolute_error_metric = self.series_to_metric(
-
+            self.SQUARED_ERROR_SUM_METRIC_NAME,
             squared_error_series,
         )

{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -22,6 +26,19 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str
 class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
     NumericAggregationFunction,
 ):
+    MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_true_positive_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_false_positive_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_false_negative_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_true_negative_count"
+    )
+
     @staticmethod
     def id() -> UUID:
         return UUID("dc728927-6928-4a3b-b174-8c1ec8b58d62")
@@ -38,6 +55,27 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
             "False Negatives, True Negatives) for that class compared to all others."
         )

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
+                description="Confusion matrix true positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
+                description="Confusion matrix false positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
+                description="Confusion matrix false negatives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
+                description="Confusion matrix true negatives count.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -238,19 +276,19 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
             timestamp_col="ts",
         )
         tp_metric = self.series_to_metric(
-
+            self.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
             tp,
         )
         fp_metric = self.series_to_metric(
-
+            self.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
             fp,
         )
         fn_metric = self.series_to_metric(
-
+            self.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
             fn,
         )
         tn_metric = self.series_to_metric(
-
+            self.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
             tn,
         )
         return [tp_metric, fp_metric, fn_metric, tn_metric]
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
@@ -7,7 +7,11 @@ from arthur_common.aggregations.functions.inference_count_by_class import (
     BinaryClassifierCountByClassAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -47,6 +51,15 @@ class MulticlassClassifierCountByClassAggregationFunction(
     def _metric_name() -> str:
         return "multiclass_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierCountByClassAggregationFunction._metric_name(),
+                description=MulticlassClassifierCountByClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_stats.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import SketchAggregationFunction
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    SketchMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -34,6 +38,15 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
             "Metric that calculates a distribution (data sketch) on a numeric column."
         )

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=NumericSketchAggregationFunction.METRIC_NAME,
+                description=NumericSketchAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_sum.py
@@ -4,7 +4,12 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    Dimension,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +37,15 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the sum of the numeric column per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=NumericSumAggregationFunction.METRIC_NAME,
+                description=NumericSumAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/shield_aggregations.py
@@ -10,7 +10,12 @@ from arthur_common.aggregations.aggregator import (
     SketchAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    SketchMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SHIELD_RESPONSE_SCHEMA,
     MetricColumnParameterAnnotation,
@@ -33,6 +38,15 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
+                description=ShieldInferencePassFailCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -93,6 +107,15 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -176,6 +199,15 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield hallucination evaluations that failed."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
+                description=ShieldInferenceHallucinationCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -231,6 +263,15 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleToxicityScoreAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -307,6 +348,15 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRulePIIDataScoreAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -389,6 +439,15 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on over the number of claims identified by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -453,6 +512,15 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimPassCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -517,6 +585,15 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimFailCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -581,6 +658,15 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleLatencyAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -643,6 +729,18 @@

 class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "token_count"
+    SUPPORTED_MODELS = [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "gpt-3.5-turbo",
+        "o1-mini",
+        "deepseek-chat",
+        "claude-3-5-sonnet-20241022",
+        "gemini/gemini-1.5-pro",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-2-11b-instruct-v1:0",
+    ]

     @staticmethod
     def id() -> UUID:
@@ -656,6 +754,27 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."

+    @staticmethod
+    def _series_name_from_model_name(model_name: str) -> str:
+        """Calculates name of reported series based on the model name considered."""
+        return f"token_cost.{model_name}"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        base_token_count_agg = BaseReportedAggregation(
+            metric_name=ShieldInferenceTokenCountAggregation.METRIC_NAME,
+            description=f"Metric that reports the number of tokens in the Shield response and prompt schemas.",
+        )
+        return [base_token_count_agg] + [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
+                    model_name
+                ),
+                description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
+            )
+            for model_name in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -708,25 +827,12 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
         resp = [metric]

         # Compute Cost for each model
-        models = [
-            "gpt-4o",
-            "gpt-4o-mini",
-            "gpt-3.5-turbo",
-            "o1-mini",
-            "deepseek-chat",
-            "claude-3-5-sonnet-20241022",
-            "gemini/gemini-1.5-pro",
-            "meta.llama3-1-8b-instruct-v1:0",
-            "meta.llama3-1-70b-instruct-v1:0",
-            "meta.llama3-2-11b-instruct-v1:0",
-        ]
-
         # Precompute input/output classification to avoid recalculating in loop
         location_type = results["location"].apply(
             lambda x: "input" if x == "prompt" else "output",
         )

-        for model in
+        for model in self.SUPPORTED_MODELS:
             # Efficient list comprehension instead of apply
             cost_values = [
                 calculate_cost_by_tokens(int(tokens), model, loc_type)
@@ -747,5 +853,9 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
                 ["location"],
                 "ts",
             )
-            resp.append(
+            resp.append(
+                self.series_to_metric(
+                    self._series_name_from_model_name(model), model_series
+                )
+            )
         return resp
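The token-count aggregation above now reports one series per supported model in addition to the base token count. As a hedged, illustrative sketch of how those series names are derived (the import path is assumed from the package layout shown in the file list, and _series_name_from_model_name is an internal helper added in this release):

# Illustrative only; values come straight from the diff above.
from arthur_common.aggregations.functions.shield_aggregations import (
    ShieldInferenceTokenCountAggregation,
)

names = [
    ShieldInferenceTokenCountAggregation._series_name_from_model_name(m)
    for m in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
]
# e.g. ["token_cost.gpt-4o", "token_cost.gpt-4o-mini", ...]

# reported_aggregations() lists the base token_count entry plus one entry per model
assert len(ShieldInferenceTokenCountAggregation.reported_aggregations()) == 1 + len(
    ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
)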
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/metrics.py
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Literal, Optional
 from uuid import UUID

-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

 from arthur_common.models.datasets import ModelProblemType
@@ -112,12 +112,9 @@ class AggregationMetricType(Enum):
     NUMERIC = "numeric"


-class
+class BaseAggregationParameterSchema(BaseModel):
+    # fields for aggregation parameters shared across all parameter types and between default and custom metrics
     parameter_key: str = Field(description="Name of the parameter.")
-    optional: bool = Field(
-        False,
-        description="Boolean denoting if the parameter is optional.",
-    )
     friendly_name: str = Field(
         description="User facing name of the parameter.",
     )
@@ -126,7 +123,16 @@ class MetricsParameterSchema(BaseModel):
     )


-class
+class MetricsParameterSchema(BaseAggregationParameterSchema):
+    # specific to default metrics/Python metrics—not available to custom aggregations
+    optional: bool = Field(
+        False,
+        description="Boolean denoting if the parameter is optional.",
+    )
+
+
+class BaseDatasetParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to dataset parameters shared across default and custom metrics
     parameter_type: Literal["dataset"] = "dataset"
     model_problem_type: Optional[ModelProblemType] = Field(
         default=None,
@@ -134,12 +140,24 @@ class MetricsDatasetParameterSchema(MetricsParameterSchema):
     )


-class
+class MetricsDatasetParameterSchema(MetricsParameterSchema, BaseDatasetParameterSchema):
+    # dataset parameter schema including fields specific to default metrics
+    pass
+
+
+class BaseLiteralParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to literal parameters shared across default and custom metrics
     parameter_type: Literal["literal"] = "literal"
     parameter_dtype: DType = Field(description="Data type of the parameter.")


-class
+class MetricsLiteralParameterSchema(MetricsParameterSchema, BaseLiteralParameterSchema):
+    # literal parameter schema including fields specific to default metrics
+    pass
+
+
+class BaseColumnBaseParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to all single or multiple column parameters shared across default and custom metrics
     tag_hints: list[ScopeSchemaTag] = Field(
         [],
         description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
@@ -165,12 +183,20 @@ class MetricsColumnBaseParameterSchema(MetricsParameterSchema):
         return self


-class
+class BaseColumnParameterSchema(BaseColumnBaseParameterSchema):
+    # single column parameter schema common across default and custom metrics
     parameter_type: Literal["column"] = "column"


-
-
+class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
+    # single column parameter schema specific to default metrics
+    parameter_type: Literal["column"] = "column"
+
+
+class MetricsColumnListParameterSchema(
+    MetricsParameterSchema, BaseColumnParameterSchema
+):
+    # list column parameter schema specific to default metrics
     parameter_type: Literal["column_list"] = "column_list"


@@ -186,6 +212,11 @@ MetricsColumnSchemaUnion = (
 )


+CustomAggregationParametersSchemaUnion = (
+    BaseDatasetParameterSchema | BaseLiteralParameterSchema | BaseColumnParameterSchema
+)
+
+
 @dataclass
 class DatasetReference:
     dataset_name: str
@@ -193,6 +224,14 @@ class DatasetReference:
     dataset_id: UUID


+class BaseReportedAggregation(BaseModel):
+    # in future will be used by default metrics
+    metric_name: str = Field(description="Name of the reported aggregation metric.")
+    description: str = Field(
+        description="Description of the reported aggregation metric and what it aggregates.",
+    )
+
+
 class AggregationSpecSchema(BaseModel):
     name: str = Field(description="Name of the aggregation function.")
     id: UUID = Field(description="Unique identifier of the aggregation function.")
@@ -209,6 +248,17 @@ class AggregationSpecSchema(BaseModel):
     aggregate_args: list[MetricsParameterSchemaUnion] = Field(
         description="List of parameters to the aggregation's aggregate function.",
     )
+    reported_aggregations: list[BaseReportedAggregation] = Field(
+        description="List of aggregations reported by the metric."
+    )
+
+    @model_validator(mode="after")
+    def at_least_one_reported_agg(self) -> Self:
+        if len(self.reported_aggregations) < 1:
+            raise ValueError(
+                "Aggregation spec must specify at least one reported aggregation."
+            )
+        return self

     @model_validator(mode="after")
     def column_dataset_references_exist(self) -> Self:
@@ -229,3 +279,25 @@ class AggregationSpecSchema(BaseModel):
                     f"Column parameter '{param.parameter_key}' references dataset parameter '{param.source_dataset_parameter_key}' which does not exist.",
                 )
         return self
+
+
+class ReportedCustomAggregation(BaseReportedAggregation):
+    value_column: str = Field(
+        description="Name of the column returned from the SQL query holding the metric value."
+    )
+    timestamp_column: str = Field(
+        description="Name of the column returned from the SQL query holding the timestamp buckets."
+    )
+    metric_kind: AggregationMetricType = Field(
+        description="Return type of the reported aggregation metric value.",
+    )
+    dimension_columns: list[str] = Field(
+        description="Name of any dimension columns returned from the SQL query. Max length is 1."
+    )
+
+    @field_validator("dimension_columns")
+    @classmethod
+    def validate_dimension_columns_length(cls, v: list[str]) -> str:
+        if len(v) > 1:
+            raise ValueError("Only one dimension column can be specified.")
+        return v
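The metrics.py diff above introduces ReportedCustomAggregation with a validator that caps dimension_columns at one entry. A hedged usage sketch, with made-up field values, of how that validation behaves:

# Illustrative only: field names and the validator come from the diff above;
# the values ("row_count", "bucket_ts", etc.) are invented for the example.
from arthur_common.models.metrics import (
    AggregationMetricType,
    ReportedCustomAggregation,
)

ReportedCustomAggregation(
    metric_name="row_count",
    description="Row count per time bucket.",
    value_column="count",
    timestamp_column="bucket_ts",
    metric_kind=AggregationMetricType.NUMERIC,
    dimension_columns=["segment"],  # at most one dimension column is allowed
)

# dimension_columns=["a", "b"] would raise a pydantic ValidationError
# ("Only one dimension column can be specified.")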
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/shield.py
@@ -10,9 +10,7 @@ DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD = 0

 class RuleType(str, Enum):
     KEYWORD = "KeywordRule"
-    MODEL_HALLUCINATION = "ModelHallucinationRule"
     MODEL_HALLUCINATION_V2 = "ModelHallucinationRuleV2"
-    MODEL_HALLUCINATION_V3 = "ModelHallucinationRuleV3"
     MODEL_SENSITIVE_DATA = "ModelSensitiveDataRule"
     PII_DATA = "PIIDataRule"
     PROMPT_INJECTION = "PromptInjectionRule"
@@ -456,14 +454,6 @@ class NewRuleRequest(BaseModel):
                 detail="PromptInjectionRule can only be enabled for prompt. Please set the 'apply_to_response' field "
                 "to false.",
             )
-        if (self.type == RuleType.MODEL_HALLUCINATION) and (
-            self.apply_to_prompt is True
-        ):
-            raise HTTPException(
-                status_code=400,
-                detail="ModelHallucinationRule can only be enabled for response. Please set the 'apply_to_prompt' "
-                "field to false.",
-            )
         if (self.type == RuleType.MODEL_HALLUCINATION_V2) and (
             self.apply_to_prompt is True
         ):
@@ -472,14 +462,6 @@ class NewRuleRequest(BaseModel):
                 detail="ModelHallucinationRuleV2 can only be enabled for response. Please set the 'apply_to_prompt' "
                 "field to false.",
             )
-        if (self.type == RuleType.MODEL_HALLUCINATION_V3) and (
-            self.apply_to_prompt is True
-        ):
-            raise HTTPException(
-                status_code=400,
-                detail="ModelHallucinationRuleV3 can only be enabled for response. Please set the "
-                "'apply_to_prompt' field to false.",
-            )
         if (self.apply_to_prompt is False) and (self.apply_to_response is False):
             raise HTTPException(
                 status_code=400,
{arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/aggregation_analyzer.py
RENAMED
@@ -207,7 +207,7 @@ class FunctionAnalyzer:
         )
         # Check if X implements the required methods
         required_methods = ["aggregate", "id", "description", "display_name"]
-        static_methods = ["description", "id", "display_name"]
+        static_methods = ["description", "id", "display_name", "reported_aggregations"]
         for method in required_methods:
             if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
                 raise AttributeError(
@@ -253,6 +253,7 @@ class FunctionAnalyzer:
             metric_type=metric_type,
             init_args=aggregation_init_args,
             aggregate_args=aggregate_args,
+            reported_aggregations=agg_func.reported_aggregations(),
         )
