arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arthur_common/aggregations/aggregator.py +73 -9
  2. arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
  3. arthur_common/aggregations/functions/categorical_count.py +15 -15
  4. arthur_common/aggregations/functions/confusion_matrix.py +24 -26
  5. arthur_common/aggregations/functions/inference_count.py +5 -9
  6. arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
  7. arthur_common/aggregations/functions/inference_null_count.py +10 -13
  8. arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
  9. arthur_common/aggregations/functions/mean_squared_error.py +12 -18
  10. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
  11. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
  12. arthur_common/aggregations/functions/numeric_stats.py +13 -15
  13. arthur_common/aggregations/functions/numeric_sum.py +12 -15
  14. arthur_common/aggregations/functions/shield_aggregations.py +457 -215
  15. arthur_common/models/common_schemas.py +214 -0
  16. arthur_common/models/connectors.py +10 -2
  17. arthur_common/models/constants.py +24 -0
  18. arthur_common/models/datasets.py +0 -9
  19. arthur_common/models/enums.py +177 -0
  20. arthur_common/models/metric_schemas.py +63 -0
  21. arthur_common/models/metrics.py +2 -9
  22. arthur_common/models/request_schemas.py +870 -0
  23. arthur_common/models/response_schemas.py +785 -0
  24. arthur_common/models/schema_definitions.py +6 -1
  25. arthur_common/models/task_job_specs.py +3 -12
  26. arthur_common/tools/duckdb_data_loader.py +34 -2
  27. arthur_common/tools/duckdb_utils.py +3 -6
  28. arthur_common/tools/schema_inferer.py +3 -6
  29. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
  30. arthur_common-2.4.13.dist-info/RECORD +49 -0
  31. arthur_common/models/shield.py +0 -642
  32. arthur_common-2.1.58.dist-info/RECORD +0 -44
  33. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
@@ -0,0 +1,214 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
7
+
8
+ from arthur_common.models.constants import (
9
+ DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD,
10
+ DEFAULT_TOXICITY_RULE_THRESHOLD,
11
+ NEGATIVE_BLOOD_EXAMPLE,
12
+ )
13
+ from arthur_common.models.enums import (
14
+ PaginationSortMethod,
15
+ PIIEntityTypes,
16
+ UserPermissionAction,
17
+ UserPermissionResource,
18
+ )
19
+
20
+
21
class AuthUserRole(BaseModel):
    # Role record: `id` may be omitted (defaults to None); `name`,
    # `description` and `composite` are required.
    id: str | None = None
    name: str
    description: str
    composite: bool
26
+
27
+
28
class ExampleConfig(BaseModel):
    # A single example text for a sensitive-data rule together with the
    # boolean `result` recorded for it.
    example: str = Field(description="Custom example for the sensitive data")
    result: bool = Field(
        description="Boolean value representing if the example passes or fails the sensitive "
        "data rule",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {"example": NEGATIVE_BLOOD_EXAMPLE, "result": True},
        },
    )
