arize 8.0.0b2__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arize/__init__.py +8 -1
  2. arize/_exporter/client.py +18 -17
  3. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  4. arize/_exporter/validation.py +1 -1
  5. arize/_flight/client.py +33 -13
  6. arize/_lazy.py +37 -2
  7. arize/client.py +61 -35
  8. arize/config.py +168 -14
  9. arize/constants/config.py +1 -0
  10. arize/datasets/client.py +32 -19
  11. arize/embeddings/auto_generator.py +14 -7
  12. arize/embeddings/base_generators.py +15 -9
  13. arize/embeddings/cv_generators.py +2 -2
  14. arize/embeddings/nlp_generators.py +8 -8
  15. arize/embeddings/tabular_generators.py +5 -5
  16. arize/exceptions/config.py +22 -0
  17. arize/exceptions/parameters.py +1 -1
  18. arize/exceptions/values.py +8 -5
  19. arize/experiments/__init__.py +4 -0
  20. arize/experiments/client.py +17 -11
  21. arize/experiments/evaluators/base.py +6 -3
  22. arize/experiments/evaluators/executors.py +6 -4
  23. arize/experiments/evaluators/rate_limiters.py +3 -1
  24. arize/experiments/evaluators/types.py +7 -5
  25. arize/experiments/evaluators/utils.py +7 -5
  26. arize/experiments/functions.py +111 -48
  27. arize/experiments/tracing.py +4 -1
  28. arize/experiments/types.py +31 -26
  29. arize/logging.py +53 -32
  30. arize/ml/batch_validation/validator.py +82 -70
  31. arize/ml/bounded_executor.py +25 -6
  32. arize/ml/casting.py +45 -27
  33. arize/ml/client.py +35 -28
  34. arize/ml/proto.py +16 -17
  35. arize/ml/stream_validation.py +63 -25
  36. arize/ml/surrogate_explainer/mimic.py +15 -7
  37. arize/ml/types.py +26 -12
  38. arize/pre_releases.py +7 -6
  39. arize/py.typed +0 -0
  40. arize/regions.py +10 -10
  41. arize/spans/client.py +113 -21
  42. arize/spans/conversion.py +7 -5
  43. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  44. arize/spans/validation/annotations/value_validation.py +11 -14
  45. arize/spans/validation/common/dataframe_form_validation.py +1 -1
  46. arize/spans/validation/common/value_validation.py +10 -13
  47. arize/spans/validation/evals/value_validation.py +1 -1
  48. arize/spans/validation/metadata/argument_validation.py +1 -1
  49. arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
  50. arize/spans/validation/metadata/value_validation.py +23 -1
  51. arize/utils/arrow.py +37 -1
  52. arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
  53. arize/utils/proto.py +0 -1
  54. arize/utils/types.py +6 -6
  55. arize/version.py +1 -1
  56. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/METADATA +18 -3
  57. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/RECORD +60 -58
  58. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/WHEEL +0 -0
  59. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/LICENSE +0 -0
  60. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/NOTICE +0 -0
@@ -64,10 +64,10 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
64
64
  super().__init__(
65
65
  use_case=UseCases.STRUCTURED.TABULAR_EMBEDDINGS,
66
66
  model_name=model_name,
67
- **kwargs,
67
+ **kwargs, # type: ignore[arg-type]
68
68
  )
69
69
 
