arize 8.0.0a23__py3-none-any.whl → 8.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. arize/__init__.py +11 -10
  2. arize/_exporter/client.py +1 -1
  3. arize/_generated/api_client/__init__.py +0 -2
  4. arize/_generated/api_client/models/__init__.py +0 -1
  5. arize/_generated/api_client/models/datasets_create_request.py +2 -10
  6. arize/_generated/api_client/models/datasets_examples_insert_request.py +2 -10
  7. arize/_generated/api_client/test/test_datasets_create_request.py +2 -6
  8. arize/_generated/api_client/test/test_datasets_examples_insert_request.py +2 -6
  9. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +2 -6
  10. arize/_generated/api_client/test/test_datasets_examples_update_request.py +2 -6
  11. arize/_generated/api_client/test/test_experiments_create_request.py +2 -6
  12. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +2 -6
  13. arize/_generated/api_client_README.md +0 -1
  14. arize/client.py +47 -163
  15. arize/config.py +59 -100
  16. arize/datasets/client.py +11 -6
  17. arize/embeddings/nlp_generators.py +12 -6
  18. arize/embeddings/tabular_generators.py +14 -11
  19. arize/experiments/__init__.py +12 -0
  20. arize/experiments/client.py +13 -9
  21. arize/experiments/functions.py +6 -6
  22. arize/experiments/types.py +3 -3
  23. arize/{models → ml}/batch_validation/errors.py +2 -2
  24. arize/{models → ml}/batch_validation/validator.py +5 -3
  25. arize/{models → ml}/casting.py +42 -78
  26. arize/{models → ml}/client.py +19 -17
  27. arize/{models → ml}/proto.py +2 -2
  28. arize/{models → ml}/stream_validation.py +1 -1
  29. arize/{models → ml}/surrogate_explainer/mimic.py +6 -2
  30. arize/{types.py → ml/types.py} +99 -234
  31. arize/pre_releases.py +2 -1
  32. arize/projects/client.py +11 -6
  33. arize/spans/client.py +91 -86
  34. arize/spans/conversion.py +11 -4
  35. arize/spans/validation/common/value_validation.py +1 -1
  36. arize/spans/validation/spans/dataframe_form_validation.py +1 -1
  37. arize/spans/validation/spans/value_validation.py +2 -1
  38. arize/utils/dataframe.py +1 -1
  39. arize/utils/online_tasks/dataframe_preprocessor.py +5 -6
  40. arize/utils/types.py +105 -0
  41. arize/version.py +1 -1
  42. {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/METADATA +56 -59
  43. {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/RECORD +50 -51
  44. arize/_generated/api_client/models/primitive_value.py +0 -172
  45. arize/_generated/api_client/test/test_primitive_value.py +0 -50
  46. /arize/{models → ml}/__init__.py +0 -0
  47. /arize/{models → ml}/batch_validation/__init__.py +0 -0
  48. /arize/{models → ml}/bounded_executor.py +0 -0
  49. /arize/{models → ml}/surrogate_explainer/__init__.py +0 -0
  50. {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/WHEEL +0 -0
  51. {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/LICENSE +0 -0
  52. {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/NOTICE +0 -0
arize/config.py CHANGED
@@ -167,112 +167,71 @@ def _parse_bool(val: bool | str | None) -> bool:
  class SDKConfiguration:
  """Configuration for the Arize SDK with endpoint and authentication settings.

- This class is used internally by ArizeClient to manage SDK configuration. Users
- typically interact with ArizeClient rather than instantiating this class directly.
+ This class is used internally by ArizeClient to manage SDK configuration. It is not
+ recommended to use this class directly; users should interact with ArizeClient
+ instead.

- Configuration Precedence
- ------------------------
  Each configuration parameter follows this resolution order:
  1. Explicit value passed to ArizeClient constructor (highest priority)
  2. Environment variable value
  3. Built-in default value (lowest priority)

- Parameters
- ----------
- api_key : str
- Arize API key for authentication. Required.
- Environment variable: ARIZE_API_KEY
- Default: None (must be provided via argument or environment variable)
-
- api_host : str
- API endpoint host.
- Environment variable: ARIZE_API_HOST
- Default: "api.arize.com"
-
- api_scheme : str
- API endpoint scheme (http/https).
- Environment variable: ARIZE_API_SCHEME
- Default: "https"
-
- otlp_host : str
- OTLP (OpenTelemetry Protocol) endpoint host.
- Environment variable: ARIZE_OTLP_HOST
- Default: "otlp.arize.com"
-
- otlp_scheme : str
- OTLP endpoint scheme (http/https).
- Environment variable: ARIZE_OTLP_SCHEME
- Default: "https"
-
- flight_host : str
- Apache Arrow Flight endpoint host.
- Environment variable: ARIZE_FLIGHT_HOST
- Default: "flight.arize.com"
-
- flight_port : int
- Apache Arrow Flight endpoint port (1-65535).
- Environment variable: ARIZE_FLIGHT_PORT
- Default: 443
-
- flight_scheme : str
- Apache Arrow Flight endpoint scheme.
- Environment variable: ARIZE_FLIGHT_SCHEME
- Default: "grpc+tls"
-
- pyarrow_max_chunksize : int
- Maximum chunk size for PyArrow operations (1 to MAX_CHUNKSIZE).
- Environment variable: ARIZE_MAX_CHUNKSIZE
- Default: 10_000
-
- request_verify : bool
- Whether to verify SSL certificates for HTTP requests.
- Environment variable: ARIZE_REQUEST_VERIFY
- Default: True
-
- stream_max_workers : int
- Maximum number of worker threads for streaming operations (minimum: 1).
- Environment variable: ARIZE_STREAM_MAX_WORKERS
- Default: 8
-
- stream_max_queue_bound : int
- Maximum queue size for streaming operations (minimum: 1).
- Environment variable: ARIZE_STREAM_MAX_QUEUE_BOUND
- Default: 5000
-
- max_http_payload_size_mb : float
- Maximum HTTP payload size in megabytes (minimum: 1).
- Environment variable: ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB
- Default: 100
-
- arize_directory : str
- Directory for Arize SDK files (cache, logs, etc.).
- Environment variable: ARIZE_DIRECTORY
- Default: "~/.arize"
-
- enable_caching : bool
- Whether to enable local caching.
- Environment variable: ARIZE_ENABLE_CACHING
- Default: True
-
- region : Region
- Arize region (e.g., US_CENTRAL, EU_WEST). When specified, overrides
- individual host/port settings.
- Environment variable: ARIZE_REGION
- Default: Region.UNSPECIFIED
-
- single_host : str
- Single host to use for all endpoints. Overrides individual host settings.
- Environment variable: ARIZE_SINGLE_HOST
- Default: "" (not set)
-
- single_port : int
- Single port to use for all endpoints. Overrides individual port settings (0-65535).
- Environment variable: ARIZE_SINGLE_PORT
- Default: 0 (not set)
-
- See Also:
- --------
- ArizeClient : Main client class that uses this configuration
+ Args:
+ api_key: Arize API key for authentication. Required.
+ Environment variable: ARIZE_API_KEY.
+ Default: None (must be provided via argument or environment variable).
+ api_host: API endpoint host.
+ Environment variable: ARIZE_API_HOST.
+ Default: "api.arize.com".
+ api_scheme: API endpoint scheme (http/https).
+ Environment variable: ARIZE_API_SCHEME.
+ Default: "https".
+ otlp_host: OTLP (OpenTelemetry Protocol) endpoint host.
+ Environment variable: ARIZE_OTLP_HOST.
+ Default: "otlp.arize.com".
+ otlp_scheme: OTLP endpoint scheme (http/https).
+ Environment variable: ARIZE_OTLP_SCHEME.
+ Default: "https".
+ flight_host: Apache Arrow Flight endpoint host.
+ Environment variable: ARIZE_FLIGHT_HOST.
+ Default: "flight.arize.com".
+ flight_port: Apache Arrow Flight endpoint port (1-65535).
+ Environment variable: ARIZE_FLIGHT_PORT.
+ Default: 443.
+ flight_scheme: Apache Arrow Flight endpoint scheme.
+ Environment variable: ARIZE_FLIGHT_SCHEME.
+ Default: "grpc+tls".
+ pyarrow_max_chunksize: Maximum chunk size for PyArrow operations (1 to MAX_CHUNKSIZE).
+ Environment variable: ARIZE_MAX_CHUNKSIZE.
+ Default: 10_000.
+ request_verify: Whether to verify SSL certificates for HTTP requests.
+ Environment variable: ARIZE_REQUEST_VERIFY.
+ Default: True.
+ stream_max_workers: Maximum number of worker threads for streaming operations (minimum: 1).
+ Environment variable: ARIZE_STREAM_MAX_WORKERS.
+ Default: 8.
+ stream_max_queue_bound: Maximum queue size for streaming operations (minimum: 1).
+ Environment variable: ARIZE_STREAM_MAX_QUEUE_BOUND.
+ Default: 5000.
+ max_http_payload_size_mb: Maximum HTTP payload size in megabytes (minimum: 1).
+ Environment variable: ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB.
+ Default: 100.
+ arize_directory: Directory for Arize SDK files (cache, logs, etc.).
+ Environment variable: ARIZE_DIRECTORY.
+ Default: "~/.arize".
+ enable_caching: Whether to enable local caching.
+ Environment variable: ARIZE_ENABLE_CACHING.
+ Default: True.
+ region: Arize region (e.g., US_CENTRAL, EU_WEST). When specified, overrides
+ individual host/port settings.
+ Environment variable: ARIZE_REGION.
+ Default: Region.UNSPECIFIED.
+ single_host: Single host to use for all endpoints. Overrides individual host settings.
+ Environment variable: ARIZE_SINGLE_HOST.
+ Default: "" (not set).
+ single_port: Single port to use for all endpoints. Overrides individual port settings (0-65535).
+ Environment variable: ARIZE_SINGLE_PORT.
+ Default: 0 (not set).
  """

  api_key: str = field(
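
The precedence order above is easiest to see end to end. A minimal sketch, assuming ArizeClient (the public client this docstring refers to) forwards these keyword arguments to SDKConfiguration and reads the documented environment variables; the key value is a placeholder:

    import os

    # Priority 2: an environment variable applies when no explicit argument is passed.
    os.environ["ARIZE_API_HOST"] = "my-proxy.example.com"

    from arize import ArizeClient  # import path assumed from :class:`arize.ArizeClient`

    client = ArizeClient(
        api_key="YOUR_API_KEY",  # Priority 1: an explicit argument beats ARIZE_API_KEY.
        # api_host is omitted, so it resolves from ARIZE_API_HOST set above.
        # request_verify has neither an argument nor an env var here, so the
        # built-in default (True) applies (Priority 3).
    )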
arize/datasets/client.py CHANGED
@@ -30,17 +30,22 @@ logger = logging.getLogger(__name__)


  class DatasetsClient:
- """Client for managing datasets including creation, retrieval, and example management."""
+ """Client for managing datasets including creation, retrieval, and example management.

- def __init__(self, *, sdk_config: SDKConfiguration) -> None:
- """Create a datasets sub-client.
+ This class is primarily intended for internal use within the SDK. Users are
+ highly encouraged to access resource-specific functionality via
+ :class:`arize.ArizeClient`.

- The datasets client is a thin wrapper around the generated REST API client,
- using the shared generated API client owned by `SDKConfiguration`.
+ The datasets client is a thin wrapper around the generated REST API client,
+ using the shared generated API client owned by
+ :class:`arize.config.SDKConfiguration`.
+ """

+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
+ """
  Args:
  sdk_config: Resolved SDK configuration.
- """
+ """ # noqa: D205, D212
  self._sdk_config = sdk_config

  # Import at runtime so it's still lazy and extras-gated by the parent
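
Because the docstring now marks DatasetsClient as internal, the supported entry point is the parent client. A minimal sketch; the `datasets` attribute name is an assumption, since this diff shows only the sub-client:

    from arize import ArizeClient

    client = ArizeClient(api_key="YOUR_API_KEY")
    # Hypothetical accessor: reach dataset operations through the parent client
    # rather than constructing DatasetsClient(sdk_config=...) by hand.
    datasets = client.datasets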
arize/embeddings/nlp_generators.py CHANGED
@@ -49,10 +49,13 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
  ) -> pd.Series:
  """Obtain embedding vectors from your text data using pre-trained large language models.

- :param text_col: a pandas Series containing the different pieces of text.
- :param class_label_col: if this column is passed, the sentence "The classification label
- is <class_label>" will be appended to the text in the `text_col`.
- :return: a pandas Series containing the embedding vectors.
+ Args:
+ text_col: A pandas Series containing the different pieces of text.
+ class_label_col: If this column is passed, the sentence "The classification label
+ is <class_label>" will be appended to the text in the `text_col`.
+
+ Returns:
+ A pandas Series containing the embedding vectors.
  """
  if not isinstance(text_col, pd.Series):
  raise TypeError("text_col must be a pandas Series")
@@ -110,8 +113,11 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
  ) -> pd.Series:
  """Obtain embedding vectors from your text data using pre-trained large language models.

- :param text_col: a pandas Series containing the different pieces of text.
- :return: a pandas Series containing the embedding vectors.
+ Args:
+ text_col: A pandas Series containing the different pieces of text.
+
+ Returns:
+ A pandas Series containing the embedding vectors.
  """
  if not isinstance(text_col, pd.Series):
  raise TypeError("text_col must be a pandas Series")
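
Both docstrings above now use Google-style Args/Returns sections. A usage sketch for the sequence-classification generator; the constructor argument and the generate_embeddings method name are assumptions, since these hunks show only the signature tail and docstring:

    import pandas as pd
    from arize.embeddings import (  # import path assumed
        EmbeddingGeneratorForNLPSequenceClassification,
    )

    generator = EmbeddingGeneratorForNLPSequenceClassification(
        model_name="distilbert-base-uncased",  # hypothetical constructor argument
    )
    text_col = pd.Series(["great battery life", "screen cracked on day one"])
    class_label_col = pd.Series(["positive", "negative"])

    # Per the docstring, "The classification label is <class_label>" is appended
    # to each text before embedding when class_label_col is passed.
    embeddings = generator.generate_embeddings(
        text_col=text_col,
        class_label_col=class_label_col,
    )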
arize/embeddings/tabular_generators.py CHANGED
@@ -11,7 +11,7 @@ from arize.embeddings.constants import (
  IMPORT_ERROR_MESSAGE,
  )
  from arize.embeddings.usecases import UseCases
- from arize.types import is_list_of
+ from arize.utils.types import is_list_of

  try:
  from datasets import Dataset
@@ -79,16 +79,19 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
  Prompts are generated from your `selected_columns` and passed to a pre-trained
  large language model for embedding vector computation.

- :param df: pandas DataFrame containing the tabular data, not all columns will be
- considered, see `selected_columns`.
- :param selected_columns: columns to be considered to construct the prompt to be passed to
- the LLM.
- :param col_name_map: mapping between selected column names and a more verbose description of
- the name. This helps the LLM understand the features better.
- :param return_prompt_col: if set to True, an extra pandas Series will be returned
- containing the constructed prompts. Defaults to False.
- :return: a pandas Series containing the embedding vectors and, if `return_prompt_col` is
- set to True, a pandas Series containing the prompts created from tabular features.
+ Args:
+ df: Pandas DataFrame containing the tabular data. Not all columns will be
+ considered, see `selected_columns`.
+ selected_columns: Columns to be considered to construct the prompt to be passed to
+ the LLM.
+ col_name_map: Mapping between selected column names and a more verbose description of
+ the name. This helps the LLM understand the features better.
+ return_prompt_col: If set to True, an extra pandas Series will be returned
+ containing the constructed prompts. Defaults to False.
+
+ Returns:
+ A pandas Series containing the embedding vectors and, if `return_prompt_col` is
+ set to True, a pandas Series containing the prompts created from tabular features.
  """
  if col_name_map is None:
  col_name_map = {}
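
A corresponding sketch for the tabular generator, using the parameters documented above; the constructor argument and method name are again assumptions:

    import pandas as pd
    from arize.embeddings import EmbeddingGeneratorForTabularFeatures  # import path assumed

    df = pd.DataFrame({"age": [34, 51], "state": ["CA", "NY"], "plan": ["pro", "free"]})
    generator = EmbeddingGeneratorForTabularFeatures(
        model_name="distilbert-base-uncased",  # hypothetical constructor argument
    )
    embeddings, prompts = generator.generate_embeddings(  # method name assumed
        df,
        selected_columns=["age", "state", "plan"],
        col_name_map={"state": "state of residence"},  # verbose names help the LLM
        return_prompt_col=True,  # also return the constructed prompts
    )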
arize/experiments/__init__.py CHANGED
@@ -1 +1,13 @@
  """Experiment tracking and evaluation functionality for the Arize SDK."""
+
+ from arize.experiments.evaluators.types import (
+ EvaluationResult,
+ EvaluationResultFieldNames,
+ )
+ from arize.experiments.types import ExperimentTaskFieldNames
+
+ __all__ = [
+ "EvaluationResult",
+ "EvaluationResultFieldNames",
+ "ExperimentTaskFieldNames",
+ ]
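
With these re-exports in place, the field-name mapping types are importable from the subpackage root rather than from their defining modules:

    from arize.experiments import (
        EvaluationResult,
        EvaluationResultFieldNames,
        ExperimentTaskFieldNames,
    )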
arize/experiments/client.py CHANGED
@@ -43,24 +43,29 @@ if TYPE_CHECKING:
  from arize.experiments.evaluators.types import EvaluationResultFieldNames
  from arize.experiments.types import (
  ExperimentTask,
- ExperimentTaskResultFieldNames,
+ ExperimentTaskFieldNames,
  )

  logger = logging.getLogger(__name__)


  class ExperimentsClient:
- """Client for managing experiments including creation, execution, and result tracking."""
+ """Client for managing experiments including creation, execution, and result tracking.

- def __init__(self, *, sdk_config: SDKConfiguration) -> None:
- """Create an experiments sub-client.
+ This class is primarily intended for internal use within the SDK. Users are
+ highly encouraged to access resource-specific functionality via
+ :class:`arize.ArizeClient`.

- The experiments client is a thin wrapper around the generated REST API client,
- using the shared generated API client owned by `SDKConfiguration`.
+ The experiments client is a thin wrapper around the generated REST API client,
+ using the shared generated API client owned by
+ :class:`arize.config.SDKConfiguration`.
+ """

+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
+ """
  Args:
  sdk_config: Resolved SDK configuration.
- """
+ """ # noqa: D205, D212
  self._sdk_config = sdk_config
  from arize._generated import api_client as gen

@@ -109,7 +114,7 @@ class ExperimentsClient:
  name: str,
  dataset_id: str,
  experiment_runs: list[dict[str, object]] | pd.DataFrame,
- task_fields: ExperimentTaskResultFieldNames,
+ task_fields: ExperimentTaskFieldNames,
  evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
  force_http: bool = False,
  ) -> models.Experiment:
@@ -170,7 +175,6 @@ class ExperimentsClient:
  from arize._generated import api_client as gen

  data = experiment_df.to_dict(orient="records")
-
  body = gen.ExperimentsCreateRequest(
  name=name,
  dataset_id=dataset_id,
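
Callers of the method whose signature appears above now supply an ExperimentTaskFieldNames mapping instead of the old ExperimentTaskResultFieldNames. A sketch; the hunk cuts off before the method name, so log_experiment and the experiments accessor are placeholders, and client is an ArizeClient as in the configuration sketch earlier:

    import pandas as pd
    from arize.experiments import ExperimentTaskFieldNames

    runs = pd.DataFrame(
        {
            "example_id": ["ex-1", "ex-2"],
            "response": ['{"answer": "4"}', '{"answer": "Paris"}'],
        }
    )
    experiment = client.experiments.log_experiment(  # hypothetical accessor/method names
        name="baseline",
        dataset_id="dataset-123",
        experiment_runs=runs,
        task_fields=ExperimentTaskFieldNames(example_id="example_id", output="response"),
    )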
arize/experiments/functions.py CHANGED
@@ -56,7 +56,7 @@ from arize.experiments.types import (
  ExperimentEvaluationRun,
  ExperimentRun,
  ExperimentTask,
- ExperimentTaskResultFieldNames,
+ ExperimentTaskFieldNames,
  _TaskSummary,
  )

@@ -768,7 +768,7 @@ def get_result_attr(r: object, attr: str, default: object = None) -> object:

  def transform_to_experiment_format(
  experiment_runs: list[dict[str, object]] | pd.DataFrame,
- task_fields: ExperimentTaskResultFieldNames,
+ task_fields: ExperimentTaskFieldNames,
  evaluator_fields: dict[str, EvaluationResultFieldNames] | None = None,
  ) -> pd.DataFrame:
  """Transform a DataFrame to match the format returned by run_experiment().
@@ -788,7 +788,7 @@ def transform_to_experiment_format(
  else pd.DataFrame(experiment_runs)
  )
  # Validate required columns
- required_cols = {task_fields.example_id, task_fields.result}
+ required_cols = {task_fields.example_id, task_fields.output}
  missing_cols = required_cols - set(data.columns)
  if missing_cols:
  raise ValueError(f"Missing required columns: {missing_cols}")
@@ -799,11 +799,11 @@
  out_df["example_id"] = data[task_fields.example_id]
  if task_fields.example_id != "example_id":
  out_df.drop(task_fields.example_id, axis=1, inplace=True)
- out_df["result"] = data[task_fields.result].apply(
+ out_df["output"] = data[task_fields.output].apply(
  lambda x: json.dumps(x) if isinstance(x, dict) else x
  )
- if task_fields.result != "result":
- out_df.drop(task_fields.result, axis=1, inplace=True)
+ if task_fields.output != "output":
+ out_df.drop(task_fields.output, axis=1, inplace=True)

  # Process evaluator results
  if evaluator_fields:
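
The result-to-output rename is visible in the transformation itself: the canonical column is now "output", and dict values are JSON-encoded. A sketch of transform_to_experiment_format as defined above, with the import path taken from this diff's file layout:

    import pandas as pd
    from arize.experiments import ExperimentTaskFieldNames
    from arize.experiments.functions import transform_to_experiment_format

    runs = pd.DataFrame(
        {
            "id": ["ex-1", "ex-2"],
            "answer": [{"label": "yes"}, {"label": "no"}],  # dicts are JSON-encoded
        }
    )
    out = transform_to_experiment_format(
        runs,
        task_fields=ExperimentTaskFieldNames(example_id="id", output="answer"),
    )
    # `out` now carries canonical "example_id" and "output" columns; the original
    # "id" and "answer" columns are dropped because their names differ.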
arize/experiments/types.py CHANGED
@@ -397,17 +397,17 @@ def _top_string(s: pd.Series, length: int = 100) -> str | None:


  @dataclass
- class ExperimentTaskResultFieldNames:
+ class ExperimentTaskFieldNames:
  """Column names for mapping experiment task results in a DataFrame.

  Args:
  example_id: Name of column containing example IDs.
  The ID values must match the id of the dataset rows.
- result: Name of column containing task results
+ output: Name of column containing task results
  """

  example_id: str
- result: str
+ output: str


  TaskOutput = JSONSerializable
arize/{models → ml}/batch_validation/errors.py CHANGED
@@ -16,12 +16,12 @@ from arize.constants.ml import (
  MAX_TAG_LENGTH,
  )
  from arize.logging import log_a_list
- from arize.types import Environments, ModelTypes
+ from arize.ml.types import Environments, ModelTypes

  if TYPE_CHECKING:
  from collections.abc import Iterable

- from arize.types import Metrics
+ from arize.ml.types import Metrics


  class ValidationError(Exception, ABC):
arize/{models → ml}/batch_validation/validator.py CHANGED
@@ -40,8 +40,8 @@ from arize.constants.ml import (
  MODEL_MAPPING_CONFIG,
  )
  from arize.logging import get_truncation_warning_message
- from arize.models.batch_validation import errors as err
- from arize.types import (
+ from arize.ml.batch_validation import errors as err
+ from arize.ml.types import (
  CATEGORICAL_MODEL_TYPES,
  NUMERIC_MODEL_TYPES,
  BaseSchema,
@@ -53,9 +53,11 @@ from arize.types import (
  ModelTypes,
  PromptTemplateColumnNames,
  Schema,
+ segments_intersect,
+ )
+ from arize.utils.types import (
  is_dict_of,
  is_iterable_of,
- segments_intersect,
  )

  logger = logging.getLogger(__name__)
arize/{models → ml}/casting.py CHANGED
@@ -9,7 +9,13 @@ from typing import TYPE_CHECKING
  import numpy as np

  from arize.logging import log_a_list
- from arize.types import ArizeTypes, Schema, TypedColumns, TypedValue, is_list_of
+ from arize.ml.types import (
+ ArizeTypes,
+ Schema,
+ TypedColumns,
+ TypedValue,
+ )
+ from arize.utils.types import is_list_of

  if TYPE_CHECKING:
  import pandas as pd
@@ -125,29 +131,20 @@ def cast_typed_columns(
  This optional feature provides a simple way for users to prevent type drift within
  a column across many SDK uploads.

- Arguments:
- ---------
- dataframe: pd.DataFrame
- A deepcopy of the user's dataframe.
- schema: Schema
- The schema, which may include feature and tag column names
+ Args:
+ dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
+ schema (Schema): The schema, which may include feature and tag column names
  in a TypedColumns object or a List[string].

  Returns:
- -------
- dataframe: pd.DataFrame
- The dataframe, with columns cast to the specified types.
- schema: Schema
- A new Schema object, with feature and tag column names converted to the List[string] format
- expected in downstream validation.
+ tuple[pd.DataFrame, Schema]: A tuple containing:
+ - dataframe: The dataframe, with columns cast to the specified types.
+ - schema: A new Schema object, with feature and tag column names converted
+ to the List[string] format expected in downstream validation.

  Raises:
- ------
- ColumnCastingError
- If casting fails.
- InvalidTypedColumnsError
- If the TypedColumns object is invalid.
-
+ ColumnCastingError: If casting fails.
+ InvalidTypedColumnsError: If the TypedColumns object is invalid.
  """
  typed_column_fields = schema.typed_column_fields()
  feature_field = "feature_column_names"
@@ -204,21 +201,14 @@ def _cast_value(
  ) -> str | int | float | list[str] | None:
  """Casts a TypedValue to its provided type, preserving all null values as None or float('nan').

- Arguments:
- ---------
- typed_value: TypedValue
- The TypedValue to cast.
+ Args:
+ typed_value (TypedValue): The TypedValue to cast.

  Returns:
- -------
- Union[str, int, float, List[str], None]
- The cast value.
+ str | int | float | list[str] | None: The cast value.

  Raises:
- ------
- CastingError
- If the value cannot be cast to the provided type.
-
+ CastingError: If the value cannot be cast to the provided type.
  """
  if typed_value.value is None:
  return None
@@ -274,18 +264,13 @@ def _validate_typed_columns(
  ) -> None:
  """Validate a TypedColumns object.

- Arguments:
- ---------
- field_name: str
- The name of the Schema field that the TypedColumns object is associated with.
- typed_columns: TypedColumns
- The TypedColumns object to validate.
+ Args:
+ field_name (str): The name of the Schema field that the TypedColumns object
+ is associated with.
+ typed_columns (TypedColumns): The TypedColumns object to validate.

  Raises:
- ------
- InvalidTypedColumnsError
- If the TypedColumns object is invalid.
-
+ InvalidTypedColumnsError: If the TypedColumns object is invalid.
  """
  if typed_columns.is_empty():
  raise InvalidTypedColumnsError(field_name=field_name, reason="is empty")
@@ -304,24 +289,16 @@ def _cast_columns(

  (feature_column_names or tag_column_names)

- Arguments:
- ---------
- dataframe: pd.DataFrame
- A deepcopy of the user's dataframe.
- columns: TypedColumns
- The TypedColumns object, which specifies the columns to cast
- (and/or to not cast) and their target types.
+ Args:
+ dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
+ columns (TypedColumns): The TypedColumns object, which specifies the columns
+ to cast (and/or to not cast) and their target types.

  Returns:
- -------
- dataframe: pd.DataFrame
- The dataframe with columns cast to the specified types.
+ pd.DataFrame: The dataframe with columns cast to the specified types.

  Raises:
- ------
- ColumnCastingError
- If casting fails.
-
+ ColumnCastingError: If casting fails.
  """
  if columns.to_str:
  try:
@@ -372,25 +349,17 @@ def _cast_df(
  ) -> pd.DataFrame:
  """Cast columns in a dataframe to the specified type.

- Arguments:
- ---------
- df: pd.DataFrame
- A deepcopy of the user's dataframe.
- cols: List[str]
- The list of column names to cast.
- target_type_str: str
- The target type to cast to.
+ Args:
+ df (pd.DataFrame): A deepcopy of the user's dataframe.
+ cols (list[str]): The list of column names to cast.
+ target_type_str (str): The target type to cast to.

  Returns:
- -------
- df: pd.DataFrame
- The dataframe with columns cast to the specified types.
+ pd.DataFrame: The dataframe with columns cast to the specified types.

  Raises:
- ------
- Exception
- If casting fails. Common exceptions raised by astype() are TypeError and ValueError.
-
+ Exception: If casting fails. Common exceptions raised by astype() are
+ TypeError and ValueError.
  """
  nan_mapping = {"nan": np.nan, "NaN": np.nan}
  df = df.replace(nan_mapping)
@@ -404,18 +373,13 @@ def _convert_schema_field_types(
  ) -> Schema:
  """Convert schema field types from TypedColumns to List[string] format.

- Arguments:
- ---------
- schema: Schema
- The schema, which may include feature and tag column names
+ Args:
+ schema (Schema): The schema, which may include feature and tag column names
  in a TypedColumns object or a List[string].

  Returns:
- -------
- schema: Schema
- A Schema, with feature and tag column names
- converted to the List[string] format expected in downstream validation.
-
+ Schema: A Schema, with feature and tag column names converted to the
+ List[string] format expected in downstream validation.
  """
  feature_column_names_list = (
  schema.feature_column_names
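
The casting helpers above now live under arize.ml. A minimal sketch of cast_typed_columns, assuming TypedColumns accepts a to_str column list (the to_str attribute is referenced in _cast_columns above; any to_int/to_float counterparts are not shown in these hunks) and that Schema's remaining fields are optional:

    import pandas as pd
    from arize.ml.casting import cast_typed_columns  # module path per this diff
    from arize.ml.types import Schema, TypedColumns  # new import path per this diff

    df = pd.DataFrame({"zip_code": [94105, 10001], "score": [0.91, 0.42]})
    schema = Schema(
        feature_column_names=TypedColumns(to_str=["zip_code"]),  # cast ints to str
    )
    # Returns the cast dataframe plus a new Schema whose feature/tag columns are
    # plain lists of names, as expected by downstream validation.
    cast_df, flat_schema = cast_typed_columns(df, schema)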