arize 8.0.0b2__py3-none-any.whl → 8.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +8 -1
- arize/_exporter/client.py +18 -17
- arize/_exporter/parsers/tracing_data_parser.py +9 -4
- arize/_exporter/validation.py +1 -1
- arize/_flight/client.py +33 -13
- arize/_lazy.py +37 -2
- arize/client.py +61 -35
- arize/config.py +168 -14
- arize/constants/config.py +1 -0
- arize/datasets/client.py +32 -19
- arize/embeddings/auto_generator.py +14 -7
- arize/embeddings/base_generators.py +15 -9
- arize/embeddings/cv_generators.py +2 -2
- arize/embeddings/nlp_generators.py +8 -8
- arize/embeddings/tabular_generators.py +5 -5
- arize/exceptions/config.py +22 -0
- arize/exceptions/parameters.py +1 -1
- arize/exceptions/values.py +8 -5
- arize/experiments/__init__.py +4 -0
- arize/experiments/client.py +17 -11
- arize/experiments/evaluators/base.py +6 -3
- arize/experiments/evaluators/executors.py +6 -4
- arize/experiments/evaluators/rate_limiters.py +3 -1
- arize/experiments/evaluators/types.py +7 -5
- arize/experiments/evaluators/utils.py +7 -5
- arize/experiments/functions.py +111 -48
- arize/experiments/tracing.py +4 -1
- arize/experiments/types.py +31 -26
- arize/logging.py +53 -32
- arize/ml/batch_validation/validator.py +82 -70
- arize/ml/bounded_executor.py +25 -6
- arize/ml/casting.py +45 -27
- arize/ml/client.py +35 -28
- arize/ml/proto.py +16 -17
- arize/ml/stream_validation.py +63 -25
- arize/ml/surrogate_explainer/mimic.py +15 -7
- arize/ml/types.py +26 -12
- arize/pre_releases.py +7 -6
- arize/py.typed +0 -0
- arize/regions.py +10 -10
- arize/spans/client.py +113 -21
- arize/spans/conversion.py +7 -5
- arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
- arize/spans/validation/annotations/value_validation.py +11 -14
- arize/spans/validation/common/dataframe_form_validation.py +1 -1
- arize/spans/validation/common/value_validation.py +10 -13
- arize/spans/validation/evals/value_validation.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +1 -1
- arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
- arize/spans/validation/metadata/value_validation.py +23 -1
- arize/utils/arrow.py +37 -1
- arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
- arize/utils/proto.py +0 -1
- arize/utils/types.py +6 -6
- arize/version.py +1 -1
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/METADATA +10 -2
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/RECORD +60 -58
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b2.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
|
@@ -64,10 +64,10 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
|
|
|
64
64
|
super().__init__(
|
|
65
65
|
use_case=UseCases.STRUCTURED.TABULAR_EMBEDDINGS,
|
|
66
66
|
model_name=model_name,
|
|
67
|
-
**kwargs,
|
|
67
|
+
**kwargs, # type: ignore[arg-type]
|
|
68
68
|
)
|
|
69
69
|
|
|
70
|
-
def generate_embeddings(
|
|
70
|
+
def generate_embeddings( # type: ignore[override]
|
|
71
71
|
self,
|
|
72
72
|
df: pd.DataFrame,
|
|
73
73
|
selected_columns: list[str],
|
|
@@ -145,11 +145,11 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
|
|
|
145
145
|
batch_size=self.batch_size,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
|
-
|
|
148
|
+
result_df: pd.DataFrame = ds.to_pandas()
|
|
149
149
|
if return_prompt_col:
|
|
150
|
-
return
|
|
150
|
+
return result_df["embedding_vector"], prompts
|
|
151
151
|
|
|
152
|
-
return
|
|
152
|
+
return result_df["embedding_vector"]
|
|
153
153
|
|
|
154
154
|
@staticmethod
|
|
155
155
|
def __prompt_fn(row: pd.DataFrame, columns: list[str]) -> str:
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Configuration validation exceptions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MultipleEndpointOverridesError(Exception):
    """Raised when multiple endpoint override options are provided.

    Only one of the following can be specified: region, single_host/single_port,
    or base_domain.
    """

    #: Message used when the caller does not supply one.
    DEFAULT_MESSAGE = (
        "Only one of the following can be specified: region, "
        "single_host/single_port, or base_domain."
    )

    def __init__(self, message: str | None = None) -> None:
        """Initialize the exception with an optional custom message.

        Args:
            message: Custom error message. When omitted (``None``), a default
                message listing the mutually exclusive options is used. An
                explicitly passed empty string is kept as-is.
        """
        if message is None:
            message = self.DEFAULT_MESSAGE
        # Forward to Exception.__init__ so ``args`` always carries the message,
        # even when constructed as ``MultipleEndpointOverridesError(message=...)``.
        # Without this, ``args`` would be empty in that case and unpickling
        # (e.g. when the error crosses a process boundary) would re-create the
        # class with no arguments.
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        """Return the error message."""
        return self.message
|
arize/exceptions/parameters.py
CHANGED
arize/exceptions/values.py
CHANGED
|
@@ -533,14 +533,15 @@ class InvalidMultiClassClassNameLength(ValidationError):
|
|
|
533
533
|
err_msg = ""
|
|
534
534
|
for col, class_names in self.invalid_col_class_name.items():
|
|
535
535
|
# limit to 10
|
|
536
|
-
|
|
536
|
+
class_names_list = (
|
|
537
537
|
list(class_names)[:10]
|
|
538
538
|
if len(class_names) > 10
|
|
539
539
|
else list(class_names)
|
|
540
540
|
)
|
|
541
541
|
err_msg += (
|
|
542
|
-
f"Found some invalid class names: {log_a_list(
|
|
543
|
-
f" names must have at least one character and
|
|
542
|
+
f"Found some invalid class names: {log_a_list(class_names_list, 'and')} "
|
|
543
|
+
f"in the {col} column. Class names must have at least one character and "
|
|
544
|
+
f"less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
|
|
544
545
|
)
|
|
545
546
|
return err_msg
|
|
546
547
|
|
|
@@ -565,9 +566,11 @@ class InvalidMultiClassPredScoreValue(ValidationError):
|
|
|
565
566
|
err_msg = ""
|
|
566
567
|
for col, scores in self.invalid_col_class_scores.items():
|
|
567
568
|
# limit to 10
|
|
568
|
-
|
|
569
|
+
scores_list = (
|
|
570
|
+
list(scores)[:10] if len(scores) > 10 else list(scores)
|
|
571
|
+
)
|
|
569
572
|
err_msg += (
|
|
570
|
-
f"Found some invalid scores: {log_a_list(
|
|
573
|
+
f"Found some invalid scores: {log_a_list(scores_list, 'and')} in the {col} column that was "
|
|
571
574
|
"invalid. All scores (values in dictionary) must be between 0 and 1, inclusive. \n"
|
|
572
575
|
)
|
|
573
576
|
return err_msg
|
arize/experiments/__init__.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
"""Experiment tracking and evaluation functionality for the Arize SDK."""
|
|
2
2
|
|
|
3
|
+
from arize.experiments.evaluators.base import (
|
|
4
|
+
Evaluator,
|
|
5
|
+
)
|
|
3
6
|
from arize.experiments.evaluators.types import (
|
|
4
7
|
EvaluationResult,
|
|
5
8
|
EvaluationResultFieldNames,
|
|
@@ -9,5 +12,6 @@ from arize.experiments.types import ExperimentTaskFieldNames
|
|
|
9
12
|
__all__ = [
|
|
10
13
|
"EvaluationResult",
|
|
11
14
|
"EvaluationResultFieldNames",
|
|
15
|
+
"Evaluator",
|
|
12
16
|
"ExperimentTaskFieldNames",
|
|
13
17
|
]
|
arize/experiments/client.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING, cast
|
|
7
7
|
|
|
8
8
|
import opentelemetry.sdk.trace as trace_sdk
|
|
9
9
|
import pandas as pd
|
|
@@ -36,6 +36,10 @@ from arize.utils.openinference_conversion import (
|
|
|
36
36
|
from arize.utils.size import get_payload_size_mb
|
|
37
37
|
|
|
38
38
|
if TYPE_CHECKING:
|
|
39
|
+
# builtins is needed to use builtins.list in type annotations because
|
|
40
|
+
# the class has a list() method that shadows the built-in list type
|
|
41
|
+
import builtins
|
|
42
|
+
|
|
39
43
|
from opentelemetry.trace import Tracer
|
|
40
44
|
|
|
41
45
|
from arize._generated.api_client.api_client import ApiClient
|
|
@@ -116,7 +120,7 @@ class ExperimentsClient:
|
|
|
116
120
|
*,
|
|
117
121
|
name: str,
|
|
118
122
|
dataset_id: str,
|
|
119
|
-
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
123
|
+
experiment_runs: builtins.list[dict[str, object]] | pd.DataFrame,
|
|
120
124
|
task_fields: ExperimentTaskFieldNames,
|
|
121
125
|
evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
|
|
122
126
|
force_http: bool = False,
|
|
@@ -181,7 +185,7 @@ class ExperimentsClient:
|
|
|
181
185
|
body = gen.ExperimentsCreateRequest(
|
|
182
186
|
name=name,
|
|
183
187
|
dataset_id=dataset_id,
|
|
184
|
-
experiment_runs=data,
|
|
188
|
+
experiment_runs=cast("list[gen.ExperimentRunCreate]", data),
|
|
185
189
|
)
|
|
186
190
|
return self._api.experiments_create(experiments_create_request=body)
|
|
187
191
|
|
|
@@ -303,7 +307,10 @@ class ExperimentsClient:
|
|
|
303
307
|
)
|
|
304
308
|
if experiment_df is not None:
|
|
305
309
|
return models.ExperimentsRunsList200Response(
|
|
306
|
-
|
|
310
|
+
experiment_runs=cast(
|
|
311
|
+
"list[models.ExperimentRun]",
|
|
312
|
+
experiment_df.to_dict(orient="records"),
|
|
313
|
+
),
|
|
307
314
|
pagination=models.PaginationMetadata(
|
|
308
315
|
has_more=False, # Note that all=True
|
|
309
316
|
),
|
|
@@ -343,7 +350,10 @@ class ExperimentsClient:
|
|
|
343
350
|
)
|
|
344
351
|
|
|
345
352
|
return models.ExperimentsRunsList200Response(
|
|
346
|
-
|
|
353
|
+
experiment_runs=cast(
|
|
354
|
+
"list[models.ExperimentRun]",
|
|
355
|
+
experiment_df.to_dict(orient="records"),
|
|
356
|
+
),
|
|
347
357
|
pagination=models.PaginationMetadata(
|
|
348
358
|
has_more=False, # Note that all=True
|
|
349
359
|
),
|
|
@@ -553,9 +563,7 @@ class ExperimentsClient:
|
|
|
553
563
|
logger.error(msg)
|
|
554
564
|
raise RuntimeError(msg)
|
|
555
565
|
|
|
556
|
-
experiment = self.get(
|
|
557
|
-
experiment_id=str(post_resp.experiment_id) # type: ignore
|
|
558
|
-
)
|
|
566
|
+
experiment = self.get(experiment_id=str(post_resp.experiment_id))
|
|
559
567
|
return experiment, output_df
|
|
560
568
|
|
|
561
569
|
def _create_experiment_via_flight(
|
|
@@ -636,9 +644,7 @@ class ExperimentsClient:
|
|
|
636
644
|
logger.error(msg)
|
|
637
645
|
raise RuntimeError(msg)
|
|
638
646
|
|
|
639
|
-
return self.get(
|
|
640
|
-
experiment_id=str(post_resp.experiment_id) # type: ignore
|
|
641
|
-
)
|
|
647
|
+
return self.get(experiment_id=str(post_resp.experiment_id))
|
|
642
648
|
|
|
643
649
|
|
|
644
650
|
def _get_tracer_resource(
|
|
@@ -7,7 +7,7 @@ import inspect
|
|
|
7
7
|
from abc import ABC
|
|
8
8
|
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
|
9
9
|
from types import MappingProxyType
|
|
10
|
-
from typing import TYPE_CHECKING
|
|
10
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
11
11
|
|
|
12
12
|
from arize.experiments.evaluators.types import (
|
|
13
13
|
AnnotatorKind,
|
|
@@ -162,7 +162,9 @@ class Evaluator(ABC):
|
|
|
162
162
|
f"`evaluate()` method should be callable, got {type(evaluate)}"
|
|
163
163
|
)
|
|
164
164
|
# need to remove the first param, i.e. `self`
|
|
165
|
-
_validate_sig(
|
|
165
|
+
_validate_sig(
|
|
166
|
+
functools.partial(evaluate, cast("Any", None)), "evaluate"
|
|
167
|
+
)
|
|
166
168
|
return
|
|
167
169
|
if async_evaluate := super_cls.__dict__.get(
|
|
168
170
|
Evaluator.async_evaluate.__name__
|
|
@@ -175,7 +177,8 @@ class Evaluator(ABC):
|
|
|
175
177
|
)
|
|
176
178
|
# need to remove the first param, i.e. `self`
|
|
177
179
|
_validate_sig(
|
|
178
|
-
functools.partial(async_evaluate, None),
|
|
180
|
+
functools.partial(async_evaluate, cast("Any", None)),
|
|
181
|
+
"async_evaluate",
|
|
179
182
|
)
|
|
180
183
|
return
|
|
181
184
|
raise ValueError(
|
|
@@ -77,7 +77,7 @@ class Executor(Protocol):
|
|
|
77
77
|
|
|
78
78
|
def run(
|
|
79
79
|
self, inputs: Sequence[Any]
|
|
80
|
-
) -> tuple[list[object], list[ExecutionDetails]]:
|
|
80
|
+
) -> tuple[list[Unset | object], list[ExecutionDetails]]:
|
|
81
81
|
"""Execute the generation function on all inputs and return outputs with execution details."""
|
|
82
82
|
...
|
|
83
83
|
|
|
@@ -255,7 +255,7 @@ class AsyncExecutor(Executor):
|
|
|
255
255
|
|
|
256
256
|
async def execute(
|
|
257
257
|
self, inputs: Sequence[Any]
|
|
258
|
-
) -> tuple[list[object], list[ExecutionDetails]]:
|
|
258
|
+
) -> tuple[list[Unset | object], list[ExecutionDetails]]:
|
|
259
259
|
"""Execute all inputs asynchronously using producer-consumer pattern."""
|
|
260
260
|
termination_event = asyncio.Event()
|
|
261
261
|
|
|
@@ -332,7 +332,7 @@ class AsyncExecutor(Executor):
|
|
|
332
332
|
|
|
333
333
|
def run(
|
|
334
334
|
self, inputs: Sequence[Any]
|
|
335
|
-
) -> tuple[list[object], list[ExecutionDetails]]:
|
|
335
|
+
) -> tuple[list[Unset | object], list[ExecutionDetails]]:
|
|
336
336
|
"""Execute all inputs asynchronously and return outputs with execution details."""
|
|
337
337
|
return asyncio.run(self.execute(inputs))
|
|
338
338
|
|
|
@@ -406,7 +406,9 @@ class SyncExecutor(Executor):
|
|
|
406
406
|
else:
|
|
407
407
|
yield
|
|
408
408
|
|
|
409
|
-
def run(
|
|
409
|
+
def run(
|
|
410
|
+
self, inputs: Sequence[Any]
|
|
411
|
+
) -> tuple[list[Unset | object], list[ExecutionDetails]]:
|
|
410
412
|
"""Execute all inputs synchronously and return outputs with execution details."""
|
|
411
413
|
with self._executor_signal_handling(self.termination_signal):
|
|
412
414
|
outputs = [self.fallback_return_value] * len(inputs)
|
|
@@ -276,7 +276,9 @@ class RateLimiter:
|
|
|
276
276
|
"""Apply rate limiting to an asynchronous function."""
|
|
277
277
|
|
|
278
278
|
@wraps(fn)
|
|
279
|
-
async def wrapper(
|
|
279
|
+
async def wrapper(
|
|
280
|
+
*args: ParameterSpec.args, **kwargs: ParameterSpec.kwargs
|
|
281
|
+
) -> GenericType:
|
|
280
282
|
self._initialize_async_primitives()
|
|
281
283
|
if self._rate_limit_handling_lock is None or not isinstance(
|
|
282
284
|
self._rate_limit_handling_lock, asyncio.Lock
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from enum import Enum
|
|
7
|
-
from typing import TYPE_CHECKING
|
|
7
|
+
from typing import TYPE_CHECKING, cast
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from collections.abc import Mapping
|
|
@@ -60,10 +60,12 @@ class EvaluationResult:
|
|
|
60
60
|
if not obj:
|
|
61
61
|
return None
|
|
62
62
|
return cls(
|
|
63
|
-
score=obj.get("score"),
|
|
64
|
-
label=obj.get("label"),
|
|
65
|
-
explanation=obj.get("explanation"),
|
|
66
|
-
metadata=
|
|
63
|
+
score=cast("float | None", obj.get("score")),
|
|
64
|
+
label=cast("str | None", obj.get("label")),
|
|
65
|
+
explanation=cast("str | None", obj.get("explanation")),
|
|
66
|
+
metadata=cast(
|
|
67
|
+
"Mapping[str, JSONSerializable]", obj.get("metadata") or {}
|
|
68
|
+
),
|
|
67
69
|
)
|
|
68
70
|
|
|
69
71
|
def __post_init__(self) -> None:
|
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
4
|
import inspect
|
|
5
|
-
from collections.abc import Callable
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from tqdm.auto import tqdm
|
|
9
9
|
|
|
@@ -154,10 +154,10 @@ def _wrap_coroutine_evaluation_function(
|
|
|
154
154
|
name: str,
|
|
155
155
|
sig: inspect.Signature,
|
|
156
156
|
convert_to_score: Callable[[object], EvaluationResult],
|
|
157
|
-
) -> Callable[[Callable[..., object]], "Evaluator"]:
|
|
157
|
+
) -> Callable[[Callable[..., Awaitable[object]]], "Evaluator"]:
|
|
158
158
|
from ..evaluators.base import Evaluator
|
|
159
159
|
|
|
160
|
-
def wrapper(func: Callable[..., object]) -> "Evaluator":
|
|
160
|
+
def wrapper(func: Callable[..., Awaitable[object]]) -> "Evaluator":
|
|
161
161
|
class AsyncEvaluator(Evaluator):
|
|
162
162
|
def __init__(self) -> None:
|
|
163
163
|
self._name = name
|
|
@@ -224,9 +224,11 @@ def _default_eval_scorer(result: object) -> EvaluationResult:
|
|
|
224
224
|
raise ValueError(f"Unsupported evaluation result type: {type(result)}")
|
|
225
225
|
|
|
226
226
|
|
|
227
|
-
def printif(condition: bool, *args:
|
|
227
|
+
def printif(condition: bool, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
|
|
228
228
|
"""Print to tqdm output if the condition is true.
|
|
229
229
|
|
|
230
|
+
Note: *args/**kwargs use Any for proper pass-through to tqdm.write().
|
|
231
|
+
|
|
230
232
|
Args:
|
|
231
233
|
condition: Whether to print the message.
|
|
232
234
|
*args: Positional arguments to pass to tqdm.write.
|