arthur-common 2.1.52__py3-none-any.whl → 2.1.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arthur-common has been flagged as potentially problematic.

@@ -10,7 +10,12 @@ from arthur_common.aggregations.aggregator import (
     SketchAggregationFunction,
 )
 from arthur_common.models.datasets import ModelProblemType
-from arthur_common.models.metrics import DatasetReference, NumericMetric, SketchMetric
+from arthur_common.models.metrics import (
+    BaseReportedAggregation,
+    DatasetReference,
+    NumericMetric,
+    SketchMetric,
+)
 from arthur_common.models.schema_definitions import (
     SHIELD_RESPONSE_SCHEMA,
     MetricColumnParameterAnnotation,
@@ -33,6 +38,15 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferencePassFailCountAggregation.METRIC_NAME,
+                description=ShieldInferencePassFailCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -93,6 +107,15 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -176,6 +199,15 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that counts the number of Shield hallucination evaluations that failed."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceHallucinationCountAggregation.METRIC_NAME,
+                description=ShieldInferenceHallucinationCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -231,6 +263,15 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleToxicityScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleToxicityScoreAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -307,6 +348,15 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRulePIIDataScoreAggregation.METRIC_NAME,
+                description=ShieldInferenceRulePIIDataScoreAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -389,6 +439,15 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on over the number of claims identified by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -453,6 +512,15 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimPassCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimPassCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -517,6 +585,15 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleClaimFailCountAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleClaimFailCountAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -581,6 +658,15 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
     def description() -> str:
         return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."

+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        return [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceRuleLatencyAggregation.METRIC_NAME,
+                description=ShieldInferenceRuleLatencyAggregation.description(),
+            ),
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -643,6 +729,18 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):

 class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "token_count"
+    SUPPORTED_MODELS = [
+        "gpt-4o",
+        "gpt-4o-mini",
+        "gpt-3.5-turbo",
+        "o1-mini",
+        "deepseek-chat",
+        "claude-3-5-sonnet-20241022",
+        "gemini/gemini-1.5-pro",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-2-11b-instruct-v1:0",
+    ]

     @staticmethod
     def id() -> UUID:
@@ -656,6 +754,27 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     def description() -> str:
         return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."

+    @staticmethod
+    def _series_name_from_model_name(model_name: str) -> str:
+        """Calculates name of reported series based on the model name considered."""
+        return f"token_cost.{model_name}"
+
+    @staticmethod
+    def reported_aggregations() -> list[BaseReportedAggregation]:
+        base_token_count_agg = BaseReportedAggregation(
+            metric_name=ShieldInferenceTokenCountAggregation.METRIC_NAME,
+            description=f"Metric that reports the number of tokens in the Shield response and prompt schemas.",
+        )
+        return [base_token_count_agg] + [
+            BaseReportedAggregation(
+                metric_name=ShieldInferenceTokenCountAggregation._series_name_from_model_name(
+                    model_name,
+                ),
+                description=f"Metric that reports the estimated cost for the {model_name} model of the tokens in the Shield response and prompt schemas.",
+            )
+            for model_name in ShieldInferenceTokenCountAggregation.SUPPORTED_MODELS
+        ]
+
     def aggregate(
         self,
         ddb_conn: DuckDBPyConnection,
@@ -708,25 +827,12 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
         resp = [metric]

         # Compute Cost for each model
-        models = [
-            "gpt-4o",
-            "gpt-4o-mini",
-            "gpt-3.5-turbo",
-            "o1-mini",
-            "deepseek-chat",
-            "claude-3-5-sonnet-20241022",
-            "gemini/gemini-1.5-pro",
-            "meta.llama3-1-8b-instruct-v1:0",
-            "meta.llama3-1-70b-instruct-v1:0",
-            "meta.llama3-2-11b-instruct-v1:0",
-        ]
-
         # Precompute input/output classification to avoid recalculating in loop
         location_type = results["location"].apply(
             lambda x: "input" if x == "prompt" else "output",
         )

-        for model in models:
+        for model in self.SUPPORTED_MODELS:
             # Efficient list comprehension instead of apply
             cost_values = [
                 calculate_cost_by_tokens(int(tokens), model, loc_type)
@@ -747,5 +853,10 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
                 ["location"],
                 "ts",
             )
-            resp.append(self.series_to_metric(f"token_cost.{model}", model_series))
+            resp.append(
+                self.series_to_metric(
+                    self._series_name_from_model_name(model),
+                    model_series,
+                ),
+            )
         return resp
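
Every aggregation in this module now declares the series it emits through a `reported_aggregations()` static method returning `BaseReportedAggregation` entries, and the token-count aggregation derives its per-model cost series names from the shared `_series_name_from_model_name` helper instead of an inline f-string. A minimal sketch of how a caller might read these declarations (the `declared_series` helper below is illustrative, not part of the package):

```python
# Sketch: list the series an aggregation class says it will report.
# BaseReportedAggregation and reported_aggregations() come from the diff above;
# the declared_series helper itself is hypothetical.
from arthur_common.models.metrics import BaseReportedAggregation


def declared_series(agg_cls: type) -> list[str]:
    reported: list[BaseReportedAggregation] = agg_cls.reported_aggregations()
    return [r.metric_name for r in reported]


# For ShieldInferenceTokenCountAggregation this would yield "token_count" plus
# one "token_cost.<model>" entry per model in SUPPORTED_MODELS.
```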
@@ -7,6 +7,7 @@ class ModelProblemType(str, Enum):
     ARTHUR_SHIELD = "arthur_shield"
     CUSTOM = "custom"
     MULTICLASS_CLASSIFICATION = "multiclass_classification"
+    AGENTIC_TRACE = "agentic_trace"


 class DatasetFileType(str, Enum):
@@ -193,7 +193,10 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
     parameter_type: Literal["column"] = "column"


-class MetricsColumnListParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
+class MetricsColumnListParameterSchema(
+    MetricsParameterSchema,
+    BaseColumnParameterSchema,
+):
     # list column parameter schema specific to default metrics
     parameter_type: Literal["column_list"] = "column_list"

@@ -211,9 +214,7 @@ MetricsColumnSchemaUnion = (


 CustomAggregationParametersSchemaUnion = (
-    BaseDatasetParameterSchema
-    | BaseLiteralParameterSchema
-    | BaseColumnParameterSchema
+    BaseDatasetParameterSchema | BaseLiteralParameterSchema | BaseColumnParameterSchema
 )


@@ -224,6 +225,14 @@ class DatasetReference:
     dataset_id: UUID


+class BaseReportedAggregation(BaseModel):
+    # in future will be used by default metrics
+    metric_name: str = Field(description="Name of the reported aggregation metric.")
+    description: str = Field(
+        description="Description of the reported aggregation metric and what it aggregates.",
+    )
+
+
 class AggregationSpecSchema(BaseModel):
     name: str = Field(description="Name of the aggregation function.")
     id: UUID = Field(description="Unique identifier of the aggregation function.")
@@ -240,6 +249,17 @@ class AggregationSpecSchema(BaseModel):
     aggregate_args: list[MetricsParameterSchemaUnion] = Field(
         description="List of parameters to the aggregation's aggregate function.",
     )
+    reported_aggregations: list[BaseReportedAggregation] = Field(
+        description="List of aggregations reported by the metric.",
+    )
+
+    @model_validator(mode="after")
+    def at_least_one_reported_agg(self) -> Self:
+        if len(self.reported_aggregations) < 1:
+            raise ValueError(
+                "Aggregation spec must specify at least one reported aggregation.",
+            )
+        return self

     @model_validator(mode="after")
     def column_dataset_references_exist(self) -> Self:
@@ -262,26 +282,23 @@ class AggregationSpecSchema(BaseModel):
         return self


-class BaseReportedAggregation(BaseModel):
-    # in future will be used by default metrics
-    metric_name: str = Field(description="Name of the reported aggregation metric.")
-    description: str = Field(
-        description="Description of the reported aggregation metric and what it aggregates.",
-    )
-
-
 class ReportedCustomAggregation(BaseReportedAggregation):
-    value_column: str = Field(description="Name of the column returned from the SQL query holding the metric value.")
-    timestamp_column: str = Field(description="Name of the column returned from the SQL query holding the timestamp buckets.")
+    value_column: str = Field(
+        description="Name of the column returned from the SQL query holding the metric value.",
+    )
+    timestamp_column: str = Field(
+        description="Name of the column returned from the SQL query holding the timestamp buckets.",
+    )
     metric_kind: AggregationMetricType = Field(
         description="Return type of the reported aggregation metric value.",
     )
-    dimension_columns: list[str] = Field(description="Name of any dimension columns returned from the SQL query. Max length is 1.")
+    dimension_columns: list[str] = Field(
+        description="Name of any dimension columns returned from the SQL query. Max length is 1.",
+    )

-    @field_validator('dimension_columns')
+    @field_validator("dimension_columns")
     @classmethod
     def validate_dimension_columns_length(cls, v: list[str]) -> str:
         if len(v) > 1:
-            raise ValueError('Only one dimension column can be specified.')
+            raise ValueError("Only one dimension column can be specified.")
         return v
-
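
The `BaseReportedAggregation` model moves above `AggregationSpecSchema` so the new `reported_aggregations` field can reference it, and specs now require at least one entry via the `at_least_one_reported_agg` validator. A minimal sketch of the model itself (field values are illustrative; assumes pydantic v2, which the `model_validator`/`ConfigDict` usage elsewhere in the diff implies):

```python
# Illustrative values only; BaseReportedAggregation is the model shown above.
from arthur_common.models.metrics import BaseReportedAggregation

agg = BaseReportedAggregation(
    metric_name="inference_pass_fail_count",
    description="Counts Shield inferences grouped by check results.",
)
print(agg.model_dump())
# {'metric_name': 'inference_pass_fail_count',
#  'description': 'Counts Shield inferences grouped by check results.'}
```

Constructing an `AggregationSpecSchema` with `reported_aggregations=[]` would now fail validation, and `ReportedCustomAggregation` rejects more than one entry in `dimension_columns`.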
@@ -367,6 +367,38 @@ def create_shield_inference_feedback_schema() -> DatasetListType:
     )


+def AGENTIC_TRACE_SCHEMA() -> DatasetSchema:
+    return DatasetSchema(
+        alias_mask={},
+        columns=[
+            DatasetColumn(
+                id=uuid4(),
+                source_name="trace_id",
+                definition=create_dataset_scalar_type(DType.STRING),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="start_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="end_time",
+                definition=create_dataset_scalar_type(DType.TIMESTAMP),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="root_spans",
+                definition=create_dataset_list_type(
+                    create_dataset_scalar_type(
+                        DType.JSON,
+                    ),  # JSON blob to preserve hierarchy
+                ),
+            ),
+        ],
+    )
+
+
 def SHIELD_SCHEMA() -> DatasetSchema:
     return DatasetSchema(
         alias_mask={},
@@ -423,6 +455,32 @@ def SHIELD_SCHEMA() -> DatasetSchema:
 SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
 SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()

+
+# Agentic trace schema base type for API responses
+def create_agentic_trace_response_schema() -> DatasetObjectType:
+    return create_dataset_object_type(
+        {
+            "count": create_dataset_scalar_type(DType.INT),
+            "traces": create_dataset_list_type(
+                create_dataset_object_type(
+                    {
+                        "trace_id": create_dataset_scalar_type(DType.STRING),
+                        "start_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "end_time": create_dataset_scalar_type(DType.TIMESTAMP),
+                        "root_spans": create_dataset_list_type(
+                            create_dataset_scalar_type(
+                                DType.JSON,
+                            ),  # JSON blob for infinite depth
+                        ),
+                    },
+                ),
+            ),
+        },
+    )
+
+
+AGENTIC_TRACE_RESPONSE_SCHEMA = create_agentic_trace_response_schema().to_base_type()
+
 SEGMENTATION_ALLOWED_DTYPES = [DType.INT, DType.BOOL, DType.STRING, DType.UUID]
 SEGMENTATION_ALLOWED_COLUMN_TYPES = [
     ScalarType(dtype=d_type) for d_type in SEGMENTATION_ALLOWED_DTYPES
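
The new `AGENTIC_TRACE_SCHEMA` stores each trace as a flat row whose `root_spans` column is a list of JSON blobs, which sidesteps a recursive column definition for arbitrarily deep span trees; `AGENTIC_TRACE_RESPONSE_SCHEMA` wraps the same shape in a `count` plus `traces` envelope for API responses. A record matching the dataset columns might look like this (values and the span payload are made up; the schema only fixes the column names and types):

```python
# Hypothetical record shaped like AGENTIC_TRACE_SCHEMA's columns.
from datetime import datetime, timezone

example_trace = {
    "trace_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
    "start_time": datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
    "end_time": datetime(2025, 1, 1, 12, 0, 5, tzinfo=timezone.utc),
    "root_spans": [
        # Each element is an opaque JSON string so nested child spans survive
        # without the schema having to model the hierarchy explicitly.
        '{"span_id": "a1", "name": "agent_step", "children": []}',
    ],
}
```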
@@ -1,3 +1,4 @@
+from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Self, Type, Union

@@ -26,6 +27,15 @@ class RuleScope(str, Enum):
     TASK = "task"


+class MetricType(str, Enum):
+    QUERY_RELEVANCE = "QueryRelevance"
+    RESPONSE_RELEVANCE = "ResponseRelevance"
+    TOOL_SELECTION = "ToolSelection"
+
+    def __str__(self):
+        return self.value
+
+
 class BaseEnum(str, Enum):
     @classmethod
     def values(cls) -> list[Any]:
@@ -240,6 +250,27 @@ class RuleResponse(BaseModel):
     )


+class MetricResponse(BaseModel):
+    id: str = Field(description="ID of the Metric")
+    name: str = Field(description="Name of the Metric")
+    type: MetricType = Field(description="Type of the Metric")
+    metric_metadata: str = Field(description="Metadata of the Metric")
+    config: Optional[str] = Field(
+        description="JSON-serialized configuration for the Metric",
+        default=None,
+    )
+    created_at: datetime = Field(
+        description="Time the Metric was created in unix milliseconds",
+    )
+    updated_at: datetime = Field(
+        description="Time the Metric was updated in unix milliseconds",
+    )
+    enabled: Optional[bool] = Field(
+        description="Whether the Metric is enabled",
+        default=None,
+    )
+
+
 class TaskResponse(BaseModel):
     id: str = Field(description=" ID of the task")
     name: str = Field(description="Name of the task")
@@ -249,7 +280,12 @@ class TaskResponse(BaseModel):
     updated_at: int = Field(
         description="Time the task was created in unix milliseconds",
     )
+    is_agentic: bool = Field(description="Whether the task is agentic or not")
     rules: List[RuleResponse] = Field(description="List of all the rules for the task.")
+    metrics: Optional[List[MetricResponse]] = Field(
+        description="List of all the metrics for the task.",
+        default=None,
+    )


 class UpdateRuleRequest(BaseModel):
@@ -484,3 +520,125 @@ class NewRuleRequest(BaseModel):
                 detail="Examples must be provided to onboard a ModelSensitiveDataRule",
             )
         return self
+
+
+class RelevanceMetricConfig(BaseModel):
+    """Configuration for relevance metrics including QueryRelevance and ResponseRelevance"""
+
+    relevance_threshold: Optional[float] = Field(
+        default=None,
+        description="Threshold for determining relevance when not using LLM judge",
+    )
+    use_llm_judge: bool = Field(
+        default=True,
+        description="Whether to use LLM as a judge for relevance scoring",
+    )
+
+
+class NewMetricRequest(BaseModel):
+    type: MetricType = Field(
+        description="Type of the metric. It can only be one of QueryRelevance, ResponseRelevance, ToolSelection",
+        examples=["UserQueryRelevance"],
+    )
+    name: str = Field(
+        description="Name of metric",
+        examples=["My User Query Relevance"],
+    )
+    metric_metadata: str = Field(description="Additional metadata for the metric")
+    config: Optional[RelevanceMetricConfig] = Field(
+        description="Configuration for the metric. Currently only applies to UserQueryRelevance and ResponseRelevance metric types.",
+        default=None,
+    )
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example1": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance",
+                "metric_metadata": "This is a test metric metadata",
+            },
+            "example2": {
+                "type": "QueryRelevance",
+                "name": "My User Query Relevance with Config",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"relevance_threshold": 0.8, "use_llm_judge": False},
+            },
+            "example3": {
+                "type": "ResponseRelevance",
+                "name": "My Response Relevance",
+                "metric_metadata": "This is a test metric metadata",
+                "config": {"use_llm_judge": True},
+            },
+        },
+    )
+
+    @field_validator("type")
+    def validate_metric_type(cls, value):
+        if value not in MetricType:
+            raise ValueError(
+                f"Invalid metric type: {value}. Valid types are: {', '.join([t.value for t in MetricType])}",
+            )
+        return value
+
+    @model_validator(mode="before")
+    def set_config_type(cls, values):
+        if not isinstance(values, dict):
+            return values
+
+        metric_type = values.get("type")
+        config_values = values.get("config")
+
+        # Map metric types to their corresponding config classes
+        metric_type_to_config = {
+            MetricType.QUERY_RELEVANCE: RelevanceMetricConfig,
+            MetricType.RESPONSE_RELEVANCE: RelevanceMetricConfig,
+            # Add new metric types and their configs here as needed
+        }
+
+        config_class = metric_type_to_config.get(metric_type)
+
+        if config_class is not None:
+            if config_values is None:
+                # Default config when none is provided
+                config_values = {"use_llm_judge": True}
+            elif isinstance(config_values, dict):
+                # Handle mutually exclusive parameters
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" in config_values
+                    and config_values["use_llm_judge"]
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail="relevance_threshold and use_llm_judge=true are mutually exclusive. Set use_llm_judge=false when using relevance_threshold.",
+                        headers={"full_stacktrace": "false"},
+                    )
+
+                # If relevance_threshold is set but use_llm_judge isn't, set use_llm_judge to false
+                if (
+                    "relevance_threshold" in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = False
+
+                # If neither is set, default to use_llm_judge=True
+                if (
+                    "relevance_threshold" not in config_values
+                    and "use_llm_judge" not in config_values
+                ):
+                    config_values["use_llm_judge"] = True
+
+            if isinstance(config_values, BaseModel):
+                config_values = config_values.model_dump()
+
+            values["config"] = config_class(**config_values)
+        elif config_values is not None:
+            # Provide a nice error message listing supported metric types
+            supported_types = [t.value for t in metric_type_to_config.keys()]
+            raise HTTPException(
+                status_code=400,
+                detail=f"Config is only supported for {', '.join(supported_types)} metric types",
+                headers={"full_stacktrace": "false"},
+            )
+
+        return values
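
The `set_config_type` pre-validator above gives `NewMetricRequest` some implicit defaulting: relevance-type metrics get a `RelevanceMetricConfig` even when none is supplied, and providing `relevance_threshold` without `use_llm_judge` flips the judge off. A sketch of that behavior (names and values are illustrative):

```python
# Assumes the behavior of the set_config_type validator shown above.
from arthur_common.models.shield import MetricType, NewMetricRequest

defaulted = NewMetricRequest(
    type=MetricType.QUERY_RELEVANCE,
    name="My User Query Relevance",
    metric_metadata="example metadata",
)
assert defaulted.config is not None and defaulted.config.use_llm_judge is True

thresholded = NewMetricRequest(
    type=MetricType.QUERY_RELEVANCE,
    name="Thresholded Query Relevance",
    metric_metadata="example metadata",
    config={"relevance_threshold": 0.8},
)
assert thresholded.config.use_llm_judge is False
```

Passing both `relevance_threshold` and `use_llm_judge=True`, or a config on a `ToolSelection` metric, raises an HTTPException with status 400.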
@@ -1,13 +1,23 @@
-from typing import Literal, Optional
+from enum import Enum
+from typing import Literal, Optional, Self
 from uuid import UUID

 from pydantic import BaseModel, Field

-from arthur_common.models.shield import NewRuleRequest
+from arthur_common.models.shield import (
+    NewMetricRequest,
+    NewRuleRequest,
+    model_validator,
+)

 onboarding_id_desc = "An identifier to assign to the created model to make it easy to retrieve. Used by the UI during the GenAI model creation flow."


+class TaskType(str, Enum):
+    TRADITIONAL = "traditional"
+    AGENTIC = "agentic"
+
+
 class CreateModelTaskJobSpec(BaseModel):
     job_type: Literal["create_model_task"] = "create_model_task"
     connector_id: UUID = Field(
@@ -21,6 +31,20 @@ class CreateModelTaskJobSpec(BaseModel):
     initial_rules: list[NewRuleRequest] = Field(
         description="The initial rules to apply to the created model.",
     )
+    task_type: TaskType = Field(
+        default=TaskType.TRADITIONAL,
+        description="The type of task to create.",
+    )
+    initial_metrics: list[NewMetricRequest] = Field(
+        description="The initial metrics to apply to agentic tasks.",
+    )
+
+    @model_validator(mode="after")
+    def initial_metric_required(self) -> Self:
+        if self.task_type == TaskType.TRADITIONAL:
+            if not len(self.initial_metrics) == 0:
+                raise ValueError("No initial_metrics when task_type is TRADITIONAL")
+        return self


 class CreateModelLinkTaskJobSpec(BaseModel):
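
`CreateModelTaskJobSpec` gains a `task_type` (defaulting to traditional) and a required `initial_metrics` list, with `initial_metric_required` rejecting any metrics on traditional tasks, so callers creating a traditional task would pass an explicit empty list. The new fields for an agentic task might be populated like this (other required fields of the spec are not shown in this diff and are omitted; values are illustrative):

```python
# Only the fields added in this diff are shown; the rest of the job spec is elided.
from arthur_common.models.shield import MetricType, NewMetricRequest

agentic_task_fields = {
    "task_type": "agentic",  # coerces to TaskType.AGENTIC
    "initial_metrics": [
        NewMetricRequest(
            type=MetricType.TOOL_SELECTION,
            name="Tool Selection",
            metric_metadata="created at onboarding",
        ),
    ],
}
```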
@@ -207,7 +207,7 @@ class FunctionAnalyzer:
         )
         # Check if X implements the required methods
         required_methods = ["aggregate", "id", "description", "display_name"]
-        static_methods = ["description", "id", "display_name"]
+        static_methods = ["description", "id", "display_name", "reported_aggregations"]
         for method in required_methods:
             if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
                 raise AttributeError(
@@ -253,6 +253,7 @@ class FunctionAnalyzer:
             metric_type=metric_type,
             init_args=aggregation_init_args,
             aggregate_args=aggregate_args,
+            reported_aggregations=agg_func.reported_aggregations(),
         )


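
`FunctionAnalyzer` now counts `reported_aggregations` among the static methods it inspects and forwards its result into the generated `AggregationSpecSchema`, so any aggregation the analyzer accepts has to expose the new hook. A minimal skeleton of what that contract looks like (the class below is illustrative and not part of the package; `aggregate` is omitted since its signature is unchanged by this diff):

```python
# Hypothetical aggregation skeleton showing the static methods the analyzer checks.
from uuid import UUID

from arthur_common.models.metrics import BaseReportedAggregation


class MyCountAggregation:
    METRIC_NAME = "my_count"

    @staticmethod
    def id() -> UUID:
        return UUID("00000000-0000-0000-0000-000000000000")  # placeholder ID

    @staticmethod
    def display_name() -> str:
        return "My Count"

    @staticmethod
    def description() -> str:
        return "Counts rows in the dataset."

    @staticmethod
    def reported_aggregations() -> list[BaseReportedAggregation]:
        return [
            BaseReportedAggregation(
                metric_name=MyCountAggregation.METRIC_NAME,
                description=MyCountAggregation.description(),
            ),
        ]
```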
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: arthur-common
-Version: 2.1.52
+Version: 2.1.54
 Summary: Utility code common to Arthur platform components.
 License: MIT
 Author: Arthur