benchmax 0.1.1.dev0__tar.gz → 0.1.1.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/PKG-INFO +9 -10
  2. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/adapters/verl/benchmax_data_process.py +7 -6
  3. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/base_env.py +10 -5
  4. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/crm/crm_env.py +6 -5
  5. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/excel/README.md +3 -0
  6. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/excel/excel_env.py +7 -6
  7. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/local_mcp_env.py +8 -6
  8. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/wikipedia/wiki_env.py +5 -5
  9. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/pyproject.toml +8 -20
  10. benchmax-0.1.1.dev0/benchmax/adapters/verifiers/examples/verifiers_crm_example.py +0 -54
  11. benchmax-0.1.1.dev0/benchmax/adapters/verifiers/examples/verifiers_excel_example.py +0 -54
  12. benchmax-0.1.1.dev0/benchmax/adapters/verifiers/examples/verifiers_math_example.py +0 -54
  13. benchmax-0.1.1.dev0/benchmax/adapters/verl/examples/config/benchmax_multiturn_grpo.yaml +0 -21
  14. benchmax-0.1.1.dev0/benchmax/adapters/verl/examples/config/tool_config/benchmax_math_tool_config.yaml +0 -7
  15. benchmax-0.1.1.dev0/benchmax/adapters/verl/examples/run_qwen2.5-3b_benchmax_math.sh +0 -69
  16. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/LICENSE +0 -0
  17. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/README.md +0 -0
  18. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/adapters/__init__.py +0 -0
  19. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/adapters/verifiers/verifiers_adapters.py +0 -0
  20. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/__init__.py +0 -0
  21. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/bounded_dict.py +0 -0
  22. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/crm/README.md +0 -0
  23. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/crm/salesforce_mcp.py +0 -0
  24. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/crm/salesforce_requirements.txt +0 -0
  25. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/excel/data_utils.py +0 -0
  26. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/excel/excel_code_runner_mcp.py +0 -0
  27. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/excel/excel_utils.py +0 -0
  28. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/math/README.md +0 -0
  29. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/math/math_env.py +0 -0
  30. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/types.py +0 -0
  31. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/wikipedia/README.md +0 -0
  32. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/envs/wikipedia/utils.py +0 -0
  33. {benchmax-0.1.1.dev0 → benchmax-0.1.1.dev1}/benchmax/prompts/tools.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: benchmax
3
- Version: 0.1.1.dev0
3
+ Version: 0.1.1.dev1
4
4
  Summary: Framework-Agnostic RL Environments for LLM Fine-Tuning
5
5
  Author: cgft.io
6
6
  Requires-Python: >=3.11,<3.13
@@ -10,17 +10,16 @@ Classifier: Programming Language :: Python :: 3.12
10
10
  Provides-Extra: crm
11
11
  Provides-Extra: excel
12
12
  Provides-Extra: excel-linux
13
- Provides-Extra: full
14
13
  Provides-Extra: verifiers
15
14
  Provides-Extra: verl
16
- Provides-Extra: wikipedia
17
- Requires-Dist: fastmcp (>=2.10.0,<2.11.0) ; extra == "excel-linux" or extra == "excel" or extra == "crm" or extra == "wikipedia" or extra == "full"
18
- Requires-Dist: openpyxl (==3.1.5) ; extra == "excel-linux" or extra == "excel" or extra == "full"
19
- Requires-Dist: python-dateutil (>=2.9.0,<2.10.0) ; extra == "crm" or extra == "full"
20
- Requires-Dist: simple-salesforce (>=1.12.3) ; extra == "crm" or extra == "full"
21
- Requires-Dist: verifiers[train] (>=0.1.1,<0.2.0) ; extra == "verifiers" or extra == "full"
22
- Requires-Dist: verl-cgft (==0.4.1.dev0) ; extra == "verl" or extra == "full"
23
- Requires-Dist: xlwings (==0.33.15) ; extra == "excel" or extra == "full"
15
+ Requires-Dist: fastmcp (>=2.10.0,<2.11.0)
16
+ Requires-Dist: openpyxl (==3.1.5) ; extra == "excel-linux" or extra == "excel"
17
+ Requires-Dist: python-dateutil (>=2.9.0,<2.10.0) ; extra == "crm"
18
+ Requires-Dist: sglang (==0.4.9) ; extra == "verl"
19
+ Requires-Dist: simple-salesforce (>=1.12.3) ; extra == "crm"
20
+ Requires-Dist: verifiers[train] (>=0.1.1,<0.2.0) ; extra == "verifiers"
21
+ Requires-Dist: verl-cgft-fork (==0.4.1.dev1) ; extra == "verl"
22
+ Requires-Dist: xlwings (==0.33.15) ; extra == "excel"
24
23
  Description-Content-Type: text/markdown
