arthur-common 2.1.51__tar.gz → 2.1.53__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arthur-common might be problematic.

Files changed (43)
  1. {arthur_common-2.1.51 → arthur_common-2.1.53}/PKG-INFO +1 -1
  2. {arthur_common-2.1.51 → arthur_common-2.1.53}/pyproject.toml +1 -1
  3. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/aggregator.py +6 -0
  4. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/categorical_count.py +14 -1
  5. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/confusion_matrix.py +35 -5
  6. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count.py +14 -1
  7. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_count_by_class.py +23 -1
  8. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/inference_null_count.py +15 -1
  9. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_absolute_error.py +25 -3
  10. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/mean_squared_error.py +25 -3
  11. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py +43 -5
  12. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +14 -1
  13. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_stats.py +14 -1
  14. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/numeric_sum.py +15 -1
  15. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/shield_aggregations.py +126 -16
  16. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/metrics.py +84 -12
  17. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/shield.py +0 -18
  18. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/aggregation_analyzer.py +2 -1
  19. {arthur_common-2.1.51 → arthur_common-2.1.53}/README.md +0 -0
  20. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/__init__.py +0 -0
  21. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/__init__.py +0 -0
  22. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/README.md +0 -0
  23. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/__init__.py +0 -0
  24. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/functions/py.typed +0 -0
  25. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/aggregations/py.typed +0 -0
  26. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/__init__.py +0 -0
  27. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/config.py +0 -0
  28. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/config/settings.yaml +0 -0
  29. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/__init__.py +0 -0
  30. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/connectors.py +0 -0
  31. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/datasets.py +0 -0
  32. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/py.typed +0 -0
  33. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/schema_definitions.py +0 -0
  34. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/models/task_job_specs.py +0 -0
  35. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/py.typed +0 -0
  36. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/__init__.py +0 -0
  37. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/aggregation_loader.py +0 -0
  38. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/duckdb_data_loader.py +0 -0
  39. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/duckdb_utils.py +0 -0
  40. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/functions.py +0 -0
  41. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/py.typed +0 -0
  42. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/schema_inferer.py +0 -0
  43. {arthur_common-2.1.51 → arthur_common-2.1.53}/src/arthur_common/tools/time_utils.py +0 -0
--- arthur_common-2.1.51/PKG-INFO
+++ arthur_common-2.1.53/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: arthur-common
-Version: 2.1.51
+Version: 2.1.53
 Summary: Utility code common to Arthur platform components.
 License: MIT
 Author: Arthur
--- arthur_common-2.1.51/pyproject.toml
+++ arthur_common-2.1.53/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "arthur-common"
-version = "2.1.51"
+version = "2.1.53"
 description = "Utility code common to Arthur platform components."
 authors = ["Arthur <engineering@arthur.ai>"]
 license = "MIT"
--- arthur_common-2.1.51/src/arthur_common/aggregations/aggregator.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/aggregator.py
@@ -29,6 +29,12 @@ class AggregationFunction(ABC):
     def aggregation_type(self) -> Type[SketchMetric] | Type[NumericMetric]:
         raise NotImplementedError

+    @staticmethod
+    @abstractmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        """Returns the list of aggregations reported by the aggregate function."""
+        raise NotImplementedError
+
     @abstractmethod
     def aggregate(
         self,
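The hunk above makes reported_aggregations() an abstract static method on AggregationFunction, so every aggregation must now enumerate the metric series it emits. A minimal sketch of a conforming subclass, assuming the imports shown in the surrounding hunks; ExampleRowCountAggregation and its metric name are hypothetical, and the other abstract members (id, display_name, aggregate, ...) are elided:

from arthur_common.aggregations.aggregator import NumericAggregationFunction
from arthur_common.models.metrics import BaseReportedAggregation


class ExampleRowCountAggregation(NumericAggregationFunction):
    METRIC_NAME = "example_row_count"  # hypothetical metric name

    @staticmethod
    def description() -> str:
        return "Example metric that counts rows per time window."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        # One entry per metric series this aggregation reports.
        return [
            BaseReportedAggregation(
                metric_name=ExampleRowCountAggregation.METRIC_NAME,
                description=ExampleRowCountAggregation.description(),
            )
        ]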
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/categorical_count.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/categorical_count.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +36,15 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of discrete values of each category in a string column. Creates a separate dimension for each category and the values are the count of occurrences of that category in the time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=CategoricalCountAggregationFunction.METRIC_NAME,
+                description=CategoricalCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/confusion_matrix.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/confusion_matrix.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -20,6 +24,32 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str


 class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
+    TRUE_POSITIVE_METRIC_NAME = "confusion_matrix_true_positive_count"
+    FALSE_POSITIVE_METRIC_NAME = "confusion_matrix_false_positive_count"
+    FALSE_NEGATIVE_METRIC_NAME = "confusion_matrix_false_negative_count"
+    TRUE_NEGATIVE_METRIC_NAME = "confusion_matrix_true_negative_count"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.TRUE_POSITIVE_METRIC_NAME,
+                description="Confusion matrix true positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.FALSE_POSITIVE_METRIC_NAME,
+                description="Confusion matrix false positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.FALSE_NEGATIVE_METRIC_NAME,
+                description="Confusion matrix false negatives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=ConfusionMatrixAggregationFunction.TRUE_NEGATIVE_METRIC_NAME,
+                description="Confusion matrix true negatives count.",
+            ),
+        ]
+
     def generate_confusion_matrix_metrics(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -129,10 +159,10 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
             dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
-        tp_metric = self.series_to_metric("confusion_matrix_true_positive_count", tp)
-        fp_metric = self.series_to_metric("confusion_matrix_false_positive_count", fp)
-        fn_metric = self.series_to_metric("confusion_matrix_false_negative_count", fn)
-        tn_metric = self.series_to_metric("confusion_matrix_true_negative_count", tn)
+        tp_metric = self.series_to_metric(self.TRUE_POSITIVE_METRIC_NAME, tp)
+        fp_metric = self.series_to_metric(self.FALSE_POSITIVE_METRIC_NAME, fp)
+        fn_metric = self.series_to_metric(self.FALSE_NEGATIVE_METRIC_NAME, fn)
+        tn_metric = self.series_to_metric(self.TRUE_NEGATIVE_METRIC_NAME, tn)
         return [tp_metric, fp_metric, fn_metric, tn_metric]

--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/inference_count.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/inference_count.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +36,15 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of inferences per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=InferenceCountAggregationFunction.METRIC_NAME,
+                description=InferenceCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/inference_count_by_class.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/inference_count_by_class.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -36,6 +40,15 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
     def _metric_name() -> str:
         return "binary_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=BinaryClassifierCountByClassAggregationFunction._metric_name(),
+                description=BinaryClassifierCountByClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -153,6 +166,15 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
     def _metric_name() -> str:
         return "binary_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=BinaryClassifierCountThresholdClassAggregationFunction._metric_name(),
+                description=BinaryClassifierCountThresholdClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/inference_null_count.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/inference_null_count.py
@@ -4,7 +4,12 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    Dimension,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +37,15 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of null values in the column per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=InferenceNullCountAggregationFunction.METRIC_NAME,
+                description=InferenceNullCountAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/mean_absolute_error.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/mean_absolute_error.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier


 class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
+    ABSOLUTE_ERROR_COUNT_METRIC_NAME = "absolute_error_count"
+    ABSOLUTE_ERROR_SUM_METRIC_NAME = "absolute_error_sum"
+
     @staticmethod
     def id() -> UUID:
         return UUID("00000000-0000-0000-0000-00000000000e")
@@ -31,6 +38,19 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that sums the absolute error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
+                description="Sum of the absolute error of a prediction and ground truth column, omitting rows where either column is null.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME,
+                description=f"Count of non-null rows used in the calculation of the {MeanAbsoluteErrorAggregationFunction.ABSOLUTE_ERROR_SUM_METRIC_NAME} metric.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -138,9 +158,11 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
             "ts",
         )

-        count_metric = self.series_to_metric("absolute_error_count", count_series)
+        count_metric = self.series_to_metric(
+            self.ABSOLUTE_ERROR_COUNT_METRIC_NAME, count_series
+        )
         absolute_error_metric = self.series_to_metric(
-            "absolute_error_sum",
+            self.ABSOLUTE_ERROR_SUM_METRIC_NAME,
             absolute_error_series,
         )

--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/mean_squared_error.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/mean_squared_error.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -19,6 +23,9 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier


 class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
+    SQUARED_ERROR_COUNT_METRIC_NAME = "squared_error_count"
+    SQUARED_ERROR_SUM_METRIC_NAME = "squared_error_sum"
+
     @staticmethod
     def id() -> UUID:
         return UUID("00000000-0000-0000-0000-000000000010")
@@ -31,6 +38,19 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that sums the squared error of a prediction and ground truth column. It omits any rows where either the prediction or ground truth are null. It reports the count of non-null rows used in the calculation in a second metric."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME,
+                description="Sum of the squared error of a prediction and ground truth column, omitting rows where either column is null.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MeanSquaredErrorAggregationFunction.SQUARED_ERROR_COUNT_METRIC_NAME,
+                description=f"Count of non-null rows used in the calculation of the {MeanSquaredErrorAggregationFunction.SQUARED_ERROR_SUM_METRIC_NAME} metric.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -138,9 +158,11 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
             "ts",
         )

-        count_metric = self.series_to_metric("squared_error_count", count_series)
+        count_metric = self.series_to_metric(
+            self.SQUARED_ERROR_COUNT_METRIC_NAME, count_series
+        )
         absolute_error_metric = self.series_to_metric(
-            "squared_error_sum",
+            self.SQUARED_ERROR_SUM_METRIC_NAME,
             squared_error_series,
         )

--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py
@@ -5,7 +5,11 @@ from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -22,6 +26,19 @@ from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str
 class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
     NumericAggregationFunction,
 ):
+    MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_true_positive_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_false_positive_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_false_negative_count"
+    )
+    MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME = (
+        "multiclass_confusion_matrix_single_class_true_negative_count"
+    )
+
     @staticmethod
     def id() -> UUID:
         return UUID("dc728927-6928-4a3b-b174-8c1ec8b58d62")
@@ -38,6 +55,27 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
            "False Negatives, True Negatives) for that class compared to all others."
        )

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
+                description="Confusion matrix true positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
+                description="Confusion matrix false positives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
+                description="Confusion matrix false negatives count.",
+            ),
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
+                description="Confusion matrix true negatives count.",
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -238,19 +276,19 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
             timestamp_col="ts",
         )
         tp_metric = self.series_to_metric(
-            "multiclass_confusion_matrix_single_class_true_positive_count",
+            self.MULTICLASS_CM_SINGLE_CLASS_TP_COUNT_METRIC_NAME,
             tp,
         )
         fp_metric = self.series_to_metric(
-            "multiclass_confusion_matrix_single_class_false_positive_count",
+            self.MULTICLASS_CM_SINGLE_CLASS_FP_COUNT_METRIC_NAME,
             fp,
         )
         fn_metric = self.series_to_metric(
-            "multiclass_confusion_matrix_single_class_false_negative_count",
+            self.MULTICLASS_CM_SINGLE_CLASS_FN_COUNT_METRIC_NAME,
             fn,
         )
         tn_metric = self.series_to_metric(
-            "multiclass_confusion_matrix_single_class_true_negative_count",
+            self.MULTICLASS_CM_SINGLE_CLASS_TN_COUNT_METRIC_NAME,
             tn,
         )
         return [tp_metric, fp_metric, fn_metric, tn_metric]
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
@@ -7,7 +7,11 @@ from arthur_common.aggregations.functions.inference_count_by_class import (
     BinaryClassifierCountByClassAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -47,6 +51,15 @@ class MulticlassClassifierCountByClassAggregationFunction(
     def _metric_name() -> str:
         return "multiclass_classifier_count_by_class"

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=MulticlassClassifierCountByClassAggregationFunction._metric_name(),
+                description=MulticlassClassifierCountByClassAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/numeric_stats.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/numeric_stats.py
@@ -4,7 +4,11 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import SketchAggregationFunction
-from arthur_common.models.metrics import DatasetReference, SketchMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    SketchMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -34,6 +38,15 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
            "Metric that calculates a distribution (data sketch) on a numeric column."
        )

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=NumericSketchAggregationFunction.METRIC_NAME,
+                description=NumericSketchAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/numeric_sum.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/numeric_sum.py
@@ -4,7 +4,12 @@ from uuid import UUID
 from duckdb import DuckDBPyConnection

 from arthur_common.aggregations.aggregator import NumericAggregationFunction
-from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    Dimension,
+    NumericMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
     DType,
@@ -32,6 +37,15 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the sum of the numeric column per time window."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=NumericSumAggregationFunction.METRIC_NAME,
+                description=NumericSumAggregationFunction.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
--- arthur_common-2.1.51/src/arthur_common/aggregations/functions/shield_aggregations.py
+++ arthur_common-2.1.53/src/arthur_common/aggregations/functions/shield_aggregations.py
@@ -10,7 +10,12 @@ from arthur_common.aggregations.aggregator import (
     SketchAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric, SketchMetric
+from arthur_common.models.metrics import (
+    DatasetReference,
+    NumericMetric,
+    SketchMetric,
+    BaseReportedAggregation,
+)
 from arthur_common.models.schema_definitions import (
     SHIELD_RESPONSE_SCHEMA,
     MetricColumnParameterAnnotation,
@@ -33,6 +38,15 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
+                description=ShieldInferencePassFailCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -93,6 +107,15 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -176,6 +199,15 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield hallucination evaluations that failed."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
+                description=ShieldInferenceHallucinationCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -231,6 +263,15 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleToxicityScoreAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -307,6 +348,15 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRulePIIDataScoreAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -389,6 +439,15 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on over the number of claims identified by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -453,6 +512,15 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimPassCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -517,6 +585,15 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimFailCountAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -581,6 +658,15 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleLatencyAggregation.description(),
+            )
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -643,6 +729,18 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):

 class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "token_count"
+    SUPPORTED_MODELS = [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "gpt-3.5-turbo",
+        "o1-mini",
+        "deepseek-chat",
+        "claude-3-5-sonnet-20241022",
+        "gemini/gemini-1.5-pro",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-2-11b-instruct-v1:0",
+    ]

     @staticmethod
     def id() -> UUID:
@@ -656,6 +754,27 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."

+    @staticmethod
+    def _series_name_from_model_name(model_name: str) -> str:
+        """Calculates name of reported series based on the model name considered."""
+        return f"token_cost.{model_name}"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        base_token_count_agg = BaseReportedAggregation(
+            metric_name=ShieldInferenceTokenCountAggregation.METRIC_NAME,
+            description=f"Metric that reports the number of tokens in the Shield response and prompt schemas.",
+        )
+        return [base_token_count_agg] + [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
+                    model_name
+                ),
+                description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
+            )
+            for model_name in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
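Taken together with SUPPORTED_MODELS above, reported_aggregations() for this class should yield eleven entries: the base token_count series plus one token_cost.<model> series per supported model. A quick sanity-check sketch, assuming the package is installed; the expected values in the comments follow from the code in this diff:

from arthur_common.aggregations.functions.shield_aggregations import (
    ShieldInferenceTokenCountAggregation,
)

aggs = ShieldInferenceTokenCountAggregation.reported_aggregations()
print(len(aggs))            # 11: "token_count" plus 10 per-model cost series
print(aggs[0].metric_name)  # token_count
print(
    ShieldInferenceTokenCountAggregation._series_name_from_model_name("gpt-4o")
)                           # token_cost.gpt-4o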
@@ -708,25 +827,12 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
             resp = [metric]

         # Compute Cost for each model
-        models = [
-            "gpt-4o",
-            "gpt-4o-mini",
-            "gpt-3.5-turbo",
-            "o1-mini",
-            "deepseek-chat",
-            "claude-3-5-sonnet-20241022",
-            "gemini/gemini-1.5-pro",
-            "meta.llama3-1-8b-instruct-v1:0",
-            "meta.llama3-1-70b-instruct-v1:0",
-            "meta.llama3-2-11b-instruct-v1:0",
-        ]
-
         # Precompute input/output classification to avoid recalculating in loop
         location_type = results["location"].apply(
             lambda x: "input" if x == "prompt" else "output",
         )

-        for model in models:
+        for model in self.SUPPORTED_MODELS:
             # Efficient list comprehension instead of apply
             cost_values = [
                 calculate_cost_by_tokens(int(tokens), model, loc_type)
@@ -747,5 +853,9 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
                 ["location"],
                 "ts",
             )
-            resp.append(self.series_to_metric(f"token_cost.{model}", model_series))
+            resp.append(
+                self.series_to_metric(
+                    self._series_name_from_model_name(model), model_series
+                )
+            )
         return resp
--- arthur_common-2.1.51/src/arthur_common/models/metrics.py
+++ arthur_common-2.1.53/src/arthur_common/models/metrics.py
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Literal, Optional
 from uuid import UUID

-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

 from arthur_common.models.datasets import ModelProblemType
@@ -112,12 +112,9 @@ class AggregationMetricType(Enum):
     NUMERIC = "numeric"


-class MetricsParameterSchema(BaseModel):
+class BaseAggregationParameterSchema(BaseModel):
+    # fields for aggregation parameters shared across all parameter types and between default and custom metrics
     parameter_key: str = Field(description="Name of the parameter.")
-    optional: bool = Field(
-        False,
-        description="Boolean denoting if the parameter is optional.",
-    )
     friendly_name: str = Field(
         description="User facing name of the parameter.",
     )
@@ -126,7 +123,16 @@ class MetricsParameterSchema(BaseModel):
     )


-class MetricsDatasetParameterSchema(MetricsParameterSchema):
+class MetricsParameterSchema(BaseAggregationParameterSchema):
+    # specific to default metrics/Python metrics—not available to custom aggregations
+    optional: bool = Field(
+        False,
+        description="Boolean denoting if the parameter is optional.",
+    )
+
+
+class BaseDatasetParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to dataset parameters shared across default and custom metrics
     parameter_type: Literal["dataset"] = "dataset"
     model_problem_type: Optional[ModelProblemType] = Field(
         default=None,
@@ -134,12 +140,24 @@ class MetricsDatasetParameterSchema(MetricsParameterSchema):
     )


-class MetricsLiteralParameterSchema(MetricsParameterSchema):
+class MetricsDatasetParameterSchema(MetricsParameterSchema, BaseDatasetParameterSchema):
+    # dataset parameter schema including fields specific to default metrics
+    pass
+
+
+class BaseLiteralParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to literal parameters shared across default and custom metrics
     parameter_type: Literal["literal"] = "literal"
     parameter_dtype: DType = Field(description="Data type of the parameter.")


-class MetricsColumnBaseParameterSchema(MetricsParameterSchema):
+class MetricsLiteralParameterSchema(MetricsParameterSchema, BaseLiteralParameterSchema):
+    # literal parameter schema including fields specific to default metrics
+    pass
+
+
+class BaseColumnBaseParameterSchema(BaseAggregationParameterSchema):
+    # fields specific to all single or multiple column parameters shared across default and custom metrics
     tag_hints: list[ScopeSchemaTag] = Field(
         [],
         description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
@@ -165,12 +183,20 @@ class MetricsColumnBaseParameterSchema(MetricsParameterSchema):
         return self


-class MetricsColumnParameterSchema(MetricsColumnBaseParameterSchema):
+class BaseColumnParameterSchema(BaseColumnBaseParameterSchema):
+    # single column parameter schema common across default and custom metrics
     parameter_type: Literal["column"] = "column"


-# Not used /implemented yet. Might turn into group by column list
-class MetricsColumnListParameterSchema(MetricsColumnBaseParameterSchema):
+class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
+    # single column parameter schema specific to default metrics
+    parameter_type: Literal["column"] = "column"
+
+
+class MetricsColumnListParameterSchema(
+    MetricsParameterSchema, BaseColumnParameterSchema
+):
+    # list column parameter schema specific to default metrics
     parameter_type: Literal["column_list"] = "column_list"

@@ -186,6 +212,11 @@ MetricsColumnSchemaUnion = (
 )


+CustomAggregationParametersSchemaUnion = (
+    BaseDatasetParameterSchema | BaseLiteralParameterSchema | BaseColumnParameterSchema
+)
+
+
 @dataclass
 class DatasetReference:
     dataset_name: str
@@ -193,6 +224,14 @@ class DatasetReference:
     dataset_id: UUID


+class BaseReportedAggregation(BaseModel):
+    # in future will be used by default metrics
+    metric_name: str = Field(description="Name of the reported aggregation metric.")
+    description: str = Field(
+        description="Description of the reported aggregation metric and what it aggregates.",
+    )
+
+
 class AggregationSpecSchema(BaseModel):
     name: str = Field(description="Name of the aggregation function.")
     id: UUID = Field(description="Unique identifier of the aggregation function.")
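BaseReportedAggregation is a plain pydantic model with two required string fields, so declaring a reported series is a one-liner. A minimal usage sketch; the field values here are illustrative, not taken from the package:

from arthur_common.models.metrics import BaseReportedAggregation

reported = BaseReportedAggregation(
    metric_name="inference_count",  # illustrative metric name
    description="Metric that counts the number of inferences per time window.",
)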
@@ -209,6 +248,17 @@ class AggregationSpecSchema(BaseModel):
     aggregate_args: list[MetricsParameterSchemaUnion] = Field(
         description="List of parameters to the aggregation's aggregate function.",
     )
+    reported_aggregations: list[BaseReportedAggregation] = Field(
+        description="List of aggregations reported by the metric."
+    )
+
+    @model_validator(mode="after")
+    def at_least_one_reported_agg(self) -> Self:
+        if len(self.reported_aggregations) < 1:
+            raise ValueError(
+                "Aggregation spec must specify at least one reported aggregation."
+            )
+        return self

     @model_validator(mode="after")
     def column_dataset_references_exist(self) -> Self:
@@ -229,3 +279,25 @@ class AggregationSpecSchema(BaseModel):
                     f"Column parameter '{param.parameter_key}' references dataset parameter '{param.source_dataset_parameter_key}' which does not exist.",
                 )
         return self
+
+
+class ReportedCustomAggregation(BaseReportedAggregation):
+    value_column: str = Field(
+        description="Name of the column returned from the SQL query holding the metric value."
+    )
+    timestamp_column: str = Field(
+        description="Name of the column returned from the SQL query holding the timestamp buckets."
+    )
+    metric_kind: AggregationMetricType = Field(
+        description="Return type of the reported aggregation metric value.",
+    )
+    dimension_columns: list[str] = Field(
+        description="Name of any dimension columns returned from the SQL query. Max length is 1."
+    )
+
+    @field_validator("dimension_columns")
+    @classmethod
+    def validate_dimension_columns_length(cls, v: list[str]) -> str:
+        if len(v) > 1:
+            raise ValueError("Only one dimension column can be specified.")
+        return v
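A short usage sketch of the new ReportedCustomAggregation model and its dimension_columns length validator; the field names come from this diff, while the concrete values are invented for illustration:

from pydantic import ValidationError

from arthur_common.models.metrics import (
    AggregationMetricType,
    ReportedCustomAggregation,
)

ok = ReportedCustomAggregation(
    metric_name="custom_null_rate",  # illustrative values throughout
    description="Null rate computed by a custom SQL aggregation.",
    value_column="null_rate",
    timestamp_column="ts",
    metric_kind=AggregationMetricType.NUMERIC,
    dimension_columns=["segment"],  # at most one dimension column is allowed
)

try:
    ReportedCustomAggregation(
        metric_name="too_many_dims",
        description="Rejected: more than one dimension column.",
        value_column="v",
        timestamp_column="ts",
        metric_kind=AggregationMetricType.NUMERIC,
        dimension_columns=["a", "b"],  # the field_validator raises here
    )
except ValidationError as err:
    print(err)  # "Only one dimension column can be specified."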
--- arthur_common-2.1.51/src/arthur_common/models/shield.py
+++ arthur_common-2.1.53/src/arthur_common/models/shield.py
@@ -10,9 +10,7 @@ DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD = 0

 class RuleType(str, Enum):
     KEYWORD = "KeywordRule"
-    MODEL_HALLUCINATION = "ModelHallucinationRule"
     MODEL_HALLUCINATION_V2 = "ModelHallucinationRuleV2"
-    MODEL_HALLUCINATION_V3 = "ModelHallucinationRuleV3"
     MODEL_SENSITIVE_DATA = "ModelSensitiveDataRule"
     PII_DATA = "PIIDataRule"
     PROMPT_INJECTION = "PromptInjectionRule"
@@ -456,14 +454,6 @@ class NewRuleRequest(BaseModel):
                 detail="PromptInjectionRule can only be enabled for prompt. Please set the 'apply_to_response' field "
                 "to false.",
             )
-        if (self.type == RuleType.MODEL_HALLUCINATION) and (
-            self.apply_to_prompt is True
-        ):
-            raise HTTPException(
-                status_code=400,
-                detail="ModelHallucinationRule can only be enabled for response. Please set the 'apply_to_prompt' "
-                "field to false.",
-            )
         if (self.type == RuleType.MODEL_HALLUCINATION_V2) and (
             self.apply_to_prompt is True
         ):
@@ -472,14 +462,6 @@ class NewRuleRequest(BaseModel):
                 detail="ModelHallucinationRuleV2 can only be enabled for response. Please set the 'apply_to_prompt' "
                 "field to false.",
             )
-        if (self.type == RuleType.MODEL_HALLUCINATION_V3) and (
-            self.apply_to_prompt is True
-        ):
-            raise HTTPException(
-                status_code=400,
-                detail="ModelHallucinationRuleV3 can only be enabled for response. Please set the "
-                "'apply_to_prompt' field to false.",
-            )
         if (self.apply_to_prompt is False) and (self.apply_to_response is False):
             raise HTTPException(
                 status_code=400,
--- arthur_common-2.1.51/src/arthur_common/tools/aggregation_analyzer.py
+++ arthur_common-2.1.53/src/arthur_common/tools/aggregation_analyzer.py
@@ -207,7 +207,7 @@ class FunctionAnalyzer:
         )
         # Check if X implements the required methods
         required_methods = ["aggregate", "id", "description", "display_name"]
-        static_methods = ["description", "id", "display_name"]
+        static_methods = ["description", "id", "display_name", "reported_aggregations"]
         for method in required_methods:
             if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
                 raise AttributeError(
@@ -253,6 +253,7 @@ class FunctionAnalyzer:
             metric_type=metric_type,
             init_args=aggregation_init_args,
             aggregate_args=aggregate_args,
+            reported_aggregations=agg_func.reported_aggregations(),
         )
