benchmax 0.1.2.dev11__tar.gz → 0.1.2.dev13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/PKG-INFO +9 -5
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/pyproject.toml +18 -9
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/skyrl/benchmax_data_process.py +9 -4
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/loader.py +18 -11
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/base_env.py +10 -13
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/crm_env.py +13 -3
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/data_utils.py +2 -1
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/excel_env.py +13 -6
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/math/math_env.py +13 -3
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/parallel_mcp_env.py +6 -1
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/__init__.py +7 -2
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +18 -7
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/proxy_server.py +11 -5
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/server_pool.py +6 -1
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/utils.py +6 -1
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/PKG-INFO +9 -5
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/requires.txt +7 -1
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/LICENSE +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/README.md +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/setup.cfg +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/benchmax_wrapper.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/bundler.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/errors.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/payload.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/validator.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/excel_utils.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/math/workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/utils.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/types.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/wikipedia/utils.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/wikipedia/wiki_env.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/prompts/__init__.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/prompts/tools.py +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/SOURCES.txt +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/dependency_links.txt +0 -0
- {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/top_level.txt +0 -0
|
@@ -1,20 +1,24 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.2.dev11
|
|
3
|
+
Version: 0.1.2.dev13
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
6
|
Classifier: Programming Language :: Python :: 3
|
|
7
7
|
Classifier: Operating System :: OS Independent
|
|
8
|
-
Requires-Python:
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
9
|
Description-Content-Type: text/markdown
|
|
10
10
|
License-File: LICENSE
|
|
11
11
|
Requires-Dist: aiohttp>=3.13.1
|
|
12
12
|
Requires-Dist: asyncio>=4.0.0
|
|
13
13
|
Requires-Dist: cloudpickle>=3.0.0
|
|
14
14
|
Requires-Dist: datasets>=4.0.0
|
|
15
|
-
|
|
16
|
-
Requires-Dist: pyjwt>=2.10.1
|
|
17
|
-
Requires-Dist: skypilot~=0.8.1
|
|
15
|
+
Provides-Extra: mcp
|
|
16
|
+
Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
|
|
17
|
+
Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
|
|
18
|
+
Provides-Extra: skypilot
|
|
19
|
+
Requires-Dist: skypilot[aws,gcp]~=0.8.1; extra == "skypilot"
|
|
20
|
+
Requires-Dist: pip>=25.3; extra == "skypilot"
|
|
21
|
+
Requires-Dist: msrestazure>=0.6.4.post1; extra == "skypilot"
|
|
18
22
|
Dynamic: license-file
|
|
19
23
|
|
|
20
24
|
<picture>
|
|
@@ -1,18 +1,15 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "benchmax"
|
|
3
|
-
version = "0.1.2.dev11"
|
|
3
|
+
version = "0.1.2.dev13"
|
|
4
4
|
description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [{ name = "cgft.io" }]
|
|
7
|
-
requires-python = "
|
|
7
|
+
requires-python = ">=3.12"
|
|
8
8
|
dependencies = [
|
|
9
9
|
"aiohttp>=3.13.1",
|
|
10
10
|
"asyncio>=4.0.0",
|
|
11
11
|
"cloudpickle>=3.0.0",
|
|
12
12
|
"datasets>=4.0.0",
|
|
13
|
-
"fastmcp~=2.12.0",
|
|
14
|
-
"pyjwt>=2.10.1",
|
|
15
|
-
"skypilot~=0.8.1",
|
|
16
13
|
]
|
|
17
14
|
classifiers = [
|
|
18
15
|
"Programming Language :: Python :: 3",
|
|
@@ -26,6 +23,17 @@ build-backend = "setuptools.build_meta"
|
|
|
26
23
|
[tool.setuptools.packages.find]
|
|
27
24
|
where = ["src"]
|
|
28
25
|
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
mcp = [
|
|
28
|
+
"fastmcp~=2.12.0",
|
|
29
|
+
"pyjwt>=2.10.1",
|
|
30
|
+
]
|
|
31
|
+
skypilot = [
|
|
32
|
+
"skypilot[aws,gcp]~=0.8.1",
|
|
33
|
+
"pip>=25.3",
|
|
34
|
+
"msrestazure>=0.6.4.post1",
|
|
35
|
+
]
|
|
36
|
+
|
|
29
37
|
[dependency-groups]
|
|
30
38
|
dev = [
|
|
31
39
|
"pytest>=8.4.2",
|
|
@@ -33,8 +41,12 @@ dev = [
|
|
|
33
41
|
"python-dotenv>=1.2.1",
|
|
34
42
|
"ruff>=0.14.2",
|
|
35
43
|
]
|
|
44
|
+
mcp = [
|
|
45
|
+
"fastmcp~=2.12.0",
|
|
46
|
+
"pyjwt>=2.10.1",
|
|
47
|
+
]
|
|
36
48
|
skypilot = [
|
|
37
|
-
"skypilot[aws,gcp
|
|
49
|
+
"skypilot[aws,gcp]~=0.8.1", # Add azure only in an Azure-specific env/group
|
|
38
50
|
"pip>=25.3", # Added as needed for skypilot launch
|
|
39
51
|
"msrestazure>=0.6.4.post1",
|
|
40
52
|
]
|
|
@@ -53,9 +65,6 @@ crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"]
|
|
|
53
65
|
[tool.uv]
|
|
54
66
|
conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]]
|
|
55
67
|
|
|
56
|
-
[tool.uv.pip]
|
|
57
|
-
extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"]
|
|
58
|
-
|
|
59
68
|
# [tool.uv.extra-build-dependencies]
|
|
60
69
|
# flash-attn = [{ requirement = "torch", match-runtime = true }]
|
|
61
70
|
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/skyrl/benchmax_data_process.py
RENAMED
|
@@ -7,14 +7,15 @@ import logging
|
|
|
7
7
|
from importlib import import_module
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
from types import ModuleType
|
|
10
|
-
from typing import Type
|
|
11
|
-
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
12
|
-
import datasets
|
|
10
|
+
from typing import TYPE_CHECKING, Type
|
|
13
11
|
import asyncio
|
|
14
12
|
import inspect
|
|
15
13
|
|
|
16
14
|
from benchmax.envs.base_env import BaseEnv
|
|
17
15
|
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
18
|
+
|
|
18
19
|
# Set logging level to WARNING and above
|
|
19
20
|
logging.basicConfig(level=logging.WARNING)
|
|
20
21
|
|
|
@@ -124,6 +125,8 @@ if __name__ == "__main__":
|
|
|
124
125
|
benchmax_cls: Type[BaseEnv] = load_class(args.env_path)
|
|
125
126
|
raw_dataset, dataset_path = benchmax_cls.load_dataset(args.dataset_name)
|
|
126
127
|
|
|
128
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
129
|
+
|
|
127
130
|
if isinstance(raw_dataset, (IterableDataset, IterableDatasetDict)):
|
|
128
131
|
raise TypeError(
|
|
129
132
|
f"Iterable datasets are currently not supported. Got {type(raw_dataset).__name__}. "
|
|
@@ -178,7 +181,9 @@ if __name__ == "__main__":
|
|
|
178
181
|
test_dataset = processed_dataset["test"]
|
|
179
182
|
else:
|
|
180
183
|
if isinstance(processed_dataset, DatasetDict):
|
|
181
|
-
|
|
184
|
+
from datasets import concatenate_datasets
|
|
185
|
+
|
|
186
|
+
processed_dataset = concatenate_datasets(
|
|
182
187
|
[ds for ds in processed_dataset.values()]
|
|
183
188
|
).shuffle(seed=42)
|
|
184
189
|
|
|
@@ -1,26 +1,25 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
import subprocess
|
|
3
2
|
import sys
|
|
4
3
|
from pathlib import Path
|
|
5
|
-
from typing import Any, Dict, Optional, Type, Union
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
|
|
6
5
|
|
|
7
|
-
import cloudpickle
|
|
8
|
-
|
|
9
|
-
from benchmax.envs.base_env import BaseEnv
|
|
10
6
|
from benchmax.bundle.errors import (
|
|
11
7
|
DependencyError,
|
|
12
8
|
IncompatiblePythonError,
|
|
13
9
|
IncompatibleBenchmaxError,
|
|
14
10
|
BundlingError,
|
|
15
11
|
)
|
|
16
|
-
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from benchmax.envs.base_env import BaseEnv
|
|
15
|
+
from benchmax.bundle.payload import BundleMetadata, BundledEnv
|
|
17
16
|
|
|
18
17
|
logger = logging.getLogger(__name__)
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
def load_env(
|
|
22
21
|
pickled_class: bytes,
|
|
23
|
-
metadata: BundleMetadata | None = None,
|
|
22
|
+
metadata: "BundleMetadata | None" = None,
|
|
24
23
|
*,
|
|
25
24
|
pip_dependencies: Optional[list[str]] = None,
|
|
26
25
|
python_version: Optional[str] = None,
|
|
@@ -30,7 +29,7 @@ def load_env(
|
|
|
30
29
|
check_benchmax_version: bool = False,
|
|
31
30
|
install_pip_deps: bool = False,
|
|
32
31
|
instantiate: Optional[bool] = None,
|
|
33
|
-
) -> BaseEnv | Type[BaseEnv]:
|
|
32
|
+
) -> "BaseEnv | Type[BaseEnv]":
|
|
34
33
|
"""Load a packaged environment class (and optionally instantiate it).
|
|
35
34
|
|
|
36
35
|
Args:
|
|
@@ -52,6 +51,8 @@ def load_env(
|
|
|
52
51
|
The unpickled BaseEnv subclass (class object), or an instance if
|
|
53
52
|
instantiation is requested.
|
|
54
53
|
"""
|
|
54
|
+
from benchmax.envs.base_env import BaseEnv
|
|
55
|
+
|
|
55
56
|
resolved_pip_deps = (
|
|
56
57
|
pip_dependencies
|
|
57
58
|
if pip_dependencies is not None
|
|
@@ -101,6 +102,8 @@ def load_env(
|
|
|
101
102
|
_install_dependencies(resolved_pip_deps)
|
|
102
103
|
|
|
103
104
|
try:
|
|
105
|
+
import cloudpickle
|
|
106
|
+
|
|
104
107
|
env_class = cloudpickle.loads(pickled_class)
|
|
105
108
|
except Exception as e:
|
|
106
109
|
raise BundlingError(
|
|
@@ -140,7 +143,9 @@ def load_env_from_files(
|
|
|
140
143
|
check_benchmax_version: bool = False,
|
|
141
144
|
install_pip_deps: bool = False,
|
|
142
145
|
instantiate: Optional[bool] = None,
|
|
143
|
-
) -> BaseEnv | Type[BaseEnv]:
|
|
146
|
+
) -> "BaseEnv | Type[BaseEnv]":
|
|
147
|
+
from benchmax.bundle.payload import BundleMetadata
|
|
148
|
+
|
|
144
149
|
pickle_path = Path(pickle_path)
|
|
145
150
|
metadata_path = Path(metadata_path)
|
|
146
151
|
pickled_class = pickle_path.read_bytes()
|
|
@@ -160,7 +165,7 @@ def load_env_from_files(
|
|
|
160
165
|
|
|
161
166
|
|
|
162
167
|
def load_env_from_bundle(
|
|
163
|
-
bundle: BundledEnv,
|
|
168
|
+
bundle: "BundledEnv",
|
|
164
169
|
*,
|
|
165
170
|
pip_dependencies: Optional[list[str]] = None,
|
|
166
171
|
python_version: Optional[str] = None,
|
|
@@ -170,7 +175,7 @@ def load_env_from_bundle(
|
|
|
170
175
|
check_benchmax_version: bool = False,
|
|
171
176
|
install_pip_deps: bool = False,
|
|
172
177
|
instantiate: Optional[bool] = None,
|
|
173
|
-
) -> BaseEnv | Type[BaseEnv]:
|
|
178
|
+
) -> "BaseEnv | Type[BaseEnv]":
|
|
174
179
|
return load_env(
|
|
175
180
|
bundle.pickled_class,
|
|
176
181
|
bundle.metadata,
|
|
@@ -187,6 +192,8 @@ def load_env_from_bundle(
|
|
|
187
192
|
|
|
188
193
|
def _install_dependencies(deps: list[str]) -> None:
|
|
189
194
|
"""Install pip dependencies in the current environment."""
|
|
195
|
+
import subprocess
|
|
196
|
+
|
|
190
197
|
logger.info(f"[bundling] Installing {len(deps)} dependencies: {deps}")
|
|
191
198
|
cmd = [sys.executable, "-m", "pip", "install", "--quiet", *deps]
|
|
192
199
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
@@ -1,27 +1,19 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Dict, List, Any, Optional, Tuple
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Any, Optional, Tuple
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from datasets import (
|
|
5
|
-
DatasetDict,
|
|
6
|
-
Dataset,
|
|
7
|
-
IterableDatasetDict,
|
|
8
|
-
IterableDataset,
|
|
9
|
-
load_dataset,
|
|
10
|
-
)
|
|
11
4
|
|
|
12
5
|
from benchmax.envs.types import ToolDefinition, StandardizedExample
|
|
13
6
|
from benchmax.prompts.tools import render_tools_prompt
|
|
14
7
|
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
10
|
+
|
|
15
11
|
|
|
16
12
|
class BaseEnv(ABC):
|
|
17
13
|
"""Base benchmax environment for tool execution and reward computation"""
|
|
18
14
|
|
|
19
15
|
system_prompt: str = ""
|
|
20
16
|
|
|
21
|
-
@abstractmethod
|
|
22
|
-
async def shutdown(self):
|
|
23
|
-
pass
|
|
24
|
-
|
|
25
17
|
# Override this method if your example does not match the default structure
|
|
26
18
|
@classmethod
|
|
27
19
|
def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
|
|
@@ -45,7 +37,7 @@ class BaseEnv(ABC):
|
|
|
45
37
|
def load_dataset(
|
|
46
38
|
cls, dataset_name: str, **kwargs
|
|
47
39
|
) -> Tuple[
|
|
48
|
-
DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
|
|
40
|
+
"DatasetDict | Dataset | IterableDatasetDict | IterableDataset", str | None
|
|
49
41
|
]:
|
|
50
42
|
"""
|
|
51
43
|
Download and prepare a dataset for use with this environment.
|
|
@@ -63,6 +55,8 @@ class BaseEnv(ABC):
|
|
|
63
55
|
Dataset: A dataset object (e.g., HuggingFace Dataset or similar) ready for processing.
|
|
64
56
|
str: Optional string pointing to where the dataset is stored locally
|
|
65
57
|
"""
|
|
58
|
+
from datasets import load_dataset
|
|
59
|
+
|
|
66
60
|
return load_dataset(dataset_name, **kwargs), None
|
|
67
61
|
|
|
68
62
|
# Methods all environment subclasses must implement
|
|
@@ -98,6 +92,9 @@ class BaseEnv(ABC):
|
|
|
98
92
|
|
|
99
93
|
# Optional rollout lifecycle management methods
|
|
100
94
|
|
|
95
|
+
async def shutdown(self):
|
|
96
|
+
pass
|
|
97
|
+
|
|
101
98
|
async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
|
|
102
99
|
"""Initialize resources for a new rollout"""
|
|
103
100
|
return None
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
|
-
from typing import Any, Dict, List, Optional
|
|
3
|
-
import sky
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
4
3
|
|
|
5
4
|
from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
|
|
6
5
|
from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
|
|
@@ -8,6 +7,9 @@ from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
|
|
|
8
7
|
from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
|
|
9
8
|
from benchmax.envs.types import StandardizedExample
|
|
10
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import sky
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
SYSTEM_PROMPT = """\
|
|
13
15
|
You are an expert in Salesforce and you have access to a Salesforce instance.
|
|
@@ -85,11 +87,19 @@ class CRMEnvSkypilot(CRMEnv):
|
|
|
85
87
|
|
|
86
88
|
def __init__(
|
|
87
89
|
self,
|
|
88
|
-
cloud: sky.clouds.Cloud =
|
|
90
|
+
cloud: "sky.clouds.Cloud | None" = None,
|
|
89
91
|
num_nodes: int = 2,
|
|
90
92
|
servers_per_node: int = 5,
|
|
91
93
|
**kwargs,
|
|
92
94
|
):
|
|
95
|
+
if cloud is None:
|
|
96
|
+
try:
|
|
97
|
+
import sky
|
|
98
|
+
except ModuleNotFoundError as e:
|
|
99
|
+
raise ModuleNotFoundError(
|
|
100
|
+
"skypilot is required for CRMEnvSkypilot. Install with: pip install 'benchmax[skypilot]'"
|
|
101
|
+
) from e
|
|
102
|
+
cloud = sky.Azure()
|
|
93
103
|
workdir_path = Path(__file__).parent / "workdir"
|
|
94
104
|
provisioner = SkypilotProvisioner(
|
|
95
105
|
workdir_path=workdir_path,
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import requests
|
|
3
2
|
import tarfile
|
|
4
3
|
|
|
5
4
|
|
|
@@ -7,6 +6,8 @@ def download_and_extract(url, output_path):
|
|
|
7
6
|
"""
|
|
8
7
|
Downloads a tar.gz file from the given URL and extracts it into output_path.
|
|
9
8
|
"""
|
|
9
|
+
import requests
|
|
10
|
+
|
|
10
11
|
# Ensure the output directory exists
|
|
11
12
|
os.makedirs(output_path, exist_ok=True)
|
|
12
13
|
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Dict, Tuple, Optional
|
|
5
|
-
|
|
6
|
-
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
7
|
-
import sky
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Tuple, Optional
|
|
8
5
|
|
|
9
6
|
from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
|
|
10
7
|
from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
|
|
@@ -16,6 +13,10 @@ from .data_utils import download_and_extract
|
|
|
16
13
|
# Using library shared with mcp workdir
|
|
17
14
|
from .workdir.excel_utils import excel_to_str_repr
|
|
18
15
|
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import sky
|
|
18
|
+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
|
19
|
+
|
|
19
20
|
SYSTEM_PROMPT = """You are a spreadsheet expert who can manipulate spreadsheets through Python code.
|
|
20
21
|
|
|
21
22
|
You need to solve the given spreadsheet manipulation question, which contains six types of information:
|
|
@@ -64,8 +65,10 @@ class ExcelEnv(ParallelMcpEnv):
|
|
|
64
65
|
data_output_path: str = DEFAULT_DATA_OUTPUT_PATH,
|
|
65
66
|
**kwargs,
|
|
66
67
|
) -> Tuple[
|
|
67
|
-
DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
|
|
68
|
+
"DatasetDict | Dataset | IterableDatasetDict | IterableDataset", str | None
|
|
68
69
|
]:
|
|
70
|
+
from datasets import Dataset
|
|
71
|
+
|
|
69
72
|
# Currently only support spreadsheetbench dataset but can be extended to other datasets in the future
|
|
70
73
|
if dataset_name == "spreadsheetbench":
|
|
71
74
|
folder_path = Path(data_output_path) / SPREADSHEET_BENCH_TRAIN_DATA
|
|
@@ -206,11 +209,15 @@ class ExcelEnvSkypilot(ExcelEnv):
|
|
|
206
209
|
|
|
207
210
|
def __init__(
|
|
208
211
|
self,
|
|
209
|
-
cloud: sky.clouds.Cloud =
|
|
212
|
+
cloud: "sky.clouds.Cloud | None" = None,
|
|
210
213
|
num_nodes: int = 2,
|
|
211
214
|
servers_per_node: int = 5,
|
|
212
215
|
**kwargs,
|
|
213
216
|
):
|
|
217
|
+
if cloud is None:
|
|
218
|
+
import sky
|
|
219
|
+
|
|
220
|
+
cloud = sky.Azure()
|
|
214
221
|
workdir_path = Path(__file__).parent / "workdir"
|
|
215
222
|
provisioner = SkypilotProvisioner(
|
|
216
223
|
workdir_path=workdir_path,
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
|
-
from typing import Any
|
|
3
|
-
import sky
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
4
3
|
|
|
5
4
|
from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
|
|
6
5
|
from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
|
|
@@ -8,6 +7,9 @@ from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
|
|
|
8
7
|
from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
|
|
9
8
|
from benchmax.envs.types import StandardizedExample
|
|
10
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import sky
|
|
12
|
+
|
|
11
13
|
SYSTEM_PROMPT = """Please use the tools provided to do any computation.
|
|
12
14
|
Write your complete answer on the final line only, within the xml tags <answer></answer>.\n
|
|
13
15
|
"""
|
|
@@ -46,11 +48,19 @@ class MathEnvSkypilot(MathEnv):
|
|
|
46
48
|
|
|
47
49
|
def __init__(
|
|
48
50
|
self,
|
|
49
|
-
cloud: sky.clouds.Cloud =
|
|
51
|
+
cloud: "sky.clouds.Cloud | None" = None,
|
|
50
52
|
num_nodes: int = 2,
|
|
51
53
|
servers_per_node: int = 5,
|
|
52
54
|
**kwargs,
|
|
53
55
|
):
|
|
56
|
+
if cloud is None:
|
|
57
|
+
try:
|
|
58
|
+
import sky
|
|
59
|
+
except ModuleNotFoundError as e:
|
|
60
|
+
raise ModuleNotFoundError(
|
|
61
|
+
"skypilot is required for MathEnvSkypilot. Install with: pip install 'benchmax[skypilot]'"
|
|
62
|
+
) from e
|
|
63
|
+
cloud = sky.Azure()
|
|
54
64
|
workdir_path = Path(__file__).parent / "workdir"
|
|
55
65
|
provisioner = SkypilotProvisioner(
|
|
56
66
|
workdir_path=workdir_path,
|
|
@@ -11,7 +11,12 @@ from typing import Any, Callable, Dict, List, Optional
|
|
|
11
11
|
import warnings
|
|
12
12
|
import aiohttp
|
|
13
13
|
from mcp.types import TextContent
|
|
14
|
-
|
|
14
|
+
try:
|
|
15
|
+
from fastmcp.exceptions import ToolError
|
|
16
|
+
except ModuleNotFoundError as e:
|
|
17
|
+
raise ModuleNotFoundError(
|
|
18
|
+
"fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
|
|
19
|
+
) from e
|
|
15
20
|
|
|
16
21
|
from benchmax.envs.base_env import BaseEnv
|
|
17
22
|
from benchmax.envs.types import ToolDefinition
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/__init__.py
RENAMED
|
@@ -5,11 +5,16 @@ Server provisioning strategies for ParallelMcpEnv.
|
|
|
5
5
|
from .base_provisioner import BaseProvisioner
|
|
6
6
|
from .manual_provisioner import ManualProvisioner
|
|
7
7
|
from .local_provisioner import LocalProvisioner
|
|
8
|
-
|
|
8
|
+
try:
|
|
9
|
+
from .skypilot_provisioner import SkypilotProvisioner
|
|
10
|
+
except ModuleNotFoundError:
|
|
11
|
+
SkypilotProvisioner = None
|
|
9
12
|
|
|
10
13
|
__all__ = [
|
|
11
14
|
"BaseProvisioner",
|
|
12
15
|
"ManualProvisioner",
|
|
13
16
|
"LocalProvisioner",
|
|
14
|
-
"SkypilotProvisioner",
|
|
15
17
|
]
|
|
18
|
+
|
|
19
|
+
if SkypilotProvisioner is not None:
|
|
20
|
+
__all__.append("SkypilotProvisioner")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
"""
|
|
2
4
|
SkyPilot provisioner for launching cloud-based server clusters.
|
|
3
5
|
"""
|
|
@@ -5,14 +7,16 @@ SkyPilot provisioner for launching cloud-based server clusters.
|
|
|
5
7
|
import logging
|
|
6
8
|
import uuid
|
|
7
9
|
from pathlib import Path
|
|
8
|
-
from typing import List, Optional
|
|
9
|
-
import sky
|
|
10
|
+
from typing import TYPE_CHECKING, List, Optional
|
|
10
11
|
|
|
11
12
|
from .base_provisioner import BaseProvisioner
|
|
12
13
|
from .utils import get_run_command, setup_sync_dir, cleanup_dir, get_setup_command
|
|
13
14
|
|
|
14
15
|
logger = logging.getLogger(__name__)
|
|
15
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import sky
|
|
19
|
+
|
|
16
20
|
|
|
17
21
|
class SkypilotProvisioner(BaseProvisioner):
|
|
18
22
|
"""
|
|
@@ -39,7 +43,7 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
39
43
|
def __init__(
|
|
40
44
|
self,
|
|
41
45
|
workdir_path: Path | str,
|
|
42
|
-
cloud: sky.clouds.Cloud,
|
|
46
|
+
cloud: "sky.clouds.Cloud",
|
|
43
47
|
num_nodes: int = 1,
|
|
44
48
|
servers_per_node: int = 5,
|
|
45
49
|
cpus: Optional[str | int] = "2+",
|
|
@@ -62,8 +66,15 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
62
66
|
raise ValueError("num_nodes must be at least 1")
|
|
63
67
|
if servers_per_node < 1 or servers_per_node > 100:
|
|
64
68
|
raise ValueError("servers_per_node must be between 1 and 100")
|
|
69
|
+
try:
|
|
70
|
+
import sky
|
|
71
|
+
except ModuleNotFoundError as e:
|
|
72
|
+
raise ModuleNotFoundError(
|
|
73
|
+
"skypilot is required for SkypilotProvisioner. Install with: pip install 'benchmax[skypilot]'"
|
|
74
|
+
) from e
|
|
65
75
|
|
|
66
76
|
self._workdir_path = Path(workdir_path).absolute()
|
|
77
|
+
self._sky = sky
|
|
67
78
|
self._cloud = cloud
|
|
68
79
|
self._num_nodes = num_nodes
|
|
69
80
|
self._servers_per_node = servers_per_node
|
|
@@ -123,7 +134,7 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
123
134
|
env = None if api_secret is None else {"API_SECRET": api_secret}
|
|
124
135
|
|
|
125
136
|
# Configure SkyPilot task
|
|
126
|
-
sky_task =
|
|
137
|
+
sky_task = self._sky.Task(
|
|
127
138
|
name="mcp-server",
|
|
128
139
|
run=get_run_command(ports=all_ports),
|
|
129
140
|
setup=get_setup_command(),
|
|
@@ -133,7 +144,7 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
133
144
|
)
|
|
134
145
|
|
|
135
146
|
sky_task.set_resources(
|
|
136
|
-
|
|
147
|
+
self._sky.Resources(
|
|
137
148
|
cloud=self._cloud,
|
|
138
149
|
cpus=self._cpus,
|
|
139
150
|
memory=self._memory,
|
|
@@ -148,7 +159,7 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
148
159
|
)
|
|
149
160
|
cluster_handle = None
|
|
150
161
|
try:
|
|
151
|
-
_, handle =
|
|
162
|
+
_, handle = self._sky.launch(
|
|
152
163
|
task=sky_task,
|
|
153
164
|
cluster_name=self._cluster_name,
|
|
154
165
|
detach_run=True,
|
|
@@ -192,7 +203,7 @@ class SkypilotProvisioner(BaseProvisioner):
|
|
|
192
203
|
|
|
193
204
|
logger.info(f"Tearing down SkyPilot cluster '{self._cluster_name}'...")
|
|
194
205
|
try:
|
|
195
|
-
|
|
206
|
+
self._sky.down(cluster_name=self._cluster_name)
|
|
196
207
|
logger.info(f"Cluster '{self._cluster_name}' torn down successfully")
|
|
197
208
|
except Exception as e:
|
|
198
209
|
logger.error(f"Error tearing down cluster '{self._cluster_name}': {e}")
|
|
@@ -2,7 +2,6 @@ import os
|
|
|
2
2
|
import sys
|
|
3
3
|
import shutil
|
|
4
4
|
import uuid
|
|
5
|
-
import yaml
|
|
6
5
|
import asyncio
|
|
7
6
|
import argparse
|
|
8
7
|
import psutil
|
|
@@ -12,10 +11,15 @@ from typing import Any, Awaitable, Callable, Dict, Union, Optional, Tuple, List
|
|
|
12
11
|
from pathlib import Path
|
|
13
12
|
from functools import wraps
|
|
14
13
|
|
|
15
|
-
|
|
16
|
-
from fastmcp.server.
|
|
17
|
-
from fastmcp.server.auth import
|
|
18
|
-
from fastmcp import
|
|
14
|
+
try:
|
|
15
|
+
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
16
|
+
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
|
17
|
+
from fastmcp.server.auth import AccessToken
|
|
18
|
+
from fastmcp import FastMCP, Client
|
|
19
|
+
except ModuleNotFoundError as e:
|
|
20
|
+
raise ModuleNotFoundError(
|
|
21
|
+
"fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
|
|
22
|
+
) from e
|
|
19
23
|
from starlette.requests import Request
|
|
20
24
|
from starlette.responses import PlainTextResponse, FileResponse, JSONResponse, Response
|
|
21
25
|
from starlette.datastructures import UploadFile
|
|
@@ -36,6 +40,8 @@ def setup_workspace(base_dir: Path) -> Path:
|
|
|
36
40
|
|
|
37
41
|
def load_config(config_path: Path, workspace: Path) -> Dict[str, Any]:
|
|
38
42
|
"""Load YAML config and inject workspace paths."""
|
|
43
|
+
import yaml
|
|
44
|
+
|
|
39
45
|
with open(config_path, "r") as f:
|
|
40
46
|
content = f.read().replace(
|
|
41
47
|
"${{ sync_workdir }}", str(Path(__file__).resolve().parent)
|
|
@@ -8,7 +8,12 @@ from dataclasses import dataclass
|
|
|
8
8
|
import random
|
|
9
9
|
from typing import Coroutine, Dict, List, Optional, Set
|
|
10
10
|
import aiohttp
|
|
11
|
-
|
|
11
|
+
try:
|
|
12
|
+
from fastmcp import Client as FastMCPClient
|
|
13
|
+
except ModuleNotFoundError as e:
|
|
14
|
+
raise ModuleNotFoundError(
|
|
15
|
+
"fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
|
|
16
|
+
) from e
|
|
12
17
|
|
|
13
18
|
from benchmax.envs.mcp.utils import generate_jwt_token, get_auth_headers
|
|
14
19
|
|
|
@@ -9,7 +9,12 @@ from contextlib import AsyncExitStack
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from typing import Optional, Dict, List, Any
|
|
11
11
|
import aiohttp
|
|
12
|
-
|
|
12
|
+
try:
|
|
13
|
+
from fastmcp import Client
|
|
14
|
+
except ModuleNotFoundError as e:
|
|
15
|
+
raise ModuleNotFoundError(
|
|
16
|
+
"fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
|
|
17
|
+
) from e
|
|
13
18
|
from mcp import Tool
|
|
14
19
|
|
|
15
20
|
from benchmax.envs.types import ToolDefinition
|
|
@@ -1,20 +1,24 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: benchmax
|
|
3
|
-
Version: 0.1.2.dev11
|
|
3
|
+
Version: 0.1.2.dev13
|
|
4
4
|
Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
|
|
5
5
|
Author: cgft.io
|
|
6
6
|
Classifier: Programming Language :: Python :: 3
|
|
7
7
|
Classifier: Operating System :: OS Independent
|
|
8
|
-
Requires-Python:
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
9
|
Description-Content-Type: text/markdown
|
|
10
10
|
License-File: LICENSE
|
|
11
11
|
Requires-Dist: aiohttp>=3.13.1
|
|
12
12
|
Requires-Dist: asyncio>=4.0.0
|
|
13
13
|
Requires-Dist: cloudpickle>=3.0.0
|
|
14
14
|
Requires-Dist: datasets>=4.0.0
|
|
15
|
-
|
|
16
|
-
Requires-Dist: pyjwt>=2.10.1
|
|
17
|
-
Requires-Dist: skypilot~=0.8.1
|
|
15
|
+
Provides-Extra: mcp
|
|
16
|
+
Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
|
|
17
|
+
Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
|
|
18
|
+
Provides-Extra: skypilot
|
|
19
|
+
Requires-Dist: skypilot[aws,gcp]~=0.8.1; extra == "skypilot"
|
|
20
|
+
Requires-Dist: pip>=25.3; extra == "skypilot"
|
|
21
|
+
Requires-Dist: msrestazure>=0.6.4.post1; extra == "skypilot"
|
|
18
22
|
Dynamic: license-file
|
|
19
23
|
|
|
20
24
|
<picture>
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/workdir/salesforce_mcp.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/excel_utils.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/example_workdir/reward_fn.py
RENAMED
|
File without changes
|
{benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/base_provisioner.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|