arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. arize/__init__.py +9 -2
  2. arize/_client_factory.py +50 -0
  3. arize/_exporter/client.py +18 -17
  4. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  5. arize/_exporter/validation.py +1 -1
  6. arize/_flight/client.py +37 -17
  7. arize/_generated/api_client/api/datasets_api.py +6 -6
  8. arize/_generated/api_client/api/experiments_api.py +6 -6
  9. arize/_generated/api_client/api/projects_api.py +3 -3
  10. arize/_lazy.py +61 -10
  11. arize/client.py +66 -50
  12. arize/config.py +175 -48
  13. arize/constants/config.py +1 -0
  14. arize/constants/ml.py +9 -16
  15. arize/constants/spans.py +5 -10
  16. arize/datasets/client.py +45 -28
  17. arize/datasets/errors.py +1 -1
  18. arize/datasets/validation.py +2 -2
  19. arize/embeddings/auto_generator.py +16 -9
  20. arize/embeddings/base_generators.py +15 -9
  21. arize/embeddings/cv_generators.py +2 -2
  22. arize/embeddings/errors.py +2 -2
  23. arize/embeddings/nlp_generators.py +8 -8
  24. arize/embeddings/tabular_generators.py +6 -6
  25. arize/exceptions/base.py +0 -52
  26. arize/exceptions/config.py +22 -0
  27. arize/exceptions/parameters.py +1 -330
  28. arize/exceptions/values.py +8 -5
  29. arize/experiments/__init__.py +4 -0
  30. arize/experiments/client.py +31 -18
  31. arize/experiments/evaluators/base.py +12 -9
  32. arize/experiments/evaluators/executors.py +16 -7
  33. arize/experiments/evaluators/rate_limiters.py +3 -1
  34. arize/experiments/evaluators/types.py +9 -7
  35. arize/experiments/evaluators/utils.py +7 -5
  36. arize/experiments/functions.py +128 -58
  37. arize/experiments/tracing.py +4 -1
  38. arize/experiments/types.py +34 -31
  39. arize/logging.py +54 -33
  40. arize/ml/batch_validation/errors.py +10 -1004
  41. arize/ml/batch_validation/validator.py +351 -291
  42. arize/ml/bounded_executor.py +25 -6
  43. arize/ml/casting.py +51 -33
  44. arize/ml/client.py +43 -35
  45. arize/ml/proto.py +21 -22
  46. arize/ml/stream_validation.py +64 -27
  47. arize/ml/surrogate_explainer/mimic.py +18 -10
  48. arize/ml/types.py +27 -67
  49. arize/pre_releases.py +10 -6
  50. arize/projects/client.py +9 -4
  51. arize/py.typed +0 -0
  52. arize/regions.py +11 -11
  53. arize/spans/client.py +125 -31
  54. arize/spans/columns.py +32 -36
  55. arize/spans/conversion.py +12 -11
  56. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  57. arize/spans/validation/annotations/value_validation.py +11 -14
  58. arize/spans/validation/common/argument_validation.py +3 -3
  59. arize/spans/validation/common/dataframe_form_validation.py +7 -7
  60. arize/spans/validation/common/value_validation.py +11 -14
  61. arize/spans/validation/evals/dataframe_form_validation.py +4 -4
  62. arize/spans/validation/evals/evals_validation.py +6 -6
  63. arize/spans/validation/evals/value_validation.py +1 -1
  64. arize/spans/validation/metadata/argument_validation.py +1 -1
  65. arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
  66. arize/spans/validation/metadata/value_validation.py +23 -1
  67. arize/spans/validation/spans/dataframe_form_validation.py +2 -2
  68. arize/spans/validation/spans/spans_validation.py +6 -6
  69. arize/utils/arrow.py +38 -2
  70. arize/utils/cache.py +2 -2
  71. arize/utils/dataframe.py +4 -4
  72. arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
  73. arize/utils/openinference_conversion.py +10 -10
  74. arize/utils/proto.py +0 -1
  75. arize/utils/types.py +6 -6
  76. arize/version.py +1 -1
  77. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
  78. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
  79. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
  80. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
  81. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/ml/stream_validation.py CHANGED
