benchmax 0.1.2.dev11__tar.gz → 0.1.2.dev13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/PKG-INFO +9 -5
  2. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/pyproject.toml +18 -9
  3. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/skyrl/benchmax_data_process.py +9 -4
  4. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/loader.py +18 -11
  5. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/base_env.py +10 -13
  6. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/crm_env.py +13 -3
  7. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/data_utils.py +2 -1
  8. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/excel_env.py +13 -6
  9. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/math/math_env.py +13 -3
  10. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/parallel_mcp_env.py +6 -1
  11. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/__init__.py +7 -2
  12. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py +18 -7
  13. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/proxy_server.py +11 -5
  14. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/server_pool.py +6 -1
  15. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/utils.py +6 -1
  16. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/PKG-INFO +9 -5
  17. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/requires.txt +7 -1
  18. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/LICENSE +0 -0
  19. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/README.md +0 -0
  20. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/setup.cfg +0 -0
  21. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/__init__.py +0 -0
  22. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/benchmax_wrapper.py +0 -0
  23. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/adapters/skyrl/skyrl_adapter.py +0 -0
  24. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/__init__.py +0 -0
  25. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/bundler.py +0 -0
  26. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/errors.py +0 -0
  27. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/payload.py +0 -0
  28. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/bundle/validator.py +0 -0
  29. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/__init__.py +0 -0
  30. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/workdir/reward_fn.py +0 -0
  31. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/crm/workdir/salesforce_mcp.py +0 -0
  32. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/__init__.py +0 -0
  33. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py +0 -0
  34. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/excel_utils.py +0 -0
  35. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/excel/workdir/reward_fn.py +0 -0
  36. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/math/workdir/reward_fn.py +0 -0
  37. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/__init__.py +0 -0
  38. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py +0 -0
  39. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/example_workdir/reward_fn.py +0 -0
  40. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/base_provisioner.py +0 -0
  41. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/local_provisioner.py +0 -0
  42. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/manual_provisioner.py +0 -0
  43. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/mcp/provisioners/utils.py +0 -0
  44. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/types.py +0 -0
  45. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/wikipedia/utils.py +0 -0
  46. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/envs/wikipedia/wiki_env.py +0 -0
  47. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/prompts/__init__.py +0 -0
  48. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax/prompts/tools.py +0 -0
  49. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/SOURCES.txt +0 -0
  50. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/dependency_links.txt +0 -0
  51. {benchmax-0.1.2.dev11 → benchmax-0.1.2.dev13}/src/benchmax.egg-info/top_level.txt +0 -0
@@ -1,20 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: benchmax
3
- Version: 0.1.2.dev11
3
+ Version: 0.1.2.dev13
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Classifier: Programming Language :: Python :: 3
7
7
  Classifier: Operating System :: OS Independent
8
- Requires-Python: ==3.12.*
8
+ Requires-Python: >=3.12
9
9
  Description-Content-Type: text/markdown
10
10
  License-File: LICENSE
11
11
  Requires-Dist: aiohttp>=3.13.1
12
12
  Requires-Dist: asyncio>=4.0.0
13
13
  Requires-Dist: cloudpickle>=3.0.0
14
14
  Requires-Dist: datasets>=4.0.0
15
- Requires-Dist: fastmcp~=2.12.0
16
- Requires-Dist: pyjwt>=2.10.1
17
- Requires-Dist: skypilot~=0.8.1
15
+ Provides-Extra: mcp
16
+ Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
17
+ Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
18
+ Provides-Extra: skypilot
19
+ Requires-Dist: skypilot[aws,gcp]~=0.8.1; extra == "skypilot"
20
+ Requires-Dist: pip>=25.3; extra == "skypilot"
21
+ Requires-Dist: msrestazure>=0.6.4.post1; extra == "skypilot"
18
22
  Dynamic: license-file
19
23
 
20
24
  <picture>
@@ -1,18 +1,15 @@
1
1
  [project]
2
2
  name = "benchmax"
3
- version = "0.1.2.dev11"
3
+ version = "0.1.2.dev13"
4
4
  description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "cgft.io" }]
