arize-8.0.0a22-py3-none-any.whl → arize-8.0.0a23-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +268 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +389 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a22.dist-info/RECORD +0 -146
  166. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/types.py CHANGED
@@ -1,19 +1,17 @@
1
+ """Common type definitions and data models used across the Arize SDK."""
2
+
1
3
  import json
2
4
  import logging
3
5
  import math
6
+ from collections.abc import Iterable, Iterator, Sequence
4
7
  from dataclasses import asdict, dataclass, replace
5
8
  from datetime import datetime
6
9
  from decimal import Decimal
7
10
  from enum import Enum, unique
8
11
  from itertools import chain
9
12
  from typing import (
10
- Dict,
11
- Iterable,
12
- List,
13
13
  NamedTuple,
14
- Sequence,
15
- Set,
16
- Tuple,
14
+ Self,
17
15
  TypeVar,
18
16
  )
19
17
 
@@ -48,6 +46,8 @@ logger = logging.getLogger(__name__)
48
46
 
49
47
  @unique
50
48
  class ModelTypes(Enum):
49
+ """Enum representing supported model types in Arize."""
50
+
51
51
  NUMERIC = 1
52
52
  SCORE_CATEGORICAL = 2
53
53
  RANKING = 3
@@ -58,7 +58,8 @@ class ModelTypes(Enum):
58
58
  MULTI_CLASS = 8
59
59
 
60
60
  @classmethod
61
- def list_types(cls):
61
+ def list_types(cls) -> list[str]:
62
+ """Return a list of all type names in this enum."""
62
63
  return [t.name for t in cls]
63
64
 
64
65
 
@@ -70,7 +71,10 @@ CATEGORICAL_MODEL_TYPES = [
70
71
 
71
72
 
72
73
  class DocEnum(Enum):
73
- def __new__(cls, value, doc=None):
74
+ """Enum subclass supporting inline documentation for enum members."""
75
+
76
+ def __new__(cls, value: object, doc: str | None = None) -> Self:
77
+ """Create a new enum instance with optional documentation."""
74
78
  self = object.__new__(
75
79
  cls
76
80
  ) # calling super().__new__(value) here would fail
@@ -80,13 +84,13 @@ class DocEnum(Enum):
80
84
  return self
81
85
 
82
86
  def __repr__(self) -> str:
87
+ """Return a string representation including documentation."""
83
88
  return f"{self.name} metrics include: {self.__doc__}"
84
89
 
85
90
 
86
91
  @unique
87
92
  class Metrics(DocEnum):
88
- """
89
- Metric groupings, used for validation of schema columns in log() call.
93
+ """Metric groupings, used for validation of schema columns in log() call.
90
94
 
91
95
  See docstring descriptions of the Enum with __doc__ or __repr__(), e.g.:
92
96
  Metrics.RANKING.__doc__
@@ -105,6 +109,8 @@ class Metrics(DocEnum):
105
109
 
106
110
  @unique
107
111
  class Environments(Enum):
112
+ """Enum representing deployment environments for models."""
113
+
108
114
  TRAINING = 1
109
115
  VALIDATION = 2
110
116
  PRODUCTION = 3
@@ -114,11 +120,18 @@ class Environments(Enum):
114
120
 
115
121
  @dataclass
116
122
  class EmbeddingColumnNames:
123
+ """Column names for embedding feature data."""
124
+
117
125
  vector_column_name: str = ""
118
126
  data_column_name: str | None = None
119
127
  link_to_data_column_name: str | None = None
120
128
 
121
- def __post_init__(self):
129
+ def __post_init__(self) -> None:
130
+ """Validate that vector column name is specified.
131
+
132
+ Raises:
133
+ ValueError: If vector_column_name is empty.
134
+ """
122
135
  if not self.vector_column_name:
123
136
  raise ValueError(
124
137
  "embedding_features require a vector to be specified. You can "
@@ -126,7 +139,8 @@ class EmbeddingColumnNames:
126
139
  "(from arize.pandas.embeddings) if you do not have them"
127
140
  )
128
141
 
129
- def __iter__(self):
142
+ def __iter__(self) -> Iterator[str | None]:
143
+ """Iterate over the embedding column names."""
130
144
  return iter(
131
145
  (
132
146
  self.vector_column_name,
@@ -137,14 +151,16 @@ class EmbeddingColumnNames:
137
151
 
138
152
 
139
153
  class Embedding(NamedTuple):
140
- vector: List[float]
141
- data: str | List[str] | None = None
154
+ """Container for embedding vector data with optional raw data and links."""
155
+
156
+ vector: list[float]
157
+ data: str | list[str] | None = None
142
158
  link_to_data: str | None = None
143
159
 
144
160
  def validate(self, emb_name: str | int | float) -> None:
145
- """
146
- Validates that the embedding object passed is of the correct format.
147
- That is, validations must be passed for vector, data & link_to_data.
161
+ """Validates that the embedding object passed is of the correct format.
162
+
163
+ Ensures validations are passed for vector, data, and link_to_data fields.
148
164
 
149
165
  Arguments:
150
166
  ---------
@@ -167,19 +183,16 @@ class Embedding(NamedTuple):
167
183
  if self.link_to_data is not None:
168
184
  self._validate_embedding_link_to_data(emb_name, self.link_to_data)
169
185
 
170
- return None
186
+ return
171
187
 
172
188
  def _validate_embedding_vector(
173
189
  self,
174
190
  emb_name: str | int | float,
175
191
  ) -> None:
176
- """
177
- Validates that the embedding vector passed is of the correct format.
178
- That is:
179
- 1. Type must be list or convertible to list (like numpy arrays,
180
- pandas Series)
181
- 2. List must not be empty
182
- 3. Elements in list must be floats
192
+ """Validates that the embedding vector passed is of the correct format.
193
+
194
+ Requirements: 1) Type must be list or convertible to list (like numpy arrays,
195
+ pandas Series), 2) List must not be empty, 3) Elements in list must be floats.
183
196
 
184
197
  Arguments:
