benchmax 0.1.2.dev0__tar.gz → 0.1.2.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/PKG-INFO +13 -16
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/README.md +8 -13
- benchmax-0.1.2.dev2/pyproject.toml +62 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/benchmax_wrapper.py +25 -10
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/skyrl/benchmax_data_process.py +57 -11
- benchmax-0.1.2.dev2/src/benchmax/adapters/skyrl/skyrl_adapter.py +311 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/base_env.py +5 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/crm_env.py +4 -2
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/workdir/reward_fn.py +31 -28
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +387 -180
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/data_utils.py +4 -3
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/excel_env.py +1 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +9 -15
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/excel_utils.py +66 -29
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/reward_fn.py +10 -5
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/math/math_env.py +1 -1
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/math/workdir/reward_fn.py +6 -6
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +12 -15
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/parallel_mcp_env.py +119 -88
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/__init__.py +1 -1
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +80 -29
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +8 -9
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +9 -9
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/utils.py +18 -17
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/proxy_server.py +29 -24
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/server_pool.py +142 -66
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/types.py +3 -1
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/wikipedia/utils.py +9 -9
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/wikipedia/wiki_env.py +21 -11
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/PKG-INFO +13 -16
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/requires.txt +1 -1
- benchmax-0.1.2.dev0/pyproject.toml +0 -58
- benchmax-0.1.2.dev0/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -224
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/LICENSE +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/setup.cfg +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/__init__.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/__init__.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/__init__.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/utils.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/prompts/__init__.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/prompts/tools.py +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/SOURCES.txt +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/dependency_links.txt +0 -0
- {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/top_level.txt +0 -0
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.2.dev0
|
|
3
|
+
Version: 0.1.2.dev2
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
|
-
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Operating System :: OS Independent
|
|
8
|
+
Requires-Python: ==3.12.*
|
|
7
9
|
Description-Content-Type: text/markdown
|
|
8
10
|
License-File: LICENSE
|
|
9
11
|
Requires-Dist: aiohttp>=3.13.1
|
|
10
12
|
Requires-Dist: asyncio>=4.0.0
|
|
11
|
-
Requires-Dist: datasets>=4.
|
|
13
|
+
Requires-Dist: datasets>=4.0.0
|
|
12
14
|
Requires-Dist: fastmcp~=2.12.0
|
|
13
15
|
Requires-Dist: pyjwt>=2.10.1
|
|
14
16
|
Requires-Dist: skypilot~=0.8.1
|
|
@@ -36,24 +38,19 @@ Dynamic: license-file
|
|
|
36
38
|
|
|
37
39
|
## 📌 News
|
|
38
40
|
|
|
39
|
-
- **[
|
|
40
|
-
- **[
|
|
41
|
+
- **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
|
|
42
|
+
- **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
|
|
41
43
|
- **[Upcoming]** 🛠️ Integration with Tinker API.
|
|
42
44
|
|
|
43
45
|
## 📘 Quickstart
|
|
44
46
|
|
|
45
47
|
**Example: Multi-node parallelization of Excel Env with SkyRL and SkyPilot**
|
|
46
48
|
|
|
47
|
-
RL environments can be computationally expensive to run (e.g.
|
|
49
|
+
RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
|
|
48
50
|
|
|
49
|
-
**SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL finetune Qwen-2.5 to do spreadsheet manipulation using a excel MCP parallelized across multiple nodes. The environment is defined
|
|
51
|
+
**SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL fine-tune Qwen-2.5 to do spreadsheet manipulation using an Excel MCP parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
|
|
50
52
|
|
|
51
|
-
|
|
52
|
-
1. **Installation**
|
|
53
|
-
|
|
54
|
-
`pip install benchmax[excel,skyrl]`
|
|
55
|
-
|
|
56
|
-
2. **Prepare the dataset**
|
|
53
|
+
1. **Prepare the dataset**
|
|
57
54
|
|
|
58
55
|
```bash
|
|
59
56
|
uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
|
|
@@ -64,13 +61,13 @@ RL environments can be computationally expensive to run (e.g. codegen). To handl
|
|
|
64
61
|
|
|
65
62
|
Note: We are using `ExcelEnvLocal` instead of `ExcelEnvSkypilot` because the MCP is only used for listing tools to prepare the system prompt.
|
|
66
63
|
|
|
67
|
-
|
|
64
|
+
2. **Run training and parallelize Excel environment**
|
|
68
65
|
|
|
69
66
|
```bash
|
|
70
|
-
|
|
67
|
+
bash examples/skyrl/run_benchmax_excel.sh
|
|
71
68
|
```
|
|
72
69
|
|
|
73
|
-
This excel env example will spin up 5 nodes with
|
|
70
|
+
This excel env example will spin up 5 nodes with 20 servers per node (a total of 100 MCP servers in parallel). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
|
|
74
71
|
|
|
75
72
|
## ℹ️ Overview
|
|
76
73
|
|
|
@@ -20,24 +20,19 @@
|
|
|
20
20
|
|
|
21
21
|
## 📌 News
|
|
22
22
|
|
|
23
|
-
- **[
|
|
24
|
-
- **[
|
|
23
|
+
- **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
|
|
24
|
+
- **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
|
|
25
25
|
- **[Upcoming]** 🛠️ Integration with Tinker API.
|
|
26
26
|
|
|
27
27
|
## 📘 Quickstart
|
|
28
28
|
|
|
29
29
|
**Example: Multi-node parallelization of Excel Env with SkyRL and SkyPilot**
|
|
30
30
|
|
|
31
|
-
RL environments can be computationally expensive to run (e.g.
|
|
31
|
+
RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
|
|
32
32
|
|
|
33
|
-
**SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL finetune Qwen-2.5 to do spreadsheet manipulation using a excel MCP parallelized across multiple nodes. The environment is defined
|
|
33
|
+
**SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL fine-tune Qwen-2.5 to do spreadsheet manipulation using an Excel MCP parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
1. **Installation**
|
|
37
|
-
|
|
38
|
-
`pip install benchmax[excel,skyrl]`
|
|
39
|
-
|
|
40
|
-
2. **Prepare the dataset**
|
|
35
|
+
1. **Prepare the dataset**
|
|
41
36
|
|
|
42
37
|
```bash
|
|
43
38
|
uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
|
|
@@ -48,13 +43,13 @@ RL environments can be computationally expensive to run (e.g. codegen). To handl
|
|
|
48
43
|
|
|
49
44
|
Note: We are using `ExcelEnvLocal` instead of `ExcelEnvSkypilot` because the MCP is only used for listing tools to prepare the system prompt.
|
|
50
45
|
|
|
51
|
-
|
|
46
|
+
2. **Run training and parallelize Excel environment**
|
|
52
47
|
|
|
53
48
|
```bash
|
|
54
|
-
|
|
49
|
+
bash examples/skyrl/run_benchmax_excel.sh
|
|
55
50
|
```
|
|
56
51
|
|
|
57
|
-
This excel env example will spin up 5 nodes with
|
|
52
|
+
This excel env example will spin up 5 nodes with 20 servers per node (a total of 100 MCP servers in parallel). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
|
|
58
53
|
|
|
59
54
|
## ℹ️ Overview
|
|
60
55
|
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "benchmax"
|
|
3
|
+
version = "0.1.2.dev2"
|
|
4
|
+
description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [{ name = "cgft.io" }]
|
|
7
|
+
requires-python = "==3.12.*"
|
|
8
|
+
dependencies = [
|
|
9
|
+
"aiohttp>=3.13.1",
|
|
10
|
+
"asyncio>=4.0.0",
|
|
11
|
+
"datasets>=4.0.0",
|
|
12
|
+
"fastmcp~=2.12.0",
|
|
13
|
+
"pyjwt>=2.10.1",
|
|
14
|
+
"skypilot~=0.8.1",
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
[build-system]
|
|
22
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
23
|
+
build-backend = "setuptools.build_meta"
|
|
24
|
+
|
|
25
|
+
[tool.setuptools.packages.find]
|
|
26
|
+
where = ["src"]
|
|
27
|
+
|
|
28
|
+
[dependency-groups]
|
|
29
|
+
dev = [
|
|
30
|
+
"pytest>=8.4.2",
|
|
31
|
+
"pytest-asyncio>=1.2.0",
|
|
32
|
+
"python-dotenv>=1.2.1",
|
|
33
|
+
"ruff>=0.14.2",
|
|
34
|
+
]
|
|
35
|
+
skypilot = [
|
|
36
|
+
"skypilot[aws,gcp,azure]~=0.8.1", # Change this to your cloud provider
|
|
37
|
+
"pip>=25.3", # Added as needed for skypilot launch
|
|
38
|
+
"msrestazure>=0.6.4.post1",
|
|
39
|
+
]
|
|
40
|
+
skyrl = [
|
|
41
|
+
"grpcio>=1.60.0",
|
|
42
|
+
"hydra-core>=1.3.2",
|
|
43
|
+
"omegaconf>=2.3.0",
|
|
44
|
+
"ray>=2.48.0",
|
|
45
|
+
"skyrl-gym>=0.1.1",
|
|
46
|
+
"skyrl-train[vllm]>=0.2.0",
|
|
47
|
+
]
|
|
48
|
+
excel = ["openpyxl>=3.1.5"]
|
|
49
|
+
excel-mac-windows = ["openpyxl>=3.1.5", "xlwings>=0.33.16"]
|
|
50
|
+
crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"]
|
|
51
|
+
|
|
52
|
+
[tool.uv]
|
|
53
|
+
conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]]
|
|
54
|
+
|
|
55
|
+
[tool.uv.pip]
|
|
56
|
+
extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"]
|
|
57
|
+
|
|
58
|
+
[tool.uv.extra-build-dependencies]
|
|
59
|
+
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
|
60
|
+
|
|
61
|
+
[tool.uv.extra-build-variables]
|
|
62
|
+
flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }
|
|
@@ -6,6 +6,9 @@ from typing import Dict, List, Any, Optional, Type, Union
|
|
|
6
6
|
|
|
7
7
|
from benchmax.envs.base_env import BaseEnv
|
|
8
8
|
|
|
9
|
+
# 5 minutes timeout in seconds
|
|
10
|
+
RAY_GET_TIMEOUT = 300
|
|
11
|
+
|
|
9
12
|
|
|
10
13
|
class BenchmaxEnv:
|
|
11
14
|
"""
|
|
@@ -34,6 +37,10 @@ class BenchmaxEnv:
|
|
|
34
37
|
async def init_rollout(self, rollout_id: str, **rollout_args: Any) -> None:
|
|
35
38
|
return await self._env.init_rollout(rollout_id=rollout_id, **rollout_args)
|
|
36
39
|
|
|
40
|
+
@ray.method
|
|
41
|
+
async def release_rollout(self, rollout_id: str) -> None:
|
|
42
|
+
return await self._env.release_rollout(rollout_id)
|
|
43
|
+
|
|
37
44
|
@ray.method
|
|
38
45
|
async def copy_to_workspace(
|
|
39
46
|
self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
|
|
@@ -95,7 +102,7 @@ class BenchmaxEnvWrapper:
|
|
|
95
102
|
obj_ref: ray.ObjectRef[str] = self._actor.get_system_prompt.remote(
|
|
96
103
|
add_tool_defs=add_tool_defs # type: ignore
|
|
97
104
|
)
|
|
98
|
-
return ray.get(obj_ref)
|
|
105
|
+
return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
99
106
|
|
|
100
107
|
async def get_system_prompt(self, add_tool_defs: bool = True) -> str:
|
|
101
108
|
"""Async method to get system prompt."""
|
|
@@ -113,7 +120,7 @@ class BenchmaxEnvWrapper:
|
|
|
113
120
|
def list_tools_sync(self) -> List[Any]:
|
|
114
121
|
"""Sync method to list available tools."""
|
|
115
122
|
obj_ref: ray.ObjectRef[List[Any]] = self._actor.list_tools.remote() # type: ignore
|
|
116
|
-
return ray.get(obj_ref)
|
|
123
|
+
return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
117
124
|
|
|
118
125
|
# === Shutdown ===
|
|
119
126
|
async def shutdown(self) -> None:
|
|
@@ -124,9 +131,9 @@ class BenchmaxEnvWrapper:
|
|
|
124
131
|
def shutdown_sync(self) -> None:
|
|
125
132
|
"""Sync method to shutdown the environment."""
|
|
126
133
|
obj_ref: ray.ObjectRef[Any] = self._actor.shutdown.remote()
|
|
127
|
-
ray.get(obj_ref)
|
|
134
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
128
135
|
|
|
129
|
-
# ===
|
|
136
|
+
# === Rollout Lifecycle ===
|
|
130
137
|
async def init_rollout(self, rollout_id: str, **rollout_args: Any) -> None:
|
|
131
138
|
"""Async method to initialize a rollout."""
|
|
132
139
|
obj_ref: ray.ObjectRef[Any] = self._actor.init_rollout.remote(
|
|
@@ -139,7 +146,15 @@ class BenchmaxEnvWrapper:
|
|
|
139
146
|
obj_ref: ray.ObjectRef[Any] = self._actor.init_rollout.remote(
|
|
140
147
|
rollout_id, **rollout_args
|
|
141
148
|
)
|
|
142
|
-
ray.get(obj_ref)
|
|
149
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
150
|
+
|
|
151
|
+
async def release_rollout(self, rollout_id: str) -> None:
|
|
152
|
+
obj_ref: ray.ObjectRef[Any] = self._actor.release_rollout.remote(rollout_id)
|
|
153
|
+
await obj_ref
|
|
154
|
+
|
|
155
|
+
def release_rollout_sync(self, rollout_id: str) -> None:
|
|
156
|
+
obj_ref: ray.ObjectRef[Any] = self._actor.release_rollout.remote(rollout_id)
|
|
157
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
143
158
|
|
|
144
159
|
# === Run Tool ===
|
|
145
160
|
async def run_tool(self, rollout_id: str, tool_name: str, **tool_args: Any) -> Any:
|
|
@@ -154,7 +169,7 @@ class BenchmaxEnvWrapper:
|
|
|
154
169
|
obj_ref: ray.ObjectRef[Any] = self._actor.run_tool.remote(
|
|
155
170
|
rollout_id, tool_name, **tool_args
|
|
156
171
|
)
|
|
157
|
-
return ray.get(obj_ref)
|
|
172
|
+
return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
158
173
|
|
|
159
174
|
# === Copy to Workspace ===
|
|
160
175
|
async def copy_to_workspace(
|
|
@@ -183,7 +198,7 @@ class BenchmaxEnvWrapper:
|
|
|
183
198
|
Path(src_path),
|
|
184
199
|
dst_filename=dst_filename, # type: ignore
|
|
185
200
|
)
|
|
186
|
-
ray.get(obj_ref)
|
|
201
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
187
202
|
|
|
188
203
|
# === Copy Content to Workspace ===
|
|
189
204
|
async def copy_content_to_workspace(
|
|
@@ -212,7 +227,7 @@ class BenchmaxEnvWrapper:
|
|
|
212
227
|
src_content,
|
|
213
228
|
dst_filename=dst_filename, # type: ignore
|
|
214
229
|
)
|
|
215
|
-
ray.get(obj_ref)
|
|
230
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
216
231
|
|
|
217
232
|
# === Copy from Workspace ===
|
|
218
233
|
async def copy_from_workspace(
|
|
@@ -237,7 +252,7 @@ class BenchmaxEnvWrapper:
|
|
|
237
252
|
obj_ref: ray.ObjectRef[Any] = self._actor.copy_from_workspace.remote(
|
|
238
253
|
rollout_id, src_filename, Path(dst_path)
|
|
239
254
|
)
|
|
240
|
-
ray.get(obj_ref)
|
|
255
|
+
ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
|
241
256
|
|
|
242
257
|
# === Compute Reward ===
|
|
243
258
|
async def compute_reward(
|
|
@@ -264,4 +279,4 @@ class BenchmaxEnvWrapper:
|
|
|
264
279
|
obj_ref: ray.ObjectRef[Dict[str, float]] = self._actor.compute_reward.remote(
|
|
265
280
|
rollout_id, completion, ground_truth, **kwargs
|
|
266
281
|
) # type: ignore
|
|
267
|
-
return ray.get(obj_ref)
|
|
282
|
+
return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
|
{benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/skyrl/benchmax_data_process.py
RENAMED
|
@@ -18,6 +18,7 @@ from benchmax.envs.base_env import BaseEnv
|
|
|
18
18
|
# Set logging level to WARNING and above
|
|
19
19
|
logging.basicConfig(level=logging.WARNING)
|
|
20
20
|
|
|
21
|
+
|
|
21
22
|
def load_class(dotted_path: str) -> Type[BaseEnv]:
|
|
22
23
|
"""
|
|
23
24
|
Load and return the class specified by `dotted_path`.
|
|
@@ -40,18 +41,58 @@ def load_class(dotted_path: str) -> Type[BaseEnv]:
|
|
|
40
41
|
|
|
41
42
|
return cls
|
|
42
43
|
|
|
44
|
+
|
|
45
|
+
def get_canonical_class_name(cls: Type[BaseEnv]) -> str:
|
|
46
|
+
"""
|
|
47
|
+
Get the canonical class name, removing local/skypilot prefix/suffix if the parent class
|
|
48
|
+
has the same name without that prefix/suffix.
|
|
49
|
+
"""
|
|
50
|
+
class_name = cls.__name__
|
|
51
|
+
|
|
52
|
+
# Check for prefixes/suffixes to strip
|
|
53
|
+
prefixes = ["local", "skypilot"]
|
|
54
|
+
suffixes = ["local", "skypilot"]
|
|
55
|
+
|
|
56
|
+
# Try to find a matching parent class without the prefix/suffix
|
|
57
|
+
for base_cls in cls.__bases__:
|
|
58
|
+
base_name = base_cls.__name__
|
|
59
|
+
|
|
60
|
+
# Check if current class has prefix that base doesn't
|
|
61
|
+
for prefix in prefixes:
|
|
62
|
+
if class_name.lower().startswith(
|
|
63
|
+
prefix
|
|
64
|
+
) and not base_name.lower().startswith(prefix):
|
|
65
|
+
# Check if removing prefix gives us the base name
|
|
66
|
+
stripped = class_name[len(prefix) :]
|
|
67
|
+
if stripped == base_name:
|
|
68
|
+
return base_name
|
|
69
|
+
|
|
70
|
+
# Check if current class has suffix that base doesn't
|
|
71
|
+
for suffix in suffixes:
|
|
72
|
+
if class_name.lower().endswith(suffix) and not base_name.lower().endswith(
|
|
73
|
+
suffix
|
|
74
|
+
):
|
|
75
|
+
# Check if removing suffix gives us the base name
|
|
76
|
+
stripped = class_name[: -len(suffix)]
|
|
77
|
+
if stripped == base_name:
|
|
78
|
+
return base_name
|
|
79
|
+
|
|
80
|
+
# No matching parent found, return original name
|
|
81
|
+
return class_name
|
|
82
|
+
|
|
83
|
+
|
|
43
84
|
async def get_system_prompt(cls: Type[BaseEnv]) -> str:
|
|
44
85
|
"""Setup env and get system prompt in async context."""
|
|
45
86
|
# Initialize env with num_local_servers=1 if supported
|
|
46
87
|
init_signature = inspect.signature(cls.__init__)
|
|
47
|
-
if
|
|
48
|
-
env = cls(num_local_servers=1)
|
|
88
|
+
if "num_local_servers" in init_signature.parameters:
|
|
89
|
+
env = cls(num_local_servers=1) # type: ignore
|
|
49
90
|
else:
|
|
50
91
|
env = cls()
|
|
51
|
-
|
|
92
|
+
|
|
52
93
|
# Get system prompt (async function)
|
|
53
94
|
prompt = await env.get_system_prompt(add_tool_defs=True)
|
|
54
|
-
|
|
95
|
+
|
|
55
96
|
await env.shutdown()
|
|
56
97
|
return prompt
|
|
57
98
|
|
|
@@ -96,11 +137,16 @@ if __name__ == "__main__":
|
|
|
96
137
|
print("Getting system prompt...", flush=True)
|
|
97
138
|
system_prompt = asyncio.run(get_system_prompt(benchmax_cls))
|
|
98
139
|
|
|
140
|
+
# Get canonical class name (strips local/skypilot if parent matches)
|
|
141
|
+
canonical_name = get_canonical_class_name(benchmax_cls)
|
|
142
|
+
|
|
99
143
|
def process_example(example):
|
|
100
144
|
"""Single mapping function that does all processing."""
|
|
101
145
|
# First apply dataset-specific preprocessing
|
|
102
|
-
standardized = benchmax_cls.dataset_preprocess(
|
|
103
|
-
|
|
146
|
+
standardized = benchmax_cls.dataset_preprocess(
|
|
147
|
+
example, dataset_path=dataset_path
|
|
148
|
+
)
|
|
149
|
+
|
|
104
150
|
# Then format as multiturn prompt
|
|
105
151
|
prompt = [
|
|
106
152
|
{
|
|
@@ -112,13 +158,13 @@ if __name__ == "__main__":
|
|
|
112
158
|
result = {
|
|
113
159
|
**standardized,
|
|
114
160
|
"prompt": prompt,
|
|
115
|
-
"env_class":
|
|
116
|
-
"data_source":
|
|
161
|
+
"env_class": canonical_name,
|
|
162
|
+
"data_source": canonical_name,
|
|
117
163
|
}
|
|
118
|
-
|
|
164
|
+
|
|
119
165
|
# Remove keys with None values
|
|
120
166
|
result = {k: v for k, v in result.items() if v is not None}
|
|
121
|
-
|
|
167
|
+
|
|
122
168
|
return result
|
|
123
169
|
|
|
124
170
|
print("Processing examples...", flush=True)
|
|
@@ -145,4 +191,4 @@ if __name__ == "__main__":
|
|
|
145
191
|
print(f"Saving to {args.local_dir}...", flush=True)
|
|
146
192
|
local_dir = Path(args.local_dir)
|
|
147
193
|
train_dataset.to_parquet(local_dir / "train.parquet")
|
|
148
|
-
test_dataset.to_parquet(local_dir / "test.parquet")
|
|
194
|
+
test_dataset.to_parquet(local_dir / "test.parquet")
|