@@ -1,7 +1,5 @@
1
1
  """Stream validation logic for ML model predictions."""
2
2
 
3
- # type: ignore[pb2]
4
-
5
3
  from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
6
4
  from arize.exceptions.parameters import (
7
5
  InvalidValueType,
@@ -29,20 +27,8 @@ from arize.ml.types import (
29
27
  def validate_label(
30
28
  prediction_or_actual: str,
31
29
  model_type: ModelTypes,
32
- label: str
33
- | bool
34
- | int
35
- | float
36
- | tuple[str | bool, float]
37
- | ObjectDetectionLabel
38
- | RankingPredictionLabel
39
- | RankingActualLabel
40
- | SemanticSegmentationLabel
41
- | InstanceSegmentationPredictionLabel
42
- | InstanceSegmentationActualLabel
43
- | MultiClassPredictionLabel
44
- | MultiClassActualLabel,
45
- embedding_features: dict[str, Embedding],
30
+ label: object,
31
+ embedding_features: dict[str, Embedding] | None,
46
32
  ) -> None:
47
33
  """Validate a label value against the specified model type.
48
34
 
@@ -76,8 +62,17 @@ def validate_label(
76
62
 
77
63
  def _validate_numeric_label(
78
64
  model_type: ModelTypes,
79
- label: str | bool | int | float | tuple[str | bool, float],
65
+ label: object,
80
66
  ) -> None:
67
+ """Validate that a label is numeric (int or float) for numeric model types.
68
+
69
+ Args:
70
+ model_type: The model type being validated.
71
+ label: The label value to validate.
72
+
73
+ Raises:
74
+ InvalidValueType: If the label is not an int or float.
75
+ """
81
76
  if not isinstance(label, (float, int)):
82
77
  raise InvalidValueType(
83
78
  f"label {label}",
@@ -88,8 +83,18 @@ def _validate_numeric_label(
88
83
 
89
84
  def _validate_categorical_label(
90
85
  model_type: ModelTypes,
91
- label: str | bool | int | float | tuple[str | bool, float],
86
+ label: object,
92
87
  ) -> None:
88
+ """Validate that a label is categorical (scalar or tuple with confidence) for categorical model types.
89
+
90
+ Args:
91
+ model_type: The model type being validated.
92
+ label: The label value to validate.
93
+
94
+ Raises:
95
+ InvalidValueType: If the label is not a valid categorical type (bool, int, float, str,
96
+ or tuple of [str/bool, float]).
97
+ """
93
98
  is_valid = isinstance(label, (str, bool, int, float)) or (
94
99
  isinstance(label, tuple)
95
100
  and isinstance(label[0], (str, bool))
@@ -105,12 +110,20 @@ def _validate_categorical_label(
105
110
 
106
111
  def _validate_cv_label(
107
112
  prediction_or_actual: str,
108
- label: ObjectDetectionLabel
109
- | SemanticSegmentationLabel
110
- | InstanceSegmentationPredictionLabel
111
- | InstanceSegmentationActualLabel,
112
- embedding_features: dict[str, Embedding],
113
+ label: object,
114
+ embedding_features: dict[str, Embedding] | None,
113
115
  ) -> None:
116
+ """Validate a computer vision label for object detection or segmentation tasks.
117
+
118
+ Args:
119
+ prediction_or_actual: Either 'prediction' or 'actual' to indicate label context.
120
+ label: The CV label to validate.
121
+ embedding_features: Dictionary of embedding features that must contain exactly one entry.
122
+
123
+ Raises:
124
+ InvalidValueType: If the label is not a valid CV label type.
125
+ ValueError: If embedding_features is None or doesn't contain exactly one feature.
126
+ """
114
127
  if (
115
128
  not isinstance(label, ObjectDetectionLabel)
116
129
  and not isinstance(label, SemanticSegmentationLabel)
@@ -138,8 +151,16 @@ def _validate_cv_label(
138
151
 
139
152
 
140
153
  def _validate_ranking_label(
141
- label: RankingPredictionLabel | RankingActualLabel,
154
+ label: object,
142
155
  ) -> None:
156
+ """Validate a ranking label for ranking model types.
157
+
158
+ Args:
159
+ label: The ranking label to validate.
160
+
161
+ Raises:
162
+ InvalidValueType: If the label is not a RankingPredictionLabel or RankingActualLabel.
163
+ """
143
164
  if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
144
165
  raise InvalidValueType(
145
166
  f"label {label}",
@@ -150,8 +171,16 @@ def _validate_ranking_label(
150
171
 
151
172
 
152
173
  def _validate_generative_llm_label(
153
- label: str | bool | int | float,
174
+ label: object,
154
175
  ) -> None:
176
+ """Validate a label for generative LLM model types.
177
+
178
+ Args:
179
+ label: The label value to validate.
180
+
181
+ Raises:
182
+ InvalidValueType: If the label is not a bool, int, float, or str.
183
+ """
155
184
  is_valid = isinstance(label, (str, bool, int, float))
156
185
  if not is_valid:
157
186
  raise InvalidValueType(
@@ -162,8 +191,16 @@ def _validate_generative_llm_label(
162
191
 
163
192
 
164
193
  def _validate_multi_class_label(
165
- label: MultiClassPredictionLabel | MultiClassActualLabel,
194
+ label: object,
166
195
  ) -> None:
196
+ """Validate a multi-class label for multi-class model types.
197
+
198
+ Args:
199
+ label: The multi-class label to validate.
200
+
201
+ Raises:
202
+ InvalidValueType: If the label is not a MultiClassPredictionLabel or MultiClassActualLabel.
203
+ """
167
204
  if not isinstance(
168
205
  label, (MultiClassPredictionLabel, MultiClassActualLabel)
169
206
  ):
@@ -185,7 +222,7 @@ def validate_and_convert_prediction_id(
185
222
  """Validate and convert a prediction ID to string format, or generate one if absent.
186
223
 
187
224
  Args:
188
- prediction_id: The prediction ID to validate/convert, or None.
225
+ prediction_id: The prediction ID to validate/convert, or :obj:`None`.
189
226
  environment: The environment context (training, validation, production).
190
227
  prediction_label: Optional prediction label for delayed record detection.
191
228
  actual_label: Optional actual label for delayed record detection.
arize/ml/surrogate_explainer/mimic.py CHANGED
@@ -19,6 +19,7 @@ from arize.ml.types import (
19
19
  CATEGORICAL_MODEL_TYPES,
20
20
  NUMERIC_MODEL_TYPES,
21
21
  ModelTypes,
22
+ _normalize_column_names,
22
23
  )
23
24
 
24
25
  if TYPE_CHECKING:
@@ -36,7 +37,7 @@ class Mimic:
36
37
  """Initialize the Mimic explainer with training data and model.
37
38
 
38
39
  Args:
39
- X: Training data DataFrame for the surrogate model.
40
+ X: Training data :class:`pandas.DataFrame` for the surrogate model.
40
41
  model_func: Model function to explain.
41
42
  """
42
43
  self.explainer = MimicExplainer(
@@ -48,7 +49,7 @@ class Mimic:
48
49
  )
49
50
 
50
51
  def explain(self, X: pd.DataFrame) -> pd.DataFrame:
51
- """Explain feature importance for the given input DataFrame."""
52
+ """Explain feature importance for the given input :class:`pandas.DataFrame`."""
52
53
  return pd.DataFrame(
53
54
  self.explainer.explain_local(X).local_importance_values,
54
55
  columns=X.columns,
@@ -59,8 +60,8 @@ class Mimic:
59
60
  def augment(
60
61
  df: pd.DataFrame, schema: Schema, model_type: ModelTypes
61
62
  ) -> tuple[pd.DataFrame, Schema]:
62
- """Augment the DataFrame and schema with SHAP values for explainability."""
63
- features = schema.feature_column_names
63
+ """Augment the :class:`pandas.DataFrame` and schema with SHAP values for explainability."""
64
+ features = _normalize_column_names(schema.feature_column_names)
64
65
  X = df[features]
65
66
 
66
67
  if X.shape[1] == 0:
@@ -85,25 +86,32 @@ class Mimic:
85
86
  )
86
87
 
87
88
  # model func requires 1 positional argument
88
- def model_func(_: object) -> object: # type: ignore
89
+ def model_func(_: object) -> object:
89
90
  return np.column_stack((1 - y, y))
90
91
 
91
92
  elif model_type in NUMERIC_MODEL_TYPES:
92
- y_col_name = schema.prediction_label_column_name
93
+ y_col_name_nullable: str | None = (
94
+ schema.prediction_label_column_name
95
+ )
93
96
  if schema.prediction_score_column_name is not None:
94
- y_col_name = schema.prediction_score_column_name
95
- y = df[y_col_name].to_numpy()
97
+ y_col_name_nullable = schema.prediction_score_column_name
98
+ if y_col_name_nullable is None:
99
+ raise ValueError(
100
+ f"For {model_type} models, either prediction_label_column_name "
101
+ "or prediction_score_column_name must be specified"
102
+ )
103
+ y = df[y_col_name_nullable].to_numpy()
96
104
 
97
105
  _finite_count = np.isfinite(y).sum()
98
106
  if len(y) - _finite_count:
99
107
  raise ValueError(
100
108
  f"To calculate surrogate explainability for {model_type}, "
101
109
  f"predictions must not contain NaN or infinite values, but "
102
- f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}."
110
+ f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name_nullable}."
103
111
  )
