pydantic-evals 0.8.1__tar.gz → 1.0.0b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-evals might be problematic.
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/PKG-INFO +3 -4
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/_utils.py +2 -2
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/dataset.py +143 -91
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/__init__.py +3 -1
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/_run_evaluator.py +47 -9
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/context.py +1 -1
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/evaluator.py +15 -4
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/llm_as_a_judge.py +3 -3
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/spec.py +3 -3
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/span_tree.py +5 -14
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/__init__.py +214 -21
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pyproject.toml +1 -2
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/.gitignore +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/LICENSE +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/README.md +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/__init__.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/common.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/generation.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/__init__.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_context_in_memory_span_exporter.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_context_subtree.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_errors.py +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/py.typed +0 -0
- {pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/render_numbers.py +0 -0
{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-evals
-Version: 0.8.1
+Version: 1.0.0b1
 Summary: Framework for evaluating stochastic code execution, especially code making use of LLMs
 Project-URL: Homepage, https://ai.pydantic.dev/evals
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -21,18 +21,17 @@ Classifier: Operating System :: Unix
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.9
+Requires-Python: >=3.10
 Requires-Dist: anyio>=0
 Requires-Dist: eval-type-backport>=0; python_version < '3.11'
 Requires-Dist: logfire-api>=3.14.1
-Requires-Dist: pydantic-ai-slim==0.8.1
+Requires-Dist: pydantic-ai-slim==1.0.0b1
 Requires-Dist: pydantic>=2.10
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: rich>=13.9.4

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/_utils.py

@@ -2,9 +2,9 @@ from __future__ import annotations as _annotations

 import asyncio
 import inspect
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable, Callable, Sequence
 from functools import partial
-from typing import Any,
+from typing import Any, TypeVar

 import anyio
 from typing_extensions import ParamSpec, TypeIs

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/dataset.py

@@ -13,14 +13,15 @@ import functools
 import inspect
 import sys
 import time
+import traceback
 import warnings
-from collections.abc import Awaitable, Mapping, Sequence
+from collections.abc import Awaitable, Callable, Mapping, Sequence
 from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from pathlib import Path
-from typing import
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

 import anyio
 import logfire_api
@@ -40,10 +41,14 @@ from .evaluators import EvaluationResult, Evaluator
 from .evaluators._run_evaluator import run_evaluator
 from .evaluators.common import DEFAULT_EVALUATORS
 from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
 from .evaluators.spec import EvaluatorSpec
 from .otel import SpanTree
 from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from pydantic_ai.retries import RetryConfig

 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -74,6 +79,7 @@ _YAML_SCHEMA_LINE_PREFIX = '# yaml-language-server: $schema='


 _REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
 _REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


@@ -161,11 +167,6 @@ class Case(Generic[InputsT, OutputT, MetadataT]):
         self.evaluators = list(evaluators)


-# TODO: Consider making one or more of the following changes to this type:
-# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
-# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
-# * Rename to `Evaluation`
-# TODO: Allow `task` to be sync _or_ async
 class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
     """A dataset of test [cases][pydantic_evals.Case].

@@ -253,6 +254,8 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
+        retry_task: RetryConfig | None = None,
+        retry_evaluators: RetryConfig | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -267,6 +270,8 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
            progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
+            retry_task: Optional retry configuration for the task execution.
+            retry_evaluators: Optional retry configuration for evaluator execution.

         Returns:
             A report containing the results of the evaluation.
@@ -277,12 +282,17 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a

         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

-        with
+        with (
+            _logfire.span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
+            progress_bar or nullcontext(),
+        ):
             task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None

             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:
-                    result = await _run_task_and_evaluators(
+                    result = await _run_task_and_evaluators(
+                        task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
+                    )
                 if progress_bar and task_id is not None:  # pragma: no branch
                     progress_bar.update(task_id, advance=1)
                 return result
@@ -293,21 +303,28 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
             else:
                 trace_id = f'{context.trace_id:032x}'
                 span_id = f'{context.span_id:016x}'
+            cases_and_failures = await task_group_gather(
+                [
+                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                    for i, case in enumerate(self.cases, 1)
+                ]
+            )
+            cases: list[ReportCase] = []
+            failures: list[ReportCaseFailure] = []
+            for item in cases_and_failures:
+                if isinstance(item, ReportCase):
+                    cases.append(item)
+                else:
+                    failures.append(item)
             report = EvaluationReport(
                 name=name,
-                cases=
-
-                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                    for i, case in enumerate(self.cases, 1)
-                ]
-                ),
+                cases=cases,
+                failures=failures,
                 span_id=span_id,
                 trace_id=trace_id,
             )
-
-
-            # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            if (averages := report.averages()) is not None and averages.assertions is not None:
+                eval_span.set_attribute('assertion_pass_rate', averages.assertions)
         return report

     def evaluate_sync(
@@ -634,7 +651,7 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any:
             td = TypedDict(f'{cls_name_prefix}_{name}', fields)  # pyright: ignore[reportArgumentType]
             config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
-            # TODO: Replace with pydantic.with_config
+            # TODO: Replace with pydantic.with_config once pydantic 2.11 is the min supported version
             td.__pydantic_config__ = config  # pyright: ignore[reportAttributeAccessIssue]
             return td

@@ -735,7 +752,7 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         See <https://github.com/json-schema-org/json-schema-spec/issues/828> for context, that seems to be the nearest
         there is to a spec for this.
         """
-        context = cast(
+        context = cast(dict[str, Any] | None, info.context)
         if isinstance(context, dict) and (schema := context.get('$schema')):
             return {'$schema': schema} | nxt(self)
         else:
@@ -815,13 +832,16 @@ class _TaskRun:


 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: RetryConfig | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.

     Args:
         task: The task to run.
         case: The case to run the task on.
+        retry: The retry config to use.

     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
@@ -829,24 +849,36 @@ async def _run_task(
     Raises:
         Exception: Any exception raised by the task.
     """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-
-
-
-
-
-
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                 t0 = time.perf_counter()
                 if iscoroutinefunction(task):
-
+                    task_output_ = cast(OutputT, await task(case.inputs))
                 else:
-
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
-
-
+            duration_ = _get_span_duration(task_span, fallback_duration)
+            return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    if retry:
+        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+        from pydantic_ai.retries import retry as tenacity_retry
+
+        _run_once = tenacity_retry(**retry)(_run_once)
+
+    task_run, task_output, duration, span_tree = await _run_once()

     if isinstance(span_tree, SpanTree):  # pragma: no branch
         # Idea for making this more configurable: replace the following logic with a call to a user-provided function
@@ -857,9 +889,8 @@ async def _run_task(
             if node.attributes.get('gen_ai.operation.name') == 'chat':
                 task_run.increment_metric('requests', 1)
             for k, v in node.attributes.items():
-                if not isinstance(v,
+                if not isinstance(v, int | float):
                     continue
-                # TODO: Revisit this choice to strip the prefix..
                 if k.startswith('gen_ai.usage.details.'):
                     task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
                 elif k.startswith('gen_ai.usage.'):
@@ -871,7 +902,7 @@ async def _run_task(
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=task_output,
-        duration=
+        duration=duration,
         _span_tree=span_tree,
         attributes=task_run.attributes,
         metrics=task_run.metrics,
@@ -883,7 +914,9 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-
+    retry_task: RetryConfig | None,
+    retry_evaluators: RetryConfig | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

     Args:
@@ -891,64 +924,83 @@ async def _run_task_and_evaluators(
         case: The case to run the task on.
         report_case_name: The name to use for this case in the report.
         dataset_evaluators: Evaluators from the dataset to apply to this case.
+        retry_task: The retry config to use for running the task.
+        retry_evaluators: The retry config to use for running the evaluators.

     Returns:
         A ReportCase containing the evaluation results.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+    trace_id: str | None = None
+    span_id: str | None = None
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no branch
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'

-
-
-
-
-
-
-
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry_task)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
+                )
+                for outputs in evaluator_outputs_by_task:
+                    if isinstance(outputs, EvaluatorFailure):
+                        evaluator_failures.append(outputs)
+                    else:
+                        evaluator_outputs.extend(outputs)
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
             fallback_duration = time.time() - t0

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            return ReportCase[InputsT, OutputT, MetadataT](
+                name=report_case_name,
+                inputs=case.inputs,
+                metadata=case.metadata,
+                expected_output=case.expected_output,
+                output=scoring_context.output,
+                metrics=scoring_context.metrics,
+                attributes=scoring_context.attributes,
+                scores=scores,
+                labels=labels,
+                assertions=assertions,
+                task_duration=scoring_context.duration,
+                total_duration=_get_span_duration(case_span, fallback_duration),
+                trace_id=trace_id,
+                span_id=span_id,
+                evaluator_failures=evaluator_failures,
+            )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_message=f'{type(exc).__name__}: {exc}',
+            error_stacktrace=traceback.format_exc(),
+            trace_id=trace_id,
+            span_id=span_id,
+        )


 _evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])

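The new `retry_task` and `retry_evaluators` parameters on `Dataset.evaluate` are easiest to see in use. Below is a minimal sketch, not taken from the package docs: it assumes tenacity is installed and that a `RetryConfig` is a mapping of keyword arguments for tenacity's `retry` decorator, as suggested by the `tenacity_retry(**retry)` call in the diff above; the `greet` task and the case contents are hypothetical.

import asyncio

from tenacity import stop_after_attempt, wait_fixed

from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import EqualsExpected


async def greet(name: str) -> str:
    # Hypothetical task under evaluation; if it raises, the run now records a
    # ReportCaseFailure for the case instead of aborting the whole evaluation.
    return f'Hello {name}'


dataset = Dataset(
    cases=[Case(name='ada', inputs='Ada', expected_output='Hello Ada')],
    evaluators=[EqualsExpected()],
)


async def main() -> None:
    report = await dataset.evaluate(
        greet,
        # Assumed shape: tenacity retry() kwargs, applied around each task / evaluator call.
        retry_task={'stop': stop_after_attempt(3), 'wait': wait_fixed(1)},
        retry_evaluators={'stop': stop_after_attempt(2)},
    )
    report.print(include_errors=True)  # failed cases are rendered in a separate table
    for failure in report.failures:    # ReportCaseFailure entries for cases whose task raised
        print(failure.name, failure.error_message)


asyncio.run(main())
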
{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/__init__.py

@@ -10,7 +10,7 @@ from .common import (
     Python,
 )
 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorOutput, EvaluatorSpec
+from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorFailure, EvaluatorOutput, EvaluatorSpec

 __all__ = (
     # common
@@ -27,6 +27,8 @@ __all__ = (
     'EvaluatorContext',
     # evaluator
     'Evaluator',
+    'EvaluationReason',
+    'EvaluatorFailure',
     'EvaluatorOutput',
     'EvaluatorSpec',
     'EvaluationReason',

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/_run_evaluator.py

@@ -1,8 +1,11 @@
 from __future__ import annotations

+import traceback
 from collections.abc import Mapping
-from
+from pathlib import Path
+from typing import TYPE_CHECKING, Any

+import logfire_api
 from pydantic import (
     TypeAdapter,
     ValidationError,
@@ -10,7 +13,20 @@ from pydantic import (
 from typing_extensions import TypeVar

 from .context import EvaluatorContext
-from .evaluator import
+from .evaluator import (
+    EvaluationReason,
+    EvaluationResult,
+    EvaluationScalar,
+    Evaluator,
+    EvaluatorFailure,
+    EvaluatorOutput,
+)
+
+if TYPE_CHECKING:
+    from pydantic_ai.retries import RetryConfig
+
+_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')
+logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute())

 InputsT = TypeVar('InputsT', default=Any, contravariant=True)
 OutputT = TypeVar('OutputT', default=Any, contravariant=True)
@@ -18,8 +34,10 @@ MetadataT = TypeVar('MetadataT', default=Any, contravariant=True)


 async def run_evaluator(
-    evaluator: Evaluator[InputsT, OutputT, MetadataT],
-
+    evaluator: Evaluator[InputsT, OutputT, MetadataT],
+    ctx: EvaluatorContext[InputsT, OutputT, MetadataT],
+    retry: RetryConfig | None = None,
+) -> list[EvaluationResult] | EvaluatorFailure:
     """Run an evaluator and return the results.

     This function runs an evaluator on the given context and processes the results into
@@ -28,19 +46,39 @@ async def run_evaluator(
     Args:
         evaluator: The evaluator to run.
         ctx: The context containing the inputs, outputs, and metadata for evaluation.
+        retry: The retry configuration to use for running the evaluator.

     Returns:
-        A list of evaluation results.
+        A list of evaluation results, or an evaluator failure if an exception is raised during its execution.

     Raises:
         ValueError: If the evaluator returns a value of an invalid type.
     """
-
+    evaluate = evaluator.evaluate_async
+    if retry is not None:
+        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+        from pydantic_ai.retries import retry as tenacity_retry
+
+        evaluate = tenacity_retry(**retry)(evaluate)

     try:
-
-
-
+        with _logfire.span(
+            'evaluator: {evaluator_name}',
+            evaluator_name=evaluator.get_default_evaluation_name(),
+        ):
+            raw_results = await evaluate(ctx)
+
+        try:
+            results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)
+        except ValidationError as e:
+            raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e
+    except Exception as e:
+        return EvaluatorFailure(
+            name=evaluator.get_default_evaluation_name(),
+            error_message=f'{type(e).__name__}: {e}',
+            error_stacktrace=traceback.format_exc(),
+            source=evaluator.as_spec(),
+        )

     results = _convert_to_mapping(results, scalar_name=evaluator.get_default_evaluation_name())

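With this change, an exception inside an evaluator no longer propagates out of `run_evaluator`: it is captured as an `EvaluatorFailure` (name, error message, stacktrace and the evaluator's spec) and attached to the report case. A minimal sketch of a custom evaluator that can trip that path; the `NonEmptyOutput` name and the check itself are made up for illustration.

from dataclasses import dataclass

from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext


@dataclass
class NonEmptyOutput(Evaluator):
    """Hypothetical evaluator: asserts that the task produced some output."""

    def evaluate(self, ctx: EvaluatorContext) -> EvaluationReason:
        if ctx.output is None:
            # Previously this exception would have aborted the evaluation run; now
            # run_evaluator turns it into an EvaluatorFailure that ends up in
            # ReportCase.evaluator_failures.
            raise ValueError('task produced no output')
        return EvaluationReason(
            value=bool(str(ctx.output).strip()),
            reason='checked that the output is non-empty',
        )
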
{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/context.py

@@ -27,7 +27,7 @@ MetadataT = TypeVar('MetadataT', default=Any, covariant=True)
 """Type variable for the metadata associated with the task being evaluated."""


-@dataclass
+@dataclass(kw_only=True)
 class EvaluatorContext(Generic[InputsT, OutputT, MetadataT]):
     """Context for evaluating a task execution.

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/evaluator.py

@@ -4,7 +4,7 @@ import inspect
 from abc import ABCMeta, abstractmethod
 from collections.abc import Awaitable, Mapping
 from dataclasses import MISSING, dataclass, fields
-from typing import Any, Generic,
+from typing import Any, Generic, cast

 from pydantic import (
     ConfigDict,
@@ -25,11 +25,12 @@ __all__ = (
     'EvaluationResult',
     'EvaluationScalar',
     'Evaluator',
+    'EvaluatorFailure',
     'EvaluatorOutput',
     'EvaluatorSpec',
 )

-EvaluationScalar =
+EvaluationScalar = bool | int | float | str
 """The most primitive output allowed as an output from an Evaluator.

 `int` and `float` are treated as scores, `str` as labels, and `bool` as assertions.
@@ -51,11 +52,11 @@ class EvaluationReason:
     reason: str | None = None


-EvaluatorOutput =
+EvaluatorOutput = EvaluationScalar | EvaluationReason | Mapping[str, EvaluationScalar | EvaluationReason]
 """Type for the output of an evaluator, which can be a scalar, an EvaluationReason, or a mapping of names to either."""


-# TODO(DavidM): Add bound=EvaluationScalar to the following typevar
+# TODO(DavidM): Add bound=EvaluationScalar to the following typevar once pydantic 2.11 is the min supported version
 EvaluationScalarT = TypeVar('EvaluationScalarT', default=EvaluationScalar, covariant=True)
 """Type variable for the scalar result type of an evaluation."""

@@ -100,6 +101,16 @@ class EvaluationResult(Generic[EvaluationScalarT]):
         return None


+@dataclass
+class EvaluatorFailure:
+    """Represents a failure raised during the execution of an evaluator."""
+
+    name: str
+    error_message: str
+    error_stacktrace: str
+    source: EvaluatorSpec
+
+
 # Evaluators are contravariant in all of its parameters.
 InputsT = TypeVar('InputsT', default=Any, contravariant=True)
 """Type variable for the inputs type of the task being evaluated."""

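The `EvaluationScalar` and `EvaluatorOutput` aliases above also document how results are bucketed: `int`/`float` become scores, `str` becomes a label, `bool` becomes an assertion, and a mapping lets one evaluator emit several named results at once. A small illustrative sketch, where the `ResponseQuality` evaluator and its heuristics are hypothetical:

from dataclasses import dataclass

from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext, EvaluatorOutput


@dataclass
class ResponseQuality(Evaluator):
    """Hypothetical evaluator returning a mapping of named results."""

    max_chars: int = 200

    def evaluate(self, ctx: EvaluatorContext) -> EvaluatorOutput:
        text = str(ctx.output or '')
        return {
            'non_empty': bool(text),                                # bool  -> assertion
            'length_score': min(len(text) / self.max_chars, 1.0),   # float -> score
            'tone': 'formal' if 'Dear' in text else 'casual',       # str   -> label
            'greets_user': EvaluationReason(                        # scalar plus an explanation
                value='hello' in text.lower(),
                reason='looked for a greeting word in the output',
            ),
        }
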
{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/llm_as_a_judge.py

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
 from pydantic_core import to_json

 from pydantic_ai import Agent, models
-from pydantic_ai.messages import
+from pydantic_ai.messages import MultiModalContent, UserContent
 from pydantic_ai.settings import ModelSettings

 __all__ = (
@@ -238,11 +238,11 @@ def _build_prompt(
     sections.append('<Input>\n')
     if isinstance(inputs, Sequence):
         for item in inputs:  # type: ignore
-            if isinstance(item,
+            if isinstance(item, str | MultiModalContent):
                 sections.append(item)
             else:
                 sections.append(_stringify(item))
-    elif isinstance(inputs,
+    elif isinstance(inputs, MultiModalContent):
         sections.append(inputs)
     else:
         sections.append(_stringify(inputs))

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/spec.py

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, cast

 from pydantic import (
     BaseModel,
@@ -17,7 +17,7 @@ from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapH
 if TYPE_CHECKING:
     # This import seems to fail on Pydantic 2.10.1 in CI
     from pydantic import ModelWrapValidatorHandler
-    # TODO:
+    # TODO: Remove this once pydantic 2.11 is the min supported version


 class EvaluatorSpec(BaseModel):
@@ -112,7 +112,7 @@ class EvaluatorSpec(BaseModel):
         return handler(self)


-class _SerializedEvaluatorSpec(RootModel[
+class _SerializedEvaluatorSpec(RootModel[str | dict[str, Any]]):
     """Internal class for handling the serialized form of an EvaluatorSpec.

     This is an auxiliary class used to serialize/deserialize instances of EvaluatorSpec

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/span_tree.py

@@ -1,12 +1,12 @@
 from __future__ import annotations

 import re
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from functools import cache
 from textwrap import indent
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any

 from pydantic import TypeAdapter
 from typing_extensions import TypedDict
@@ -16,16 +16,7 @@ if TYPE_CHECKING:  # pragma: no cover
     from opentelemetry.sdk.trace import ReadableSpan

     # Should match opentelemetry.util.types.AttributeValue
-    AttributeValue =
-        str,
-        bool,
-        int,
-        float,
-        Sequence[str],
-        Sequence[bool],
-        Sequence[int],
-        Sequence[float],
-    ]
+    AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]


 __all__ = 'SpanNode', 'SpanTree', 'SpanQuery'
@@ -87,7 +78,7 @@ class SpanQuery(TypedDict, total=False):
     no_ancestor_has: SpanQuery


-@dataclass(repr=False)
+@dataclass(repr=False, kw_only=True)
 class SpanNode:
     """A node in the span tree; provides references to parents/children for easy traversal and queries."""

@@ -435,7 +426,7 @@ class SpanNode:
 SpanPredicate = Callable[[SpanNode], bool]


-@dataclass(repr=False)
+@dataclass(repr=False, kw_only=True)
 class SpanTree:
     """A container that builds a hierarchy of SpanNode objects from a list of finished spans.

{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/__init__.py

@@ -1,10 +1,10 @@
 from __future__ import annotations as _annotations

 from collections import defaultdict
-from collections.abc import Mapping
-from dataclasses import dataclass
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass, field
 from io import StringIO
-from typing import Any,
+from typing import Any, Generic, Literal, Protocol, cast

 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -27,12 +27,16 @@ __all__ = (
     'EvaluationReportAdapter',
     'ReportCase',
     'ReportCaseAdapter',
+    'ReportCaseFailure',
+    'ReportCaseFailureAdapter',
     'EvaluationRenderer',
     'RenderValueConfig',
     'RenderNumberConfig',
     'ReportCaseAggregate',
 )

+from ..evaluators.evaluator import EvaluatorFailure
+
 MISSING_VALUE_STR = '[i]<missing>[/i]'
 EMPTY_CELL_STR = '-'
 EMPTY_AGGREGATE_CELL_STR = ''
@@ -42,7 +46,7 @@ OutputT = TypeVar('OutputT', default=Any)
 MetadataT = TypeVar('MetadataT', default=Any)


-@dataclass
+@dataclass(kw_only=True)
 class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     """A single case in an evaluation report."""

@@ -67,12 +71,40 @@ class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     task_duration: float
     total_duration: float  # includes evaluator execution time

-
-
-    span_id: str | None
+    trace_id: str | None = None
+    """The trace ID of the case span."""
+    span_id: str | None = None
+    """The span ID of the case span."""
+
+    evaluator_failures: list[EvaluatorFailure] = field(default_factory=list)
+
+
+@dataclass(kw_only=True)
+class ReportCaseFailure(Generic[InputsT, OutputT, MetadataT]):
+    """A single case in an evaluation report that failed due to an error during task execution."""
+
+    name: str
+    """The name of the [case][pydantic_evals.Case]."""
+    inputs: InputsT
+    """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
+    metadata: MetadataT | None
+    """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
+    expected_output: OutputT | None
+    """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
+
+    error_message: str
+    """The message of the exception that caused the failure."""
+    error_stacktrace: str
+    """The stacktrace of the exception that caused the failure."""
+
+    trace_id: str | None = None
+    """The trace ID of the case span."""
+    span_id: str | None = None
+    """The span ID of the case span."""


 ReportCaseAdapter = TypeAdapter(ReportCase[Any, Any, Any])
+ReportCaseFailureAdapter = TypeAdapter(ReportCaseFailure[Any, Any, Any])


 class ReportCaseAggregate(BaseModel):
@@ -152,7 +184,7 @@ class ReportCaseAggregate(BaseModel):
         )


-@dataclass
+@dataclass(kw_only=True)
 class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
     """A report of the results of evaluating a model on a set of cases."""

@@ -161,15 +193,18 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):

     cases: list[ReportCase[InputsT, OutputT, MetadataT]]
     """The cases in the report."""
-
-
-    """The span ID of the evaluation."""
+    failures: list[ReportCaseFailure[InputsT, OutputT, MetadataT]] = field(default_factory=list)
+    """The failures in the report. These are cases where task execution raised an exception."""

     trace_id: str | None = None
     """The trace ID of the evaluation."""
+    span_id: str | None = None
+    """The span ID of the evaluation."""

-    def averages(self) -> ReportCaseAggregate:
-
+    def averages(self) -> ReportCaseAggregate | None:
+        if self.cases:
+            return ReportCaseAggregate.average(self.cases)
+        return None

     def print(
         self,
@@ -184,6 +219,9 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
+        include_errors: bool = True,
+        include_error_stacktrace: bool = False,
+        include_evaluator_failures: bool = True,
         input_config: RenderValueConfig | None = None,
         metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
@@ -207,6 +245,7 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
+            include_evaluator_failures=include_evaluator_failures,
             input_config=input_config,
             metadata_config=metadata_config,
             output_config=output_config,
@@ -216,7 +255,19 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             duration_config=duration_config,
             include_reasons=include_reasons,
         )
-        Console(width=width)
+        console = Console(width=width)
+        console.print(table)
+        if include_errors and self.failures:
+            failures_table = self.failures_table(
+                include_input=include_input,
+                include_metadata=include_metadata,
+                include_expected_output=include_expected_output,
+                include_error_message=True,
+                include_error_stacktrace=include_error_stacktrace,
+                input_config=input_config,
+                metadata_config=metadata_config,
+            )
+            console.print(failures_table, style='red')

     def console_table(
         self,
@@ -230,6 +281,7 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
+        include_evaluator_failures: bool = True,
         input_config: RenderValueConfig | None = None,
         metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
@@ -252,6 +304,9 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
+            include_error_message=False,
+            include_error_stacktrace=False,
+            include_evaluator_failures=include_evaluator_failures,
             input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
             metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})},
             output_config=output_config or _DEFAULT_VALUE_CONFIG,
@@ -266,6 +321,41 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         else:  # pragma: no cover
             return renderer.build_diff_table(self, baseline)

+    def failures_table(
+        self,
+        *,
+        include_input: bool = False,
+        include_metadata: bool = False,
+        include_expected_output: bool = False,
+        include_error_message: bool = True,
+        include_error_stacktrace: bool = True,
+        input_config: RenderValueConfig | None = None,
+        metadata_config: RenderValueConfig | None = None,
+    ) -> Table:
+        """Return a table containing the failures in this report."""
+        renderer = EvaluationRenderer(
+            include_input=include_input,
+            include_metadata=include_metadata,
+            include_expected_output=include_expected_output,
+            include_output=False,
+            include_durations=False,
+            include_total_duration=False,
+            include_removed_cases=False,
+            include_averages=False,
+            input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
+            metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})},
+            output_config=_DEFAULT_VALUE_CONFIG,
+            score_configs={},
+            label_configs={},
+            metric_configs={},
+            duration_config=_DEFAULT_DURATION_CONFIG,
+            include_reasons=False,
+            include_error_message=include_error_message,
+            include_error_stacktrace=include_error_stacktrace,
+            include_evaluator_failures=False,  # Not applicable for failures table
+        )
+        return renderer.build_failures_table(self)
+
     def __str__(self) -> str:  # pragma: lax no cover
         """Return a string representation of the report."""
         table = self.console_table()
@@ -286,7 +376,7 @@ class RenderValueConfig(TypedDict, total=False):
     diff_style: str


-@dataclass
+@dataclass(kw_only=True)
 class _ValueRenderer:
     value_formatter: str | Callable[[Any], str] = '{}'
     diff_checker: Callable[[Any, Any], bool] | None = lambda x, y: x != y
@@ -401,7 +491,7 @@ class RenderNumberConfig(TypedDict, total=False):
     """


-@dataclass
+@dataclass(kw_only=True)
 class _NumberRenderer:
     """See documentation of `RenderNumberConfig` for more details about the parameters here."""

@@ -503,7 +593,7 @@ class _NumberRenderer:
             return None

         diff = new - old
-        if abs(diff) < self.diff_atol + self.diff_rtol * abs(old):
+        if abs(diff) < self.diff_atol + self.diff_rtol * abs(old):
             return None
         return self.diff_increase_style if diff > 0 else self.diff_decrease_style

@@ -532,7 +622,7 @@ _DEFAULT_DURATION_CONFIG = RenderNumberConfig(
 T = TypeVar('T')


-@dataclass
+@dataclass(kw_only=True)
 class ReportCaseRenderer:
     include_input: bool
     include_metadata: bool
@@ -545,6 +635,9 @@ class ReportCaseRenderer:
     include_reasons: bool
     include_durations: bool
     include_total_duration: bool
+    include_error_message: bool
+    include_error_stacktrace: bool
+    include_evaluator_failures: bool

     input_renderer: _ValueRenderer
     metadata_renderer: _ValueRenderer
@@ -574,10 +667,28 @@
             table.add_column('Metrics', overflow='fold')
         if self.include_assertions:
             table.add_column('Assertions', overflow='fold')
+        if self.include_evaluator_failures:
+            table.add_column('Evaluator Failures', overflow='fold')
         if self.include_durations:
             table.add_column('Durations' if self.include_total_duration else 'Duration', justify='right')
         return table

+    def build_failures_table(self, title: str) -> Table:
+        """Build and return a Rich Table for the failures output."""
+        table = Table(title=title, show_lines=True)
+        table.add_column('Case ID', style='bold')
+        if self.include_input:
+            table.add_column('Inputs', overflow='fold')
+        if self.include_metadata:
+            table.add_column('Metadata', overflow='fold')
+        if self.include_expected_output:
+            table.add_column('Expected Output', overflow='fold')
+        if self.include_error_message:
+            table.add_column('Error Message', overflow='fold')
+        if self.include_error_stacktrace:
+            table.add_column('Error Stacktrace', overflow='fold')
+        return table
+
     def build_row(self, case: ReportCase) -> list[str]:
         """Build a table row for a single case."""
         row = [case.name]
@@ -606,6 +717,9 @@
         if self.include_assertions:
             row.append(self._render_assertions(list(case.assertions.values())))

+        if self.include_evaluator_failures:
+            row.append(self._render_evaluator_failures(case.evaluator_failures))
+
         if self.include_durations:
             row.append(self._render_durations(case))

@@ -639,6 +753,9 @@
         if self.include_assertions:
             row.append(self._render_aggregate_assertions(aggregate.assertions))

+        if self.include_evaluator_failures:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_durations:
             row.append(self._render_durations(aggregate))

@@ -700,6 +817,12 @@
             )
             row.append(assertions_diff)

+        if self.include_evaluator_failures:  # pragma: no branch
+            evaluator_failures_diff = self._render_evaluator_failures_diff(
+                baseline.evaluator_failures, new_case.evaluator_failures
+            )
+            row.append(evaluator_failures_diff)
+
         if self.include_durations:  # pragma: no branch
             durations_diff = self._render_durations_diff(baseline, new_case)
             row.append(durations_diff)
@@ -743,12 +866,36 @@
             assertions_diff = self._render_aggregate_assertions_diff(baseline.assertions, new.assertions)
             row.append(assertions_diff)

+        if self.include_evaluator_failures:  # pragma: no branch
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_durations:  # pragma: no branch
             durations_diff = self._render_durations_diff(baseline, new)
             row.append(durations_diff)

         return row

+    def build_failure_row(self, case: ReportCaseFailure) -> list[str]:
+        """Build a table row for a single case failure."""
+        row = [case.name]
+
+        if self.include_input:
+            row.append(self.input_renderer.render_value(None, case.inputs) or EMPTY_CELL_STR)
+
+        if self.include_metadata:
+            row.append(self.metadata_renderer.render_value(None, case.metadata) or EMPTY_CELL_STR)
+
+        if self.include_expected_output:
+            row.append(self.output_renderer.render_value(None, case.expected_output) or EMPTY_CELL_STR)
+
+        if self.include_error_message:
+            row.append(case.error_message or EMPTY_CELL_STR)
+
+        if self.include_error_stacktrace:
+            row.append(case.error_stacktrace or EMPTY_CELL_STR)
+
+        return row
+
     def _render_durations(self, case: ReportCase | ReportCaseAggregate) -> str:
         """Build the diff string for a duration value."""
         case_durations: dict[str, float] = {'task': case.task_duration}
@@ -862,8 +1009,33 @@
         rendered_new = default_render_percentage(new) + ' [green]✔[/]' if new is not None else EMPTY_CELL_STR
         return rendered_new if rendered_baseline == rendered_new else f'{rendered_baseline} → {rendered_new}'

+    def _render_evaluator_failures(
+        self,
+        failures: list[EvaluatorFailure],
+    ) -> str:
+        if not failures:
+            return EMPTY_CELL_STR  # pragma: no cover
+        lines: list[str] = []
+        for failure in failures:
+            line = f'[red]{failure.name}[/]'
+            if failure.error_message:
+                line += f': {failure.error_message}'
+            lines.append(line)
+        return '\n'.join(lines)
+
+    def _render_evaluator_failures_diff(
+        self,
+        baseline_failures: list[EvaluatorFailure],
+        new_failures: list[EvaluatorFailure],
+    ) -> str:
+        baseline_str = self._render_evaluator_failures(baseline_failures)
+        new_str = self._render_evaluator_failures(new_failures)
+        if baseline_str == new_str:
+            return baseline_str  # pragma: no cover
+        return f'{baseline_str}\n→\n{new_str}'
+

-@dataclass
+@dataclass(kw_only=True)
 class EvaluationRenderer:
     """A class for rendering an EvalReport or the diff between two EvalReports."""

@@ -887,10 +1059,13 @@ class EvaluationRenderer:
     metric_configs: dict[str, RenderNumberConfig]
     duration_config: RenderNumberConfig

-    # TODO: Make this class kw-only so we can reorder the kwargs
     # Data to include
     include_reasons: bool  # only applies to reports, not to diffs

+    include_error_message: bool
+    include_error_stacktrace: bool
+    include_evaluator_failures: bool
+
     def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
         return any(case.scores for case in self._all_cases(report, baseline))

@@ -903,6 +1078,11 @@
     def include_assertions(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
         return any(case.assertions for case in self._all_cases(report, baseline))

+    def include_evaluator_failures_column(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
+        return self.include_evaluator_failures and any(
+            case.evaluator_failures for case in self._all_cases(report, baseline)
+        )
+
     def _all_cases(self, report: EvaluationReport, baseline: EvaluationReport | None) -> list[ReportCase]:
         if not baseline:
             return report.cases
@@ -940,6 +1120,9 @@
             include_reasons=self.include_reasons,
             include_durations=self.include_durations,
             include_total_duration=self.include_total_duration,
+            include_error_message=self.include_error_message,
+            include_error_stacktrace=self.include_error_stacktrace,
+            include_evaluator_failures=self.include_evaluator_failures_column(report, baseline),
             input_renderer=input_renderer,
             metadata_renderer=metadata_renderer,
             output_renderer=output_renderer,
@@ -957,7 +1140,9 @@

         if self.include_averages:  # pragma: no branch
             average = report.averages()
-
+            if average:  # pragma: no branch
+                table.add_row(*case_renderer.build_aggregate_row(average))
+
         return table

     def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport) -> Table:
@@ -1004,6 +1189,14 @@

         return table

+    def build_failures_table(self, report: EvaluationReport) -> Table:
+        case_renderer = self._get_case_renderer(report)
+        table = case_renderer.build_failures_table('Case Failures')
+        for case in report.failures:
+            table.add_row(*case_renderer.build_failure_row(case))
+
+        return table
+
     def _infer_score_renderers(
         self, report: EvaluationReport, baseline: EvaluationReport | None
     ) -> dict[str, _NumberRenderer]:

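The reporting changes are easiest to see from the consumer side: a report now carries `failures` alongside `cases`, `averages()` returns `None` when there are no successful cases, and `failures_table()` renders the failed cases. A minimal sketch, constructing the report by hand with made-up values:

from rich.console import Console

from pydantic_evals.reporting import EvaluationReport, ReportCaseFailure

report = EvaluationReport(
    name='demo',
    cases=[],  # no successful cases in this made-up example
    failures=[
        ReportCaseFailure(
            name='case 1',
            inputs={'question': 'why?'},
            metadata=None,
            expected_output=None,
            error_message='TimeoutError: task took too long',
            error_stacktrace='Traceback (most recent call last): ...',
        )
    ],
)

# averages() is now Optional: None when there are no successful cases to aggregate.
assert report.averages() is None

# failures_table() returns a rich Table; report.print(include_errors=True) does this automatically.
Console().print(report.failures_table(include_input=True), style='red')
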
{pydantic_evals-0.8.1 → pydantic_evals-1.0.0b1}/pyproject.toml

@@ -28,7 +28,6 @@ classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -44,7 +43,7 @@ classifiers = [
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Topic :: Internet",
 ]
-requires-python = ">=3.9"
+requires-python = ">=3.10"

 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [