arize 8.0.0b2__py3-none-any.whl → 8.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +8 -1
- arize/_exporter/client.py +18 -17
- arize/_exporter/parsers/tracing_data_parser.py +9 -4
- arize/_exporter/validation.py +1 -1
- arize/_flight/client.py +33 -13
- arize/_lazy.py +37 -2
- arize/client.py +61 -35
- arize/config.py +168 -14
- arize/constants/config.py +1 -0
- arize/datasets/client.py +32 -19
- arize/embeddings/auto_generator.py +14 -7
- arize/embeddings/base_generators.py +15 -9
- arize/embeddings/cv_generators.py +2 -2
- arize/embeddings/nlp_generators.py +8 -8
- arize/embeddings/tabular_generators.py +5 -5
- arize/exceptions/config.py +22 -0
- arize/exceptions/parameters.py +1 -1
- arize/exceptions/values.py +8 -5
- arize/experiments/__init__.py +4 -0
- arize/experiments/client.py +17 -11
- arize/experiments/evaluators/base.py +6 -3
- arize/experiments/evaluators/executors.py +6 -4
- arize/experiments/evaluators/rate_limiters.py +3 -1
- arize/experiments/evaluators/types.py +7 -5
- arize/experiments/evaluators/utils.py +7 -5
- arize/experiments/functions.py +111 -48
- arize/experiments/tracing.py +4 -1
- arize/experiments/types.py +31 -26
- arize/logging.py +53 -32
- arize/ml/batch_validation/validator.py +82 -70
- arize/ml/bounded_executor.py +25 -6
- arize/ml/casting.py +45 -27
- arize/ml/client.py +35 -28
- arize/ml/proto.py +16 -17
- arize/ml/stream_validation.py +63 -25
- arize/ml/surrogate_explainer/mimic.py +15 -7
- arize/ml/types.py +26 -12
- arize/pre_releases.py +7 -6
- arize/py.typed +0 -0
- arize/regions.py +10 -10
- arize/spans/client.py +113 -21
- arize/spans/conversion.py +7 -5
- arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
- arize/spans/validation/annotations/value_validation.py +11 -14
- arize/spans/validation/common/dataframe_form_validation.py +1 -1
- arize/spans/validation/common/value_validation.py +10 -13
- arize/spans/validation/evals/value_validation.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +1 -1
- arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
- arize/spans/validation/metadata/value_validation.py +23 -1
- arize/utils/arrow.py +37 -1
- arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
- arize/utils/proto.py +0 -1
- arize/utils/types.py +6 -6
- arize/version.py +1 -1
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/METADATA +10 -2
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/RECORD +60 -58
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/ml/client.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# type: ignore[pb2]
|
|
2
1
|
"""Client implementation for managing ML models in the Arize platform."""
|
|
3
2
|
|
|
4
3
|
from __future__ import annotations
|
|
@@ -6,7 +5,7 @@ from __future__ import annotations
|
|
|
6
5
|
import copy
|
|
7
6
|
import logging
|
|
8
7
|
import time
|
|
9
|
-
from typing import TYPE_CHECKING
|
|
8
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
10
9
|
|
|
11
10
|
from arize._generated.protocol.rec import public_pb2 as pb2
|
|
12
11
|
from arize._lazy import require
|
|
@@ -377,8 +376,12 @@ class MLModelsClient:
|
|
|
377
376
|
|
|
378
377
|
if embedding_features or prompt or response:
|
|
379
378
|
# NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
|
|
380
|
-
combined_embedding_features = (
|
|
381
|
-
|
|
379
|
+
combined_embedding_features: dict[str, str | Embedding] = (
|
|
380
|
+
cast(
|
|
381
|
+
"dict[str, str | Embedding]", embedding_features.copy()
|
|
382
|
+
)
|
|
383
|
+
if embedding_features
|
|
384
|
+
else {}
|
|
382
385
|
)
|
|
383
386
|
# Map prompt as embedding features for generative models
|
|
384
387
|
if prompt is not None:
|
|
@@ -395,7 +398,7 @@ class MLModelsClient:
|
|
|
395
398
|
p.MergeFrom(embedding_feats)
|
|
396
399
|
|
|
397
400
|
if tags or llm_run_metadata:
|
|
398
|
-
joined_tags = copy.deepcopy(tags)
|
|
401
|
+
joined_tags = copy.deepcopy(tags) if tags is not None else {}
|
|
399
402
|
if llm_run_metadata:
|
|
400
403
|
if llm_run_metadata.total_token_count is not None:
|
|
401
404
|
joined_tags[
|
|
@@ -522,7 +525,7 @@ class MLModelsClient:
|
|
|
522
525
|
record=rec,
|
|
523
526
|
headers=headers,
|
|
524
527
|
timeout=timeout,
|
|
525
|
-
indexes=None,
|
|
528
|
+
indexes=None, # type: ignore[arg-type]
|
|
526
529
|
)
|
|
527
530
|
|
|
528
531
|
def log(
|
|
@@ -668,8 +671,10 @@ class MLModelsClient:
|
|
|
668
671
|
dataframe = remove_extraneous_columns(df=dataframe, schema=schema)
|
|
669
672
|
|
|
670
673
|
# always validate pd.Category is not present, if yes, convert to string
|
|
674
|
+
# Type ignore: pandas.api.types.is_categorical_dtype exists but stubs may be incomplete
|
|
671
675
|
has_cat_col = any(
|
|
672
|
-
ptypes.is_categorical_dtype(x)
|
|
676
|
+
ptypes.is_categorical_dtype(x) # type: ignore[attr-defined]
|
|
677
|
+
for x in dataframe.dtypes
|
|
673
678
|
)
|
|
674
679
|
if has_cat_col:
|
|
675
680
|
cat_cols = [
|
|
@@ -691,14 +696,15 @@ class MLModelsClient:
|
|
|
691
696
|
from arize.ml.surrogate_explainer.mimic import Mimic
|
|
692
697
|
|
|
693
698
|
logger.debug("Running surrogate_explainability.")
|
|
694
|
-
|
|
699
|
+
# Type ignore: schema typed as BaseSchema but runtime is Schema with these attrs
|
|
700
|
+
if schema.shap_values_column_names: # type: ignore[attr-defined]
|
|
695
701
|
logger.info(
|
|
696
702
|
"surrogate_explainability=True has no effect "
|
|
697
703
|
"because shap_values_column_names is already specified in schema."
|
|
698
704
|
)
|
|
699
|
-
elif schema.feature_column_names is None or (
|
|
700
|
-
hasattr(schema.feature_column_names, "__len__")
|
|
701
|
-
and len(schema.feature_column_names) == 0
|
|
705
|
+
elif schema.feature_column_names is None or ( # type: ignore[attr-defined]
|
|
706
|
+
hasattr(schema.feature_column_names, "__len__") # type: ignore[attr-defined]
|
|
707
|
+
and len(schema.feature_column_names) == 0 # type: ignore[attr-defined]
|
|
702
708
|
):
|
|
703
709
|
logger.info(
|
|
704
710
|
"surrogate_explainability=True has no effect "
|
|
@@ -706,7 +712,9 @@ class MLModelsClient:
|
|
|
706
712
|
)
|
|
707
713
|
else:
|
|
708
714
|
dataframe, schema = Mimic.augment(
|
|
709
|
-
df=dataframe,
|
|
715
|
+
df=dataframe,
|
|
716
|
+
schema=schema, # type: ignore[arg-type]
|
|
717
|
+
model_type=model_type,
|
|
710
718
|
)
|
|
711
719
|
|
|
712
720
|
# Convert to Arrow table
|
|
@@ -733,8 +741,8 @@ class MLModelsClient:
|
|
|
733
741
|
pyarrow_schema=pa_table.schema,
|
|
734
742
|
)
|
|
735
743
|
if errors:
|
|
736
|
-
for
|
|
737
|
-
logger.error(
|
|
744
|
+
for error in errors:
|
|
745
|
+
logger.error(error)
|
|
738
746
|
raise ValidationFailure(errors)
|
|
739
747
|
if validate:
|
|
740
748
|
logger.debug("Performing values validation.")
|
|
@@ -745,8 +753,8 @@ class MLModelsClient:
|
|
|
745
753
|
model_type=model_type,
|
|
746
754
|
)
|
|
747
755
|
if errors:
|
|
748
|
-
for
|
|
749
|
-
logger.error(
|
|
756
|
+
for error in errors:
|
|
757
|
+
logger.error(error)
|
|
750
758
|
raise ValidationFailure(errors)
|
|
751
759
|
|
|
752
760
|
if isinstance(schema, Schema) and not schema.has_prediction_columns():
|
|
@@ -759,12 +767,12 @@ class MLModelsClient:
|
|
|
759
767
|
|
|
760
768
|
if environment == Environments.CORPUS:
|
|
761
769
|
proto_schema = _get_pb_schema_corpus(
|
|
762
|
-
schema=schema,
|
|
770
|
+
schema=schema, # type: ignore[arg-type]
|
|
763
771
|
model_id=model_name,
|
|
764
772
|
)
|
|
765
773
|
else:
|
|
766
774
|
proto_schema = _get_pb_schema(
|
|
767
|
-
schema=schema,
|
|
775
|
+
schema=schema, # type: ignore[arg-type]
|
|
768
776
|
model_id=model_name,
|
|
769
777
|
model_version=model_version,
|
|
770
778
|
model_type=model_type,
|
|
@@ -880,6 +888,7 @@ class MLModelsClient:
|
|
|
880
888
|
def export_to_parquet(
|
|
881
889
|
self,
|
|
882
890
|
*,
|
|
891
|
+
path: str,
|
|
883
892
|
space_id: str,
|
|
884
893
|
model_name: str,
|
|
885
894
|
environment: Environments,
|
|
@@ -892,13 +901,14 @@ class MLModelsClient:
|
|
|
892
901
|
columns: list | None = None,
|
|
893
902
|
similarity_search_params: SimilaritySearchParams | None = None,
|
|
894
903
|
stream_chunk_size: int | None = None,
|
|
895
|
-
) ->
|
|
896
|
-
"""Export model data from Arize to a Parquet file
|
|
904
|
+
) -> None:
|
|
905
|
+
"""Export model data from Arize to a Parquet file.
|
|
897
906
|
|
|
898
907
|
Retrieves prediction and optional actual data for a model within a specified time
|
|
899
|
-
range
|
|
908
|
+
range and writes it directly to a Parquet file at the specified path.
|
|
900
909
|
|
|
901
910
|
Args:
|
|
911
|
+
path: The file path where the Parquet file will be written.
|
|
902
912
|
space_id: The space ID where the model resides.
|
|
903
913
|
model_name: The name of the model to export data from.
|
|
904
914
|
environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
|
|
@@ -916,16 +926,12 @@ class MLModelsClient:
|
|
|
916
926
|
filtering.
|
|
917
927
|
stream_chunk_size: Optional chunk size for streaming large result sets.
|
|
918
928
|
|
|
919
|
-
Returns:
|
|
920
|
-
:class:`pandas.DataFrame`: A pandas DataFrame containing the exported data.
|
|
921
|
-
The data is also saved to a Parquet file by the underlying export client.
|
|
922
|
-
|
|
923
929
|
Raises:
|
|
924
930
|
RuntimeError: If the Flight client request fails or returns no response.
|
|
925
931
|
|
|
926
932
|
Notes:
|
|
927
933
|
- Uses Apache Arrow Flight for efficient data transfer
|
|
928
|
-
-
|
|
934
|
+
- Data is written directly to the specified path as a Parquet file
|
|
929
935
|
- Large exports may benefit from specifying stream_chunk_size
|
|
930
936
|
"""
|
|
931
937
|
require(_BATCH_EXTRA, _BATCH_DEPS)
|
|
@@ -943,7 +949,8 @@ class MLModelsClient:
|
|
|
943
949
|
exporter = ArizeExportClient(
|
|
944
950
|
flight_client=flight_client,
|
|
945
951
|
)
|
|
946
|
-
|
|
952
|
+
exporter.export_to_parquet(
|
|
953
|
+
path=path,
|
|
947
954
|
space_id=space_id,
|
|
948
955
|
model_id=model_name,
|
|
949
956
|
environment=environment,
|
|
@@ -982,7 +989,7 @@ class MLModelsClient:
|
|
|
982
989
|
headers: dict[str, str],
|
|
983
990
|
timeout: float | None,
|
|
984
991
|
indexes: tuple,
|
|
985
|
-
) ->
|
|
992
|
+
) -> cf.Future[Any]:
|
|
986
993
|
"""Post a record to Arize via async HTTP request with protobuf JSON serialization."""
|
|
987
994
|
from google.protobuf.json_format import MessageToDict
|
|
988
995
|
|
arize/ml/proto.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# type: ignore[pb2]
|
|
2
1
|
"""Protocol buffer utilities for ML model data serialization."""
|
|
3
2
|
|
|
4
3
|
from __future__ import annotations
|
|
@@ -26,7 +25,7 @@ from arize.ml.types import (
|
|
|
26
25
|
from arize.utils.types import is_list_of
|
|
27
26
|
|
|
28
27
|
|
|
29
|
-
def get_pb_dictionary(d:
|
|
28
|
+
def get_pb_dictionary(d: object | None) -> dict[str, object]:
|
|
30
29
|
"""Convert a dictionary to protobuf format with string keys and pb2.Value values.
|
|
31
30
|
|
|
32
31
|
Args:
|
|
@@ -37,6 +36,8 @@ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
|
|
|
37
36
|
"""
|
|
38
37
|
if d is None:
|
|
39
38
|
return {}
|
|
39
|
+
if not isinstance(d, dict):
|
|
40
|
+
return {}
|
|
40
41
|
# Takes a dictionary and
|
|
41
42
|
# - casts the keys as strings
|
|
42
43
|
# - turns the values of the dictionary to our proto values pb2.Value()
|
|
@@ -48,7 +49,7 @@ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
|
|
|
48
49
|
return converted_dict
|
|
49
50
|
|
|
50
51
|
|
|
51
|
-
def get_pb_value(name:
|
|
52
|
+
def get_pb_value(name: object, value: pb2.Value) -> pb2.Value:
|
|
52
53
|
"""Convert a Python value to a protobuf Value object.
|
|
53
54
|
|
|
54
55
|
Args:
|
|
@@ -114,7 +115,8 @@ def get_pb_label(
|
|
|
114
115
|
Raises:
|
|
115
116
|
ValueError: If model_type is not supported.
|
|
116
117
|
"""
|
|
117
|
-
value
|
|
118
|
+
# convert_element preserves value type but returns object for type safety
|
|
119
|
+
value = convert_element(value) # type: ignore[assignment]
|
|
118
120
|
if model_type in NUMERIC_MODEL_TYPES:
|
|
119
121
|
return _get_numeric_pb_label(prediction_or_actual, value)
|
|
120
122
|
if (
|
|
@@ -129,7 +131,7 @@ def get_pb_label(
|
|
|
129
131
|
if model_type == ModelTypes.MULTI_CLASS:
|
|
130
132
|
return _get_multi_class_pb_label(value)
|
|
131
133
|
raise ValueError(
|
|
132
|
-
f"model_type must be one of: {[mt.
|
|
134
|
+
f"model_type must be one of: {[mt.name for mt in ModelTypes]} "
|
|
133
135
|
f"Got "
|
|
134
136
|
f"{model_type} instead."
|
|
135
137
|
)
|
|
@@ -197,12 +199,12 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
|
|
|
197
199
|
|
|
198
200
|
def _get_numeric_pb_label(
|
|
199
201
|
prediction_or_actual: str,
|
|
200
|
-
value:
|
|
202
|
+
value: object,
|
|
201
203
|
) -> pb2.PredictionLabel | pb2.ActualLabel:
|
|
202
204
|
if not isinstance(value, (int, float)):
|
|
203
205
|
raise TypeError(
|
|
204
206
|
f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
|
|
205
|
-
+ f"{[mt.
|
|
207
|
+
+ f"{[mt.name for mt in NUMERIC_MODEL_TYPES]} models accept labels of "
|
|
206
208
|
f"type int or float"
|
|
207
209
|
)
|
|
208
210
|
if prediction_or_actual == "prediction":
|
|
@@ -214,7 +216,7 @@ def _get_numeric_pb_label(
|
|
|
214
216
|
|
|
215
217
|
def _get_score_categorical_pb_label(
|
|
216
218
|
prediction_or_actual: str,
|
|
217
|
-
value:
|
|
219
|
+
value: object,
|
|
218
220
|
) -> pb2.PredictionLabel | pb2.ActualLabel:
|
|
219
221
|
sc = pb2.ScoreCategorical()
|
|
220
222
|
if isinstance(value, bool):
|
|
@@ -229,7 +231,7 @@ def _get_score_categorical_pb_label(
|
|
|
229
231
|
raise TypeError(
|
|
230
232
|
f"Received {prediction_or_actual}_label = {value}, of type "
|
|
231
233
|
f"{type(value)}[{type(value[0])}, None]. "
|
|
232
|
-
f"{[mt.
|
|
234
|
+
f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
|
|
233
235
|
"values of type str, bool, or Tuple[str, float]"
|
|
234
236
|
)
|
|
235
237
|
if not isinstance(value[0], (bool, str)) or not isinstance(
|
|
@@ -238,7 +240,7 @@ def _get_score_categorical_pb_label(
|
|
|
238
240
|
raise TypeError(
|
|
239
241
|
f"Received {prediction_or_actual}_label = {value}, of type "
|
|
240
242
|
f"{type(value)}[{type(value[0])}, {type(value[1])}]. "
|
|
241
|
-
f"{[mt.
|
|
243
|
+
f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
|
|
242
244
|
"values of type str, bool, or Tuple[str or bool, float]"
|
|
243
245
|
)
|
|
244
246
|
if isinstance(value[0], bool):
|
|
@@ -249,7 +251,7 @@ def _get_score_categorical_pb_label(
|
|
|
249
251
|
else:
|
|
250
252
|
raise TypeError(
|
|
251
253
|
f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
|
|
252
|
-
+ f"{[mt.
|
|
254
|
+
+ f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept values "
|
|
253
255
|
f"of type str, bool, int, float or Tuple[str, float]"
|
|
254
256
|
)
|
|
255
257
|
if prediction_or_actual == "prediction":
|
|
@@ -261,10 +263,7 @@ def _get_score_categorical_pb_label(
|
|
|
261
263
|
|
|
262
264
|
def _get_cv_pb_label(
|
|
263
265
|
prediction_or_actual: str,
|
|
264
|
-
value:
|
|
265
|
-
| SemanticSegmentationLabel
|
|
266
|
-
| InstanceSegmentationPredictionLabel
|
|
267
|
-
| InstanceSegmentationActualLabel,
|
|
266
|
+
value: object,
|
|
268
267
|
) -> pb2.PredictionLabel | pb2.ActualLabel:
|
|
269
268
|
if isinstance(value, ObjectDetectionLabel):
|
|
270
269
|
return _get_object_detection_pb_label(prediction_or_actual, value)
|
|
@@ -429,7 +428,7 @@ def _get_instance_segmentation_actual_pb_label(
|
|
|
429
428
|
|
|
430
429
|
|
|
431
430
|
def _get_ranking_pb_label(
|
|
432
|
-
value:
|
|
431
|
+
value: object,
|
|
433
432
|
) -> pb2.PredictionLabel | pb2.ActualLabel:
|
|
434
433
|
if not isinstance(value, (RankingPredictionLabel, RankingActualLabel)):
|
|
435
434
|
raise InvalidValueType(
|
|
@@ -460,7 +459,7 @@ def _get_ranking_pb_label(
|
|
|
460
459
|
|
|
461
460
|
|
|
462
461
|
def _get_multi_class_pb_label(
|
|
463
|
-
value:
|
|
462
|
+
value: object,
|
|
464
463
|
) -> pb2.PredictionLabel | pb2.ActualLabel:
|
|
465
464
|
if not isinstance(
|
|
466
465
|
value, (MultiClassPredictionLabel, MultiClassActualLabel)
|
arize/ml/stream_validation.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# type: ignore[pb2]
|
|
2
1
|
"""Stream validation logic for ML model predictions."""
|
|
3
2
|
|
|
4
3
|
from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
|
|
@@ -28,20 +27,8 @@ from arize.ml.types import (
|
|
|
28
27
|
def validate_label(
|
|
29
28
|
prediction_or_actual: str,
|
|
30
29
|
model_type: ModelTypes,
|
|
31
|
-
label:
|
|
32
|
-
|
|
|
33
|
-
| int
|
|
34
|
-
| float
|
|
35
|
-
| tuple[str | bool, float]
|
|
36
|
-
| ObjectDetectionLabel
|
|
37
|
-
| RankingPredictionLabel
|
|
38
|
-
| RankingActualLabel
|
|
39
|
-
| SemanticSegmentationLabel
|
|
40
|
-
| InstanceSegmentationPredictionLabel
|
|
41
|
-
| InstanceSegmentationActualLabel
|
|
42
|
-
| MultiClassPredictionLabel
|
|
43
|
-
| MultiClassActualLabel,
|
|
44
|
-
embedding_features: dict[str, Embedding],
|
|
30
|
+
label: object,
|
|
31
|
+
embedding_features: dict[str, Embedding] | None,
|
|
45
32
|
) -> None:
|
|
46
33
|
"""Validate a label value against the specified model type.
|
|
47
34
|
|
|
@@ -75,8 +62,17 @@ def validate_label(
|
|
|
75
62
|
|
|
76
63
|
def _validate_numeric_label(
|
|
77
64
|
model_type: ModelTypes,
|
|
78
|
-
label:
|
|
65
|
+
label: object,
|
|
79
66
|
) -> None:
|
|
67
|
+
"""Validate that a label is numeric (int or float) for numeric model types.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
model_type: The model type being validated.
|
|
71
|
+
label: The label value to validate.
|
|
72
|
+
|
|
73
|
+
Raises:
|
|
74
|
+
InvalidValueType: If the label is not an int or float.
|
|
75
|
+
"""
|
|
80
76
|
if not isinstance(label, (float, int)):
|
|
81
77
|
raise InvalidValueType(
|
|
82
78
|
f"label {label}",
|
|
@@ -87,8 +83,18 @@ def _validate_numeric_label(
|
|
|
87
83
|
|
|
88
84
|
def _validate_categorical_label(
|
|
89
85
|
model_type: ModelTypes,
|
|
90
|
-
label:
|
|
86
|
+
label: object,
|
|
91
87
|
) -> None:
|
|
88
|
+
"""Validate that a label is categorical (scalar or tuple with confidence) for categorical model types.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
model_type: The model type being validated.
|
|
92
|
+
label: The label value to validate.
|
|
93
|
+
|
|
94
|
+
Raises:
|
|
95
|
+
InvalidValueType: If the label is not a valid categorical type (bool, int, float, str,
|
|
96
|
+
or tuple of [str/bool, float]).
|
|
97
|
+
"""
|
|
92
98
|
is_valid = isinstance(label, (str, bool, int, float)) or (
|
|
93
99
|
isinstance(label, tuple)
|
|
94
100
|
and isinstance(label[0], (str, bool))
|
|
@@ -104,12 +110,20 @@ def _validate_categorical_label(
|
|
|
104
110
|
|
|
105
111
|
def _validate_cv_label(
|
|
106
112
|
prediction_or_actual: str,
|
|
107
|
-
label:
|
|
108
|
-
|
|
|
109
|
-
| InstanceSegmentationPredictionLabel
|
|
110
|
-
| InstanceSegmentationActualLabel,
|
|
111
|
-
embedding_features: dict[str, Embedding],
|
|
113
|
+
label: object,
|
|
114
|
+
embedding_features: dict[str, Embedding] | None,
|
|
112
115
|
) -> None:
|
|
116
|
+
"""Validate a computer vision label for object detection or segmentation tasks.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
prediction_or_actual: Either 'prediction' or 'actual' to indicate label context.
|
|
120
|
+
label: The CV label to validate.
|
|
121
|
+
embedding_features: Dictionary of embedding features that must contain exactly one entry.
|
|
122
|
+
|
|
123
|
+
Raises:
|
|
124
|
+
InvalidValueType: If the label is not a valid CV label type.
|
|
125
|
+
ValueError: If embedding_features is None or doesn't contain exactly one feature.
|
|
126
|
+
"""
|
|
113
127
|
if (
|
|
114
128
|
not isinstance(label, ObjectDetectionLabel)
|
|
115
129
|
and not isinstance(label, SemanticSegmentationLabel)
|
|
@@ -137,8 +151,16 @@ def _validate_cv_label(
|
|
|
137
151
|
|
|
138
152
|
|
|
139
153
|
def _validate_ranking_label(
|
|
140
|
-
label:
|
|
154
|
+
label: object,
|
|
141
155
|
) -> None:
|
|
156
|
+
"""Validate a ranking label for ranking model types.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
label: The ranking label to validate.
|
|
160
|
+
|
|
161
|
+
Raises:
|
|
162
|
+
InvalidValueType: If the label is not a RankingPredictionLabel or RankingActualLabel.
|
|
163
|
+
"""
|
|
142
164
|
if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
|
|
143
165
|
raise InvalidValueType(
|
|
144
166
|
f"label {label}",
|
|
@@ -149,8 +171,16 @@ def _validate_ranking_label(
|
|
|
149
171
|
|
|
150
172
|
|
|
151
173
|
def _validate_generative_llm_label(
|
|
152
|
-
label:
|
|
174
|
+
label: object,
|
|
153
175
|
) -> None:
|
|
176
|
+
"""Validate a label for generative LLM model types.
|
|
177
|
+
|
|
178
|
+
Args:
|
|
179
|
+
label: The label value to validate.
|
|
180
|
+
|
|
181
|
+
Raises:
|
|
182
|
+
InvalidValueType: If the label is not a bool, int, float, or str.
|
|
183
|
+
"""
|
|
154
184
|
is_valid = isinstance(label, (str, bool, int, float))
|
|
155
185
|
if not is_valid:
|
|
156
186
|
raise InvalidValueType(
|
|
@@ -161,8 +191,16 @@ def _validate_generative_llm_label(
|
|
|
161
191
|
|
|
162
192
|
|
|
163
193
|
def _validate_multi_class_label(
|
|
164
|
-
label:
|
|
194
|
+
label: object,
|
|
165
195
|
) -> None:
|
|
196
|
+
"""Validate a multi-class label for multi-class model types.
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
label: The multi-class label to validate.
|
|
200
|
+
|
|
201
|
+
Raises:
|
|
202
|
+
InvalidValueType: If the label is not a MultiClassPredictionLabel or MultiClassActualLabel.
|
|
203
|
+
"""
|
|
166
204
|
if not isinstance(
|
|
167
205
|
label, (MultiClassPredictionLabel, MultiClassActualLabel)
|
|
168
206
|
):
|
|
@@ -19,6 +19,7 @@ from arize.ml.types import (
|
|
|
19
19
|
CATEGORICAL_MODEL_TYPES,
|
|
20
20
|
NUMERIC_MODEL_TYPES,
|
|
21
21
|
ModelTypes,
|
|
22
|
+
_normalize_column_names,
|
|
22
23
|
)
|
|
23
24
|
|
|
24
25
|
if TYPE_CHECKING:
|
|
@@ -60,7 +61,7 @@ class Mimic:
|
|
|
60
61
|
df: pd.DataFrame, schema: Schema, model_type: ModelTypes
|
|
61
62
|
) -> tuple[pd.DataFrame, Schema]:
|
|
62
63
|
"""Augment the :class:`pandas.DataFrame` and schema with SHAP values for explainability."""
|
|
63
|
-
features = schema.feature_column_names
|
|
64
|
+
features = _normalize_column_names(schema.feature_column_names)
|
|
64
65
|
X = df[features]
|
|
65
66
|
|
|
66
67
|
if X.shape[1] == 0:
|
|
@@ -85,25 +86,32 @@ class Mimic:
|
|
|
85
86
|
)
|
|
86
87
|
|
|
87
88
|
# model func requires 1 positional argument
|
|
88
|
-
def model_func(_: object) -> object:
|
|
89
|
+
def model_func(_: object) -> object:
|
|
89
90
|
return np.column_stack((1 - y, y))
|
|
90
91
|
|
|
91
92
|
elif model_type in NUMERIC_MODEL_TYPES:
|
|
92
|
-
|
|
93
|
+
y_col_name_nullable: str | None = (
|
|
94
|
+
schema.prediction_label_column_name
|
|
95
|
+
)
|
|
93
96
|
if schema.prediction_score_column_name is not None:
|
|
94
|
-
|
|
95
|
-
|
|
97
|
+
y_col_name_nullable = schema.prediction_score_column_name
|
|
98
|
+
if y_col_name_nullable is None:
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"For {model_type} models, either prediction_label_column_name "
|
|
101
|
+
"or prediction_score_column_name must be specified"
|
|
102
|
+
)
|
|
103
|
+
y = df[y_col_name_nullable].to_numpy()
|
|
96
104
|
|
|
97
105
|
_finite_count = np.isfinite(y).sum()
|
|
98
106
|
if len(y) - _finite_count:
|
|
99
107
|
raise ValueError(
|
|
100
108
|
f"To calculate surrogate explainability for {model_type}, "
|
|
101
109
|
f"predictions must not contain NaN or infinite values, but "
|
|
102
|
-
f"{len(y) - _finite_count} NaN or infinite value(s) are found in {
|
|
110
|
+
f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name_nullable}."
|
|
103
111
|
)
|
|
104
112
|
|
|
105
113
|
# model func requires 1 positional argument
|
|
106
|
-
def model_func(_: object) -> object:
|
|
114
|
+
def model_func(_: object) -> object:
|
|
107
115
|
return y
|
|
108
116
|
|
|
109
117
|
else:
|
arize/ml/types.py
CHANGED
|
@@ -2,16 +2,19 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import math
|
|
5
|
+
import sys
|
|
5
6
|
from collections.abc import Iterator
|
|
6
7
|
from dataclasses import asdict, dataclass, replace
|
|
7
8
|
from datetime import datetime
|
|
8
9
|
from decimal import Decimal
|
|
9
10
|
from enum import Enum, unique
|
|
10
11
|
from itertools import chain
|
|
11
|
-
from typing import
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
12
|
+
from typing import NamedTuple
|
|
13
|
+
|
|
14
|
+
if sys.version_info >= (3, 11):
|
|
15
|
+
from typing import Self
|
|
16
|
+
else:
|
|
17
|
+
from typing_extensions import Self
|
|
15
18
|
|
|
16
19
|
import numpy as np
|
|
17
20
|
|
|
@@ -29,6 +32,17 @@ from arize.utils.types import is_dict_of, is_iterable_of, is_list_of
|
|
|
29
32
|
logger = logging.getLogger(__name__)
|
|
30
33
|
|
|
31
34
|
|
|
35
|
+
def _normalize_column_names(
|
|
36
|
+
col_names: "list[str] | TypedColumns | None",
|
|
37
|
+
) -> list[str]:
|
|
38
|
+
"""Convert TypedColumns or list to a flat list of column names."""
|
|
39
|
+
if col_names is None:
|
|
40
|
+
return []
|
|
41
|
+
if isinstance(col_names, list):
|
|
42
|
+
return col_names
|
|
43
|
+
return col_names.get_all_column_names()
|
|
44
|
+
|
|
45
|
+
|
|
32
46
|
@unique
|
|
33
47
|
class ModelTypes(Enum):
|
|
34
48
|
"""Enum representing supported model types in Arize."""
|
|
@@ -190,7 +204,7 @@ class Embedding(NamedTuple):
|
|
|
190
204
|
)
|
|
191
205
|
# Fail if not all elements in list are floats
|
|
192
206
|
allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
|
|
193
|
-
if not all(isinstance(val, allowed_types) for val in self.vector):
|
|
207
|
+
if not all(isinstance(val, allowed_types) for val in self.vector):
|
|
194
208
|
raise TypeError(
|
|
195
209
|
f"Embedding vector must be a vector of integers and/or floats. Got "
|
|
196
210
|
f"{emb_name}.vector = {self.vector}"
|
|
@@ -269,7 +283,7 @@ class Embedding(NamedTuple):
|
|
|
269
283
|
|
|
270
284
|
@staticmethod
|
|
271
285
|
def _is_valid_iterable(
|
|
272
|
-
data:
|
|
286
|
+
data: object,
|
|
273
287
|
) -> bool:
|
|
274
288
|
"""Validates that the input data field is of the correct iterable type.
|
|
275
289
|
|
|
@@ -1196,7 +1210,7 @@ class Schema(BaseSchema):
|
|
|
1196
1210
|
actual_score_column_name: str | None = None
|
|
1197
1211
|
shap_values_column_names: dict[str, str] | None = None
|
|
1198
1212
|
embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
|
|
1199
|
-
None
|
|
1213
|
+
None
|
|
1200
1214
|
)
|
|
1201
1215
|
prediction_group_id_column_name: str | None = None
|
|
1202
1216
|
rank_column_name: str | None = None
|
|
@@ -1214,7 +1228,7 @@ class Schema(BaseSchema):
|
|
|
1214
1228
|
prompt_template_column_names: PromptTemplateColumnNames | None = None
|
|
1215
1229
|
llm_config_column_names: LLMConfigColumnNames | None = None
|
|
1216
1230
|
llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
|
|
1217
|
-
retrieved_document_ids_column_name:
|
|
1231
|
+
retrieved_document_ids_column_name: str | None = None
|
|
1218
1232
|
multi_class_threshold_scores_column_name: str | None = None
|
|
1219
1233
|
semantic_segmentation_prediction_column_names: (
|
|
1220
1234
|
SemanticSegmentationColumnNames | None
|
|
@@ -1231,7 +1245,7 @@ class Schema(BaseSchema):
|
|
|
1231
1245
|
|
|
1232
1246
|
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1233
1247
|
"""Return a dict mapping column names to their usage count."""
|
|
1234
|
-
columns_used_counts = {}
|
|
1248
|
+
columns_used_counts: dict[str, int] = {}
|
|
1235
1249
|
|
|
1236
1250
|
for field in self.__dataclass_fields__:
|
|
1237
1251
|
if field.endswith("column_name"):
|
|
@@ -1240,7 +1254,7 @@ class Schema(BaseSchema):
|
|
|
1240
1254
|
add_to_column_count_dictionary(columns_used_counts, col)
|
|
1241
1255
|
|
|
1242
1256
|
if self.feature_column_names is not None:
|
|
1243
|
-
for col in self.feature_column_names:
|
|
1257
|
+
for col in _normalize_column_names(self.feature_column_names):
|
|
1244
1258
|
add_to_column_count_dictionary(columns_used_counts, col)
|
|
1245
1259
|
|
|
1246
1260
|
if self.embedding_feature_column_names is not None:
|
|
@@ -1259,7 +1273,7 @@ class Schema(BaseSchema):
|
|
|
1259
1273
|
)
|
|
1260
1274
|
|
|
1261
1275
|
if self.tag_column_names is not None:
|
|
1262
|
-
for col in self.tag_column_names:
|
|
1276
|
+
for col in _normalize_column_names(self.tag_column_names):
|
|
1263
1277
|
add_to_column_count_dictionary(columns_used_counts, col)
|
|
1264
1278
|
|
|
1265
1279
|
if self.shap_values_column_names is not None:
|
|
@@ -1404,7 +1418,7 @@ class CorpusSchema(BaseSchema):
|
|
|
1404
1418
|
|
|
1405
1419
|
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1406
1420
|
"""Return a dict mapping column names to their usage count."""
|
|
1407
|
-
columns_used_counts = {}
|
|
1421
|
+
columns_used_counts: dict[str, int] = {}
|
|
1408
1422
|
|
|
1409
1423
|
if self.document_id_column_name is not None:
|
|
1410
1424
|
add_to_column_count_dictionary(
|
arize/pre_releases.py
CHANGED
|
@@ -3,15 +3,15 @@
|
|
|
3
3
|
import functools
|
|
4
4
|
import logging
|
|
5
5
|
from collections.abc import Callable
|
|
6
|
-
from enum import
|
|
7
|
-
from typing import TypeVar
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import TypeVar, cast
|
|
8
8
|
|
|
9
9
|
from arize.version import __version__
|
|
10
10
|
|
|
11
11
|
logger = logging.getLogger(__name__)
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
class ReleaseStage(
|
|
14
|
+
class ReleaseStage(Enum):
|
|
15
15
|
"""Enum representing the release stage of API features."""
|
|
16
16
|
|
|
17
17
|
ALPHA = "alpha"
|
|
@@ -26,12 +26,12 @@ _F = TypeVar("_F", bound=Callable)
|
|
|
26
26
|
def _format_prerelease_message(*, key: str, stage: ReleaseStage) -> str:
|
|
27
27
|
article = "an" if stage is ReleaseStage.ALPHA else "a"
|
|
28
28
|
return (
|
|
29
|
-
f"[{stage.upper()}] {key} is {article} {stage} API "
|
|
29
|
+
f"[{stage.value.upper()}] {key} is {article} {stage.value} API "
|
|
30
30
|
f"in Arize SDK v{__version__} and may change without notice."
|
|
31
31
|
)
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def prerelease_endpoint(*,
|
|
34
|
+
def prerelease_endpoint(*, key: str, stage: ReleaseStage) -> Callable[[_F], _F]:
|
|
35
35
|
"""Decorate a method to emit a prerelease warning via logging once per process."""
|
|
36
36
|
|
|
37
37
|
def deco(fn: _F) -> _F:
|
|
@@ -42,6 +42,7 @@ def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> Callable[[_F], _F]:
|
|
|
42
42
|
logger.warning(_format_prerelease_message(key=key, stage=stage))
|
|
43
43
|
return fn(*args, **kwargs)
|
|
44
44
|
|
|
45
|
-
|
|
45
|
+
# Cast: functools.wraps preserves function signature at runtime but mypy can't verify this
|
|
46
|
+
return cast("_F", wrapper)
|
|
46
47
|
|
|
47
48
|
return deco
|
arize/py.typed
ADDED
|
File without changes
|