arize 8.0.0b2__py3-none-any.whl → 8.0.1__py3-none-any.whl
- arize/__init__.py +8 -1
- arize/_exporter/client.py +18 -17
- arize/_exporter/parsers/tracing_data_parser.py +9 -4
- arize/_exporter/validation.py +1 -1
- arize/_flight/client.py +33 -13
- arize/_lazy.py +37 -2
- arize/client.py +61 -35
- arize/config.py +168 -14
- arize/constants/config.py +1 -0
- arize/datasets/client.py +32 -19
- arize/embeddings/auto_generator.py +14 -7
- arize/embeddings/base_generators.py +15 -9
- arize/embeddings/cv_generators.py +2 -2
- arize/embeddings/nlp_generators.py +8 -8
- arize/embeddings/tabular_generators.py +5 -5
- arize/exceptions/config.py +22 -0
- arize/exceptions/parameters.py +1 -1
- arize/exceptions/values.py +8 -5
- arize/experiments/__init__.py +4 -0
- arize/experiments/client.py +17 -11
- arize/experiments/evaluators/base.py +6 -3
- arize/experiments/evaluators/executors.py +6 -4
- arize/experiments/evaluators/rate_limiters.py +3 -1
- arize/experiments/evaluators/types.py +7 -5
- arize/experiments/evaluators/utils.py +7 -5
- arize/experiments/functions.py +111 -48
- arize/experiments/tracing.py +4 -1
- arize/experiments/types.py +31 -26
- arize/logging.py +53 -32
- arize/ml/batch_validation/validator.py +82 -70
- arize/ml/bounded_executor.py +25 -6
- arize/ml/casting.py +45 -27
- arize/ml/client.py +35 -28
- arize/ml/proto.py +16 -17
- arize/ml/stream_validation.py +63 -25
- arize/ml/surrogate_explainer/mimic.py +15 -7
- arize/ml/types.py +26 -12
- arize/pre_releases.py +7 -6
- arize/py.typed +0 -0
- arize/regions.py +10 -10
- arize/spans/client.py +113 -21
- arize/spans/conversion.py +7 -5
- arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
- arize/spans/validation/annotations/value_validation.py +11 -14
- arize/spans/validation/common/dataframe_form_validation.py +1 -1
- arize/spans/validation/common/value_validation.py +10 -13
- arize/spans/validation/evals/value_validation.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +1 -1
- arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
- arize/spans/validation/metadata/value_validation.py +23 -1
- arize/utils/arrow.py +37 -1
- arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
- arize/utils/proto.py +0 -1
- arize/utils/types.py +6 -6
- arize/version.py +1 -1
- {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/METADATA +18 -3
- {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/RECORD +60 -58
- {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/WHEEL +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/NOTICE +0 -0
arize/config.py
CHANGED

```diff
@@ -25,6 +25,7 @@ from arize.constants.config import (
     ENV_API_KEY,
     ENV_API_SCHEME,
     ENV_ARIZE_DIRECTORY,
+    ENV_BASE_DOMAIN,
     ENV_ENABLE_CACHING,
     ENV_FLIGHT_HOST,
     ENV_FLIGHT_PORT,
@@ -42,6 +43,7 @@ from arize.constants.config import (
 )
 from arize.constants.pyarrow import MAX_CHUNKSIZE
 from arize.exceptions.auth import MissingAPIKeyError
+from arize.exceptions.config import MultipleEndpointOverridesError
 from arize.regions import REGION_ENDPOINTS, Region
 from arize.version import __version__
 
@@ -53,18 +55,44 @@ ALLOWED_HTTP_SCHEMES = {"http", "https"}
 
 
 def _is_sensitive_field(name: str) -> bool:
+    """Check if a field name contains sensitive information markers.
+
+    Args:
+        name: The field name to check.
+
+    Returns:
+        bool: True if the field name contains 'key', 'token', or 'secret' (case-insensitive).
+    """
     n = name.lower()
     return bool(any(k in n for k in SENSITIVE_FIELD_MARKERS))
 
 
 def _mask_secret(secret: str, N: int = 4) -> str:
-    """
+    """Mask a secret string by showing only the first N characters.
+
+    Args:
+        secret: The secret string to mask.
+        N: Number of characters to show before masking. Defaults to 4.
+
+    Returns:
+        str: The masked string (first N chars + '***'), or empty string if input is empty.
+    """
     if len(secret) == 0:
         return ""
     return f"{secret[:N]}***"
 
 
 def _endpoint(scheme: str, base: str, path: str = "") -> str:
+    """Construct a full endpoint URL from scheme, base, and optional path.
+
+    Args:
+        scheme: The URL scheme (e.g., "http", "https").
+        base: The base URL or hostname.
+        path: Optional path to append to the base URL. Defaults to empty string.
+
+    Returns:
+        str: The fully constructed endpoint URL.
+    """
     endpoint = scheme + "://" + base.rstrip("/")
     if path:
         endpoint += "/" + path.lstrip("/")
@@ -72,6 +100,18 @@ def _endpoint(scheme: str, base: str, path: str = "") -> str:
 
 
 def _env_http_scheme(name: str, default: str) -> str:
+    """Get an HTTP scheme from environment variable with validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+
+    Returns:
+        str: The validated HTTP scheme ('http' or 'https').
+
+    Raises:
+        ValueError: If the scheme is not 'http' or 'https'.
+    """
     v = _env_str(name, default).lower()
     if v not in ALLOWED_HTTP_SCHEMES:
         raise ValueError(
@@ -86,6 +126,20 @@ def _env_str(
     min_len: int | None = None,
     max_len: int | None = None,
 ) -> str:
+    """Get a string value from environment variable with length validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_len: Optional minimum length constraint for the string.
+        max_len: Optional maximum length constraint for the string.
+
+    Returns:
+        str: The validated string value (stripped of whitespace).
+
+    Raises:
+        ValueError: If the string length violates min_len or max_len constraints.
+    """
     val = os.getenv(name, default).strip()
 
     if min_len is not None and len(val) < min_len:
@@ -107,6 +161,20 @@ def _env_int(
     min_val: int | None = None,
     max_val: int | None = None,
 ) -> int:
+    """Get an integer value from environment variable with range validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_val: Optional minimum value constraint for the integer.
+        max_val: Optional maximum value constraint for the integer.
+
+    Returns:
+        int: The validated integer value.
+
+    Raises:
+        ValueError: If the value cannot be parsed as an integer or violates min_val/max_val constraints.
+    """
     raw = os.getenv(name, default)
     try:
         val = int(raw)
@@ -132,6 +200,20 @@ def _env_float(
     min_val: float | None = None,
     max_val: float | None = None,
 ) -> float:
+    """Get a float value from environment variable with range validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_val: Optional minimum value constraint for the float.
+        max_val: Optional maximum value constraint for the float.
+
+    Returns:
+        float: The validated float value.
+
+    Raises:
+        ValueError: If the value cannot be parsed as a float or violates min_val/max_val constraints.
+    """
     raw = os.getenv(name, default)
     try:
         val = float(raw)
@@ -152,10 +234,28 @@ def _env_float(
 
 
 def _env_bool(name: str, default: bool) -> bool:
+    """Get a boolean value from environment variable.
+
+    Args:
+        name: The environment variable name.
+        default: The default boolean value if the environment variable is not set.
+
+    Returns:
+        bool: The parsed boolean value.
+    """
     return _parse_bool(os.getenv(name, str(default)))
 
 
 def _parse_bool(val: bool | str | None) -> bool:
+    """Parse a boolean value from various input types.
+
+    Args:
+        val: The value to parse. Can be a bool, string, or None.
+
+    Returns:
+        bool: True if the value is already True or matches one of the truthy strings
+            ('1', 'true', 'yes', 'on', case-insensitive). False otherwise.
+    """
     if isinstance(val, bool):
         return val
     return (val or "").strip().lower() in {"1", "true", "yes", "on"}
@@ -227,15 +327,27 @@ class SDKConfiguration:
             individual host/port settings.
             Environment variable: ARIZE_REGION.
            Default: :class:`Region.UNSET`.
-        single_host: Single host to use for all endpoints.
+        single_host: Single host to use for all endpoints. When specified, overrides
+            individual host settings.
             Environment variable: ARIZE_SINGLE_HOST.
             Default: "" (not set).
-        single_port: Single port to use for all endpoints.
+        single_port: Single port to use for all endpoints. When specified, overrides
+            individual port settings (0-65535).
             Environment variable: ARIZE_SINGLE_PORT.
             Default: 0 (not set).
+        base_domain: Base domain for generating all endpoint hosts. Intended for Private Connect
+            setups. When specified, generates hosts as api.<base_domain>, otlp.<base_domain>,
+            flight.<base_domain>. When specified, overrides individual host settings.
+            Environment variable: ARIZE_BASE_DOMAIN.
+            Default: "" (not set).
+
+    Note:
+        The endpoint override options (region, single_host/single_port, base_domain) are
+        mutually exclusive. Specifying more than one will raise MultipleEndpointOverridesError.
 
     Raises:
         MissingAPIKeyError: If api_key is not provided via argument or environment variable.
+        MultipleEndpointOverridesError: If multiple endpoint override options are provided.
     """
 
     api_key: str = field(
@@ -326,27 +438,73 @@ class SDKConfiguration:
             ENV_SINGLE_PORT, 0, min_val=0, max_val=65535
         )
     )
+    base_domain: str = field(
+        default_factory=lambda: _env_str(ENV_BASE_DOMAIN, "")
+    )
 
     def __post_init__(self) -> None:
         """Validate and configure SDK endpoints after initialization.
 
+        Endpoint override options are mutually exclusive. Only one of the following
+        can be specified:
+        1. region - Overrides all via REGION_ENDPOINTS mapping
+        2. single_host/single_port - Overrides individual hosts/ports
+        3. base_domain - Generates hosts from base domain
+
+        If none are specified, per-endpoint host/port settings are used.
+
         Raises:
-            MissingAPIKeyError: If api_key is not provided
+            MissingAPIKeyError: If api_key is not provided.
+            MultipleEndpointOverridesError: If multiple endpoint override options are provided.
         """
-        # Validate
+        # Validate configuration
         if not self.api_key:
             raise MissingAPIKeyError()
 
+        # Check which override options are set
+        has_base_domain = bool(self.base_domain)
         has_single_host = bool(self.single_host)
         has_single_port = self.single_port != 0
         has_region = self.region is not Region.UNSET
-
+
+        # Ensure only one override method is used (mutually exclusive)
+        override_count = sum(
+            [has_base_domain, has_single_host or has_single_port, has_region]
+        )
+        if override_count > 1:
+            # Determine which overrides were provided
+            provided_overrides = []
+            if has_region:
+                provided_overrides.append(f"region={self.region.value}")
+            if has_single_host or has_single_port:
+                if has_single_host:
+                    provided_overrides.append(
+                        f"single_host={self.single_host!r}"
+                    )
+                if has_single_port:
+                    provided_overrides.append(f"single_port={self.single_port}")
+            if has_base_domain:
+                provided_overrides.append(f"base_domain={self.base_domain!r}")
+
+            error_message = (
+                f"Multiple endpoint override options provided: {', '.join(provided_overrides)}. "
+                "Only one of the following can be specified: 'region', "
+                "'single_host'/'single_port', or 'base_domain'."
+            )
+            logger.error(error_message)
+            raise MultipleEndpointOverridesError(error_message)
+
+        if has_base_domain:
             logger.info(
-                "
-
+                "Base domain %r provided; generating hosts from base domain.",
+                self.base_domain,
+            )
+            object.__setattr__(self, "api_host", f"api.{self.base_domain}")
+            object.__setattr__(self, "otlp_host", f"otlp.{self.base_domain}")
+            object.__setattr__(
+                self, "flight_host", f"flight.{self.base_domain}"
             )
 
-        # Single host override: if single_host is set, it overrides hosts
         if has_single_host:
             logger.info(
                 "Single host %r provided; overriding hosts configuration with single host.",
@@ -356,7 +514,6 @@ class SDKConfiguration:
             object.__setattr__(self, "otlp_host", self.single_host)
             object.__setattr__(self, "flight_host", self.single_host)
 
-        # Single port override: if single_port is set, it overrides ports
         if has_single_port:
             logger.info(
                 "Single port %s provided; overriding ports configuration with single port.",
@@ -364,15 +521,12 @@ class SDKConfiguration:
             )
             object.__setattr__(self, "flight_port", self.single_port)
 
-        # Region override: if region is set, it *always* wins over host/port fields
         if has_region:
-            endpoints = REGION_ENDPOINTS[self.region]
-
-            # Override config (region trumps everything)
             logger.info(
                 "Region %s provided; overriding hosts & ports configuration with region defaults.",
                 self.region.value,
             )
+            endpoints = REGION_ENDPOINTS[self.region]
            object.__setattr__(self, "api_host", endpoints.api_host)
             object.__setattr__(self, "otlp_host", endpoints.otlp_host)
             object.__setattr__(self, "flight_host", endpoints.flight_host)
```
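The mutual-exclusivity check introduced in `__post_init__` is easiest to see in use. A minimal sketch, assuming `SDKConfiguration` is constructed directly with keyword arguments (field names are taken from the diff above; the key and domain values are placeholders):

```python
from arize.config import SDKConfiguration
from arize.exceptions.config import MultipleEndpointOverridesError

# base_domain alone: hosts derive as api.<domain>, otlp.<domain>, flight.<domain>
cfg = SDKConfiguration(api_key="my-key", base_domain="example.internal")
assert cfg.api_host == "api.example.internal"
assert cfg.flight_host == "flight.example.internal"

# Combining base_domain with another override is rejected in __post_init__
try:
    SDKConfiguration(
        api_key="my-key",
        base_domain="example.internal",
        single_host="localhost",
    )
except MultipleEndpointOverridesError as err:
    print(err)  # names the conflicting options that were provided
```

Deriving all three hosts from one domain keeps a Private Connect setup to a single setting instead of three separate host overrides.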
arize/constants/config.py
CHANGED

```diff
@@ -14,6 +14,7 @@ ENV_FLIGHT_PORT = "ARIZE_FLIGHT_PORT"
 ENV_FLIGHT_SCHEME = "ARIZE_FLIGHT_SCHEME"
 ENV_SINGLE_HOST = "ARIZE_SINGLE_HOST"
 ENV_SINGLE_PORT = "ARIZE_SINGLE_PORT"
+ENV_BASE_DOMAIN = "ARIZE_BASE_DOMAIN"
 ENV_PYARROW_MAX_CHUNKSIZE = "ARIZE_MAX_CHUNKSIZE"
 ENV_REQUEST_VERIFY = "ARIZE_REQUEST_VERIFY"
 ENV_MAX_HTTP_PAYLOAD_SIZE_MB = "ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB"
```
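Since `ENV_BASE_DOMAIN` follows the existing `ARIZE_*` naming scheme, the same behavior is reachable purely through the environment. A small sketch (placeholder values; the field's `default_factory` reads the variable at instantiation time, so it must be set before `SDKConfiguration` is constructed):

```python
import os

os.environ["ARIZE_BASE_DOMAIN"] = "example.internal"  # placeholder domain

from arize.config import SDKConfiguration

cfg = SDKConfiguration(api_key="my-key")
print(cfg.otlp_host)  # -> otlp.example.internal
```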
arize/datasets/client.py
CHANGED

```diff
@@ -5,7 +5,7 @@ from __future__ import annotations
 import logging
 import time
 import uuid
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 import pandas as pd
 import pyarrow as pa
@@ -24,6 +24,10 @@ from arize.utils.openinference_conversion import (
 from arize.utils.size import get_payload_size_mb
 
 if TYPE_CHECKING:
+    # builtins is needed to use builtins.list in type annotations because
+    # the class has a list() method that shadows the built-in list type
+    import builtins
+
     from arize._generated.api_client.api_client import ApiClient
     from arize.config import SDKConfiguration
 
@@ -97,7 +101,7 @@ class DatasetsClient:
         *,
         name: str,
         space_id: str,
-        examples: list[dict[str, object]] | pd.DataFrame,
+        examples: builtins.list[dict[str, object]] | pd.DataFrame,
         force_http: bool = False,
     ) -> models.Dataset:
         """Create a dataset with JSON examples.
@@ -150,7 +154,7 @@ class DatasetsClient:
         from arize._generated import api_client as gen
 
         data = (
-            examples.to_dict(orient="records")
+            examples.to_dict(orient="records")
             if isinstance(examples, pd.DataFrame)
             else examples
         )
@@ -158,7 +162,8 @@ class DatasetsClient:
         body = gen.DatasetsCreateRequest(
             name=name,
             space_id=space_id,
-
+            # Cast: pandas to_dict returns dict[Hashable, Any] but API requires dict[str, Any]
+            examples=cast("list[dict[str, Any]]", data),
         )
         return self._api.datasets_create(datasets_create_request=body)
 
@@ -169,15 +174,12 @@ class DatasetsClient:
             "Trying to convert to DataFrame for more efficient upload via "
             "gRPC + Flight."
         )
-
-            examples
-            if isinstance(examples, pd.DataFrame)
-            else pd.DataFrame(examples)
-        )
+        if not isinstance(examples, pd.DataFrame):
+            examples = pd.DataFrame(examples)
         return self._create_dataset_via_flight(
             name=name,
             space_id=space_id,
-            examples=
+            examples=examples,
         )
 
     @prerelease_endpoint(key="datasets.get", stage=ReleaseStage.BETA)
@@ -280,7 +282,11 @@ class DatasetsClient:
         )
         if dataset_df is not None:
             return models.DatasetsExamplesList200Response(
-
+                # Cast: Pydantic validates and converts dicts to DatasetExample at runtime
+                examples=cast(
+                    "list[models.DatasetExample]",
+                    dataset_df.to_dict(orient="records"),
+                ),
                 pagination=models.PaginationMetadata(
                     has_more=False,  # Note that all=True
                 ),
@@ -321,7 +327,11 @@ class DatasetsClient:
         )
 
         return models.DatasetsExamplesList200Response(
-
+            # Cast: Pydantic validates and converts dicts to DatasetExample at runtime
+            examples=cast(
+                "list[models.DatasetExample]",
+                dataset_df.to_dict(orient="records"),
+            ),
             pagination=models.PaginationMetadata(
                 has_more=False,  # Note that all=True
             ),
@@ -336,7 +346,7 @@ class DatasetsClient:
         *,
         dataset_id: str,
         dataset_version_id: str = "",
-        examples: list[dict[str, object]] | pd.DataFrame,
+        examples: builtins.list[dict[str, object]] | pd.DataFrame,
     ) -> models.Dataset:
         """Append new examples to an existing dataset.
 
@@ -377,11 +387,14 @@ class DatasetsClient:
         )
 
         data = (
-            examples.to_dict(orient="records")
+            examples.to_dict(orient="records")
             if isinstance(examples, pd.DataFrame)
             else examples
         )
-
+        # Cast: pandas to_dict returns dict[Hashable, Any] but API requires dict[str, Any]
+        body = gen.DatasetsExamplesInsertRequest(
+            examples=cast("list[dict[str, Any]]", data)
+        )
 
         return self._api.datasets_examples_insert(
             dataset_id=dataset_id,
@@ -394,7 +407,7 @@ class DatasetsClient:
         name: str,
         space_id: str,
         examples: pd.DataFrame,
-    ) ->
+    ) -> models.Dataset:
         """Internal method to create a dataset using Flight protocol for large example sets."""
         data = examples.copy()
         # Convert datetime columns to int64 (ms since epoch)
@@ -454,19 +467,19 @@ def _set_default_columns_for_dataset(df: pd.DataFrame) -> pd.DataFrame:
     """Set default values for created_at and updated_at columns if missing or null."""
     current_time = int(time.time() * 1000)
     if "created_at" in df.columns:
-        if df["created_at"].isnull().
+        if df["created_at"].isnull().any():
             df["created_at"].fillna(current_time, inplace=True)
         else:
             df["created_at"] = current_time
 
     if "updated_at" in df.columns:
-        if df["updated_at"].isnull().
+        if df["updated_at"].isnull().any():
             df["updated_at"].fillna(current_time, inplace=True)
         else:
             df["updated_at"] = current_time
 
     if "id" in df.columns:
-        if df["id"].isnull().
+        if df["id"].isnull().any():
             df["id"] = df["id"].apply(
                 lambda x: str(uuid.uuid4()) if pd.isnull(x) else x
             )
```
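The `builtins` import comment is worth unpacking: inside a class body, a method named `list` shadows the built-in, so a plain `list[...]` annotation would resolve to the method. A self-contained sketch of the pattern, using a hypothetical `Client` class rather than the real `DatasetsClient`:

```python
import builtins


class Client:
    def list(self) -> builtins.list[str]:
        """After this def, the name `list` in this class body refers to the method."""
        return ["a", "b"]

    def create(self, examples: builtins.list[dict[str, object]]) -> int:
        # A plain `list[...]` here would resolve to the method above, so the
        # annotation goes through `builtins.list` instead.
        return len(examples)


print(Client().create([{"input": "x"}]))  # -> 1
```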
arize/embeddings/auto_generator.py
CHANGED

```diff
@@ -1,5 +1,7 @@
 """Automatic embedding generation factory for various ML use cases."""
 
+from typing import TypeAlias
+
 import pandas as pd
 
 from arize.embeddings import constants
@@ -24,9 +26,14 @@ from arize.embeddings.nlp_generators import (
 from arize.embeddings.tabular_generators import (
     EmbeddingGeneratorForTabularFeatures,
 )
-from arize.embeddings.usecases import
+from arize.embeddings.usecases import (
+    CVUseCases,
+    NLPUseCases,
+    TabularUseCases,
+    UseCases,
+)
 
-UseCaseLike = str |
+UseCaseLike: TypeAlias = str | NLPUseCases | CVUseCases | TabularUseCases
 
 
 class EmbeddingGenerator:
@@ -49,15 +56,15 @@ class EmbeddingGenerator:
     ) -> BaseEmbeddingGenerator:
         """Create an embedding generator for the specified use case."""
         if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
-            return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
+            return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.NLP.SUMMARIZATION:
-            return EmbeddingGeneratorForNLPSummarization(**kwargs)
+            return EmbeddingGeneratorForNLPSummarization(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.CV.IMAGE_CLASSIFICATION:
-            return EmbeddingGeneratorForCVImageClassification(**kwargs)
+            return EmbeddingGeneratorForCVImageClassification(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.CV.OBJECT_DETECTION:
-            return EmbeddingGeneratorForCVObjectDetection(**kwargs)
+            return EmbeddingGeneratorForCVObjectDetection(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
-            return EmbeddingGeneratorForTabularFeatures(**kwargs)
+            return EmbeddingGeneratorForTabularFeatures(**kwargs)  # type: ignore[arg-type]
         raise ValueError(f"Invalid use case {use_case}")
 
     @classmethod
```
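Annotating `UseCaseLike` with `TypeAlias` tells type checkers that the assignment is a type alias rather than an ordinary module-level value. A hypothetical reduction of the pattern (stand-in enums, not the real ones from arize.embeddings.usecases; Python 3.10+ for the `|` union syntax):

```python
from enum import Enum
from typing import TypeAlias


class NLPUseCases(Enum):  # stand-in for the real enum
    SEQUENCE_CLASSIFICATION = "sequence_classification"


class CVUseCases(Enum):  # stand-in for the real enum
    IMAGE_CLASSIFICATION = "image_classification"


# Explicit TypeAlias: checkers treat this as a type alias, not a runtime union value.
UseCaseLike: TypeAlias = str | NLPUseCases | CVUseCases


def describe(use_case: UseCaseLike) -> str:
    # Accepts either an enum member or its string value.
    return use_case.value if isinstance(use_case, Enum) else use_case


print(describe(NLPUseCases.SEQUENCE_CLASSIFICATION))  # -> sequence_classification
```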
arize/embeddings/base_generators.py
CHANGED

```diff
@@ -14,11 +14,15 @@ try:
     import torch
     from datasets import Dataset
     from PIL import Image
-    from transformers import (
+    from transformers import (
         AutoImageProcessor,
         AutoModel,
         AutoTokenizer,
+        BaseImageProcessor,
         BatchEncoding,
+        BatchFeature,
+        PreTrainedModel,
+        PreTrainedTokenizerBase,
     )
     from transformers.utils import logging as transformer_logging
 except ImportError as e:
@@ -67,7 +71,9 @@ class BaseEmbeddingGenerator(ABC):
         raise
 
     @abstractmethod
-    def generate_embeddings(
+    def generate_embeddings(
+        self, **kwargs: object
+    ) -> pd.Series | tuple[pd.Series, pd.Series]:
         """Generate embeddings for the input data."""
         ...
 
@@ -95,7 +101,7 @@ class BaseEmbeddingGenerator(ABC):
         return self.__model_name
 
     @property
-    def model(self) ->
+    def model(self) -> PreTrainedModel:
         """Return the underlying model instance."""
         return self.__model
 
@@ -183,7 +189,7 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
         tokenizer_max_length: Maximum sequence length for the tokenizer.
         **kwargs: Additional arguments for model initialization.
         """
-        super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+        super().__init__(use_case=use_case, model_name=model_name, **kwargs)  # type: ignore[arg-type]
         self.__tokenizer_max_length = tokenizer_max_length
         # We don't check for the tokenizer's existence since it is coupled with the corresponding model
         # We check the model's existence in `BaseEmbeddingGenerator`
@@ -193,7 +199,7 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
         )
 
     @property
-    def tokenizer(self) ->
+    def tokenizer(self) -> PreTrainedTokenizerBase:
         """Return the tokenizer instance for text processing."""
         return self.__tokenizer
 
@@ -240,7 +246,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
         model_name: Name of the pre-trained vision model.
         **kwargs: Additional arguments for model initialization.
         """
-        super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+        super().__init__(use_case=use_case, model_name=model_name, **kwargs)  # type: ignore[arg-type]
         logger.info("Downloading image processor")
         # We don't check for the image processor's existence since it is coupled with the corresponding model
         # We check the model's existence in `BaseEmbeddingGenerator`
@@ -249,7 +255,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
         )
 
     @property
-    def image_processor(self) ->
+    def image_processor(self) -> BaseImageProcessor:
         """Return the image processor instance for image preprocessing."""
         return self.__image_processor
 
@@ -262,7 +268,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
 
     def preprocess_image(
         self, batch: dict[str, list[str]], local_image_feat_name: str
-    ) ->
+    ) -> BatchFeature:
         """Preprocess a batch of images for model input."""
         return self.image_processor(
             [
@@ -272,7 +278,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
             return_tensors="pt",
         ).to(self.device)
 
-    def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
+    def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:  # type: ignore[override]
         """Obtain embedding vectors from your image data using pre-trained image models.
 
         :param local_image_path_col: a pandas Series containing the local path to the images to
```
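The repeated `# type: ignore[override]` and `# type: ignore[arg-type]` markers exist because the concrete generators narrow the abstract `generate_embeddings(self, **kwargs: object)` signature to specific parameters, which mypy flags as a Liskov substitution violation. A minimal sketch of the trade-off, with hypothetical `Base`/`CVGenerator` stand-ins:

```python
from abc import ABC, abstractmethod

import pandas as pd


class Base(ABC):  # hypothetical stand-in for BaseEmbeddingGenerator
    @abstractmethod
    def generate_embeddings(
        self, **kwargs: object
    ) -> pd.Series | tuple[pd.Series, pd.Series]: ...


class CVGenerator(Base):  # hypothetical stand-in for CVEmbeddingGenerator
    # Narrower parameters and return type than the base: convenient for callers,
    # but incompatible in the strict Liskov sense, hence the targeted suppression.
    def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:  # type: ignore[override]
        return local_image_path_col
```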
arize/embeddings/cv_generators.py
CHANGED

```diff
@@ -25,7 +25,7 @@ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.CV.IMAGE_CLASSIFICATION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
 
 
@@ -46,5 +46,5 @@ class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.CV.OBJECT_DETECTION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
```
arize/embeddings/nlp_generators.py
CHANGED

```diff
@@ -39,10 +39,10 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
 
-    def generate_embeddings(
+    def generate_embeddings(  # type: ignore[override]
         self,
         text_col: pd.Series,
         class_label_col: pd.Series | None = None,
@@ -65,10 +65,10 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
         if class_label_col is not None:
             if not isinstance(class_label_col, pd.Series):
                 raise TypeError("class_label_col must be a pandas Series")
-
+            temp_df = pd.concat(
                 {"text": text_col, "class_label": class_label_col}, axis=1
             )
-            prepared_text_col =
+            prepared_text_col = temp_df.apply(
                 lambda row: f" The classification label is {row['class_label']}. {row['text']}",
                 axis=1,
             )
@@ -83,8 +83,8 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
             batched=True,
             batch_size=self.batch_size,
         )
-
-        return
+        result_df: pd.DataFrame = ds.to_pandas()
+        return result_df["embedding_vector"]
 
 
 class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
@@ -104,10 +104,10 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.NLP.SUMMARIZATION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
 
-    def generate_embeddings(
+    def generate_embeddings(  # type: ignore[override]
         self,
         text_col: pd.Series,
     ) -> pd.Series:
```