benchmax 0.1.2.dev0__tar.gz → 0.1.2.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/PKG-INFO +13 -16
  2. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/README.md +8 -13
  3. benchmax-0.1.2.dev2/pyproject.toml +62 -0
  4. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/benchmax_wrapper.py +25 -10
  5. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/skyrl/benchmax_data_process.py +57 -11
  6. benchmax-0.1.2.dev2/src/benchmax/adapters/skyrl/skyrl_adapter.py +311 -0
  7. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/base_env.py +5 -0
  8. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/crm_env.py +4 -2
  9. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/workdir/reward_fn.py +31 -28
  10. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +387 -180
  11. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/data_utils.py +4 -3
  12. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/excel_env.py +1 -0
  13. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +9 -15
  14. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/excel_utils.py +66 -29
  15. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/reward_fn.py +10 -5
  16. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/math/math_env.py +1 -1
  17. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/math/workdir/reward_fn.py +6 -6
  18. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +12 -15
  19. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/parallel_mcp_env.py +119 -88
  20. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/__init__.py +1 -1
  21. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +80 -29
  22. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +8 -9
  23. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +9 -9
  24. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/utils.py +18 -17
  25. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/proxy_server.py +29 -24
  26. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/server_pool.py +142 -66
  27. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/types.py +3 -1
  28. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/wikipedia/utils.py +9 -9
  29. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/wikipedia/wiki_env.py +21 -11
  30. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/PKG-INFO +13 -16
  31. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/requires.txt +1 -1
  32. benchmax-0.1.2.dev0/pyproject.toml +0 -58
  33. benchmax-0.1.2.dev0/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -224
  34. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/LICENSE +0 -0
  35. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/setup.cfg +0 -0
  36. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/adapters/__init__.py +0 -0
  37. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/__init__.py +0 -0
  38. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
  39. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/__init__.py +0 -0
  40. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
  41. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
  42. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/envs/mcp/utils.py +0 -0
  43. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/prompts/__init__.py +0 -0
  44. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax/prompts/tools.py +0 -0
  45. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/SOURCES.txt +0 -0
  46. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/dependency_links.txt +0 -0
  47. {benchmax-0.1.2.dev0 → benchmax-0.1.2.dev2}/src/benchmax.egg-info/top_level.txt +0 -0
@@ -1,14 +1,16 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: benchmax
3
- Version: 0.1.2.dev0
3
+ Version: 0.1.2.dev2
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
- Requires-Python: <3.14,>=3.12
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: Operating System :: OS Independent
8
+ Requires-Python: ==3.12.*
7
9
  Description-Content-Type: text/markdown
8
10
  License-File: LICENSE
9
11
  Requires-Dist: aiohttp>=3.13.1
10
12
  Requires-Dist: asyncio>=4.0.0
11
- Requires-Dist: datasets>=4.3.0
13
+ Requires-Dist: datasets>=4.0.0
12
14
  Requires-Dist: fastmcp~=2.12.0
13
15
  Requires-Dist: pyjwt>=2.10.1
14
16
  Requires-Dist: skypilot~=0.8.1
@@ -36,24 +38,19 @@ Dynamic: license-file
36
38
 
37
39
  ## 📌 News
38
40
 