25
24
 
26
25
  <picture>
@@ -136,12 +136,13 @@ if __name__ == "__main__":
136
136
  "init_rollout_args"
137
137
  ]
138
138
  }
139
- if example.get("init_rollout_args", None):
140
- extra_info["tools_kwargs"] = {
141
- tool_name: {
142
- "create_kwargs": {**example.get("init_rollout_args", {})},
143
- } for tool_name in tool_names
144
- }
139
+ create_args = example.get("init_rollout_args", {}) or {"dummy": "dummy"}
140
+ extra_info["tools_kwargs"] = {
141
+ tool_name: {
142
+ "create_kwargs": {**create_args},
143
+ } for tool_name in tool_names
144
+ }
145
+
145
146
  example.pop("init_rollout_args")
146
147
  # This extra_info is used to pass additional info during reward computation
147
148
  example["extra_info"] = extra_info
@@ -24,11 +24,16 @@ class BaseEnv(ABC):
24
24
  - "ground_truth": Any
25
25
  - "init_rollout_args": Optional[Dict[str, Any]]
26
26
  """
27
- return {
28
- "prompt": example.get("prompt", ""),
29
- "ground_truth": example.get("ground_truth", None),
30
- "init_rollout_args": example.get("init_rollout_args", {})
31
- }
27
+ prompt = example.pop("prompt", "")
28
+ ground_truth = example.pop("ground_truth", "")
29
+ init_rollout_args = example.pop("init_rollout_args", "")
30
+ return StandardizedExample(
31
+ prompt=prompt,
32
+ ground_truth=ground_truth,
33
+ init_rollout_args=init_rollout_args,
34
+ **example,
35
+ )
36
+
32
37
  @classmethod
33
38
  def load_dataset(
34
39
  cls, dataset_name: str, **kwargs
@@ -245,8 +245,9 @@ class CRMEnv(LocalMCPEnv):
245
245
  if metadata and "required" in metadata:
246
246
  required_metadata = metadata["required"]
247
247
  prompt = f"{persona}\n{task}\n{required_metadata}\n{query}"
248
- return {
249
- "prompt": prompt,
250
- "ground_truth": answer,
251
- "init_rollout_args": None
252
- }
248
+
249
+ return StandardizedExample(
250
+ prompt=prompt,
251
+ ground_truth=answer,
252
+ init_rollout_args={}
253
+ )
@@ -8,6 +8,9 @@ This is based off the [SpreadsheetBench Benchmark](https://spreadsheetbench.gith
8
8
 
9
9
  **Important**: Before using this environment, ensure you have the appropriate spreadsheet application installed:
10
10
  - **Linux**: LibreOffice must be installed
11
+ ```bash
12
+ sudo apt install libreoffice
13
+ ```
11
14
  - **Windows/macOS**: Microsoft Excel must be installed
12
15
 
13
16
  ## Installation
@@ -121,14 +121,15 @@ class ExcelEnv(LocalMCPEnv):
121
121
  Output Path: {target_output_path}
122
122
  """
123
123
 
124
- return {
125
- "prompt": prompt.strip(),
126
- "ground_truth": "",
127
- "init_rollout_args": {
124
+ return StandardizedExample(
125
+ prompt=prompt.strip(),
126
+ ground_truth="",
127
+ init_rollout_args={
128
128
  "spreadsheet_path": str(source_input_path),
129
129
  "answer_spreadsheet_path": str(Path(spreadsheet_path) / target_answer_path),
130
- }
131
- }
130
+ },
131
+ **example
132
+ )
132
133
 
