benchmax 0.1.1.dev4__tar.gz → 0.1.1.dev6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/PKG-INFO +8 -3
  2. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/README.md +5 -3
  3. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/base_env.py +19 -36
  4. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_env.py +4 -13
  5. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_utils.py +9 -8
  6. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/local_mcp_env.py +28 -0
  7. benchmax-0.1.1.dev6/benchmax/envs/skypilot/proxy_server.py +167 -0
  8. benchmax-0.1.1.dev6/benchmax/envs/skypilot/remote_skypilot_mcp_server.py +694 -0
  9. benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/mcp_config.yaml +5 -0
  10. benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/reward_func.py +16 -0
  11. benchmax-0.1.1.dev6/benchmax/envs/skypilot/workdir/setup.sh +3 -0
  12. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/types.py +1 -3
  13. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/wiki_env.py +12 -0
  14. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/pyproject.toml +6 -1
  15. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/LICENSE +0 -0
  16. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/__init__.py +0 -0
  17. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/verifiers/verifiers_adapters.py +0 -0
  18. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/adapters/verl/benchmax_data_process.py +0 -0
  19. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/__init__.py +0 -0
  20. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/bounded_dict.py +0 -0
  21. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/README.md +0 -0
  22. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/crm_env.py +0 -0
  23. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/salesforce_mcp.py +0 -0
  24. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/crm/salesforce_requirements.txt +0 -0
  25. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/README.md +0 -0
  26. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/data_utils.py +0 -0
  27. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/excel/excel_code_runner_mcp.py +0 -0
  28. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/math/README.md +0 -0
  29. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/math/math_env.py +0 -0
  30. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/README.md +0 -0
  31. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/envs/wikipedia/utils.py +0 -0
  32. {benchmax-0.1.1.dev4 → benchmax-0.1.1.dev6}/benchmax/prompts/tools.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: benchmax
3
- Version: 0.1.1.dev4
3
+ Version: 0.1.1.dev6
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Requires-Python: >=3.11,<3.13
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.12
10
10
  Provides-Extra: crm
11
11
  Provides-Extra: excel
12
12
  Provides-Extra: excel-linux
13
+ Provides-Extra: skypilot
13
14
  Provides-Extra: verifiers
14
15
  Provides-Extra: verl
15
16
  Requires-Dist: fastmcp (>=2.10.0,<2.11.0)
@@ -17,6 +18,7 @@ Requires-Dist: openpyxl (==3.1.5) ; extra == "excel-linux" or extra == "excel"
17
18
  Requires-Dist: python-dateutil (>=2.9.0,<2.10.0) ; extra == "crm"
18
19
  Requires-Dist: sglang[all] (==0.4.9) ; extra == "verl"
19
20
  Requires-Dist: simple-salesforce (>=1.12.3) ; extra == "crm"
21
+ Requires-Dist: skypilot (==0.8.1) ; extra == "skypilot"
20
22
  Requires-Dist: verifiers[train] (>=0.1.1,<0.2.0) ; extra == "verifiers"
21
23
  Requires-Dist: verl-cgft-fork (==0.5.0.dev2) ; extra == "verl"
22
24
  Requires-Dist: xlwings (==0.33.15) ; extra == "excel"
@@ -68,7 +70,7 @@ Get started with ready to use recipes, from Wikipedia search to spreadsheet mani
68
70
 
69
71
  **Trainer Integrations**
70
72
 
71
- Use your own trainer or training framework - no lock-in. `benchmax` is already Integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
73
+ Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
72
74
 
73
75
  **MCP Support**
74
76
  Tap into the growing MCP ecosystem and integrate them as tools within your environments.
@@ -89,6 +91,8 @@ Tap into the growing MCP ecosystem and integrate them as tools within your envir
89
91
 
90
92
  `pip install benchmax[verl]`
91
93
 
