arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +208 -77
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +269 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +390 -286
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a21.dist-info/RECORD +0 -146
- arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
arize/models/client.py
CHANGED
@@ -1,11 +1,14 @@
+"""Client implementation for managing ML models in the Arize platform."""
+
 # type: ignore[pb2]
 from __future__ import annotations
 
 import copy
 import logging
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
+from arize._generated.protocol.rec import public_pb2 as pb2
 from arize._lazy import require
 from arize.constants.ml import (
     LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
@@ -43,6 +46,7 @@ from arize.types import (
     BaseSchema,
     CorpusSchema,
     Embedding,
+    EmbeddingColumnNames,
     Environments,
     LLMRunMetadata,
     Metrics,
@@ -64,12 +68,7 @@ if TYPE_CHECKING:
     import requests
     from requests_futures.sessions import FuturesSession
 
-    from arize._generated.protocol.rec import public_pb2 as pb2
     from arize.config import SDKConfiguration
-    from arize.types import (
-        EmbeddingColumnNames,
-        Schema,
-    )
 
 
 logger = logging.getLogger(__name__)
@@ -96,7 +95,14 @@ _MIMIC_EXTRA = "mimic-explainer"
 
 
 class MLModelsClient:
-
+    """Client for logging ML model predictions and actuals to Arize."""
+
+    def __init__(self, *, sdk_config: SDKConfiguration) -> None:
+        """Initialize the ML models client with SDK configuration.
+
+        Args:
+            sdk_config: SDK configuration containing API endpoints and credentials.
+        """
         self._sdk_config = sdk_config
 
         # internal cache for the futures session
@@ -114,21 +120,86 @@ class MLModelsClient:
         prediction_timestamp: int | None = None,
         prediction_label: PredictionLabelTypes | None = None,
         actual_label: ActualLabelTypes | None = None,
-        features:
+        features: dict[str, str | bool | float | int | list[str] | TypedValue]
         | None = None,
-        embedding_features:
-        shap_values:
-        tags:
+        embedding_features: dict[str, Embedding] | None = None,
+        shap_values: dict[str, float] | None = None,
+        tags: dict[str, str | bool | float | int | TypedValue] | None = None,
         batch_id: str | None = None,
         prompt: str | Embedding | None = None,
         response: str | Embedding | None = None,
         prompt_template: str | None = None,
         prompt_template_version: str | None = None,
         llm_model_name: str | None = None,
-        llm_params:
+        llm_params: dict[str, str | bool | float | int] | None = None,
         llm_run_metadata: LLMRunMetadata | None = None,
         timeout: float | None = None,
     ) -> cf.Future:
+        """Log a single model prediction or actual to Arize asynchronously.
+
+        This method sends a single prediction, actual, or both to Arize for ML monitoring.
+        The request is made asynchronously and returns a Future that can be used to check
+        the status or retrieve the response.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: A unique name to identify your model in the Arize platform.
+            model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
+                RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
+                spans module instead.
+            environment: The environment this data belongs to (PRODUCTION, TRAINING, or
+                VALIDATION).
+            model_version: Optional version identifier for the model.
+            prediction_id: Unique identifier for this prediction. If not provided, one
+                will be auto-generated for PRODUCTION environment.
+            prediction_timestamp: Unix timestamp (seconds) for when the prediction was made.
+                If not provided, the current time is used. Must be within 1 year in the
+                future and 2 years in the past from the current time.
+            prediction_label: The prediction output from your model. Type depends on
+                model_type (e.g., string for categorical, float for numeric).
+            actual_label: The ground truth label. Type depends on model_type.
+            features: Dictionary of feature name to feature value. Values can be str, bool,
+                float, int, list[str], or TypedValue.
+            embedding_features: Dictionary of embedding feature name to Embedding object.
+                Maximum 50 embeddings per record. Object detection models support only 1.
+            shap_values: Dictionary of feature name to SHAP value (float) for feature
+                importance/explainability.
+            tags: Dictionary of metadata tags. Tag names cannot end with "_shap" or be
+                reserved names. Values must be under 1000 characters (warning at 100).
+            batch_id: Required for VALIDATION environment; identifies the validation batch.
+            prompt: For generative models, the prompt text or embedding sent to the model.
+            response: For generative models, the response text or embedding from the model.
+            prompt_template: Template used to generate the prompt.
+            prompt_template_version: Version identifier for the prompt template.
+            llm_model_name: Name of the LLM model used (e.g., "gpt-4").
+            llm_params: Dictionary of LLM configuration parameters (e.g., temperature,
+                max_tokens).
+            llm_run_metadata: Metadata about the LLM run including token counts and latency.
+            timeout: Maximum time (in seconds) to wait for the request to complete.
+
+        Returns:
+            A concurrent.futures.Future object representing the async request. Call
+            .result() to block and retrieve the Response object, or check .done() for
+            completion status.
+
+        Raises:
+            ValueError: If model_type is GENERATIVE_LLM, or if validation environment is
+                missing batch_id, or if training/validation environment is missing
+                prediction or actual, or if timestamp is out of range, or if no data
+                is provided (must have prediction_label, actual_label, tags, or shap_values),
+                or if tag names end with "_shap" or exceed length limits.
+            MissingSpaceIDError: If space_id is not provided or empty.
+            MissingModelNameError: If model_name is not provided or empty.
+            InvalidValueType: If features, tags, or other parameters have incorrect types.
+            InvalidNumberOfEmbeddings: If more than 50 embedding features are provided.
+            KeyError: If tag names include reserved names.
+
+        Notes:
+            - Timestamps must be within 1 year future and 2 years past from current time
+            - Tag values are truncated at 1000 characters, with warnings at 100 characters
+            - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
+            - The Future returned can be monitored for request status asynchronously
+        """
         require(_STREAM_EXTRA, _STREAM_DEPS)
         from arize._generated.protocol.rec import public_pb2 as pb2
         from arize.models.proto import (
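
For orientation, here is a minimal usage sketch of the `log()` API documented above. How `client` (an `MLModelsClient`) is obtained is an assumption, since client construction sits outside this diff, and the argument values are illustrative:

```python
# Hedged sketch of MLModelsClient.log() based on the docstring above.
# `client` is assumed to be an MLModelsClient instance from the SDK.
from arize.types import Environments, ModelTypes

future = client.log(
    space_id="YOUR_SPACE_ID",
    model_name="fraud-detector",
    model_type=ModelTypes.BINARY,          # GENERATIVE_LLM is rejected here
    environment=Environments.PRODUCTION,
    prediction_id="pred-123",
    prediction_label="fraud",
    features={"amount": 120.5, "merchant": "acme"},
    tags={"region": "emea"},
)
response = future.result()  # block on the returned concurrent.futures.Future
```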
@@ -179,16 +250,15 @@ class MLModelsClient:
                 _validate_mapping_key(feat_name, "features")
                 if is_list_of(feat_value, str):
                     continue
-
-
-
-
-
-
-
-
-
-                )
+                val = convert_element(feat_value)
+                if val is not None and not isinstance(
+                    val, (str, bool, float, int)
+                ):
+                    raise InvalidValueType(
+                        f"feature '{feat_name}'",
+                        feat_value,
+                        "one of: bool, int, float, str",
+                    )
 
             # Validate embedding_features type
             if embedding_features:
@@ -247,7 +317,7 @@ class MLModelsClient:
                         f"{MAX_TAG_LENGTH}. The tag {tag_name} with value {tag_value} has "
                         f"{len(str(val))} characters."
                     )
-
+                if len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
                     logger.warning(
                         get_truncation_warning_message(
                             "tags", MAX_TAG_LENGTH_TRUNCATION
@@ -304,7 +374,7 @@ class MLModelsClient:
         if embedding_features or prompt or response:
            # NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
             combined_embedding_features = (
-
+                dict(embedding_features.items())
                 if embedding_features
                 else {}
             )
@@ -453,7 +523,6 @@ class MLModelsClient:
             indexes=None,
         )
 
-    # TODO(Kiko): Handle sync argument
     def log_batch(
         self,
         *,
@@ -466,12 +535,64 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         validate: bool = True,
-        metrics_validation:
+        metrics_validation: list[Metrics] | None = None,
         surrogate_explainability: bool = False,
         timeout: float | None = None,
         tmp_dir: str = "",
-        sync: bool = False,
     ) -> requests.Response:
+        """Log a batch of model predictions and actuals to Arize from a pandas DataFrame.
+
+        This method uploads multiple records to Arize in a single batch operation using
+        Apache Arrow format for efficient transfer. The dataframe structure is defined
+        by the provided schema which maps dataframe columns to Arize data fields.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: A unique name to identify your model in the Arize platform.
+            model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
+                RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
+                spans module instead.
+            dataframe: Pandas DataFrame containing the data to upload. Columns should
+                correspond to the schema field mappings.
+            schema: Schema object (Schema or CorpusSchema) that defines the mapping between
+                dataframe columns and Arize data fields (e.g., prediction_label_column_name,
+                feature_column_names, etc.).
+            environment: The environment this data belongs to (PRODUCTION, TRAINING,
+                VALIDATION, or CORPUS).
+            model_version: Optional version identifier for the model.
+            batch_id: Required for VALIDATION environment; identifies the validation batch.
+            validate: When True, performs comprehensive validation before sending data.
+                Includes checks for required fields, data types, and value constraints.
+            metrics_validation: Optional list of metric families to validate against.
+            surrogate_explainability: When True, automatically generates SHAP values using
+                MIMIC surrogate explainer. Requires the 'mimic-explainer' extra. Has no
+                effect if shap_values_column_names is already specified in schema.
+            timeout: Maximum time (in seconds) to wait for the request to complete.
+            tmp_dir: Optional temporary directory to store serialized Arrow data before
+                upload.
+
+        Returns:
+            A requests.Response object from the upload request. Check .status_code for
+            success (200) or error conditions.
+
+        Raises:
+            MissingSpaceIDError: If space_id is not provided or empty.
+            MissingModelNameError: If model_name is not provided or empty.
+            ValueError: If model_type is GENERATIVE_LLM, or if environment is CORPUS with
+                non-CorpusSchema, or if training/validation records are incomplete.
+            ValidationFailure: If validate=True and validation checks fail. Contains list
+                of validation error messages.
+            pa.ArrowInvalid: If the dataframe cannot be converted to Arrow format, typically
+                due to mixed types in columns not specified in the schema.
+
+        Notes:
+            - Categorical dtype columns are automatically converted to string
+            - Extraneous columns not in the schema are removed before upload
+            - Surrogate explainability requires 'mimic-explainer' extra
+            - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
+            - If logging actuals without predictions, ensure predictions were logged first
+            - Data is sent via Apache Arrow for efficient large batch transfers
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         import pandas.api.types as ptypes
         import pyarrow as pa
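
A minimal batch-logging sketch matching the `log_batch()` docstring above. The `Schema` keyword names are assumptions inferred from the docstring (`prediction_label_column_name` and `feature_column_names` are named there; `prediction_id_column_name` is assumed), and `client` is again an assumed `MLModelsClient`:

```python
import pandas as pd

from arize.types import Environments, ModelTypes, Schema

df = pd.DataFrame(
    {
        "prediction_id": ["a", "b"],
        "prediction_label": ["fraud", "not_fraud"],
        "amount": [120.5, 7.0],
    }
)

schema = Schema(
    prediction_id_column_name="prediction_id",        # assumed field name
    prediction_label_column_name="prediction_label",  # named in the docstring
    feature_column_names=["amount"],                  # named in the docstring
)

response = client.log_batch(
    space_id="YOUR_SPACE_ID",
    model_name="fraud-detector",
    model_type=ModelTypes.BINARY,
    dataframe=df,
    schema=schema,
    environment=Environments.PRODUCTION,
)
assert response.status_code == 200  # per the docstring, 200 indicates success
```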
@@ -506,8 +627,8 @@ class MLModelsClient:
         # Thus we can only offer this functionality with pandas>=1.0.0.
         try:
             dataframe, schema = cast_typed_columns(dataframe, schema)
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("Error casting typed columns")
             raise
 
         logger.debug("Performing required validation.")
@@ -546,7 +667,7 @@ class MLModelsClient:
 
         # always validate pd.Category is not present, if yes, convert to string
         has_cat_col = any(
-
+            ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
         )
         if has_cat_col:
             cat_cols = [
@@ -554,7 +675,13 @@ class MLModelsClient:
                 for col_name, col_cat in dataframe.dtypes.items()
                 if col_cat.name == "category"
             ]
-            cat_str_map = dict(
+            cat_str_map = dict(
+                zip(
+                    cat_cols,
+                    ["str"] * len(cat_cols),
+                    strict=True,
+                )
+            )
             dataframe = dataframe.astype(cat_str_map)
 
         if surrogate_explainability:
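
The category handling above is plain pandas; a self-contained equivalent of the conversion the hunk builds with `dict(zip(cat_cols, ["str"] * len(cat_cols), strict=True))`:

```python
import pandas as pd
import pandas.api.types as ptypes

df = pd.DataFrame({"color": pd.Categorical(["red", "blue"]), "n": [1, 2]})

# Collect categorical columns, then cast each one to string, as the hunk does.
cat_cols = [name for name, dtype in df.dtypes.items() if dtype.name == "category"]
df = df.astype(dict(zip(cat_cols, ["str"] * len(cat_cols))))

assert not any(ptypes.is_categorical_dtype(t) for t in df.dtypes)
```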
@@ -588,12 +715,12 @@ class MLModelsClient:
             # error conditions that we're currently not aware of.
             pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
         except pa.ArrowInvalid as e:
-            logger.
+            logger.exception(INVALID_ARROW_CONVERSION_MSG)
             raise pa.ArrowInvalid(
-                f"Error converting to Arrow format: {
+                f"Error converting to Arrow format: {e!s}"
             ) from e
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("Unexpected error creating Arrow table")
             raise
 
         if validate:
@@ -678,18 +805,53 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         where: str = "",
-        columns:
+        columns: list | None = None,
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
     ) -> pd.DataFrame:
+        """Export model data from Arize to a pandas DataFrame.
+
+        Retrieves prediction and optional actual data for a model within a specified time
+        range and returns it as a pandas DataFrame for analysis.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: The name of the model to export data from.
+            environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
+            start_time: Start of the time range (inclusive) as a datetime object.
+            end_time: End of the time range (inclusive) as a datetime object.
+            include_actuals: When True, includes actual labels in the export. When False,
+                only predictions are returned.
+            model_version: Optional model version to filter by. Empty string returns all
+                versions.
+            batch_id: Optional batch ID to filter by (for VALIDATION environment).
+            where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
+            columns: Optional list of column names to include. If None, all columns are
+                returned.
+            similarity_search_params: Optional parameters for embedding similarity search
+                filtering.
+            stream_chunk_size: Optional chunk size for streaming large result sets.
+
+        Returns:
+            A pandas DataFrame containing the exported data with columns for predictions,
+            actuals (if requested), features, tags, timestamps, and other model metadata.
+
+        Raises:
+            RuntimeError: If the Flight client request fails or returns no response.
+
+        Notes:
+            - Uses Apache Arrow Flight for efficient data transfer
+            - Large exports may benefit from specifying stream_chunk_size
+            - The where clause supports SQL-like filtering syntax
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         from arize._exporter.client import ArizeExportClient
         from arize._flight.client import ArizeFlightClient
 
         with ArizeFlightClient(
             api_key=self._sdk_config.api_key,
-            host=self._sdk_config.
-            port=self._sdk_config.
+            host=self._sdk_config.flight_host,
+            port=self._sdk_config.flight_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
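
A hedged sketch of calling the DataFrame export documented above. The method's own name sits outside this hunk, so `export_model_to_df` below is a placeholder, and `client` is an assumed `MLModelsClient`:

```python
from datetime import datetime, timedelta, timezone

from arize.types import Environments

end = datetime.now(timezone.utc)
start = end - timedelta(days=7)

df = client.export_model_to_df(      # placeholder name; see the docstring above
    space_id="YOUR_SPACE_ID",
    model_name="fraud-detector",
    environment=Environments.PRODUCTION,
    start_time=start,
    end_time=end,
    include_actuals=True,
    where="amount > 100",            # SQL-like row filter, per the docstring
)
print(df.shape)
```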
@@ -724,18 +886,53 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         where: str = "",
-        columns:
+        columns: list | None = None,
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
     ) -> pd.DataFrame:
+        """Export model data from Arize to a Parquet file and return as DataFrame.
+
+        Retrieves prediction and optional actual data for a model within a specified time
+        range, saves it as a Parquet file, and returns it as a pandas DataFrame.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: The name of the model to export data from.
+            environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
+            start_time: Start of the time range (inclusive) as a datetime object.
+            end_time: End of the time range (inclusive) as a datetime object.
+            include_actuals: When True, includes actual labels in the export. When False,
+                only predictions are returned.
+            model_version: Optional model version to filter by. Empty string returns all
+                versions.
+            batch_id: Optional batch ID to filter by (for VALIDATION environment).
+            where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
+            columns: Optional list of column names to include. If None, all columns are
+                returned.
+            similarity_search_params: Optional parameters for embedding similarity search
+                filtering.
+            stream_chunk_size: Optional chunk size for streaming large result sets.
+
+        Returns:
+            A pandas DataFrame containing the exported data. The data is also saved to a
+            Parquet file by the underlying export client.
+
+        Raises:
+            RuntimeError: If the Flight client request fails or returns no response.
+
+        Notes:
+            - Uses Apache Arrow Flight for efficient data transfer
+            - The Parquet file location is managed by the ArizeExportClient
+            - Large exports may benefit from specifying stream_chunk_size
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         from arize._exporter.client import ArizeExportClient
         from arize._flight.client import ArizeFlightClient
 
         with ArizeFlightClient(
             api_key=self._sdk_config.api_key,
-            host=self._sdk_config.
-            port=self._sdk_config.
+            host=self._sdk_config.flight_host,
+            port=self._sdk_config.flight_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -759,6 +956,7 @@ class MLModelsClient:
         )
 
     def _ensure_session(self) -> FuturesSession:
+        """Lazily initialize and return the FuturesSession for async streaming requests."""
         from requests_futures.sessions import FuturesSession
 
         session = object.__getattribute__(self, "_session")
@@ -778,10 +976,11 @@ class MLModelsClient:
     def _post(
         self,
         record: pb2.Record,
-        headers:
+        headers: dict[str, str],
         timeout: float | None,
-        indexes:
-    ):
+        indexes: tuple,
+    ) -> object:
+        """Post a record to Arize via async HTTP request with protobuf JSON serialization."""
         from google.protobuf.json_format import MessageToDict
 
         session = self._ensure_session()
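
The `_post` docstring mentions protobuf JSON serialization via `MessageToDict`; a standalone illustration using a generic well-known-type message, since the SDK's actual `pb2.Record` fields are not shown in this diff:

```python
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct

# Struct stands in for pb2.Record here; it is a placeholder message type.
msg = Struct()
msg.update({"prediction_id": "pred-123", "label": "fraud"})

payload = MessageToDict(msg)  # plain dict, ready to send as a JSON body
print(payload)                # {'prediction_id': 'pred-123', 'label': 'fraud'}
```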
@@ -801,9 +1000,10 @@ class MLModelsClient:
         return resp
 
 
-def _validate_mapping_key(key_name: str, name: str):
+def _validate_mapping_key(key_name: str, name: str) -> None:
+    """Validate that a mapping key (feature/tag name) is a string and doesn't end with '_shap'."""
     if not isinstance(key_name, str):
-        raise
+        raise TypeError(
             f"{name} dictionary key {key_name} must be named with string, type used: {type(key_name)}"
         )
     if key_name.endswith("_shap"):
@@ -813,7 +1013,8 @@ def _validate_mapping_key(key_name: str, name: str):
     return
 
 
-def _is_timestamp_in_range(now: int, ts: int):
+def _is_timestamp_in_range(now: int, ts: int) -> bool:
+    """Check if a timestamp is within the acceptable range (1 year future, 2 years past)."""
     max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
     min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
     return min_time <= ts <= max_time
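
The range check above uses the two limits stated in the docstrings (1 year forward, 2 years back); a self-contained version with those values inlined:

```python
import time

MAX_FUTURE_YEARS = 1  # values taken from the docstrings above
MAX_PAST_YEARS = 2

def is_timestamp_in_range(now: int, ts: int) -> bool:
    max_time = now + MAX_FUTURE_YEARS * 365 * 24 * 60 * 60
    min_time = now - MAX_PAST_YEARS * 365 * 24 * 60 * 60
    return min_time <= ts <= max_time

now = int(time.time())
assert is_timestamp_in_range(now, now)
assert not is_timestamp_in_range(now, now + 2 * 365 * 24 * 60 * 60)  # too far ahead
```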
@@ -826,7 +1027,8 @@ def _get_pb_schema(
     model_type: ModelTypes,
     environment: Environments,
     batch_id: str,
-):
+) -> object:
+    """Construct a protocol buffer Schema from the user's Schema for batch logging."""
     s = pb2.Schema()
     s.constants.model_id = model_id
 
@@ -874,48 +1076,52 @@ def _get_pb_schema(
 
     if model_type == ModelTypes.OBJECT_DETECTION:
         if schema.object_detection_prediction_column_names is not None:
-
-
+            obj_det_pred = schema.object_detection_prediction_column_names
+            pred_labels = (
+                s.arrow_schema.prediction_object_detection_label_column_names
             )
-
-
+            pred_labels.bboxes_coordinates_column_name = (
+                obj_det_pred.bounding_boxes_coordinates_column_name
             )
-
-
-
-
-
-
+            pred_labels.bboxes_categories_column_name = (
+                obj_det_pred.categories_column_name
+            )
+            if obj_det_pred.scores_column_name is not None:
+                pred_labels.bboxes_scores_column_name = (
+                    obj_det_pred.scores_column_name
                 )
 
         if schema.semantic_segmentation_prediction_column_names is not None:
-
-
+            seg_pred_cols = schema.semantic_segmentation_prediction_column_names
+            pred_seg_labels = s.arrow_schema.prediction_semantic_segmentation_label_column_names
+            pred_seg_labels.polygons_coordinates_column_name = (
+                seg_pred_cols.polygon_coordinates_column_name
             )
-
-
+            pred_seg_labels.polygons_categories_column_name = (
+                seg_pred_cols.categories_column_name
             )
 
         if schema.instance_segmentation_prediction_column_names is not None:
-
-                schema.instance_segmentation_prediction_column_names
+            inst_seg_pred_cols = (
+                schema.instance_segmentation_prediction_column_names
             )
-            s.arrow_schema.prediction_instance_segmentation_label_column_names
-
+            pred_inst_seg_labels = s.arrow_schema.prediction_instance_segmentation_label_column_names
+            pred_inst_seg_labels.polygons_coordinates_column_name = (
+                inst_seg_pred_cols.polygon_coordinates_column_name
             )
-
-
-
-
-
-
+            pred_inst_seg_labels.polygons_categories_column_name = (
+                inst_seg_pred_cols.categories_column_name
+            )
+            if inst_seg_pred_cols.scores_column_name is not None:
+                pred_inst_seg_labels.polygons_scores_column_name = (
+                    inst_seg_pred_cols.scores_column_name
                 )
             if (
-
+                inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                 is not None
             ):
-
-
+                pred_inst_seg_labels.bboxes_coordinates_column_name = (
+                    inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                 )
 
     if schema.prediction_score_column_name is not None:
@@ -1038,50 +1244,61 @@ def _get_pb_schema(
 
     if model_type == ModelTypes.OBJECT_DETECTION:
         if schema.object_detection_actual_column_names is not None:
-
-
+            obj_det_actual = schema.object_detection_actual_column_names
+            actual_labels = (
+                s.arrow_schema.actual_object_detection_label_column_names
            )
-
-
+            actual_labels.bboxes_coordinates_column_name = (
+                obj_det_actual.bounding_boxes_coordinates_column_name
             )
-
-
-
-
-
-
+            actual_labels.bboxes_categories_column_name = (
+                obj_det_actual.categories_column_name
+            )
+            if obj_det_actual.scores_column_name is not None:
+                actual_labels.bboxes_scores_column_name = (
+                    obj_det_actual.scores_column_name
                 )
 
         if schema.semantic_segmentation_actual_column_names is not None:
-
-
+            sem_seg_actual = schema.semantic_segmentation_actual_column_names
+            sem_seg_labels = (
+                s.arrow_schema.actual_semantic_segmentation_label_column_names
+            )
+            sem_seg_labels.polygons_coordinates_column_name = (
+                sem_seg_actual.polygon_coordinates_column_name
             )
-
-
+            sem_seg_labels.polygons_categories_column_name = (
+                sem_seg_actual.categories_column_name
             )
 
         if schema.instance_segmentation_actual_column_names is not None:
-
-
+            inst_seg_actual = schema.instance_segmentation_actual_column_names
+            inst_seg_labels = (
+                s.arrow_schema.actual_instance_segmentation_label_column_names
            )
-
-
+            inst_seg_labels.polygons_coordinates_column_name = (
+                inst_seg_actual.polygon_coordinates_column_name
+            )
+            inst_seg_labels.polygons_categories_column_name = (
+                inst_seg_actual.categories_column_name
            )
            if (
-
+                inst_seg_actual.bounding_boxes_coordinates_column_name
                is not None
            ):
-
-
+                inst_seg_labels.bboxes_coordinates_column_name = (
+                    inst_seg_actual.bounding_boxes_coordinates_column_name
                )
 
    if model_type == ModelTypes.GENERATIVE_LLM:
        if schema.prompt_template_column_names is not None:
-
-
+            prompt_template_names = schema.prompt_template_column_names
+            arrow_prompt_names = s.arrow_schema.prompt_template_column_names
+            arrow_prompt_names.template_column_name = (
+                prompt_template_names.template_column_name
            )
-
-
+            arrow_prompt_names.template_version_column_name = (
+                prompt_template_names.template_version_column_name
            )
        if schema.llm_config_column_names is not None:
            s.arrow_schema.llm_config_column_names.model_column_name = (
@@ -1114,6 +1331,7 @@ def _get_pb_schema_corpus(
    schema: CorpusSchema,
    model_id: str,
 ) -> pb2.Schema:
+    """Construct a protocol buffer Schema from CorpusSchema for document corpus logging."""
     s = pb2.Schema()
     s.constants.model_id = model_id
     s.constants.environment = pb2.Schema.Environment.CORPUS
@@ -1127,11 +1345,12 @@ def _get_pb_schema_corpus(
             schema.document_version_column_name
         )
     if schema.document_text_embedding_column_names is not None:
-
-        s.arrow_schema.document_column_names.text_column_name
-
-
-
-
-
+        doc_text_emb_cols = schema.document_text_embedding_column_names
+        doc_text_col = s.arrow_schema.document_column_names.text_column_name
+        doc_text_col.vector_column_name = doc_text_emb_cols.vector_column_name
+        doc_text_col.data_column_name = doc_text_emb_cols.data_column_name
+        if doc_text_emb_cols.link_to_data_column_name is not None:
+            doc_text_col.link_to_data_column_name = (
+                doc_text_emb_cols.link_to_data_column_name
+            )
     return s