arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +9 -2
- arize/_client_factory.py +50 -0
- arize/_exporter/client.py +18 -17
- arize/_exporter/parsers/tracing_data_parser.py +9 -4
- arize/_exporter/validation.py +1 -1
- arize/_flight/client.py +37 -17
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +6 -6
- arize/_generated/api_client/api/projects_api.py +3 -3
- arize/_lazy.py +61 -10
- arize/client.py +66 -50
- arize/config.py +175 -48
- arize/constants/config.py +1 -0
- arize/constants/ml.py +9 -16
- arize/constants/spans.py +5 -10
- arize/datasets/client.py +45 -28
- arize/datasets/errors.py +1 -1
- arize/datasets/validation.py +2 -2
- arize/embeddings/auto_generator.py +16 -9
- arize/embeddings/base_generators.py +15 -9
- arize/embeddings/cv_generators.py +2 -2
- arize/embeddings/errors.py +2 -2
- arize/embeddings/nlp_generators.py +8 -8
- arize/embeddings/tabular_generators.py +6 -6
- arize/exceptions/base.py +0 -52
- arize/exceptions/config.py +22 -0
- arize/exceptions/parameters.py +1 -330
- arize/exceptions/values.py +8 -5
- arize/experiments/__init__.py +4 -0
- arize/experiments/client.py +31 -18
- arize/experiments/evaluators/base.py +12 -9
- arize/experiments/evaluators/executors.py +16 -7
- arize/experiments/evaluators/rate_limiters.py +3 -1
- arize/experiments/evaluators/types.py +9 -7
- arize/experiments/evaluators/utils.py +7 -5
- arize/experiments/functions.py +128 -58
- arize/experiments/tracing.py +4 -1
- arize/experiments/types.py +34 -31
- arize/logging.py +54 -33
- arize/ml/batch_validation/errors.py +10 -1004
- arize/ml/batch_validation/validator.py +351 -291
- arize/ml/bounded_executor.py +25 -6
- arize/ml/casting.py +51 -33
- arize/ml/client.py +43 -35
- arize/ml/proto.py +21 -22
- arize/ml/stream_validation.py +64 -27
- arize/ml/surrogate_explainer/mimic.py +18 -10
- arize/ml/types.py +27 -67
- arize/pre_releases.py +10 -6
- arize/projects/client.py +9 -4
- arize/py.typed +0 -0
- arize/regions.py +11 -11
- arize/spans/client.py +125 -31
- arize/spans/columns.py +32 -36
- arize/spans/conversion.py +12 -11
- arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
- arize/spans/validation/annotations/value_validation.py +11 -14
- arize/spans/validation/common/argument_validation.py +3 -3
- arize/spans/validation/common/dataframe_form_validation.py +7 -7
- arize/spans/validation/common/value_validation.py +11 -14
- arize/spans/validation/evals/dataframe_form_validation.py +4 -4
- arize/spans/validation/evals/evals_validation.py +6 -6
- arize/spans/validation/evals/value_validation.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +1 -1
- arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
- arize/spans/validation/metadata/value_validation.py +23 -1
- arize/spans/validation/spans/dataframe_form_validation.py +2 -2
- arize/spans/validation/spans/spans_validation.py +6 -6
- arize/utils/arrow.py +38 -2
- arize/utils/cache.py +2 -2
- arize/utils/dataframe.py +4 -4
- arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
- arize/utils/openinference_conversion.py +10 -10
- arize/utils/proto.py +0 -1
- arize/utils/types.py +6 -6
- arize/version.py +1 -1
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/ml/stream_validation.py
CHANGED
@@ -1,7 +1,5 @@
 """Stream validation logic for ML model predictions."""

-# type: ignore[pb2]
-
 from arize.constants.ml import MAX_PREDICTION_ID_LEN, MIN_PREDICTION_ID_LEN
 from arize.exceptions.parameters import (
     InvalidValueType,
@@ -29,20 +27,8 @@ from arize.ml.types import (
 def validate_label(
     prediction_or_actual: str,
     model_type: ModelTypes,
-    label:
-
-    | int
-    | float
-    | tuple[str | bool, float]
-    | ObjectDetectionLabel
-    | RankingPredictionLabel
-    | RankingActualLabel
-    | SemanticSegmentationLabel
-    | InstanceSegmentationPredictionLabel
-    | InstanceSegmentationActualLabel
-    | MultiClassPredictionLabel
-    | MultiClassActualLabel,
-    embedding_features: dict[str, Embedding],
+    label: object,
+    embedding_features: dict[str, Embedding] | None,
 ) -> None:
     """Validate a label value against the specified model type.

@@ -76,8 +62,17 @@ def validate_label(

 def _validate_numeric_label(
     model_type: ModelTypes,
-    label:
+    label: object,
 ) -> None:
+    """Validate that a label is numeric (int or float) for numeric model types.
+
+    Args:
+        model_type: The model type being validated.
+        label: The label value to validate.
+
+    Raises:
+        InvalidValueType: If the label is not an int or float.
+    """
     if not isinstance(label, (float, int)):
         raise InvalidValueType(
             f"label {label}",
@@ -88,8 +83,18 @@ def _validate_numeric_label(

 def _validate_categorical_label(
     model_type: ModelTypes,
-    label:
+    label: object,
 ) -> None:
+    """Validate that a label is categorical (scalar or tuple with confidence) for categorical model types.
+
+    Args:
+        model_type: The model type being validated.
+        label: The label value to validate.
+
+    Raises:
+        InvalidValueType: If the label is not a valid categorical type (bool, int, float, str,
+            or tuple of [str/bool, float]).
+    """
     is_valid = isinstance(label, (str, bool, int, float)) or (
         isinstance(label, tuple)
         and isinstance(label[0], (str, bool))
@@ -105,12 +110,20 @@ def _validate_categorical_label(

 def _validate_cv_label(
     prediction_or_actual: str,
-    label:
-
-    | InstanceSegmentationPredictionLabel
-    | InstanceSegmentationActualLabel,
-    embedding_features: dict[str, Embedding],
+    label: object,
+    embedding_features: dict[str, Embedding] | None,
 ) -> None:
+    """Validate a computer vision label for object detection or segmentation tasks.
+
+    Args:
+        prediction_or_actual: Either 'prediction' or 'actual' to indicate label context.
+        label: The CV label to validate.
+        embedding_features: Dictionary of embedding features that must contain exactly one entry.
+
+    Raises:
+        InvalidValueType: If the label is not a valid CV label type.
+        ValueError: If embedding_features is None or doesn't contain exactly one feature.
+    """
     if (
         not isinstance(label, ObjectDetectionLabel)
         and not isinstance(label, SemanticSegmentationLabel)
@@ -138,8 +151,16 @@ def _validate_cv_label(


 def _validate_ranking_label(
-    label:
+    label: object,
 ) -> None:
+    """Validate a ranking label for ranking model types.
+
+    Args:
+        label: The ranking label to validate.
+
+    Raises:
+        InvalidValueType: If the label is not a RankingPredictionLabel or RankingActualLabel.
+    """
     if not isinstance(label, (RankingPredictionLabel, RankingActualLabel)):
         raise InvalidValueType(
             f"label {label}",
@@ -150,8 +171,16 @@ def _validate_ranking_label(


 def _validate_generative_llm_label(
-    label:
+    label: object,
 ) -> None:
+    """Validate a label for generative LLM model types.
+
+    Args:
+        label: The label value to validate.
+
+    Raises:
+        InvalidValueType: If the label is not a bool, int, float, or str.
+    """
     is_valid = isinstance(label, (str, bool, int, float))
     if not is_valid:
         raise InvalidValueType(
@@ -162,8 +191,16 @@ def _validate_generative_llm_label(


 def _validate_multi_class_label(
-    label:
+    label: object,
 ) -> None:
+    """Validate a multi-class label for multi-class model types.
+
+    Args:
+        label: The multi-class label to validate.
+
+    Raises:
+        InvalidValueType: If the label is not a MultiClassPredictionLabel or MultiClassActualLabel.
+    """
     if not isinstance(
         label, (MultiClassPredictionLabel, MultiClassActualLabel)
     ):
@@ -185,7 +222,7 @@ def validate_and_convert_prediction_id(
     """Validate and convert a prediction ID to string format, or generate one if absent.

     Args:
-        prediction_id: The prediction ID to validate/convert, or None
+        prediction_id: The prediction ID to validate/convert, or :obj:`None`.
         environment: The environment context (training, validation, production).
         prediction_label: Optional prediction label for delayed record detection.
         actual_label: Optional actual label for delayed record detection.
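The hunks above widen every label parameter to `object`, move enforcement to runtime `isinstance` checks, and make `embedding_features` optional. A minimal usage sketch, not from the SDK docs: the import paths follow the diff, but `ModelTypes.NUMERIC` as an enum member and the exact dispatch behavior are assumptions.

```python
# Hedged sketch: exercising the relaxed validate_label signature shown above.
from arize.ml.stream_validation import validate_label
from arize.ml.types import ModelTypes  # ModelTypes.NUMERIC assumed to exist

# A numeric label for a numeric model type is expected to pass without raising.
validate_label(
    prediction_or_actual="prediction",
    model_type=ModelTypes.NUMERIC,
    label=0.87,
    embedding_features=None,  # None is now accepted per the new signature
)

# Passing e.g. label="high" here would surface as an InvalidValueType at runtime,
# since the static union annotation was replaced by an `object` annotation.
```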
arize/ml/surrogate_explainer/mimic.py
CHANGED

@@ -19,6 +19,7 @@ from arize.ml.types import (
     CATEGORICAL_MODEL_TYPES,
     NUMERIC_MODEL_TYPES,
     ModelTypes,
+    _normalize_column_names,
 )

 if TYPE_CHECKING:
@@ -36,7 +37,7 @@ class Mimic:
         """Initialize the Mimic explainer with training data and model.

         Args:
-            X: Training data DataFrame for the surrogate model.
+            X: Training data :class:`pandas.DataFrame` for the surrogate model.
             model_func: Model function to explain.
         """
         self.explainer = MimicExplainer(
@@ -48,7 +49,7 @@ class Mimic:
         )

     def explain(self, X: pd.DataFrame) -> pd.DataFrame:
-        """Explain feature importance for the given input DataFrame
+        """Explain feature importance for the given input :class:`pandas.DataFrame`."""
         return pd.DataFrame(
             self.explainer.explain_local(X).local_importance_values,
             columns=X.columns,
@@ -59,8 +60,8 @@ class Mimic:
     def augment(
         df: pd.DataFrame, schema: Schema, model_type: ModelTypes
     ) -> tuple[pd.DataFrame, Schema]:
-        """Augment the DataFrame and schema with SHAP values for explainability."""
-        features = schema.feature_column_names
+        """Augment the :class:`pandas.DataFrame` and schema with SHAP values for explainability."""
+        features = _normalize_column_names(schema.feature_column_names)
         X = df[features]

         if X.shape[1] == 0:
@@ -85,25 +86,32 @@ class Mimic:
            )

            # model func requires 1 positional argument
-            def model_func(_: object) -> object:
+            def model_func(_: object) -> object:
                return np.column_stack((1 - y, y))

        elif model_type in NUMERIC_MODEL_TYPES:
-
+            y_col_name_nullable: str | None = (
+                schema.prediction_label_column_name
+            )
            if schema.prediction_score_column_name is not None:
-
-
+                y_col_name_nullable = schema.prediction_score_column_name
+            if y_col_name_nullable is None:
+                raise ValueError(
+                    f"For {model_type} models, either prediction_label_column_name "
+                    "or prediction_score_column_name must be specified"
+                )
+            y = df[y_col_name_nullable].to_numpy()

            _finite_count = np.isfinite(y).sum()
            if len(y) - _finite_count:
                raise ValueError(
                    f"To calculate surrogate explainability for {model_type}, "
                    f"predictions must not contain NaN or infinite values, but "
-                    f"{len(y) - _finite_count} NaN or infinite value(s) are found in {
+                    f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name_nullable}."
                )

            # model func requires 1 positional argument
-            def model_func(_: object) -> object:
+            def model_func(_: object) -> object:
                return y

        else:
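For NUMERIC model types, `Mimic.augment` now resolves the y column explicitly and fails fast when neither column is configured. The decision rule the new hunk encodes, restated as a standalone sketch (not the SDK's code):

```python
# Standalone restatement of the y-column resolution added to Mimic.augment.
def resolve_y_column(
    prediction_label_col: str | None,
    prediction_score_col: str | None,
) -> str:
    # The score column takes precedence; otherwise fall back to the label column.
    y_col = prediction_label_col
    if prediction_score_col is not None:
        y_col = prediction_score_col
    if y_col is None:
        raise ValueError(
            "either prediction_label_column_name or "
            "prediction_score_column_name must be specified"
        )
    return y_col


assert resolve_y_column("pred_label", None) == "pred_label"
assert resolve_y_column("pred_label", "pred_score") == "pred_score"
assert resolve_y_column(None, "pred_score") == "pred_score"
```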
arize/ml/types.py
CHANGED
@@ -2,47 +2,47 @@

 import logging
 import math
+import sys
 from collections.abc import Iterator
 from dataclasses import asdict, dataclass, replace
 from datetime import datetime
 from decimal import Decimal
 from enum import Enum, unique
 from itertools import chain
-from typing import
-
-
-
+from typing import NamedTuple
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self

 import numpy as np

 from arize.constants.ml import (
-    # MAX_MULTI_CLASS_NAME_LENGTH,
-    # MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
     MAX_MULTI_CLASS_NAME_LENGTH,
     MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
     MAX_NUMBER_OF_SIMILARITY_REFERENCES,
     MAX_RAW_DATA_CHARACTERS,
     MAX_RAW_DATA_CHARACTERS_TRUNCATION,
-    # MAX_RAW_DATA_CHARACTERS,
-    # MAX_RAW_DATA_CHARACTERS_TRUNCATION,
 )
 from arize.exceptions.parameters import InvalidValueType
-
-#
-# from arize.utils.constants import (
-#     MAX_MULTI_CLASS_NAME_LENGTH,
-#     MAX_NUMBER_OF_MULTI_CLASS_CLASSES,
-#     MAX_NUMBER_OF_SIMILARITY_REFERENCES,
-#     MAX_RAW_DATA_CHARACTERS,
-#     MAX_RAW_DATA_CHARACTERS_TRUNCATION,
-# )
-# from arize.utils.errors import InvalidValueType
 from arize.logging import get_truncation_warning_message
 from arize.utils.types import is_dict_of, is_iterable_of, is_list_of

 logger = logging.getLogger(__name__)


+def _normalize_column_names(
+    col_names: "list[str] | TypedColumns | None",
+) -> list[str]:
+    """Convert TypedColumns or list to a flat list of column names."""
+    if col_names is None:
+        return []
+    if isinstance(col_names, list):
+        return col_names
+    return col_names.get_all_column_names()
+
+
 @unique
 class ModelTypes(Enum):
     """Enum representing supported model types in Arize."""
@@ -204,7 +204,7 @@ class Embedding(NamedTuple):
             )
         # Fail if not all elements in list are floats
         allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
-        if not all(isinstance(val, allowed_types) for val in self.vector):
+        if not all(isinstance(val, allowed_types) for val in self.vector):
             raise TypeError(
                 f"Embedding vector must be a vector of integers and/or floats. Got "
                 f"{emb_name}.vector = {self.vector}"
@@ -283,7 +283,7 @@ class Embedding(NamedTuple):

     @staticmethod
     def _is_valid_iterable(
-        data:
+        data: object,
     ) -> bool:
         """Validates that the input data field is of the correct iterable type.

@@ -299,30 +299,6 @@ class Embedding(NamedTuple):
         return any(isinstance(data, t) for t in (list, np.ndarray))


-# @dataclass
-# class _PromptOrResponseText:
-#     data: str
-#
-#     def validate(self, name: str) -> None:
-#         # Validate that data is a string
-#         if not isinstance(self.data, str):
-#             raise TypeError(f"'{name}' must be a str")
-#
-#         character_count = len(self.data)
-#         if character_count > MAX_RAW_DATA_CHARACTERS:
-#             raise ValueError(
-#                 f"'{name}' field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
-#                 f"Found {character_count}."
-#             )
-#         elif character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
-#             logger.warning(
-#                 get_truncation_warning_message(
-#                     f"'{name}'", MAX_RAW_DATA_CHARACTERS_TRUNCATION
-#                 )
-#             )
-#         return None
-
-
 class LLMRunMetadata(NamedTuple):
     """Metadata for LLM execution including token counts and latency."""

@@ -1021,22 +997,6 @@ class LLMRunMetadataColumnNames:
     )


-# @dataclass
-# class DocumentColumnNames:
-#     id_column_name: Optional[str] = None
-#     version_column_name: Optional[str] = None
-#     text_embedding_column_names: Optional[EmbeddingColumnNames] = None
-#
-#     def __iter__(self):
-#         return iter(
-#             (
-#                 self.id_column_name,
-#                 self.version_column_name,
-#                 self.text_embedding_column_names,
-#             )
-#         )
-#
-#
 @dataclass
 class SimilarityReference:
     """Reference to a prediction for similarity search operations."""
@@ -1250,7 +1210,7 @@ class Schema(BaseSchema):
     actual_score_column_name: str | None = None
     shap_values_column_names: dict[str, str] | None = None
     embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
-        None
+        None
     )
     prediction_group_id_column_name: str | None = None
     rank_column_name: str | None = None
@@ -1268,7 +1228,7 @@ class Schema(BaseSchema):
     prompt_template_column_names: PromptTemplateColumnNames | None = None
     llm_config_column_names: LLMConfigColumnNames | None = None
     llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
-    retrieved_document_ids_column_name:
+    retrieved_document_ids_column_name: str | None = None
     multi_class_threshold_scores_column_name: str | None = None
     semantic_segmentation_prediction_column_names: (
         SemanticSegmentationColumnNames | None
@@ -1285,7 +1245,7 @@ class Schema(BaseSchema):

     def get_used_columns_counts(self) -> dict[str, int]:
         """Return a dict mapping column names to their usage count."""
-        columns_used_counts = {}
+        columns_used_counts: dict[str, int] = {}

         for field in self.__dataclass_fields__:
             if field.endswith("column_name"):
@@ -1294,7 +1254,7 @@ class Schema(BaseSchema):
                 add_to_column_count_dictionary(columns_used_counts, col)

         if self.feature_column_names is not None:
-            for col in self.feature_column_names:
+            for col in _normalize_column_names(self.feature_column_names):
                 add_to_column_count_dictionary(columns_used_counts, col)

         if self.embedding_feature_column_names is not None:
@@ -1313,7 +1273,7 @@ class Schema(BaseSchema):
             )

         if self.tag_column_names is not None:
-            for col in self.tag_column_names:
+            for col in _normalize_column_names(self.tag_column_names):
                 add_to_column_count_dictionary(columns_used_counts, col)

         if self.shap_values_column_names is not None:
@@ -1458,7 +1418,7 @@ class CorpusSchema(BaseSchema):

     def get_used_columns_counts(self) -> dict[str, int]:
         """Return a dict mapping column names to their usage count."""
-        columns_used_counts = {}
+        columns_used_counts: dict[str, int] = {}

         if self.document_id_column_name is not None:
             add_to_column_count_dictionary(
@@ -1531,7 +1491,7 @@ def add_to_column_count_dictionary(

     Args:
         column_dictionary: Dictionary mapping column names to counts.
-        col: The column name to increment, or None to skip.
+        col: The column name to increment, or :obj:`None` to skip.
     """
     if col:
         if col in column_dictionary:
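The new `_normalize_column_names` helper lets `get_used_columns_counts` iterate plain lists and `TypedColumns` objects uniformly. A hedged behaviour sketch, stubbing `TypedColumns` because only the `get_all_column_names()` contract the helper relies on is known from this diff:

```python
# Hedged behaviour check for the private helper introduced above.
from arize.ml.types import _normalize_column_names


class FakeTypedColumns:
    """Stand-in for TypedColumns; only the method the helper calls is stubbed."""

    def get_all_column_names(self) -> list[str]:
        return ["age", "income"]


assert _normalize_column_names(None) == []
assert _normalize_column_names(["age", "income"]) == ["age", "income"]
assert _normalize_column_names(FakeTypedColumns()) == ["age", "income"]
```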
arize/pre_releases.py
CHANGED
@@ -3,14 +3,15 @@
 import functools
 import logging
 from collections.abc import Callable
-from enum import
+from enum import Enum
+from typing import TypeVar, cast

 from arize.version import __version__

 logger = logging.getLogger(__name__)


-class ReleaseStage(StrEnum):
+class ReleaseStage(Enum):
     """Enum representing the release stage of API features."""

     ALPHA = "alpha"
@@ -19,19 +20,21 @@ class ReleaseStage(StrEnum):

 _WARNED: set[str] = set()

+_F = TypeVar("_F", bound=Callable)
+

 def _format_prerelease_message(*, key: str, stage: ReleaseStage) -> str:
     article = "an" if stage is ReleaseStage.ALPHA else "a"
     return (
-        f"[{stage.upper()}] {key} is {article} {stage} API "
+        f"[{stage.value.upper()}] {key} is {article} {stage.value} API "
         f"in Arize SDK v{__version__} and may change without notice."
     )


-def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> object:
+def prerelease_endpoint(*, key: str, stage: ReleaseStage) -> Callable[[_F], _F]:
     """Decorate a method to emit a prerelease warning via logging once per process."""

-    def deco(fn:
+    def deco(fn: _F) -> _F:
         @functools.wraps(fn)
         def wrapper(*args: object, **kwargs: object) -> object:
             if key not in _WARNED:
@@ -39,6 +42,7 @@ def prerelease_endpoint(*, stage: ReleaseStage, key: str) -> object:
                 logger.warning(_format_prerelease_message(key=key, stage=stage))
             return fn(*args, **kwargs)

-
+        # Cast: functools.wraps preserves function signature at runtime but mypy can't verify this
+        return cast("_F", wrapper)

     return deco
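With `ReleaseStage` now a plain `Enum` rather than a `str` subclass, messages are built from `.value`, and the decorator preserves the wrapped signature through a `TypeVar` plus `cast`. A small sketch of applying it, mirroring how the `ProjectsClient` methods below use it (the key string here is arbitrary):

```python
# Sketch of the prerelease decorator in use; key name chosen for illustration only.
import logging

from arize.pre_releases import ReleaseStage, prerelease_endpoint

logging.basicConfig(level=logging.WARNING)


@prerelease_endpoint(key="example.hello", stage=ReleaseStage.BETA)
def hello(name: str) -> str:
    return f"hello {name}"


hello("arize")  # logs "[BETA] example.hello is a beta API in Arize SDK v..." once
hello("again")  # silent afterwards: warnings are emitted once per key per process
```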
arize/projects/client.py
CHANGED
@@ -9,6 +9,7 @@ from arize.pre_releases import ReleaseStage, prerelease_endpoint

 if TYPE_CHECKING:
     from arize._generated.api_client import models
+    from arize._generated.api_client.api_client import ApiClient
     from arize.config import SDKConfiguration

 logger = logging.getLogger(__name__)
@@ -26,18 +27,21 @@ class ProjectsClient:
     :class:`arize.config.SDKConfiguration`.
     """

-    def __init__(
+    def __init__(
+        self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
+    ) -> None:
         """
         Args:
             sdk_config: Resolved SDK configuration.
+            generated_client: Shared generated API client instance.
         """  # noqa: D205, D212
         self._sdk_config = sdk_config

         # Import at runtime so it's still lazy and extras-gated by the parent
         from arize._generated import api_client as gen

-        # Use the
-        self._api = gen.ProjectsApi(
+        # Use the provided client directly
+        self._api = gen.ProjectsApi(generated_client)

     @prerelease_endpoint(key="projects.list", stage=ReleaseStage.BETA)
     def list(
@@ -125,7 +129,8 @@ class ProjectsClient:
         Args:
             project_id: Project ID.

-        Returns:
+        Returns:
+            This method returns None on success (common empty 204 response).

         Raises:
             arize._generated.api_client.exceptions.ApiException: If the API request fails
arize/py.typed
ADDED
File added with no content (py.typed is an empty marker file).
arize/regions.py
CHANGED
@@ -1,19 +1,19 @@
 """Region definitions and configuration for Arize deployment zones."""

 from dataclasses import dataclass
-from enum import
+from enum import Enum

 from arize.constants.config import DEFAULT_FLIGHT_PORT


-class Region(
+class Region(Enum):
     """Enum representing available Arize deployment regions."""

-
-
-
-
-
+    CA_CENTRAL_1A = "ca-central-1a"
+    EU_WEST_1A = "eu-west-1a"
+    US_CENTRAL_1A = "us-central-1a"
+    US_EAST_1B = "us-east-1b"
+    UNSET = ""


 @dataclass(frozen=True)
@@ -28,13 +28,13 @@ class RegionEndpoints:

 def _get_region_endpoints(region: Region) -> RegionEndpoints:
     return RegionEndpoints(
-        api_host=f"api.{region}.arize.com",
-        otlp_host=f"otlp.{region}.arize.com",
-        flight_host=f"flight.{region}.arize.com",
+        api_host=f"api.{region.value}.arize.com",
+        otlp_host=f"otlp.{region.value}.arize.com",
+        flight_host=f"flight.{region.value}.arize.com",
         flight_port=DEFAULT_FLIGHT_PORT,
     )


 REGION_ENDPOINTS: dict[Region, RegionEndpoints] = {
-    r: _get_region_endpoints(r) for r in Region if r != Region.
+    r: _get_region_endpoints(r) for r in list(Region) if r != Region.UNSET
 }