extract_bench-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extract_bench/__init__.py +42 -0
- extract_bench/evaluation/__init__.py +42 -0
- extract_bench/evaluation/evaluation_config.py +33 -0
- extract_bench/evaluation/metric_id_collector.py +41 -0
- extract_bench/evaluation/metric_registry.py +120 -0
- extract_bench/evaluation/metrics/__init__.py +43 -0
- extract_bench/evaluation/metrics/array_metrics.py +116 -0
- extract_bench/evaluation/metrics/base_metric.py +40 -0
- extract_bench/evaluation/metrics/boolean_metrics.py +56 -0
- extract_bench/evaluation/metrics/llm_metrics.py +231 -0
- extract_bench/evaluation/metrics/metric_descriptors.py +36 -0
- extract_bench/evaluation/metrics/metric_prompts/__init__.py +16 -0
- extract_bench/evaluation/metrics/metric_prompts/array_llm.txt +35 -0
- extract_bench/evaluation/metrics/metric_prompts/llm_judge.txt +15 -0
- extract_bench/evaluation/metrics/metric_prompts/string_semantic.txt +17 -0
- extract_bench/evaluation/metrics/metric_utils.py +123 -0
- extract_bench/evaluation/metrics/number_metrics.py +148 -0
- extract_bench/evaluation/metrics/policy_metric.py +44 -0
- extract_bench/evaluation/metrics/string_metrics.py +195 -0
- extract_bench/evaluation/presets.py +109 -0
- extract_bench/evaluation/reporting/README.md +191 -0
- extract_bench/evaluation/reporting/__init__.py +47 -0
- extract_bench/evaluation/reporting/content_stats.py +160 -0
- extract_bench/evaluation/reporting/formatters.py +195 -0
- extract_bench/evaluation/reporting/models.py +181 -0
- extract_bench/evaluation/reporting/outcome_stats.py +290 -0
- extract_bench/evaluation/reporting/report_builder.py +169 -0
- extract_bench/evaluation/reporting/schema_stats.py +104 -0
- extract_bench/evaluation/schema_config_helpers.py +107 -0
- extract_bench/evaluation/schema_value_instantiator.py +213 -0
- extract_bench/evaluation/structured_evaluator.py +226 -0
- extract_bench/infra/__init__.py +60 -0
- extract_bench/infra/asyncio_utils.py +53 -0
- extract_bench/infra/construct_ast.py +110 -0
- extract_bench/infra/nodes.py +384 -0
- extract_bench/infra/ref_expander.py +43 -0
- extract_bench/infra/schema_instance_visitor.py +125 -0
- extract_bench/infra/visitors.py +452 -0
- extract_bench-0.1.0.dist-info/METADATA +342 -0
- extract_bench-0.1.0.dist-info/RECORD +42 -0
- extract_bench-0.1.0.dist-info/WHEEL +4 -0
- extract_bench-0.1.0.dist-info/licenses/LICENSE +21 -0

extract_bench/__init__.py
@@ -0,0 +1,42 @@
+"""Structured Extraction Evaluation Suite.
+
+A standalone package for evaluating structured extraction quality by comparing
+predicted JSON against gold JSON with per-field metrics.
+"""
+
+from .evaluation import (
+    AsyncEvaluationConfig,
+    BaseMetric,
+    EvaluationConfig,
+    EvaluationReport,
+    MetricConfig,
+    MetricRegistry,
+    MetricResult,
+    ReportBuilder,
+    ReportConfig,
+    StructuredEvaluator,
+    StructuredEvaluatorConfig,
+    global_metric_registry,
+)
+
+__version__ = "0.1.0"
+
+__all__ = [
+    # Main evaluator
+    "StructuredEvaluator",
+    "StructuredEvaluatorConfig",
+    "AsyncEvaluationConfig",
+    # Config
+    "EvaluationConfig",
+    "MetricConfig",
+    # Registry
+    "MetricRegistry",
+    "global_metric_registry",
+    # Metrics
+    "BaseMetric",
+    "MetricResult",
+    # Reporting
+    "ReportBuilder",
+    "ReportConfig",
+    "EvaluationReport",
+]
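
For orientation, a minimal sketch of the public surface this module re-exports, assuming the wheel is installed under the name extract_bench:

from extract_bench import __version__, global_metric_registry

print(__version__)                                 # "0.1.0"
print(global_metric_registry.available_metrics())  # ids of the default metric factories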

extract_bench/evaluation/__init__.py
@@ -0,0 +1,42 @@
+"""Structured extraction evaluation module."""
+
+from .evaluation_config import EvaluationConfig, MetricConfig
+from .metric_registry import MetricRegistry, global_metric_registry
+from .metrics.base_metric import BaseMetric, MetricResult
+from .reporting import EvaluationReport, ReportBuilder, ReportConfig
+from .schema_config_helpers import (
+    add_evaluation_configs_to_export,
+    get_default_evaluation_config,
+    get_evaluation_config,
+    should_evaluate,
+)
+from .structured_evaluator import (
+    AsyncEvaluationConfig,
+    StructuredEvaluator,
+    StructuredEvaluatorConfig,
+)
+
+__all__ = [
+    # Config
+    "EvaluationConfig",
+    "MetricConfig",
+    # Registry
+    "MetricRegistry",
+    "global_metric_registry",
+    # Metrics
+    "BaseMetric",
+    "MetricResult",
+    # Helpers
+    "get_evaluation_config",
+    "get_default_evaluation_config",
+    "should_evaluate",
+    "add_evaluation_configs_to_export",
+    # Evaluator
+    "AsyncEvaluationConfig",
+    "StructuredEvaluator",
+    "StructuredEvaluatorConfig",
+    # Reporting
+    "ReportBuilder",
+    "ReportConfig",
+    "EvaluationReport",
+]

extract_bench/evaluation/evaluation_config.py
@@ -0,0 +1,33 @@
+"""Evaluation config models used to configure metrics per schema node."""
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class MetricConfig(BaseModel):
+    """Configuration for a single evaluation metric."""
+
+    metric_id: str
+    weight: Optional[float] = None
+    params: Optional[Dict[str, Any]] = Field(
+        default=None, description="Override default metric parameters"
+    )
+
+    model_config = ConfigDict(frozen=True, extra="forbid")
+
+
+class EvaluationConfig(BaseModel):
+    """Configuration attached to schema nodes to guide evaluation."""
+
+    metrics: List[MetricConfig] = Field(default_factory=list)
+    aggregation_weight: Optional[float] = None
+
+    model_config = ConfigDict(frozen=True, extra="forbid")
+
+    @classmethod
+    def from_preset(cls, preset: str) -> "EvaluationConfig":
+        """Create evaluation config from a preset string."""
+        from .presets import get_preset_config
+
+        return get_preset_config(preset)
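
MetricConfig and EvaluationConfig are frozen Pydantic models, so a per-node evaluation config can be declared directly. A small illustrative sketch; the metric ids and the pass_threshold param are taken from the metric classes that appear later in this diff, the weight values are arbitrary:

from extract_bench.evaluation import EvaluationConfig, MetricConfig

config = EvaluationConfig(
    metrics=[
        MetricConfig(metric_id="boolean_exact"),
        MetricConfig(metric_id="array_llm", weight=2.0,
                     params={"pass_threshold": 0.8}),
    ],
    aggregation_weight=1.0,
)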

extract_bench/evaluation/metric_id_collector.py
@@ -0,0 +1,41 @@
+"""Collect metric IDs and parameter examples from evaluation configs in a schema subtree."""
+
+from typing import Any, Dict
+
+from ..infra.visitors import AnalyzerVisitor
+from ..infra.nodes import Schema
+from .schema_config_helpers import get_evaluation_config
+
+
+def collect_metric_ids_with_params(
+    *, schema: Schema
+) -> dict[str, list[dict[str, Any]]]:
+    """Return metric_id -> list of distinct params dicts observed in the subtree."""
+    visitor = _MetricIdCollector()
+    visitor.visit(schema)
+    return visitor.metric_params
+
+
+class _MetricIdCollector(AnalyzerVisitor):
+    def __init__(self) -> None:
+        super().__init__()
+        self.metric_params: Dict[str, list[dict[str, Any]]] = {}
+        self._seen_param_sigs: Dict[str, set[str]] = {}
+
+    def _collect(self, node: Schema) -> None:
+        config = get_evaluation_config(node)
+        for metric in config.metrics:
+            metric_id = metric.metric_id
+            params = metric.params or {}
+            sig = repr(sorted(params.items()))
+            if sig in self._seen_param_sigs.setdefault(metric_id, set()):
+                continue
+            self._seen_param_sigs[metric_id].add(sig)
+            self.metric_params.setdefault(metric_id, []).append(dict(params))
+
+    def pre_visit_children_hook(self, node: Schema) -> None:
+        self._collect(node)
+
+    def generic_leaf_visit(self, node: Schema) -> Schema:
+        self._collect(node)
+        return node
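
The dedup key is just the repr of the sorted params items, so configs that differ only in key order collapse into one example. A standalone illustration of that signature, outside the package:

def params_signature(params: dict) -> str:
    # Same construction _MetricIdCollector uses for its dedup key.
    return repr(sorted(params.items()))

assert params_signature({"a": 1, "b": 2}) == params_signature({"b": 2, "a": 1})
assert params_signature({"a": 1}) != params_signature({"a": 2})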

extract_bench/evaluation/metric_registry.py
@@ -0,0 +1,120 @@
+"""Metric registry for structured extraction evaluation."""
+
+from collections.abc import Callable
+from typing import Dict, Mapping
+
+from .metrics.array_metrics import ArrayLlmJudgeMetric
+from .metrics.base_metric import BaseMetric
+from .metrics.boolean_metrics import BooleanExactMatchMetric
+from .metrics.llm_metrics import LlmJudgeMetric
+from .metrics.metric_descriptors import MetricPromptDescriptor
+from .metrics.number_metrics import (
+    ExactNumberMatchMetric,
+    IntegerExactMatchMetric,
+    NumberToleranceMetric,
+)
+from .metrics.string_metrics import (
+    CaseInsensitiveStringMatchMetric,
+    ExactStringMatchMetric,
+    NormalizedLevenshteinSimilarityMetric,
+    StringSemanticMetric,
+)
+
+MetricFactory = Callable[[], BaseMetric]
+
+
+class MetricRegistry:
+    """Registry storing factories for metric instances."""
+
+    def __init__(self, *, aliases: Mapping[str, str] | None = None) -> None:
+        self._factories: Dict[str, MetricFactory] = {}
+        self._aliases: Dict[str, str] = dict(aliases or {})
+        self._descriptors: Dict[str, MetricPromptDescriptor] = {}
+
+    def register_metric_factory(
+        self, metric_id: str, factory: MetricFactory, *, override: bool = False
+    ) -> None:
+        if metric_id in self._factories and not override:
+            raise ValueError(f"Metric '{metric_id}' is already registered")
+        self._factories[metric_id] = factory
+
+    def register_metric(
+        self,
+        factory: MetricFactory,
+        *,
+        override: bool = False,
+        aliases: tuple[str, ...] = (),
+    ) -> str:
+        sample = factory()
+        metric_id = sample.metric_id
+        self.register_metric_factory(metric_id, factory, override=override)
+        descriptor = getattr(sample, "prompt_descriptor", None)
+        if descriptor is not None:
+            self.register_metric_descriptor(metric_id, descriptor, override=override)
+        for alias in aliases:
+            self._aliases[alias] = metric_id
+        return metric_id
+
+    def has_metric(self, metric_id: str) -> bool:
+        resolved = self._aliases.get(metric_id, metric_id)
+        return resolved in self._factories
+
+    def create_metric(self, metric_id: str) -> BaseMetric:
+        resolved = self._aliases.get(metric_id, metric_id)
+        if resolved not in self._factories:
+            raise KeyError(f"Metric '{metric_id}' is not registered")
+        metric = self._factories[resolved]()
+        return metric
+
+    def register_metric_descriptor(
+        self,
+        metric_id: str,
+        descriptor: MetricPromptDescriptor,
+        *,
+        override: bool = False,
+    ) -> None:
+        if metric_id in self._descriptors and not override:
+            raise ValueError(f"Metric descriptor '{metric_id}' is already registered")
+        if descriptor.metric_id != metric_id:
+            raise ValueError(
+                f"Metric descriptor id mismatch: expected '{metric_id}', got '{descriptor.metric_id}'"
+            )
+        self._descriptors[metric_id] = descriptor
+
+    def get_metric_descriptor(self, metric_id: str) -> MetricPromptDescriptor | None:
+        resolved = self._aliases.get(metric_id, metric_id)
+        return self._descriptors.get(resolved)
+
+    def available_metrics(self) -> tuple[str, ...]:
+        return tuple(self._factories.keys())
+
+    def unregister_metric(self, metric_id: str) -> None:
+        resolved = self._aliases.get(metric_id, metric_id)
+        self._factories.pop(resolved, None)
+        self._descriptors.pop(resolved, None)
+        self._aliases = {
+            alias: target
+            for alias, target in self._aliases.items()
+            if alias != metric_id and target != resolved
+        }
+
+
+def _register_default_metrics(registry: MetricRegistry) -> None:
+    default_metrics = [
+        ExactStringMatchMetric,
+        CaseInsensitiveStringMatchMetric,
+        NormalizedLevenshteinSimilarityMetric,
+        StringSemanticMetric,
+        ExactNumberMatchMetric,
+        NumberToleranceMetric,
+        IntegerExactMatchMetric,
+        BooleanExactMatchMetric,
+        LlmJudgeMetric,
+        ArrayLlmJudgeMetric,
+    ]
+    for metric in default_metrics:
+        registry.register_metric(metric)
+
+
+global_metric_registry = MetricRegistry()
+_register_default_metrics(global_metric_registry)
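
register_metric derives the metric id (and optional prompt descriptor) from a sample instance built by the factory, so a zero-argument callable such as a metric class is enough. A hypothetical sketch of registering a custom metric; NonEmptyStringMetric and the "nonempty" alias are invented here for illustration:

from dataclasses import dataclass

from extract_bench.evaluation import MetricConfig, MetricRegistry, MetricResult
from extract_bench.evaluation.metrics import MetricContext


@dataclass(frozen=True, slots=True)
class NonEmptyStringMetric:
    """Hypothetical metric: passes when the extracted string is non-empty."""

    metric_id: str = "non_empty_string"
    recurse_into_children: bool = True

    async def evaluate(
        self, node: MetricContext, config: MetricConfig | None = None
    ) -> MetricResult:
        extracted = node.get_extracted_value()
        passed = isinstance(extracted, str) and bool(extracted.strip())
        return MetricResult(
            metric_id=self.metric_id,
            score=1.0 if passed else 0.0,
            passed=passed,
        )


registry = MetricRegistry()
registry.register_metric(NonEmptyStringMetric, aliases=("nonempty",))
assert registry.has_metric("nonempty")
metric = registry.create_metric("non_empty_string")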

extract_bench/evaluation/metrics/__init__.py
@@ -0,0 +1,43 @@
+"""Evaluation metrics for structured extraction."""
+
+from .array_metrics import ArrayLlmJudgeMetric
+from .base_metric import BaseMetric, MetricContext, MetricResult
+from .boolean_metrics import BooleanExactMatchMetric
+from .llm_metrics import LlmJudgeMetric
+from .metric_descriptors import MetricPromptDescriptor
+from .number_metrics import (
+    ExactNumberMatchMetric,
+    IntegerExactMatchMetric,
+    NumberToleranceMetric,
+)
+from .policy_metric import PolicyAwareMetric
+from .string_metrics import (
+    CaseInsensitiveStringMatchMetric,
+    ExactStringMatchMetric,
+    NormalizedLevenshteinSimilarityMetric,
+    StringSemanticMetric,
+)
+
+__all__ = [
+    # Base
+    "BaseMetric",
+    "MetricContext",
+    "MetricResult",
+    "MetricPromptDescriptor",
+    "PolicyAwareMetric",
+    # String
+    "ExactStringMatchMetric",
+    "CaseInsensitiveStringMatchMetric",
+    "NormalizedLevenshteinSimilarityMetric",
+    "StringSemanticMetric",
+    # Number
+    "ExactNumberMatchMetric",
+    "NumberToleranceMetric",
+    "IntegerExactMatchMetric",
+    # Boolean
+    "BooleanExactMatchMetric",
+    # Array
+    "ArrayLlmJudgeMetric",
+    # LLM
+    "LlmJudgeMetric",
+]

extract_bench/evaluation/metrics/array_metrics.py
@@ -0,0 +1,116 @@
+"""Array evaluation metrics."""
+
+import json
+from dataclasses import dataclass
+from typing import Any, ClassVar, Dict
+
+from ...infra.nodes import BaseSchema
+from ..evaluation_config import MetricConfig
+from .base_metric import MetricContext
+from .llm_metrics import LlmJudgeMetric
+from .metric_descriptors import MetricPromptDescriptor
+from .metric_prompts import ARRAY_LLM_PROMPT
+
+
+@dataclass(frozen=True, slots=True)
+class ArrayLlmJudgeMetric(LlmJudgeMetric):
+    metric_id: ClassVar[str] = "array_llm"
+    default_prompt_template: ClassVar[str] = ARRAY_LLM_PROMPT
+    prompt_descriptor = MetricPromptDescriptor(
+        metric_id="array_llm",
+        summary="LLM-based array evaluation with item matching and confusion summary.",
+        pass_rule="pass iff score >= pass_threshold (default threshold is metric-specific)",
+        score_rule="matched/(matched+missed_gold) after postprocessing",
+        params={
+            "pass_threshold": {"type": "number", "meaning": "minimum score to pass"},
+            "additional_instructions": {
+                "type": "string",
+                "meaning": "extra domain rules",
+            },
+        },
+    )
+    default_structured_output_schema: ClassVar[Dict[str, Any]] = {
+        "type": "object",
+        "properties": {
+            "reasoning": {"type": "string"},
+            "matched_items": {"type": "array", "items": {"type": "string"}},
+            "missed_gold_items": {"type": "array", "items": {"type": "string"}},
+            "spurious_pred_items": {"type": "array", "items": {"type": "string"}},
+            "matches_summary": {
+                "type": "object",
+                "properties": {
+                    "matched": {"type": "integer", "minimum": 0},
+                    "missed_gold": {"type": "integer", "minimum": 0},
+                    "spurious_pred": {"type": "integer", "minimum": 0},
+                },
+                "required": ["matched", "missed_gold", "spurious_pred"],
+            },
+        },
+        "required": ["reasoning", "matches_summary"],
+        "additionalProperties": True,
+    }
+
+    def _get_additional_prompt_context(self, node: MetricContext) -> str | None:
+        """Append metric definitions for child fields so the LLM understands per-field evaluation rules."""
+        from ..metric_id_collector import collect_metric_ids_with_params
+        from ..metric_registry import global_metric_registry
+
+        if not isinstance(node, BaseSchema):
+            return None
+        metric_params = collect_metric_ids_with_params(schema=node)
+        if not metric_params:
+            return None
+        definitions: dict[str, Any] = {}
+        for metric_id, params_examples in sorted(metric_params.items()):
+            descriptor = global_metric_registry.get_metric_descriptor(metric_id)
+            if descriptor is None:
+                definitions[metric_id] = {
+                    "metric_id": metric_id,
+                    "summary": "No descriptor registered for this metric_id.",
+                    "params_examples": params_examples,
+                }
+            else:
+                definitions[metric_id] = descriptor.describe(
+                    params_examples=params_examples
+                )
+        return (
+            f"Metric definitions (JSON):\n"
+            f"{json.dumps(definitions, ensure_ascii=True, indent=2)}"
+        )
+
+    def postprocess_parsed_result(
+        self,
+        node: MetricContext,
+        config: MetricConfig | None,
+        parsed: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        parsed = dict(parsed)
+        summary = parsed["matches_summary"]
+        matched, missed, spurious = (
+            summary["matched"],
+            summary["missed_gold"],
+            summary["spurious_pred"],
+        )
+        denominator = max(1, matched + missed)
+        parsed["score"] = matched / denominator
+
+        total = matched + missed + spurious
+        gold_total = matched + missed
+        pred_total = matched + spurious
+
+        accuracy = matched / total if total > 0 else 1.0
+        precision = matched / pred_total if pred_total > 0 else 1.0
+        recall = matched / gold_total if gold_total > 0 else 1.0
+        if precision + recall == 0.0:
+            f1 = 0.0
+        else:
+            f1 = 2 * precision * recall / (precision + recall)
+
+        parsed["aggregate_metrics"] = {
+            "accuracy": accuracy,
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+        }
+
+        return parsed
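
postprocess_parsed_result turns the judge's confusion summary into a recall-style score plus standard aggregates. A standalone sketch of the same arithmetic, without the LLM call (the function name is illustrative):

def summarize(matched: int, missed_gold: int, spurious_pred: int) -> dict:
    # Score is recall over gold items, guarded against a zero denominator.
    score = matched / max(1, matched + missed_gold)
    total = matched + missed_gold + spurious_pred
    precision = matched / (matched + spurious_pred) if matched + spurious_pred else 1.0
    recall = matched / (matched + missed_gold) if matched + missed_gold else 1.0
    f1 = 0.0 if precision + recall == 0.0 else 2 * precision * recall / (precision + recall)
    accuracy = matched / total if total else 1.0
    return {"score": score, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# e.g. 8 matched, 2 missed, 1 spurious -> score 0.8, precision ~0.889, recall 0.8, f1 ~0.842
print(summarize(8, 2, 1))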

extract_bench/evaluation/metrics/base_metric.py
@@ -0,0 +1,40 @@
+"""Base interfaces for evaluation metrics."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, Protocol
+
+from ..evaluation_config import MetricConfig
+
+
+@dataclass(frozen=True, slots=True)
+class MetricResult:
+    """Structured result returned by metrics."""
+
+    metric_id: str
+    score: float
+    passed: bool | None = None
+    details: dict[str, Any] | None = None
+
+
+class MetricContext(Protocol):
+    """Subset of schema node interface used by metrics."""
+
+    def get_gold_value(self) -> Any: ...
+
+    def get_extracted_value(self) -> Any: ...
+
+    def get_metadata_summary(self) -> Dict[str, Any]: ...
+
+
+class BaseMetric(Protocol):
+    """Common interface implemented by all metrics.
+
+    All metrics are async to enable parallel evaluation.
+    """
+
+    metric_id: str
+    recurse_into_children: bool = True
+
+    async def evaluate(
+        self, node: MetricContext, config: MetricConfig | None = None
+    ) -> MetricResult: ...
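
Because MetricContext is a Protocol, any object exposing the three accessors can be passed to a metric. A hypothetical stub context driving the registered boolean metric, assuming the policy-aware base class (defined in policy_metric.py, not shown in this section) needs nothing beyond these accessors:

import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict

from extract_bench.evaluation import global_metric_registry


@dataclass
class StubContext:
    """Minimal stand-in for a schema node during ad-hoc metric testing."""

    gold: Any
    extracted: Any
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_gold_value(self) -> Any:
        return self.gold

    def get_extracted_value(self) -> Any:
        return self.extracted

    def get_metadata_summary(self) -> Dict[str, Any]:
        return self.metadata


metric = global_metric_registry.create_metric("boolean_exact")
result = asyncio.run(metric.evaluate(StubContext(gold=True, extracted=True)))
print(result.metric_id, result.score, result.passed)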

extract_bench/evaluation/metrics/boolean_metrics.py
@@ -0,0 +1,56 @@
+"""Boolean evaluation metrics."""
+
+from dataclasses import dataclass
+from typing import Any
+
+from ..evaluation_config import MetricConfig
+from .base_metric import BaseMetric, MetricContext, MetricResult
+from .metric_descriptors import MetricPromptDescriptor
+from .policy_metric import PolicyAwareMetric
+
+
+def _to_bool(value: Any) -> bool | None:
+    if isinstance(value, bool):
+        return value
+    return None
+
+
+@dataclass(frozen=True, slots=True)
+class BooleanExactMatchMetric(PolicyAwareMetric, BaseMetric):
+    metric_id: str = "boolean_exact"
+    prompt_descriptor = MetricPromptDescriptor(
+        metric_id="boolean_exact",
+        summary="Exact boolean equality.",
+        pass_rule="pass iff gold is True/False and predicted equals gold",
+        score_rule="1.0 if pass else 0.0",
+    )
+
+    async def _evaluate_values(
+        self,
+        *,
+        node: MetricContext,
+        gold: Any,
+        extracted: Any,
+        config: MetricConfig | None,
+    ) -> MetricResult:
+        gold_bool = _to_bool(gold)
+        extracted_bool = _to_bool(extracted)
+
+        if gold_bool is None or extracted_bool is None:
+            passed = gold_bool is extracted_bool
+            score = 1.0 if passed else 0.0
+            return MetricResult(
+                metric_id=self.metric_id,
+                score=score,
+                passed=passed,
+                details={"gold": gold, "extracted": extracted},
+            )
+
+        passed = gold_bool == extracted_bool
+        score = 1.0 if passed else 0.0
+        return MetricResult(
+            metric_id=self.metric_id,
+            score=score,
+            passed=passed,
+            details={"gold": gold_bool, "extracted": extracted_bool},
+        )
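
Note the non-boolean branch above: only genuine bools survive _to_bool, so two non-boolean values (for example, both None) compare as a pass via identity, while a bool against anything else fails. A standalone sketch of that decision rule, mirroring _evaluate_values:

def _coerce(value):
    # Mirror of _to_bool: only genuine bools survive coercion.
    return value if isinstance(value, bool) else None

def boolean_exact_passes(gold, extracted) -> bool:
    gold_b, extracted_b = _coerce(gold), _coerce(extracted)
    if gold_b is None or extracted_b is None:
        return gold_b is extracted_b   # both non-bool -> pass; bool vs non-bool -> fail
    return gold_b == extracted_b

assert boolean_exact_passes(True, True)
assert not boolean_exact_passes(True, "true")   # strings are not coerced
assert boolean_exact_passes(None, None)         # both missing counts as a match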