arthur-common 1.0.1 (py3-none-any.whl)

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arthur-common might be problematic.

Files changed (40)
  1. arthur_common/__init__.py +0 -0
  2. arthur_common/__version__.py +1 -0
  3. arthur_common/aggregations/__init__.py +2 -0
  4. arthur_common/aggregations/aggregator.py +214 -0
  5. arthur_common/aggregations/functions/README.md +26 -0
  6. arthur_common/aggregations/functions/__init__.py +25 -0
  7. arthur_common/aggregations/functions/categorical_count.py +89 -0
  8. arthur_common/aggregations/functions/confusion_matrix.py +412 -0
  9. arthur_common/aggregations/functions/inference_count.py +69 -0
  10. arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
  11. arthur_common/aggregations/functions/inference_null_count.py +82 -0
  12. arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
  13. arthur_common/aggregations/functions/mean_squared_error.py +110 -0
  14. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
  15. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
  16. arthur_common/aggregations/functions/numeric_stats.py +90 -0
  17. arthur_common/aggregations/functions/numeric_sum.py +87 -0
  18. arthur_common/aggregations/functions/py.typed +0 -0
  19. arthur_common/aggregations/functions/shield_aggregations.py +752 -0
  20. arthur_common/aggregations/py.typed +0 -0
  21. arthur_common/models/__init__.py +0 -0
  22. arthur_common/models/connectors.py +41 -0
  23. arthur_common/models/datasets.py +22 -0
  24. arthur_common/models/metrics.py +227 -0
  25. arthur_common/models/py.typed +0 -0
  26. arthur_common/models/schema_definitions.py +420 -0
  27. arthur_common/models/shield.py +504 -0
  28. arthur_common/models/task_job_specs.py +78 -0
  29. arthur_common/py.typed +0 -0
  30. arthur_common/tools/__init__.py +0 -0
  31. arthur_common/tools/aggregation_analyzer.py +243 -0
  32. arthur_common/tools/aggregation_loader.py +59 -0
  33. arthur_common/tools/duckdb_data_loader.py +329 -0
  34. arthur_common/tools/functions.py +46 -0
  35. arthur_common/tools/py.typed +0 -0
  36. arthur_common/tools/schema_inferer.py +104 -0
  37. arthur_common/tools/time_utils.py +33 -0
  38. arthur_common-1.0.1.dist-info/METADATA +74 -0
  39. arthur_common-1.0.1.dist-info/RECORD +40 -0
  40. arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/models/schema_definitions.py
