arize 8.0.0a22__py3-none-any.whl → 8.0.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +268 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +389 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a22.dist-info/RECORD +0 -146
  166. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/models/client.py CHANGED
@@ -1,11 +1,14 @@
1
+ """Client implementation for managing ML models in the Arize platform."""
2
+
1
3
  # type: ignore[pb2]
2
4
  from __future__ import annotations
3
5
 
4
6
  import copy
5
7
  import logging
6
8
  import time
7
- from typing import TYPE_CHECKING, Dict, List, Tuple
9
+ from typing import TYPE_CHECKING
8
10
 
11
+ from arize._generated.protocol.rec import public_pb2 as pb2
9
12
  from arize._lazy import require
10
13
  from arize.constants.ml import (
11
14
  LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
@@ -43,6 +46,7 @@ from arize.types import (
43
46
  BaseSchema,
44
47
  CorpusSchema,
45
48
  Embedding,
49
+ EmbeddingColumnNames,
46
50
  Environments,
47
51
  LLMRunMetadata,
48
52
  Metrics,
@@ -64,12 +68,7 @@ if TYPE_CHECKING:
64
68
  import requests
65
69
  from requests_futures.sessions import FuturesSession
66
70
 
67
- from arize._generated.protocol.rec import public_pb2 as pb2
68
71
  from arize.config import SDKConfiguration
69
- from arize.types import (
70
- EmbeddingColumnNames,
71
- Schema,
72
- )
73
72
 
74
73
 
75
74
  logger = logging.getLogger(__name__)
@@ -96,7 +95,14 @@ _MIMIC_EXTRA = "mimic-explainer"
96
95
 
97
96
 
98
97
  class MLModelsClient:
99
- def __init__(self, *, sdk_config: SDKConfiguration):
98
+ """Client for logging ML model predictions and actuals to Arize."""
99
+
100
+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
101
+ """Initialize the ML models client with SDK configuration.
102
+
103
+ Args:
104
+ sdk_config: SDK configuration containing API endpoints and credentials.
105
+ """
100
106
  self._sdk_config = sdk_config
101
107
 
102
108
  # internal cache for the futures session
@@ -114,21 +120,86 @@ class MLModelsClient:
114
120
  prediction_timestamp: int | None = None,
115
121
  prediction_label: PredictionLabelTypes | None = None,
116
122
  actual_label: ActualLabelTypes | None = None,
117
- features: Dict[str, str | bool | float | int | List[str] | TypedValue]
123
+ features: dict[str, str | bool | float | int | list[str] | TypedValue]
118
124
  | None = None,
119
- embedding_features: Dict[str, Embedding] | None = None,
120
- shap_values: Dict[str, float] | None = None,
121
- tags: Dict[str, str | bool | float | int | TypedValue] | None = None,
125
+ embedding_features: dict[str, Embedding] | None = None,
126
+ shap_values: dict[str, float] | None = None,
127
+ tags: dict[str, str | bool | float | int | TypedValue] | None = None,
122
128
  batch_id: str | None = None,
123
129
  prompt: str | Embedding | None = None,
124
130
  response: str | Embedding | None = None,
125
131
  prompt_template: str | None = None,
126
132
  prompt_template_version: str | None = None,
127
133
  llm_model_name: str | None = None,
128
- llm_params: Dict[str, str | bool | float | int] | None = None,
134
+ llm_params: dict[str, str | bool | float | int] | None = None,
129
135
  llm_run_metadata: LLMRunMetadata | None = None,
130
136
  timeout: float | None = None,
131
137
  ) -> cf.Future:
138
+ """Log a single model prediction or actual to Arize asynchronously.
139
+
140
+ This method sends a single prediction, actual, or both to Arize for ML monitoring.
141
+ The request is made asynchronously and returns a Future that can be used to check
142
+ the status or retrieve the response.
143
+
144
+ Args:
145
+ space_id: The space ID where the model resides.
146
+ model_name: A unique name to identify your model in the Arize platform.
147
+ model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
148
+ RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
149
+ spans module instead.
150
+ environment: The environment this data belongs to (PRODUCTION, TRAINING, or
151
+ VALIDATION).
152
+ model_version: Optional version identifier for the model.
153
+ prediction_id: Unique identifier for this prediction. If not provided, one
154
+ will be auto-generated for PRODUCTION environment.
155
+ prediction_timestamp: Unix timestamp (seconds) for when the prediction was made.
156
+ If not provided, the current time is used. Must be within 1 year in the
157
+ future and 2 years in the past from the current time.
158
+ prediction_label: The prediction output from your model. Type depends on
159
+ model_type (e.g., string for categorical, float for numeric).
160
+ actual_label: The ground truth label. Type depends on model_type.
161
+ features: Dictionary of feature name to feature value. Values can be str, bool,
162
+ float, int, list[str], or TypedValue.
163
+ embedding_features: Dictionary of embedding feature name to Embedding object.
164
+ Maximum 50 embeddings per record. Object detection models support only 1.
165
+ shap_values: Dictionary of feature name to SHAP value (float) for feature
166
+ importance/explainability.
167
+ tags: Dictionary of metadata tags. Tag names cannot end with "_shap" or be
168
+ reserved names. Values must be under 1000 characters (warning at 100).
169
+ batch_id: Required for VALIDATION environment; identifies the validation batch.
170
+ prompt: For generative models, the prompt text or embedding sent to the model.
171
+ response: For generative models, the response text or embedding from the model.
172
+ prompt_template: Template used to generate the prompt.
173
+ prompt_template_version: Version identifier for the prompt template.
174
+ llm_model_name: Name of the LLM model used (e.g., "gpt-4").
175
+ llm_params: Dictionary of LLM configuration parameters (e.g., temperature,
176
+ max_tokens).
177
+ llm_run_metadata: Metadata about the LLM run including token counts and latency.
178
+ timeout: Maximum time (in seconds) to wait for the request to complete.
179
+
180
+ Returns:
181
+ A concurrent.futures.Future object representing the async request. Call
182
+ .result() to block and retrieve the Response object, or check .done() for
183
+ completion status.
184
+
185
+ Raises:
186
+ ValueError: If model_type is GENERATIVE_LLM, or if validation environment is
187
+ missing batch_id, or if training/validation environment is missing
188
+ prediction or actual, or if timestamp is out of range, or if no data
189
+ is provided (must have prediction_label, actual_label, tags, or shap_values),
190
+ or if tag names end with "_shap" or exceed length limits.
191
+ MissingSpaceIDError: If space_id is not provided or empty.
192
+ MissingModelNameError: If model_name is not provided or empty.
193
+ InvalidValueType: If features, tags, or other parameters have incorrect types.
194
+ InvalidNumberOfEmbeddings: If more than 50 embedding features are provided.
195
+ KeyError: If tag names include reserved names.
196
+
197
+ Notes:
198
+ - Timestamps must be within 1 year future and 2 years past from current time
199
+ - Tag values are truncated at 1000 characters, with warnings at 100 characters
200
+ - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
201
+ - The Future returned can be monitored for request status asynchronously
202
+ """
132
203
  require(_STREAM_EXTRA, _STREAM_DEPS)
133
204
  from arize._generated.protocol.rec import public_pb2 as pb2
134
205
  from arize.models.proto import (
@@ -179,16 +250,15 @@ class MLModelsClient:
179
250
  _validate_mapping_key(feat_name, "features")
180
251
  if is_list_of(feat_value, str):
181
252
  continue
182
- else:
183
- val = convert_element(feat_value)
184
- if val is not None and not isinstance(
185
- val, (str, bool, float, int)
186
- ):
187
- raise InvalidValueType(
188
- f"feature '{feat_name}'",
189
- feat_value,
190
- "one of: bool, int, float, str",
191
- )
253
+ val = convert_element(feat_value)
254
+ if val is not None and not isinstance(
255
+ val, (str, bool, float, int)
256
+ ):
257
+ raise InvalidValueType(
258
+ f"feature '{feat_name}'",
259
+ feat_value,
260
+ "one of: bool, int, float, str",
261
+ )
192
262
 
193
263
  # Validate embedding_features type
194
264
  if embedding_features:
@@ -247,7 +317,7 @@ class MLModelsClient:
247
317
  f"{MAX_TAG_LENGTH}. The tag {tag_name} with value {tag_value} has "
248
318
  f"{len(str(val))} characters."
249
319
  )
250
- elif len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
320
+ if len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
251
321
  logger.warning(
252
322
  get_truncation_warning_message(
253
323
  "tags", MAX_TAG_LENGTH_TRUNCATION
@@ -304,7 +374,7 @@ class MLModelsClient:
304
374
  if embedding_features or prompt or response:
305
375
  # NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
306
376
  combined_embedding_features = (
307
- {k: v for k, v in embedding_features.items()}
377
+ dict(embedding_features.items())
308
378
  if embedding_features
309
379
  else {}
310
380
  )
@@ -453,7 +523,6 @@ class MLModelsClient:
453
523
  indexes=None,
454
524
  )
455
525
 
456
- # TODO(Kiko): Handle sync argument
457
526
  def log_batch(
458
527
  self,
459
528
  *,
@@ -466,12 +535,64 @@ class MLModelsClient:
466
535
  model_version: str = "",
467
536
  batch_id: str = "",
468
537
  validate: bool = True,
469
- metrics_validation: List[Metrics] | None = None,
538
+ metrics_validation: list[Metrics] | None = None,
470
539
  surrogate_explainability: bool = False,
471
540
  timeout: float | None = None,
472
541
  tmp_dir: str = "",
473
- sync: bool = False,
474
542
  ) -> requests.Response:
543
+ """Log a batch of model predictions and actuals to Arize from a pandas DataFrame.
544
+
545
+ This method uploads multiple records to Arize in a single batch operation using
546
+ Apache Arrow format for efficient transfer. The dataframe structure is defined
547
+ by the provided schema which maps dataframe columns to Arize data fields.
548
+
549
+ Args:
550
+ space_id: The space ID where the model resides.
551
+ model_name: A unique name to identify your model in the Arize platform.
552
+ model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
553
+ RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
554
+ spans module instead.
555
+ dataframe: Pandas DataFrame containing the data to upload. Columns should
556
+ correspond to the schema field mappings.
557
+ schema: Schema object (Schema or CorpusSchema) that defines the mapping between
558
+ dataframe columns and Arize data fields (e.g., prediction_label_column_name,
559
+ feature_column_names, etc.).
560
+ environment: The environment this data belongs to (PRODUCTION, TRAINING,
561
+ VALIDATION, or CORPUS).
562
+ model_version: Optional version identifier for the model.
563
+ batch_id: Required for VALIDATION environment; identifies the validation batch.
564
+ validate: When True, performs comprehensive validation before sending data.
565
+ Includes checks for required fields, data types, and value constraints.
566
+ metrics_validation: Optional list of metric families to validate against.
567
+ surrogate_explainability: When True, automatically generates SHAP values using
568
+ MIMIC surrogate explainer. Requires the 'mimic-explainer' extra. Has no
569
+ effect if shap_values_column_names is already specified in schema.
570
+ timeout: Maximum time (in seconds) to wait for the request to complete.
571
+ tmp_dir: Optional temporary directory to store serialized Arrow data before
572
+ upload.
573
+
574
+ Returns:
575
+ A requests.Response object from the upload request. Check .status_code for
576
+ success (200) or error conditions.
577
+
578
+ Raises:
579
+ MissingSpaceIDError: If space_id is not provided or empty.
580
+ MissingModelNameError: If model_name is not provided or empty.
581
+ ValueError: If model_type is GENERATIVE_LLM, or if environment is CORPUS with
582
+ non-CorpusSchema, or if training/validation records are incomplete.
583
+ ValidationFailure: If validate=True and validation checks fail. Contains list
584
+ of validation error messages.
585
+ pa.ArrowInvalid: If the dataframe cannot be converted to Arrow format, typically
586
+ due to mixed types in columns not specified in the schema.
587
+
588
+ Notes:
589
+ - Categorical dtype columns are automatically converted to string
590
+ - Extraneous columns not in the schema are removed before upload
591
+ - Surrogate explainability requires 'mimic-explainer' extra
592
+ - For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
593
+ - If logging actuals without predictions, ensure predictions were logged first
594
+ - Data is sent via Apache Arrow for efficient large batch transfers
595
+ """
475
596
  require(_BATCH_EXTRA, _BATCH_DEPS)
476
597
  import pandas.api.types as ptypes
477
598
  import pyarrow as pa
@@ -506,8 +627,8 @@ class MLModelsClient:
506
627
  # Thus we can only offer this functionality with pandas>=1.0.0.
507
628
  try:
508
629
  dataframe, schema = cast_typed_columns(dataframe, schema)
509
- except Exception as e:
510
- logger.error(e)
630
+ except Exception:
631
+ logger.exception("Error casting typed columns")
511
632
  raise
512
633
 
513
634
  logger.debug("Performing required validation.")
@@ -546,7 +667,7 @@ class MLModelsClient:
546
667
 
547
668
  # always validate pd.Category is not present, if yes, convert to string
548
669
  has_cat_col = any(
549
- [ptypes.is_categorical_dtype(x) for x in dataframe.dtypes]
670
+ ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
550
671
  )
551
672
  if has_cat_col:
552
673
  cat_cols = [
@@ -554,7 +675,13 @@ class MLModelsClient:
554
675
  for col_name, col_cat in dataframe.dtypes.items()
555
676
  if col_cat.name == "category"
556
677
  ]
557
- cat_str_map = dict(zip(cat_cols, ["str"] * len(cat_cols)))
678
+ cat_str_map = dict(
679
+ zip(
680
+ cat_cols,
681
+ ["str"] * len(cat_cols),
682
+ strict=True,
683
+ )
684
+ )
558
685
  dataframe = dataframe.astype(cat_str_map)
559
686
 
560
687
  if surrogate_explainability:
@@ -588,12 +715,12 @@ class MLModelsClient:
588
715
  # error conditions that we're currently not aware of.
589
716
  pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
590
717
  except pa.ArrowInvalid as e:
591
- logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
718
+ logger.exception(INVALID_ARROW_CONVERSION_MSG)
592
719
  raise pa.ArrowInvalid(
593
- f"Error converting to Arrow format: {str(e)}"
720
+ f"Error converting to Arrow format: {e!s}"
594
721
  ) from e
595
- except Exception as e:
596
- logger.error(f"Unexpected error creating Arrow table: {str(e)}")
722
+ except Exception:
723
+ logger.exception("Unexpected error creating Arrow table")
597
724
  raise
598
725
 
599
726
  if validate:
@@ -678,18 +805,53 @@ class MLModelsClient:
678
805
  model_version: str = "",
679
806
  batch_id: str = "",
680
807
  where: str = "",
681
- columns: List | None = None,
808
+ columns: list | None = None,
682
809
  similarity_search_params: SimilaritySearchParams | None = None,
683
810
  stream_chunk_size: int | None = None,
684
811
  ) -> pd.DataFrame:
812
+ """Export model data from Arize to a pandas DataFrame.
813
+
814
+ Retrieves prediction and optional actual data for a model within a specified time
815
+ range and returns it as a pandas DataFrame for analysis.
816
+
817
+ Args:
818
+ space_id: The space ID where the model resides.
819
+ model_name: The name of the model to export data from.
820
+ environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
821
+ start_time: Start of the time range (inclusive) as a datetime object.
822
+ end_time: End of the time range (inclusive) as a datetime object.
823
+ include_actuals: When True, includes actual labels in the export. When False,
824
+ only predictions are returned.
825
+ model_version: Optional model version to filter by. Empty string returns all
826
+ versions.
827
+ batch_id: Optional batch ID to filter by (for VALIDATION environment).
828
+ where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
829
+ columns: Optional list of column names to include. If None, all columns are
830
+ returned.
831
+ similarity_search_params: Optional parameters for embedding similarity search
832
+ filtering.
833
+ stream_chunk_size: Optional chunk size for streaming large result sets.
834
+
835
+ Returns:
836
+ A pandas DataFrame containing the exported data with columns for predictions,
837
+ actuals (if requested), features, tags, timestamps, and other model metadata.
838
+
839
+ Raises:
840
+ RuntimeError: If the Flight client request fails or returns no response.
841
+
842
+ Notes:
843
+ - Uses Apache Arrow Flight for efficient data transfer
844
+ - Large exports may benefit from specifying stream_chunk_size
845
+ - The where clause supports SQL-like filtering syntax
846
+ """
685
847
  require(_BATCH_EXTRA, _BATCH_DEPS)
686
848
  from arize._exporter.client import ArizeExportClient
687
849
  from arize._flight.client import ArizeFlightClient
688
850
 
689
851
  with ArizeFlightClient(
690
852
  api_key=self._sdk_config.api_key,
691
- host=self._sdk_config.flight_server_host,
692
- port=self._sdk_config.flight_server_port,
853
+ host=self._sdk_config.flight_host,
854
+ port=self._sdk_config.flight_port,
693
855
  scheme=self._sdk_config.flight_scheme,
694
856
  request_verify=self._sdk_config.request_verify,
695
857
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -724,18 +886,53 @@ class MLModelsClient:
724
886
  model_version: str = "",
725
887
  batch_id: str = "",
726
888
  where: str = "",
727
- columns: List | None = None,
889
+ columns: list | None = None,
728
890
  similarity_search_params: SimilaritySearchParams | None = None,
729
891
  stream_chunk_size: int | None = None,
730
892
  ) -> pd.DataFrame:
893
+ """Export model data from Arize to a Parquet file and return as DataFrame.
894
+
895
+ Retrieves prediction and optional actual data for a model within a specified time
896
+ range, saves it as a Parquet file, and returns it as a pandas DataFrame.
897
+
898
+ Args:
899
+ space_id: The space ID where the model resides.
900
+ model_name: The name of the model to export data from.
901
+ environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
902
+ start_time: Start of the time range (inclusive) as a datetime object.
903
+ end_time: End of the time range (inclusive) as a datetime object.
904
+ include_actuals: When True, includes actual labels in the export. When False,
905
+ only predictions are returned.
906
+ model_version: Optional model version to filter by. Empty string returns all
907
+ versions.
908
+ batch_id: Optional batch ID to filter by (for VALIDATION environment).
909
+ where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
910
+ columns: Optional list of column names to include. If None, all columns are
911
+ returned.
912
+ similarity_search_params: Optional parameters for embedding similarity search
913
+ filtering.
914
+ stream_chunk_size: Optional chunk size for streaming large result sets.
915
+
916
+ Returns:
917
+ A pandas DataFrame containing the exported data. The data is also saved to a
918
+ Parquet file by the underlying export client.
919
+
920
+ Raises:
921
+ RuntimeError: If the Flight client request fails or returns no response.
922
+
923
+ Notes:
924
+ - Uses Apache Arrow Flight for efficient data transfer
925
+ - The Parquet file location is managed by the ArizeExportClient
926
+ - Large exports may benefit from specifying stream_chunk_size
927
+ """
731
928
  require(_BATCH_EXTRA, _BATCH_DEPS)
732
929
  from arize._exporter.client import ArizeExportClient
733
930
  from arize._flight.client import ArizeFlightClient
734
931
 
735
932
  with ArizeFlightClient(
736
933
  api_key=self._sdk_config.api_key,
737
- host=self._sdk_config.flight_server_host,
738
- port=self._sdk_config.flight_server_port,
934
+ host=self._sdk_config.flight_host,
935
+ port=self._sdk_config.flight_port,
739
936
  scheme=self._sdk_config.flight_scheme,
740
937
  request_verify=self._sdk_config.request_verify,
741
938
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -759,6 +956,7 @@ class MLModelsClient:
759
956
  )
760
957
 
761
958
  def _ensure_session(self) -> FuturesSession:
959
+ """Lazily initialize and return the FuturesSession for async streaming requests."""
762
960
  from requests_futures.sessions import FuturesSession
763
961
 
764
962
  session = object.__getattribute__(self, "_session")
@@ -778,10 +976,11 @@ class MLModelsClient:
778
976
  def _post(
779
977
  self,
780
978
  record: pb2.Record,
781
- headers: Dict[str, str],
979
+ headers: dict[str, str],
782
980
  timeout: float | None,
783
- indexes: Tuple,
784
- ):
981
+ indexes: tuple,
982
+ ) -> object:
983
+ """Post a record to Arize via async HTTP request with protobuf JSON serialization."""
785
984
  from google.protobuf.json_format import MessageToDict
