extract_bench-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. extract_bench/__init__.py +42 -0
  2. extract_bench/evaluation/__init__.py +42 -0
  3. extract_bench/evaluation/evaluation_config.py +33 -0
  4. extract_bench/evaluation/metric_id_collector.py +41 -0
  5. extract_bench/evaluation/metric_registry.py +120 -0
  6. extract_bench/evaluation/metrics/__init__.py +43 -0
  7. extract_bench/evaluation/metrics/array_metrics.py +116 -0
  8. extract_bench/evaluation/metrics/base_metric.py +40 -0
  9. extract_bench/evaluation/metrics/boolean_metrics.py +56 -0
  10. extract_bench/evaluation/metrics/llm_metrics.py +231 -0
  11. extract_bench/evaluation/metrics/metric_descriptors.py +36 -0
  12. extract_bench/evaluation/metrics/metric_prompts/__init__.py +16 -0
  13. extract_bench/evaluation/metrics/metric_prompts/array_llm.txt +35 -0
  14. extract_bench/evaluation/metrics/metric_prompts/llm_judge.txt +15 -0
  15. extract_bench/evaluation/metrics/metric_prompts/string_semantic.txt +17 -0
  16. extract_bench/evaluation/metrics/metric_utils.py +123 -0
  17. extract_bench/evaluation/metrics/number_metrics.py +148 -0
  18. extract_bench/evaluation/metrics/policy_metric.py +44 -0
  19. extract_bench/evaluation/metrics/string_metrics.py +195 -0
  20. extract_bench/evaluation/presets.py +109 -0
  21. extract_bench/evaluation/reporting/README.md +191 -0
  22. extract_bench/evaluation/reporting/__init__.py +47 -0
  23. extract_bench/evaluation/reporting/content_stats.py +160 -0
  24. extract_bench/evaluation/reporting/formatters.py +195 -0
  25. extract_bench/evaluation/reporting/models.py +181 -0
  26. extract_bench/evaluation/reporting/outcome_stats.py +290 -0
  27. extract_bench/evaluation/reporting/report_builder.py +169 -0
  28. extract_bench/evaluation/reporting/schema_stats.py +104 -0
  29. extract_bench/evaluation/schema_config_helpers.py +107 -0
  30. extract_bench/evaluation/schema_value_instantiator.py +213 -0
  31. extract_bench/evaluation/structured_evaluator.py +226 -0
  32. extract_bench/infra/__init__.py +60 -0
  33. extract_bench/infra/asyncio_utils.py +53 -0
  34. extract_bench/infra/construct_ast.py +110 -0
  35. extract_bench/infra/nodes.py +384 -0
  36. extract_bench/infra/ref_expander.py +43 -0
  37. extract_bench/infra/schema_instance_visitor.py +125 -0
  38. extract_bench/infra/visitors.py +452 -0
  39. extract_bench-0.1.0.dist-info/METADATA +342 -0
  40. extract_bench-0.1.0.dist-info/RECORD +42 -0
  41. extract_bench-0.1.0.dist-info/WHEEL +4 -0
  42. extract_bench-0.1.0.dist-info/licenses/LICENSE +21 -0

extract_bench/__init__.py
@@ -0,0 +1,42 @@
+ """Structured Extraction Evaluation Suite.
+
+ A standalone package for evaluating structured extraction quality by comparing
+ predicted JSON against gold JSON with per-field metrics.
+ """
+
+ from .evaluation import (
+     AsyncEvaluationConfig,
+     BaseMetric,
+     EvaluationConfig,
+     EvaluationReport,
+     MetricConfig,
+     MetricRegistry,
+     MetricResult,
+     ReportBuilder,
+     ReportConfig,
+     StructuredEvaluator,
+     StructuredEvaluatorConfig,
+     global_metric_registry,
+ )
+
+ __version__ = "0.1.0"
+
+ __all__ = [
+     # Main evaluator
+     "StructuredEvaluator",
+     "StructuredEvaluatorConfig",
+     "AsyncEvaluationConfig",
+     # Config
+     "EvaluationConfig",
+     "MetricConfig",
+     # Registry
+     "MetricRegistry",
+     "global_metric_registry",
+     # Metrics
+     "BaseMetric",
+     "MetricResult",
+     # Reporting
+     "ReportBuilder",
+     "ReportConfig",
+     "EvaluationReport",
+ ]
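
A minimal usage sketch for the public surface re-exported above, assuming the wheel installs under the import name extract_bench; it only lists the default registry contents and builds an empty config, and assumes nothing about constructor signatures not shown in this diff.

from extract_bench import EvaluationConfig, global_metric_registry

# Metric ids registered by _register_default_metrics (see metric_registry.py below).
print(sorted(global_metric_registry.available_metrics()))

# EvaluationConfig defaults to no metrics and no aggregation weight.
print(EvaluationConfig())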

extract_bench/evaluation/__init__.py
@@ -0,0 +1,42 @@
+ """Structured extraction evaluation module."""
+
+ from .evaluation_config import EvaluationConfig, MetricConfig
+ from .metric_registry import MetricRegistry, global_metric_registry
+ from .metrics.base_metric import BaseMetric, MetricResult
+ from .reporting import EvaluationReport, ReportBuilder, ReportConfig
+ from .schema_config_helpers import (
+     add_evaluation_configs_to_export,
+     get_default_evaluation_config,
+     get_evaluation_config,
+     should_evaluate,
+ )
+ from .structured_evaluator import (
+     AsyncEvaluationConfig,
+     StructuredEvaluator,
+     StructuredEvaluatorConfig,
+ )
+
+ __all__ = [
+     # Config
+     "EvaluationConfig",
+     "MetricConfig",
+     # Registry
+     "MetricRegistry",
+     "global_metric_registry",
+     # Metrics
+     "BaseMetric",
+     "MetricResult",
+     # Helpers
+     "get_evaluation_config",
+     "get_default_evaluation_config",
+     "should_evaluate",
+     "add_evaluation_configs_to_export",
+     # Evaluator
+     "AsyncEvaluationConfig",
+     "StructuredEvaluator",
+     "StructuredEvaluatorConfig",
+     # Reporting
+     "ReportBuilder",
+     "ReportConfig",
+     "EvaluationReport",
+ ]

extract_bench/evaluation/evaluation_config.py
@@ -0,0 +1,33 @@
+ """Evaluation config models used to configure metrics per schema node."""
+
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+
+ class MetricConfig(BaseModel):
+     """Configuration for a single evaluation metric."""
+
+     metric_id: str
+     weight: Optional[float] = None
+     params: Optional[Dict[str, Any]] = Field(
+         default=None, description="Override default metric parameters"
+     )
+
+     model_config = ConfigDict(frozen=True, extra="forbid")
+
+
+ class EvaluationConfig(BaseModel):
+     """Configuration attached to schema nodes to guide evaluation."""
+
+     metrics: List[MetricConfig] = Field(default_factory=list)
+     aggregation_weight: Optional[float] = None
+
+     model_config = ConfigDict(frozen=True, extra="forbid")
+
+     @classmethod
+     def from_preset(cls, preset: str) -> "EvaluationConfig":
+         """Create evaluation config from a preset string."""
+         from .presets import get_preset_config
+
+         return get_preset_config(preset)
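
A short sketch of how these models compose, using metric ids and a parameter that appear elsewhere in this diff ("boolean_exact", "array_llm", pass_threshold); the weights and threshold value are illustrative.

from extract_bench.evaluation.evaluation_config import EvaluationConfig, MetricConfig

# Attach two metrics to a schema node; params override a metric's defaults.
config = EvaluationConfig(
    metrics=[
        MetricConfig(metric_id="boolean_exact"),
        MetricConfig(metric_id="array_llm", weight=2.0, params={"pass_threshold": 0.8}),
    ],
    aggregation_weight=1.0,
)

# Both models are frozen (immutable) and reject unknown fields (extra="forbid").
print(config.metrics[1].params)  # {'pass_threshold': 0.8}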

extract_bench/evaluation/metric_id_collector.py
@@ -0,0 +1,41 @@
+ """Collect metric IDs and parameter examples from evaluation configs in a schema subtree."""
+
+ from typing import Any, Dict
+
+ from ..infra.visitors import AnalyzerVisitor
+ from ..infra.nodes import Schema
+ from .schema_config_helpers import get_evaluation_config
+
+
+ def collect_metric_ids_with_params(
+     *, schema: Schema
+ ) -> dict[str, list[dict[str, Any]]]:
+     """Return metric_id -> list of distinct params dicts observed in the subtree."""
+     visitor = _MetricIdCollector()
+     visitor.visit(schema)
+     return visitor.metric_params
+
+
+ class _MetricIdCollector(AnalyzerVisitor):
+     def __init__(self) -> None:
+         super().__init__()
+         self.metric_params: Dict[str, list[dict[str, Any]]] = {}
+         self._seen_param_sigs: Dict[str, set[str]] = {}
+
+     def _collect(self, node: Schema) -> None:
+         config = get_evaluation_config(node)
+         for metric in config.metrics:
+             metric_id = metric.metric_id
+             params = metric.params or {}
+             sig = repr(sorted(params.items()))
+             if sig in self._seen_param_sigs.setdefault(metric_id, set()):
+                 continue
+             self._seen_param_sigs[metric_id].add(sig)
+             self.metric_params.setdefault(metric_id, []).append(dict(params))
+
+     def pre_visit_children_hook(self, node: Schema) -> None:
+         self._collect(node)
+
+     def generic_leaf_visit(self, node: Schema) -> Schema:
+         self._collect(node)
+         return node
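
The collector dedupes params per metric_id by a canonical signature. The standalone snippet below replicates just that dedup step (no visitor or schema objects involved; the observed pairs are made up for illustration) so the return shape is concrete.

from typing import Any

observed = [
    ("array_llm", {"pass_threshold": 0.8}),
    ("array_llm", {"pass_threshold": 0.8}),  # identical signature -> skipped
    ("array_llm", {"pass_threshold": 0.9}),
    ("boolean_exact", {}),
]

metric_params: dict[str, list[dict[str, Any]]] = {}
seen: dict[str, set[str]] = {}
for metric_id, params in observed:
    sig = repr(sorted(params.items()))
    if sig in seen.setdefault(metric_id, set()):
        continue
    seen[metric_id].add(sig)
    metric_params.setdefault(metric_id, []).append(dict(params))

print(metric_params)
# {'array_llm': [{'pass_threshold': 0.8}, {'pass_threshold': 0.9}], 'boolean_exact': [{}]}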

extract_bench/evaluation/metric_registry.py
@@ -0,0 +1,120 @@
+ """Metric registry for structured extraction evaluation."""
+
+ from collections.abc import Callable
+ from typing import Dict, Mapping
+
+ from .metrics.array_metrics import ArrayLlmJudgeMetric
+ from .metrics.base_metric import BaseMetric
+ from .metrics.boolean_metrics import BooleanExactMatchMetric
+ from .metrics.llm_metrics import LlmJudgeMetric
+ from .metrics.metric_descriptors import MetricPromptDescriptor
+ from .metrics.number_metrics import (
+     ExactNumberMatchMetric,
+     IntegerExactMatchMetric,
+     NumberToleranceMetric,
+ )
+ from .metrics.string_metrics import (
+     CaseInsensitiveStringMatchMetric,
+     ExactStringMatchMetric,
+     NormalizedLevenshteinSimilarityMetric,
+     StringSemanticMetric,
+ )
+
+ MetricFactory = Callable[[], BaseMetric]
+
+
+ class MetricRegistry:
+     """Registry storing factories for metric instances."""
+
+     def __init__(self, *, aliases: Mapping[str, str] | None = None) -> None:
+         self._factories: Dict[str, MetricFactory] = {}
+         self._aliases: Dict[str, str] = dict(aliases or {})
+         self._descriptors: Dict[str, MetricPromptDescriptor] = {}
+
+     def register_metric_factory(
+         self, metric_id: str, factory: MetricFactory, *, override: bool = False
+     ) -> None:
+         if metric_id in self._factories and not override:
+             raise ValueError(f"Metric '{metric_id}' is already registered")
+         self._factories[metric_id] = factory
+
+     def register_metric(
+         self,
+         factory: MetricFactory,
+         *,
+         override: bool = False,
+         aliases: tuple[str, ...] = (),
+     ) -> str:
+         sample = factory()
+         metric_id = sample.metric_id
+         self.register_metric_factory(metric_id, factory, override=override)
+         descriptor = getattr(sample, "prompt_descriptor", None)
+         if descriptor is not None:
+             self.register_metric_descriptor(metric_id, descriptor, override=override)
+         for alias in aliases:
+             self._aliases[alias] = metric_id
+         return metric_id
+
+     def has_metric(self, metric_id: str) -> bool:
+         resolved = self._aliases.get(metric_id, metric_id)
+         return resolved in self._factories
+
+     def create_metric(self, metric_id: str) -> BaseMetric:
+         resolved = self._aliases.get(metric_id, metric_id)
+         if resolved not in self._factories:
+             raise KeyError(f"Metric '{metric_id}' is not registered")
+         metric = self._factories[resolved]()
+         return metric
+
+     def register_metric_descriptor(
+         self,
+         metric_id: str,
+         descriptor: MetricPromptDescriptor,
+         *,
+         override: bool = False,
+     ) -> None:
+         if metric_id in self._descriptors and not override:
+             raise ValueError(f"Metric descriptor '{metric_id}' is already registered")
+         if descriptor.metric_id != metric_id:
+             raise ValueError(
+                 f"Metric descriptor id mismatch: expected '{metric_id}', got '{descriptor.metric_id}'"
+             )
+         self._descriptors[metric_id] = descriptor
+
+     def get_metric_descriptor(self, metric_id: str) -> MetricPromptDescriptor | None:
+         resolved = self._aliases.get(metric_id, metric_id)
+         return self._descriptors.get(resolved)
+
+     def available_metrics(self) -> tuple[str, ...]:
+         return tuple(self._factories.keys())
+
+     def unregister_metric(self, metric_id: str) -> None:
+         resolved = self._aliases.get(metric_id, metric_id)
+         self._factories.pop(resolved, None)
+         self._descriptors.pop(resolved, None)
+         self._aliases = {
+             alias: target
+             for alias, target in self._aliases.items()
+             if alias != metric_id and target != resolved
+         }
+
+
+ def _register_default_metrics(registry: MetricRegistry) -> None:
+     default_metrics = [
+         ExactStringMatchMetric,
+         CaseInsensitiveStringMatchMetric,
+         NormalizedLevenshteinSimilarityMetric,
+         StringSemanticMetric,
+         ExactNumberMatchMetric,
+         NumberToleranceMetric,
+         IntegerExactMatchMetric,
+         BooleanExactMatchMetric,
+         LlmJudgeMetric,
+         ArrayLlmJudgeMetric,
+     ]
+     for metric in default_metrics:
+         registry.register_metric(metric)
+
+
+ global_metric_registry = MetricRegistry()
+ _register_default_metrics(global_metric_registry)
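
A hedged sketch of extending the registry with a custom metric. DummyLengthMetric and its ids are invented for illustration; only the MetricRegistry calls and the MetricResult/MetricContext types come from the code in this diff.

from dataclasses import dataclass

from extract_bench.evaluation.metric_registry import MetricRegistry
from extract_bench.evaluation.metrics.base_metric import MetricContext, MetricResult


@dataclass(frozen=True, slots=True)
class DummyLengthMetric:
    """Hypothetical metric: passes when gold and extracted strings have equal length."""

    metric_id: str = "dummy_length"
    recurse_into_children: bool = True

    async def evaluate(self, node: MetricContext, config=None) -> MetricResult:
        gold, extracted = node.get_gold_value(), node.get_extracted_value()
        passed = (
            isinstance(gold, str)
            and isinstance(extracted, str)
            and len(gold) == len(extracted)
        )
        return MetricResult(metric_id=self.metric_id, score=1.0 if passed else 0.0, passed=passed)


registry = MetricRegistry(aliases={"len": "dummy_length"})
registry.register_metric(DummyLengthMetric, aliases=("length",))

assert registry.has_metric("len") and registry.has_metric("length")
metric = registry.create_metric("length")   # both aliases resolve to "dummy_length"
print(registry.available_metrics())         # ('dummy_length',)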

extract_bench/evaluation/metrics/__init__.py
@@ -0,0 +1,43 @@
+ """Evaluation metrics for structured extraction."""
+
+ from .array_metrics import ArrayLlmJudgeMetric
+ from .base_metric import BaseMetric, MetricContext, MetricResult
+ from .boolean_metrics import BooleanExactMatchMetric
+ from .llm_metrics import LlmJudgeMetric
+ from .metric_descriptors import MetricPromptDescriptor
+ from .number_metrics import (
+     ExactNumberMatchMetric,
+     IntegerExactMatchMetric,
+     NumberToleranceMetric,
+ )
+ from .policy_metric import PolicyAwareMetric
+ from .string_metrics import (
+     CaseInsensitiveStringMatchMetric,
+     ExactStringMatchMetric,
+     NormalizedLevenshteinSimilarityMetric,
+     StringSemanticMetric,
+ )
+
+ __all__ = [
+     # Base
+     "BaseMetric",
+     "MetricContext",
+     "MetricResult",
+     "MetricPromptDescriptor",
+     "PolicyAwareMetric",
+     # String
+     "ExactStringMatchMetric",
+     "CaseInsensitiveStringMatchMetric",
+     "NormalizedLevenshteinSimilarityMetric",
+     "StringSemanticMetric",
+     # Number
+     "ExactNumberMatchMetric",
+     "NumberToleranceMetric",
+     "IntegerExactMatchMetric",
+     # Boolean
+     "BooleanExactMatchMetric",
+     # Array
+     "ArrayLlmJudgeMetric",
+     # LLM
+     "LlmJudgeMetric",
+ ]

extract_bench/evaluation/metrics/array_metrics.py
@@ -0,0 +1,116 @@
+ """Array evaluation metrics."""
+
+ import json
+ from dataclasses import dataclass
+ from typing import Any, ClassVar, Dict
+
+ from ...infra.nodes import BaseSchema
+ from ..evaluation_config import MetricConfig
+ from .base_metric import MetricContext
+ from .llm_metrics import LlmJudgeMetric
+ from .metric_descriptors import MetricPromptDescriptor
+ from .metric_prompts import ARRAY_LLM_PROMPT
+
+
+ @dataclass(frozen=True, slots=True)
+ class ArrayLlmJudgeMetric(LlmJudgeMetric):
+     metric_id: ClassVar[str] = "array_llm"
+     default_prompt_template: ClassVar[str] = ARRAY_LLM_PROMPT
+     prompt_descriptor = MetricPromptDescriptor(
+         metric_id="array_llm",
+         summary="LLM-based array evaluation with item matching and confusion summary.",
+         pass_rule="pass iff score >= pass_threshold (default threshold is metric-specific)",
+         score_rule="matched/(matched+missed_gold) after postprocessing",
+         params={
+             "pass_threshold": {"type": "number", "meaning": "minimum score to pass"},
+             "additional_instructions": {
+                 "type": "string",
+                 "meaning": "extra domain rules",
+             },
+         },
+     )
+     default_structured_output_schema: ClassVar[Dict[str, Any]] = {
+         "type": "object",
+         "properties": {
+             "reasoning": {"type": "string"},
+             "matched_items": {"type": "array", "items": {"type": "string"}},
+             "missed_gold_items": {"type": "array", "items": {"type": "string"}},
+             "spurious_pred_items": {"type": "array", "items": {"type": "string"}},
+             "matches_summary": {
+                 "type": "object",
+                 "properties": {
+                     "matched": {"type": "integer", "minimum": 0},
+                     "missed_gold": {"type": "integer", "minimum": 0},
+                     "spurious_pred": {"type": "integer", "minimum": 0},
+                 },
+                 "required": ["matched", "missed_gold", "spurious_pred"],
+             },
+         },
+         "required": ["reasoning", "matches_summary"],
+         "additionalProperties": True,
+     }
+
+     def _get_additional_prompt_context(self, node: MetricContext) -> str | None:
+         """Append metric definitions for child fields so the LLM understands per-field evaluation rules."""
+         from ..metric_id_collector import collect_metric_ids_with_params
+         from ..metric_registry import global_metric_registry
+
+         if not isinstance(node, BaseSchema):
+             return None
+         metric_params = collect_metric_ids_with_params(schema=node)
+         if not metric_params:
+             return None
+         definitions: dict[str, Any] = {}
+         for metric_id, params_examples in sorted(metric_params.items()):
+             descriptor = global_metric_registry.get_metric_descriptor(metric_id)
+             if descriptor is None:
+                 definitions[metric_id] = {
+                     "metric_id": metric_id,
+                     "summary": "No descriptor registered for this metric_id.",
+                     "params_examples": params_examples,
+                 }
+             else:
+                 definitions[metric_id] = descriptor.describe(
+                     params_examples=params_examples
+                 )
+         return (
+             f"Metric definitions (JSON):\n"
+             f"{json.dumps(definitions, ensure_ascii=True, indent=2)}"
+         )
+
+     def postprocess_parsed_result(
+         self,
+         node: MetricContext,
+         config: MetricConfig | None,
+         parsed: Dict[str, Any],
+     ) -> Dict[str, Any]:
+         parsed = dict(parsed)
+         summary = parsed["matches_summary"]
+         matched, missed, spurious = (
+             summary["matched"],
+             summary["missed_gold"],
+             summary["spurious_pred"],
+         )
+         denominator = max(1, matched + missed)
+         parsed["score"] = matched / denominator
+
+         total = matched + missed + spurious
+         gold_total = matched + missed
+         pred_total = matched + spurious
+
+         accuracy = matched / total if total > 0 else 1.0
+         precision = matched / pred_total if pred_total > 0 else 1.0
+         recall = matched / gold_total if gold_total > 0 else 1.0
+         if precision + recall == 0.0:
+             f1 = 0.0
+         else:
+             f1 = 2 * precision * recall / (precision + recall)
+
+         parsed["aggregate_metrics"] = {
+             "accuracy": accuracy,
+             "precision": precision,
+             "recall": recall,
+             "f1": f1,
+         }
+
+         return parsed
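
Worked through on concrete counts, the postprocessing above reduces to plain arithmetic (the counts here are made up; no LLM call is involved in this step):

# Mirrors postprocess_parsed_result for matched=8, missed_gold=2, spurious_pred=4.
matched, missed, spurious = 8, 2, 4

score = matched / max(1, matched + missed)          # 8/10 = 0.8, the headline score
precision = matched / (matched + spurious)          # 8/12 ≈ 0.667
recall = matched / (matched + missed)               # 8/10 = 0.8
accuracy = matched / (matched + missed + spurious)  # 8/14 ≈ 0.571
f1 = 2 * precision * recall / (precision + recall)  # ≈ 0.727

print(score, accuracy, precision, recall, f1)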

extract_bench/evaluation/metrics/base_metric.py
@@ -0,0 +1,40 @@
+ """Base interfaces for evaluation metrics."""
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, Protocol
+
+ from ..evaluation_config import MetricConfig
+
+
+ @dataclass(frozen=True, slots=True)
+ class MetricResult:
+     """Structured result returned by metrics."""
+
+     metric_id: str
+     score: float
+     passed: bool | None = None
+     details: dict[str, Any] | None = None
+
+
+ class MetricContext(Protocol):
+     """Subset of schema node interface used by metrics."""
+
+     def get_gold_value(self) -> Any: ...
+
+     def get_extracted_value(self) -> Any: ...
+
+     def get_metadata_summary(self) -> Dict[str, Any]: ...
+
+
+ class BaseMetric(Protocol):
+     """Common interface implemented by all metrics.
+
+     All metrics are async to enable parallel evaluation.
+     """
+
+     metric_id: str
+     recurse_into_children: bool = True
+
+     async def evaluate(
+         self, node: MetricContext, config: MetricConfig | None = None
+     ) -> MetricResult: ...
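
Because MetricContext is a structural Protocol, any object exposing these three methods can be passed to a metric's evaluate(); the stub below (names are illustrative) shows the minimum surface a caller has to provide.

from typing import Any, Dict


class StubContext:
    """Hypothetical leaf node exposing just the MetricContext surface."""

    def __init__(self, gold: Any, extracted: Any) -> None:
        self._gold, self._extracted = gold, extracted

    def get_gold_value(self) -> Any:
        return self._gold

    def get_extracted_value(self) -> Any:
        return self._extracted

    def get_metadata_summary(self) -> Dict[str, Any]:
        return {"path": "$.example"}


# Any BaseMetric implementation is then awaited the same way, e.g.
# asyncio.run(metric.evaluate(StubContext(gold=True, extracted=True))).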

extract_bench/evaluation/metrics/boolean_metrics.py
@@ -0,0 +1,56 @@
+ """Boolean evaluation metrics."""
+
+ from dataclasses import dataclass
+ from typing import Any
+
+ from ..evaluation_config import MetricConfig
+ from .base_metric import BaseMetric, MetricContext, MetricResult
+ from .metric_descriptors import MetricPromptDescriptor
+ from .policy_metric import PolicyAwareMetric
+
+
+ def _to_bool(value: Any) -> bool | None:
+     if isinstance(value, bool):
+         return value
+     return None
+
+
+ @dataclass(frozen=True, slots=True)
+ class BooleanExactMatchMetric(PolicyAwareMetric, BaseMetric):
+     metric_id: str = "boolean_exact"
+     prompt_descriptor = MetricPromptDescriptor(
+         metric_id="boolean_exact",
+         summary="Exact boolean equality.",
+         pass_rule="pass iff gold is True/False and predicted equals gold",
+         score_rule="1.0 if pass else 0.0",
+     )
+
+     async def _evaluate_values(
+         self,
+         *,
+         node: MetricContext,
+         gold: Any,
+         extracted: Any,
+         config: MetricConfig | None,
+     ) -> MetricResult:
+         gold_bool = _to_bool(gold)
+         extracted_bool = _to_bool(extracted)
+
+         if gold_bool is None or extracted_bool is None:
+             passed = gold_bool is extracted_bool
+             score = 1.0 if passed else 0.0
+             return MetricResult(
+                 metric_id=self.metric_id,
+                 score=score,
+                 passed=passed,
+                 details={"gold": gold, "extracted": extracted},
+             )
+
+         passed = gold_bool == extracted_bool
+         score = 1.0 if passed else 0.0
+         return MetricResult(
+             metric_id=self.metric_id,
+             score=score,
+             passed=passed,
+             details={"gold": gold_bool, "extracted": extracted_bool},
+         )
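
One corner case worth spelling out: when either side is not a strict bool, _to_bool returns None and the metric compares the coerced values by identity, so two missing values still count as a pass. A small replication of just that decision logic (the function name is illustrative):

def boolean_exact_outcome(gold, extracted):
    # Mirrors only the pass rule of BooleanExactMatchMetric._evaluate_values.
    gold_bool = gold if isinstance(gold, bool) else None
    extracted_bool = extracted if isinstance(extracted, bool) else None
    if gold_bool is None or extracted_bool is None:
        return gold_bool is extracted_bool   # both non-bool -> pass
    return gold_bool == extracted_bool


print(boolean_exact_outcome(True, True))   # True
print(boolean_exact_outcome(True, "yes"))  # False ("yes" is not coerced)
print(boolean_exact_outcome(None, None))   # True (both absent)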