arthur-common 2.1.53__py3-none-any.whl → 2.1.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arthur-common might be problematic. Click here for more details.
- arthur_common/aggregations/functions/agentic_aggregations.py +875 -0
- arthur_common/aggregations/functions/categorical_count.py +2 -2
- arthur_common/aggregations/functions/confusion_matrix.py +1 -1
- arthur_common/aggregations/functions/inference_count.py +2 -2
- arthur_common/aggregations/functions/inference_count_by_class.py +3 -3
- arthur_common/aggregations/functions/inference_null_count.py +2 -2
- arthur_common/aggregations/functions/mean_absolute_error.py +3 -2
- arthur_common/aggregations/functions/mean_squared_error.py +3 -2
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +1 -1
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +2 -2
- arthur_common/aggregations/functions/numeric_stats.py +2 -2
- arthur_common/aggregations/functions/numeric_sum.py +2 -2
- arthur_common/aggregations/functions/shield_aggregations.py +14 -13
- arthur_common/models/datasets.py +1 -0
- arthur_common/models/metrics.py +7 -6
- arthur_common/models/schema_definitions.py +58 -0
- arthur_common/models/shield.py +153 -0
- arthur_common/models/task_job_specs.py +26 -2
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.55.dist-info}/METADATA +1 -1
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.55.dist-info}/RECORD +21 -20
- {arthur_common-2.1.53.dist-info → arthur_common-2.1.55.dist-info}/WHEEL +0 -0
|
@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.metrics import (
|
|
8
|
+
BaseReportedAggregation,
|
|
8
9
|
DatasetReference,
|
|
9
10
|
NumericMetric,
|
|
10
|
-
BaseReportedAggregation,
|
|
11
11
|
)
|
|
12
12
|
from arthur_common.models.schema_definitions import (
|
|
13
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -42,7 +42,7 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
|
|
|
42
42
|
BaseReportedAggregation(
|
|
43
43
|
metric_name=CategoricalCountAggregationFunction.METRIC_NAME,
|
|
44
44
|
description=CategoricalCountAggregationFunction.description(),
|
|
45
|
-
)
|
|
45
|
+
),
|
|
46
46
|
]
|
|
47
47
|
|
|
48
48
|
def aggregate(
|
|
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
8
|
from arthur_common.models.metrics import (
|
|
9
|
+
BaseReportedAggregation,
|
|
9
10
|
DatasetReference,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.metrics import (
|
|
8
|
+
BaseReportedAggregation,
|
|
8
9
|
DatasetReference,
|
|
9
10
|
NumericMetric,
|
|
10
|
-
BaseReportedAggregation,
|
|
11
11
|
)
|
|
12
12
|
from arthur_common.models.schema_definitions import (
|
|
13
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -42,7 +42,7 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
|
|
|
42
42
|
BaseReportedAggregation(
|
|
43
43
|
metric_name=InferenceCountAggregationFunction.METRIC_NAME,
|
|
44
44
|
description=InferenceCountAggregationFunction.description(),
|
|
45
|
-
)
|
|
45
|
+
),
|
|
46
46
|
]
|
|
47
47
|
|
|
48
48
|
def aggregate(
|
|
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
8
|
from arthur_common.models.metrics import (
|
|
9
|
+
BaseReportedAggregation,
|
|
9
10
|
DatasetReference,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -46,7 +46,7 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
|
|
|
46
46
|
BaseReportedAggregation(
|
|
47
47
|
metric_name=BinaryClassifierCountByClassAggregationFunction._metric_name(),
|
|
48
48
|
description=BinaryClassifierCountByClassAggregationFunction.description(),
|
|
49
|
-
)
|
|
49
|
+
),
|
|
50
50
|
]
|
|
51
51
|
|
|
52
52
|
def aggregate(
|
|
@@ -172,7 +172,7 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
|
|
|
172
172
|
BaseReportedAggregation(
|
|
173
173
|
metric_name=BinaryClassifierCountThresholdClassAggregationFunction._metric_name(),
|
|
174
174
|
description=BinaryClassifierCountThresholdClassAggregationFunction.description(),
|
|
175
|
-
)
|
|
175
|
+
),
|
|
176
176
|
]
|
|
177
177
|
|
|
178
178
|
def aggregate(
|
|
@@ -5,10 +5,10 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.metrics import (
|
|
8
|
+
BaseReportedAggregation,
|
|
8
9
|
DatasetReference,
|
|
9
10
|
Dimension,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -43,7 +43,7 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
|
|
|
43
43
|
BaseReportedAggregation(
|
|
44
44
|
metric_name=InferenceNullCountAggregationFunction.METRIC_NAME,
|
|
45
45
|
description=InferenceNullCountAggregationFunction.description(),
|
|
46
|
-
)
|
|
46
|
+
),
|
|
47
47
|
]
|
|
48
48
|
|
|
49
49
|
def aggregate(
|
|
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
8
|
from arthur_common.models.metrics import (
|
|
9
|
+
BaseReportedAggregation,
|
|
9
10
|
DatasetReference,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -159,7 +159,8 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
|
|
|
159
159
|
)
|
|
160
160
|
|
|
161
161
|
count_metric = self.series_to_metric(
|
|
162
|
-
self.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
|
|
162
|
+
self.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
|
|
163
|
+
count_series,
|
|
163
164
|
)
|
|
164
165
|
absolute_error_metric = self.series_to_metric(
|
|
165
166
|
self.ABSOLUTE_ERROR_SUM_METRIC_NAME,
|
|
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
8
|
from arthur_common.models.metrics import (
|
|
9
|
+
BaseReportedAggregation,
|
|
9
10
|
DatasetReference,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -159,7 +159,8 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
|
|
|
159
159
|
)
|
|
160
160
|
|
|
161
161
|
count_metric = self.series_to_metric(
|
|
162
|
-
self.SQUARED_ERROR_COUNT_METRIC_NAME,
|
|
162
|
+
self.SQUARED_ERROR_COUNT_METRIC_NAME,
|
|
163
|
+
count_series,
|
|
163
164
|
)
|
|
164
165
|
absolute_error_metric = self.series_to_metric(
|
|
165
166
|
self.SQUARED_ERROR_SUM_METRIC_NAME,
|
|
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.datasets import ModelProblemType
|
|
8
8
|
from arthur_common.models.metrics import (
|
|
9
|
+
BaseReportedAggregation,
|
|
9
10
|
DatasetReference,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -8,9 +8,9 @@ from arthur_common.aggregations.functions.inference_count_by_class import (
|
|
|
8
8
|
)
|
|
9
9
|
from arthur_common.models.datasets import ModelProblemType
|
|
10
10
|
from arthur_common.models.metrics import (
|
|
11
|
+
BaseReportedAggregation,
|
|
11
12
|
DatasetReference,
|
|
12
13
|
NumericMetric,
|
|
13
|
-
BaseReportedAggregation,
|
|
14
14
|
)
|
|
15
15
|
from arthur_common.models.schema_definitions import (
|
|
16
16
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -57,7 +57,7 @@ class MulticlassClassifierCountByClassAggregationFunction(
|
|
|
57
57
|
BaseReportedAggregation(
|
|
58
58
|
metric_name=MulticlassClassifierCountByClassAggregationFunction._metric_name(),
|
|
59
59
|
description=MulticlassClassifierCountByClassAggregationFunction.description(),
|
|
60
|
-
)
|
|
60
|
+
),
|
|
61
61
|
]
|
|
62
62
|
|
|
63
63
|
def aggregate(
|
|
@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import SketchAggregationFunction
|
|
7
7
|
from arthur_common.models.metrics import (
|
|
8
|
+
BaseReportedAggregation,
|
|
8
9
|
DatasetReference,
|
|
9
10
|
SketchMetric,
|
|
10
|
-
BaseReportedAggregation,
|
|
11
11
|
)
|
|
12
12
|
from arthur_common.models.schema_definitions import (
|
|
13
13
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -44,7 +44,7 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
|
|
|
44
44
|
BaseReportedAggregation(
|
|
45
45
|
metric_name=NumericSketchAggregationFunction.METRIC_NAME,
|
|
46
46
|
description=NumericSketchAggregationFunction.description(),
|
|
47
|
-
)
|
|
47
|
+
),
|
|
48
48
|
]
|
|
49
49
|
|
|
50
50
|
def aggregate(
|
|
@@ -5,10 +5,10 @@ from duckdb import DuckDBPyConnection
|
|
|
5
5
|
|
|
6
6
|
from arthur_common.aggregations.aggregator import NumericAggregationFunction
|
|
7
7
|
from arthur_common.models.metrics import (
|
|
8
|
+
BaseReportedAggregation,
|
|
8
9
|
DatasetReference,
|
|
9
10
|
Dimension,
|
|
10
11
|
NumericMetric,
|
|
11
|
-
BaseReportedAggregation,
|
|
12
12
|
)
|
|
13
13
|
from arthur_common.models.schema_definitions import (
|
|
14
14
|
SEGMENTATION_ALLOWED_COLUMN_TYPES,
|
|
@@ -43,7 +43,7 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
|
|
|
43
43
|
BaseReportedAggregation(
|
|
44
44
|
metric_name=NumericSumAggregationFunction.METRIC_NAME,
|
|
45
45
|
description=NumericSumAggregationFunction.description(),
|
|
46
|
-
)
|
|
46
|
+
),
|
|
47
47
|
]
|
|
48
48
|
|
|
49
49
|
def aggregate(
|
|
@@ -11,10 +11,10 @@ from arthur_common.aggregations.aggregator import (
|
|
|
11
11
|
)
|
|
12
12
|
from arthur_common.models.datasets import ModelProblemType
|
|
13
13
|
from arthur_common.models.metrics import (
|
|
14
|
+
BaseReportedAggregation,
|
|
14
15
|
DatasetReference,
|
|
15
16
|
NumericMetric,
|
|
16
17
|
SketchMetric,
|
|
17
|
-
BaseReportedAggregation,
|
|
18
18
|
)
|
|
19
19
|
from arthur_common.models.schema_definitions import (
|
|
20
20
|
SHIELD_RESPONSE_SCHEMA,
|
|
@@ -44,7 +44,7 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
|
|
|
44
44
|
BaseReportedAggregation(
|
|
45
45
|
metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
|
|
46
46
|
description=ShieldInferencePassFailCountAggregation.description(),
|
|
47
|
-
)
|
|
47
|
+
),
|
|
48
48
|
]
|
|
49
49
|
|
|
50
50
|
def aggregate(
|
|
@@ -113,7 +113,7 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
|
|
|
113
113
|
BaseReportedAggregation(
|
|
114
114
|
metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
|
|
115
115
|
description=ShieldInferenceRuleCountAggregation.description(),
|
|
116
|
-
)
|
|
116
|
+
),
|
|
117
117
|
]
|
|
118
118
|
|
|
119
119
|
def aggregate(
|
|
@@ -205,7 +205,7 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
|
|
|
205
205
|
BaseReportedAggregation(
|
|
206
206
|
metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
|
|
207
207
|
description=ShieldInferenceHallucinationCountAggregation.description(),
|
|
208
|
-
)
|
|
208
|
+
),
|
|
209
209
|
]
|
|
210
210
|
|
|
211
211
|
def aggregate(
|
|
@@ -269,7 +269,7 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
|
|
|
269
269
|
BaseReportedAggregation(
|
|
270
270
|
metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
|
|
271
271
|
description=ShieldInferenceRuleToxicityScoreAggregation.description(),
|
|
272
|
-
)
|
|
272
|
+
),
|
|
273
273
|
]
|
|
274
274
|
|
|
275
275
|
def aggregate(
|
|
@@ -354,7 +354,7 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
|
|
|
354
354
|
BaseReportedAggregation(
|
|
355
355
|
metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
|
|
356
356
|
description=ShieldInferenceRulePIIDataScoreAggregation.description(),
|
|
357
|
-
)
|
|
357
|
+
),
|
|
358
358
|
]
|
|
359
359
|
|
|
360
360
|
def aggregate(
|
|
@@ -445,7 +445,7 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
|
|
|
445
445
|
BaseReportedAggregation(
|
|
446
446
|
metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
|
|
447
447
|
description=ShieldInferenceRuleClaimCountAggregation.description(),
|
|
448
|
-
)
|
|
448
|
+
),
|
|
449
449
|
]
|
|
450
450
|
|
|
451
451
|
def aggregate(
|
|
@@ -518,7 +518,7 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
|
|
|
518
518
|
BaseReportedAggregation(
|
|
519
519
|
metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
|
|
520
520
|
description=ShieldInferenceRuleClaimPassCountAggregation.description(),
|
|
521
|
-
)
|
|
521
|
+
),
|
|
522
522
|
]
|
|
523
523
|
|
|
524
524
|
def aggregate(
|
|
@@ -591,7 +591,7 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
|
|
|
591
591
|
BaseReportedAggregation(
|
|
592
592
|
metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
|
|
593
593
|
description=ShieldInferenceRuleClaimFailCountAggregation.description(),
|
|
594
|
-
)
|
|
594
|
+
),
|
|
595
595
|
]
|
|
596
596
|
|
|
597
597
|
def aggregate(
|
|
@@ -664,7 +664,7 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
|
664
664
|
BaseReportedAggregation(
|
|
665
665
|
metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
|
|
666
666
|
description=ShieldInferenceRuleLatencyAggregation.description(),
|
|
667
|
-
)
|
|
667
|
+
),
|
|
668
668
|
]
|
|
669
669
|
|
|
670
670
|
def aggregate(
|
|
@@ -768,7 +768,7 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
768
768
|
return [base_token_count_agg] + [
|
|
769
769
|
BaseReportedAggregation(
|
|
770
770
|
metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
|
|
771
|
-
model_name
|
|
771
|
+
model_name,
|
|
772
772
|
),
|
|
773
773
|
description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
|
|
774
774
|
)
|
|
@@ -855,7 +855,8 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
855
855
|
)
|
|
856
856
|
resp.append(
|
|
857
857
|
self.series_to_metric(
|
|
858
|
-
self._series_name_from_model_name(model),
|
|
859
|
-
|
|
858
|
+
self._series_name_from_model_name(model),
|
|
859
|
+
model_series,
|
|
860
|
+
),
|
|
860
861
|
)
|
|
861
862
|
return resp
|
arthur_common/models/datasets.py
CHANGED
arthur_common/models/metrics.py
CHANGED
|
@@ -194,7 +194,8 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSc
|
|
|
194
194
|
|
|
195
195
|
|
|
196
196
|
class MetricsColumnListParameterSchema(
|
|
197
|
-
MetricsParameterSchema,
|
|
197
|
+
MetricsParameterSchema,
|
|
198
|
+
BaseColumnParameterSchema,
|
|
198
199
|
):
|
|
199
200
|
# list column parameter schema specific to default metrics
|
|
200
201
|
parameter_type: Literal["column_list"] = "column_list"
|
|
@@ -249,14 +250,14 @@ class AggregationSpecSchema(BaseModel):
|
|
|
249
250
|
description="List of parameters to the aggregation's aggregate function.",
|
|
250
251
|
)
|
|
251
252
|
reported_aggregations: list[BaseReportedAggregation] = Field(
|
|
252
|
-
description="List of aggregations reported by the metric."
|
|
253
|
+
description="List of aggregations reported by the metric.",
|
|
253
254
|
)
|
|
254
255
|
|
|
255
256
|
@model_validator(mode="after")
|
|
256
257
|
def at_least_one_reported_agg(self) -> Self:
|
|
257
258
|
if len(self.reported_aggregations) < 1:
|
|
258
259
|
raise ValueError(
|
|
259
|
-
"Aggregation spec must specify at least one reported aggregation."
|
|
260
|
+
"Aggregation spec must specify at least one reported aggregation.",
|
|
260
261
|
)
|
|
261
262
|
return self
|
|
262
263
|
|
|
@@ -283,16 +284,16 @@ class AggregationSpecSchema(BaseModel):
|
|
|
283
284
|
|
|
284
285
|
class ReportedCustomAggregation(BaseReportedAggregation):
|
|
285
286
|
value_column: str = Field(
|
|
286
|
-
description="Name of the column returned from the SQL query holding the metric value."
|
|
287
|
+
description="Name of the column returned from the SQL query holding the metric value.",
|
|
287
288
|
)
|
|
288
289
|
timestamp_column: str = Field(
|
|
289
|
-
description="Name of the column returned from the SQL query holding the timestamp buckets."
|
|
290
|
+
description="Name of the column returned from the SQL query holding the timestamp buckets.",
|
|
290
291
|
)
|
|
291
292
|
metric_kind: AggregationMetricType = Field(
|
|
292
293
|
description="Return type of the reported aggregation metric value.",
|
|
293
294
|
)
|
|
294
295
|
dimension_columns: list[str] = Field(
|
|
295
|
-
description="Name of any dimension columns returned from the SQL query. Max length is 1."
|
|
296
|
+
description="Name of any dimension columns returned from the SQL query. Max length is 1.",
|
|
296
297
|
)
|
|
297
298
|
|
|
298
299
|
@field_validator("dimension_columns")
|
|
@@ -367,6 +367,38 @@ def create_shield_inference_feedback_schema() -> DatasetListType:
|
|
|
367
367
|
)
|
|
368
368
|
|
|
369
369
|
|
|
370
|
+
def AGENTIC_TRACE_SCHEMA() -> DatasetSchema:
    """Build the dataset schema describing one agentic trace per row.

    Columns, in order: ``trace_id`` (string), ``start_time`` and ``end_time``
    (timestamps), and ``root_spans`` (a list of JSON values). Every call
    returns a new schema whose columns carry freshly generated ``uuid4`` ids.
    """
    # root_spans holds JSON blobs so the span hierarchy is preserved without
    # the schema having to spell out each nesting level.
    column_specs = [
        ("trace_id", create_dataset_scalar_type(DType.STRING)),
        ("start_time", create_dataset_scalar_type(DType.TIMESTAMP)),
        ("end_time", create_dataset_scalar_type(DType.TIMESTAMP)),
        (
            "root_spans",
            create_dataset_list_type(create_dataset_scalar_type(DType.JSON)),
        ),
    ]
    return DatasetSchema(
        alias_mask={},
        columns=[
            DatasetColumn(id=uuid4(), source_name=column_name, definition=column_type)
            for column_name, column_type in column_specs
        ],
    )
