arthur_common-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- arthur_common/__init__.py +0 -0
- arthur_common/__version__.py +1 -0
- arthur_common/aggregations/__init__.py +2 -0
- arthur_common/aggregations/aggregator.py +214 -0
- arthur_common/aggregations/functions/README.md +26 -0
- arthur_common/aggregations/functions/__init__.py +25 -0
- arthur_common/aggregations/functions/categorical_count.py +89 -0
- arthur_common/aggregations/functions/confusion_matrix.py +412 -0
- arthur_common/aggregations/functions/inference_count.py +69 -0
- arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
- arthur_common/aggregations/functions/inference_null_count.py +82 -0
- arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
- arthur_common/aggregations/functions/mean_squared_error.py +110 -0
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
- arthur_common/aggregations/functions/numeric_stats.py +90 -0
- arthur_common/aggregations/functions/numeric_sum.py +87 -0
- arthur_common/aggregations/functions/py.typed +0 -0
- arthur_common/aggregations/functions/shield_aggregations.py +752 -0
- arthur_common/aggregations/py.typed +0 -0
- arthur_common/models/__init__.py +0 -0
- arthur_common/models/connectors.py +41 -0
- arthur_common/models/datasets.py +22 -0
- arthur_common/models/metrics.py +227 -0
- arthur_common/models/py.typed +0 -0
- arthur_common/models/schema_definitions.py +420 -0
- arthur_common/models/shield.py +504 -0
- arthur_common/models/task_job_specs.py +78 -0
- arthur_common/py.typed +0 -0
- arthur_common/tools/__init__.py +0 -0
- arthur_common/tools/aggregation_analyzer.py +243 -0
- arthur_common/tools/aggregation_loader.py +59 -0
- arthur_common/tools/duckdb_data_loader.py +329 -0
- arthur_common/tools/functions.py +46 -0
- arthur_common/tools/py.typed +0 -0
- arthur_common/tools/schema_inferer.py +104 -0
- arthur_common/tools/time_utils.py +33 -0
- arthur_common-1.0.1.dist-info/METADATA +74 -0
- arthur_common-1.0.1.dist-info/RECORD +40 -0
- arthur_common-1.0.1.dist-info/WHEEL +4 -0
File without changes
File without changes
arthur_common/models/connectors.py
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, Field
+
+
+class ConnectorPaginationOptions(BaseModel):
+    page: int = Field(default=1, ge=1)
+    page_size: int = Field(default=25, gt=0, le=500)
+    # comment to run pipeline
+
+    @property
+    def page_params(self) -> tuple[int, int]:
+        if self.page is not None:
+            return self.page, self.page_size
+        else:
+            raise ValueError(
+                "Pagination options must be set to return a page and page size",
+            )
+
+
+# connector constants
+S3_CONNECTOR_ENDPOINT_FIELD = "endpoint"
+AWS_CONNECTOR_REGION_FIELD = "region"
+AWS_CONNECTOR_ACCESS_KEY_ID_FIELD = "access_key_id"
+AWS_CONNECTOR_SECRET_ACCESS_KEY_FIELD = "secret_access_key"
+AWS_CONNECTOR_ROLE_ARN_FIELD = "role_arn"
+AWS_CONNECTOR_EXTERNAL_ID_FIELD = "external_id"
+AWS_CONNECTOR_ROLE_DURATION_SECONDS_FIELD = "role_duration_seconds"
+BUCKET_BASED_CONNECTOR_BUCKET_FIELD = "bucket"
+GOOGLE_CONNECTOR_CREDENTIALS_FIELD = "credentials"
+GOOGLE_CONNECTOR_PROJECT_ID_FIELD = "project_id"
+GOOGLE_CONNECTOR_LOCATION_FIELD = "location"
+SHIELD_CONNECTOR_API_KEY_FIELD = "api_key"
+SHIELD_CONNECTOR_ENDPOINT_FIELD = "endpoint"
+
+# dataset (connector type dependent) constants
+SHIELD_DATASET_TASK_ID_FIELD = "task_id"
+BUCKET_BASED_DATASET_FILE_PREFIX_FIELD = "file_prefix"
+BUCKET_BASED_DATASET_FILE_SUFFIX_FIELD = "file_suffix"
+BUCKET_BASED_DATASET_FILE_TYPE_FIELD = "data_file_type"
+BUCKET_BASED_DATASET_TIMESTAMP_TIME_ZONE_FIELD = "timestamp_time_zone"
+BIG_QUERY_DATASET_TABLE_NAME_FIELD = "table_name"
+BIG_QUERY_DATASET_DATASET_ID_FIELD = "dataset_id"
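For orientation, a minimal usage sketch for the pagination model above. It is not part of the release; it assumes pydantic v2 (consistent with the `model_validator(mode="after")` calls later in this diff):

from pydantic import ValidationError

from arthur_common.models.connectors import ConnectorPaginationOptions

# Defaults apply when no arguments are given: page=1, page_size=25.
opts = ConnectorPaginationOptions()
assert opts.page_params == (1, 25)

# Field constraints are enforced at construction time: page must be >= 1,
# and page_size must satisfy 0 < page_size <= 500.
try:
    ConnectorPaginationOptions(page=0, page_size=1000)
except ValidationError as exc:
    print(exc.error_count())  # 2: one violation per field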
arthur_common/models/datasets.py
@@ -0,0 +1,22 @@
+from enum import Enum
+
+
+class ModelProblemType(str, Enum):
+    REGRESSION = "regression"
+    BINARY_CLASSIFICATION = "binary_classification"
+    ARTHUR_SHIELD = "arthur_shield"
+    CUSTOM = "custom"
+    MULTICLASS_CLASSIFICATION = "multiclass_classification"
+
+
+class DatasetFileType(str, Enum):
+    JSON = "json"
+    CSV = "csv"
+    PARQUET = "parquet"
+
+
+class DatasetJoinKind(str, Enum):
+    INNER = "inner"
+    LEFT_OUTER = "left_outer"
+    OUTER = "outer"
+    RIGHT_OUTER = "right_outer"
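Because these enums subclass `str`, members compare equal to their raw values and serialize as plain strings in pydantic models. A short illustrative sketch (not part of the package):

from arthur_common.models.datasets import DatasetFileType, ModelProblemType

# str-backed enums compare equal to their raw string values...
assert ModelProblemType.REGRESSION == "regression"
# ...and can be constructed from them, e.g. when parsing a config value.
assert DatasetFileType("parquet") is DatasetFileType.PARQUET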
arthur_common/models/metrics.py
@@ -0,0 +1,227 @@
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Literal, Optional
+from uuid import UUID
+
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.schema_definitions import (
+    DType,
+    SchemaTypeUnion,
+    ScopeSchemaTag,
+)
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Self
+
+
+# Temporary limited list, expand this as we grow and make it more in line with custom transformations later on
+class AggregationType(str, Enum):
+    MIN = "min"
+    MAX = "max"
+    AVERAGE = "average"
+    COUNT = "count"
+    # Highly specific for Shield MVP work, to be abstracted more along the lines of the above later
+    SHIELD_INFERENCE_PASS_FAIL_COUNT = "shield_inference_pass_fail_count"
+    SHIELD_PROMPT_RESPONSE_PASS_FAIL_COUNT = "shield_prompt_response_pass_fail_count"
+    SHIELD_INFERENCE_RULE_COUNT = "shield_inference_rule_count"
+    SHIELD_INFERENCE_RULE_PASS_FAIL_COUNT = "shield_inference_rule_pass_fail_count"
+    SHIELD_INFERENCE_RULE_TOXICITY_SCORE = "shield_inference_rule_toxicity_score"
+    SHIELD_INFERENCE_RULE_PII_SCORE = "shield_inference_rule_pii_score"
+    SHIELD_INFERENCE_HALLUCINATION_COUNT = "shield_inference_hallucination_count"
+    SHIELD_INFERENCE_RULE_CLAIM_COUNT = "shield_inference_rule_claim_count"
+    SHIELD_INFERENCE_RULE_CLAIM_PASS_COUNT = "shield_inference_rule_claim_pass_count"
+    SHIELD_INFERENCE_RULE_CLAIM_FAIL_COUNT = "shield_inference_rule_claim_fail_count"
+    SHIELD_INFERENCE_RULE_LATENCY = "shield_inference_rule_latency"
+
+
+class Dimension(BaseModel):
+    name: str = Field(description="Name of the dimension.")
+    value: str = Field(description="Value of the dimension.")
+
+
+class NumericPoint(BaseModel):
+    timestamp: datetime = Field(
+        description="Timestamp with timezone. Should be the timestamp of the start of the interval covered by 'value'.",
+    )
+    value: float = Field(description="Floating point value for the metric.")
+
+
+class NumericTimeSeries(BaseModel):
+    dimensions: list[Dimension] = Field(
+        description="List of dimensions for the series. If multiple dimensions are uploaded with the same key, "
+        "the one that is kept is undefined.",
+    )
+    values: list[NumericPoint] = Field(
+        description="List of numeric time series points.",
+    )
+
+
+class SketchPoint(BaseModel):
+    timestamp: datetime = Field(
+        description="Timestamp with timezone. Should be the timestamp of the start of the interval covered by 'value'.",
+    )
+    value: str = Field(description="Base64-encoded string representation of a sketch.")
+
+
+class SketchTimeSeries(BaseModel):
+    dimensions: list[Dimension] = Field(
+        description="List of dimensions for the series. If multiple dimensions are uploaded with the same key, "
+        "the one that is kept is undefined.",
+    )
+    values: list[SketchPoint] = Field(
+        description="List of sketch-based time series points.",
+    )
+
+
+class BaseMetric(BaseModel):
+    name: str = Field(description="Name of the metric.")
+
+
+class NumericMetric(BaseMetric):
+    numeric_series: list[NumericTimeSeries] = Field(
+        description="List of numeric time series to upload for the metric.",
+    )
+
+
+class SketchMetric(BaseMetric):
+    sketch_series: list[SketchTimeSeries] = Field(
+        description="List of sketch-based time series to upload for the metric.",
+    )
+
+
+class SystemMetricEventKind(Enum):
+    MODEL_JOB_FAILURE = "model_job_failure"
+
+
+class SystemMetric(BaseModel):
+    event_kind: SystemMetricEventKind = Field(
+        description="Kind of the system metric event.",
+    )
+    timestamp: datetime = Field(
+        description="Timezone-aware timestamp of the system metric event.",
+    )
+    dimensions: list[Dimension] = Field(
+        description="List of dimensions for the system metric. If multiple dimensions are uploaded with the same key, "
+        "the one that is kept is undefined.",
+    )
+
+
+class AggregationMetricType(Enum):
+    SKETCH = "sketch"
+    NUMERIC = "numeric"
+
+
+class MetricsParameterSchema(BaseModel):
+    parameter_key: str = Field(description="Name of the parameter.")
+    optional: bool = Field(
+        False,
+        description="Boolean denoting if the parameter is optional.",
+    )
+    friendly_name: str = Field(
+        description="User-facing name of the parameter.",
+    )
+    description: str = Field(
+        description="Description of the parameter.",
+    )
+
+
+class MetricsDatasetParameterSchema(MetricsParameterSchema):
+    parameter_type: Literal["dataset"] = "dataset"
+    model_problem_type: Optional[ModelProblemType] = Field(
+        default=None,
+        description="Model problem type of the parameter. If not set, any model problem type is allowed.",
+    )
+
+
+class MetricsLiteralParameterSchema(MetricsParameterSchema):
+    parameter_type: Literal["literal"] = "literal"
+    parameter_dtype: DType = Field(description="Data type of the parameter.")
+
+
+class MetricsColumnParameterSchema(MetricsParameterSchema):
+    parameter_type: Literal["column"] = "column"
+    tag_hints: list[ScopeSchemaTag] = Field(
+        [],
+        description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
+    )
+    source_dataset_parameter_key: str = Field(
+        description="Name of the parameter that provides the dataset to be used for this column.",
+    )
+    allowed_column_types: Optional[list[SchemaTypeUnion]] = Field(
+        default=None,
+        description="List of column types applicable to this parameter.",
+    )
+    allow_any_column_type: bool = Field(
+        False,
+        description="Indicates if this metric parameter can accept any column type.",
+    )
+
+    @model_validator(mode="after")
+    def column_type_combination_validator(self) -> Self:
+        if self.allowed_column_types and self.allow_any_column_type:
+            raise ValueError(
+                "Parameter cannot allow any column while also explicitly listing applicable ones.",
+            )
+        return self
+
+
+# Not used/implemented yet. Might turn into a group-by column list.
+class MetricsColumnListParameterSchema(MetricsParameterSchema):
+    parameter_type: Literal["column_list"] = "column_list"
+
+
+MetricsParameterSchemaUnion = (
+    MetricsDatasetParameterSchema
+    | MetricsLiteralParameterSchema
+    | MetricsColumnParameterSchema
+    | MetricsColumnListParameterSchema
+)
+
+
+@dataclass
+class DatasetReference:
+    dataset_name: str
+    dataset_table_name: str
+    dataset_id: UUID
+
+
+class AggregationSpecSchema(BaseModel):
+    name: str = Field(description="Name of the aggregation function.")
+    id: UUID = Field(description="Unique identifier of the aggregation function.")
+    description: str = Field(
+        description="Description of the aggregation function and what it aggregates.",
+    )
+    # version: int = Field("Version number of the aggregation function.")
+    metric_type: AggregationMetricType = Field(
+        description="Return type of the aggregation's aggregate function.",
+    )  # Sketch, Numeric
+    init_args: list[MetricsParameterSchemaUnion] = Field(
+        description="List of parameters to the aggregation's init function.",
+    )
+    aggregate_args: list[MetricsParameterSchemaUnion] = Field(
+        description="List of parameters to the aggregation's aggregate function.",
+    )
+
+    def parameter_is_column_reference(self, parameter_name: str) -> bool:
+        return any(
+            param.parameter_key == parameter_name
+            and isinstance(param, MetricsColumnParameterSchema)
+            for param in self.aggregate_args
+        )
+
+    @model_validator(mode="after")
+    def column_dataset_references_exist(self) -> Self:
+        dataset_parameter_keys = [
+            p.parameter_key
+            for p in self.aggregate_args
+            if isinstance(p, MetricsDatasetParameterSchema)
+        ]
+        for param in self.aggregate_args:
+            if (
+                isinstance(param, MetricsColumnParameterSchema)
+                and param.source_dataset_parameter_key not in dataset_parameter_keys
+            ):
+                raise ValueError(
+                    f"Column parameter '{param.parameter_key}' references dataset parameter '{param.source_dataset_parameter_key}' which does not exist.",
+                )
+        return self
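To make the metric and parameter-schema models above concrete, here is a hypothetical construction sketch. It is not part of the package; the names and values (the dimension, timestamps, the "inference count" spec) are invented for illustration, and it assumes pydantic v2 semantics for the `model_validator` hooks:

from datetime import datetime, timezone
from uuid import uuid4

from pydantic import ValidationError

from arthur_common.models.metrics import (
    AggregationMetricType,
    AggregationSpecSchema,
    Dimension,
    MetricsColumnParameterSchema,
    MetricsDatasetParameterSchema,
    NumericMetric,
    NumericPoint,
    NumericTimeSeries,
)

# A single-series numeric metric payload with one dimension and one point.
metric = NumericMetric(
    name="inference_count",
    numeric_series=[
        NumericTimeSeries(
            dimensions=[Dimension(name="model_version", value="v1")],
            values=[
                NumericPoint(
                    timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
                    value=42.0,
                ),
            ],
        ),
    ],
)

# An aggregation spec whose column parameter points at a declared dataset
# parameter, satisfying column_dataset_references_exist.
dataset_param = MetricsDatasetParameterSchema(
    parameter_key="dataset",
    friendly_name="Dataset",
    description="Dataset to aggregate.",
)
column_param = MetricsColumnParameterSchema(
    parameter_key="timestamp_col",
    friendly_name="Timestamp column",
    description="Column holding inference timestamps.",
    source_dataset_parameter_key="dataset",  # must name a dataset parameter
)
spec = AggregationSpecSchema(
    name="inference count",
    id=uuid4(),
    description="Counts inferences per interval.",
    metric_type=AggregationMetricType.NUMERIC,
    init_args=[],
    aggregate_args=[dataset_param, column_param],
)
assert spec.parameter_is_column_reference("timestamp_col")

# A dangling reference is rejected at validation time: the column names a
# dataset parameter that is absent from aggregate_args.
bad_column = MetricsColumnParameterSchema(
    parameter_key="timestamp_col",
    friendly_name="Timestamp column",
    description="Column holding inference timestamps.",
    source_dataset_parameter_key="missing_dataset",
)
try:
    AggregationSpecSchema(
        name="broken spec",
        id=uuid4(),
        description="References an undeclared dataset parameter.",
        metric_type=AggregationMetricType.NUMERIC,
        init_args=[],
        aggregate_args=[bad_column],
    )
except ValidationError:
    pass  # column_dataset_references_exist raised on the dangling reference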
File without changes