104
112
 
105
113
  # model func requires 1 positional argument
106
- def model_func(_: object) -> object: # type: ignore
114
+ def model_func(_: object) -> object:
107
115
  return y
108
116
 
109
117
  else:
arize/ml/types.py CHANGED
@@ -2,47 +2,47 @@
2
2
 
3
3
  import logging
4
4
  import math
5
+ import sys
5
6
  from collections.abc import Iterator
6
7
  from dataclasses import asdict, dataclass, replace
7
8
  from datetime import datetime
8
9
  from decimal import Decimal
9
10
  from enum import Enum, unique
10
11
  from itertools import chain
11
- from typing import (
12
- NamedTuple,
13
- Self,
14
- )
12
+ from typing import NamedTuple
13
+
14
+ if sys.version_info >= (3, 11):
15
+ from typing import Self
16
+ else:
17
+ from typing_extensions import Self
15
18
 
16
19
  import numpy as np
17
20
 
18
21
  from arize.constants.ml import (
19
- # MAX_MULTI_CLASS_NAME_LENGTH,
20
- # MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
21
22
  MAX_MULTI_CLASS_NAME_LENGTH,
22
23
  MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
23
24
  MAX_NUMBER_OF_SIMILARITY_REFERENCES,
24
25
  MAX_RAW_DATA_CHARACTERS,
25
26
  MAX_RAW_DATA_CHARACTERS_TRUNCATION,
26
- # MAX_RAW_DATA_CHARACTERS,
27
- # MAX_RAW_DATA_CHARACTERS_TRUNCATION,
28
27
  )
