arize 8.0.0b2__py3-none-any.whl → 8.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arize/__init__.py +8 -1
  2. arize/_exporter/client.py +18 -17
  3. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  4. arize/_exporter/validation.py +1 -1
  5. arize/_flight/client.py +33 -13
  6. arize/_lazy.py +37 -2
  7. arize/client.py +61 -35
  8. arize/config.py +168 -14
  9. arize/constants/config.py +1 -0
  10. arize/datasets/client.py +32 -19
  11. arize/embeddings/auto_generator.py +14 -7
  12. arize/embeddings/base_generators.py +15 -9
  13. arize/embeddings/cv_generators.py +2 -2
  14. arize/embeddings/nlp_generators.py +8 -8
  15. arize/embeddings/tabular_generators.py +5 -5
  16. arize/exceptions/config.py +22 -0
  17. arize/exceptions/parameters.py +1 -1
  18. arize/exceptions/values.py +8 -5
  19. arize/experiments/__init__.py +4 -0
  20. arize/experiments/client.py +17 -11
  21. arize/experiments/evaluators/base.py +6 -3
  22. arize/experiments/evaluators/executors.py +6 -4
  23. arize/experiments/evaluators/rate_limiters.py +3 -1
  24. arize/experiments/evaluators/types.py +7 -5
  25. arize/experiments/evaluators/utils.py +7 -5
  26. arize/experiments/functions.py +111 -48
  27. arize/experiments/tracing.py +4 -1
  28. arize/experiments/types.py +31 -26
  29. arize/logging.py +53 -32
  30. arize/ml/batch_validation/validator.py +82 -70
  31. arize/ml/bounded_executor.py +25 -6
  32. arize/ml/casting.py +45 -27
  33. arize/ml/client.py +35 -28
  34. arize/ml/proto.py +16 -17
  35. arize/ml/stream_validation.py +63 -25
  36. arize/ml/surrogate_explainer/mimic.py +15 -7
  37. arize/ml/types.py +26 -12
  38. arize/pre_releases.py +7 -6
  39. arize/py.typed +0 -0
  40. arize/regions.py +10 -10
  41. arize/spans/client.py +113 -21
  42. arize/spans/conversion.py +7 -5
  43. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  44. arize/spans/validation/annotations/value_validation.py +11 -14
  45. arize/spans/validation/common/dataframe_form_validation.py +1 -1
  46. arize/spans/validation/common/value_validation.py +10 -13
  47. arize/spans/validation/evals/value_validation.py +1 -1
  48. arize/spans/validation/metadata/argument_validation.py +1 -1
  49. arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
  50. arize/spans/validation/metadata/value_validation.py +23 -1
  51. arize/utils/arrow.py +37 -1
  52. arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
  53. arize/utils/proto.py +0 -1
  54. arize/utils/types.py +6 -6
  55. arize/version.py +1 -1
  56. {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/METADATA +10 -2
  57. {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/RECORD +60 -58
  58. {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
  59. {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
  60. {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/ml/client.py CHANGED
@@ -1,4 +1,3 @@
1
- # type: ignore[pb2]
2
1
  """Client implementation for managing ML models in the Arize platform."""
3
2
 
4
3
  from __future__ import annotations
@@ -6,7 +5,7 @@ from __future__ import annotations
6
5
  import copy
7
6
  import logging
8
7
  import time
9
- from typing import TYPE_CHECKING
8
+ from typing import TYPE_CHECKING, Any, cast
10
9
 
11
10
  from arize._generated.protocol.rec import public_pb2 as pb2
12
11
  from arize._lazy import require
@@ -377,8 +376,12 @@ class MLModelsClient:
377
376
 
378
377
  if embedding_features or prompt or response:
379
378
  # NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
380
- combined_embedding_features = (
381
- embedding_features.copy() if embedding_features else {}
379
+ combined_embedding_features: dict[str, str | Embedding] = (
380
+ cast(
381
+ "dict[str, str | Embedding]", embedding_features.copy()
382
+ )
383
+ if embedding_features
384
+ else {}
382
385
  )
383
386
  # Map prompt as embedding features for generative models
384
387
  if prompt is not None:
@@ -395,7 +398,7 @@ class MLModelsClient:
395
398
  p.MergeFrom(embedding_feats)
396
399
 
397
400
  if tags or llm_run_metadata:
398
- joined_tags = copy.deepcopy(tags)
401
+ joined_tags = copy.deepcopy(tags) if tags is not None else {}
399
402
  if llm_run_metadata:
400
403
  if llm_run_metadata.total_token_count is not None:
401
404
  joined_tags[
@@ -522,7 +525,7 @@ class MLModelsClient:
522
525
  record=rec,
523
526
  headers=headers,
524
527
  timeout=timeout,
525
- indexes=None,
528
+ indexes=None, # type: ignore[arg-type]
526
529
  )
527
530
 
528
531
  def log(
@@ -668,8 +671,10 @@ class MLModelsClient:
668
671
  dataframe = remove_extraneous_columns(df=dataframe, schema=schema)
669
672
 
670
673
  # always validate pd.Category is not present, if yes, convert to string
674
+ # Type ignore: pandas.api.types.is_categorical_dtype exists but stubs may be incomplete
671
675
  has_cat_col = any(
672
- ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
676
+ ptypes.is_categorical_dtype(x) # type: ignore[attr-defined]
677
+ for x in dataframe.dtypes
673
678
  )
674
679
  if has_cat_col:
675
680
  cat_cols = [
@@ -691,14 +696,15 @@ class MLModelsClient:
691
696
  from arize.ml.surrogate_explainer.mimic import Mimic
692
697
 
693
698
  logger.debug("Running surrogate_explainability.")
694
- if schema.shap_values_column_names:
699
+ # Type ignore: schema typed as BaseSchema but runtime is Schema with these attrs
700
+ if schema.shap_values_column_names: # type: ignore[attr-defined]
695
701
  logger.info(
696
702
  "surrogate_explainability=True has no effect "
697
703
  "because shap_values_column_names is already specified in schema."
698
704
  )
699
- elif schema.feature_column_names is None or (
700
- hasattr(schema.feature_column_names, "__len__")
701
- and len(schema.feature_column_names) == 0
705
+ elif schema.feature_column_names is None or ( # type: ignore[attr-defined]
706
+ hasattr(schema.feature_column_names, "__len__") # type: ignore[attr-defined]
707
+ and len(schema.feature_column_names) == 0 # type: ignore[attr-defined]
702
708
  ):
703
709
  logger.info(
704
710
  "surrogate_explainability=True has no effect "
@@ -706,7 +712,9 @@ class MLModelsClient:
706
712
  )
707
713
  else:
708
714
  dataframe, schema = Mimic.augment(
709
- df=dataframe, schema=schema, model_type=model_type
715
+ df=dataframe,
716
+ schema=schema, # type: ignore[arg-type]
717
+ model_type=model_type,
710
718
  )
711
719
 
712
720
  # Convert to Arrow table
@@ -733,8 +741,8 @@ class MLModelsClient:
733
741
  pyarrow_schema=pa_table.schema,
734
742
  )
735
743
  if errors:
736
- for e in errors:
737
- logger.error(e)
744
+ for error in errors:
745
+ logger.error(error)
738
746
  raise ValidationFailure(errors)
739
747
  if validate:
740
748
  logger.debug("Performing values validation.")
@@ -745,8 +753,8 @@ class MLModelsClient:
745
753
  model_type=model_type,
746
754
  )
747
755
  if errors:
748
- for e in errors:
749
- logger.error(e)
756
+ for error in errors:
757
+ logger.error(error)
750
758
  raise ValidationFailure(errors)
751
759
 
752
760
  if isinstance(schema, Schema) and not schema.has_prediction_columns():
@@ -759,12 +767,12 @@ class MLModelsClient:
759
767
 
760
768
  if environment == Environments.CORPUS:
761
769
  proto_schema = _get_pb_schema_corpus(
762
- schema=schema,
770
+ schema=schema, # type: ignore[arg-type]
763
771
  model_id=model_name,
764
772
  )
765
773
  else:
766
774
  proto_schema = _get_pb_schema(
767
- schema=schema,
775
+ schema=schema, # type: ignore[arg-type]
768
776
  model_id=model_name,
769
777
  model_version=model_version,
770
778
  model_type=model_type,
@@ -880,6 +888,7 @@ class MLModelsClient:
880
888
  def export_to_parquet(
881
889
  self,
882
890
  *,
891
+ path: str,
883
892
  space_id: str,
884
893
  model_name: str,
885
894
  environment: Environments,
@@ -892,13 +901,14 @@ class MLModelsClient:
892
901
  columns: list | None = None,
893
902
  similarity_search_params: SimilaritySearchParams | None = None,
894
903
  stream_chunk_size: int | None = None,
895
- ) -> pd.DataFrame:
896
- """Export model data from Arize to a Parquet file and return as DataFrame.
904
+ ) -> None:
905
+ """Export model data from Arize to a Parquet file.
897
906
 
898
907
  Retrieves prediction and optional actual data for a model within a specified time
899
- range, saves it as a Parquet file, and returns it as a :class:`pandas.DataFrame`.
908
+ range and writes it directly to a Parquet file at the specified path.
900
909
 
901
910
  Args:
911
+ path: The file path where the Parquet file will be written.
902
912
  space_id: The space ID where the model resides.
903
913
  model_name: The name of the model to export data from.
904
914
  environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
@@ -916,16 +926,12 @@ class MLModelsClient:
916
926
  filtering.
917
927
  stream_chunk_size: Optional chunk size for streaming large result sets.
918
928
 
919
- Returns:
920
- :class:`pandas.DataFrame`: A pandas DataFrame containing the exported data.
921
- The data is also saved to a Parquet file by the underlying export client.
922
-
923
929
  Raises:
924
930
  RuntimeError: If the Flight client request fails or returns no response.
925
931
 
926
932
  Notes:
927
933
  - Uses Apache Arrow Flight for efficient data transfer
928
- - The Parquet file location is managed by the ArizeExportClient
934
+ - Data is written directly to the specified path as a Parquet file
929
935
  - Large exports may benefit from specifying stream_chunk_size
930
936
  """
931
937
  require(_BATCH_EXTRA, _BATCH_DEPS)
@@ -943,7 +949,8 @@ class MLModelsClient:
943
949
  exporter = ArizeExportClient(
944
950
  flight_client=flight_client,
945
951
  )
946
- return exporter.export_to_parquet(
952
+ exporter.export_to_parquet(
953
+ path=path,
947
954
  space_id=space_id,
948
955
  model_id=model_name,
949
956
  environment=environment,
@@ -982,7 +989,7 @@ class MLModelsClient:
982
989
  headers: dict[str, str],
983
990
  timeout: float | None,
984
991
  indexes: tuple,
985
- ) -> object:
992
+ ) -> cf.Future[Any]:
986
993
  """Post a record to Arize via async HTTP request with protobuf JSON serialization."""
987
994
  from google.protobuf.json_format import MessageToDict
988
995
 
arize/ml/proto.py CHANGED
@@ -1,4 +1,3 @@
1
- # type: ignore[pb2]
2
1
  """Protocol buffer utilities for ML model data serialization."""
3
2
 
4
3
  from __future__ import annotations
@@ -26,7 +25,7 @@ from arize.ml.types import (
26
25
  from arize.utils.types import is_list_of
27
26
 
28
27
 
29
- def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
28
+ def get_pb_dictionary(d: object | None) -> dict[str, object]:
30
29
  """Convert a dictionary to protobuf format with string keys and pb2.Value values.
31
30
 
32
31
  Args:
@@ -37,6 +36,8 @@ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
37
36
  """
38
37
  if d is None:
39
38
  return {}
39
+ if not isinstance(d, dict):
40
+ return {}
40
41
  # Takes a dictionary and
41
42
  # - casts the keys as strings
42
43
  # - turns the values of the dictionary to our proto values pb2.Value()
@@ -48,7 +49,7 @@ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
48
49
  return converted_dict
49
50
 
50
51
 
51
- def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
52
+ def get_pb_value(name: object, value: pb2.Value) -> pb2.Value:
52
53
  """Convert a Python value to a protobuf Value object.
53
54
 
54
55
  Args:
@@ -114,7 +115,8 @@ def get_pb_label(
114
115
  Raises:
115
116
  ValueError: If model_type is not supported.
116
117
  """
117
- value = convert_element(value)
118
+ # convert_element preserves value type but returns object for type safety
119
+ value = convert_element(value) # type: ignore[assignment]
118
120
  if model_type in NUMERIC_MODEL_TYPES:
119
121
  return _get_numeric_pb_label(prediction_or_actual, value)
120
122
  if (
@@ -129,7 +131,7 @@ def get_pb_label(
129
131
  if model_type == ModelTypes.MULTI_CLASS:
130
132
  return _get_multi_class_pb_label(value)
131
133
  raise ValueError(
132
- f"model_type must be one of: {[mt.prediction_or_actual for mt in ModelTypes]} "
134
+ f"model_type must be one of: {[mt.name for mt in ModelTypes]} "
133
135
  f"Got "
134
136
  f"{model_type} instead."
135
137
  )
@@ -197,12 +199,12 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
197
199
 
198
200
  def _get_numeric_pb_label(
199
201
  prediction_or_actual: str,
200
- value: int | float,
202
+ value: object,
201
203
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
202
204
  if not isinstance(value, (int, float)):
203
205
  raise TypeError(
204
206
  f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
205
- + f"{[mt.prediction_or_actual for mt in NUMERIC_MODEL_TYPES]} models accept labels of "
207
+ + f"{[mt.name for mt in NUMERIC_MODEL_TYPES]} models accept labels of "
206
208
  f"type int or float"
207
209
  )
208
210
  if prediction_or_actual == "prediction":
@@ -214,7 +216,7 @@ def _get_numeric_pb_label(
214
216
 
215
217
  def _get_score_categorical_pb_label(
216
218
  prediction_or_actual: str,
217
- value: bool | str | tuple[str, float],
219
+ value: object,
218
220
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
219
221
  sc = pb2.ScoreCategorical()
220
222
  if isinstance(value, bool):
@@ -229,7 +231,7 @@ def _get_score_categorical_pb_label(
229
231
  raise TypeError(
230
232
  f"Received {prediction_or_actual}_label = {value}, of type "
231
233
  f"{type(value)}[{type(value[0])}, None]. "
232
- f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept "
234
+ f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
233
235
  "values of type str, bool, or Tuple[str, float]"
234
236
  )
235
237
  if not isinstance(value[0], (bool, str)) or not isinstance(
@@ -238,7 +240,7 @@ def _get_score_categorical_pb_label(
238
240
  raise TypeError(
239
241
  f"Received {prediction_or_actual}_label = {value}, of type "
240
242
  f"{type(value)}[{type(value[0])}, {type(value[1])}]. "
241
- f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept "
243
+ f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
242
244
  "values of type str, bool, or Tuple[str or bool, float]"
243
245
  )
244
246
  if isinstance(value[0], bool):
@@ -249,7 +251,7 @@ def _get_score_categorical_pb_label(
249
251
  else:
250
252
  raise TypeError(
251
253
  f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
252
- + f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept values "
254
+ + f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept values "
253
255
  f"of type str, bool, int, float or Tuple[str, float]"
254
256
  )
255
257
  if prediction_or_actual == "prediction":
@@ -261,10 +263,7 @@ def _get_score_categorical_pb_label(
261
263
 
262
264
  def _get_cv_pb_label(
263
265
  prediction_or_actual: str,
264
- value: ObjectDetectionLabel
265
- | SemanticSegmentationLabel
266
- | InstanceSegmentationPredictionLabel
267
- | InstanceSegmentationActualLabel,
266
+ value: object,
268
267
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
269
268
  if isinstance(value, ObjectDetectionLabel):
270
269
  return _get_object_detection_pb_label(prediction_or_actual, value)
@@ -429,7 +428,7 @@ def _get_instance_segmentation_actual_pb_label(
429
428
 
430
429
 
431
430
  def _get_ranking_pb_label(
432
- value: RankingPredictionLabel | RankingActualLabel,
431
+ value: object,
433
432
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
434
433
  if not isinstance(value, (RankingPredictionLabel, RankingActualLabel)):
435
434
  raise InvalidValueType(
@@ -460,7 +459,7 @@ def _get_ranking_pb_label(
460
459
 
461
460
 
462
461
  def _get_multi_class_pb_label(
463
- value: MultiClassPredictionLabel | MultiClassActualLabel,
462
+ value: object,
464
463
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
465
464
  if not isinstance(
466
465
  value, (MultiClassPredictionLabel, MultiClassActualLabel)
@@ -1,4 +1,3 @@
1
- # type: ignore[pb2]
2
1
  """Stream validation logic for ML model predictions."""
3
2
 
4
3
  from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
@@ -28,20 +27,8 @@ from arize.ml.types import (
28
27
  def validate_label(
29
28
  prediction_or_actual: str,
30
29
  model_type: ModelTypes,
31
- label: str
32
- | bool
33
- | int
34
- | float
35
- | tuple[str | bool, float]
36
- | ObjectDetectionLabel
37
- | RankingPredictionLabel
38
- | RankingActualLabel
39
- | SemanticSegmentationLabel
40
- | InstanceSegmentationPredictionLabel
41
- | InstanceSegmentationActualLabel
42
- | MultiClassPredictionLabel
43
- | MultiClassActualLabel,
44
- embedding_features: dict[str, Embedding],
30
+ label: object,
31
+ embedding_features: dict[str, Embedding] | None,
45
32
  ) -> None:
46
33
  """Validate a label value against the specified model type.
47
34
 
@@ -75,8 +62,17 @@ def validate_label(
75
62
 
76
63
  def _validate_numeric_label(
77
64
  model_type: ModelTypes,
78
- label: str | bool | int | float | tuple[str | bool, float],
65
+ label: object,
79
66
  ) -> None:
67
+ """Validate that a label is numeric (int or float) for numeric model types.
68
+
69
+ Args:
70
+ model_type: The model type being validated.
71
+ label: The label value to validate.
72
+
73
+ Raises:
74
+ InvalidValueType: If the label is not an int or float.
75
+ """
80
76
  if not isinstance(label, (float, int)):
81
77
  raise InvalidValueType(
82
78
  f"label {label}",
@@ -87,8 +83,18 @@ def _validate_numeric_label(
87
83
 
88
84
  def _validate_categorical_label(
89
85
  model_type: ModelTypes,
90
- label: str | bool | int | float | tuple[str | bool, float],
86
+ label: object,
91
87
  ) -> None:
88
+ """Validate that a label is categorical (scalar or tuple with confidence) for categorical model types.
89
+
90
+ Args:
91
+ model_type: The model type being validated.
92
+ label: The label value to validate.
93
+
94
+ Raises:
95
+ InvalidValueType: If the label is not a valid categorical type (bool, int, float, str,
96
+ or tuple of [str/bool, float]).
97
+ """
92
98
  is_valid = isinstance(label, (str, bool, int, float)) or (
93
99
  isinstance(label, tuple)
94
100
  and isinstance(label[0], (str, bool))
@@ -104,12 +110,20 @@ def _validate_categorical_label(
104
110
 
105
111
  def _validate_cv_label(
106
112
  prediction_or_actual: str,
107
- label: ObjectDetectionLabel
108
- | SemanticSegmentationLabel
109
- | InstanceSegmentationPredictionLabel
110
- | InstanceSegmentationActualLabel,
111
- embedding_features: dict[str, Embedding],
113
+ label: object,
114
+ embedding_features: dict[str, Embedding] | None,
112
115
  ) -> None:
116
+ """Validate a computer vision label for object detection or segmentation tasks.
117
+
118
+ Args:
119
+ prediction_or_actual: Either 'prediction' or 'actual' to indicate label context.
120
+ label: The CV label to validate.
121
+ embedding_features: Dictionary of embedding features that must contain exactly one entry.
122
+
123
+ Raises:
124
+ InvalidValueType: If the label is not a valid CV label type.
125
+ ValueError: If embedding_features is None or doesn't contain exactly one feature.
126
+ """
113
127
  if (
114
128
  not isinstance(label, ObjectDetectionLabel)
115
129
  and not isinstance(label, SemanticSegmentationLabel)
@@ -137,8 +151,16 @@ def _validate_cv_label(
137
151
 
138
152
 
139
153
  def _validate_ranking_label(
140
- label: RankingPredictionLabel | RankingActualLabel,
154
+ label: object,
141
155
  ) -> None:
156
+ """Validate a ranking label for ranking model types.
157
+
158
+ Args:
159
+ label: The ranking label to validate.
160
+
161
+ Raises:
162
+ InvalidValueType: If the label is not a RankingPredictionLabel or RankingActualLabel.
163
+ """
142
164
  if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
143
165
  raise InvalidValueType(
144
166
  f"label {label}",
@@ -149,8 +171,16 @@ def _validate_ranking_label(
149
171
 
150
172
 
151
173
  def _validate_generative_llm_label(
152
- label: str | bool | int | float,
174
+ label: object,
153
175
  ) -> None:
176
+ """Validate a label for generative LLM model types.
177
+
178
+ Args:
179
+ label: The label value to validate.
180
+
181
+ Raises:
182
+ InvalidValueType: If the label is not a bool, int, float, or str.
183
+ """
154
184
  is_valid = isinstance(label, (str, bool, int, float))
155
185
  if not is_valid:
156
186
  raise InvalidValueType(
@@ -161,8 +191,16 @@ def _validate_generative_llm_label(
161
191
 
162
192
 
163
193
  def _validate_multi_class_label(
164
- label: MultiClassPredictionLabel | MultiClassActualLabel,
194
+ label: object,
165
195
  ) -> None:
196
+ """Validate a multi-class label for multi-class model types.
197
+
198
+ Args:
199
+ label: The multi-class label to validate.
200
+
201
+ Raises:
202
+ InvalidValueType: If the label is not a MultiClassPredictionLabel or MultiClassActualLabel.
203
+ """
166
204
  if not isinstance(
167
205
  label, (MultiClassPredictionLabel, MultiClassActualLabel)
168
206
  ):
@@ -19,6 +19,7 @@ from arize.ml.types import (
19
19
  CATEGORICAL_MODEL_TYPES,
20
20
  NUMERIC_MODEL_TYPES,
21
21
  ModelTypes,
22
+ _normalize_column_names,
22
23
  )
23
24
 
24
25
  if TYPE_CHECKING:
@@ -60,7 +61,7 @@ class Mimic:
60
61
  df: pd.DataFrame, schema: Schema, model_type: ModelTypes
61
62
  ) -> tuple[pd.DataFrame, Schema]:
62
63
  """Augment the :class:`pandas.DataFrame` and schema with SHAP values for explainability."""
63
- features = schema.feature_column_names
64
+ features = _normalize_column_names(schema.feature_column_names)
64
65
  X = df[features]
65
66
 
66
67
  if X.shape[1] == 0:
@@ -85,25 +86,32 @@ class Mimic:
85
86
  )
86
87
 
87
88
  # model func requires 1 positional argument
88
- def model_func(_: object) -> object: # type: ignore
89
+ def model_func(_: object) -> object:
89
90
  return np.column_stack((1 - y, y))
90
91
 
91
92
  elif model_type in NUMERIC_MODEL_TYPES:
92
- y_col_name = schema.prediction_label_column_name
93
+ y_col_name_nullable: str | None = (
94
+ schema.prediction_label_column_name
95
+ )
93
96
  if schema.prediction_score_column_name is not None:
94
- y_col_name = schema.prediction_score_column_name
95
- y = df[y_col_name].to_numpy()
97
+ y_col_name_nullable = schema.prediction_score_column_name
98
+ if y_col_name_nullable is None:
99
+ raise ValueError(
100
+ f"For {model_type} models, either prediction_label_column_name "
101
+ "or prediction_score_column_name must be specified"
102
+ )
103
+ y = df[y_col_name_nullable].to_numpy()
96
104
 
97
105
  _finite_count = np.isfinite(y).sum()
98
106
  if len(y) - _finite_count:
99
107
  raise ValueError(
100
108
  f"To calculate surrogate explainability for {model_type}, "
101
109
  f"predictions must not contain NaN or infinite values, but "
102
- f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}."
110
+ f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name_nullable}."
103
111
  )
104
112
 
105
113
  # model func requires 1 positional argument
106
- def model_func(_: object) -> object: # type: ignore
114
+ def model_func(_: object) -> object:
107
115
  return y
108
116
 
109
117
  else:
arize/ml/types.py CHANGED
@@ -2,16 +2,19 @@
2
2
 
3
3
  import logging
4
4
  import math
5
+ import sys
5
6
  from collections.abc import Iterator
6
7
  from dataclasses import asdict, dataclass, replace
7
8
  from datetime import datetime
8
9
  from decimal import Decimal
9
10
  from enum import Enum, unique
10
11
  from itertools import chain
11
- from typing import (
12
- NamedTuple,
13
- Self,
14
- )
12
+ from typing import NamedTuple
13
+
14
+ if sys.version_info >= (3, 11):
15
+ from typing import Self
16
+ else:
17
+ from typing_extensions import Self
15
18
 
16
19
  import numpy as np
17
20
 
@@ -29,6 +32,17 @@ from arize.utils.types import is_dict_of, is_iterable_of, is_list_of
29
32
  logger = logging.getLogger(__name__)
30
33
 
31
34
 
35
+ def _normalize_column_names(
36
+ col_names: "list[str] | TypedColumns | None",
37
+ ) -> list[str]:
38
+ """Convert TypedColumns or list to a flat list of column names."""
39
+ if col_names is None:
40
+ return []
41
+ if isinstance(col_names, list):
42
+ return col_names
43
+ return col_names.get_all_column_names()
44
+
45
+
32
46
  @unique
33
47
  class ModelTypes(Enum):
34
48
  """Enum representing supported model types in Arize."""
@@ -190,7 +204,7 @@ class Embedding(NamedTuple):
190
204
  )
191
205
  # Fail if not all elements in list are floats
192
206
  allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
193
- if not all(isinstance(val, allowed_types) for val in self.vector): # type: ignore
207
+ if not all(isinstance(val, allowed_types) for val in self.vector):
194
208
  raise TypeError(
195
209
  f"Embedding vector must be a vector of integers and/or floats. Got "
196
210
  f"{emb_name}.vector = {self.vector}"
@@ -269,7 +283,7 @@ class Embedding(NamedTuple):
269
283
 
270
284
  @staticmethod
271
285
  def _is_valid_iterable(
272
- data: str | list[str] | list[float] | np.ndarray,
286
+ data: object,
273
287
  ) -> bool:
274
288
  """Validates that the input data field is of the correct iterable type.
275
289
 
@@ -1196,7 +1210,7 @@ class Schema(BaseSchema):
1196
1210
  actual_score_column_name: str | None = None
1197
1211
  shap_values_column_names: dict[str, str] | None = None
1198
1212
  embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
1199
- None # type:ignore
1213
+ None
1200
1214
  )
1201
1215
  prediction_group_id_column_name: str | None = None
1202
1216
  rank_column_name: str | None = None
@@ -1214,7 +1228,7 @@ class Schema(BaseSchema):
1214
1228
  prompt_template_column_names: PromptTemplateColumnNames | None = None
1215
1229
  llm_config_column_names: LLMConfigColumnNames | None = None
1216
1230
  llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
1217
- retrieved_document_ids_column_name: list[str] | None = None
1231
+ retrieved_document_ids_column_name: str | None = None
1218
1232
  multi_class_threshold_scores_column_name: str | None = None
1219
1233
  semantic_segmentation_prediction_column_names: (
1220
1234
  SemanticSegmentationColumnNames | None
@@ -1231,7 +1245,7 @@ class Schema(BaseSchema):
1231
1245
 
1232
1246
  def get_used_columns_counts(self) -> dict[str, int]:
1233
1247
  """Return a dict mapping column names to their usage count."""
1234
- columns_used_counts = {}
1248
+ columns_used_counts: dict[str, int] = {}
1235
1249
 
1236
1250
  for field in self.__dataclass_fields__:
1237
1251
  if field.endswith("column_name"):
@@ -1240,7 +1254,7 @@ class Schema(BaseSchema):
1240
1254
  add_to_column_count_dictionary(columns_used_counts, col)
1241
1255
 
1242
1256
  if self.feature_column_names is not None:
1243
- for col in self.feature_column_names:
1257
+ for col in _normalize_column_names(self.feature_column_names):
1244
1258
  add_to_column_count_dictionary(columns_used_counts, col)
1245
1259
 
1246
1260
  if self.embedding_feature_column_names is not None:
@@ -1259,7 +1273,7 @@ class Schema(BaseSchema):
1259
1273
  )
1260
1274
 
1261
1275
  if self.tag_column_names is not None:
1262
- for col in self.tag_column_names:
1276
+ for col in _normalize_column_names(self.tag_column_names):
1263
1277
  add_to_column_count_dictionary(columns_used_counts, col)
1264
1278
 
1265
1279
  if self.shap_values_column_names is not None:
@@ -1404,7 +1418,7 @@ class CorpusSchema(BaseSchema):
1404
1418
 
1405
1419
  def get_used_columns_counts(self) -> dict[str, int]:
1406
1420
  """Return a dict mapping column names to their usage count."""
1407
- columns_used_counts = {}
1421
+ columns_used_counts: dict[str, int] = {}
1408
1422
 
1409
1423
  if self.document_id_column_name is not None:
1410
1424
  add_to_column_count_dictionary(
arize/pre_releases.py CHANGED
@@ -3,15 +3,15 @@
3
3
  import functools
4
4
  import logging
5
5
  from collections.abc import Callable
6
- from enum import StrEnum
7
- from typing import TypeVar
6
+ from enum import Enum
7
+ from typing import TypeVar, cast
8
8
 
9
9
  from arize.version import __version__
10
10
 
11
11
  logger = logging.getLogger(__name__)
12
12
 
13
13
 
14
- class ReleaseStage(StrEnum):
14
+ class ReleaseStage(Enum):
15
15
  """Enum representing the release stage of API features."""
16
16
 
17
17
  ALPHA = "alpha"
@@ -26,12 +26,12 @@ _F = TypeVar("_F", bound=Callable)
26
26
  def _format_prerelease_message(*, key: str, stage: ReleaseStage) -> str:
27
27
  article = "an" if stage is ReleaseStage.ALPHA else "a"
28
28
  return (
29
- f"[{stage.upper()}] {key} is {article} {stage} API "
29
+ f"[{stage.value.upper()}] {key} is {article} {stage.value} API "
30
30
  f"in Arize SDK v{__version__} and may change without notice."
31
31
  )
32
32
 
33
33
 
34
- def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> Callable[[_F], _F]:
34
+ def prerelease_endpoint(*, key: str, stage: ReleaseStage) -> Callable[[_F], _F]:
35
35
  """Decorate a method to emit a prerelease warning via logging once per process."""
36
36
 
37
37
  def deco(fn: _F) -> _F:
@@ -42,6 +42,7 @@ def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> Callable[[_F], _F]:
42
42
  logger.warning(_format_prerelease_message(key=key, stage=stage))
43
43
  return fn(*args, **kwargs)
44
44
 
45
- return wrapper # type: ignore[return-value]
45
+ # Cast: functools.wraps preserves function signature at runtime but mypy can't verify this
46
+ return cast("_F", wrapper)
46
47
 
47
48
  return deco
arize/py.typed ADDED
File without changes