94
+ \* Note that benchmax installs our verl fork (temporary until [PR gets merged](https://github.com/volcengine/verl/pull/2792))
95
+
92
96
  1. **Prepare the dataset**
93
97
 
94
98
  ```bash
@@ -387,7 +391,7 @@ Open an issue and tag us & we will look into building you one!
387
391
  - Facilitate easy deployment and scalability in cloud environments.
388
392
  - **MCP as a first class citizen**:
389
393
 
390
- There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax` allow folks to leverage and composes these existing MCP servers to build environment integrated with real world systems e.g. excel
394
+ There has been an explosion of MCP servers/tools built for use cases ranging from browser use to Excel to game creation. `benchmax` allows folks to leverage and compose these existing MCP servers to build environments integrated with real-world systems, e.g. Excel.
391
395
 
392
396
 
393
397
  ## 🤝 Contributing
@@ -399,3 +403,4 @@ We welcome new environment recipes, bug reports, and trainer integrations!
399
403
  ## 📜 License
400
404
 
401
405
  Apache 2.0 © 2025 CGFT Inc.
406
+
@@ -44,7 +44,7 @@ Get started with ready to use recipes, from Wikipedia search to spreadsheet mani
44
44
 
45
45
  **Trainer Integrations**
46
46
 
47
- Use your own trainer or training framework - no lock-in. `benchmax` is already Integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
47
+ Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into verl and verifiers, with more integrations (SkyRL, etc.) coming soon!
48
48
 
49
49
  **MCP Support**
50
50
  Tap into the growing MCP ecosystem and integrate them as tools within your environments.
@@ -65,6 +65,8 @@ Tap into the growing MCP ecosystem and integrate them as tools within your envir
65
65
 
66
66
  `pip install benchmax[verl]`
67
67
 
68
+ \* Note that benchmax installs our verl fork (temporary until [PR gets merged](https://github.com/volcengine/verl/pull/2792))
69
+
68
70
  1. **Prepare the dataset**
69
71
 
70
72
  ```bash
@@ -363,7 +365,7 @@ Open an issue and tag us & we will look into building you one!
363
365
  - Facilitate easy deployment and scalability in cloud environments.
364
366
  - **MCP as a first class citizen**:
365
367
 
366
- There has been an explosion of MCP servers/tools built out for usecases ranging from browser use to excel to game creation.`benchmax` allow folks to leverage and composes these existing MCP servers to build environment integrated with real world systems e.g. excel
368
+ There has been an explosion of MCP servers/tools built for use cases ranging from browser use to Excel to game creation. `benchmax` allows folks to leverage and compose these existing MCP servers to build environments integrated with real-world systems, e.g. Excel.
367
369
 
368
370
 
369
371
  ## 🤝 Contributing
@@ -374,4 +376,4 @@ We welcome new environment recipes, bug reports, and trainer integrations!
374
376
 
375
377
  ## 📜 License
376
378
 
377
- Apache 2.0 © 2025 CGFT Inc.
379
+ Apache 2.0 © 2025 CGFT Inc.
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, List, Any, Tuple
2
+ from typing import Dict, List, Any, Optional, Tuple
3
3
  from pathlib import Path
4
4
 
5
5
  from datasets import (
@@ -59,34 +59,38 @@ class BaseEnv(ABC):
59
59
  return load_dataset(dataset_name, **kwargs), None
60
60
 
61
61
  @abstractmethod
62
- def list_tools(self) -> List[ToolDefinition]:
62
+ async def list_tools(self) -> List[ToolDefinition]:
63
63
  """Return list of available tools"""
64
64
  pass
65
65
 
66
66
  @abstractmethod
67
- def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
67
+ async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
68
68
  """Execute named tool in rollout context with given arguments"""
69
69
  pass
70
70
 
71
71
  @abstractmethod
72
- def init_rollout(self, rollout_id: str, **rollout_args) -> None:
72
+ async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
73
73
  """Initialize resources for a new rollout"""
74
74
  pass
75
-
75
+
76
76
  @abstractmethod
77
- def cleanup_rollout(self, rollout_id: str) -> None:
78
- """Clean up resources for a rollout"""
77
+ async def copy_to_workspace(
78
+ self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
79
+ ) -> None:
80
+ """Copy a file to the workspace for a specific rollout. If dst_filename is None, use the original filename."""
79
81
  pass
80
82
 
81
83
  @abstractmethod
82
- def get_rollout_workspace(self, rollout_id: str) -> Path:
83
- """Get the workspace path for a specific rollout"""
84
+ async def copy_from_workspace(
85
+ self, rollout_id: str, src_filename: str, dst_path: Path
86
+ ) -> None:
87
+ """Copy a file from the workspace for a specific rollout"""
84
88
  pass
85
-
86
- def compute_reward(
89
+
90
+ @abstractmethod
91
+ async def compute_reward(
87
92
  self,
88
93
  rollout_id: str,
89
- prompt: str,
90
94
  completion: str,
91
95
  ground_truth: Any,
92
96
  **kwargs: Any
@@ -95,32 +99,11 @@ class BaseEnv(ABC):
95
99
 
96
100
  Returns dict mapping reward function names to their computed scores.
97
101
  """
98
- workspace = self.get_rollout_workspace(rollout_id)
99
- if workspace is None:
100
- raise ValueError(f"No workspace found for rollout {rollout_id}")
101
-
102
- results: Dict[str, float] = {}
103
- for func in self.reward_funcs:
104
- try:
105
- # Get function name, falling back to string representation if not available
106
- func_name = getattr(func, "__name__", str(func))
107
- results[func_name] = func(
108
- prompt=prompt,
109
- completion=completion,
110
- ground_truth=ground_truth,
111
- workspace=workspace,
112
- **kwargs
113
- )
114
- except Exception as e:
115
- # Use same function name resolution
116
- func_name = getattr(func, "__name__", str(func))
117
- results[func_name] = float('nan')
118
- print(f"[WARN] reward {func_name} failed: {e}")
119
- return results
102
+ pass
120
103
 
121
- def get_system_prompt(self, add_tool_defs: bool = False) -> str:
104
+ async def get_system_prompt(self, add_tool_defs: bool = False) -> str:
122
105
  """Get system prompt. To add tool definitions, set add_tool_defs to True."""
123
106
  if add_tool_defs:
124
- return render_tools_prompt(self.list_tools(), self.system_prompt or "")
107
+ return render_tools_prompt(await self.list_tools(), self.system_prompt or "")
125
108
  else:
126
109
  return self.system_prompt
@@ -138,16 +138,7 @@ class ExcelEnv(LocalMCPEnv):
138
138
  answer_spreadsheet_path = rollout_args["answer_spreadsheet_path"]
139
139
 
140
140
  super().init_rollout(rollout_id, **rollout_args)
141
- workspace = self.get_rollout_workspace(rollout_id)
142
-
143
- def _copy_to_workspace(src_path: Path):
144
- """
145
- Copy the spreadsheet file to the workspace if it doesn't already exist.
146
- """
147
- dest_path = workspace / src_path.name
148
- if not dest_path.exists():
149
- dest_path.write_bytes(src_path.read_bytes())
150
-
151
- # Copy the spreadsheet to the workspace
152
- _copy_to_workspace(Path(spreadsheet_path))
153
- _copy_to_workspace(Path(answer_spreadsheet_path))
141
+
142
+ self.copy_to_workspace(rollout_id, Path(spreadsheet_path))
143
+ self.copy_to_workspace(rollout_id, Path(answer_spreadsheet_path))
144
+
@@ -16,7 +16,6 @@ WHITE_LIKE_COLORS = [
16
16
  ]
17
17
 
18
18
  def evaluate_excel(excel_path: str):
19
- import xlwings
20
19
  """
21
20
  Evaluate Python code that manipulates an Excel file using xlwings.
22
21
  """
@@ -24,11 +23,13 @@ def evaluate_excel(excel_path: str):
24
23
  if platform.system() == "Linux":
25
24
  evaluate_excel_libre(excel_path)
26
25
  return
27
- excel_app = xlwings.App(visible=False)
28
- excel_book = excel_app.books.open(excel_path)
29
- excel_book.save()
30
- excel_book.close()
31
- excel_app.quit()
26
+ else:
27
+ import xlwings
28
+ excel_app = xlwings.App(visible=False)
29
+ excel_book = excel_app.books.open(excel_path)
30
+ excel_book.save()
31
+ excel_book.close()
32
+ excel_app.quit()
32
33
 
33
34
  def evaluate_excel_libre(excel_path: str) -> None:
34
35
  """
@@ -120,7 +121,7 @@ def excel_to_str_repr(excel_path: str, evaluate_formulas = False) -> str:
120
121
  result.append(f"{coords}: null [{', '.join(style)}]")
121
122
  is_row_empty = False
122
123
  elif display_value:
123
- style_str = f" [{", ".join(style)}]" if style else ""
124
+ style_str = f" [{', '.join(style)}]" if style else ""
124
125
  result.append(f"{coords}: {display_value}{style_str}")
125
126
  is_row_empty = False
126
127
  if not is_row_empty:
@@ -212,4 +213,4 @@ def compare_excel_cells(ground_truth_path: str, output_path: str, answer_positio
212
213
  return False, f"Fill color mismatch at {cell_name}"
213
214
  if not compare_font_color(cell_gt.font, cell_out.font):
214
215
  return False, f"Font color mismatch at {cell_name}"
215
- return True, "All comparisons passed."
216
+ return True, "All comparisons passed."
@@ -228,6 +228,34 @@ class LocalMCPEnv(BaseEnv):
228
228
  raise ValueError(f"No active client found for rollout {rollout_id}")
229
229
  else:
230
230
  return Path()
231
+
232
def copy_to_workspace(
    self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
) -> None:
    """Copy ``src_path`` into the workspace that belongs to ``rollout_id``.

    The file is stored under ``dst_filename`` when one is given, otherwise
    under the source file's own name.

    Raises:
        ValueError: if the rollout has no active client.
        FileNotFoundError: if ``src_path`` does not exist.
    """
    if rollout_id not in self._active_clients:
        raise ValueError(f"No active client found for rollout {rollout_id}")
    if not src_path.exists():
        raise FileNotFoundError(f"Source file {src_path} does not exist")

    workspace = self._active_clients[rollout_id].workspace
    target_name = dst_filename or src_path.name
    payload = src_path.read_bytes()
    (workspace / target_name).write_bytes(payload)
245
+
246
def copy_from_workspace(
    self, rollout_id: str, src_filename: str, dst_path: Path
) -> None:
    """Copy ``src_filename`` out of the rollout's workspace to ``dst_path``.

    Raises:
        ValueError: if the rollout has no active client.
        FileNotFoundError: if the file is not present in the workspace.
    """
    if rollout_id not in self._active_clients:
        raise ValueError(f"No active client found for rollout {rollout_id}")

    pair = self._active_clients[rollout_id]
    source = pair.workspace / src_filename
    if source.exists():
        dst_path.write_bytes(source.read_bytes())
        return
    raise FileNotFoundError(f"File {src_filename} not found in workspace {pair.workspace}")
231
259
 
232
260
  # ---- Private Helper Methods ----
233
261
 
@@ -0,0 +1,167 @@
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import uuid
5
+ import yaml
6
+ import asyncio
7
+ from pathlib import Path
8
+ from functools import wraps
9
+ from fastmcp import FastMCP, Client
10
+ from starlette.requests import Request
11
+ from starlette.responses import PlainTextResponse, FileResponse, JSONResponse
12
+ from starlette.datastructures import UploadFile
13
+
14
+ from reward_func import reward_functions # your reward functions
15
+
16
+
17
+ # ---------------- Utility Functions ---------------- #
18
def setup_workspace(base_dir: Path) -> Path:
    """Create a uniquely named workspace directory under ``base_dir``.

    The directory name is a random UUID4 hex string; the returned path is
    resolved to an absolute path.
    """
    workspace_name = uuid.uuid4().hex
    workspace = base_dir.joinpath(workspace_name).resolve()
    workspace.mkdir(parents=True, exist_ok=True)
    return workspace
23
+
24
+
25
def load_config(config_path: Path, workspace: Path) -> dict:
    """Load the MCP YAML config and point every configured server at ``workspace``.

    Each literal ``${{ sync_workdir }}`` placeholder in the file is replaced
    with the directory containing this module before the YAML is parsed.
    Every entry under the top-level ``mcpServers`` mapping then gets its
    ``cwd`` set to ``workspace``.

    Args:
        config_path: Path of the YAML file to read.
        workspace: Directory used as the working directory of each server.

    Returns:
        The parsed configuration. An empty (or comment-only) file yields an
        empty dict instead of raising.
    """
    sync_workdir = str(Path(__file__).resolve().parent)
    with open(config_path, "r") as f:
        content = f.read().replace("${{ sync_workdir }}", sync_workdir)

    config = yaml.safe_load(content)
    if config is None:
        # safe_load returns None for an empty document; the membership test
        # below would otherwise fail with a TypeError.
        config = {}

    servers = config.get("mcpServers") if isinstance(config, dict) else None
    # `mcpServers:` with no value parses as None -- treat it as "no servers".
    for server in (servers or {}).values():
        server["cwd"] = str(workspace)
    return config
34
+
35
+
36
+ # ---------------- Auth Decorator ---------------- #
37
def require_auth(func):
    """Decorator that rejects requests lacking the expected ``Authorization`` header.

    The expected token is read from the ``API_TOKEN`` environment variable on
    every request, falling back to ``"default-secret-token"`` when unset.
    The request object is taken from the second positional argument when the
    handler is called with exactly two (bound-method style ``self, request``),
    otherwise from the first.

    Returns a 401 plain-text response on mismatch; otherwise awaits and
    returns the wrapped handler's result.
    """
    import hmac  # function-scope import: only needed for the constant-time compare

    @wraps(func)
    async def wrapper(*args, **kwargs):
        request = args[1] if len(args) == 2 else args[0]
        supplied = request.headers.get("Authorization")
        # NOTE(security): the hard-coded fallback is kept for backward
        # compatibility, but it is a publicly known value -- deployments
        # should always set API_TOKEN.
        expected = os.getenv("API_TOKEN", "default-secret-token")
        # compare_digest avoids leaking how many leading characters matched.
        authorized = supplied is not None and hmac.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )
        if not authorized:
            return PlainTextResponse("Unauthorized", status_code=401)
        return await func(*args, **kwargs)

    return wrapper
47
+
48
+
49
+ # ---------------- Proxy Server ---------------- #
50
class ProxyServer:
    """Proxy an MCP client over HTTP and expose workspace helper endpoints.

    On start-up a fresh workspace directory is created under ``base_dir``, an
    MCP client is built from ``mcp_config.yaml`` (located next to this module)
    and wrapped in a FastMCP proxy. The proxy additionally serves:

    * ``GET  /health``          -- unauthenticated liveness probe
    * ``POST /upload``          -- store uploaded files in the workspace
    * ``GET  /download``        -- fetch a file from the workspace
    * ``POST /compute_reward``  -- run the configured reward functions
    * ``POST /reset``           -- delete the workspace and re-exec the process
    """

    def __init__(self, base_dir="workspace", host="0.0.0.0", port=8080):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.host = host
        self.port = port
        self.workspace: Path | None = None
        self.client: Client | None = None
        self.proxy: FastMCP | None = None
        self.config_path = Path(__file__).parent / "mcp_config.yaml"
        # Strong reference to the pending restart task so it cannot be
        # garbage-collected before it runs.
        self._reset_task: asyncio.Task | None = None

    async def _setup(self):
        """Initialize workspace, MCP client, and proxy server."""
        self.workspace = setup_workspace(self.base_dir)
        config = load_config(self.config_path, self.workspace)

        self.client = Client(config)
        await self.client._connect()

        self.proxy = FastMCP.as_proxy(self.client, name="proxy")

        # Register endpoints
        self.proxy.custom_route("/health", methods=["GET"])(self._health)
        self.proxy.custom_route("/upload", methods=["POST"])(self._upload)
        self.proxy.custom_route("/download", methods=["GET"])(self._download)
        self.proxy.custom_route("/compute_reward", methods=["POST"])(self._compute_reward)
        self.proxy.custom_route("/reset", methods=["POST"])(self._reset)

    # ---------------- Endpoints ---------------- #
    async def _health(self, request: Request):
        """Liveness probe; always answers ``OK``."""
        return PlainTextResponse("OK")

    @require_auth
    async def _upload(self, request: Request):
        """Save every uploaded file from the multipart form into the workspace."""
        if not self.workspace:
            return PlainTextResponse("No workspace available", 500)
        form = await request.form()
        uploaded = []
        for file in form.values():
            if not (isinstance(file, UploadFile) and file.filename):
                continue
            # Keep only the final path component so a crafted filename such
            # as "../x" or "/etc/x" cannot write outside the workspace.
            safe_name = Path(file.filename).name
            if safe_name in ("", ".", ".."):
                continue
            dest = self.workspace / safe_name
            with open(dest, "wb") as f:
                f.write(await file.read())
            uploaded.append(safe_name)
        if not uploaded:
            return PlainTextResponse("No files uploaded", 400)
        return PlainTextResponse(f"Uploaded: {', '.join(uploaded)}")

    @require_auth
    async def _download(self, request: Request):
        """Return the workspace file named by the ``file_path`` query parameter."""
        if not self.workspace:
            return PlainTextResponse("No workspace", 500)
        file_path = request.query_params.get("file_path")
        if not file_path:
            return PlainTextResponse("file_path required", 400)
        full_path = self.workspace / file_path
        # Refuse anything that resolves outside the workspace ("..", absolute
        # paths, symlinks); answer 404 so the outside layout is not disclosed.
        if not full_path.resolve().is_relative_to(self.workspace.resolve()):
            return PlainTextResponse("File not found", 404)
        if not full_path.exists() or not full_path.is_file():
            return PlainTextResponse("File not found", 404)
        return FileResponse(str(full_path), filename=full_path.name)

    @require_auth
    async def _compute_reward(self, request: Request):
        """Run every reward function and return ``{function_name: score}`` as JSON.

        The JSON body must be an object containing ``completion`` and
        ``ground_truth``; all remaining keys are forwarded to the reward
        functions as keyword arguments. A reward function that raises is
        logged and scored as NaN, which is sent to the client as JSON
        ``null``.
        """
        import math  # function-scope import: only needed to detect NaN/inf

        try:
            data = await request.json()
        except Exception:
            return PlainTextResponse("Invalid JSON", 400)
        if not isinstance(data, dict):
            # A JSON array/scalar has no .get(); reject it like malformed JSON.
            return PlainTextResponse("Invalid JSON", 400)

        completion = data.get("completion")
        ground_truth = data.get("ground_truth")
        if completion is None or ground_truth is None:
            return PlainTextResponse("completion and ground_truth required", 400)

        extra_kwargs = {
            k: v for k, v in data.items() if k not in ("completion", "ground_truth")
        }
        results = {}
        for func in reward_functions or []:
            name = getattr(func, "__name__", str(func))
            try:
                score = func(
                    completion=completion,
                    ground_truth=ground_truth,
                    workspace=self.workspace,
                    mcp_client=self.client,
                    **extra_kwargs,
                )
            except Exception as e:
                score = float("nan")
                print(f"[WARN] reward {name} failed: {e}")
            # JSONResponse serialises with allow_nan=False, so a NaN/inf score
            # would raise and turn the whole request into a 500. Send null
            # for non-finite floats instead.
            if isinstance(score, float) and not math.isfinite(score):
                score = None
            results[name] = score
        return JSONResponse(results)

    @require_auth
    async def _reset(self, request: Request):
        """Reset server: clean workspace and restart process."""
        async def do_reset():
            # Short delay so the HTTP response below is scheduled before the
            # process image is replaced.
            await asyncio.sleep(0.1)
            print("[INFO] Resetting server...")
            sys.stdout.flush()
            os.execv(sys.executable, [sys.executable] + sys.argv)

        # Clean up workspace
        self.cleanup_workspace()

        # Keep a reference: the event loop only holds weak references to tasks.
        self._reset_task = asyncio.create_task(do_reset())
        return PlainTextResponse("Server reset scheduled")

    # ---------------- Public API ---------------- #
    def cleanup_workspace(self):
        """Delete the current workspace directory if it exists."""
        if self.workspace and self.workspace.exists():
            shutil.rmtree(self.workspace)

    async def start(self):
        """Set up the proxy and serve it over HTTP on ``host``/``port``."""
        await self._setup()
        if self.proxy:
            await self.proxy.run_async(transport="http", host=self.host, port=self.port)
158
+
159
+
160
+ # ---------------- Main ---------------- #
161
if __name__ == "__main__":
    # Script entry point: serve until interrupted, then remove the workspace.
    server = ProxyServer("../workspace")
    try:
        asyncio.run(server.start())
    except KeyboardInterrupt:
        print("\nShutting down gracefully...")
        server.cleanup_workspace()
@@ -0,0 +1,694 @@
1
+ import asyncio
2
+ import datetime
3
+ import aiohttp
4
+ import uuid
5
+ import tempfile
6
+ import shutil
7
+ from typing import Callable, List, Any, Optional, Dict
8
+ from pathlib import Path
9
+ from fastmcp import Client as FastMCPClient
10
+ from mcp.types import TextContent
11
+ from fastmcp.exceptions import ToolError
12
+ from mcp import Tool
13
+ import sky
14
+ import logging
15
+ import warnings
16
+
17
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
18
+
19
+ logging.basicConfig(level=logging.DEBUG, format='%(levelname)s:%(name)s:%(message)s')
20
+ logging.getLogger('httpx').setLevel(logging.CRITICAL)
21
+ logging.getLogger('aiohttp').setLevel(logging.CRITICAL)
22
+ logging.getLogger('mcp.client.streamable_http').setLevel(logging.CRITICAL)
23
+ logging.getLogger('urllib3').setLevel(logging.CRITICAL)
24
+ logging.getLogger('requests').setLevel(logging.CRITICAL)
25
+
26
+ from benchmax.envs.base_env import BaseEnv
27
+ from benchmax.envs.types import ToolDefinition
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ class RemoteSkypilotMcpEnv(BaseEnv):
32
+ """Remote MCP Environment for managing tool execution and rollouts with a remote MCP server.
33
+ Currently only supports running on Skypilot containers.
34
+ """
35
def __init__(
    self,
    workdir_path: str,
    num_nodes: int = 1,
    allowed_tools: Optional[List[str]] = None,
    output_parsers: Optional[Dict[str, Callable[[str], Any]]] = None,
    cluster_name: str = "benchmax-env-cluster",
    health_check_timeout: int = 300,  # 5 minutes
    health_check_interval: int = 5,  # 5 seconds
    launch_workers_on_init: bool = True,
    cloud: Optional[Any] = sky.Azure(),  # sky.Cloud instance
    cpus: str = "2+",
) -> None:
    """Store configuration, create the worker-pool state, optionally launch workers.

    Args:
        workdir_path: Directory whose contents are synced to the workers.
        num_nodes: Number of nodes requested for the cluster.
        allowed_tools: Tool names to expose; empty/None means no filtering.
        output_parsers: Optional per-tool output parser callables.
        cluster_name: Cluster name prefix; a random 8-hex-char suffix is added.
        health_check_timeout: Seconds to wait for a worker to become healthy.
        health_check_interval: Seconds between health-check attempts.
        launch_workers_on_init: Call ``launch_workers()`` at the end of init.
        cloud: Cloud passed to ``sky.Resources``.
        cpus: CPU requirement passed to ``sky.Resources``.
    """
    # NOTE(review): the `cloud` default is evaluated once, when the function
    # is defined, so every instance relying on the default shares that object.
    super().__init__()

    # --- plain configuration ---
    self._workdir_path = Path(workdir_path)
    self._num_nodes = num_nodes
    self._allowed_tools = allowed_tools if allowed_tools else []
    self._output_parsers: Dict[str, Callable[[str], Any]] = (
        output_parsers if output_parsers else {}
    )
    self._cluster_name = cluster_name
    self._health_check_timeout = health_check_timeout
    self._health_check_interval = health_check_interval
    self._cloud = cloud
    self._cpus = cpus
    self._ports = ["8080"]  # MCP server port

    # Random token the workers require in the Authorization header.
    self._api_token = uuid.uuid4().hex

    # Cluster name made unique with a short random suffix.
    self._full_cluster_name = f"{self._cluster_name}-{uuid.uuid4().hex[:8]}"

    # Temporary sync directory; populated when workers are launched.
    self._sync_dir: Optional[Path] = None

    # --- worker bookkeeping ---
    self._client_pool: Dict[str, FastMCPClient] = {}
    self._available_workers: asyncio.Queue[str] = asyncio.Queue()
    self._rollout_to_worker: Dict[str, str] = {}
    self._worker_init_tasks: List[asyncio.Task] = []

    # Shared HTTP session for health/reset requests.
    self._http_session = aiohttp.ClientSession()

    # Tool definitions cache, filled lazily.
    self._tool_definitions: Optional[List[ToolDefinition]] = None

    # Guard flag checked by launch_workers().
    self.launch_workers_started = False
    if launch_workers_on_init:
        self.launch_workers()
87
+
88
def _setup_sync_directory(self) -> Path:
    """Create a temporary sync directory and populate it for the workers.

    The directory receives ``proxy_server.py`` (taken from the directory of
    this module) plus every file and sub-directory of ``self._workdir_path``.
    After copying, ``reward_func.py``, ``setup.sh`` and ``mcp_config.yaml``
    must be present.

    Returns:
        Path of the populated sync directory (also stored in ``self._sync_dir``).

    Raises:
        FileNotFoundError: ``proxy_server.py``, the workdir, or one of the
            required files is missing.
        ValueError: the workdir path is not a directory.

    On any failure the partially built directory is removed and
    ``self._sync_dir`` is reset to ``None`` before the error propagates.
    """
    self._sync_dir = Path(tempfile.mkdtemp(prefix="benchmax_skypilot_"))
    logger.info(f"Created sync directory: {self._sync_dir}")

    try:
        # proxy_server.py ships alongside this module.
        proxy_server_path = Path(__file__).parent / "proxy_server.py"
        if not proxy_server_path.exists():
            raise FileNotFoundError(f"proxy_server.py not found at {proxy_server_path}")

        shutil.copy2(proxy_server_path, self._sync_dir / "proxy_server.py")
        logger.debug("Copied proxy_server.py to sync directory")

        if not self._workdir_path.exists():
            raise FileNotFoundError(f"Workdir path does not exist: {self._workdir_path}")
        if not self._workdir_path.is_dir():
            raise ValueError(f"Workdir path is not a directory: {self._workdir_path}")

        # Copy the workdir's top-level entries (directories recursively).
        for item in self._workdir_path.iterdir():
            if item.is_file():
                shutil.copy2(item, self._sync_dir / item.name)
                logger.debug(f"Copied file {item.name} to sync directory")
            elif item.is_dir():
                shutil.copytree(item, self._sync_dir / item.name)
                logger.debug(f"Copied directory {item.name} to sync directory")

        # Files the remote setup/run commands depend on.
        for required in ("reward_func.py", "setup.sh", "mcp_config.yaml"):
            if not (self._sync_dir / required).exists():
                raise FileNotFoundError(
                    f"{required} not found in workdir: {self._workdir_path}"
                )

        logger.info("Validated required files in sync directory")
        return self._sync_dir

    except Exception:
        # Remove the half-populated directory, then re-raise with the
        # original traceback intact.
        if self._sync_dir and self._sync_dir.exists():
            shutil.rmtree(self._sync_dir, ignore_errors=True)
        self._sync_dir = None
        raise
145
+
146
+ def _cleanup_sync_directory(self) -> None:
147
+ """Clean up the temporary sync directory."""
148
+ if self._sync_dir and self._sync_dir.exists():
149
+ try:
150
+ shutil.rmtree(self._sync_dir)
151
+ logger.info(f"Cleaned up sync directory: {self._sync_dir}")
152
+ except Exception as e:
153
+ logger.warning(f"Failed to clean up sync directory {self._sync_dir}: {e}")
154
+ finally:
155
+ self._sync_dir = None
156
+
157
async def _init_worker(self, worker_ip: str) -> None:
    """Bring one worker online: health check, connect a client, enqueue it.

    Any failure is logged and swallowed so that the remaining workers can
    still be initialised.
    """
    try:
        # Block until the worker's /health endpoint answers 200.
        await self._wait_for_worker_health(worker_ip)

        # Connect a FastMCP client to the worker's MCP endpoint.
        worker_client = FastMCPClient(f"http://{worker_ip}:8080/mcp/")
        await worker_client._connect()
        self._client_pool[worker_ip] = worker_client

        # Mark the worker as available for rollouts.
        await self._available_workers.put(worker_ip)
        logger.debug(f"Worker {worker_ip} initialized and added to pool")
    except Exception as e:
        # Deliberately not re-raised: one bad worker must not stop the rest.
        logger.error(f"Failed to initialize worker {worker_ip}: {e}")
176
+
177
+ async def _wait_for_worker_health(self, worker_ip: str) -> None:
178
+ """Wait for worker to pass health check."""
179
+ health_url = f"http://{worker_ip}:8080/health"
180
+ start_time = asyncio.get_event_loop().time()
181
+
182
+ while True:
183
+ elapsed = asyncio.get_event_loop().time() - start_time
184
+ if elapsed > self._health_check_timeout:
185
+ raise TimeoutError(f"Health check timeout for worker {worker_ip}")
186
+
187
+ try:
188
+ timeout = aiohttp.ClientTimeout(total=5)
189
+ async with self._http_session.get(health_url, timeout=timeout) as response:
190
+ if response.status == 200:
191
+ logger.debug(f"Worker {worker_ip} is healthy")
192
+ return
193
+ else:
194
+ logger.debug(f"Worker {worker_ip} health check returned {response.status}")
195
+ except (aiohttp.ClientError, asyncio.TimeoutError) as e:
196
+ logger.debug(f"Health check failed for {worker_ip}: {e}")
197
+
198
+ await asyncio.sleep(self._health_check_interval)
199
+
200
+ async def _get_available_worker(self) -> str:
201
+ """Get an available worker, blocking until one is ready."""
202
+ return await self._available_workers.get()
203
+
204
+ async def _release_worker(self, worker_ip: str) -> None:
205
+ """Return a worker to the available pool."""
206
+ await self._available_workers.put(worker_ip)
207
+
208
async def _call_worker_reset(self, worker_ip: str) -> None:
    """POST to the worker's ``/reset`` endpoint, authenticated with the API token.

    Raises:
        RuntimeError: the worker answered with a non-200 status, or the
            request itself failed (the ``aiohttp.ClientError`` is chained as
            the cause).
    """
    reset_url = f"http://{worker_ip}:8080/reset"
    headers = {"Authorization": self._api_token}

    try:
        async with self._http_session.post(reset_url, headers=headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(
                    f"Reset failed for worker {worker_ip}: {response.status} - {error_text}"
                )
            logger.info(f"Reset successful for worker {worker_ip}")
    except aiohttp.ClientError as e:
        # Chain explicitly so the transport error stays visible as __cause__.
        raise RuntimeError(f"Reset request failed for worker {worker_ip}: {e}") from e
222
+
223
async def add_worker_back_once_available(self, worker_ip: str) -> None:
    """Re-initialise a worker after a reset and report whether it rejoined the pool.

    Never raises: on failure the worker simply stays out of the pool.
    """
    await asyncio.sleep(1)  # brief delay before starting health checks

    try:
        await self._init_worker(worker_ip)
    except Exception as e:
        logger.error(f"Failed to add worker {worker_ip} back to pool: {e}")
        # Don't re-raise - worker remains out of pool
        return

    # _init_worker logs and swallows its own failures, so returning normally
    # does not prove success. It registers the client in _client_pool only
    # once the worker is healthy and connected, so use that as the signal.
    if worker_ip in self._client_pool:
        logger.info(f"Worker {worker_ip} added back to available pool after reset")
    else:
        logger.warning(f"Worker {worker_ip} was not re-initialized; it remains out of the pool")
233
+
234
+ # Function is expected to be called at the end of compute_reward
235
async def _cleanup_rollout(self, rollout_id: str) -> None:
    """Release a rollout's worker: reset it, drop its client, re-add it when healthy.

    Reset and disconnect failures are logged, not raised, so the worker is
    always scheduled to rejoin the pool.

    Raises:
        ValueError: ``rollout_id`` has no worker assigned.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    logger.debug(f"Cleaning up rollout {rollout_id} on worker {worker_ip}")

    del self._rollout_to_worker[rollout_id]
    try:
        # Call reset endpoint
        await self._call_worker_reset(worker_ip)
    except Exception as e:
        logger.error(f"Failed to reset worker {worker_ip} for rollout {rollout_id}: {e}")

    # Drop the client from the pool first, then disconnect it. A disconnect
    # error must not abort cleanup, otherwise the worker would never be
    # scheduled to rejoin the pool.
    client = self._client_pool.pop(worker_ip, None)
    if client:
        try:
            await client._disconnect()
        except Exception as e:
            logger.warning(f"Failed to disconnect client for worker {worker_ip}: {e}")

    # Start background task to add worker back once healthy. The event loop
    # keeps only weak references to tasks, so hold a strong one until done.
    task = asyncio.create_task(self.add_worker_back_once_available(worker_ip))
    self._worker_init_tasks.append(task)

    def _forget(done_task: asyncio.Task) -> None:
        # The list may have been cleared or replaced in the meantime.
        if done_task in self._worker_init_tasks:
            self._worker_init_tasks.remove(done_task)

    task.add_done_callback(_forget)
258
+
259
def _convert_and_filter_tools(self, tools: List[Tool]) -> List[ToolDefinition]:
    """Turn ``Tool`` objects into ``ToolDefinition`` objects and apply the allow-list.

    A missing tool description is replaced by an empty string. When
    ``self._allowed_tools`` is falsy, every converted tool is returned;
    otherwise only tools whose name appears in it are kept.
    """
    converted = []
    for tool in tools:
        converted.append(
            ToolDefinition(
                name=tool.name,
                description=tool.description or "",
                input_schema=tool.inputSchema,
            )
        )

    if self._allowed_tools:
        return [definition for definition in converted if definition.name in self._allowed_tools]

    return converted
274
+
275
+ # ---- Public API Methods ----
276
def launch_workers(self) -> None:
    """Launch the SkyPilot cluster and start initialising each worker.

    Builds a ``sky.Task`` from the synced work directory, launches it with
    ``sky.launch``, collects the external IP of every node and schedules
    ``self._init_worker`` for each one as an asyncio task stored in
    ``self._worker_init_tasks``.

    This method is synchronous but must be called while an asyncio event loop
    is running, because the initialisation tasks are created with
    ``asyncio.create_task``.

    Raises:
        RuntimeError: If workers were already launched, if no event loop is
            running, if SkyPilot returns no handle, or if the handle reports
            no worker IPs.
    """
    if self.launch_workers_started:
        raise RuntimeError("Workers have already been launched.")

    # Fail fast: without a running loop asyncio.create_task() below would only
    # raise after the (slow, billable) cluster launch had already happened.
    asyncio.get_running_loop()

    self.launch_workers_started = True

    try:
        # Setup sync directory and copy files
        sync_dir = self._setup_sync_directory()

        # Create the task programmatically
        sky_task = sky.Task(
            name='fastmcp',
            setup='pip install fastmcp~=2.10.0\npip install pyyaml\nsh setup.sh',
            run='python proxy_server.py',
            workdir=str(sync_dir),
            num_nodes=self._num_nodes,
        )

        # Set the resources
        sky_task.set_resources(
            sky.Resources(
                cloud=self._cloud,
                cpus=self._cpus,
                ports=self._ports,
            )
        )

        # Update environment variables with API token
        sky_task.update_envs({"API_TOKEN": self._api_token})

        # Launch the cluster
        _, handle = sky.launch(
            task=sky_task,
            cluster_name=self._full_cluster_name,
            detach_run=True,
            detach_setup=True,
            retry_until_up=True,
        )

        if handle is None:
            raise RuntimeError("Failed to launch SkyPilot task.")

        # The IP list may be missing or empty; iterating None would raise an
        # unhelpful TypeError, so report it explicitly.
        ip_pairs = handle.stable_internal_external_ips
        if not ip_pairs:
            raise RuntimeError("SkyPilot cluster launched but reported no worker IPs.")

        worker_ips = [external_ip for _, external_ip in ip_pairs]
        logger.info(f"Launched workers with IPs: {worker_ips}")

        # Start background initialization for each worker
        for worker_ip in worker_ips:
            init_task = asyncio.create_task(self._init_worker(worker_ip))
            self._worker_init_tasks.append(init_task)

    except Exception:
        # Clean up sync directory if launch fails
        self._cleanup_sync_directory()
        raise
334
+
335
async def shutdown(self) -> None:
    """Stop background tasks, close clients, and tear down the cluster.

    Order of operations: cancel and await the worker initialisation tasks,
    close every client in ``self._client_pool``, close the HTTP session, call
    ``sky.down`` for the cluster, and finally remove the sync directory.

    The cluster teardown is attempted even when an earlier step raises, so a
    failure while closing connections does not leave the cluster running. The
    sync directory is always cleaned up. An exception from the earlier steps
    still propagates to the caller afterwards.
    """
    try:
        try:
            # Snapshot the list so tasks registered or removed while we await
            # cannot change what is being gathered.
            init_tasks = list(self._worker_init_tasks)

            # Cancel worker initialization tasks
            for task in init_tasks:
                if not task.done():
                    task.cancel()

            if init_tasks:
                results = await asyncio.gather(*init_tasks, return_exceptions=True)
                for i, result in enumerate(results):
                    if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                        logger.error(f"Error in worker init task {i}: {result}")

            # Close FastMCP clients
            if self._client_pool:
                # Pair each result with its worker from one snapshot instead of
                # re-listing the dict keys after the await.
                pool_snapshot = list(self._client_pool.items())
                close_tasks = [client.close() for _, client in pool_snapshot]
                results = await asyncio.gather(*close_tasks, return_exceptions=True)

                for (worker_ip, _), result in zip(pool_snapshot, results):
                    if isinstance(result, Exception):
                        logger.error(f"Error closing FastMCP client for {worker_ip}: {result}")

            # Close HTTP session
            await self._http_session.close()

        finally:
            # Tear down SkyPilot cluster, even if the steps above failed.
            try:
                sky.down(cluster_name=self._full_cluster_name)
            except Exception as e:
                logger.error(f"Error tearing down SkyPilot cluster: {e}")

    finally:
        # Always clean up sync directory
        self._cleanup_sync_directory()
371
+
372
async def list_tools(self) -> List[ToolDefinition]:
    """Return the tool definitions, fetching them from a worker on first use.

    The result is cached in ``self._tool_definitions``; later calls return the
    cached list without contacting a worker. The worker borrowed for the fetch
    is released again whether or not the fetch succeeds.
    """
    cached = self._tool_definitions
    if cached is not None:
        return cached

    # Borrow any available worker just long enough to ask for its tools.
    worker_ip = await self._get_available_worker()
    try:
        raw_tools = await self._client_pool[worker_ip].list_tools()
        definitions = self._convert_and_filter_tools(raw_tools)
        self._tool_definitions = definitions
        return definitions
    finally:
        await self._release_worker(worker_ip)
386
+
387
async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
    """Reserve a worker for ``rollout_id``.

    Awaits ``self._get_available_worker()`` and records the returned worker in
    ``self._rollout_to_worker``. ``rollout_args`` is accepted but not used here.

    Raises:
        ValueError: If ``rollout_id`` already has a worker assigned.
    """
    already_assigned = rollout_id in self._rollout_to_worker
    if already_assigned:
        raise ValueError(f"Rollout {rollout_id} is already initialized")

    # Waits until a worker can be handed out.
    assigned_ip = await self._get_available_worker()

    self._rollout_to_worker[rollout_id] = assigned_ip
    logger.info(f"Rollout {rollout_id} assigned to worker {assigned_ip}")
398
+
399
async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Optional[str]:
    """Call ``tool_name`` on the worker assigned to ``rollout_id``.

    The call uses a 30 second timeout. Only ``TextContent`` items of the
    response are kept; their texts are joined with newlines. If an output
    parser is registered for the tool in ``self._output_parsers``, its result
    is returned instead of the raw text.

    Returns:
        The (optionally parsed) text, or ``None`` when the call fails; the
        failure is logged.

    Raises:
        ValueError: If ``rollout_id`` has no worker assigned.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized. Call init_rollout() first.")

    client = self._client_pool[self._rollout_to_worker[rollout_id]]

    try:
        call_result = await client.call_tool(tool_name, tool_args, timeout=datetime.timedelta(seconds=30))

        # Non-text content is ignored for now.
        text_parts = [item.text for item in call_result.content if isinstance(item, TextContent)]
        combined_text = "\n".join(text_parts)

        if tool_name in self._output_parsers:
            return self._output_parsers[tool_name](combined_text)
        return combined_text

    except ToolError as e:
        logger.error(f"[ERROR] Tool call returned error: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"[ERROR] Tool call failed: {str(e)}")
        return None
430
+
431
async def copy_to_workspace(
    self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
) -> None:
    """Upload a local file to the worker assigned to ``rollout_id``.

    The file is sent as multipart field ``file`` in a POST to the worker's
    ``/upload`` endpoint on port 8080, with the API token in the
    ``Authorization`` header.

    Args:
        rollout_id: Rollout whose worker receives the file.
        src_path: Local file to upload.
        dst_filename: Name to upload under; defaults to ``src_path.name``.

    Raises:
        ValueError: If ``rollout_id`` has no worker assigned.
        RuntimeError: If the worker answers with a status other than 200.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    upload_url = f"http://{worker_ip}:8080/upload"
    headers = {"Authorization": self._api_token}

    filename = dst_filename or src_path.name

    try:
        with open(src_path, 'rb') as file_obj:
            form = aiohttp.FormData()
            form.add_field('file', file_obj, filename=filename)

            async with self._http_session.post(upload_url, headers=headers, data=form) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Upload failed: {response.status} - {error_text}")
                logger.info(f"File {src_path} uploaded as {filename} for rollout {rollout_id}")
    except Exception as e:
        # Log for context, then let the caller see the original error.
        logger.error(f"Failed to copy {src_path} to workspace for rollout {rollout_id}: {e}")
        raise
459
+
460
async def copy_content_to_workspace(
    self, rollout_id: str, src_content: str | bytes, dst_filename: str, encoding: str = "utf-8"
) -> None:
    """Upload in-memory content to the worker assigned to ``rollout_id``.

    Text is encoded with ``encoding`` and sent as ``text/plain``; bytes are
    sent unchanged as ``application/octet-stream``. The upload is a multipart
    POST to the worker's ``/upload`` endpoint on port 8080.

    Args:
        rollout_id: The rollout identifier.
        src_content: The content to upload (str or bytes).
        dst_filename: The filename to assign in the workspace.
        encoding: Encoding to use if src_content is str. Defaults to UTF-8.

    Raises:
        ValueError: If ``rollout_id`` has no worker assigned.
        RuntimeError: If the worker answers with a status other than 200.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    upload_url = f"http://{worker_ip}:8080/upload"
    headers = {"Authorization": self._api_token}

    try:
        is_text = isinstance(src_content, str)
        file_bytes = src_content.encode(encoding) if is_text else src_content
        content_type = "text/plain" if is_text else "application/octet-stream"

        form = aiohttp.FormData()
        form.add_field("file", file_bytes, filename=dst_filename, content_type=content_type)

        async with self._http_session.post(upload_url, headers=headers, data=form) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Upload failed: {response.status} - {error_text}")
            logger.info(f"Content uploaded as {dst_filename} for rollout {rollout_id}")
    except Exception as e:
        # Log for context, then let the caller see the original error.
        logger.error(f"Failed to upload content to workspace for rollout {rollout_id}: {e}")
        raise
503
+
504
+
505
async def copy_from_workspace(
    self, rollout_id: str, src_filename: str, dst_path: Path
) -> None:
    """Download a file from the worker assigned to ``rollout_id``.

    Issues a GET to the worker's ``/download`` endpoint on port 8080 with
    ``file_path=src_filename`` and streams the body to ``dst_path`` in
    8192-byte chunks. Missing parent directories of ``dst_path`` are created.

    Raises:
        ValueError: If ``rollout_id`` has no worker assigned.
        RuntimeError: If the worker answers with a status other than 200.
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    download_url = f"http://{worker_ip}:8080/download"
    headers = {"Authorization": self._api_token}
    params = {"file_path": src_filename}

    try:
        async with self._http_session.get(download_url, headers=headers, params=params) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Download failed: {response.status} - {error_text}")

            # Make sure the target directory exists before writing.
            dst_path.parent.mkdir(parents=True, exist_ok=True)

            with open(dst_path, 'wb') as out_file:
                async for chunk in response.content.iter_chunked(8192):
                    out_file.write(chunk)

            logger.info(f"File {src_filename} downloaded from rollout {rollout_id} to {dst_path}")
    except Exception as e:
        # Log for context, then let the caller see the original error.
        logger.error(f"Failed to copy {src_filename} from workspace for rollout {rollout_id}: {e}")
        raise
535
+
536
async def compute_reward(
    self,
    rollout_id: str,
    completion: str,
    ground_truth: Any,
    **kwargs: Any
) -> Dict[str, float]:
    """Compute rewards on the rollout's worker, then clean the rollout up.

    POSTs ``completion``, ``ground_truth`` and any extra keyword arguments as
    JSON to the worker's ``/compute_reward`` endpoint on port 8080 and returns
    the decoded JSON response. Once a worker is assigned, the rollout is
    cleaned up via ``self._cleanup_rollout`` whether the request succeeds or
    fails.

    Returns:
        Dict mapping reward function names to their computed scores.

    Raises:
        ValueError: If ``rollout_id`` has no worker assigned.
        RuntimeError: If the worker answers with a status other than 200, or
            if the HTTP request itself fails (chained to the underlying
            ``aiohttp.ClientError``).
    """
    if rollout_id not in self._rollout_to_worker:
        raise ValueError(f"Rollout {rollout_id} is not initialized")

    worker_ip = self._rollout_to_worker[rollout_id]
    compute_reward_url = f"http://{worker_ip}:8080/compute_reward"
    headers = {
        "Authorization": self._api_token,
        "Content-Type": "application/json"
    }

    # Prepare request payload
    payload = {
        "completion": completion,
        "ground_truth": ground_truth,
        **kwargs
    }

    try:
        async with self._http_session.post(
            compute_reward_url,
            headers=headers,
            json=payload
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(f"Reward computation failed: {response.status} - {error_text}")

            result = await response.json()
            logger.debug(f"Reward computed successfully for rollout {rollout_id}")
            return result
    except aiohttp.ClientError as e:
        logger.error(f"Failed to compute reward for rollout {rollout_id}: {e}")
        # Chain explicitly so the transport error is reported as the cause.
        raise RuntimeError(f"Reward computation request failed: {e}") from e
    except Exception as e:
        logger.error(f"Unexpected error computing reward for rollout {rollout_id}: {e}")
        raise
    finally:
        # The worker is returned to the pool regardless of the outcome.
        await self._cleanup_rollout(rollout_id)
585
+
586
+
587
async def run_single_rollout(env: "RemoteSkypilotMcpEnv", rollout_id: str, expression: str, expected: str, tmp_root: Path):
    """Run one end-to-end smoke-test rollout against ``env``.

    Stages: init -> upload several payloads -> upload a local file ->
    download and compare everything -> call the ``calculate`` tool ->
    compute the reward. Comparison results are printed only; a mismatch does
    not abort the rollout.

    Args:
        env: Environment exposing the rollout API used below.
        rollout_id: Identifier for this rollout; also names its temp directory.
        expression: Expression passed to the ``calculate`` tool.
        expected: Ground truth handed to ``compute_reward``.
        tmp_root: Directory under which a per-rollout scratch dir is created.

    Returns:
        Tuple ``(rollout_id, reward)``.
    """
    import json  # local import: only this smoke test needs it

    print(f"Starting rollout: {rollout_id}")

    # Create rollout-specific tmp dir
    rollout_tmp = tmp_root / rollout_id
    rollout_tmp.mkdir(parents=True, exist_ok=True)

    # Stage 1: Initialize rollout
    await env.init_rollout(rollout_id)
    print(f"Initialized rollout: {rollout_id}")

    # Stage 1.5a: Upload various content types
    test_contents = {
        "utf8_text.txt": f"# UTF-8 text for {rollout_id}\nExpression: {expression}\n",
        "latin1_text.txt": "Café Münster".encode("latin-1"),
        # json.dumps keeps the fixture valid JSON: the expression ("1 + 1") is
        # not a JSON value and must be quoted, and the id must be escaped.
        "json_data.json": json.dumps({"rollout": rollout_id, "value": expression}),
        "binary_data.bin": b"\x00\x01\x02\x03\xFF",
        "unicode_text.txt": "你好, мир, hello 🌍",
    }

    for filename, content in test_contents.items():
        await env.copy_content_to_workspace(rollout_id, content, filename)
        print(f"Uploaded {filename} for {rollout_id}")

    # Stage 1.5b: Test file-based copy
    tmp_path = rollout_tmp / "local_file.txt"
    tmp_path.write_text(f"Temporary file for {rollout_id}, expression={expression}\n", encoding="utf-8")
    await env.copy_to_workspace(rollout_id, tmp_path, dst_filename=f"copied_{rollout_id}.txt")
    print(f"Copied file {tmp_path} to workspace for {rollout_id}")

    # Stage 1.6: Download and verify content
    for filename, original_content in test_contents.items():
        download_path = rollout_tmp / f"dl_{filename}"
        await env.copy_from_workspace(rollout_id, filename, download_path)

        downloaded_bytes = download_path.read_bytes()
        # Text payloads are compared as UTF-8 bytes.
        if isinstance(original_content, str):
            original_bytes = original_content.encode("utf-8")
        else:
            original_bytes = original_content

        if downloaded_bytes == original_bytes:
            print(f"Verified {filename} ✅")
        else:
            print(f"Mismatch in {filename}! ❌")

    # Verify copied file
    print(f"Verifying copied file for {rollout_id}")
    copied_dl = rollout_tmp / f"dl_copied_{rollout_id}.txt"
    await env.copy_from_workspace(rollout_id, f"copied_{rollout_id}.txt", copied_dl)
    if copied_dl.read_text(encoding="utf-8") == tmp_path.read_text(encoding="utf-8"):
        print("Verified copied file ✅")
    else:
        print("Mismatch in copied file ❌")

    # Stage 2: Run tool
    tool_result = await env.run_tool(rollout_id, "calculate", expression=expression)
    print(f"Tool result for {rollout_id}: {tool_result}")

    # Stage 3: Compute reward
    reward = await env.compute_reward(rollout_id, completion=str(tool_result), ground_truth=expected)
    print(f"Computed reward for {rollout_id}: {reward}")

    return rollout_id, reward
652
+
653
+
654
async def main():
    """Smoke-test driver: run three concurrent rollouts, then tear everything down.

    Builds a two-node environment, lists its tools, runs the rollouts with
    ``asyncio.gather`` (exceptions are collected and printed, not raised), and
    finally shuts the environment down and deletes ``./tmp``.
    """
    env = RemoteSkypilotMcpEnv(
        workdir_path="benchmax/envs/skypilot/workdir",
        num_nodes=2,
        cluster_name="test-cluster",
        cloud=sky.Azure(),
        cpus="2+",
    )

    tmp_root = Path("./tmp")
    tmp_root.mkdir(exist_ok=True)

    try:
        tools = await env.list_tools()
        print(f"Available tools: {[tool.name for tool in tools]}")

        # Rollout i evaluates "(i+1) + (i+1)"; kept small for debugging.
        rollout_tasks = [
            run_single_rollout(
                env,
                f"test-rollout-{i:03d}",
                f"{i + 1} + {i + 1}",
                str((i + 1) + (i + 1)),
                tmp_root,
            )
            for i in range(3)
        ]

        print("Starting concurrent rollouts...")
        results = await asyncio.gather(*rollout_tasks, return_exceptions=True)

        print("Rollout results:")
        for outcome in results:
            print(outcome)

    finally:
        await env.shutdown()
        # Cleanup tmp dir at the very end
        shutil.rmtree(tmp_root, ignore_errors=True)
        print("Cleaned up temporary files.")


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,5 @@
1
+ mcpServers:
2
+ server-name:
3
+ command: uvx
4
+ args:
5
+ - mcp-server-calculator
@@ -0,0 +1,16 @@
1
+ from typing import Any
2
+ from fastmcp import Client
3
+
4
def reward_function(
    completion: str,
    ground_truth: Any,
    workspace: str,
    mcp_client: "Client",
    **kwargs: Any
) -> float:
    """Compute the reward for a given model completion.

    Returns 1.0 when the whitespace-stripped completion equals the
    whitespace-stripped ground truth, else 0.0. ``ground_truth`` is converted
    with ``str()`` first, so non-string values (e.g. a number decoded from
    JSON) are compared by their text form instead of raising
    ``AttributeError``.

    Args:
        completion: Text produced by the model.
        ground_truth: Expected answer; any type.
        workspace: Workspace of the rollout; only printed here.
        mcp_client: Unused by this reward function.
        **kwargs: Ignored.
    """
    print(f"Workspace for reward function: {workspace}")
    return 1.0 if completion.strip() == str(ground_truth).strip() else 0.0
14
+
15
+
16
+ reward_functions = [reward_function]
@@ -0,0 +1,3 @@
1
+ #!/bin/bash
2
+ # Install uv for our calculator mcp
3
+ curl -LsSf https://astral.sh/uv/0.8.14/install.sh | sh
@@ -7,7 +7,6 @@ class StandardizedExample(TypedDict):
7
7
  ground_truth: Any
8
8
  init_rollout_args: Optional[Dict[str, Any]]
9
9
 
10
-
11
10
  @dataclass
12
11
  class ToolDefinition:
13
12
  """Definition of a tool's interface"""
@@ -19,10 +18,9 @@ class RewardFunction(Protocol):
19
18
  """Function that evaluates model interactions"""
20
19
  def __call__(
21
20
  self,
22
- prompt: str, # Input prompt given to the model
23
21
  completion: str, # Model's generated completion/response
24
22
  ground_truth: Any, # Expected/correct output to compare against
25
- workspace: Path, # Path to rollout's workspace with tool outputs
23
+ workspace: str, # Current workspace of the rollout
26
24
  **kwargs: Any # Additional context for reward computation
27
25
  ) -> float: # Reward score (typically in range [0, 1])
28
26
  ...
@@ -220,6 +220,18 @@ class WikipediaEnv(BaseEnv):
220
220
 
221
221
  def get_rollout_workspace(self, rollout_id: str) -> Path:
222
222
  return super().get_rollout_workspace(rollout_id)
223
+
224
+ def copy_to_workspace(
225
+ self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
226
+ ) -> None:
227
+ """Copy a file to the workspace for a specific rollout."""
228
+ pass
229
+
230
+ def copy_from_workspace(
231
+ self, rollout_id: str, src_filename: str, dst_path: Path
232
+ ) -> None:
233
+ """Copy a file from the workspace for a specific rollout."""
234
+ pass
223
235
 
224
236
  if __name__ == "__main__":
225
237
  # Example usage
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "benchmax"
3
- version = "0.1.1.dev4"
3
+ version = "0.1.1.dev6"
4
4
  description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5
5
  authors = ["cgft.io"]
6
6
  readme = "README.md"
@@ -12,6 +12,8 @@ packages = [
12
12
  python = ">=3.11,<3.13"
13
13
  fastmcp = "~2.10.0"
14
14
 
15
+ skypilot = { version = "0.8.1", optional = true }
16
+
15
17
  verl-cgft-fork = { version = "0.5.0.dev2", optional = true }
16
18
  sglang = { version = "0.4.9", optional = true, extras = ["all"] }
17
19
  verifiers = { version = "^0.1.1", optional = true, extras = ["train"] }
@@ -30,6 +32,9 @@ pytest = "^8.4.1"
30
32
  verifiers = ["verifiers"]
31
33
  verl = ["verl-cgft-fork", "sglang"]
32
34
 
35
+ # Hosting-specific
36
+ skypilot = ["skypilot"]
37
+
33
38
  # Environment-specific
34
39
  excel-linux = ["openpyxl"]
35
40
  excel = ["openpyxl", "xlwings"]
File without changes