arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
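What follows appears to be the diff for arize/{models → ml}/client.py (entry 112 above): the `arize.models` package is renamed to `arize.ml`, `arize/types.py` moves to `arize/ml/types.py`, `log_batch` is renamed to `log`, and docstrings are added throughout. For callers, a minimal import-migration sketch (module paths taken from the rename entries above; the chosen symbols are illustrative):

```python
# Before (8.0.0a22): the old arize.models / arize.types layout.
# from arize.models.casting import cast_dictionary, cast_typed_columns
# from arize.types import Embedding, Environments, Schema

# After (8.0.0b0): the same symbols under the renamed arize.ml package.
from arize.ml.casting import cast_dictionary, cast_typed_columns
from arize.ml.types import Embedding, Environments, Schema
from arize.utils.types import is_list_of  # moved out of arize.types per this diff
```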
@@ -1,11 +1,14 @@
+"""Client implementation for managing ML models in the Arize platform."""
+
 # type: ignore[pb2]
 from __future__ import annotations

 import copy
 import logging
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING

+from arize._generated.protocol.rec import public_pb2 as pb2
 from arize._lazy import require
 from arize.constants.ml import (
     LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
@@ -30,19 +33,20 @@ from arize.exceptions.parameters import (
 )
 from arize.exceptions.spaces import MissingSpaceIDError
 from arize.logging import get_truncation_warning_message
-from arize.models.bounded_executor import BoundedExecutor
-from arize.models.casting import cast_dictionary, cast_typed_columns
-from arize.models.stream_validation import (
+from arize.ml.bounded_executor import BoundedExecutor
+from arize.ml.casting import cast_dictionary, cast_typed_columns
+from arize.ml.stream_validation import (
     validate_and_convert_prediction_id,
     validate_label,
 )
-from arize.types import (
+from arize.ml.types import (
     CATEGORICAL_MODEL_TYPES,
     NUMERIC_MODEL_TYPES,
     ActualLabelTypes,
     BaseSchema,
     CorpusSchema,
     Embedding,
+    EmbeddingColumnNames,
     Environments,
     LLMRunMetadata,
     Metrics,
@@ -53,8 +57,8 @@ from arize.types import (
     SimilaritySearchParams,
     TypedValue,
     convert_element,
-    is_list_of,
 )
+from arize.utils.types import is_list_of

 if TYPE_CHECKING:
     import concurrent.futures as cf
@@ -64,12 +68,7 @@ if TYPE_CHECKING:
     import requests
     from requests_futures.sessions import FuturesSession

-    from arize._generated.protocol.rec import public_pb2 as pb2
     from arize.config import SDKConfiguration
-    from arize.types import (
-        EmbeddingColumnNames,
-        Schema,
-    )


 logger = logging.getLogger(__name__)
@@ -96,7 +95,18 @@ _MIMIC_EXTRA = "mimic-explainer"
96
95
 
97
96
 
98
97
  class MLModelsClient:
99
- def __init__(self, *, sdk_config: SDKConfiguration):
98
+ """Client for logging ML model predictions and actuals to Arize.
99
+
100
+ This class is primarily intended for internal use within the SDK. Users are
101
+ highly encouraged to access resource-specific functionality via
102
+ :class:`arize.ArizeClient`.
103
+ """
104
+
105
+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
106
+ """
107
+ Args:
108
+ sdk_config: Resolved SDK configuration.
109
+ """ # noqa: D205, D212
100
110
  self._sdk_config = sdk_config
101
111
 
102
112
  # internal cache for the futures session
@@ -114,24 +124,89 @@ class MLModelsClient:
         prediction_timestamp: int | None = None,
         prediction_label: PredictionLabelTypes | None = None,
         actual_label: ActualLabelTypes | None = None,
-        features: Dict[str, str | bool | float | int | List[str] | TypedValue]
+        features: dict[str, str | bool | float | int | list[str] | TypedValue]
         | None = None,
-        embedding_features: Dict[str, Embedding] | None = None,
-        shap_values: Dict[str, float] | None = None,
-        tags: Dict[str, str | bool | float | int | TypedValue] | None = None,
+        embedding_features: dict[str, Embedding] | None = None,
+        shap_values: dict[str, float] | None = None,
+        tags: dict[str, str | bool | float | int | TypedValue] | None = None,
         batch_id: str | None = None,
         prompt: str | Embedding | None = None,
         response: str | Embedding | None = None,
         prompt_template: str | None = None,
         prompt_template_version: str | None = None,
         llm_model_name: str | None = None,
-        llm_params: Dict[str, str | bool | float | int] | None = None,
+        llm_params: dict[str, str | bool | float | int] | None = None,
         llm_run_metadata: LLMRunMetadata | None = None,
         timeout: float | None = None,
     ) -> cf.Future:
+        """Log a single model prediction or actual to Arize asynchronously.
+
+        This method sends a single prediction, actual, or both to Arize for ML monitoring.
+        The request is made asynchronously and returns a Future that can be used to check
+        the status or retrieve the response.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: A unique name to identify your model in the Arize platform.
+            model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
+                RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
+                spans module instead.
+            environment: The environment this data belongs to (PRODUCTION, TRAINING, or
+                VALIDATION).
+            model_version: Optional version identifier for the model.
+            prediction_id: Unique identifier for this prediction. If not provided, one
+                will be auto-generated for PRODUCTION environment.
+            prediction_timestamp: Unix timestamp (seconds) for when the prediction was made.
+                If not provided, the current time is used. Must be within 1 year in the
+                future and 2 years in the past from the current time.
+            prediction_label: The prediction output from your model. Type depends on
+                model_type (e.g., string for categorical, float for numeric).
+            actual_label: The ground truth label. Type depends on model_type.
+            features: Dictionary of feature name to feature value. Values can be str, bool,
+                float, int, list[str], or TypedValue.
+            embedding_features: Dictionary of embedding feature name to Embedding object.
+                Maximum 50 embeddings per record. Object detection models support only 1.
+            shap_values: Dictionary of feature name to SHAP value (float) for feature
+                importance/explainability.
+            tags: Dictionary of metadata tags. Tag names cannot end with "_shap" or be
+                reserved names. Values must be under 1000 characters (warning at 100).
+            batch_id: Required for VALIDATION environment; identifies the validation batch.
+            prompt: For generative models, the prompt text or embedding sent to the model.
+            response: For generative models, the response text or embedding from the model.
+            prompt_template: Template used to generate the prompt.
+            prompt_template_version: Version identifier for the prompt template.
+            llm_model_name: Name of the LLM model used (e.g., "gpt-4").
+            llm_params: Dictionary of LLM configuration parameters (e.g., temperature,
+                max_tokens).
+            llm_run_metadata: Metadata about the LLM run including token counts and latency.
+            timeout: Maximum time (in seconds) to wait for the request to complete.
+
+        Returns:
+            A concurrent.futures.Future object representing the async request. Call
+            .result() to block and retrieve the Response object, or check .done() for
+            completion status.
+
+        Raises:
+            ValueError: If model_type is GENERATIVE_LLM, or if validation environment is
+                missing batch_id, or if training/validation environment is missing
+                prediction or actual, or if timestamp is out of range, or if no data
+                is provided (must have prediction_label, actual_label, tags, or shap_values),
+                or if tag names end with "_shap" or exceed length limits.
+            MissingSpaceIDError: If space_id is not provided or empty.
+            MissingModelNameError: If model_name is not provided or empty.
+            InvalidValueType: If features, tags, or other parameters have incorrect types.
+            InvalidNumberOfEmbeddings: If more than 50 embedding features are provided.
+            KeyError: If tag names include reserved names.
+
+        Notes:
+            - Timestamps must be within 1 year future and 2 years past from current time
+            - Tag values are truncated at 1000 characters, with warnings at 100 characters
+            - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
+            - The Future returned can be monitored for request status asynchronously
+        """
         require(_STREAM_EXTRA, _STREAM_DEPS)
         from arize._generated.protocol.rec import public_pb2 as pb2
-        from arize.models.proto import (
+        from arize.ml.proto import (
             get_pb_dictionary,
             get_pb_label,
             get_pb_timestamp,
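Based on the signature and docstring above, a single-record call might look like the following sketch. Hedged: the method name `log_single` is assumed, since the method's `def` line falls outside this hunk, and `ml_client` stands in for an `MLModelsClient` reached through `arize.ArizeClient` as the class docstring recommends; the labels and features are invented.

```python
from arize.ml.types import Environments, ModelTypes  # module path per this diff

future = ml_client.log_single(  # method name assumed, not shown in this hunk
    space_id="YOUR_SPACE_ID",
    model_name="fraud-model",
    model_type=ModelTypes.BINARY,
    environment=Environments.PRODUCTION,
    prediction_id="pred-123",
    prediction_label="fraud",
    features={"amount": 42.0, "merchant": "acme"},
    tags={"region": "us-east-1"},
)
response = future.result()  # block until the async request completes
```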
@@ -179,16 +254,15 @@ class MLModelsClient:
                 _validate_mapping_key(feat_name, "features")
                 if is_list_of(feat_value, str):
                     continue
-                else:
-                    val = convert_element(feat_value)
-                    if val is not None and not isinstance(
-                        val, (str, bool, float, int)
-                    ):
-                        raise InvalidValueType(
-                            f"feature '{feat_name}'",
-                            feat_value,
-                            "one of: bool, int, float, str",
-                        )
+                val = convert_element(feat_value)
+                if val is not None and not isinstance(
+                    val, (str, bool, float, int)
+                ):
+                    raise InvalidValueType(
+                        f"feature '{feat_name}'",
+                        feat_value,
+                        "one of: bool, int, float, str",
+                    )

         # Validate embedding_features type
         if embedding_features:
@@ -247,7 +321,7 @@ class MLModelsClient:
                     f"{MAX_TAG_LENGTH}. The tag {tag_name} with value {tag_value} has "
                     f"{len(str(val))} characters."
                 )
-            elif len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
+            if len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
                 logger.warning(
                     get_truncation_warning_message(
                         "tags", MAX_TAG_LENGTH_TRUNCATION
@@ -304,9 +378,7 @@ class MLModelsClient:
         if embedding_features or prompt or response:
             # NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
             combined_embedding_features = (
-                {k: v for k, v in embedding_features.items()}
-                if embedding_features
-                else {}
+                embedding_features.copy() if embedding_features else {}
             )
             # Map prompt as embedding features for generative models
             if prompt is not None:
@@ -453,8 +525,7 @@ class MLModelsClient:
             indexes=None,
         )

-    # TODO(Kiko): Handle sync argument
-    def log_batch(
+    def log(
         self,
         *,
         space_id: str,
@@ -466,17 +537,69 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         validate: bool = True,
-        metrics_validation: List[Metrics] | None = None,
+        metrics_validation: list[Metrics] | None = None,
         surrogate_explainability: bool = False,
         timeout: float | None = None,
         tmp_dir: str = "",
-        sync: bool = False,
     ) -> requests.Response:
+        """Log a batch of model predictions and actuals to Arize from a pandas DataFrame.
+
+        This method uploads multiple records to Arize in a single batch operation using
+        Apache Arrow format for efficient transfer. The dataframe structure is defined
+        by the provided schema which maps dataframe columns to Arize data fields.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: A unique name to identify your model in the Arize platform.
+            model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
+                RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
+                spans module instead.
+            dataframe: Pandas DataFrame containing the data to upload. Columns should
+                correspond to the schema field mappings.
+            schema: Schema object (Schema or CorpusSchema) that defines the mapping between
+                dataframe columns and Arize data fields (e.g., prediction_label_column_name,
+                feature_column_names, etc.).
+            environment: The environment this data belongs to (PRODUCTION, TRAINING,
+                VALIDATION, or CORPUS).
+            model_version: Optional version identifier for the model.
+            batch_id: Required for VALIDATION environment; identifies the validation batch.
+            validate: When True, performs comprehensive validation before sending data.
+                Includes checks for required fields, data types, and value constraints.
+            metrics_validation: Optional list of metric families to validate against.
+            surrogate_explainability: When True, automatically generates SHAP values using
+                MIMIC surrogate explainer. Requires the 'mimic-explainer' extra. Has no
+                effect if shap_values_column_names is already specified in schema.
+            timeout: Maximum time (in seconds) to wait for the request to complete.
+            tmp_dir: Optional temporary directory to store serialized Arrow data before
+                upload.
+
+        Returns:
+            A requests.Response object from the upload request. Check .status_code for
+            success (200) or error conditions.
+
+        Raises:
+            MissingSpaceIDError: If space_id is not provided or empty.
+            MissingModelNameError: If model_name is not provided or empty.
+            ValueError: If model_type is GENERATIVE_LLM, or if environment is CORPUS with
+                non-CorpusSchema, or if training/validation records are incomplete.
+            ValidationFailure: If validate=True and validation checks fail. Contains list
+                of validation error messages.
+            pa.ArrowInvalid: If the dataframe cannot be converted to Arrow format, typically
+                due to mixed types in columns not specified in the schema.
+
+        Notes:
+            - Categorical dtype columns are automatically converted to string
+            - Extraneous columns not in the schema are removed before upload
+            - Surrogate explainability requires 'mimic-explainer' extra
+            - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
+            - If logging actuals without predictions, ensure predictions were logged first
+            - Data is sent via Apache Arrow for efficient large batch transfers
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         import pandas.api.types as ptypes
         import pyarrow as pa

-        from arize.models.batch_validation.validator import Validator
+        from arize.ml.batch_validation.validator import Validator
         from arize.utils.arrow import post_arrow_table
         from arize.utils.dataframe import remove_extraneous_columns

@@ -506,8 +629,8 @@ class MLModelsClient:
         # Thus we can only offer this functionality with pandas>=1.0.0.
         try:
             dataframe, schema = cast_typed_columns(dataframe, schema)
-        except Exception as e:
-            logger.error(e)
+        except Exception:
+            logger.exception("Error casting typed columns")
             raise

         logger.debug("Performing required validation.")
@@ -546,7 +669,7 @@ class MLModelsClient:

         # always validate pd.Category is not present, if yes, convert to string
         has_cat_col = any(
-            [ptypes.is_categorical_dtype(x) for x in dataframe.dtypes]
+            ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
         )
         if has_cat_col:
             cat_cols = [
@@ -554,12 +677,18 @@ class MLModelsClient:
                 for col_name, col_cat in dataframe.dtypes.items()
                 if col_cat.name == "category"
             ]
-            cat_str_map = dict(zip(cat_cols, ["str"] * len(cat_cols)))
+            cat_str_map = dict(
+                zip(
+                    cat_cols,
+                    ["str"] * len(cat_cols),
+                    strict=True,
+                )
+            )
             dataframe = dataframe.astype(cat_str_map)

         if surrogate_explainability:
             require(_MIMIC_EXTRA, _MIMIC_DEPS)
-            from arize.models.surrogate_explainer.mimic import Mimic
+            from arize.ml.surrogate_explainer.mimic import Mimic

             logger.debug("Running surrogate_explainability.")
             if schema.shap_values_column_names:
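The category-to-string normalization above can be reproduced standalone; note that the `zip(..., strict=True)` added here requires Python 3.10+. A minimal sketch of the same idea (not the SDK's code):

```python
import pandas as pd

df = pd.DataFrame({"color": pd.Categorical(["red", "blue", "red"])})

# Mirror the client: find category-dtype columns and cast them to str,
# since the Arrow upload path does not accept pandas categorical columns.
cat_cols = [name for name, dtype in df.dtypes.items() if dtype.name == "category"]
df = df.astype(dict(zip(cat_cols, ["str"] * len(cat_cols), strict=True)))

print(df.dtypes)  # color is now object (string) dtype
```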
@@ -588,12 +717,12 @@ class MLModelsClient:
             # error conditions that we're currently not aware of.
             pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
         except pa.ArrowInvalid as e:
-            logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
+            logger.exception(INVALID_ARROW_CONVERSION_MSG)
             raise pa.ArrowInvalid(
-                f"Error converting to Arrow format: {str(e)}"
+                f"Error converting to Arrow format: {e!s}"
             ) from e
-        except Exception as e:
-            logger.error(f"Unexpected error creating Arrow table: {str(e)}")
+        except Exception:
+            logger.exception("Unexpected error creating Arrow table")
             raise

         if validate:
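The `pa.ArrowInvalid` branch above typically fires on object columns holding mixed Python types; a standalone sketch of that failure mode (invented data; depending on the pyarrow version and the offending value, the error may instead surface as `ArrowTypeError`, which the client re-raises through its generic fallback):

```python
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"mixed": ["a string", 3.14]})  # str and float in one object column

try:
    pa.Table.from_pandas(df, preserve_index=False)
except pa.ArrowInvalid as e:
    # The branch the client handles explicitly.
    print(f"Error converting to Arrow format: {e!s}")
except Exception as e:
    # Other conversion errors fall through to the client's generic handler.
    print(f"Unexpected error creating Arrow table: {e!s}")
```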
@@ -678,18 +807,53 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         where: str = "",
-        columns: List | None = None,
+        columns: list | None = None,
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
     ) -> pd.DataFrame:
+        """Export model data from Arize to a pandas DataFrame.
+
+        Retrieves prediction and optional actual data for a model within a specified time
+        range and returns it as a pandas DataFrame for analysis.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: The name of the model to export data from.
+            environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
+            start_time: Start of the time range (inclusive) as a datetime object.
+            end_time: End of the time range (inclusive) as a datetime object.
+            include_actuals: When True, includes actual labels in the export. When False,
+                only predictions are returned.
+            model_version: Optional model version to filter by. Empty string returns all
+                versions.
+            batch_id: Optional batch ID to filter by (for VALIDATION environment).
+            where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
+            columns: Optional list of column names to include. If None, all columns are
+                returned.
+            similarity_search_params: Optional parameters for embedding similarity search
+                filtering.
+            stream_chunk_size: Optional chunk size for streaming large result sets.
+
+        Returns:
+            A pandas DataFrame containing the exported data with columns for predictions,
+            actuals (if requested), features, tags, timestamps, and other model metadata.
+
+        Raises:
+            RuntimeError: If the Flight client request fails or returns no response.
+
+        Notes:
+            - Uses Apache Arrow Flight for efficient data transfer
+            - Large exports may benefit from specifying stream_chunk_size
+            - The where clause supports SQL-like filtering syntax
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         from arize._exporter.client import ArizeExportClient
         from arize._flight.client import ArizeFlightClient

         with ArizeFlightClient(
             api_key=self._sdk_config.api_key,
-            host=self._sdk_config.flight_server_host,
-            port=self._sdk_config.flight_server_port,
+            host=self._sdk_config.flight_host,
+            port=self._sdk_config.flight_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
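Note that the `SDKConfiguration` fields feeding the Flight client are renamed here from `flight_server_host`/`flight_server_port` to `flight_host`/`flight_port`. An export call per the docstring above might look like the following sketch; the method name `export_to_dataframe` is an assumption, since only the tail of the signature is visible in this hunk.

```python
from datetime import datetime, timezone

from arize.ml.types import Environments  # module path per this diff

df = ml_client.export_to_dataframe(  # method name assumed, not shown in this hunk
    space_id="YOUR_SPACE_ID",
    model_name="fraud-model",
    environment=Environments.PRODUCTION,
    start_time=datetime(2025, 1, 1, tzinfo=timezone.utc),
    end_time=datetime(2025, 1, 2, tzinfo=timezone.utc),
    include_actuals=True,
    where="amount > 100",  # SQL-like row filter per the docstring
    columns=["prediction_label", "amount"],
)
print(df.shape)
```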
@@ -724,18 +888,53 @@ class MLModelsClient:
         model_version: str = "",
         batch_id: str = "",
         where: str = "",
-        columns: List | None = None,
+        columns: list | None = None,
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
     ) -> pd.DataFrame:
+        """Export model data from Arize to a Parquet file and return as DataFrame.
+
+        Retrieves prediction and optional actual data for a model within a specified time
+        range, saves it as a Parquet file, and returns it as a pandas DataFrame.
+
+        Args:
+            space_id: The space ID where the model resides.
+            model_name: The name of the model to export data from.
+            environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
+            start_time: Start of the time range (inclusive) as a datetime object.
+            end_time: End of the time range (inclusive) as a datetime object.
+            include_actuals: When True, includes actual labels in the export. When False,
+                only predictions are returned.
+            model_version: Optional model version to filter by. Empty string returns all
+                versions.
+            batch_id: Optional batch ID to filter by (for VALIDATION environment).
+            where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
+            columns: Optional list of column names to include. If None, all columns are
+                returned.
+            similarity_search_params: Optional parameters for embedding similarity search
+                filtering.
+            stream_chunk_size: Optional chunk size for streaming large result sets.
+
+        Returns:
+            A pandas DataFrame containing the exported data. The data is also saved to a
+            Parquet file by the underlying export client.
+
+        Raises:
+            RuntimeError: If the Flight client request fails or returns no response.
+
+        Notes:
+            - Uses Apache Arrow Flight for efficient data transfer
+            - The Parquet file location is managed by the ArizeExportClient
+            - Large exports may benefit from specifying stream_chunk_size
+        """
         require(_BATCH_EXTRA, _BATCH_DEPS)
         from arize._exporter.client import ArizeExportClient
         from arize._flight.client import ArizeFlightClient

         with ArizeFlightClient(
             api_key=self._sdk_config.api_key,
-            host=self._sdk_config.flight_server_host,
-            port=self._sdk_config.flight_server_port,
+            host=self._sdk_config.flight_host,
+            port=self._sdk_config.flight_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -759,6 +958,7 @@ class MLModelsClient:
         )

     def _ensure_session(self) -> FuturesSession:
+        """Lazily initialize and return the FuturesSession for async streaming requests."""
         from requests_futures.sessions import FuturesSession

         session = object.__getattribute__(self, "_session")
@@ -778,10 +978,11 @@ class MLModelsClient:
     def _post(
         self,
         record: pb2.Record,
-        headers: Dict[str, str],
+        headers: dict[str, str],
         timeout: float | None,
-        indexes: Tuple,
-    ):
+        indexes: tuple,
+    ) -> object:
+        """Post a record to Arize via async HTTP request with protobuf JSON serialization."""
         from google.protobuf.json_format import MessageToDict

         session = self._ensure_session()
@@ -801,9 +1002,10 @@ class MLModelsClient:
         return resp


-def _validate_mapping_key(key_name: str, name: str):
+def _validate_mapping_key(key_name: str, name: str) -> None:
+    """Validate that a mapping key (feature/tag name) is a string and doesn't end with '_shap'."""
     if not isinstance(key_name, str):
-        raise ValueError(
+        raise TypeError(
             f"{name} dictionary key {key_name} must be named with string, type used: {type(key_name)}"
         )
     if key_name.endswith("_shap"):
@@ -813,7 +1015,8 @@ def _validate_mapping_key(key_name: str, name: str):
     return


-def _is_timestamp_in_range(now: int, ts: int):
+def _is_timestamp_in_range(now: int, ts: int) -> bool:
+    """Check if a timestamp is within the acceptable range (1 year future, 2 years past)."""
     max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
     min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
     return min_time <= ts <= max_time
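The bounds use 365-day years in seconds; a quick standalone check of the same arithmetic, with the 1-year/2-year constant values taken from the docstrings above (their actual definitions live elsewhere in the SDK):

```python
import time

MAX_FUTURE_YEARS_FROM_CURRENT_TIME = 1  # values per the docstrings above
MAX_PAST_YEARS_FROM_CURRENT_TIME = 2

def is_timestamp_in_range(now: int, ts: int) -> bool:
    # Same arithmetic as the helper: +1 year / -2 years, in seconds.
    max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
    min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
    return min_time <= ts <= max_time

now = int(time.time())
print(is_timestamp_in_range(now, now))                           # True
print(is_timestamp_in_range(now, now - 3 * 365 * 24 * 60 * 60))  # False: 3 years past
```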
@@ -826,7 +1029,8 @@ def _get_pb_schema(
     model_type: ModelTypes,
     environment: Environments,
     batch_id: str,
-):
+) -> object:
+    """Construct a protocol buffer Schema from the user's Schema for batch logging."""
     s = pb2.Schema()
     s.constants.model_id = model_id

@@ -874,48 +1078,52 @@

     if model_type == ModelTypes.OBJECT_DETECTION:
         if schema.object_detection_prediction_column_names is not None:
-            s.arrow_schema.prediction_object_detection_label_column_names.bboxes_coordinates_column_name = (
-                schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name  # noqa: E501
+            obj_det_pred = schema.object_detection_prediction_column_names
+            pred_labels = (
+                s.arrow_schema.prediction_object_detection_label_column_names
             )
-            s.arrow_schema.prediction_object_detection_label_column_names.bboxes_categories_column_name = (
-                schema.object_detection_prediction_column_names.categories_column_name  # noqa: E501
+            pred_labels.bboxes_coordinates_column_name = (
+                obj_det_pred.bounding_boxes_coordinates_column_name
             )
-            if (
-                schema.object_detection_prediction_column_names.scores_column_name
-                is not None
-            ):
-                s.arrow_schema.prediction_object_detection_label_column_names.bboxes_scores_column_name = (
-                    schema.object_detection_prediction_column_names.scores_column_name  # noqa: E501
+            pred_labels.bboxes_categories_column_name = (
+                obj_det_pred.categories_column_name
+            )
+            if obj_det_pred.scores_column_name is not None:
+                pred_labels.bboxes_scores_column_name = (
+                    obj_det_pred.scores_column_name
                 )

         if schema.semantic_segmentation_prediction_column_names is not None:
-            s.arrow_schema.prediction_semantic_segmentation_label_column_names.polygons_coordinates_column_name = (  # noqa: E501
-                schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name
+            seg_pred_cols = schema.semantic_segmentation_prediction_column_names
+            pred_seg_labels = s.arrow_schema.prediction_semantic_segmentation_label_column_names
+            pred_seg_labels.polygons_coordinates_column_name = (
+                seg_pred_cols.polygon_coordinates_column_name
             )
-            s.arrow_schema.prediction_semantic_segmentation_label_column_names.polygons_categories_column_name = (  # noqa: E501
-                schema.semantic_segmentation_prediction_column_names.categories_column_name
+            pred_seg_labels.polygons_categories_column_name = (
+                seg_pred_cols.categories_column_name
             )

         if schema.instance_segmentation_prediction_column_names is not None:
-            s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_coordinates_column_name = (  # noqa: E501
-                schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name
+            inst_seg_pred_cols = (
+                schema.instance_segmentation_prediction_column_names
             )
-            s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_categories_column_name = (  # noqa: E501
-                schema.instance_segmentation_prediction_column_names.categories_column_name
+            pred_inst_seg_labels = s.arrow_schema.prediction_instance_segmentation_label_column_names
+            pred_inst_seg_labels.polygons_coordinates_column_name = (
+                inst_seg_pred_cols.polygon_coordinates_column_name
             )
-            if (
-                schema.instance_segmentation_prediction_column_names.scores_column_name
-                is not None
-            ):
-                s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_scores_column_name = (  # noqa: E501
-                    schema.instance_segmentation_prediction_column_names.scores_column_name
+            pred_inst_seg_labels.polygons_categories_column_name = (
+                inst_seg_pred_cols.categories_column_name
+            )
+            if inst_seg_pred_cols.scores_column_name is not None:
+                pred_inst_seg_labels.polygons_scores_column_name = (
+                    inst_seg_pred_cols.scores_column_name
                 )
             if (
-                schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name
+                inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                 is not None
             ):
-                s.arrow_schema.prediction_instance_segmentation_label_column_names.bboxes_coordinates_column_name = (  # noqa: E501
-                    schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name
+                pred_inst_seg_labels.bboxes_coordinates_column_name = (
+                    inst_seg_pred_cols.bounding_boxes_coordinates_column_name
                 )

     if schema.prediction_score_column_name is not None:
@@ -1038,50 +1246,61 @@

     if model_type == ModelTypes.OBJECT_DETECTION:
         if schema.object_detection_actual_column_names is not None:
-            s.arrow_schema.actual_object_detection_label_column_names.bboxes_coordinates_column_name = (  # noqa: E501
-                schema.object_detection_actual_column_names.bounding_boxes_coordinates_column_name
+            obj_det_actual = schema.object_detection_actual_column_names
+            actual_labels = (
+                s.arrow_schema.actual_object_detection_label_column_names
             )
-            s.arrow_schema.actual_object_detection_label_column_names.bboxes_categories_column_name = (  # noqa: E501
-                schema.object_detection_actual_column_names.categories_column_name
+            actual_labels.bboxes_coordinates_column_name = (
+                obj_det_actual.bounding_boxes_coordinates_column_name
             )
-            if (
-                schema.object_detection_actual_column_names.scores_column_name
-                is not None
-            ):
-                s.arrow_schema.actual_object_detection_label_column_names.bboxes_scores_column_name = (  # noqa: E501
-                    schema.object_detection_actual_column_names.scores_column_name
+            actual_labels.bboxes_categories_column_name = (
+                obj_det_actual.categories_column_name
+            )
+            if obj_det_actual.scores_column_name is not None:
+                actual_labels.bboxes_scores_column_name = (
+                    obj_det_actual.scores_column_name
                 )

         if schema.semantic_segmentation_actual_column_names is not None:
-            s.arrow_schema.actual_semantic_segmentation_label_column_names.polygons_coordinates_column_name = (  # noqa: E501
-                schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
+            sem_seg_actual = schema.semantic_segmentation_actual_column_names
+            sem_seg_labels = (
+                s.arrow_schema.actual_semantic_segmentation_label_column_names
+            )
+            sem_seg_labels.polygons_coordinates_column_name = (
+                sem_seg_actual.polygon_coordinates_column_name
             )
-            s.arrow_schema.actual_semantic_segmentation_label_column_names.polygons_categories_column_name = (  # noqa: E501
-                schema.semantic_segmentation_actual_column_names.categories_column_name
+            sem_seg_labels.polygons_categories_column_name = (
+                sem_seg_actual.categories_column_name
             )

         if schema.instance_segmentation_actual_column_names is not None:
-            s.arrow_schema.actual_instance_segmentation_label_column_names.polygons_coordinates_column_name = (  # noqa: E501
-                schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
+            inst_seg_actual = schema.instance_segmentation_actual_column_names
+            inst_seg_labels = (
+                s.arrow_schema.actual_instance_segmentation_label_column_names
             )
-            s.arrow_schema.actual_instance_segmentation_label_column_names.polygons_categories_column_name = (  # noqa: E501
-                schema.instance_segmentation_actual_column_names.categories_column_name
+            inst_seg_labels.polygons_coordinates_column_name = (
+                inst_seg_actual.polygon_coordinates_column_name
+            )
+            inst_seg_labels.polygons_categories_column_name = (
+                inst_seg_actual.categories_column_name
             )
             if (
-                schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name
+                inst_seg_actual.bounding_boxes_coordinates_column_name
                 is not None
             ):
-                s.arrow_schema.actual_instance_segmentation_label_column_names.bboxes_coordinates_column_name = (  # noqa: E501
-                    schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name
+                inst_seg_labels.bboxes_coordinates_column_name = (
+                    inst_seg_actual.bounding_boxes_coordinates_column_name
                 )

     if model_type == ModelTypes.GENERATIVE_LLM:
         if schema.prompt_template_column_names is not None:
-            s.arrow_schema.prompt_template_column_names.template_column_name = (
-                schema.prompt_template_column_names.template_column_name
+            prompt_template_names = schema.prompt_template_column_names
+            arrow_prompt_names = s.arrow_schema.prompt_template_column_names
+            arrow_prompt_names.template_column_name = (
+                prompt_template_names.template_column_name
             )
-            s.arrow_schema.prompt_template_column_names.template_version_column_name = (  # noqa: E501
-                schema.prompt_template_column_names.template_version_column_name
+            arrow_prompt_names.template_version_column_name = (
+                prompt_template_names.template_version_column_name
             )
         if schema.llm_config_column_names is not None:
             s.arrow_schema.llm_config_column_names.model_column_name = (
@@ -1114,6 +1333,7 @@ def _get_pb_schema_corpus(
     schema: CorpusSchema,
    model_id: str,
 ) -> pb2.Schema:
+    """Construct a protocol buffer Schema from CorpusSchema for document corpus logging."""
     s = pb2.Schema()
     s.constants.model_id = model_id
     s.constants.environment = pb2.Schema.Environment.CORPUS
@@ -1127,11 +1347,12 @@ def _get_pb_schema_corpus(
         schema.document_version_column_name
     )
     if schema.document_text_embedding_column_names is not None:
-        s.arrow_schema.document_column_names.text_column_name.vector_column_name = schema.document_text_embedding_column_names.vector_column_name  # noqa: E501
-        s.arrow_schema.document_column_names.text_column_name.data_column_name = schema.document_text_embedding_column_names.data_column_name  # noqa: E501
-        if (
-            schema.document_text_embedding_column_names.link_to_data_column_name
-            is not None
-        ):
-            s.arrow_schema.document_column_names.text_column_name.link_to_data_column_name = schema.document_text_embedding_column_names.link_to_data_column_name  # noqa: E501
+        doc_text_emb_cols = schema.document_text_embedding_column_names
+        doc_text_col = s.arrow_schema.document_column_names.text_column_name
+        doc_text_col.vector_column_name = doc_text_emb_cols.vector_column_name
+        doc_text_col.data_column_name = doc_text_emb_cols.data_column_name
+        if doc_text_emb_cols.link_to_data_column_name is not None:
+            doc_text_col.link_to_data_column_name = (
+                doc_text_emb_cols.link_to_data_column_name
+            )
     return s