benchmax 0.1.1.dev4__tar.gz → 0.1.1.dev6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/PKG-INFO +8 -3
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/README.md +5 -3
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/base_env.py +19 -36
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_env.py +4 -13
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_utils.py +9 -8
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/local_mcp_env.py +28 -0
- benchmax-0.1.1.dev6/benchmax/envs/skypilot/proxy_server.py +167 -0
- benchmax-0.1.1.dev6/benchmax/envs/skypilot/remote_skypilot_mcp_server.py +694 -0
- benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/mcp_config.yaml +5 -0
- benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/reward_func.py +16 -0
- benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/setup.sh +3 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/types.py +1 -3
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/wiki_env.py +12 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/pyproject.toml +6 -1
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/LICENSE +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/__init__.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/verifiers/verifiers_adapters.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/verl/benchmax_data_process.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/__init__.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/bounded_dict.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/README.md +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/crm_env.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/salesforce_mcp.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/salesforce_requirements.txt +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/README.md +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/data_utils.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_code_runner_mcp.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/math/README.md +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/math/math_env.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/README.md +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/utils.py +0 -0
- {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/prompts/tools.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.1.
|
|
3
|
+
Version: 0.1.1.dev6
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
6
|
Requires-Python: >=3.11,<3.13
|
|
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
10
10
|
Provides-Extra: crm
|
|
11
11
|
Provides-Extra: excel
|
|
12
12
|
Provides-Extra: excel-linux
|
|
13
|
+
Provides-Extra: skypilot
|
|
13
14
|
Provides-Extra: verifiers
|
|
14
15
|
Provides-Extra: verl
|
|
15
16
|
Requires-Dist: fastmcp (>=2.10.0,<2.11.0)
|
|
@@ -17,6 +18,7 @@ Requires-Dist: openpyxl (==3.1.5) ; extra == "excel-linux" or extra == "excel"
|
|
|
17
18
|
Requires-Dist: python-dateutil (>=2.9.0,<2.10.0) ; extra == "crm"
|
|
18
19
|
Requires-Dist: sglang[all] (==0.4.9) ; extra == "verl"
|
|
19
20
|
Requires-Dist: simple-salesforce (>=1.12.3) ; extra == "crm"
|
|
21
|
+
Requires-Dist: skypilot (==0.8.1) ; extra == "skypilot"
|
|
20
22
|
Requires-Dist: verifiers[train] (>=0.1.1,<0.2.0) ; extra == "verifiers"
|
|
21
23
|
Requires-Dist: verl-cgft-fork (==0.5.0.dev2) ; extra == "verl"
|
|
22
24
|
Requires-Dist: xlwings (==0.33.15) ; extra == "excel"
|
|
@@ -68,7 +70,7 @@ Get started with ready to use recipes, from Wikipedia search to spreadsheet mani
|
|
|
68
70
|
|
|
69
71
|
**Trainer Integrations**
|
|
70
72
|
|
|
71
|
-
Use your own trainer or training framework - no lock-in. `benchmax` is already
|
|
73
|
+
Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
|
|
72
74
|
|
|
73
75
|
**MCP Support**
|
|
74
76
|
Tap into the growing MCP ecosystem and integrate them as tools within your environments.
|
|
@@ -89,6 +91,8 @@ Tap into the growing MCP ecosystem and integrate them as tools within your envir
|
|
|
89
91
|
|
|
90
92
|
`pip install benchmax[verl]`
|
|
91
93
|
|
|
94
|
+
\* Note that benchmax installs our verl fork (temporary until [PR gets merged](https://github.com/volcengine/verl/pull/2792))
|
|
95
|
+
|
|
92
96
|
1. **Prepare the dataset**
|
|
93
97
|
|
|
94
98
|
```bash
|
|
@@ -387,7 +391,7 @@ Open an issue and tag us & we will look into building you one!
|
|
|
387
391
|
- Facilitate easy deployment and scalability in cloud environments.
|
|
388
392
|
- **MCP as a first class citizen**:
|
|
389
393
|
|
|
390
|
-
There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax`
|
|
394
|
+
There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax` allows folks to leverage and compose these existing MCP servers to build environments integrated with real world systems e.g. excel
|
|
391
395
|
|
|
392
396
|
|
|
393
397
|
## 🤝 Contributing
|
|
@@ -399,3 +403,4 @@ We welcome new environment recipes, bug reports, and trainer integrations!
|
|
|
399
403
|
## 📜 License
|
|
400
404
|
|
|
401
405
|
Apache 2.0 © 2025 CGFT Inc.
|
|
406
|
+
|
|
@@ -44,7 +44,7 @@ Get started with ready to use recipes, from Wikipedia search to spreadsheet mani
|
|
|
44
44
|
|
|
45
45
|
**Trainer Integrations**
|
|
46
46
|
|
|
47
|
-
Use your own trainer or training framework - no lock-in. `benchmax` is already
|
|
47
|
+
Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
|
|
48
48
|
|
|
49
49
|
**MCP Support**
|
|
50
50
|
Tap into the growing MCP ecosystem and integrate them as tools within your environments.
|
|
@@ -65,6 +65,8 @@ Tap into the growing MCP ecosystem and integrate them as tools within your envir
|
|
|
65
65
|
|
|
66
66
|
`pip install benchmax[verl]`
|
|
67
67
|
|
|
68
|
+
\* Note that benchmax installs our verl fork (temporary until [PR gets merged](https://github.com/volcengine/verl/pull/2792))
|
|
69
|
+
|
|
68
70
|
1. **Prepare the dataset**
|
|
69
71
|
|
|
70
72
|
```bash
|
|
@@ -363,7 +365,7 @@ Open an issue and tag us & we will look into building you one!
|
|
|
363
365
|
- Facilitate easy deployment and scalability in cloud environments.
|
|
364
366
|
- **MCP as a first class citizen**:
|
|
365
367
|
|
|
366
|
-
There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax`
|
|
368
|
+
There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax` allows folks to leverage and compose these existing MCP servers to build environments integrated with real world systems e.g. excel
|
|
367
369
|
|
|
368
370
|
|
|
369
371
|
## 🤝 Contributing
|
|
@@ -374,4 +376,4 @@ We welcome new environment recipes, bug reports, and trainer integrations!
|
|
|
374
376
|
|
|
375
377
|
## 📜 License
|
|
376
378
|
|
|
377
|
-
Apache 2.0 © 2025 CGFT Inc.
|
|
379
|
+
Apache 2.0 © 2025 CGFT Inc.
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Dict, List, Any, Tuple
|
|
2
|
+
from typing import Dict, List, Any, Optional, Tuple
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
5
|
from datasets import (
|
|
@@ -59,34 +59,38 @@ class BaseEnv(ABC):
|
|
|
59
59
|
return load_dataset(dataset_name, **kwargs), None
|
|
60
60
|
|
|
61
61
|
@abstractmethod
|
|
62
|
-
def list_tools(self) -> List[ToolDefinition]:
|
|
62
|
+
async def list_tools(self) -> List[ToolDefinition]:
|
|
63
63
|
"""Return list of available tools"""
|
|
64
64
|
pass
|
|
65
65
|
|
|
66
66
|
@abstractmethod
|
|
67
|
-
def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
|
|
67
|
+
async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
|
|
68
68
|
"""Execute named tool in rollout context with given arguments"""
|
|
69
69
|
pass
|
|
70
70
|
|
|
71
71
|
@abstractmethod
|
|
72
|
-
def init_rollout(self, rollout_id: str, **rollout_args) -> None:
|
|
72
|
+
async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
|
|
73
73
|
"""Initialize resources for a new rollout"""
|
|
74
74
|
pass
|
|
75
|
-
|
|
75
|
+
|
|
76
76
|
@abstractmethod
|
|
77
|
-
def
|
|
78
|
-
|
|
77
|
+
async def copy_to_workspace(
|
|
78
|
+
self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
|
|
79
|
+
) -> None:
|
|
80
|
+
"""Copy a file to the workspace for a specific rollout. If dst_filename is None, use the original filename."""
|
|
79
81
|
pass
|
|
80
82
|
|
|
81
83
|
@abstractmethod
|
|
82
|
-
def
|
|
83
|
-
|
|
84
|
+
async def copy_from_workspace(
|
|
85
|
+
self, rollout_id: str, src_filename: str, dst_path: Path
|
|
86
|
+
) -> None:
|
|
87
|
+
"""Copy a file from the workspace for a specific rollout"""
|
|
84
88
|
pass
|
|
85
|
-
|
|
86
|
-
|
|
89
|
+
|
|
90
|
+
@abstractmethod
|
|
91
|
+
async def compute_reward(
|
|
87
92
|
self,
|
|
88
93
|
rollout_id: str,
|
|
89
|
-
prompt: str,
|
|
90
94
|
completion: str,
|
|
91
95
|
ground_truth: Any,
|
|
92
96
|
**kwargs: Any
|
|
@@ -95,32 +99,11 @@ class BaseEnv(ABC):
|
|
|
95
99
|
|
|
96
100
|
Returns dict mapping reward function names to their computed scores.
|
|
97
101
|
"""
|
|
98
|
-
|
|
99
|
-
if workspace is None:
|
|
100
|
-
raise ValueError(f"No workspace found for rollout {rollout_id}")
|
|
101
|
-
|
|
102
|
-
results: Dict[str, float] = {}
|
|
103
|
-
for func in self.reward_funcs:
|
|
104
|
-
try:
|
|
105
|
-
# Get function name, falling back to string representation if not available
|
|
106
|
-
func_name = getattr(func, "__name__", str(func))
|
|
107
|
-
results[func_name] = func(
|
|
108
|
-
prompt=prompt,
|
|
109
|
-
completion=completion,
|
|
110
|
-
ground_truth=ground_truth,
|
|
111
|
-
workspace=workspace,
|
|
112
|
-
**kwargs
|
|
113
|
-
)
|
|
114
|
-
except Exception as e:
|
|
115
|
-
# Use same function name resolution
|
|
116
|
-
func_name = getattr(func, "__name__", str(func))
|
|
117
|
-
results[func_name] = float('nan')
|
|
118
|
-
print(f"[WARN] reward {func_name} failed: {e}")
|
|
119
|
-
return results
|
|
102
|
+
pass
|
|
120
103
|
|
|
121
|
-
def get_system_prompt(self, add_tool_defs: bool = False) -> str:
|
|
104
|
+
async def get_system_prompt(self, add_tool_defs: bool = False) -> str:
|
|
122
105
|
"""Get system prompt. To add tool definitions, set add_tool_defs to True."""
|
|
123
106
|
if add_tool_defs:
|
|
124
|
-
return render_tools_prompt(self.list_tools(), self.system_prompt or "")
|
|
107
|
+
return render_tools_prompt(await self.list_tools(), self.system_prompt or "")
|
|
125
108
|
else:
|
|
126
109
|
return self.system_prompt
|
|
@@ -138,16 +138,7 @@ class ExcelEnv(LocalMCPEnv):
|
|
|
138
138
|
answer_spreadsheet_path = rollout_args["answer_spreadsheet_path"]
|
|
139
139
|
|
|
140
140
|
super().init_rollout(rollout_id, **rollout_args)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
Copy the spreadsheet file to the workspace if it doesn't already exist.
|
|
146
|
-
"""
|
|
147
|
-
dest_path = workspace / src_path.name
|
|
148
|
-
if not dest_path.exists():
|
|
149
|
-
dest_path.write_bytes(src_path.read_bytes())
|
|
150
|
-
|
|
151
|
-
# Copy the spreadsheet to the workspace
|
|
152
|
-
_copy_to_workspace(Path(spreadsheet_path))
|
|
153
|
-
_copy_to_workspace(Path(answer_spreadsheet_path))
|
|
141
|
+
|
|
142
|
+
self.copy_to_workspace(rollout_id, Path(spreadsheet_path))
|
|
143
|
+
self.copy_to_workspace(rollout_id, Path(answer_spreadsheet_path))
|
|
144
|
+
|
|
@@ -16,7 +16,6 @@ WHITE_LIKE_COLORS = [
|
|
|
16
16
|
]
|
|
17
17
|
|
|
18
18
|
def evaluate_excel(excel_path: str):
|
|
19
|
-
import xlwings
|
|
20
19
|
"""
|
|
21
20
|
Evaluate Python code that manipulates an Excel file using xlwings.
|
|
22
21
|
"""
|
|
@@ -24,11 +23,13 @@ def evaluate_excel(excel_path: str):
|
|
|
24
23
|
if platform.system() == "Linux":
|
|
25
24
|
evaluate_excel_libre(excel_path)
|
|
26
25
|
return
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
26
|
+
else:
|
|
27
|
+
import xlwings
|
|
28
|
+
excel_app = xlwings.App(visible=False)
|
|
29
|
+
excel_book = excel_app.books.open(excel_path)
|
|
30
|
+
excel_book.save()
|
|
31
|
+
excel_book.close()
|
|
32
|
+
excel_app.quit()
|
|
32
33
|
|
|
33
34
|
def evaluate_excel_libre(excel_path: str) -> None:
|
|
34
35
|
"""
|
|
@@ -120,7 +121,7 @@ def excel_to_str_repr(excel_path: str, evaluate_formulas = False) -> str:
|
|
|
120
121
|
result.append(f"{coords}: null [{', '.join(style)}]")
|
|
121
122
|
is_row_empty = False
|
|
122
123
|
elif display_value:
|
|
123
|
-
style_str = f" [{
|
|
124
|
+
style_str = f" [{', '.join(style)}]" if style else ""
|
|
124
125
|
result.append(f"{coords}: {display_value}{style_str}")
|
|
125
126
|
is_row_empty = False
|
|
126
127
|
if not is_row_empty:
|
|
@@ -212,4 +213,4 @@ def compare_excel_cells(ground_truth_path: str, output_path: str, answer_positio
|
|
|
212
213
|
return False, f"Fill color mismatch at {cell_name}"
|
|
213
214
|
if not compare_font_color(cell_gt.font, cell_out.font):
|
|
214
215
|
return False, f"Font color mismatch at {cell_name}"
|
|
215
|
-
return True, "All comparisons passed."
|
|
216
|
+
return True, "All comparisons passed."
|
|
@@ -228,6 +228,34 @@ class LocalMCPEnv(BaseEnv):
|
|
|
228
228
|
raise ValueError(f"No active client found for rollout {rollout_id}")
|
|
229
229
|
else:
|
|
230
230
|
return Path()
|
|
231
|
+
|
|
232
|
+
def copy_to_workspace(
|
|
233
|
+
self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
|
|
234
|
+
) -> None:
|
|
235
|
+
"""Copy a file to the workspace for a specific rollout. If dst_filename is None, use the original filename."""
|
|
236
|
+
if rollout_id not in self._active_clients:
|
|
237
|
+
raise ValueError(f"No active client found for rollout {rollout_id}")
|
|
238
|
+
|
|
239
|
+
if not src_path.exists():
|
|
240
|
+
raise FileNotFoundError(f"Source file {src_path} does not exist")
|
|
241
|
+
|
|
242
|
+
pair = self._active_clients[rollout_id]
|
|
243
|
+
dst_path = pair.workspace / (dst_filename or src_path.name)
|
|
244
|
+
dst_path.write_bytes(src_path.read_bytes())
|
|
245
|
+
|
|
246
|
+
def copy_from_workspace(
|
|
247
|
+
self, rollout_id: str, src_filename: str, dst_path: Path
|
|
248
|
+
) -> None:
|
|
249
|
+
"""Copy a file from the workspace for a specific rollout"""
|
|
250
|
+
if rollout_id not in self._active_clients:
|
|
251
|
+
raise ValueError(f"No active client found for rollout {rollout_id}")
|
|
252
|
+
|
|
253
|
+
pair = self._active_clients[rollout_id]
|
|
254
|
+
src_path = pair.workspace / src_filename
|
|
255
|
+
if not src_path.exists():
|
|
256
|
+
raise FileNotFoundError(f"File {src_filename} not found in workspace {pair.workspace}")
|
|
257
|
+
|
|
258
|
+
dst_path.write_bytes(src_path.read_bytes())
|
|
231
259
|
|
|
232
260
|
# ---- Private Helper Methods ----
|
|
233
261
|
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import shutil
|
|
4
|
+
import uuid
|
|
5
|
+
import yaml
|
|
6
|
+
import asyncio
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from functools import wraps
|
|
9
|
+
from fastmcp import FastMCP, Client
|
|
10
|
+
from starlette.requests import Request
|
|
11
|
+
from starlette.responses import PlainTextResponse, FileResponse, JSONResponse
|
|
12
|
+
from starlette.datastructures import UploadFile
|
|
13
|
+
|
|
14
|
+
from reward_func import reward_functions # your reward functions
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# ---------------- Utility Functions ---------------- #
|
|
18
|
+
def setup_workspace(base_dir: Path) -> Path:
|
|
19
|
+
"""Create a unique workspace directory."""
|
|
20
|
+
ws = (base_dir / uuid.uuid4().hex).resolve()
|
|
21
|
+
ws.mkdir(parents=True, exist_ok=True)
|
|
22
|
+
return ws
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def load_config(config_path: Path, workspace: Path) -> dict:
|
|
26
|
+
"""Load YAML config and inject workspace paths."""
|
|
27
|
+
with open(config_path, "r") as f:
|
|
28
|
+
content = f.read().replace("${{ sync_workdir }}", str(Path(__file__).resolve().parent))
|
|
29
|
+
config = yaml.safe_load(content)
|
|
30
|
+
if "mcpServers" in config:
|
|
31
|
+
for server in config["mcpServers"].values():
|
|
32
|
+
server["cwd"] = str(workspace)
|
|
33
|
+
return config
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ---------------- Auth Decorator ---------------- #
|
|
37
|
+
def require_auth(func):
|
|
38
|
+
"""Require API_TOKEN header."""
|
|
39
|
+
@wraps(func)
|
|
40
|
+
async def wrapper(*args, **kwargs):
|
|
41
|
+
request = args[1] if len(args) == 2 else args[0]
|
|
42
|
+
token = request.headers.get("Authorization")
|
|
43
|
+
if token != os.getenv("API_TOKEN", "default-secret-token"):
|
|
44
|
+
return PlainTextResponse("Unauthorized", status_code=401)
|
|
45
|
+
return await func(*args, **kwargs)
|
|
46
|
+
return wrapper
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ---------------- Proxy Server ---------------- #
|
|
50
|
+
class ProxyServer:
|
|
51
|
+
def __init__(self, base_dir="workspace", host="0.0.0.0", port=8080):
|
|
52
|
+
self.base_dir = Path(base_dir)
|
|
53
|
+
self.base_dir.mkdir(parents=True, exist_ok=True)
|
|
54
|
+
self.host = host
|
|
55
|
+
self.port = port
|
|
56
|
+
self.workspace: Path | None = None
|
|
57
|
+
self.client: Client | None = None
|
|
58
|
+
self.proxy: FastMCP | None = None
|
|
59
|
+
self.config_path = Path(__file__).parent / "mcp_config.yaml"
|
|
60
|
+
|
|
61
|
+
async def _setup(self):
|
|
62
|
+
"""Initialize workspace, MCP client, and proxy server."""
|
|
63
|
+
self.workspace = setup_workspace(self.base_dir)
|
|
64
|
+
config = load_config(self.config_path, self.workspace)
|
|
65
|
+
|
|
66
|
+
self.client = Client(config)
|
|
67
|
+
await self.client._connect()
|
|
68
|
+
|
|
69
|
+
self.proxy = FastMCP.as_proxy(self.client, name="proxy")
|
|
70
|
+
|
|
71
|
+
# Register endpoints
|
|
72
|
+
self.proxy.custom_route("/health", methods=["GET"])(self._health)
|
|
73
|
+
self.proxy.custom_route("/upload", methods=["POST"])(self._upload)
|
|
74
|
+
self.proxy.custom_route("/download", methods=["GET"])(self._download)
|
|
75
|
+
self.proxy.custom_route("/compute_reward", methods=["POST"])(self._compute_reward)
|
|
76
|
+
self.proxy.custom_route("/reset", methods=["POST"])(self._reset)
|
|
77
|
+
|
|
78
|
+
# ---------------- Endpoints ---------------- #
|
|
79
|
+
async def _health(self, request: Request):
|
|
80
|
+
return PlainTextResponse("OK")
|
|
81
|
+
|
|
82
|
+
@require_auth
|
|
83
|
+
async def _upload(self, request: Request):
|
|
84
|
+
if not self.workspace:
|
|
85
|
+
return PlainTextResponse("No workspace available", 500)
|
|
86
|
+
form = await request.form()
|
|
87
|
+
uploaded = []
|
|
88
|
+
for file in form.values():
|
|
89
|
+
if isinstance(file, UploadFile) and file.filename:
|
|
90
|
+
dest = self.workspace / file.filename
|
|
91
|
+
with open(dest, "wb") as f:
|
|
92
|
+
f.write(await file.read())
|
|
93
|
+
uploaded.append(file.filename)
|
|
94
|
+
if not uploaded:
|
|
95
|
+
return PlainTextResponse("No files uploaded", 400)
|
|
96
|
+
return PlainTextResponse(f"Uploaded: {', '.join(uploaded)}")
|
|
97
|
+
|
|
98
|
+
@require_auth
|
|
99
|
+
async def _download(self, request: Request):
|
|
100
|
+
if not self.workspace:
|
|
101
|
+
return PlainTextResponse("No workspace", 500)
|
|
102
|
+
file_path = request.query_params.get("file_path")
|
|
103
|
+
if not file_path:
|
|
104
|
+
return PlainTextResponse("file_path required", 400)
|
|
105
|
+
full_path = self.workspace / file_path
|
|
106
|
+
if not full_path.exists() or not full_path.is_file():
|
|
107
|
+
return PlainTextResponse("File not found", 404)
|
|
108
|
+
return FileResponse(str(full_path), filename=full_path.name)
|
|
109
|
+
|
|
110
|
+
@require_auth
|
|
111
|
+
async def _compute_reward(self, request: Request):
|
|
112
|
+
try:
|
|
113
|
+
data = await request.json()
|
|
114
|
+
except Exception:
|
|
115
|
+
return PlainTextResponse("Invalid JSON", 400)
|
|
116
|
+
|
|
117
|
+
completion = data.get("completion")
|
|
118
|
+
ground_truth = data.get("ground_truth")
|
|
119
|
+
if completion is None or ground_truth is None:
|
|
120
|
+
return PlainTextResponse("completion and ground_truth required", 400)
|
|
121
|
+
|
|
122
|
+
results = {}
|
|
123
|
+
for func in reward_functions or []:
|
|
124
|
+
name = getattr(func, "__name__", str(func))
|
|
125
|
+
try:
|
|
126
|
+
results[name] = func(completion=completion, ground_truth=ground_truth, workspace=self.workspace, mcp_client=self.client, **{
|
|
127
|
+
k: v for k, v in data.items() if k not in ("completion", "ground_truth")
|
|
128
|
+
})
|
|
129
|
+
except Exception as e:
|
|
130
|
+
results[name] = float("nan")
|
|
131
|
+
print(f"[WARN] reward {name} failed: {e}")
|
|
132
|
+
return JSONResponse(results)
|
|
133
|
+
|
|
134
|
+
@require_auth
|
|
135
|
+
async def _reset(self, request: Request):
|
|
136
|
+
"""Reset server: clean workspace and restart process."""
|
|
137
|
+
async def do_reset():
|
|
138
|
+
await asyncio.sleep(0.1)
|
|
139
|
+
print("[INFO] Resetting server...")
|
|
140
|
+
sys.stdout.flush()
|
|
141
|
+
os.execv(sys.executable, [sys.executable] + sys.argv)
|
|
142
|
+
|
|
143
|
+
# Clean up workspace
|
|
144
|
+
self.cleanup_workspace()
|
|
145
|
+
|
|
146
|
+
asyncio.create_task(do_reset())
|
|
147
|
+
return PlainTextResponse("Server reset scheduled")
|
|
148
|
+
|
|
149
|
+
# ---------------- Public API ---------------- #
|
|
150
|
+
def cleanup_workspace(self):
|
|
151
|
+
if self.workspace and self.workspace.exists():
|
|
152
|
+
shutil.rmtree(self.workspace)
|
|
153
|
+
|
|
154
|
+
async def start(self):
|
|
155
|
+
await self._setup()
|
|
156
|
+
if self.proxy:
|
|
157
|
+
await self.proxy.run_async(transport="http", host=self.host, port=self.port)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# ---------------- Main ---------------- #
|
|
161
|
+
if __name__ == "__main__":
|
|
162
|
+
server = ProxyServer("../workspace")
|
|
163
|
+
try:
|
|
164
|
+
asyncio.run(server.start())
|
|
165
|
+
except KeyboardInterrupt:
|
|
166
|
+
print("\nShutting down gracefully...")
|
|
167
|
+
server.cleanup_workspace()
|
|
@@ -0,0 +1,694 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import datetime
|
|
3
|
+
import aiohttp
|
|
4
|
+
import uuid
|
|
5
|
+
import tempfile
|
|
6
|
+
import shutil
|
|
7
|
+
from typing import Callable, List, Any, Optional, Dict
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from fastmcp import Client as FastMCPClient
|
|
10
|
+
from mcp.types import TextContent
|
|
11
|
+
from fastmcp.exceptions import ToolError
|
|
12
|
+
from mcp import Tool
|
|
13
|
+
import sky
|
|
14
|
+
import logging
|
|
15
|
+
import warnings
|
|
16
|
+
|
|
17
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
18
|
+
|
|
19
|
+
logging.basicConfig(level=logging.DEBUG, format='%(levelname)s:%(name)s:%(message)s')
|
|
20
|
+
logging.getLogger('httpx').setLevel(logging.CRITICAL)
|
|
21
|
+
logging.getLogger('aiohttp').setLevel(logging.CRITICAL)
|
|
22
|
+
logging.getLogger('mcp.client.streamable_http').setLevel(logging.CRITICAL)
|
|
23
|
+
logging.getLogger('urllib3').setLevel(logging.CRITICAL)
|
|
24
|
+
logging.getLogger('requests').setLevel(logging.CRITICAL)
|
|
25
|
+
|
|
26
|
+
from benchmax.envs.base_env import BaseEnv
|
|
27
|
+
from benchmax.envs.types import ToolDefinition
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
class RemoteSkypilotMcpEnv(BaseEnv):
|
|
32
|
+
"""Remote MCP Environment for managing tool execution and rollouts with a remote MCP server.
|
|
33
|
+
Currently only supports running on Skypilot containers.
|
|
34
|
+
"""
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
workdir_path: str,
|
|
38
|
+
num_nodes: int = 1,
|
|
39
|
+
allowed_tools: Optional[List[str]] = None,
|
|
40
|
+
output_parsers: Optional[Dict[str, Callable[[str], Any]]] = None,
|
|
41
|
+
cluster_name: str = "benchmax-env-cluster",
|
|
42
|
+
health_check_timeout: int = 300, # 5 minutes
|
|
43
|
+
health_check_interval: int = 5, # 5 seconds
|
|
44
|
+
launch_workers_on_init: bool = True,
|
|
45
|
+
cloud: Optional[Any] = sky.Azure(), # sky.Cloud instance
|
|
46
|
+
cpus: str = "2+",
|
|
47
|
+
) -> None:
|
|
48
|
+
"""Initialize the environment with configuration and pool settings."""
|
|
49
|
+
super().__init__()
|
|
50
|
+
self._workdir_path = Path(workdir_path)
|
|
51
|
+
self._num_nodes = num_nodes
|
|
52
|
+
self._allowed_tools = allowed_tools or []
|
|
53
|
+
self._output_parsers: Dict[str, Callable[[str], Any]] = output_parsers or {}
|
|
54
|
+
self._cluster_name = cluster_name
|
|
55
|
+
self._health_check_timeout = health_check_timeout
|
|
56
|
+
self._health_check_interval = health_check_interval
|
|
57
|
+
self._cloud = cloud
|
|
58
|
+
self._cpus = cpus
|
|
59
|
+
self._ports = ["8080"] # MCP server port
|
|
60
|
+
|
|
61
|
+
# Generate API token for worker authentication
|
|
62
|
+
self._api_token = uuid.uuid4().hex
|
|
63
|
+
|
|
64
|
+
# Generate unique cluster name with suffix
|
|
65
|
+
unique_suffix = uuid.uuid4().hex[:8]
|
|
66
|
+
self._full_cluster_name = f"{self._cluster_name}-{unique_suffix}"
|
|
67
|
+
|
|
68
|
+
# Sync directory management
|
|
69
|
+
self._sync_dir: Optional[Path] = None
|
|
70
|
+
|
|
71
|
+
# Worker management
|
|
72
|
+
self._client_pool: Dict[str, FastMCPClient] = {}
|
|
73
|
+
self._available_workers: asyncio.Queue[str] = asyncio.Queue()
|
|
74
|
+
self._rollout_to_worker: Dict[str, str] = {}
|
|
75
|
+
self._worker_init_tasks: List[asyncio.Task] = []
|
|
76
|
+
|
|
77
|
+
# HTTP session
|
|
78
|
+
self._http_session = aiohttp.ClientSession()
|
|
79
|
+
|
|
80
|
+
# Cached tool definitions
|
|
81
|
+
self._tool_definitions: Optional[List[ToolDefinition]] = None
|
|
82
|
+
|
|
83
|
+
# Launch workers and start initialization
|
|
84
|
+
self.launch_workers_started = False
|
|
85
|
+
if launch_workers_on_init:
|
|
86
|
+
self.launch_workers()
|
|
87
|
+
|
|
88
|
+
def _setup_sync_directory(self) -> Path:
|
|
89
|
+
"""Create temporary sync directory and copy required files."""
|
|
90
|
+
# Create temporary directory
|
|
91
|
+
self._sync_dir = Path(tempfile.mkdtemp(prefix="benchmax_skypilot_"))
|
|
92
|
+
logger.info(f"Created sync directory: {self._sync_dir}")
|
|
93
|
+
|
|
94
|
+
try:
|
|
95
|
+
# Get the directory where this file is located
|
|
96
|
+
current_file_dir = Path(__file__).parent # inside remote_skypilot_mcp_server.py
|
|
97
|
+
proxy_server_path = current_file_dir / "proxy_server.py"
|
|
98
|
+
|
|
99
|
+
# Copy proxy_server.py
|
|
100
|
+
if not proxy_server_path.exists():
|
|
101
|
+
raise FileNotFoundError(f"proxy_server.py not found at {proxy_server_path}")
|
|
102
|
+
|
|
103
|
+
shutil.copy2(proxy_server_path, self._sync_dir / "proxy_server.py")
|
|
104
|
+
logger.debug(f"Copied proxy_server.py to sync directory")
|
|
105
|
+
|
|
106
|
+
# Copy all contents from workdir_path
|
|
107
|
+
if not self._workdir_path.exists():
|
|
108
|
+
raise FileNotFoundError(f"Workdir path does not exist: {self._workdir_path}")
|
|
109
|
+
|
|
110
|
+
if not self._workdir_path.is_dir():
|
|
111
|
+
raise ValueError(f"Workdir path is not a directory: {self._workdir_path}")
|
|
112
|
+
|
|
113
|
+
# Copy all contents
|
|
114
|
+
for item in self._workdir_path.iterdir():
|
|
115
|
+
if item.is_file():
|
|
116
|
+
shutil.copy2(item, self._sync_dir / item.name)
|
|
117
|
+
logger.debug(f"Copied file {item.name} to sync directory")
|
|
118
|
+
elif item.is_dir():
|
|
119
|
+
shutil.copytree(item, self._sync_dir / item.name)
|
|
120
|
+
logger.debug(f"Copied directory {item.name} to sync directory")
|
|
121
|
+
|
|
122
|
+
# Validate required files exist
|
|
123
|
+
reward_func_path = self._sync_dir / "reward_func.py"
|
|
124
|
+
setup_sh_path = self._sync_dir / "setup.sh"
|
|
125
|
+
mcp_config_path = self._sync_dir / "mcp_config.yaml"
|
|
126
|
+
|
|
127
|
+
if not reward_func_path.exists():
|
|
128
|
+
raise FileNotFoundError(f"reward_func.py not found in workdir: {self._workdir_path}")
|
|
129
|
+
|
|
130
|
+
if not setup_sh_path.exists():
|
|
131
|
+
raise FileNotFoundError(f"setup.sh not found in workdir: {self._workdir_path}")
|
|
132
|
+
|
|
133
|
+
if not mcp_config_path.exists():
|
|
134
|
+
raise FileNotFoundError(f"mcp_config.yaml not found in workdir: {self._workdir_path}")
|
|
135
|
+
|
|
136
|
+
logger.info(f"Validated required files in sync directory")
|
|
137
|
+
return self._sync_dir
|
|
138
|
+
|
|
139
|
+
except Exception as e:
|
|
140
|
+
# Clean up sync directory if setup fails
|
|
141
|
+
if self._sync_dir and self._sync_dir.exists():
|
|
142
|
+
shutil.rmtree(self._sync_dir, ignore_errors=True)
|
|
143
|
+
self._sync_dir = None
|
|
144
|
+
raise e
|
|
145
|
+
|
|
146
|
+
def _cleanup_sync_directory(self) -> None:
|
|
147
|
+
"""Clean up the temporary sync directory."""
|
|
148
|
+
if self._sync_dir and self._sync_dir.exists():
|
|
149
|
+
try:
|
|
150
|
+
shutil.rmtree(self._sync_dir)
|
|
151
|
+
logger.info(f"Cleaned up sync directory: {self._sync_dir}")
|
|
152
|
+
except Exception as e:
|
|
153
|
+
logger.warning(f"Failed to clean up sync directory {self._sync_dir}: {e}")
|
|
154
|
+
finally:
|
|
155
|
+
self._sync_dir = None
|
|
156
|
+
|
|
157
|
+
async def _init_worker(self, worker_ip: str) -> None:
|
|
158
|
+
"""Initialize a single worker: health check + FastMCP client + add to pool."""
|
|
159
|
+
try:
|
|
160
|
+
# Health check
|
|
161
|
+
await self._wait_for_worker_health(worker_ip)
|
|
162
|
+
|
|
163
|
+
# Initialize FastMCP client
|
|
164
|
+
mcp_url = f"http://{worker_ip}:8080/mcp/"
|
|
165
|
+
client = FastMCPClient(mcp_url)
|
|
166
|
+
await client._connect()
|
|
167
|
+
self._client_pool[worker_ip] = client
|
|
168
|
+
|
|
169
|
+
# Add to available pool
|
|
170
|
+
await self._available_workers.put(worker_ip)
|
|
171
|
+
logger.debug(f"Worker {worker_ip} initialized and added to pool")
|
|
172
|
+
|
|
173
|
+
except Exception as e:
|
|
174
|
+
logger.error(f"Failed to initialize worker {worker_ip}: {e}")
|
|
175
|
+
# Don't re-raise - let other workers continue
|
|
176
|
+
|
|
177
|
+
async def _wait_for_worker_health(self, worker_ip: str) -> None:
|
|
178
|
+
"""Wait for worker to pass health check."""
|
|
179
|
+
health_url = f"http://{worker_ip}:8080/health"
|
|
180
|
+
start_time = asyncio.get_event_loop().time()
|
|
181
|
+
|
|
182
|
+
while True:
|
|
183
|
+
elapsed = asyncio.get_event_loop().time() - start_time
|
|
184
|
+
if elapsed > self._health_check_timeout:
|
|
185
|
+
raise TimeoutError(f"Health check timeout for worker {worker_ip}")
|
|
186
|
+
|
|
187
|
+
try:
|
|
188
|
+
timeout = aiohttp.ClientTimeout(total=5)
|
|
189
|
+
async with self._http_session.get(health_url, timeout=timeout) as response:
|
|
190
|
+
if response.status == 200:
|
|
191
|
+
logger.debug(f"Worker {worker_ip} is healthy")
|
|
192
|
+
return
|
|
193
|
+
else:
|
|
194
|
+
logger.debug(f"Worker {worker_ip} health check returned {response.status}")
|
|
195
|
+
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
|
196
|
+
logger.debug(f"Health check failed for {worker_ip}: {e}")
|
|
197
|
+
|
|
198
|
+
await asyncio.sleep(self._health_check_interval)
|
|
199
|
+
|
|
200
|
+
async def _get_available_worker(self) -> str:
    """Block until some worker is free, then return its IP address."""
    worker_ip = await self._available_workers.get()
    return worker_ip
|
|
204
|
+
async def _release_worker(self, worker_ip: str) -> None:
    """Hand a worker's IP back to the pool of available workers."""
    pool = self._available_workers
    await pool.put(worker_ip)
|
|
208
|
+
async def _call_worker_reset(self, worker_ip: str) -> None:
    """Ask a worker to reset its state via its HTTP ``/reset`` endpoint.

    Raises:
        RuntimeError: If the worker answers with a non-200 status, or the
            request itself fails at the transport level.
    """
    reset_url = f"http://{worker_ip}:8080/reset"
    headers = {"Authorization": self._api_token}

    try:
        async with self._http_session.post(reset_url, headers=headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Reset failed for worker {worker_ip}: {response.status} - {error_text}")
            logger.info(f"Reset successful for worker {worker_ip}")
    except aiohttp.ClientError as e:
        raise RuntimeError(f"Reset request failed for worker {worker_ip}: {e}")
|
|
223
|
+
async def add_worker_back_once_available(self, worker_ip: str) -> None:
    """Re-register a worker in the available pool once it is healthy again."""
    # Give the worker a moment to finish resetting before probing it.
    await asyncio.sleep(1)

    try:
        await self._init_worker(worker_ip)
    except Exception as e:
        logger.error(f"Failed to add worker {worker_ip} back to pool: {e}")
        # Don't re-raise - worker remains out of pool
    else:
        logger.info(f"Worker {worker_ip} added back to available pool after reset")
|
|
234
|
+
# Function is expected to be called at the end of compute_reward
async def _cleanup_rollout(self, rollout_id: str) -> None:
    """Tear down a finished rollout: reset its worker, drop the FastMCP
    client, and schedule the worker's return to the available pool.

    Raises:
        ValueError: If the rollout was never initialized.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    # pop() both reads and removes the mapping in one step.
    worker_ip = self._rollout_to_worker.pop(rollout_id)
    logger.debug(f"Cleaning up rollout {rollout_id} on worker {worker_ip}")

    try:
        # Call reset endpoint
        await self._call_worker_reset(worker_ip)
    except Exception as e:
        logger.error(f"Failed to reset worker {worker_ip} for rollout {rollout_id}: {e}")

    # Disconnect the client
    client = self._client_pool.get(worker_ip)
    if client:
        await client._disconnect()
        self._client_pool.pop(worker_ip, None)

    # Start background task to add worker back once healthy
    asyncio.create_task(self.add_worker_back_once_available(worker_ip))
|
|
259
|
+
def _convert_and_filter_tools(self, tools: List[Tool]) -> List[ToolDefinition]:
    """Translate MCP ``Tool`` objects into ``ToolDefinition``s, keeping
    only tools named in ``self._allowed_tools`` (all tools if unset)."""
    tool_definitions: List[ToolDefinition] = []
    for tool in tools:
        tool_definitions.append(
            ToolDefinition(
                name=tool.name,
                description=tool.description or "",
                input_schema=tool.inputSchema,
            )
        )

    # An empty/unset allow-list means "expose everything".
    if not self._allowed_tools:
        return tool_definitions

    allowed = self._allowed_tools
    return [td for td in tool_definitions if td.name in allowed]
|
275
|
+
# ---- Public API Methods ----
|
|
276
|
+
def launch_workers(self) -> None:
|
|
277
|
+
"""Launch SkyPilot workers synchronously with programmatically created task."""
|
|
278
|
+
if self.launch_workers_started:
|
|
279
|
+
raise RuntimeError("Workers have already been launched.")
|
|
280
|
+
|
|
281
|
+
self.launch_workers_started = True
|
|
282
|
+
|
|
283
|
+
try:
|
|
284
|
+
# Setup sync directory and copy files
|
|
285
|
+
sync_dir = self._setup_sync_directory()
|
|
286
|
+
|
|
287
|
+
# Create the task programmatically
|
|
288
|
+
task = sky.Task(
|
|
289
|
+
name='fastmcp',
|
|
290
|
+
setup='pip install fastmcp~=2.10.0\npip install pyyaml\nsh setup.sh',
|
|
291
|
+
run='python proxy_server.py',
|
|
292
|
+
workdir=str(sync_dir),
|
|
293
|
+
num_nodes=self._num_nodes
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
# Set the resources
|
|
297
|
+
task.set_resources(
|
|
298
|
+
sky.Resources(
|
|
299
|
+
cloud=self._cloud,
|
|
300
|
+
cpus=self._cpus,
|
|
301
|
+
ports=self._ports
|
|
302
|
+
)
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
# Update environment variables with API token
|
|
306
|
+
task.update_envs({"API_TOKEN": self._api_token})
|
|
307
|
+
|
|
308
|
+
# Launch the cluster
|
|
309
|
+
_, handle = sky.launch(
|
|
310
|
+
task=task,
|
|
311
|
+
cluster_name=self._full_cluster_name,
|
|
312
|
+
detach_run=True,
|
|
313
|
+
detach_setup=True,
|
|
314
|
+
retry_until_up=True
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
if handle is None:
|
|
318
|
+
raise RuntimeError("Failed to launch SkyPilot task.")
|
|
319
|
+
|
|
320
|
+
worker_ips = [
|
|
321
|
+
external_ip for _, external_ip in handle.stable_internal_external_ips
|
|
322
|
+
]
|
|
323
|
+
logger.info(f"Launched workers with IPs: {worker_ips}")
|
|
324
|
+
|
|
325
|
+
# Start background initialization for each worker
|
|
326
|
+
for worker_ip in worker_ips:
|
|
327
|
+
task = asyncio.create_task(self._init_worker(worker_ip))
|
|
328
|
+
self._worker_init_tasks.append(task)
|
|
329
|
+
|
|
330
|
+
except Exception as e:
|
|
331
|
+
# Clean up sync directory if launch fails
|
|
332
|
+
self._cleanup_sync_directory()
|
|
333
|
+
raise e
|
|
334
|
+
|
|
335
|
+
async def shutdown(self) -> None:
    """Clean up resources - stop all tasks and close clients.

    Cancels pending worker-init tasks, closes every FastMCP client and
    the shared HTTP session, tears down the SkyPilot cluster, and always
    removes the local sync directory.
    """
    try:
        # Cancel worker initialization tasks
        for task in self._worker_init_tasks:
            if not task.done():
                task.cancel()

        if self._worker_init_tasks:
            results = await asyncio.gather(*self._worker_init_tasks, return_exceptions=True)
            for i, result in enumerate(results):
                # CancelledError here is expected (we just cancelled them).
                if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                    logger.error(f"Error in worker init task {i}: {result}")

        # Close FastMCP clients. Snapshot the keys once and pair them
        # with results via zip — the old `list(self._client_pool.keys())[i]`
        # rebuilt the key list inside the loop on every failure.
        if self._client_pool:
            worker_ips = list(self._client_pool.keys())
            close_tasks = [client.close() for client in self._client_pool.values()]
            results = await asyncio.gather(*close_tasks, return_exceptions=True)

            for worker_ip, result in zip(worker_ips, results):
                if isinstance(result, Exception):
                    logger.error(f"Error closing FastMCP client for {worker_ip}: {result}")

        # Close HTTP session
        await self._http_session.close()

        # Tear down SkyPilot cluster
        try:
            sky.down(cluster_name=self._full_cluster_name)
        except Exception as e:
            logger.error(f"Error tearing down SkyPilot cluster: {e}")

    finally:
        # Always clean up sync directory
        self._cleanup_sync_directory()
|
|
372
|
+
async def list_tools(self) -> List[ToolDefinition]:
    """Return the (filtered) tool definitions, caching them after the
    first successful fetch."""
    if self._tool_definitions is not None:
        return self._tool_definitions

    # Borrow any worker just long enough to query its tool list.
    worker_ip = await self._get_available_worker()
    try:
        raw_tools = await self._client_pool[worker_ip].list_tools()
        self._tool_definitions = self._convert_and_filter_tools(raw_tools)
        return self._tool_definitions
    finally:
        await self._release_worker(worker_ip)
|
|
387
|
+
async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
    """Reserve a worker for a new rollout, blocking until one is free.

    Raises:
        ValueError: If the rollout id was already initialized.
    """
    if rollout_id in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is already initialized")

    # Blocks here until the pool yields a free worker.
    worker_ip = await self._get_available_worker()
    self._rollout_to_worker[rollout_id] = worker_ip
    logger.info(f"Rollout {rollout_id} assigned to worker {worker_ip}")
|
|
399
|
+
async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Optional[str]:
    """Execute a tool in the context of a specific rollout.

    Returns the joined text content of the tool result — passed through
    a registered output parser when one exists for ``tool_name`` — or
    ``None`` if the call fails.

    Raises:
        ValueError: If the rollout was never initialized.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized. Call init_rollout() first.")

    worker_ip = self._rollout_to_worker[rollout_id]
    client = self._client_pool[worker_ip]

    try:
        result = await client.call_tool(tool_name, tool_args, timeout=datetime.timedelta(seconds=30))
        # Only text content is processed for now; other content types
        # are silently dropped.
        combined_text = "\n".join(
            content.text for content in result.content if isinstance(content, TextContent)
        )

        # Apply output parser if available. `combined_text` is always a
        # str after join(), so the old isinstance(str) guard was redundant.
        if tool_name in self._output_parsers:
            return self._output_parsers[tool_name](combined_text)

        return combined_text

    except ToolError as e:
        logger.error(f"[ERROR] Tool call returned error: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"[ERROR] Tool call failed: {str(e)}")
        return None
|
|
431
|
+
async def copy_to_workspace(
    self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
) -> None:
    """Copy a file to the workspace for a specific rollout.

    Args:
        rollout_id: The rollout identifier.
        src_path: Local file to upload.
        dst_filename: Name to store the file under remotely; defaults to
            ``src_path.name``.

    Raises:
        ValueError: If the rollout was never initialized.
        RuntimeError: If the worker rejects the upload.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    upload_url = f"http://{worker_ip}:8080/upload"
    headers = {"Authorization": self._api_token}

    # Prepare file for upload
    filename = dst_filename or src_path.name

    try:
        with open(src_path, 'rb') as f:
            data = aiohttp.FormData()
            data.add_field('file', f, filename=filename)

            async with self._http_session.post(upload_url, headers=headers, data=data) as response:
                if response.status == 200:
                    # Fixed: the success log previously interpolated a
                    # literal placeholder instead of the actual filename.
                    logger.info(f"File {src_path} uploaded as {filename} for rollout {rollout_id}")
                else:
                    error_text = await response.text()
                    raise RuntimeError(f"Upload failed: {response.status} - {error_text}")
    except Exception as e:
        logger.error(f"Failed to copy {src_path} to workspace for rollout {rollout_id}: {e}")
        raise
|
|
460
|
+
async def copy_content_to_workspace(
    self, rollout_id: str, src_content: str | bytes, dst_filename: str, encoding: str = "utf-8"
) -> None:
    """Copy content (string or bytes) to the workspace for a specific rollout.

    Args:
        rollout_id: The rollout identifier.
        src_content: The content to upload (str or bytes).
        dst_filename: The filename to assign in the workspace.
        encoding: Encoding to use if src_content is str. Defaults to UTF-8.

    Raises:
        ValueError: If the rollout was never initialized.
        RuntimeError: If the worker rejects the upload.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    upload_url = f"http://{worker_ip}:8080/upload"
    headers = {"Authorization": self._api_token}

    try:
        # Strings are encoded with the caller-supplied encoding; raw
        # bytes are shipped verbatim.
        is_text = isinstance(src_content, str)
        file_bytes = src_content.encode(encoding) if is_text else src_content
        content_type = "text/plain" if is_text else "application/octet-stream"

        form = aiohttp.FormData()
        form.add_field(
            "file",
            file_bytes,
            filename=dst_filename,
            content_type=content_type,
        )

        async with self._http_session.post(upload_url, headers=headers, data=form) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Upload failed: {response.status} - {error_text}")
            logger.info(f"Content uploaded as {dst_filename} for rollout {rollout_id}")
    except Exception as e:
        logger.error(f"Failed to upload content to workspace for rollout {rollout_id}: {e}")
        raise
|
|
505
|
+
async def copy_from_workspace(
    self, rollout_id: str, src_filename: str, dst_path: Path
) -> None:
    """Download ``src_filename`` from a rollout's workspace into ``dst_path``.

    Raises:
        ValueError: If the rollout was never initialized.
        RuntimeError: If the worker rejects the download.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    download_url = f"http://{worker_ip}:8080/download"
    headers = {"Authorization": self._api_token}
    params = {"file_path": src_filename}

    try:
        async with self._http_session.get(download_url, headers=headers, params=params) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Download failed: {response.status} - {error_text}")

            # Make sure the destination directory exists before writing.
            dst_path.parent.mkdir(parents=True, exist_ok=True)

            # Stream the body to disk in 8 KiB chunks.
            with open(dst_path, 'wb') as f:
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)

            logger.info(f"File {src_filename} downloaded from rollout {rollout_id} to {dst_path}")
    except Exception as e:
        logger.error(f"Failed to copy {src_filename} from workspace for rollout {rollout_id}: {e}")
        raise
|
|
536
|
+
async def compute_reward(
    self,
    rollout_id: str,
    completion: str,
    ground_truth: Any,
    **kwargs: Any
) -> Dict[str, float]:
    """Compute rewards using registered functions

    Delegates the computation to the assigned worker's HTTP
    /compute_reward endpoint, then ALWAYS cleans up the rollout in the
    finally block (resetting the worker and recycling it into the pool),
    even when reward computation fails.

    Returns dict mapping reward function names to their computed scores.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    compute_reward_url = f"http://{worker_ip}:8080/compute_reward"
    headers = {
        "Authorization": self._api_token,
        "Content-Type": "application/json"
    }

    # Prepare request payload; extra kwargs are forwarded verbatim to
    # the remote reward functions.
    payload = {
        "completion": completion,
        "ground_truth": ground_truth,
        **kwargs
    }

    try:
        async with self._http_session.post(
            compute_reward_url,
            headers=headers,
            json=payload
        ) as response:
            if response.status == 200:
                result = await response.json()
                logger.debug(f"Reward computed successfully for rollout {rollout_id}")
                return result
            else:
                error_text = await response.text()
                raise RuntimeError(f"Reward computation failed: {response.status} - {error_text}")
    except aiohttp.ClientError as e:
        # Transport-level failures are wrapped in RuntimeError for callers.
        logger.error(f"Failed to compute reward for rollout {rollout_id}: {e}")
        raise RuntimeError(f"Reward computation request failed: {e}")
    except Exception as e:
        logger.error(f"Unexpected error computing reward for rollout {rollout_id}: {e}")
        raise
    finally:
        # Worker is reset and returned to the pool regardless of outcome.
        await self._cleanup_rollout(rollout_id)
|
|
586
|
+
|
|
587
|
+
async def run_single_rollout(env: RemoteSkypilotMcpEnv, rollout_id: str, expression: str, expected: str, tmp_root: Path):
    """Run a complete rollout: init -> upload -> download+verify -> tool -> reward -> cleanup

    Args:
        env: The remote SkyPilot MCP environment.
        rollout_id: Unique id for this rollout.
        expression: Arithmetic expression passed to the `calculate` tool.
        expected: Expected result, used as ground truth for the reward.
        tmp_root: Local directory for per-rollout scratch files.

    Returns:
        Tuple of (rollout_id, reward dict).
    """
    print(f"Starting rollout: {rollout_id}")

    # Create rollout-specific tmp dir
    rollout_tmp = tmp_root / rollout_id
    rollout_tmp.mkdir(parents=True, exist_ok=True)

    # Stage 1: Initialize rollout
    await env.init_rollout(rollout_id)
    print(f"Initialized rollout: {rollout_id}")

    # Stage 1.5a: Upload various content types
    test_contents = {
        "utf8_text.txt": f"# UTF-8 text for {rollout_id}\nExpression: {expression}\n",
        "latin1_text.txt": "Café Münster".encode("latin-1"),
        "json_data.json": '{"rollout": "%s", "value": %s}' % (rollout_id, expression),
        "binary_data.bin": b"\x00\x01\x02\x03\xFF",
        "unicode_text.txt": "你好, мир, hello 🌍",
    }

    for filename, content in test_contents.items():
        await env.copy_content_to_workspace(rollout_id, content, filename)
        # Fixed: log the actual filename instead of a literal placeholder.
        print(f"Uploaded {filename} for {rollout_id}")

    # Stage 1.5b: Test file-based copy
    tmp_path = rollout_tmp / "local_file.txt"
    tmp_path.write_text(f"Temporary file for {rollout_id}, expression={expression}\n", encoding="utf-8")
    await env.copy_to_workspace(rollout_id, tmp_path, dst_filename=f"copied_{rollout_id}.txt")
    print(f"Copied file {tmp_path} to workspace for {rollout_id}")

    # Stage 1.6: Download and verify content
    for filename, original_content in test_contents.items():
        # Fixed: previously every file downloaded to the same literal
        # "dl_(unknown)" path, so only the last file was actually verified.
        download_path = rollout_tmp / f"dl_{filename}"
        await env.copy_from_workspace(rollout_id, filename, download_path)

        downloaded_bytes = download_path.read_bytes()
        if isinstance(original_content, str):
            original_bytes = original_content.encode("utf-8")
        else:
            original_bytes = original_content

        if downloaded_bytes == original_bytes:
            print(f"Verified {filename} ✅")
        else:
            print(f"Mismatch in {filename}! ❌")

    # Verify copied file
    print(f"Verifying copied file for {rollout_id}")
    copied_dl = rollout_tmp / f"dl_copied_{rollout_id}.txt"
    await env.copy_from_workspace(rollout_id, f"copied_{rollout_id}.txt", copied_dl)
    if copied_dl.read_text(encoding="utf-8") == tmp_path.read_text(encoding="utf-8"):
        print(f"Verified copied file ✅")
    else:
        print(f"Mismatch in copied file ❌")

    # Stage 2: Run tool
    tool_result = await env.run_tool(rollout_id, "calculate", expression=expression)
    print(f"Tool result for {rollout_id}: {tool_result}")

    # Stage 3: Compute reward (also cleans up the rollout server-side)
    reward = await env.compute_reward(rollout_id, completion=str(tool_result), ground_truth=expected)
    print(f"Computed reward for {rollout_id}: {reward}")

    return rollout_id, reward
|
|
653
|
+
|
|
654
|
+
async def main():
    """Smoke-test driver: launch a small cluster and run a few rollouts."""
    env = RemoteSkypilotMcpEnv(
        workdir_path="benchmax/envs/skypilot/workdir",
        num_nodes=2,
        cluster_name="test-cluster",
        cloud=sky.Azure(),
        cpus="2+",
    )

    tmp_root = Path("./tmp")
    tmp_root.mkdir(exist_ok=True)

    try:
        tools = await env.list_tools()
        print(f"Available tools: {[tool.name for tool in tools]}")

        # Build one rollout coroutine per index: expression "i+1 + i+1"
        # with its arithmetic result as the expected ground truth.
        rollout_tasks = [
            run_single_rollout(
                env,
                f"test-rollout-{i:03d}",
                f"{i + 1} + {i + 1}",
                str((i + 1) + (i + 1)),
                tmp_root,
            )
            for i in range(3)  # fewer for debugging; adjust as needed
        ]

        print("Starting concurrent rollouts...")
        results = await asyncio.gather(*rollout_tasks, return_exceptions=True)

        print("Rollout results:")
        for result in results:
            print(result)

    finally:
        await env.shutdown()
        # Cleanup tmp dir at the very end
        shutil.rmtree(tmp_root, ignore_errors=True)
        print("Cleaned up temporary files.")


if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from fastmcp import Client
|
|
3
|
+
|
|
4
|
+
def reward_function(
    completion: str,
    ground_truth: Any,
    workspace: str,
    mcp_client: Client,
    **kwargs: Any
) -> float:
    """Compute the reward for a given model completion.

    Exact-match reward: 1.0 when the stripped completion equals the
    stripped ground truth, otherwise 0.0.

    Args:
        completion: Model output to score.
        ground_truth: Expected answer; coerced via str() so non-string
            ground truths (e.g. ints from a dataset) no longer crash on
            .strip().
        workspace: Path of the rollout's workspace (logged only).
        mcp_client: MCP client for tool access (unused here).
        **kwargs: Extra context, ignored.
    """
    print(f"Workspace for reward function: {workspace}")
    # str() makes the comparison robust: ground_truth is typed Any but
    # the old code called .strip() on it directly.
    return 1.0 if completion.strip() == str(ground_truth).strip() else 0.0


reward_functions = [reward_function]
|
|
@@ -7,7 +7,6 @@ class StandardizedExample(TypedDict):
|
|
|
7
7
|
ground_truth: Any
|
|
8
8
|
init_rollout_args: Optional[Dict[str, Any]]
|
|
9
9
|
|
|
10
|
-
|
|
11
10
|
@dataclass
|
|
12
11
|
class ToolDefinition:
|
|
13
12
|
"""Definition of a tool's interface"""
|
|
@@ -19,10 +18,9 @@ class RewardFunction(Protocol):
|
|
|
19
18
|
"""Function that evaluates model interactions"""
|
|
20
19
|
def __call__(
|
|
21
20
|
self,
|
|
22
|
-
prompt: str, # Input prompt given to the model
|
|
23
21
|
completion: str, # Model's generated completion/response
|
|
24
22
|
ground_truth: Any, # Expected/correct output to compare against
|
|
25
|
-
workspace:
|
|
23
|
+
workspace: str, # Current workspace of the rollout
|
|
26
24
|
**kwargs: Any # Additional context for reward computation
|
|
27
25
|
) -> float: # Reward score (typically in range [0, 1])
|
|
28
26
|
...
|
|
@@ -220,6 +220,18 @@ class WikipediaEnv(BaseEnv):
|
|
|
220
220
|
|
|
221
221
|
def get_rollout_workspace(self, rollout_id: str) -> Path:
|
|
222
222
|
return super().get_rollout_workspace(rollout_id)
|
|
223
|
+
|
|
224
|
+
def copy_to_workspace(
    self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
) -> None:
    """Copy a file to the workspace for a specific rollout."""
    # NOTE(review): body is a no-op — no file is actually copied. Looks
    # like either an intentional override (env may not need uploads) or
    # an unimplemented stub; confirm intended behavior.
    pass
|
|
229
|
+
|
|
230
|
+
def copy_from_workspace(
    self, rollout_id: str, src_filename: str, dst_path: Path
) -> None:
    """Copy a file from the workspace for a specific rollout."""
    # NOTE(review): body is a no-op — nothing is written to dst_path.
    # Confirm whether this env intentionally skips downloads or the
    # method is an unimplemented stub.
    pass
|
|
223
235
|
|
|
224
236
|
if __name__ == "__main__":
|
|
225
237
|
# Example usage
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "benchmax"
|
|
3
|
-
version = "0.1.1.
|
|
3
|
+
version = "0.1.1.dev6"
|
|
4
4
|
description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
|
|
5
5
|
authors = ["cgft.io"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -12,6 +12,8 @@ packages = [
|
|
|
12
12
|
python = ">=3.11,<3.13"
|
|
13
13
|
fastmcp = "~2.10.0"
|
|
14
14
|
|
|
15
|
+
skypilot = { version = "0.8.1", optional = true }
|
|
16
|
+
|
|
15
17
|
verl-cgft-fork = { version = "0.5.0.dev2", optional = true }
|
|
16
18
|
sglang = { version = "0.4.9", optional = true, extras = ["all"] }
|
|
17
19
|
verifiers = { version = "^0.1.1", optional = true, extras = ["train"] }
|
|
@@ -30,6 +32,9 @@ pytest = "^8.4.1"
|
|
|
30
32
|
verifiers = ["verifiers"]
|
|
31
33
|
verl = ["verl-cgft-fork", "sglang"]
|
|
32
34
|
|
|
35
|
+
# Hosting-specific
|
|
36
|
+
skypilot = ["skypilot"]
|
|
37
|
+
|
|
33
38
|
# Environment-specific
|
|
34
39
|
excel-linux = ["openpyxl"]
|
|
35
40
|
excel = ["openpyxl", "xlwings"]
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/verifiers/verifiers_adapters.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|