29
28
  from arize.exceptions.parameters import InvalidValueType
30
-
31
- #
32
- # from arize.utils.constants import (
33
- # MAX_MULTI_CLASS_NAME_LENGTH,
34
- # MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
35
- # MAX_NUMBER_OF_SIMILARITY_REFERENCES,
36
- # MAX_RAW_DATA_CHARACTERS,
37
- # MAX_RAW_DATA_CHARACTERS_TRUNCATION,
38
- # )
39
- # from arize.utils.errors import InvalidValueType
40
29
  from arize.logging import get_truncation_warning_message
41
30
  from arize.utils.types import is_dict_of, is_iterable_of, is_list_of
42
31
 
43
32
  logger = logging.getLogger(__name__)
44
33
 
45
34
 
35
+ def _normalize_column_names(
36
+ col_names: "list[str] | TypedColumns | None",
37
+ ) -> list[str]:
38
+ """Convert TypedColumns or list to a flat list of column names."""
39
+ if col_names is None:
40
+ return []
41
+ if isinstance(col_names, list):
42
+ return col_names
43
+ return col_names.get_all_column_names()
44
+
45
+
46
46
  @unique
47
47
  class ModelTypes(Enum):
48
48
  """Enum representing supported model types in Arize."""
@@ -204,7 +204,7 @@ class Embedding(NamedTuple):
204
204
  )