@@ -0,0 +1,420 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional, Self, Union
+from uuid import UUID, uuid4
+
+from arthur_common.models.datasets import ModelProblemType
+from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator
+
+
+class ScopeSchemaTag(str, Enum):
+    LLM_CONTEXT = "llm_context"
+    LLM_PROMPT = "llm_prompt"
+    LLM_RESPONSE = "llm_response"
+    PRIMARY_TIMESTAMP = "primary_timestamp"
+    CATEGORICAL = "categorical"
+    CONTINUOUS = "continuous"
+    PREDICTION = "prediction"
+    GROUND_TRUTH = "ground_truth"
+
+
+class DType(str, Enum):
+    UNDEFINED = "undefined"
+    INT = "int"
+    FLOAT = "float"
+    BOOL = "bool"
+    STRING = "str"
+    UUID = "uuid"
+    TIMESTAMP = "timestamp"
+    DATE = "date"
+    JSON = "json"
+
+
+class MetricParameterAnnotation(BaseModel):
+    optional: bool = Field(
+        False,
+        description="Boolean denoting if the parameter is optional.",
+    )
+    friendly_name: str = Field(
+        description="User facing name of the parameter.",
+    )
+    description: str = Field(
+        description="Description of the parameter.",
+    )
+
+
+class MetricDatasetParameterAnnotation(MetricParameterAnnotation):
+    model_problem_type: Optional[ModelProblemType] = Field(
+        default=None,
+        description="Model problem type that is applicable to this parameter.",
+    )
+
+
+class MetricLiteralParameterAnnotation(MetricParameterAnnotation):
+    parameter_dtype: DType = Field(description="Data type of the parameter.")
+
+
+class MetricColumnParameterAnnotation(MetricParameterAnnotation):
+    tag_hints: list[ScopeSchemaTag] = Field(
+        [],
+        description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
+    )
+    source_dataset_parameter_key: str = Field(
+        description="Name of the parameter that provides the dataset to be used for this column.",
+    )
+    allowed_column_types: Optional[list[SchemaTypeUnion]] = Field(
+        default=None,
+        description="List of column types applicable to this parameter",
+    )
+    allow_any_column_type: bool = Field(
+        False,
+        description="Indicates if this metric parameter can accept any column type.",
+    )
+
+    @model_validator(mode="after")
+    def column_type_combination_validator(self) -> Self:
+        if self.allowed_column_types and self.allow_any_column_type:
+            raise ValueError(
+                "Parameter cannot allow any column while also explicitly listing applicable ones.",
+            )
+        return self
+
+
+class MetricMultipleColumnParameterAnnotation(MetricColumnParameterAnnotation):
+    pass
+
+
+MetricsParameterAnnotationUnion = (
+    MetricDatasetParameterAnnotation
+    | MetricLiteralParameterAnnotation
+    | MetricColumnParameterAnnotation
+    | MetricMultipleColumnParameterAnnotation
+)
+
+
+class Type(BaseModel):
+    # There's bound to be something in common here eventually
+    pass
+
+
+class ScalarType(Type):
+    dtype: DType
+
+    def __hash__(self) -> int:
+        return hash(self.dtype)
+
+
+class ObjectType(Type):
+    object: dict[str, SchemaTypeUnion]
+
+    def __getitem__(self, key: str) -> SchemaTypeUnion:
+        return self.object[key]
+
+    def __hash__(self) -> int:
+        # Combine the hash of all dictionary values
+        combined_hash = 0
+        for name, col in self.object.items():
+            combined_hash ^= hash(name)
+            combined_hash ^= hash(col)
+        return combined_hash
+
+
+class ListType(Type):
+    items: SchemaTypeUnion
+
+    def __hash__(self) -> int:
+        return hash(self.items)
+
+
+class DatasetSchemaType(Type):
+    tag_hints: list[ScopeSchemaTag] = []
+    nullable: Optional[bool] = True
+    id: UUID = Field(default_factory=uuid4, description="Unique ID of the schema node.")
+
+
+class DatasetScalarType(ScalarType, DatasetSchemaType):
+    def __hash__(self) -> int:
+        return hash(self.dtype)
+
+    def to_base_type(self) -> ScalarType:
+        return ScalarType(dtype=self.dtype)
+
+
+class DatasetObjectType(DatasetSchemaType):
+    object: dict[str, DatasetSchemaTypeUnion]
+
+    def __getitem__(self, key: str) -> DatasetSchemaTypeUnion:
+        return self.object[key]
+
+    def __hash__(self) -> int:
+        # Combine the hash of all dictionary values
+        combined_hash = 0
+        for name, col in self.object.items():
+            combined_hash ^= hash(name)
+            combined_hash ^= hash(col)
+        return combined_hash
+
+    def to_base_type(self) -> ObjectType:
+        return ObjectType(object={k: v.to_base_type() for k, v in self.object.items()})
+
+
+class DatasetListType(DatasetSchemaType):
+    items: DatasetSchemaTypeUnion
+
+    def __hash__(self) -> int:
+        return hash(self.items)
+
+    def to_base_type(self) -> ListType:
+        return ListType(items=self.items.to_base_type())
+
+
+class DatasetColumn(BaseModel):
+    id: UUID = Field(default_factory=uuid4, description="Unique ID of the column.")
+    source_name: str
+    definition: DatasetSchemaTypeUnion
+
+    def __hash__(self) -> int:
+        combined_hash = 0
+        combined_hash ^= hash(self.source_name)
+        combined_hash ^= hash(self.definition)
+        return combined_hash
+
+
+class PutDatasetSchema(BaseModel):
+    alias_mask: dict[UUID, str]
+    columns: list[DatasetColumn]
+
+    def regenerate_ids(self) -> PutDatasetSchema:
+        new_columns = []
+        new_alias_mask = {}
+        for column in self.columns:
+            new_id = uuid4()
+            new_columns.append(
+                DatasetColumn(
+                    id=new_id,
+                    source_name=column.source_name,
+                    definition=self._regenerate_definition_ids(column.definition),
+                ),
+            )
+            if column.id in self.alias_mask:
+                new_alias_mask[new_id] = self.alias_mask[column.id]
+
+        self.columns = new_columns
+        self.alias_mask = new_alias_mask
+        return self
+
+    @staticmethod
+    def _regenerate_definition_ids(
+        definition: DatasetSchemaTypeUnion,
+    ) -> DatasetSchemaTypeUnion:
+        new_def = definition.model_copy(deep=True)
+        new_def.id = uuid4()
+
+        if isinstance(new_def, DatasetObjectType):
+            new_def.object = {
+                k: PutDatasetSchema._regenerate_definition_ids(v)
+                for k, v in new_def.object.items()
+            }
+        elif isinstance(new_def, DatasetListType):
+            new_def.items = PutDatasetSchema._regenerate_definition_ids(new_def.items)
+
+        return new_def
+
+
+# This needs to be a separate model than PutDatasetSchema because of this generated field not being consumed correctly by the client generator
+# Issue tracked externally here: https://github.com/OpenAPITools/openapi-generator/issues/4190
+class DatasetSchema(PutDatasetSchema):
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def column_names(self) -> dict[UUID, str]:
+        col_names = {column.id: column.source_name for column in self.columns}
+        col_names.update(self.alias_mask)
+        return col_names
+
+    model_config = ConfigDict(
+        json_schema_mode_override="serialization",
+    )
+
+
+SchemaTypeUnion = Union[ScalarType, ObjectType, ListType]
+DatasetSchemaTypeUnion = Union[DatasetScalarType, DatasetObjectType, DatasetListType]
+
+from uuid import uuid4
+
+
+def create_dataset_scalar_type(dtype: DType) -> DatasetScalarType:
+    return DatasetScalarType(id=uuid4(), dtype=dtype)
+
+
+def create_dataset_object_type(
+    object_dict: dict[str, DatasetSchemaTypeUnion],
+) -> DatasetObjectType:
+    return DatasetObjectType(id=uuid4(), object={k: v for k, v in object_dict.items()})
+
+
+def create_dataset_list_type(items: DatasetSchemaTypeUnion) -> DatasetListType:
+    return DatasetListType(id=uuid4(), items=items)
+
+
+def create_shield_rule_results_schema() -> DatasetListType:
+    return create_dataset_list_type(
+        create_dataset_object_type(
+            {
+                "id": create_dataset_scalar_type(DType.UUID),
+                "name": create_dataset_scalar_type(DType.STRING),
+                "rule_type": create_dataset_scalar_type(DType.STRING),
+                "scope": create_dataset_scalar_type(DType.STRING),
+                "result": create_dataset_scalar_type(DType.STRING),
+                "latency_ms": create_dataset_scalar_type(DType.INT),
+                "details": create_dataset_object_type(
+                    {
+                        "message": create_dataset_scalar_type(DType.STRING),
+                        "claims": create_dataset_list_type(
+                            create_dataset_object_type(
+                                {
+                                    "claim": create_dataset_scalar_type(DType.STRING),
+                                    "valid": create_dataset_scalar_type(DType.BOOL),
+                                    "reason": create_dataset_scalar_type(DType.STRING),
+                                },
+                            ),
+                        ),
+                        "pii_entities": create_dataset_list_type(
+                            create_dataset_object_type(
+                                {
+                                    "entity": create_dataset_scalar_type(DType.STRING),
+                                    "span": create_dataset_scalar_type(DType.STRING),
+                                    "confidence": create_dataset_scalar_type(
+                                        DType.FLOAT,
+                                    ),
+                                },
+                            ),
+                        ),
+                        "toxicity_score": create_dataset_scalar_type(DType.FLOAT),
+                        "regex_matches": create_dataset_list_type(
+                            create_dataset_object_type(
+                                {
+                                    "matching_text": create_dataset_scalar_type(
+                                        DType.STRING,
+                                    ),
+                                    "pattern": create_dataset_scalar_type(DType.STRING),
+                                },
+                            ),
+                        ),
+                        "keyword_matches": create_dataset_list_type(
+                            create_dataset_object_type(
+                                {
+                                    "keyword": create_dataset_scalar_type(DType.STRING),
+                                },
+                            ),
+                        ),
+                    },
+                ),
+            },
+        ),
+    )
+
+
+def create_shield_prompt_schema() -> DatasetObjectType:
+    return create_dataset_object_type(
+        {
+            "id": create_dataset_scalar_type(DType.UUID),
+            "inference_id": create_dataset_scalar_type(DType.UUID),
+            "result": create_dataset_scalar_type(DType.STRING),
+            "created_at": create_dataset_scalar_type(DType.INT),
+            "updated_at": create_dataset_scalar_type(DType.INT),
+            "message": create_dataset_scalar_type(DType.STRING),
+            "prompt_rule_results": create_shield_rule_results_schema(),
+            "tokens": create_dataset_scalar_type(DType.INT),
+        },
+    )
+
+
+def create_shield_response_schema() -> DatasetObjectType:
+    return create_dataset_object_type(
+        {
+            "id": create_dataset_scalar_type(DType.UUID),
+            "inference_id": create_dataset_scalar_type(DType.UUID),
+            "result": create_dataset_scalar_type(DType.STRING),
+            "created_at": create_dataset_scalar_type(DType.INT),
+            "updated_at": create_dataset_scalar_type(DType.INT),
+            "message": create_dataset_scalar_type(DType.STRING),
+            "context": create_dataset_scalar_type(DType.STRING),
+            "response_rule_results": create_shield_rule_results_schema(),
+            "tokens": create_dataset_scalar_type(DType.INT),
+        },
+    )
+
+
+def create_shield_inference_feedback_schema() -> DatasetListType:
+    return create_dataset_list_type(
+        create_dataset_object_type(
+            {
+                "id": create_dataset_scalar_type(DType.UUID),
+                "inference_id": create_dataset_scalar_type(DType.UUID),
+                "target": create_dataset_scalar_type(DType.STRING),
+                "score": create_dataset_scalar_type(DType.FLOAT),
+                "reason": create_dataset_scalar_type(DType.STRING),
+                "user_id": create_dataset_scalar_type(DType.UUID),
+                "created_at": create_dataset_scalar_type(DType.TIMESTAMP),
+                "updated_at": create_dataset_scalar_type(DType.TIMESTAMP),
+            },
+        ),
+    )
+
+
+def SHIELD_SCHEMA() -> DatasetSchema:
+    return DatasetSchema(
+        alias_mask={},
+        columns=[
+            DatasetColumn(
+                id=uuid4(),
+                source_name="id",
+                definition=create_dataset_scalar_type(DType.UUID),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="result",
+                definition=create_dataset_scalar_type(DType.STRING),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="created_at",
+                definition=create_dataset_scalar_type(DType.INT),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="updated_at",
+                definition=create_dataset_scalar_type(DType.INT),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="task_id",
+                definition=create_dataset_scalar_type(DType.UUID),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="conversation_id",
+                definition=create_dataset_scalar_type(DType.STRING),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="inference_prompt",
+                definition=create_shield_prompt_schema(),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="inference_response",
+                definition=create_shield_response_schema(),
+            ),
+            DatasetColumn(
+                id=uuid4(),
+                source_name="inference_feedback",
+                definition=create_shield_inference_feedback_schema(),
+            ),
+        ],
+    )
+
+
+SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
+SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()