|
|
400
|
+
|
|
401
|
+
|
|
370
402
|
def SHIELD_SCHEMA() -> DatasetSchema:
|
|
371
403
|
return DatasetSchema(
|
|
372
404
|
alias_mask={},
|
|
@@ -423,6 +455,32 @@ def SHIELD_SCHEMA() -> DatasetSchema:
|
|
|
423
455
|
SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
|
|
424
456
|
SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()
|
|
425
457
|
|
|
458
|
+
|
|
459
|
+
def create_agentic_trace_response_schema() -> DatasetObjectType:
    """Describe the API response payload that carries agentic traces.

    The payload is an object with an integer ``count`` and a ``traces`` list.
    Each trace object has ``trace_id`` (string), ``start_time`` and
    ``end_time`` (timestamps), and ``root_spans`` (a list of JSON values).
    """
    # Spans are kept as opaque JSON blobs so span trees of any depth fit.
    root_spans_type = create_dataset_list_type(
        create_dataset_scalar_type(DType.JSON),
    )
    single_trace_type = create_dataset_object_type(
        {
            "trace_id": create_dataset_scalar_type(DType.STRING),
            "start_time": create_dataset_scalar_type(DType.TIMESTAMP),
            "end_time": create_dataset_scalar_type(DType.TIMESTAMP),
            "root_spans": root_spans_type,
        },
    )
    return create_dataset_object_type(
        {
            "count": create_dataset_scalar_type(DType.INT),
            "traces": create_dataset_list_type(single_trace_type),
        },
    )


# Base-type form of the agentic trace response schema, built once at import.
AGENTIC_TRACE_RESPONSE_SCHEMA = create_agentic_trace_response_schema().to_base_type()
|
|
483
|
+
|
|
426
484
|
SEGMENTATION_ALLOWED_DTYPES = [DType.INT, DType.BOOL, DType.STRING, DType.UUID]
|
|
427
485
|
SEGMENTATION_ALLOWED_COLUMN_TYPES = [
|
|
428
486
|
ScalarType(dtype=d_type) for d_type in SEGMENTATION_ALLOWED_DTYPES
|
arthur_common/models/shield.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from datetime import datetime
|
|
1
2
|
from enum import Enum
|
|
2
3
|
from typing import Any, Dict, List, Optional, Self, Type, Union
|
|
3
4
|
|
|
@@ -26,6 +27,15 @@ class RuleScope(str, Enum):
|
|
|
26
27
|
TASK = "task"
|
|
27
28
|
|
|
28
29
|
|
|
30
|
+
class MetricType(str, Enum):
    """Kinds of metric that can be attached to a task.

    Members are also ``str`` instances. ``str(member)`` yields the raw value
    (e.g. ``"QueryRelevance"``) instead of ``"MetricType.QUERY_RELEVANCE"``.
    """

    QUERY_RELEVANCE = "QueryRelevance"
    RESPONSE_RELEVANCE = "ResponseRelevance"
    TOOL_SELECTION = "ToolSelection"

    def __str__(self) -> str:
        # Render as the bare value so members read cleanly inside messages.
        return self.value
|
|
37
|
+
|
|
38
|
+
|
|
29
39
|
class BaseEnum(str, Enum):
|
|
30
40
|
@classmethod
|
|
31
41
|
def values(cls) -> list[Any]:
|
|
@@ -240,6 +250,27 @@ class RuleResponse(BaseModel):
|
|
|
240
250
|
)
|
|
241
251
|
|
|
242
252
|
|
|
253
|
+
class MetricResponse(BaseModel):
    """API representation of a metric configured on a task."""

    id: str = Field(description="ID of the Metric")
    name: str = Field(description="Name of the Metric")
    type: MetricType = Field(description="Type of the Metric")
    metric_metadata: str = Field(description="Metadata of the Metric")
    # Carried as a JSON string, not a parsed object.
    config: Optional[str] = Field(
        description="JSON-serialized configuration for the Metric",
        default=None,
    )
    # These fields are typed ``datetime`` (not integer unix milliseconds), so
    # the descriptions must not advertise a unix-millisecond encoding.
    created_at: datetime = Field(
        description="Time the Metric was created",
    )
    updated_at: datetime = Field(
        description="Time the Metric was last updated",
    )
    enabled: Optional[bool] = Field(
        description="Whether the Metric is enabled",
        default=None,
    )
|
|
272
|
+
|
|
273
|
+
|
|
243
274
|
class TaskResponse(BaseModel):
|
|
244
275
|
id: str = Field(description=" ID of the task")
|
|
245
276
|
name: str = Field(description="Name of the task")
|
|
@@ -249,7 +280,12 @@ class TaskResponse(BaseModel):
|
|
|
249
280
|
updated_at: int = Field(
|
|
250
281
|
description="Time the task was created in unix milliseconds",
|
|
251
282
|
)
|
|
283
|
+
is_agentic: bool = Field(description="Whether the task is agentic or not")
|
|
252
284
|
rules: List[RuleResponse] = Field(description="List of all the rules for the task.")
|
|
285
|
+
metrics: Optional[List[MetricResponse]] = Field(
|
|
286
|
+
description="List of all the metrics for the task.",
|
|
287
|
+
default=None,
|
|
288
|
+
)
|
|
253
289
|
|
|
254
290
|
|
|
255
291
|
class UpdateRuleRequest(BaseModel):
|
|
@@ -484,3 +520,120 @@ class NewRuleRequest(BaseModel):
|
|
|
484
520
|
detail="Examples must be provided to onboard a ModelSensitiveDataRule",
|
|
485
521
|
)
|
|
486
522
|
return self
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
class RelevanceMetricConfig(BaseModel):
    """Settings shared by the QueryRelevance and ResponseRelevance metrics.

    Relevance is scored either by an LLM judge (enabled by default) or, when
    the judge is turned off, against ``relevance_threshold``.
    """

    # Only consulted when the LLM judge is disabled.
    relevance_threshold: Optional[float] = Field(
        description="Threshold for determining relevance when not using LLM judge",
        default=None,
    )
    use_llm_judge: bool = Field(
        description="Whether to use LLM as a judge for relevance scoring",
        default=True,
    )
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class NewMetricRequest(BaseModel):
    """Request body for creating a new metric on a task."""

    type: MetricType = Field(
        description="Type of the metric. It can only be one of QueryRelevance, ResponseRelevance, ToolSelection",
        # Must be a real MetricType value so the published example validates.
        examples=["QueryRelevance"],
    )
    name: str = Field(
        description="Name of metric",
        examples=["My User Query Relevance"],
    )
    metric_metadata: str = Field(description="Additional metadata for the metric")
    config: Optional[RelevanceMetricConfig] = Field(
        description="Configuration for the metric. Currently only applies to QueryRelevance and ResponseRelevance metric types.",
        default=None,
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example1": {
                "type": "QueryRelevance",
                "name": "My User Query Relevance",
                "metric_metadata": "This is a test metric metadata",
            },
            "example2": {
                "type": "QueryRelevance",
                "name": "My User Query Relevance with Config",
                "metric_metadata": "This is a test metric metadata",
                "config": {"relevance_threshold": 0.8, "use_llm_judge": False},
            },
            "example3": {
                "type": "ResponseRelevance",
                "name": "My Response Relevance",
                "metric_metadata": "This is a test metric metadata",
                "config": {"use_llm_judge": True},
            },
        },
    )

    @field_validator("type", mode="before")
    def validate_metric_type(cls, value):
        """Coerce ``value`` to a MetricType, with a readable error for bad input.

        This runs *before* pydantic's enum coercion. As an after-validator it
        would only ever see already-valid members, so its message could never
        be reached.

        Raises:
            ValueError: if ``value`` is not one of the MetricType values.
        """
        try:
            return MetricType(value)
        except ValueError:
            raise ValueError(
                f"Invalid metric type: {value}. Valid types are: {', '.join([t.value for t in MetricType])}",
            ) from None

    @model_validator(mode="before")
    def set_config_type(cls, values):
        """Normalize ``config`` into the config class for the metric type.

        For QueryRelevance / ResponseRelevance:
          - a missing config becomes ``{"use_llm_judge": True}``;
          - ``relevance_threshold`` together with a truthy ``use_llm_judge``
            is rejected;
          - a threshold without ``use_llm_judge`` implies ``use_llm_judge=False``;
          - with no threshold, ``use_llm_judge`` is forced to True, even when
            it was explicitly False.
        These rules apply to a config given as a dict or as a model instance.
        For any other metric type, supplying a config is rejected.

        The input mapping and any nested config dict are copied, never
        mutated in place.

        Raises:
            HTTPException: status 400 for the rejected combinations above.
        """
        if not isinstance(values, dict):
            return values

        metric_type = values.get("type")
        config_values = values.get("config")

        # Map metric types to their corresponding config classes. MetricType is
        # a str enum, so a raw string such as "QueryRelevance" finds the same
        # entry as the enum member.
        metric_type_to_config = {
            MetricType.QUERY_RELEVANCE: RelevanceMetricConfig,
            MetricType.RESPONSE_RELEVANCE: RelevanceMetricConfig,
            # Add new metric types and their configs here as needed
        }

        config_class = metric_type_to_config.get(metric_type)

        if config_class is None:
            if config_values is not None:
                # Provide a nice error message listing supported metric types
                supported_types = [t.value for t in metric_type_to_config]
                raise HTTPException(
                    status_code=400,
                    detail=f"Config is only supported for {', '.join(supported_types)} metric types",
                    headers={"full_stacktrace": "false"},
                )
            return values

        if isinstance(config_values, BaseModel):
            # Keep only the fields the caller set. A model instance then goes
            # through the same rules as the equivalent dict instead of
            # bypassing them.
            config_values = config_values.model_dump(exclude_unset=True)

        if config_values is None:
            # Default config when none is provided
            config_values = {"use_llm_judge": True}
        elif isinstance(config_values, dict):
            # Copy so the caller's dict is not modified.
            config_values = dict(config_values)
            relevance_threshold = config_values.get("relevance_threshold")
            use_llm_judge = config_values.get("use_llm_judge")

            # Handle mutually exclusive parameters
            if relevance_threshold is not None and use_llm_judge:
                raise HTTPException(
                    status_code=400,
                    detail="relevance_threshold and use_llm_judge=true are mutually exclusive. Set use_llm_judge=false when using relevance_threshold.",
                    headers={"full_stacktrace": "false"},
                )

            # A threshold without an explicit use_llm_judge means threshold scoring.
            if relevance_threshold is not None and use_llm_judge is None:
                config_values["use_llm_judge"] = False

            # Without a threshold the LLM judge is the only scoring method, so
            # it is enabled whether use_llm_judge was unset or False.
            if relevance_threshold is None and use_llm_judge in (None, False):
                config_values["use_llm_judge"] = True

        values = dict(values)
        values["config"] = config_class(**config_values)
        return values
|