arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arthur_common/aggregations/aggregator.py +73 -9
- arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
- arthur_common/aggregations/functions/categorical_count.py +15 -15
- arthur_common/aggregations/functions/confusion_matrix.py +24 -26
- arthur_common/aggregations/functions/inference_count.py +5 -9
- arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
- arthur_common/aggregations/functions/inference_null_count.py +10 -13
- arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
- arthur_common/aggregations/functions/mean_squared_error.py +12 -18
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
- arthur_common/aggregations/functions/numeric_stats.py +13 -15
- arthur_common/aggregations/functions/numeric_sum.py +12 -15
- arthur_common/aggregations/functions/shield_aggregations.py +457 -215
- arthur_common/models/common_schemas.py +214 -0
- arthur_common/models/connectors.py +10 -2
- arthur_common/models/constants.py +24 -0
- arthur_common/models/datasets.py +0 -9
- arthur_common/models/enums.py +177 -0
- arthur_common/models/metric_schemas.py +63 -0
- arthur_common/models/metrics.py +2 -9
- arthur_common/models/request_schemas.py +870 -0
- arthur_common/models/response_schemas.py +785 -0
- arthur_common/models/schema_definitions.py +6 -1
- arthur_common/models/task_job_specs.py +3 -12
- arthur_common/tools/duckdb_data_loader.py +34 -2
- arthur_common/tools/duckdb_utils.py +3 -6
- arthur_common/tools/schema_inferer.py +3 -6
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
- arthur_common-2.4.13.dist-info/RECORD +49 -0
- arthur_common/models/shield.py +0 -642
- arthur_common-2.1.58.dist-info/RECORD +0 -44
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,870 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Any, Dict, List, Optional, Self, Type
|
|
3
|
+
|
|
4
|
+
from fastapi import HTTPException
|
|
5
|
+
from openinference.semconv.trace import OpenInferenceSpanKindValues
|
|
6
|
+
from pydantic import (
|
|
7
|
+
BaseModel,
|
|
8
|
+
ConfigDict,
|
|
9
|
+
Field,
|
|
10
|
+
ValidationInfo,
|
|
11
|
+
field_validator,
|
|
12
|
+
model_validator,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from arthur_common.models.common_schemas import (
|
|
16
|
+
ExamplesConfig,
|
|
17
|
+
KeywordsConfig,
|
|
18
|
+
PIIConfig,
|
|
19
|
+
RegexConfig,
|
|
20
|
+
ToxicityConfig,
|
|
21
|
+
)
|
|
22
|
+
from arthur_common.models.constants import (
|
|
23
|
+
ERROR_PASSWORD_POLICY_NOT_MET,
|
|
24
|
+
GENAI_ENGINE_KEYCLOAK_PASSWORD_LENGTH,
|
|
25
|
+
HALLUCINATION_RULE_NAME,
|
|
26
|
+
NEGATIVE_BLOOD_EXAMPLE,
|
|
27
|
+
)
|
|
28
|
+
from arthur_common.models.enums import (
|
|
29
|
+
APIKeysRolesEnum,
|
|
30
|
+
InferenceFeedbackTarget,
|
|
31
|
+
MetricType,
|
|
32
|
+
PIIEntityTypes,
|
|
33
|
+
RuleScope,
|
|
34
|
+
RuleType,
|
|
35
|
+
StatusCodeEnum,
|
|
36
|
+
ToolClassEnum,
|
|
37
|
+
)
|
|
38
|
+
from arthur_common.models.metric_schemas import RelevanceMetricConfig
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class UpdateRuleRequest(BaseModel):
    """Request body for enabling or disabling an existing rule.

    The rule itself is addressed elsewhere (e.g. the URL path); this body
    carries only the desired enablement state.
    """

    # Desired state: True enables the rule, False disables it.
    enabled: bool = Field(description="Boolean value to enable or disable the rule. ")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Using the latest version from arthur-common
class NewRuleRequest(BaseModel):
    """Request body for creating a new validation rule.

    The ``config`` payload is polymorphic: a ``mode="before"`` validator
    coerces the raw dict into the config class matching ``type``. Two
    ``mode="after"`` validators then enforce which rule types may apply to
    prompt vs. response, and that sensitive-data rules ship examples.
    """

    name: str = Field(description="Name of the rule", examples=["SSN Regex Rule"])
    type: str = Field(
        description="Type of the rule. It can only be one of KeywordRule, RegexRule, "
        "ModelSensitiveDataRule, ModelHallucinationRule, ModelHallucinationRuleV2, PromptInjectionRule, PIIDataRule",
        examples=["RegexRule"],
    )
    apply_to_prompt: bool = Field(
        description="Boolean value to enable or disable the rule for llm prompt",
        examples=[True],
    )
    apply_to_response: bool = Field(
        description="Boolean value to enable or disable the rule for llm response",
        examples=[False],
    )
    # Polymorphic config; concrete type is resolved in set_config_type below.
    config: (
        KeywordsConfig
        | RegexConfig
        | ExamplesConfig
        | ToxicityConfig
        | PIIConfig
        | None
    ) = Field(description="Config of the rule", default=None)

    model_config = ConfigDict(
        json_schema_extra={
            "example1": {
                "summary": "Sensitive Data Example",
                "description": "Sensitive Data Example with its required configuration",
                "value": {
                    "name": "Sensitive Data Rule",
                    "type": "ModelSensitiveDataRule",
                    "apply_to_prompt": True,
                    "apply_to_response": False,
                    "config": {
                        "examples": [
                            {
                                "example": NEGATIVE_BLOOD_EXAMPLE,
                                "result": True,
                            },
                            {
                                "example": "Most of the people have A positive blood group",
                                "result": False,
                            },
                        ],
                        "hint": "specific individual's blood types",
                    },
                },
            },
            "example2": {
                "summary": "Regex Example",
                "description": "Regex Example with its required configuration. Be sure to properly encode requests "
                "using JSON libraries. For example, the regex provided encodes to a different string "
                "when encoded to account for escape characters.",
                "value": {
                    "name": "SSN Regex Rule",
                    "type": "RegexRule",
                    "apply_to_prompt": True,
                    "apply_to_response": True,
                    "config": {
                        "regex_patterns": [
                            "\\d{3}-\\d{2}-\\d{4}",
                            "\\d{5}-\\d{6}-\\d{7}",
                        ],
                    },
                },
            },
            "example3": {
                "summary": "Keywords Rule Example",
                "description": "Keywords Rule Example with its required configuration",
                "value": {
                    "name": "Blocked Keywords Rule",
                    "type": "KeywordRule",
                    "apply_to_prompt": True,
                    "apply_to_response": True,
                    "config": {"keywords": ["Blocked_Keyword_1", "Blocked_Keyword_2"]},
                },
            },
            "example4": {
                "summary": "Prompt Injection Rule Example",
                "description": "Prompt Injection Rule Example, no configuration required",
                "value": {
                    "name": "Prompt Injection Rule",
                    "type": "PromptInjectionRule",
                    "apply_to_prompt": True,
                    "apply_to_response": False,
                },
            },
            "example5": {
                "summary": "Hallucination Rule V1 Example (Deprecated)",
                "description": "Hallucination Rule Example, no configuration required (This rule is deprecated. Use "
                "ModelHallucinationRuleV2 instead.)",
                "value": {
                    "name": HALLUCINATION_RULE_NAME,
                    "type": "ModelHallucinationRule",
                    "apply_to_prompt": False,
                    "apply_to_response": True,
                },
            },
            "example6": {
                "summary": "Hallucination Rule V2 Example",
                "description": "Hallucination Rule Example, no configuration required",
                "value": {
                    "name": HALLUCINATION_RULE_NAME,
                    "type": "ModelHallucinationRuleV2",
                    "apply_to_prompt": False,
                    "apply_to_response": True,
                },
            },
            "example7": {
                "summary": "Hallucination Rule V3 Example (Beta)",
                "description": "Hallucination Rule Example, no configuration required. This rule is in beta and must "
                "be enabled by the system administrator.",
                "value": {
                    "name": HALLUCINATION_RULE_NAME,
                    "type": "ModelHallucinationRuleV3",
                    "apply_to_prompt": False,
                    "apply_to_response": True,
                },
            },
            "example8": {
                "summary": "PII Rule Example",
                "description": f'PII Rule Example, no configuration required. "disabled_pii_entities", '
                f'"confidence_threshold", and "allow_list" accepted. Valid value for '
                f'"confidence_threshold" is 0.0-1.0. Valid values for "disabled_pii_entities" '
                f"are {PIIEntityTypes.to_string()}",
                "value": {
                    "name": "PII Rule",
                    "type": "PIIDataRule",
                    "apply_to_prompt": True,
                    "apply_to_response": True,
                    "config": {
                        "disabled_pii_entities": [
                            "EMAIL_ADDRESS",
                            "PHONE_NUMBER",
                        ],
                        "confidence_threshold": "0.5",
                        "allow_list": ["arthur.ai", "Arthur"],
                    },
                },
            },
            "example9": {
                "summary": "Toxicity Rule Example",
                "description": "Toxicity Rule Example, no configuration required. Threshold accepted",
                "value": {
                    "name": "Toxicity Rule",
                    "type": "ToxicityRule",
                    "apply_to_prompt": True,
                    "apply_to_response": True,
                    "config": {"threshold": 0.5},
                },
            },
        },
    )

    @model_validator(mode="before")
    def set_config_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce the raw ``config`` payload into the class matching ``type``.

        Raises an HTTP 400 when a rule type that requires a config (regex,
        keyword) is submitted without one.
        """
        # mode="before" validators can receive non-dict input; defer to
        # pydantic's own error reporting in that case (consistent with
        # NewMetricRequest.set_config_type).
        if not isinstance(values, dict):
            return values

        config_type_to_class: Dict[str, Type[BaseModel]] = {
            RuleType.REGEX: RegexConfig,
            RuleType.KEYWORD: KeywordsConfig,
            RuleType.TOXICITY: ToxicityConfig,
            RuleType.PII_DATA: PIIConfig,
            RuleType.MODEL_SENSITIVE_DATA: ExamplesConfig,
        }

        # .get() instead of ["type"]: a missing "type" should surface as a
        # pydantic "field required" error, not a KeyError (HTTP 500).
        config_type = values.get("type")
        config_class = config_type_to_class.get(config_type)

        if config_class is not None:
            config_values = values.get("config")
            if config_values is None:
                # Regex/keyword rules are meaningless without a config.
                if config_type in [RuleType.REGEX, RuleType.KEYWORD]:
                    raise HTTPException(
                        status_code=400,
                        detail="This rule must be created with a config parameter",
                    )
                config_values = {}
            if isinstance(config_values, BaseModel):
                config_values = config_values.model_dump()
            values["config"] = config_class(**config_values)
        return values

    @model_validator(mode="after")
    def check_prompt_or_response(self) -> Self:
        """Enforce prompt/response applicability constraints per rule type."""
        if (self.type == RuleType.MODEL_SENSITIVE_DATA) and (
            self.apply_to_response is True
        ):
            raise HTTPException(
                status_code=400,
                detail="ModelSensitiveDataRule can only be enabled for prompt. Please set the 'apply_to_response' "
                "field to false.",
            )
        if (self.type == RuleType.PROMPT_INJECTION) and (
            self.apply_to_response is True
        ):
            raise HTTPException(
                status_code=400,
                detail="PromptInjectionRule can only be enabled for prompt. Please set the 'apply_to_response' field "
                "to false.",
            )
        if (self.type == RuleType.MODEL_HALLUCINATION_V2) and (
            self.apply_to_prompt is True
        ):
            raise HTTPException(
                status_code=400,
                detail="ModelHallucinationRuleV2 can only be enabled for response. Please set the 'apply_to_prompt' "
                "field to false.",
            )
        if (self.apply_to_prompt is False) and (self.apply_to_response is False):
            raise HTTPException(
                status_code=400,
                detail="Rule must be either applied to the prompt or to the response.",
            )

        return self

    @model_validator(mode="after")
    def check_examples_non_null(self) -> Self:
        """Require at least one example for ModelSensitiveDataRule configs."""
        if self.type == RuleType.MODEL_SENSITIVE_DATA:
            config = self.config
            if (
                config is not None
                and isinstance(config, ExamplesConfig)
                and (config.examples is None or len(config.examples) == 0)
            ):
                raise HTTPException(
                    status_code=400,
                    detail="Examples must be provided to onboard a ModelSensitiveDataRule",
                )
        return self
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class SearchTasksRequest(BaseModel):
    """Request body for searching tasks; all filters are optional ANDed criteria."""

    task_ids: Optional[list[str]] = Field(
        description="List of tasks to query for.",
        default=None,
    )
    # Substring match, not exact match.
    task_name: Optional[str] = Field(
        description="Task name substring search string.",
        default=None,
    )
    is_agentic: Optional[bool] = Field(
        description="Filter tasks by agentic status. If not provided, returns both agentic and non-agentic tasks.",
        default=None,
    )
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class SearchRulesRequest(BaseModel):
    """Request body for searching rules; all filters are optional ANDed criteria."""

    rule_ids: Optional[list[str]] = Field(
        description="List of rule IDs to search for.",
        default=None,
    )
    rule_scopes: Optional[list[RuleScope]] = Field(
        description="List of rule scopes to search for.",
        default=None,
    )
    prompt_enabled: Optional[bool] = Field(
        description="Include or exclude prompt-enabled rules.",
        default=None,
    )
    response_enabled: Optional[bool] = Field(
        description="Include or exclude response-enabled rules.",
        default=None,
    )
    rule_types: Optional[list[RuleType]] = Field(
        description="List of rule types to search for.",
        default=None,
    )
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class NewTaskRequest(BaseModel):
    """Request body for creating a new task."""

    # Non-empty task name (min_length enforces at least one character).
    name: str = Field(description="Name of the task.", min_length=1)
    is_agentic: bool = Field(
        description="Whether the task is agentic or not.",
        default=False,
    )
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class NewApiKeyRequest(BaseModel):
    """Request body for creating a new API key."""

    description: Optional[str] = Field(
        description="Description of the API key. Optional.",
        default=None,
    )
    # Mutable default is safe here: pydantic copies field defaults per instance.
    roles: Optional[list[APIKeysRolesEnum]] = Field(
        description=f"Role that will be assigned to API key. Allowed values: {[role for role in APIKeysRolesEnum]}",
        default=[APIKeysRolesEnum.VALIDATION_USER],
    )
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class PromptValidationRequest(BaseModel):
    """Request body for validating a user prompt with GenAI Engine."""

    prompt: str = Field(description="Prompt to be validated by GenAI Engine")
    # context: Optional[str] = Field(
    #     description="Optional data provided as context for the prompt validation. "
    #     "Currently not used"
    # )
    conversation_id: Optional[str] = Field(
        description="The unique conversation ID this prompt belongs to. All prompts and responses from this \
conversation can later be reconstructed with this ID.",
        default=None,
    )
    user_id: Optional[str] = Field(
        description="The user ID this prompt belongs to",
        default=None,
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class ResponseValidationRequest(BaseModel):
    """Request body for validating an LLM response with GenAI Engine."""

    response: str = Field(description="LLM Response to be validated by GenAI Engine")
    context: Optional[str] = Field(
        description="Optional data provided as context for the validation.",
        default=None,
    )
    model_name: Optional[str] = Field(
        description="The model name and version being used for this response (e.g., 'gpt-4', 'gpt-3.5-turbo', 'claude-3-opus', 'gemini-pro').",
        default=None,
    )
    # tokens: Optional[List[str]] = Field(description="optional, not used currently")
    # token_likelihoods: Optional[List[str]] = Field(
    #     description="optional, not used currently"
    # )

    @model_validator(mode="after")
    def check_prompt_or_response(cls, values: Any) -> Any:
        # NOTE(review): the PromptValidationRequest branch can never match on
        # this class — `values` here is always a ResponseValidationRequest.
        # Looks copy-pasted for reuse across both request types; confirm
        # before removing. Also, `prompt`/`response` are required str fields,
        # so the None checks appear to be defensive dead code.
        if isinstance(values, PromptValidationRequest) and values.prompt is None:
            raise ValueError("prompt is required when validating a prompt")
        if isinstance(values, ResponseValidationRequest) and values.response is None:
            raise ValueError("response is required when validating a response")
        return values
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class ChatRequest(BaseModel):
    """Request body for sending a user prompt to the chat endpoint."""

    user_prompt: str = Field(description="Prompt user wants to send to chat.")
    conversation_id: str = Field(description="Conversation ID")
    # Files used for retrieval-augmented chat; required (may be empty list).
    file_ids: List[str] = Field(
        description="list of file IDs to retrieve from during chat.",
    )
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class FeedbackRequest(BaseModel):
    """Request body for attaching user feedback to an inference."""

    # Which part of the inference the feedback applies to.
    target: InferenceFeedbackTarget
    # Numeric feedback score; no range constraint is enforced here.
    score: int
    # Optional free-text explanation for the score (explicitly nullable).
    reason: str | None
    # Optional ID of the user submitting the feedback.
    user_id: str | None = None
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class CreateUserRequest(BaseModel):
    """Request body for creating a new user account.

    NOTE(review): firstName/lastName are camelCase, presumably to mirror the
    Keycloak admin API payload shape — confirm before renaming.
    """

    email: str
    password: str
    # When True, the user must change this password on first login.
    temporary: bool = True
    roles: list[str]
    firstName: str
    lastName: str
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
class PasswordResetRequest(BaseModel):
    """Request body for resetting a user's password."""

    password: str

    @field_validator("password")
    @classmethod
    def password_meets_security(cls, value: str) -> str:
        """Reject passwords that do not satisfy the password policy.

        The policy requires a minimum length plus at least one uppercase
        letter, one lowercase letter, one digit, and one special character.
        Raises ValueError with the shared policy message on any violation.
        """
        specials = '!@#$%^&*()-+?_=,<>/"'
        if len(value) < GENAI_ENGINE_KEYCLOAK_PASSWORD_LENGTH:
            raise ValueError(ERROR_PASSWORD_POLICY_NOT_MET)
        # Each required character class, checked independently.
        requirements = (
            any(ch.isupper() for ch in value),
            any(ch.islower() for ch in value),
            any(ch.isdigit() for ch in value),
            any(ch in specials for ch in value),
        )
        if not all(requirements):
            raise ValueError(ERROR_PASSWORD_POLICY_NOT_MET)
        return value
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
class ChatDefaultTaskRequest(BaseModel):
    """Request body for setting the default task used by the chat feature."""

    # ID of the task to use as the chat default.
    task_id: str
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# Using the latest version from arthur-common
class NewMetricRequest(BaseModel):
    """Request body for creating a new metric on a task.

    A ``mode="before"`` validator checks that ``type`` is a valid MetricType
    and coerces the raw ``config`` dict into RelevanceMetricConfig for the
    metric types that support configuration.
    """

    type: MetricType = Field(
        description="Type of the metric. It can only be one of QueryRelevance, ResponseRelevance, ToolSelection",
        examples=["UserQueryRelevance"],
    )
    name: str = Field(
        description="Name of metric",
        examples=["My User Query Relevance"],
    )
    metric_metadata: str = Field(description="Additional metadata for the metric")
    config: Optional[RelevanceMetricConfig] = Field(
        description="Configuration for the metric. Currently only applies to UserQueryRelevance and ResponseRelevance metric types.",
        default=None,
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example1": {
                "type": "QueryRelevance",
                "name": "My User Query Relevance",
                "metric_metadata": "This is a test metric metadata",
            },
            "example2": {
                "type": "QueryRelevance",
                "name": "My User Query Relevance with Config",
                "metric_metadata": "This is a test metric metadata",
                "config": {"relevance_threshold": 0.8, "use_llm_judge": False},
            },
            "example3": {
                "type": "ResponseRelevance",
                "name": "My Response Relevance",
                "metric_metadata": "This is a test metric metadata",
                "config": {"use_llm_judge": True},
            },
        },
    )

    @model_validator(mode="before")
    def set_config_type(cls, values: dict[str, Any] | None) -> dict[str, Any] | None:
        """Validate metric type and normalize ``config`` before field parsing.

        Raises HTTP 400 for an unknown metric type, for mutually exclusive
        config options, or for a config supplied on an unsupported type.
        """
        # Non-dict input (e.g. already-validated model): let pydantic handle it.
        if not isinstance(values, dict):
            return values

        try:
            metric_type = MetricType(values.get("type", "empty_value"))
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid metric type: {values.get('type', 'empty_value')}. Must be one of {[t.value for t in MetricType]}",
                headers={"full_stacktrace": "false"},
            )

        config_values = values.get("config")

        # Map metric types to their corresponding config classes
        metric_type_to_config = {
            MetricType.QUERY_RELEVANCE: RelevanceMetricConfig,
            MetricType.RESPONSE_RELEVANCE: RelevanceMetricConfig,
            # Add new metric types and their configs here as needed
        }

        config_class = metric_type_to_config.get(metric_type)

        if config_class is not None:
            if config_values is None:
                # Default config when none is provided
                config_values = {"use_llm_judge": True}
            elif isinstance(config_values, dict):
                relevance_threshold = config_values.get("relevance_threshold")
                use_llm_judge = config_values.get("use_llm_judge")

                # Handle mutually exclusive parameters
                if relevance_threshold is not None and use_llm_judge:
                    raise HTTPException(
                        status_code=400,
                        detail="relevance_threshold and use_llm_judge=true are mutually exclusive. Set use_llm_judge=false when using relevance_threshold.",
                        headers={"full_stacktrace": "false"},
                    )

                # If relevance_threshold is set but use_llm_judge isn't, set use_llm_judge to false
                if relevance_threshold is not None and use_llm_judge is None:
                    config_values["use_llm_judge"] = False

                # If neither is set, default to use_llm_judge=True
                # NOTE(review): this branch also fires when the caller
                # explicitly sent use_llm_judge=False with no threshold,
                # silently overriding it to True — that goes beyond the
                # "neither is set" case described above; confirm intended.
                if relevance_threshold is None and (
                    use_llm_judge is None or use_llm_judge == False
                ):
                    config_values["use_llm_judge"] = True

            if isinstance(config_values, BaseModel):
                config_values = config_values.model_dump()

            values["config"] = config_class(**config_values)
        elif config_values is not None:
            # Provide a nice error message listing supported metric types
            supported_types = [t.value for t in metric_type_to_config.keys()]
            raise HTTPException(
                status_code=400,
                detail=f"Config is only supported for {', '.join(supported_types)} metric types",
                headers={"full_stacktrace": "false"},
            )

        return values
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
class UpdateMetricRequest(BaseModel):
    """Request body for enabling or disabling an existing metric."""

    # Desired state: True enables the metric, False disables it.
    enabled: bool = Field(description="Boolean value to enable or disable the metric. ")
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
class SpanQueryRequest(BaseModel):
    """Request schema for querying spans with validation."""

    # At least one task ID is required; everything else narrows the result set.
    task_ids: list[str] = Field(
        ...,
        description="Task IDs to filter on. At least one is required.",
        min_length=1,
    )
    span_types: Optional[list[str]] = Field(
        None,
        description=f"Span types to filter on. Optional. Valid values: {', '.join(sorted([kind.value for kind in OpenInferenceSpanKindValues]))}",
    )
    start_time: Optional[datetime] = Field(
        None,
        description="Inclusive start date in ISO8601 string format.",
    )
    end_time: Optional[datetime] = Field(
        None,
        description="Exclusive end date in ISO8601 string format.",
    )
    session_ids: Optional[list[str]] = Field(
        None,
        description="Session IDs to filter on. Optional.",
    )
    span_ids: Optional[list[str]] = Field(
        None,
        description="Span IDs to filter on. Optional.",
    )
    user_ids: Optional[list[str]] = Field(
        None,
        description="User IDs to filter on. Optional.",
    )
    span_name: Optional[str] = Field(
        None,
        description="Return only results with this span name.",
    )
    span_name_contains: Optional[str] = Field(
        None,
        description="Return only results where span name contains this substring.",
    )
    status_code: Optional[list[StatusCodeEnum]] = Field(
        None,
        description="Status codes to filter on. Optional. Valid values: Ok, Error, Unset",
    )

    @field_validator("span_types")
    @classmethod
    def validate_span_types(cls, value: Optional[list[str]]) -> Optional[list[str]]:
        """Validate that all span_types are valid OpenInference span kinds."""
        # None and [] both pass through unchanged (field is optional).
        if not value:
            return value

        # Set membership: O(1) per lookup instead of scanning a list for
        # every submitted span type.
        valid_span_kinds = {kind.value for kind in OpenInferenceSpanKindValues}
        invalid_types = [st for st in value if st not in valid_span_kinds]

        if invalid_types:
            raise ValueError(
                f"Invalid span_types received: {invalid_types}. "
                f"Valid values: {', '.join(sorted(valid_span_kinds))}",
            )
        return value
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class TraceQueryRequest(BaseModel):
|
|
597
|
+
"""Request schema for querying traces with comprehensive filtering."""
|
|
598
|
+
|
|
599
|
+
# Required
|
|
600
|
+
task_ids: list[str] = Field(
|
|
601
|
+
...,
|
|
602
|
+
description="Task IDs to filter on. At least one is required.",
|
|
603
|
+
min_length=1,
|
|
604
|
+
)
|
|
605
|
+
|
|
606
|
+
# Common optional filters
|
|
607
|
+
trace_ids: Optional[list[str]] = Field(
|
|
608
|
+
None,
|
|
609
|
+
description="Trace IDs to filter on. Optional.",
|
|
610
|
+
)
|
|
611
|
+
start_time: Optional[datetime] = Field(
|
|
612
|
+
None,
|
|
613
|
+
description="Inclusive start date in ISO8601 string format. Use local time (not UTC).",
|
|
614
|
+
)
|
|
615
|
+
end_time: Optional[datetime] = Field(
|
|
616
|
+
None,
|
|
617
|
+
description="Exclusive end date in ISO8601 string format. Use local time (not UTC).",
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
# New trace-level filters
|
|
621
|
+
tool_name: Optional[str] = Field(
|
|
622
|
+
None,
|
|
623
|
+
description="Return only results with this tool name.",
|
|
624
|
+
)
|
|
625
|
+
span_types: Optional[list[str]] = Field(
|
|
626
|
+
None,
|
|
627
|
+
description="Span types to filter on. Optional.",
|
|
628
|
+
)
|
|
629
|
+
span_ids: Optional[list[str]] = Field(
|
|
630
|
+
None,
|
|
631
|
+
description="Span IDs to filter on. Optional.",
|
|
632
|
+
)
|
|
633
|
+
session_ids: Optional[list[str]] = Field(
|
|
634
|
+
None,
|
|
635
|
+
description="Session IDs to filter on. Optional.",
|
|
636
|
+
)
|
|
637
|
+
user_ids: Optional[list[str]] = Field(
|
|
638
|
+
None,
|
|
639
|
+
description="User IDs to filter on. Optional.",
|
|
640
|
+
)
|
|
641
|
+
span_name: Optional[str] = Field(
|
|
642
|
+
None,
|
|
643
|
+
description="Return only results with this span name.",
|
|
644
|
+
)
|
|
645
|
+
span_name_contains: Optional[str] = Field(
|
|
646
|
+
None,
|
|
647
|
+
description="Return only results where span name contains this substring.",
|
|
648
|
+
)
|
|
649
|
+
status_code: Optional[list[StatusCodeEnum]] = Field(
|
|
650
|
+
None,
|
|
651
|
+
description="Status codes to filter on. Optional. Valid values: Ok, Error, Unset",
|
|
652
|
+
)
|
|
653
|
+
annotation_score: Optional[int] = Field(
|
|
654
|
+
None,
|
|
655
|
+
ge=0,
|
|
656
|
+
le=1,
|
|
657
|
+
description="Filter by trace annotation score (0 or 1).",
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
# Query relevance filters
|
|
661
|
+
query_relevance_eq: Optional[float] = Field(
|
|
662
|
+
None,
|
|
663
|
+
ge=0,
|
|
664
|
+
le=1,
|
|
665
|
+
description="Equal to this value.",
|
|
666
|
+
)
|
|
667
|
+
query_relevance_gt: Optional[float] = Field(
|
|
668
|
+
None,
|
|
669
|
+
ge=0,
|
|
670
|
+
le=1,
|
|
671
|
+
description="Greater than this value.",
|
|
672
|
+
)
|
|
673
|
+
query_relevance_gte: Optional[float] = Field(
|
|
674
|
+
None,
|
|
675
|
+
ge=0,
|
|
676
|
+
le=1,
|
|
677
|
+
description="Greater than or equal to this value.",
|
|
678
|
+
)
|
|
679
|
+
query_relevance_lt: Optional[float] = Field(
|
|
680
|
+
None,
|
|
681
|
+
ge=0,
|
|
682
|
+
le=1,
|
|
683
|
+
description="Less than this value.",
|
|
684
|
+
)
|
|
685
|
+
query_relevance_lte: Optional[float] = Field(
|
|
686
|
+
None,
|
|
687
|
+
ge=0,
|
|
688
|
+
le=1,
|
|
689
|
+
description="Less than or equal to this value.",
|
|
690
|
+
)
|
|
691
|
+
|
|
692
|
+
# Response relevance filters
|
|
693
|
+
response_relevance_eq: Optional[float] = Field(
|
|
694
|
+
None,
|
|
695
|
+
ge=0,
|
|
696
|
+
le=1,
|
|
697
|
+
description="Equal to this value.",
|
|
698
|
+
)
|
|
699
|
+
response_relevance_gt: Optional[float] = Field(
|
|
700
|
+
None,
|
|
701
|
+
ge=0,
|
|
702
|
+
le=1,
|
|
703
|
+
description="Greater than this value.",
|
|
704
|
+
)
|
|
705
|
+
response_relevance_gte: Optional[float] = Field(
|
|
706
|
+
None,
|
|
707
|
+
ge=0,
|
|
708
|
+
le=1,
|
|
709
|
+
description="Greater than or equal to this value.",
|
|
710
|
+
)
|
|
711
|
+
response_relevance_lt: Optional[float] = Field(
|
|
712
|
+
None,
|
|
713
|
+
ge=0,
|
|
714
|
+
le=1,
|
|
715
|
+
description="Less than this value.",
|
|
716
|
+
)
|
|
717
|
+
response_relevance_lte: Optional[float] = Field(
|
|
718
|
+
None,
|
|
719
|
+
ge=0,
|
|
720
|
+
le=1,
|
|
721
|
+
description="Less than or equal to this value.",
|
|
722
|
+
)
|
|
723
|
+
|
|
724
|
+
# Tool classification filters
|
|
725
|
+
tool_selection: Optional[ToolClassEnum] = Field(
|
|
726
|
+
None,
|
|
727
|
+
description="Tool selection evaluation result.",
|
|
728
|
+
)
|
|
729
|
+
tool_usage: Optional[ToolClassEnum] = Field(
|
|
730
|
+
None,
|
|
731
|
+
description="Tool usage evaluation result.",
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
# Trace duration filters
|
|
735
|
+
trace_duration_eq: Optional[float] = Field(
|
|
736
|
+
None,
|
|
737
|
+
ge=0,
|
|
738
|
+
description="Duration exactly equal to this value (seconds).",
|
|
739
|
+
)
|
|
740
|
+
trace_duration_gt: Optional[float] = Field(
|
|
741
|
+
None,
|
|
742
|
+
ge=0,
|
|
743
|
+
description="Duration greater than this value (seconds).",
|
|
744
|
+
)
|
|
745
|
+
trace_duration_gte: Optional[float] = Field(
|
|
746
|
+
None,
|
|
747
|
+
ge=0,
|
|
748
|
+
description="Duration greater than or equal to this value (seconds).",
|
|
749
|
+
)
|
|
750
|
+
trace_duration_lt: Optional[float] = Field(
|
|
751
|
+
None,
|
|
752
|
+
ge=0,
|
|
753
|
+
description="Duration less than this value (seconds).",
|
|
754
|
+
)
|
|
755
|
+
trace_duration_lte: Optional[float] = Field(
|
|
756
|
+
None,
|
|
757
|
+
ge=0,
|
|
758
|
+
description="Duration less than or equal to this value (seconds).",
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
@field_validator(
    "query_relevance_eq",
    "query_relevance_gt",
    "query_relevance_gte",
    "query_relevance_lt",
    "query_relevance_lte",
    "response_relevance_eq",
    "response_relevance_gt",
    "response_relevance_gte",
    "response_relevance_lt",
    "response_relevance_lte",
    mode="before",
)
@classmethod
def validate_relevance_scores(
    cls,
    value: Optional[float],
    info: ValidationInfo,
) -> Optional[float]:
    """Reject relevance scores that fall outside the inclusive [0, 1] range."""
    # An unset filter is always acceptable.
    if value is None:
        return value
    # In-range (and only in-range) values pass through unchanged; anything
    # else, including NaN, fails the chained comparison and raises.
    if 0.0 <= value <= 1.0:
        return value
    raise ValueError(
        f"{info.field_name} value must be between 0 and 1 (inclusive)",
    )
|
|
787
|
+
|
|
788
|
+
@field_validator(
    "trace_duration_eq",
    "trace_duration_gt",
    "trace_duration_gte",
    "trace_duration_lt",
    "trace_duration_lte",
    mode="before",
)
@classmethod
def validate_trace_duration(
    cls,
    value: Optional[float],
    info: ValidationInfo,
) -> Optional[float]:
    """Reject negative trace-duration filter values; None passes through."""
    # Only explicitly negative values are rejected (note: this deliberately
    # lets NaN through, matching a strict `< 0` comparison).
    if value is not None and value < 0:
        raise ValueError(
            f"{info.field_name} value must be non-negative (greater than or equal to 0)",
        )
    return value
|
|
809
|
+
|
|
810
|
+
@field_validator("tool_selection", "tool_usage", mode="before")
@classmethod
def validate_tool_classification(cls, value: Any) -> Optional[ToolClassEnum]:
    """Coerce and validate tool classification input.

    Accepts None (filter unset), an existing ToolClassEnum member, or a
    plain integer corresponding to one of the enum's values; anything
    else raises ValueError.
    """
    if value is None:
        return value
    if isinstance(value, ToolClassEnum):
        return value
    # bool is a subclass of int — reject it explicitly so True/False
    # don't silently coerce to the 1/0 classification values.
    if isinstance(value, int) and not isinstance(value, bool):
        # Delegate the range check to the enum itself rather than
        # hard-coding [0, 1, 2], so this validator stays correct if
        # ToolClassEnum ever gains or loses members.
        try:
            return ToolClassEnum(value)
        except ValueError:
            raise ValueError(
                "Tool classification must be 0 (INCORRECT), "
                "1 (CORRECT), or 2 (NA)",
            ) from None
    raise ValueError(
        "Tool classification must be an integer (0, 1, 2) or ToolClassEnum instance",
    )
|
|
830
|
+
|
|
831
|
+
@field_validator("span_types")
@classmethod
def validate_span_types(cls, value: Optional[list[str]]) -> Optional[list[str]]:
    """Validate that all span_types are valid OpenInference span kinds."""
    # None or an empty list means no span-type filtering is requested.
    if not value:
        return value

    # Set membership keeps each lookup O(1) while preserving the same
    # accepted values as the enum.
    valid_kinds = {kind.value for kind in OpenInferenceSpanKindValues}
    unknown = [candidate for candidate in value if candidate not in valid_kinds]

    if unknown:
        raise ValueError(
            f"Invalid span_types received: {unknown}. "
            f"Valid values: {', '.join(sorted(valid_kinds))}",
        )
    return value
|
|
848
|
+
|
|
849
|
+
@model_validator(mode="after")
def validate_filter_combinations(self) -> Self:
    """Validate that filter combinations are logically valid.

    For each filterable metric, an exact-match (``_eq``) filter may not be
    combined with any comparison filter, and ``_gt``/``_gte`` (likewise
    ``_lt``/``_lte``) are mutually exclusive.

    Raises:
        ValueError: if an invalid combination of filters is set.
    """
    # Check mutually exclusive filters for each metric type.
    # NOTE: fields are tested with `is not None` rather than truthiness,
    # because 0 / 0.0 are legitimate filter values (e.g. relevance_eq=0.0)
    # and must still count as "set".
    for prefix in ["query_relevance", "response_relevance", "trace_duration"]:
        eq_field = f"{prefix}_eq"
        comparison_fields = [f"{prefix}_{op}" for op in ["gt", "gte", "lt", "lte"]]

        if getattr(self, eq_field) is not None and any(
            getattr(self, field) is not None for field in comparison_fields
        ):
            raise ValueError(
                f"{eq_field} cannot be combined with other {prefix} comparison operators",
            )

        # Check for incompatible operator combinations.
        if (
            getattr(self, f"{prefix}_gt") is not None
            and getattr(self, f"{prefix}_gte") is not None
        ):
            raise ValueError(f"Cannot combine {prefix}_gt with {prefix}_gte")
        if (
            getattr(self, f"{prefix}_lt") is not None
            and getattr(self, f"{prefix}_lte") is not None
        ):
            raise ValueError(f"Cannot combine {prefix}_lt with {prefix}_lte")

    return self
|