133
134
  def init_rollout(self, rollout_id: str, **rollout_args):
134
135
  if "spreadsheet_path" not in rollout_args:
@@ -4,6 +4,8 @@ from pathlib import Path
4
4
  import asyncio
5
5
  import json
6
6
  from threading import Thread
7
+ import uuid
8
+
7
9
  from fastmcp import Client as FastMCPClient
8
10
  from fastmcp.exceptions import ToolError
9
11
  from mcp import Tool
@@ -58,8 +60,7 @@ class LocalMCPEnv(BaseEnv):
58
60
  self._output_parsers: Dict[str, Callable[[str], Any]] = {}
59
61
  self._workspace_dir = workspace_dir or Path("workspaces")
60
62
  self._workspace_dir.mkdir(parents=True, exist_ok=True)
61
-
62
- self._counter = 0 # Counter for workspace naming
63
+
63
64
  self._pre_warmed_pool: List[ClientWorkspacePair] = [] # Available pre-initialized pairs
64
65
  self._active_clients: BoundedDict[str, ClientWorkspacePair] = BoundedDict(10000) # rollout_id -> pair mapping
65
66
  self._tool_definitions: Optional[List[ToolDefinition]] = None
@@ -219,12 +220,14 @@ class LocalMCPEnv(BaseEnv):
219
220
  else:
220
221
  self._active_clients.pop(rollout_id)
221
222
 
222
- def get_rollout_workspace(self, rollout_id: str) -> Path:
223
+ def get_rollout_workspace(self, rollout_id: str, strict_check: bool = False) -> Path:
223
224
  """Get dedicated workspace path for a rollout"""
224
225
  if rollout_id in self._active_clients:
225
226
  return self._active_clients[rollout_id].workspace
226
- else:
227
+ if strict_check:
227
228
  raise ValueError(f"No active client found for rollout {rollout_id}")
229
+ else:
230
+ return Path()
228
231
 
229
232
  # ---- Private Helper Methods ----
230
233
 
@@ -267,8 +270,7 @@ class LocalMCPEnv(BaseEnv):
267
270
 
268
271
  async def _create_client_workspace(self) -> ClientWorkspacePair:
269
272
  """Create a new FastMCP client with a unique workspace"""
270
- workspace = Path(self._workspace_dir) / f"{self._counter}"
271
- self._counter += 1
273
+ workspace = self._workspace_dir / uuid.uuid4().hex
272
274
  workspace.mkdir(parents=True, exist_ok=True)
273
275
  config = self._prepare_config(workspace)
274
276
 
@@ -206,11 +206,11 @@ class WikipediaEnv(BaseEnv):
206
206
  return tool_function(**tool_args)
207
207
 
208
208
  def dataset_preprocess(self, example: Any) -> StandardizedExample:
209
- return {
210
- "prompt": example.get("Question", ""),
211
- "ground_truth": example.get("Answer", None),
212
- "init_rollout_args": {}
213
- }
209
+ return StandardizedExample(
210
+ prompt=example.get("Question", ""),
211
+ ground_truth=example.get("Answer", None),
212
+ init_rollout_args={}
213
+ )
214
214
 
215
215
  def init_rollout(self, rollout_id: str, **rollout_args) -> None:
216
216
  return super().init_rollout(rollout_id, **rollout_args)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "benchmax"
3
- version = "0.1.1.dev0"
3
+ version = "0.1.1.dev1"
4
4
  description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5
5
  authors = ["cgft.io"]
6
6
  readme = "README.md"
