arbor-ai 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -199,10 +199,16 @@ class GRPOConfigRequest(BaseModel):
     bf16: Optional[bool] = None
     scale_rewards: Optional[bool] = None
     max_grad_norm: Optional[float] = None
+    report_to: Optional[str] = None
+    log_completions: Optional[bool] = None
+    logging_steps: Optional[int] = None
+    mask_truncated_completions: Optional[bool] = None
+    # Arbor specific
+    max_context_length: Optional[int] = None
     lora: Optional[bool] = None
-    update_interval: Optional[int] = None
     # To name the run
     suffix: Optional[str] = None
+    generation_batch_size: Optional[int] = None


 class GRPOConfigResponse(BaseModel):
@@ -216,8 +222,23 @@ class GRPOTerminateRequest(BaseModel):
 class GRPOTerminateResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: Optional[dict[str, str]] = None
+    last_checkpoint: Optional[str] = None


 class GRPOStepResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: Optional[str] = None
+
+
+class GRPOCheckpointRequest(BaseModel):
+    checkpoint_name: str
+
+
+class GRPOCheckpointResponse(BaseModel):
+    status: str
+    current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: str
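
The new checkpoint schemas above are plain pydantic models. A minimal sketch of how they fit together; the field values here are illustrative and not taken from the package:

# Illustrative only: constructing the checkpoint schemas added in 0.1.13.
# The import path is confirmed by the routes diff below; the paths/names are made up.
from arbor.server.api.models.schemas import GRPOCheckpointRequest, GRPOCheckpointResponse

req = GRPOCheckpointRequest(checkpoint_name="step_100")
resp = GRPOCheckpointResponse(
    status="success",
    current_model="./outputs/run_1",  # hypothetical model path
    checkpoints={"step_100": "./outputs/run_1/checkpoints/step_100/"},  # hypothetical
    last_checkpoint="step_100",
)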
@@ -4,6 +4,8 @@ import subprocess
 from fastapi import APIRouter, BackgroundTasks, Request

 from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOCheckpointResponse,
     GRPOConfigRequest,
     GRPOConfigResponse,
     GRPORequest,
@@ -31,17 +33,24 @@ def run_grpo_step(
     inference_manager = request.app.state.inference_manager
     grpo_manager = request.app.state.grpo_manager

-    current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+    step_data = grpo_manager.grpo_step(grpo_request, inference_manager)

-    return GRPOStepResponse(status="success", current_model=current_model)
+    return GRPOStepResponse(status="success", **step_data)


 @router.post("/update_model", response_model=GRPOStepResponse)
 def update_model(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
-    current_model = grpo_manager.update_model(request, inference_manager)
-    return GRPOStepResponse(status="success", current_model=current_model)
+    update_model_data = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", **update_model_data)
+
+
+@router.post("/checkpoint", response_model=GRPOCheckpointResponse)
+def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
+    grpo_manager = request.app.state.grpo_manager
+    checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+    return GRPOCheckpointResponse(status="success", **checkpoint_data)


 @router.post("/terminate", response_model=GRPOTerminateResponse)
@@ -50,5 +59,5 @@ def terminate_grpo(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager

-    final_model = grpo_manager.terminate(inference_manager)
-    return GRPOTerminateResponse(status="success", current_model=final_model)
+    terminate_data = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", **terminate_data)
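
With these routes a client can request a named checkpoint between GRPO steps, and every step/update/terminate response now carries the checkpoint map. A hedged sketch of calling the new endpoint; the host, port, and router prefix are assumptions, only the /checkpoint path and payload come from the diff:

import requests

# Assumed server address and prefix; adjust to wherever the grpo router is mounted
# in your Arbor deployment.
BASE = "http://localhost:8000/v1/fine_tuning/grpo"

resp = requests.post(f"{BASE}/checkpoint", json={"checkpoint_name": "step_100"})
data = resp.json()
print(data["status"], data["last_checkpoint"], data["checkpoints"])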
@@ -13,7 +13,11 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional

-from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest
+from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOConfigRequest,
+    GRPORequest,
+)
 from arbor.server.core.config import Settings
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
@@ -28,7 +32,10 @@ class GRPOManager:
         self.server_comms_handler = None
         self.status_thread = None
         self.model_saved_and_reload_requested = False
+        self.saving_checkpoint = False

+        self.checkpoints = {}
+        self.last_checkpoint = None
         self.data_count = 0
         self.last_inference_update = 0
         # Set up signal handler
@@ -86,12 +93,17 @@ class GRPOManager:
             "bf16",
             "scale_rewards",
             "max_grad_norm",
+            "report_to",
+            "log_completions",
+            "logging_steps",
+            "generation_batch_size",
+            "mask_truncated_completions",
         ]
         trl_train_kwargs = {
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }

-        arbor_keys = ["update_interval", "lora"]
+        arbor_keys = ["max_context_length", "lora"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -119,6 +131,8 @@ class GRPOManager:
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        # WandB can block the training process for login, so we silence it
+        my_env["WANDB_SILENT"] = "true"

         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1

@@ -209,6 +223,12 @@ class GRPOManager:

         # Launch the inference server
         print("Launching inference server...")
+        # launch_kwargs = {
+        #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
+        # }
+        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
+            "max_context_length", None
+        )
         inference_manager.launch(self.current_model)

     def _handle_status_updates(self, inference_manager: InferenceManager):
@@ -228,6 +248,12 @@ class GRPOManager:
                     self.model_saved_and_reload_requested = False
                     self.current_model = status["output_dir"]
                     print("Model update complete")
+                elif status["status"] == "checkpoint_saved":
+                    print("Received checkpoint saved status")
+                    self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
+                    self.last_checkpoint = status["checkpoint_name"]
+                    self.saving_checkpoint = False
+                    print("Checkpoint saved")
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
@@ -249,6 +275,10 @@ class GRPOManager:
             )
             time.sleep(5)

+        while self.saving_checkpoint:
+            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            time.sleep(5)
+
         try:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
@@ -256,12 +286,11 @@ class GRPOManager:
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")

-        # We tell the script to save the model. The script will let us know when it's done via the status update handler
-        # Then we'll actually run the update_model function in the inference manager and finally update the last_inference_update variable
-        # if self._should_update_model():
-        #     self.server_comms_handler.send_command({"command": "save_model"})
-
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }

     def update_model(self, request, inference_manager: InferenceManager):
         if inference_manager._session:
@@ -286,18 +315,41 @@ class GRPOManager:
                 "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
             )
             time.sleep(5)
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
+
+    def checkpoint(self, request: GRPOCheckpointRequest):
+        self.saving_checkpoint = True
+        self.server_comms_handler.send_command(
+            {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
+        )
+        while self.saving_checkpoint:
+            print("Waiting for checkpoint to be saved...")
+            time.sleep(5)
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }

     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        termination_data = {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
         try:
             # Stop the inference server
             if inference_manager.process is not None:
                 inference_manager.kill()

             # Send termination command through REQ socket
-            # self.server_comms_handler.send_broadcast({"message": "terminate"})
-            self.training_process.terminate()
+            self.server_comms_handler.send_broadcast({"message": "terminate"})
+            # self.training_process.terminate()
             print("Waiting for training process to finish")

             # Wait for training process to finish
@@ -336,17 +388,13 @@ class GRPOManager:
             )
             output_dir = self.train_kwargs["output_dir"]
             self.train_kwargs = None
-            return output_dir
         else:
             print("Training terminated, no output directory specified")
             self.train_kwargs = None
-            return None
+
+        return termination_data

     def _should_update_model(self):
-        # return (
-        #     self.data_count - self.last_inference_update
-        #     >= self.train_kwargs["update_interval"]
-        # )
         return self.model_saved_and_reload_requested


@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import random
 import signal
 import socket
 import subprocess
@@ -47,7 +48,12 @@ class InferenceManager:
     def is_server_restarting(self):
         return self.restarting

-    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
+    def launch(
+        self,
+        model: str,
+        launch_kwargs: Optional[Dict[str, Any]] = None,
+        max_retries: int = 3,
+    ):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -59,77 +65,112 @@ class InferenceManager:
         if model.startswith(prefix):
             model = model[len(prefix) :]

-        print(f"Grabbing a free port to launch an SGLang server for model {model}")
-        port = get_free_port()
-        timeout = launch_kwargs.get("timeout", 1800)
-        my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
-        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-        # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
-        command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
-        print(f"Running command: {command}")
-
-        # We will manually stream & capture logs.
-        process = subprocess.Popen(
-            command.replace("\\\n", " ").replace("\\", " ").split(),
-            text=True,
-            stdout=subprocess.PIPE,  # We'll read from pipe
-            stderr=subprocess.STDOUT,  # Merge stderr into stdout
-            env=my_env,
-        )
-
-        # A threading.Event to control printing after the server is ready.
-        # This will store *all* lines (both before and after readiness).
-        print(f"SGLang server process started with PID {process.pid}.")
-        stop_printing_event = threading.Event()
-        logs_buffer = []
-
-        def _tail_process(proc, buffer, stop_event):
-            while True:
-                line = proc.stdout.readline()
-                if not line and proc.poll() is not None:
-                    # Process ended and no new line
-                    break
-                if line:
-                    buffer.append(line)
-                    # Print only if stop_event is not set
-                    if not stop_event.is_set():
-                        print(f"[SGLang LOG] {line}", end="")
-
-        # Start a background thread to read from the process continuously
-        thread = threading.Thread(
-            target=_tail_process,
-            args=(process, logs_buffer, stop_printing_event),
-            daemon=True,
-        )
-        thread.start()
-
-        # Wait until the server is ready (or times out)
-        base_url = f"http://localhost:{port}"
-        try:
-            wait_for_server(base_url, timeout=timeout)
-        except TimeoutError:
-            # If the server doesn't come up, we might want to kill it:
-            process.kill()
-            raise
-
-        # Once server is ready, we tell the thread to stop printing further lines.
-        stop_printing_event.set()
-
-        # A convenience getter so the caller can see all logs so far (and future).
-        def get_logs() -> str:
-            # Join them all into a single string, or you might return a list
-            return "".join(logs_buffer)
-
-        # Let the user know server is up
-        print(f"Server ready on random port {port}!")
+        retries = 0
+        while retries < max_retries:
+            try:
+                print(
+                    f"Attempt {retries + 1} of {max_retries} to launch server for model {model}"
+                )
+                print(
+                    f"Grabbing a free port to launch an SGLang server for model {model}"
+                )
+                port = get_free_port()
+                timeout = launch_kwargs.get("timeout", 1800)
+                my_env = os.environ.copy()
+                my_env["CUDA_VISIBLE_DEVICES"] = (
+                    self.settings.arbor_config.inference.gpu_ids
+                )
+                n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+                # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
+                command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
+                print(f"Running command: {command}")
+                if launch_kwargs.get("max_context_length"):
+                    command += (
+                        f" --context-length {launch_kwargs['max_context_length']}"
+                    )

-        self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
-        self.launch_kwargs["api_key"] = "local"
-        self.get_logs = get_logs
-        self.process = process
-        self.thread = thread
-        self.current_model = model
+                # We will manually stream & capture logs.
+                process = subprocess.Popen(
+                    command.replace("\\\n", " ").replace("\\", " ").split(),
+                    text=True,
+                    stdout=subprocess.PIPE,  # We'll read from pipe
+                    stderr=subprocess.STDOUT,  # Merge stderr into stdout
+                    env=my_env,
+                )
+
+                # A threading.Event to control printing after the server is ready.
+                # This will store *all* lines (both before and after readiness).
+                print(f"SGLang server process started with PID {process.pid}.")
+                stop_printing_event = threading.Event()
+                logs_buffer = []
+
+                def _tail_process(proc, buffer, stop_event):
+                    while True:
+                        line = proc.stdout.readline()
+                        if not line and proc.poll() is not None:
+                            # Process ended and no new line
+                            break
+                        if line:
+                            buffer.append(line)
+                            # Print only if stop_event is not set
+                            if not stop_event.is_set():
+                                print(f"[SGLang LOG] {line}", end="")
+
+                # Start a background thread to read from the process continuously
+                thread = threading.Thread(
+                    target=_tail_process,
+                    args=(process, logs_buffer, stop_printing_event),
+                    daemon=True,
+                )
+                thread.start()
+
+                # Wait until the server is ready (or times out)
+                base_url = f"http://localhost:{port}"
+                try:
+                    wait_for_server(base_url, timeout=timeout)
+                except TimeoutError:
+                    # If the server doesn't come up, we might want to kill it:
+                    process.kill()
+                    raise
+
+                # Once server is ready, we tell the thread to stop printing further lines.
+                stop_printing_event.set()
+
+                # A convenience getter so the caller can see all logs so far (and future).
+                def get_logs() -> str:
+                    # Join them all into a single string, or you might return a list
+                    return "".join(logs_buffer)
+
+                # Let the user know server is up
+                print(f"Server ready on random port {port}!")
+
+                self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+                self.launch_kwargs["api_key"] = "local"
+                self.get_logs = get_logs
+                self.process = process
+                self.thread = thread
+                self.current_model = model
+
+                # If we get here, the launch was successful
+                return
+
+            except Exception as e:
+                retries += 1
+                print(
+                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
+                )
+                # Clean up any failed processes
+                if "process" in locals():
+                    try:
+                        process.kill()
+                    except:
+                        pass
+                if retries == max_retries:
+                    raise Exception(
+                        f"Failed to launch server after {max_retries} attempts"
+                    ) from e
+                # Wait a bit before retrying
+                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds

     def kill(self):
         from sglang.utils import terminate_process
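
The retry loop added above backs off exponentially between failed launch attempts, capped at 30 seconds. A quick illustration of the resulting wait times (plain arithmetic, not code from the package):

# Backoff used after attempt number `retries` fails: min(2**retries, 30) seconds
for retries in range(1, 6):
    print(retries, min(2**retries, 30))
# 1 -> 2s, 2 -> 4s, 3 -> 8s, 4 -> 16s, 5 -> 30s (capped); with the default
# max_retries=3 only the first two sleeps are ever reached.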
@@ -184,7 +225,7 @@ class InferenceManager:
         print(f"Running inference for model {model}")
         # Monkeypatch:
         if model != self.current_model:
-            print(f"MONKEYPATCH: Model changed from {model} to {self.current_model}")
+            print(f"Model changed from {model} to {self.current_model}")
             model = self.current_model
             request_json["model"] = model

@@ -214,6 +255,12 @@ class InferenceManager:
             await self._session.close()
             self._session = None
             return None
+        except json.decoder.JSONDecodeError:
+            print(f"JSON Decode Error during inference: {content}")
+            return {
+                "error": "JSON Decode Error",
+                "content": content if content else "Content is null",
+            }
         except Exception as e:
             print(f"Error during inference: {e}")
             raise
@@ -241,6 +288,7 @@ class InferenceManager:
         tik = time.time()
         self.kill()
         print("Just killed server")
+        time.sleep(5)
         # Check that output directory exists and was created successfully
         print(f"Checking that output directory {output_dir} exists")
         if not os.path.exists(output_dir):
@@ -14,7 +14,7 @@ from typing import Any, List, Optional, Union
 import torch
 import zmq
 from accelerate import Accelerator
-from accelerate.utils import gather
+from accelerate.utils import broadcast_object_list, gather, gather_object
 from datasets import Dataset, IterableDataset, load_dataset
 from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig  # type: ignore
 from torch.utils.data import Dataset
@@ -23,6 +23,7 @@ from transformers import (
     PreTrainedTokenizerBase,
     Trainer,
     TrainerCallback,
+    is_wandb_available,
 )
 from trl import GRPOConfig, GRPOTrainer
 from trl.data_utils import maybe_apply_chat_template
@@ -32,6 +33,9 @@ from arbor.server.services.comms.comms import (
     ArborServerCommsHandler,
 )

+if is_wandb_available():
+    import wandb
+
 last_step_time = None
 last_queue_pop_time = None

@@ -65,8 +69,9 @@ class ArborGRPOTrainer(GRPOTrainer):
         ] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
         comms_handler: Optional[ArborScriptCommsHandler] = None,
-        update_interval: Optional[int] = 5,
         lora: Optional[bool] = False,
+        # We do nothing with max_context_length right now
+        max_context_length: Optional[int] = None,
         **kwargs,
     ):

@@ -85,12 +90,12 @@ class ArborGRPOTrainer(GRPOTrainer):
         self.peft_config = peft_config
         self.scale_rewards = scale_rewards
         self.comms_handler = comms_handler
-        self.update_interval = update_interval

     def _generate_and_score_completions(
         self, batch: List[dict[str, Any]]
     ) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"

         # Process prompts and completions
         prompt_completion_texts = []
@@ -106,12 +111,12 @@ class ArborGRPOTrainer(GRPOTrainer):
             )

         # Tokenize prompts
-        prompt_texts = [
+        prompts_text = [
             prompt_completion_text["prompt"]
             for prompt_completion_text in prompt_completion_texts
         ]
         prompt_inputs = self.processing_class(
-            prompt_texts,
+            prompts_text,
             return_tensors="pt",
             padding=True,
             padding_side="left",
@@ -124,12 +129,12 @@ class ArborGRPOTrainer(GRPOTrainer):
             )

         # Tokenize completions
-        completion_texts = [
+        completions_text = [
             prompt_completion_text["completion"]
             for prompt_completion_text in prompt_completion_texts
         ]
         completion_ids = self.processing_class(
-            completion_texts,
+            completions_text,
             return_tensors="pt",
             padding=True,
             add_special_tokens=False,
@@ -156,6 +161,30 @@ class ArborGRPOTrainer(GRPOTrainer):
         # self._move_model_to_vllm()
         # self._last_loaded_step = self.state.global_step

+        prompt_ids = broadcast_object_list(prompt_ids)
+        prompt_mask = broadcast_object_list(prompt_mask)
+        completion_ids = broadcast_object_list(completion_ids)
+        completion_mask = broadcast_object_list(completion_mask)
+
+        process_slice = slice(
+            self.accelerator.process_index * len(batch),
+            (self.accelerator.process_index + 1) * len(batch),
+        )
+
+        prompt_ids = prompt_ids[process_slice]
+        prompt_mask = prompt_mask[process_slice]
+        completion_ids = completion_ids[process_slice]
+        completion_mask = completion_mask[process_slice]
+
+        is_eos = completion_ids == self.processing_class.eos_token_id
+
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            truncated_completions = ~is_eos.any(dim=1)
+            completion_mask = (
+                completion_mask * (~truncated_completions).unsqueeze(1).int()
+            )
+
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

@@ -164,34 +193,29 @@ class ArborGRPOTrainer(GRPOTrainer):
         )

         logits_to_keep = completion_ids.size(1)
+        batch_size = (
+            self.args.per_device_train_batch_size
+            if mode == "train"
+            else self.args.per_device_eval_batch_size
+        )

         with torch.no_grad():
             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
             # computation here, and use per_token_logps.detach() instead.
-            if self.num_iterations > 1:
+            if (
+                self.num_iterations > 1
+                or self.args.steps_per_generation
+                > self.args.gradient_accumulation_steps
+            ):
                 old_per_token_logps = self._get_per_token_logps(
-                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
-                )
-            else:
-                old_per_token_logps = None
-
-            if self.beta == 0.0:
-                ref_per_token_logps = None
-            elif self.ref_model is not None:
-                ref_per_token_logps = self._get_per_token_logps(
-                    self.ref_model,
+                    self.model,
                     prompt_completion_ids,
                     attention_mask,
                     logits_to_keep,
+                    batch_size,
                 )
             else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    ref_per_token_logps = self._get_per_token_logps(
-                        self.model,
-                        prompt_completion_ids,
-                        attention_mask,
-                        logits_to_keep,
-                    )
+                old_per_token_logps = None

         rewards = torch.tensor(
             [example["reward"] for example in batch], dtype=torch.float32
@@ -219,7 +243,56 @@ class ArborGRPOTrainer(GRPOTrainer):
         )
         advantages = advantages[process_slice]

-        ## Logged Metrics Removed Here
+        # Log the metrics
+        if mode == "train":
+            self.state.num_input_tokens_seen += (
+                self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+            )
+        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
+
+        # log completion lengths, mean, min, max
+        agg_completion_mask = self.accelerator.gather_for_metrics(
+            completion_mask.sum(1)
+        )
+        self._metrics[mode]["completions/mean_length"].append(
+            agg_completion_mask.float().mean().item()
+        )
+        self._metrics[mode]["completions/min_length"].append(
+            agg_completion_mask.float().min().item()
+        )
+        self._metrics[mode]["completions/max_length"].append(
+            agg_completion_mask.float().max().item()
+        )
+
+        # identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
+        term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_mask) / len(
+            agg_completion_mask
+        )
+        self._metrics[mode]["completions/clipped_ratio"].append(
+            clipped_completions_ratio
+        )
+        if len(term_completion_mask) == 0:
+            # edge case where no completed sequences are found
+            term_completion_mask = torch.zeros(1, device=device)
+        self._metrics[mode]["completions/mean_terminated_length"].append(
+            term_completion_mask.float().mean().item()
+        )
+        self._metrics[mode]["completions/min_terminated_length"].append(
+            term_completion_mask.float().min().item()
+        )
+        self._metrics[mode]["completions/max_terminated_length"].append(
+            term_completion_mask.float().max().item()
+        )
+
+        # Calculate mean reward
+        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
+        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+
+        # Log prompt and completion texts
+        self._textual_logs["prompt"].extend(gather_object(prompts_text))
+        self._textual_logs["completion"].extend(gather_object(completions_text))

         return {
             "prompt_ids": prompt_ids,
@@ -227,7 +300,6 @@ class ArborGRPOTrainer(GRPOTrainer):
             "completion_ids": completion_ids,
             "completion_mask": completion_mask,
             "old_per_token_logps": old_per_token_logps,
-            "ref_per_token_logps": ref_per_token_logps,
             "advantages": advantages,
         }

@@ -326,15 +398,10 @@ class CommandMonitor:
                    print(
                        f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
                    )
-                    # Wait until data queue is empty before saving
-
                    while (
                        time_since_last_step() <= 10
                        or get_time_since_last_queue_pop() <= 10
                    ):
-                        # print(
-                        #     f"Waiting for data queue to empty...{self.comms_handler.get_data_queue_size()}"
-                        # )
                        print(f"Waiting for steps to finish")
                        print(
                            f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
@@ -342,15 +409,12 @@ class CommandMonitor:
                        print(
                            f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
                        )
-                        time.sleep(5)  # Small delay to prevent busy waiting)
+                        time.sleep(5)
                    print("[Training Script] Saving model...")
-
                    if self.trainer.peft_config:
-
                        self.trainer.save_model(
                            output_dir=self.trainer.args.output_dir + "/adapter/"
                        )
-
                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
                            self.trainer.args.output_dir + "/adapter/",
                            config=self.trainer.peft_config,
@@ -373,6 +437,56 @@ class CommandMonitor:
                            "output_dir": self.trainer.args.output_dir,
                        }
                    )
+                elif command.get("command") == "save_checkpoint":
+                    print(
+                        f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
+                    )
+                    while (
+                        time_since_last_step() <= 10
+                        or get_time_since_last_queue_pop() <= 10
+                    ):
+                        print(f"Waiting for steps to finish")
+                        print(
+                            f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
+                        )
+                        print(
+                            f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
+                        )
+                        time.sleep(5)
+                    if self.trainer.peft_config:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
+                        )
+                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
+                            config=self.trainer.peft_config,
+                        )
+                        merged_model = _model_to_merge.merge_and_unload()
+                        merged_model.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                            safe_serialization=True,
+                        )
+                        self.trainer.processing_class.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+                    else:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+                    self.comms_handler.send_status(
+                        {
+                            "status": "checkpoint_saved",
+                            "checkpoint_name": command.get("checkpoint_name"),
+                            "output_dir": self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                        }
+                    )
+

        except Exception as e:
            print(e)
@@ -385,13 +499,15 @@ class CommandMonitor:
            for broadcast in self.comms_handler.receive_broadcast():
                print(f"!!!Received broadcast: {broadcast}")
                if broadcast.get("message") == "terminate":
-                    self.trainer.control.should_training_stop = True
-                    self.comms_handler.send_status(
-                        {
-                            "status": "Received termination command",
-                            "process_id": self.trainer.accelerator.process_index,
-                        }
-                    )
+                    # self.trainer.control.should_training_stop = True
+                    # self.comms_handler.send_status(
+                    #     {
+                    #         "status": "Received termination command",
+                    #         "process_id": self.trainer.accelerator.process_index,
+                    #     }
+                    # )
+                    if self.trainer.accelerator.is_main_process:
+                        self.trainer.accelerator.end_training()
        except Exception as e:
            self.comms_handler.send_status({"status": "error", "error": str(e)})

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.12
+Version: 0.1.13
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -15,7 +15,7 @@ Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
 Requires-Dist: torch
 Requires-Dist: transformers
-Requires-Dist: trl
+Requires-Dist: trl==0.17.0
 Requires-Dist: peft
 Requires-Dist: ray>=2.9
 Requires-Dist: setuptools<77.0.0,>=76.0.0
@@ -23,6 +23,7 @@ Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: sglang[all]>=0.4.5.post3
 Requires-Dist: sglang-router
+Requires-Dist: wandb
 Dynamic: license-file

 <p align="center">
@@ -5,10 +5,10 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
 arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
 arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhcnrKIk,5594
+arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
-arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
+arbor/server/api/routes/grpo.py,sha256=AbQ_BHgk-Om5U0qSt_FeJfyBJ0vItpfrnCNtJgD6p5k,2245
 arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256=TAU2BMHgbCgiAvKNVd2Y8N20SR4qEms3lChA4Z0ZzyY,13777
-arbor/server/services/inference_manager.py,sha256=q4RVUqh1snGfW-AADkCqW8hC5x3WAZNe0jwXKOY5joU,10685
+arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
+arbor/server/services/inference_manager.py,sha256=NcsUI-pgf3cRhU6P3xlPx0dxhvgYrfGZkEEGORcHcis,12833
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
-arbor/server/services/scripts/grpo_training.py,sha256=Q9jwnbRdXAv_jVgrChLX6IiB3BLZU1F3BP6mBV0DVik,20889
+arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.12.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
-arbor_ai-0.1.12.dist-info/METADATA,sha256=upqnB_F9JDLytHm4AFrDnvPaOHdj8XiBCdrlam0rgRc,2413
-arbor_ai-0.1.12.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
-arbor_ai-0.1.12.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
-arbor_ai-0.1.12.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
-arbor_ai-0.1.12.dist-info/RECORD,,
+arbor_ai-0.1.13.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.13.dist-info/METADATA,sha256=c0yScMpCiWYSFqVLjgk5TrRBuAVJK3aTBl0z0IPZ_8Y,2442
+arbor_ai-0.1.13.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
+arbor_ai-0.1.13.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.13.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.13.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.4.0)
+Generator: setuptools (80.6.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