185
198
  ---------
@@ -209,11 +222,11 @@ class Embedding(NamedTuple):
209
222
 
210
223
  @staticmethod
211
224
  def _validate_embedding_data(
212
- emb_name: str | int | float, data: str | List[str]
225
+ emb_name: str | int | float, data: str | list[str]
213
226
  ) -> None:
214
- """
215
- Validates that the embedding raw data field is of the correct format. That is:
216
- 1. Must be string or list of strings (NLP case)
227
+ """Validates that the embedding raw data field is of the correct format.
228
+
229
+ Requirement: Must be string or list of strings (NLP case).
217
230
 
218
231
  Arguments:
219
232
  ---------
@@ -247,7 +260,7 @@ class Embedding(NamedTuple):
247
260
  f"Embedding data field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
248
261
  f"Found {character_count}."
249
262
  )
250
- elif character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
263
+ if character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
251
264
  logger.warning(
252
265
  get_truncation_warning_message(
253
266
  "Embedding raw data fields",
@@ -259,9 +272,9 @@ class Embedding(NamedTuple):
259
272
  def _validate_embedding_link_to_data(
260
273
  emb_name: str | int | float, link_to_data: str
261
274
  ) -> None:
262
- """
263
- Validates that the embedding link to data field is of the correct format. That is:
264
- 1. Must be string
275
+ """Validates that the embedding link to data field is of the correct format.
276
+
277
+ Requirement: Must be string.
265
278
 
266
279
  Arguments:
267
280
  ---------
@@ -282,13 +295,11 @@ class Embedding(NamedTuple):
282
295
 
283
296
  @staticmethod
284
297
  def _is_valid_iterable(
285
- data: str | List[str] | List[float] | np.ndarray,
298
+ data: str | list[str] | list[float] | np.ndarray,
286
299
  ) -> bool:
287
- """
288
- Validates that the input data field is of the correct iterable type. That is:
289
- 1. List or
290
- 2. numpy array or
291
- 3. pandas Series
300
+ """Validates that the input data field is of the correct iterable type.
301
+
302
+ Accepted types: 1) List, 2) numpy array, or 3) pandas Series.
292
303
 
293
304
  Arguments:
294
305
  ---------
@@ -327,12 +338,15 @@ class Embedding(NamedTuple):
327
338
 
328
339
 
329
340
  class LLMRunMetadata(NamedTuple):
341
+ """Metadata for LLM execution including token counts and latency."""
342
+
330
343
  total_token_count: int | None = None
331
344
  prompt_token_count: int | None = None
332
345
  response_token_count: int | None = None
333
346
  response_latency_ms: int | float | None = None
334
347
 
335
348
  def validate(self) -> None:
349
+ """Validate the field values and constraints."""
336
350
  allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
337
351
  if not isinstance(self.total_token_count, allowed_types):
