arize 8.0.0a23__py3-none-any.whl → 8.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +11 -10
- arize/_exporter/client.py +1 -1
- arize/_generated/api_client/__init__.py +0 -2
- arize/_generated/api_client/models/__init__.py +0 -1
- arize/_generated/api_client/models/datasets_create_request.py +2 -10
- arize/_generated/api_client/models/datasets_examples_insert_request.py +2 -10
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_insert_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +2 -6
- arize/_generated/api_client_README.md +0 -1
- arize/client.py +47 -163
- arize/config.py +59 -100
- arize/datasets/client.py +11 -6
- arize/embeddings/nlp_generators.py +12 -6
- arize/embeddings/tabular_generators.py +14 -11
- arize/experiments/__init__.py +12 -0
- arize/experiments/client.py +13 -9
- arize/experiments/functions.py +6 -6
- arize/experiments/types.py +3 -3
- arize/{models → ml}/batch_validation/errors.py +2 -2
- arize/{models → ml}/batch_validation/validator.py +5 -3
- arize/{models → ml}/casting.py +42 -78
- arize/{models → ml}/client.py +19 -17
- arize/{models → ml}/proto.py +2 -2
- arize/{models → ml}/stream_validation.py +1 -1
- arize/{models → ml}/surrogate_explainer/mimic.py +6 -2
- arize/{types.py → ml/types.py} +99 -234
- arize/pre_releases.py +2 -1
- arize/projects/client.py +11 -6
- arize/spans/client.py +91 -86
- arize/spans/conversion.py +11 -4
- arize/spans/validation/common/value_validation.py +1 -1
- arize/spans/validation/spans/dataframe_form_validation.py +1 -1
- arize/spans/validation/spans/value_validation.py +2 -1
- arize/utils/dataframe.py +1 -1
- arize/utils/online_tasks/dataframe_preprocessor.py +5 -6
- arize/utils/types.py +105 -0
- arize/version.py +1 -1
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/METADATA +56 -59
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/RECORD +50 -51
- arize/_generated/api_client/models/primitive_value.py +0 -172
- arize/_generated/api_client/test/test_primitive_value.py +0 -50
- /arize/{models → ml}/__init__.py +0 -0
- /arize/{models → ml}/batch_validation/__init__.py +0 -0
- /arize/{models → ml}/bounded_executor.py +0 -0
- /arize/{models → ml}/surrogate_explainer/__init__.py +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/WHEEL +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0a23.dist-info → arize-8.0.0b1.dist-info}/licenses/NOTICE +0 -0
arize/config.py
CHANGED
|
@@ -167,112 +167,71 @@ def _parse_bool(val: bool | str | None) -> bool:
|
|
|
167
167
|
class SDKConfiguration:
|
|
168
168
|
"""Configuration for the Arize SDK with endpoint and authentication settings.
|
|
169
169
|
|
|
170
|
-
This class is used internally by ArizeClient to manage SDK configuration.
|
|
171
|
-
|
|
170
|
+
This class is used internally by ArizeClient to manage SDK configuration. It is not
|
|
171
|
+
recommended to use this class directly; users should interact with ArizeClient
|
|
172
|
+
instead.
|
|
172
173
|
|
|
173
|
-
Configuration Precedence
|
|
174
|
-
------------------------
|
|
175
174
|
Each configuration parameter follows this resolution order:
|
|
176
175
|
1. Explicit value passed to ArizeClient constructor (highest priority)
|
|
177
176
|
2. Environment variable value
|
|
178
177
|
3. Built-in default value (lowest priority)
|
|
179
178
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
stream_max_queue_bound : int
|
|
238
|
-
Maximum queue size for streaming operations (minimum: 1).
|
|
239
|
-
Environment variable: ARIZE_STREAM_MAX_QUEUE_BOUND
|
|
240
|
-
Default: 5000
|
|
241
|
-
|
|
242
|
-
max_http_payload_size_mb : float
|
|
243
|
-
Maximum HTTP payload size in megabytes (minimum: 1).
|
|
244
|
-
Environment variable: ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB
|
|
245
|
-
Default: 100
|
|
246
|
-
|
|
247
|
-
arize_directory : str
|
|
248
|
-
Directory for Arize SDK files (cache, logs, etc.).
|
|
249
|
-
Environment variable: ARIZE_DIRECTORY
|
|
250
|
-
Default: "~/.arize"
|
|
251
|
-
|
|
252
|
-
enable_caching : bool
|
|
253
|
-
Whether to enable local caching.
|
|
254
|
-
Environment variable: ARIZE_ENABLE_CACHING
|
|
255
|
-
Default: True
|
|
256
|
-
|
|
257
|
-
region : Region
|
|
258
|
-
Arize region (e.g., US_CENTRAL, EU_WEST). When specified, overrides
|
|
259
|
-
individual host/port settings.
|
|
260
|
-
Environment variable: ARIZE_REGION
|
|
261
|
-
Default: Region.UNSPECIFIED
|
|
262
|
-
|
|
263
|
-
single_host : str
|
|
264
|
-
Single host to use for all endpoints. Overrides individual host settings.
|
|
265
|
-
Environment variable: ARIZE_SINGLE_HOST
|
|
266
|
-
Default: "" (not set)
|
|
267
|
-
|
|
268
|
-
single_port : int
|
|
269
|
-
Single port to use for all endpoints. Overrides individual port settings (0-65535).
|
|
270
|
-
Environment variable: ARIZE_SINGLE_PORT
|
|
271
|
-
Default: 0 (not set)
|
|
272
|
-
|
|
273
|
-
See Also:
|
|
274
|
-
--------
|
|
275
|
-
ArizeClient : Main client class that uses this configuration
|
|
179
|
+
Args:
|
|
180
|
+
api_key: Arize API key for authentication. Required.
|
|
181
|
+
Environment variable: ARIZE_API_KEY.
|
|
182
|
+
Default: None (must be provided via argument or environment variable).
|
|
183
|
+
api_host: API endpoint host.
|
|
184
|
+
Environment variable: ARIZE_API_HOST.
|
|
185
|
+
Default: "api.arize.com".
|
|
186
|
+
api_scheme: API endpoint scheme (http/https).
|
|
187
|
+
Environment variable: ARIZE_API_SCHEME.
|
|
188
|
+
Default: "https".
|
|
189
|
+
otlp_host: OTLP (OpenTelemetry Protocol) endpoint host.
|
|
190
|
+
Environment variable: ARIZE_OTLP_HOST.
|
|
191
|
+
Default: "otlp.arize.com".
|
|
192
|
+
otlp_scheme: OTLP endpoint scheme (http/https).
|
|
193
|
+
Environment variable: ARIZE_OTLP_SCHEME.
|
|
194
|
+
Default: "https".
|
|
195
|
+
flight_host: Apache Arrow Flight endpoint host.
|
|
196
|
+
Environment variable: ARIZE_FLIGHT_HOST.
|
|
197
|
+
Default: "flight.arize.com".
|
|
198
|
+
flight_port: Apache Arrow Flight endpoint port (1-65535).
|
|
199
|
+
Environment variable: ARIZE_FLIGHT_PORT.
|
|
200
|
+
Default: 443.
|
|
201
|
+
flight_scheme: Apache Arrow Flight endpoint scheme.
|
|
202
|
+
Environment variable: ARIZE_FLIGHT_SCHEME.
|
|
203
|
+
Default: "grpc+tls".
|
|
204
|
+
pyarrow_max_chunksize: Maximum chunk size for PyArrow operations (1 to MAX_CHUNKSIZE).
|
|
205
|
+
Environment variable: ARIZE_MAX_CHUNKSIZE.
|
|
206
|
+
Default: 10_000.
|
|
207
|
+
request_verify: Whether to verify SSL certificates for HTTP requests.
|
|
208
|
+
Environment variable: ARIZE_REQUEST_VERIFY.
|
|
209
|
+
Default: True.
|
|
210
|
+
stream_max_workers: Maximum number of worker threads for streaming operations (minimum: 1).
|
|
211
|
+
Environment variable: ARIZE_STREAM_MAX_WORKERS.
|
|
212
|
+
Default: 8.
|
|
213
|
+
stream_max_queue_bound: Maximum queue size for streaming operations (minimum: 1).
|
|
214
|
+
Environment variable: ARIZE_STREAM_MAX_QUEUE_BOUND.
|
|
215
|
+
Default: 5000.
|
|
216
|
+
max_http_payload_size_mb: Maximum HTTP payload size in megabytes (minimum: 1).
|
|
217
|
+
Environment variable: ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB.
|
|
218
|
+
Default: 100.
|
|
219
|
+
arize_directory: Directory for Arize SDK files (cache, logs, etc.).
|
|
220
|
+
Environment variable: ARIZE_DIRECTORY.
|
|
221
|
+
Default: "~/.arize".
|
|
222
|
+
enable_caching: Whether to enable local caching.
|
|
223
|
+
Environment variable: ARIZE_ENABLE_CACHING.
|
|
224
|
+
Default: True.
|
|
225
|
+
region: Arize region (e.g., US_CENTRAL, EU_WEST). When specified, overrides
|
|
226
|
+
individual host/port settings.
|
|
227
|
+
Environment variable: ARIZE_REGION.
|
|
228
|
+
Default: Region.UNSPECIFIED.
|
|
229
|
+
single_host: Single host to use for all endpoints. Overrides individual host settings.
|
|
230
|
+
Environment variable: ARIZE_SINGLE_HOST.
|
|
231
|
+
Default: "" (not set).
|
|
232
|
+
single_port: Single port to use for all endpoints. Overrides individual port settings (0-65535).
|
|
233
|
+
Environment variable: ARIZE_SINGLE_PORT.
|
|
234
|
+
Default: 0 (not set).
|
|
276
235
|
"""
|
|
277
236
|
|
|
278
237
|
api_key: str = field(
|
arize/datasets/client.py
CHANGED
|
@@ -30,17 +30,22 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class DatasetsClient:
|
|
33
|
-
"""Client for managing datasets including creation, retrieval, and example management.
|
|
33
|
+
"""Client for managing datasets including creation, retrieval, and example management.
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
|
|
35
|
+
This class is primarily intended for internal use within the SDK. Users are
|
|
36
|
+
highly encouraged to access resource-specific functionality via
|
|
37
|
+
:class:`arize.ArizeClient`.
|
|
37
38
|
|
|
38
|
-
|
|
39
|
-
|
|
39
|
+
The datasets client is a thin wrapper around the generated REST API client,
|
|
40
|
+
using the shared generated API client owned by
|
|
41
|
+
:class:`arize.config.SDKConfiguration`.
|
|
42
|
+
"""
|
|
40
43
|
|
|
44
|
+
def __init__(self, *, sdk_config: SDKConfiguration) -> None:
|
|
45
|
+
"""
|
|
41
46
|
Args:
|
|
42
47
|
sdk_config: Resolved SDK configuration.
|
|
43
|
-
"""
|
|
48
|
+
""" # noqa: D205, D212
|
|
44
49
|
self._sdk_config = sdk_config
|
|
45
50
|
|
|
46
51
|
# Import at runtime so it's still lazy and extras-gated by the parent
|
|
@@ -49,10 +49,13 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
|
|
|
49
49
|
) -> pd.Series:
|
|
50
50
|
"""Obtain embedding vectors from your text data using pre-trained large language models.
|
|
51
51
|
|
|
52
|
-
:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
52
|
+
Args:
|
|
53
|
+
text_col: A pandas Series containing the different pieces of text.
|
|
54
|
+
class_label_col: If this column is passed, the sentence "The classification label
|
|
55
|
+
is <class_label>" will be appended to the text in the `text_col`.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
A pandas Series containing the embedding vectors.
|
|
56
59
|
"""
|
|
57
60
|
if not isinstance(text_col, pd.Series):
|
|
58
61
|
raise TypeError("text_col must be a pandas Series")
|
|
@@ -110,8 +113,11 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
|
|
|
110
113
|
) -> pd.Series:
|
|
111
114
|
"""Obtain embedding vectors from your text data using pre-trained large language models.
|
|
112
115
|
|
|
113
|
-
:
|
|
114
|
-
|
|
116
|
+
Args:
|
|
117
|
+
text_col: A pandas Series containing the different pieces of text.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
A pandas Series containing the embedding vectors.
|
|
115
121
|
"""
|
|
116
122
|
if not isinstance(text_col, pd.Series):
|
|
117
123
|
raise TypeError("text_col must be a pandas Series")
|
|
@@ -11,7 +11,7 @@ from arize.embeddings.constants import (
|
|
|
11
11
|
IMPORT_ERROR_MESSAGE,
|
|
12
12
|
)
|
|
13
13
|
from arize.embeddings.usecases import UseCases
|
|
14
|
-
from arize.types import is_list_of
|
|
14
|
+
from arize.utils.types import is_list_of
|
|
15
15
|
|
|
16
16
|
try:
|
|
17
17
|
from datasets import Dataset
|
|
@@ -79,16 +79,19 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
|
|
|
79
79
|
Prompts are generated from your `selected_columns` and passed to a pre-trained
|
|
80
80
|
large language model for embedding vector computation.
|
|
81
81
|
|
|
82
|
-
:
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
82
|
+
Args:
|
|
83
|
+
df: Pandas DataFrame containing the tabular data. Not all columns will be
|
|
84
|
+
considered, see `selected_columns`.
|
|
85
|
+
selected_columns: Columns to be considered to construct the prompt to be passed to
|
|
86
|
+
the LLM.
|
|
87
|
+
col_name_map: Mapping between selected column names and a more verbose description of
|
|
88
|
+
the name. This helps the LLM understand the features better.
|
|
89
|
+
return_prompt_col: If set to True, an extra pandas Series will be returned
|
|
90
|
+
containing the constructed prompts. Defaults to False.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
A pandas Series containing the embedding vectors and, if `return_prompt_col` is
|
|
94
|
+
set to True, a pandas Series containing the prompts created from tabular features.
|
|
92
95
|
"""
|
|
93
96
|
if col_name_map is None:
|
|
94
97
|
col_name_map = {}
|
arize/experiments/__init__.py
CHANGED
|
@@ -1 +1,13 @@
|
|
|
1
1
|
"""Experiment tracking and evaluation functionality for the Arize SDK."""
|
|
2
|
+
|
|
3
|
+
from arize.experiments.evaluators.types import (
|
|
4
|
+
EvaluationResult,
|
|
5
|
+
EvaluationResultFieldNames,
|
|
6
|
+
)
|
|
7
|
+
from arize.experiments.types import ExperimentTaskFieldNames
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"EvaluationResult",
|
|
11
|
+
"EvaluationResultFieldNames",
|
|
12
|
+
"ExperimentTaskFieldNames",
|
|
13
|
+
]
|
arize/experiments/client.py
CHANGED
|
@@ -43,24 +43,29 @@ if TYPE_CHECKING:
|
|
|
43
43
|
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
44
44
|
from arize.experiments.types import (
|
|
45
45
|
ExperimentTask,
|
|
46
|
-
|
|
46
|
+
ExperimentTaskFieldNames,
|
|
47
47
|
)
|
|
48
48
|
|
|
49
49
|
logger = logging.getLogger(__name__)
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class ExperimentsClient:
|
|
53
|
-
"""Client for managing experiments including creation, execution, and result tracking.
|
|
53
|
+
"""Client for managing experiments including creation, execution, and result tracking.
|
|
54
54
|
|
|
55
|
-
|
|
56
|
-
|
|
55
|
+
This class is primarily intended for internal use within the SDK. Users are
|
|
56
|
+
highly encouraged to access resource-specific functionality via
|
|
57
|
+
:class:`arize.ArizeClient`.
|
|
57
58
|
|
|
58
|
-
|
|
59
|
-
|
|
59
|
+
The experiments client is a thin wrapper around the generated REST API client,
|
|
60
|
+
using the shared generated API client owned by
|
|
61
|
+
:class:`arize.config.SDKConfiguration`.
|
|
62
|
+
"""
|
|
60
63
|
|
|
64
|
+
def __init__(self, *, sdk_config: SDKConfiguration) -> None:
|
|
65
|
+
"""
|
|
61
66
|
Args:
|
|
62
67
|
sdk_config: Resolved SDK configuration.
|
|
63
|
-
"""
|
|
68
|
+
""" # noqa: D205, D212
|
|
64
69
|
self._sdk_config = sdk_config
|
|
65
70
|
from arize._generated import api_client as gen
|
|
66
71
|
|
|
@@ -109,7 +114,7 @@ class ExperimentsClient:
|
|
|
109
114
|
name: str,
|
|
110
115
|
dataset_id: str,
|
|
111
116
|
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
112
|
-
task_fields:
|
|
117
|
+
task_fields: ExperimentTaskFieldNames,
|
|
113
118
|
evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
|
|
114
119
|
force_http: bool = False,
|
|
115
120
|
) -> models.Experiment:
|
|
@@ -170,7 +175,6 @@ class ExperimentsClient:
|
|
|
170
175
|
from arize._generated import api_client as gen
|
|
171
176
|
|
|
172
177
|
data = experiment_df.to_dict(orient="records")
|
|
173
|
-
|
|
174
178
|
body = gen.ExperimentsCreateRequest(
|
|
175
179
|
name=name,
|
|
176
180
|
dataset_id=dataset_id,
|
arize/experiments/functions.py
CHANGED
|
@@ -56,7 +56,7 @@ from arize.experiments.types import (
|
|
|
56
56
|
ExperimentEvaluationRun,
|
|
57
57
|
ExperimentRun,
|
|
58
58
|
ExperimentTask,
|
|
59
|
-
|
|
59
|
+
ExperimentTaskFieldNames,
|
|
60
60
|
_TaskSummary,
|
|
61
61
|
)
|
|
62
62
|
|
|
@@ -768,7 +768,7 @@ def get_result_attr(r: object, attr: str, default: object = None) -> object:
|
|
|
768
768
|
|
|
769
769
|
def transform_to_experiment_format(
|
|
770
770
|
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
771
|
-
task_fields:
|
|
771
|
+
task_fields: ExperimentTaskFieldNames,
|
|
772
772
|
evaluator_fields: dict[str, EvaluationResultFieldNames] | None = None,
|
|
773
773
|
) -> pd.DataFrame:
|
|
774
774
|
"""Transform a DataFrame to match the format returned by run_experiment().
|
|
@@ -788,7 +788,7 @@ def transform_to_experiment_format(
|
|
|
788
788
|
else pd.DataFrame(experiment_runs)
|
|
789
789
|
)
|
|
790
790
|
# Validate required columns
|
|
791
|
-
required_cols = {task_fields.example_id, task_fields.
|
|
791
|
+
required_cols = {task_fields.example_id, task_fields.output}
|
|
792
792
|
missing_cols = required_cols - set(data.columns)
|
|
793
793
|
if missing_cols:
|
|
794
794
|
raise ValueError(f"Missing required columns: {missing_cols}")
|
|
@@ -799,11 +799,11 @@ def transform_to_experiment_format(
|
|
|
799
799
|
out_df["example_id"] = data[task_fields.example_id]
|
|
800
800
|
if task_fields.example_id != "example_id":
|
|
801
801
|
out_df.drop(task_fields.example_id, axis=1, inplace=True)
|
|
802
|
-
out_df["
|
|
802
|
+
out_df["output"] = data[task_fields.output].apply(
|
|
803
803
|
lambda x: json.dumps(x) if isinstance(x, dict) else x
|
|
804
804
|
)
|
|
805
|
-
if task_fields.
|
|
806
|
-
out_df.drop(task_fields.
|
|
805
|
+
if task_fields.output != "output":
|
|
806
|
+
out_df.drop(task_fields.output, axis=1, inplace=True)
|
|
807
807
|
|
|
808
808
|
# Process evaluator results
|
|
809
809
|
if evaluator_fields:
|
arize/experiments/types.py
CHANGED
|
@@ -397,17 +397,17 @@ def _top_string(s: pd.Series, length: int = 100) -> str | None:
|
|
|
397
397
|
|
|
398
398
|
|
|
399
399
|
@dataclass
|
|
400
|
-
class
|
|
400
|
+
class ExperimentTaskFieldNames:
|
|
401
401
|
"""Column names for mapping experiment task results in a DataFrame.
|
|
402
402
|
|
|
403
403
|
Args:
|
|
404
404
|
example_id: Name of column containing example IDs.
|
|
405
405
|
The ID values must match the id of the dataset rows.
|
|
406
|
-
|
|
406
|
+
output: Name of column containing task results
|
|
407
407
|
"""
|
|
408
408
|
|
|
409
409
|
example_id: str
|
|
410
|
-
|
|
410
|
+
output: str
|
|
411
411
|
|
|
412
412
|
|
|
413
413
|
TaskOutput = JSONSerializable
|
|
@@ -16,12 +16,12 @@ from arize.constants.ml import (
|
|
|
16
16
|
MAX_TAG_LENGTH,
|
|
17
17
|
)
|
|
18
18
|
from arize.logging import log_a_list
|
|
19
|
-
from arize.types import Environments, ModelTypes
|
|
19
|
+
from arize.ml.types import Environments, ModelTypes
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
22
|
from collections.abc import Iterable
|
|
23
23
|
|
|
24
|
-
from arize.types import Metrics
|
|
24
|
+
from arize.ml.types import Metrics
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class ValidationError(Exception, ABC):
|
|
@@ -40,8 +40,8 @@ from arize.constants.ml import (
|
|
|
40
40
|
MODEL_MAPPING_CONFIG,
|
|
41
41
|
)
|
|
42
42
|
from arize.logging import get_truncation_warning_message
|
|
43
|
-
from arize.
|
|
44
|
-
from arize.types import (
|
|
43
|
+
from arize.ml.batch_validation import errors as err
|
|
44
|
+
from arize.ml.types import (
|
|
45
45
|
CATEGORICAL_MODEL_TYPES,
|
|
46
46
|
NUMERIC_MODEL_TYPES,
|
|
47
47
|
BaseSchema,
|
|
@@ -53,9 +53,11 @@ from arize.types import (
|
|
|
53
53
|
ModelTypes,
|
|
54
54
|
PromptTemplateColumnNames,
|
|
55
55
|
Schema,
|
|
56
|
+
segments_intersect,
|
|
57
|
+
)
|
|
58
|
+
from arize.utils.types import (
|
|
56
59
|
is_dict_of,
|
|
57
60
|
is_iterable_of,
|
|
58
|
-
segments_intersect,
|
|
59
61
|
)
|
|
60
62
|
|
|
61
63
|
logger = logging.getLogger(__name__)
|
arize/{models → ml}/casting.py
RENAMED
|
@@ -9,7 +9,13 @@ from typing import TYPE_CHECKING
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
11
11
|
from arize.logging import log_a_list
|
|
12
|
-
from arize.types import
|
|
12
|
+
from arize.ml.types import (
|
|
13
|
+
ArizeTypes,
|
|
14
|
+
Schema,
|
|
15
|
+
TypedColumns,
|
|
16
|
+
TypedValue,
|
|
17
|
+
)
|
|
18
|
+
from arize.utils.types import is_list_of
|
|
13
19
|
|
|
14
20
|
if TYPE_CHECKING:
|
|
15
21
|
import pandas as pd
|
|
@@ -125,29 +131,20 @@ def cast_typed_columns(
|
|
|
125
131
|
This optional feature provides a simple way for users to prevent type drift within
|
|
126
132
|
a column across many SDK uploads.
|
|
127
133
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
A deepcopy of the user's dataframe.
|
|
132
|
-
schema: Schema
|
|
133
|
-
The schema, which may include feature and tag column names
|
|
134
|
+
Args:
|
|
135
|
+
dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
|
|
136
|
+
schema (Schema): The schema, which may include feature and tag column names
|
|
134
137
|
in a TypedColumns object or a List[string].
|
|
135
138
|
|
|
136
139
|
Returns:
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
A new Schema object, with feature and tag column names converted to the List[string] format
|
|
142
|
-
expected in downstream validation.
|
|
140
|
+
tuple[pd.DataFrame, Schema]: A tuple containing:
|
|
141
|
+
- dataframe: The dataframe, with columns cast to the specified types.
|
|
142
|
+
- schema: A new Schema object, with feature and tag column names converted
|
|
143
|
+
to the List[string] format expected in downstream validation.
|
|
143
144
|
|
|
144
145
|
Raises:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
If casting fails.
|
|
148
|
-
InvalidTypedColumnsError
|
|
149
|
-
If the TypedColumns object is invalid.
|
|
150
|
-
|
|
146
|
+
ColumnCastingError: If casting fails.
|
|
147
|
+
InvalidTypedColumnsError: If the TypedColumns object is invalid.
|
|
151
148
|
"""
|
|
152
149
|
typed_column_fields = schema.typed_column_fields()
|
|
153
150
|
feature_field = "feature_column_names"
|
|
@@ -204,21 +201,14 @@ def _cast_value(
|
|
|
204
201
|
) -> str | int | float | list[str] | None:
|
|
205
202
|
"""Casts a TypedValue to its provided type, preserving all null values as None or float('nan').
|
|
206
203
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
typed_value: TypedValue
|
|
210
|
-
The TypedValue to cast.
|
|
204
|
+
Args:
|
|
205
|
+
typed_value (TypedValue): The TypedValue to cast.
|
|
211
206
|
|
|
212
207
|
Returns:
|
|
213
|
-
|
|
214
|
-
Union[str, int, float, List[str], None]
|
|
215
|
-
The cast value.
|
|
208
|
+
str | int | float | list[str] | None: The cast value.
|
|
216
209
|
|
|
217
210
|
Raises:
|
|
218
|
-
|
|
219
|
-
CastingError
|
|
220
|
-
If the value cannot be cast to the provided type.
|
|
221
|
-
|
|
211
|
+
CastingError: If the value cannot be cast to the provided type.
|
|
222
212
|
"""
|
|
223
213
|
if typed_value.value is None:
|
|
224
214
|
return None
|
|
@@ -274,18 +264,13 @@ def _validate_typed_columns(
|
|
|
274
264
|
) -> None:
|
|
275
265
|
"""Validate a TypedColumns object.
|
|
276
266
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
typed_columns: TypedColumns
|
|
282
|
-
The TypedColumns object to validate.
|
|
267
|
+
Args:
|
|
268
|
+
field_name (str): The name of the Schema field that the TypedColumns object
|
|
269
|
+
is associated with.
|
|
270
|
+
typed_columns (TypedColumns): The TypedColumns object to validate.
|
|
283
271
|
|
|
284
272
|
Raises:
|
|
285
|
-
|
|
286
|
-
InvalidTypedColumnsError
|
|
287
|
-
If the TypedColumns object is invalid.
|
|
288
|
-
|
|
273
|
+
InvalidTypedColumnsError: If the TypedColumns object is invalid.
|
|
289
274
|
"""
|
|
290
275
|
if typed_columns.is_empty():
|
|
291
276
|
raise InvalidTypedColumnsError(field_name=field_name, reason="is empty")
|
|
@@ -304,24 +289,16 @@ def _cast_columns(
|
|
|
304
289
|
|
|
305
290
|
(feature_column_names or tag_column_names)
|
|
306
291
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
columns: TypedColumns
|
|
312
|
-
The TypedColumns object, which specifies the columns to cast
|
|
313
|
-
(and/or to not cast) and their target types.
|
|
292
|
+
Args:
|
|
293
|
+
dataframe (pd.DataFrame): A deepcopy of the user's dataframe.
|
|
294
|
+
columns (TypedColumns): The TypedColumns object, which specifies the columns
|
|
295
|
+
to cast (and/or to not cast) and their target types.
|
|
314
296
|
|
|
315
297
|
Returns:
|
|
316
|
-
|
|
317
|
-
dataframe: pd.DataFrame
|
|
318
|
-
The dataframe with columns cast to the specified types.
|
|
298
|
+
pd.DataFrame: The dataframe with columns cast to the specified types.
|
|
319
299
|
|
|
320
300
|
Raises:
|
|
321
|
-
|
|
322
|
-
ColumnCastingError
|
|
323
|
-
If casting fails.
|
|
324
|
-
|
|
301
|
+
ColumnCastingError: If casting fails.
|
|
325
302
|
"""
|
|
326
303
|
if columns.to_str:
|
|
327
304
|
try:
|
|
@@ -372,25 +349,17 @@ def _cast_df(
|
|
|
372
349
|
) -> pd.DataFrame:
|
|
373
350
|
"""Cast columns in a dataframe to the specified type.
|
|
374
351
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
cols: List[str]
|
|
380
|
-
The list of column names to cast.
|
|
381
|
-
target_type_str: str
|
|
382
|
-
The target type to cast to.
|
|
352
|
+
Args:
|
|
353
|
+
df (pd.DataFrame): A deepcopy of the user's dataframe.
|
|
354
|
+
cols (list[str]): The list of column names to cast.
|
|
355
|
+
target_type_str (str): The target type to cast to.
|
|
383
356
|
|
|
384
357
|
Returns:
|
|
385
|
-
|
|
386
|
-
df: pd.DataFrame
|
|
387
|
-
The dataframe with columns cast to the specified types.
|
|
358
|
+
pd.DataFrame: The dataframe with columns cast to the specified types.
|
|
388
359
|
|
|
389
360
|
Raises:
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
If casting fails. Common exceptions raised by astype() are TypeError and ValueError.
|
|
393
|
-
|
|
361
|
+
Exception: If casting fails. Common exceptions raised by astype() are
|
|
362
|
+
TypeError and ValueError.
|
|
394
363
|
"""
|
|
395
364
|
nan_mapping = {"nan": np.nan, "NaN": np.nan}
|
|
396
365
|
df = df.replace(nan_mapping)
|
|
@@ -404,18 +373,13 @@ def _convert_schema_field_types(
|
|
|
404
373
|
) -> Schema:
|
|
405
374
|
"""Convert schema field types from TypedColumns to List[string] format.
|
|
406
375
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
schema: Schema
|
|
410
|
-
The schema, which may include feature and tag column names
|
|
376
|
+
Args:
|
|
377
|
+
schema (Schema): The schema, which may include feature and tag column names
|
|
411
378
|
in a TypedColumns object or a List[string].
|
|
412
379
|
|
|
413
380
|
Returns:
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
A Schema, with feature and tag column names
|
|
417
|
-
converted to the List[string] format expected in downstream validation.
|
|
418
|
-
|
|
381
|
+
Schema: A Schema, with feature and tag column names converted to the
|
|
382
|
+
List[string] format expected in downstream validation.
|
|
419
383
|
"""
|
|
420
384
|
feature_column_names_list = (
|
|
421
385
|
schema.feature_column_names
|