arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. arize/__init__.py +9 -2
  2. arize/_client_factory.py +50 -0
  3. arize/_exporter/client.py +18 -17
  4. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  5. arize/_exporter/validation.py +1 -1
  6. arize/_flight/client.py +37 -17
  7. arize/_generated/api_client/api/datasets_api.py +6 -6
  8. arize/_generated/api_client/api/experiments_api.py +6 -6
  9. arize/_generated/api_client/api/projects_api.py +3 -3
  10. arize/_lazy.py +61 -10
  11. arize/client.py +66 -50
  12. arize/config.py +175 -48
  13. arize/constants/config.py +1 -0
  14. arize/constants/ml.py +9 -16
  15. arize/constants/spans.py +5 -10
  16. arize/datasets/client.py +45 -28
  17. arize/datasets/errors.py +1 -1
  18. arize/datasets/validation.py +2 -2
  19. arize/embeddings/auto_generator.py +16 -9
  20. arize/embeddings/base_generators.py +15 -9
  21. arize/embeddings/cv_generators.py +2 -2
  22. arize/embeddings/errors.py +2 -2
  23. arize/embeddings/nlp_generators.py +8 -8
  24. arize/embeddings/tabular_generators.py +6 -6
  25. arize/exceptions/base.py +0 -52
  26. arize/exceptions/config.py +22 -0
  27. arize/exceptions/parameters.py +1 -330
  28. arize/exceptions/values.py +8 -5
  29. arize/experiments/__init__.py +4 -0
  30. arize/experiments/client.py +31 -18
  31. arize/experiments/evaluators/base.py +12 -9
  32. arize/experiments/evaluators/executors.py +16 -7
  33. arize/experiments/evaluators/rate_limiters.py +3 -1
  34. arize/experiments/evaluators/types.py +9 -7
  35. arize/experiments/evaluators/utils.py +7 -5
  36. arize/experiments/functions.py +128 -58
  37. arize/experiments/tracing.py +4 -1
  38. arize/experiments/types.py +34 -31
  39. arize/logging.py +54 -33
  40. arize/ml/batch_validation/errors.py +10 -1004
  41. arize/ml/batch_validation/validator.py +351 -291
  42. arize/ml/bounded_executor.py +25 -6
  43. arize/ml/casting.py +51 -33
  44. arize/ml/client.py +43 -35
  45. arize/ml/proto.py +21 -22
  46. arize/ml/stream_validation.py +64 -27
  47. arize/ml/surrogate_explainer/mimic.py +18 -10
  48. arize/ml/types.py +27 -67
  49. arize/pre_releases.py +10 -6
  50. arize/projects/client.py +9 -4
  51. arize/py.typed +0 -0
  52. arize/regions.py +11 -11
  53. arize/spans/client.py +125 -31
  54. arize/spans/columns.py +32 -36
  55. arize/spans/conversion.py +12 -11
  56. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  57. arize/spans/validation/annotations/value_validation.py +11 -14
  58. arize/spans/validation/common/argument_validation.py +3 -3
  59. arize/spans/validation/common/dataframe_form_validation.py +7 -7
  60. arize/spans/validation/common/value_validation.py +11 -14
  61. arize/spans/validation/evals/dataframe_form_validation.py +4 -4
  62. arize/spans/validation/evals/evals_validation.py +6 -6
  63. arize/spans/validation/evals/value_validation.py +1 -1
  64. arize/spans/validation/metadata/argument_validation.py +1 -1
  65. arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
  66. arize/spans/validation/metadata/value_validation.py +23 -1
  67. arize/spans/validation/spans/dataframe_form_validation.py +2 -2
  68. arize/spans/validation/spans/spans_validation.py +6 -6
  69. arize/utils/arrow.py +38 -2
  70. arize/utils/cache.py +2 -2
  71. arize/utils/dataframe.py +4 -4
  72. arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
  73. arize/utils/openinference_conversion.py +10 -10
  74. arize/utils/proto.py +0 -1
  75. arize/utils/types.py +6 -6
  76. arize/version.py +1 -1
  77. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
  78. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
  79. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
  80. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
  81. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/ml/bounded_executor.py CHANGED
@@ -24,12 +24,26 @@ class BoundedExecutor:
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self.semaphore = BoundedSemaphore(bound + max_workers)
 
-    """See concurrent.futures.Executor#submit"""
-
     def submit(
         self, fn: Callable[..., object], *args: object, **kwargs: object
     ) -> object:
-        """Submit a callable to be executed with bounded concurrency."""
+        """Submit a callable to be executed with bounded concurrency.
+
+        This method blocks if the work queue is full (at the bound limit) until
+        space becomes available. Compatible with concurrent.futures.Executor.submit().
+
+        Args:
+            fn: The callable to execute.
+            *args: Positional arguments to pass to the callable.
+            **kwargs: Keyword arguments to pass to the callable.
+
+        Returns:
+            concurrent.futures.Future: A Future representing the pending execution.
+
+        Raises:
+            Exception: Any exception raised during submission is re-raised after
+                releasing the semaphore.
+        """
         self.semaphore.acquire()
         try:
             future = self.executor.submit(fn, *args, **kwargs)
@@ -40,8 +54,13 @@ class BoundedExecutor:
         future.add_done_callback(lambda _: self.semaphore.release())
         return future
 
-    """See concurrent.futures.Executor#shutdown"""
-
     def shutdown(self, wait: bool = True) -> None:
-        """Shutdown the executor, optionally waiting for pending tasks to complete."""
+        """Shutdown the executor, optionally waiting for pending tasks to complete.
+
+        Compatible with concurrent.futures.Executor.shutdown().
+
+        Args:
+            wait: If True, blocks until all pending tasks complete. If False,
+                returns immediately without waiting. Defaults to True.
+        """
         self.executor.shutdown(wait)
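
For context, BoundedExecutor pairs a ThreadPoolExecutor with a BoundedSemaphore so that submit() blocks once bound + max_workers tasks are in flight, which is exactly what the new docstring describes. A minimal, self-contained sketch of that pattern follows; the constructor signature is inferred from the context lines above and may not match the real class exactly:

from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import Callable


class BoundedExecutorSketch:
    """Illustrative stand-in for arize.ml.bounded_executor.BoundedExecutor."""

    def __init__(self, bound: int, max_workers: int) -> None:
        # Permit `bound` queued tasks on top of the `max_workers` running ones.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = BoundedSemaphore(bound + max_workers)

    def submit(self, fn: Callable[..., object], *args: object, **kwargs: object) -> Future:
        self.semaphore.acquire()  # blocks when the bound is reached
        try:
            future = self.executor.submit(fn, *args, **kwargs)
        except Exception:
            self.semaphore.release()  # re-raise after releasing, per the docstring
            raise
        future.add_done_callback(lambda _: self.semaphore.release())
        return future

    def shutdown(self, wait: bool = True) -> None:
        self.executor.shutdown(wait)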
arize/ml/casting.py CHANGED
@@ -1,6 +1,5 @@
 """Type casting utilities for ML model data conversion."""
 
-# type: ignore[pb2]
 from __future__ import annotations
 
 import math
@@ -14,8 +13,8 @@ from arize.ml.types import (
     Schema,
     TypedColumns,
     TypedValue,
+    _normalize_column_names,
 )
-from arize.utils.types import is_list_of
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -25,7 +24,11 @@ class CastingError(Exception):
     """Raised when type casting fails for a value."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the casting failure.
+        """
         return self.error_message()
 
     def __init__(self, error_msg: str, typed_value: TypedValue) -> None:
@@ -39,7 +42,11 @@ class CastingError(Exception):
         self.typed_value = typed_value
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Detailed error message including the value, its type, target type, and failure reason.
+        """
         return (
             f"Failed to cast value {self.typed_value.value} of type {type(self.typed_value.value)} "
             f"to type {self.typed_value.type}. "
@@ -51,14 +58,18 @@ class ColumnCastingError(Exception):
     """Raised when type casting fails for a column."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the column casting failure.
+        """
         return self.error_message()
 
     def __init__(
         self,
         error_msg: str,
-        attempted_columns: str,
-        attempted_type: TypedColumns,
+        attempted_columns: list[str],
+        attempted_type: str,
     ) -> None:
         """Initialize the exception with column casting context.
 
@@ -72,7 +83,11 @@ class ColumnCastingError(Exception):
         self.attempted_casting_type = attempted_type
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Detailed error message including the target type, affected columns, and failure reason.
+        """
         return (
             f"Failed to cast to type {self.attempted_casting_type} "
             f"for columns: {log_a_list(self.attempted_casting_columns, 'and')}. "
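
Call-site implication of the swapped parameter types above, as a hedged sketch: attempted_columns is now a list of column names and attempted_type the target dtype string, matching how error_message() formats them. The import path follows this diff's module layout, and the exact output of log_a_list is assumed:

from arize.ml.casting import ColumnCastingError

err = ColumnCastingError(
    error_msg="invalid literal for int() with base 10: 'abc'",
    attempted_columns=["age", "num_clicks"],
    attempted_type="int64",
)
# __str__ delegates to error_message(), yielding something like:
# "Failed to cast to type int64 for columns: age and num_clicks. ..."
print(err)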
@@ -84,7 +99,11 @@ class InvalidTypedColumnsError(Exception):
     """Raised when typed columns are invalid or incorrectly specified."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the invalid typed columns.
+        """
         return self.error_message()
 
     def __init__(self, field_name: str, reason: str) -> None:
@@ -98,7 +117,11 @@ class InvalidTypedColumnsError(Exception):
         self.reason = reason
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Error message describing which field has invalid typed columns and why.
+        """
         return f"The {self.field_name} TypedColumns object {self.reason}."
 
 
@@ -106,7 +129,11 @@ class InvalidSchemaFieldTypeError(Exception):
     """Raised when schema field has invalid or unexpected type."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the invalid schema field type.
+        """
         return self.error_message()
 
     def __init__(self, msg: str) -> None:
@@ -118,7 +145,11 @@ class InvalidSchemaFieldTypeError(Exception):
         self.msg = msg
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: The error message describing the schema field type issue.
+        """
         return self.msg
 
 
@@ -132,12 +163,12 @@ def cast_typed_columns(
     a column across many SDK uploads.
 
     Args:
-        dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
+        dataframe (:class:`pandas.DataFrame`): A deepcopy of the user's dataframe.
         schema (Schema): The schema, which may include feature and tag column names
             in a TypedColumns object or a List[string].
 
     Returns:
-        tuple[pd.DataFrame, Schema]: A tuple containing:
+        tuple[:class:`pandas.DataFrame`, Schema]: A tuple containing:
             - dataframe: The dataframe, with columns cast to the specified types.
             - schema: A new Schema object, with feature and tag column names converted
                 to the List[string] format expected in downstream validation.
@@ -290,12 +321,12 @@ def _cast_columns(
     (feature_column_names or tag_column_names)
 
     Args:
-        dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
+        dataframe (:class:`pandas.DataFrame`): A deepcopy of the user's dataframe.
         columns (TypedColumns): The TypedColumns object, which specifies the columns
             to cast (and/or to not cast) and their target types.
 
     Returns:
-        pd.DataFrame: The dataframe with columns cast to the specified types.
+        :class:`pandas.DataFrame`: The dataframe with columns cast to the specified types.
 
     Raises:
         ColumnCastingError: If casting fails.
@@ -350,12 +381,12 @@ def _cast_df(
     """Cast columns in a dataframe to the specified type.
 
    Args:
-        df (pd.DataFrame): A deepcopy of the user's dataframe.
+        df (:class:`pandas.DataFrame`): A deepcopy of the user's dataframe.
         cols (list[str]): The list of column names to cast.
         target_type_str (str): The target type to cast to.
 
     Returns:
-        pd.DataFrame: The dataframe with columns cast to the specified types.
+        :class:`pandas.DataFrame`: The dataframe with columns cast to the specified types.
 
     Raises:
         Exception: If casting fails. Common exceptions raised by astype() are
@@ -381,23 +412,10 @@ def _convert_schema_field_types(
         Schema: A Schema, with feature and tag column names converted to the
             List[string] format expected in downstream validation.
     """
-    feature_column_names_list = (
+    feature_column_names_list = _normalize_column_names(
         schema.feature_column_names
-        if is_list_of(schema.feature_column_names, str)
-        else (
-            schema.feature_column_names.get_all_column_names()
-            if schema.feature_column_names
-            else []
-        )
-    )
-
-    tag_column_names_list = (
-        schema.tag_column_names
-        if is_list_of(schema.tag_column_names, str)
-        else schema.tag_column_names.get_all_column_names()
-        if schema.tag_column_names
-        else []
     )
+    tag_column_names_list = _normalize_column_names(schema.tag_column_names)
 
     schema_dict = {
         "feature_column_names": feature_column_names_list,
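
The duplicated conditionals removed above now live in a single _normalize_column_names helper imported from arize.ml.types. Its body is not shown in this diff; the following is a hypothetical sketch inferred from the branches it replaces:

# Hypothetical reconstruction; the real helper in arize.ml.types may differ.
def _normalize_column_names(names: object) -> list[str]:
    """Return feature/tag column names as a plain list[str]."""
    if names is None:
        return []
    if isinstance(names, list) and all(isinstance(n, str) for n in names):
        return names  # already the list[str] form
    # TypedColumns case, mirroring the removed get_all_column_names() branch
    return names.get_all_column_names()  # type: ignore[attr-defined]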
arize/ml/client.py CHANGED
@@ -1,12 +1,11 @@
 """Client implementation for managing ML models in the Arize platform."""
 
-# type: ignore[pb2]
 from __future__ import annotations
 
 import copy
 import logging
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 from arize._generated.protocol.rec import public_pb2 as pb2
 from arize._lazy import require
@@ -377,8 +376,12 @@ class MLModelsClient:
 
         if embedding_features or prompt or response:
             # NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
-            combined_embedding_features = (
-                embedding_features.copy() if embedding_features else {}
+            combined_embedding_features: dict[str, str | Embedding] = (
+                cast(
+                    "dict[str, str | Embedding]", embedding_features.copy()
+                )
+                if embedding_features
+                else {}
             )
             # Map prompt as embedding features for generative models
             if prompt is not None:
@@ -395,7 +398,7 @@ class MLModelsClient:
                 p.MergeFrom(embedding_feats)
 
         if tags or llm_run_metadata:
-            joined_tags = copy.deepcopy(tags)
+            joined_tags = copy.deepcopy(tags) if tags is not None else {}
             if llm_run_metadata:
                 if llm_run_metadata.total_token_count is not None:
                     joined_tags[
@@ -522,7 +525,7 @@ class MLModelsClient:
             record=rec,
             headers=headers,
             timeout=timeout,
-            indexes=None,
+            indexes=None,  # type: ignore[arg-type]
         )
 
     def log(
@@ -542,7 +545,7 @@
         timeout: float | None = None,
         tmp_dir: str = "",
     ) -> requests.Response:
-        """Log a batch of model predictions and actuals to Arize from a pandas DataFrame.
+        """Log a batch of model predictions and actuals to Arize from a :class:`pandas.DataFrame`.
 
         This method uploads multiple records to Arize in a single batch operation using
         Apache Arrow format for efficient transfer. The dataframe structure is defined
@@ -554,8 +557,8 @@
             model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
                 RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
                 spans module instead.
-            dataframe: Pandas DataFrame containing the data to upload. Columns should
-                correspond to the schema field mappings.
+            dataframe (:class:`pandas.DataFrame`): Pandas DataFrame containing the data to
+                upload. Columns should correspond to the schema field mappings.
             schema: Schema object (Schema or CorpusSchema) that defines the mapping between
                 dataframe columns and Arize data fields (e.g., prediction_label_column_name,
                 feature_column_names, etc.).
@@ -668,8 +671,10 @@ class MLModelsClient:
         dataframe = remove_extraneous_columns(df=dataframe, schema=schema)
 
         # always validate pd.Category is not present, if yes, convert to string
+        # Type ignore: pandas.api.types.is_categorical_dtype exists but stubs may be incomplete
         has_cat_col = any(
-            ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
+            ptypes.is_categorical_dtype(x)  # type: ignore[attr-defined]
+            for x in dataframe.dtypes
         )
         if has_cat_col:
             cat_cols = [
@@ -691,14 +696,15 @@ class MLModelsClient:
             from arize.ml.surrogate_explainer.mimic import Mimic
 
             logger.debug("Running surrogate_explainability.")
-            if schema.shap_values_column_names:
+            # Type ignore: schema typed as BaseSchema but runtime is Schema with these attrs
+            if schema.shap_values_column_names:  # type: ignore[attr-defined]
                 logger.info(
                     "surrogate_explainability=True has no effect "
                     "because shap_values_column_names is already specified in schema."
                 )
-            elif schema.feature_column_names is None or (
-                hasattr(schema.feature_column_names, "__len__")
-                and len(schema.feature_column_names) == 0
+            elif schema.feature_column_names is None or (  # type: ignore[attr-defined]
+                hasattr(schema.feature_column_names, "__len__")  # type: ignore[attr-defined]
+                and len(schema.feature_column_names) == 0  # type: ignore[attr-defined]
             ):
                 logger.info(
                     "surrogate_explainability=True has no effect "
@@ -706,7 +712,9 @@
                 )
             else:
                 dataframe, schema = Mimic.augment(
-                    df=dataframe, schema=schema, model_type=model_type
+                    df=dataframe,
+                    schema=schema,  # type: ignore[arg-type]
+                    model_type=model_type,
                 )
 
         # Convert to Arrow table
@@ -733,8 +741,8 @@
             pyarrow_schema=pa_table.schema,
         )
         if errors:
-            for e in errors:
-                logger.error(e)
+            for error in errors:
+                logger.error(error)
             raise ValidationFailure(errors)
         if validate:
             logger.debug("Performing values validation.")
@@ -745,8 +753,8 @@
                 model_type=model_type,
             )
             if errors:
-                for e in errors:
-                    logger.error(e)
+                for error in errors:
+                    logger.error(error)
                 raise ValidationFailure(errors)
 
         if isinstance(schema, Schema) and not schema.has_prediction_columns():
@@ -759,12 +767,12 @@
 
         if environment == Environments.CORPUS:
             proto_schema = _get_pb_schema_corpus(
-                schema=schema,
+                schema=schema,  # type: ignore[arg-type]
                 model_id=model_name,
             )
         else:
             proto_schema = _get_pb_schema(
-                schema=schema,
+                schema=schema,  # type: ignore[arg-type]
                 model_id=model_name,
                 model_version=model_version,
                 model_type=model_type,
@@ -811,10 +819,10 @@
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
     ) -> pd.DataFrame:
-        """Export model data from Arize to a pandas DataFrame.
+        """Export model data from Arize to a :class:`pandas.DataFrame`.
 
         Retrieves prediction and optional actual data for a model within a specified time
-        range and returns it as a pandas DataFrame for analysis.
+        range and returns it as a :class:`pandas.DataFrame` for analysis.
 
         Args:
             space_id: The space ID where the model resides.
@@ -835,8 +843,9 @@
             stream_chunk_size: Optional chunk size for streaming large result sets.
 
         Returns:
-            A pandas DataFrame containing the exported data with columns for predictions,
-            actuals (if requested), features, tags, timestamps, and other model metadata.
+            :class:`pandas.DataFrame`: A pandas DataFrame containing the exported data
+            with columns for predictions, actuals (if requested), features, tags,
+            timestamps, and other model metadata.
 
         Raises:
             RuntimeError: If the Flight client request fails or returns no response.
@@ -879,6 +888,7 @@
     def export_to_parquet(
         self,
         *,
+        path: str,
         space_id: str,
         model_name: str,
         environment: Environments,
@@ -891,13 +901,14 @@
         columns: list | None = None,
         similarity_search_params: SimilaritySearchParams | None = None,
         stream_chunk_size: int | None = None,
-    ) -> pd.DataFrame:
-        """Export model data from Arize to a Parquet file and return as DataFrame.
+    ) -> None:
+        """Export model data from Arize to a Parquet file.
 
         Retrieves prediction and optional actual data for a model within a specified time
-        range, saves it as a Parquet file, and returns it as a pandas DataFrame.
+        range and writes it directly to a Parquet file at the specified path.
 
         Args:
+            path: The file path where the Parquet file will be written.
             space_id: The space ID where the model resides.
             model_name: The name of the model to export data from.
             environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
@@ -915,16 +926,12 @@
                 filtering.
             stream_chunk_size: Optional chunk size for streaming large result sets.
 
-        Returns:
-            A pandas DataFrame containing the exported data. The data is also saved to a
-            Parquet file by the underlying export client.
-
         Raises:
             RuntimeError: If the Flight client request fails or returns no response.
 
         Notes:
             - Uses Apache Arrow Flight for efficient data transfer
-            - The Parquet file location is managed by the ArizeExportClient
+            - Data is written directly to the specified path as a Parquet file
            - Large exports may benefit from specifying stream_chunk_size
        """
        require(_BATCH_EXTRA, _BATCH_DEPS)
@@ -942,7 +949,8 @@
        exporter = ArizeExportClient(
            flight_client=flight_client,
        )
-        return exporter.export_to_parquet(
+        exporter.export_to_parquet(
+            path=path,
            space_id=space_id,
            model_id=model_name,
            environment=environment,
@@ -981,7 +989,7 @@
        headers: dict[str, str],
        timeout: float | None,
        indexes: tuple,
-    ) -> object:
+    ) -> cf.Future[Any]:
        """Post a record to Arize via async HTTP request with protobuf JSON serialization."""
        from google.protobuf.json_format import MessageToDict
 
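The export_to_parquet edit is the one breaking API change in this file: path is now a required keyword argument and the method returns None instead of a DataFrame. A hedged migration sketch follows; the Environments import path and client wiring are assumptions, and only path, space_id, model_name, and environment are taken from the hunks above:

import pandas as pd

from arize.ml.types import Environments  # assumed import path


def export_model_data(ml_client, space_id: str, model_name: str) -> pd.DataFrame:
    # 8.0.0b1: a DataFrame came back directly from export_to_parquet().
    # 8.0.0b4: write to an explicit path, then read the file back if needed.
    ml_client.export_to_parquet(
        path="exports/model_data.parquet",  # new required keyword argument
        space_id=space_id,
        model_name=model_name,
        environment=Environments.PRODUCTION,
        # time-range and filter arguments elided; see the full signature above
    )
    return pd.read_parquet("exports/model_data.parquet")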
arize/ml/proto.py CHANGED
@@ -1,6 +1,5 @@
 """Protocol buffer utilities for ML model data serialization."""
 
-# type: ignore[pb2]
 from __future__ import annotations
 
 from google.protobuf.timestamp_pb2 import Timestamp
@@ -26,17 +25,19 @@ from arize.ml.types import (
 from arize.utils.types import is_list_of
 
 
-def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
+def get_pb_dictionary(d: object | None) -> dict[str, object]:
     """Convert a dictionary to protobuf format with string keys and pb2.Value values.
 
     Args:
-        d: Dictionary to convert, or None.
+        d: Dictionary to convert, or :obj:`None`.
 
     Returns:
-        Dictionary with string keys and protobuf Value objects, or empty dict if input is None.
+        Dictionary with string keys and protobuf Value objects, or empty dict if input is :obj:`None`.
     """
     if d is None:
         return {}
+    if not isinstance(d, dict):
+        return {}
     # Takes a dictionary and
     # - casts the keys as strings
     # - turns the values of the dictionary to our proto values pb2.Value()
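
The isinstance guard added above changes behavior for malformed input: with the widened d: object | None signature, a non-dict argument now short-circuits to an empty dict rather than (presumably) failing once the function iterates the dictionary. Illustrative only; the import path follows this diff's layout:

from arize.ml.proto import get_pb_dictionary

assert get_pb_dictionary(None) == {}
assert get_pb_dictionary(["not", "a", "dict"]) == {}  # guarded instead of raising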
@@ -48,7 +49,7 @@ def get_pb_dictionary(d: dict[object, object] | None) -> dict[str, object]:
     return converted_dict
 
 
-def get_pb_value(name: str | int | float, value: pb2.Value) -> pb2.Value:
+def get_pb_value(name: object, value: pb2.Value) -> pb2.Value:
     """Convert a Python value to a protobuf Value object.
 
     Args:
@@ -56,7 +57,7 @@ def get_pb_value(name: object, value: pb2.Value) -> pb2.Value:
         value: The value to convert to protobuf format.
 
     Returns:
-        A pb2.Value protobuf object, or None if value cannot be converted.
+        A pb2.Value protobuf object, or :obj:`None` if value cannot be converted.
 
     Raises:
         TypeError: If value type is not supported.
@@ -114,7 +115,8 @@ def get_pb_label(
     Raises:
         ValueError: If model_type is not supported.
     """
-    value = convert_element(value)
+    # convert_element preserves value type but returns object for type safety
+    value = convert_element(value)  # type: ignore[assignment]
     if model_type in NUMERIC_MODEL_TYPES:
         return _get_numeric_pb_label(prediction_or_actual, value)
     if (
@@ -129,7 +131,7 @@
     if model_type == ModelTypes.MULTI_CLASS:
         return _get_multi_class_pb_label(value)
     raise ValueError(
-        f"model_type must be one of: {[mt.prediction_or_actual for mt in ModelTypes]} "
+        f"model_type must be one of: {[mt.name for mt in ModelTypes]} "
        f"Got "
        f"{model_type} instead."
    )
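
The mt.name change here is a bug fix rather than a cosmetic one, assuming ModelTypes is a standard Enum (the for mt in ModelTypes iteration suggests it is): enum members expose .name, while .prediction_or_actual is presumably not an attribute of a member, so the old f-string would itself raise AttributeError while formatting this ValueError. A stand-in illustration with a hypothetical enum:

from enum import Enum


class ModelTypesSketch(Enum):  # hypothetical stand-in for arize's ModelTypes
    BINARY = 1
    REGRESSION = 2


print([mt.name for mt in ModelTypesSketch])  # ['BINARY', 'REGRESSION']
# [mt.prediction_or_actual for mt in ModelTypesSketch]  # would raise AttributeError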
@@ -139,10 +141,10 @@ def get_pb_timestamp(time_overwrite: int | None) -> object | None:
     """Convert a Unix timestamp to a protobuf Timestamp object.
 
     Args:
-        time_overwrite: Unix epoch time in seconds, or None.
+        time_overwrite: Unix epoch time in seconds, or :obj:`None`.
 
     Returns:
-        A protobuf Timestamp object, or None if input is None.
+        A protobuf Timestamp object, or :obj:`None` if input is :obj:`None`.
 
     Raises:
         TypeError: If time_overwrite is not an integer.
@@ -197,12 +199,12 @@ def get_pb_embedding(val: Embedding) -> pb2.Embedding:
 
 
 def _get_numeric_pb_label(
     prediction_or_actual: str,
-    value: int | float,
+    value: object,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     if not isinstance(value, (int, float)):
         raise TypeError(
             f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
-            + f"{[mt.prediction_or_actual for mt in NUMERIC_MODEL_TYPES]} models accept labels of "
+            + f"{[mt.name for mt in NUMERIC_MODEL_TYPES]} models accept labels of "
             f"type int or float"
         )
     if prediction_or_actual == "prediction":
@@ -214,7 +216,7 @@
 
 def _get_score_categorical_pb_label(
     prediction_or_actual: str,
-    value: bool | str | tuple[str, float],
+    value: object,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     sc = pb2.ScoreCategorical()
     if isinstance(value, bool):
@@ -229,7 +231,7 @@
         raise TypeError(
             f"Received {prediction_or_actual}_label = {value}, of type "
             f"{type(value)}[{type(value[0])}, None]. "
-            f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept "
+            f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
             "values of type str, bool, or Tuple[str, float]"
         )
     if not isinstance(value[0], (bool, str)) or not isinstance(
@@ -238,7 +240,7 @@
         raise TypeError(
             f"Received {prediction_or_actual}_label = {value}, of type "
             f"{type(value)}[{type(value[0])}, {type(value[1])}]. "
-            f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept "
+            f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept "
             "values of type str, bool, or Tuple[str or bool, float]"
         )
     if isinstance(value[0], bool):
@@ -249,7 +251,7 @@
     else:
         raise TypeError(
             f"Received {prediction_or_actual}_label = {value}, of type {type(value)}. "
-            + f"{[mt.prediction_or_actual for mt in CATEGORICAL_MODEL_TYPES]} models accept values "
+            + f"{[mt.name for mt in CATEGORICAL_MODEL_TYPES]} models accept values "
             f"of type str, bool, int, float or Tuple[str, float]"
         )
     if prediction_or_actual == "prediction":
@@ -261,10 +263,7 @@
 
 def _get_cv_pb_label(
     prediction_or_actual: str,
-    value: ObjectDetectionLabel
-    | SemanticSegmentationLabel
-    | InstanceSegmentationPredictionLabel
-    | InstanceSegmentationActualLabel,
+    value: object,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     if isinstance(value, ObjectDetectionLabel):
         return _get_object_detection_pb_label(prediction_or_actual, value)
@@ -429,7 +428,7 @@ def _get_instance_segmentation_actual_pb_label(
 
 
 def _get_ranking_pb_label(
-    value: RankingPredictionLabel | RankingActualLabel,
+    value: object,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     if not isinstance(value, (RankingPredictionLabel, RankingActualLabel)):
         raise InvalidValueType(
@@ -460,7 +459,7 @@ def _get_ranking_pb_label(
 
 
 def _get_multi_class_pb_label(
-    value: MultiClassPredictionLabel | MultiClassActualLabel,
+    value: object,
 ) -> pb2.PredictionLabel | pb2.ActualLabel:
     if not isinstance(
         value, (MultiClassPredictionLabel, MultiClassActualLabel)