arthur-common 2.1.52__py3-none-any.whl → 2.1.54__py3-none-any.whl
This diff shows the changes between these publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
This version of arthur-common has been flagged as potentially problematic.
- arthur_common/aggregations/aggregator.py +6 -0
- arthur_common/aggregations/functions/agentic_aggregations.py +875 -0
- arthur_common/aggregations/functions/categorical_count.py +14 -1
- arthur_common/aggregations/functions/confusion_matrix.py +35 -5
- arthur_common/aggregations/functions/inference_count.py +14 -1
- arthur_common/aggregations/functions/inference_count_by_class.py +23 -1
- arthur_common/aggregations/functions/inference_null_count.py +15 -1
- arthur_common/aggregations/functions/mean_absolute_error.py +26 -3
- arthur_common/aggregations/functions/mean_squared_error.py +26 -3
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +43 -5
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +14 -1
- arthur_common/aggregations/functions/numeric_stats.py +14 -1
- arthur_common/aggregations/functions/numeric_sum.py +15 -1
- arthur_common/aggregations/functions/shield_aggregations.py +127 -16
- arthur_common/models/datasets.py +1 -0
- arthur_common/models/metrics.py +35 -18
- arthur_common/models/schema_definitions.py +58 -0
- arthur_common/models/shield.py +158 -0
- arthur_common/models/task_job_specs.py +26 -2
- arthur_common/tools/aggregation_analyzer.py +2 -1
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.54.dist-info}/METADATA +1 -1
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.54.dist-info}/RECORD +23 -22
- {arthur_common-2.1.52.dist-info → arthur_common-2.1.54.dist-info}/WHEEL +0 -0
arthur_common/aggregations/functions/shield_aggregations.py CHANGED

@@ -10,7 +10,12 @@ from arthur_common.aggregations.aggregator import (
     SketchAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import
+from arthur_common.models.metrics import (
+    BaseReportedAggregation,
+    DatasetReference,
+    NumericMetric,
+    SketchMetric,
+)
 from arthur_common.models.schema_definitions import (
     SHIELD_RESPONSE_SCHEMA,
     MetricColumnParameterAnnotation,
@@ -33,6 +38,15 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
+                description=ShieldInferencePassFailCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -93,6 +107,15 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -176,6 +199,15 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield hallucination evaluations that failed."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
+                description=ShieldInferenceHallucinationCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -231,6 +263,15 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleToxicityScoreAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -307,6 +348,15 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRulePIIDataScoreAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -389,6 +439,15 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on over the number of claims identified by the Shield hallucination rule."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -453,6 +512,15 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimPassCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -517,6 +585,15 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimFailCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -581,6 +658,15 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."
 
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleLatencyAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -643,6 +729,18 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
 
 class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "token_count"
+    SUPPORTED_MODELS = [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "gpt-3.5-turbo",
+        "o1-mini",
+        "deepseek-chat",
+        "claude-3-5-sonnet-20241022",
+        "gemini/gemini-1.5-pro",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-2-11b-instruct-v1:0",
+    ]
 
     @staticmethod
     def id() -> UUID:
@@ -656,6 +754,27 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."
 
+    @staticmethod
+    def _series_name_from_model_name(model_name: str) -> str:
+        """Calculates name of reported series based on the model name considered."""
+        return f"token_cost.{model_name}"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        base_token_count_agg = BaseReportedAggregation(
+            metric_name=ShieldInferenceTokenCountAggregation.METRIC_NAME,
+            description=f"Metric that reports the number of tokens in the Shield response and prompt schemas.",
+        )
+        return [base_token_count_agg] + [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
+                    model_name,
+                ),
+                description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
+            )
+            for model_name in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -708,25 +827,12 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
         resp = [metric]
 
         # Compute Cost for each model
-        models = [
-            "gpt-4o",
-            "gpt-4o-mini",
-            "gpt-3.5-turbo",
-            "o1-mini",
-            "deepseek-chat",
-            "claude-3-5-sonnet-20241022",
-            "gemini/gemini-1.5-pro",
-            "meta.llama3-1-8b-instruct-v1:0",
-            "meta.llama3-1-70b-instruct-v1:0",
-            "meta.llama3-2-11b-instruct-v1:0",
-        ]
-
         # Precompute input/output classification to avoid recalculating in loop
         location_type = results["location"].apply(
             lambda x: "input" if x == "prompt" else "output",
         )
 
-        for model in
+        for model in self.SUPPORTED_MODELS:
             # Efficient list comprehension instead of apply
             cost_values = [
                 calculate_cost_by_tokens(int(tokens), model, loc_type)
@@ -747,5 +853,10 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
                 ["location"],
                 "ts",
             )
-            resp.append(
+            resp.append(
+                self.series_to_metric(
+                    self._series_name_from_model_name(model),
+                    model_series,
+                ),
+            )
         return resp
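The hunks above all apply one pattern: each Shield aggregation gains a static `reported_aggregations()` hook that declares up front which metric series it emits. A minimal sketch of that contract, assuming a simplified `BaseReportedAggregation`; the `ExampleCountAggregation` class is hypothetical, not from the package:

```python
from pydantic import BaseModel, Field


class BaseReportedAggregation(BaseModel):
    # Simplified stand-in for arthur_common.models.metrics.BaseReportedAggregation.
    metric_name: str = Field(description="Name of the reported aggregation metric.")
    description: str = Field(
        description="Description of the reported aggregation metric and what it aggregates.",
    )


class ExampleCountAggregation:
    # Hypothetical aggregation following the pattern added in this release.
    METRIC_NAME = "example_count"

    @staticmethod
    def description() -> str:
        return "Counts example inferences."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        # Single-series aggregations report exactly one entry.
        return [
            BaseReportedAggregation(
                metric_name=ExampleCountAggregation.METRIC_NAME,
                description=ExampleCountAggregation.description(),
            ),
        ]
```

`ShieldInferenceTokenCountAggregation` is the multi-series case: it reports `token_count` plus one `token_cost.<model>` series per entry in `SUPPORTED_MODELS`, e.g. `token_cost.gpt-4o`.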
arthur_common/models/datasets.py CHANGED

arthur_common/models/metrics.py CHANGED
@@ -193,7 +193,10 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
     parameter_type: Literal["column"] = "column"
 
 
-class MetricsColumnListParameterSchema(
+class MetricsColumnListParameterSchema(
+    MetricsParameterSchema,
+    BaseColumnParameterSchema,
+):
     # list column parameter schema specific to default metrics
     parameter_type: Literal["column_list"] = "column_list"
 
@@ -211,9 +214,7 @@ MetricsColumnSchemaUnion = (
 
 
 CustomAggregationParametersSchemaUnion = (
-    BaseDatasetParameterSchema
-    | BaseLiteralParameterSchema
-    | BaseColumnParameterSchema
+    BaseDatasetParameterSchema | BaseLiteralParameterSchema | BaseColumnParameterSchema
 )
 
 
@@ -224,6 +225,14 @@ class DatasetReference:
     dataset_id: UUID
 
 
+class BaseReportedAggregation(BaseModel):
+    # in future will be used by default metrics
+    metric_name: str = Field(description="Name of the reported aggregation metric.")
+    description: str = Field(
+        description="Description of the reported aggregation metric and what it aggregates.",
+    )
+
+
 class AggregationSpecSchema(BaseModel):
     name: str = Field(description="Name of the aggregation function.")
     id: UUID = Field(description="Unique identifier of the aggregation function.")
@@ -240,6 +249,17 @@ class AggregationSpecSchema(BaseModel):
     aggregate_args: list[MetricsParameterSchemaUnion] = Field(
         description="List of parameters to the aggregation's aggregate function.",
     )
+    reported_aggregations: list[BaseReportedAggregation] = Field(
+        description="List of aggregations reported by the metric.",
+    )
+
+    @model_validator(mode="after")
+    def at_least_one_reported_agg(self) -> Self:
+        if len(self.reported_aggregations) < 1:
+            raise ValueError(
+                "Aggregation spec must specify at least one reported aggregation.",
+            )
+        return self
 
     @model_validator(mode="after")
     def column_dataset_references_exist(self) -> Self:
@@ -262,26 +282,23 @@ class AggregationSpecSchema(BaseModel):
         return self
 
 
-class BaseReportedAggregation(BaseModel):
-    # in future will be used by default metrics
-    metric_name: str = Field(description="Name of the reported aggregation metric.")
-    description: str = Field(
-        description="Description of the reported aggregation metric and what it aggregates.",
-    )
-
-
 class ReportedCustomAggregation(BaseReportedAggregation):
-    value_column: str = Field(
-
+    value_column: str = Field(
+        description="Name of the column returned from the SQL query holding the metric value.",
+    )
+    timestamp_column: str = Field(
+        description="Name of the column returned from the SQL query holding the timestamp buckets.",
+    )
     metric_kind: AggregationMetricType = Field(
         description="Return type of the reported aggregation metric value.",
     )
-    dimension_columns: list[str] = Field(
+    dimension_columns: list[str] = Field(
+        description="Name of any dimension columns returned from the SQL query. Max length is 1.",
+    )
 
-    @field_validator(
+    @field_validator("dimension_columns")
     @classmethod
     def validate_dimension_columns_length(cls, v: list[str]) -> str:
         if len(v) > 1:
-            raise ValueError(
+            raise ValueError("Only one dimension column can be specified.")
         return v
-
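How the relocated models compose now that `BaseReportedAggregation` is defined before `AggregationSpecSchema`: a sketch reduced to the new field only, so the spec class below is a simplified stand-in rather than the real `AggregationSpecSchema`:

```python
from typing import Self

from pydantic import BaseModel, Field, model_validator


class BaseReportedAggregation(BaseModel):
    metric_name: str = Field(description="Name of the reported aggregation metric.")
    description: str = Field(
        description="Description of the reported aggregation metric and what it aggregates.",
    )


class SimplifiedAggregationSpec(BaseModel):
    # Stand-in keeping only the field added in this release.
    reported_aggregations: list[BaseReportedAggregation]

    @model_validator(mode="after")
    def at_least_one_reported_agg(self) -> Self:
        if len(self.reported_aggregations) < 1:
            raise ValueError(
                "Aggregation spec must specify at least one reported aggregation.",
            )
        return self


try:
    SimplifiedAggregationSpec(reported_aggregations=[])
except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
    print(exc)  # ...must specify at least one reported aggregation.
```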
arthur_common/models/schema_definitions.py CHANGED

@@ -367,6 +367,38 @@ def create_shield_inference_feedback_schema() -> DatasetListType:
     )
 
 
+def AGENTIC_TRACE_SCHEMA() -> DatasetSchema:
+    return DatasetSchema(
+        alias_mask={},
+        columns=[
+            DatasetColumn(
+                id=uuid4(),
+                source_name="trace_id",
+                definition=create_dataset_scalar_type(DType.STRING),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="start_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="end_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="root_spans",
+                definition=create_dataset_list_type(
+                    create_dataset_scalar_type(
+                        DType.JSON,
+                    ),  # JSON blob to preserve hierarchy
+                ),
+            ),
+        ],
+    )
+
+
 def SHIELD_SCHEMA() -> DatasetSchema:
     return DatasetSchema(
         alias_mask={},
@@ -423,6 +455,32 @@ def SHIELD_SCHEMA() -> DatasetSchema:
 SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
 SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()
 
+
+# Agentic trace schema base type for API responses
+def create_agentic_trace_response_schema() -> DatasetObjectType:
+    return create_dataset_object_type(
+        {
+            "count": create_dataset_scalar_type(DType.INT),
+            "traces": create_dataset_list_type(
+                create_dataset_object_type(
+                    {
+                        "trace_id": create_dataset_scalar_type(DType.STRING),
+                        "start_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "end_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "root_spans": create_dataset_list_type(
+                            create_dataset_scalar_type(
+                                DType.JSON,
+                            ),  # JSON blob for infinite depth
+                        ),
+                    },
+                ),
+            ),
+        },
+    )
+
+
+AGENTIC_TRACE_RESPONSE_SCHEMA = create_agentic_trace_response_schema().to_base_type()
+
 SEGMENTATION_ALLOWED_DTYPES = [DType.INT, DType.BOOL, DType.STRING, DType.UUID]
 SEGMENTATION_ALLOWED_COLUMN_TYPES = [
     ScalarType(dtype=d_type) for d_type in SEGMENTATION_ALLOWED_DTYPES
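For orientation, a record shaped like the new agentic trace dataset might look as follows. The keys inside each `root_spans` entry are illustrative assumptions: the schema deliberately stores spans as opaque JSON so arbitrarily deep span trees survive without a recursive column type:

```python
# Illustrative trace record matching the AGENTIC_TRACE_SCHEMA columns (values made up).
example_trace = {
    "trace_id": "trace-123",
    "start_time": "2025-01-01T00:00:00Z",
    "end_time": "2025-01-01T00:00:02Z",
    "root_spans": [  # each element is an opaque JSON blob preserving hierarchy
        {"span_id": "span-1", "children": [{"span_id": "span-2", "children": []}]},
    ],
}

# The wrapper shape mirrors create_agentic_trace_response_schema().
example_response = {"count": 1, "traces": [example_trace]}
```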
arthur_common/models/shield.py CHANGED
@@ -1,3 +1,4 @@
+from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Self, Type, Union
 
@@ -26,6 +27,15 @@ class RuleScope(str, Enum):
     TASK = "task"
 
 
+class MetricType(str, Enum):
+    QUERY_RELEVANCE = "QueryRelevance"
+    RESPONSE_RELEVANCE = "ResponseRelevance"
+    TOOL_SELECTION = "ToolSelection"
+
+    def __str__(self):
+        return self.value
+
+
 class BaseEnum(str, Enum):
     @classmethod
     def values(cls) -> list[Any]:
@@ -240,6 +250,27 @@ class RuleResponse(BaseModel):
     )
 
 
+class MetricResponse(BaseModel):
+    id: str = Field(description="ID of the Metric")
+    name: str = Field(description="Name of the Metric")
+    type: MetricType = Field(description="Type of the Metric")
+    metric_metadata: str = Field(description="Metadata of the Metric")
+    config: Optional[str] = Field(
+        description="JSON-serialized configuration for the Metric",
+        default=None,
+    )
+    created_at: datetime = Field(
+        description="Time the Metric was created in unix milliseconds",
+    )
+    updated_at: datetime = Field(
+        description="Time the Metric was updated in unix milliseconds",
+    )
+    enabled: Optional[bool] = Field(
+        description="Whether the Metric is enabled",
+        default=None,
+    )
+
+
 class TaskResponse(BaseModel):
     id: str = Field(description=" ID of the task")
     name: str = Field(description="Name of the task")
@@ -249,7 +280,12 @@ class TaskResponse(BaseModel):
     updated_at: int = Field(
         description="Time the task was created in unix milliseconds",
     )
+    is_agentic: bool = Field(description="Whether the task is agentic or not")
     rules: List[RuleResponse] = Field(description="List of all the rules for the task.")
+    metrics: Optional[List[MetricResponse]] = Field(
+        description="List of all the metrics for the task.",
+        default=None,
+    )
 
 
 class UpdateRuleRequest(BaseModel):
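Serialized, an agentic task response carrying the new fields might look roughly like this; all values are illustrative, and note that `MetricResponse.config` travels as a JSON-encoded string per its field description:

```python
# Illustrative TaskResponse payload (values made up; rules elided).
example_task = {
    "id": "task-abc",
    "name": "support-agent",
    "updated_at": 1735689600000,  # unix milliseconds, per the field description
    "is_agentic": True,
    "rules": [],
    "metrics": [
        {
            "id": "metric-1",
            "name": "My User Query Relevance",
            "type": "QueryRelevance",
            "metric_metadata": "This is a test metric metadata",
            "config": '{"relevance_threshold": 0.8, "use_llm_judge": false}',
            "created_at": "2025-01-01T00:00:00Z",
            "updated_at": "2025-01-01T00:00:00Z",
            "enabled": True,
        },
    ],
}
```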
@@ -484,3 +520,125 @@ class NewRuleRequest(BaseModel):
             detail="Examples must be provided to onboard a ModelSensitiveDataRule",
         )
         return self
+
+
+class RelevanceMetricConfig(BaseModel):
+    """Configuration for relevance metrics including QueryRelevance and ResponseRelevance"""
+
+    relevance_threshold: Optional[float] = Field(
+        default=None,
+        description="Threshold for determining relevance when not using LLM judge",
+    )
+    use_llm_judge: bool = Field(
+        default=True,
+        description="Whether to use LLM as a judge for relevance scoring",
+    )
+
+
+class NewMetricRequest(BaseModel):
+    type: MetricType = Field(
+        description="Type of the metric. It can only be one of QueryRelevance, ResponseRelevance, ToolSelection",
+        examples=["UserQueryRelevance"],
+    )
+    name: str = Field(
+        description="Name of metric",
+        examples=["My User Query Relevance"],
+    )
+    metric_metadata: str = Field(description="Additional metadata for the metric")
+    config: Optional[RelevanceMetricConfig] = Field(
+        description="Configuration for the metric. Currently only applies to UserQueryRelevance and ResponseRelevance metric types.",
+        default=None,
+    )
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example1": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance",
+                "metric_metadata": "This is a test metric metadata",
+            },
+            "example2": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance with Config",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"relevance_threshold": 0.8, "use_llm_judge": False},
+            },
+            "example3": {
+                "type": "ResponseRelevance",
+                "name": "My Response Relevance",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"use_llm_judge": True},
+            },
+        },
+    )
+
+    @field_validator("type")
+    def validate_metric_type(cls, value):
+        if value not in MetricType:
+            raise ValueError(
+                f"Invalid metric type: {value}. Valid types are: {', '.join([t.value for t in MetricType])}",
+            )
+        return value
+
+    @model_validator(mode="before")
+    def set_config_type(cls, values):
+        if not isinstance(values, dict):
+            return values
+
+        metric_type = values.get("type")
+        config_values = values.get("config")
+
+        # Map metric types to their corresponding config classes
+        metric_type_to_config = {
+            MetricType.QUERY_RELEVANCE: RelevanceMetricConfig,
+            MetricType.RESPONSE_RELEVANCE: RelevanceMetricConfig,
+            # Add new metric types and their configs here as needed
+        }
+
+        config_class = metric_type_to_config.get(metric_type)
+
+        if config_class is not None:
+            if config_values is None:
+                # Default config when none is provided
+                config_values = {"use_llm_judge": True}
+            elif isinstance(config_values, dict):
+                # Handle mutually exclusive parameters
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" in config_values
+                    and config_values["use_llm_judge"]
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail="relevance_threshold and use_llm_judge=true are mutually exclusive. Set use_llm_judge=false when using relevance_threshold.",
+                        headers={"full_stacktrace": "false"},
+                    )
+
+                # If relevance_threshold is set but use_llm_judge isn't, set use_llm_judge to false
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = False
+
+                # If neither is set, default to use_llm_judge=True
+                if (
+                    "relevance_threshold" not in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = True
+
+            if isinstance(config_values, BaseModel):
+                config_values = config_values.model_dump()
+
+            values["config"] = config_class(**config_values)
+        elif config_values is not None:
+            # Provide a nice error message listing supported metric types
+            supported_types = [t.value for t in metric_type_to_config.keys()]
+            raise HTTPException(
+                status_code=400,
+                detail=f"Config is only supported for {', '.join(supported_types)} metric types",
+                headers={"full_stacktrace": "false"},
+            )
+
+        return values
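The `set_config_type` pre-validator boils down to three defaulting rules for relevance configs. A condensed sketch of just that decision table, with the pydantic and HTTP plumbing stripped out (the helper name is hypothetical):

```python
def normalize_relevance_config(config: dict | None) -> dict:
    # Mirrors the defaulting rules in NewMetricRequest.set_config_type
    # for QueryRelevance / ResponseRelevance metrics.
    if config is None:
        return {"use_llm_judge": True}  # default when no config is provided
    if config.get("use_llm_judge") and "relevance_threshold" in config:
        raise ValueError(
            "relevance_threshold and use_llm_judge=true are mutually exclusive",
        )
    if "relevance_threshold" in config and "use_llm_judge" not in config:
        config["use_llm_judge"] = False  # a threshold implies no LLM judge
    if "relevance_threshold" not in config and "use_llm_judge" not in config:
        config["use_llm_judge"] = True  # neither set: judge by default
    return config


assert normalize_relevance_config(None) == {"use_llm_judge": True}
assert normalize_relevance_config({"relevance_threshold": 0.8}) == {
    "relevance_threshold": 0.8,
    "use_llm_judge": False,
}
```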
arthur_common/models/task_job_specs.py CHANGED

@@ -1,13 +1,23 @@
-from
+from enum import Enum
+from typing import Literal, Optional, Self
 from uuid import UUID
 
 from pydantic import BaseModel, Field
 
-from arthur_common.models.shield import
+from arthur_common.models.shield import (
+    NewMetricRequest,
+    NewRuleRequest,
+    model_validator,
+)
 
 onboarding_id_desc = "An identifier to assign to the created model to make it easy to retrieve. Used by the UI during the GenAI model creation flow."
 
 
+class TaskType(str, Enum):
+    TRADITIONAL = "traditional"
+    AGENTIC = "agentic"
+
+
 class CreateModelTaskJobSpec(BaseModel):
     job_type: Literal["create_model_task"] = "create_model_task"
     connector_id: UUID = Field(
@@ -21,6 +31,20 @@ class CreateModelTaskJobSpec(BaseModel):
     initial_rules: list[NewRuleRequest] = Field(
         description="The initial rules to apply to the created model.",
     )
+    task_type: TaskType = Field(
+        default=TaskType.TRADITIONAL,
+        description="The type of task to create.",
+    )
+    initial_metrics: list[NewMetricRequest] = Field(
+        description="The initial metrics to apply to agentic tasks.",
+    )
+
+    @model_validator(mode="after")
+    def initial_metric_required(self) -> Self:
+        if self.task_type == TaskType.TRADITIONAL:
+            if not len(self.initial_metrics) == 0:
+                raise ValueError("No initial_metrics when task_type is TRADITIONAL")
+        return self
 
 
 class CreateModelLinkTaskJobSpec(BaseModel):
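The validator's net effect is that `initial_metrics` must stay empty unless the task is agentic. A minimal reproduction with a simplified stand-in spec (the real `CreateModelTaskJobSpec` has more required fields):

```python
from enum import Enum
from typing import Self

from pydantic import BaseModel, model_validator


class TaskType(str, Enum):
    TRADITIONAL = "traditional"
    AGENTIC = "agentic"


class SimplifiedTaskSpec(BaseModel):
    # Stand-in keeping only the fields the new validator touches.
    task_type: TaskType = TaskType.TRADITIONAL
    initial_metrics: list[str] = []

    @model_validator(mode="after")
    def initial_metric_required(self) -> Self:
        if self.task_type == TaskType.TRADITIONAL and self.initial_metrics:
            raise ValueError("No initial_metrics when task_type is TRADITIONAL")
        return self


SimplifiedTaskSpec(task_type=TaskType.AGENTIC, initial_metrics=["m"])  # accepted
try:
    SimplifiedTaskSpec(initial_metrics=["m"])  # traditional task with metrics
except ValueError as exc:
    print(exc)  # rejected by the validator
```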
arthur_common/tools/aggregation_analyzer.py CHANGED

@@ -207,7 +207,7 @@ class FunctionAnalyzer:
         )
         # Check if X implements the required methods
         required_methods = ["aggregate", "id", "description", "display_name"]
-        static_methods = ["description", "id", "display_name"]
+        static_methods = ["description", "id", "display_name", "reported_aggregations"]
         for method in required_methods:
             if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
                 raise AttributeError(
@@ -253,6 +253,7 @@ class FunctionAnalyzer:
             metric_type=metric_type,
             init_args=aggregation_init_args,
             aggregate_args=aggregate_args,
+            reported_aggregations=agg_func.reported_aggregations(),
         )
 
 
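With `reported_aggregations` added to `static_methods`, any aggregation class fed to `FunctionAnalyzer` now has to expose it as a callable static method alongside `description`, `id`, and `display_name`. A check in the analyzer's own hasattr/callable style (hypothetical helper, not from the package):

```python
def has_reporting_hook(agg_func: type) -> bool:
    # Mirrors the analyzer's hasattr/callable probing for the new static method.
    return hasattr(agg_func, "reported_aggregations") and callable(
        getattr(agg_func, "reported_aggregations"),
    )
```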