strands-env 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strands_env/core/models.py +4 -0
- strands_env/core/types.py +2 -0
- strands_env/environments/__init__.py +20 -0
- strands_env/environments/calculator/__init__.py +19 -0
- strands_env/environments/calculator/env.py +30 -0
- strands_env/environments/calculator/system_prompt.md +1 -0
- strands_env/environments/code_sandbox/__init__.py +19 -0
- strands_env/environments/code_sandbox/env.py +114 -0
- strands_env/environments/code_sandbox/system_prompt.md +9 -0
- strands_env/eval/__init__.py +25 -0
- strands_env/eval/aime.py +64 -0
- strands_env/eval/evaluator.py +221 -0
- strands_env/eval/metrics.py +70 -0
- strands_env/rewards/__init__.py +21 -0
- strands_env/rewards/math_reward.py +134 -0
- strands_env/tools/__init__.py +21 -0
- strands_env/tools/code_interpreter.py +192 -0
- strands_env/utils/__init__.py +29 -0
- strands_env/utils/aws.py +98 -0
- strands_env/utils/sglang.py +47 -0
- strands_env-0.1.1.dist-info/METADATA +203 -0
- strands_env-0.1.1.dist-info/RECORD +27 -0
- strands_env-0.1.0.dist-info/METADATA +0 -98
- strands_env-0.1.0.dist-info/RECORD +0 -9
- {strands_env-0.1.0.dist-info → strands_env-0.1.1.dist-info}/WHEEL +0 -0
- {strands_env-0.1.0.dist-info → strands_env-0.1.1.dist-info}/licenses/LICENSE +0 -0
strands_env/core/models.py
CHANGED
@@ -49,6 +49,7 @@ from strands.models import Model
 from strands.models.bedrock import BedrockModel
 from strands.models.openai import OpenAIModel
 from strands_sglang import SGLangClient, SGLangModel
+from strands_sglang.tool_parser import HermesToolCallParser, ToolCallParser
 from transformers import PreTrainedTokenizerBase
 
 #: Factory that produces a fresh `Model` per step (for concurrent step isolation).
@@ -66,6 +67,7 @@ def sglang_model_factory(
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     client: SGLangClient,
+    tool_call_parser: ToolCallParser = HermesToolCallParser(),
     sampling_params: dict[str, Any] = DEFAULT_SAMPLING_PARAMS,
     enable_thinking: bool | None = None,
 ) -> ModelFactory:
@@ -81,6 +83,7 @@ def sglang_model_factory(
     return lambda: SGLangModel(
         tokenizer=tokenizer,
         client=client,
+        tool_call_parser=tool_call_parser,
         params=sampling_params,
         model_id=model_id,
         return_logprobs=True,
@@ -124,6 +127,7 @@ def bedrock_model_factory(
         model_id=model_id,
         boto_session=boto_session,
         boto_client_config=boto_client_config,
+        streaming=False,
         **sampling_params,
     )
 
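In 0.1.1, `sglang_model_factory` gains a `tool_call_parser` argument (defaulting to `HermesToolCallParser()`) that is forwarded to `SGLangModel`, and `bedrock_model_factory` now constructs the Bedrock model with `streaming=False`. A minimal sketch of wiring the new parameter; the `SGLangClient` constructor arguments, tokenizer name, and model id below are illustrative assumptions, not taken from this diff:

    # Hedged sketch: client/tokenizer/model-id values are assumptions for illustration.
    from strands_sglang import SGLangClient
    from strands_sglang.tool_parser import HermesToolCallParser
    from transformers import AutoTokenizer

    from strands_env.core.models import sglang_model_factory

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed tokenizer
    client = SGLangClient("http://localhost:30000")                        # assumed constructor args

    model_factory = sglang_model_factory(
        model_id="qwen2.5-7b-instruct",           # assumed id
        tokenizer=tokenizer,
        client=client,
        tool_call_parser=HermesToolCallParser(),  # new in 0.1.1; matches the default
    )
    model = model_factory()  # a fresh SGLangModel per step, per the factory's own comment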
|
strands_env/core/types.py
CHANGED
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import logging
+import uuid
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Any
@@ -41,6 +42,7 @@ class TaskContext(BaseModel):
 
     model_config = ConfigDict(extra="allow")
 
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     ground_truth: Any = None
     conversation_history: Messages = Field(default_factory=list)
 
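`TaskContext` now auto-generates a unique `id`, which the evaluator added in this release uses to key checkpointed samples. A small sketch of the behavior implied by the new field (values shown are illustrative):

    # Hedged sketch based on the fields visible in this diff.
    from strands_env.core import TaskContext

    ctx = TaskContext(ground_truth="42")
    print(ctx.id)  # auto-generated UUID string

    # Or set explicitly, as AIMEEvaluator does when expanding prompts into samples:
    ctx2 = TaskContext(id="AIME_2024_0", ground_truth="42")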
strands_env/environments/__init__.py
ADDED
@@ -0,0 +1,20 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Environments for strands-env."""
+
+from .calculator import CalculatorEnv
+from .code_sandbox import CodeMode, CodeSandboxEnv
+
+__all__ = ["CalculatorEnv", "CodeMode", "CodeSandboxEnv"]

strands_env/environments/calculator/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple math environment with a calculator tool."""
+
+from .env import CalculatorEnv
+
+__all__ = ["CalculatorEnv"]

strands_env/environments/calculator/env.py
ADDED
@@ -0,0 +1,30 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple math environment using a calculator tool."""
+
+from pathlib import Path
+
+from strands_tools import calculator
+
+from strands_env.core.environment import Environment
+
+
+class CalculatorEnv(Environment):
+    """Simple math environment using a calculator tool."""
+
+    default_system_prompt_path = Path(__file__).parent / "system_prompt.md"
+
+    def get_tools(self):
+        return [calculator]

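`CalculatorEnv` only overrides `get_tools()` and the default system prompt; everything else comes from the `Environment` base class. A hedged usage sketch, with the base-class constructor arguments inferred from `CodeSandboxEnv.__init__` later in this diff and `MathRewardFunction` assumed to take no constructor arguments:

    # Hedged sketch; constructor parameters are inferred, and model_factory is a placeholder.
    from strands_env.core import Action, TaskContext
    from strands_env.environments import CalculatorEnv
    from strands_env.rewards import MathRewardFunction

    env = CalculatorEnv(
        model_factory=model_factory,     # e.g. from sglang_model_factory or bedrock_model_factory
        reward_fn=MathRewardFunction(),  # assumed no-arg constructor
        max_tool_iterations=10,
    )

    async def solve() -> None:
        await env.reset()
        result = await env.step(
            Action(message="What is 17 * 23?", task_context=TaskContext(ground_truth="391"))
        )
        print(result.reward)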
strands_env/environments/calculator/system_prompt.md
ADDED
@@ -0,0 +1 @@
+You are a math problem solver. Solve the given problem step by step using the calculator tool when needed. Put your final answer in \boxed{}.

strands_env/environments/code_sandbox/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code sandbox environment using AWS Bedrock AgentCore Code Interpreter."""
+
+from .env import CodeMode, CodeSandboxEnv
+
+__all__ = ["CodeMode", "CodeSandboxEnv"]

strands_env/environments/code_sandbox/env.py
ADDED
@@ -0,0 +1,114 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code sandbox environment using AWS Bedrock AgentCore Code Interpreter."""
+
+from __future__ import annotations
+
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from strands_env.core.environment import Environment
+from strands_env.tools import CodeInterpreterToolkit
+from strands_env.utils.aws import get_boto3_session
+
+if TYPE_CHECKING:
+    import boto3
+
+    from strands_env.core.types import ModelFactory, RewardFunction
+
+
+class CodeMode(str, Enum):
+    """Tool modes for CodeSandboxEnv."""
+
+    CODE = "code"
+    """Only `execute_code` tool (Python execution)."""
+
+    TERMINAL = "terminal"
+    """Only `execute_command` tool (shell commands)."""
+
+    CODE_AND_TERMINAL = "code_and_terminal"
+    """Both `execute_code` and `execute_command` tools."""
+
+
+class CodeSandboxEnv(Environment):
+    """Code sandbox environment using AWS Bedrock AgentCore Code Interpreter.
+
+    Provides `execute_code` (Python) and/or `execute_command` (shell) tools
+    depending on the configured `CodeMode`.
+
+    Example:
+        from strands_env.environments.code_sandbox import CodeSandboxEnv, CodeMode
+        from strands_env.utils import get_boto3_session
+
+        session = get_boto3_session(region="us-east-1")
+        env = CodeSandboxEnv(
+            boto3_session=session,
+            model_factory=model_factory,
+            mode=CodeMode.CODE,  # Only Python execution
+        )
+
+        result = await env.step(action)
+        await env.cleanup()  # Clean up code interpreter session
+    """
+
+    default_system_prompt_path = Path(__file__).parent / "system_prompt.md"
+
+    def __init__(
+        self,
+        *,
+        model_factory: ModelFactory,
+        system_prompt: str | None = None,
+        reward_fn: RewardFunction | None = None,
+        max_tool_iterations: int = 10,
+        verbose: bool = False,
+        boto3_session: boto3.Session | None = None,
+        mode: CodeMode = CodeMode.CODE,
+    ):
+        """Initialize the code sandbox environment.
+
+        Args:
+            boto3_session: boto3 session for AWS credentials.
+            model_factory: Factory function that creates a fresh Model per step.
+            system_prompt: Optional system prompt override.
+            reward_fn: Optional reward function to compute rewards.
+            max_tool_iterations: Maximum tool iterations per step.
+            verbose: Whether to print verbose output.
+            mode: Tool mode - CODE, TERMINAL, or CODE_AND_TERMINAL.
+        """
+        super().__init__(
+            model_factory=model_factory,
+            reward_fn=reward_fn,
+            system_prompt=system_prompt,
+            max_tool_iterations=max_tool_iterations,
+            verbose=verbose,
+        )
+        self.mode = mode
+        self._toolkit = CodeInterpreterToolkit(
+            boto3_session=boto3_session or get_boto3_session(), session_name="strands-env-code-sandbox"
+        )
+
+    def get_tools(self):
+        """Return tools based on configured mode."""
+        tool_map = {
+            CodeMode.CODE: [self._toolkit.execute_code],
+            CodeMode.TERMINAL: [self._toolkit.execute_command],
+            CodeMode.CODE_AND_TERMINAL: [self._toolkit.execute_code, self._toolkit.execute_command],
+        }
+        return tool_map[self.mode]
+
+    async def cleanup(self) -> None:
+        """Clean up code interpreter session."""
+        self._toolkit.cleanup()

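The docstring example above shows the default `CODE` mode; the sketch below pairs the terminal-enabled mode with a Bedrock-backed model factory and a `try/finally` so the interpreter session is always released. The Bedrock model id is illustrative, and calling `bedrock_model_factory` with only `model_id` assumes its remaining parameters have defaults, which this diff does not show:

    # Hedged sketch; bedrock_model_factory's full signature is not visible in this diff.
    from strands_env.core.models import bedrock_model_factory
    from strands_env.environments import CodeMode, CodeSandboxEnv
    from strands_env.utils import get_boto3_session

    model_factory = bedrock_model_factory(model_id="anthropic.claude-3-5-sonnet-20240620-v1:0")  # assumed id

    env = CodeSandboxEnv(
        model_factory=model_factory,
        boto3_session=get_boto3_session(),
        mode=CodeMode.CODE_AND_TERMINAL,  # expose both execute_code and execute_command
    )

    async def run_one(action) -> None:
        try:
            await env.reset()
            result = await env.step(action)
        finally:
            await env.cleanup()  # releases the AgentCore code interpreter session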
strands_env/environments/code_sandbox/system_prompt.md
ADDED
@@ -0,0 +1,9 @@
+You are a helpful coding assistant with access to a sandboxed execution environment.
+
+Use the available tools to write and execute code to solve problems.
+
+When solving problems:
+1. Break down complex tasks into smaller steps
+2. Write and execute code to verify your solutions
+3. Use print statements to show intermediate results
+4. Handle errors gracefully and retry if needed

strands_env/eval/__init__.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .aime import AIMEEvaluator
+from .evaluator import EvalSample, Evaluator
+from .metrics import MetricFn, pass_at_k_metric
+
+__all__ = [
+    "AIMEEvaluator",
+    "EvalSample",
+    "Evaluator",
+    "MetricFn",
+    "pass_at_k_metric",
+]

strands_env/eval/aime.py
ADDED
@@ -0,0 +1,64 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AIME (American Invitational Mathematics Examination) evaluator."""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable
+from typing import Literal
+
+from datasets import load_dataset
+
+from strands_env.core import Action, TaskContext
+
+from .evaluator import Evaluator
+
+logger = logging.getLogger(__name__)
+
+_AIME_DATASETS = {
+    "2024": "HuggingFaceH4/aime_2024",
+    "2025": "MathArena/aime_2025",
+}
+
+
+class AIMEEvaluator(Evaluator):
+    """Evaluator for AIME math competition problems."""
+
+    benchmark_name = "AIME"
+
+    def load_dataset(self, version: Literal["2024", "2025"] = "2024") -> Iterable[Action]:
+        """Load AIME dataset from HuggingFace."""
+        self.benchmark_name = f"{self.benchmark_name}_{version}"
+        dataset = load_dataset(_AIME_DATASETS[version], split="train")
+
+        actions = []
+        for i, row in enumerate(dataset):
+            problem, answer = row.get("problem"), row.get("answer")
+            if problem is None or answer is None:
+                logger.warning(f"Row {i}: missing problem/answer, skipped")
+                continue
+            actions.append(
+                Action(
+                    message=str(problem),
+                    task_context=TaskContext(
+                        id=f"{self.benchmark_name}_{row.get('id', i)}",
+                        ground_truth=str(answer),
+                    ),
+                )
+            )
+
+        logger.info(f"[{self.benchmark_name}] Loaded {len(actions)}/{len(dataset)} prompts")
+        return actions

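Putting the pieces together, a hedged end-to-end sketch of running the AIME benchmark with the calculator environment; `model_factory` and the reward function are placeholders, not prescriptions:

    # Hedged sketch combining the classes added in this release.
    import asyncio

    from strands_env.environments import CalculatorEnv
    from strands_env.eval import AIMEEvaluator
    from strands_env.rewards import MathRewardFunction

    async def env_factory(action):
        # A fresh environment per sample, as AsyncEnvFactory expects.
        return CalculatorEnv(model_factory=model_factory, reward_fn=MathRewardFunction())

    async def main() -> None:
        evaluator = AIMEEvaluator(
            env_factory,
            n_samples_per_prompt=4,          # enables pass@1 .. pass@4
            output_path="aime_2024.jsonl",   # JSONL checkpoint; enables resume
        )
        actions = evaluator.load_dataset(version="2024")
        results = await evaluator.run(actions)
        evaluator.compute_metrics(results)

    asyncio.run(main())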
strands_env/eval/evaluator.py
ADDED
@@ -0,0 +1,221 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator for running agentic benchmarks with `strands-env` environments."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from collections import defaultdict
+from collections.abc import Awaitable, Callable, Iterable
+from functools import partial
+from pathlib import Path
+
+from pydantic import BaseModel
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+from strands_env.core import Action, Environment, StepResult
+
+from .metrics import MetricFn, pass_at_k_metric
+
+logger = logging.getLogger(__name__)
+
+#: Type alias for environment factory function (async).
+AsyncEnvFactory = Callable[[Action], Awaitable[Environment]]
+
+
+class EvalSample(BaseModel):
+    """Evaluation sample result."""
+
+    action: Action
+    """The action (task) that was evaluated."""
+
+    step_result: StepResult
+    """The result of the step (observation, reward, termination reason)."""
+
+
+class Evaluator:
+    """Evaluator for running concurrent environment evaluations."""
+
+    benchmark_name: str = ""
+    """Benchmark identifier. Override in subclasses."""
+
+    def __init__(
+        self,
+        env_factory: AsyncEnvFactory,
+        *,
+        max_concurrency: int = 10,
+        n_samples_per_prompt: int = 1,
+        output_path: Path | str = Path.cwd() / "results.jsonl",
+        save_interval: int = 10,
+        keep_tokens: bool = False,
+        metric_fns: list[MetricFn] = [],
+    ):
+        """Initialize the evaluator.
+
+        Args:
+            env_factory: Async factory function that creates a fresh Environment per sample.
+            max_concurrency: Maximum concurrent evaluate_sample() calls.
+            n_samples_per_prompt: Number of samples per prompt (for pass@k, set to max(k_values)).
+            output_path: Path to JSONL file for saving results. Enables resume.
+            save_interval: Flush results to disk every N completed samples.
+            keep_tokens: Keep token-level observation in results (only valid for `SGLangModel` backends).
+            metric_fns: Additional metric functions. `pass@k` is always included.
+        """
+        self.env_factory: AsyncEnvFactory = env_factory
+        self.max_concurrency = max_concurrency
+        self.n_samples_per_prompt = n_samples_per_prompt
+        self.output_path = Path(output_path)
+        self.save_interval = save_interval
+        self.keep_tokens = keep_tokens
+
+        # Always include pass@k, then any additional metrics
+        self.metric_fns: list[MetricFn] = [
+            partial(pass_at_k_metric, k_values=list(range(1, n_samples_per_prompt + 1)), reward_threshold=1.0)
+        ]
+        self.metric_fns += metric_fns
+
+        # Runtime state
+        self.results: dict[str, list[EvalSample]] = defaultdict(list)
+        self.completed_ids: set[str] = set()
+
+    def load_dataset(self) -> Iterable[Action]:
+        """Load dataset. Override in subclasses."""
+        raise NotImplementedError("Subclasses must implement load_dataset()")
+
+    def load_results(self) -> None:
+        """Load completed samples from checkpoint file."""
+        if not self.output_path.exists():
+            return
+
+        self.results = defaultdict(list)
+        self.completed_ids = set()
+
+        with open(self.output_path, encoding="utf-8") as f:
+            for line in f:
+                data = json.loads(line)
+                prompt_id = data.pop("prompt_id")
+                sample = EvalSample.model_validate(data)
+                self.results[prompt_id].append(sample)
+                self.completed_ids.add(sample.action.task_context.id)
+
+        total = sum(len(s) for s in self.results.values())
+        logger.info(f"Resumed {total} samples from {self.output_path}")
+
+    def save_results(self) -> None:
+        """Save all samples to checkpoint file."""
+        self.output_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.output_path, "w", encoding="utf-8") as f:
+            for prompt_id, samples in self.results.items():
+                for sample in samples:
+                    data = sample.model_dump()
+                    data["prompt_id"] = prompt_id
+                    f.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+    async def evaluate_sample(self, action: Action) -> EvalSample:
+        """Evaluate a single sample."""
+        env = await self.env_factory(action)
+        await env.reset()
+        step_result = await env.step(action)
+        if not self.keep_tokens:
+            step_result.observation.tokens = None
+        await env.cleanup()
+        # Runtime logging for debugging
+        reward_str = f"{step_result.reward.reward:.2f}" if step_result.reward else "N/A"
+        reward_info = step_result.reward.info if step_result.reward else {}
+        logger.info(
+            f"[{action.task_context.id}]: "
+            f"reward={reward_str} | "
+            f"label={action.task_context.ground_truth} | "
+            f"reward_info={reward_info} | "
+            f"metrics={step_result.observation.metrics}"
+        )
+        return EvalSample(action=action, step_result=step_result)
+
+    async def run(self, actions: Iterable[Action]) -> dict[str, list[EvalSample]]:
+        """Run evaluation on actions with n_samples_per_prompt each.
+
+        Args:
+            actions: Actions to evaluate.
+
+        Returns:
+            Dict mapping prompt_id to list of EvalSample results.
+        """
+        self.load_results()
+
+        # Expand actions to (prompt_id, sample_id, action) tuples
+        to_process: list[tuple[str, str, Action]] = []
+        for action in actions:
+            prompt_id = action.task_context.id
+            for i in range(self.n_samples_per_prompt):
+                sample_id = f"{prompt_id}_{i}"
+                if sample_id not in self.completed_ids:
+                    expanded = action.model_copy(deep=True)
+                    expanded.task_context.id = sample_id
+                    to_process.append((prompt_id, sample_id, expanded))
+
+        semaphore = asyncio.Semaphore(self.max_concurrency)
+        save_counter = 0
+        total = len(to_process)
+
+        async def process(prompt_id: str, sample_id: str, action: Action, pbar: tqdm) -> None:
+            nonlocal save_counter
+            async with semaphore:
+                sample = await self.evaluate_sample(action)
+                self.results[prompt_id].append(sample)
+                self.completed_ids.add(sample_id)
+                pbar.update(1)
+                save_counter += 1
+                if save_counter >= self.save_interval:
+                    self.save_results()
+                    save_counter = 0
+
+        with logging_redirect_tqdm():
+            with tqdm(total=total, desc=f"Evaluating {self.benchmark_name}", unit="sample", dynamic_ncols=True) as pbar:
+                await asyncio.gather(*[process(pid, sid, a, pbar) for pid, sid, a in to_process])
+        self.save_results()
+        return dict(self.results)
+
+    def compute_metrics(self, results: dict[str, list[EvalSample]], log: bool = True) -> dict[str, float]:
+        """Compute all metrics on results.
+
+        Args:
+            results: Dict mapping prompt_id to sample results.
+            log: Whether to log the metrics summary.
+
+        Returns:
+            Dict mapping metric names to values.
+        """
+        metrics = {}
+        for fn in self.metric_fns:
+            metrics.update(fn(results))
+
+        if log and metrics:
+            n_prompts = len(results)
+            n_samples = sum(len(s) for s in results.values())
+            name = self.benchmark_name or "Evaluation"
+
+            # Build formatted output
+            lines = [f"{'─' * 40}", f" {name} Results", f"{'─' * 40}"]
+            lines.append(f" Prompts: {n_prompts} Samples (n={self.n_samples_per_prompt}): {n_samples}")
+            lines.append("")
+            for metric, value in sorted(metrics.items()):
+                lines.append(f" {metric:<12} {value:>6.1%}")
+            lines.append(f"{'─' * 40}")
+            logger.info("\n" + "\n".join(lines))
+
+        return metrics

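Because `MetricFn` is just a callable over the results dict, extra metrics can be dropped in alongside the built-in pass@k. A hedged sketch of a mean-reward metric passed through `metric_fns`; field access follows the `EvalSample`/`StepResult` shapes used elsewhere in this diff:

    # Hedged sketch of a custom MetricFn.
    from strands_env.eval import EvalSample

    def mean_reward_metric(results: dict[str, list[EvalSample]]) -> dict[str, float]:
        rewards = [
            s.step_result.reward.reward
            for samples in results.values()
            for s in samples
            if s.step_result.reward is not None
        ]
        return {"mean_reward": sum(rewards) / len(rewards) if rewards else 0.0}

    # evaluator = AIMEEvaluator(env_factory, metric_fns=[mean_reward_metric])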
strands_env/eval/metrics.py
ADDED
@@ -0,0 +1,70 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation metrics for benchmark results."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .evaluator import EvalSample
+
+#: Type alias for metric function: takes results {prompt_id: [EvalSample, ...]}, returns {metric_name: value}.
+MetricFn = Callable[[dict[str, list["EvalSample"]]], dict[str, float]]
+
+
+def pass_at_k_metric(
+    results: dict[str, list["EvalSample"]],
+    k_values: list[int],
+    reward_threshold: float = 1.0,
+) -> dict[str, float]:
+    """Compute pass@k metrics using unbiased estimator.
+
+    Args:
+        results: Dict mapping prompt_id to list of samples.
+        k_values: List of k values for pass@k.
+        reward_threshold: Reward threshold for "pass" (default: 1.0).
+
+    Returns:
+        Dict mapping "pass@k" to average score.
+    """
+    if not results:
+        return {f"pass@{k}": 0.0 for k in k_values}
+
+    def is_correct(s: EvalSample) -> bool:
+        r = s.step_result.reward
+        return r is not None and r.reward >= reward_threshold
+
+    def pass_at_k_single(n: int, c: int, k: int) -> float:
+        """Unbiased estimator: 1 - C(n-c, k) / C(n, k)."""
+        if n - c < k:
+            return 1.0
+        if c == 0:
+            return 0.0
+        log_ratio = sum(math.log(n - c - i) - math.log(n - i) for i in range(k))
+        return 1.0 - math.exp(log_ratio)
+
+    metrics = {}
+    for k in k_values:
+        scores = []
+        for samples in results.values():
+            n, c = len(samples), sum(1 for s in samples if is_correct(s))
+            if k <= n:
+                scores.append(pass_at_k_single(n, c, k))
+        metrics[f"pass@{k}"] = sum(scores) / len(scores) if scores else 0.0
+
+    return metrics

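A quick numerical check of the estimator: with n = 4 samples and c = 1 correct, pass@1 = 1 − C(3,1)/C(4,1) = 0.25 and pass@2 = 1 − C(3,2)/C(4,2) = 0.5, which the log-space implementation reproduces. A standalone replication of the helper from the file above:

    # Worked example of the unbiased pass@k estimator (standalone check).
    import math

    def pass_at_k_single(n: int, c: int, k: int) -> float:
        if n - c < k:
            return 1.0
        if c == 0:
            return 0.0
        log_ratio = sum(math.log(n - c - i) - math.log(n - i) for i in range(k))
        return 1.0 - math.exp(log_ratio)

    print(pass_at_k_single(4, 1, 1))  # 0.25
    print(pass_at_k_single(4, 1, 2))  # 0.5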
strands_env/rewards/__init__.py
ADDED
@@ -0,0 +1,21 @@
+# Copyright 2025 Horizon RL Contributors
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Reward functions for strands-env."""
+
+from .math_reward import MathRewardFunction
+
+__all__ = [
+    "MathRewardFunction",
+]