arthur-common 1.0.1 (arthur_common-1.0.1-py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- arthur_common/__init__.py +0 -0
- arthur_common/__version__.py +1 -0
- arthur_common/aggregations/__init__.py +2 -0
- arthur_common/aggregations/aggregator.py +214 -0
- arthur_common/aggregations/functions/README.md +26 -0
- arthur_common/aggregations/functions/__init__.py +25 -0
- arthur_common/aggregations/functions/categorical_count.py +89 -0
- arthur_common/aggregations/functions/confusion_matrix.py +412 -0
- arthur_common/aggregations/functions/inference_count.py +69 -0
- arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
- arthur_common/aggregations/functions/inference_null_count.py +82 -0
- arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
- arthur_common/aggregations/functions/mean_squared_error.py +110 -0
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
- arthur_common/aggregations/functions/numeric_stats.py +90 -0
- arthur_common/aggregations/functions/numeric_sum.py +87 -0
- arthur_common/aggregations/functions/py.typed +0 -0
- arthur_common/aggregations/functions/shield_aggregations.py +752 -0
- arthur_common/aggregations/py.typed +0 -0
- arthur_common/models/__init__.py +0 -0
- arthur_common/models/connectors.py +41 -0
- arthur_common/models/datasets.py +22 -0
- arthur_common/models/metrics.py +227 -0
- arthur_common/models/py.typed +0 -0
- arthur_common/models/schema_definitions.py +420 -0
- arthur_common/models/shield.py +504 -0
- arthur_common/models/task_job_specs.py +78 -0
- arthur_common/py.typed +0 -0
- arthur_common/tools/__init__.py +0 -0
- arthur_common/tools/aggregation_analyzer.py +243 -0
- arthur_common/tools/aggregation_loader.py +59 -0
- arthur_common/tools/duckdb_data_loader.py +329 -0
- arthur_common/tools/functions.py +46 -0
- arthur_common/tools/py.typed +0 -0
- arthur_common/tools/schema_inferer.py +104 -0
- arthur_common/tools/time_utils.py +33 -0
- arthur_common-1.0.1.dist-info/METADATA +74 -0
- arthur_common-1.0.1.dist-info/RECORD +40 -0
- arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/models/schema_definitions.py (new file)

@@ -0,0 +1,420 @@

```python
from __future__ import annotations

from enum import Enum
from typing import Optional, Self, Union
from uuid import UUID, uuid4

from arthur_common.models.datasets import ModelProblemType
from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator


class ScopeSchemaTag(str, Enum):
    LLM_CONTEXT = "llm_context"
    LLM_PROMPT = "llm_prompt"
    LLM_RESPONSE = "llm_response"
    PRIMARY_TIMESTAMP = "primary_timestamp"
    CATEGORICAL = "categorical"
    CONTINUOUS = "continuous"
    PREDICTION = "prediction"
    GROUND_TRUTH = "ground_truth"


class DType(str, Enum):
    UNDEFINED = "undefined"
    INT = "int"
    FLOAT = "float"
    BOOL = "bool"
    STRING = "str"
    UUID = "uuid"
    TIMESTAMP = "timestamp"
    DATE = "date"
    JSON = "json"
```
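Both enums subclass `str`, so members compare equal to their raw string values and serialize as plain strings. A minimal sketch (assuming the wheel is installed):

```python
# Both enums subclass str, so members compare equal to their raw values.
from arthur_common.models.schema_definitions import DType, ScopeSchemaTag

assert ScopeSchemaTag.LLM_PROMPT == "llm_prompt"
assert DType.STRING == "str"                   # note: the member's value is "str"
assert DType("timestamp") is DType.TIMESTAMP   # lookup by value
```

The annotation models that describe metric parameters come next: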
```python
class MetricParameterAnnotation(BaseModel):
    optional: bool = Field(
        False,
        description="Boolean denoting if the parameter is optional.",
    )
    friendly_name: str = Field(
        description="User facing name of the parameter.",
    )
    description: str = Field(
        description="Description of the parameter.",
    )


class MetricDatasetParameterAnnotation(MetricParameterAnnotation):
    model_problem_type: Optional[ModelProblemType] = Field(
        default=None,
        description="Model problem type that is applicable to this parameter.",
    )


class MetricLiteralParameterAnnotation(MetricParameterAnnotation):
    parameter_dtype: DType = Field(description="Data type of the parameter.")


class MetricColumnParameterAnnotation(MetricParameterAnnotation):
    tag_hints: list[ScopeSchemaTag] = Field(
        [],
        description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
    )
    source_dataset_parameter_key: str = Field(
        description="Name of the parameter that provides the dataset to be used for this column.",
    )
    allowed_column_types: Optional[list[SchemaTypeUnion]] = Field(
        default=None,
        description="List of column types applicable to this parameter",
    )
    allow_any_column_type: bool = Field(
        False,
        description="Indicates if this metric parameter can accept any column type.",
    )

    @model_validator(mode="after")
    def column_type_combination_validator(self) -> Self:
        if self.allowed_column_types and self.allow_any_column_type:
            raise ValueError(
                "Parameter cannot allow any column while also explicitly listing applicable ones.",
            )
        return self


class MetricMultipleColumnParameterAnnotation(MetricColumnParameterAnnotation):
    pass


MetricsParameterAnnotationUnion = (
    MetricDatasetParameterAnnotation
    | MetricLiteralParameterAnnotation
    | MetricColumnParameterAnnotation
    | MetricMultipleColumnParameterAnnotation
)
```
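The `after`-mode model validator rejects contradictory column-type settings at construction time; Pydantic surfaces the raised `ValueError` as a `ValidationError`. A minimal sketch (names from the diff above; the field values are illustrative):

```python
# The ValueError raised in the "after" validator surfaces as a pydantic
# ValidationError when both settings are supplied together.
from pydantic import ValidationError

from arthur_common.models.schema_definitions import (
    DType,
    MetricColumnParameterAnnotation,
    ScalarType,
)

try:
    MetricColumnParameterAnnotation(
        friendly_name="Prediction column",
        description="Column holding the model's predictions.",
        source_dataset_parameter_key="dataset",
        allowed_column_types=[ScalarType(dtype=DType.FLOAT)],
        allow_any_column_type=True,  # contradicts the explicit list above
    )
except ValidationError as exc:
    print(exc)  # "Parameter cannot allow any column while also explicitly listing applicable ones."
```

The structural type nodes and their dataset-aware variants follow: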
```python
class Type(BaseModel):
    # There's bound to be something in common here eventually
    pass


class ScalarType(Type):
    dtype: DType

    def __hash__(self) -> int:
        return hash(self.dtype)


class ObjectType(Type):
    object: dict[str, SchemaTypeUnion]

    def __getitem__(self, key: str) -> SchemaTypeUnion:
        return self.object[key]

    def __hash__(self) -> int:
        # Combine the hash of all dictionary values
        combined_hash = 0
        for name, col in self.object.items():
            combined_hash ^= hash(name)
            combined_hash ^= hash(col)
        return combined_hash


class ListType(Type):
    items: SchemaTypeUnion

    def __hash__(self) -> int:
        return hash(self.items)


class DatasetSchemaType(Type):
    tag_hints: list[ScopeSchemaTag] = []
    nullable: Optional[bool] = True
    id: UUID = Field(default_factory=uuid4, description="Unique ID of the schema node.")


class DatasetScalarType(ScalarType, DatasetSchemaType):
    def __hash__(self) -> int:
        return hash(self.dtype)

    def to_base_type(self) -> ScalarType:
        return ScalarType(dtype=self.dtype)


class DatasetObjectType(DatasetSchemaType):
    object: dict[str, DatasetSchemaTypeUnion]

    def __getitem__(self, key: str) -> DatasetSchemaTypeUnion:
        return self.object[key]

    def __hash__(self) -> int:
        # Combine the hash of all dictionary values
        combined_hash = 0
        for name, col in self.object.items():
            combined_hash ^= hash(name)
            combined_hash ^= hash(col)
        return combined_hash

    def to_base_type(self) -> ObjectType:
        return ObjectType(object={k: v.to_base_type() for k, v in self.object.items()})


class DatasetListType(DatasetSchemaType):
    items: DatasetSchemaTypeUnion

    def __hash__(self) -> int:
        return hash(self.items)

    def to_base_type(self) -> ListType:
        return ListType(items=self.items.to_base_type())
```
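The `Dataset*` variants add per-node metadata (a UUID and tag hints) on top of the plain structural types, and `to_base_type()` strips that metadata back off. Because the `__hash__` implementations only consider structure, a dataset node and its base form hash alike. A minimal sketch:

```python
# to_base_type() strips node IDs and tag hints down to the structural types.
from arthur_common.models.schema_definitions import (
    DatasetListType,
    DatasetScalarType,
    DType,
    ListType,
    ScalarType,
)

nested = DatasetListType(items=DatasetScalarType(dtype=DType.STRING))
base = nested.to_base_type()

assert base == ListType(items=ScalarType(dtype=DType.STRING))
assert hash(nested) == hash(base)  # hashing is structural and ignores IDs
```

The file then defines the column and schema container models: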
```python
class DatasetColumn(BaseModel):
    id: UUID = Field(default_factory=uuid4, description="Unique ID of the column.")
    source_name: str
    definition: DatasetSchemaTypeUnion

    def __hash__(self) -> int:
        combined_hash = 0
        combined_hash ^= hash(self.source_name)
        combined_hash ^= hash(self.definition)
        return combined_hash


class PutDatasetSchema(BaseModel):
    alias_mask: dict[UUID, str]
    columns: list[DatasetColumn]

    def regenerate_ids(self) -> PutDatasetSchema:
        new_columns = []
        new_alias_mask = {}
        for column in self.columns:
            new_id = uuid4()
            new_columns.append(
                DatasetColumn(
                    id=new_id,
                    source_name=column.source_name,
                    definition=self._regenerate_definition_ids(column.definition),
                ),
            )
            if column.id in self.alias_mask:
                new_alias_mask[new_id] = self.alias_mask[column.id]

        self.columns = new_columns
        self.alias_mask = new_alias_mask
        return self

    @staticmethod
    def _regenerate_definition_ids(
        definition: DatasetSchemaTypeUnion,
    ) -> DatasetSchemaTypeUnion:
        new_def = definition.model_copy(deep=True)
        new_def.id = uuid4()

        if isinstance(new_def, DatasetObjectType):
            new_def.object = {
                k: PutDatasetSchema._regenerate_definition_ids(v)
                for k, v in new_def.object.items()
            }
        elif isinstance(new_def, DatasetListType):
            new_def.items = PutDatasetSchema._regenerate_definition_ids(new_def.items)

        return new_def


# This needs to be a separate model from PutDatasetSchema because the generated
# field is not consumed correctly by the client generator.
# Issue tracked externally here: https://github.com/OpenAPITools/openapi-generator/issues/4190
class DatasetSchema(PutDatasetSchema):
    @computed_field  # type: ignore[prop-decorator]
    @property
    def column_names(self) -> dict[UUID, str]:
        col_names = {column.id: column.source_name for column in self.columns}
        col_names.update(self.alias_mask)
        return col_names

    model_config = ConfigDict(
        json_schema_mode_override="serialization",
    )


SchemaTypeUnion = Union[ScalarType, ObjectType, ListType]
DatasetSchemaTypeUnion = Union[DatasetScalarType, DatasetObjectType, DatasetListType]
```
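`regenerate_ids()` mints a fresh UUID for every column and every schema node while re-keying `alias_mask`, and the `column_names` computed field resolves display names with aliases taking precedence over source names. A minimal sketch (the column and alias values are illustrative):

```python
# regenerate_ids() replaces all IDs but keeps aliases attached to their columns.
from arthur_common.models.schema_definitions import (
    DatasetColumn,
    DatasetScalarType,
    DatasetSchema,
    DType,
)

col = DatasetColumn(
    source_name="ts",
    definition=DatasetScalarType(dtype=DType.TIMESTAMP),
)
schema = DatasetSchema(alias_mask={col.id: "event_timestamp"}, columns=[col])

old_id = col.id
schema.regenerate_ids()
new_id = schema.columns[0].id

assert new_id != old_id                                   # fresh column ID
assert schema.alias_mask == {new_id: "event_timestamp"}   # alias follows the column
assert schema.column_names[new_id] == "event_timestamp"   # alias masks "ts"
```

The rest of the file provides factory helpers and the Shield schemas built from them: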
```python
from uuid import uuid4


def create_dataset_scalar_type(dtype: DType) -> DatasetScalarType:
    return DatasetScalarType(id=uuid4(), dtype=dtype)


def create_dataset_object_type(
    object_dict: dict[str, DatasetSchemaTypeUnion],
) -> DatasetObjectType:
    return DatasetObjectType(id=uuid4(), object={k: v for k, v in object_dict.items()})


def create_dataset_list_type(items: DatasetSchemaTypeUnion) -> DatasetListType:
    return DatasetListType(id=uuid4(), items=items)


def create_shield_rule_results_schema() -> DatasetListType:
    return create_dataset_list_type(
        create_dataset_object_type(
            {
                "id": create_dataset_scalar_type(DType.UUID),
                "name": create_dataset_scalar_type(DType.STRING),
                "rule_type": create_dataset_scalar_type(DType.STRING),
                "scope": create_dataset_scalar_type(DType.STRING),
                "result": create_dataset_scalar_type(DType.STRING),
                "latency_ms": create_dataset_scalar_type(DType.INT),
                "details": create_dataset_object_type(
                    {
                        "message": create_dataset_scalar_type(DType.STRING),
                        "claims": create_dataset_list_type(
                            create_dataset_object_type(
                                {
                                    "claim": create_dataset_scalar_type(DType.STRING),
                                    "valid": create_dataset_scalar_type(DType.BOOL),
                                    "reason": create_dataset_scalar_type(DType.STRING),
                                },
                            ),
                        ),
                        "pii_entities": create_dataset_list_type(
                            create_dataset_object_type(
                                {
                                    "entity": create_dataset_scalar_type(DType.STRING),
                                    "span": create_dataset_scalar_type(DType.STRING),
                                    "confidence": create_dataset_scalar_type(
                                        DType.FLOAT,
                                    ),
                                },
                            ),
                        ),
                        "toxicity_score": create_dataset_scalar_type(DType.FLOAT),
                        "regex_matches": create_dataset_list_type(
                            create_dataset_object_type(
                                {
                                    "matching_text": create_dataset_scalar_type(
                                        DType.STRING,
                                    ),
                                    "pattern": create_dataset_scalar_type(DType.STRING),
                                },
                            ),
                        ),
                        "keyword_matches": create_dataset_list_type(
                            create_dataset_object_type(
                                {
                                    "keyword": create_dataset_scalar_type(DType.STRING),
                                },
                            ),
                        ),
                    },
                ),
            },
        ),
    )


def create_shield_prompt_schema() -> DatasetObjectType:
    return create_dataset_object_type(
        {
            "id": create_dataset_scalar_type(DType.UUID),
            "inference_id": create_dataset_scalar_type(DType.UUID),
            "result": create_dataset_scalar_type(DType.STRING),
            "created_at": create_dataset_scalar_type(DType.INT),
            "updated_at": create_dataset_scalar_type(DType.INT),
            "message": create_dataset_scalar_type(DType.STRING),
            "prompt_rule_results": create_shield_rule_results_schema(),
            "tokens": create_dataset_scalar_type(DType.INT),
        },
    )


def create_shield_response_schema() -> DatasetObjectType:
    return create_dataset_object_type(
        {
            "id": create_dataset_scalar_type(DType.UUID),
            "inference_id": create_dataset_scalar_type(DType.UUID),
            "result": create_dataset_scalar_type(DType.STRING),
            "created_at": create_dataset_scalar_type(DType.INT),
            "updated_at": create_dataset_scalar_type(DType.INT),
            "message": create_dataset_scalar_type(DType.STRING),
            "context": create_dataset_scalar_type(DType.STRING),
            "response_rule_results": create_shield_rule_results_schema(),
            "tokens": create_dataset_scalar_type(DType.INT),
        },
    )


def create_shield_inference_feedback_schema() -> DatasetListType:
    return create_dataset_list_type(
        create_dataset_object_type(
            {
                "id": create_dataset_scalar_type(DType.UUID),
                "inference_id": create_dataset_scalar_type(DType.UUID),
                "target": create_dataset_scalar_type(DType.STRING),
                "score": create_dataset_scalar_type(DType.FLOAT),
                "reason": create_dataset_scalar_type(DType.STRING),
                "user_id": create_dataset_scalar_type(DType.UUID),
                "created_at": create_dataset_scalar_type(DType.TIMESTAMP),
                "updated_at": create_dataset_scalar_type(DType.TIMESTAMP),
            },
        ),
    )
```
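The `create_dataset_*` helpers keep deeply nested schema definitions declarative, and `__getitem__` on the object types allows walking the result by key. A minimal sketch with a made-up two-field schema:

```python
# Composing the factory helpers and traversing the result by key.
from arthur_common.models.schema_definitions import (
    DType,
    create_dataset_list_type,
    create_dataset_object_type,
    create_dataset_scalar_type,
)

rule_results = create_dataset_list_type(
    create_dataset_object_type(
        {
            "name": create_dataset_scalar_type(DType.STRING),
            "latency_ms": create_dataset_scalar_type(DType.INT),
        },
    ),
)

assert rule_results.items["latency_ms"].dtype == DType.INT
```

Finally, `SHIELD_SCHEMA` assembles the top-level dataset schema: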
```python
def SHIELD_SCHEMA() -> DatasetSchema:
    return DatasetSchema(
        alias_mask={},
        columns=[
            DatasetColumn(
                id=uuid4(),
                source_name="id",
                definition=create_dataset_scalar_type(DType.UUID),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="result",
                definition=create_dataset_scalar_type(DType.STRING),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="created_at",
                definition=create_dataset_scalar_type(DType.INT),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="updated_at",
                definition=create_dataset_scalar_type(DType.INT),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="task_id",
                definition=create_dataset_scalar_type(DType.UUID),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="conversation_id",
                definition=create_dataset_scalar_type(DType.STRING),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="inference_prompt",
                definition=create_shield_prompt_schema(),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="inference_response",
                definition=create_shield_response_schema(),
            ),
            DatasetColumn(
                id=uuid4(),
                source_name="inference_feedback",
                definition=create_shield_inference_feedback_schema(),
            ),
        ],
    )


SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()
```
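`SHIELD_SCHEMA()` builds a fresh nine-column schema on every call (each column and node gets a new UUID), while the module-level constants expose the prompt and response sub-schemas in their ID-free base form. A minimal sketch:

```python
# SHIELD_SCHEMA() returns a fresh schema; the constants are base-form types.
from arthur_common.models.schema_definitions import (
    DType,
    SHIELD_PROMPT_SCHEMA,
    SHIELD_SCHEMA,
)

schema = SHIELD_SCHEMA()
assert len(schema.columns) == 9
assert "inference_prompt" in schema.column_names.values()

# ObjectType.__getitem__ gives keyed access to the base-form schema.
assert SHIELD_PROMPT_SCHEMA["tokens"].dtype == DType.INT
```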