arize 8.0.0a22__py3-none-any.whl → 8.0.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +268 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +389 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a22.dist-info/RECORD +0 -146
  166. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/models/proto.py CHANGED
@@ -1,8 +1,8 @@
1
+ """Protocol buffer utilities for ML model data serialization."""
2
+
1
3
  # type: ignore[pb2]
2
4
  from __future__ import annotations
3
5
 
4
- from typing import Tuple
5
-
6
6
  from google.protobuf.timestamp_pb2 import Timestamp
7
7
  from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
8
8
 
@@ -26,7 +26,15 @@ from arize.types import (
26
26
  )
27
27
 
28
28
 
29
- def get_pb_dictionary(d):
29
+ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
30
+ """Convert a dictionary to protobuf format with string keys and pb2.Value values.
31
+
32
+ Args:
33
+ d: Dictionary to convert, or None.
34
+
35
+ Returns:
36
+ Dictionary with string keys and protobuf Value objects, or empty dict if input is None.
37
+ """
30
38
  if d is None:
31
39
  return {}
32
40
  # Takes a dictionary and
@@ -41,6 +49,18 @@ def get_pb_dictionary(d):
41
49
 
42
50
 
43
51
  def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
52
+ """Convert a Python value to a protobuf Value object.
53
+
54
+ Args:
55
+ name: The name/key associated with this value.
56
+ value: The value to convert to protobuf format.
57
+
58
+ Returns:
59
+ A pb2.Value protobuf object, or None if value cannot be converted.
60
+
61
+ Raises:
62
+ TypeError: If value type is not supported.
63
+ """
44
64
  if isinstance(value, pb2.Value):
45
65
  return value
46
66
  if value is not None and is_list_of(value, str):
@@ -50,19 +70,18 @@ def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
50
70
  val = convert_element(value)
51
71
  if val is None:
52
72
  return None
53
- elif isinstance(val, (str, bool)):
73
+ if isinstance(val, (str, bool)):
54
74
  return pb2.Value(string=str(val))
55
- elif isinstance(val, int):
75
+ if isinstance(val, int):
56
76
  return pb2.Value(int=val)
57
- elif isinstance(val, float):
77
+ if isinstance(val, float):
58
78
  return pb2.Value(double=val)
59
- elif isinstance(val, Embedding):
79
+ if isinstance(val, Embedding):
60
80
  return pb2.Value(embedding=get_pb_embedding(val))
61
- else:
62
- raise TypeError(
63
- f"dimension '{name}' = {value} is type {type(value)}, but must be "
64
- "one of: bool, str, float, int, embedding"
65
- )
81
+ raise TypeError(
82
+ f"dimension '{name}' = {value} is type {type(value)}, but must be "
83
+ "one of: bool, str, float, int, embedding"
84
+ )
66
85
 
67
86
 
