arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,14 +1,14 @@
1
+ """Protocol buffer utilities for ML model data serialization."""
2
+
1
3
  # type: ignore[pb2]
2
4
  from __future__ import annotations
3
5
 
4
- from typing import Tuple
5
-
6
6
  from google.protobuf.timestamp_pb2 import Timestamp
7
7
  from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
8
8
 
9
9
  from arize._generated.protocol.rec import public_pb2 as pb2
10
10
  from arize.exceptions.parameters import InvalidValueType
11
- from arize.types import (
11
+ from arize.ml.types import (
12
12
  CATEGORICAL_MODEL_TYPES,
13
13
  NUMERIC_MODEL_TYPES,
14
14
  Embedding,
@@ -22,11 +22,19 @@ from arize.types import (
22
22
  RankingPredictionLabel,
23
23
  SemanticSegmentationLabel,
24
24
  convert_element,
25
- is_list_of,
26
25
  )
26
+ from arize.utils.types import is_list_of
27
+
28
+
29
+ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
30
+ """Convert a dictionary to protobuf format with string keys and pb2.Value values.
27
31
 
32
+ Args:
33
+ d: Dictionary to convert, or None.
28
34
 
29
- def get_pb_dictionary(d):
35
+ Returns:
36
+ Dictionary with string keys and protobuf Value objects, or empty dict if input is None.
37
+ """
30
38
  if d is None:
31
39
  return {}
32
40
  # Takes a dictionary and
@@ -41,6 +49,18 @@ def get_pb_dictionary(d):
41
49
 
42
50
 
43
51
  def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
52
+ """Convert a Python value to a protobuf Value object.
53
+
54
+ Args:
55
+ name: The name/key associated with this value.
56
+ value: The value to convert to protobuf format.
57
+
58
+ Returns:
59
+ A pb2.Value protobuf object, or None if value cannot be converted.
60
+
61
+ Raises:
62
+ TypeError: If value type is not supported.
63
+ """
44
64
  if isinstance(value, pb2.Value):
45
65
  return value
46
66
  if value is not None and is_list_of(value, str):
@@ -50,19 +70,18 @@ def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
50
70
  val = convert_element(value)
51
71
  if val is None:
52
72
  return None
53
- elif isinstance(val, (str, bool)):
73
+ if isinstance(val, (str, bool)):
54
74
  return pb2.Value(string=str(val))
55
- elif isinstance(val, int):
75
+ if isinstance(val, int):
56
76
  return pb2.Value(int=val)
57
- elif isinstance(val, float):
77
+ if isinstance(val, float):
58
78
  return pb2.Value(double=val)
59
- elif isinstance(val, Embedding):
79
+ if isinstance(val, Embedding):
60
80
  return pb2.Value(embedding=get_pb_embedding(val))
61
- else:
62
- raise TypeError(
63
- f"dimension '{name}' = {value} is type {type(value)}, but must be "
64
- "one of: bool, str, float, int, embedding"
65
- )
81
+ raise TypeError(
82
+ f"dimension '{name}' = {value} is type {type(value)}, but must be "
83
+ "one of: bool, str, float, int, embedding"
84
+ )
66
85
 
67
86
 
68
87
  def get_pb_label(
@@ -71,7 +90,7 @@ def get_pb_label(
71
90
  | bool
72
91
  | int
73
92
  | float
74
- | Tuple[str, float]
93
+ | tuple[str, float]
75
94
  | ObjectDetectionLabel
76
95
  | SemanticSegmentationLabel
77
96
  | InstanceSegmentationPredictionLabel
@@ -82,19 +101,32 @@ def get_pb_label(
82
101
  | MultiClassActualLabel,
83
102
  model_type: ModelTypes,
84
103
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
104
+ """Convert a label value to the appropriate protobuf label type.
105
+
106
+ Args:
107
+ prediction_or_actual: Whether this is a "prediction" or "actual" label.
108
+ value: The label value to convert.
109
+ model_type: The type of model (numeric, categorical, etc.).
110
+
111
+ Returns:
112
+ A protobuf PredictionLabel or ActualLabel object.
113
+
114
+ Raises:
115
+ ValueError: If model_type is not supported.
116
+ """
85
117
  value = convert_element(value)
86
118
  if model_type in NUMERIC_MODEL_TYPES:
87
119
  return _get_numeric_pb_label(prediction_or_actual, value)
88
- elif (
120
+ if (
89
121
  model_type in CATEGORICAL_MODEL_TYPES
90
122
  or model_type == ModelTypes.GENERATIVE_LLM
91
123
  ):
92
124
  return _get_score_categorical_pb_label(prediction_or_actual, value)
93
- elif model_type == ModelTypes.OBJECT_DETECTION:
125
+ if model_type == ModelTypes.OBJECT_DETECTION:
94
126
  return _get_cv_pb_label(prediction_or_actual, value)
95
- elif model_type == ModelTypes.RANKING:
127
+ if model_type == ModelTypes.RANKING:
96
128
  return _get_ranking_pb_label(value)
97
- elif model_type == ModelTypes.MULTI_CLASS:
129
+ if model_type == ModelTypes.MULTI_CLASS:
98
130
  return _get_multi_class_pb_label(value)
99
131
  raise ValueError(
100
132
  f"model_type must be one of: {[mt.prediction_or_actual for mt in ModelTypes]} "
@@ -103,7 +135,18 @@ def get_pb_label(
103
135
  )
104
136
 
105
137
 
106
- def get_pb_timestamp(time_overwrite):
138
+ def get_pb_timestamp(time_overwrite: int | None) -> object | None:
139
+ """Convert a Unix timestamp to a protobuf Timestamp object.
140
+
141
+ Args:
142
+ time_overwrite: Unix epoch time in seconds, or None.
143
+
144
+ Returns:
145
+ A protobuf Timestamp object, or None if input is None.
146
+
147
+ Raises:
148
+ TypeError: If time_overwrite is not an integer.
149
+ """
107
150
  if time_overwrite is None:
108
151
  return None
109
152
  time = convert_element(time_overwrite)
@@ -118,6 +161,14 @@ def get_pb_timestamp(time_overwrite):
118
161
 
119
162
 
120
163
  def get_pb_embedding(val: Embedding) -> pb2.Embedding:
164
+ """Convert an Embedding object to a protobuf Embedding.
165
+
166
+ Args:
167
+ val: The Embedding object containing vector, data, and link_to_data.
168
+
169
+ Returns:
170
+ A protobuf Embedding object with the vector and optional raw data.
171
+ """
121
172
  if Embedding._is_valid_iterable(val.data):
122
173
  return pb2.Embedding(
123
174
  vector=val.vector,
@@ -126,7 +177,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
126
177
  ),
127
178
  link_to_data=StringValue(value=val.link_to_data),
128
179
  )
129
- elif isinstance(val.data, str):
180
+ if isinstance(val.data, str):
130
181
  return pb2.Embedding(
131
182
  vector=val.vector,
132
183
  raw_data=pb2.Embedding.RawData(
@@ -135,7 +186,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
135
186
  ),
136
187
  link_to_data=StringValue(value=val.link_to_data),
137
188
  )
138
- elif val.data is None:
189
+ if val.data is None:
139
190
  return pb2.Embedding(
140
191
  vector=val.vector,
141
192
  link_to_data=StringValue(value=val.link_to_data),
@@ -156,13 +207,14 @@ def _get_numeric_pb_label(
156
207
  )
157
208
  if prediction_or_actual == "prediction":
158
209
  return pb2.PredictionLabel(numeric=value)
159
- elif prediction_or_actual == "actual":
210
+ if prediction_or_actual == "actual":
160
211
  return pb2.ActualLabel(numeric=value)
212
+ return None
161
213
 
162
214
 
163
215
  def _get_score_categorical_pb_label(
164
216
  prediction_or_actual: str,
165
- value: bool | str | Tuple[str, float],
217
+ value: bool | str | tuple[str, float],
166
218
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
167
219
  sc = pb2.ScoreCategorical()
168
220
  if isinstance(value, bool):
@@ -202,8 +254,9 @@ def _get_score_categorical_pb_label(
202
254
  )
203
255
  if prediction_or_actual == "prediction":
204
256
  return pb2.PredictionLabel(score_categorical=sc)
205
- elif prediction_or_actual == "actual":
257
+ if prediction_or_actual == "actual":
206
258
  return pb2.ActualLabel(score_categorical=sc)
259
+ return None
207
260
 
208
261
 
209
262
  def _get_cv_pb_label(
@@ -215,20 +268,19 @@ def _get_cv_pb_label(
215
268
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
216
269
  if isinstance(value, ObjectDetectionLabel):
217
270
  return _get_object_detection_pb_label(prediction_or_actual, value)
218
- elif isinstance(value, SemanticSegmentationLabel):
271
+ if isinstance(value, SemanticSegmentationLabel):
219
272
  return _get_semantic_segmentation_pb_label(prediction_or_actual, value)
220
- elif isinstance(value, InstanceSegmentationPredictionLabel):
273
+ if isinstance(value, InstanceSegmentationPredictionLabel):
221
274
  return _get_instance_segmentation_prediction_pb_label(value)
222
- elif isinstance(value, InstanceSegmentationActualLabel):
275
+ if isinstance(value, InstanceSegmentationActualLabel):
223
276
  return _get_instance_segmentation_actual_pb_label(value)
224
- else:
225
- raise InvalidValueType(
226
- "cv label",
227
- value,
228
- "ObjectDetectionLabel, SemanticSegmentationLabel, or "
229
- "InstanceSegmentationPredictionLabel for model type "
230
- f"{ModelTypes.OBJECT_DETECTION}",
231
- )
277
+ raise InvalidValueType(
278
+ "cv label",
279
+ value,
280
+ "ObjectDetectionLabel, SemanticSegmentationLabel, or "
281
+ "InstanceSegmentationPredictionLabel for model type "
282
+ f"{ModelTypes.OBJECT_DETECTION}",
283
+ )
232
284
 
233
285
 
234
286
  def _get_object_detection_pb_label(
@@ -265,8 +317,9 @@ def _get_object_detection_pb_label(
265
317
  od.bounding_boxes.extend(bounding_boxes)
266
318
  if prediction_or_actual == "prediction":
267
319
  return pb2.PredictionLabel(object_detection=od)
268
- elif prediction_or_actual == "actual":
320
+ if prediction_or_actual == "actual":
269
321
  return pb2.ActualLabel(object_detection=od)
322
+ return None
270
323
 
271
324
 
272
325
  def _get_semantic_segmentation_pb_label(
@@ -292,10 +345,11 @@ def _get_semantic_segmentation_pb_label(
292
345
  cv_label = pb2.CVPredictionLabel()
293
346
  cv_label.semantic_segmentation_label.polygons.extend(polygons)
294
347
  return pb2.PredictionLabel(cv_label=cv_label)
295
- elif prediction_or_actual == "actual":
348
+ if prediction_or_actual == "actual":
296
349
  cv_label = pb2.CVActualLabel()
297
350
  cv_label.semantic_segmentation_label.polygons.extend(polygons)
298
351
  return pb2.ActualLabel(cv_label=cv_label)
352
+ return None
299
353
 
300
354
 
301
355
  def _get_instance_segmentation_prediction_pb_label(
@@ -394,7 +448,7 @@ def _get_ranking_pb_label(
394
448
  if value.label is not None:
395
449
  rp.label = value.label
396
450
  return pb2.PredictionLabel(ranking=rp)
397
- elif isinstance(value, RankingActualLabel):
451
+ if isinstance(value, RankingActualLabel):
398
452
  ra = pb2.RankingActual()
399
453
  # relevance_labels and relevance_score are optional
400
454
  if value.relevance_labels is not None:
@@ -402,6 +456,7 @@ def _get_ranking_pb_label(
402
456
  if value.relevance_score is not None:
403
457
  ra.relevance_score.value = value.relevance_score
404
458
  return pb2.ActualLabel(ranking=ra)
459
+ return None
405
460
 
406
461
 
407
462
  def _get_multi_class_pb_label(
@@ -447,9 +502,8 @@ def _get_multi_class_pb_label(
447
502
  prediction_scores=prediction_scores_double_values,
448
503
  )
449
504
  mc_pred = pb2.MultiClassPrediction(single_label=single_label)
450
- p_label = pb2.PredictionLabel(multi_class=mc_pred)
451
- return p_label
452
- elif isinstance(value, MultiClassActualLabel):
505
+ return pb2.PredictionLabel(multi_class=mc_pred)
506
+ if isinstance(value, MultiClassActualLabel):
453
507
  # Validations checked actual score map is not None
454
508
  actual_labels = [] # list of class names with actual score of 1
455
509
  for class_name, score in value.actual_scores.items():
@@ -459,3 +513,4 @@ def _get_multi_class_pb_label(
459
513
  actual_labels=actual_labels,
460
514
  )
461
515
  return pb2.ActualLabel(multi_class=mc_act)
516
+ return None
@@ -1,11 +1,12 @@
1
+ """Stream validation logic for ML model predictions."""
2
+
1
3
  # type: ignore[pb2]
2
- from typing import Dict, Tuple
3
4
 
4
5
  from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
5
6
  from arize.exceptions.parameters import (
6
7
  InvalidValueType,
7
8
  )
8
- from arize.types import (
9
+ from arize.ml.types import (
9
10
  CATEGORICAL_MODEL_TYPES,
10
11
  NUMERIC_MODEL_TYPES,
11
12
  ActualLabelTypes,
@@ -32,7 +33,7 @@ def validate_label(
32
33
  | bool
33
34
  | int
34
35
  | float
35
- | Tuple[str | bool, float]
36
+ | tuple[str | bool, float]
36
37
  | ObjectDetectionLabel
37
38
  | RankingPredictionLabel
38
39
  | RankingActualLabel
@@ -41,8 +42,20 @@ def validate_label(
41
42
  | InstanceSegmentationActualLabel
42
43
  | MultiClassPredictionLabel
43
44
  | MultiClassActualLabel,
44
- embedding_features: Dict[str, Embedding],
45
- ):
45
+ embedding_features: dict[str, Embedding],
46
+ ) -> None:
47
+ """Validate a label value against the specified model type.
48
+
49
+ Args:
50
+ prediction_or_actual: Whether this is a "prediction" or "actual" label.
51
+ model_type: The type of model (numeric, categorical, etc.).
52
+ label: The label value to validate.
53
+ embedding_features: Dictionary of embedding features for validation.
54
+
55
+ Raises:
56
+ ValueError: If label is invalid for the given model type.
57
+ TypeError: If label type is incorrect.
58
+ """
46
59
  if model_type in NUMERIC_MODEL_TYPES:
47
60
  _validate_numeric_label(model_type, label)
48
61
  elif model_type in CATEGORICAL_MODEL_TYPES:
@@ -63,8 +76,8 @@ def validate_label(
63
76
 
64
77
  def _validate_numeric_label(
65
78
  model_type: ModelTypes,
66
- label: str | bool | int | float | Tuple[str | bool, float],
67
- ):
79
+ label: str | bool | int | float | tuple[str | bool, float],
80
+ ) -> None:
68
81
  if not isinstance(label, (float, int)):
69
82
  raise InvalidValueType(
70
83
  f"label {label}",
@@ -75,8 +88,8 @@ def _validate_numeric_label(
75
88
 
76
89
  def _validate_categorical_label(
77
90
  model_type: ModelTypes,
78
- label: str | bool | int | float | Tuple[str | bool, float],
79
- ):
91
+ label: str | bool | int | float | tuple[str | bool, float],
92
+ ) -> None:
80
93
  is_valid = isinstance(label, (str, bool, int, float)) or (
81
94
  isinstance(label, tuple)
82
95
  and isinstance(label[0], (str, bool))
@@ -96,8 +109,8 @@ def _validate_cv_label(
96
109
  | SemanticSegmentationLabel
97
110
  | InstanceSegmentationPredictionLabel
98
111
  | InstanceSegmentationActualLabel,
99
- embedding_features: Dict[str, Embedding],
100
- ):
112
+ embedding_features: dict[str, Embedding],
113
+ ) -> None:
101
114
  if (
102
115
  not isinstance(label, ObjectDetectionLabel)
103
116
  and not isinstance(label, SemanticSegmentationLabel)
@@ -126,7 +139,7 @@ def _validate_cv_label(
126
139
 
127
140
  def _validate_ranking_label(
128
141
  label: RankingPredictionLabel | RankingActualLabel,
129
- ):
142
+ ) -> None:
130
143
  if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
131
144
  raise InvalidValueType(
132
145
  f"label {label}",
@@ -138,7 +151,7 @@ def _validate_ranking_label(
138
151
 
139
152
  def _validate_generative_llm_label(
140
153
  label: str | bool | int | float,
141
- ):
154
+ ) -> None:
142
155
  is_valid = isinstance(label, (str, bool, int, float))
143
156
  if not is_valid:
144
157
  raise InvalidValueType(
@@ -150,7 +163,7 @@ def _validate_generative_llm_label(
150
163
 
151
164
  def _validate_multi_class_label(
152
165
  label: MultiClassPredictionLabel | MultiClassActualLabel,
153
- ):
166
+ ) -> None:
154
167
  if not isinstance(
155
168
  label, (MultiClassPredictionLabel, MultiClassActualLabel)
156
169
  ):
@@ -167,8 +180,23 @@ def validate_and_convert_prediction_id(
167
180
  environment: Environments,
168
181
  prediction_label: PredictionLabelTypes | None = None,
169
182
  actual_label: ActualLabelTypes | None = None,
170
- shap_values: Dict[str, float] | None = None,
183
+ shap_values: dict[str, float] | None = None,
171
184
  ) -> str:
185
+ """Validate and convert a prediction ID to string format, or generate one if absent.
186
+
187
+ Args:
188
+ prediction_id: The prediction ID to validate/convert, or None.
189
+ environment: The environment context (training, validation, production).
190
+ prediction_label: Optional prediction label for delayed record detection.
191
+ actual_label: Optional actual label for delayed record detection.
192
+ shap_values: Optional SHAP values for delayed record detection.
193
+
194
+ Returns:
195
+ A validated prediction ID string.
196
+
197
+ Raises:
198
+ ValueError: If prediction ID is invalid.
199
+ """
172
200
  # If the user does not provide prediction id
173
201
  if prediction_id:
174
202
  # If prediction id is given by user, convert it to string and validate length
@@ -0,0 +1 @@
1
+ """Surrogate explainer implementations for model interpretability."""
@@ -1,9 +1,11 @@
1
+ """Mimic explainer implementation for surrogate model explanations."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import random
4
6
  import string
5
7
  from dataclasses import replace
6
- from typing import TYPE_CHECKING, Callable, Tuple
8
+ from typing import TYPE_CHECKING
7
9
 
8
10
  import numpy as np
9
11
  import pandas as pd
@@ -13,20 +15,30 @@ from interpret_community.mimic.mimic_explainer import (
13
15
  )
14
16
  from sklearn.preprocessing import LabelEncoder
15
17
 
16
- from arize.types import (
18
+ from arize.ml.types import (
17
19
  CATEGORICAL_MODEL_TYPES,
18
20
  NUMERIC_MODEL_TYPES,
19
21
  ModelTypes,
20
22
  )
21
23
 
22
24
  if TYPE_CHECKING:
23
- from arize.types import Schema
25
+ from collections.abc import Callable
26
+
27
+ from arize.ml.types import Schema
24
28
 
25
29
 
26
30
  class Mimic:
31
+ """Mimic explainer wrapper for generating surrogate model explanations."""
32
+
27
33
  _testing = False
28
34
 
29
- def __init__(self, X: pd.DataFrame, model_func: Callable):
35
+ def __init__(self, X: pd.DataFrame, model_func: Callable) -> None:
36
+ """Initialize the Mimic explainer with training data and model.
37
+
38
+ Args:
39
+ X: Training data DataFrame for the surrogate model.
40
+ model_func: Model function to explain.
41
+ """
30
42
  self.explainer = MimicExplainer(
31
43
  model_func,
32
44
  X,
@@ -36,6 +48,7 @@ class Mimic:
36
48
  )
37
49
 
38
50
  def explain(self, X: pd.DataFrame) -> pd.DataFrame:
51
+ """Explain feature importance for the given input DataFrame."""
39
52
  return pd.DataFrame(
40
53
  self.explainer.explain_local(X).local_importance_values,
41
54
  columns=X.columns,
@@ -45,7 +58,8 @@ class Mimic:
45
58
  @staticmethod
46
59
  def augment(
47
60
  df: pd.DataFrame, schema: Schema, model_type: ModelTypes
48
- ) -> Tuple[pd.DataFrame, Schema]:
61
+ ) -> tuple[pd.DataFrame, Schema]:
62
+ """Augment the DataFrame and schema with SHAP values for explainability."""
49
63
  features = schema.feature_column_names
50
64
  X = df[features]
51
65
 
@@ -71,7 +85,7 @@ class Mimic:
71
85
  )
72
86
 
73
87
  # model func requires 1 positional argument
74
- def model_func(_): # type: ignore
88
+ def model_func(_: object) -> object: # type: ignore
75
89
  return np.column_stack((1 - y, y))
76
90
 
77
91
  elif model_type in NUMERIC_MODEL_TYPES:
@@ -89,7 +103,7 @@ class Mimic:
89
103
  )
90
104
 
91
105
  # model func requires 1 positional argument
92
- def model_func(_): # type: ignore
106
+ def model_func(_: object) -> object: # type: ignore
93
107
  return y
94
108
 
95
109
  else:
@@ -100,8 +114,9 @@ class Mimic:
100
114
 
101
115
  # Column name mapping between features and feature importance values.
102
116
  # This is used to augment the schema.
117
+ # Generate unique column names to avoid collisions (not security-sensitive)
103
118
  col_map = {
104
- ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
119
+ ft: f"{''.join(random.choices(string.ascii_letters, k=8))}" # noqa: S311
105
120
  for ft in features
106
121
  }
107
122
  aug_schema = replace(schema, shap_values_column_names=col_map)
@@ -127,7 +142,7 @@ class Mimic:
127
142
  X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)
128
143
 
129
144
  # Apply integer encoding to non-numeric columns.
130
- # Currently training and explaining detasets are the same, but
145
+ # Currently training and explaining datasets are the same, but
131
146
  # this can be changed in the future. The student model can be
132
147
  # fitted on a much larger dataset since it takes a lot less time.
133
148
  X = pd.concat(
@@ -156,7 +171,7 @@ class Mimic:
156
171
 
157
172
  # Fill null with zero so they're not counted as missing records by server
158
173
  if not Mimic._testing:
159
- aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True)
174
+ aug_df.fillna(dict.fromkeys(col_map.values(), 0), inplace=True)
160
175
 
161
176
  return (
162
177
  aug_df,