fba_bench_core-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. fba_bench_core/__init__.py +11 -0
  2. fba_bench_core/agents/__init__.py +15 -0
  3. fba_bench_core/agents/base.py +83 -0
  4. fba_bench_core/agents/registry.py +16 -0
  5. fba_bench_core/benchmarking/__init__.py +6 -0
  6. fba_bench_core/benchmarking/core/__init__.py +1 -0
  7. fba_bench_core/benchmarking/engine/__init__.py +12 -0
  8. fba_bench_core/benchmarking/engine/core.py +135 -0
  9. fba_bench_core/benchmarking/engine/models.py +62 -0
  10. fba_bench_core/benchmarking/metrics/__init__.py +30 -0
  11. fba_bench_core/benchmarking/metrics/accuracy_score.py +27 -0
  12. fba_bench_core/benchmarking/metrics/aggregate.py +39 -0
  13. fba_bench_core/benchmarking/metrics/completeness.py +38 -0
  14. fba_bench_core/benchmarking/metrics/cost_efficiency.py +32 -0
  15. fba_bench_core/benchmarking/metrics/custom_scriptable.py +17 -0
  16. fba_bench_core/benchmarking/metrics/keyword_coverage.py +41 -0
  17. fba_bench_core/benchmarking/metrics/policy_compliance.py +18 -0
  18. fba_bench_core/benchmarking/metrics/registry.py +57 -0
  19. fba_bench_core/benchmarking/metrics/robustness.py +27 -0
  20. fba_bench_core/benchmarking/metrics/technical_performance.py +16 -0
  21. fba_bench_core/benchmarking/registry.py +48 -0
  22. fba_bench_core/benchmarking/scenarios/__init__.py +1 -0
  23. fba_bench_core/benchmarking/scenarios/base.py +36 -0
  24. fba_bench_core/benchmarking/scenarios/complex_marketplace.py +181 -0
  25. fba_bench_core/benchmarking/scenarios/multiturn_tool_use.py +176 -0
  26. fba_bench_core/benchmarking/scenarios/registry.py +18 -0
  27. fba_bench_core/benchmarking/scenarios/research_summarization.py +141 -0
  28. fba_bench_core/benchmarking/validators/__init__.py +24 -0
  29. fba_bench_core/benchmarking/validators/determinism_check.py +95 -0
  30. fba_bench_core/benchmarking/validators/fairness_balance.py +75 -0
  31. fba_bench_core/benchmarking/validators/outlier_detection.py +53 -0
  32. fba_bench_core/benchmarking/validators/registry.py +57 -0
  33. fba_bench_core/benchmarking/validators/reproducibility_metadata.py +74 -0
  34. fba_bench_core/benchmarking/validators/schema_adherence.py +59 -0
  35. fba_bench_core/benchmarking/validators/structural_consistency.py +74 -0
  36. fba_bench_core/config.py +154 -0
  37. fba_bench_core/domain/__init__.py +75 -0
  38. fba_bench_core/domain/events/__init__.py +230 -0
  39. fba_bench_core/domain/events/analytics.py +69 -0
  40. fba_bench_core/domain/events/base.py +59 -0
  41. fba_bench_core/domain/events/inventory.py +119 -0
  42. fba_bench_core/domain/events/marketing.py +102 -0
  43. fba_bench_core/domain/events/pricing.py +179 -0
  44. fba_bench_core/domain/models.py +296 -0
  45. fba_bench_core/exceptions/__init__.py +9 -0
  46. fba_bench_core/exceptions/base.py +46 -0
  47. fba_bench_core/services/__init__.py +12 -0
  48. fba_bench_core/services/base.py +52 -0
  49. fba_bench_core-1.0.0.dist-info/METADATA +152 -0
  50. fba_bench_core-1.0.0.dist-info/RECORD +52 -0
  51. fba_bench_core-1.0.0.dist-info/WHEEL +4 -0
  52. fba_bench_core-1.0.0.dist-info/licenses/LICENSE +21 -0
fba_bench_core/__init__.py
@@ -0,0 +1,11 @@
+ """Phase 2 scaffold for the fba_bench_core package.
+
+ This module is a minimal placeholder created during Phase 2 of the
+ core rebuild. It intentionally contains no runtime logic; only the
+ package-level metadata required to allow safe imports and satisfy
+ linters during the rescue process.
+ """
+
+ __all__ = []
+
+ __version__ = "1.0.0"
fba_bench_core/agents/__init__.py
@@ -0,0 +1,15 @@
+ """Package exports for fba_bench_core.agents.
+
+ Exports the BaseAgent abstract class and exposes the registry module to allow
+ external code to discover and register agent implementations. Also export the
+ typed base configuration model to make it easy for downstream users to extend.
+ """
+
+ from __future__ import annotations
+
+ from fba_bench_core.config import BaseAgentConfig
+
+ from . import registry
+ from .base import BaseAgent
+
+ __all__ = ["BaseAgent", "registry", "BaseAgentConfig"]
fba_bench_core/agents/base.py
@@ -0,0 +1,83 @@
+ """Typed BaseAgent for fba_bench_core.
+
+ Phase D change:
+ - Replace legacy **kwargs configuration with a typed Pydantic configuration
+   object. Downstream implementations should subclass the provided config model
+   for specialized parameters.
+ """
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+
+ from fba_bench_core.config import BaseAgentConfig
+ from fba_bench_core.domain.events import BaseEvent, Command
+
+
+ class BaseAgent(ABC):
+     """Abstract base class for agents that receive a validated configuration.
+
+     Rationale:
+         Using a typed configuration object prevents downstream implementations
+         from hiding untyped parameters behind `Any` and enables validation at
+         construction time. Implementations that require additional fields may
+         subclass `BaseAgentConfig` (see examples in the config module).
+
+     Initialization:
+         The agent receives a single `config: BaseAgentConfig` argument. The
+         `agent_id` is expected to be present on that config and becomes the
+         agent's immutable identifier.
+
+     Immutability:
+         The provided config model is frozen (Pydantic frozen model). The agent
+         stores the config object directly and exposes it via a read-only
+         property to avoid accidental mutation.
+     """
+
+     def __init__(self, config: BaseAgentConfig) -> None:
+         """Initialize the base agent with a validated, typed configuration.
+
+         Parameters:
+             config: An instance of BaseAgentConfig or a subclass thereof. The
+                 model is validated by Pydantic prior to construction.
+
+         Notes:
+             - Do not accept `**kwargs` here: typed configs are required.
+             - The agent keeps a reference to the provided config (which is
+               immutable/frozen). Use `agent.config.model_copy()` to obtain a
+               mutable copy if necessary.
+         """
+         self._config = config
+         self._agent_id = config.agent_id
+
+     @property
+     def agent_id(self) -> str:
+         """Return the agent's unique identifier (from config.agent_id)."""
+         return self._agent_id
+
+     @property
+     def config(self) -> BaseAgentConfig:
+         """Return the typed configuration object for this agent.
+
+         The returned object is immutable (Pydantic frozen model). Downstream
+         code that needs to modify configuration should create a new instance
+         (e.g., via `model_copy(update={...})`).
+         """
+         return self._config
+
+     def get_config(self) -> dict:
+         """Return a serializable shallow mapping of the configuration.
+
+         Returns:
+             A dict produced by Pydantic's model_dump() representing the config.
+         """
+         return self._config.model_dump()
+
+     @abstractmethod
+     async def decide(self, events: list[BaseEvent]) -> list[Command]:
+         """Decide on a list of Commands given observed domain events.
+
+         Implementations must be async coroutines and must not mutate the
+         provided `events` list.
+         """
+         raise NotImplementedError
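To make the contract above concrete, here is a minimal downstream sketch (not part of the package). It assumes `BaseAgentConfig` can be subclassed with extra fields and that `agent_id` is its only required field, as the docstrings suggest; the names `GreeterConfig`, `GreeterAgent`, and the `greeting` field are hypothetical.

from fba_bench_core.agents import BaseAgent, BaseAgentConfig
from fba_bench_core.domain.events import BaseEvent, Command


class GreeterConfig(BaseAgentConfig):
    # Hypothetical extra field layered on top of the frozen base config.
    greeting: str = "hello"


class GreeterAgent(BaseAgent):
    async def decide(self, events: list[BaseEvent]) -> list[Command]:
        # Read-only inspection of events; emit no commands in this sketch.
        return []


agent = GreeterAgent(GreeterConfig(agent_id="agent-1"))
print(agent.agent_id)      # "agent-1"
print(agent.get_config())  # plain dict via model_dump()
tweaked = agent.config.model_copy(update={"greeting": "hi"})  # new, modified copy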
fba_bench_core/agents/registry.py
@@ -0,0 +1,16 @@
+ """Phase 5 placeholder for the agent registry.
+
+ This module will hold mappings of agent names to agent classes and will be
+ populated during Phase 5. It intentionally contains no runtime logic now.
+ """
+
+ AGENT_REGISTRY: dict[str, type] = {}  # Populated in a later phase (Phase 5)
+
+
+ def create_runner(key: str, config: dict):
+     """Stub function for creating runners."""
+
+     class DummyRunner:
+         agent_id = config.get("agent_id", "dummy")
+
+     return DummyRunner()
fba_bench_core/benchmarking/__init__.py
@@ -0,0 +1,6 @@
+ """Benchmarking module for FBA Bench Core.
+
+ This module provides benchmarking functionality including engine, scenarios, metrics, and validators.
+ """
+
+ __all__ = []
fba_bench_core/benchmarking/core/__init__.py
@@ -0,0 +1 @@
+ # Benchmarking core module
fba_bench_core/benchmarking/engine/__init__.py
@@ -0,0 +1,12 @@
+ """Public API for the FBA-Bench Benchmarking Engine."""
+
+ from .core import Engine
+ from .models import EngineConfig, EngineReport, RunnerSpec, ScenarioSpec
+
+ __all__ = [
+     "Engine",
+     "EngineConfig",
+     "EngineReport",
+     "RunnerSpec",
+     "ScenarioSpec",
+ ]
fba_bench_core/benchmarking/engine/core.py
@@ -0,0 +1,135 @@
+ import asyncio
+
+ from agent_runners.registry import create_runner
+ from fba_bench_core.benchmarking.metrics.registry import MetricRegistry
+ from fba_bench_core.benchmarking.scenarios.registry import scenario_registry
+ from fba_bench_core.benchmarking.validators.registry import ValidatorRegistry
+
+ from .models import EngineConfig, EngineReport, RunReport, ScenarioReport
+
+
+ class Engine:
+     def __init__(self, config: EngineConfig):
+         self.config = config
+
+     async def run(self) -> EngineReport:
+         scenario_reports = []
+         for scenario_spec in self.config.scenarios:
+             scenario_fn = scenario_registry.get(scenario_spec.key)
+             if scenario_fn is None:
+                 raise ValueError(
+                     f"Scenario '{scenario_spec.key}' not found in registry"
+                 )
+
+             # Assume first runner for minimal implementation
+             if not self.config.runners:
+                 raise ValueError("No runners configured")
+             runner_spec = self.config.runners[0]
+             runner = create_runner(runner_spec.key, runner_spec.config)
+
+             repetitions = scenario_spec.repetitions or 1
+             seeds = scenario_spec.seeds or [42] * repetitions
+             scenario_runs = []
+             timeout = scenario_spec.timeout_seconds
+             retries = self.config.retries
+
+             for i in range(repetitions):
+                 seed = seeds[i] if i < len(seeds) else 42
+                 payload = {"seed": seed}
+                 status = None
+                 output = None
+                 for attempt in range(retries + 1):
+                     try:
+                         if timeout is not None:
+                             output = await asyncio.wait_for(
+                                 scenario_fn(runner, payload), timeout=timeout
+                             )
+                         else:
+                             output = await scenario_fn(runner, payload)
+                         status = "success"
+                         break
+                     except TimeoutError:
+                         status = "timeout"
+                         output = None
+                         break  # Do not retry timeouts
+                     except Exception:
+                         if attempt < retries:
+                             continue
+                         status = "error"
+                         output = None
+                         break
+
+                 # Apply metrics
+                 metrics = {}
+                 if self.config.metrics:
+                     metric_reg = MetricRegistry()
+                     for metric_name in self.config.metrics:
+                         try:
+                             metric = metric_reg.create_metric(metric_name)
+                             if metric:
+                                 metrics[metric_name] = (
+                                     metric.calculate(output or {})
+                                     if output is not None
+                                     else 0.0
+                                 )
+                         except Exception:
+                             metrics[metric_name] = 0.0
+
+                 run_report = RunReport(
+                     status=status,
+                     output=output,
+                     seed=seed,
+                     metrics=metrics,
+                 )
+                 scenario_runs.append(run_report)
+
+             # Compute aggregates for metrics
+             aggregates = {}
+             if scenario_runs and self.config.metrics:
+                 all_metrics = set()
+                 for r in scenario_runs:
+                     all_metrics.update(r.metrics.keys())
+                 mean_metrics = {}
+                 for m in all_metrics:
+                     values = [r.metrics.get(m, 0.0) for r in scenario_runs]
+                     mean_metrics[m] = sum(values) / len(values)
+                 aggregates["metrics"] = {"mean": mean_metrics}
+
+             # Apply validators
+             if self.config.validators and scenario_runs:
+                 validations = []
+                 val_reg = ValidatorRegistry()
+                 run_data = [r.model_dump() for r in scenario_runs]
+                 for val_name in self.config.validators:
+                     try:
+                         validator = val_reg.create_validator(val_name)
+                         if validator:
+                             result = validator.validate(run_data)
+                             result_dict = (
+                                 result.to_dict()
+                                 if hasattr(result, "to_dict")
+                                 else result
+                             )
+                             validations.append(
+                                 {"name": val_name, "result": result_dict}
+                             )
+                     except Exception:
+                         validations.append(
+                             {
+                                 "name": val_name,
+                                 "result": {
+                                     "is_valid": False,
+                                     "error": "Validation failed",
+                                 },
+                             }
+                         )
+                 aggregates["validations"] = validations
+
+             scenario_report = ScenarioReport(
+                 key=scenario_spec.key,
+                 runs=scenario_runs,
+                 aggregates=aggregates,
+             )
+             scenario_reports.append(scenario_report)
+
+         return EngineReport(scenario_reports=scenario_reports)
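A sketch of driving the engine end to end. The scenario key, runner key, and metric name are illustrative and are assumed to be registered elsewhere (scenarios in `scenario_registry`, runners in `agent_runners.registry`), which is outside the scope of this file.

import asyncio

from fba_bench_core.benchmarking.engine import (
    Engine,
    EngineConfig,
    RunnerSpec,
    ScenarioSpec,
)

# Illustrative keys; they must already be registered for the run to succeed.
config = EngineConfig(
    scenarios=[ScenarioSpec(key="research_summarization", repetitions=2, seeds=[1, 2])],
    runners=[RunnerSpec(key="dummy", config={"agent_id": "agent-1"})],
    metrics=["accuracy_score"],
    retries=1,
)

report = asyncio.run(Engine(config).run())
for scenario in report.scenario_reports:
    print(scenario.key, [r.status for r in scenario.runs], scenario.aggregates.get("metrics"))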
fba_bench_core/benchmarking/engine/models.py
@@ -0,0 +1,62 @@
+ from typing import Any
+
+ from pydantic import BaseModel, field_validator
+
+
+ class RunnerSpec(BaseModel):
+     key: str
+     config: dict[str, Any]
+
+
+ class ScenarioSpec(BaseModel):
+     key: str
+     timeout_seconds: float | None = None
+     repetitions: int | None = None
+     seeds: list[int] | None = None
+
+
+ class EngineConfig(BaseModel):
+     scenarios: list[ScenarioSpec]
+     runners: list[RunnerSpec]
+     metrics: list[str] | None = None
+     validators: list[str] | None = None
+     parallelism: int = 1
+     retries: int = 0
+
+     @field_validator("parallelism")
+     @classmethod
+     def validate_parallelism(cls, v):
+         if v <= 0:
+             raise ValueError("parallelism must be greater than 0")
+         return v
+
+     @field_validator("scenarios")
+     @classmethod
+     def validate_scenarios(cls, v):
+         if not v:
+             raise ValueError("scenarios cannot be empty")
+         return v
+
+     @field_validator("runners")
+     @classmethod
+     def validate_runners(cls, v):
+         if not v:
+             raise ValueError("runners cannot be empty")
+         return v
+
+
+ class RunReport(BaseModel):
+     status: str
+     output: Any = None
+     seed: int | None = None
+     metrics: dict[str, Any] = {}
+
+
+ class ScenarioReport(BaseModel):
+     key: str
+     runs: list[RunReport] = []
+     aggregates: dict[str, Any] = {}
+
+
+ class EngineReport(BaseModel):
+     scenario_reports: list[ScenarioReport]
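The field validators above make invalid configurations fail fast at construction time; a quick sketch of that behaviour using pydantic's ValidationError:

from pydantic import ValidationError

from fba_bench_core.benchmarking.engine.models import EngineConfig, RunnerSpec, ScenarioSpec

try:
    EngineConfig(scenarios=[], runners=[RunnerSpec(key="dummy", config={})])
except ValidationError as exc:
    print(exc)  # message includes "scenarios cannot be empty"

try:
    EngineConfig(
        scenarios=[ScenarioSpec(key="s")],
        runners=[RunnerSpec(key="r", config={})],
        parallelism=0,
    )
except ValidationError as exc:
    print(exc)  # message includes "parallelism must be greater than 0"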
fba_bench_core/benchmarking/metrics/__init__.py
@@ -0,0 +1,30 @@
+ """Metrics module for benchmarking."""
+
+ # Import metric modules to register them
+ from . import (
+     accuracy_score,
+     aggregate,
+     completeness,
+     cost_efficiency,
+     custom_scriptable,
+     keyword_coverage,
+     policy_compliance,
+     robustness,
+     technical_performance,
+ )
+ from .registry import get_metric, list_metrics, register_metric
+
+ __all__ = [
+     "accuracy_score",
+     "aggregate",
+     "completeness",
+     "cost_efficiency",
+     "custom_scriptable",
+     "keyword_coverage",
+     "policy_compliance",
+     "robustness",
+     "technical_performance",
+     "get_metric",
+     "list_metrics",
+     "register_metric",
+ ]
fba_bench_core/benchmarking/metrics/accuracy_score.py
@@ -0,0 +1,27 @@
+ """Accuracy score metric."""
+
+ from typing import Any
+
+ from .registry import register_metric
+
+
+ @register_metric("accuracy_score")
+ def accuracy_score(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+     """Calculate accuracy score."""
+     output = run.get("output", "")
+     expected = config.get("expected_output", "")
+     mode = config.get("mode", "exact")
+     case_insensitive = config.get("case_insensitive", False)
+
+     if case_insensitive:
+         output = output.lower()
+         expected = expected.lower()
+
+     if mode == "exact":
+         score = 1.0 if output == expected else 0.0
+     elif mode == "contains":
+         score = 1.0 if expected in output else 0.0
+     else:
+         score = 0.0  # default
+
+     return {"mode": mode, "accuracy": score}
fba_bench_core/benchmarking/metrics/aggregate.py
@@ -0,0 +1,39 @@
+ """Aggregation utilities for metrics."""
+
+ from typing import Any
+
+
+ def aggregate_all(data: list[dict[str, Any]]) -> dict[str, Any]:
+     """Aggregate all metrics from a list of data points."""
+     if not data:
+         return {}
+     # Simple implementation: return the first item
+     return data[0] if data else {}
+
+
+ def aggregate_metric_values(data: list[dict[str, Any]], field: str) -> dict[str, Any]:
+     """Aggregate metric values for a specific field."""
+     values = [item.get(field) for item in data if field in item]
+     if not values:
+         return {}
+     # Filter out None values for type safety
+     clean_values = [v for v in values if v is not None]
+     if not clean_values:
+         return {}
+     numeric_values = [v for v in clean_values if isinstance(v, (int, float))]
+     boolean_values = [v for v in clean_values if isinstance(v, bool)]
+     by_field = {}
+     if numeric_values:
+         by_field["numeric"] = {
+             "mean": sum(numeric_values) / len(numeric_values),
+             "min": min(numeric_values),
+             "max": max(numeric_values),
+             "count": len(numeric_values),
+         }
+     if boolean_values:
+         by_field["boolean"] = {
+             "success_rate": sum(boolean_values) / len(boolean_values),
+             "true_count": sum(1 for b in boolean_values if b),
+             "false_count": len(boolean_values) - sum(1 for b in boolean_values if b),
+         }
+     return by_field
1
+ """Completeness metric."""
2
+
3
+ from typing import Any
4
+
5
+ from .registry import register_metric
6
+
7
+
8
+ @register_metric("completeness")
9
+ def completeness(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
10
+ """Calculate completeness."""
11
+ output = run.get("output", {})
12
+ required_fields = config.get("required_fields", [])
13
+ allow_nested = config.get("allow_nested", False)
14
+ if not required_fields:
15
+ return {"required": 0, "present": 0, "completeness": 1.0}
16
+
17
+ present = 0
18
+ for field in required_fields:
19
+ if allow_nested and "." in field:
20
+ keys = field.split(".")
21
+ current = output
22
+ found = True
23
+ for key in keys:
24
+ if isinstance(current, dict) and key in current:
25
+ current = current[key]
26
+ else:
27
+ found = False
28
+ break
29
+ if found:
30
+ present += 1
31
+ elif field in output:
32
+ present += 1
33
+
34
+ return {
35
+ "required": len(required_fields),
36
+ "present": present,
37
+ "completeness": present / len(required_fields),
38
+ }
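With `allow_nested` enabled, dotted field paths are resolved through nested dictionaries; an illustrative call (the data is made up):

from fba_bench_core.benchmarking.metrics.completeness import completeness

run = {"output": {"summary": "...", "meta": {"source": "arxiv"}}}
config = {
    "required_fields": ["summary", "meta.source", "citations"],
    "allow_nested": True,
}
print(completeness(run, config))
# {'required': 3, 'present': 2, 'completeness': 0.666...}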
fba_bench_core/benchmarking/metrics/cost_efficiency.py
@@ -0,0 +1,32 @@
+ """Cost efficiency metric."""
+
+ from typing import Any
+
+ from .registry import register_metric
+
+
+ @register_metric("cost_efficiency")
+ def cost_efficiency(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+     """Calculate cost efficiency."""
+     output = run.get("output", {})
+     cost = output.get("cost", 0)
+     token_usage = output.get("token_usage", {})
+     total_tokens = token_usage.get("total_tokens", 0)
+     token_to_cost_rate = config.get("token_to_cost_rate", 0)
+     score_value = config.get("score_value", 1.0)
+
+     if cost > 0:
+         efficiency = score_value / cost
+         supported = True
+         reason = None
+     elif total_tokens > 0 and token_to_cost_rate > 0:
+         cost = total_tokens * token_to_cost_rate
+         efficiency = score_value / cost
+         supported = True
+         reason = None
+     else:
+         efficiency = 0.0
+         supported = False
+         reason = "missing_usage"
+
+     return {"supported": supported, "reason": reason, "efficiency": efficiency}
fba_bench_core/benchmarking/metrics/custom_scriptable.py
@@ -0,0 +1,17 @@
+ """Custom scriptable metric."""
+
+ from typing import Any
+
+ from .registry import register_metric
+
+
+ @register_metric("custom_scriptable")
+ def custom_scriptable(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+     """Calculate custom scriptable metric."""
+     expression = config.get("expression", "0")
+     try:
+         # Simple eval with run and config in scope, restricted globals
+         result = eval(expression, {"__builtins__": {}}, {"run": run, "config": config})
+         return {"result": result, "expression": expression}
+     except Exception as e:
+         return {"result": False, "error": str(e), "expression": expression}
fba_bench_core/benchmarking/metrics/keyword_coverage.py
@@ -0,0 +1,41 @@
+ """Keyword coverage metric."""
+
+ from typing import Any
+
+ from .registry import register_metric
+
+
+ @register_metric("keyword_coverage")
+ def keyword_coverage(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+     """Calculate keyword coverage."""
+     field_path = config.get("field_path", "")
+     keywords = config.get("keywords", [])
+     unique_match = config.get("unique_match", True)
+
+     text = ""
+     data = run.get("output", {})
+     if field_path:
+         keys = field_path.split(".")
+         for key in keys:
+             if isinstance(data, dict):
+                 data = data.get(key, "")
+             else:
+                 break
+         text = str(data) if data else ""
+     else:
+         text = str(data)
+
+     if not keywords:
+         return {"keyword_total": 0, "keyword_hits": 0, "coverage": 0.0}
+
+     if unique_match:
+         found = set(kw for kw in keywords if kw in text)
+         hits = len(found)
+     else:
+         hits = sum(text.count(kw) for kw in keywords)
+
+     return {
+         "keyword_total": len(keywords),
+         "keyword_hits": hits,
+         "coverage": hits / len(keywords),
+     }
fba_bench_core/benchmarking/metrics/policy_compliance.py
@@ -0,0 +1,18 @@
+ """Policy compliance metric."""
+
+ from typing import Any
+
+ from .registry import register_metric
+
+
+ @register_metric("policy_compliance")
+ def policy_compliance(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+     """Calculate policy compliance."""
+     output = run.get("output", {})
+     violations = output.get("policy_violations", 0)
+     if isinstance(violations, (list, tuple)):
+         violations = len(violations)
+     elif isinstance(violations, dict):
+         violations = violations.get("count", 0)
+     compliant = violations == 0
+     return {"policy_violations": violations, "compliant": compliant}
fba_bench_core/benchmarking/metrics/registry.py
@@ -0,0 +1,57 @@
+ """Registry for metrics."""
+
+ from collections.abc import Callable
+
+
+ class MetricRegistry:
+     _metrics: dict[str, Callable] = {}
+
+     @classmethod
+     def register(cls, name: str, metric_class: Callable) -> None:
+         """Register a metric class."""
+         cls._metrics[name] = metric_class
+
+     @classmethod
+     def create_metric(cls, name: str, config=None) -> Callable | None:
+         """Create a metric instance."""
+         fn = cls._metrics.get(name)
+         if fn:
+             return fn
+         return None
+
+     @classmethod
+     def get_metric(cls, name: str) -> Callable | None:
+         """Get a metric class by name."""
+         return cls._metrics.get(name)
+
+     @classmethod
+     def list_metrics(cls) -> list[str]:
+         """List all registered metric names."""
+         return list(cls._metrics.keys())
+
+
+ # Global instance for function-based API
+ registry = MetricRegistry()
+
+
+ def get_metric(name: str) -> Callable:
+     """Get a metric by name, raising KeyError if not found."""
+     metric = registry.get_metric(name)
+     if metric is None:
+         raise KeyError(f"Metric '{name}' not found")
+     return metric
+
+
+ def list_metrics() -> list[str]:
+     """List all registered metric names."""
+     return registry.list_metrics()
+
+
+ def register_metric(name: str):
+     """Decorator to register a metric function with the given name."""
+
+     def decorator(func: Callable) -> Callable:
+         registry.register(name, func)
+         return func
+
+     return decorator
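A sketch of the decorator-based registration flow provided by this module; the `word_count` metric is hypothetical and not shipped with the package:

from fba_bench_core.benchmarking.metrics.registry import (
    get_metric,
    list_metrics,
    register_metric,
)


@register_metric("word_count")  # hypothetical example metric
def word_count(run, config):
    text = str(run.get("output", ""))
    return {"word_count": len(text.split())}


assert "word_count" in list_metrics()
metric = get_metric("word_count")
print(metric({"output": "three short words"}, {}))  # {'word_count': 3}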