338
352
  raise InvalidValueType(
@@ -361,9 +375,9 @@ class LLMRunMetadata(NamedTuple):
361
375
 
362
376
 
363
377
  class ObjectDetectionColumnNames(NamedTuple):
364
- """
365
- Used to log object detection prediction and actual values that are assigned to the prediction or
366
- actual schema parameter.
378
+ """Used to log object detection prediction and actual values.
379
+
380
+ These values are assigned to the prediction or actual schema parameter.
367
381
 
368
382
  Arguments:
369
383
  ---------
@@ -385,9 +399,9 @@ class ObjectDetectionColumnNames(NamedTuple):
385
399
 
386
400
 
387
401
  class SemanticSegmentationColumnNames(NamedTuple):
388
- """
389
- Used to log semantic segmentation prediction and actual values that are assigned to the prediction or
390
- actual schema parameter.
402
+ """Used to log semantic segmentation prediction and actual values.
403
+
404
+ These values are assigned to the prediction or actual schema parameter.
391
405
 
392
406
  Arguments:
393
407
  ---------
@@ -405,8 +419,7 @@ class SemanticSegmentationColumnNames(NamedTuple):
405
419
 
406
420
 
407
421
  class InstanceSegmentationPredictionColumnNames(NamedTuple):
408
- """
409
- Used to log instance segmentation prediction values that are assigned to the prediction schema parameter.
422
+ """Used to log instance segmentation prediction values for the prediction schema parameter.
410
423
 
411
424
  Arguments:
412
425
  ---------
@@ -433,8 +446,7 @@ class InstanceSegmentationPredictionColumnNames(NamedTuple):
433
446
 
434
447
 
435
448
  class InstanceSegmentationActualColumnNames(NamedTuple):
436
- """
437
- Used to log instance segmentation actual values that are assigned to the actual schema parameter.
449
+ """Used to log instance segmentation actual values that are assigned to the actual schema parameter.
438
450
 
439
451
  Arguments:
440
452
  ---------
@@ -455,12 +467,15 @@ class InstanceSegmentationActualColumnNames(NamedTuple):
455
467
 
456
468
 
457
469
  class ObjectDetectionLabel(NamedTuple):
458
- bounding_boxes_coordinates: List[List[float]]
459
- categories: List[str]
470
+ """Label data for object detection tasks with bounding boxes and categories."""
471
+
472
+ bounding_boxes_coordinates: list[list[float]]
473
+ categories: list[str]
460
474
  # Actual Object Detection Labels won't have scores
461
- scores: List[float] | None = None
475
+ scores: list[float] | None = None
462
476
 
463
- def validate(self, prediction_or_actual: str):
477
+ def validate(self, prediction_or_actual: str) -> None:
478
+ """Validate the object detection label fields and constraints."""
464
479
  # Validate bounding boxes
465
480
  self._validate_bounding_boxes_coordinates()
466
481
  # Validate categories
@@ -470,7 +485,7 @@ class ObjectDetectionLabel(NamedTuple):
470
485
  # Validate we have the same number of bounding boxes, categories and scores
471
486
  self._validate_count_match()
472
487
 
473
- def _validate_bounding_boxes_coordinates(self):
488
+ def _validate_bounding_boxes_coordinates(self) -> None:
474
489
  if not is_list_of(self.bounding_boxes_coordinates, list):
475
490
  raise TypeError(
476
491
  "Object Detection Label bounding boxes must be a list of lists of floats"
@@ -478,14 +493,14 @@ class ObjectDetectionLabel(NamedTuple):
478
493
  for coordinates in self.bounding_boxes_coordinates:
479
494
  _validate_bounding_box_coordinates(coordinates)
480
495
 
481
- def _validate_categories(self):
496
+ def _validate_categories(self) -> None:
482
497
  # Allows for categories as empty strings
483
498
  if not is_list_of(self.categories, str):
484
499
  raise TypeError(
485
500
  "Object Detection Label categories must be a list of strings"
486
501
  )
487
502
 
488
- def _validate_scores(self, prediction_or_actual: str):
503
+ def _validate_scores(self, prediction_or_actual: str) -> None:
489
504
  if self.scores is None:
490
505
  if prediction_or_actual == "prediction":
491
506
  raise ValueError(
@@ -507,7 +522,7 @@ class ObjectDetectionLabel(NamedTuple):
507
522
  f"{self.scores}"
508
523
  )
509
524
 
510
- def _validate_count_match(self):
525
+ def _validate_count_match(self) -> None:
511
526
  n_bounding_boxes = len(self.bounding_boxes_coordinates)
512
527
  if n_bounding_boxes == 0:
513
528
  raise ValueError(
@@ -534,10 +549,13 @@ class ObjectDetectionLabel(NamedTuple):
534
549
 
535
550
 
536
551
  class SemanticSegmentationLabel(NamedTuple):
537
- polygon_coordinates: List[List[float]]
538
- categories: List[str]
552
+ """Label data for semantic segmentation with polygon coordinates and categories."""
539
553
 
540
- def validate(self):
554
+ polygon_coordinates: list[list[float]]
555
+ categories: list[str]
556
+
557
+ def validate(self) -> None:
558
+ """Validate the field values and constraints."""
541
559
  # Validate polygon coordinates
542
560
  self._validate_polygon_coordinates()
543
561
  # Validate categories
@@ -545,17 +563,17 @@ class SemanticSegmentationLabel(NamedTuple):
545
563
  # Validate we have the same number of polygon coordinates and categories
546
564
  self._validate_count_match()
547
565
 
548
- def _validate_polygon_coordinates(self):
566
+ def _validate_polygon_coordinates(self) -> None:
549
567
  _validate_polygon_coordinates(self.polygon_coordinates)
550
568
 
551
- def _validate_categories(self):
569
+ def _validate_categories(self) -> None:
552
570
  # Allows for categories as empty strings
553
571
  if not is_list_of(self.categories, str):
554
572
  raise TypeError(
555
573
  "Semantic Segmentation Label categories must be a list of strings"
556
574
  )
557
575
 
558
- def _validate_count_match(self):
576
+ def _validate_count_match(self) -> None:
559
577
  n_polygon_coordinates = len(self.polygon_coordinates)
560
578
  if n_polygon_coordinates == 0:
561
579
  raise ValueError(
@@ -573,12 +591,15 @@ class SemanticSegmentationLabel(NamedTuple):
573
591
 
574
592
 
575
593
  class InstanceSegmentationPredictionLabel(NamedTuple):
576
- polygon_coordinates: List[List[float]]
577
- categories: List[str]
578
- scores: List[float] | None = None
579
- bounding_boxes_coordinates: List[List[float]] | None = None
594
+ """Prediction label for instance segmentation with polygons and category information."""
595
+
596
+ polygon_coordinates: list[list[float]]
597
+ categories: list[str]
598
+ scores: list[float] | None = None
599
+ bounding_boxes_coordinates: list[list[float]] | None = None
580
600
 
581
- def validate(self):
601
+ def validate(self) -> None:
602
+ """Validate the field values and constraints."""
582
603
  # Validate polygon coordinates
583
604
  self._validate_polygon_coordinates()
584
605
  # Validate categories
@@ -590,17 +611,17 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
590
611
  # Validate we have the same number of polygon coordinates and categories
591
612
  self._validate_count_match()
592
613
 
593
- def _validate_polygon_coordinates(self):
614
+ def _validate_polygon_coordinates(self) -> None:
594
615
  _validate_polygon_coordinates(self.polygon_coordinates)
595
616
 
596
- def _validate_categories(self):
617
+ def _validate_categories(self) -> None:
597
618
  # Allows for categories as empty strings
598
619
  if not is_list_of(self.categories, str):
599
620
  raise TypeError(
600
621
  "Instance Segmentation Prediction Label categories must be a list of strings"
601
622
  )
602
623
 
603
- def _validate_scores(self):
624
+ def _validate_scores(self) -> None:
604
625
  if self.scores is not None:
605
626
  if not is_list_of(self.scores, float):
606
627
  raise TypeError(
@@ -613,7 +634,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
613
634
  f"{self.scores}"
614
635
  )
615
636
 
616
- def _validate_bounding_boxes(self):
637
+ def _validate_bounding_boxes(self) -> None:
617
638
  if self.bounding_boxes_coordinates is not None:
618
639
  if not is_list_of(self.bounding_boxes_coordinates, list):
619
640
  raise TypeError(
@@ -622,7 +643,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
622
643
  for coordinates in self.bounding_boxes_coordinates:
623
644
  _validate_bounding_box_coordinates(coordinates)
624
645
 
625
- def _validate_count_match(self):
646
+ def _validate_count_match(self) -> None:
626
647
  n_polygon_coordinates = len(self.polygon_coordinates)
627
648
  if n_polygon_coordinates == 0:
628
649
  raise ValueError(
@@ -657,11 +678,14 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
657
678
 
658
679
 
659
680
  class InstanceSegmentationActualLabel(NamedTuple):
660
- polygon_coordinates: List[List[float]]
661
- categories: List[str]
662
- bounding_boxes_coordinates: List[List[float]] | None = None
681
+ """Actual label for instance segmentation with polygon coordinates and categories."""
682
+
683
+ polygon_coordinates: list[list[float]]
684
+ categories: list[str]
685
+ bounding_boxes_coordinates: list[list[float]] | None = None
663
686
 
664
- def validate(self):
687
+ def validate(self) -> None:
688
+ """Validate the field values and constraints."""
665
689
  # Validate polygon coordinates
666
690
  self._validate_polygon_coordinates()
667
691
  # Validate categories
@@ -671,17 +695,17 @@ class InstanceSegmentationActualLabel(NamedTuple):
671
695
  # Validate we have the same number of polygon coordinates and categories
672
696
  self._validate_count_match()
673
697
 
674
- def _validate_polygon_coordinates(self):
698
+ def _validate_polygon_coordinates(self) -> None:
675
699
  _validate_polygon_coordinates(self.polygon_coordinates)
676
700
 
677
- def _validate_categories(self):
701
+ def _validate_categories(self) -> None:
678
702
  # Allows for categories as empty strings
679
703
  if not is_list_of(self.categories, str):
680
704
  raise TypeError(
681
705
  "Instance Segmentation Actual Label categories must be a list of strings"
682
706
  )
683
707
 
684
- def _validate_bounding_boxes(self):
708
+ def _validate_bounding_boxes(self) -> None:
685
709
  if self.bounding_boxes_coordinates is not None:
686
710
  if not is_list_of(self.bounding_boxes_coordinates, list):
687
711
  raise TypeError(
@@ -690,7 +714,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
690
714
  for coordinates in self.bounding_boxes_coordinates:
691
715
  _validate_bounding_box_coordinates(coordinates)
692
716
 
693
- def _validate_count_match(self):
717
+ def _validate_count_match(self) -> None:
694
718
  n_polygon_coordinates = len(self.polygon_coordinates)
695
719
  if n_polygon_coordinates == 0:
696
720
  raise ValueError(
@@ -717,8 +741,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
717
741
 
718
742
 
719
743
  class MultiClassPredictionLabel(NamedTuple):
720
- """
721
- Used to log multi class prediction label
744
+ """Used to log multi class prediction label.
722
745
 
723
746
  Arguments:
724
747
  ---------
@@ -729,15 +752,16 @@ class MultiClassPredictionLabel(NamedTuple):
729
752
 
730
753
  """
731
754
 
732
- prediction_scores: Dict[str, float | int]
733
- threshold_scores: Dict[str, float | int] | None = None
755
+ prediction_scores: dict[str, float | int]
756
+ threshold_scores: dict[str, float | int] | None = None
734
757
 
735
- def validate(self):
758
+ def validate(self) -> None:
759
+ """Validate the field values and constraints."""
736
760
  # Validate scores
737
761
  self._validate_prediction_scores()
738
762
  self._validate_threshold_scores()
739
763
 
740
- def _validate_prediction_scores(self):
764
+ def _validate_prediction_scores(self) -> None:
741
765
  # prediction dictionary validations
742
766
  if not is_dict_of(
743
767
  self.prediction_scores,
@@ -778,7 +802,7 @@ class MultiClassPredictionLabel(NamedTuple):
778
802
  "invalid. All scores (values in dictionary) must be between 0 and 1, inclusive."
779
803
  )
780
804
 
781
- def _validate_threshold_scores(self):
805
+ def _validate_threshold_scores(self) -> None:
782
806
  if self.threshold_scores is None or len(self.threshold_scores) == 0:
783
807
  return
784
808
  if not is_dict_of(
@@ -822,8 +846,7 @@ class MultiClassPredictionLabel(NamedTuple):
822
846
 
823
847
 
824
848
  class MultiClassActualLabel(NamedTuple):
825
- """
826
- Used to log multi class actual label
849
+ """Used to log multi class actual label.
827
850
 
828
851
  Arguments:
829
852
  ---------
@@ -833,13 +856,14 @@ class MultiClassActualLabel(NamedTuple):
833
856
 
834
857
  """
835
858
 
836
- actual_scores: Dict[str, float | int]
859
+ actual_scores: dict[str, float | int]
837
860
 
838
- def validate(self):
861
+ def validate(self) -> None:
862
+ """Validate the field values and constraints."""
839
863
  # Validate scores
840
864
  self._validate_actual_scores()
841
865
 
842
- def _validate_actual_scores(self):
866
+ def _validate_actual_scores(self) -> None:
843
867
  if not is_dict_of(
844
868
  self.actual_scores,
845
869
  key_allowed_types=str,
@@ -879,12 +903,15 @@ class MultiClassActualLabel(NamedTuple):
879
903
 
880
904
 
881
905
  class RankingPredictionLabel(NamedTuple):
906
+ """Prediction label for ranking tasks with group and rank information."""
907
+
882
908
  group_id: str
883
909
  rank: int
884
910
  score: float | None = None
885
911
  label: str | None = None
886
912
 
887
- def validate(self):
913
+ def validate(self) -> None:
914
+ """Validate the field values and constraints."""
888
915
  # Validate existence of required fields: prediction_group_id and rank
889
916
  if self.group_id is None or self.rank is None:
890
917
  raise ValueError(
@@ -901,7 +928,7 @@ class RankingPredictionLabel(NamedTuple):
901
928
  if self.score is not None:
902
929
  self._validate_score()
903
930
 
904
- def _validate_group_id(self):
931
+ def _validate_group_id(self) -> None:
905
932
  if not isinstance(self.group_id, str):
906
933
  raise TypeError("Prediction Group ID must be a string")
907
934
  if not (1 <= len(self.group_id) <= 36):
@@ -909,7 +936,7 @@ class RankingPredictionLabel(NamedTuple):
909
936
  f"Prediction Group ID must have length between 1 and 36. Found {len(self.group_id)}"
910
937
  )
911
938
 
912
- def _validate_rank(self):
939
+ def _validate_rank(self) -> None:
913
940
  if not isinstance(self.rank, int):
914
941
  raise TypeError("Prediction Rank must be an int")
915
942
  if not (1 <= self.rank <= 100):
@@ -917,22 +944,25 @@ class RankingPredictionLabel(NamedTuple):
917
944
  f"Prediction Rank must be between 1 and 100, inclusive. Found {self.rank}"
918
945
  )
919
946
 
920
- def _validate_label(self):
947
+ def _validate_label(self) -> None:
921
948
  if not isinstance(self.label, str):
922
949
  raise TypeError("Prediction Label must be a str")
923
950
  if self.label == "":
924
951
  raise ValueError("Prediction Label must not be an empty string.")
925
952
 
926
- def _validate_score(self):
953
+ def _validate_score(self) -> None:
927
954
  if not isinstance(self.score, (float, int)):
928
955
  raise TypeError("Prediction Score must be a float or an int")
929
956
 
930
957
 
931
958
  class RankingActualLabel(NamedTuple):
932
- relevance_labels: List[str] | None = None
959
+ """Actual label for ranking tasks with relevance information."""
960
+
961
+ relevance_labels: list[str] | None = None
933
962
  relevance_score: float | None = None
934
963
 
935
- def validate(self):
964
+ def validate(self) -> None:
965
+ """Validate the field values and constraints."""
936
966
  # Validate relevance_labels type
937
967
  if self.relevance_labels is not None:
938
968
  self._validate_relevance_labels(self.relevance_labels)
@@ -941,7 +971,7 @@ class RankingActualLabel(NamedTuple):
941
971
  self._validate_relevance_score(self.relevance_score)
942
972
 
943
973
  @staticmethod
944
- def _validate_relevance_labels(relevance_labels: List[str]):
974
+ def _validate_relevance_labels(relevance_labels: list[str]) -> None:
945
975
  if not is_list_of(relevance_labels, str):
946
976
  raise TypeError("Actual Relevance Labels must be a list of strings")
947
977
  if any(label == "" for label in relevance_labels):
@@ -950,17 +980,20 @@ class RankingActualLabel(NamedTuple):
950
980
  )
951
981
 
952
982
  @staticmethod
953
- def _validate_relevance_score(relevance_score: float):
983
+ def _validate_relevance_score(relevance_score: float) -> None:
954
984
  if not isinstance(relevance_score, (float, int)):
955
985
  raise TypeError("Actual Relevance score must be a float or an int")
956
986
 
957
987
 
958
988
  @dataclass
959
989
  class PromptTemplateColumnNames:
990
+ """Column names for prompt template configuration in LLM schemas."""
991
+
960
992
  template_column_name: str | None = None
961
993
  template_version_column_name: str | None = None
962
994
 
963
- def __iter__(self):
995
+ def __iter__(self) -> Iterator[str | None]:
996
+ """Iterate over the prompt template column names."""
964
997
  return iter(
965
998
  (self.template_column_name, self.template_version_column_name)
966
999
  )
@@ -968,21 +1001,27 @@ class PromptTemplateColumnNames:
968
1001
 
969
1002
  @dataclass
970
1003
  class LLMConfigColumnNames:
1004
+ """Column names for LLM configuration parameters in schemas."""
1005
+
971
1006
  model_column_name: str | None = None
972
1007
  params_column_name: str | None = None
973
1008
 
974
- def __iter__(self):
1009
+ def __iter__(self) -> Iterator[str | None]:
1010
+ """Iterate over the LLM config column names."""
975
1011
  return iter((self.model_column_name, self.params_column_name))
976
1012
 
977
1013
 
978
1014
  @dataclass
979
1015
  class LLMRunMetadataColumnNames:
1016
+ """Column names for LLM run metadata fields in schemas."""
1017
+
980
1018
  total_token_count_column_name: str | None = None
981
1019
  prompt_token_count_column_name: str | None = None
982
1020
  response_token_count_column_name: str | None = None
983
1021
  response_latency_ms_column_name: str | None = None
984
1022
 
985
- def __iter__(self):
1023
+ def __iter__(self) -> Iterator[str | None]:
1024
+ """Iterate over the LLM run metadata column names."""
986
1025
  return iter(
987
1026
  (
988
1027
  self.total_token_count_column_name,
@@ -1011,11 +1050,19 @@ class LLMRunMetadataColumnNames:
1011
1050
  #
1012
1051
  @dataclass
1013
1052
  class SimilarityReference:
1053
+ """Reference to a prediction for similarity search operations."""
1054
+
1014
1055
  prediction_id: str
1015
1056
  reference_column_name: str
1016
1057
  prediction_timestamp: datetime | None = None
1017
1058
 
1018
- def __post_init__(self):
1059
+ def __post_init__(self) -> None:
1060
+ """Validate similarity reference fields after initialization.
1061
+
1062
+ Raises:
1063
+ ValueError: If prediction_id or reference_column_name is empty.
1064
+ TypeError: If prediction_timestamp is not a datetime object.
1065
+ """
1019
1066
  if self.prediction_id == "":
1020
1067
  raise ValueError("prediction id cannot be empty")
1021
1068
  if self.reference_column_name == "":
@@ -1028,11 +1075,20 @@ class SimilarityReference:
1028
1075
 
1029
1076
  @dataclass
1030
1077
  class SimilaritySearchParams:
1031
- references: List[SimilarityReference]
1078
+ """Parameters for configuring similarity search operations."""
1079
+
1080
+ references: list[SimilarityReference]
1032
1081
  search_column_name: str
1033
1082
  threshold: float = 0
1034
1083
 
1035
- def __post_init__(self):
1084
+ def __post_init__(self) -> None:
1085
+ """Validate similarity search parameters after initialization.
1086
+
1087
+ Raises:
1088
+ ValueError: If references list is invalid, search_column_name is
1089
+ empty, or threshold is out of range.
1090
+ TypeError: If any reference is not a SimilarityReference instance.
1091
+ """
1036
1092
  if (
1037
1093
  not self.references
1038
1094
  or len(self.references) <= 0
@@ -1054,23 +1110,28 @@ class SimilaritySearchParams:
1054
1110
 
1055
1111
  @dataclass(frozen=True)
1056
1112
  class BaseSchema:
1057
- def replace(self, **changes):
1113
+ """Base class for all schema definitions with immutable fields."""
1114
+
1115
+ def replace(self, **changes: object) -> Self:
1116
+ """Return a new instance with specified fields replaced."""
1058
1117
  return replace(self, **changes)
1059
1118
 
1060
- def asdict(self) -> Dict[str, str]:
1119
+ def asdict(self) -> dict[str, str]:
1120
+ """Convert the schema to a dictionary."""
1061
1121
  return asdict(self)
1062
1122
 
1063
- def get_used_columns(self) -> Set[str]:
1123
+ def get_used_columns(self) -> set[str]:
1124
+ """Return the set of column names used in this schema."""
1064
1125
  return set(self.get_used_columns_counts().keys())
1065
1126
 
1066
- def get_used_columns_counts(self) -> Dict[str, int]:
1127
+ def get_used_columns_counts(self) -> dict[str, int]:
1128
+ """Return a dict mapping column names to their usage count."""
1067
1129
  raise NotImplementedError()
1068
1130
 
1069
1131
 
1070
1132
  @dataclass(frozen=True)
1071
1133
  class TypedColumns:
1072
- """
1073
- Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
1134
+ """Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
1074
1135
 
1075
1136
  Usage:
1076
1137
  ------
@@ -1101,30 +1162,31 @@ class TypedColumns:
1101
1162
 
1102
1163
  """
1103
1164
 
1104
- inferred: List[str] | None = None
1105
- to_str: List[str] | None = None
1106
- to_int: List[str] | None = None
1107
- to_float: List[str] | None = None
1165
+ inferred: list[str] | None = None
1166
+ to_str: list[str] | None = None
1167
+ to_int: list[str] | None = None
1168
+ to_float: list[str] | None = None
1108
1169
 
1109
- def get_all_column_names(self) -> List[str]:
1170
+ def get_all_column_names(self) -> list[str]:
1171
+ """Return all column names across all conversion lists."""
1110
1172
  return list(chain.from_iterable(filter(None, self.__dict__.values())))
1111
1173
 
1112
- def has_duplicate_columns(self) -> Tuple[bool, Set[str]]:
1174
+ def has_duplicate_columns(self) -> tuple[bool, set[str]]:
1175
+ """Check for duplicate columns and return (has_duplicates, duplicate_set)."""
1113
1176
  # True if there are duplicates within a field's list or across fields.
1114
1177
  # Return a set of the duplicate column names.
1115
1178
  cols = self.get_all_column_names()
1116
- duplicates = set([x for x in cols if cols.count(x) > 1])
1179
+ duplicates = {x for x in cols if cols.count(x) > 1}
1117
1180
  return len(duplicates) > 0, duplicates
1118
1181
 
1119
1182
  def is_empty(self) -> bool:
1183
+ """Return True if no columns are configured for conversion."""
1120
1184
  return not self.get_all_column_names()
1121
1185
 
1122
1186
 
1123
1187
  @dataclass(frozen=True)
1124
1188
  class Schema(BaseSchema):
1125
- """
1126
- Used to organize and map column names containing model data within your Pandas dataframe to
1127
- Arize.
1189
+ """Used to organize and map column names containing model data within your Pandas dataframe to Arize.
1128
1190
 
1129
1191
  Arguments:
1130
1192
  ---------
@@ -1215,15 +1277,15 @@ class Schema(BaseSchema):
1215
1277
  """
1216
1278
 
1217
1279
  prediction_id_column_name: str | None = None
1218
- feature_column_names: List[str] | TypedColumns | None = None
1219
- tag_column_names: List[str] | TypedColumns | None = None
1280
+ feature_column_names: list[str] | TypedColumns | None = None
1281
+ tag_column_names: list[str] | TypedColumns | None = None
1220
1282
  timestamp_column_name: str | None = None
1221
1283
  prediction_label_column_name: str | None = None
1222
1284
  prediction_score_column_name: str | None = None
1223
1285
  actual_label_column_name: str | None = None
1224
1286
  actual_score_column_name: str | None = None
1225
- shap_values_column_names: Dict[str, str] | None = None
1226
- embedding_feature_column_names: Dict[str, EmbeddingColumnNames] | None = (
1287
+ shap_values_column_names: dict[str, str] | None = None
1288
+ embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
1227
1289
  None # type:ignore
1228
1290
  )
1229
1291
  prediction_group_id_column_name: str | None = None
@@ -1242,7 +1304,7 @@ class Schema(BaseSchema):
1242
1304
  prompt_template_column_names: PromptTemplateColumnNames | None = None
1243
1305
  llm_config_column_names: LLMConfigColumnNames | None = None
1244
1306
  llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
1245
- retrieved_document_ids_column_name: List[str] | None = None
1307
+ retrieved_document_ids_column_name: list[str] | None = None
1246
1308
  multi_class_threshold_scores_column_name: str | None = None
1247
1309
  semantic_segmentation_prediction_column_names: (
1248
1310
  SemanticSegmentationColumnNames | None
@@ -1257,7 +1319,8 @@ class Schema(BaseSchema):
1257
1319
  InstanceSegmentationActualColumnNames | None
1258
1320
  ) = None
1259
1321
 
1260
- def get_used_columns_counts(self) -> Dict[str, int]:
1322
+ def get_used_columns_counts(self) -> dict[str, int]:
1323
+ """Return a dict mapping column names to their usage count."""
1261
1324
  columns_used_counts = {}
1262
1325
 
1263
1326
  for field in self.__dataclass_fields__:
@@ -1364,6 +1427,7 @@ class Schema(BaseSchema):
1364
1427
  return columns_used_counts
1365
1428
 
1366
1429
  def has_prediction_columns(self) -> bool:
1430
+ """Return True if prediction columns are configured."""
1367
1431
  prediction_cols = (
1368
1432
  self.prediction_label_column_name,
1369
1433
  self.prediction_score_column_name,
@@ -1377,6 +1441,7 @@ class Schema(BaseSchema):
1377
1441
  return any(col is not None for col in prediction_cols)
1378
1442
 
1379
1443
  def has_actual_columns(self) -> bool:
1444
+ """Return True if actual label columns are configured."""
1380
1445
  actual_cols = (
1381
1446
  self.actual_label_column_name,
1382
1447
  self.actual_score_column_name,
@@ -1389,13 +1454,16 @@ class Schema(BaseSchema):
1389
1454
  return any(col is not None for col in actual_cols)
1390
1455
 
1391
1456
  def has_feature_importance_columns(self) -> bool:
1457
+ """Return True if feature importance columns are configured."""
1392
1458
  feature_importance_cols = (self.shap_values_column_names,)
1393
1459
  return any(col is not None for col in feature_importance_cols)
1394
1460
 
1395
1461
  def has_typed_columns(self) -> bool:
1462
+ """Return True if typed columns are configured."""
1396
1463
  return any(self.typed_column_fields())
1397
1464
 
1398
- def typed_column_fields(self) -> Set[str]:
1465
+ def typed_column_fields(self) -> set[str]:
1466
+ """Return the set of field names with typed columns."""
1399
1467
  return {
1400
1468
  field
1401
1469
  for field in self.__dataclass_fields__
@@ -1403,9 +1471,9 @@ class Schema(BaseSchema):
1403
1471
  }
1404
1472
 
1405
1473
  def is_delayed(self) -> bool:
1406
- """
1407
- This function checks if the given schema, according to the columns provided
1408
- by the user, has inherently latent information
1474
+ """Check if the schema has inherently latent information.
1475
+
1476
+ Determines this based on the columns provided by the user.
1409
1477
 
1410
1478
  Returns:
1411
1479
  bool: True if the schema is "delayed", i.e., does not possess prediction
@@ -1418,11 +1486,14 @@ class Schema(BaseSchema):
1418
1486
 
1419
1487
  @dataclass(frozen=True)
1420
1488
  class CorpusSchema(BaseSchema):
1489
+ """Schema for corpus data with document identification and content columns."""
1490
+
1421
1491
  document_id_column_name: str | None = None
1422
1492
  document_version_column_name: str | None = None
1423
1493
  document_text_embedding_column_names: EmbeddingColumnNames | None = None
1424
1494
 
1425
- def get_used_columns_counts(self) -> Dict[str, int]:
1495
+ def get_used_columns_counts(self) -> dict[str, int]:
1496
+ """Return a dict mapping column names to their usage count."""
1426
1497
  columns_used_counts = {}
1427
1498
 
1428
1499
  if self.document_id_column_name is not None:
@@ -1459,6 +1530,8 @@ class CorpusSchema(BaseSchema):
1459
1530
 
1460
1531
  @unique
1461
1532
  class ArizeTypes(Enum):
1533
+ """Enum representing supported data types in Arize platform."""
1534
+
1462
1535
  STR = 0
1463
1536
  FLOAT = 1
1464
1537
  INT = 2
@@ -1466,11 +1539,21 @@ class ArizeTypes(Enum):
1466
1539
 
1467
1540
  @dataclass(frozen=True)
1468
1541
  class TypedValue:
1542
+ """Container for a value with its associated Arize type."""
1543
+
1469
1544
  type: ArizeTypes
1470
1545
  value: str | bool | float | int
1471
1546
 
1472
1547
 
1473
1548
  def is_json_str(s: str) -> bool:
1549
+ """Check if a string is valid JSON.
1550
+
1551
+ Args:
1552
+ s: The string to validate.
1553
+
1554
+ Returns:
1555
+ True if the string is valid JSON, False otherwise.
1556
+ """
1474
1557
  try:
1475
1558
  json.loads(s)
1476
1559
  except ValueError:
@@ -1484,25 +1567,51 @@ T = TypeVar("T", bound=type)
1484
1567
 
1485
1568
 
1486
1569
  def is_array_of(arr: Sequence[object], tp: T) -> bool:
1570
+ """Check if a value is a numpy array with all elements of a specific type.
1571
+
1572
+ Args:
1573
+ arr: The sequence to check.
1574
+ tp: The expected type for all elements.
1575
+
1576
+ Returns:
1577
+ True if arr is a numpy array and all elements are of type tp.
1578
+ """
1487
1579
  return isinstance(arr, np.ndarray) and all(isinstance(x, tp) for x in arr)
1488
1580
 
1489
1581
 
1490
1582
  def is_list_of(lst: Sequence[object], tp: T) -> bool:
1583
+ """Check if a value is a list with all elements of a specific type.
1584
+
1585
+ Args:
1586
+ lst: The sequence to check.
1587
+ tp: The expected type for all elements.
1588
+
1589
+ Returns:
1590
+ True if lst is a list and all elements are of type tp.
1591
+ """
1491
1592
  return isinstance(lst, list) and all(isinstance(x, tp) for x in lst)
1492
1593
 
1493
1594
 
1494
1595
  def is_iterable_of(lst: Sequence[object], tp: T) -> bool:
1596
+ """Check if a value is an iterable with all elements of a specific type.
1597
+
1598
+ Args:
1599
+ lst: The sequence to check.
1600
+ tp: The expected type for all elements.
1601
+
1602
+ Returns:
1603
+ True if lst is an iterable and all elements are of type tp.
1604
+ """
1495
1605
  return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)
1496
1606
 
1497
1607
 
1498
1608
  def is_dict_of(
1499
- d: Dict[object, object],
1609
+ d: dict[object, object],
1500
1610
  key_allowed_types: T,
1501
1611
  value_allowed_types: T = (),
1502
1612
  value_list_allowed_types: T = (),
1503
1613
  ) -> bool:
1504
- """
1505
- Method to check types are valid for dictionary.
1614
+ """Method to check types are valid for dictionary.
1506
1615
 
1507
1616
  Arguments:
1508
1617
  ---------
@@ -1535,7 +1644,7 @@ def is_dict_of(
1535
1644
  )
1536
1645
 
1537
1646
 
1538
- def _count_characters_raw_data(data: str | List[str]) -> int:
1647
+ def _count_characters_raw_data(data: str | list[str]) -> int:
1539
1648
  character_count = 0
1540
1649
  if isinstance(data, str):
1541
1650
  character_count = len(data)
@@ -1551,8 +1660,14 @@ def _count_characters_raw_data(data: str | List[str]) -> int:
1551
1660
 
1552
1661
 
1553
1662
  def add_to_column_count_dictionary(
1554
- column_dictionary: Dict[str, int], col: str | None
1555
- ):
1663
+ column_dictionary: dict[str, int], col: str | None
1664
+ ) -> None:
1665
+ """Increment the count for a column name in a dictionary.
1666
+
1667
+ Args:
1668
+ column_dictionary: Dictionary mapping column names to counts.
1669
+ col: The column name to increment, or None to skip.
1670
+ """
1556
1671
  if col:
1557
1672
  if col in column_dictionary:
1558
1673
  column_dictionary[col] += 1
@@ -1560,7 +1675,9 @@ def add_to_column_count_dictionary(
1560
1675
  column_dictionary[col] = 1
1561
1676
 
1562
1677
 
1563
- def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
1678
+ def _validate_bounding_box_coordinates(
1679
+ bounding_box_coordinates: list[float],
1680
+ ) -> None:
1564
1681
  if not is_list_of(bounding_box_coordinates, float):
1565
1682
  raise TypeError(
1566
1683
  "Each bounding box's coordinates must be a lists of floats"
@@ -1586,10 +1703,12 @@ def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
1586
1703
  f"top-left. Found {bounding_box_coordinates}"
1587
1704
  )
1588
1705
 
1589
- return None
1706
+ return
1590
1707
 
1591
1708
 
1592
- def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
1709
+ def _validate_polygon_coordinates(
1710
+ polygon_coordinates: list[list[float]],
1711
+ ) -> None:
1593
1712
  if not is_list_of(polygon_coordinates, list):
1594
1713
  raise TypeError("Polygon coordinates must be a list of lists of floats")
1595
1714
  for coordinates in polygon_coordinates:
@@ -1651,27 +1770,41 @@ def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
1651
1770
  f"{coordinates}"
1652
1771
  )
1653
1772
 
1654
- return None
1773
+ return
1655
1774
 
1656
1775
 
1657
- def segments_intersect(p1, p2, p3, p4):
1658
- """
1659
- Check if two line segments intersect.
1776
+ def segments_intersect(
1777
+ p1: tuple[float, float],
1778
+ p2: tuple[float, float],
1779
+ p3: tuple[float, float],
1780
+ p4: tuple[float, float],
1781
+ ) -> bool:
1782
+ """Check if two line segments intersect.
1660
1783
 
1661
1784
  Args:
1662
- p1, p2: First line segment endpoints (x,y)
1663
- p3, p4: Second line segment endpoints (x,y)
1785
+ p1: First endpoint of the first line segment (x,y)
1786
+ p2: Second endpoint of the first line segment (x,y)
1787
+ p3: First endpoint of the second line segment (x,y)
1788
+ p4: Second endpoint of the second line segment (x,y)
1664
1789
 
1665
1790
  Returns:
1666
1791
  True if the line segments intersect, False otherwise
1667
1792
  """
1668
1793
 
1669
1794
  # Function to calculate direction
1670
- def orientation(p, q, r):
1795
+ def orientation(
1796
+ p: tuple[float, float],
1797
+ q: tuple[float, float],
1798
+ r: tuple[float, float],
1799
+ ) -> float:
1671
1800
  return (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1])
1672
1801
 
1673
1802
  # Function to check if point q is on segment pr
1674
- def on_segment(p, q, r):
1803
+ def on_segment(
1804
+ p: tuple[float, float],
1805
+ q: tuple[float, float],
1806
+ r: tuple[float, float],
1807
+ ) -> bool:
1675
1808
  return (
1676
1809
  q[0] <= max(p[0], r[0])
1677
1810
  and q[0] >= min(p[0], r[0])
@@ -1703,17 +1836,20 @@ def segments_intersect(p1, p2, p3, p4):
1703
1836
 
1704
1837
  @unique
1705
1838
  class StatusCodes(Enum):
1839
+ """Enum representing status codes for operations and responses."""
1840
+
1706
1841
  UNSET = 0
1707
1842
  OK = 1
1708
1843
  ERROR = 2
1709
1844
 
1710
1845
  @classmethod
1711
- def list_codes(cls):
1846
+ def list_codes(cls) -> list[str]:
1847
+ """Return a list of all status code names."""
1712
1848
  return [t.name for t in cls]
1713
1849
 
1714
1850
 
1715
- def convert_element(value):
1716
- """Converts scalar or array to python native"""
1851
+ def convert_element(value: object) -> object:
1852
+ """Converts scalar or array to python native."""
1717
1853
  val = getattr(value, "tolist", lambda: value)()
1718
1854
  # Check if it's a list since elements from pd indices are converted to a
1719
1855
  # scalar whereas pd series/dataframe elements are converted to list of 1
@@ -1734,7 +1870,7 @@ PredictionLabelTypes = (
1734
1870
  | bool
1735
1871
  | int
1736
1872
  | float
1737
- | Tuple[str, float]
1873
+ | tuple[str, float]
1738
1874
  | ObjectDetectionLabel
1739
1875
  | RankingPredictionLabel
1740
1876
  | MultiClassPredictionLabel
@@ -1745,7 +1881,7 @@ ActualLabelTypes = (
1745
1881
  | bool
1746
1882
  | int
1747
1883
  | float
1748
- | Tuple[str, float]
1884
+ | tuple[str, float]
1749
1885
  | ObjectDetectionLabel
1750
1886
  | RankingActualLabel
1751
1887
  | MultiClassActualLabel