68
87
  def get_pb_label(
@@ -71,7 +90,7 @@ def get_pb_label(
71
90
  | bool
72
91
  | int
73
92
  | float
74
- | Tuple[str, float]
93
+ | tuple[str, float]
75
94
  | ObjectDetectionLabel
76
95
  | SemanticSegmentationLabel
77
96
  | InstanceSegmentationPredictionLabel
@@ -82,19 +101,32 @@ def get_pb_label(
82
101
  | MultiClassActualLabel,
83
102
  model_type: ModelTypes,
84
103
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
104
+ """Convert a label value to the appropriate protobuf label type.
105
+
106
+ Args:
107
+ prediction_or_actual: Whether this is a "prediction" or "actual" label.
108
+ value: The label value to convert.
109
+ model_type: The type of model (numeric, categorical, etc.).
110
+
111
+ Returns:
112
+ A protobuf PredictionLabel or ActualLabel object.
113
+
114
+ Raises:
115
+ ValueError: If model_type is not supported.
116
+ """
85
117
  value = convert_element(value)
86
118
  if model_type in NUMERIC_MODEL_TYPES:
87
119
  return _get_numeric_pb_label(prediction_or_actual, value)
88
- elif (
120
+ if (
89
121
  model_type in CATEGORICAL_MODEL_TYPES
90
122
  or model_type == ModelTypes.GENERATIVE_LLM
91
123
  ):
92
124
  return _get_score_categorical_pb_label(prediction_or_actual, value)
93
- elif model_type == ModelTypes.OBJECT_DETECTION:
125
+ if model_type == ModelTypes.OBJECT_DETECTION:
94
126
  return _get_cv_pb_label(prediction_or_actual, value)
95
- elif model_type == ModelTypes.RANKING:
127
+ if model_type == ModelTypes.RANKING:
96
128
  return _get_ranking_pb_label(value)
97
- elif model_type == ModelTypes.MULTI_CLASS:
129
+ if model_type == ModelTypes.MULTI_CLASS:
98
130
  return _get_multi_class_pb_label(value)
99
131
  raise ValueError(
100
132
  f"model_type must be one of: {[mt.prediction_or_actual for mt in ModelTypes]} "
@@ -103,7 +135,18 @@ def get_pb_label(
103
135
  )
104
136
 
105
137
 
106
- def get_pb_timestamp(time_overwrite):
138
+ def get_pb_timestamp(time_overwrite: int | None) -> object | None:
139
+ """Convert a Unix timestamp to a protobuf Timestamp object.
140
+
141
+ Args:
142
+ time_overwrite: Unix epoch time in seconds, or None.
143
+
144
+ Returns:
145
+ A protobuf Timestamp object, or None if input is None.
146
+
147
+ Raises:
148
+ TypeError: If time_overwrite is not an integer.
149
+ """
107
150
  if time_overwrite is None:
108
151
  return None
109
152
  time = convert_element(time_overwrite)
@@ -118,6 +161,14 @@ def get_pb_timestamp(time_overwrite):
118
161
 
119
162
 
120
163
  def get_pb_embedding(val: Embedding) -> pb2.Embedding:
164
+ """Convert an Embedding object to a protobuf Embedding.
165
+
166
+ Args:
167
+ val: The Embedding object containing vector, data, and link_to_data.
168
+
169
+ Returns:
170
+ A protobuf Embedding object with the vector and optional raw data.
171
+ """
121
172
  if Embedding._is_valid_iterable(val.data):
122
173
  return pb2.Embedding(
123
174
  vector=val.vector,
@@ -126,7 +177,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
126
177
  ),
127
178
  link_to_data=StringValue(value=val.link_to_data),
128
179
  )
129
- elif isinstance(val.data, str):
180
+ if isinstance(val.data, str):
130
181
  return pb2.Embedding(
131
182
  vector=val.vector,
132
183
  raw_data=pb2.Embedding.RawData(
@@ -135,7 +186,7 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
135
186
  ),
136
187
  link_to_data=StringValue(value=val.link_to_data),
137
188
  )
138
- elif val.data is None:
189
+ if val.data is None:
139
190
  return pb2.Embedding(
140
191
  vector=val.vector,
141
192
  link_to_data=StringValue(value=val.link_to_data),
@@ -156,13 +207,14 @@ def _get_numeric_pb_label(
156
207
  )
157
208
  if prediction_or_actual == "prediction":
158
209
  return pb2.PredictionLabel(numeric=value)
159
- elif prediction_or_actual == "actual":
210
+ if prediction_or_actual == "actual":
160
211
  return pb2.ActualLabel(numeric=value)
212
+ return None
161
213
 
162
214
 
163
215
  def _get_score_categorical_pb_label(
164
216
  prediction_or_actual: str,
165
- value: bool | str | Tuple[str, float],
217
+ value: bool | str | tuple[str, float],
166
218
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
167
219
  sc = pb2.ScoreCategorical()
168
220
  if isinstance(value, bool):
@@ -202,8 +254,9 @@ def _get_score_categorical_pb_label(
202
254
  )
203
255
  if prediction_or_actual == "prediction":
204
256
  return pb2.PredictionLabel(score_categorical=sc)
205
- elif prediction_or_actual == "actual":
257
+ if prediction_or_actual == "actual":
206
258
  return pb2.ActualLabel(score_categorical=sc)
259
+ return None
207
260
 
208
261
 
209
262
  def _get_cv_pb_label(
@@ -215,20 +268,19 @@ def _get_cv_pb_label(
215
268
  ) -> pb2.PredictionLabel | pb2.ActualLabel:
216
269
  if isinstance(value, ObjectDetectionLabel):
217
270
  return _get_object_detection_pb_label(prediction_or_actual, value)
218
- elif isinstance(value, SemanticSegmentationLabel):
271
+ if isinstance(value, SemanticSegmentationLabel):
219
272
  return _get_semantic_segmentation_pb_label(prediction_or_actual, value)
220
- elif isinstance(value, InstanceSegmentationPredictionLabel):
273
+ if isinstance(value, InstanceSegmentationPredictionLabel):
221
274
  return _get_instance_segmentation_prediction_pb_label(value)
222
- elif isinstance(value, InstanceSegmentationActualLabel):
275
+ if isinstance(value, InstanceSegmentationActualLabel):
223
276
  return _get_instance_segmentation_actual_pb_label(value)
224
- else:
225
- raise InvalidValueType(
226
- "cv label",
227
- value,
228
- "ObjectDetectionLabel, SemanticSegmentationLabel, or "
229
- "InstanceSegmentationPredictionLabel for model type "
230
- f"{ModelTypes.OBJECT_DETECTION}",
231
- )
277
+ raise InvalidValueType(
278
+ "cv label",
279
+ value,
280
+ "ObjectDetectionLabel, SemanticSegmentationLabel, or "
281
+ "InstanceSegmentationPredictionLabel for model type "
282
+ f"{ModelTypes.OBJECT_DETECTION}",
283
+ )
232
284
 
233
285
 
234
286
  def _get_object_detection_pb_label(
@@ -265,8 +317,9 @@ def _get_object_detection_pb_label(
265
317
  od.bounding_boxes.extend(bounding_boxes)
266
318
  if prediction_or_actual == "prediction":
267
319
  return pb2.PredictionLabel(object_detection=od)
268
- elif prediction_or_actual == "actual":
320
+ if prediction_or_actual == "actual":
269
321
  return pb2.ActualLabel(object_detection=od)
322
+ return None
270
323
 
271
324
 
272
325
  def _get_semantic_segmentation_pb_label(
@@ -292,10 +345,11 @@ def _get_semantic_segmentation_pb_label(
292
345
  cv_label = pb2.CVPredictionLabel()
293
346
  cv_label.semantic_segmentation_label.polygons.extend(polygons)
294
347
  return pb2.PredictionLabel(cv_label=cv_label)
295
- elif prediction_or_actual == "actual":
348
+ if prediction_or_actual == "actual":
296
349
  cv_label = pb2.CVActualLabel()
297
350
  cv_label.semantic_segmentation_label.polygons.extend(polygons)
298
351
  return pb2.ActualLabel(cv_label=cv_label)
352
+ return None
299
353
 
300
354
 
301
355
  def _get_instance_segmentation_prediction_pb_label(
@@ -394,7 +448,7 @@ def _get_ranking_pb_label(
394
448
  if value.label is not None:
395
449
  rp.label = value.label
396
450
  return pb2.PredictionLabel(ranking=rp)
397
- elif isinstance(value, RankingActualLabel):
451
+ if isinstance(value, RankingActualLabel):
398
452
  ra = pb2.RankingActual()
399
453
  # relevance_labels and relevance_score are optional
400
454
  if value.relevance_labels is not None:
@@ -402,6 +456,7 @@ def _get_ranking_pb_label(
402
456
  if value.relevance_score is not None:
403
457
  ra.relevance_score.value = value.relevance_score
404
458
  return pb2.ActualLabel(ranking=ra)
459
+ return None
405
460
 
406
461
 
407
462
  def _get_multi_class_pb_label(
@@ -447,9 +502,8 @@ def _get_multi_class_pb_label(
447
502
  prediction_scores=prediction_scores_double_values,
448
503
  )
449
504
  mc_pred = pb2.MultiClassPrediction(single_label=single_label)
450
- p_label = pb2.PredictionLabel(multi_class=mc_pred)
451
- return p_label
452
- elif isinstance(value, MultiClassActualLabel):
505
+ return pb2.PredictionLabel(multi_class=mc_pred)
506
+ if isinstance(value, MultiClassActualLabel):
453
507
  # Validations checked actual score map is not None
454
508
  actual_labels = [] # list of class names with actual score of 1
455
509
  for class_name, score in value.actual_scores.items():
@@ -459,3 +513,4 @@ def _get_multi_class_pb_label(
459
513
  actual_labels=actual_labels,
460
514
  )
461
515
  return pb2.ActualLabel(multi_class=mc_act)
516
+ return None
@@ -1,5 +1,6 @@
1
+ """Stream validation logic for ML model predictions."""
2
+
1
3
  # type: ignore[pb2]
2
- from typing import Dict, Tuple
3
4
 
4
5
  from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
5
6
  from arize.exceptions.parameters import (
@@ -32,7 +33,7 @@ def validate_label(
32
33
  | bool
33
34
  | int
34
35
  | float
35
- | Tuple[str | bool, float]
36
+ | tuple[str | bool, float]
36
37
  | ObjectDetectionLabel
37
38
  | RankingPredictionLabel
38
39
  | RankingActualLabel
@@ -41,8 +42,20 @@ def validate_label(
41
42
  | InstanceSegmentationActualLabel
42
43
  | MultiClassPredictionLabel
43
44
  | MultiClassActualLabel,
44
- embedding_features: Dict[str, Embedding],
45
- ):
45
+ embedding_features: dict[str, Embedding],
46
+ ) -> None:
47
+ """Validate a label value against the specified model type.
48
+
49
+ Args:
50
+ prediction_or_actual: Whether this is a "prediction" or "actual" label.
51
+ model_type: The type of model (numeric, categorical, etc.).
52
+ label: The label value to validate.
53
+ embedding_features: Dictionary of embedding features for validation.
54
+
55
+ Raises:
56
+ ValueError: If label is invalid for the given model type.
57
+ TypeError: If label type is incorrect.
58
+ """
46
59
  if model_type in NUMERIC_MODEL_TYPES:
47
60
  _validate_numeric_label(model_type, label)
48
61
  elif model_type in CATEGORICAL_MODEL_TYPES:
@@ -63,8 +76,8 @@ def validate_label(
63
76
 
64
77
  def _validate_numeric_label(
65
78
  model_type: ModelTypes,
66
- label: str | bool | int | float | Tuple[str | bool, float],
67
- ):
79
+ label: str | bool | int | float | tuple[str | bool, float],
80
+ ) -> None:
68
81
  if not isinstance(label, (float, int)):
69
82
  raise InvalidValueType(
70
83
  f"label {label}",
@@ -75,8 +88,8 @@ def _validate_numeric_label(
75
88
 
76
89
  def _validate_categorical_label(
77
90
  model_type: ModelTypes,
78
- label: str | bool | int | float | Tuple[str | bool, float],
79
- ):
91
+ label: str | bool | int | float | tuple[str | bool, float],
92
+ ) -> None:
80
93
  is_valid = isinstance(label, (str, bool, int, float)) or (
81
94
  isinstance(label, tuple)
82
95
  and isinstance(label[0], (str, bool))
@@ -96,8 +109,8 @@ def _validate_cv_label(
96
109
  | SemanticSegmentationLabel
97
110
  | InstanceSegmentationPredictionLabel
98
111
  | InstanceSegmentationActualLabel,
99
- embedding_features: Dict[str, Embedding],
100
- ):
112
+ embedding_features: dict[str, Embedding],
113
+ ) -> None:
101
114
  if (
102
115
  not isinstance(label, ObjectDetectionLabel)
103
116
  and not isinstance(label, SemanticSegmentationLabel)
@@ -126,7 +139,7 @@ def _validate_cv_label(
126
139
 
127
140
  def _validate_ranking_label(
128
141
  label: RankingPredictionLabel | RankingActualLabel,
129
- ):
142
+ ) -> None:
130
143
  if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
131
144
  raise InvalidValueType(
132
145
  f"label {label}",
@@ -138,7 +151,7 @@ def _validate_ranking_label(
138
151
 
139
152
  def _validate_generative_llm_label(
140
153
  label: str | bool | int | float,
141
- ):
154
+ ) -> None:
142
155
  is_valid = isinstance(label, (str, bool, int, float))
143
156
  if not is_valid:
144
157
  raise InvalidValueType(
@@ -150,7 +163,7 @@ def _validate_generative_llm_label(
150
163
 
151
164
  def _validate_multi_class_label(
152
165
  label: MultiClassPredictionLabel | MultiClassActualLabel,
153
- ):
166
+ ) -> None:
154
167
  if not isinstance(
155
168
  label, (MultiClassPredictionLabel, MultiClassActualLabel)
156
169
  ):
@@ -167,8 +180,23 @@ def validate_and_convert_prediction_id(
167
180
  environment: Environments,
168
181
  prediction_label: PredictionLabelTypes | None = None,
169
182
  actual_label: ActualLabelTypes | None = None,
170
- shap_values: Dict[str, float] | None = None,
183
+ shap_values: dict[str, float] | None = None,
171
184
  ) -> str:
185
+ """Validate and convert a prediction ID to string format, or generate one if absent.
186
+
187
+ Args:
188
+ prediction_id: The prediction ID to validate/convert, or None.
189
+ environment: The environment context (training, validation, production).
190
+ prediction_label: Optional prediction label for delayed record detection.
191
+ actual_label: Optional actual label for delayed record detection.
192
+ shap_values: Optional SHAP values for delayed record detection.
193
+
194
+ Returns:
195
+ A validated prediction ID string.
196
+
197
+ Raises:
198
+ ValueError: If prediction ID is invalid.
199
+ """
172
200
  # If the user does not provide prediction id
173
201
  if prediction_id:
174
202
  # If prediction id is given by user, convert it to string and validate length
@@ -0,0 +1 @@
1
+ """Surrogate explainer implementations for model interpretability."""
@@ -1,9 +1,11 @@
1
+ """Mimic explainer implementation for surrogate model explanations."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import random
4
6
  import string
5
7
  from dataclasses import replace
6
- from typing import TYPE_CHECKING, Callable, Tuple
8
+ from typing import TYPE_CHECKING
7
9
 
8
10
  import numpy as np
9
11
  import pandas as pd
@@ -13,20 +15,26 @@ from interpret_community.mimic.mimic_explainer import (
13
15
  )
14
16
  from sklearn.preprocessing import LabelEncoder
15
17
 
16
- from arize.types import (
17
- CATEGORICAL_MODEL_TYPES,
18
- NUMERIC_MODEL_TYPES,
19
- ModelTypes,
20
- )
18
+ from arize.types import CATEGORICAL_MODEL_TYPES, NUMERIC_MODEL_TYPES, ModelTypes
21
19
 
22
20
  if TYPE_CHECKING:
21
+ from collections.abc import Callable
22
+
23
23
  from arize.types import Schema
24
24
 
25
25
 
26
26
  class Mimic:
27
+ """Mimic explainer wrapper for generating surrogate model explanations."""
28
+
27
29
  _testing = False
28
30
 
29
- def __init__(self, X: pd.DataFrame, model_func: Callable):
31
+ def __init__(self, X: pd.DataFrame, model_func: Callable) -> None:
32
+ """Initialize the Mimic explainer with training data and model.
33
+
34
+ Args:
35
+ X: Training data DataFrame for the surrogate model.
36
+ model_func: Model function to explain.
37
+ """
30
38
  self.explainer = MimicExplainer(
31
39
  model_func,
32
40
  X,
@@ -36,6 +44,7 @@ class Mimic:
36
44
  )
37
45
 
38
46
  def explain(self, X: pd.DataFrame) -> pd.DataFrame:
47
+ """Explain feature importance for the given input DataFrame."""
39
48
  return pd.DataFrame(
40
49
  self.explainer.explain_local(X).local_importance_values,
41
50
  columns=X.columns,
@@ -45,7 +54,8 @@ class Mimic:
45
54
  @staticmethod
46
55
  def augment(
47
56
  df: pd.DataFrame, schema: Schema, model_type: ModelTypes
48
- ) -> Tuple[pd.DataFrame, Schema]:
57
+ ) -> tuple[pd.DataFrame, Schema]:
58
+ """Augment the DataFrame and schema with SHAP values for explainability."""
49
59
  features = schema.feature_column_names
50
60
  X = df[features]
51
61
 
@@ -71,7 +81,7 @@ class Mimic:
71
81
  )
72
82
 
73
83
  # model func requires 1 positional argument
74
- def model_func(_): # type: ignore
84
+ def model_func(_: object) -> object: # type: ignore
75
85
  return np.column_stack((1 - y, y))
76
86
 
77
87
  elif model_type in NUMERIC_MODEL_TYPES:
@@ -89,7 +99,7 @@ class Mimic:
89
99
  )
90
100
 
91
101
  # model func requires 1 positional argument
92
- def model_func(_): # type: ignore
102
+ def model_func(_: object) -> object: # type: ignore
93
103
  return y
94
104
 
95
105
  else:
@@ -100,8 +110,9 @@ class Mimic:
100
110
 
101
111
  # Column name mapping between features and feature importance values.
102
112
  # This is used to augment the schema.
113
+ # Generate unique column names to avoid collisions (not security-sensitive)
103
114
  col_map = {
104
- ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
115
+ ft: f"{''.join(random.choices(string.ascii_letters, k=8))}" # noqa: S311
105
116
  for ft in features
106
117
  }
107
118
  aug_schema = replace(schema, shap_values_column_names=col_map)
@@ -127,7 +138,7 @@ class Mimic:
127
138
  X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)
128
139
 
129
140
  # Apply integer encoding to non-numeric columns.
130
- # Currently training and explaining detasets are the same, but
141
+ # Currently training and explaining datasets are the same, but
131
142
  # this can be changed in the future. The student model can be
132
143
  # fitted on a much larger dataset since it takes a lot less time.
133
144
  X = pd.concat(
@@ -156,7 +167,7 @@ class Mimic:
156
167
 
157
168
  # Fill null with zero so they're not counted as missing records by server
158
169
  if not Mimic._testing:
159
- aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True)
170
+ aug_df.fillna(dict.fromkeys(col_map.values(), 0), inplace=True)
160
171
 
161
172
  return (
162
173
  aug_df,
arize/pre_releases.py ADDED
@@ -0,0 +1,43 @@
1
+ """Pre-release feature management and gating for the Arize SDK."""
2
+
3
+ import functools
4
+ import logging
5
+ from collections.abc import Callable
6
+ from enum import StrEnum
7
+
8
+ from arize.version import __version__
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class ReleaseStage(StrEnum):
14
+ """Enum representing the release stage of API features."""
15
+
16
+ ALPHA = "alpha"
17
+ BETA = "beta"
18
+
19
+
20
+ _WARNED: set[str] = set()
21
+
22
+
23
def _format_prerelease_message(*, key: str, stage: ReleaseStage) -> str:
    """Build the warning text shown for a prerelease API.

    Args:
        key: Identifier of the API being used (for example ``"projects.list"``).
        stage: Release stage of that API.

    Returns:
        A single-line message naming the stage, the key and the SDK version.
    """
    # Pick the indefinite article from the stage name so the sentence reads
    # "an alpha API" but "a beta API" (a fixed "an" produced "an beta API").
    starts_with_vowel = str(stage).lower().startswith(("a", "e", "i", "o", "u"))
    article = "an" if starts_with_vowel else "a"
    return (
        f"[{stage.upper()}] {key} is {article} {stage} API "
        f"in Arize SDK v{__version__} and may change without notice."
    )
28
+
29
+
30
def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> object:
    """Create a decorator that logs a prerelease warning on first use of *key*.

    The warning is emitted through this module's logger at most once per
    ``key``: keys that have already warned are remembered in the module-level
    ``_WARNED`` set, so later calls go straight to the wrapped callable.

    Args:
        stage: Release stage reported in the warning message.
        key: Identifier used both in the message and for de-duplication.

    Returns:
        A decorator that wraps a callable and returns the wrapped version.
    """

    def decorator(func: Callable[..., object]) -> object:
        @functools.wraps(func)
        def guarded(*args: object, **kwargs: object) -> object:
            already_warned = key in _WARNED
            if not already_warned:
                # Record the key before logging so it is marked as warned
                # ahead of the (single) log call.
                _WARNED.add(key)
                logger.warning(_format_prerelease_message(key=key, stage=stage))
            return func(*args, **kwargs)

        return guarded

    return decorator
@@ -0,0 +1 @@
1
+ """Project management and operations for the Arize platform."""
@@ -0,0 +1,129 @@
1
+ """Client implementation for managing projects in the Arize platform."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING
7
+
8
+ from arize.pre_releases import ReleaseStage, prerelease_endpoint
9
+
10
+ if TYPE_CHECKING:
11
+ from arize._generated.api_client import models
12
+ from arize.config import SDKConfiguration
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class ProjectsClient:
    """Sub-client for project operations exposed by the Arize REST API.

    Every public method forwards to the generated ``ProjectsApi`` and is
    decorated with ``prerelease_endpoint`` at the BETA stage.
    """

    def __init__(self, *, sdk_config: SDKConfiguration) -> None:
        """Build the projects sub-client on top of the shared generated client.

        No new HTTP client is created here; the generated API client owned by
        ``sdk_config`` is reused.

        Args:
            sdk_config: Resolved SDK configuration.
        """
        self._sdk_config = sdk_config

        # Deferred import keeps the generated package lazy and extras-gated
        # by the parent client.
        from arize._generated import api_client as gen

        shared_client = self._sdk_config.get_generated_client()
        self._api = gen.ProjectsApi(shared_client)

    @prerelease_endpoint(key="projects.list", stage=ReleaseStage.BETA)
    def list(
        self,
        *,
        space_id: str | None = None,
        limit: int = 100,
        cursor: str | None = None,
    ) -> models.ProjectsList200Response:
        """List projects the user has access to.

        Results are paginated with an opaque cursor. Passing ``space_id``
        restricts the listing to that space.

        Args:
            space_id: Optional space ID to filter results.
            limit: Maximum number of projects to return. The server may
                enforce an upper bound.
            cursor: Opaque pagination cursor from a previous response.

        Returns:
            A paginated project list response from the Arize REST API.

        Raises:
            arize._generated.api_client.exceptions.ApiException: If the API
                request fails.
        """
        page = self._api.projects_list(
            space_id=space_id,
            limit=limit,
            cursor=cursor,
        )
        return page

    @prerelease_endpoint(key="projects.create", stage=ReleaseStage.BETA)
    def create(
        self,
        *,
        name: str,
        space_id: str,
    ) -> models.Project:
        """Create a new project in a space.

        Args:
            name: Project name (must be unique within ``space_id``).
            space_id: Space ID to create the project in.

        Returns:
            The created project object.

        Raises:
            arize._generated.api_client.exceptions.ApiException: If the API
                request fails (for example, due to invalid input or a
                uniqueness conflict).
        """
        # Deferred import, same as in ``__init__``.
        from arize._generated import api_client as gen

        request = gen.ProjectsCreateRequest(name=name, space_id=space_id)
        return self._api.projects_create(projects_create_request=request)

    @prerelease_endpoint(key="projects.get", stage=ReleaseStage.BETA)
    def get(self, *, project_id: str) -> models.Project:
        """Fetch a single project by its ID.

        Args:
            project_id: Project ID.

        Returns:
            The project object.

        Raises:
            arize._generated.api_client.exceptions.ApiException: If the API
                request fails (for example, project not found).
        """
        project = self._api.projects_get(project_id=project_id)
        return project

    @prerelease_endpoint(key="projects.delete", stage=ReleaseStage.BETA)
    def delete(self, *, project_id: str) -> None:
        """Delete a project by ID.

        This operation is irreversible.

        Args:
            project_id: Project ID.

        Returns:
            None on success (the API answers with an empty 204 response).

        Raises:
            arize._generated.api_client.exceptions.ApiException: If the API
                request fails (for example, project not found or insufficient
                permissions).
        """
        # The generated call's result is passed through unchanged.
        return self._api.projects_delete(project_id=project_id)