fba_bench_core-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fba_bench_core/__init__.py +11 -0
- fba_bench_core/agents/__init__.py +15 -0
- fba_bench_core/agents/base.py +83 -0
- fba_bench_core/agents/registry.py +16 -0
- fba_bench_core/benchmarking/__init__.py +6 -0
- fba_bench_core/benchmarking/core/__init__.py +1 -0
- fba_bench_core/benchmarking/engine/__init__.py +12 -0
- fba_bench_core/benchmarking/engine/core.py +135 -0
- fba_bench_core/benchmarking/engine/models.py +62 -0
- fba_bench_core/benchmarking/metrics/__init__.py +30 -0
- fba_bench_core/benchmarking/metrics/accuracy_score.py +27 -0
- fba_bench_core/benchmarking/metrics/aggregate.py +39 -0
- fba_bench_core/benchmarking/metrics/completeness.py +38 -0
- fba_bench_core/benchmarking/metrics/cost_efficiency.py +32 -0
- fba_bench_core/benchmarking/metrics/custom_scriptable.py +17 -0
- fba_bench_core/benchmarking/metrics/keyword_coverage.py +41 -0
- fba_bench_core/benchmarking/metrics/policy_compliance.py +18 -0
- fba_bench_core/benchmarking/metrics/registry.py +57 -0
- fba_bench_core/benchmarking/metrics/robustness.py +27 -0
- fba_bench_core/benchmarking/metrics/technical_performance.py +16 -0
- fba_bench_core/benchmarking/registry.py +48 -0
- fba_bench_core/benchmarking/scenarios/__init__.py +1 -0
- fba_bench_core/benchmarking/scenarios/base.py +36 -0
- fba_bench_core/benchmarking/scenarios/complex_marketplace.py +181 -0
- fba_bench_core/benchmarking/scenarios/multiturn_tool_use.py +176 -0
- fba_bench_core/benchmarking/scenarios/registry.py +18 -0
- fba_bench_core/benchmarking/scenarios/research_summarization.py +141 -0
- fba_bench_core/benchmarking/validators/__init__.py +24 -0
- fba_bench_core/benchmarking/validators/determinism_check.py +95 -0
- fba_bench_core/benchmarking/validators/fairness_balance.py +75 -0
- fba_bench_core/benchmarking/validators/outlier_detection.py +53 -0
- fba_bench_core/benchmarking/validators/registry.py +57 -0
- fba_bench_core/benchmarking/validators/reproducibility_metadata.py +74 -0
- fba_bench_core/benchmarking/validators/schema_adherence.py +59 -0
- fba_bench_core/benchmarking/validators/structural_consistency.py +74 -0
- fba_bench_core/config.py +154 -0
- fba_bench_core/domain/__init__.py +75 -0
- fba_bench_core/domain/events/__init__.py +230 -0
- fba_bench_core/domain/events/analytics.py +69 -0
- fba_bench_core/domain/events/base.py +59 -0
- fba_bench_core/domain/events/inventory.py +119 -0
- fba_bench_core/domain/events/marketing.py +102 -0
- fba_bench_core/domain/events/pricing.py +179 -0
- fba_bench_core/domain/models.py +296 -0
- fba_bench_core/exceptions/__init__.py +9 -0
- fba_bench_core/exceptions/base.py +46 -0
- fba_bench_core/services/__init__.py +12 -0
- fba_bench_core/services/base.py +52 -0
- fba_bench_core-1.0.0.dist-info/METADATA +152 -0
- fba_bench_core-1.0.0.dist-info/RECORD +52 -0
- fba_bench_core-1.0.0.dist-info/WHEEL +4 -0
- fba_bench_core-1.0.0.dist-info/licenses/LICENSE +21 -0
fba_bench_core/__init__.py
@@ -0,0 +1,11 @@
"""Phase 2 scaffold for the fba_bench_core package.

This module is a minimal placeholder created during Phase 2 of the
core rebuild. It intentionally contains no runtime logic; only the
package-level metadata required to allow safe imports and satisfy
linters during the rescue process.
"""

__all__ = []

__version__ = "1.0.0"
fba_bench_core/agents/__init__.py
@@ -0,0 +1,15 @@
"""Package exports for fba_bench_core.agents.

Exports the BaseAgent abstract class and exposes the registry module to allow
external code to discover and register agent implementations. Also export the
typed base configuration model to make it easy for downstream users to extend.
"""

from __future__ import annotations

from fba_bench_core.config import BaseAgentConfig

from . import registry
from .base import BaseAgent

__all__ = ["BaseAgent", "registry", "BaseAgentConfig"]
fba_bench_core/agents/base.py
@@ -0,0 +1,83 @@
"""Typed BaseAgent for fba_bench_core.

Phase D change:
- Replace legacy **kwargs configuration with a typed Pydantic configuration
  object. Downstream implementations should subclass the provided config model
  for specialized parameters.
"""

from __future__ import annotations

from abc import ABC, abstractmethod

from fba_bench_core.config import BaseAgentConfig
from fba_bench_core.domain.events import BaseEvent, Command


class BaseAgent(ABC):
    """Abstract base class for agents that receive a validated configuration.

    Rationale:
        Using a typed configuration object prevents downstream implementations
        from hiding untyped parameters behind `Any` and enables validation at
        construction time. Implementations that require additional fields may
        subclass `BaseAgentConfig` (see examples in the config module).

    Initialization:
        The agent receives a single `config: BaseAgentConfig` argument. The
        `agent_id` is expected to be present on that config and becomes the
        agent's immutable identifier.

    Immutability:
        The provided config model is frozen (Pydantic frozen model). The agent
        stores the config object directly and exposes it via a read-only
        property to avoid accidental mutation.
    """

    def __init__(self, config: BaseAgentConfig) -> None:
        """Initialize the base agent with a validated, typed configuration.

        Parameters:
            config: An instance of BaseAgentConfig or a subclass thereof. The
                model is validated by Pydantic prior to construction.

        Notes:
            - Do not accept `**kwargs` here: typed configs are required.
            - The agent keeps a reference to the provided config (which is
              immutable/frozen). Use `agent.config.model_copy()` to obtain a
              mutable copy if necessary.
        """
        self._config = config
        self._agent_id = config.agent_id

    @property
    def agent_id(self) -> str:
        """Return the agent's unique identifier (from config.agent_id)."""
        return self._agent_id

    @property
    def config(self) -> BaseAgentConfig:
        """Return the typed configuration object for this agent.

        The returned object is immutable (Pydantic frozen model). Downstream
        code that needs to modify configuration should create a new instance
        (e.g., via `model_copy(update={...})`).
        """
        return self._config

    def get_config(self) -> dict:
        """Return a serializable shallow mapping of the configuration.

        Returns:
            A dict produced by Pydantic's model_dump() representing the config.
        """
        return self._config.model_dump()

    @abstractmethod
    async def decide(self, events: list[BaseEvent]) -> list[Command]:
        """Decide on a list of Commands given observed domain events.

        Implementations must be async coroutines and must not mutate the
        provided `events` list.
        """
        raise NotImplementedError
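For orientation, here is a minimal sketch of a downstream implementation, assuming `agent_id` is the only required field on `BaseAgentConfig`; the `GreedyPricingConfig`/`GreedyPricingAgent` names and the `max_discount` field are illustrative, not part of the package:

```python
from fba_bench_core.agents import BaseAgent, BaseAgentConfig
from fba_bench_core.domain.events import BaseEvent, Command


class GreedyPricingConfig(BaseAgentConfig):
    """Hypothetical config subclass adding one extra, validated field."""

    max_discount: float = 0.10


class GreedyPricingAgent(BaseAgent):
    """Hypothetical agent; decide() is async and must not mutate `events`."""

    async def decide(self, events: list[BaseEvent]) -> list[Command]:
        # Inspect the observed events and emit no commands in this sketch.
        return []


agent = GreedyPricingAgent(GreedyPricingConfig(agent_id="pricer-1"))
assert agent.agent_id == "pricer-1"
assert agent.get_config()["max_discount"] == 0.10
```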
fba_bench_core/agents/registry.py
@@ -0,0 +1,16 @@
"""Phase 5 placeholder for the agent registry.

This module will hold mappings of agent names to agent classes and will be
populated during Phase 5. It intentionally contains no runtime logic now.
"""

AGENT_REGISTRY: dict[str, type] = {}  # Populated in a later phase (Phase 5)


def create_runner(key: str, config: dict):
    """Stub function for creating runners."""

    class DummyRunner:
        agent_id = config.get("agent_id", "dummy")

    return DummyRunner()
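As a quick check of the stub's behavior as written, the returned dummy runner simply echoes the configured `agent_id` (or falls back to `"dummy"`):

```python
from fba_bench_core.agents.registry import create_runner

runner = create_runner("any-key", {"agent_id": "agent-7"})
print(runner.agent_id)  # -> "agent-7"
```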
fba_bench_core/benchmarking/core/__init__.py
@@ -0,0 +1 @@
# Benchmarking core module
fba_bench_core/benchmarking/engine/__init__.py
@@ -0,0 +1,12 @@
"""Public API for the FBA-Bench Benchmarking Engine."""

from .core import Engine
from .models import EngineConfig, EngineReport, RunnerSpec, ScenarioSpec

__all__ = [
    "Engine",
    "EngineConfig",
    "EngineReport",
    "RunnerSpec",
    "ScenarioSpec",
]
fba_bench_core/benchmarking/engine/core.py
@@ -0,0 +1,135 @@
import asyncio

from agent_runners.registry import create_runner
from fba_bench_core.benchmarking.metrics.registry import MetricRegistry
from fba_bench_core.benchmarking.scenarios.registry import scenario_registry
from fba_bench_core.benchmarking.validators.registry import ValidatorRegistry

from .models import EngineConfig, EngineReport, RunReport, ScenarioReport


class Engine:
    def __init__(self, config: EngineConfig):
        self.config = config

    async def run(self) -> EngineReport:
        scenario_reports = []
        for scenario_spec in self.config.scenarios:
            scenario_fn = scenario_registry.get(scenario_spec.key)
            if scenario_fn is None:
                raise ValueError(
                    f"Scenario '{scenario_spec.key}' not found in registry"
                )

            # Assume first runner for minimal implementation
            if not self.config.runners:
                raise ValueError("No runners configured")
            runner_spec = self.config.runners[0]
            runner = create_runner(runner_spec.key, runner_spec.config)

            repetitions = scenario_spec.repetitions or 1
            seeds = scenario_spec.seeds or [42] * repetitions
            scenario_runs = []
            timeout = scenario_spec.timeout_seconds
            retries = self.config.retries

            for i in range(repetitions):
                seed = seeds[i] if i < len(seeds) else 42
                payload = {"seed": seed}
                status = None
                output = None
                for attempt in range(retries + 1):
                    try:
                        if timeout is not None:
                            output = await asyncio.wait_for(
                                scenario_fn(runner, payload), timeout=timeout
                            )
                        else:
                            output = await scenario_fn(runner, payload)
                        status = "success"
                        break
                    except TimeoutError:
                        status = "timeout"
                        output = None
                        break  # Do not retry timeouts
                    except Exception:
                        if attempt < retries:
                            continue
                        status = "error"
                        output = None
                        break

                # Apply metrics
                metrics = {}
                if self.config.metrics:
                    metric_reg = MetricRegistry()
                    for metric_name in self.config.metrics:
                        try:
                            metric = metric_reg.create_metric(metric_name)
                            if metric:
                                metrics[metric_name] = (
                                    metric.calculate(output or {})
                                    if output is not None
                                    else 0.0
                                )
                        except Exception:
                            metrics[metric_name] = 0.0

                run_report = RunReport(
                    status=status,
                    output=output,
                    seed=seed,
                    metrics=metrics,
                )
                scenario_runs.append(run_report)

            # Compute aggregates for metrics
            aggregates = {}
            if scenario_runs and self.config.metrics:
                all_metrics = set()
                for r in scenario_runs:
                    all_metrics.update(r.metrics.keys())
                mean_metrics = {}
                for m in all_metrics:
                    values = [r.metrics.get(m, 0.0) for r in scenario_runs]
                    mean_metrics[m] = sum(values) / len(values)
                aggregates["metrics"] = {"mean": mean_metrics}

            # Apply validators
            if self.config.validators and scenario_runs:
                validations = []
                val_reg = ValidatorRegistry()
                run_data = [r.model_dump() for r in scenario_runs]
                for val_name in self.config.validators:
                    try:
                        validator = val_reg.create_validator(val_name)
                        if validator:
                            result = validator.validate(run_data)
                            result_dict = (
                                result.to_dict()
                                if hasattr(result, "to_dict")
                                else result
                            )
                            validations.append(
                                {"name": val_name, "result": result_dict}
                            )
                    except Exception:
                        validations.append(
                            {
                                "name": val_name,
                                "result": {
                                    "is_valid": False,
                                    "error": "Validation failed",
                                },
                            }
                        )
                aggregates["validations"] = validations

            scenario_report = ScenarioReport(
                key=scenario_spec.key,
                runs=scenario_runs,
                aggregates=aggregates,
            )
            scenario_reports.append(scenario_report)

        return EngineReport(scenario_reports=scenario_reports)
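A minimal driver sketch for the loop above. The scenario key, runner key, and metric name are illustrative; they assume entries have already been registered in `scenario_registry` and in `agent_runners.registry`, both of which live outside this module. Note also that the engine calls `.calculate(...)` on whatever `create_metric` returns, so metrics registered as plain functions fall back to 0.0 in this path.

```python
import asyncio

from fba_bench_core.benchmarking.engine import (
    Engine,
    EngineConfig,
    RunnerSpec,
    ScenarioSpec,
)

# Placeholder keys: "research_summarization" and "dummy" must already be
# registered for the run to succeed.
config = EngineConfig(
    scenarios=[
        ScenarioSpec(key="research_summarization", repetitions=2, seeds=[1, 2])
    ],
    runners=[RunnerSpec(key="dummy", config={"agent_id": "agent-1"})],
    metrics=["keyword_coverage"],
)

report = asyncio.run(Engine(config).run())
for scenario in report.scenario_reports:
    print(scenario.key, scenario.aggregates.get("metrics"))
```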
fba_bench_core/benchmarking/engine/models.py
@@ -0,0 +1,62 @@
from typing import Any

from pydantic import BaseModel, field_validator


class RunnerSpec(BaseModel):
    key: str
    config: dict[str, Any]


class ScenarioSpec(BaseModel):
    key: str
    timeout_seconds: float | None = None
    repetitions: int | None = None
    seeds: list[int] | None = None


class EngineConfig(BaseModel):
    scenarios: list[ScenarioSpec]
    runners: list[RunnerSpec]
    metrics: list[str] | None = None
    validators: list[str] | None = None
    parallelism: int = 1
    retries: int = 0

    @field_validator("parallelism")
    @classmethod
    def validate_parallelism(cls, v):
        if v <= 0:
            raise ValueError("parallelism must be greater than 0")
        return v

    @field_validator("scenarios")
    @classmethod
    def validate_scenarios(cls, v):
        if not v:
            raise ValueError("scenarios cannot be empty")
        return v

    @field_validator("runners")
    @classmethod
    def validate_runners(cls, v):
        if not v:
            raise ValueError("runners cannot be empty")
        return v


class RunReport(BaseModel):
    status: str
    output: Any = None
    seed: int | None = None
    metrics: dict[str, Any] = {}


class ScenarioReport(BaseModel):
    key: str
    runs: list[RunReport] = []
    aggregates: dict[str, Any] = {}


class EngineReport(BaseModel):
    scenario_reports: list[ScenarioReport]
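The field validators above reject empty scenario/runner lists and non-positive parallelism at construction time; a quick illustration:

```python
from pydantic import ValidationError

from fba_bench_core.benchmarking.engine import EngineConfig, RunnerSpec

try:
    EngineConfig(
        scenarios=[],
        runners=[RunnerSpec(key="dummy", config={})],
        parallelism=0,
    )
except ValidationError as exc:
    # 2 errors: "scenarios cannot be empty" and "parallelism must be greater than 0"
    print(len(exc.errors()))
```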
fba_bench_core/benchmarking/metrics/__init__.py
@@ -0,0 +1,30 @@
"""Metrics module for benchmarking."""

# Import metric modules to register them
from . import (
    accuracy_score,
    aggregate,
    completeness,
    cost_efficiency,
    custom_scriptable,
    keyword_coverage,
    policy_compliance,
    robustness,
    technical_performance,
)
from .registry import get_metric, list_metrics, register_metric

__all__ = [
    "accuracy_score",
    "aggregate",
    "completeness",
    "cost_efficiency",
    "custom_scriptable",
    "keyword_coverage",
    "policy_compliance",
    "robustness",
    "technical_performance",
    "get_metric",
    "list_metrics",
    "register_metric",
]
fba_bench_core/benchmarking/metrics/accuracy_score.py
@@ -0,0 +1,27 @@
"""Accuracy score metric."""

from typing import Any

from .registry import register_metric


@register_metric("accuracy_score")
def accuracy_score(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate accuracy score."""
    output = run.get("output", "")
    expected = config.get("expected_output", "")
    mode = config.get("mode", "exact")
    case_insensitive = config.get("case_insensitive", False)

    if case_insensitive:
        output = output.lower()
        expected = expected.lower()

    if mode == "exact":
        score = 1.0 if output == expected else 0.0
    elif mode == "contains":
        score = 1.0 if expected in output else 0.0
    else:
        score = 0.0  # default

    return {"mode": mode, "accuracy": score}
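To illustrate the two matching modes exactly as implemented above:

```python
from fba_bench_core.benchmarking.metrics.accuracy_score import accuracy_score

run = {"output": "The answer is 42."}

# Exact match after lowercasing both sides.
print(accuracy_score(run, {"expected_output": "the answer is 42.", "case_insensitive": True}))
# {'mode': 'exact', 'accuracy': 1.0}

# Substring match.
print(accuracy_score(run, {"expected_output": "42", "mode": "contains"}))
# {'mode': 'contains', 'accuracy': 1.0}
```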
fba_bench_core/benchmarking/metrics/aggregate.py
@@ -0,0 +1,39 @@
"""Aggregation utilities for metrics."""

from typing import Any


def aggregate_all(data: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate all metrics from a list of data points."""
    if not data:
        return {}
    # Simple implementation: return the first item
    return data[0] if data else {}


def aggregate_metric_values(data: list[dict[str, Any]], field: str) -> dict[str, Any]:
    """Aggregate metric values for a specific field."""
    values = [item.get(field) for item in data if field in item]
    if not values:
        return {}
    # Filter out None values for type safety
    clean_values = [v for v in values if v is not None]
    if not clean_values:
        return {}
    numeric_values = [v for v in clean_values if isinstance(v, (int, float))]
    boolean_values = [v for v in clean_values if isinstance(v, bool)]
    by_field = {}
    if numeric_values:
        by_field["numeric"] = {
            "mean": sum(numeric_values) / len(numeric_values),
            "min": min(numeric_values),
            "max": max(numeric_values),
            "count": len(numeric_values),
        }
    if boolean_values:
        by_field["boolean"] = {
            "success_rate": sum(boolean_values) / len(boolean_values),
            "true_count": sum(1 for b in boolean_values if b),
            "false_count": len(boolean_values) - sum(1 for b in boolean_values if b),
        }
    return by_field
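A short example of the field-level aggregation above, using purely numeric values (items that lack the field are skipped):

```python
from fba_bench_core.benchmarking.metrics.aggregate import aggregate_metric_values

runs = [{"latency": 0.8}, {"latency": 1.2}, {"other": 3}]
print(aggregate_metric_values(runs, "latency"))
# {'numeric': {'mean': 1.0, 'min': 0.8, 'max': 1.2, 'count': 2}}
```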
fba_bench_core/benchmarking/metrics/completeness.py
@@ -0,0 +1,38 @@
"""Completeness metric."""

from typing import Any

from .registry import register_metric


@register_metric("completeness")
def completeness(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate completeness."""
    output = run.get("output", {})
    required_fields = config.get("required_fields", [])
    allow_nested = config.get("allow_nested", False)
    if not required_fields:
        return {"required": 0, "present": 0, "completeness": 1.0}

    present = 0
    for field in required_fields:
        if allow_nested and "." in field:
            keys = field.split(".")
            current = output
            found = True
            for key in keys:
                if isinstance(current, dict) and key in current:
                    current = current[key]
                else:
                    found = False
                    break
            if found:
                present += 1
        elif field in output:
            present += 1

    return {
        "required": len(required_fields),
        "present": present,
        "completeness": present / len(required_fields),
    }
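Example showing flat and dotted-path lookups as handled above:

```python
from fba_bench_core.benchmarking.metrics.completeness import completeness

run = {"output": {"summary": "...", "meta": {"source": "web"}}}
cfg = {"required_fields": ["summary", "meta.source", "citations"], "allow_nested": True}
print(completeness(run, cfg))
# 2 of the 3 required fields are present:
# {'required': 3, 'present': 2, 'completeness': 0.666...}
```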
fba_bench_core/benchmarking/metrics/cost_efficiency.py
@@ -0,0 +1,32 @@
"""Cost efficiency metric."""

from typing import Any

from .registry import register_metric


@register_metric("cost_efficiency")
def cost_efficiency(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate cost efficiency."""
    output = run.get("output", {})
    cost = output.get("cost", 0)
    token_usage = output.get("token_usage", {})
    total_tokens = token_usage.get("total_tokens", 0)
    token_to_cost_rate = config.get("token_to_cost_rate", 0)
    score_value = config.get("score_value", 1.0)

    if cost > 0:
        efficiency = score_value / cost
        supported = True
        reason = None
    elif total_tokens > 0 and token_to_cost_rate > 0:
        cost = total_tokens * token_to_cost_rate
        efficiency = score_value / cost
        supported = True
        reason = None
    else:
        efficiency = 0.0
        supported = False
        reason = "missing_usage"

    return {"supported": supported, "reason": reason, "efficiency": efficiency}
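When no explicit cost is present, the metric falls back to token usage times the configured rate, for example:

```python
from fba_bench_core.benchmarking.metrics.cost_efficiency import cost_efficiency

run = {"output": {"token_usage": {"total_tokens": 2000}}}
cfg = {"token_to_cost_rate": 0.001, "score_value": 1.0}
print(cost_efficiency(run, cfg))
# derived cost = 2000 * 0.001 = 2.0, so efficiency = 1.0 / 2.0
# {'supported': True, 'reason': None, 'efficiency': 0.5}
```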
fba_bench_core/benchmarking/metrics/custom_scriptable.py
@@ -0,0 +1,17 @@
"""Custom scriptable metric."""

from typing import Any

from .registry import register_metric


@register_metric("custom_scriptable")
def custom_scriptable(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate custom scriptable metric."""
    expression = config.get("expression", "0")
    try:
        # Simple eval with run and config in scope, restricted globals
        result = eval(expression, {"__builtins__": {}}, {"run": run, "config": config})
        return {"result": result, "expression": expression}
    except Exception as e:
        return {"result": False, "error": str(e), "expression": expression}
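The expression is evaluated with `run` and `config` in scope and built-ins stripped; since it is still `eval`, expressions should come from trusted configuration only. For example:

```python
from fba_bench_core.benchmarking.metrics.custom_scriptable import custom_scriptable

run = {"output": {"score": 7}}
cfg = {"expression": "run['output']['score'] / 10"}
print(custom_scriptable(run, cfg))
# {'result': 0.7, 'expression': "run['output']['score'] / 10"}
```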
fba_bench_core/benchmarking/metrics/keyword_coverage.py
@@ -0,0 +1,41 @@
"""Keyword coverage metric."""

from typing import Any

from .registry import register_metric


@register_metric("keyword_coverage")
def keyword_coverage(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate keyword coverage."""
    field_path = config.get("field_path", "")
    keywords = config.get("keywords", [])
    unique_match = config.get("unique_match", True)

    text = ""
    data = run.get("output", {})
    if field_path:
        keys = field_path.split(".")
        for key in keys:
            if isinstance(data, dict):
                data = data.get(key, "")
            else:
                break
        text = str(data) if data else ""
    else:
        text = str(data)

    if not keywords:
        return {"keyword_total": 0, "keyword_hits": 0, "coverage": 0.0}

    if unique_match:
        found = set(kw for kw in keywords if kw in text)
        hits = len(found)
    else:
        hits = sum(text.count(kw) for kw in keywords)

    return {
        "keyword_total": len(keywords),
        "keyword_hits": hits,
        "coverage": hits / len(keywords),
    }
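Matching is a case-sensitive substring check against the selected field, for example:

```python
from fba_bench_core.benchmarking.metrics.keyword_coverage import keyword_coverage

run = {"output": {"summary": "Revenue grew while inventory costs fell."}}
cfg = {"field_path": "summary", "keywords": ["revenue", "inventory", "Revenue"]}
print(keyword_coverage(run, cfg))
# lowercase "revenue" misses, so 2 of 3 keywords are counted:
# {'keyword_total': 3, 'keyword_hits': 2, 'coverage': 0.666...}
```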
fba_bench_core/benchmarking/metrics/policy_compliance.py
@@ -0,0 +1,18 @@
"""Policy compliance metric."""

from typing import Any

from .registry import register_metric


@register_metric("policy_compliance")
def policy_compliance(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    """Calculate policy compliance."""
    output = run.get("output", {})
    violations = output.get("policy_violations", 0)
    if isinstance(violations, (list, tuple)):
        violations = len(violations)
    elif isinstance(violations, dict):
        violations = violations.get("count", 0)
    compliant = violations == 0
    return {"policy_violations": violations, "compliant": compliant}
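The violation count is normalized from ints, sequences, or dicts before the compliance flag is derived:

```python
from fba_bench_core.benchmarking.metrics.policy_compliance import policy_compliance

print(policy_compliance({"output": {"policy_violations": []}}, {}))
# {'policy_violations': 0, 'compliant': True}

print(policy_compliance({"output": {"policy_violations": ["price_gouging"]}}, {}))
# {'policy_violations': 1, 'compliant': False}
```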
fba_bench_core/benchmarking/metrics/registry.py
@@ -0,0 +1,57 @@
"""Registry for metrics."""

from collections.abc import Callable


class MetricRegistry:
    _metrics: dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str, metric_class: Callable) -> None:
        """Register a metric class."""
        cls._metrics[name] = metric_class

    @classmethod
    def create_metric(cls, name: str, config=None) -> Callable | None:
        """Create a metric instance."""
        fn = cls._metrics.get(name)
        if fn:
            return fn
        return None

    @classmethod
    def get_metric(cls, name: str) -> Callable | None:
        """Get a metric class by name."""
        return cls._metrics.get(name)

    @classmethod
    def list_metrics(cls) -> list[str]:
        """List all registered metric names."""
        return list(cls._metrics.keys())


# Global instance for function-based API
registry = MetricRegistry()


def get_metric(name: str) -> Callable:
    """Get a metric by name, raising KeyError if not found."""
    metric = registry.get_metric(name)
    if metric is None:
        raise KeyError(f"Metric '{name}' not found")
    return metric


def list_metrics() -> list[str]:
    """List all registered metric names."""
    return registry.list_metrics()


def register_metric(name: str):
    """Decorator to register a metric function with the given name."""

    def decorator(func: Callable) -> Callable:
        registry.register(name, func)
        return func

    return decorator
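A sketch of registering a project-specific metric through the decorator above; `latency_under_budget` and its config fields are hypothetical names used only for illustration:

```python
from typing import Any

from fba_bench_core.benchmarking.metrics.registry import (
    get_metric,
    list_metrics,
    register_metric,
)


@register_metric("latency_under_budget")
def latency_under_budget(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
    # Compare the run's reported latency against a configured budget.
    latency = run.get("output", {}).get("latency_seconds", 0.0)
    budget = config.get("budget_seconds", 1.0)
    return {"within_budget": latency <= budget, "latency_seconds": latency}


assert "latency_under_budget" in list_metrics()
metric = get_metric("latency_under_budget")
print(metric({"output": {"latency_seconds": 0.4}}, {"budget_seconds": 0.5}))
# {'within_budget': True, 'latency_seconds': 0.4}
```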