arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,20 +1,16 @@
1
- import json
1
+ """Common type definitions and data models used across the ML Client."""
2
+
2
3
  import logging
3
4
  import math
5
+ from collections.abc import Iterator
4
6
  from dataclasses import asdict, dataclass, replace
5
7
  from datetime import datetime
6
8
  from decimal import Decimal
7
9
  from enum import Enum, unique
8
10
  from itertools import chain
9
11
  from typing import (
10
- Dict,
11
- Iterable,
12
- List,
13
12
  NamedTuple,
14
- Sequence,
15
- Set,
16
- Tuple,
17
- TypeVar,
13
+ Self,
18
14
  )
19
15
 
20
16
  import numpy as np
@@ -42,12 +38,15 @@ from arize.exceptions.parameters import InvalidValueType
42
38
  # )
43
39
  # from arize.utils.errors import InvalidValueType
44
40
  from arize.logging import get_truncation_warning_message
41
+ from arize.utils.types import is_dict_of, is_iterable_of, is_list_of
45
42
 
46
43
  logger = logging.getLogger(__name__)
47
44
 
48
45
 
49
46
  @unique
50
47
  class ModelTypes(Enum):
48
+ """Enum representing supported model types in Arize."""
49
+
51
50
  NUMERIC = 1
52
51
  SCORE_CATEGORICAL = 2
53
52
  RANKING = 3
@@ -58,7 +57,8 @@ class ModelTypes(Enum):
58
57
  MULTI_CLASS = 8
59
58
 
60
59
  @classmethod
61
- def list_types(cls):
60
+ def list_types(cls) -> list[str]:
61
+ """Return a list of all type names in this enum."""
62
62
  return [t.name for t in cls]
63
63
 
64
64
 