7
- requires-python = "==3.12.*"
7
+ requires-python = ">=3.12"
8
8
  dependencies = [
9
9
  "aiohttp>=3.13.1",
10
10
  "asyncio>=4.0.0",
11
11
  "cloudpickle>=3.0.0",
12
12
  "datasets>=4.0.0",
13
- "fastmcp~=2.12.0",
14
- "pyjwt>=2.10.1",
15
- "skypilot~=0.8.1",
16
13
  ]
17
14
  classifiers = [
18
15
  "Programming Language :: Python :: 3",
@@ -26,6 +23,17 @@ build-backend = "setuptools.build_meta"
26
23
  [tool.setuptools.packages.find]
27
24
  where = ["src"]
28
25
 
26
+ [project.optional-dependencies]
27
+ mcp = [
28
+ "fastmcp~=2.12.0",
29
+ "pyjwt>=2.10.1",
30
+ ]
31
+ skypilot = [
32
+ "skypilot[aws,gcp]~=0.8.1",
33
+ "pip>=25.3",
34
+ "msrestazure>=0.6.4.post1",
35
+ ]
36
+
29
37
  [dependency-groups]
30
38
  dev = [
31
39
  "pytest>=8.4.2",
@@ -33,8 +41,12 @@ dev = [
33
41
  "python-dotenv>=1.2.1",
34
42
  "ruff>=0.14.2",
35
43
  ]
44
+ mcp = [
45
+ "fastmcp~=2.12.0",
46
+ "pyjwt>=2.10.1",
47
+ ]
36
48
  skypilot = [
37
- "skypilot[aws,gcp,azure]~=0.8.1", # Change this to your cloud provider
49
+ "skypilot[aws,gcp]~=0.8.1", # Add azure only in an Azure-specific env/group
38
50
  "pip>=25.3", # Added as needed for skypilot launch
39
51
  "msrestazure>=0.6.4.post1",
40
52
  ]
@@ -53,9 +65,6 @@ crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"]
53
65
  [tool.uv]
54
66
  conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]]
55
67
 
56
- [tool.uv.pip]
57
- extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"]
58
-
59
68
  # [tool.uv.extra-build-dependencies]
60
69
  # flash-attn = [{ requirement = "torch", match-runtime = true }]
61
70
 
@@ -7,14 +7,15 @@ import logging
7
7
  from importlib import import_module
8
8
  from pathlib import Path
9
9
  from types import ModuleType
10
- from typing import Type
11
- from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
12
- import datasets
10
+ from typing import TYPE_CHECKING, Type
13
11
  import asyncio
14
12
  import inspect
15
13
 
16
14
  from benchmax.envs.base_env import BaseEnv
17
15
 
16
+ if TYPE_CHECKING:
17
+ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
18
+
18
19
  # Set logging level to WARNING and above
19
20
  logging.basicConfig(level=logging.WARNING)
20
21
 
@@ -124,6 +125,8 @@ if __name__ == "__main__":
124
125
  benchmax_cls: Type[BaseEnv] = load_class(args.env_path)
125
126
  raw_dataset, dataset_path = benchmax_cls.load_dataset(args.dataset_name)
126
127
 
128
+ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
129
+
127
130
  if isinstance(raw_dataset, (IterableDataset, IterableDatasetDict)):
