arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -0,0 +1,47 @@
1
+ """Bounded thread pool executor with queue size limits."""
2
+
3
+ from collections.abc import Callable
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from threading import BoundedSemaphore
6
+
7
+
8
class BoundedExecutor:
    """BoundedExecutor behaves as a ThreadPoolExecutor which will block on calls to submit().

    Blocks once the limit given as "bound" work items are queued for execution.

    Internally a ``BoundedSemaphore`` holding ``bound + max_workers`` permits
    guards a ``ThreadPoolExecutor``: every submitted task takes one permit and
    hands it back when its future completes.

    :param bound: Integer - the maximum number of items in the work queue
    :param max_workers: Integer - the size of the thread pool
    """

    def __init__(self, bound: int, max_workers: int) -> None:
        """Initialize the bounded executor.

        Args:
            bound: Maximum number of items in the work queue.
            max_workers: Size of the thread pool.

        Raises:
            ValueError: If ``bound + max_workers`` is less than 1, or if the
                underlying ``ThreadPoolExecutor`` rejects ``max_workers``.
        """
        capacity = bound + max_workers
        # A semaphore with zero permits would make the very first submit()
        # block forever; fail fast instead of deadlocking the caller.
        if capacity < 1:
            raise ValueError(
                "bound + max_workers must be at least 1, "
                f"got bound={bound!r} and max_workers={max_workers!r}"
            )
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = BoundedSemaphore(capacity)

    # See concurrent.futures.Executor#submit
    def submit(
        self, fn: Callable[..., object], *args: object, **kwargs: object
    ) -> object:
        """Submit a callable to be executed with bounded concurrency.

        Blocks until a permit is available, then forwards the call to the
        wrapped thread pool.

        Args:
            fn: The callable to run on a worker thread.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Returns:
            The future returned by the wrapped executor's ``submit()``.

        Raises:
            Exception: Whatever the wrapped executor's ``submit()`` raises; the
                permit taken for this call is released before re-raising.
        """
        self.semaphore.acquire()
        try:
            future = self.executor.submit(fn, *args, **kwargs)
        except BaseException:
            # Nothing was scheduled, so no done-callback will ever return the
            # permit. BaseException (not just Exception) is caught so that an
            # interrupt arriving here cannot leak a permit either.
            self.semaphore.release()
            raise
        # The permit is returned when the task finishes, fails or is cancelled.
        future.add_done_callback(lambda _: self.semaphore.release())
        return future

    # See concurrent.futures.Executor#shutdown
    def shutdown(self, wait: bool = True) -> None:
        """Shutdown the executor, optionally waiting for pending tasks to complete.

        Args:
            wait: Passed through to the wrapped executor's ``shutdown()``.
        """
        self.executor.shutdown(wait)
@@ -1,27 +1,45 @@
1
+ """Type casting utilities for ML model data conversion."""
2
+
1
3
  # type: ignore[pb2]
2
4
  from __future__ import annotations
3
5
 
4
6
  import math
5
- from typing import TYPE_CHECKING, List, Tuple, Union
7
+ from typing import TYPE_CHECKING
6
8
 
7
9
  import numpy as np
8
10
 
9
11
  from arize.logging import log_a_list
10
- from arize.types import ArizeTypes, Schema, TypedColumns, TypedValue, is_list_of
12
+ from arize.ml.types import (
13
+ ArizeTypes,
14
+ Schema,
15
+ TypedColumns,
16
+ TypedValue,
17
+ )
18
+ from arize.utils.types import is_list_of
11
19
 
12
20
  if TYPE_CHECKING:
13
21
  import pandas as pd
14
22
 
15
23
 
16
24
  class CastingError(Exception):
25
+ """Raised when type casting fails for a value."""
26
+
17
27
  def __str__(self) -> str:
28
+ """Return a human-readable error message."""
18
29
  return self.error_message()
19
30
 
20
31
  def __init__(self, error_msg: str, typed_value: TypedValue) -> None:
32
+ """Initialize the exception with type casting context.
33
+
34
+ Args:
35
+ error_msg: Description of the casting failure.
36
+ typed_value: The TypedValue that failed to cast.
37
+ """
21
38
  self.error_msg = error_msg
22
39
  self.typed_value = typed_value
23
40
 
24
41
  def error_message(self) -> str:
