fba-bench-core 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fba_bench_core/__init__.py +11 -0
- fba_bench_core/agents/__init__.py +15 -0
- fba_bench_core/agents/base.py +83 -0
- fba_bench_core/agents/registry.py +16 -0
- fba_bench_core/benchmarking/__init__.py +6 -0
- fba_bench_core/benchmarking/core/__init__.py +1 -0
- fba_bench_core/benchmarking/engine/__init__.py +12 -0
- fba_bench_core/benchmarking/engine/core.py +135 -0
- fba_bench_core/benchmarking/engine/models.py +62 -0
- fba_bench_core/benchmarking/metrics/__init__.py +30 -0
- fba_bench_core/benchmarking/metrics/accuracy_score.py +27 -0
- fba_bench_core/benchmarking/metrics/aggregate.py +39 -0
- fba_bench_core/benchmarking/metrics/completeness.py +38 -0
- fba_bench_core/benchmarking/metrics/cost_efficiency.py +32 -0
- fba_bench_core/benchmarking/metrics/custom_scriptable.py +17 -0
- fba_bench_core/benchmarking/metrics/keyword_coverage.py +41 -0
- fba_bench_core/benchmarking/metrics/policy_compliance.py +18 -0
- fba_bench_core/benchmarking/metrics/registry.py +57 -0
- fba_bench_core/benchmarking/metrics/robustness.py +27 -0
- fba_bench_core/benchmarking/metrics/technical_performance.py +16 -0
- fba_bench_core/benchmarking/registry.py +48 -0
- fba_bench_core/benchmarking/scenarios/__init__.py +1 -0
- fba_bench_core/benchmarking/scenarios/base.py +36 -0
- fba_bench_core/benchmarking/scenarios/complex_marketplace.py +181 -0
- fba_bench_core/benchmarking/scenarios/multiturn_tool_use.py +176 -0
- fba_bench_core/benchmarking/scenarios/registry.py +18 -0
- fba_bench_core/benchmarking/scenarios/research_summarization.py +141 -0
- fba_bench_core/benchmarking/validators/__init__.py +24 -0
- fba_bench_core/benchmarking/validators/determinism_check.py +95 -0
- fba_bench_core/benchmarking/validators/fairness_balance.py +75 -0
- fba_bench_core/benchmarking/validators/outlier_detection.py +53 -0
- fba_bench_core/benchmarking/validators/registry.py +57 -0
- fba_bench_core/benchmarking/validators/reproducibility_metadata.py +74 -0
- fba_bench_core/benchmarking/validators/schema_adherence.py +59 -0
- fba_bench_core/benchmarking/validators/structural_consistency.py +74 -0
- fba_bench_core/config.py +154 -0
- fba_bench_core/domain/__init__.py +75 -0
- fba_bench_core/domain/events/__init__.py +230 -0
- fba_bench_core/domain/events/analytics.py +69 -0
- fba_bench_core/domain/events/base.py +59 -0
- fba_bench_core/domain/events/inventory.py +119 -0
- fba_bench_core/domain/events/marketing.py +102 -0
- fba_bench_core/domain/events/pricing.py +179 -0
- fba_bench_core/domain/models.py +296 -0
- fba_bench_core/exceptions/__init__.py +9 -0
- fba_bench_core/exceptions/base.py +46 -0
- fba_bench_core/services/__init__.py +12 -0
- fba_bench_core/services/base.py +52 -0
- fba_bench_core-1.0.0.dist-info/METADATA +152 -0
- fba_bench_core-1.0.0.dist-info/RECORD +52 -0
- fba_bench_core-1.0.0.dist-info/WHEEL +4 -0
- fba_bench_core-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,27 @@
|
|
1
|
+
"""Robustness metric."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_metric
|
6
|
+
|
7
|
+
|
8
|
+
@register_metric("robustness")
|
9
|
+
def robustness(run: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
|
10
|
+
"""Calculate robustness."""
|
11
|
+
output = run.get("output", "")
|
12
|
+
expected_signal = config.get("expected_signal", "")
|
13
|
+
mode = config.get("mode", "exact")
|
14
|
+
|
15
|
+
if mode == "exact_casefold":
|
16
|
+
score = 1.0 if output.lower() == expected_signal.lower() else 0.0
|
17
|
+
elif mode == "normalized_overlap":
|
18
|
+
# Simple overlap for now
|
19
|
+
score = (
|
20
|
+
1.0
|
21
|
+
if any(word in output.lower() for word in expected_signal.lower().split())
|
22
|
+
else 0.0
|
23
|
+
)
|
24
|
+
else:
|
25
|
+
score = 0.0
|
26
|
+
|
27
|
+
return {"mode": mode, "robustness_score": score}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
"""Technical performance metric."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_metric
|
6
|
+
|
7
|
+
|
8
|
+
@register_metric("technical_performance")
|
9
|
+
def technical_performance(
|
10
|
+
run: dict[str, Any], config: dict[str, Any]
|
11
|
+
) -> dict[str, Any]:
|
12
|
+
"""Calculate technical performance metrics."""
|
13
|
+
duration_ms = run.get("duration_ms", 0)
|
14
|
+
latency_threshold_ms = config.get("latency_threshold_ms", 1000)
|
15
|
+
fast_enough = duration_ms <= latency_threshold_ms
|
16
|
+
return {"latency_ms": duration_ms, "fast_enough": fast_enough}
|
@@ -0,0 +1,48 @@
|
|
1
|
+
"""Global registry for benchmarking components."""
|
2
|
+
|
3
|
+
from collections.abc import Callable
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
|
7
|
+
class MetricRegistry:
|
8
|
+
"""Registry for metrics."""
|
9
|
+
|
10
|
+
_metrics: dict[str, Callable] = {}
|
11
|
+
|
12
|
+
@classmethod
|
13
|
+
def register_metric(cls, name: str, metric_fn: Callable) -> None:
|
14
|
+
"""Register a metric function."""
|
15
|
+
cls._metrics[name] = metric_fn
|
16
|
+
|
17
|
+
@classmethod
|
18
|
+
def get_metric(cls, name: str) -> Callable:
|
19
|
+
"""Get a metric function by name."""
|
20
|
+
if name not in cls._metrics:
|
21
|
+
raise KeyError(f"Metric '{name}' not found")
|
22
|
+
return cls._metrics[name]
|
23
|
+
|
24
|
+
@classmethod
|
25
|
+
def list_metrics(cls) -> list[str]:
|
26
|
+
"""List all registered metric names."""
|
27
|
+
return list(cls._metrics.keys())
|
28
|
+
|
29
|
+
@classmethod
|
30
|
+
def create_metric(cls, name: str, config: dict[str, Any] | None = None) -> Any:
|
31
|
+
"""Create a metric instance (for compatibility)."""
|
32
|
+
return cls.get_metric(name)
|
33
|
+
|
34
|
+
|
35
|
+
# Convenience functions
|
36
|
+
def register_metric(name: str, metric_fn: Callable) -> None:
|
37
|
+
"""Register a metric function."""
|
38
|
+
MetricRegistry.register_metric(name, metric_fn)
|
39
|
+
|
40
|
+
|
41
|
+
def get_metric(name: str) -> Callable:
|
42
|
+
"""Get a metric function by name."""
|
43
|
+
return MetricRegistry.get_metric(name)
|
44
|
+
|
45
|
+
|
46
|
+
def list_metrics() -> list[str]:
|
47
|
+
"""List all registered metric names."""
|
48
|
+
return MetricRegistry.list_metrics()
|
@@ -0,0 +1 @@
|
|
1
|
+
# Benchmarking scenarios module
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""Base classes for all scenario types in the FBA-Bench benchmarking framework."""
|
2
|
+
|
3
|
+
import abc
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
|
7
|
+
class BaseScenario(abc.ABC):
|
8
|
+
"""
|
9
|
+
Abstract base class for all benchmark scenarios.
|
10
|
+
|
11
|
+
This class defines the minimal interface required for a scenario to be
|
12
|
+
executable by the modern benchmarking engine.
|
13
|
+
"""
|
14
|
+
|
15
|
+
def __init__(self, params: dict[str, Any] | None = None):
|
16
|
+
"""
|
17
|
+
Initializes the scenario with its parameters.
|
18
|
+
|
19
|
+
Args:
|
20
|
+
params: A dictionary of parameters that configure this scenario.
|
21
|
+
"""
|
22
|
+
self.params = params or {}
|
23
|
+
|
24
|
+
@abc.abstractmethod
|
25
|
+
async def run(self, runner: Any, payload: dict[str, Any]) -> dict[str, Any]:
|
26
|
+
"""
|
27
|
+
Asynchronously executes the scenario with a given agent runner.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
runner: The agent runner instance.
|
31
|
+
payload: A dictionary containing runtime parameters, including the seed.
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
A dictionary containing the results of the scenario run.
|
35
|
+
"""
|
36
|
+
pass
|
@@ -0,0 +1,181 @@
|
|
1
|
+
"""Complex marketplace scenario."""
|
2
|
+
|
3
|
+
import random
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from pydantic import BaseModel, Field
|
7
|
+
|
8
|
+
from .base import BaseScenario
|
9
|
+
|
10
|
+
|
11
|
+
class ComplexMarketplaceConfig(BaseModel):
|
12
|
+
"""Configuration for complex marketplace scenario."""
|
13
|
+
|
14
|
+
num_products: int = Field(default=10, gt=0, description="Number of products")
|
15
|
+
num_orders: int = Field(default=20, gt=0, description="Number of orders")
|
16
|
+
max_quantity: int = Field(default=5, gt=0, description="Maximum quantity per order")
|
17
|
+
price_variance: float = Field(
|
18
|
+
default=0.1, ge=0.0, le=1.0, description="Price variance"
|
19
|
+
)
|
20
|
+
allow_backorder: bool = Field(default=False, description="Allow backorders")
|
21
|
+
|
22
|
+
|
23
|
+
class ComplexMarketplaceScenario(BaseScenario):
|
24
|
+
"""
|
25
|
+
This class now correctly implements the modern BaseScenario API.
|
26
|
+
"""
|
27
|
+
|
28
|
+
def __init__(self, params: dict[str, Any] | None = None):
|
29
|
+
"""
|
30
|
+
The constructor now only accepts a `params` dictionary.
|
31
|
+
"""
|
32
|
+
super().__init__(params)
|
33
|
+
self.config = ComplexMarketplaceConfig(**(self.params or {}))
|
34
|
+
|
35
|
+
async def run(self, runner: Any, payload: dict[str, Any]) -> dict[str, Any]:
|
36
|
+
"""
|
37
|
+
All scenario logic is now contained within this single method.
|
38
|
+
"""
|
39
|
+
seed = payload.get("seed", 42)
|
40
|
+
rng = random.Random(seed)
|
41
|
+
|
42
|
+
# --- 1. Setup Phase (Logic from old `initialize` and `setup_for_agent`) ---
|
43
|
+
world_state = self.setup_simulation(rng)
|
44
|
+
|
45
|
+
original_orders = world_state["remaining_orders"][:]
|
46
|
+
total_requested = sum(order["quantity"] for order in original_orders)
|
47
|
+
|
48
|
+
# --- 2. Execution Loop (Logic from old `run` and `update_tick`) ---
|
49
|
+
for _ in range(self.config.num_orders):
|
50
|
+
self.simulate_market_changes(world_state, self.config.price_variance, rng)
|
51
|
+
|
52
|
+
if not world_state["remaining_orders"]:
|
53
|
+
continue
|
54
|
+
|
55
|
+
current_order = world_state["remaining_orders"].pop(0)
|
56
|
+
|
57
|
+
# Interact with the agent runner
|
58
|
+
agent_input = self.get_percepts_for_agent(world_state, current_order)
|
59
|
+
agent_actions = await runner.process(agent_input)
|
60
|
+
self.apply_agent_actions(world_state, agent_actions, current_order)
|
61
|
+
|
62
|
+
# --- 3. Evaluation and Results (Logic from old `evaluate_agent_performance`) ---
|
63
|
+
final_metrics = self.calculate_final_kpis(world_state, total_requested)
|
64
|
+
|
65
|
+
return {
|
66
|
+
"metrics": final_metrics,
|
67
|
+
"final_world_state": world_state,
|
68
|
+
}
|
69
|
+
|
70
|
+
def setup_simulation(self, rng: random.Random) -> dict[str, Any]:
|
71
|
+
config = self.config
|
72
|
+
num_products = config.num_products
|
73
|
+
catalog: list[dict[str, Any]] = []
|
74
|
+
for i in range(num_products):
|
75
|
+
base_price = 10.0 + i * 0.5
|
76
|
+
price = base_price * (
|
77
|
+
1 + rng.uniform(-config.price_variance, config.price_variance)
|
78
|
+
)
|
79
|
+
catalog.append({"id": i, "price": round(price, 2)})
|
80
|
+
|
81
|
+
orders: list[dict[str, int]] = []
|
82
|
+
for _ in range(config.num_orders):
|
83
|
+
product_id = rng.randint(0, num_products - 1)
|
84
|
+
quantity = rng.randint(1, config.max_quantity)
|
85
|
+
orders.append({"product_id": product_id, "quantity": quantity})
|
86
|
+
|
87
|
+
inventory = {i: rng.randint(10, 100) for i in range(num_products)}
|
88
|
+
|
89
|
+
world_state = {
|
90
|
+
"inventory": inventory,
|
91
|
+
"catalog": catalog,
|
92
|
+
"remaining_orders": orders,
|
93
|
+
"history": [],
|
94
|
+
"total_revenue": 0.0,
|
95
|
+
"total_fulfilled": 0,
|
96
|
+
}
|
97
|
+
return world_state
|
98
|
+
|
99
|
+
def simulate_market_changes(
|
100
|
+
self,
|
101
|
+
world_state: dict[str, Any],
|
102
|
+
volatility: float,
|
103
|
+
rng: random.Random,
|
104
|
+
) -> None:
|
105
|
+
for product in world_state["catalog"]:
|
106
|
+
product["price"] *= 1 + rng.uniform(-volatility, volatility)
|
107
|
+
product["price"] = round(product["price"], 2)
|
108
|
+
|
109
|
+
def get_percepts_for_agent(
|
110
|
+
self,
|
111
|
+
world_state: dict[str, Any],
|
112
|
+
current_order: dict[str, int],
|
113
|
+
) -> dict[str, Any]:
|
114
|
+
return {
|
115
|
+
"type": "process_order",
|
116
|
+
"order": current_order,
|
117
|
+
"available_inventory": {
|
118
|
+
pid: qty for pid, qty in world_state["inventory"].items() if qty > 0
|
119
|
+
},
|
120
|
+
"catalog": [p.copy() for p in world_state["catalog"]],
|
121
|
+
}
|
122
|
+
|
123
|
+
def apply_agent_actions(
|
124
|
+
self,
|
125
|
+
world_state: dict[str, Any],
|
126
|
+
actions: dict[str, Any],
|
127
|
+
current_order: dict[str, int],
|
128
|
+
) -> None:
|
129
|
+
product_id = current_order["product_id"]
|
130
|
+
requested = current_order["quantity"]
|
131
|
+
available = world_state["inventory"].get(product_id, 0)
|
132
|
+
|
133
|
+
fulfilled_quantity = min(
|
134
|
+
actions.get("fulfilled_quantity", requested), requested
|
135
|
+
)
|
136
|
+
|
137
|
+
use_backorder = (
|
138
|
+
actions.get("use_backorder", False) and self.config.allow_backorder
|
139
|
+
)
|
140
|
+
if use_backorder:
|
141
|
+
world_state["inventory"][product_id] -= fulfilled_quantity
|
142
|
+
else:
|
143
|
+
fulfilled_quantity = min(fulfilled_quantity, available)
|
144
|
+
world_state["inventory"][product_id] -= fulfilled_quantity
|
145
|
+
|
146
|
+
price = next(
|
147
|
+
p["price"] for p in world_state["catalog"] if p["id"] == product_id
|
148
|
+
)
|
149
|
+
revenue = fulfilled_quantity * price
|
150
|
+
world_state["total_revenue"] += revenue
|
151
|
+
world_state["total_fulfilled"] += fulfilled_quantity
|
152
|
+
|
153
|
+
world_state["history"].append(
|
154
|
+
{
|
155
|
+
"order": current_order,
|
156
|
+
"actions": actions,
|
157
|
+
"fulfilled": fulfilled_quantity,
|
158
|
+
"revenue": revenue,
|
159
|
+
"price": price,
|
160
|
+
}
|
161
|
+
)
|
162
|
+
|
163
|
+
def calculate_final_kpis(
|
164
|
+
self,
|
165
|
+
world_state: dict[str, Any],
|
166
|
+
total_requested: int,
|
167
|
+
) -> dict[str, Any]:
|
168
|
+
fulfilled_rate = (
|
169
|
+
world_state["total_fulfilled"] / total_requested if total_requested else 0.0
|
170
|
+
)
|
171
|
+
backorder_amount = sum(
|
172
|
+
abs(qty) for qty in world_state["inventory"].values() if qty < 0
|
173
|
+
)
|
174
|
+
|
175
|
+
return {
|
176
|
+
"total_revenue": round(world_state["total_revenue"], 2),
|
177
|
+
"fulfilled_rate": round(fulfilled_rate, 4),
|
178
|
+
"total_fulfilled": world_state["total_fulfilled"],
|
179
|
+
"total_requested": total_requested,
|
180
|
+
"backorder_amount": backorder_amount,
|
181
|
+
}
|
@@ -0,0 +1,176 @@
|
|
1
|
+
"""Multi-turn tool use scenario."""
|
2
|
+
|
3
|
+
import random
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from pydantic import BaseModel, Field
|
7
|
+
|
8
|
+
from .base import BaseScenario
|
9
|
+
|
10
|
+
|
11
|
+
class MultiTurnToolUseConfig(BaseModel):
|
12
|
+
"""Configuration for multi-turn tool use scenario."""
|
13
|
+
|
14
|
+
steps: int = Field(default=5, gt=0, description="Number of steps")
|
15
|
+
include_math: bool = Field(default=True, description="Include math operations")
|
16
|
+
include_extraction: bool = Field(
|
17
|
+
default=True, description="Include data extraction"
|
18
|
+
)
|
19
|
+
include_transform: bool = Field(
|
20
|
+
default=True, description="Include data transformation"
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
class MultiturnToolUseScenario(BaseScenario):
|
25
|
+
"""
|
26
|
+
Scenario for evaluating agents on multi-turn tool usage across different capabilities.
|
27
|
+
|
28
|
+
Agents must demonstrate effective tool selection and usage over multiple sequential turns,
|
29
|
+
handling tasks like mathematical computations, data extraction, and transformations.
|
30
|
+
"""
|
31
|
+
|
32
|
+
def __init__(self, params: dict[str, Any] | None = None):
|
33
|
+
"""
|
34
|
+
Initialize the multiturn tool use scenario.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
params: A dictionary of parameters that configure this scenario.
|
38
|
+
"""
|
39
|
+
super().__init__(params)
|
40
|
+
self.config = MultiTurnToolUseConfig(**self.params)
|
41
|
+
|
42
|
+
async def run(self, runner: Any, payload: dict[str, Any]) -> dict[str, Any]:
|
43
|
+
"""
|
44
|
+
Asynchronously executes the multiturn tool use scenario.
|
45
|
+
|
46
|
+
This method orchestrates a sequence of turns where the agent must use appropriate tools
|
47
|
+
for tasks involving math, extraction, or transformation based on enabled capabilities.
|
48
|
+
State is tracked across turns, and metrics are computed based on success rates.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
runner: The agent runner instance, expected to have an async `process` method
|
52
|
+
that takes input data and returns a response dict with 'success' bool
|
53
|
+
and optional 'result' for verification.
|
54
|
+
payload: Runtime parameters, including 'seed' for reproducible randomness.
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
Dictionary with scenario results, including metrics (success rates per capability),
|
58
|
+
final state (success counts, attempts), and interaction history.
|
59
|
+
"""
|
60
|
+
seed = payload.get("seed", 0)
|
61
|
+
rng = random.Random(seed)
|
62
|
+
|
63
|
+
# --- 1. Setup Phase ---
|
64
|
+
capabilities = self._determine_capabilities()
|
65
|
+
state = self._initialize_state(capabilities)
|
66
|
+
|
67
|
+
# --- 2. Execution Loop ---
|
68
|
+
for step in range(1, self.config.steps + 1):
|
69
|
+
capability = self._choose_capability(capabilities, rng)
|
70
|
+
agent_input = self._build_task_input(capability, rng)
|
71
|
+
response = await runner.process(agent_input)
|
72
|
+
self._record_interaction(state, step, capability, agent_input, response)
|
73
|
+
|
74
|
+
# --- 3. Evaluation ---
|
75
|
+
metrics = self._compute_metrics(state, self.config.steps, capabilities)
|
76
|
+
|
77
|
+
return {
|
78
|
+
"metrics": metrics,
|
79
|
+
"final_state": {
|
80
|
+
"successes": state["successes"],
|
81
|
+
"total_attempts": state["total_attempts"],
|
82
|
+
},
|
83
|
+
"interactions": state["interactions"],
|
84
|
+
}
|
85
|
+
|
86
|
+
def _determine_capabilities(self) -> list[str]:
|
87
|
+
capabilities: list[str] = []
|
88
|
+
if self.config.include_math:
|
89
|
+
capabilities.append("math")
|
90
|
+
if self.config.include_extraction:
|
91
|
+
capabilities.append("extraction")
|
92
|
+
if self.config.include_transform:
|
93
|
+
capabilities.append("transform")
|
94
|
+
if not capabilities:
|
95
|
+
capabilities.append("basic")
|
96
|
+
return capabilities
|
97
|
+
|
98
|
+
def _initialize_state(self, capabilities: list[str]) -> dict[str, Any]:
|
99
|
+
return {
|
100
|
+
"interactions": [],
|
101
|
+
"successes": {capability: 0 for capability in capabilities},
|
102
|
+
"total_attempts": {capability: 0 for capability in capabilities},
|
103
|
+
}
|
104
|
+
|
105
|
+
def _choose_capability(self, capabilities: list[str], rng: random.Random) -> str:
|
106
|
+
return rng.choice(capabilities)
|
107
|
+
|
108
|
+
def _build_task_input(self, capability: str, rng: random.Random) -> dict[str, Any]:
|
109
|
+
if capability == "math":
|
110
|
+
a, b = rng.randint(1, 100), rng.randint(1, 100)
|
111
|
+
return {
|
112
|
+
"task": "math",
|
113
|
+
"problem": f"Calculate the sum of {a} and {b}.",
|
114
|
+
"expected_result": a + b,
|
115
|
+
}
|
116
|
+
if capability == "extraction":
|
117
|
+
value = rng.randint(100, 999)
|
118
|
+
return {
|
119
|
+
"task": "extraction",
|
120
|
+
"text": f"The key value in this document is {value}. Extract it.",
|
121
|
+
"expected_result": value,
|
122
|
+
}
|
123
|
+
if capability == "transform":
|
124
|
+
data = [rng.randint(1, 10) for _ in range(5)]
|
125
|
+
return {
|
126
|
+
"task": "transform",
|
127
|
+
"data": data,
|
128
|
+
"operation": "sort ascending",
|
129
|
+
"expected_result": sorted(data),
|
130
|
+
}
|
131
|
+
return {
|
132
|
+
"task": "basic",
|
133
|
+
"query": "Perform a simple tool call to confirm functionality.",
|
134
|
+
}
|
135
|
+
|
136
|
+
def _record_interaction(
|
137
|
+
self,
|
138
|
+
state: dict[str, Any],
|
139
|
+
step: int,
|
140
|
+
capability: str,
|
141
|
+
agent_input: dict[str, Any],
|
142
|
+
response: dict[str, Any],
|
143
|
+
) -> None:
|
144
|
+
state["interactions"].append(
|
145
|
+
{
|
146
|
+
"step": step,
|
147
|
+
"task_type": capability,
|
148
|
+
"input": agent_input,
|
149
|
+
"response": response,
|
150
|
+
}
|
151
|
+
)
|
152
|
+
success = response.get("success", False)
|
153
|
+
if success:
|
154
|
+
state["successes"][capability] += 1
|
155
|
+
state["total_attempts"][capability] += 1
|
156
|
+
|
157
|
+
def _compute_metrics(
|
158
|
+
self,
|
159
|
+
state: dict[str, Any],
|
160
|
+
total_steps: int,
|
161
|
+
capabilities: list[str],
|
162
|
+
) -> dict[str, Any]:
|
163
|
+
overall_attempts = sum(state["total_attempts"].values())
|
164
|
+
total_successes = sum(state["successes"].values())
|
165
|
+
metrics: dict[str, Any] = {
|
166
|
+
"overall_success_rate": (
|
167
|
+
total_successes / overall_attempts if overall_attempts else 0.0
|
168
|
+
),
|
169
|
+
"steps_completed": total_steps,
|
170
|
+
}
|
171
|
+
for capability in capabilities:
|
172
|
+
attempts = state["total_attempts"][capability]
|
173
|
+
metrics[f"{capability}_success_rate"] = (
|
174
|
+
state["successes"][capability] / attempts if attempts else 0.0
|
175
|
+
)
|
176
|
+
return metrics
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
|
3
|
+
|
4
|
+
class ScenarioRegistry:
|
5
|
+
def __init__(self):
|
6
|
+
self._registry: dict[str, Callable] = {}
|
7
|
+
|
8
|
+
def register(self, key: str, fn: Callable):
|
9
|
+
self._registry[key] = fn
|
10
|
+
|
11
|
+
def get(self, key: str) -> Callable:
|
12
|
+
return self._registry[key]
|
13
|
+
|
14
|
+
def clear(self):
|
15
|
+
self._registry.clear()
|
16
|
+
|
17
|
+
|
18
|
+
scenario_registry = ScenarioRegistry()
|
@@ -0,0 +1,141 @@
|
|
1
|
+
"""Research summarization scenario."""
|
2
|
+
|
3
|
+
import random
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from pydantic import BaseModel, Field
|
7
|
+
|
8
|
+
from .base import BaseScenario
|
9
|
+
|
10
|
+
|
11
|
+
class ResearchSummarizationConfig(BaseModel):
|
12
|
+
"""Configuration for research summarization scenario."""
|
13
|
+
|
14
|
+
num_docs: int = Field(default=5, gt=0, description="Number of documents")
|
15
|
+
max_tokens: int = Field(
|
16
|
+
default=200, gt=0, description="Maximum tokens per document"
|
17
|
+
)
|
18
|
+
focus_keywords: list[str] = Field(
|
19
|
+
default=["research", "findings", "methodology"],
|
20
|
+
description="Keywords to focus on",
|
21
|
+
)
|
22
|
+
noise_probability: float = Field(
|
23
|
+
default=0.1, ge=0.0, le=0.5, description="Probability of noise"
|
24
|
+
)
|
25
|
+
|
26
|
+
|
27
|
+
class ResearchSummarizationScenario(BaseScenario):
|
28
|
+
"""
|
29
|
+
This class now correctly implements the modern BaseScenario API.
|
30
|
+
"""
|
31
|
+
|
32
|
+
def __init__(self, params: dict[str, Any] | None = None):
|
33
|
+
"""
|
34
|
+
The constructor now only accepts a `params` dictionary.
|
35
|
+
"""
|
36
|
+
super().__init__(params)
|
37
|
+
self.config = ResearchSummarizationConfig(**(self.params or {}))
|
38
|
+
|
39
|
+
async def run(self, runner: Any, payload: dict[str, Any]) -> dict[str, Any]:
|
40
|
+
"""
|
41
|
+
All scenario logic is now contained within this single method.
|
42
|
+
"""
|
43
|
+
seed = payload.get("seed", 0)
|
44
|
+
rng = random.Random(seed)
|
45
|
+
|
46
|
+
# --- 1. Setup Phase ---
|
47
|
+
documents = self.setup_documents(rng)
|
48
|
+
|
49
|
+
# --- 2. Execution Loop ---
|
50
|
+
summaries = []
|
51
|
+
for doc in documents:
|
52
|
+
prompt = self.create_prompt(doc)
|
53
|
+
response = await runner.process(prompt)
|
54
|
+
summary_text = (
|
55
|
+
response.get("content", str(response))
|
56
|
+
if isinstance(response, dict)
|
57
|
+
else str(response)
|
58
|
+
)
|
59
|
+
summaries.append(
|
60
|
+
{
|
61
|
+
"doc_id": doc["id"],
|
62
|
+
"summary": summary_text,
|
63
|
+
"original_content": doc["content"],
|
64
|
+
}
|
65
|
+
)
|
66
|
+
|
67
|
+
# --- 3. Evaluation and Results ---
|
68
|
+
metrics = self.evaluate_summaries(summaries)
|
69
|
+
|
70
|
+
return {
|
71
|
+
"metrics": metrics,
|
72
|
+
"summaries": [s["summary"] for s in summaries],
|
73
|
+
"documents": [d["content"] for d in documents],
|
74
|
+
}
|
75
|
+
|
76
|
+
def setup_documents(self, rng: random.Random) -> list[dict[str, Any]]:
|
77
|
+
documents = []
|
78
|
+
for i in range(self.config.num_docs):
|
79
|
+
base_content = f"Research paper {i + 1}: This study explores key findings in {', '.join(self.config.focus_keywords)}. "
|
80
|
+
base_content += f"The methodology involved analysis with approximately {self.config.max_tokens} tokens of data. "
|
81
|
+
|
82
|
+
if rng.random() < self.config.noise_probability:
|
83
|
+
base_content += (
|
84
|
+
"Irrelevant detail: weather conditions during the study. "
|
85
|
+
)
|
86
|
+
|
87
|
+
content = base_content * (
|
88
|
+
self.config.max_tokens // len(base_content.split()) + 1
|
89
|
+
)
|
90
|
+
content = " ".join(content.split()[: self.config.max_tokens // 4])
|
91
|
+
|
92
|
+
documents.append({"id": i, "content": content})
|
93
|
+
return documents
|
94
|
+
|
95
|
+
def create_prompt(self, doc: dict[str, Any]) -> str:
|
96
|
+
return (
|
97
|
+
f"Summarize the following research paper, focusing on the key {', '.join(self.config.focus_keywords)}.\n\n"
|
98
|
+
f"Paper content:\n{doc['content']}\n\n"
|
99
|
+
f"Provide a concise summary of 100-200 words highlighting the main findings and methodology."
|
100
|
+
)
|
101
|
+
|
102
|
+
def evaluate_summaries(self, summaries: list[dict[str, Any]]) -> dict[str, Any]:
|
103
|
+
if not summaries:
|
104
|
+
return {
|
105
|
+
"average_quality_score": 0.0,
|
106
|
+
"keyword_coverage": 0.0,
|
107
|
+
"conciseness_score": 0.0,
|
108
|
+
"total_documents": 0,
|
109
|
+
}
|
110
|
+
|
111
|
+
total_coverage = 0.0
|
112
|
+
total_conciseness = 0.0
|
113
|
+
num_keywords = len(self.config.focus_keywords)
|
114
|
+
if num_keywords == 0:
|
115
|
+
num_keywords = 1
|
116
|
+
|
117
|
+
for summary_info in summaries:
|
118
|
+
summary_lower = summary_info["summary"].lower()
|
119
|
+
keywords_lower = [kw.lower() for kw in self.config.focus_keywords]
|
120
|
+
|
121
|
+
coverage = (
|
122
|
+
sum(1 for kw in keywords_lower if kw in summary_lower) / num_keywords
|
123
|
+
)
|
124
|
+
total_coverage += coverage
|
125
|
+
|
126
|
+
original_words = len(summary_info["original_content"].split())
|
127
|
+
summary_words = len(summary_info["summary"].split())
|
128
|
+
conciseness = 1.0 if 0.1 <= summary_words / original_words <= 0.3 else 0.5
|
129
|
+
|
130
|
+
total_conciseness += conciseness
|
131
|
+
|
132
|
+
avg_coverage = total_coverage / len(summaries)
|
133
|
+
avg_conciseness = total_conciseness / len(summaries)
|
134
|
+
avg_quality = (avg_coverage + avg_conciseness) / 2
|
135
|
+
|
136
|
+
return {
|
137
|
+
"average_quality_score": round(avg_quality, 4),
|
138
|
+
"keyword_coverage": round(avg_coverage, 4),
|
139
|
+
"conciseness_score": round(avg_conciseness, 4),
|
140
|
+
"total_documents": len(summaries),
|
141
|
+
}
|
@@ -0,0 +1,24 @@
|
|
1
|
+
"""Validators module for benchmarking."""
|
2
|
+
|
3
|
+
# Import validator modules to register them
|
4
|
+
from . import (
|
5
|
+
determinism_check,
|
6
|
+
fairness_balance,
|
7
|
+
outlier_detection,
|
8
|
+
reproducibility_metadata,
|
9
|
+
schema_adherence,
|
10
|
+
structural_consistency,
|
11
|
+
)
|
12
|
+
from .registry import get_validator, list_validators, register_validator
|
13
|
+
|
14
|
+
__all__ = [
|
15
|
+
"determinism_check",
|
16
|
+
"fairness_balance",
|
17
|
+
"get_validator",
|
18
|
+
"list_validators",
|
19
|
+
"outlier_detection",
|
20
|
+
"reproducibility_metadata",
|
21
|
+
"register_validator",
|
22
|
+
"schema_adherence",
|
23
|
+
"structural_consistency",
|
24
|
+
]
|