128
131
  raise TypeError(
129
132
  f"Iterable datasets are currently not supported. Got {type(raw_dataset).__name__}. "
@@ -178,7 +181,9 @@ if __name__ == "__main__":
178
181
  test_dataset = processed_dataset["test"]
179
182
  else:
180
183
  if isinstance(processed_dataset, DatasetDict):
181
- processed_dataset = datasets.concatenate_datasets(
184
+ from datasets import concatenate_datasets
185
+
186
+ processed_dataset = concatenate_datasets(
182
187
  [ds for ds in processed_dataset.values()]
183
188
  ).shuffle(seed=42)
184
189
 
@@ -1,26 +1,25 @@
1
1
  import logging
2
- import subprocess
3
2
  import sys
4
3
  from pathlib import Path
5
- from typing import Any, Dict, Optional, Type, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
6
5
 
7
- import cloudpickle
8
-
9
- from benchmax.envs.base_env import BaseEnv
10
6
  from benchmax.bundle.errors import (
11
7
  DependencyError,
12
8
  IncompatiblePythonError,
13
9
  IncompatibleBenchmaxError,
14
10
  BundlingError,
15
11
  )
16
- from benchmax.bundle.payload import BundleMetadata, BundledEnv
12
+
13
+ if TYPE_CHECKING:
14
+ from benchmax.envs.base_env import BaseEnv
15
+ from benchmax.bundle.payload import BundleMetadata, BundledEnv
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
20
19
 
21
20
  def load_env(
22
21
  pickled_class: bytes,
23
- metadata: BundleMetadata | None = None,
22
+ metadata: "BundleMetadata | None" = None,
24
23
  *,
25
24
  pip_dependencies: Optional[list[str]] = None,
26
25
  python_version: Optional[str] = None,
@@ -30,7 +29,7 @@ def load_env(
30
29
  check_benchmax_version: bool = False,
31
30
  install_pip_deps: bool = False,
32
31
  instantiate: Optional[bool] = None,
33
- ) -> BaseEnv | Type[BaseEnv]:
32
+ ) -> "BaseEnv | Type[BaseEnv]":
34
33
  """Load a packaged environment class (and optionally instantiate it).
35
34
 
36
35
  Args:
@@ -52,6 +51,8 @@ def load_env(
52
51
  The unpickled BaseEnv subclass (class object), or an instance if
53
52
  instantiation is requested.
54
53
  """
54
+ from benchmax.envs.base_env import BaseEnv
55
+
55
56
  resolved_pip_deps = (
56
57
  pip_dependencies
57
58
  if pip_dependencies is not None
@@ -101,6 +102,8 @@ def load_env(
101
102
  _install_dependencies(resolved_pip_deps)
102
103
 
103
104
  try:
105
+ import cloudpickle
106
+
104
107
  env_class = cloudpickle.loads(pickled_class)
105
108
  except Exception as e:
106
109
  raise BundlingError(
@@ -140,7 +143,9 @@ def load_env_from_files(
140
143
  check_benchmax_version: bool = False,
141
144
  install_pip_deps: bool = False,
142
145
  instantiate: Optional[bool] = None,
143
- ) -> BaseEnv | Type[BaseEnv]:
146
+ ) -> "BaseEnv | Type[BaseEnv]":
147
+ from benchmax.bundle.payload import BundleMetadata
148
+
144
149
  pickle_path = Path(pickle_path)
145
150
  metadata_path = Path(metadata_path)
146
151
  pickled_class = pickle_path.read_bytes()
@@ -160,7 +165,7 @@ def load_env_from_files(
160
165
 
161
166
 
162
167
  def load_env_from_bundle(
163
- bundle: BundledEnv,
168
+ bundle: "BundledEnv",
164
169
  *,
165
170
  pip_dependencies: Optional[list[str]] = None,
166
171
  python_version: Optional[str] = None,
@@ -170,7 +175,7 @@ def load_env_from_bundle(
170
175
  check_benchmax_version: bool = False,
171
176
  install_pip_deps: bool = False,
172
177
  instantiate: Optional[bool] = None,
173
- ) -> BaseEnv | Type[BaseEnv]:
178
+ ) -> "BaseEnv | Type[BaseEnv]":
174
179
  return load_env(
175
180
  bundle.pickled_class,
176
181
  bundle.metadata,
@@ -187,6 +192,8 @@ def load_env_from_bundle(
187
192
 
188
193
  def _install_dependencies(deps: list[str]) -> None:
189
194
  """Install pip dependencies in the current environment."""
195
+ import subprocess
196
+
190
197
  logger.info(f"[bundling] Installing {len(deps)} dependencies: {deps}")
191
198
  cmd = [sys.executable, "-m", "pip", "install", "--quiet", *deps]
192
199
  result = subprocess.run(cmd, capture_output=True, text=True)
@@ -1,27 +1,19 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, List, Any, Optional, Tuple
2
+ from typing import TYPE_CHECKING, Dict, List, Any, Optional, Tuple
3
3
  from pathlib import Path
4
- from datasets import (
5
- DatasetDict,
6
- Dataset,
7
- IterableDatasetDict,
8
- IterableDataset,
9
- load_dataset,
10
- )
11
4
 
12
5
  from benchmax.envs.types import ToolDefinition, StandardizedExample
13
6
  from benchmax.prompts.tools import render_tools_prompt
14
7
 
8
+ if TYPE_CHECKING:
9
+ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
10
+
15
11
 
16
12
  class BaseEnv(ABC):
17
13
  """Base benchmax environment for tool execution and reward computation"""
18
14
 
19
15
  system_prompt: str = ""
20
16
 
21
- @abstractmethod
22
- async def shutdown(self):
23
- pass
24
-
25
17
  # Override this method if your example does not match the default structure
26
18
  @classmethod
27
19
  def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
@@ -45,7 +37,7 @@ class BaseEnv(ABC):
45
37
  def load_dataset(
46
38
  cls, dataset_name: str, **kwargs
47
39
  ) -> Tuple[
48
- DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
40
+ "DatasetDict | Dataset | IterableDatasetDict | IterableDataset", str | None
49
41
  ]:
50
42
  """
51
43
  Download and prepare a dataset for use with this environment.
@@ -63,6 +55,8 @@ class BaseEnv(ABC):
63
55
  Dataset: A dataset object (e.g., HuggingFace Dataset or similar) ready for processing.
64
56
  str: Optional string pointing to where the dataset is stored locally
65
57
  """
58
+ from datasets import load_dataset
59
+
66
60
  return load_dataset(dataset_name, **kwargs), None
67
61
 
68
62
  # Methods all environment subclasses must implement
@@ -98,6 +92,9 @@ class BaseEnv(ABC):
98
92
 
99
93
  # Optional rollout lifecycle management methods
100
94
 
95
+ async def shutdown(self):
96
+ pass
97
+
101
98
  async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
102
99
  """Initialize resources for a new rollout"""
103
100
  return None
@@ -1,6 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, Dict, List, Optional
3
- import sky
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
4
3
 
5
4
  from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
6
5
  from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
@@ -8,6 +7,9 @@ from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
8
7
  from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
9
8
  from benchmax.envs.types import StandardizedExample
10
9
 
10
+ if TYPE_CHECKING:
11
+ import sky
12
+
11
13
 
12
14
  SYSTEM_PROMPT = """\
13
15
  You are an expert in Salesforce and you have access to a Salesforce instance.
@@ -85,11 +87,19 @@ class CRMEnvSkypilot(CRMEnv):
85
87
 
86
88
  def __init__(
87
89
  self,
88
- cloud: sky.clouds.Cloud = sky.Azure(),
90
+ cloud: "sky.clouds.Cloud | None" = None,
89
91
  num_nodes: int = 2,
90
92
  servers_per_node: int = 5,
91
93
  **kwargs,
92
94
  ):
95
+ if cloud is None:
96
+ try:
97
+ import sky
98
+ except ModuleNotFoundError as e:
99
+ raise ModuleNotFoundError(
100
+ "skypilot is required for CRMEnvSkypilot. Install with: pip install 'benchmax[skypilot]'"
101
+ ) from e
102
+ cloud = sky.Azure()
93
103
  workdir_path = Path(__file__).parent / "workdir"
94
104
  provisioner = SkypilotProvisioner(
95
105
  workdir_path=workdir_path,
@@ -1,5 +1,4 @@
1
1
  import os
2
- import requests
3
2
  import tarfile
4
3
 
5
4
 
@@ -7,6 +6,8 @@ def download_and_extract(url, output_path):
7
6
  """
8
7
  Downloads a tar.gz file from the given URL and extracts it into output_path.
9
8
  """
9
+ import requests
10
+
10
11
  # Ensure the output directory exists
11
12
  os.makedirs(output_path, exist_ok=True)
12
13
 
@@ -1,10 +1,7 @@
1
1
  import json
2
2
  import os
3
3
  from pathlib import Path
4
- from typing import Any, Dict, Tuple, Optional
5
-
6
- from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
7
- import sky
4
+ from typing import TYPE_CHECKING, Any, Dict, Tuple, Optional
8
5
 
9
6
  from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
10
7
  from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
@@ -16,6 +13,10 @@ from .data_utils import download_and_extract
16
13
  # Using library shared with mcp workdir
17
14
  from .workdir.excel_utils import excel_to_str_repr
18
15
 
16
+ if TYPE_CHECKING:
17
+ import sky
18
+ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
19
+
19
20
  SYSTEM_PROMPT = """You are a spreadsheet expert who can manipulate spreadsheets through Python code.
20
21
 
21
22
  You need to solve the given spreadsheet manipulation question, which contains six types of information:
@@ -64,8 +65,10 @@ class ExcelEnv(ParallelMcpEnv):
64
65
  data_output_path: str = DEFAULT_DATA_OUTPUT_PATH,
65
66
  **kwargs,
66
67
  ) -> Tuple[
67
- DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
68
+ "DatasetDict | Dataset | IterableDatasetDict | IterableDataset", str | None
68
69
  ]:
70
+ from datasets import Dataset
71
+
69
72
  # Currently only support spreadsheetbench dataset but can be extended to other datasets in the future
70
73
  if dataset_name == "spreadsheetbench":
71
74
  folder_path = Path(data_output_path) / SPREADSHEET_BENCH_TRAIN_DATA
@@ -206,11 +209,15 @@ class ExcelEnvSkypilot(ExcelEnv):
206
209
 
207
210
  def __init__(
208
211
  self,
209
- cloud: sky.clouds.Cloud = sky.Azure(),
212
+ cloud: "sky.clouds.Cloud | None" = None,
210
213
  num_nodes: int = 2,
211
214
  servers_per_node: int = 5,
212
215
  **kwargs,
213
216
  ):
217
+ if cloud is None:
218
+ import sky
219
+
220
+ cloud = sky.Azure()
214
221
  workdir_path = Path(__file__).parent / "workdir"
215
222
  provisioner = SkypilotProvisioner(
216
223
  workdir_path=workdir_path,
@@ -1,6 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any
3
- import sky
2
+ from typing import TYPE_CHECKING, Any
4
3
 
5
4
  from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
6
5
  from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
@@ -8,6 +7,9 @@ from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
8
7
  from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
9
8
  from benchmax.envs.types import StandardizedExample
10
9
 
10
+ if TYPE_CHECKING:
11
+ import sky
12
+
11
13
  SYSTEM_PROMPT = """Please use the tools provided to do any computation.
12
14
  Write your complete answer on the final line only, within the xml tags <answer></answer>.\n
13
15
  """
@@ -46,11 +48,19 @@ class MathEnvSkypilot(MathEnv):
46
48
 
47
49
  def __init__(
48
50
  self,
49
- cloud: sky.clouds.Cloud = sky.Azure(),
51
+ cloud: "sky.clouds.Cloud | None" = None,
50
52
  num_nodes: int = 2,
51
53
  servers_per_node: int = 5,
52
54
  **kwargs,
53
55
  ):
56
+ if cloud is None:
57
+ try:
58
+ import sky
59
+ except ModuleNotFoundError as e:
60
+ raise ModuleNotFoundError(
61
+ "skypilot is required for MathEnvSkypilot. Install with: pip install 'benchmax[skypilot]'"
62
+ ) from e
63
+ cloud = sky.Azure()
54
64
  workdir_path = Path(__file__).parent / "workdir"
55
65
  provisioner = SkypilotProvisioner(
56
66
  workdir_path=workdir_path,
@@ -11,7 +11,12 @@ from typing import Any, Callable, Dict, List, Optional
11
11
  import warnings
12
12
  import aiohttp
13
13
  from mcp.types import TextContent
14
- from fastmcp.exceptions import ToolError
14
+ try:
15
+ from fastmcp.exceptions import ToolError
16
+ except ModuleNotFoundError as e:
17
+ raise ModuleNotFoundError(
18
+ "fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
19
+ ) from e
15
20
 
16
21
  from benchmax.envs.base_env import BaseEnv
17
22
  from benchmax.envs.types import ToolDefinition
@@ -5,11 +5,16 @@ Server provisioning strategies for ParallelMcpEnv.
5
5
  from .base_provisioner import BaseProvisioner
6
6
  from .manual_provisioner import ManualProvisioner
7
7
  from .local_provisioner import LocalProvisioner
8
- from .skypilot_provisioner import SkypilotProvisioner
8
+ try:
9
+ from .skypilot_provisioner import SkypilotProvisioner
10
+ except ModuleNotFoundError:
11
+ SkypilotProvisioner = None
9
12
 
10
13
  __all__ = [
11
14
  "BaseProvisioner",
12
15
  "ManualProvisioner",
13
16
  "LocalProvisioner",
14
- "SkypilotProvisioner",
15
17
  ]
18
+
19
+ if SkypilotProvisioner is not None:
20
+ __all__.append("SkypilotProvisioner")
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  """
2
4
  SkyPilot provisioner for launching cloud-based server clusters.
3
5
  """
@@ -5,14 +7,16 @@ SkyPilot provisioner for launching cloud-based server clusters.
5
7
  import logging
6
8
  import uuid
7
9
  from pathlib import Path
8
- from typing import List, Optional
9
- import sky
10
+ from typing import TYPE_CHECKING, List, Optional
10
11
 
11
12
  from .base_provisioner import BaseProvisioner
12
13
  from .utils import get_run_command, setup_sync_dir, cleanup_dir, get_setup_command
13
14
 
14
15
  logger = logging.getLogger(__name__)
15
16
 
17
+ if TYPE_CHECKING:
18
+ import sky
19
+
16
20
 
17
21
  class SkypilotProvisioner(BaseProvisioner):
18
22
  """
@@ -39,7 +43,7 @@ class SkypilotProvisioner(BaseProvisioner):
39
43
  def __init__(
40
44
  self,
41
45
  workdir_path: Path | str,
42
- cloud: sky.clouds.Cloud,
46
+ cloud: "sky.clouds.Cloud",
43
47
  num_nodes: int = 1,
44
48
  servers_per_node: int = 5,
45
49
  cpus: Optional[str | int] = "2+",
@@ -62,8 +66,15 @@ class SkypilotProvisioner(BaseProvisioner):
62
66
  raise ValueError("num_nodes must be at least 1")
63
67
  if servers_per_node < 1 or servers_per_node > 100:
64
68
  raise ValueError("servers_per_node must be between 1 and 100")
69
+ try:
70
+ import sky
71
+ except ModuleNotFoundError as e:
72
+ raise ModuleNotFoundError(
73
+ "skypilot is required for SkypilotProvisioner. Install with: pip install 'benchmax[skypilot]'"
74
+ ) from e
65
75
 
66
76
  self._workdir_path = Path(workdir_path).absolute()
77
+ self._sky = sky
67
78
  self._cloud = cloud
68
79
  self._num_nodes = num_nodes
69
80
  self._servers_per_node = servers_per_node
@@ -123,7 +134,7 @@ class SkypilotProvisioner(BaseProvisioner):
123
134
  env = None if api_secret is None else {"API_SECRET": api_secret}
124
135
 
125
136
  # Configure SkyPilot task
126
- sky_task = sky.Task(
137
+ sky_task = self._sky.Task(
127
138
  name="mcp-server",
128
139
  run=get_run_command(ports=all_ports),
129
140
  setup=get_setup_command(),
@@ -133,7 +144,7 @@ class SkypilotProvisioner(BaseProvisioner):
133
144
  )
134
145
 
135
146
  sky_task.set_resources(
136
- sky.Resources(
147
+ self._sky.Resources(
137
148
  cloud=self._cloud,
138
149
  cpus=self._cpus,
139
150
  memory=self._memory,
@@ -148,7 +159,7 @@ class SkypilotProvisioner(BaseProvisioner):
148
159
  )
149
160
  cluster_handle = None
150
161
  try:
151
- _, handle = sky.launch(
162
+ _, handle = self._sky.launch(
152
163
  task=sky_task,
153
164
  cluster_name=self._cluster_name,
154
165
  detach_run=True,
@@ -192,7 +203,7 @@ class SkypilotProvisioner(BaseProvisioner):
192
203
 
193
204
  logger.info(f"Tearing down SkyPilot cluster '{self._cluster_name}'...")
194
205
  try:
195
- sky.down(cluster_name=self._cluster_name)
206
+ self._sky.down(cluster_name=self._cluster_name)
196
207
  logger.info(f"Cluster '{self._cluster_name}' torn down successfully")
197
208
  except Exception as e:
198
209
  logger.error(f"Error tearing down cluster '{self._cluster_name}': {e}")
@@ -2,7 +2,6 @@ import os
2
2
  import sys
3
3
  import shutil
4
4
  import uuid
5
- import yaml
6
5
  import asyncio
7
6
  import argparse
8
7
  import psutil
@@ -12,10 +11,15 @@ from typing import Any, Awaitable, Callable, Dict, Union, Optional, Tuple, List
12
11
  from pathlib import Path
13
12
  from functools import wraps
14
13
 
15
- from fastmcp.server.middleware import Middleware, MiddlewareContext
16
- from fastmcp.server.auth.providers.jwt import JWTVerifier
17
- from fastmcp.server.auth import AccessToken
18
- from fastmcp import FastMCP, Client
14
+ try:
15
+ from fastmcp.server.middleware import Middleware, MiddlewareContext
16
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
17
+ from fastmcp.server.auth import AccessToken
18
+ from fastmcp import FastMCP, Client
19
+ except ModuleNotFoundError as e:
20
+ raise ModuleNotFoundError(
21
+ "fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
22
+ ) from e
19
23
  from starlette.requests import Request
20
24
  from starlette.responses import PlainTextResponse, FileResponse, JSONResponse, Response
21
25
  from starlette.datastructures import UploadFile
@@ -36,6 +40,8 @@ def setup_workspace(base_dir: Path) -> Path:
36
40
 
37
41
  def load_config(config_path: Path, workspace: Path) -> Dict[str, Any]:
38
42
  """Load YAML config and inject workspace paths."""
43
+ import yaml
44
+
39
45
  with open(config_path, "r") as f:
40
46
  content = f.read().replace(
41
47
  "${{ sync_workdir }}", str(Path(__file__).resolve().parent)
@@ -8,7 +8,12 @@ from dataclasses import dataclass
8
8
  import random
9
9
  from typing import Coroutine, Dict, List, Optional, Set
10
10
  import aiohttp
11
- from fastmcp import Client as FastMCPClient
11
+ try:
12
+ from fastmcp import Client as FastMCPClient
13
+ except ModuleNotFoundError as e:
14
+ raise ModuleNotFoundError(
15
+ "fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
16
+ ) from e
12
17
 
13
18
  from benchmax.envs.mcp.utils import generate_jwt_token, get_auth_headers
14
19
 
@@ -9,7 +9,12 @@ from contextlib import AsyncExitStack
9
9
  from pathlib import Path
10
10
  from typing import Optional, Dict, List, Any
11
11
  import aiohttp
12
- from fastmcp import Client
12
+ try:
13
+ from fastmcp import Client
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError(
16
+ "fastmcp is required for MCP environments. Install with: pip install 'benchmax[mcp]'"
17
+ ) from e
13
18
  from mcp import Tool
14
19
 
15
20
  from benchmax.envs.types import ToolDefinition
@@ -1,20 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: benchmax
3
- Version: 0.1.2.dev11
3
+ Version: 0.1.2.dev13
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Classifier: Programming Language :: Python :: 3
7
7
  Classifier: Operating System :: OS Independent
8
- Requires-Python: ==3.12.*
8
+ Requires-Python: >=3.12
9
9
  Description-Content-Type: text/markdown
10
10
  License-File: LICENSE
11
11
  Requires-Dist: aiohttp>=3.13.1
12
12
  Requires-Dist: asyncio>=4.0.0
13
13
  Requires-Dist: cloudpickle>=3.0.0
14
14
  Requires-Dist: datasets>=4.0.0
15
- Requires-Dist: fastmcp~=2.12.0
16
- Requires-Dist: pyjwt>=2.10.1
17
- Requires-Dist: skypilot~=0.8.1
15
+ Provides-Extra: mcp
16
+ Requires-Dist: fastmcp~=2.12.0; extra == "mcp"
17
+ Requires-Dist: pyjwt>=2.10.1; extra == "mcp"
18
+ Provides-Extra: skypilot
19
+ Requires-Dist: skypilot[aws,gcp]~=0.8.1; extra == "skypilot"
20
+ Requires-Dist: pip>=25.3; extra == "skypilot"
21
+ Requires-Dist: msrestazure>=0.6.4.post1; extra == "skypilot"
18
22
  Dynamic: license-file
19
23
 
20
24
  <picture>
@@ -2,6 +2,12 @@ aiohttp>=3.13.1
2
2
  asyncio>=4.0.0
3
3
  cloudpickle>=3.0.0
4
4
  datasets>=4.0.0
5
+
6
+ [mcp]
5
7
  fastmcp~=2.12.0
6
8
  pyjwt>=2.10.1
7
- skypilot~=0.8.1
9
+
10
+ [skypilot]
11
+ skypilot[aws,gcp]~=0.8.1
12
+ pip>=25.3
13
+ msrestazure>=0.6.4.post1
File without changes
File without changes
File without changes