arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. arize/__init__.py +9 -2
  2. arize/_client_factory.py +50 -0
  3. arize/_exporter/client.py +18 -17
  4. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  5. arize/_exporter/validation.py +1 -1
  6. arize/_flight/client.py +37 -17
  7. arize/_generated/api_client/api/datasets_api.py +6 -6
  8. arize/_generated/api_client/api/experiments_api.py +6 -6
  9. arize/_generated/api_client/api/projects_api.py +3 -3
  10. arize/_lazy.py +61 -10
  11. arize/client.py +66 -50
  12. arize/config.py +175 -48
  13. arize/constants/config.py +1 -0
  14. arize/constants/ml.py +9 -16
  15. arize/constants/spans.py +5 -10
  16. arize/datasets/client.py +45 -28
  17. arize/datasets/errors.py +1 -1
  18. arize/datasets/validation.py +2 -2
  19. arize/embeddings/auto_generator.py +16 -9
  20. arize/embeddings/base_generators.py +15 -9
  21. arize/embeddings/cv_generators.py +2 -2
  22. arize/embeddings/errors.py +2 -2
  23. arize/embeddings/nlp_generators.py +8 -8
  24. arize/embeddings/tabular_generators.py +6 -6
  25. arize/exceptions/base.py +0 -52
  26. arize/exceptions/config.py +22 -0
  27. arize/exceptions/parameters.py +1 -330
  28. arize/exceptions/values.py +8 -5
  29. arize/experiments/__init__.py +4 -0
  30. arize/experiments/client.py +31 -18
  31. arize/experiments/evaluators/base.py +12 -9
  32. arize/experiments/evaluators/executors.py +16 -7
  33. arize/experiments/evaluators/rate_limiters.py +3 -1
  34. arize/experiments/evaluators/types.py +9 -7
  35. arize/experiments/evaluators/utils.py +7 -5
  36. arize/experiments/functions.py +128 -58
  37. arize/experiments/tracing.py +4 -1
  38. arize/experiments/types.py +34 -31
  39. arize/logging.py +54 -33
  40. arize/ml/batch_validation/errors.py +10 -1004
  41. arize/ml/batch_validation/validator.py +351 -291
  42. arize/ml/bounded_executor.py +25 -6
  43. arize/ml/casting.py +51 -33
  44. arize/ml/client.py +43 -35
  45. arize/ml/proto.py +21 -22
  46. arize/ml/stream_validation.py +64 -27
  47. arize/ml/surrogate_explainer/mimic.py +18 -10
  48. arize/ml/types.py +27 -67
  49. arize/pre_releases.py +10 -6
  50. arize/projects/client.py +9 -4
  51. arize/py.typed +0 -0
  52. arize/regions.py +11 -11
  53. arize/spans/client.py +125 -31
  54. arize/spans/columns.py +32 -36
  55. arize/spans/conversion.py +12 -11
  56. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  57. arize/spans/validation/annotations/value_validation.py +11 -14
  58. arize/spans/validation/common/argument_validation.py +3 -3
  59. arize/spans/validation/common/dataframe_form_validation.py +7 -7
  60. arize/spans/validation/common/value_validation.py +11 -14
  61. arize/spans/validation/evals/dataframe_form_validation.py +4 -4
  62. arize/spans/validation/evals/evals_validation.py +6 -6
  63. arize/spans/validation/evals/value_validation.py +1 -1
  64. arize/spans/validation/metadata/argument_validation.py +1 -1
  65. arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
  66. arize/spans/validation/metadata/value_validation.py +23 -1
  67. arize/spans/validation/spans/dataframe_form_validation.py +2 -2
  68. arize/spans/validation/spans/spans_validation.py +6 -6
  69. arize/utils/arrow.py +38 -2
  70. arize/utils/cache.py +2 -2
  71. arize/utils/dataframe.py +4 -4
  72. arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
  73. arize/utils/openinference_conversion.py +10 -10
  74. arize/utils/proto.py +0 -1
  75. arize/utils/types.py +6 -6
  76. arize/version.py +1 -1
  77. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
  78. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
  79. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
  80. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
  81. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
--- a/arize/experiments/evaluators/executors.py
+++ b/arize/experiments/evaluators/executors.py
@@ -77,7 +77,7 @@ class Executor(Protocol):
 
     def run(
         self, inputs: Sequence[Any]
-    ) -> tuple[list[object], list[ExecutionDetails]]:
+    ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
         """Execute the generation function on all inputs and return outputs with execution details."""
         ...
 
@@ -94,7 +94,7 @@ class AsyncExecutor(Executor):
 
         concurrency (int, optional): The number of concurrent consumers. Defaults to 3.
 
-        tqdm_bar_format (Optional[str], optional): The format string for the progress bar.
+        tqdm_bar_format (str | :obj:`None`, optional): The format string for the progress bar.
            Defaults to None.
 
        max_retries (int, optional): The maximum number of times to retry on exceptions.
@@ -119,6 +119,7 @@ class AsyncExecutor(Executor):
         exit_on_error: bool = True,
         fallback_return_value: Unset | object = _unset,
         termination_signal: signal.Signals = signal.SIGINT,
+        timeout: int = 120,
     ) -> None:
         """Initialize the async executor with configuration parameters.
 
@@ -130,6 +131,7 @@ class AsyncExecutor(Executor):
             exit_on_error: Whether to exit on first error.
             fallback_return_value: Value to return when execution fails.
             termination_signal: Signal to handle for graceful termination.
+            timeout: Timeout for each task in seconds.
         """
         self.generate = generation_fn
         self.fallback_return_value = fallback_return_value
@@ -139,6 +141,7 @@ class AsyncExecutor(Executor):
         self.exit_on_error = exit_on_error
         self.base_priority = 0
         self.termination_signal = termination_signal
+        self.timeout = timeout
 
     async def producer(
         self,
@@ -195,7 +198,7 @@ class AsyncExecutor(Executor):
             )
             done, _pending = await asyncio.wait(
                 [generate_task, termination_event_watcher],
-                timeout=120,
+                timeout=self.timeout,
                 return_when=asyncio.FIRST_COMPLETED,
             )
 
@@ -252,7 +255,7 @@ class AsyncExecutor(Executor):
 
     async def execute(
         self, inputs: Sequence[Any]
-    ) -> tuple[list[object], list[ExecutionDetails]]:
+    ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
         """Execute all inputs asynchronously using producer-consumer pattern."""
         termination_event = asyncio.Event()
 
@@ -329,7 +332,7 @@ class AsyncExecutor(Executor):
 
     def run(
         self, inputs: Sequence[Any]
-    ) -> tuple[list[object], list[ExecutionDetails]]:
+    ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
         """Execute all inputs asynchronously and return outputs with execution details."""
         return asyncio.run(self.execute(inputs))
 
@@ -341,7 +344,7 @@ class SyncExecutor(Executor):
         generation_fn (Callable[[object], Any]): The generation function that takes an input and
             returns an output.
 
-        tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
+        tqdm_bar_format (str | :obj:`None`, optional): The format string for the progress bar. Defaults
             to None.
 
         max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
@@ -403,7 +406,9 @@ class SyncExecutor(Executor):
         else:
             yield
 
-    def run(self, inputs: Sequence[Any]) -> tuple[list[object], list[object]]:
+    def run(
+        self, inputs: Sequence[Any]
+    ) -> tuple[list[Unset | object], list[ExecutionDetails]]:
         """Execute all inputs synchronously and return outputs with execution details."""
         with self._executor_signal_handling(self.termination_signal):
             outputs = [self.fallback_return_value] * len(inputs)
@@ -460,6 +465,7 @@ def get_executor_on_sync_context(
     max_retries: int = 10,
     exit_on_error: bool = True,
     fallback_return_value: Unset | object = _unset,
+    timeout: int = 120,
 ) -> Executor:
     """Get an appropriate executor based on the current threading context.
 
@@ -475,6 +481,7 @@ def get_executor_on_sync_context(
         max_retries: Maximum number of retry attempts. Defaults to 10.
         exit_on_error: Whether to exit on first error. Defaults to True.
         fallback_return_value: Value to return on failure. Defaults to unset.
+        timeout: Timeout for each task in seconds. Defaults to 120.
 
     Returns:
         An Executor instance configured for the current context.
@@ -513,6 +520,7 @@ def get_executor_on_sync_context(
             max_retries=max_retries,
             exit_on_error=exit_on_error,
             fallback_return_value=fallback_return_value,
+            timeout=timeout,
         )
         logger.warning(
             "🐌!! If running inside a notebook, patching the event loop with "
@@ -533,6 +541,7 @@ def get_executor_on_sync_context(
             max_retries=max_retries,
             exit_on_error=exit_on_error,
             fallback_return_value=fallback_return_value,
+            timeout=timeout,
         )
 
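These executors.py hunks make the previously hard-coded 120-second asyncio.wait deadline configurable end to end. A minimal sketch of the new knob against the b4 signature shown above, assuming hypothetical my_sync_task/my_async_task callables and an inputs sequence:

    # Hedged sketch; in b1 the per-task timeout was fixed at 120 seconds
    # inside AsyncExecutor, with no way to override it from callers.
    executor = get_executor_on_sync_context(
        sync_fn=my_sync_task,
        async_fn=my_async_task,
        max_retries=0,
        exit_on_error=False,
        fallback_return_value=None,
        timeout=300,  # seconds per task; defaults to 120
    )
    outputs, details = executor.run(inputs)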
--- a/arize/experiments/evaluators/rate_limiters.py
+++ b/arize/experiments/evaluators/rate_limiters.py
@@ -276,7 +276,9 @@ class RateLimiter:
         """Apply rate limiting to an asynchronous function."""
 
         @wraps(fn)
-        async def wrapper(*args: object, **kwargs: object) -> GenericType:
+        async def wrapper(
+            *args: ParameterSpec.args, **kwargs: ParameterSpec.kwargs
+        ) -> GenericType:
             self._initialize_async_primitives()
             if self._rate_limit_handling_lock is None or not isinstance(
                 self._rate_limit_handling_lock, asyncio.Lock
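The new ParameterSpec.args/ParameterSpec.kwargs annotations indicate the decorator is now typed with a ParamSpec so the wrapped coroutine's signature survives type checking; ParameterSpec and GenericType appear to be the module's own ParamSpec/TypeVar names. A standalone sketch of the standard pattern, with illustrative P/R names rather than Arize's exact code:

    from collections.abc import Awaitable, Callable
    from functools import wraps
    from typing import ParamSpec, TypeVar

    P = ParamSpec("P")
    R = TypeVar("R")

    def rate_limited(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        @wraps(fn)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # acquire a rate-limit token here, then delegate unchanged
            return await fn(*args, **kwargs)
        return wrapper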
--- a/arize/experiments/evaluators/types.py
+++ b/arize/experiments/evaluators/types.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
@@ -60,10 +60,12 @@ class EvaluationResult:
        if not obj:
            return None
        return cls(
-            score=obj.get("score"),
-            label=obj.get("label"),
-            explanation=obj.get("explanation"),
-            metadata=obj.get("metadata") or {},
+            score=cast("float | None", obj.get("score")),
+            label=cast("str | None", obj.get("label")),
+            explanation=cast("str | None", obj.get("explanation")),
+            metadata=cast(
+                "Mapping[str, JSONSerializable]", obj.get("metadata") or {}
+            ),
        )
 
    def __post_init__(self) -> None:
@@ -94,14 +96,14 @@ EvaluatorOutput = (
 
 @dataclass
 class EvaluationResultFieldNames:
-    """Column names for mapping evaluation results in a DataFrame.
+    """Column names for mapping evaluation results in a :class:`pandas.DataFrame`.
 
     Args:
         score: Optional name of column containing evaluation scores
         label: Optional name of column containing evaluation labels
         explanation: Optional name of column containing evaluation explanations
         metadata: Optional mapping of metadata keys to column names. If a column name
-            is None or empty string, the metadata key will be used as the column name.
+            is :obj:`None` or empty string, the metadata key will be used as the column name.
 
     Examples:
         >>> # Basic usage with score and label columns
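Worth noting: typing.cast is a runtime no-op, so these types.py changes only document the expected shapes for static checkers; behavior is unchanged. A self-contained illustration with made-up values:

    from typing import cast

    obj: dict[str, object] = {"score": 0.87, "label": "relevant"}
    score = cast("float | None", obj.get("score"))  # checker now sees float | None
    assert score == 0.87  # cast returned the value untouched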
--- a/arize/experiments/evaluators/utils.py
+++ b/arize/experiments/evaluators/utils.py
@@ -2,8 +2,8 @@
 
 import functools
 import inspect
-from collections.abc import Callable
-from typing import TYPE_CHECKING
+from collections.abc import Awaitable, Callable
+from typing import TYPE_CHECKING, Any
 
 from tqdm.auto import tqdm
 
@@ -154,10 +154,10 @@ def _wrap_coroutine_evaluation_function(
     name: str,
     sig: inspect.Signature,
     convert_to_score: Callable[[object], EvaluationResult],
-) -> Callable[[Callable[..., object]], "Evaluator"]:
+) -> Callable[[Callable[..., Awaitable[object]]], "Evaluator"]:
     from ..evaluators.base import Evaluator
 
-    def wrapper(func: Callable[..., object]) -> "Evaluator":
+    def wrapper(func: Callable[..., Awaitable[object]]) -> "Evaluator":
         class AsyncEvaluator(Evaluator):
             def __init__(self) -> None:
                 self._name = name
@@ -224,9 +224,11 @@ def _default_eval_scorer(result: object) -> EvaluationResult:
     raise ValueError(f"Unsupported evaluation result type: {type(result)}")
 
 
-def printif(condition: bool, *args: object, **kwargs: object) -> None:
+def printif(condition: bool, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
     """Print to tqdm output if the condition is true.
 
+    Note: *args/**kwargs use Any for proper pass-through to tqdm.write().
+
     Args:
         condition: Whether to print the message.
         *args: Positional arguments to pass to tqdm.write.
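The switch to Callable[..., Awaitable[object]] captures that calling an async def function yields a coroutine to be awaited, not a plain value; the old Callable[..., object] accepted the same functions but lost that fact. A small self-contained illustration:

    import asyncio
    from collections.abc import Awaitable, Callable

    async def grade(answer: str) -> float:
        return 1.0 if answer else 0.0

    # Type-checks: calling `grade` returns a coroutine, which is an Awaitable.
    fn: Callable[..., Awaitable[object]] = grade
    print(asyncio.run(grade("yes")))  # 1.0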
--- a/arize/experiments/functions.py
+++ b/arize/experiments/functions.py
@@ -7,7 +7,14 @@ import json
 import logging
 import traceback
 from binascii import hexlify
-from collections.abc import Awaitable, Callable, Mapping, Sequence
+from collections.abc import (
+    Awaitable,
+    Callable,
+    Coroutine,
+    Iterable,
+    Mapping,
+    Sequence,
+)
 from contextlib import ExitStack
 from copy import deepcopy
 from datetime import date, datetime, time, timedelta, timezone
@@ -34,7 +41,7 @@ from openinference.semconv.trace import (
 )
 from opentelemetry.context import Context
 from opentelemetry.sdk.resources import Resource
-from opentelemetry.trace import Status, StatusCode, Tracer
+from opentelemetry.trace import NoOpTracer, Status, StatusCode, Tracer
 
 if TYPE_CHECKING:
     from opentelemetry.sdk.trace import Span
@@ -48,6 +55,7 @@ from arize.experiments.evaluators.types import (
     EvaluationResult,
     EvaluationResultFieldNames,
     EvaluatorName,
+    JSONSerializable,
 )
 from arize.experiments.evaluators.utils import create_evaluator
 from arize.experiments.tracing import capture_spans, flatten
@@ -64,6 +72,9 @@ RateLimitErrors: TypeAlias = type[BaseException] | Sequence[type[BaseException]]
 
 logger = logging.getLogger(__name__)
 
+# Module-level singleton for no-op tracing
+_NOOP_TRACER = NoOpTracer()
+
 
 def run_experiment(
     experiment_name: str,
@@ -76,23 +87,25 @@ def run_experiment(
     evaluators: Evaluators | None = None,
     concurrency: int = 3,
     exit_on_error: bool = False,
+    timeout: int = 120,
 ) -> pd.DataFrame:
     """Run an experiment on a dataset.
 
     Args:
         experiment_name (str): The name for the experiment.
         experiment_id (str): The ID for the experiment.
-        dataset (pd.DataFrame): The dataset to run the experiment on.
+        dataset (:class:`pandas.DataFrame`): The dataset to run the experiment on.
         task (ExperimentTask): The task to be executed on the dataset.
         tracer (Tracer): Tracer for tracing the experiment.
         resource (Resource): The resource for tracing the experiment.
-        rate_limit_errors (Optional[RateLimitErrors]): Optional rate limit errors.
-        evaluators (Optional[Evaluators]): Optional evaluators to assess the task.
+        rate_limit_errors (RateLimitErrors | :obj:`None`): Optional rate limit errors.
+        evaluators (Evaluators | :obj:`None`): Optional evaluators to assess the task.
         concurrency (int): The number of concurrent tasks to run. Default is 3.
         exit_on_error (bool): Whether to exit on error. Default is False.
+        timeout (int): The timeout for each task execution in seconds. Default is 120.
 
     Returns:
-        pd.DataFrame: The results of the experiment.
+        :class:`pandas.DataFrame`: The results of the experiment.
     """
     task_signature = inspect.signature(task)
     _validate_task_signature(task_signature)
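A hedged sketch of a caller using the b4 signature shown above; the dataset, task, tracer, and resource objects are hypothetical placeholders:

    results_df = run_experiment(
        experiment_name="my-experiment",
        experiment_id="exp-123",
        dataset=dataset_df,
        task=my_task,
        tracer=tracer,
        resource=resource,
        concurrency=3,
        timeout=300,  # new in b4: per-task timeout in seconds, default 120
    )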
@@ -114,11 +127,12 @@
         error: BaseException | None = None
         status = Status(StatusCode.OK)
         with ExitStack() as stack:
-            span: Span = stack.enter_context(
+            # Type ignore: OpenTelemetry interface vs implementation type mismatch
+            span: Span = stack.enter_context(  # type: ignore[assignment]
                 cm=tracer.start_as_current_span(
                     name=root_span_name, context=Context()
                 )
-            )  # type:ignore
+            )
             stack.enter_context(capture_spans(resource))
             span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
             try:
@@ -144,9 +158,12 @@
                 raise TypeError(sync_error_message)
             output = _output
 
-            output = jsonify(output)
+            # Type ignore: jsonify returns object but runtime result is JSONSerializable
+            output = jsonify(output)  # type: ignore[assignment]
             if example.input:
-                span.set_attribute(INPUT_VALUE, example.input)  # type: ignore
+                # OpenTelemetry type hints are restrictive, but Arize's tracing layer
+                # accepts JSON-serializable structures which are auto-serialized
+                span.set_attribute(INPUT_VALUE, example.input)  # type: ignore[arg-type]
             else:
                 span.set_attribute(
                     INPUT_VALUE,
@@ -185,9 +202,9 @@
                 else datetime.now(tz=timezone.utc)
             ),
             dataset_example_id=example.id,
-            output=output,  # type:ignore
+            output=output,
             error=repr(error) if error else None,
-            trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore
+            trace_id=_str_trace_id(span.get_span_context().trace_id),
         )
 
     async def async_run_experiment(example: Example) -> ExperimentRun:
@@ -195,11 +212,12 @@
         error: BaseException | None = None
         status = Status(StatusCode.OK)
         with ExitStack() as stack:
-            span: Span = stack.enter_context(
+            # Type ignore: OpenTelemetry interface vs implementation type mismatch
+            span: Span = stack.enter_context(  # type: ignore[assignment]
                 cm=tracer.start_as_current_span(
                     name=root_span_name, context=Context()
                 )
-            )  # type:ignore
+            )
             stack.enter_context(capture_spans(resource))
             span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
             try:
@@ -218,9 +236,12 @@
                 )
                 error = exc
                 _print_experiment_error(exc, example_id=example.id, kind="task")
-            output = jsonify(output)
+            # Type ignore: jsonify returns object but runtime result is JSONSerializable
+            output = jsonify(output)  # type: ignore[assignment]
             if example.input:
-                span.set_attribute(INPUT_VALUE, example.input)  # type: ignore
+                # OpenTelemetry type hints are restrictive, but Arize's tracing layer
+                # accepts JSON-serializable structures which are auto-serialized
+                span.set_attribute(INPUT_VALUE, example.input)  # type: ignore[arg-type]
             else:
                 span.set_attribute(
                     INPUT_VALUE,
@@ -259,14 +280,14 @@
                 else datetime.now(tz=timezone.utc)
             ),
             dataset_example_id=example.id,
-            output=output,  # type: ignore
+            output=output,
             error=repr(error) if error else None,
-            trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore
+            trace_id=_str_trace_id(span.get_span_context().trace_id),
         )
 
     _errors: tuple[type[BaseException], ...]
     if not isinstance(rate_limit_errors, Sequence):
-        _errors = (rate_limit_errors,)  # type: ignore
+        _errors = (rate_limit_errors,) if rate_limit_errors is not None else ()
     else:
         _errors = tuple(filter(None, rate_limit_errors))
     rate_limiters = [RateLimiter(rate_limit_error=rle) for rle in _errors]
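The last hunk above also fixes a subtle edge case: in b1, passing rate_limit_errors=None produced _errors = (None,) and hence a RateLimiter(rate_limit_error=None); b4 normalizes None to an empty tuple. A standalone sketch of the new behavior:

    from collections.abc import Sequence

    def normalize(
        errs: type[BaseException] | Sequence[type[BaseException]] | None,
    ) -> tuple[type[BaseException], ...]:
        if not isinstance(errs, Sequence):
            return (errs,) if errs is not None else ()  # b1 returned (None,) here
        return tuple(filter(None, errs))

    assert normalize(None) == ()
    assert normalize(TimeoutError) == (TimeoutError,)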
@@ -282,30 +303,43 @@
     )
 
     executor = get_executor_on_sync_context(
-        sync_fn=rate_limited_sync_run_experiment,
-        async_fn=rate_limited_async_run_experiment,
+        sync_fn=cast(
+            "Callable[[object], Any]", rate_limited_sync_run_experiment
+        ),
+        async_fn=cast(
+            "Callable[[object], Coroutine[Any, Any, Any]]",
+            rate_limited_async_run_experiment,
+        ),
         max_retries=0,
         exit_on_error=exit_on_error,
         fallback_return_value=None,
         tqdm_bar_format=get_tqdm_progress_bar_formatter("running tasks"),
         concurrency=concurrency,
+        timeout=timeout,
     )
 
     runs, _ = executor.run(examples)
-    task_summary = _TaskSummary.from_task_runs(len(dataset), runs)
+    task_summary = _TaskSummary.from_task_runs(
+        len(dataset), cast("list[ExperimentRun | None]", runs)
+    )
 
     if exit_on_error and (None in runs):
         # When exit_on_error is True, the result of a failed task execution is None
         # If any task execution failed, raise an error to exit early
         raise RuntimeError("An error occurred during execution of tasks.")
 
+    # Filter out None values before accessing attributes
+    runs_filtered = [
+        r for r in cast("list[ExperimentRun | None]", runs) if r is not None
+    ]
+
     out_df = pd.DataFrame()
-    out_df["id"] = [run.id for run in runs]
-    out_df["example_id"] = [run.dataset_example_id for run in runs]
-    out_df["result"] = [run.output for run in runs]
-    out_df["result.trace.id"] = [run.trace_id for run in runs]
+    out_df["id"] = [run.id for run in runs_filtered]
+    out_df["example_id"] = [run.dataset_example_id for run in runs_filtered]
+    out_df["result"] = [run.output for run in runs_filtered]  # type: ignore[assignment]
+    out_df["result.trace.id"] = [run.trace_id for run in runs_filtered]
     out_df["result.trace.timestamp"] = [
-        int(run.start_time.timestamp() * 1e3) for run in runs
+        int(run.start_time.timestamp() * 1e3) for run in runs_filtered
     ]
     out_df.set_index("id", inplace=True, drop=False)
     logger.info(f"✅ Task runs completed.\n{task_summary}")
@@ -314,13 +348,14 @@
         eval_results = evaluate_experiment(
             experiment_name=experiment_name,
             examples=examples,
-            experiment_results=runs,
+            experiment_results=cast("Sequence[ExperimentRun]", runs),
             evaluators=evaluators,
             rate_limit_errors=rate_limit_errors,
             concurrency=concurrency,
             tracer=tracer,
             resource=resource,
             exit_on_error=exit_on_error,
+            timeout=timeout,
         )
 
         if exit_on_error and (None in eval_results):
@@ -329,7 +364,7 @@
             )
 
         # group evaluation results by name
-        eval_results_by_name = {}
+        eval_results_by_name: dict[str, list[ExperimentEvaluationRun]] = {}
         for r in eval_results:
             if r is None:
                 continue
@@ -351,7 +386,8 @@
             }
 
             for attr, getter in eval_data.items():
-                out_df[f"eval.{eval_name}.{attr}"] = out_df.index.map(
+                # Type ignore: pandas DataFrame column assignment type is overly restrictive
+                out_df[f"eval.{eval_name}.{attr}"] = out_df.index.map(  # type: ignore[assignment]
                    {r.experiment_run_id: getter(r) for r in eval_res}
                )
            out_df = _add_metadata_to_output_df(out_df, eval_res, eval_name)
@@ -368,9 +404,10 @@ def evaluate_experiment(
     evaluators: Evaluators | None = None,
     rate_limit_errors: RateLimitErrors | None = None,
     concurrency: int = 3,
-    tracer: Tracer | None = None,
+    tracer: Tracer = _NOOP_TRACER,
     resource: Resource | None = None,
     exit_on_error: bool = False,
+    timeout: int = 120,
 ) -> list[ExperimentEvaluationRun]:
     """Evaluate the results of an experiment using the provided evaluators.
 
@@ -379,11 +416,12 @@
         examples (Sequence[Example]): The examples to evaluate.
         experiment_results (Sequence[ExperimentRun]): The results of the experiment.
         evaluators (Evaluators): The evaluators to use for assessment.
-        rate_limit_errors (Optional[RateLimitErrors]): Optional rate limit errors.
+        rate_limit_errors (RateLimitErrors | :obj:`None`): Optional rate limit errors.
         concurrency (int): The number of concurrent tasks to run. Default is 3.
-        tracer (Optional[Tracer]): Optional tracer for tracing the evaluation.
-        resource (Optional[Resource]): Optional resource for the evaluation.
+        tracer (Tracer): Tracer for tracing the evaluation. Defaults to NoOpTracer().
+        resource (Resource | :obj:`None`): Optional resource for the evaluation.
         exit_on_error (bool): Whether to exit on error. Default is False.
+        timeout (int): The timeout for each evaluation in seconds. Default is 120.
 
     Returns:
         List[ExperimentEvaluationRun]: The evaluation results.
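Because tracer now defaults to the module-level _NOOP_TRACER, callers can omit tracing entirely. A hedged sketch with hypothetical examples/runs/evaluator objects:

    eval_runs = evaluate_experiment(
        experiment_name="my-experiment",
        examples=examples,
        experiment_results=runs,
        evaluators=[my_evaluator],
        timeout=300,  # new in b4, default 120
        # tracer omitted: spans go to a NoOpTracer, so the None-handling
        # and type: ignore workarounds removed below are no longer needed
    )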
@@ -419,12 +457,16 @@
         status = Status(StatusCode.OK)
         root_span_name = f"Evaluation: {evaluator.name}"
         with ExitStack() as stack:
-            span: Span = stack.enter_context(
-                tracer.start_as_current_span(  # type:ignore
-                    name=root_span_name, context=Context()
-                )
+            span: Span = cast(
+                "Span",
+                stack.enter_context(
+                    tracer.start_as_current_span(
+                        name=root_span_name, context=Context()
+                    )
+                ),
             )
-            stack.enter_context(capture_spans(resource))  # type:ignore
+            if resource is not None:
+                stack.enter_context(capture_spans(resource))
             span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
             try:
                 result = evaluator.evaluate(
@@ -450,7 +492,15 @@
                 )
             if result:
                 span.set_attributes(
-                    dict(flatten(jsonify(result), recurse_on_sequence=True))
+                    dict(
+                        flatten(
+                            cast(
+                                "Mapping[str, Any] | Iterable[Any]",
+                                jsonify(result),
+                            ),
+                            recurse_on_sequence=True,
+                        )
+                    )
                 )
             span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
             span.set_status(status)
@@ -467,7 +517,7 @@
             annotator_kind=evaluator.kind,
             error=repr(error) if error else None,
             result=result,
-            trace_id=_str_trace_id(span.get_span_context().trace_id),  # type:ignore
+            trace_id=_str_trace_id(span.get_span_context().trace_id),
         )
 
     async def async_eval_run(
@@ -479,12 +529,16 @@
         status = Status(StatusCode.OK)
         root_span_name = f"Evaluation: {evaluator.name}"
         with ExitStack() as stack:
-            span: Span = stack.enter_context(
-                tracer.start_as_current_span(  # type:ignore
-                    name=root_span_name, context=Context()
-                )
+            span: Span = cast(
+                "Span",
+                stack.enter_context(
+                    tracer.start_as_current_span(
+                        name=root_span_name, context=Context()
+                    )
+                ),
             )
-            stack.enter_context(capture_spans(resource))  # type:ignore
+            if resource is not None:
+                stack.enter_context(capture_spans(resource))
             span.set_attribute(METADATA, json.dumps(md, ensure_ascii=False))
             try:
                 result = await evaluator.async_evaluate(
@@ -510,7 +564,15 @@
                 )
             if result:
                 span.set_attributes(
-                    dict(flatten(jsonify(result), recurse_on_sequence=True))
+                    dict(
+                        flatten(
+                            cast(
+                                "Mapping[str, Any] | Iterable[Any]",
+                                jsonify(result),
+                            ),
+                            recurse_on_sequence=True,
+                        )
+                    )
                 )
             span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
             span.set_status(status)
@@ -526,7 +588,7 @@
             annotator_kind=evaluator.kind,
             error=repr(error) if error else None,
             result=result,
-            trace_id=_str_trace_id(span.get_span_context().trace_id),  # type:ignore
+            trace_id=_str_trace_id(span.get_span_context().trace_id),
         )
 
     _errors: tuple[type[BaseException], ...]
@@ -547,8 +609,11 @@
     )
 
     executor = get_executor_on_sync_context(
-        rate_limited_sync_evaluate_run,
-        rate_limited_async_evaluate_run,
+        cast("Callable[[object], Any]", rate_limited_sync_evaluate_run),
+        cast(
+            "Callable[[object], Coroutine[Any, Any, Any]]",
+            rate_limited_async_evaluate_run,
+        ),
         max_retries=0,
         exit_on_error=exit_on_error,
         fallback_return_value=None,
@@ -556,16 +621,18 @@
             "running experiment evaluations"
         ),
         concurrency=concurrency,
+        timeout=timeout,
     )
     eval_runs, _ = executor.run(evaluation_input)
-    return eval_runs
+    # Cast: run returns list[Unset | object], but sync/async_eval_run guarantee ExperimentEvaluationRun
+    return cast("list[ExperimentEvaluationRun]", eval_runs)
 
 
 def _add_metadata_to_output_df(
     output_df: pd.DataFrame,
     eval_runs: list[ExperimentEvaluationRun],
     evaluator_name: str,
-) -> object:
+) -> pd.DataFrame:
     for eval_run in eval_runs:
         if eval_run.result is None:
             continue
@@ -596,7 +663,9 @@ def _dataframe_to_examples(dataset: pd.DataFrame) -> list[Example]:
     examples = []
 
     for _, row in dataset.iterrows():
-        example = Example(dataset_row=row.to_dict())
+        example = Example(
+            dataset_row=cast("Mapping[str, JSONSerializable]", row.to_dict())
+        )
         examples.append(example)
     return examples
 
@@ -763,7 +832,8 @@ def get_result_attr(r: object, attr: str, default: object = None) -> object:
     Returns:
         The attribute value if found, otherwise the default value.
     """
-    return getattr(r.result, attr, default) if r.result else default
+    # Type ignore: r typed as object but expected to have result attribute at runtime
+    return getattr(r.result, attr, default) if r.result else default  # type: ignore[attr-defined]
 
 
 def transform_to_experiment_format(
@@ -771,16 +841,16 @@ def transform_to_experiment_format(
     task_fields: ExperimentTaskFieldNames,
     evaluator_fields: dict[str, EvaluationResultFieldNames] | None = None,
 ) -> pd.DataFrame:
-    """Transform a DataFrame to match the format returned by run_experiment().
+    """Transform a :class:`pandas.DataFrame` to match the format returned by run_experiment().
 
     Args:
-        experiment_runs: Input list of dictionaries or DataFrame containing experiment results
+        experiment_runs: Input list of dictionaries or :class:`pandas.DataFrame` containing experiment results
         task_fields: Field name mapping for task results
         evaluator_fields: Dictionary mapping evaluator names (str)
             to their field name mappings (EvaluationResultFieldNames)
 
     Returns:
-        DataFrame in the format matching run_experiment() output
+        :class:`pandas.DataFrame` in the format matching run_experiment() output
     """
     data = (
         experiment_runs
@@ -822,7 +892,7 @@
     evaluator_name: str,
     column_names: EvaluationResultFieldNames,
 ) -> None:
-    """Helper function to add evaluator columns to output DataFrame."""
+    """Helper function to add evaluator columns to output :class:`pandas.DataFrame`."""
     # Add score if specified
     if column_names.score and column_names.score in input_df.columns:
         output_df[f"eval.{evaluator_name}.score"] = input_df[column_names.score]
--- a/arize/experiments/tracing.py
+++ b/arize/experiments/tracing.py
@@ -57,7 +57,10 @@ _ACTIVE_MODIFIER: ContextVar[SpanModifier | None] = ContextVar(
 
 
 def override_span(
-    init: Callable[..., None], span: ReadableSpan, args: object, kwargs: object
+    init: Callable[..., None],
+    span: ReadableSpan,
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
 ) -> None:
     """Override span initialization to apply active span modifiers.
 
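The tightened override_span annotations spell out the shapes a *args/**kwargs intercept actually produces. A hypothetical sketch of such an install site, not shown in this diff, assuming the wrapper is applied by monkey-patching ReadableSpan.__init__:

    original_init = ReadableSpan.__init__

    def patched_init(self: ReadableSpan, *args: object, **kwargs: object) -> None:
        # Positional args arrive as a tuple and keywords as a dict, matching
        # the new args/kwargs annotations on override_span.
        override_span(original_init, self, args, kwargs)

    ReadableSpan.__init__ = patched_init  # type: ignore[method-assign]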