70
- def generate_embeddings(
70
+ def generate_embeddings( # type: ignore[override]
71
71
  self,
72
72
  df: pd.DataFrame,
73
73
  selected_columns: list[str],
@@ -145,11 +145,11 @@ class EmbeddingGeneratorForTabularFeatures(NLPEmbeddingGenerator):
145
145
  batch_size=self.batch_size,
146
146
  )
147
147
 
148
- df: pd.DataFrame = ds.to_pandas()
148
+ result_df: pd.DataFrame = ds.to_pandas()
149
149
  if return_prompt_col:
150
- return df["embedding_vector"], prompts
150
+ return result_df["embedding_vector"], prompts
151
151
 
152
- return df["embedding_vector"]
152
+ return result_df["embedding_vector"]
153
153
 
154
154
  @staticmethod
155
155
  def __prompt_fn(row: pd.DataFrame, columns: list[str]) -> str:
@@ -0,0 +1,22 @@
1
+ """Configuration validation exceptions."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
+ class MultipleEndpointOverridesError(Exception):
7
+ """Raised when multiple endpoint override options are provided.
8
+
9
+ Only one of the following can be specified: region, single_host/single_port, or base_domain.
10
+ """
11
+
12
+ def __init__(self, message: str) -> None:
13
+ """Initialize the exception with an optional custom message.
14
+
15
+ Args:
16
+ message: Custom error message, or empty string.
17
+ """
18
+ self.message = message
19
+
20
+ def __str__(self) -> str:
21
+ """Return the error message."""
22
+ return self.message
@@ -61,7 +61,7 @@ class InvalidValueType(Exception):
61
61
  def __init__(
62
62
  self,
63
63
  value_name: str,
64
- value: bool | int | float | str,
64
+ value: object,
65
65
  correct_type: str,
66
66
  ) -> None:
67
67
  """Initialize the exception with value type validation context.
@@ -533,14 +533,15 @@ class InvalidMultiClassClassNameLength(ValidationError):
533
533
  err_msg = ""
534
534
  for col, class_names in self.invalid_col_class_name.items():
535
535
  # limit to 10
536
- class_names = (
536
+ class_names_list = (
537
537
  list(class_names)[:10]
538
538
  if len(class_names) > 10
539
539
  else list(class_names)
540
540
  )
541
541
  err_msg += (
542
- f"Found some invalid class names: {log_a_list(class_names, 'and')} in the {col} column. Class"
543
- f" names must have at least one character and less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
542
+ f"Found some invalid class names: {log_a_list(class_names_list, 'and')} "
543
+ f"in the {col} column. Class names must have at least one character and "
544
+ f"less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
544
545
  )
545
546
  return err_msg
546
547
 
@@ -565,9 +566,11 @@ class InvalidMultiClassPredScoreValue(ValidationError):
565
566
  err_msg = ""
566
567
  for col, scores in self.invalid_col_class_scores.items():
567
568
  # limit to 10
568
- scores = list(scores)[:10] if len(scores) > 10 else list(scores)
569
+ scores_list = (
570
+ list(scores)[:10] if len(scores) > 10 else list(scores)
571
+ )
569
572
  err_msg += (
570
- f"Found some invalid scores: {log_a_list(scores, 'and')} in the {col} column that was "
573
+ f"Found some invalid scores: {log_a_list(scores_list, 'and')} in the {col} column that was "
571
574
  "invalid. All scores (values in dictionary) must be between 0 and 1, inclusive. \n"
572
575
  )
573
576
  return err_msg
@@ -1,5 +1,8 @@
1
1
  """Experiment tracking and evaluation functionality for the Arize SDK."""
2
2
 
3
+ from arize.experiments.evaluators.base import (
4
+ Evaluator,
5
+ )
3
6
  from arize.experiments.evaluators.types import (
4
7
  EvaluationResult,
5
8
  EvaluationResultFieldNames,
@@ -9,5 +12,6 @@ from arize.experiments.types import ExperimentTaskFieldNames
9
12
  __all__ = [
10
13
  "EvaluationResult",
11
14
  "EvaluationResultFieldNames",
15
+ "Evaluator",
12
16
  "ExperimentTaskFieldNames",
13
17
  ]
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING, cast
7
7
 
8
8
  import opentelemetry.sdk.trace as trace_sdk
9
9
  import pandas as pd
@@ -36,6 +36,10 @@ from arize.utils.openinference_conversion import (
36
36
  from arize.utils.size import get_payload_size_mb
37
37
 
38
38
  if TYPE_CHECKING:
39
+ # builtins is needed to use builtins.list in type annotations because
40
+ # the class has a list() method that shadows the built-in list type
41
+ import builtins
42
+
39
43
  from opentelemetry.trace import Tracer
40
44
 
41
45
  from arize._generated.api_client.api_client import ApiClient
@@ -116,7 +120,7 @@ class ExperimentsClient:
116
120
  *,
117
121
  name: str,
118
122
  dataset_id: str,
119
- experiment_runs: list[dict[str, object]] | pd.DataFrame,
123
+ experiment_runs: builtins.list[dict[str, object]] | pd.DataFrame,
120
124
  task_fields: ExperimentTaskFieldNames,
121
125
  evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
122
126
  force_http: bool = False,
@@ -181,7 +185,7 @@ class ExperimentsClient:
181
185
  body = gen.ExperimentsCreateRequest(
182
186
  name=name,
183
187
  dataset_id=dataset_id,
184
- experiment_runs=data, # type: ignore
188
+ experiment_runs=cast("list[gen.ExperimentRunCreate]", data),
185
189
  )
186
190
  return self._api.experiments_create(experiments_create_request=body)
187
191
 
@@ -303,7 +307,10 @@ class ExperimentsClient:
303
307
  )
304
308
  if experiment_df is not None:
305
309
  return models.ExperimentsRunsList200Response(
306
- experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
310
+ experiment_runs=cast(
311
+ "list[models.ExperimentRun]",
312
+ experiment_df.to_dict(orient="records"),
313
+ ),
307
314
  pagination=models.PaginationMetadata(
308
315
  has_more=False, # Note that all=True
309
316
  ),
@@ -343,7 +350,10 @@ class ExperimentsClient:
343
350
  )
344
351
 
345
352
  return models.ExperimentsRunsList200Response(
346
- experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
353
+ experiment_runs=cast(
354
+ "list[models.ExperimentRun]",
355
+ experiment_df.to_dict(orient="records"),
356
+ ),
347
357
  pagination=models.PaginationMetadata(
348
358
  has_more=False, # Note that all=True
349
359
  ),
@@ -553,9 +563,7 @@ class ExperimentsClient:
553
563
  logger.error(msg)
554
564
  raise RuntimeError(msg)
555
565
 
556
- experiment = self.get(
557
- experiment_id=str(post_resp.experiment_id) # type: ignore
558
- )
566
+ experiment = self.get(experiment_id=str(post_resp.experiment_id))
559
567
  return experiment, output_df
560
568
 
561
569
  def _create_experiment_via_flight(
@@ -636,9 +644,7 @@ class ExperimentsClient:
636
644
  logger.error(msg)
637
645
  raise RuntimeError(msg)
638
646
 
639
- return self.get(
640
- experiment_id=str(post_resp.experiment_id) # type: ignore
641
- )
647
+ return self.get(experiment_id=str(post_resp.experiment_id))
642
648
 
643
649
 
644
650
  def _get_tracer_resource(
@@ -7,7 +7,7 @@ import inspect
7
7
  from abc import ABC
8
8
  from collections.abc import Awaitable, Callable, Mapping, Sequence
9
9
  from types import MappingProxyType
10
- from typing import TYPE_CHECKING
10
+ from typing import TYPE_CHECKING, Any, cast
11
11
 
12
12
  from arize.experiments.evaluators.types import (
13
13
  AnnotatorKind,
@@ -162,7 +162,9 @@ class Evaluator(ABC):
162
162
  f"`evaluate()` method should be callable, got {type(evaluate)}"
163
163
  )
164
164
  # need to remove the first param, i.e. `self`
165
- _validate_sig(functools.partial(evaluate, None), "evaluate")
165
+ _validate_sig(
166
+ functools.partial(evaluate, cast("Any", None)), "evaluate"
167
+ )
166
168
  return
167
169
  if async_evaluate := super_cls.__dict__.get(
168
170
  Evaluator.async_evaluate.__name__
@@ -175,7 +177,8 @@ class Evaluator(ABC):
175
177
  )
176
178
  # need to remove the first param, i.e. `self`
177
179
  _validate_sig(
178
- functools.partial(async_evaluate, None), "async_evaluate"
180
+ functools.partial(async_evaluate, cast("Any", None)),
181
+ "async_evaluate",
179
182
  )
180
183
  return
181
184
  raise ValueError(
@@ -77,7 +77,7 @@ class Executor(Protocol):
77
77
 
78
78
  def run(
79
79
  self, inputs: Sequence[Any]
80
- ) -> tuple[list[object], list[ExecutionDetails]]:
80
+ ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
81
81
  """Execute the generation function on all inputs and return outputs with execution details."""
82
82
  ...
83
83
 
@@ -255,7 +255,7 @@ class AsyncExecutor(Executor):
255
255
 
256
256
  async def execute(
257
257
  self, inputs: Sequence[Any]
258
- ) -> tuple[list[object], list[ExecutionDetails]]:
258
+ ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
259
259
  """Execute all inputs asynchronously using producer-consumer pattern."""
260
260
  termination_event = asyncio.Event()
261
261
 
@@ -332,7 +332,7 @@ class AsyncExecutor(Executor):
332
332
 
333
333
  def run(
334
334
  self, inputs: Sequence[Any]
335
- ) -> tuple[list[object], list[ExecutionDetails]]:
335
+ ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
336
336
  """Execute all inputs asynchronously and return outputs with execution details."""
337
337
  return asyncio.run(self.execute(inputs))
338
338
 
@@ -406,7 +406,9 @@ class SyncExecutor(Executor):
406
406
  else:
407
407
  yield
408
408
 
409
- def run(self, inputs: Sequence[Any]) -> tuple[list[object], list[object]]:
409
+ def run(
410
+ self, inputs: Sequence[Any]
411
+ ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
410
412
  """Execute all inputs synchronously and return outputs with execution details."""
411
413
  with self._executor_signal_handling(self.termination_signal):
412
414
  outputs = [self.fallback_return_value] * len(inputs)
@@ -276,7 +276,9 @@ class RateLimiter:
276
276
  """Apply rate limiting to an asynchronous function."""
277
277
 
278
278
  @wraps(fn)
279
- async def wrapper(*args: object, **kwargs: object) -> GenericType:
279
+ async def wrapper(
280
+ *args: ParameterSpec.args, **kwargs: ParameterSpec.kwargs
281
+ ) -> GenericType:
280
282
  self._initialize_async_primitives()
281
283
  if self._rate_limit_handling_lock is None or not isinstance(
282
284
  self._rate_limit_handling_lock, asyncio.Lock
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from dataclasses import dataclass, field
6
6
  from enum import Enum
7
- from typing import TYPE_CHECKING
7
+ from typing import TYPE_CHECKING, cast
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from collections.abc import Mapping
@@ -60,10 +60,12 @@ class EvaluationResult:
60
60
  if not obj:
61
61
  return None
62
62
  return cls(
63
- score=obj.get("score"),
64
- label=obj.get("label"),
65
- explanation=obj.get("explanation"),
66
- metadata=obj.get("metadata") or {},
63
+ score=cast("float | None", obj.get("score")),
64
+ label=cast("str | None", obj.get("label")),
65
+ explanation=cast("str | None", obj.get("explanation")),
66
+ metadata=cast(
67
+ "Mapping[str, JSONSerializable]", obj.get("metadata") or {}
68
+ ),
67
69
  )
68
70
 
69
71
  def __post_init__(self) -> None:
@@ -2,8 +2,8 @@
2
2
 
3
3
  import functools
4
4
  import inspect
5
- from collections.abc import Callable
6
- from typing import TYPE_CHECKING
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from tqdm.auto import tqdm
9
9
 
@@ -154,10 +154,10 @@ def _wrap_coroutine_evaluation_function(
154
154
  name: str,
155
155
  sig: inspect.Signature,
156
156
  convert_to_score: Callable[[object], EvaluationResult],
157
- ) -> Callable[[Callable[..., object]], "Evaluator"]:
157
+ ) -> Callable[[Callable[..., Awaitable[object]]], "Evaluator"]:
158
158
  from ..evaluators.base import Evaluator
159
159
 
160
- def wrapper(func: Callable[..., object]) -> "Evaluator":
160
+ def wrapper(func: Callable[..., Awaitable[object]]) -> "Evaluator":
161
161
  class AsyncEvaluator(Evaluator):
162
162
  def __init__(self) -> None:
163
163
  self._name = name
@@ -224,9 +224,11 @@ def _default_eval_scorer(result: object) -> EvaluationResult:
224
224
  raise ValueError(f"Unsupported evaluation result type: {type(result)}")
225
225
 
226
226
 
227
- def printif(condition: bool, *args: object, **kwargs: object) -> None:
227
+ def printif(condition: bool, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
228
228
  """Print to tqdm output if the condition is true.
229
229
 
230
+ Note: *args/**kwargs use Any for proper pass-through to tqdm.write().
231
+
230
232
  Args:
231
233
  condition: Whether to print the message.
232
234
  *args: Positional arguments to pass to tqdm.write.