786
985
 
787
986
  session = self._ensure_session()
@@ -801,9 +1000,10 @@ class MLModelsClient:
801
1000
  return resp
802
1001
 
803
1002
 
804
- def _validate_mapping_key(key_name: str, name: str):
1003
+ def _validate_mapping_key(key_name: str, name: str) -> None:
1004
+ """Validate that a mapping key (feature/tag name) is a string and doesn't end with '_shap'."""
805
1005
  if not isinstance(key_name, str):
806
- raise ValueError(
1006
+ raise TypeError(
807
1007
  f"{name} dictionary key {key_name} must be named with string, type used: {type(key_name)}"
808
1008
  )
809
1009
  if key_name.endswith("_shap"):
@@ -813,7 +1013,8 @@ def _validate_mapping_key(key_name: str, name: str):
813
1013
  return
814
1014
 
815
1015
 
816
- def _is_timestamp_in_range(now: int, ts: int):
1016
+ def _is_timestamp_in_range(now: int, ts: int) -> bool:
1017
+ """Check if a timestamp is within the acceptable range (1 year future, 2 years past)."""
817
1018
  max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
818
1019
  min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
819
1020
  return min_time <= ts <= max_time
@@ -826,7 +1027,8 @@ def _get_pb_schema(
826
1027
  model_type: ModelTypes,
827
1028
  environment: Environments,
828
1029
  batch_id: str,
829
- ):
1030
+ ) -> object:
1031
+ """Construct a protocol buffer Schema from the user's Schema for batch logging."""
830
1032
  s = pb2.Schema()