39
- - **[26 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
40
- - **[26 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
41
+ - **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
42
+ - **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
41
43
  - **[Upcoming]** 🛠️ Integration with Tinker API.
42
44
 
43
45
  ## 📘 Quickstart
44
46
 
45
47
  **Example: Multi-node parallelization of Excel Env with SkyRL and SkyPilot**
46
48
 
47
- RL environments can be computationally expensive to run (e.g. codegen). To handle this workload efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
49
+ RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
48
50
 
49
- **SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL finetune Qwen-2.5 to do spreadsheet manipulation using a excel MCP parallelized across multiple nodes. The environment is defined at `benchmax.envs.excel.excel_env.ExcelEnvSkypilot`
51
+ **SkyRL** is a training framework that `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL fine-tune Qwen-2.5 for spreadsheet manipulation using an Excel MCP parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
50
52
 
51
-
52
- 1. **Installation**
53
-
54
- `pip install benchmax[excel,skyrl]`
55
-
56
- 2. **Prepare the dataset**
53
+ 1. **Prepare the dataset**
57
54
 
58
55
  ```bash
59
56
  uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
@@ -64,13 +61,13 @@ RL environments can be computationally expensive to run (e.g. codegen). To handl
64
61
 
65
62
  Note: We are using `ExcelEnvLocal` instead of `ExcelEnvSkypilot` because the MCP is only used for listing tools to prepare the system prompt.
66
63
 
67
- 3. **Run training and parallelize Excel environment**
64
+ 2. **Run training and parallelize Excel environment**
68
65
 
69
66
  ```bash
70
- sh examples/skyrl/run_benchmax_excel.sh
67
+ bash examples/skyrl/run_benchmax_excel.sh
71
68
  ```
72
69
 
73
- This excel env example will spin up 5 nodes with 4 servers per node (total 20 MCP server in parallel). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
70
+ This Excel env example will spin up 5 nodes with 20 servers per node (100 MCP servers in parallel in total). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
74
71
 
75
72
  ## ℹ️ Overview
76
73
 
@@ -20,24 +20,19 @@
20
20
 
21
21
  ## 📌 News
22
22
 
23
- - **[26 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
24
- - **[26 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
23
+ - **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
24
+ - **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
25
25
  - **[Upcoming]** 🛠️ Integration with Tinker API.
26
26
 
27
27
  ## 📘 Quickstart
28
28
 
29
29
  **Example: Multi-node parallelization of Excel Env with SkyRL and SkyPilot**
30
30
 
31
- RL environments can be computationally expensive to run (e.g. codegen). To handle this workload efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
31
+ RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
32
32
 
33
- **SkyRL** is a training framework `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL finetune Qwen-2.5 to do spreadsheet manipulation using a excel MCP parallelized across multiple nodes. The environment is defined at `benchmax.envs.excel.excel_env.ExcelEnvSkypilot`
33
+ **SkyRL** is a training framework that `benchmax` is currently integrated with. Use our ***SkyRL*** integration to RL fine-tune Qwen-2.5 for spreadsheet manipulation using an Excel MCP parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
34
34
 
35
-
36
- 1. **Installation**
37
-
38
- `pip install benchmax[excel,skyrl]`
39
-
40
- 2. **Prepare the dataset**
35
+ 1. **Prepare the dataset**
41
36
 
42
37
  ```bash
43
38
  uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
@@ -48,13 +43,13 @@ RL environments can be computationally expensive to run (e.g. codegen). To handl
48
43
 
49
44
  Note: We are using `ExcelEnvLocal` instead of `ExcelEnvSkypilot` because the MCP is only used for listing tools to prepare the system prompt.
50
45
 
51
- 3. **Run training and parallelize Excel environment**
46
+ 2. **Run training and parallelize Excel environment**
52
47
 
53
48
  ```bash
54
- sh examples/skyrl/run_benchmax_excel.sh
49
+ bash examples/skyrl/run_benchmax_excel.sh
55
50
  ```
56
51
 
57
- This excel env example will spin up 5 nodes with 4 servers per node (total 20 MCP server in parallel). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
52
+ This Excel env example will spin up 5 nodes with 20 servers per node (100 MCP servers in parallel in total). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
58
53
 
59
54
  ## ℹ️ Overview
60
55
 
@@ -0,0 +1,62 @@
1
+ [project]
2
+ name = "benchmax"
3
+ version = "0.1.2.dev2"
4
+ description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5
+ readme = "README.md"
6
+ authors = [{ name = "cgft.io" }]
7
+ requires-python = "==3.12.*"
8
+ dependencies = [
9
+ "aiohttp>=3.13.1",
10
+ "asyncio>=4.0.0",
11
+ "datasets>=4.0.0",
12
+ "fastmcp~=2.12.0",
13
+ "pyjwt>=2.10.1",
14
+ "skypilot~=0.8.1",
15
+ ]
16
+ classifiers = [
17
+ "Programming Language :: Python :: 3",
18
+ "Operating System :: OS Independent",
19
+ ]
20
+
21
+ [build-system]
22
+ requires = ["setuptools>=61.0", "wheel"]
23
+ build-backend = "setuptools.build_meta"
24
+
25
+ [tool.setuptools.packages.find]
26
+ where = ["src"]
27
+
28
+ [dependency-groups]
29
+ dev = [
30
+ "pytest>=8.4.2",
31
+ "pytest-asyncio>=1.2.0",
32
+ "python-dotenv>=1.2.1",
33
+ "ruff>=0.14.2",
34
+ ]
35
+ skypilot = [
36
+ "skypilot[aws,gcp,azure]~=0.8.1", # Change this to your cloud provider
37
+ "pip>=25.3", # Added as needed for skypilot launch
38
+ "msrestazure>=0.6.4.post1",
39
+ ]
40
+ skyrl = [
41
+ "grpcio>=1.60.0",
42
+ "hydra-core>=1.3.2",
43
+ "omegaconf>=2.3.0",
44
+ "ray>=2.48.0",
45
+ "skyrl-gym>=0.1.1",
46
+ "skyrl-train[vllm]>=0.2.0",
47
+ ]
48
+ excel = ["openpyxl>=3.1.5"]
49
+ excel-mac-windows = ["openpyxl>=3.1.5", "xlwings>=0.33.16"]
50
+ crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"]
51
+
52
+ [tool.uv]
53
+ conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]]
54
+
55
+ [tool.uv.pip]
56
+ extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"]
57
+
58
+ [tool.uv.extra-build-dependencies]
59
+ flash-attn = [{ requirement = "torch", match-runtime = true }]
60
+
61
+ [tool.uv.extra-build-variables]
62
+ flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }
@@ -6,6 +6,9 @@ from typing import Dict, List, Any, Optional, Type, Union
6
6
 
7
7
  from benchmax.envs.base_env import BaseEnv
8
8
 
# Upper bound, in seconds, that the synchronous wrapper methods wait in
# ``ray.get`` for the remote actor to respond (5 minutes).
RAY_GET_TIMEOUT = 5 * 60
11
+
9
12
 
10
13
  class BenchmaxEnv:
11
14
  """
@@ -34,6 +37,10 @@ class BenchmaxEnv:
34
37
  async def init_rollout(self, rollout_id: str, **rollout_args: Any) -> None:
35
38
  return await self._env.init_rollout(rollout_id=rollout_id, **rollout_args)
36
39
 
@ray.method
async def release_rollout(self, rollout_id: str) -> None:
    """Release ``rollout_id`` by delegating to the wrapped environment.

    Exposed as an actor method so the wrapper can invoke it via
    ``self._actor.release_rollout.remote(rollout_id)``.
    """
    wrapped_env = self._env
    return await wrapped_env.release_rollout(rollout_id)
43
+
37
44
  @ray.method
38
45
  async def copy_to_workspace(
39
46
  self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
@@ -95,7 +102,7 @@ class BenchmaxEnvWrapper:
95
102
  obj_ref: ray.ObjectRef[str] = self._actor.get_system_prompt.remote(
96
103
  add_tool_defs=add_tool_defs # type: ignore
97
104
  )
98
- return ray.get(obj_ref)
105
+ return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
99
106
 
100
107
  async def get_system_prompt(self, add_tool_defs: bool = True) -> str:
101
108
  """Async method to get system prompt."""
@@ -113,7 +120,7 @@ class BenchmaxEnvWrapper:
113
120
  def list_tools_sync(self) -> List[Any]:
114
121
  """Sync method to list available tools."""
115
122
  obj_ref: ray.ObjectRef[List[Any]] = self._actor.list_tools.remote() # type: ignore
116
- return ray.get(obj_ref)
123
+ return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
117
124
 
118
125
  # === Shutdown ===
119
126
  async def shutdown(self) -> None:
@@ -124,9 +131,9 @@ class BenchmaxEnvWrapper:
124
131
  def shutdown_sync(self) -> None:
125
132
  """Sync method to shutdown the environment."""
126
133
  obj_ref: ray.ObjectRef[Any] = self._actor.shutdown.remote()
127
- ray.get(obj_ref)
134
+ ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
128
135
 
129
- # === Init Rollout ===
136
+ # === Rollout Lifecycle ===
130
137
  async def init_rollout(self, rollout_id: str, **rollout_args: Any) -> None:
131
138
  """Async method to initialize a rollout."""
132
139
  obj_ref: ray.ObjectRef[Any] = self._actor.init_rollout.remote(
@@ -139,7 +146,15 @@ class BenchmaxEnvWrapper:
139
146
  obj_ref: ray.ObjectRef[Any] = self._actor.init_rollout.remote(
140
147
  rollout_id, **rollout_args
141
148
  )
142
- ray.get(obj_ref)
149
+ ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
150
+
async def release_rollout(self, rollout_id: str) -> None:
    """Async method to release a rollout.

    Schedules ``release_rollout`` on the remote actor and awaits its
    completion; no timeout is applied here.
    """
    await self._actor.release_rollout.remote(rollout_id)
154
+
def release_rollout_sync(self, rollout_id: str) -> None:
    """Sync method to release a rollout.

    Blocks until the remote actor finishes releasing ``rollout_id``,
    waiting at most ``RAY_GET_TIMEOUT`` seconds.
    """
    ray.get(
        self._actor.release_rollout.remote(rollout_id),
        timeout=RAY_GET_TIMEOUT,
    )
143
158
 
144
159
  # === Run Tool ===
145
160
  async def run_tool(self, rollout_id: str, tool_name: str, **tool_args: Any) -> Any:
@@ -154,7 +169,7 @@ class BenchmaxEnvWrapper:
154
169
  obj_ref: ray.ObjectRef[Any] = self._actor.run_tool.remote(
155
170
  rollout_id, tool_name, **tool_args
156
171
  )
157
- return ray.get(obj_ref)
172
+ return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
158
173
 
159
174
  # === Copy to Workspace ===
160
175
  async def copy_to_workspace(
@@ -183,7 +198,7 @@ class BenchmaxEnvWrapper:
183
198
  Path(src_path),
184
199
  dst_filename=dst_filename, # type: ignore
185
200
  )
186
- ray.get(obj_ref)
201
+ ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
187
202
 
188
203
  # === Copy Content to Workspace ===
189
204
  async def copy_content_to_workspace(
@@ -212,7 +227,7 @@ class BenchmaxEnvWrapper:
212
227
  src_content,
213
228
  dst_filename=dst_filename, # type: ignore
214
229
  )
215
- ray.get(obj_ref)
230
+ ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
216
231
 
217
232
  # === Copy from Workspace ===
218
233
  async def copy_from_workspace(
@@ -237,7 +252,7 @@ class BenchmaxEnvWrapper:
237
252
  obj_ref: ray.ObjectRef[Any] = self._actor.copy_from_workspace.remote(
238
253
  rollout_id, src_filename, Path(dst_path)
239
254
  )
240
- ray.get(obj_ref)
255
+ ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
241
256
 
242
257
  # === Compute Reward ===
243
258
  async def compute_reward(
@@ -264,4 +279,4 @@ class BenchmaxEnvWrapper:
264
279
  obj_ref: ray.ObjectRef[Dict[str, float]] = self._actor.compute_reward.remote(
265
280
  rollout_id, completion, ground_truth, **kwargs
266
281
  ) # type: ignore
267
- return ray.get(obj_ref)
282
+ return ray.get(obj_ref, timeout=RAY_GET_TIMEOUT)
@@ -18,6 +18,7 @@ from benchmax.envs.base_env import BaseEnv
18
18
  # Set logging level to WARNING and above
19
19
  logging.basicConfig(level=logging.WARNING)
20
20
 
21
+
21
22
  def load_class(dotted_path: str) -> Type[BaseEnv]:
22
23
  """
23
24
  Load and return the class specified by `dotted_path`.
@@ -40,18 +41,58 @@ def load_class(dotted_path: str) -> Type[BaseEnv]:
40
41
 
41
42
  return cls
42
43
 
44
+
def get_canonical_class_name(cls: Type[BaseEnv]) -> str:
    """
    Return the class name to record for ``cls``, folding away a
    ``local``/``skypilot`` variant marker when a direct parent matches.

    A marker is recognised as a prefix or suffix of ``cls.__name__``
    (case-insensitively). It only counts when the parent's own name does not
    also carry that marker, and when removing it leaves exactly the parent's
    name (case-sensitive). For example, a class ``ExcelEnvSkypilot`` whose
    parent is ``ExcelEnv`` maps to ``"ExcelEnv"``.

    Parents are checked in ``__bases__`` order and the first match wins. If
    nothing matches, ``cls.__name__`` is returned unchanged.
    """
    own_name = cls.__name__
    own_lower = own_name.lower()
    markers = ("local", "skypilot")

    for parent in cls.__bases__:
        parent_name = parent.__name__
        parent_lower = parent_name.lower()

        # Marker at the front of our name, absent from the parent's.
        prefixed = any(
            own_lower.startswith(marker)
            and not parent_lower.startswith(marker)
            and own_name[len(marker) :] == parent_name
            for marker in markers
        )
        if prefixed:
            return parent_name

        # Marker at the end of our name, absent from the parent's.
        suffixed = any(
            own_lower.endswith(marker)
            and not parent_lower.endswith(marker)
            and own_name[: -len(marker)] == parent_name
            for marker in markers
        )
        if suffixed:
            return parent_name

    # No parent matched after stripping a marker: keep the class's own name.
    return own_name
82
+
83
+
async def get_system_prompt(cls: Type[BaseEnv]) -> str:
    """Setup env and get system prompt in async context.

    Instantiates ``cls``, asks it for its system prompt including tool
    definitions, then shuts the env down. ``num_local_servers=1`` is passed
    only when ``cls.__init__`` accepts that parameter.

    Args:
        cls: Environment class to instantiate.

    Returns:
        The prompt returned by ``env.get_system_prompt(add_tool_defs=True)``.
    """
    # Only pass num_local_servers when the constructor supports it. A single
    # server is enough, since the env is only queried for its prompt/tools.
    init_signature = inspect.signature(cls.__init__)
    if "num_local_servers" in init_signature.parameters:
        env = cls(num_local_servers=1)  # type: ignore
    else:
        env = cls()

    try:
        return await env.get_system_prompt(add_tool_defs=True)
    finally:
        # Shut down even when fetching the prompt raises. Otherwise whatever
        # the env started (presumably local MCP servers) would be left running.
        await env.shutdown()
57
98
 
@@ -96,11 +137,16 @@ if __name__ == "__main__":
96
137
  print("Getting system prompt...", flush=True)
97
138
  system_prompt = asyncio.run(get_system_prompt(benchmax_cls))
98
139
 
140
+ # Get canonical class name (strips local/skypilot if parent matches)
141
+ canonical_name = get_canonical_class_name(benchmax_cls)
142
+
99
143
  def process_example(example):
100
144
  """Single mapping function that does all processing."""
101
145
  # First apply dataset-specific preprocessing
102
- standardized = benchmax_cls.dataset_preprocess(example, dataset_path=dataset_path)
103
-
146
+ standardized = benchmax_cls.dataset_preprocess(
147
+ example, dataset_path=dataset_path
148
+ )
149
+
104
150
  # Then format as multiturn prompt
105
151
  prompt = [
106
152
  {
@@ -112,13 +158,13 @@ if __name__ == "__main__":
112
158
  result = {
113
159
  **standardized,
114
160
  "prompt": prompt,
115
- "env_class": benchmax_cls.__name__,
116
- "data_source": benchmax_cls.__name__,
161
+ "env_class": canonical_name,
162
+ "data_source": canonical_name,
117
163
  }
118
-
164
+
119
165
  # Remove keys with None values
120
166
  result = {k: v for k, v in result.items() if v is not None}
121
-
167
+
122
168
  return result
123
169
 
124
170
  print("Processing examples...", flush=True)
@@ -145,4 +191,4 @@ if __name__ == "__main__":
145
191
  print(f"Saving to {args.local_dir}...", flush=True)
146
192
  local_dir = Path(args.local_dir)
147
193
  train_dataset.to_parquet(local_dir / "train.parquet")
148
- test_dataset.to_parquet(local_dir / "test.parquet")
194
+ test_dataset.to_parquet(local_dir / "test.parquet")