strands-env 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,6 +49,7 @@ from strands.models import Model
 from strands.models.bedrock import BedrockModel
 from strands.models.openai import OpenAIModel
 from strands_sglang import SGLangClient, SGLangModel
+from strands_sglang.tool_parser import HermesToolCallParser, ToolCallParser
 from transformers import PreTrainedTokenizerBase
 
 #: Factory that produces a fresh `Model` per step (for concurrent step isolation).
@@ -66,6 +67,7 @@ def sglang_model_factory(
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     client: SGLangClient,
+    tool_call_parser: ToolCallParser = HermesToolCallParser(),
     sampling_params: dict[str, Any] = DEFAULT_SAMPLING_PARAMS,
     enable_thinking: bool | None = None,
 ) -> ModelFactory:
@@ -81,6 +83,7 @@ def sglang_model_factory(
     return lambda: SGLangModel(
         tokenizer=tokenizer,
         client=client,
+        tool_call_parser=tool_call_parser,
         params=sampling_params,
         model_id=model_id,
         return_logprobs=True,
@@ -124,6 +127,7 @@ def bedrock_model_factory(
         model_id=model_id,
         boto_session=boto_session,
         boto_client_config=boto_client_config,
+        streaming=False,
         **sampling_params,
     )
 
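For orientation, a minimal sketch of how the updated `sglang_model_factory` signature could be exercised with the new `tool_call_parser` argument. Only the parameter itself comes from this diff (it defaults to `HermesToolCallParser()` when omitted); the factory's import path and the model id below are illustrative assumptions. The `streaming=False` argument added to `bedrock_model_factory` simply pins the `BedrockModel` backend to non-streaming responses.

    # Hypothetical usage of the new tool_call_parser parameter. The strands_env
    # module path and the model id are assumptions, not shown in this diff;
    # client/tokenizer construction follows the strands_sglang documentation.
    from strands_sglang import SGLangClient
    from strands_sglang.tool_parser import HermesToolCallParser
    from transformers import PreTrainedTokenizerBase

    def build_factory(client: SGLangClient, tokenizer: PreTrainedTokenizerBase):
        from strands_env.model_factories import sglang_model_factory  # assumed module path

        return sglang_model_factory(
            model_id="Qwen/Qwen3-8B",                 # assumed model id
            tokenizer=tokenizer,
            client=client,
            tool_call_parser=HermesToolCallParser(),  # new in 0.1.1; also the default
        )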
strands_env/core/types.py CHANGED
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import logging
+import uuid
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Any
@@ -41,6 +42,7 @@ class TaskContext(BaseModel):
 
     model_config = ConfigDict(extra="allow")
 
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     ground_truth: Any = None
     conversation_history: Messages = Field(default_factory=list)
 
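The new `id` field gives every `TaskContext` a stable identifier without the caller supplying one, which the evaluation code added later in this diff relies on for checkpointing and per-sample expansion. A small sketch, assuming `TaskContext` remains importable from `strands_env.core.types` as the hunk header suggests:

    from strands_env.core.types import TaskContext

    ctx = TaskContext(ground_truth="42")
    print(ctx.id)                                           # random UUID4 string by default

    ctx = TaskContext(id="AIME_2024_3", ground_truth="33")  # explicit ids still work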
@@ -0,0 +1,20 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Environments for strands-env."""
+
+from .calculator import CalculatorEnv
+from .code_sandbox import CodeMode, CodeSandboxEnv
+
+__all__ = ["CalculatorEnv", "CodeMode", "CodeSandboxEnv"]
@@ -0,0 +1,19 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple math environment with a calculator tool."""
+
+from .env import CalculatorEnv
+
+__all__ = ["CalculatorEnv"]
@@ -0,0 +1,30 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple math environment using a calculator tool."""
+
+from pathlib import Path
+
+from strands_tools import calculator
+
+from strands_env.core.environment import Environment
+
+
+class CalculatorEnv(Environment):
+    """Simple math environment using a calculator tool."""
+
+    default_system_prompt_path = Path(__file__).parent / "system_prompt.md"
+
+    def get_tools(self):
+        return [calculator]
@@ -0,0 +1 @@
+You are a math problem solver. Solve the given problem step by step using the calculator tool when needed. Put your final answer in \boxed{}.
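A hedged sketch of driving the new `CalculatorEnv`. The `strands_env.environments` import path and the base `Environment` keyword arguments are assumptions inferred from the surrounding hunks; `model_factory` would come from one of the factories shown earlier, and the `reset`/`step`/`cleanup` sequence mirrors the evaluator code later in this diff.

    from strands_env.core import Action
    from strands_env.environments import CalculatorEnv   # assumed package path

    async def solve(model_factory, problem: str):
        env = CalculatorEnv(model_factory=model_factory)  # base Environment kwargs assumed
        await env.reset()
        result = await env.step(Action(message=problem))  # agent may call the calculator tool
        await env.cleanup()
        return result                                     # StepResult with reward/observation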
@@ -0,0 +1,19 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code sandbox environment using AWS Bedrock AgentCore Code Interpreter."""
+
+from .env import CodeMode, CodeSandboxEnv
+
+__all__ = ["CodeMode", "CodeSandboxEnv"]
@@ -0,0 +1,114 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code sandbox environment using AWS Bedrock AgentCore Code Interpreter."""
+
+from __future__ import annotations
+
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from strands_env.core.environment import Environment
+from strands_env.tools import CodeInterpreterToolkit
+from strands_env.utils.aws import get_boto3_session
+
+if TYPE_CHECKING:
+    import boto3
+
+    from strands_env.core.types import ModelFactory, RewardFunction
+
+
+class CodeMode(str, Enum):
+    """Tool modes for CodeSandboxEnv."""
+
+    CODE = "code"
+    """Only `execute_code` tool (Python execution)."""
+
+    TERMINAL = "terminal"
+    """Only `execute_command` tool (shell commands)."""
+
+    CODE_AND_TERMINAL = "code_and_terminal"
+    """Both `execute_code` and `execute_command` tools."""
+
+
+class CodeSandboxEnv(Environment):
+    """Code sandbox environment using AWS Bedrock AgentCore Code Interpreter.
+
+    Provides `execute_code` (Python) and/or `execute_command` (shell) tools
+    depending on the configured `CodeMode`.
+
+    Example:
+        from strands_env.environments.code_sandbox import CodeSandboxEnv, CodeMode
+        from strands_env.utils import get_boto3_session
+
+        session = get_boto3_session(region="us-east-1")
+        env = CodeSandboxEnv(
+            boto3_session=session,
+            model_factory=model_factory,
+            mode=CodeMode.CODE,  # Only Python execution
+        )
+
+        result = await env.step(action)
+        await env.cleanup()  # Clean up code interpreter session
+    """
+
+    default_system_prompt_path = Path(__file__).parent / "system_prompt.md"
+
+    def __init__(
+        self,
+        *,
+        model_factory: ModelFactory,
+        system_prompt: str | None = None,
+        reward_fn: RewardFunction | None = None,
+        max_tool_iterations: int = 10,
+        verbose: bool = False,
+        boto3_session: boto3.Session | None = None,
+        mode: CodeMode = CodeMode.CODE,
+    ):
+        """Initialize the code sandbox environment.
+
+        Args:
+            boto3_session: boto3 session for AWS credentials.
+            model_factory: Factory function that creates a fresh Model per step.
+            system_prompt: Optional system prompt override.
+            reward_fn: Optional reward function to compute rewards.
+            max_tool_iterations: Maximum tool iterations per step.
+            verbose: Whether to print verbose output.
+            mode: Tool mode - CODE, TERMINAL, or CODE_AND_TERMINAL.
+        """
+        super().__init__(
+            model_factory=model_factory,
+            reward_fn=reward_fn,
+            system_prompt=system_prompt,
+            max_tool_iterations=max_tool_iterations,
+            verbose=verbose,
+        )
+        self.mode = mode
+        self._toolkit = CodeInterpreterToolkit(
+            boto3_session=boto3_session or get_boto3_session(), session_name="strands-env-code-sandbox"
+        )
+
+    def get_tools(self):
+        """Return tools based on configured mode."""
+        tool_map = {
+            CodeMode.CODE: [self._toolkit.execute_code],
+            CodeMode.TERMINAL: [self._toolkit.execute_command],
+            CodeMode.CODE_AND_TERMINAL: [self._toolkit.execute_code, self._toolkit.execute_command],
+        }
+        return tool_map[self.mode]
+
+    async def cleanup(self) -> None:
+        """Clean up code interpreter session."""
+        self._toolkit.cleanup()
@@ -0,0 +1,9 @@
+You are a helpful coding assistant with access to a sandboxed execution environment.
+
+Use the available tools to write and execute code to solve problems.
+
+When solving problems:
+1. Break down complex tasks into smaller steps
+2. Write and execute code to verify your solutions
+3. Use print statements to show intermediate results
+4. Handle errors gracefully and retry if needed
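The `CodeMode` enum is what decides which sandbox tools the agent sees, mirroring the `get_tools()` mapping above. A short sketch under the same assumptions as the class docstring example (AWS credentials available, `model_factory` built elsewhere, `strands_env.environments` as the re-export path):

    from strands_env.environments import CodeMode, CodeSandboxEnv

    env = CodeSandboxEnv(model_factory=model_factory, mode=CodeMode.CODE_AND_TERMINAL)
    tools = env.get_tools()   # [execute_code, execute_command]
    # mode=CodeMode.TERMINAL would expose only execute_command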
@@ -0,0 +1,25 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .aime import AIMEEvaluator
+from .evaluator import EvalSample, Evaluator
+from .metrics import MetricFn, pass_at_k_metric
+
+__all__ = [
+    "AIMEEvaluator",
+    "EvalSample",
+    "Evaluator",
+    "MetricFn",
+    "pass_at_k_metric",
+]
@@ -0,0 +1,64 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AIME (American Invitational Mathematics Examination) evaluator."""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+from typing import Literal
+
+from datasets import load_dataset
+
+from strands_env.core import Action, TaskContext
+
+from .evaluator import Evaluator
+
+logger = logging.getLogger(__name__)
+
+_AIME_DATASETS = {
+    "2024": "HuggingFaceH4/aime_2024",
+    "2025": "MathArena/aime_2025",
+}
+
+
+class AIMEEvaluator(Evaluator):
+    """Evaluator for AIME math competition problems."""
+
+    benchmark_name = "AIME"
+
+    def load_dataset(self, version: Literal["2024", "2025"] = "2024") -> Iterable[Action]:
+        """Load AIME dataset from HuggingFace."""
+        self.benchmark_name = f"{self.benchmark_name}_{version}"
+        dataset = load_dataset(_AIME_DATASETS[version], split="train")
+
+        actions = []
+        for i, row in enumerate(dataset):
+            problem, answer = row.get("problem"), row.get("answer")
+            if problem is None or answer is None:
+                logger.warning(f"Row {i}: missing problem/answer, skipped")
+                continue
+            actions.append(
+                Action(
+                    message=str(problem),
+                    task_context=TaskContext(
+                        id=f"{self.benchmark_name}_{row.get('id', i)}",
+                        ground_truth=str(answer),
+                    ),
+                )
+            )
+
+        logger.info(f"[{self.benchmark_name}] Loaded {len(actions)}/{len(dataset)} prompts")
+        return actions
@@ -0,0 +1,221 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator for running agentic benchmarks with `strands-env` environments."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from collections import defaultdict
+from collections.abc import Awaitable, Callable, Iterable
+from functools import partial
+from pathlib import Path
+
+from pydantic import BaseModel
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+from strands_env.core import Action, Environment, StepResult
+
+from .metrics import MetricFn, pass_at_k_metric
+
+logger = logging.getLogger(__name__)
+
+#: Type alias for environment factory function (async).
+AsyncEnvFactory = Callable[[Action], Awaitable[Environment]]
+
+
+class EvalSample(BaseModel):
+    """Evaluation sample result."""
+
+    action: Action
+    """The action (task) that was evaluated."""
+
+    step_result: StepResult
+    """The result of the step (observation, reward, termination reason)."""
+
+
+class Evaluator:
+    """Evaluator for running concurrent environment evaluations."""
+
+    benchmark_name: str = ""
+    """Benchmark identifier. Override in subclasses."""
+
+    def __init__(
+        self,
+        env_factory: AsyncEnvFactory,
+        *,
+        max_concurrency: int = 10,
+        n_samples_per_prompt: int = 1,
+        output_path: Path | str = Path.cwd() / "results.jsonl",
+        save_interval: int = 10,
+        keep_tokens: bool = False,
+        metric_fns: list[MetricFn] = [],
+    ):
+        """Initialize the evaluator.
+
+        Args:
+            env_factory: Async factory function that creates a fresh Environment per sample.
+            max_concurrency: Maximum concurrent evaluate_sample() calls.
+            n_samples_per_prompt: Number of samples per prompt (for pass@k, set to max(k_values)).
+            output_path: Path to JSONL file for saving results. Enables resume.
+            save_interval: Flush results to disk every N completed samples.
+            keep_tokens: Keep token-level observation in results (only valid for `SGLangModel` backends).
+            metric_fns: Additional metric functions. `pass@k` is always included.
+        """
+        self.env_factory: AsyncEnvFactory = env_factory
+        self.max_concurrency = max_concurrency
+        self.n_samples_per_prompt = n_samples_per_prompt
+        self.output_path = Path(output_path)
+        self.save_interval = save_interval
+        self.keep_tokens = keep_tokens
+
+        # Always include pass@k, then any additional metrics
+        self.metric_fns: list[MetricFn] = [
+            partial(pass_at_k_metric, k_values=list(range(1, n_samples_per_prompt + 1)), reward_threshold=1.0)
+        ]
+        self.metric_fns += metric_fns
+
+        # Runtime state
+        self.results: dict[str, list[EvalSample]] = defaultdict(list)
+        self.completed_ids: set[str] = set()
+
+    def load_dataset(self) -> Iterable[Action]:
+        """Load dataset. Override in subclasses."""
+        raise NotImplementedError("Subclasses must implement load_dataset()")
+
+    def load_results(self) -> None:
+        """Load completed samples from checkpoint file."""
+        if not self.output_path.exists():
+            return
+
+        self.results = defaultdict(list)
+        self.completed_ids = set()
+
+        with open(self.output_path, encoding="utf-8") as f:
+            for line in f:
+                data = json.loads(line)
+                prompt_id = data.pop("prompt_id")
+                sample = EvalSample.model_validate(data)
+                self.results[prompt_id].append(sample)
+                self.completed_ids.add(sample.action.task_context.id)
+
+        total = sum(len(s) for s in self.results.values())
+        logger.info(f"Resumed {total} samples from {self.output_path}")
+
+    def save_results(self) -> None:
+        """Save all samples to checkpoint file."""
+        self.output_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.output_path, "w", encoding="utf-8") as f:
+            for prompt_id, samples in self.results.items():
+                for sample in samples:
+                    data = sample.model_dump()
+                    data["prompt_id"] = prompt_id
+                    f.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+    async def evaluate_sample(self, action: Action) -> EvalSample:
+        """Evaluate a single sample."""
+        env = await self.env_factory(action)
+        await env.reset()
+        step_result = await env.step(action)
+        if not self.keep_tokens:
+            step_result.observation.tokens = None
+        await env.cleanup()
+        # Runtime logging for debugging
+        reward_str = f"{step_result.reward.reward:.2f}" if step_result.reward else "N/A"
+        reward_info = step_result.reward.info if step_result.reward else {}
+        logger.info(
+            f"[{action.task_context.id}]: "
+            f"reward={reward_str} | "
+            f"label={action.task_context.ground_truth} | "
+            f"reward_info={reward_info} | "
+            f"metrics={step_result.observation.metrics}"
+        )
+        return EvalSample(action=action, step_result=step_result)
+
+    async def run(self, actions: Iterable[Action]) -> dict[str, list[EvalSample]]:
+        """Run evaluation on actions with n_samples_per_prompt each.
+
+        Args:
+            actions: Actions to evaluate.
+
+        Returns:
+            Dict mapping prompt_id to list of EvalSample results.
+        """
+        self.load_results()
+
+        # Expand actions to (prompt_id, sample_id, action) tuples
+        to_process: list[tuple[str, str, Action]] = []
+        for action in actions:
+            prompt_id = action.task_context.id
+            for i in range(self.n_samples_per_prompt):
+                sample_id = f"{prompt_id}_{i}"
+                if sample_id not in self.completed_ids:
+                    expanded = action.model_copy(deep=True)
+                    expanded.task_context.id = sample_id
+                    to_process.append((prompt_id, sample_id, expanded))
+
+        semaphore = asyncio.Semaphore(self.max_concurrency)
+        save_counter = 0
+        total = len(to_process)
+
+        async def process(prompt_id: str, sample_id: str, action: Action, pbar: tqdm) -> None:
+            nonlocal save_counter
+            async with semaphore:
+                sample = await self.evaluate_sample(action)
+                self.results[prompt_id].append(sample)
+                self.completed_ids.add(sample_id)
+                pbar.update(1)
+                save_counter += 1
+                if save_counter >= self.save_interval:
+                    self.save_results()
+                    save_counter = 0
+
+        with logging_redirect_tqdm():
+            with tqdm(total=total, desc=f"Evaluating {self.benchmark_name}", unit="sample", dynamic_ncols=True) as pbar:
+                await asyncio.gather(*[process(pid, sid, a, pbar) for pid, sid, a in to_process])
+        self.save_results()
+        return dict(self.results)
+
+    def compute_metrics(self, results: dict[str, list[EvalSample]], log: bool = True) -> dict[str, float]:
+        """Compute all metrics on results.
+
+        Args:
+            results: Dict mapping prompt_id to sample results.
+            log: Whether to log the metrics summary.
+
+        Returns:
+            Dict mapping metric names to values.
+        """
+        metrics = {}
+        for fn in self.metric_fns:
+            metrics.update(fn(results))
+
+        if log and metrics:
+            n_prompts = len(results)
+            n_samples = sum(len(s) for s in results.values())
+            name = self.benchmark_name or "Evaluation"
+
+            # Build formatted output
+            lines = [f"{'─' * 40}", f" {name} Results", f"{'─' * 40}"]
+            lines.append(f" Prompts: {n_prompts} Samples (n={self.n_samples_per_prompt}): {n_samples}")
+            lines.append("")
+            for metric, value in sorted(metrics.items()):
+                lines.append(f" {metric:<12} {value:>6.1%}")
+            lines.append(f"{'─' * 40}")
+            logger.info("\n" + "\n".join(lines))
+
+        return metrics
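Putting the new pieces together, a sketch of an end-to-end AIME run. The `Evaluator` arguments and method names come from the code above; the `CalculatorEnv`-based `env_factory` and the `strands_env.evaluation` import path are illustrative assumptions, and `model_factory` is expected to come from the factory functions shown at the top of this diff.

    import asyncio

    from strands_env.core import Action
    from strands_env.environments import CalculatorEnv    # assumed re-export path
    from strands_env.evaluation import AIMEEvaluator      # assumed package path

    def make_env_factory(model_factory):
        async def env_factory(action: Action) -> CalculatorEnv:
            return CalculatorEnv(model_factory=model_factory)
        return env_factory

    async def main(model_factory) -> None:
        evaluator = AIMEEvaluator(
            make_env_factory(model_factory),
            max_concurrency=8,
            n_samples_per_prompt=4,          # enables pass@1..pass@4
            output_path="aime_2024.jsonl",   # doubles as a resume checkpoint
        )
        actions = evaluator.load_dataset(version="2024")
        results = await evaluator.run(actions)
        evaluator.compute_metrics(results)   # logs the pass@k summary table

    # asyncio.run(main(model_factory)) once a model_factory has been constructed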
@@ -0,0 +1,70 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation metrics for benchmark results."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .evaluator import EvalSample
+
+#: Type alias for metric function: takes results {prompt_id: [EvalSample, ...]}, returns {metric_name: value}.
+MetricFn = Callable[[dict[str, list["EvalSample"]]], dict[str, float]]
+
+
+def pass_at_k_metric(
+    results: dict[str, list["EvalSample"]],
+    k_values: list[int],
+    reward_threshold: float = 1.0,
+) -> dict[str, float]:
+    """Compute pass@k metrics using unbiased estimator.
+
+    Args:
+        results: Dict mapping prompt_id to list of samples.
+        k_values: List of k values for pass@k.
+        reward_threshold: Reward threshold for "pass" (default: 1.0).
+
+    Returns:
+        Dict mapping "pass@k" to average score.
+    """
+    if not results:
+        return {f"pass@{k}": 0.0 for k in k_values}
+
+    def is_correct(s: EvalSample) -> bool:
+        r = s.step_result.reward
+        return r is not None and r.reward >= reward_threshold
+
+    def pass_at_k_single(n: int, c: int, k: int) -> float:
+        """Unbiased estimator: 1 - C(n-c, k) / C(n, k)."""
+        if n - c < k:
+            return 1.0
+        if c == 0:
+            return 0.0
+        log_ratio = sum(math.log(n - c - i) - math.log(n - i) for i in range(k))
+        return 1.0 - math.exp(log_ratio)
+
+    metrics = {}
+    for k in k_values:
+        scores = []
+        for samples in results.values():
+            n, c = len(samples), sum(1 for s in samples if is_correct(s))
+            if k <= n:
+                scores.append(pass_at_k_single(n, c, k))
+        metrics[f"pass@{k}"] = sum(scores) / len(scores) if scores else 0.0
+
+    return metrics
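As a quick sanity check on the unbiased estimator above: pass@k is 1 - C(n-c, k)/C(n, k), i.e. one minus the probability that a random size-k subset of the n samples contains no correct one. Restating the helper from the hunk and verifying a small case by hand:

    import math

    def pass_at_k_single(n: int, c: int, k: int) -> float:
        # 1 - C(n - c, k) / C(n, k), computed in log space as in the metric above
        if n - c < k:
            return 1.0
        if c == 0:
            return 0.0
        log_ratio = sum(math.log(n - c - i) - math.log(n - i) for i in range(k))
        return 1.0 - math.exp(log_ratio)

    # n = 4 samples, c = 2 correct: pass@2 = 1 - C(2,2)/C(4,2) = 1 - 1/6 ≈ 0.833
    assert abs(pass_at_k_single(4, 2, 2) - (1 - 1 / 6)) < 1e-9
    # single draw: 2 of the 4 samples are correct, so pass@1 = 0.5
    assert abs(pass_at_k_single(4, 2, 1) - 0.5) < 1e-9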
@@ -0,0 +1,21 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Reward functions for strands-env."""
+
+from .math_reward import MathRewardFunction
+
+__all__ = [
+    "MathRewardFunction",
+]