831
1033
  s.constants.model_id = model_id
832
1034
 
@@ -874,48 +1076,52 @@ def _get_pb_schema(
874
1076
 
875
1077
  if model_type == ModelTypes.OBJECT_DETECTION:
876
1078
  if schema.object_detection_prediction_column_names is not None:
877
- s.arrow_schema.prediction_object_detection_label_column_names.bboxes_coordinates_column_name = (
878
- schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
1079
+ obj_det_pred = schema.object_detection_prediction_column_names
1080
+ pred_labels = (
1081
+ s.arrow_schema.prediction_object_detection_label_column_names
879
1082
  )
880
- s.arrow_schema.prediction_object_detection_label_column_names.bboxes_categories_column_name = (
881
- schema.object_detection_prediction_column_names.categories_column_name # noqa: E501
1083
+ pred_labels.bboxes_coordinates_column_name = (
1084
+ obj_det_pred.bounding_boxes_coordinates_column_name
882
1085
  )
883
- if (
884
- schema.object_detection_prediction_column_names.scores_column_name
885
- is not None
886
- ):
887
- s.arrow_schema.prediction_object_detection_label_column_names.bboxes_scores_column_name = (
888
- schema.object_detection_prediction_column_names.scores_column_name # noqa: E501
1086
+ pred_labels.bboxes_categories_column_name = (
1087
+ obj_det_pred.categories_column_name
1088
+ )
1089
+ if obj_det_pred.scores_column_name is not None:
1090
+ pred_labels.bboxes_scores_column_name = (
1091
+ obj_det_pred.scores_column_name
889
1092
  )
890
1093
 
891
1094
  if schema.semantic_segmentation_prediction_column_names is not None:
892
- s.arrow_schema.prediction_semantic_segmentation_label_column_names.polygons_coordinates_column_name = ( # noqa: E501
893
- schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name
1095
+ seg_pred_cols = schema.semantic_segmentation_prediction_column_names
1096
+ pred_seg_labels = s.arrow_schema.prediction_semantic_segmentation_label_column_names
1097
+ pred_seg_labels.polygons_coordinates_column_name = (
1098
+ seg_pred_cols.polygon_coordinates_column_name
894
1099
  )
895
- s.arrow_schema.prediction_semantic_segmentation_label_column_names.polygons_categories_column_name = ( # noqa: E501
896
- schema.semantic_segmentation_prediction_column_names.categories_column_name
1100
+ pred_seg_labels.polygons_categories_column_name = (
1101
+ seg_pred_cols.categories_column_name
897
1102
  )
898
1103
 
899
1104
  if schema.instance_segmentation_prediction_column_names is not None:
900
- s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_coordinates_column_name = ( # noqa: E501
901
- schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name
1105
+ inst_seg_pred_cols = (
1106
+ schema.instance_segmentation_prediction_column_names
902
1107
  )
903
- s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_categories_column_name = ( # noqa: E501
904
- schema.instance_segmentation_prediction_column_names.categories_column_name
1108
+ pred_inst_seg_labels = s.arrow_schema.prediction_instance_segmentation_label_column_names
1109
+ pred_inst_seg_labels.polygons_coordinates_column_name = (
1110
+ inst_seg_pred_cols.polygon_coordinates_column_name
905
1111
  )
906
- if (
907
- schema.instance_segmentation_prediction_column_names.scores_column_name
908
- is not None
909
- ):
910
- s.arrow_schema.prediction_instance_segmentation_label_column_names.polygons_scores_column_name = ( # noqa: E501
911
- schema.instance_segmentation_prediction_column_names.scores_column_name
1112
+ pred_inst_seg_labels.polygons_categories_column_name = (
1113
+ inst_seg_pred_cols.categories_column_name
1114
+ )
1115
+ if inst_seg_pred_cols.scores_column_name is not None:
1116
+ pred_inst_seg_labels.polygons_scores_column_name = (
1117
+ inst_seg_pred_cols.scores_column_name
912
1118
  )
913
1119
  if (
914
- schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name
1120
+ inst_seg_pred_cols.bounding_boxes_coordinates_column_name
915
1121
  is not None
916
1122
  ):
917
- s.arrow_schema.prediction_instance_segmentation_label_column_names.bboxes_coordinates_column_name = ( # noqa: E501
918
- schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name
1123
+ pred_inst_seg_labels.bboxes_coordinates_column_name = (
1124
+ inst_seg_pred_cols.bounding_boxes_coordinates_column_name
919
1125
  )
920
1126
 
921
1127
  if schema.prediction_score_column_name is not None:
@@ -1038,50 +1244,61 @@ def _get_pb_schema(
1038
1244
 
1039
1245
  if model_type == ModelTypes.OBJECT_DETECTION:
1040
1246
  if schema.object_detection_actual_column_names is not None:
1041
- s.arrow_schema.actual_object_detection_label_column_names.bboxes_coordinates_column_name = ( # noqa: E501
1042
- schema.object_detection_actual_column_names.bounding_boxes_coordinates_column_name
1247
+ obj_det_actual = schema.object_detection_actual_column_names
1248
+ actual_labels = (
1249
+ s.arrow_schema.actual_object_detection_label_column_names
1043
1250
  )
1044
- s.arrow_schema.actual_object_detection_label_column_names.bboxes_categories_column_name = ( # noqa: E501
1045
- schema.object_detection_actual_column_names.categories_column_name
1251
+ actual_labels.bboxes_coordinates_column_name = (
1252
+ obj_det_actual.bounding_boxes_coordinates_column_name
1046
1253
  )
1047
- if (
1048
- schema.object_detection_actual_column_names.scores_column_name
1049
- is not None
1050
- ):
1051
- s.arrow_schema.actual_object_detection_label_column_names.bboxes_scores_column_name = ( # noqa: E501
1052
- schema.object_detection_actual_column_names.scores_column_name
1254
+ actual_labels.bboxes_categories_column_name = (
1255
+ obj_det_actual.categories_column_name
1256
+ )
1257
+ if obj_det_actual.scores_column_name is not None:
1258
+ actual_labels.bboxes_scores_column_name = (
1259
+ obj_det_actual.scores_column_name
1053
1260
  )
1054
1261
 
1055
1262
  if schema.semantic_segmentation_actual_column_names is not None:
1056
- s.arrow_schema.actual_semantic_segmentation_label_column_names.polygons_coordinates_column_name = ( # noqa: E501
1057
- schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
1263
+ sem_seg_actual = schema.semantic_segmentation_actual_column_names
1264
+ sem_seg_labels = (
1265
+ s.arrow_schema.actual_semantic_segmentation_label_column_names
1266
+ )
1267
+ sem_seg_labels.polygons_coordinates_column_name = (
1268
+ sem_seg_actual.polygon_coordinates_column_name
1058
1269
  )
1059
- s.arrow_schema.actual_semantic_segmentation_label_column_names.polygons_categories_column_name = ( # noqa: E501
1060
- schema.semantic_segmentation_actual_column_names.categories_column_name
1270
+ sem_seg_labels.polygons_categories_column_name = (
1271
+ sem_seg_actual.categories_column_name
1061
1272
  )
1062
1273
 
1063
1274
  if schema.instance_segmentation_actual_column_names is not None:
1064
- s.arrow_schema.actual_instance_segmentation_label_column_names.polygons_coordinates_column_name = ( # noqa: E501
1065
- schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
1275
+ inst_seg_actual = schema.instance_segmentation_actual_column_names
1276
+ inst_seg_labels = (
1277
+ s.arrow_schema.actual_instance_segmentation_label_column_names
1066
1278
  )
1067
- s.arrow_schema.actual_instance_segmentation_label_column_names.polygons_categories_column_name = ( # noqa: E501
1068
- schema.instance_segmentation_actual_column_names.categories_column_name
1279
+ inst_seg_labels.polygons_coordinates_column_name = (
1280
+ inst_seg_actual.polygon_coordinates_column_name
1281
+ )
1282
+ inst_seg_labels.polygons_categories_column_name = (
1283
+ inst_seg_actual.categories_column_name
1069
1284
  )
1070
1285
  if (
1071
- schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name
1286
+ inst_seg_actual.bounding_boxes_coordinates_column_name
1072
1287
  is not None
1073
1288
  ):
1074
- s.arrow_schema.actual_instance_segmentation_label_column_names.bboxes_coordinates_column_name = ( # noqa: E501
1075
- schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name
1289
+ inst_seg_labels.bboxes_coordinates_column_name = (
1290
+ inst_seg_actual.bounding_boxes_coordinates_column_name
1076
1291
  )
1077
1292
 
1078
1293
  if model_type == ModelTypes.GENERATIVE_LLM:
1079
1294
  if schema.prompt_template_column_names is not None:
1080
- s.arrow_schema.prompt_template_column_names.template_column_name = (
1081
- schema.prompt_template_column_names.template_column_name
1295
+ prompt_template_names = schema.prompt_template_column_names
1296
+ arrow_prompt_names = s.arrow_schema.prompt_template_column_names
1297
+ arrow_prompt_names.template_column_name = (
1298
+ prompt_template_names.template_column_name
1082
1299
  )
1083
- s.arrow_schema.prompt_template_column_names.template_version_column_name = ( # noqa: E501
1084
- schema.prompt_template_column_names.template_version_column_name
1300
+ arrow_prompt_names.template_version_column_name = (
1301
+ prompt_template_names.template_version_column_name
1085
1302
  )
1086
1303
  if schema.llm_config_column_names is not None:
1087
1304
  s.arrow_schema.llm_config_column_names.model_column_name = (
@@ -1114,6 +1331,7 @@ def _get_pb_schema_corpus(
1114
1331
  schema: CorpusSchema,
1115
1332
  model_id: str,
1116
1333
  ) -> pb2.Schema:
1334
+ """Construct a protocol buffer Schema from CorpusSchema for document corpus logging."""
1117
1335
  s = pb2.Schema()
1118
1336
  s.constants.model_id = model_id
1119
1337
  s.constants.environment = pb2.Schema.Environment.CORPUS
@@ -1127,11 +1345,12 @@ def _get_pb_schema_corpus(
1127
1345
  schema.document_version_column_name
1128
1346
  )
1129
1347
  if schema.document_text_embedding_column_names is not None:
1130
- s.arrow_schema.document_column_names.text_column_name.vector_column_name = schema.document_text_embedding_column_names.vector_column_name # noqa: E501
1131
- s.arrow_schema.document_column_names.text_column_name.data_column_name = schema.document_text_embedding_column_names.data_column_name # noqa: E501
1132
- if (
1133
- schema.document_text_embedding_column_names.link_to_data_column_name
1134
- is not None
1135
- ):
1136
- s.arrow_schema.document_column_names.text_column_name.link_to_data_column_name = schema.document_text_embedding_column_names.link_to_data_column_name # noqa: E501
1348
+ doc_text_emb_cols = schema.document_text_embedding_column_names
1349
+ doc_text_col = s.arrow_schema.document_column_names.text_column_name
1350
+ doc_text_col.vector_column_name = doc_text_emb_cols.vector_column_name
1351
+ doc_text_col.data_column_name = doc_text_emb_cols.data_column_name
1352
+ if doc_text_emb_cols.link_to_data_column_name is not None:
1353
+ doc_text_col.link_to_data_column_name = (
1354
+ doc_text_emb_cols.link_to_data_column_name
1355
+ )
1137
1356
  return s