arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,3 +1,5 @@
1
+ """Experiment utility functions for task execution and annotation."""
2
+
1
3
  import dataclasses
2
4
  import functools
3
5
  import inspect
@@ -5,6 +7,7 @@ import json
5
7
  import logging
6
8
  import traceback
7
9
  from binascii import hexlify
10
+ from collections.abc import Awaitable, Callable, Mapping, Sequence
8
11
  from contextlib import ExitStack
9
12
  from copy import deepcopy
10
13
  from datetime import date, datetime, time, timedelta, timezone
@@ -12,16 +15,10 @@ from enum import Enum
12
15
  from itertools import product
13
16
  from pathlib import Path
14
17
  from typing import (
18
+ TYPE_CHECKING,
15
19
  Any,
16
- Awaitable,
17
- Callable,
18
- Dict,
19
- List,
20
20
  Literal,
21
- Mapping,
22
- Sequence,
23
- Tuple,
24
- Type,
21
+ TypeAlias,
25
22
  Union,
26
23
  cast,
27
24
  get_args,
@@ -37,9 +34,10 @@ from openinference.semconv.trace import (
37
34
  )
38
35
  from opentelemetry.context import Context
39
36
  from opentelemetry.sdk.resources import Resource
40
- from opentelemetry.sdk.trace import Span
41
37
  from opentelemetry.trace import Status, StatusCode, Tracer
42
- from typing_extensions import TypeAlias
38
+
39
+ if TYPE_CHECKING:
40
+ from opentelemetry.sdk.trace import Span
43
41
 
44
42
  from arize.experiments.evaluators.base import Evaluator, Evaluators
45
43
  from arize.experiments.evaluators.executors import (
@@ -62,9 +60,7 @@ from arize.experiments.types import (
62
60
  _TaskSummary,
63
61
  )
64
62
 
65
- RateLimitErrors: TypeAlias = Union[
66
- Type[BaseException], Sequence[Type[BaseException]]
67
- ]
63
+ RateLimitErrors: TypeAlias = type[BaseException] | Sequence[type[BaseException]]
68
64
 
69
65
  logger = logging.getLogger(__name__)
70
66
 
@@ -81,8 +77,8 @@ def run_experiment(
81
77
  concurrency: int = 3,
82
78
  exit_on_error: bool = False,
83
79
  ) -> pd.DataFrame:
84
- """
85
- Run an experiment on a dataset.
80
+ """Run an experiment on a dataset.
81
+
86
82
  Args:
87
83
  experiment_name (str): The name for the experiment.
88
84
  experiment_id (str): The ID for the experiment.
@@ -94,6 +90,7 @@ def run_experiment(
94
90
  evaluators (Optional[Evaluators]): Optional evaluators to assess the task.
95
91
  concurrency (int): The number of concurrent tasks to run. Default is 3.
96
92
  exit_on_error (bool): Whether to exit on error. Default is False.
93
+
97
94
  Returns:
98
95
  pd.DataFrame: The results of the experiment.
99
96
  """
@@ -127,25 +124,25 @@ def run_experiment(
127
124
  try:
128
125
  bound_task_args = _bind_task_signature(task_signature, example)
129
126
  _output = task(*bound_task_args.args, **bound_task_args.kwargs)
130
- if isinstance(_output, Awaitable):
131
- sync_error_message = (
132
- "Task is async and cannot be run within an existing event loop. "
133
- "Consider the following options:\n\n"
134
- "1. Pass in a synchronous task callable.\n"
135
- "2. Use `nest_asyncio.apply()` to allow nesting event loops."
136
- )
137
- raise RuntimeError(sync_error_message)
138
- else:
139
- output = _output
140
127
  except BaseException as exc:
141
128
  if exit_on_error:
142
- raise exc
129
+ raise
143
130
  span.record_exception(exc)
144
131
  status = Status(
145
132
  StatusCode.ERROR, f"{type(exc).__name__}: {exc}"
146
133
  )
147
134
  error = exc
148
135
  _print_experiment_error(exc, example_id=example.id, kind="task")
136
+ else:
137
+ if isinstance(_output, Awaitable):
138
+ sync_error_message = (
139
+ "Task is async and cannot be run within an existing event loop. "
140
+ "Consider the following options:\n\n"
141
+ "1. Pass in a synchronous task callable.\n"
142
+ "2. Use `nest_asyncio.apply()` to allow nesting event loops."
143
+ )
144
+ raise TypeError(sync_error_message)
145
+ output = _output
149
146
 
150
147
  output = jsonify(output)
151
148
  if example.input:
@@ -171,25 +168,27 @@ def run_experiment(
171
168
  )
172
169
  span.set_status(status)
173
170
 
174
- assert isinstance(
171
+ if not isinstance(
175
172
  output, (dict, list, str, int, float, bool, type(None))
176
- ), "Output must be JSON serializable"
173
+ ):
174
+ raise TypeError(
175
+ f"Output must be JSON serializable, got {type(output).__name__}"
176
+ )
177
177
 
178
- exp_run = ExperimentRun(
178
+ return ExperimentRun(
179
179
  experiment_id=experiment_name,
180
180
  repetition_number=1,
181
- start_time=_decode_unix_nano(cast(int, span.start_time)),
181
+ start_time=_decode_unix_nano(cast("int", span.start_time)),
182
182
  end_time=(
183
- _decode_unix_nano(cast(int, span.end_time))
183
+ _decode_unix_nano(cast("int", span.end_time))
184
184
  if span.end_time
185
- else datetime.now()
185
+ else datetime.now(tz=timezone.utc)
186
186
  ),
187
187
  dataset_example_id=example.id,
188
188
  output=output, # type:ignore
189
189
  error=repr(error) if error else None,
190
190
  trace_id=_str_trace_id(span.get_span_context().trace_id), # type: ignore
191
191
  )
192
- return exp_run
193
192
 
194
193
  async def async_run_experiment(example: Example) -> ExperimentRun:
195
194
  output = None
@@ -212,7 +211,7 @@ def run_experiment(
212
211
  output = _output
213
212
  except BaseException as exc:
214
213
  if exit_on_error:
215
- raise exc
214
+ raise
216
215
  span.record_exception(exc)
217
216
  status = Status(
218
217
  StatusCode.ERROR, f"{type(exc).__name__}: {exc}"
@@ -243,27 +242,29 @@ def run_experiment(
243
242
  )
244
243
  span.set_status(status)
245
244
 
246
- assert isinstance(
245
+ if not isinstance(
247
246
  output, (dict, list, str, int, float, bool, type(None))
248
- ), "Output must be JSON serializable"
247
+ ):
248
+ raise TypeError(
249
+ f"Output must be JSON serializable, got {type(output).__name__}"
250
+ )
249
251
 
250
- exp_run = ExperimentRun(
252
+ return ExperimentRun(
251
253
  experiment_id=experiment_name,
252
254
  repetition_number=1,
253
- start_time=_decode_unix_nano(cast(int, span.start_time)),
255
+ start_time=_decode_unix_nano(cast("int", span.start_time)),
254
256
  end_time=(
255
- _decode_unix_nano(cast(int, span.end_time))
257
+ _decode_unix_nano(cast("int", span.end_time))
256
258
  if span.end_time
257
- else datetime.now()
259
+ else datetime.now(tz=timezone.utc)
258
260
  ),
259
261
  dataset_example_id=example.id,
260
262
  output=output, # type: ignore
261
263
  error=repr(error) if error else None,
262
264
  trace_id=_str_trace_id(span.get_span_context().trace_id), # type: ignore
263
265
  )
264
- return exp_run
265
266
 
266
- _errors: Tuple[Type[BaseException], ...]
267
+ _errors: tuple[type[BaseException], ...]
267
268
  if not isinstance(rate_limit_errors, Sequence):
268
269
  _errors = (rate_limit_errors,) # type: ignore
269
270
  else:
@@ -370,9 +371,9 @@ def evaluate_experiment(
370
371
  tracer: Tracer | None = None,
371
372
  resource: Resource | None = None,
372
373
  exit_on_error: bool = False,
373
- ):
374
- """
375
- Evaluate the results of an experiment using the provided evaluators.
374
+ ) -> list[ExperimentEvaluationRun]:
375
+ """Evaluate the results of an experiment using the provided evaluators.
376
+
376
377
  Args:
377
378
  experiment_name (str): The name of the experiment.
378
379
  examples (Sequence[Example]): The examples to evaluate.
@@ -383,6 +384,7 @@ def evaluate_experiment(
383
384
  tracer (Optional[Tracer]): Optional tracer for tracing the evaluation.
384
385
  resource (Optional[Resource]): Optional resource for the evaluation.
385
386
  exit_on_error (bool): Whether to exit on error. Default is False.
387
+
386
388
  Returns:
387
389
  List[ExperimentEvaluationRun]: The evaluation results.
388
390
  """
@@ -409,7 +411,7 @@ def evaluate_experiment(
409
411
  md = {"experiment_name": experiment_name}
410
412
 
411
413
  def sync_eval_run(
412
- obj: Tuple[Example, ExperimentRun, Evaluator],
414
+ obj: tuple[Example, ExperimentRun, Evaluator],
413
415
  ) -> ExperimentEvaluationRun:
414
416
  example, experiment_run, evaluator = obj
415
417
  result: EvaluationResult | None = None
@@ -435,7 +437,7 @@ def evaluate_experiment(
435
437
  )
436
438
  except BaseException as exc:
437
439
  if exit_on_error:
438
- raise exc
440
+ raise
439
441
  span.record_exception(exc)
440
442
  status = Status(
441
443
  StatusCode.ERROR, f"{type(exc).__name__}: {exc}"
@@ -453,13 +455,13 @@ def evaluate_experiment(
453
455
  span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
454
456
  span.set_status(status)
455
457
 
456
- eval_run = ExperimentEvaluationRun(
458
+ return ExperimentEvaluationRun(
457
459
  experiment_run_id=experiment_run.id,
458
- start_time=_decode_unix_nano(cast(int, span.start_time)),
460
+ start_time=_decode_unix_nano(cast("int", span.start_time)),
459
461
  end_time=(
460
- _decode_unix_nano(cast(int, span.end_time))
462
+ _decode_unix_nano(cast("int", span.end_time))
461
463
  if span.end_time
462
- else datetime.now()
464
+ else datetime.now(tz=timezone.utc)
463
465
  ),
464
466
  name=evaluator.name,
465
467
  annotator_kind=evaluator.kind,
@@ -467,10 +469,9 @@ def evaluate_experiment(
467
469
  result=result,
468
470
  trace_id=_str_trace_id(span.get_span_context().trace_id), # type:ignore
469
471
  )
470
- return eval_run
471
472
 
472
473
  async def async_eval_run(
473
- obj: Tuple[Example, ExperimentRun, Evaluator],
474
+ obj: tuple[Example, ExperimentRun, Evaluator],
474
475
  ) -> ExperimentEvaluationRun:
475
476
  example, experiment_run, evaluator = obj
476
477
  result: EvaluationResult | None = None
@@ -496,7 +497,7 @@ def evaluate_experiment(
496
497
  )
497
498
  except BaseException as exc:
498
499
  if exit_on_error:
499
- raise exc
500
+ raise
500
501
  span.record_exception(exc)
501
502
  status = Status(
502
503
  StatusCode.ERROR, f"{type(exc).__name__}: {exc}"
@@ -513,13 +514,13 @@ def evaluate_experiment(
513
514
  )
514
515
  span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
515
516
  span.set_status(status)
516
- eval_run = ExperimentEvaluationRun(
517
+ return ExperimentEvaluationRun(
517
518
  experiment_run_id=experiment_run.id,
518
- start_time=_decode_unix_nano(cast(int, span.start_time)),
519
+ start_time=_decode_unix_nano(cast("int", span.start_time)),
519
520
  end_time=(
520
- _decode_unix_nano(cast(int, span.end_time))
521
+ _decode_unix_nano(cast("int", span.end_time))
521
522
  if span.end_time
522
- else datetime.now()
523
+ else datetime.now(tz=timezone.utc)
523
524
  ),
524
525
  name=evaluator.name,
525
526
  annotator_kind=evaluator.kind,
@@ -527,9 +528,8 @@ def evaluate_experiment(
527
528
  result=result,
528
529
  trace_id=_str_trace_id(span.get_span_context().trace_id), # type:ignore
529
530
  )
530
- return eval_run
531
531
 
532
- _errors: Tuple[Type[BaseException], ...]
532
+ _errors: tuple[type[BaseException], ...]
533
533
  if not isinstance(rate_limit_errors, Sequence):
534
534
  _errors = (rate_limit_errors,) if rate_limit_errors is not None else ()
535
535
  else:
@@ -563,9 +563,9 @@ def evaluate_experiment(
563
563
 
564
564
  def _add_metadata_to_output_df(
565
565
  output_df: pd.DataFrame,
566
- eval_runs: List[ExperimentEvaluationRun],
566
+ eval_runs: list[ExperimentEvaluationRun],
567
567
  evaluator_name: str,
568
- ):
568
+ ) -> object:
569
569
  for eval_run in eval_runs:
570
570
  if eval_run.result is None:
571
571
  continue
@@ -589,7 +589,7 @@ def _add_metadata_to_output_df(
589
589
  return output_df
590
590
 
591
591
 
592
- def _dataframe_to_examples(dataset: pd.DataFrame) -> List[Example]:
592
+ def _dataframe_to_examples(dataset: pd.DataFrame) -> list[Example]:
593
593
  for column in dataset.columns:
594
594
  if pd.api.types.is_datetime64_any_dtype(dataset[column]):
595
595
  dataset[column] = dataset[column].astype(str)
@@ -637,8 +637,7 @@ def _bind_task_signature(
637
637
  parameter_name = next(iter(params))
638
638
  if parameter_name in parameter_mapping:
639
639
  return sig.bind(parameter_mapping[parameter_name])
640
- else:
641
- return sig.bind(parameter_mapping["dataset_row"])
640
+ return sig.bind(parameter_mapping["dataset_row"])
642
641
  return sig.bind_partial(
643
642
  **{
644
643
  name: parameter_mapping[name]
@@ -650,7 +649,7 @@ def _bind_task_signature(
650
649
  def _evaluators_by_name(
651
650
  obj: Evaluators | None,
652
651
  ) -> Mapping[EvaluatorName, Evaluator]:
653
- evaluators_by_name: Dict[EvaluatorName, Evaluator] = {}
652
+ evaluators_by_name: dict[EvaluatorName, Evaluator] = {}
654
653
  if obj is None:
655
654
  return evaluators_by_name
656
655
  if isinstance(obj, Mapping):
@@ -676,7 +675,10 @@ def _evaluators_by_name(
676
675
  raise ValueError(f"Two evaluators have the same name: {name}")
677
676
  evaluators_by_name[name] = evaluator
678
677
  else:
679
- assert not isinstance(obj, Mapping) and not isinstance(obj, Sequence)
678
+ if isinstance(obj, (Mapping, Sequence)):
679
+ raise TypeError(
680
+ "Expected a single evaluator, got a mapping or sequence"
681
+ )
680
682
  evaluator = (
681
683
  create_evaluator()(obj) if not isinstance(obj, Evaluator) else obj
682
684
  )
@@ -688,9 +690,7 @@ def _evaluators_by_name(
688
690
 
689
691
 
690
692
  def get_func_name(fn: Callable[..., Any]) -> str:
691
- """
692
- Makes a best-effort attempt to get the name of the function.
693
- """
693
+ """Makes a best-effort attempt to get the name of the function."""
694
694
  if isinstance(fn, functools.partial):
695
695
  return fn.func.__qualname__
696
696
  if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith("<lambda>"):
@@ -705,12 +705,8 @@ def _print_experiment_error(
705
705
  example_id: str,
706
706
  kind: Literal["evaluator", "task"],
707
707
  ) -> None:
708
- """
709
- Prints an experiment error.
710
- """
711
- display_error = RuntimeError(
712
- f"{kind} failed for example id {repr(example_id)}"
713
- )
708
+ """Prints an experiment error."""
709
+ display_error = RuntimeError(f"{kind} failed for example id {example_id!r}")
714
710
  display_error.__cause__ = error
715
711
  formatted_exception = "".join(
716
712
  traceback.format_exception(
@@ -729,8 +725,7 @@ def _str_trace_id(id_: int) -> str:
729
725
 
730
726
 
731
727
  def get_tqdm_progress_bar_formatter(title: str) -> str:
732
- """
733
- Returns a progress bar formatter for use with tqdm.
728
+ """Returns a progress bar formatter for use with tqdm.
734
729
 
735
730
  Args:
736
731
  title (str): The title of the progress bar, displayed as a prefix.
@@ -757,23 +752,32 @@ EVALUATOR = OpenInferenceSpanKindValues.EVALUATOR.value
757
752
  JSON = OpenInferenceMimeTypeValues.JSON
758
753
 
759
754
 
760
- def get_result_attr(r, attr, default=None):
755
+ def get_result_attr(r: object, attr: str, default: object = None) -> object:
756
+ """Get an attribute from a result object, with fallback to default.
757
+
758
+ Args:
759
+ r: An object with a `result` attribute.
760
+ attr: The attribute name to retrieve from the result.
761
+ default: Value to return if result is None or attribute not found. Defaults to None.
762
+
763
+ Returns:
764
+ The attribute value if found, otherwise the default value.
765
+ """
761
766
  return getattr(r.result, attr, default) if r.result else default
762
767
 
763
768
 
764
769
  def transform_to_experiment_format(
765
- experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
770
+ experiment_runs: list[dict[str, object]] | pd.DataFrame,
766
771
  task_fields: ExperimentTaskResultFieldNames,
767
- evaluator_fields: Dict[str, EvaluationResultFieldNames] | None = None,
772
+ evaluator_fields: dict[str, EvaluationResultFieldNames] | None = None,
768
773
  ) -> pd.DataFrame:
769
- """
770
- Transform a DataFrame to match the format returned by run_experiment().
774
+ """Transform a DataFrame to match the format returned by run_experiment().
771
775
 
772
776
  Args:
773
- df: Input DataFrame containing experiment results
774
- task_columns: Column mapping for task results
775
- evaluator_columns: Dictionary mapping evaluator names (str)
776
- to their column mappings (EvaluationResultColumnNames)
777
+ experiment_runs: Input list of dictionaries or DataFrame containing experiment results
778
+ task_fields: Field name mapping for task results
779
+ evaluator_fields: Dictionary mapping evaluator names (str)
780
+ to their field name mappings (EvaluationResultFieldNames)
777
781
 
778
782
  Returns:
779
783
  DataFrame in the format matching run_experiment() output
@@ -818,7 +822,7 @@ def _add_evaluator_columns(
818
822
  evaluator_name: str,
819
823
  column_names: EvaluationResultFieldNames,
820
824
  ) -> None:
821
- """Helper function to add evaluator columns to output DataFrame"""
825
+ """Helper function to add evaluator columns to output DataFrame."""
822
826
  # Add score if specified
823
827
  if column_names.score and column_names.score in input_df.columns:
824
828
  output_df[f"eval.{evaluator_name}.score"] = input_df[column_names.score]
@@ -862,10 +866,8 @@ def _add_evaluator_columns(
862
866
  output_df[output_col] = output_vals
863
867
 
864
868
 
865
- def jsonify(obj: Any) -> Any:
866
- """
867
- Coerce object to be json serializable.
868
- """
869
+ def jsonify(obj: object) -> object:
870
+ """Coerce object to be json serializable."""
869
871
  if isinstance(obj, Enum):
870
872
  return jsonify(obj.value)
871
873
  if isinstance(obj, (str, int, float, bool)) or obj is None:
@@ -901,20 +903,20 @@ def jsonify(obj: Any) -> Any:
901
903
  if hasattr(obj, "model_dump") and callable(obj.model_dump):
902
904
  # pydantic v2
903
905
  try:
904
- d = obj
905
- assert isinstance(d, dict)
906
- except BaseException:
906
+ d = obj.model_dump()
907
+ if isinstance(d, dict):
908
+ return jsonify(d)
909
+ except Exception: # noqa: S110
910
+ # If model_dump fails or returns non-dict, fall through to next handler
907
911
  pass
908
- else:
909
- return jsonify(d)
910
912
  if hasattr(obj, "dict") and callable(obj.dict):
911
913
  # pydantic v1
912
914
  try:
913
915
  d = obj.dict()
914
- assert isinstance(d, dict)
915
- except BaseException:
916
+ if isinstance(d, dict):
917
+ return jsonify(d)
918
+ except Exception: # noqa: S110
919
+ # If dict fails or returns non-dict, fall through to next handler
916
920
  pass
917
- else:
918
- return jsonify(d)
919
921
  cls = obj.__class__
920
922
  return f"<{cls.__module__}.{cls.__name__} object>"
@@ -1,44 +1,46 @@
1
+ """Experiment tracing functionality for capturing execution context."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import inspect
4
6
  import json
7
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
5
8
  from contextlib import contextmanager
6
9
  from contextvars import ContextVar
7
10
  from threading import Lock
8
11
  from typing import (
12
+ TYPE_CHECKING,
9
13
  Any,
10
- Callable,
11
- Iterable,
12
- Iterator,
13
- List,
14
- Mapping,
15
- Sequence,
16
14
  cast,
17
15
  )
18
16
 
19
17
  import numpy as np
20
18
  from openinference.semconv import trace
21
19
  from openinference.semconv.trace import DocumentAttributes, SpanAttributes
22
- from opentelemetry.sdk.resources import Resource
23
20
  from opentelemetry.sdk.trace import ReadableSpan
24
21
  from opentelemetry.trace import INVALID_TRACE_ID
25
22
  from typing_extensions import assert_never
26
23
  from wrapt import apply_patch, resolve_path, wrap_function_wrapper
27
24
 
25
+ if TYPE_CHECKING:
26
+ from opentelemetry.sdk.resources import Resource
27
+
28
28
 
29
29
  class SpanModifier:
30
- """
31
- A class that modifies spans with the specified resource attributes.
32
- """
30
+ """A class that modifies spans with the specified resource attributes."""
33
31
 
34
32
  __slots__ = ("_resource",)
35
33
 
36
34
  def __init__(self, resource: Resource) -> None:
35
+ """Initialize the span modifier with resource attributes.
36
+
37
+ Args:
38
+ resource: OpenTelemetry Resource containing attributes to merge.
39
+ """
37
40
  self._resource = resource
38
41
 
39
42
  def modify_resource(self, span: ReadableSpan) -> None:
40
- """
41
- Takes a span and merges in the resource attributes specified in the constructor.
43
+ """Takes a span and merges in the resource attributes specified in the constructor.
42
44
 
43
45
  Args:
44
46
  span: ReadableSpan: the span to modify
@@ -55,8 +57,16 @@ _ACTIVE_MODIFIER: ContextVar[SpanModifier | None] = ContextVar(
55
57
 
56
58
 
57
59
  def override_span(
58
- init: Callable[..., None], span: ReadableSpan, args: Any, kwargs: Any
60
+ init: Callable[..., None], span: ReadableSpan, args: object, kwargs: object
59
61
  ) -> None:
62
+ """Override span initialization to apply active span modifiers.
63
+
64
+ Args:
65
+ init: The original span initialization function.
66
+ span: The span being initialized.
67
+ args: Positional arguments for the init function.
68
+ kwargs: Keyword arguments for the init function.
69
+ """
60
70
  init(*args, **kwargs)
61
71
  if isinstance(span_modifier := _ACTIVE_MODIFIER.get(None), SpanModifier):
62
72
  span_modifier.modify_resource(span)
@@ -91,8 +101,7 @@ def _monkey_patch_span_init() -> Iterator[None]:
91
101
 
92
102
  @contextmanager
93
103
  def capture_spans(resource: Resource) -> Iterator[SpanModifier]:
94
- """
95
- A context manager that captures spans and modifies them with the specified resources.
104
+ """A context manager that captures spans and modifies them with the specified resources.
96
105
 
97
106
  Args:
98
107
  resource: Resource: The resource to merge into the spans created within the context.
@@ -139,10 +148,10 @@ JSON_STRING_ATTRIBUTES = (
139
148
  TOOL_PARAMETERS,
140
149
  )
141
150
 
142
- SEMANTIC_CONVENTIONS: List[str] = sorted(
151
+ SEMANTIC_CONVENTIONS: list[str] = sorted(
143
152
  # e.g. "input.value", "llm.token_count.total", etc.
144
153
  (
145
- cast(str, getattr(klass, attr))
154
+ cast("str", getattr(klass, attr))
146
155
  for name in dir(trace)
147
156
  if name.endswith("Attributes")
148
157
  and inspect.isclass(klass := getattr(trace, name))
@@ -162,14 +171,13 @@ def flatten(
162
171
  recurse_on_sequence: bool = False,
163
172
  json_string_attributes: Sequence[str] | None = None,
164
173
  ) -> Iterator[tuple[str, Any]]:
165
- """
166
- Flatten a nested dictionary or a sequence of dictionaries into a list of
167
- key value pairs. If `recurse_on_sequence` is True, then the function will
168
- also recursively flatten nested sequences of dictionaries. If
169
- `json_string_attributes` is provided, then the function will interpret the
170
- attributes in the list as JSON strings and convert them into dictionaries.
171
- The `prefix` argument is used to prefix the keys in the output list, but
172
- it's mostly used internally to facilitate recursion.
174
+ """Flatten a nested dictionary or a sequence of dictionaries into a list of key value pairs.
175
+
176
+ If `recurse_on_sequence` is True, then the function will also recursively flatten
177
+ nested sequences of dictionaries. If `json_string_attributes` is provided, then the
178
+ function will interpret the attributes in the list as JSON strings and convert them
179
+ into dictionaries. The `prefix` argument is used to prefix the keys in the output list,
180
+ but it's mostly used internally to facilitate recursion.
173
181
  """
174
182
  if isinstance(obj, Mapping):
175
183
  yield from _flatten_mapping(
@@ -192,11 +200,11 @@ def flatten(
192
200
 
193
201
 
194
202
  def has_mapping(sequence: Iterable[Any]) -> bool:
195
- """
196
- Check if a sequence contains a dictionary. We don't flatten sequences that
197
- only contain primitive types, such as strings, integers, etc. Conversely,
198
- we'll only un-flatten digit sub-keys if it can be interpreted the index of
199
- an array of dictionaries.
203
+ """Check if a sequence contains a dictionary.
204
+
205
+ We don't flatten sequences that only contain primitive types, such as strings,
206
+ integers, etc. Conversely, we'll only un-flatten digit sub-keys if it can be
207
+ interpreted the index of an array of dictionaries.
200
208
  """
201
209
  return any(isinstance(item, Mapping) for item in sequence)
202
210
 
@@ -209,13 +217,13 @@ def _flatten_mapping(
209
217
  json_string_attributes: Sequence[str] | None = None,
210
218
  separator: str = ".",
211
219
  ) -> Iterator[tuple[str, Any]]:
212
- """
213
- Flatten a nested dictionary into a list of key value pairs. If `recurse_on_sequence`
214
- is True, then the function will also recursively flatten nested sequences of dictionaries.
215
- If `json_string_attributes` is provided, then the function will interpret the attributes
216
- in the list as JSON strings and convert them into dictionaries. The `prefix` argument is
217
- used to prefix the keys in the output list, but it's mostly used internally to facilitate
218
- recursion.
220
+ """Flatten a nested dictionary into a list of key value pairs.
221
+
222
+ If `recurse_on_sequence` is True, then the function will also recursively flatten
223
+ nested sequences of dictionaries. If `json_string_attributes` is provided, then the
224
+ function will interpret the attributes in the list as JSON strings and convert them
225
+ into dictionaries. The `prefix` argument is used to prefix the keys in the output list,
226
+ but it's mostly used internally to facilitate recursion.
219
227
  """
220
228
  for key, value in mapping.items():
221
229
  prefixed_key = f"{prefix}{separator}{key}" if prefix else key
@@ -254,13 +262,13 @@ def _flatten_sequence(
254
262
  json_string_attributes: Sequence[str] | None = None,
255
263
  separator: str = ".",
256
264
  ) -> Iterator[tuple[str, Any]]:
257
- """
258
- Flatten a sequence of dictionaries into a list of key value pairs. If `recurse_on_sequence`
259
- is True, then the function will also recursively flatten nested sequences of dictionaries.
260
- If `json_string_attributes` is provided, then the function will interpret the attributes
261
- in the list as JSON strings and convert them into dictionaries. The `prefix` argument is
262
- used to prefix the keys in the output list, but it's mostly used internally to facilitate
263
- recursion.
265
+ """Flatten a sequence of dictionaries into a list of key value pairs.
266
+
267
+ If `recurse_on_sequence` is True, then the function will also recursively flatten
268
+ nested sequences of dictionaries. If `json_string_attributes` is provided, then the
269
+ function will interpret the attributes in the list as JSON strings and convert them
270
+ into dictionaries. The `prefix` argument is used to prefix the keys in the output list,
271
+ but it's mostly used internally to facilitate recursion.
264
272
  """
265
273
  if isinstance(sequence, str) or not has_mapping(sequence):
266
274
  yield prefix, sequence