pydantic-evals 0.8.0__tar.gz → 1.0.0b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/PKG-INFO +3 -4
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/_utils.py +2 -2
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/dataset.py +147 -106
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/__init__.py +3 -1
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/_run_evaluator.py +47 -9
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/context.py +1 -1
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/evaluator.py +15 -4
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/llm_as_a_judge.py +3 -3
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/spec.py +3 -3
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/span_tree.py +5 -14
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/__init__.py +214 -21
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pyproject.toml +1 -2
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/.gitignore +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/LICENSE +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/README.md +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/__init__.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/common.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/generation.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/__init__.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_context_in_memory_span_exporter.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_context_subtree.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/_errors.py +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/py.typed +0 -0
- {pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/render_numbers.py +0 -0
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-evals
-Version: 0.8.0
+Version: 1.0.0b1
 Summary: Framework for evaluating stochastic code execution, especially code making use of LLMs
 Project-URL: Homepage, https://ai.pydantic.dev/evals
 Project-URL: Source, https://github.com/pydantic/pydantic-ai
@@ -21,18 +21,17 @@ Classifier: Operating System :: Unix
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3 :: Only
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Internet
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.9
+Requires-Python: >=3.10
 Requires-Dist: anyio>=0
 Requires-Dist: eval-type-backport>=0; python_version < '3.11'
 Requires-Dist: logfire-api>=3.14.1
-Requires-Dist: pydantic-ai-slim==0.8.0
+Requires-Dist: pydantic-ai-slim==1.0.0b1
 Requires-Dist: pydantic>=2.10
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: rich>=13.9.4
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/_utils.py

@@ -2,9 +2,9 @@ from __future__ import annotations as _annotations

 import asyncio
 import inspect
-from collections.abc import Awaitable, Sequence
+from collections.abc import Awaitable, Callable, Sequence
 from functools import partial
-from typing import Any, Callable, TypeVar
+from typing import Any, TypeVar

 import anyio
 from typing_extensions import ParamSpec, TypeIs
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/dataset.py

@@ -13,14 +13,15 @@ import functools
 import inspect
 import sys
 import time
+import traceback
 import warnings
-from collections.abc import Awaitable, Mapping, Sequence
+from collections.abc import Awaitable, Callable, Mapping, Sequence
 from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from pathlib import Path
-from typing import
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

 import anyio
 import logfire_api
@@ -40,26 +41,20 @@ from .evaluators import EvaluationResult, Evaluator
 from .evaluators._run_evaluator import run_evaluator
 from .evaluators.common import DEFAULT_EVALUATORS
 from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
 from .evaluators.spec import EvaluatorSpec
 from .otel import SpanTree
 from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from pydantic_ai.retries import RetryConfig

 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
 else:
     ExceptionGroup = ExceptionGroup  # pragma: lax no cover

-# while waiting for https://github.com/pydantic/logfire/issues/745
-try:
-    import logfire._internal.stack_info
-except ImportError:
-    pass
-else:
-    from pathlib import Path
-
-    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)  # pyright: ignore[reportPrivateImportUsage]
-
 __all__ = (
     'Case',
     'Dataset',
@@ -84,6 +79,7 @@ _YAML_SCHEMA_LINE_PREFIX = '# yaml-language-server: $schema='


 _REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
 _REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


@@ -171,11 +167,6 @@ class Case(Generic[InputsT, OutputT, MetadataT]):
         self.evaluators = list(evaluators)


-# TODO: Consider making one or more of the following changes to this type:
-# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
-# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
-# * Rename to `Evaluation`
-# TODO: Allow `task` to be sync _or_ async
 class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
     """A dataset of test [cases][pydantic_evals.Case].

@@ -263,6 +254,8 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
+        retry_task: RetryConfig | None = None,
+        retry_evaluators: RetryConfig | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -277,6 +270,8 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
             progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
+            retry_task: Optional retry configuration for the task execution.
+            retry_evaluators: Optional retry configuration for evaluator execution.

         Returns:
             A report containing the results of the evaluation.
@@ -287,12 +282,17 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a

         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

-        with
+        with (
+            _logfire.span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
+            progress_bar or nullcontext(),
+        ):
             task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None

             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:
-                    result = await _run_task_and_evaluators(
+                    result = await _run_task_and_evaluators(
+                        task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
+                    )
                     if progress_bar and task_id is not None:  # pragma: no branch
                         progress_bar.update(task_id, advance=1)
                     return result
@@ -303,21 +303,28 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
             else:
                 trace_id = f'{context.trace_id:032x}'
                 span_id = f'{context.span_id:016x}'
+            cases_and_failures = await task_group_gather(
+                [
+                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                    for i, case in enumerate(self.cases, 1)
+                ]
+            )
+            cases: list[ReportCase] = []
+            failures: list[ReportCaseFailure] = []
+            for item in cases_and_failures:
+                if isinstance(item, ReportCase):
+                    cases.append(item)
+                else:
+                    failures.append(item)
             report = EvaluationReport(
                 name=name,
-                cases=await task_group_gather(
-                    [
-                        lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                        for i, case in enumerate(self.cases, 1)
-                    ]
-                ),
+                cases=cases,
+                failures=failures,
                 span_id=span_id,
                 trace_id=trace_id,
             )
-
-
-            # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            if (averages := report.averages()) is not None and averages.assertions is not None:
+                eval_span.set_attribute('assertion_pass_rate', averages.assertions)
             return report

     def evaluate_sync(
@@ -644,7 +651,7 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any:
             td = TypedDict(f'{cls_name_prefix}_{name}', fields)  # pyright: ignore[reportArgumentType]
             config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
-            # TODO: Replace with pydantic.with_config
+            # TODO: Replace with pydantic.with_config once pydantic 2.11 is the min supported version
             td.__pydantic_config__ = config  # pyright: ignore[reportAttributeAccessIssue]
             return td

@@ -745,7 +752,7 @@ class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', a
         See <https://github.com/json-schema-org/json-schema-spec/issues/828> for context, that seems to be the nearest
         there is to a spec for this.
         """
-        context = cast(
+        context = cast(dict[str, Any] | None, info.context)
         if isinstance(context, dict) and (schema := context.get('$schema')):
             return {'$schema': schema} | nxt(self)
         else:
@@ -825,13 +832,16 @@ class _TaskRun:


 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: RetryConfig | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.

     Args:
         task: The task to run.
         case: The case to run the task on.
+        retry: The retry config to use.

     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
@@ -839,38 +849,48 @@ async def _run_task(
     Raises:
         Exception: Any exception raised by the task.
     """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-
-
-
-
-
-
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                 t0 = time.perf_counter()
                 if iscoroutinefunction(task):
-
+                    task_output_ = cast(OutputT, await task(case.inputs))
                 else:
-
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
-
-
+            duration_ = _get_span_duration(task_span, fallback_duration)
+            return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    if retry:
+        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+        from pydantic_ai.retries import retry as tenacity_retry
+
+        _run_once = tenacity_retry(**retry)(_run_once)
+
+    task_run, task_output, duration, span_tree = await _run_once()

     if isinstance(span_tree, SpanTree):  # pragma: no branch
-        #
-        #
-        #
-        #
-        # users. Maybe via an argument of type Callable[[SpanTree], dict[str, int | float]] or similar?
+        # Idea for making this more configurable: replace the following logic with a call to a user-provided function
+        # of type Callable[[_TaskRun, SpanTree], None] or similar, (maybe no _TaskRun and just use the public APIs).
+        # That way users can customize this logic. We'd default to a function that does the current thing but also
+        # allow `None` to disable it entirely.
         for node in span_tree:
             if node.attributes.get('gen_ai.operation.name') == 'chat':
                 task_run.increment_metric('requests', 1)
             for k, v in node.attributes.items():
-                if not isinstance(v,
+                if not isinstance(v, int | float):
                     continue
-                # TODO: Revisit this choice to strip the prefix..
                 if k.startswith('gen_ai.usage.details.'):
                     task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
                 elif k.startswith('gen_ai.usage.'):
@@ -882,7 +902,7 @@ async def _run_task(
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=task_output,
-        duration=
+        duration=duration,
         _span_tree=span_tree,
         attributes=task_run.attributes,
         metrics=task_run.metrics,
@@ -894,7 +914,9 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-
+    retry_task: RetryConfig | None,
+    retry_evaluators: RetryConfig | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

     Args:
@@ -902,64 +924,83 @@ async def _run_task_and_evaluators(
         case: The case to run the task on.
         report_case_name: The name to use for this case in the report.
         dataset_evaluators: Evaluators from the dataset to apply to this case.
+        retry_task: The retry config to use for running the task.
+        retry_evaluators: The retry config to use for running the evaluators.

     Returns:
         A ReportCase containing the evaluation results.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+    trace_id: str | None = None
+    span_id: str | None = None
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no branch
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'

-
-
-
-
-
-
-
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry_task)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
+                )
+                for outputs in evaluator_outputs_by_task:
+                    if isinstance(outputs, EvaluatorFailure):
+                        evaluator_failures.append(outputs)
+                    else:
+                        evaluator_outputs.extend(outputs)
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
            fallback_duration = time.time() - t0

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return ReportCase[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            output=scoring_context.output,
+            metrics=scoring_context.metrics,
+            attributes=scoring_context.attributes,
+            scores=scores,
+            labels=labels,
+            assertions=assertions,
+            task_duration=scoring_context.duration,
+            total_duration=_get_span_duration(case_span, fallback_duration),
+            trace_id=trace_id,
+            span_id=span_id,
+            evaluator_failures=evaluator_failures,
+        )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_message=f'{type(exc).__name__}: {exc}',
+            error_stacktrace=traceback.format_exc(),
+            trace_id=trace_id,
+            span_id=span_id,
+        )


 _evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
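The new `retry_task` / `retry_evaluators` parameters are unpacked straight into the `retry` decorator re-exported by `pydantic_ai.retries` (see the `tenacity_retry(**retry)` calls above), so a `RetryConfig` is effectively a mapping of tenacity keyword arguments. A minimal sketch of how this might be used, assuming tenacity is installed and that `evaluate_sync` forwards the same keyword arguments as `evaluate`; the task and dataset are illustrative, not part of this package:

    from tenacity import stop_after_attempt, wait_exponential

    from pydantic_evals import Case, Dataset


    async def flaky_task(question: str) -> str:
        # Stand-in for an LLM call that occasionally raises a transient error.
        return question.upper()


    dataset = Dataset(cases=[Case(name='smoke', inputs='hello', expected_output='HELLO')])

    report = dataset.evaluate_sync(
        flaky_task,
        # Assumed shape: tenacity keyword arguments, matching the tenacity_retry(**retry) call in the diff.
        retry_task={'stop': stop_after_attempt(3), 'wait': wait_exponential(max=10)},
        retry_evaluators={'stop': stop_after_attempt(2)},
    )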
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/__init__.py

@@ -10,7 +10,7 @@ from .common import (
     Python,
 )
 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorOutput, EvaluatorSpec
+from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorFailure, EvaluatorOutput, EvaluatorSpec

 __all__ = (
     # common
@@ -27,6 +27,8 @@ __all__ = (
     'EvaluatorContext',
     # evaluator
     'Evaluator',
+    'EvaluationReason',
+    'EvaluatorFailure',
     'EvaluatorOutput',
     'EvaluatorSpec',
     'EvaluationReason',
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/_run_evaluator.py

@@ -1,8 +1,11 @@
 from __future__ import annotations

+import traceback
 from collections.abc import Mapping
-from
+from pathlib import Path
+from typing import TYPE_CHECKING, Any

+import logfire_api
 from pydantic import (
     TypeAdapter,
     ValidationError,
@@ -10,7 +13,20 @@ from pydantic import (
 from typing_extensions import TypeVar

 from .context import EvaluatorContext
-from .evaluator import
+from .evaluator import (
+    EvaluationReason,
+    EvaluationResult,
+    EvaluationScalar,
+    Evaluator,
+    EvaluatorFailure,
+    EvaluatorOutput,
+)
+
+if TYPE_CHECKING:
+    from pydantic_ai.retries import RetryConfig
+
+_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')
+logfire_api.add_non_user_code_prefix(Path(__file__).parent.absolute())

 InputsT = TypeVar('InputsT', default=Any, contravariant=True)
 OutputT = TypeVar('OutputT', default=Any, contravariant=True)
@@ -18,8 +34,10 @@ MetadataT = TypeVar('MetadataT', default=Any, contravariant=True)


 async def run_evaluator(
-    evaluator: Evaluator[InputsT, OutputT, MetadataT],
-
+    evaluator: Evaluator[InputsT, OutputT, MetadataT],
+    ctx: EvaluatorContext[InputsT, OutputT, MetadataT],
+    retry: RetryConfig | None = None,
+) -> list[EvaluationResult] | EvaluatorFailure:
     """Run an evaluator and return the results.

     This function runs an evaluator on the given context and processes the results into
@@ -28,19 +46,39 @@ async def run_evaluator(
     Args:
         evaluator: The evaluator to run.
         ctx: The context containing the inputs, outputs, and metadata for evaluation.
+        retry: The retry configuration to use for running the evaluator.

     Returns:
-        A list of evaluation results.
+        A list of evaluation results, or an evaluator failure if an exception is raised during its execution.

     Raises:
         ValueError: If the evaluator returns a value of an invalid type.
     """
-
+    evaluate = evaluator.evaluate_async
+    if retry is not None:
+        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+        from pydantic_ai.retries import retry as tenacity_retry
+
+        evaluate = tenacity_retry(**retry)(evaluate)

     try:
-
-
-
+        with _logfire.span(
+            'evaluator: {evaluator_name}',
+            evaluator_name=evaluator.get_default_evaluation_name(),
+        ):
+            raw_results = await evaluate(ctx)
+
+        try:
+            results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)
+        except ValidationError as e:
+            raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e
+    except Exception as e:
+        return EvaluatorFailure(
+            name=evaluator.get_default_evaluation_name(),
+            error_message=f'{type(e).__name__}: {e}',
+            error_stacktrace=traceback.format_exc(),
+            source=evaluator.as_spec(),
+        )

     results = _convert_to_mapping(results, scalar_name=evaluator.get_default_evaluation_name())

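With this change `run_evaluator` returns an `EvaluatorFailure` instead of propagating evaluator exceptions, and `dataset.py` attaches those failures to the corresponding report case. A hedged sketch of what that looks like from the outside; the evaluator and task below are made up for illustration:

    from dataclasses import dataclass

    from pydantic_evals import Case, Dataset
    from pydantic_evals.evaluators import Evaluator, EvaluatorContext


    @dataclass
    class AlwaysBroken(Evaluator):
        def evaluate(self, ctx: EvaluatorContext) -> bool:
            raise RuntimeError('boom')


    async def task(x: int) -> int:
        return x + 1


    dataset = Dataset(cases=[Case(name='one', inputs=1)], evaluators=[AlwaysBroken()])
    report = dataset.evaluate_sync(task)

    # The run no longer aborts; the failure is recorded on the report case instead.
    for case in report.cases:
        for failure in case.evaluator_failures:
            print(failure.name, failure.error_message)  # e.g. "AlwaysBroken RuntimeError: boom"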
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/context.py

@@ -27,7 +27,7 @@ MetadataT = TypeVar('MetadataT', default=Any, covariant=True)
 """Type variable for the metadata associated with the task being evaluated."""


-@dataclass
+@dataclass(kw_only=True)
 class EvaluatorContext(Generic[InputsT, OutputT, MetadataT]):
     """Context for evaluating a task execution.

{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/evaluator.py

@@ -4,7 +4,7 @@ import inspect
 from abc import ABCMeta, abstractmethod
 from collections.abc import Awaitable, Mapping
 from dataclasses import MISSING, dataclass, fields
-from typing import Any, Generic,
+from typing import Any, Generic, cast

 from pydantic import (
     ConfigDict,
@@ -25,11 +25,12 @@ __all__ = (
     'EvaluationResult',
     'EvaluationScalar',
     'Evaluator',
+    'EvaluatorFailure',
     'EvaluatorOutput',
     'EvaluatorSpec',
 )

-EvaluationScalar =
+EvaluationScalar = bool | int | float | str
 """The most primitive output allowed as an output from an Evaluator.

 `int` and `float` are treated as scores, `str` as labels, and `bool` as assertions.
@@ -51,11 +52,11 @@ class EvaluationReason:
     reason: str | None = None


-EvaluatorOutput =
+EvaluatorOutput = EvaluationScalar | EvaluationReason | Mapping[str, EvaluationScalar | EvaluationReason]
 """Type for the output of an evaluator, which can be a scalar, an EvaluationReason, or a mapping of names to either."""


-# TODO(DavidM): Add bound=EvaluationScalar to the following typevar
+# TODO(DavidM): Add bound=EvaluationScalar to the following typevar once pydantic 2.11 is the min supported version
 EvaluationScalarT = TypeVar('EvaluationScalarT', default=EvaluationScalar, covariant=True)
 """Type variable for the scalar result type of an evaluation."""

@@ -100,6 +101,16 @@ class EvaluationResult(Generic[EvaluationScalarT]):
         return None


+@dataclass
+class EvaluatorFailure:
+    """Represents a failure raised during the execution of an evaluator."""
+
+    name: str
+    error_message: str
+    error_stacktrace: str
+    source: EvaluatorSpec
+
+
 # Evaluators are contravariant in all of its parameters.
 InputsT = TypeVar('InputsT', default=Any, contravariant=True)
 """Type variable for the inputs type of the task being evaluated."""
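The `EvaluationScalar` / `EvaluatorOutput` unions above keep their meaning (only rewritten in PEP 604 syntax): an evaluator may return a bare scalar, an `EvaluationReason`, or a mapping of names to either, with `bool` treated as an assertion, `int`/`float` as scores, and `str` as labels. An illustrative evaluator returning the mapping form; the class and threshold below are made up:

    from dataclasses import dataclass

    from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext


    @dataclass
    class LengthCheck(Evaluator):
        max_chars: int = 280

        def evaluate(self, ctx: EvaluatorContext) -> dict[str, bool | float | EvaluationReason]:
            text = str(ctx.output)
            return {
                'within_limit': len(text) <= self.max_chars,  # bool -> assertion
                'length_ratio': len(text) / self.max_chars,  # float -> score
                'verdict': EvaluationReason(
                    value='ok' if len(text) <= self.max_chars else 'too_long',  # str -> label
                    reason=f'{len(text)} characters against a limit of {self.max_chars}',
                ),
            }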
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/llm_as_a_judge.py

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
 from pydantic_core import to_json

 from pydantic_ai import Agent, models
-from pydantic_ai.messages import
+from pydantic_ai.messages import MultiModalContent, UserContent
 from pydantic_ai.settings import ModelSettings

 __all__ = (
@@ -238,11 +238,11 @@ def _build_prompt(
     sections.append('<Input>\n')
     if isinstance(inputs, Sequence):
         for item in inputs:  # type: ignore
-            if isinstance(item,
+            if isinstance(item, str | MultiModalContent):
                 sections.append(item)
             else:
                 sections.append(_stringify(item))
-    elif isinstance(inputs,
+    elif isinstance(inputs, MultiModalContent):
         sections.append(inputs)
     else:
         sections.append(_stringify(inputs))
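The `_build_prompt` change above lets multimodal user content pass through to the judge prompt unchanged. A hedged sketch of what that enables, assuming `LLMJudge` keeps its existing `rubric`/`include_input` parameters; the image URL, rubric, and case are placeholders:

    from pydantic_ai import ImageUrl

    from pydantic_evals import Case, Dataset
    from pydantic_evals.evaluators import LLMJudge

    dataset = Dataset(
        cases=[
            Case(
                name='caption',
                # A mixed sequence of text and multimodal content, as handled by the isinstance checks above.
                inputs=[ImageUrl(url='https://example.com/cat.png'), 'Describe the image'],
                expected_output='A cat sitting on a windowsill.',
            )
        ],
        evaluators=[LLMJudge(rubric='The caption accurately describes the image', include_input=True)],
    )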
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/evaluators/spec.py

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, cast

 from pydantic import (
     BaseModel,
@@ -17,7 +17,7 @@ from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapH
 if TYPE_CHECKING:
     # This import seems to fail on Pydantic 2.10.1 in CI
     from pydantic import ModelWrapValidatorHandler
-    # TODO:
+    # TODO: Remove this once pydantic 2.11 is the min supported version


 class EvaluatorSpec(BaseModel):
@@ -112,7 +112,7 @@ class EvaluatorSpec(BaseModel):
         return handler(self)


-class _SerializedEvaluatorSpec(RootModel[
+class _SerializedEvaluatorSpec(RootModel[str | dict[str, Any]]):
     """Internal class for handling the serialized form of an EvaluatorSpec.

     This is an auxiliary class used to serialize/deserialize instances of EvaluatorSpec
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/otel/span_tree.py

@@ -1,12 +1,12 @@
 from __future__ import annotations

 import re
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from functools import cache
 from textwrap import indent
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any

 from pydantic import TypeAdapter
 from typing_extensions import TypedDict
@@ -16,16 +16,7 @@ if TYPE_CHECKING:  # pragma: no cover
     from opentelemetry.sdk.trace import ReadableSpan

     # Should match opentelemetry.util.types.AttributeValue
-    AttributeValue = Union[
-        str,
-        bool,
-        int,
-        float,
-        Sequence[str],
-        Sequence[bool],
-        Sequence[int],
-        Sequence[float],
-    ]
+    AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]


 __all__ = 'SpanNode', 'SpanTree', 'SpanQuery'
@@ -87,7 +78,7 @@ class SpanQuery(TypedDict, total=False):
     no_ancestor_has: SpanQuery


-@dataclass(repr=False)
+@dataclass(repr=False, kw_only=True)
 class SpanNode:
     """A node in the span tree; provides references to parents/children for easy traversal and queries."""

@@ -435,7 +426,7 @@ class SpanNode:
 SpanPredicate = Callable[[SpanNode], bool]


-@dataclass(repr=False)
+@dataclass(repr=False, kw_only=True)
 class SpanTree:
     """A container that builds a hierarchy of SpanNode objects from a list of finished spans.

{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pydantic_evals/reporting/__init__.py

@@ -1,10 +1,10 @@
 from __future__ import annotations as _annotations

 from collections import defaultdict
-from collections.abc import Mapping
-from dataclasses import dataclass
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass, field
 from io import StringIO
-from typing import Any,
+from typing import Any, Generic, Literal, Protocol, cast

 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -27,12 +27,16 @@ __all__ = (
     'EvaluationReportAdapter',
     'ReportCase',
     'ReportCaseAdapter',
+    'ReportCaseFailure',
+    'ReportCaseFailureAdapter',
     'EvaluationRenderer',
     'RenderValueConfig',
     'RenderNumberConfig',
     'ReportCaseAggregate',
 )

+from ..evaluators.evaluator import EvaluatorFailure
+
 MISSING_VALUE_STR = '[i]<missing>[/i]'
 EMPTY_CELL_STR = '-'
 EMPTY_AGGREGATE_CELL_STR = ''
@@ -42,7 +46,7 @@ OutputT = TypeVar('OutputT', default=Any)
 MetadataT = TypeVar('MetadataT', default=Any)


-@dataclass
+@dataclass(kw_only=True)
 class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     """A single case in an evaluation report."""

@@ -67,12 +71,40 @@ class ReportCase(Generic[InputsT, OutputT, MetadataT]):
     task_duration: float
     total_duration: float  # includes evaluator execution time

-
-
-    span_id: str | None
+    trace_id: str | None = None
+    """The trace ID of the case span."""
+    span_id: str | None = None
+    """The span ID of the case span."""
+
+    evaluator_failures: list[EvaluatorFailure] = field(default_factory=list)
+
+
+@dataclass(kw_only=True)
+class ReportCaseFailure(Generic[InputsT, OutputT, MetadataT]):
+    """A single case in an evaluation report that failed due to an error during task execution."""
+
+    name: str
+    """The name of the [case][pydantic_evals.Case]."""
+    inputs: InputsT
+    """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
+    metadata: MetadataT | None
+    """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
+    expected_output: OutputT | None
+    """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
+
+    error_message: str
+    """The message of the exception that caused the failure."""
+    error_stacktrace: str
+    """The stacktrace of the exception that caused the failure."""
+
+    trace_id: str | None = None
+    """The trace ID of the case span."""
+    span_id: str | None = None
+    """The span ID of the case span."""


 ReportCaseAdapter = TypeAdapter(ReportCase[Any, Any, Any])
+ReportCaseFailureAdapter = TypeAdapter(ReportCaseFailure[Any, Any, Any])


 class ReportCaseAggregate(BaseModel):
@@ -152,7 +184,7 @@ class ReportCaseAggregate(BaseModel):
     )


-@dataclass
+@dataclass(kw_only=True)
 class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
     """A report of the results of evaluating a model on a set of cases."""

@@ -161,15 +193,18 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):

     cases: list[ReportCase[InputsT, OutputT, MetadataT]]
     """The cases in the report."""
-
-
-    """The span ID of the evaluation."""
+    failures: list[ReportCaseFailure[InputsT, OutputT, MetadataT]] = field(default_factory=list)
+    """The failures in the report. These are cases where task execution raised an exception."""

     trace_id: str | None = None
     """The trace ID of the evaluation."""
+    span_id: str | None = None
+    """The span ID of the evaluation."""

-    def averages(self) -> ReportCaseAggregate:
-
+    def averages(self) -> ReportCaseAggregate | None:
+        if self.cases:
+            return ReportCaseAggregate.average(self.cases)
+        return None

     def print(
         self,
@@ -184,6 +219,9 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
+        include_errors: bool = True,
+        include_error_stacktrace: bool = False,
+        include_evaluator_failures: bool = True,
         input_config: RenderValueConfig | None = None,
         metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
@@ -207,6 +245,7 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
+            include_evaluator_failures=include_evaluator_failures,
             input_config=input_config,
             metadata_config=metadata_config,
             output_config=output_config,
@@ -216,7 +255,19 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             duration_config=duration_config,
             include_reasons=include_reasons,
         )
-        Console(width=width)
+        console = Console(width=width)
+        console.print(table)
+        if include_errors and self.failures:
+            failures_table = self.failures_table(
+                include_input=include_input,
+                include_metadata=include_metadata,
+                include_expected_output=include_expected_output,
+                include_error_message=True,
+                include_error_stacktrace=include_error_stacktrace,
+                input_config=input_config,
+                metadata_config=metadata_config,
+            )
+            console.print(failures_table, style='red')

     def console_table(
         self,
@@ -230,6 +281,7 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
+        include_evaluator_failures: bool = True,
         input_config: RenderValueConfig | None = None,
         metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
@@ -252,6 +304,9 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
+            include_error_message=False,
+            include_error_stacktrace=False,
+            include_evaluator_failures=include_evaluator_failures,
             input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
             metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})},
             output_config=output_config or _DEFAULT_VALUE_CONFIG,
@@ -266,6 +321,41 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
         else:  # pragma: no cover
             return renderer.build_diff_table(self, baseline)

+    def failures_table(
+        self,
+        *,
+        include_input: bool = False,
+        include_metadata: bool = False,
+        include_expected_output: bool = False,
+        include_error_message: bool = True,
+        include_error_stacktrace: bool = True,
+        input_config: RenderValueConfig | None = None,
+        metadata_config: RenderValueConfig | None = None,
+    ) -> Table:
+        """Return a table containing the failures in this report."""
+        renderer = EvaluationRenderer(
+            include_input=include_input,
+            include_metadata=include_metadata,
+            include_expected_output=include_expected_output,
+            include_output=False,
+            include_durations=False,
+            include_total_duration=False,
+            include_removed_cases=False,
+            include_averages=False,
+            input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
+            metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})},
+            output_config=_DEFAULT_VALUE_CONFIG,
+            score_configs={},
+            label_configs={},
+            metric_configs={},
+            duration_config=_DEFAULT_DURATION_CONFIG,
+            include_reasons=False,
+            include_error_message=include_error_message,
+            include_error_stacktrace=include_error_stacktrace,
+            include_evaluator_failures=False,  # Not applicable for failures table
+        )
+        return renderer.build_failures_table(self)
+
     def __str__(self) -> str:  # pragma: lax no cover
         """Return a string representation of the report."""
         table = self.console_table()
@@ -286,7 +376,7 @@ class RenderValueConfig(TypedDict, total=False):
     diff_style: str


-@dataclass
+@dataclass(kw_only=True)
 class _ValueRenderer:
     value_formatter: str | Callable[[Any], str] = '{}'
     diff_checker: Callable[[Any, Any], bool] | None = lambda x, y: x != y
@@ -401,7 +491,7 @@ class RenderNumberConfig(TypedDict, total=False):
     """


-@dataclass
+@dataclass(kw_only=True)
 class _NumberRenderer:
     """See documentation of `RenderNumberConfig` for more details about the parameters here."""

@@ -503,7 +593,7 @@ class _NumberRenderer:
             return None

         diff = new - old
-        if abs(diff) < self.diff_atol + self.diff_rtol * abs(old):
+        if abs(diff) < self.diff_atol + self.diff_rtol * abs(old):
             return None
         return self.diff_increase_style if diff > 0 else self.diff_decrease_style

@@ -532,7 +622,7 @@ _DEFAULT_DURATION_CONFIG = RenderNumberConfig(
 T = TypeVar('T')


-@dataclass
+@dataclass(kw_only=True)
 class ReportCaseRenderer:
     include_input: bool
     include_metadata: bool
@@ -545,6 +635,9 @@ class ReportCaseRenderer:
     include_reasons: bool
     include_durations: bool
     include_total_duration: bool
+    include_error_message: bool
+    include_error_stacktrace: bool
+    include_evaluator_failures: bool

     input_renderer: _ValueRenderer
     metadata_renderer: _ValueRenderer
@@ -574,10 +667,28 @@ class ReportCaseRenderer:
             table.add_column('Metrics', overflow='fold')
         if self.include_assertions:
             table.add_column('Assertions', overflow='fold')
+        if self.include_evaluator_failures:
+            table.add_column('Evaluator Failures', overflow='fold')
         if self.include_durations:
             table.add_column('Durations' if self.include_total_duration else 'Duration', justify='right')
         return table

+    def build_failures_table(self, title: str) -> Table:
+        """Build and return a Rich Table for the failures output."""
+        table = Table(title=title, show_lines=True)
+        table.add_column('Case ID', style='bold')
+        if self.include_input:
+            table.add_column('Inputs', overflow='fold')
+        if self.include_metadata:
+            table.add_column('Metadata', overflow='fold')
+        if self.include_expected_output:
+            table.add_column('Expected Output', overflow='fold')
+        if self.include_error_message:
+            table.add_column('Error Message', overflow='fold')
+        if self.include_error_stacktrace:
+            table.add_column('Error Stacktrace', overflow='fold')
+        return table
+
     def build_row(self, case: ReportCase) -> list[str]:
         """Build a table row for a single case."""
         row = [case.name]
@@ -606,6 +717,9 @@ class ReportCaseRenderer:
         if self.include_assertions:
             row.append(self._render_assertions(list(case.assertions.values())))

+        if self.include_evaluator_failures:
+            row.append(self._render_evaluator_failures(case.evaluator_failures))
+
         if self.include_durations:
             row.append(self._render_durations(case))

@@ -639,6 +753,9 @@ class ReportCaseRenderer:
         if self.include_assertions:
             row.append(self._render_aggregate_assertions(aggregate.assertions))

+        if self.include_evaluator_failures:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_durations:
             row.append(self._render_durations(aggregate))

@@ -700,6 +817,12 @@ class ReportCaseRenderer:
             )
             row.append(assertions_diff)

+        if self.include_evaluator_failures:  # pragma: no branch
+            evaluator_failures_diff = self._render_evaluator_failures_diff(
+                baseline.evaluator_failures, new_case.evaluator_failures
+            )
+            row.append(evaluator_failures_diff)
+
         if self.include_durations:  # pragma: no branch
             durations_diff = self._render_durations_diff(baseline, new_case)
             row.append(durations_diff)
@@ -743,12 +866,36 @@ class ReportCaseRenderer:
             assertions_diff = self._render_aggregate_assertions_diff(baseline.assertions, new.assertions)
             row.append(assertions_diff)

+        if self.include_evaluator_failures:  # pragma: no branch
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_durations:  # pragma: no branch
             durations_diff = self._render_durations_diff(baseline, new)
             row.append(durations_diff)

         return row

+    def build_failure_row(self, case: ReportCaseFailure) -> list[str]:
+        """Build a table row for a single case failure."""
+        row = [case.name]
+
+        if self.include_input:
+            row.append(self.input_renderer.render_value(None, case.inputs) or EMPTY_CELL_STR)
+
+        if self.include_metadata:
+            row.append(self.metadata_renderer.render_value(None, case.metadata) or EMPTY_CELL_STR)
+
+        if self.include_expected_output:
+            row.append(self.output_renderer.render_value(None, case.expected_output) or EMPTY_CELL_STR)
+
+        if self.include_error_message:
+            row.append(case.error_message or EMPTY_CELL_STR)
+
+        if self.include_error_stacktrace:
+            row.append(case.error_stacktrace or EMPTY_CELL_STR)
+
+        return row
+
     def _render_durations(self, case: ReportCase | ReportCaseAggregate) -> str:
         """Build the diff string for a duration value."""
         case_durations: dict[str, float] = {'task': case.task_duration}
@@ -862,8 +1009,33 @@ class ReportCaseRenderer:
         rendered_new = default_render_percentage(new) + ' [green]✔[/]' if new is not None else EMPTY_CELL_STR
         return rendered_new if rendered_baseline == rendered_new else f'{rendered_baseline} → {rendered_new}'

+    def _render_evaluator_failures(
+        self,
+        failures: list[EvaluatorFailure],
+    ) -> str:
+        if not failures:
+            return EMPTY_CELL_STR  # pragma: no cover
+        lines: list[str] = []
+        for failure in failures:
+            line = f'[red]{failure.name}[/]'
+            if failure.error_message:
+                line += f': {failure.error_message}'
+            lines.append(line)
+        return '\n'.join(lines)
+
+    def _render_evaluator_failures_diff(
+        self,
+        baseline_failures: list[EvaluatorFailure],
+        new_failures: list[EvaluatorFailure],
+    ) -> str:
+        baseline_str = self._render_evaluator_failures(baseline_failures)
+        new_str = self._render_evaluator_failures(new_failures)
+        if baseline_str == new_str:
+            return baseline_str  # pragma: no cover
+        return f'{baseline_str}\n→\n{new_str}'
+

-@dataclass
+@dataclass(kw_only=True)
 class EvaluationRenderer:
     """A class for rendering an EvalReport or the diff between two EvalReports."""

@@ -887,10 +1059,13 @@ class EvaluationRenderer:
     metric_configs: dict[str, RenderNumberConfig]
     duration_config: RenderNumberConfig

-    # TODO: Make this class kw-only so we can reorder the kwargs
     # Data to include
     include_reasons: bool  # only applies to reports, not to diffs

+    include_error_message: bool
+    include_error_stacktrace: bool
+    include_evaluator_failures: bool
+
     def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
         return any(case.scores for case in self._all_cases(report, baseline))

@@ -903,6 +1078,11 @@ class EvaluationRenderer:
     def include_assertions(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
         return any(case.assertions for case in self._all_cases(report, baseline))

+    def include_evaluator_failures_column(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
+        return self.include_evaluator_failures and any(
+            case.evaluator_failures for case in self._all_cases(report, baseline)
+        )
+
     def _all_cases(self, report: EvaluationReport, baseline: EvaluationReport | None) -> list[ReportCase]:
         if not baseline:
             return report.cases
@@ -940,6 +1120,9 @@ class EvaluationRenderer:
             include_reasons=self.include_reasons,
             include_durations=self.include_durations,
             include_total_duration=self.include_total_duration,
+            include_error_message=self.include_error_message,
+            include_error_stacktrace=self.include_error_stacktrace,
+            include_evaluator_failures=self.include_evaluator_failures_column(report, baseline),
             input_renderer=input_renderer,
             metadata_renderer=metadata_renderer,
             output_renderer=output_renderer,
@@ -957,7 +1140,9 @@ class EvaluationRenderer:

         if self.include_averages:  # pragma: no branch
             average = report.averages()
-
+            if average:  # pragma: no branch
+                table.add_row(*case_renderer.build_aggregate_row(average))
+
         return table

     def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport) -> Table:
@@ -1004,6 +1189,14 @@ class EvaluationRenderer:

         return table

+    def build_failures_table(self, report: EvaluationReport) -> Table:
+        case_renderer = self._get_case_renderer(report)
+        table = case_renderer.build_failures_table('Case Failures')
+        for case in report.failures:
+            table.add_row(*case_renderer.build_failure_row(case))
+
+        return table
+
     def _infer_score_renderers(
         self, report: EvaluationReport, baseline: EvaluationReport | None
     ) -> dict[str, _NumberRenderer]:
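Putting the reporting changes together: a case whose task raises now lands in report.failures as a ReportCaseFailure instead of aborting the run, and print() can append a red failures table. A small sketch with an intentionally failing task; the task and cases below are illustrative:

    from pydantic_evals import Case, Dataset


    async def brittle_task(x: int) -> int:
        if x < 0:
            raise ValueError('negative input')
        return x * 2


    dataset = Dataset(cases=[Case(name='ok', inputs=2), Case(name='bad', inputs=-1)])
    report = dataset.evaluate_sync(brittle_task)

    print(f'{len(report.cases)} succeeded, {len(report.failures)} failed')
    for failure in report.failures:
        print(failure.name, failure.error_message)  # "bad ValueError: negative input"

    # Appends the failures table (rendered in red) after the main results table.
    report.print(include_errors=True, include_error_stacktrace=True)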
{pydantic_evals-0.8.0 → pydantic_evals-1.0.0b1}/pyproject.toml

@@ -28,7 +28,6 @@ classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -44,7 +43,7 @@ classifiers = [
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Topic :: Internet",
 ]
-requires-python = ">=3.9"
+requires-python = ">=3.10"

 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [