arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that public registry.
- arize/__init__.py +28 -19
- arize/_exporter/client.py +56 -37
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +207 -76
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +181 -58
- arize/config.py +324 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +304 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +43 -18
- arize/embeddings/tabular_generators.py +46 -31
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +13 -0
- arize/experiments/client.py +394 -285
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/ml/__init__.py +1 -0
- arize/ml/batch_validation/__init__.py +1 -0
- arize/{models → ml}/batch_validation/errors.py +545 -67
- arize/{models → ml}/batch_validation/validator.py +344 -303
- arize/ml/bounded_executor.py +47 -0
- arize/{models → ml}/casting.py +118 -108
- arize/{models → ml}/client.py +339 -118
- arize/{models → ml}/proto.py +97 -42
- arize/{models → ml}/stream_validation.py +43 -15
- arize/ml/surrogate_explainer/__init__.py +1 -0
- arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
- arize/{types.py → ml/types.py} +355 -354
- arize/pre_releases.py +44 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +134 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +204 -175
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +60 -37
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +81 -14
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +35 -14
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +78 -8
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +20 -3
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/utils/types.py +105 -0
- arize/version.py +3 -1
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
- arize-8.0.0b0.dist-info/RECORD +175 -0
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
- arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize/models/__init__.py +0 -0
- arize/models/batch_validation/__init__.py +0 -0
- arize/models/bounded_executor.py +0 -34
- arize/models/surrogate_explainer/__init__.py +0 -0
- arize-8.0.0a22.dist-info/RECORD +0 -146
- arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
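
The bulk of this release is a module reorganization: the `arize/models/` package moves to `arize/ml/`, the top-level `arize/types.py` moves to `arize/ml/types.py`, and shared helpers such as `is_list_of` move to `arize/utils/types.py`. The sketch below is illustrative only and assumes the public import paths simply track the file moves listed above (the diff confirms these specific imports inside the package, but not every downstream path):

# Illustrative only; import paths are assumed to mirror the file moves above.
# Before (8.0.0a22):
#   from arize.types import Embedding, ModelTypes, is_list_of

# After (8.0.0b0):
from arize.ml.types import Embedding, ModelTypes  # formerly arize.types
from arize.utils.types import is_list_of          # helper split out of arize.types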
arize/{models → ml}/proto.py
RENAMED
@@ -1,14 +1,14 @@
+"""Protocol buffer utilities for ML model data serialization."""
+
 # type: ignore[pb2]
 from __future__ import annotations

-from typing import Tuple
-
 from google.protobuf.timestamp_pb2 import Timestamp
 from google.protobuf.wrappers_pb2 import DoubleValue, StringValue

 from arize._generated.protocol.rec import public_pb2 as pb2
 from arize.exceptions.parameters import InvalidValueType
-from arize.types import (
+from arize.ml.types import (
     CATEGORICAL_MODEL_TYPES,
     NUMERIC_MODEL_TYPES,
     Embedding,
@@ -22,11 +22,19 @@ from arize.types import (
     RankingPredictionLabel,
     SemanticSegmentationLabel,
     convert_element,
-    is_list_of,
 )
+from arize.utils.types import is_list_of
+
+
+def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
+    """Convert a dictionary to protobuf format with string keys and pb2.Value values.

+    Args:
+        d: Dictionary to convert, or None.

-
+    Returns:
+        Dictionary with string keys and protobuf Value objects, or empty dict if input is None.
+    """
     if d is None:
         return {}
     # Takes a dictionary and
@@ -41,6 +49,18 @@ def get_pb_dictionary(d):


 def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
+    """Convert a Python value to a protobuf Value object.
+
+    Args:
+        name: The name/key associated with this value.
+        value: The value to convert to protobuf format.
+
+    Returns:
+        A pb2.Value protobuf object, or None if value cannot be converted.
+
+    Raises:
+        TypeError: If value type is not supported.
+    """
     if isinstance(value, pb2.Value):
         return value
     if value is not None and is_list_of(value, str):
@@ -50,19 +70,18 @@ def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
     val = convert_element(value)
     if val is None:
         return None
-
+    if isinstance(val, (str, bool)):
         return pb2.Value(string=str(val))
-
+    if isinstance(val, int):
         return pb2.Value(int=val)
-
+    if isinstance(val, float):
         return pb2.Value(double=val)
-
+    if isinstance(val, Embedding):
         return pb2.Value(embedding=get_pb_embedding(val))
-
-
-
-
-    )
+    raise TypeError(
+        f"dimension '{name}' = {value} is type {type(value)}, but must be "
+        "one of: bool, str, float, int, embedding"
+    )


 def get_pb_label(
@@ -71,7 +90,7 @@ def get_pb_label(
     | bool
     | int
     | float
-
+    | tuple[str, float]
     | ObjectDetectionLabel
     | SemanticSegmentationLabel
     | InstanceSegmentationPredictionLabel
@@ -82,19 +101,32 @@
     | MultiClassActualLabel,
     model_type: ModelTypes,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
+    """Convert a label value to the appropriate protobuf label type.
+
+    Args:
+        prediction_or_actual: Whether this is a "prediction" or "actual" label.
+        value: The label value to convert.
+        model_type: The type of model (numeric, categorical, etc.).
+
+    Returns:
+        A protobuf PredictionLabel or ActualLabel object.
+
+    Raises:
+        ValueError: If model_type is not supported.
+    """
     value = convert_element(value)
     if model_type in NUMERIC_MODEL_TYPES:
         return _get_numeric_pb_label(prediction_or_actual, value)
-
+    if (
         model_type in CATEGORICAL_MODEL_TYPES
         or model_type == ModelTypes.GENERATIVE_LLM
     ):
         return _get_score_categorical_pb_label(prediction_or_actual, value)
-
+    if model_type == ModelTypes.OBJECT_DETECTION:
         return _get_cv_pb_label(prediction_or_actual, value)
-
+    if model_type == ModelTypes.RANKING:
         return _get_ranking_pb_label(value)
-
+    if model_type == ModelTypes.MULTI_CLASS:
         return _get_multi_class_pb_label(value)
     raise ValueError(
         f"model_type must be one of: {[mt.prediction_or_actual for mt in ModelTypes]} "
@@ -103,7 +135,18 @@ def get_pb_label(
     )


-def get_pb_timestamp(time_overwrite):
+def get_pb_timestamp(time_overwrite: int | None) -> object | None:
+    """Convert a Unix timestamp to a protobuf Timestamp object.
+
+    Args:
+        time_overwrite: Unix epoch time in seconds, or None.
+
+    Returns:
+        A protobuf Timestamp object, or None if input is None.
+
+    Raises:
+        TypeError: If time_overwrite is not an integer.
+    """
     if time_overwrite is None:
         return None
     time = convert_element(time_overwrite)
@@ -118,6 +161,14 @@ def get_pb_timestamp(time_overwrite):


 def get_pb_embedding(val: Embedding) -> pb2.Embedding:
+    """Convert an Embedding object to a protobuf Embedding.
+
+    Args:
+        val: The Embedding object containing vector, data, and link_to_data.
+
+    Returns:
+        A protobuf Embedding object with the vector and optional raw data.
+    """
     if Embedding._is_valid_iterable(val.data):
         return pb2.Embedding(
             vector=val.vector,
@@ -126,7 +177,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
             ),
             link_to_data=StringValue(value=val.link_to_data),
         )
-
+    if isinstance(val.data, str):
         return pb2.Embedding(
             vector=val.vector,
             raw_data=pb2.Embedding.RawData(
@@ -135,7 +186,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
             ),
             link_to_data=StringValue(value=val.link_to_data),
         )
-
+    if val.data is None:
         return pb2.Embedding(
             vector=val.vector,
             link_to_data=StringValue(value=val.link_to_data),
@@ -156,13 +207,14 @@ def _get_numeric_pb_label(
     )
     if prediction_or_actual == "prediction":
         return pb2.PredictionLabel(numeric=value)
-
+    if prediction_or_actual == "actual":
         return pb2.ActualLabel(numeric=value)
+    return None


 def _get_score_categorical_pb_label(
     prediction_or_actual: str,
-    value: bool | str |
+    value: bool | str | tuple[str, float],
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     sc = pb2.ScoreCategorical()
     if isinstance(value, bool):
@@ -202,8 +254,9 @@ def _get_score_categorical_pb_label(
     )
     if prediction_or_actual == "prediction":
         return pb2.PredictionLabel(score_categorical=sc)
-
+    if prediction_or_actual == "actual":
         return pb2.ActualLabel(score_categorical=sc)
+    return None


 def _get_cv_pb_label(
@@ -215,20 +268,19 @@ def _get_cv_pb_label(
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     if isinstance(value, ObjectDetectionLabel):
         return _get_object_detection_pb_label(prediction_or_actual, value)
-
+    if isinstance(value, SemanticSegmentationLabel):
         return _get_semantic_segmentation_pb_label(prediction_or_actual, value)
-
+    if isinstance(value, InstanceSegmentationPredictionLabel):
         return _get_instance_segmentation_prediction_pb_label(value)
-
+    if isinstance(value, InstanceSegmentationActualLabel):
         return _get_instance_segmentation_actual_pb_label(value)
-
-
-
-
-
-
-
-    )
+    raise InvalidValueType(
+        "cv label",
+        value,
+        "ObjectDetectionLabel, SemanticSegmentationLabel, or "
+        "InstanceSegmentationPredictionLabel for model type "
+        f"{ModelTypes.OBJECT_DETECTION}",
+    )


 def _get_object_detection_pb_label(
@@ -265,8 +317,9 @@ def _get_object_detection_pb_label(
     od.bounding_boxes.extend(bounding_boxes)
     if prediction_or_actual == "prediction":
         return pb2.PredictionLabel(object_detection=od)
-
+    if prediction_or_actual == "actual":
         return pb2.ActualLabel(object_detection=od)
+    return None


 def _get_semantic_segmentation_pb_label(
@@ -292,10 +345,11 @@ def _get_semantic_segmentation_pb_label(
         cv_label = pb2.CVPredictionLabel()
         cv_label.semantic_segmentation_label.polygons.extend(polygons)
         return pb2.PredictionLabel(cv_label=cv_label)
-
+    if prediction_or_actual == "actual":
         cv_label = pb2.CVActualLabel()
         cv_label.semantic_segmentation_label.polygons.extend(polygons)
         return pb2.ActualLabel(cv_label=cv_label)
+    return None


 def _get_instance_segmentation_prediction_pb_label(
@@ -394,7 +448,7 @@ def _get_ranking_pb_label(
         if value.label is not None:
             rp.label = value.label
         return pb2.PredictionLabel(ranking=rp)
-
+    if isinstance(value, RankingActualLabel):
         ra = pb2.RankingActual()
         # relevance_labels and relevance_score are optional
         if value.relevance_labels is not None:
@@ -402,6 +456,7 @@
         if value.relevance_score is not None:
             ra.relevance_score.value = value.relevance_score
         return pb2.ActualLabel(ranking=ra)
+    return None


 def _get_multi_class_pb_label(
@@ -447,9 +502,8 @@
             prediction_scores=prediction_scores_double_values,
         )
         mc_pred = pb2.MultiClassPrediction(single_label=single_label)
-
-
-    elif isinstance(value, MultiClassActualLabel):
+        return pb2.PredictionLabel(multi_class=mc_pred)
+    if isinstance(value, MultiClassActualLabel):
         # Validations checked actual score map is not None
         actual_labels = []  # list of class names with actual score of 1
         for class_name, score in value.actual_scores.items():
@@ -459,3 +513,4 @@
             actual_labels=actual_labels,
         )
         return pb2.ActualLabel(multi_class=mc_act)
+    return None
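
The proto.py changes above replace the old elif chains with early returns, add Google-style docstrings, and raise explicit errors for unsupported dimension values. Below is a minimal, hedged sketch of how the documented helpers behave, assuming they remain importable from arize.ml.proto as the module-level functions shown in the diff (this is not an official usage example):

# Hedged sketch based on the docstrings and branches in the diff above.
from arize.ml.proto import get_pb_dictionary, get_pb_value

tags = get_pb_dictionary({"env": "prod", "score": 0.92})  # -> {str: pb2.Value}; None -> {}
latency = get_pb_value("latency_ms", 125)       # int -> pb2.Value(int=125)
region = get_pb_value("region", "us-east-1")    # str/bool -> pb2.Value(string=...)
# Unsupported dimension types now raise TypeError ("must be one of: bool, str,
# float, int, embedding") instead of falling off the end of an elif chain.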

arize/{models → ml}/stream_validation.py
RENAMED
@@ -1,11 +1,12 @@
+"""Stream validation logic for ML model predictions."""
+
 # type: ignore[pb2]
-from typing import Dict, Tuple

 from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
 from arize.exceptions.parameters import (
     InvalidValueType,
 )
-from arize.types import (
+from arize.ml.types import (
     CATEGORICAL_MODEL_TYPES,
     NUMERIC_MODEL_TYPES,
     ActualLabelTypes,
@@ -32,7 +33,7 @@ def validate_label(
     | bool
     | int
     | float
-
+    | tuple[str | bool, float]
     | ObjectDetectionLabel
     | RankingPredictionLabel
     | RankingActualLabel
@@ -41,8 +42,20 @@
     | InstanceSegmentationActualLabel
     | MultiClassPredictionLabel
     | MultiClassActualLabel,
-    embedding_features:
-):
+    embedding_features: dict[str, Embedding],
+) -> None:
+    """Validate a label value against the specified model type.
+
+    Args:
+        prediction_or_actual: Whether this is a "prediction" or "actual" label.
+        model_type: The type of model (numeric, categorical, etc.).
+        label: The label value to validate.
+        embedding_features: Dictionary of embedding features for validation.
+
+    Raises:
+        ValueError: If label is invalid for the given model type.
+        TypeError: If label type is incorrect.
+    """
     if model_type in NUMERIC_MODEL_TYPES:
         _validate_numeric_label(model_type, label)
     elif model_type in CATEGORICAL_MODEL_TYPES:
@@ -63,8 +76,8 @@ def validate_label(

 def _validate_numeric_label(
     model_type: ModelTypes,
-    label: str | bool | int | float |
-):
+    label: str | bool | int | float | tuple[str | bool, float],
+) -> None:
     if not isinstance(label, (float, int)):
         raise InvalidValueType(
             f"label {label}",
@@ -75,8 +88,8 @@ def _validate_numeric_label(

 def _validate_categorical_label(
     model_type: ModelTypes,
-    label: str | bool | int | float |
-):
+    label: str | bool | int | float | tuple[str | bool, float],
+) -> None:
     is_valid = isinstance(label, (str, bool, int, float)) or (
         isinstance(label, tuple)
         and isinstance(label[0], (str, bool))
@@ -96,8 +109,8 @@ def _validate_cv_label(
     | SemanticSegmentationLabel
     | InstanceSegmentationPredictionLabel
     | InstanceSegmentationActualLabel,
-    embedding_features:
-):
+    embedding_features: dict[str, Embedding],
+) -> None:
     if (
         not isinstance(label, ObjectDetectionLabel)
         and not isinstance(label, SemanticSegmentationLabel)
@@ -126,7 +139,7 @@ def _validate_cv_label(

 def _validate_ranking_label(
     label: RankingPredictionLabel | RankingActualLabel,
-):
+) -> None:
     if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
         raise InvalidValueType(
             f"label {label}",
@@ -138,7 +151,7 @@ def _validate_ranking_label(

 def _validate_generative_llm_label(
     label: str | bool | int | float,
-):
+) -> None:
     is_valid = isinstance(label, (str, bool, int, float))
     if not is_valid:
         raise InvalidValueType(
@@ -150,7 +163,7 @@ def _validate_generative_llm_label(

 def _validate_multi_class_label(
     label: MultiClassPredictionLabel | MultiClassActualLabel,
-):
+) -> None:
     if not isinstance(
         label, (MultiClassPredictionLabel, MultiClassActualLabel)
     ):
@@ -167,8 +180,23 @@ def validate_and_convert_prediction_id(
     environment: Environments,
     prediction_label: PredictionLabelTypes | None = None,
     actual_label: ActualLabelTypes | None = None,
-    shap_values:
+    shap_values: dict[str, float] | None = None,
 ) -> str:
+    """Validate and convert a prediction ID to string format, or generate one if absent.
+
+    Args:
+        prediction_id: The prediction ID to validate/convert, or None.
+        environment: The environment context (training, validation, production).
+        prediction_label: Optional prediction label for delayed record detection.
+        actual_label: Optional actual label for delayed record detection.
+        shap_values: Optional SHAP values for delayed record detection.
+
+    Returns:
+        A validated prediction ID string.
+
+    Raises:
+        ValueError: If prediction ID is invalid.
+    """
     # If the user does not provide prediction id
     if prediction_id:
         # If prediction id is given by user, convert it to string and validate length
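
The widened annotations in stream_validation.py make explicit that categorical labels may be passed either as a bare value or as a (class, score) pair. A small illustration of the label shapes those signatures now document (the values here are made up):

# Label shapes per the tuple[str | bool, float] annotations above; values are illustrative.
plain_label = "fraud"            # bare categorical label
scored_label = ("fraud", 0.87)   # (class, score) pair for score-categorical labels
numeric_label = 0.42             # numeric models still take int/float labels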

arize/ml/surrogate_explainer/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Surrogate explainer implementations for model interpretability."""

arize/{models → ml}/surrogate_explainer/mimic.py
RENAMED
@@ -1,9 +1,11 @@
+"""Mimic explainer implementation for surrogate model explanations."""
+
 from __future__ import annotations

 import random
 import string
 from dataclasses import replace
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING

 import numpy as np
 import pandas as pd
@@ -13,20 +15,30 @@ from interpret_community.mimic.mimic_explainer import (
 )
 from sklearn.preprocessing import LabelEncoder

-from arize.types import (
+from arize.ml.types import (
     CATEGORICAL_MODEL_TYPES,
     NUMERIC_MODEL_TYPES,
     ModelTypes,
 )

 if TYPE_CHECKING:
-    from
+    from collections.abc import Callable
+
+    from arize.ml.types import Schema


 class Mimic:
+    """Mimic explainer wrapper for generating surrogate model explanations."""
+
     _testing = False

-    def __init__(self, X: pd.DataFrame, model_func: Callable):
+    def __init__(self, X: pd.DataFrame, model_func: Callable) -> None:
+        """Initialize the Mimic explainer with training data and model.
+
+        Args:
+            X: Training data DataFrame for the surrogate model.
+            model_func: Model function to explain.
+        """
         self.explainer = MimicExplainer(
             model_func,
             X,
@@ -36,6 +48,7 @@ class Mimic:
         )

     def explain(self, X: pd.DataFrame) -> pd.DataFrame:
+        """Explain feature importance for the given input DataFrame."""
         return pd.DataFrame(
             self.explainer.explain_local(X).local_importance_values,
             columns=X.columns,
@@ -45,7 +58,8 @@ class Mimic:
     @staticmethod
     def augment(
         df: pd.DataFrame, schema: Schema, model_type: ModelTypes
-    ) ->
+    ) -> tuple[pd.DataFrame, Schema]:
+        """Augment the DataFrame and schema with SHAP values for explainability."""
         features = schema.feature_column_names
         X = df[features]

@@ -71,7 +85,7 @@ class Mimic:
             )

             # model func requires 1 positional argument
-            def model_func(_):  # type: ignore
+            def model_func(_: object) -> object:  # type: ignore
                 return np.column_stack((1 - y, y))

         elif model_type in NUMERIC_MODEL_TYPES:
@@ -89,7 +103,7 @@ class Mimic:
             )

             # model func requires 1 positional argument
-            def model_func(_):  # type: ignore
+            def model_func(_: object) -> object:  # type: ignore
                 return y

         else:
@@ -100,8 +114,9 @@ class Mimic:

         # Column name mapping between features and feature importance values.
         # This is used to augment the schema.
+        # Generate unique column names to avoid collisions (not security-sensitive)
         col_map = {
-            ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
+            ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"  # noqa: S311
             for ft in features
         }
         aug_schema = replace(schema, shap_values_column_names=col_map)
@@ -127,7 +142,7 @@ class Mimic:
             X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)

         # Apply integer encoding to non-numeric columns.
-        # Currently training and explaining
+        # Currently training and explaining datasets are the same, but
         # this can be changed in the future. The student model can be
         # fitted on a much larger dataset since it takes a lot less time.
         X = pd.concat(
@@ -156,7 +171,7 @@ class Mimic:

         # Fill null with zero so they're not counted as missing records by server
         if not Mimic._testing:
-            aug_df.fillna(
+            aug_df.fillna(dict.fromkeys(col_map.values(), 0), inplace=True)

         return (
             aug_df,
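
One behavioral detail worth noting in mimic.py: the fillna call that zero-fills the surrogate SHAP columns is now written with dict.fromkeys(col_map.values(), 0). The following is a self-contained pandas illustration of that pattern; the column names and values are invented, not Arize's:

# Standalone pandas example of the fillna pattern used in Mimic.augment above.
import pandas as pd

col_map = {"feature_a": "XqWpLmZt", "feature_b": "KdRyNvBc"}  # feature -> SHAP column
aug_df = pd.DataFrame({"XqWpLmZt": [0.1, None], "KdRyNvBc": [None, 0.3]})

# Zero-fill only the SHAP columns so they are not counted as missing records.
aug_df.fillna(dict.fromkeys(col_map.values(), 0), inplace=True)
print(aug_df)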