40
+
41
+
42
class ExamplesConfig(BaseModel):
    # Collection of examples (plus an optional hint) for a sensitive-data rule.
    examples: List[ExampleConfig] = Field(
        description="List of all the examples for Sensitive Data Rule",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "examples": [
                    {"example": NEGATIVE_BLOOD_EXAMPLE, "result": True},
                    {
                        "example": "Most of the people have A positive blood group",
                        "result": False,
                    },
                ],
                "hint": "specific individual's blood type",
            },
        },
    )
    hint: Optional[str] = Field(
        description="Optional. Hint added to describe what Sensitive Data Rule should be checking for",
        default=None,
    )

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict copy of this config.

        Each entry of ``examples`` becomes a dict of that example's attributes,
        and ``hint`` is copied as-is.

        A new dict is built rather than writing into ``self.__dict__``. Writing
        there would replace ``self.examples`` on the live model with plain dicts.
        """
        return {
            **self.__dict__,
            "examples": [dict(ex.__dict__) for ex in self.examples],
            "hint": self.hint,
        }
71
+
72
+
73
class KeywordsConfig(BaseModel):
    # A required list of keyword strings; the schema example below shows two
    # placeholder blocked keywords.
    keywords: List[str] = Field(description="List of Keywords")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "keywords": ["Blocked_Keyword_1", "Blocked_Keyword_2"],
            },
        },
    )
81
+
82
+
83
class LLMTokenConsumption(BaseModel):
    # Prompt and completion token counts that can be summed and accumulated.
    prompt_tokens: int
    completion_tokens: int

    def total_tokens(self) -> int:
        """Return prompt plus completion tokens."""
        prompt, completion = self.prompt_tokens, self.completion_tokens
        return prompt + completion

    def add(self, token_consumption: "LLMTokenConsumption") -> "LLMTokenConsumption":
        """Add ``token_consumption``'s counts into this instance and return ``self``.

        The receiver is mutated in place, which allows chained calls.
        """
        self.prompt_tokens = self.prompt_tokens + token_consumption.prompt_tokens
        self.completion_tokens = (
            self.completion_tokens + token_consumption.completion_tokens
        )
        return self
94
+
95
+
96
class PaginationParameters(BaseModel):
    # Sort direction plus page size / page index. Defaults: descending,
    # 10 items per page, page 0.
    sort: Optional[PaginationSortMethod] = PaginationSortMethod.DESCENDING
    page_size: int = 10
    page: int = 0

    def calculate_total_pages(self, total_items_count: int) -> int:
        """Return how many pages of ``page_size`` items cover ``total_items_count``.

        A partially filled last page still counts as a page. A ``page_size``
        of 0 raises ``ZeroDivisionError``.
        """
        items_per_page = self.page_size
        fractional_pages = total_items_count / items_per_page
        return math.ceil(fractional_pages)
103
+
104
+
105
class PIIConfig(BaseModel):
    # Settings for the PII rule: entity types to disable, a deprecated
    # confidence threshold in [0, 1], and an allow-list of strings.
    disabled_pii_entities: Optional[list[str]] = Field(
        description=f"Optional. List of PII entities to disable. Valid values are: {PIIEntityTypes.to_string()}",
        default=None,
    )

    confidence_threshold: Optional[float] = Field(
        description=f"Optional. Float (0, 1) indicating the level of tolerable PII to consider the rule passed or failed. Min: 0 (less confident) Max: 1 (very confident). Default: {DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD}",
        default=DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD,
        json_schema_extra={"deprecated": True},
    )

    allow_list: Optional[list[str]] = Field(
        description="Optional. List of strings to pass PII validation.",
        default=None,
    )

    @field_validator("disabled_pii_entities")
    @classmethod
    def validate_pii_entities(cls, v: list[str] | None) -> list[str] | None:
        """Validate the list of PII entity types to disable.

        Raises:
            ValueError: if any name is not a ``PIIEntityTypes`` value, or if
                every supported entity type would be disabled.
        """
        if not v:
            return v

        entities_passed = set(v)
        entities_supported = set(PIIEntityTypes.values())
        invalid_entities = entities_passed - entities_supported
        if invalid_entities:
            raise ValueError(
                f"The following values are not valid PII entities: {invalid_entities}",
            )

        # Every passed entity is valid at this point, so equal set sizes means
        # the caller is trying to disable all PII entity types.
        if len(entities_passed) == len(entities_supported):
            raise ValueError(
                "Cannot disable all supported PII entities on PIIDataRule",
            )
        return v

    @field_validator("confidence_threshold")
    @classmethod
    def validate_confidence_threshold(cls, v: float | None) -> float | None:
        """Ensure the threshold, when given, lies in the closed range [0, 1].

        Raises:
            ValueError: if the value is outside [0, 1] or is NaN.
        """
        # The chained comparison is False for NaN, so NaN is rejected too.
        # The former ``(v < 0) | (v > 1)`` check let NaN through.
        if v is not None and not 0 <= v <= 1:
            raise ValueError('"confidence_threshold" must be between 0 and 1')
        return v

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "disabled_pii_entities": ["PERSON", "URL"],
                # A float, matching the field type (was the string "0.5").
                "confidence_threshold": 0.5,
                "allow_list": ["arthur.ai", "Arthur"],
            },
        },
        extra="forbid",
    )
163
+
164
+
165
class RegexConfig(BaseModel):
    # Required list of regex pattern strings; unknown keys are rejected.
    regex_patterns: List[str] = Field(
        description="List of Regex patterns to be used for validation. Be sure to encode requests in JSON and account for escape characters.",
    )

    model_config = ConfigDict(
        extra="forbid",
        json_schema_extra={
            "example": {
                # Raw strings: identical values to "\\d{3}-..." etc.
                "regex_patterns": [r"\d{3}-\d{2}-\d{4}", r"\d{5}-\d{6}-\d{7}"],
            },
        },
    )
178
+
179
+
180
class ToxicityConfig(BaseModel):
    # Toxicity rule settings: a single threshold in [0, 1]; unknown keys are
    # rejected.
    threshold: float = Field(
        default=DEFAULT_TOXICITY_RULE_THRESHOLD,
        description=f"Optional. Float (0, 1) indicating the level of tolerable toxicity to consider the rule passed or failed. Min: 0 (no toxic language) Max: 1 (very toxic language). Default: {DEFAULT_TOXICITY_RULE_THRESHOLD}",
    )

    model_config = ConfigDict(
        extra="forbid",
        json_schema_extra={"example": {"threshold": DEFAULT_TOXICITY_RULE_THRESHOLD}},
    )

    @field_validator("threshold", mode="before")
    @classmethod
    def validate_toxicity_threshold(cls, v: Any) -> Any:
        """Map ``None`` to the default and range-check numeric inputs.

        Runs before pydantic's float coercion, so ``v`` is the raw input.
        Numeric strings are parsed for the range check. Input that cannot be
        compared is returned unchanged so pydantic's own float validation
        reports it as a ValidationError. A raw ``TypeError`` from ``<`` would
        escape pydantic v2 unconverted.

        Raises:
            ValueError: if the value is outside [0, 1] or is NaN.
        """
        if v is None:
            return float(DEFAULT_TOXICITY_RULE_THRESHOLD)

        numeric = v
        if isinstance(v, str):
            try:
                numeric = float(v)
            except ValueError:
                return v  # not a number; pydantic rejects it
        try:
            out_of_range = not 0 <= numeric <= 1  # also True for NaN
        except TypeError:
            return v  # not comparable; pydantic rejects it
        if out_of_range:
            raise ValueError('"threshold" must be between 0 and 1')
        return v
199
+
200
+
201
class UserPermission(BaseModel):
    # An (action, resource) pair, hashable so permissions can be kept in sets.
    action: UserPermissionAction
    resource: UserPermissionResource

    def __hash__(self) -> int:
        return hash((self.action, self.resource))

    def __eq__(self, other: object) -> bool:
        # Compare the fields themselves: equal hashes do not imply equal
        # values, so the former hash-based comparison could report two
        # different permissions as equal on a collision.
        return isinstance(other, UserPermission) and (
            self.action,
            self.resource,
        ) == (other.action, other.resource)
210
+
211
+
212
class VariableTemplateValue(BaseModel):
    # A required name/value pair of strings. Field() without a default keeps
    # both fields required, same as Field(...).
    name: str = Field(description="Name of the variable")
    value: str = Field(description="Value of the variable")
@@ -3,12 +3,12 @@ from pydantic import BaseModel, Field
3
3
 
4
4
  class ConnectorPaginationOptions(BaseModel):
5
5
  page: int = Field(default=1, ge=1)
6
- page_size: int = Field(default=25, gt=0, le=500)
6
+ page_size: int = Field(default=25, ge=1, le=500)
7
7
 
8
8
  @property
9
9
  def page_params(self) -> tuple[int, int]:
10
10
  if self.page is not None:
11
- return self.page, self.page_size
11
+ return (self.page, self.page_size)
12
12
  else:
13
13
  raise ValueError(
14
14
  "Pagination options must be set to return a page and page size",
@@ -38,6 +38,14 @@ ODBC_CONNECTOR_DRIVER_FIELD = "driver"
38
38
  ODBC_CONNECTOR_TABLE_NAME_FIELD = "table_name"
39
39
  ODBC_CONNECTOR_DIALECT_FIELD = "dialect"
40
40
 
41
+ # Snowflake connector constants
42
+ SNOWFLAKE_CONNECTOR_ACCOUNT_FIELD = "account"
43
+ SNOWFLAKE_CONNECTOR_SCHEMA_FIELD = "schema"
44
+ SNOWFLAKE_CONNECTOR_WAREHOUSE_FIELD = "warehouse"
45
+ SNOWFLAKE_CONNECTOR_ROLE_FIELD = "role"
46
+ SNOWFLAKE_CONNECTOR_AUTHENTICATOR_FIELD = "authenticator"
47
+ SNOWFLAKE_CONNECTOR_PRIVATE_KEY_FIELD = "private_key"
48
+ SNOWFLAKE_CONNECTOR_PRIVATE_KEY_PASSPHRASE_FIELD = "private_key_passphrase"
41
49
 
42
50
  # dataset (connector type dependent) constants
43
51
  SHIELD_DATASET_TASK_ID_FIELD = "task_id"
@@ -0,0 +1,24 @@
1
# Role names used for RBAC.
CHAT_USER: str = "CHAT-USER"
ORG_ADMIN: str = "ORG-ADMIN"
TASK_ADMIN: str = "TASK-ADMIN"
DEFAULT_RULE_ADMIN: str = "DEFAULT-RULE-ADMIN"
VALIDATION_USER: str = "VALIDATION-USER"
ORG_AUDITOR: str = "ORG-AUDITOR"
ADMIN_KEY: str = "ADMIN-KEY"

# Legacy Keycloak role name -> current role name.
LEGACY_KEYCLOAK_ROLES: dict[str, str] = {"genai_engine_admin_user": TASK_ADMIN}

# The Keycloak policy string and the user-facing error message are both built
# from the same length constant; keep the remaining requirements in sync too.
GENAI_ENGINE_KEYCLOAK_PASSWORD_LENGTH: int = 12
GENAI_ENGINE_KEYCLOAK_PASSWORD_POLICY: str = (
    f"length({GENAI_ENGINE_KEYCLOAK_PASSWORD_LENGTH}) and specialChars(1) "
    "and upperCase(1) and lowerCase(1)"
)
ERROR_PASSWORD_POLICY_NOT_MET: str = (
    f"Password should be at least {GENAI_ENGINE_KEYCLOAK_PASSWORD_LENGTH} characters "
    "and contain at least one special character, lowercase character, "
    "and uppercase character."
)
ERROR_DEFAULT_METRICS_ENGINE: str = "This metric could not be evaluated"

# Miscellaneous defaults and names.
DEFAULT_TOXICITY_RULE_THRESHOLD: float = 0.5
DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD: float = 0  # int literal kept: renders as "0"
NEGATIVE_BLOOD_EXAMPLE: str = "John has O negative blood group"
HALLUCINATION_RULE_NAME: str = "Hallucination Rule"
@@ -1,15 +1,6 @@
1
1
  from enum import Enum
2
2
 
3
3
 
4
- class ModelProblemType(str, Enum):
5
- REGRESSION = "regression"
6
- BINARY_CLASSIFICATION = "binary_classification"
7
- ARTHUR_SHIELD = "arthur_shield"
8
- CUSTOM = "custom"
9
- MULTICLASS_CLASSIFICATION = "multiclass_classification"
10
- AGENTIC_TRACE = "agentic_trace"
11
-
12
-
13
4
  class DatasetFileType(str, Enum):
14
5
  JSON = "json"
15
6
  CSV = "csv"
@@ -0,0 +1,177 @@
1
+ from enum import IntEnum, StrEnum
2
+
3
+ from arthur_common.models.constants import (
4
+ DEFAULT_RULE_ADMIN,
5
+ ORG_ADMIN,
6
+ ORG_AUDITOR,
7
+ TASK_ADMIN,
8
+ VALIDATION_USER,
9
+ )
10
+
11
+
12
+ class BaseEnum(StrEnum):
13
+ @classmethod
14
+ def values(self) -> list[str]:
15
+ values: list[str] = [e for e in self]
16
+ return values
17
+
18
+ def __str__(self) -> str:
19
+ return str(self.value)
20
+
21
+
22
class APIKeysRolesEnum(BaseEnum):
    # Each member's value is the same-named role constant imported from
    # arthur_common.models.constants.
    DEFAULT_RULE_ADMIN = DEFAULT_RULE_ADMIN
    TASK_ADMIN = TASK_ADMIN
    VALIDATION_USER = VALIDATION_USER
    ORG_AUDITOR = ORG_AUDITOR
    ORG_ADMIN = ORG_ADMIN
28
+
29
+
30
class InferenceFeedbackTarget(BaseEnum):
    # Snake-case target names for inference feedback.
    CONTEXT = "context"
    RESPONSE_RESULTS = "response_results"
    PROMPT_RESULTS = "prompt_results"
34
+
35
+
36
class MetricType(BaseEnum):
    # Metric identifiers, stored as PascalCase strings.
    QUERY_RELEVANCE = "QueryRelevance"
    RESPONSE_RELEVANCE = "ResponseRelevance"
    TOOL_SELECTION = "ToolSelection"
40
+
41
+
42
class ModelProblemType(BaseEnum):
    # Snake-case model problem types (moved here from models/datasets.py).
    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"
    ARTHUR_SHIELD = "arthur_shield"
    CUSTOM = "custom"
    MULTICLASS_CLASSIFICATION = "multiclass_classification"
    AGENTIC_TRACE = "agentic_trace"
49
+
50
+
51
class SnowflakeConnectorAuthenticatorMethods(BaseEnum):
    # Authenticator choices for the Snowflake connector: key pair or password.
    SNOWFLAKE_KEY_PAIR = "snowflake_key_pair"
    SNOWFLAKE_PASSWORD = "snowflake_password"
54
+
55
+
56
# Mirrors the arthur-engine definition (str + enum inheritance).
# The values are not arbitrary: they are Presidio entity type names, see
# https://microsoft.github.io/presidio/supported_entities/
class PIIEntityTypes(BaseEnum):
    CREDIT_CARD = "CREDIT_CARD"
    CRYPTO = "CRYPTO"
    DATE_TIME = "DATE_TIME"
    EMAIL_ADDRESS = "EMAIL_ADDRESS"
    IBAN_CODE = "IBAN_CODE"
    IP_ADDRESS = "IP_ADDRESS"
    NRP = "NRP"
    LOCATION = "LOCATION"
    PERSON = "PERSON"
    PHONE_NUMBER = "PHONE_NUMBER"
    MEDICAL_LICENSE = "MEDICAL_LICENSE"
    URL = "URL"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    US_DRIVER_LICENSE = "US_DRIVER_LICENSE"
    US_ITIN = "US_ITIN"
    US_PASSPORT = "US_PASSPORT"
    US_SSN = "US_SSN"

    @classmethod
    def to_string(cls) -> str:
        """Return every entity value joined by commas (no spaces), in definition order."""
        entity_values = [entity.value for entity in cls]
        return ",".join(entity_values)
80
+
81
+
82
class PaginationSortMethod(BaseEnum):
    # Sort direction tokens: "asc" / "desc".
    ASCENDING = "asc"
    DESCENDING = "desc"
85
+
86
+
87
class RuleResultEnum(BaseEnum):
    # Rule outcome labels (human-readable, title case).
    PASS = "Pass"
    FAIL = "Fail"
    SKIPPED = "Skipped"
    UNAVAILABLE = "Unavailable"
    PARTIALLY_UNAVAILABLE = "Partially Unavailable"
    MODEL_NOT_AVAILABLE = "Model Not Available"
94
+
95
+
96
class RuleScope(BaseEnum):
    # Rule scope: "default" or "task".
    DEFAULT = "default"
    TASK = "task"
99
+
100
+
101
class RuleType(BaseEnum):
    # Rule type identifiers, stored as PascalCase class-style names.
    KEYWORD = "KeywordRule"
    MODEL_HALLUCINATION_V2 = "ModelHallucinationRuleV2"
    MODEL_SENSITIVE_DATA = "ModelSensitiveDataRule"
    PII_DATA = "PIIDataRule"
    PROMPT_INJECTION = "PromptInjectionRule"
    REGEX = "RegexRule"
    TOXICITY = "ToxicityRule"
109
+
110
+
111
class TaskType(BaseEnum):
    # Task kinds: "traditional" or "agentic".
    TRADITIONAL = "traditional"
    AGENTIC = "agentic"
114
+
115
+
116
class TokenUsageScope(BaseEnum):
    # Token-usage grouping keys: by rule type or by task.
    RULE_TYPE = "rule_type"
    TASK = "task"
119
+
120
+
121
class ToolClassEnum(IntEnum):
    # Integer-coded classes; ``str(member)`` yields the bare number, e.g. "1".
    INCORRECT = 0
    CORRECT = 1
    NA = 2

    def __str__(self) -> str:
        return f"{self.value:d}"
128
+
129
+
130
class ToxicityViolationType(BaseEnum):
    # Snake-case toxicity violation categories, including "benign" and "unknown".
    BENIGN = "benign"
    HARMFUL_REQUEST = "harmful_request"
    TOXIC_CONTENT = "toxic_content"
    PROFANITY = "profanity"
    UNKNOWN = "unknown"
136
+
137
+
138
# When adding values here, also update permission_mappings.py in arthur-engine.
class UserPermissionAction(BaseEnum):
    CREATE = "create"
    READ = "read"
142
+
143
+
144
# When adding values here, also update permission_mappings.py in arthur-engine.
class UserPermissionResource(BaseEnum):
    PROMPTS = "prompts"
    RESPONSES = "responses"
    RULES = "rules"
    TASKS = "tasks"
150
+
151
+
152
class ComparisonOperatorEnum(BaseEnum):
    # Short comparison-operator tokens (eq, gt, gte, lt, lte).
    EQUAL = "eq"
    GREATER_THAN = "gt"
    GREATER_THAN_OR_EQUAL = "gte"
    LESS_THAN = "lt"
    LESS_THAN_OR_EQUAL = "lte"
158
+
159
+
160
class StatusCodeEnum(BaseEnum):
    # Status code labels, capitalised: "Ok", "Error", "Unset".
    OK = "Ok"
    ERROR = "Error"
    UNSET = "Unset"
164
+
165
+
166
class ContinuousEvalRunStatus(BaseEnum):
    # Lower-case run status labels for continuous evals.
    PENDING = "pending"
    PASSED = "passed"
    RUNNING = "running"
    FAILED = "failed"
    SKIPPED = "skipped"
    ERROR = "error"
173
+
174
+
175
class AgenticAnnotationType(BaseEnum):
    # Annotation source: "human" or "continuous_eval".
    HUMAN = "human"
    CONTINUOUS_EVAL = "continuous_eval"
@@ -0,0 +1,63 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class RelevanceMetricConfig(BaseModel):
    """Configuration for relevance metrics including QueryRelevance and ResponseRelevance"""

    # Optional numeric threshold; defaults to None.
    relevance_threshold: Optional[float] = Field(
        description="Threshold for determining relevance when not using LLM judge",
        default=None,
    )
    # Defaults to True.
    use_llm_judge: bool = Field(
        description="Whether to use LLM as a judge for relevance scoring",
        default=True,
    )
17
+
18
+
19
class RelevanceMetric(BaseModel):
    # Shared relevance result fields; every field is optional and defaults to None.
    bert_f_score: float | None = None
    reranker_relevance_score: float | None = None
    llm_relevance_score: float | None = None
    reason: str | None = None
    refinement: str | None = None
25
+
26
+
27
class QueryRelevanceMetric(RelevanceMetric):
    """Inherits from RelevanceMetric. This class is left empty so that the openapi response schema remains the same as before, but we have a single source of truth for the relevance metric details."""

    # No fields of its own: everything comes from RelevanceMetric.
29
+
30
+
31
class ResponseRelevanceMetric(RelevanceMetric):
    """Inherits from RelevanceMetric. This class is left empty so that the openapi response schema remains the same as before, but we have a single source of truth for the relevance metric details."""

    # No fields of its own: everything comes from RelevanceMetric.
33
+
34
+
35
class MetricRequest(BaseModel):
    # Inputs for metric computation: optional system prompt, user query and
    # response, plus a list of context/message dicts (empty by default).
    system_prompt: Optional[str] = Field(
        description="System prompt to be used by GenAI Engine for computing metrics.",
        default=None,
    )
    user_query: Optional[str] = Field(
        description="User query to be used by GenAI Engine for computing metrics.",
        default=None,
    )
    context: List[Dict[str, Any]] = Field(
        description="Conversation history and additional context to be used by GenAI Engine for computing metrics.",
        default_factory=list,
        # Each entry of ``examples`` must be a complete value of this field.
        # The field is a list of message dicts, so the sample conversation is
        # ONE example (wrapped in an outer list). Previously each message was
        # listed as a separate example, i.e. a bare dict that does not match
        # the field type.
        examples=[
            [
                {"role": "user", "value": "What is the weather in Tokyo?"},
                {"role": "assistant", "value": "WeatherTool", "args": {"city": "Tokyo"}},
                {
                    "role": "tool",
                    "value": '[{"name": "WeatherTool", "result": {"temperature": "20°C", "humidity": "50%", "condition": "sunny"}}]',
                },
                {
                    "role": "assistant",
                    "value": "The weather in Tokyo is sunny and the temperature is 20°C.",
                },
            ],
        ],
    )
    response: Optional[str] = Field(
        description="Response to be used by GenAI Engine for computing metrics.",
        default=None,
    )
@@ -7,7 +7,7 @@ from uuid import UUID
7
7
  from pydantic import BaseModel, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
9
9
 
10
- from arthur_common.models.datasets import ModelProblemType
10
+ from arthur_common.models.enums import ModelProblemType
11
11
  from arthur_common.models.schema_definitions import (
12
12
  DType,
13
13
  SchemaTypeUnion,
@@ -195,7 +195,7 @@ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSc
195
195
 
196
196
  class MetricsColumnListParameterSchema(
197
197
  MetricsParameterSchema,
198
- BaseColumnParameterSchema,
198
+ BaseColumnBaseParameterSchema,
199
199
  ):
200
200
  # list column parameter schema specific to default metrics
201
201
  parameter_type: Literal["column_list"] = "column_list"
@@ -295,10 +295,3 @@ class ReportedCustomAggregation(BaseReportedAggregation):
295
295
  dimension_columns: list[str] = Field(
296
296
  description="Name of any dimension columns returned from the SQL query. Max length is 1.",
297
297
  )
298
-
299
- @field_validator("dimension_columns")
300
- @classmethod
301
- def validate_dimension_columns_length(cls, v: list[str]) -> str:
302
- if len(v) > 1:
303
- raise ValueError("Only one dimension column can be specified.")
304
- return v