@@ -70,7 +70,10 @@ CATEGORICAL_MODEL_TYPES = [
70
70
 
71
71
 
72
72
  class DocEnum(Enum):
73
- def __new__(cls, value, doc=None):
73
+ """Enum subclass supporting inline documentation for enum members."""
74
+
75
+ def __new__(cls, value: object, doc: str | None = None) -> Self:
76
+ """Create a new enum instance with optional documentation."""
74
77
  self = object.__new__(
75
78
  cls
76
79
  ) # calling super().__new__(value) here would fail
@@ -80,13 +83,13 @@ class DocEnum(Enum):
80
83
  return self
81
84
 
82
85
  def __repr__(self) -> str:
86
+ """Return a string representation including documentation."""
83
87
  return f"{self.name} metrics include: {self.__doc__}"
84
88
 
85
89
 
86
90
  @unique
87
91
  class Metrics(DocEnum):
88
- """
89
- Metric groupings, used for validation of schema columns in log() call.
92
+ """Metric groupings, used for validation of schema columns in log() call.
90
93
 
91
94
  See docstring descriptions of the Enum with __doc__ or __repr__(), e.g.:
92
95
  Metrics.RANKING.__doc__
@@ -105,6 +108,8 @@ class Metrics(DocEnum):
105
108
 
106
109
  @unique
107
110
  class Environments(Enum):
111
+ """Enum representing deployment environments for models."""
112
+
108
113
  TRAINING = 1
109
114
  VALIDATION = 2
110
115
  PRODUCTION = 3
@@ -114,11 +119,18 @@ class Environments(Enum):
114
119
 
115
120
  @dataclass
116
121
  class EmbeddingColumnNames:
122
+ """Column names for embedding feature data."""
123
+
117
124
  vector_column_name: str = ""
118
125
  data_column_name: str | None = None
119
126
  link_to_data_column_name: str | None = None
120
127
 
121
- def __post_init__(self):
128
+ def __post_init__(self) -> None:
129
+ """Validate that vector column name is specified.
130
+
131
+ Raises:
132
+ ValueError: If vector_column_name is empty.
133
+ """
122
134
  if not self.vector_column_name:
123
135
  raise ValueError(
124
136
  "embedding_features require a vector to be specified. You can "
@@ -126,7 +138,8 @@ class EmbeddingColumnNames:
126
138
  "(from arize.pandas.embeddings) if you do not have them"
127
139
  )
128
140
 
129
- def __iter__(self):
141
+ def __iter__(self) -> Iterator[str | None]:
142
+ """Iterate over the embedding column names."""
130
143
  return iter(
131
144
  (
132
145
  self.vector_column_name,
@@ -137,24 +150,23 @@ class EmbeddingColumnNames:
137
150
 
138
151
 
139
152
  class Embedding(NamedTuple):
140
- vector: List[float]
141
- data: str | List[str] | None = None
153
+ """Container for embedding vector data with optional raw data and links."""
154
+
155
+ vector: list[float]
156
+ data: str | list[str] | None = None
142
157
  link_to_data: str | None = None
143
158
 
144
159
  def validate(self, emb_name: str | int | float) -> None:
145
- """
146
- Validates that the embedding object passed is of the correct format.
147
- That is, validations must be passed for vector, data & link_to_data.
160
+ """Validates that the embedding object passed is of the correct format.
148
161
 
149
- Arguments:
150
- ---------
151
- emb_name (str, int, float): Name of the embedding feature the
152
- vector belongs to
162
+ Ensures validations are passed for vector, data, and link_to_data fields.
153
163
 
154
- Raises:
155
- ------
156
- TypeError: If the embedding fields are of the wrong type
164
+ Args:
165
+ emb_name: Name of the embedding feature the
166
+ vector belongs to.
157
167
 
168
+ Raises:
169
+ TypeError: If the embedding fields are of the wrong type.
158
170
  """
159
171
  if self.vector is not None:
160
172
  self._validate_embedding_vector(emb_name)
@@ -167,29 +179,23 @@ class Embedding(NamedTuple):
167
179
  if self.link_to_data is not None:
168
180
  self._validate_embedding_link_to_data(emb_name, self.link_to_data)
169
181
 
170
- return None
182
+ return
171
183
 
172
184
  def _validate_embedding_vector(
173
185
  self,
174
186
  emb_name: str | int | float,
175
187
  ) -> None:
176
- """
177
- Validates that the embedding vector passed is of the correct format.
178
- That is:
179
- 1. Type must be list or convertible to list (like numpy arrays,
180
- pandas Series)
181
- 2. List must not be empty
182
- 3. Elements in list must be floats
183
-
184
- Arguments:
185
- ---------
186
- emb_name (str, int, float): Name of the embedding feature the vector
187
- belongs to
188
+ """Validates that the embedding vector passed is of the correct format.
188
189
 
189
- Raises:
190
- ------
191
- TypeError: If the embedding does not satisfy requirements above
190
+ Requirements: 1) Type must be list or convertible to list (like numpy arrays,
191
+ pandas Series), 2) List must not be empty, 3) Elements in list must be floats.
192
192
 
193
+ Args:
194
+ emb_name: Name of the embedding feature the vector
195
+ belongs to.
196
+
197
+ Raises:
198
+ TypeError: If the embedding does not satisfy requirements above.
193
199
  """
194
200
  if not Embedding._is_valid_iterable(self.vector):
195
201
  raise TypeError(
@@ -209,21 +215,19 @@ class Embedding(NamedTuple):
209
215
 
210
216
  @staticmethod
211
217
  def _validate_embedding_data(
212
- emb_name: str | int | float, data: str | List[str]
218
+ emb_name: str | int | float, data: str | list[str]
213
219
  ) -> None:
214
- """
215
- Validates that the embedding raw data field is of the correct format. That is:
216
- 1. Must be string or list of strings (NLP case)
220
+ """Validates that the embedding raw data field is of the correct format.
217
221
 
218
- Arguments:
219
- ---------
220
- emb_name (str, int, float): Name of the embedding feature the vector belongs to
221
- data (str, int, float): Raw data associated with the embedding feature. Typically raw text.
222
+ Requirement: Must be string or list of strings (NLP case).
222
223
 
223
- Raises:
224
- ------
225
- TypeError: If the embedding does not satisfy requirements above
224
+ Args:
225
+ emb_name: Name of the embedding feature the vector belongs to.
226
+ data: Raw data associated with the embedding feature.
227
+ Typically raw text.
226
228
 
229
+ Raises:
230
+ TypeError: If the embedding does not satisfy requirements above.
227
231
  """
228
232
  # Validate that data is a string or iterable of strings
229
233
  is_string = isinstance(data, str)
@@ -247,7 +251,7 @@ class Embedding(NamedTuple):
247
251
  f"Embedding data field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
248
252
  f"Found {character_count}."
249
253
  )
250
- elif character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
254
+ if character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
251
255
  logger.warning(
252
256
  get_truncation_warning_message(
253
257
  "Embedding raw data fields",
@@ -259,20 +263,17 @@ class Embedding(NamedTuple):
259
263
  def _validate_embedding_link_to_data(
260
264
  emb_name: str | int | float, link_to_data: str
261
265
  ) -> None:
262
- """
263
- Validates that the embedding link to data field is of the correct format. That is:
264
- 1. Must be string
266
+ """Validates that the embedding link to data field is of the correct format.
265
267
 
266
- Arguments:
267
- ---------
268
- emb_name (str, int, float): Name of the embedding feature the vector belongs to
269
- link_to_data (str): Link to source data of embedding feature, typically an image file on
270
- cloud storage
268
+ Requirement: Must be string.
271
269
 
272
- Raises:
273
- ------
274
- TypeError: If the embedding does not satisfy requirements above
270
+ Args:
271
+ emb_name: Name of the embedding feature the vector belongs to.
272
+ link_to_data: Link to source data of embedding feature, typically an
273
+ image file on cloud storage.
275
274
 
275
+ Raises:
276
+ TypeError: If the embedding does not satisfy requirements above.
276
277
  """
277
278
  if not isinstance(link_to_data, str):
278
279
  raise TypeError(
@@ -282,22 +283,18 @@ class Embedding(NamedTuple):
282
283
 
283
284
  @staticmethod
284
285
  def _is_valid_iterable(
285
- data: str | List[str] | List[float] | np.ndarray,
286
+ data: str | list[str] | list[float] | np.ndarray,
286
287
  ) -> bool:
287
- """
288
- Validates that the input data field is of the correct iterable type. That is:
289
- 1. List or
290
- 2. numpy array or
291
- 3. pandas Series
288
+ """Validates that the input data field is of the correct iterable type.
292
289
 
293
- Arguments:
294
- ---------
295
- data: input iterable
290
+ Accepted types: 1) List, 2) numpy array, or 3) pandas Series.
296
291
 
297
- Returns:
298
- -------
299
- True if the data type is one of the accepted iterable types, false otherwise
292
+ Args:
293
+ data: Input iterable.
300
294
 
295
+ Returns:
296
+ True if the data type is one of the accepted iterable types,
297
+ false otherwise.
301
298
  """
302
299
  return any(isinstance(data, t) for t in (list, np.ndarray))
303
300
 
@@ -327,12 +324,15 @@ class Embedding(NamedTuple):
327
324
 
328
325
 
329
326
  class LLMRunMetadata(NamedTuple):
327
+ """Metadata for LLM execution including token counts and latency."""
328
+
330
329
  total_token_count: int | None = None
331
330
  prompt_token_count: int | None = None
332
331
  response_token_count: int | None = None
333
332
  response_latency_ms: int | float | None = None
334
333
 
335
334
  def validate(self) -> None:
335
+ """Validate the field values and constraints."""
336
336
  allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
337
337
  if not isinstance(self.total_token_count, allowed_types):
338
338
  raise InvalidValueType(
@@ -361,22 +361,20 @@ class LLMRunMetadata(NamedTuple):
361
361
 
362
362
 
363
363
  class ObjectDetectionColumnNames(NamedTuple):
364
- """
365
- Used to log object detection prediction and actual values that are assigned to the prediction or
366
- actual schema parameter.
364
+ """Used to log object detection prediction and actual values.
365
+
366
+ These values are assigned to the prediction or actual schema parameter.
367
367
 
368
- Arguments:
369
- ---------
370
- bounding_boxes_coordinates_column_name (str): Column name containing the coordinates of the
368
+ Args:
369
+ bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
371
370
  rectangular outline that locates an object within an image or video. Pascal VOC format
372
371
  required. The contents of this column must be a List[List[float]].
373
- categories_column_name (str): Column name containing the predefined classes or labels used
372
+ categories_column_name: Column name containing the predefined classes or labels used
374
373
  by the model to classify the detected objects. The contents of this column must be List[str].
375
- scores_column_names (str, optional): Column name containing the confidence scores that the
374
+ scores_column_names: Column name containing the confidence scores that the
376
375
  model assigns to its predictions, indicating how certain the model is that the predicted
377
376
  class is contained within the bounding box. This argument is only applicable for prediction
378
377
  values. The contents of this column must be List[float].
379
-
380
378
  """
381
379
 
382
380
  bounding_boxes_coordinates_column_name: str
@@ -385,19 +383,17 @@ class ObjectDetectionColumnNames(NamedTuple):
385
383
 
386
384
 
387
385
  class SemanticSegmentationColumnNames(NamedTuple):
388
- """
389
- Used to log semantic segmentation prediction and actual values that are assigned to the prediction or
390
- actual schema parameter.
386
+ """Used to log semantic segmentation prediction and actual values.
391
387
 
392
- Arguments:
393
- ---------
394
- polygon_coordinates_column_name (str): Column name containing the coordinates of the vertices
388
+ These values are assigned to the prediction or actual schema parameter.
389
+
390
+ Args:
391
+ polygon_coordinates_column_name: Column name containing the coordinates of the vertices
395
392
  of the polygon mask within an image or video. The first sublist contains the
396
393
  coordinates of the outline of the polygon. The subsequent sublists contain the coordinates
397
394
  of any cutouts within the polygon. The contents of this column must be a List[List[float]].
398
- categories_column_name (str): Column name containing the predefined classes or labels used
395
+ categories_column_name: Column name containing the predefined classes or labels used
399
396
  by the model to classify the detected objects. The contents of this column must be List[str].
400
-
401
397
  """
402
398
 
403
399
  polygon_coordinates_column_name: str
@@ -405,25 +401,22 @@ class SemanticSegmentationColumnNames(NamedTuple):
405
401
 
406
402
 
407
403
  class InstanceSegmentationPredictionColumnNames(NamedTuple):
408
- """
409
- Used to log instance segmentation prediction values that are assigned to the prediction schema parameter.
404
+ """Used to log instance segmentation prediction values for the prediction schema parameter.
410
405
 
411
- Arguments:
412
- ---------
413
- polygon_coordinates_column_name (str): Column name containing the coordinates of the vertices
406
+ Args:
407
+ polygon_coordinates_column_name: Column name containing the coordinates of the vertices
414
408
  of the polygon mask within an image or video. The first sublist contains the
415
409
  coordinates of the outline of the polygon. The subsequent sublists contain the coordinates
416
410
  of any cutouts within the polygon. The contents of this column must be a List[List[float]].
417
- categories_column_name (str): Column name containing the predefined classes or labels used
411
+ categories_column_name: Column name containing the predefined classes or labels used
418
412
  by the model to classify the detected objects. The contents of this column must be List[str].
419
- scores_column_name (str, optional): Column name containing the confidence scores that the
413
+ scores_column_name: Column name containing the confidence scores that the
420
414
  model assigns to its predictions, indicating how certain the model is that the predicted
421
415
  class is contained within the bounding box. This argument is only applicable for prediction
422
416
  values. The contents of this column must be List[float].
423
- bounding_boxes_coordinates_column_name (str, optional): Column name containing the coordinates of the
417
+ bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
424
418
  rectangular outline that locates an object within an image or video. Pascal VOC format
425
419
  required. The contents of this column must be a List[List[float]].
426
-
427
420
  """
428
421
 
429
422
  polygon_coordinates_column_name: str
@@ -433,20 +426,17 @@ class InstanceSegmentationPredictionColumnNames(NamedTuple):
433
426
 
434
427
 
435
428
  class InstanceSegmentationActualColumnNames(NamedTuple):
436
- """
437
- Used to log instance segmentation actual values that are assigned to the actual schema parameter.
429
+ """Used to log instance segmentation actual values that are assigned to the actual schema parameter.
438
430
 
439
- Arguments:
440
- ---------
441
- polygon_coordinates_column_name (str): Column name containing the coordinates of the
431
+ Args:
432
+ polygon_coordinates_column_name: Column name containing the coordinates of the
442
433
  polygon that locates an object within an image or video. The contents of this column
443
434
  must be a List[List[float]].
444
- categories_column_name (str): Column name containing the predefined classes or labels used
435
+ categories_column_name: Column name containing the predefined classes or labels used
445
436
  by the model to classify the detected objects. The contents of this column must be List[str].
446
- bounding_boxes_coordinates_column_name (str, optional): Column name containing the coordinates of the
437
+ bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
447
438
  rectangular outline that locates an object within an image or video. Pascal VOC format
448
439
  required. The contents of this column must be a List[List[float]].
449
-
450
440
  """
451
441
 
452
442
  polygon_coordinates_column_name: str
@@ -455,12 +445,15 @@ class InstanceSegmentationActualColumnNames(NamedTuple):
455
445
 
456
446
 
457
447
  class ObjectDetectionLabel(NamedTuple):
458
- bounding_boxes_coordinates: List[List[float]]
459
- categories: List[str]
448
+ """Label data for object detection tasks with bounding boxes and categories."""
449
+
450
+ bounding_boxes_coordinates: list[list[float]]
451
+ categories: list[str]
460
452
  # Actual Object Detection Labels won't have scores
461
- scores: List[float] | None = None
453
+ scores: list[float] | None = None
462
454
 
463
- def validate(self, prediction_or_actual: str):
455
+ def validate(self, prediction_or_actual: str) -> None:
456
+ """Validate the object detection label fields and constraints."""
464
457
  # Validate bounding boxes
465
458
  self._validate_bounding_boxes_coordinates()
466
459
  # Validate categories
@@ -470,7 +463,7 @@ class ObjectDetectionLabel(NamedTuple):
470
463
  # Validate we have the same number of bounding boxes, categories and scores
471
464
  self._validate_count_match()
472
465
 
473
- def _validate_bounding_boxes_coordinates(self):
466
+ def _validate_bounding_boxes_coordinates(self) -> None:
474
467
  if not is_list_of(self.bounding_boxes_coordinates, list):
475
468
  raise TypeError(
476
469
  "Object Detection Label bounding boxes must be a list of lists of floats"
@@ -478,14 +471,14 @@ class ObjectDetectionLabel(NamedTuple):
478
471
  for coordinates in self.bounding_boxes_coordinates:
479
472
  _validate_bounding_box_coordinates(coordinates)
480
473
 
481
- def _validate_categories(self):
474
+ def _validate_categories(self) -> None:
482
475
  # Allows for categories as empty strings
483
476
  if not is_list_of(self.categories, str):
484
477
  raise TypeError(
485
478
  "Object Detection Label categories must be a list of strings"
486
479
  )
487
480
 
488
- def _validate_scores(self, prediction_or_actual: str):
481
+ def _validate_scores(self, prediction_or_actual: str) -> None:
489
482
  if self.scores is None:
490
483
  if prediction_or_actual == "prediction":
491
484
  raise ValueError(
@@ -507,7 +500,7 @@ class ObjectDetectionLabel(NamedTuple):
507
500
  f"{self.scores}"
508
501
  )
509
502
 
510
- def _validate_count_match(self):
503
+ def _validate_count_match(self) -> None:
511
504
  n_bounding_boxes = len(self.bounding_boxes_coordinates)
512
505
  if n_bounding_boxes == 0:
513
506
  raise ValueError(
@@ -534,10 +527,13 @@ class ObjectDetectionLabel(NamedTuple):
534
527
 
535
528
 
536
529
  class SemanticSegmentationLabel(NamedTuple):
537
- polygon_coordinates: List[List[float]]
538
- categories: List[str]
530
+ """Label data for semantic segmentation with polygon coordinates and categories."""
539
531
 
540
- def validate(self):
532
+ polygon_coordinates: list[list[float]]
533
+ categories: list[str]
534
+
535
+ def validate(self) -> None:
536
+ """Validate the field values and constraints."""
541
537
  # Validate polygon coordinates
542
538
  self._validate_polygon_coordinates()
543
539
  # Validate categories
@@ -545,17 +541,17 @@ class SemanticSegmentationLabel(NamedTuple):
545
541
  # Validate we have the same number of polygon coordinates and categories
546
542
  self._validate_count_match()
547
543
 
548
- def _validate_polygon_coordinates(self):
544
+ def _validate_polygon_coordinates(self) -> None:
549
545
  _validate_polygon_coordinates(self.polygon_coordinates)
550
546
 
551
- def _validate_categories(self):
547
+ def _validate_categories(self) -> None:
552
548
  # Allows for categories as empty strings
553
549
  if not is_list_of(self.categories, str):
554
550
  raise TypeError(
555
551
  "Semantic Segmentation Label categories must be a list of strings"
556
552
  )
557
553
 
558
- def _validate_count_match(self):
554
+ def _validate_count_match(self) -> None:
559
555
  n_polygon_coordinates = len(self.polygon_coordinates)
560
556
  if n_polygon_coordinates == 0:
561
557
  raise ValueError(
@@ -573,12 +569,15 @@ class SemanticSegmentationLabel(NamedTuple):
573
569
 
574
570
 
575
571
  class InstanceSegmentationPredictionLabel(NamedTuple):
576
- polygon_coordinates: List[List[float]]
577
- categories: List[str]
578
- scores: List[float] | None = None
579
- bounding_boxes_coordinates: List[List[float]] | None = None
572
+ """Prediction label for instance segmentation with polygons and category information."""
580
573
 
581
- def validate(self):
574
+ polygon_coordinates: list[list[float]]
575
+ categories: list[str]
576
+ scores: list[float] | None = None
577
+ bounding_boxes_coordinates: list[list[float]] | None = None
578
+
579
+ def validate(self) -> None:
580
+ """Validate the field values and constraints."""
582
581
  # Validate polygon coordinates
583
582
  self._validate_polygon_coordinates()
584
583
  # Validate categories
@@ -590,17 +589,17 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
590
589
  # Validate we have the same number of polygon coordinates and categories
591
590
  self._validate_count_match()
592
591
 
593
- def _validate_polygon_coordinates(self):
592
+ def _validate_polygon_coordinates(self) -> None:
594
593
  _validate_polygon_coordinates(self.polygon_coordinates)
595
594
 
596
- def _validate_categories(self):
595
+ def _validate_categories(self) -> None:
597
596
  # Allows for categories as empty strings
598
597
  if not is_list_of(self.categories, str):
599
598
  raise TypeError(
600
599
  "Instance Segmentation Prediction Label categories must be a list of strings"
601
600
  )
602
601
 
603
- def _validate_scores(self):
602
+ def _validate_scores(self) -> None:
604
603
  if self.scores is not None:
605
604
  if not is_list_of(self.scores, float):
606
605
  raise TypeError(
@@ -613,7 +612,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
613
612
  f"{self.scores}"
614
613
  )
615
614
 
616
- def _validate_bounding_boxes(self):
615
+ def _validate_bounding_boxes(self) -> None:
617
616
  if self.bounding_boxes_coordinates is not None:
618
617
  if not is_list_of(self.bounding_boxes_coordinates, list):
619
618
  raise TypeError(
@@ -622,7 +621,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
622
621
  for coordinates in self.bounding_boxes_coordinates:
623
622
  _validate_bounding_box_coordinates(coordinates)
624
623
 
625
- def _validate_count_match(self):
624
+ def _validate_count_match(self) -> None:
626
625
  n_polygon_coordinates = len(self.polygon_coordinates)
627
626
  if n_polygon_coordinates == 0:
628
627
  raise ValueError(
@@ -657,11 +656,14 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
657
656
 
658
657
 
659
658
  class InstanceSegmentationActualLabel(NamedTuple):
660
- polygon_coordinates: List[List[float]]
661
- categories: List[str]
662
- bounding_boxes_coordinates: List[List[float]] | None = None
659
+ """Actual label for instance segmentation with polygon coordinates and categories."""
663
660
 
664
- def validate(self):
661
+ polygon_coordinates: list[list[float]]
662
+ categories: list[str]
663
+ bounding_boxes_coordinates: list[list[float]] | None = None
664
+
665
+ def validate(self) -> None:
666
+ """Validate the field values and constraints."""
665
667
  # Validate polygon coordinates
666
668
  self._validate_polygon_coordinates()
667
669
  # Validate categories
@@ -671,17 +673,17 @@ class InstanceSegmentationActualLabel(NamedTuple):
671
673
  # Validate we have the same number of polygon coordinates and categories
672
674
  self._validate_count_match()
673
675
 
674
- def _validate_polygon_coordinates(self):
676
+ def _validate_polygon_coordinates(self) -> None:
675
677
  _validate_polygon_coordinates(self.polygon_coordinates)
676
678
 
677
- def _validate_categories(self):
679
+ def _validate_categories(self) -> None:
678
680
  # Allows for categories as empty strings
679
681
  if not is_list_of(self.categories, str):
680
682
  raise TypeError(
681
683
  "Instance Segmentation Actual Label categories must be a list of strings"
682
684
  )
683
685
 
684
- def _validate_bounding_boxes(self):
686
+ def _validate_bounding_boxes(self) -> None:
685
687
  if self.bounding_boxes_coordinates is not None:
686
688
  if not is_list_of(self.bounding_boxes_coordinates, list):
687
689
  raise TypeError(
@@ -690,7 +692,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
690
692
  for coordinates in self.bounding_boxes_coordinates:
691
693
  _validate_bounding_box_coordinates(coordinates)
692
694
 
693
- def _validate_count_match(self):
695
+ def _validate_count_match(self) -> None:
694
696
  n_polygon_coordinates = len(self.polygon_coordinates)
695
697
  if n_polygon_coordinates == 0:
696
698
  raise ValueError(
@@ -717,27 +719,24 @@ class InstanceSegmentationActualLabel(NamedTuple):
717
719
 
718
720
 
719
721
  class MultiClassPredictionLabel(NamedTuple):
720
- """
721
- Used to log multi class prediction label
722
+ """Used to log multi class prediction label.
722
723
 
723
- Arguments:
724
- ---------
725
- MultiClassPredictionLabel
726
- prediction_scores (Dict[str, Union[float, int]]): the prediction scores of the classes.
727
- threshold_scores (Optional[Dict[str, Union[float, int]]]): the threshold scores of the classes.
724
+ Args:
725
+ prediction_scores: The prediction scores of the classes.
726
+ threshold_scores: The threshold scores of the classes.
728
727
  Only Multi Label will have threshold scores.
729
-
730
728
  """
731
729
 
732
- prediction_scores: Dict[str, float | int]
733
- threshold_scores: Dict[str, float | int] | None = None
730
+ prediction_scores: dict[str, float | int]
731
+ threshold_scores: dict[str, float | int] | None = None
734
732
 
735
- def validate(self):
733
+ def validate(self) -> None:
734
+ """Validate the field values and constraints."""
736
735
  # Validate scores
737
736
  self._validate_prediction_scores()
738
737
  self._validate_threshold_scores()
739
738
 
740
- def _validate_prediction_scores(self):
739
+ def _validate_prediction_scores(self) -> None:
741
740
  # prediction dictionary validations
742
741
  if not is_dict_of(
743
742
  self.prediction_scores,
@@ -778,7 +777,7 @@ class MultiClassPredictionLabel(NamedTuple):
778
777
  "invalid. All scores (values in dictionary) must be between 0 and 1, inclusive."
779
778
  )
780
779
 
781
- def _validate_threshold_scores(self):
780
+ def _validate_threshold_scores(self) -> None:
782
781
  if self.threshold_scores is None or len(self.threshold_scores) == 0:
783
782
  return
784
783
  if not is_dict_of(
@@ -822,24 +821,21 @@ class MultiClassPredictionLabel(NamedTuple):
822
821
 
823
822
 
824
823
  class MultiClassActualLabel(NamedTuple):
825
- """
826
- Used to log multi class actual label
827
-
828
- Arguments:
829
- ---------
830
- MultiClassActualLabel
831
- actual_scores (Dict[str, Union[float, int]]): the actual scores of the classes.
832
- Any class in actual_scores with a score of 1 will be sent to arize
824
+ """Used to log multi class actual label.
833
825
 
826
+ Args:
827
+ actual_scores: The actual scores of the classes.
828
+ Any class in actual_scores with a score of 1 will be sent to arize.
834
829
  """
835
830
 
836
- actual_scores: Dict[str, float | int]
831
+ actual_scores: dict[str, float | int]
837
832
 
838
- def validate(self):
833
+ def validate(self) -> None:
834
+ """Validate the field values and constraints."""
839
835
  # Validate scores
840
836
  self._validate_actual_scores()
841
837
 
842
- def _validate_actual_scores(self):
838
+ def _validate_actual_scores(self) -> None:
843
839
  if not is_dict_of(
844
840
  self.actual_scores,
845
841
  key_allowed_types=str,
@@ -879,12 +875,15 @@ class MultiClassActualLabel(NamedTuple):
879
875
 
880
876
 
881
877
  class RankingPredictionLabel(NamedTuple):
878
+ """Prediction label for ranking tasks with group and rank information."""
879
+
882
880
  group_id: str
883
881
  rank: int
884
882
  score: float | None = None
885
883
  label: str | None = None
886
884
 
887
- def validate(self):
885
+ def validate(self) -> None:
886
+ """Validate the field values and constraints."""
888
887
  # Validate existence of required fields: prediction_group_id and rank
889
888
  if self.group_id is None or self.rank is None:
890
889
  raise ValueError(
@@ -901,7 +900,7 @@ class RankingPredictionLabel(NamedTuple):
901
900
  if self.score is not None:
902
901
  self._validate_score()
903
902
 
904
- def _validate_group_id(self):
903
+ def _validate_group_id(self) -> None:
905
904
  if not isinstance(self.group_id, str):
906
905
  raise TypeError("Prediction Group ID must be a string")
907
906
  if not (1 <= len(self.group_id) <= 36):
@@ -909,7 +908,7 @@ class RankingPredictionLabel(NamedTuple):
909
908
  f"Prediction Group ID must have length between 1 and 36. Found {len(self.group_id)}"
910
909
  )
911
910
 
912
- def _validate_rank(self):
911
+ def _validate_rank(self) -> None:
913
912
  if not isinstance(self.rank, int):
914
913
  raise TypeError("Prediction Rank must be an int")
915
914
  if not (1 <= self.rank <= 100):
@@ -917,22 +916,25 @@ class RankingPredictionLabel(NamedTuple):
917
916
  f"Prediction Rank must be between 1 and 100, inclusive. Found {self.rank}"
918
917
  )
919
918
 
920
- def _validate_label(self):
919
+ def _validate_label(self) -> None:
921
920
  if not isinstance(self.label, str):
922
921
  raise TypeError("Prediction Label must be a str")
923
922
  if self.label == "":
924
923
  raise ValueError("Prediction Label must not be an empty string.")
925
924
 
926
- def _validate_score(self):
925
+ def _validate_score(self) -> None:
927
926
  if not isinstance(self.score, (float, int)):
928
927
  raise TypeError("Prediction Score must be a float or an int")
929
928
 
930
929
 
931
930
  class RankingActualLabel(NamedTuple):
932
- relevance_labels: List[str] | None = None
931
+ """Actual label for ranking tasks with relevance information."""
932
+
933
+ relevance_labels: list[str] | None = None
933
934
  relevance_score: float | None = None
934
935
 
935
- def validate(self):
936
+ def validate(self) -> None:
937
+ """Validate the field values and constraints."""
936
938
  # Validate relevance_labels type
937
939
  if self.relevance_labels is not None:
938
940
  self._validate_relevance_labels(self.relevance_labels)
@@ -941,7 +943,16 @@ class RankingActualLabel(NamedTuple):
941
943
  self._validate_relevance_score(self.relevance_score)
942
944
 
943
945
  @staticmethod
944
- def _validate_relevance_labels(relevance_labels: List[str]):
946
+ def _validate_relevance_labels(relevance_labels: list[str]) -> None:
947
+ """Validate relevance labels.
948
+
949
+ Args:
950
+ relevance_labels: List of relevance labels to validate.
951
+
952
+ Raises:
953
+ TypeError: If relevance_labels is not a list of strings.
954
+ ValueError: If any label is an empty string.
955
+ """
945
956
  if not is_list_of(relevance_labels, str):
946
957
  raise TypeError("Actual Relevance Labels must be a list of strings")
947
958
  if any(label == "" for label in relevance_labels):
@@ -950,17 +961,28 @@ class RankingActualLabel(NamedTuple):
950
961
  )
951
962
 
952
963
  @staticmethod
953
- def _validate_relevance_score(relevance_score: float):
964
+ def _validate_relevance_score(relevance_score: float) -> None:
965
+ """Validate relevance score.
966
+
967
+ Args:
968
+ relevance_score: Relevance score to validate.
969
+
970
+ Raises:
971
+ TypeError: If relevance_score is not a float or int.
972
+ """
954
973
  if not isinstance(relevance_score, (float, int)):
955
974
  raise TypeError("Actual Relevance score must be a float or an int")
956
975
 
957
976
 
958
977
  @dataclass
959
978
  class PromptTemplateColumnNames:
979
+ """Column names for prompt template configuration in LLM schemas."""
980
+
960
981
  template_column_name: str | None = None
961
982
  template_version_column_name: str | None = None
962
983
 
963
- def __iter__(self):
984
+ def __iter__(self) -> Iterator[str | None]:
985
+ """Iterate over the prompt template column names."""
964
986
  return iter(
965
987
  (self.template_column_name, self.template_version_column_name)
966
988
  )
@@ -968,21 +990,27 @@ class PromptTemplateColumnNames:
968
990
 
969
991
  @dataclass
970
992
  class LLMConfigColumnNames:
993
+ """Column names for LLM configuration parameters in schemas."""
994
+
971
995
  model_column_name: str | None = None
972
996
  params_column_name: str | None = None
973
997
 
974
- def __iter__(self):
998
+ def __iter__(self) -> Iterator[str | None]:
999
+ """Iterate over the LLM config column names."""
975
1000
  return iter((self.model_column_name, self.params_column_name))
976
1001
 
977
1002
 
978
1003
  @dataclass
979
1004
  class LLMRunMetadataColumnNames:
1005
+ """Column names for LLM run metadata fields in schemas."""
1006
+
980
1007
  total_token_count_column_name: str | None = None
981
1008
  prompt_token_count_column_name: str | None = None
982
1009
  response_token_count_column_name: str | None = None
983
1010
  response_latency_ms_column_name: str | None = None
984
1011
 
985
- def __iter__(self):
1012
+ def __iter__(self) -> Iterator[str | None]:
1013
+ """Iterate over the LLM run metadata column names."""
986
1014
  return iter(
987
1015
  (
988
1016
  self.total_token_count_column_name,
@@ -1011,11 +1039,19 @@ class LLMRunMetadataColumnNames:
1011
1039
  #
1012
1040
  @dataclass
1013
1041
  class SimilarityReference:
1042
+ """Reference to a prediction for similarity search operations."""
1043
+
1014
1044
  prediction_id: str
1015
1045
  reference_column_name: str
1016
1046
  prediction_timestamp: datetime | None = None
1017
1047
 
1018
- def __post_init__(self):
1048
+ def __post_init__(self) -> None:
1049
+ """Validate similarity reference fields after initialization.
1050
+
1051
+ Raises:
1052
+ ValueError: If prediction_id or reference_column_name is empty.
1053
+ TypeError: If prediction_timestamp is not a datetime object.
1054
+ """
1019
1055
  if self.prediction_id == "":
1020
1056
  raise ValueError("prediction id cannot be empty")
1021
1057
  if self.reference_column_name == "":
@@ -1028,11 +1064,20 @@ class SimilarityReference:
1028
1064
 
1029
1065
  @dataclass
1030
1066
  class SimilaritySearchParams:
1031
- references: List[SimilarityReference]
1067
+ """Parameters for configuring similarity search operations."""
1068
+
1069
+ references: list[SimilarityReference]
1032
1070
  search_column_name: str
1033
1071
  threshold: float = 0
1034
1072
 
1035
- def __post_init__(self):
1073
+ def __post_init__(self) -> None:
1074
+ """Validate similarity search parameters after initialization.
1075
+
1076
+ Raises:
1077
+ ValueError: If references list is invalid, search_column_name is
1078
+ empty, or threshold is out of range.
1079
+ TypeError: If any reference is not a SimilarityReference instance.
1080
+ """
1036
1081
  if (
1037
1082
  not self.references
1038
1083
  or len(self.references) <= 0
@@ -1054,176 +1099,157 @@ class SimilaritySearchParams:
1054
1099
 
1055
1100
  @dataclass(frozen=True)
1056
1101
  class BaseSchema:
1057
- def replace(self, **changes):
1102
+ """Base class for all schema definitions with immutable fields."""
1103
+
1104
+ def replace(self, **changes: object) -> Self:
1105
+ """Return a new instance with specified fields replaced."""
1058
1106
  return replace(self, **changes)
1059
1107
 
1060
- def asdict(self) -> Dict[str, str]:
1108
+ def asdict(self) -> dict[str, str]:
1109
+ """Convert the schema to a dictionary."""
1061
1110
  return asdict(self)
1062
1111
 
1063
- def get_used_columns(self) -> Set[str]:
1112
+ def get_used_columns(self) -> set[str]:
1113
+ """Return the set of column names used in this schema."""
1064
1114
  return set(self.get_used_columns_counts().keys())
1065
1115
 
1066
- def get_used_columns_counts(self) -> Dict[str, int]:
1116
+ def get_used_columns_counts(self) -> dict[str, int]:
1117
+ """Return a dict mapping column names to their usage count."""
1067
1118
  raise NotImplementedError()
1068
1119
 
1069
1120
 
1070
1121
  @dataclass(frozen=True)
1071
1122
  class TypedColumns:
1072
- """
1073
- Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
1074
-
1075
- Usage:
1076
- ------
1077
- When initializing a Schema, use TypedColumns in place of a list of string column names.
1078
- e.g. feature_column_names=TypedColumns(
1079
- inferred=["feature_1", "feature_2"],
1080
- to_str=["feature_3"],
1081
- to_int=["feature_4"]
1082
- )
1123
+ """Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
1124
+
1125
+ When initializing a Schema, use TypedColumns in place of a list of string column names::
1083
1126
 
1084
- Fields:
1085
- -------
1086
- inferred (Optional[List[str]]): List of columns that will not be altered at all.
1087
- The values in these columns will have their type inferred as Arize validates and ingests the data.
1088
- There's no difference between passing in all column names as inferred
1089
- vs. not using TypedColumns at all.
1090
- to_str (Optional[List[str]]): List of columns that should be cast to pandas StringDType.
1091
- to_int (Optional[List[str]]): List of columns that should be cast to pandas Int64DType.
1092
- to_float (Optional[List[str]]): List of columns that should be cast to pandas Float64DType.
1127
+ feature_column_names = TypedColumns(
1128
+ inferred=["feature_1", "feature_2"],
1129
+ to_str=["feature_3"],
1130
+ to_int=["feature_4"],
1131
+ )
1093
1132
 
1094
1133
  Notes:
1095
- -----
1096
1134
  - If a TypedColumns object is included in a Schema, pandas version 1.0.0 or higher is required.
1097
1135
  - Pandas StringDType is still considered an experimental field.
1098
1136
  - Columns not present in any field will not be captured in the Schema.
1099
1137
  - StringDType, Int64DType, and Float64DType are all nullable column types.
1100
- Null values will be ingested and represented in Arize as empty values.
1101
-
1138
+ Null values will be ingested and represented in Arize as empty values.
1102
1139
  """
1103
1140
 
1104
- inferred: List[str] | None = None
1105
- to_str: List[str] | None = None
1106
- to_int: List[str] | None = None
1107
- to_float: List[str] | None = None
1141
+ inferred: list[str] | None = None
1142
+ to_str: list[str] | None = None
1143
+ to_int: list[str] | None = None
1144
+ to_float: list[str] | None = None
1108
1145
 
1109
- def get_all_column_names(self) -> List[str]:
1146
+ def get_all_column_names(self) -> list[str]:
1147
+ """Return all column names across all conversion lists."""
1110
1148
  return list(chain.from_iterable(filter(None, self.__dict__.values())))
1111
1149
 
1112
- def has_duplicate_columns(self) -> Tuple[bool, Set[str]]:
1150
+ def has_duplicate_columns(self) -> tuple[bool, set[str]]:
1151
+ """Check for duplicate columns and return (has_duplicates, duplicate_set)."""
1113
1152
  # True if there are duplicates within a field's list or across fields.
1114
1153
  # Return a set of the duplicate column names.
1115
1154
  cols = self.get_all_column_names()
1116
- duplicates = set([x for x in cols if cols.count(x) > 1])
1155
+ duplicates = {x for x in cols if cols.count(x) > 1}
1117
1156
  return len(duplicates) > 0, duplicates
1118
1157
 
1119
1158
  def is_empty(self) -> bool:
1159
+ """Return True if no columns are configured for conversion."""
1120
1160
  return not self.get_all_column_names()
1121
1161
 
1122
1162
 
1123
1163
  @dataclass(frozen=True)
1124
1164
  class Schema(BaseSchema):
1125
- """
1126
- Used to organize and map column names containing model data within your Pandas dataframe to
1127
- Arize.
1165
+ """Used to organize and map column names containing model data within your Pandas dataframe to Arize.
1128
1166
 
1129
- Arguments:
1130
- ---------
1131
- prediction_id_column_name (str, optional): Column name for the predictions unique identifier.
1167
+ Args:
1168
+ prediction_id_column_name: Column name for the predictions unique identifier.
1132
1169
  Unique IDs are used to match a prediction to delayed actuals or feature importances in Arize.
1133
1170
  If prediction ids are not provided, it will default to an empty string "" and, when possible,
1134
1171
  Arize will create a random prediction id on the server side. Prediction id must be a string column
1135
1172
  with each row indicating a unique prediction event.
1136
- feature_column_names (Union[List[str], TypedColumns], optional): Column names for features.
1173
+ feature_column_names: Column names for features.
1137
1174
  The content of feature columns can be int, float, string. If TypedColumns is used,
1138
1175
  the columns will be cast to the provided types prior to logging.
1139
- tag_column_names (Union[List[str], TypedColumns], optional): Column names for tags. The content of tag
1176
+ tag_column_names: Column names for tags. The content of tag
1140
1177
  columns can be int, float, string. If TypedColumns is used,
1141
1178
  the columns will be cast to the provided types prior to logging.
1142
- timestamp_column_name (str, optional): Column name for timestamps. The content of this
1179
+ timestamp_column_name: Column name for timestamps. The content of this
1143
1180
  column must be int Unix Timestamps in seconds.
1144
- prediction_label_column_name (str, optional): Column name for categorical prediction values.
1181
+ prediction_label_column_name: Column name for categorical prediction values.
1145
1182
  The content of this column must be convertible to string.
1146
- prediction_score_column_name (str, optional): Column name for numeric prediction values. The
1183
+ prediction_score_column_name: Column name for numeric prediction values. The
1147
1184
  content of this column must be int/float or list of dictionaries mapping class names to
1148
1185
  int/float scores in the case of MULTI_CLASS model types.
1149
- actual_label_column_name (str, optional): Column name for categorical ground truth values.
1186
+ actual_label_column_name: Column name for categorical ground truth values.
1150
1187
  The content of this column must be convertible to string.
1151
- actual_score_column_name (str, optional): Column name for numeric ground truth values. The
1188
+ actual_score_column_name: Column name for numeric ground truth values. The
1152
1189
  content of this column must be int/float or list of dictionaries mapping class names to
1153
1190
  int/float scores in the case of MULTI_CLASS model types.
1154
- shap_values_column_names (Dict[str, str], optional): Dictionary mapping feature column name
1191
+ shap_values_column_names: Dictionary mapping feature column name
1155
1192
  and corresponding SHAP feature importance column name. e.g.
1156
1193
  {{"feat_A": "feat_A_shap", "feat_B": "feat_B_shap"}}
1157
- embedding_feature_column_names (Dict[str, EmbeddingColumnNames], optional): Dictionary
1194
+ embedding_feature_column_names: Dictionary
1158
1195
  mapping embedding display names to EmbeddingColumnNames objects.
1159
- prediction_group_id_column_name (str, optional): Column name for ranking groups or lists in
1196
+ prediction_group_id_column_name: Column name for ranking groups or lists in
1160
1197
  ranking models. The content of this column must be string and is limited to 128 characters.
1161
- rank_column_name (str, optional): Column name for rank of each element on the its group or
1198
+ rank_column_name: Column name for rank of each element on the its group or
1162
1199
  list. The content of this column must be integer between 1-100.
1163
- relevance_score_column_name (str, optional): Column name for ranking model type numeric
1200
+ relevance_score_column_name: Column name for ranking model type numeric
1164
1201
  ground truth values. The content of this column must be int/float.
1165
- relevance_labels_column_name (str, optional): Column name for ranking model type categorical
1202
+ relevance_labels_column_name: Column name for ranking model type categorical
1166
1203
  ground truth values. The content of this column must be a string.
1167
- object_detection_prediction_column_names (ObjectDetectionColumnNames, optional):
1204
+ object_detection_prediction_column_names:
1168
1205
  ObjectDetectionColumnNames object containing information defining the predicted bounding
1169
1206
  boxes' coordinates, categories, and scores.
1170
- object_detection_actual_column_names (ObjectDetectionColumnNames, optional):
1207
+ object_detection_actual_column_names:
1171
1208
  ObjectDetectionColumnNames object containing information defining the actual bounding
1172
1209
  boxes' coordinates, categories, and scores.
1173
- prompt_column_names (str or EmbeddingColumnNames, optional): column names for text that is passed
1210
+ prompt_column_names: column names for text that is passed
1174
1211
  to the GENERATIVE_LLM model. It accepts a string (if sending only a text column) or
1175
1212
  EmbeddingColumnNames object containing the embedding vector data (required) and raw text
1176
1213
  (optional) for the input text your model acts on.
1177
- response_column_names (str or EmbeddingColumnNames, optional): column names for text generated by
1214
+ response_column_names: column names for text generated by
1178
1215
  the GENERATIVE_LLM model. It accepts a string (if sending only a text column) or
1179
1216
  EmbeddingColumnNames object containing the embedding vector data (required) and raw text
1180
1217
  (optional) for the text your model generates.
1181
- prompt_template_column_names (PromptTemplateColumnNames, optional): PromptTemplateColumnNames object
1218
+ prompt_template_column_names: PromptTemplateColumnNames object
1182
1219
  containing the prompt template and the prompt template version.
1183
- llm_config_column_names (LLMConfigColumnNames, optional): LLMConfigColumnNames object containing
1220
+ llm_config_column_names: LLMConfigColumnNames object containing
1184
1221
  the LLM's model name and its hyper parameters used at inference.
1185
- llm_run_metadata_column_names (LLMRunMetadataColumnNames, optional): LLMRunMetadataColumnNames
1222
+ llm_run_metadata_column_names: LLMRunMetadataColumnNames
1186
1223
  object containing token counts and latency metrics
1187
- retrieved_document_ids_column_name (str, optional): Column name for retrieved document ids.
1224
+ retrieved_document_ids_column_name: Column name for retrieved document ids.
1188
1225
  The content of this column must be lists with entries convertible to strings.
1189
- multi_class_threshold_scores_column_name (str, optional):
1226
+ multi_class_threshold_scores_column_name:
1190
1227
  Column name for dictionary that maps class names to threshold values. The
1191
1228
  content of this column must be dictionary of str -> int/float.
1192
- semantic_segmentation_prediction_column_names (SemanticSegmentationColumnNames, optional):
1229
+ semantic_segmentation_prediction_column_names:
1193
1230
  SemanticSegmentationColumnNames object containing information defining the predicted
1194
1231
  polygon coordinates and categories.
1195
- semantic_segmentation_actual_column_names (SemanticSegmentationColumnNames, optional):
1232
+ semantic_segmentation_actual_column_names:
1196
1233
  SemanticSegmentationColumnNames object containing information defining the actual
1197
1234
  polygon coordinates and categories.
1198
- instance_segmentation_prediction_column_names (InstanceSegmentationPredictionColumnNames, optional):
1235
+ instance_segmentation_prediction_column_names:
1199
1236
  InstanceSegmentationPredictionColumnNames object containing information defining the predicted
1200
1237
  polygon coordinates, categories, scores, and bounding box coordinates.
1201
- instance_segmentation_actual_column_names (InstanceSegmentationActualColumnNames, optional):
1238
+ instance_segmentation_actual_column_names:
1202
1239
  InstanceSegmentationActualColumnNames object containing information defining the actual
1203
1240
  polygon coordinates, categories, scores, and bounding box coordinates.
1204
-
1205
- Methods:
1206
- -------
1207
- replace(**changes):
1208
- Replaces fields of the schema
1209
- asdict():
1210
- Returns the schema as a dictionary. Warning: the types are not maintained, fields are
1211
- converted to strings.
1212
- get_used_columns():
1213
- Returns a set with the unique collection of columns to be used from the dataframe.
1214
-
1215
1241
  """
1216
1242
 
1217
1243
  prediction_id_column_name: str | None = None
1218
- feature_column_names: List[str] | TypedColumns | None = None
1219
- tag_column_names: List[str] | TypedColumns | None = None
1244
+ feature_column_names: list[str] | TypedColumns | None = None
1245
+ tag_column_names: list[str] | TypedColumns | None = None
1220
1246
  timestamp_column_name: str | None = None
1221
1247
  prediction_label_column_name: str | None = None
1222
1248
  prediction_score_column_name: str | None = None
1223
1249
  actual_label_column_name: str | None = None
1224
1250
  actual_score_column_name: str | None = None
1225
- shap_values_column_names: Dict[str, str] | None = None
1226
- embedding_feature_column_names: Dict[str, EmbeddingColumnNames] | None = (
1251
+ shap_values_column_names: dict[str, str] | None = None
1252
+ embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
1227
1253
  None # type:ignore
1228
1254
  )
1229
1255
  prediction_group_id_column_name: str | None = None
@@ -1242,7 +1268,7 @@ class Schema(BaseSchema):
1242
1268
  prompt_template_column_names: PromptTemplateColumnNames | None = None
1243
1269
  llm_config_column_names: LLMConfigColumnNames | None = None
1244
1270
  llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
1245
- retrieved_document_ids_column_name: List[str] | None = None
1271
+ retrieved_document_ids_column_name: list[str] | None = None
1246
1272
  multi_class_threshold_scores_column_name: str | None = None
1247
1273
  semantic_segmentation_prediction_column_names: (
1248
1274
  SemanticSegmentationColumnNames | None
@@ -1257,7 +1283,8 @@ class Schema(BaseSchema):
1257
1283
  InstanceSegmentationActualColumnNames | None
1258
1284
  ) = None
1259
1285
 
1260
- def get_used_columns_counts(self) -> Dict[str, int]:
1286
+ def get_used_columns_counts(self) -> dict[str, int]:
1287
+ """Return a dict mapping column names to their usage count."""
1261
1288
  columns_used_counts = {}
1262
1289
 
1263
1290
  for field in self.__dataclass_fields__:
@@ -1364,6 +1391,7 @@ class Schema(BaseSchema):
1364
1391
  return columns_used_counts
1365
1392
 
1366
1393
  def has_prediction_columns(self) -> bool:
1394
+ """Return True if prediction columns are configured."""
1367
1395
  prediction_cols = (
1368
1396
  self.prediction_label_column_name,
1369
1397
  self.prediction_score_column_name,
@@ -1377,6 +1405,7 @@ class Schema(BaseSchema):
1377
1405
  return any(col is not None for col in prediction_cols)
1378
1406
 
1379
1407
  def has_actual_columns(self) -> bool:
1408
+ """Return True if actual label columns are configured."""
1380
1409
  actual_cols = (
1381
1410
  self.actual_label_column_name,
1382
1411
  self.actual_score_column_name,
@@ -1389,13 +1418,16 @@ class Schema(BaseSchema):
1389
1418
  return any(col is not None for col in actual_cols)
1390
1419
 
1391
1420
  def has_feature_importance_columns(self) -> bool:
1421
+ """Return True if feature importance columns are configured."""
1392
1422
  feature_importance_cols = (self.shap_values_column_names,)
1393
1423
  return any(col is not None for col in feature_importance_cols)
1394
1424
 
1395
1425
  def has_typed_columns(self) -> bool:
1426
+ """Return True if typed columns are configured."""
1396
1427
  return any(self.typed_column_fields())
1397
1428
 
1398
- def typed_column_fields(self) -> Set[str]:
1429
+ def typed_column_fields(self) -> set[str]:
1430
+ """Return the set of field names with typed columns."""
1399
1431
  return {
1400
1432
  field
1401
1433
  for field in self.__dataclass_fields__
@@ -1403,9 +1435,9 @@ class Schema(BaseSchema):
1403
1435
  }
1404
1436
 
1405
1437
  def is_delayed(self) -> bool:
1406
- """
1407
- This function checks if the given schema, according to the columns provided
1408
- by the user, has inherently latent information
1438
+ """Check if the schema has inherently latent information.
1439
+
1440
+ Determines this based on the columns provided by the user.
1409
1441
 
1410
1442
  Returns:
1411
1443
  bool: True if the schema is "delayed", i.e., does not possess prediction
@@ -1418,11 +1450,14 @@ class Schema(BaseSchema):
1418
1450
 
1419
1451
  @dataclass(frozen=True)
1420
1452
  class CorpusSchema(BaseSchema):
1453
+ """Schema for corpus data with document identification and content columns."""
1454
+
1421
1455
  document_id_column_name: str | None = None
1422
1456
  document_version_column_name: str | None = None
1423
1457
  document_text_embedding_column_names: EmbeddingColumnNames | None = None
1424
1458
 
1425
- def get_used_columns_counts(self) -> Dict[str, int]:
1459
+ def get_used_columns_counts(self) -> dict[str, int]:
1460
+ """Return a dict mapping column names to their usage count."""
1426
1461
  columns_used_counts = {}
1427
1462
 
1428
1463
  if self.document_id_column_name is not None:
@@ -1459,6 +1494,8 @@ class CorpusSchema(BaseSchema):
1459
1494
 
1460
1495
  @unique
1461
1496
  class ArizeTypes(Enum):
1497
+ """Enum representing supported data types in Arize platform."""
1498
+
1462
1499
  STR = 0
1463
1500
  FLOAT = 1
1464
1501
  INT = 2
@@ -1466,76 +1503,13 @@ class ArizeTypes(Enum):
1466
1503
 
1467
1504
  @dataclass(frozen=True)
1468
1505
  class TypedValue:
1506
+ """Container for a value with its associated Arize type."""
1507
+
1469
1508
  type: ArizeTypes
1470
1509
  value: str | bool | float | int
1471
1510
 
1472
1511
 
1473
- def is_json_str(s: str) -> bool:
1474
- try:
1475
- json.loads(s)
1476
- except ValueError:
1477
- return False
1478
- except TypeError:
1479
- return False
1480
- return True
1481
-
1482
-
1483
- T = TypeVar("T", bound=type)
1484
-
1485
-
1486
- def is_array_of(arr: Sequence[object], tp: T) -> bool:
1487
- return isinstance(arr, np.ndarray) and all(isinstance(x, tp) for x in arr)
1488
-
1489
-
1490
- def is_list_of(lst: Sequence[object], tp: T) -> bool:
1491
- return isinstance(lst, list) and all(isinstance(x, tp) for x in lst)
1492
-
1493
-
1494
- def is_iterable_of(lst: Sequence[object], tp: T) -> bool:
1495
- return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)
1496
-
1497
-
1498
- def is_dict_of(
1499
- d: Dict[object, object],
1500
- key_allowed_types: T,
1501
- value_allowed_types: T = (),
1502
- value_list_allowed_types: T = (),
1503
- ) -> bool:
1504
- """
1505
- Method to check types are valid for dictionary.
1506
-
1507
- Arguments:
1508
- ---------
1509
- d (Dict[object, object]): dictionary itself
1510
- key_allowed_types (T): all allowed types for keys of dictionary
1511
- value_allowed_types (T): all allowed types for values of dictionary
1512
- value_list_allowed_types (T): if value is a list, these are the allowed
1513
- types for value list
1514
-
1515
- Returns:
1516
- -------
1517
- True if the data types of dictionary match the types specified by the
1518
- arguments, false otherwise
1519
-
1520
- """
1521
- if value_list_allowed_types and not isinstance(
1522
- value_list_allowed_types, tuple
1523
- ):
1524
- value_list_allowed_types = (value_list_allowed_types,)
1525
-
1526
- return (
1527
- isinstance(d, dict)
1528
- and all(isinstance(k, key_allowed_types) for k in d)
1529
- and all(
1530
- isinstance(v, value_allowed_types)
1531
- or any(is_list_of(v, t) for t in value_list_allowed_types)
1532
- for v in d.values()
1533
- if value_allowed_types or value_list_allowed_types
1534
- )
1535
- )
1536
-
1537
-
1538
- def _count_characters_raw_data(data: str | List[str]) -> int:
1512
+ def _count_characters_raw_data(data: str | list[str]) -> int:
1539
1513
  character_count = 0
1540
1514
  if isinstance(data, str):
1541
1515
  character_count = len(data)
@@ -1551,8 +1525,14 @@ def _count_characters_raw_data(data: str | List[str]) -> int:
1551
1525
 
1552
1526
 
1553
1527
  def add_to_column_count_dictionary(
1554
- column_dictionary: Dict[str, int], col: str | None
1555
- ):
1528
+ column_dictionary: dict[str, int], col: str | None
1529
+ ) -> None:
1530
+ """Increment the count for a column name in a dictionary.
1531
+
1532
+ Args:
1533
+ column_dictionary: Dictionary mapping column names to counts.
1534
+ col: The column name to increment, or None to skip.
1535
+ """
1556
1536
  if col:
1557
1537
  if col in column_dictionary:
1558
1538
  column_dictionary[col] += 1
@@ -1560,7 +1540,9 @@ def add_to_column_count_dictionary(
1560
1540
  column_dictionary[col] = 1
1561
1541
 
1562
1542
 
1563
- def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
1543
+ def _validate_bounding_box_coordinates(
1544
+ bounding_box_coordinates: list[float],
1545
+ ) -> None:
1564
1546
  if not is_list_of(bounding_box_coordinates, float):
1565
1547
  raise TypeError(
1566
1548
  "Each bounding box's coordinates must be a lists of floats"
@@ -1586,10 +1568,12 @@ def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
1586
1568
  f"top-left. Found {bounding_box_coordinates}"
1587
1569
  )
1588
1570
 
1589
- return None
1571
+ return
1590
1572
 
1591
1573
 
1592
- def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
1574
+ def _validate_polygon_coordinates(
1575
+ polygon_coordinates: list[list[float]],
1576
+ ) -> None:
1593
1577
  if not is_list_of(polygon_coordinates, list):
1594
1578
  raise TypeError("Polygon coordinates must be a list of lists of floats")
1595
1579
  for coordinates in polygon_coordinates:
@@ -1651,27 +1635,41 @@ def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
1651
1635
  f"{coordinates}"
1652
1636
  )
1653
1637
 
1654
- return None
1638
+ return
1655
1639
 
1656
1640
 
1657
- def segments_intersect(p1, p2, p3, p4):
1658
- """
1659
- Check if two line segments intersect.
1641
+ def segments_intersect(
1642
+ p1: tuple[float, float],
1643
+ p2: tuple[float, float],
1644
+ p3: tuple[float, float],
1645
+ p4: tuple[float, float],
1646
+ ) -> bool:
1647
+ """Check if two line segments intersect.
1660
1648
 
1661
1649
  Args:
1662
- p1, p2: First line segment endpoints (x,y)
1663
- p3, p4: Second line segment endpoints (x,y)
1650
+ p1: First endpoint of the first line segment (x,y)
1651
+ p2: Second endpoint of the first line segment (x,y)
1652
+ p3: First endpoint of the second line segment (x,y)
1653
+ p4: Second endpoint of the second line segment (x,y)
1664
1654
 
1665
1655
  Returns:
1666
1656
  True if the line segments intersect, False otherwise
1667
1657
  """
1668
1658
 
1669
1659
  # Function to calculate direction
1670
- def orientation(p, q, r):
1660
+ def orientation(
1661
+ p: tuple[float, float],
1662
+ q: tuple[float, float],
1663
+ r: tuple[float, float],
1664
+ ) -> float:
1671
1665
  return (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1])
1672
1666
 
1673
1667
  # Function to check if point q is on segment pr
1674
- def on_segment(p, q, r):
1668
+ def on_segment(
1669
+ p: tuple[float, float],
1670
+ q: tuple[float, float],
1671
+ r: tuple[float, float],
1672
+ ) -> bool:
1675
1673
  return (
1676
1674
  q[0] <= max(p[0], r[0])
1677
1675
  and q[0] >= min(p[0], r[0])
@@ -1703,17 +1701,20 @@ def segments_intersect(p1, p2, p3, p4):
1703
1701
 
1704
1702
  @unique
1705
1703
  class StatusCodes(Enum):
1704
+ """Enum representing status codes for operations and responses."""
1705
+
1706
1706
  UNSET = 0
1707
1707
  OK = 1
1708
1708
  ERROR = 2
1709
1709
 
1710
1710
  @classmethod
1711
- def list_codes(cls):
1711
+ def list_codes(cls) -> list[str]:
1712
+ """Return a list of all status code names."""
1712
1713
  return [t.name for t in cls]
1713
1714
 
1714
1715
 
1715
- def convert_element(value):
1716
- """Converts scalar or array to python native"""
1716
+ def convert_element(value: object) -> object:
1717
+ """Converts scalar or array to python native."""
1717
1718
  val = getattr(value, "tolist", lambda: value)()
1718
1719
  # Check if it's a list since elements from pd indices are converted to a
1719
1720
  # scalar whereas pd series/dataframe elements are converted to list of 1
@@ -1734,7 +1735,7 @@ PredictionLabelTypes = (
1734
1735
  | bool
1735
1736
  | int
1736
1737
  | float
1737
- | Tuple[str, float]
1738
+ | tuple[str, float]
1738
1739
  | ObjectDetectionLabel
1739
1740
  | RankingPredictionLabel
1740
1741
  | MultiClassPredictionLabel
@@ -1745,7 +1746,7 @@ ActualLabelTypes = (
1745
1746
  | bool
1746
1747
  | int
1747
1748
  | float
1748
- | Tuple[str, float]
1749
+ | tuple[str, float]
1749
1750
  | ObjectDetectionLabel
1750
1751
  | RankingActualLabel
1751
1752
  | MultiClassActualLabel