205
205
  # Fail if not all elements in list are floats
206
206
  allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
207
- if not all(isinstance(val, allowed_types) for val in self.vector): # type: ignore
207
+ if not all(isinstance(val, allowed_types) for val in self.vector):
208
208
  raise TypeError(
209
209
  f"Embedding vector must be a vector of integers and/or floats. Got "
210
210
  f"{emb_name}.vector = {self.vector}"
@@ -283,7 +283,7 @@ class Embedding(NamedTuple):
283
283
 
284
284
  @staticmethod
285
285
  def _is_valid_iterable(
286
- data: str | list[str] | list[float] | np.ndarray,
286
+ data: object,
287
287
  ) -> bool:
288
288
  """Validates that the input data field is of the correct iterable type.
289
289
 
@@ -299,30 +299,6 @@ class Embedding(NamedTuple):
299
299
  return any(isinstance(data, t) for t in (list, np.ndarray))
300
300
 
301
301
 
302
- # @dataclass
303
- # class _PromptOrResponseText:
304
- # data: str
305
- #
306
- # def validate(self, name: str) -> None:
307
- # # Validate that data is a string
308
- # if not isinstance(self.data, str):
309
- # raise TypeError(f"'{name}' must be a str")
310
- #
311
- # character_count = len(self.data)
312
- # if character_count > MAX_RAW_DATA_CHARACTERS:
313
- # raise ValueError(
314
- # f"'{name}' field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
315
- # f"Found {character_count}."
316
- # )
317
- # elif character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
318
- # logger.warning(
319
- # get_truncation_warning_message(
320
- # f"'{name}'", MAX_RAW_DATA_CHARACTERS_TRUNCATION
321
- # )
322
- # )
323
- # return None
324
-
325
-
326
302
  class LLMRunMetadata(NamedTuple):
327
303
  """Metadata for LLM execution including token counts and latency."""
328
304
 
@@ -1021,22 +997,6 @@ class LLMRunMetadataColumnNames:
1021
997
  )
1022
998
 
1023
999
 
1024
- # @dataclass
1025
- # class DocumentColumnNames:
1026
- # id_column_name: Optional[str] = None
1027
- # version_column_name: Optional[str] = None
1028
- # text_embedding_column_names: Optional[EmbeddingColumnNames] = None
1029
- #
1030
- # def __iter__(self):
1031
- # return iter(
1032
- # (
1033
- # self.id_column_name,
1034
- # self.version_column_name,
1035
- # self.text_embedding_column_names,
1036
- # )
1037
- # )
1038
- #
1039
- #
1040
1000
  @dataclass
1041
1001
  class SimilarityReference:
1042
1002
  """Reference to a prediction for similarity search operations."""
@@ -1250,7 +1210,7 @@ class Schema(BaseSchema):
1250
1210
  actual_score_column_name: str | None = None
1251
1211
  shap_values_column_names: dict[str, str] | None = None
1252
1212
  embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
1253
- None # type:ignore
1213
+ None
1254
1214
  )
1255
1215
  prediction_group_id_column_name: str | None = None
1256
1216
  rank_column_name: str | None = None
@@ -1268,7 +1228,7 @@ class Schema(BaseSchema):
1268
1228
  prompt_template_column_names: PromptTemplateColumnNames | None = None
1269
1229
  llm_config_column_names: LLMConfigColumnNames | None = None
1270
1230
  llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
1271
- retrieved_document_ids_column_name: list[str] | None = None
1231
+ retrieved_document_ids_column_name: str | None = None
1272
1232
  multi_class_threshold_scores_column_name: str | None = None
1273
1233
  semantic_segmentation_prediction_column_names: (
1274
1234
  SemanticSegmentationColumnNames | None
@@ -1285,7 +1245,7 @@ class Schema(BaseSchema):
1285
1245
 
1286
1246
  def get_used_columns_counts(self) -> dict[str, int]:
1287
1247
  """Return a dict mapping column names to their usage count."""
1288
- columns_used_counts = {}
1248
+ columns_used_counts: dict[str, int] = {}
1289
1249
 
1290
1250
  for field in self.__dataclass_fields__:
1291
1251
  if field.endswith("column_name"):
@@ -1294,7 +1254,7 @@ class Schema(BaseSchema):
1294
1254
  add_to_column_count_dictionary(columns_used_counts, col)
1295
1255
 
1296
1256
  if self.feature_column_names is not None:
1297
- for col in self.feature_column_names:
1257
+ for col in _normalize_column_names(self.feature_column_names):
1298
1258
  add_to_column_count_dictionary(columns_used_counts, col)
1299
1259
 
1300
1260
  if self.embedding_feature_column_names is not None:
@@ -1313,7 +1273,7 @@ class Schema(BaseSchema):
1313
1273
  )
1314
1274
 
1315
1275
  if self.tag_column_names is not None:
1316
- for col in self.tag_column_names:
1276
+ for col in _normalize_column_names(self.tag_column_names):
1317
1277
  add_to_column_count_dictionary(columns_used_counts, col)
1318
1278
 
1319
1279
  if self.shap_values_column_names is not None:
@@ -1458,7 +1418,7 @@ class CorpusSchema(BaseSchema):
1458
1418
 
1459
1419
  def get_used_columns_counts(self) -> dict[str, int]:
1460
1420
  """Return a dict mapping column names to their usage count."""
1461
- columns_used_counts = {}
1421
+ columns_used_counts: dict[str, int] = {}
1462
1422
 
1463
1423
  if self.document_id_column_name is not None:
1464
1424
  add_to_column_count_dictionary(
@@ -1531,7 +1491,7 @@ def add_to_column_count_dictionary(
1531
1491
 
1532
1492
  Args:
1533
1493
  column_dictionary: Dictionary mapping column names to counts.
1534
- col: The column name to increment, or None to skip.
1494
+ col: The column name to increment, or :obj:`None` to skip.
1535
1495
  """
1536
1496
  if col:
1537
1497
  if col in column_dictionary:
arize/pre_releases.py CHANGED
@@ -3,14 +3,15 @@
3
3
  import functools
4
4
  import logging
5
5
  from collections.abc import Callable
6
- from enum import StrEnum
6
+ from enum import Enum
7
+ from typing import TypeVar, cast
7
8
 
8
9
  from arize.version import __version__
9
10
 
10
11
  logger = logging.getLogger(__name__)
11
12
 
12
13
 
13
- class ReleaseStage(StrEnum):
14
+ class ReleaseStage(Enum):
14
15
  """Enum representing the release stage of API features."""
15
16
 
16
17
  ALPHA = "alpha"
@@ -19,19 +20,21 @@ class ReleaseStage(StrEnum):
19
20
 
20
21
  _WARNED: set[str] = set()
21
22
 
23
+ _F = TypeVar("_F", bound=Callable)
24
+
22
25
 
23
26
  def _format_prerelease_message(*, key: str, stage: ReleaseStage) -> str:
24
27
  article = "an" if stage is ReleaseStage.ALPHA else "a"
25
28
  return (
26
- f"[{stage.upper()}] {key} is {article} {stage} API "
29
+ f"[{stage.value.upper()}] {key} is {article} {stage.value} API "
27
30
  f"in Arize SDK v{__version__} and may change without notice."
28
31
  )
29
32
 
30
33
 
31
- def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> object:
34
+ def prerelease_endpoint(*, key: str, stage: ReleaseStage) -> Callable[[_F], _F]:
32
35
  """Decorate a method to emit a prerelease warning via logging once per process."""
33
36
 
34
- def deco(fn: Callable[..., object]) -> object:
37
+ def deco(fn: _F) -> _F:
35
38
  @functools.wraps(fn)
36
39
  def wrapper(*args: object, **kwargs: object) -> object:
37
40
  if key not in _WARNED:
@@ -39,6 +42,7 @@ def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> object:
39
42
  logger.warning(_format_prerelease_message(key=key, stage=stage))
40
43
  return fn(*args, **kwargs)
41
44
 
42
- return wrapper
45
+ # Cast: functools.wraps preserves function signature at runtime but mypy can't verify this
46
+ return cast("_F", wrapper)
43
47
 
44
48
  return deco
arize/projects/client.py CHANGED
@@ -9,6 +9,7 @@ from arize.pre_releases import ReleaseStage, prerelease_endpoint
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from arize._generated.api_client import models
12
+ from arize._generated.api_client.api_client import ApiClient
12
13
  from arize.config import SDKConfiguration
13
14
 
14
15
  logger = logging.getLogger(__name__)
@@ -26,18 +27,21 @@ class ProjectsClient:
26
27
  :class:`arize.config.SDKConfiguration`.
27
28
  """
28
29
 
29
- def __init__(self, *, sdk_config: SDKConfiguration) -> None:
30
+ def __init__(
31
+ self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
32
+ ) -> None:
30
33
  """
31
34
  Args:
32
35
  sdk_config: Resolved SDK configuration.
36
+ generated_client: Shared generated API client instance.
33
37
  """ # noqa: D205, D212
34
38
  self._sdk_config = sdk_config
35
39
 
36
40
  # Import at runtime so it's still lazy and extras-gated by the parent
37
41
  from arize._generated import api_client as gen
38
42
 
39
- # Use the shared generated client from the config
40
- self._api = gen.ProjectsApi(self._sdk_config.get_generated_client())
43
+ # Use the provided client directly
44
+ self._api = gen.ProjectsApi(generated_client)
41
45
 
42
46
  @prerelease_endpoint(key="projects.list", stage=ReleaseStage.BETA)
43
47
  def list(
@@ -125,7 +129,8 @@ class ProjectsClient:
125
129
  Args:
126
130
  project_id: Project ID.
127
131
 
128
- Returns: This method returns None on success (common empty 204 response)
132
+ Returns:
133
+ This method returns None on success (common empty 204 response).
129
134
 
130
135
  Raises:
131
136
  arize._generated.api_client.exceptions.ApiException: If the API request fails
arize/py.typed ADDED
File without changes
arize/regions.py CHANGED
@@ -1,19 +1,19 @@
1
1
  """Region definitions and configuration for Arize deployment zones."""
2
2
 
3
3
  from dataclasses import dataclass
4
- from enum import StrEnum
4
+ from enum import Enum
5
5
 
6
6
  from arize.constants.config import DEFAULT_FLIGHT_PORT
7
7
 
8
8
 
9
- class Region(StrEnum):
9
+ class Region(Enum):
10
10
  """Enum representing available Arize deployment regions."""
11
11
 
12
- US_CENTRAL_1 = "us-central-1a"
13
- EU_WEST_1 = "eu-west-1a"
14
- CA_CENTRAL_1 = "ca-central-1a"
15
- US_EAST_1 = "us-east-1b"
16
- UNSPECIFIED = ""
12
+ CA_CENTRAL_1A = "ca-central-1a"
13
+ EU_WEST_1A = "eu-west-1a"
14
+ US_CENTRAL_1A = "us-central-1a"
15
+ US_EAST_1B = "us-east-1b"
16
+ UNSET = ""
17
17
 
18
18
 
19
19
  @dataclass(frozen=True)
@@ -28,13 +28,13 @@ class RegionEndpoints:
28
28
 
29
29
  def _get_region_endpoints(region: Region) -> RegionEndpoints:
30
30
  return RegionEndpoints(
31
- api_host=f"api.{region}.arize.com",
32
- otlp_host=f"otlp.{region}.arize.com",
33
- flight_host=f"flight.{region}.arize.com",
31
+ api_host=f"api.{region.value}.arize.com",
32
+ otlp_host=f"otlp.{region.value}.arize.com",
33
+ flight_host=f"flight.{region.value}.arize.com",
34
34
  flight_port=DEFAULT_FLIGHT_PORT,
35
35
  )
36
36
 
37
37
 
38
38
  REGION_ENDPOINTS: dict[Region, RegionEndpoints] = {
39
- r: _get_region_endpoints(r) for r in Region if r != Region.UNSPECIFIED
39
+ r: _get_region_endpoints(r) for r in list(Region) if r != Region.UNSET
40
40
  }