42
+ """Return the error message for this exception."""
25
43
  return (
26
44
  f"Failed to cast value {self.typed_value.value} of type {type(self.typed_value.value)} "
27
45
  f"to type {self.typed_value.type}. "
@@ -30,7 +48,10 @@ class CastingError(Exception):
30
48
 
31
49
 
32
50
  class ColumnCastingError(Exception):
51
+ """Raised when type casting fails for a column."""
52
+
33
53
  def __str__(self) -> str:
54
+ """Return a human-readable error message."""
34
55
  return self.error_message()
35
56
 
36
57
  def __init__(
@@ -39,11 +60,19 @@ class ColumnCastingError(Exception):
39
60
  attempted_columns: str,
40
61
  attempted_type: TypedColumns,
41
62
  ) -> None:
63
+ """Initialize the exception with column casting context.
64
+
65
+ Args:
66
+ error_msg: Description of the casting failure.
67
+ attempted_columns: Columns that failed to cast.
68
+ attempted_type: The TypedColumns type that was attempted.
69
+ """
42
70
  self.error_msg = error_msg
43
71
  self.attempted_casting_columns = attempted_columns
44
72
  self.attempted_casting_type = attempted_type
45
73
 
46
74
  def error_message(self) -> str:
75
+ """Return the error message for this exception."""
47
76
  return (
48
77
  f"Failed to cast to type {self.attempted_casting_type} "
49
78
  f"for columns: {log_a_list(self.attempted_casting_columns, 'and')}. "
@@ -52,60 +81,70 @@ class ColumnCastingError(Exception):
52
81
 
53
82
 
54
83
  class InvalidTypedColumnsError(Exception):
84
+ """Raised when typed columns are invalid or incorrectly specified."""
85
+
55
86
  def __str__(self) -> str:
87
+ """Return a human-readable error message."""
56
88
  return self.error_message()
57
89
 
58
90
  def __init__(self, field_name: str, reason: str) -> None:
91
+ """Initialize the exception with typed columns validation context.
92
+
93
+ Args:
94
+ field_name: Name of the schema field with invalid typed columns.
95
+ reason: Description of why the typed columns are invalid.
96
+ """
59
97
  self.field_name = field_name
60
98
  self.reason = reason
61
99
 
62
100
  def error_message(self) -> str:
101
+ """Return the error message for this exception."""
63
102
  return f"The {self.field_name} TypedColumns object {self.reason}."
64
103
 
65
104
 
66
105
  class InvalidSchemaFieldTypeError(Exception):
106
+ """Raised when schema field has invalid or unexpected type."""
107
+
67
108
  def __str__(self) -> str:
109
+ """Return a human-readable error message."""
68
110
  return self.error_message()
69
111
 
70
112
  def __init__(self, msg: str) -> None:
113
+ """Initialize the exception with schema field type error message.
114
+
115
+ Args:
116
+ msg: Error message describing the schema field type issue.
117
+ """
71
118
  self.msg = msg
72
119
 
73
120
  def error_message(self) -> str:
121
+ """Return the error message for this exception."""
74
122
  return self.msg
75
123
 
76
124
 
77
125
  def cast_typed_columns(
78
126
  dataframe: pd.DataFrame,
79
127
  schema: Schema,
80
- ) -> Tuple[pd.DataFrame, Schema]:
81
- """
82
- Cast feature and tag columns in the dataframe to the types specified in each TypedColumns config.
83
- This optional feature provides a simple way for users to prevent
84
- type drift within a column across many SDK uploads.
85
-
86
- Arguments:
87
- ---------
88
- dataframe: pd.DataFrame
89
- A deepcopy of the user's dataframe.
90
- schema: Schema
91
- The schema, which may include feature and tag column names
128
+ ) -> tuple[pd.DataFrame, Schema]:
129
+ """Cast feature and tag columns in the dataframe to the types specified in each TypedColumns config.
130
+
131
+ This optional feature provides a simple way for users to prevent type drift within
132
+ a column across many SDK uploads.
133
+
134
+ Args:
135
+ dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
136
+ schema (Schema): The schema, which may include feature and tag column names
92
137
  in a TypedColumns object or a List[string].
93
138
 
94
139
  Returns:
95
- -------
96
- dataframe: pd.DataFrame
97
- The dataframe, with columns cast to the specified types.
98
- schema: Schema
99
- A new Schema object, with feature and tag column names converted to the List[string] format
100
- expected in downstream validation.
140
+ tuple[pd.DataFrame, Schema]: A tuple containing:
141
+ - dataframe: The dataframe, with columns cast to the specified types.
142
+ - schema: A new Schema object, with feature and tag column names converted
143
+ to the List[string] format expected in downstream validation.
101
144
 
102
145
  Raises:
103
- ------
104
- ColumnCastingError
105
- If casting fails.
106
- InvalidTypedColumnsError
107
- If the TypedColumns object is invalid.
108
-
146
+ ColumnCastingError: If casting fails.
147
+ InvalidTypedColumnsError: If the TypedColumns object is invalid.
109
148
  """
110
149
  typed_column_fields = schema.typed_column_fields()
111
150
  feature_field = "feature_column_names"
@@ -120,7 +159,7 @@ def cast_typed_columns(
120
159
  )
121
160
 
122
161
  # Make sure no other schema fields have this type.
123
- if any({f for f in typed_column_fields if f not in allowed_fields}):
162
+ if any(f for f in typed_column_fields if f not in allowed_fields):
124
163
  raise InvalidSchemaFieldTypeError(
125
164
  "Only the feature_column_names and tag_column_names Schema fields can be of type "
126
165
  "TypedColumns. Fields with type TypedColumns:"
@@ -130,10 +169,7 @@ def cast_typed_columns(
130
169
  for field_name in typed_column_fields:
131
170
  f = getattr(schema, field_name)
132
171
  if f:
133
- try:
134
- _validate_typed_columns(field_name, f)
135
- except InvalidTypedColumnsError:
136
- raise
172
+ _validate_typed_columns(field_name, f)
137
173
  dataframe = _cast_columns(dataframe, f)
138
174
 
139
175
  # Now that the dataframe values have been cast to the specified types:
@@ -144,6 +180,14 @@ def cast_typed_columns(
144
180
 
145
181
 
146
182
  def cast_dictionary(d: dict) -> dict:
183
+ """Cast TypedValue entries in a dictionary to their appropriate Python types.
184
+
185
+ Args:
186
+ d: Dictionary that may contain TypedValue objects as values.
187
+
188
+ Returns:
189
+ Dictionary with TypedValue objects cast to their native Python types.
190
+ """
147
191
  cast_dict = {}
148
192
  for k, v in d.items():
149
193
  if isinstance(v, TypedValue):
@@ -154,47 +198,38 @@ def cast_dictionary(d: dict) -> dict:
154
198
 
155
199
  def _cast_value(
156
200
  typed_value: TypedValue,
157
- ) -> Union[str, int, float, List[str], None]:
158
- """
159
- Casts a TypedValue to its provided type, preserving all null values as None or float('nan').
201
+ ) -> str | int | float | list[str] | None:
202
+ """Casts a TypedValue to its provided type, preserving all null values as None or float('nan').
160
203
 
161
- Arguments:
162
- ---------
163
- typed_value: TypedValue
164
- The TypedValue to cast.
204
+ Args:
205
+ typed_value (TypedValue): The TypedValue to cast.
165
206
 
166
207
  Returns:
167
- -------
168
- Union[str, int, float, List[str], None]
169
- The cast value.
208
+ str | int | float | list[str] | None: The cast value.
170
209
 
171
210
  Raises:
172
- ------
173
- CastingError
174
- If the value cannot be cast to the provided type.
175
-
211
+ CastingError: If the value cannot be cast to the provided type.
176
212
  """
177
213
  if typed_value.value is None:
178
214
  return None
179
215
 
180
216
  if typed_value.type == ArizeTypes.FLOAT:
181
217
  return _cast_to_float(typed_value)
182
- elif typed_value.type == ArizeTypes.INT:
218
+ if typed_value.type == ArizeTypes.INT:
183
219
  return _cast_to_int(typed_value)
184
- elif typed_value.type == ArizeTypes.STR:
220
+ if typed_value.type == ArizeTypes.STR:
185
221
  return _cast_to_str(typed_value)
186
- else:
187
- raise CastingError("Unknown casting type", typed_value)
222
+ raise CastingError("Unknown casting type", typed_value)
188
223
 
189
224
 
190
- def _cast_to_float(typed_value: TypedValue) -> Union[float, None]:
225
+ def _cast_to_float(typed_value: TypedValue) -> float | None:
191
226
  try:
192
227
  return float(typed_value.value)
193
228
  except Exception as e:
194
229
  raise CastingError(str(e), typed_value) from e
195
230
 
196
231
 
197
- def _cast_to_int(typed_value: TypedValue) -> Union[int, None]:
232
+ def _cast_to_int(typed_value: TypedValue) -> int | None:
198
233
  # a NaN float can't be cast to an int. Proactively return None instead.
199
234
  if isinstance(typed_value.value, float) and math.isnan(typed_value.value):
200
235
  return None
@@ -214,7 +249,7 @@ def _cast_to_int(typed_value: TypedValue) -> Union[int, None]:
214
249
  raise CastingError(str(e), typed_value) from e
215
250
 
216
251
 
217
- def _cast_to_str(typed_value: TypedValue) -> Union[str, None]:
252
+ def _cast_to_str(typed_value: TypedValue) -> str | None:
218
253
  # a NaN float can't be cast to a string. Proactively return None instead.
219
254
  if isinstance(typed_value.value, float) and math.isnan(typed_value.value):
220
255
  return None
@@ -227,21 +262,15 @@ def _cast_to_str(typed_value: TypedValue) -> Union[str, None]:
227
262
  def _validate_typed_columns(
228
263
  field_name: str, typed_columns: TypedColumns
229
264
  ) -> None:
230
- """
231
- Validate a TypedColumns object.
265
+ """Validate a TypedColumns object.
232
266
 
233
- Arguments:
234
- ---------
235
- field_name: str
236
- The name of the Schema field that the TypedColumns object is associated with.
237
- typed_columns: TypedColumns
238
- The TypedColumns object to validate.
267
+ Args:
268
+ field_name (str): The name of the Schema field that the TypedColumns object
269
+ is associated with.
270
+ typed_columns (TypedColumns): The TypedColumns object to validate.
239
271
 
240
272
  Raises:
241
- ------
242
- InvalidTypedColumnsError
243
- If the TypedColumns object is invalid.
244
-
273
+ InvalidTypedColumnsError: If the TypedColumns object is invalid.
245
274
  """
246
275
  if typed_columns.is_empty():
247
276
  raise InvalidTypedColumnsError(field_name=field_name, reason="is empty")
@@ -256,28 +285,20 @@ def _validate_typed_columns(
256
285
  def _cast_columns(
257
286
  dataframe: pd.DataFrame, columns: TypedColumns
258
287
  ) -> pd.DataFrame:
259
- """
260
- Cast columns corresponding to a single TypedColumns object and a single Arize Schema field.
288
+ """Cast columns corresponding to a single TypedColumns object and a single Arize Schema field.
289
+
261
290
  (feature_column_names or tag_column_names)
262
291
 
263
- Arguments:
264
- ---------
265
- dataframe: pd.DataFrame
266
- A deepcopy of the user's dataframe.
267
- columns: TypedColumns
268
- The TypedColumns object, which specifies the columns to cast
269
- (and/or to not cast) and their target types.
292
+ Args:
293
+ dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
294
+ columns (TypedColumns): The TypedColumns object, which specifies the columns
295
+ to cast (and/or to not cast) and their target types.
270
296
 
271
297
  Returns:
272
- -------
273
- dataframe: pd.DataFrame
274
- The dataframe with columns cast to the specified types.
298
+ pd.DataFrame: The dataframe with columns cast to the specified types.
275
299
 
276
300
  Raises:
277
- ------
278
- ColumnCastingError
279
- If casting fails.
280
-
301
+ ColumnCastingError: If casting fails.
281
302
  """
282
303
  if columns.to_str:
283
304
  try:
@@ -324,52 +345,41 @@ def _cast_columns(
324
345
 
325
346
 
326
347
  def _cast_df(
327
- df: pd.DataFrame, cols: List[str], target_type_str: str
348
+ df: pd.DataFrame, cols: list[str], target_type_str: str
328
349
  ) -> pd.DataFrame:
329
- """
330
- Arguments:
331
- ---------
332
- df: pd.DataFrame
333
- A deepcopy of the user's dataframe.
334
- cols: List[str]
335
- The list of column names to cast.
336
- target_type_str: str
337
- The target type to cast to.
350
+ """Cast columns in a dataframe to the specified type.
351
+
352
+ Args:
353
+ df (pd.DataFrame): A deepcopy of the user's dataframe.
354
+ cols (list[str]): The list of column names to cast.
355
+ target_type_str (str): The target type to cast to.
338
356
 
339
357
  Returns:
340
- -------
341
- df: pd.DataFrame
342
- The dataframe with columns cast to the specified types.
358
+ pd.DataFrame: The dataframe with columns cast to the specified types.
343
359
 
344
360
  Raises:
345
- ------
346
- Exception
347
- If casting fails. Common exceptions raised by astype() are TypeError and ValueError.
348
-
361
+ Exception: If casting fails. Common exceptions raised by astype() are
362
+ TypeError and ValueError.
349
363
  """
350
364
  nan_mapping = {"nan": np.nan, "NaN": np.nan}
351
365
  df = df.replace(nan_mapping)
352
366
 
353
367
  # None or NaN-based values (including np.nan) are automatically converted to pandas pd.NA type
354
- return df.astype({col: target_type_str for col in cols})
368
+ return df.astype(dict.fromkeys(cols, target_type_str))
355
369
 
356
370
 
357
371
  def _convert_schema_field_types(
358
372
  schema: Schema,
359
373
  ) -> Schema:
360
- """
361
- Arguments:
362
- ---------
363
- schema: Schema
364
- The schema, which may include feature and tag column names
374
+ """Convert schema field types from TypedColumns to List[string] format.
375
+
376
+ Args:
377
+ schema (Schema): The schema, which may include feature and tag column names
365
378
  in a TypedColumns object or a List[string].
366
379
 
367
380
  Returns:
368
- -------
369
- schema: Schema
370
- A Schema, with feature and tag column names
371
- converted to the List[string] format expected in downstream validation.
372
-
381
+ Schema: A Schema, with feature and tag column names converted to the
382
+ List[string] format expected in downstream validation.
373
383
  """
374
384
  feature_column_names_list = (
375
385
  schema.feature_column_names