arthur-common 2.1.53__py3-none-any.whl → 2.1.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -42,7 +42,7 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=CategoricalCountAggregationFunction.METRIC_NAME,
             description=CategoricalCountAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -42,7 +42,7 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=InferenceCountAggregationFunction.METRIC_NAME,
             description=InferenceCountAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -46,7 +46,7 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
         BaseReportedAggregation(
             metric_name=BinaryClassifierCountByClassAggregationFunction._metric_name(),
             description=BinaryClassifierCountByClassAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -172,7 +172,7 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
         BaseReportedAggregation(
             metric_name=BinaryClassifierCountThresholdClassAggregationFunction._metric_name(),
             description=BinaryClassifierCountThresholdClassAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -5,10 +5,10 @@ from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     Dimension,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -43,7 +43,7 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=InferenceNullCountAggregationFunction.METRIC_NAME,
             description=InferenceNullCountAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -159,7 +159,8 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
         )
 
         count_metric = self.series_to_metric(
-            self.ABSOLUTE_ERROR_COUNT_METRIC_NAME, count_series
+            self.ABSOLUTE_ERROR_COUNT_METRIC_NAME,
+            count_series,
         )
         absolute_error_metric = self.series_to_metric(
             self.ABSOLUTE_ERROR_SUM_METRIC_NAME,
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -159,7 +159,8 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
         )
 
         count_metric = self.series_to_metric(
-            self.SQUARED_ERROR_COUNT_METRIC_NAME, count_series
+            self.SQUARED_ERROR_COUNT_METRIC_NAME,
+            count_series,
         )
         absolute_error_metric = self.series_to_metric(
             self.SQUARED_ERROR_SUM_METRIC_NAME,
@@ -6,9 +6,9 @@ from duckdb import DuckDBPyConnection
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -8,9 +8,9 @@ from arthur_common.aggregations.functions.inference_count_by_class import (
 )
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -57,7 +57,7 @@ class MulticlassClassifierCountByClassAggregationFunction(
         BaseReportedAggregation(
             metric_name=MulticlassClassifierCountByClassAggregationFunction._metric_name(),
             description=MulticlassClassifierCountByClassAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -5,9 +5,9 @@ from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import SketchAggregationFunction
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     SketchMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -44,7 +44,7 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=NumericSketchAggregationFunction.METRIC_NAME,
             description=NumericSketchAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -5,10 +5,10 @@ from duckdb import DuckDBPyConnection
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     Dimension,
     NumericMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SEGMENTATION_ALLOWED_COLUMN_TYPES,
@@ -43,7 +43,7 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=NumericSumAggregationFunction.METRIC_NAME,
             description=NumericSumAggregationFunction.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -11,10 +11,10 @@ from arthur_common.aggregations.aggregator import (
 )
 from arthur_common.models.datasets import ModelProblemType
 from arthur_common.models.metrics import (
+    BaseReportedAggregation,
     DatasetReference,
     NumericMetric,
     SketchMetric,
-    BaseReportedAggregation,
 )
 from arthur_common.models.schema_definitions import (
     SHIELD_RESPONSE_SCHEMA,
@@ -44,7 +44,7 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
             description=ShieldInferencePassFailCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -113,7 +113,7 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
             description=ShieldInferenceRuleCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -205,7 +205,7 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
             description=ShieldInferenceHallucinationCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -269,7 +269,7 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
             description=ShieldInferenceRuleToxicityScoreAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -354,7 +354,7 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
             description=ShieldInferenceRulePIIDataScoreAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -445,7 +445,7 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
             description=ShieldInferenceRuleClaimCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -518,7 +518,7 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
             description=ShieldInferenceRuleClaimPassCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -591,7 +591,7 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
             description=ShieldInferenceRuleClaimFailCountAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -664,7 +664,7 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
         BaseReportedAggregation(
             metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
             description=ShieldInferenceRuleLatencyAggregation.description(),
-        )
+        ),
     ]
 
     def aggregate(
@@ -768,7 +768,7 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
         return [base_token_count_agg] + [
             BaseReportedAggregation(
                 metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
-                    model_name
+                    model_name,
                 ),
                 description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
             )
@@ -855,7 +855,8 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
             )
             resp.append(
                 self.series_to_metric(
-                    self._series_name_from_model_name(model), model_series
-                )
+                    self._series_name_from_model_name(model),
+                    model_series,
+                ),
             )
         return resp
@@ -7,6 +7,7 @@ class ModelProblemType(str, Enum):
     ARTHUR_SHIELD = "arthur_shield"
     CUSTOM = "custom"
     MULTICLASS_CLASSIFICATION = "multiclass_classification"
+    AGENTIC_TRACE = "agentic_trace"
 
 
 class DatasetFileType(str, Enum):
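
Since `ModelProblemType` subclasses `str`, the new `AGENTIC_TRACE` member compares equal to its raw string value, so existing string-based checks keep working. A minimal standalone sketch of that pattern (abbreviated enum, not imported from the package):

```python
from enum import Enum


class ModelProblemType(str, Enum):
    # Abbreviated sketch of arthur_common.models.datasets.ModelProblemType
    MULTICLASS_CLASSIFICATION = "multiclass_classification"
    AGENTIC_TRACE = "agentic_trace"  # new in 2.1.54


# str mixin: the member is usable wherever a plain string is expected
assert ModelProblemType.AGENTIC_TRACE == "agentic_trace"
assert ModelProblemType("agentic_trace") is ModelProblemType.AGENTIC_TRACE
```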
@@ -194,7 +194,8 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSc
 
 
 class MetricsColumnListParameterSchema(
-    MetricsParameterSchema, BaseColumnParameterSchema
+    MetricsParameterSchema,
+    BaseColumnParameterSchema,
 ):
     # list column parameter schema specific to default metrics
     parameter_type: Literal["column_list"] = "column_list"
@@ -249,14 +250,14 @@ class AggregationSpecSchema(BaseModel):
         description="List of parameters to the aggregation's aggregate function.",
     )
     reported_aggregations: list[BaseReportedAggregation] = Field(
-        description="List of aggregations reported by the metric."
+        description="List of aggregations reported by the metric.",
     )
 
     @model_validator(mode="after")
     def at_least_one_reported_agg(self) -> Self:
         if len(self.reported_aggregations) < 1:
             raise ValueError(
                 "Aggregation spec must specify at least one reported aggregation.",
             )
         return self
 
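The `model_validator(mode="after")` hook above runs once field parsing succeeds, so an empty `reported_aggregations` list fails at construction time rather than at query time. A standalone sketch of the same pattern (the model is abbreviated; the real field holds `BaseReportedAggregation` objects):

```python
from typing import Self

from pydantic import BaseModel, Field, ValidationError, model_validator


class AggregationSpecSchema(BaseModel):
    # Abbreviated: the real field is list[BaseReportedAggregation]
    reported_aggregations: list[str] = Field(
        description="List of aggregations reported by the metric.",
    )

    @model_validator(mode="after")
    def at_least_one_reported_agg(self) -> Self:
        if len(self.reported_aggregations) < 1:
            raise ValueError(
                "Aggregation spec must specify at least one reported aggregation.",
            )
        return self


try:
    AggregationSpecSchema(reported_aggregations=[])
except ValidationError as exc:
    print(exc)  # pydantic wraps the ValueError raised by the validator
```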
@@ -283,16 +284,16 @@ class AggregationSpecSchema(BaseModel):
 
 class ReportedCustomAggregation(BaseReportedAggregation):
     value_column: str = Field(
-        description="Name of the column returned from the SQL query holding the metric value."
+        description="Name of the column returned from the SQL query holding the metric value.",
     )
     timestamp_column: str = Field(
-        description="Name of the column returned from the SQL query holding the timestamp buckets."
+        description="Name of the column returned from the SQL query holding the timestamp buckets.",
     )
     metric_kind: AggregationMetricType = Field(
         description="Return type of the reported aggregation metric value.",
     )
     dimension_columns: list[str] = Field(
-        description="Name of any dimension columns returned from the SQL query. Max length is 1."
+        description="Name of any dimension columns returned from the SQL query. Max length is 1.",
     )
 
     @field_validator("dimension_columns")
@@ -367,6 +367,38 @@ def create_shield_inference_feedback_schema() -> DatasetListType:
     )
 
 
+def AGENTIC_TRACE_SCHEMA() -> DatasetSchema:
+    return DatasetSchema(
+        alias_mask={},
+        columns=[
+            DatasetColumn(
+                id=uuid4(),
+                source_name="trace_id",
+                definition=create_dataset_scalar_type(DType.STRING),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="start_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="end_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="root_spans",
+                definition=create_dataset_list_type(
+                    create_dataset_scalar_type(
+                        DType.JSON,
+                    ),  # JSON blob to preserve hierarchy
+                ),
+            ),
+        ],
+    )
+
+
 def SHIELD_SCHEMA() -> DatasetSchema:
     return DatasetSchema(
         alias_mask={},
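
Read as a dataset contract, the new `AGENTIC_TRACE_SCHEMA` describes one row per trace: a string `trace_id`, `start_time`/`end_time` timestamps, and a `root_spans` list of JSON blobs whose nesting carries the span hierarchy. A hypothetical record shaped like those columns (illustrative field values only):

```python
from datetime import datetime, timezone

# Hypothetical row matching the AGENTIC_TRACE_SCHEMA columns above
trace_row = {
    "trace_id": "trace-001",  # DType.STRING
    "start_time": datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc),  # TIMESTAMP
    "end_time": datetime(2025, 1, 1, 12, 0, 3, tzinfo=timezone.utc),  # TIMESTAMP
    # DType.JSON list: each entry is a serialized root span; children nest inside
    "root_spans": [
        '{"span_id": "a1", "name": "agent", "children": [{"span_id": "b2", "name": "tool_call"}]}',
    ],
}
```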
@@ -423,6 +455,32 @@ def SHIELD_SCHEMA() -> DatasetSchema:
 SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
 SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()
 
+
+# Agentic trace schema base type for API responses
+def create_agentic_trace_response_schema() -> DatasetObjectType:
+    return create_dataset_object_type(
+        {
+            "count": create_dataset_scalar_type(DType.INT),
+            "traces": create_dataset_list_type(
+                create_dataset_object_type(
+                    {
+                        "trace_id": create_dataset_scalar_type(DType.STRING),
+                        "start_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "end_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "root_spans": create_dataset_list_type(
+                            create_dataset_scalar_type(
+                                DType.JSON,
+                            ),  # JSON blob for infinite depth
+                        ),
+                    },
+                ),
+            ),
+        },
+    )
+
+
+AGENTIC_TRACE_RESPONSE_SCHEMA = create_agentic_trace_response_schema().to_base_type()
+
 SEGMENTATION_ALLOWED_DTYPES = [DType.INT, DType.BOOL, DType.STRING, DType.UUID]
 SEGMENTATION_ALLOWED_COLUMN_TYPES = [
     ScalarType(dtype=d_type) for d_type in SEGMENTATION_ALLOWED_DTYPES
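
The response-schema counterpart wraps the same trace shape in an envelope with a top-level `count` plus a `traces` list, which is what `AGENTIC_TRACE_RESPONSE_SCHEMA` describes for API responses. A hypothetical payload (values illustrative):

```python
# Hypothetical API response matching create_agentic_trace_response_schema()
response = {
    "count": 1,  # DType.INT
    "traces": [
        {
            "trace_id": "trace-001",
            "start_time": "2025-01-01T12:00:00Z",
            "end_time": "2025-01-01T12:00:03Z",
            "root_spans": ['{"span_id": "a1", "name": "agent"}'],
        },
    ],
}
```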
@@ -1,3 +1,4 @@
+from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Self, Type, Union
 
@@ -26,6 +27,15 @@ class RuleScope(str, Enum):
     TASK = "task"
 
 
+class MetricType(str, Enum):
+    QUERY_RELEVANCE = "QueryRelevance"
+    RESPONSE_RELEVANCE = "ResponseRelevance"
+    TOOL_SELECTION = "ToolSelection"
+
+    def __str__(self):
+        return self.value
+
+
 class BaseEnum(str, Enum):
     @classmethod
     def values(cls) -> list[Any]:
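
`MetricType` overrides `__str__` so that both `str()` and f-string interpolation yield the API-facing value rather than the `MetricType.QUERY_RELEVANCE` member name, consistently across Python versions. A standalone sketch:

```python
from enum import Enum


class MetricType(str, Enum):  # as added in the diff
    QUERY_RELEVANCE = "QueryRelevance"
    RESPONSE_RELEVANCE = "ResponseRelevance"
    TOOL_SELECTION = "ToolSelection"

    def __str__(self):
        return self.value


# Without the override, str() would give "MetricType.QUERY_RELEVANCE"
assert str(MetricType.QUERY_RELEVANCE) == "QueryRelevance"
assert f"{MetricType.TOOL_SELECTION}" == "ToolSelection"
```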
@@ -240,6 +250,27 @@ class RuleResponse(BaseModel):
     )
 
 
+class MetricResponse(BaseModel):
+    id: str = Field(description="ID of the Metric")
+    name: str = Field(description="Name of the Metric")
+    type: MetricType = Field(description="Type of the Metric")
+    metric_metadata: str = Field(description="Metadata of the Metric")
+    config: Optional[str] = Field(
+        description="JSON-serialized configuration for the Metric",
+        default=None,
+    )
+    created_at: datetime = Field(
+        description="Time the Metric was created in unix milliseconds",
+    )
+    updated_at: datetime = Field(
+        description="Time the Metric was updated in unix milliseconds",
+    )
+    enabled: Optional[bool] = Field(
+        description="Whether the Metric is enabled",
+        default=None,
+    )
+
+
 class TaskResponse(BaseModel):
     id: str = Field(description=" ID of the task")
     name: str = Field(description="Name of the task")
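
`MetricResponse` mirrors `RuleResponse` for the new metrics API; note that `config` arrives as a JSON-serialized string, while `created_at`/`updated_at` are unix-millisecond timestamps parsed into `datetime`. A hypothetical payload such a model would accept (field names from the diff, values illustrative):

```python
# Hypothetical input for MetricResponse.model_validate(...)
metric_payload = {
    "id": "metric-123",
    "name": "My User Query Relevance",
    "type": "QueryRelevance",  # coerced to MetricType
    "metric_metadata": "example metadata",
    "config": '{"use_llm_judge": true}',  # a JSON string, not a nested object
    "created_at": 1704067200000,  # unix milliseconds per the field description
    "updated_at": 1704067200000,
    "enabled": True,
}
```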
@@ -249,7 +280,12 @@ class TaskResponse(BaseModel):
     updated_at: int = Field(
         description="Time the task was created in unix milliseconds",
     )
+    is_agentic: bool = Field(description="Whether the task is agentic or not")
     rules: List[RuleResponse] = Field(description="List of all the rules for the task.")
+    metrics: Optional[List[MetricResponse]] = Field(
+        description="List of all the metrics for the task.",
+        default=None,
+    )
 
 
 class UpdateRuleRequest(BaseModel):
@@ -484,3 +520,125 @@ class NewRuleRequest(BaseModel):
                 detail="Examples must be provided to onboard a ModelSensitiveDataRule",
             )
         return self
+
+
+class RelevanceMetricConfig(BaseModel):
+    """Configuration for relevance metrics including QueryRelevance and ResponseRelevance"""
+
+    relevance_threshold: Optional[float] = Field(
+        default=None,
+        description="Threshold for determining relevance when not using LLM judge",
+    )
+    use_llm_judge: bool = Field(
+        default=True,
+        description="Whether to use LLM as a judge for relevance scoring",
+    )
+
+
+class NewMetricRequest(BaseModel):
+    type: MetricType = Field(
+        description="Type of the metric. It can only be one of QueryRelevance, ResponseRelevance, ToolSelection",
+        examples=["UserQueryRelevance"],
+    )
+    name: str = Field(
+        description="Name of metric",
+        examples=["My User Query Relevance"],
+    )
+    metric_metadata: str = Field(description="Additional metadata for the metric")
+    config: Optional[RelevanceMetricConfig] = Field(
+        description="Configuration for the metric. Currently only applies to UserQueryRelevance and ResponseRelevance metric types.",
+        default=None,
+    )
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example1": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance",
+                "metric_metadata": "This is a test metric metadata",
+            },
+            "example2": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance with Config",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"relevance_threshold": 0.8, "use_llm_judge": False},
+            },
+            "example3": {
+                "type": "ResponseRelevance",
+                "name": "My Response Relevance",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"use_llm_judge": True},
+            },
+        },
+    )
+
+    @field_validator("type")
+    def validate_metric_type(cls, value):
+        if value not in MetricType:
+            raise ValueError(
+                f"Invalid metric type: {value}. Valid types are: {', '.join([t.value for t in MetricType])}",
+            )
+        return value
+
+    @model_validator(mode="before")
+    def set_config_type(cls, values):
+        if not isinstance(values, dict):
+            return values
+
+        metric_type = values.get("type")
+        config_values = values.get("config")
+
+        # Map metric types to their corresponding config classes
+        metric_type_to_config = {
+            MetricType.QUERY_RELEVANCE: RelevanceMetricConfig,
+            MetricType.RESPONSE_RELEVANCE: RelevanceMetricConfig,
+            # Add new metric types and their configs here as needed
+        }
+
+        config_class = metric_type_to_config.get(metric_type)
+
+        if config_class is not None:
+            if config_values is None:
+                # Default config when none is provided
+                config_values = {"use_llm_judge": True}
+            elif isinstance(config_values, dict):
+                # Handle mutually exclusive parameters
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" in config_values
+                    and config_values["use_llm_judge"]
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail="relevance_threshold and use_llm_judge=true are mutually exclusive. Set use_llm_judge=false when using relevance_threshold.",
+                        headers={"full_stacktrace": "false"},
+                    )
+
+                # If relevance_threshold is set but use_llm_judge isn't, set use_llm_judge to false
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = False
+
+                # If neither is set, default to use_llm_judge=True
+                if (
+                    "relevance_threshold" not in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = True
+
+            if isinstance(config_values, BaseModel):
+                config_values = config_values.model_dump()
+
+            values["config"] = config_class(**config_values)
+        elif config_values is not None:
+            # Provide a nice error message listing supported metric types
+            supported_types = [t.value for t in metric_type_to_config.keys()]
+            raise HTTPException(
+                status_code=400,
+                detail=f"Config is only supported for {', '.join(supported_types)} metric types",
+                headers={"full_stacktrace": "false"},
+            )
+
+        return values
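
Taken together, the `mode="before"` validator gives `config` three observable behaviors for the relevance metric types, and rejects `config` outright for types with no config class (currently `ToolSelection`). A usage sketch, assuming `NewMetricRequest` is importable from its module (the diff does not show the path):

```python
# Assumption: NewMetricRequest imported from the module shown in the diff.

# 1. No config -> defaulted to RelevanceMetricConfig(use_llm_judge=True).
req = NewMetricRequest(type="QueryRelevance", name="q", metric_metadata="m")

# 2. relevance_threshold alone -> use_llm_judge is forced to False.
req = NewMetricRequest(
    type="ResponseRelevance",
    name="r",
    metric_metadata="m",
    config={"relevance_threshold": 0.8},
)

# 3. relevance_threshold together with use_llm_judge=True -> HTTPException(400)
#    ("mutually exclusive"), as is any config passed for ToolSelection.
```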