@@ -10,12 +10,12 @@ packages = [
10
10
 
11
11
  [tool.poetry.dependencies]
12
12
  python = ">=3.11,<3.13"
13
+ fastmcp = "~2.10.0"
13
14
 
14
- verl-cgft = { version = "0.4.1.dev0", optional = true }
15
+ verl-cgft-fork = { version = "0.4.1.dev1", optional = true }
16
+ sglang = { version = "0.4.9", optional = true }
15
17
  verifiers = { version = "^0.1.1", optional = true, extras = ["train"] }
16
18
 
17
- fastmcp = { version = "~2.10.0", optional = true }
18
-
19
19
  openpyxl = { version = "3.1.5", optional = true }
20
20
  xlwings = { version = "0.33.15", optional = true }
21
21
 
@@ -28,24 +28,12 @@ pytest = "^8.4.1"
28
28
  [tool.poetry.extras]
29
29
  # Independent feature extras
30
30
  verifiers = ["verifiers"]
31
- verl = ["verl-cgft"]
31
+ verl = ["verl-cgft-fork", "sglang"]
32
32
 
33
33
  # Environment-specific
34
- excel-linux = ["openpyxl", "fastmcp"]
35
- excel = ["openpyxl", "xlwings", "fastmcp"]
36
- crm = ["simple-salesforce", "python-dateutil", "fastmcp"]
37
- wikipedia = ["fastmcp"]
38
-
39
- # Everything
40
- full = [
41
- "verl-cgft",
42
- "verifiers",
43
- "openpyxl",
44
- "xlwings",
45
- "simple-salesforce",
46
- "python-dateutil",
47
- "fastmcp"
48
- ]
34
+ excel-linux = ["openpyxl"]
35
+ excel = ["openpyxl", "xlwings"]
36
+ crm = ["simple-salesforce", "python-dateutil"]
49
37
 
50
38
  [build-system]
51
39
  requires = ["poetry-core"]
@@ -1,54 +0,0 @@
1
- import verifiers as vf
2
-
3
- from benchmax.adapters.verifiers.verifiers_adapters import get_verifiers_environment
4
- from benchmax.envs.crm.crm_env import CRMEnv
5
-
6
-
7
- """
8
- Multi-GPU training (single node, 3 training + 1 inference)
9
-
10
- CUDA_VISIBLE_DEVICES=0 poetry run vf-vllm --model willcb/Qwen3-4B
11
-
12
- CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch benchmax/adapters/verifiers/examples/verifiers_crm_example.py
13
- """
14
-
15
- dataset, _ = CRMEnv.load_dataset("Salesforce/CRMArenaPro", name="CRMArenaPro", split="b2b")
16
- benchmax_env = CRMEnv(pool_size=3)
17
- dataset = dataset.map(
18
- lambda example: benchmax_env.dataset_preprocess(example),
19
- )
20
- splits = dataset.train_test_split(test_size=0.1, seed=42)
21
-
22
- train_ds = splits["train"]
23
-
24
- vf_env = get_verifiers_environment(
25
- benchmax_env,
26
- max_concurrent=3,
27
- max_turns=3,
28
- dataset=train_ds,
29
- )
30
-
31
- model_name = "willcb/Qwen3-4B"
32
- model, tokenizer = vf.get_model_and_tokenizer(model_name)
33
- run_name = "verifiers-excel" + model_name.split("/")[-1].lower()
34
-
35
- training_args=vf.grpo_defaults(run_name=run_name)
36
- training_args.per_device_train_batch_size=2
37
- training_args.num_generations=12
38
- training_args.gradient_accumulation_steps=2
39
- training_args.num_iterations=1
40
- training_args.num_train_epochs=5
41
- training_args.max_prompt_length=10000
42
- training_args.max_completion_length=4096
43
- training_args.max_steps=500
44
- training_args.save_steps=100
45
- training_args.report_to = "none"
46
- training_args.log_completions = False
47
-
48
- trainer = vf.GRPOTrainer(
49
- model=model,
50
- processing_class=tokenizer,
51
- env=vf_env,
52
- args=training_args,
53
- )
54
- trainer.train()
@@ -1,54 +0,0 @@
1
- import verifiers as vf
2
-
3
- from benchmax.adapters.verifiers.verifiers_adapters import get_verifiers_environment
4
- from benchmax.envs.excel.excel_env import ExcelEnv
5
-
6
-
7
- """
8
- Multi-GPU training (single node, 3 training + 1 inference)
9
-
10
- CUDA_VISIBLE_DEVICES=0 poetry run vf-vllm --model willcb/Qwen3-4B
11
-
12
- CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch benchmax/adapters/verifiers/examples/verifiers_excel_example.py
13
- """
14
-
15
- dataset, dataset_path = ExcelEnv.load_dataset()
16
- mcp_benchmax_env = ExcelEnv(dataset_path=dataset_path, pool_size=3)
17
- dataset = dataset.map(
18
- lambda example: mcp_benchmax_env.dataset_preprocess(example),
19
- )
20
- splits = dataset.train_test_split(test_size=0.1, seed=42)
21
-
22
- train_ds = splits["train"]
23
-
24
- vf_env = get_verifiers_environment(
25
- mcp_benchmax_env,
26
- max_concurrent=3,
27
- max_turns=3,
28
- dataset=train_ds,
29
- )
30
-
31
- model_name = "willcb/Qwen3-4B"
32
- model, tokenizer = vf.get_model_and_tokenizer(model_name)
33
- run_name = "verifiers-excel" + model_name.split("/")[-1].lower()
34
-
35
- training_args=vf.grpo_defaults(run_name=run_name)
36
- training_args.per_device_train_batch_size=2
37
- training_args.num_generations=12
38
- training_args.gradient_accumulation_steps=2
39
- training_args.num_iterations=1
40
- training_args.num_train_epochs=5
41
- training_args.max_prompt_length=10000
42
- training_args.max_completion_length=4096
43
- training_args.max_steps=500
44
- training_args.save_steps=100
45
- training_args.report_to = "none"
46
- training_args.log_completions = False
47
-
48
- trainer = vf.GRPOTrainer(
49
- model=model,
50
- processing_class=tokenizer,
51
- env=vf_env,
52
- args=training_args,
53
- )
54
- trainer.train()
@@ -1,54 +0,0 @@
1
- import verifiers as vf
2
-
3
- from datasets import load_dataset
4
-
5
- from benchmax.adapters.verifiers.verifiers_adapters import get_verifiers_environment
6
- from benchmax.envs.math.math_env import MathEnv
7
-
8
-
9
- """
10
- Multi-GPU training (single node, 3 training + 1 inference)
11
-
12
- CUDA_VISIBLE_DEVICES=0 poetry run vf-vllm --model willcb/Qwen3-4B
13
-
14
- CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch benchmax/adapters/verifiers/examples/verifiers_math_example.py
15
- """
16
-
17
- math_env = MathEnv()
18
- dataset, _ = MathEnv.load_dataset("dawidmt/arithmetic50", split="test")
19
- dataset = dataset.map(
20
- lambda example: math_env.dataset_preprocess(example),
21
- )
22
- splits = dataset.train_test_split(test_size=0.1, seed=42)
23
- train_ds = splits["train"]
24
-
25
- vf_env = get_verifiers_environment(
26
- math_env,
27
- max_concurrent=3,
28
- max_turns=3,
29
- dataset=train_ds,
30
- )
31
-
32
- model_name = "willcb/Qwen3-4B"
33
- model, tokenizer = vf.get_model_and_tokenizer(model_name)
34
- run_name = "math-grpo" + model_name.split("/")[-1].lower()
35
-
36
- training_args=vf.grpo_defaults(run_name=run_name)
37
- training_args.per_device_train_batch_size=6
38
- training_args.num_generations=12
39
- training_args.gradient_accumulation_steps=2
40
- training_args.num_iterations=1
41
- training_args.num_train_epochs=5
42
- training_args.max_prompt_length=1024
43
- training_args.max_completion_length=4096
44
- training_args.max_steps=500
45
- training_args.save_steps=100
46
- training_args.report_to = "none"
47
-
48
- trainer = vf.GRPOTrainer(
49
- model=model,
50
- processing_class=tokenizer,
51
- env=vf_env,
52
- args=training_args,
53
- )
54
- trainer.train()
@@ -1,21 +0,0 @@
1
- hydra:
2
- searchpath:
3
- - pkg://verl.trainer.config
4
-
5
- defaults:
6
- - ppo_trainer
7
- - _self_
8
-
9
- data:
10
- max_prompt_length: 1024
11
- max_response_length: 1024
12
- train_batch_size: 256
13
- return_raw_chat: True
14
-
15
- actor_rollout_ref:
16
- hybrid_engine: True
17
- rollout:
18
- name: sglang
19
- multi_turn:
20
- enable: True
21
- max_assistant_turns: 5
@@ -1,7 +0,0 @@
1
- # Example Benchmax Tool Config
2
- tools:
3
- # Class name points to benchmax class. This is expected to be a subclass of benchmax.envs.BaseEnv
4
- - class_name: benchmax.envs.math.math_env.MathEnv
5
- config:
6
- type: benchmax
7
- # Specify initialization args for Sandbox here e.g. api_keys
@@ -1,69 +0,0 @@
1
- # make sure your current working directory is the root of the project
2
- # Specifically note the last 3 lines
3
- # The first line points to tool config, which is necessary for initializing tools from the benchmax environment
4
- # The second and third lines point to the relevant benchmax environment to initialize the rewards from
5
-
6
- set -x
7
-
8
- ulimit -n 65535
9
-
10
- PROJECT_DIR="$(pwd)"
11
- CONFIG_PATH="$PROJECT_DIR/benchmax/adapters/verl/examples/config"
12
- CONFIG_NAME="benchmax_multiturn_grpo"
13
-
14
-
15
- TRAIN_DATA="~/data/math/train.parquet"
16
- VAL_DATA="~/data/math/test.parquet"
17
-
18
- TOOL_CONFIG="$CONFIG_PATH/tool_config/benchmax_math_tool_config.yaml"
19
- BENCHMAX_CLASS_NAME="benchmax.envs.math.math_env.MathEnv"
20
-
21
- PYTHONPATH="$PYTHONPATH:$(pwd)" python -m verl.trainer.main_ppo \
22
- --config-path="$CONFIG_PATH" \
23
- --config-name="$CONFIG_NAME" \
24
- algorithm.adv_estimator=grpo \
25
- data.train_batch_size=4 \
26
- data.val_batch_size=4 \
27
- data.max_prompt_length=4096 \
28
- data.max_response_length=3000 \
29
- data.filter_overlong_prompts=True \
30
- data.truncation='error' \
31
- data.return_raw_chat=True \
32
- actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
33
- actor_rollout_ref.actor.optim.lr=1e-6 \
34
- actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
35
- actor_rollout_ref.model.use_remove_padding=True \
36
- actor_rollout_ref.actor.ppo_mini_batch_size=4 \
37
- actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
38
- actor_rollout_ref.actor.use_kl_loss=True \
39
- actor_rollout_ref.actor.kl_loss_coef=0.001 \
40
- actor_rollout_ref.actor.kl_loss_type=low_var_kl \
41
- actor_rollout_ref.actor.entropy_coeff=0 \
42
- actor_rollout_ref.model.enable_gradient_checkpointing=True \
43
- actor_rollout_ref.actor.fsdp_config.param_offload=False \
44
- actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
45
- actor_rollout_ref.rollout.max_model_len=15000 \
46
- actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
47
- actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
48
- actor_rollout_ref.rollout.name=sglang \
49
- actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
50
- actor_rollout_ref.rollout.n=5 \
51
- actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \
52
- actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
53
- actor_rollout_ref.ref.fsdp_config.param_offload=True \
54
- algorithm.use_kl_in_reward=False \
55
- trainer.critic_warmup=0 \
56
- trainer.val_before_train=False \
57
- trainer.logger=['console','wandb'] \
58
- trainer.project_name='wiki_search' \
59
- trainer.experiment_name='qwen2.5-3b-instruct_wiki_search' \
60
- trainer.n_gpus_per_node=4 \
61
- trainer.nnodes=1 \
62
- trainer.save_freq=100 \
63
- trainer.test_freq=50 \
64
- data.train_files="$TRAIN_DATA" \
65
- data.val_files="$VAL_DATA" \
66
- trainer.total_epochs=1 $@ \
67
- actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \
68
- reward_model.reward_manager=benchmax \
69
- +reward_model.reward_kwargs.benchmax_cls_name="$BENCHMAX_CLASS_NAME"
File without changes
File without changes