arbor-ai 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/api/models/schemas.py +22 -1
- arbor/server/api/routes/grpo.py +15 -6
- arbor/server/services/grpo_manager.py +65 -17
- arbor/server/services/inference_manager.py +120 -72
- arbor/server/services/scripts/grpo_training.py +159 -43
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/METADATA +3 -2
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/RECORD +11 -11
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/top_level.txt +0 -0
arbor/server/api/models/schemas.py
CHANGED
@@ -199,10 +199,16 @@ class GRPOConfigRequest(BaseModel):
     bf16: Optional[bool] = None
     scale_rewards: Optional[bool] = None
     max_grad_norm: Optional[float] = None
+    report_to: Optional[str] = None
+    log_completions: Optional[bool] = None
+    logging_steps: Optional[int] = None
+    mask_truncated_completions: Optional[bool] = None
+    # Arbor specific
+    max_context_length: Optional[int] = None
     lora: Optional[bool] = None
-    update_interval: Optional[int] = None
     # To name the run
     suffix: Optional[str] = None
+    generation_batch_size: Optional[int] = None


 class GRPOConfigResponse(BaseModel):
@@ -216,8 +222,23 @@ class GRPOTerminateRequest(BaseModel):

 class GRPOTerminateResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: Optional[dict[str, str]] = None
+    last_checkpoint: Optional[str] = None


 class GRPOStepResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: Optional[str] = None
+
+
+class GRPOCheckpointRequest(BaseModel):
+    checkpoint_name: str
+
+
+class GRPOCheckpointResponse(BaseModel):
+    status: str
+    current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: str
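Note (not part of the diff): the new checkpoint models above are plain Pydantic schemas. A minimal sketch of constructing and serializing them, assuming Pydantic v2 (`model_dump`) and made-up sample values:

```python
# Hypothetical usage of the checkpoint schemas added in 0.1.13; values are illustrative.
from pydantic import BaseModel


class GRPOCheckpointRequest(BaseModel):
    checkpoint_name: str


class GRPOCheckpointResponse(BaseModel):
    status: str
    current_model: str
    checkpoints: dict[str, str]
    last_checkpoint: str


req = GRPOCheckpointRequest(checkpoint_name="ckpt_001")
resp = GRPOCheckpointResponse(
    status="success",
    current_model="outputs/grpo_run",
    checkpoints={"ckpt_001": "outputs/grpo_run/checkpoints/ckpt_001/"},
    last_checkpoint="ckpt_001",
)
print(resp.model_dump())
```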
arbor/server/api/routes/grpo.py
CHANGED
@@ -4,6 +4,8 @@ import subprocess
 from fastapi import APIRouter, BackgroundTasks, Request

 from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOCheckpointResponse,
     GRPOConfigRequest,
     GRPOConfigResponse,
     GRPORequest,
@@ -31,17 +33,24 @@ def run_grpo_step(
     inference_manager = request.app.state.inference_manager
     grpo_manager = request.app.state.grpo_manager

-
+    step_data = grpo_manager.grpo_step(grpo_request, inference_manager)

-    return GRPOStepResponse(status="success",
+    return GRPOStepResponse(status="success", **step_data)


 @router.post("/update_model", response_model=GRPOStepResponse)
 def update_model(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
-
-    return GRPOStepResponse(status="success",
+    update_model_data = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", **update_model_data)
+
+
+@router.post("/checkpoint", response_model=GRPOCheckpointResponse)
+def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
+    grpo_manager = request.app.state.grpo_manager
+    checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+    return GRPOCheckpointResponse(status="success", **checkpoint_data)


 @router.post("/terminate", response_model=GRPOTerminateResponse)
@@ -50,5 +59,5 @@ def terminate_grpo(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager

-
-    return GRPOTerminateResponse(status="success",
+    terminate_data = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", **terminate_data)
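Note (illustrative, not from the diff): with the route changes above, `/checkpoint` accepts a JSON body matching `GRPOCheckpointRequest` and returns the checkpoint map. A hedged client sketch; the host, port, and router prefix are assumptions, since the mount point is not shown in this diff:

```python
# Hypothetical client call to the new checkpoint endpoint; adjust URL and prefix
# to your deployment.
import requests

ARBOR_URL = "http://localhost:8000"   # assumed server address
GRPO_PREFIX = "/v1/fine_tuning/grpo"  # assumed router prefix

resp = requests.post(
    f"{ARBOR_URL}{GRPO_PREFIX}/checkpoint",
    json={"checkpoint_name": "ckpt_001"},
    timeout=600,  # the server blocks until the training script reports the checkpoint
)
resp.raise_for_status()
data = resp.json()
print(data["last_checkpoint"], data["checkpoints"])
```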
arbor/server/services/grpo_manager.py
CHANGED
@@ -13,7 +13,11 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional

-from arbor.server.api.models.schemas import
+from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOConfigRequest,
+    GRPORequest,
+)
 from arbor.server.core.config import Settings
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
@@ -28,7 +32,10 @@ class GRPOManager:
         self.server_comms_handler = None
         self.status_thread = None
         self.model_saved_and_reload_requested = False
+        self.saving_checkpoint = False

+        self.checkpoints = {}
+        self.last_checkpoint = None
         self.data_count = 0
         self.last_inference_update = 0
         # Set up signal handler
@@ -86,12 +93,17 @@ class GRPOManager:
             "bf16",
             "scale_rewards",
             "max_grad_norm",
+            "report_to",
+            "log_completions",
+            "logging_steps",
+            "generation_batch_size",
+            "mask_truncated_completions",
         ]
         trl_train_kwargs = {
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }

-        arbor_keys = ["
+        arbor_keys = ["max_context_length", "lora"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -119,6 +131,8 @@ class GRPOManager:
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        # WandB can block the training process for login, so we silence it
+        my_env["WANDB_SILENT"] = "true"

         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1

@@ -209,6 +223,12 @@ class GRPOManager:

         # Launch the inference server
         print("Launching inference server...")
+        # launch_kwargs = {
+        #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
+        # }
+        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
+            "max_context_length", None
+        )
         inference_manager.launch(self.current_model)

     def _handle_status_updates(self, inference_manager: InferenceManager):
@@ -228,6 +248,12 @@ class GRPOManager:
                     self.model_saved_and_reload_requested = False
                     self.current_model = status["output_dir"]
                     print("Model update complete")
+                elif status["status"] == "checkpoint_saved":
+                    print("Received checkpoint saved status")
+                    self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
+                    self.last_checkpoint = status["checkpoint_name"]
+                    self.saving_checkpoint = False
+                    print("Checkpoint saved")
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
@@ -249,6 +275,10 @@ class GRPOManager:
             )
             time.sleep(5)

+        while self.saving_checkpoint:
+            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            time.sleep(5)
+
         try:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
@@ -256,12 +286,11 @@ class GRPOManager:
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")

-
-
-
-
-
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }

     def update_model(self, request, inference_manager: InferenceManager):
         if inference_manager._session:
@@ -286,18 +315,41 @@ class GRPOManager:
                 "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
             )
             time.sleep(5)
-        return
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
+
+    def checkpoint(self, request: GRPOCheckpointRequest):
+        self.saving_checkpoint = True
+        self.server_comms_handler.send_command(
+            {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
+        )
+        while self.saving_checkpoint:
+            print("Waiting for checkpoint to be saved...")
+            time.sleep(5)
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }

     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        termination_data = {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
         try:
             # Stop the inference server
             if inference_manager.process is not None:
                 inference_manager.kill()

             # Send termination command through REQ socket
-
-            self.training_process.terminate()
+            self.server_comms_handler.send_broadcast({"message": "terminate"})
+            # self.training_process.terminate()
             print("Waiting for training process to finish")

             # Wait for training process to finish
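Note (illustrative): `checkpoint()` above is a blocking handshake: the manager sets `saving_checkpoint`, sends a `save_checkpoint` command over the comms handler, and spins until the status thread sees a `checkpoint_saved` status and clears the flag. A stripped-down sketch of that pattern, with in-process queues standing in for the ZMQ sockets and a `threading.Event` for the flag:

```python
# Minimal sketch of the save-checkpoint handshake; queues replace ZMQ sockets.
import queue
import threading
import time

command_queue: "queue.Queue[dict]" = queue.Queue()
status_queue: "queue.Queue[dict]" = queue.Queue()
checkpoint_saved = threading.Event()


def training_script() -> None:
    # Stand-in for the CommandMonitor in grpo_training.py.
    cmd = command_queue.get()
    if cmd["command"] == "save_checkpoint":
        time.sleep(0.1)  # pretend to write the checkpoint to disk
        status_queue.put({
            "status": "checkpoint_saved",
            "checkpoint_name": cmd["checkpoint_name"],
            "output_dir": f"outputs/checkpoints/{cmd['checkpoint_name']}/",
        })


def status_listener() -> None:
    # Stand-in for GRPOManager._handle_status_updates.
    status = status_queue.get()
    if status["status"] == "checkpoint_saved":
        checkpoint_saved.set()


threading.Thread(target=training_script, daemon=True).start()
threading.Thread(target=status_listener, daemon=True).start()

command_queue.put({"command": "save_checkpoint", "checkpoint_name": "ckpt_001"})
checkpoint_saved.wait(timeout=5)  # the real manager polls a boolean every 5 seconds
print("checkpoint saved")
```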
@@ -336,17 +388,13 @@ class GRPOManager:
                 )
                 output_dir = self.train_kwargs["output_dir"]
                 self.train_kwargs = None
-                return output_dir
             else:
                 print("Training terminated, no output directory specified")
                 self.train_kwargs = None
-
+
+        return termination_data

     def _should_update_model(self):
-        # return (
-        #     self.data_count - self.last_inference_update
-        #     >= self.train_kwargs["update_interval"]
-        # )
         return self.model_saved_and_reload_requested


arbor/server/services/inference_manager.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import random
 import signal
 import socket
 import subprocess
@@ -47,7 +48,12 @@ class InferenceManager:
     def is_server_restarting(self):
         return self.restarting

-    def launch(
+    def launch(
+        self,
+        model: str,
+        launch_kwargs: Optional[Dict[str, Any]] = None,
+        max_retries: int = 3,
+    ):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -59,77 +65,112 @@ class InferenceManager:
         if model.startswith(prefix):
             model = model[len(prefix) :]

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        logs_buffer = []
-
-        def _tail_process(proc, buffer, stop_event):
-            while True:
-                line = proc.stdout.readline()
-                if not line and proc.poll() is not None:
-                    # Process ended and no new line
-                    break
-                if line:
-                    buffer.append(line)
-                    # Print only if stop_event is not set
-                    if not stop_event.is_set():
-                        print(f"[SGLang LOG] {line}", end="")
-
-        # Start a background thread to read from the process continuously
-        thread = threading.Thread(
-            target=_tail_process,
-            args=(process, logs_buffer, stop_printing_event),
-            daemon=True,
-        )
-        thread.start()
-
-        # Wait until the server is ready (or times out)
-        base_url = f"http://localhost:{port}"
-        try:
-            wait_for_server(base_url, timeout=timeout)
-        except TimeoutError:
-            # If the server doesn't come up, we might want to kill it:
-            process.kill()
-            raise
-
-        # Once server is ready, we tell the thread to stop printing further lines.
-        stop_printing_event.set()
-
-        # A convenience getter so the caller can see all logs so far (and future).
-        def get_logs() -> str:
-            # Join them all into a single string, or you might return a list
-            return "".join(logs_buffer)
-
-        # Let the user know server is up
-        print(f"Server ready on random port {port}!")
+        retries = 0
+        while retries < max_retries:
+            try:
+                print(
+                    f"Attempt {retries + 1} of {max_retries} to launch server for model {model}"
+                )
+                print(
+                    f"Grabbing a free port to launch an SGLang server for model {model}"
+                )
+                port = get_free_port()
+                timeout = launch_kwargs.get("timeout", 1800)
+                my_env = os.environ.copy()
+                my_env["CUDA_VISIBLE_DEVICES"] = (
+                    self.settings.arbor_config.inference.gpu_ids
+                )
+                n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+                # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
+                command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
+                print(f"Running command: {command}")
+                if launch_kwargs.get("max_context_length"):
+                    command += (
+                        f" --context-length {launch_kwargs['max_context_length']}"
+                    )

-
-
-
-
-
-
+                # We will manually stream & capture logs.
+                process = subprocess.Popen(
+                    command.replace("\\\n", " ").replace("\\", " ").split(),
+                    text=True,
+                    stdout=subprocess.PIPE,  # We'll read from pipe
+                    stderr=subprocess.STDOUT,  # Merge stderr into stdout
+                    env=my_env,
+                )
+
+                # A threading.Event to control printing after the server is ready.
+                # This will store *all* lines (both before and after readiness).
+                print(f"SGLang server process started with PID {process.pid}.")
+                stop_printing_event = threading.Event()
+                logs_buffer = []
+
+                def _tail_process(proc, buffer, stop_event):
+                    while True:
+                        line = proc.stdout.readline()
+                        if not line and proc.poll() is not None:
+                            # Process ended and no new line
+                            break
+                        if line:
+                            buffer.append(line)
+                            # Print only if stop_event is not set
+                            if not stop_event.is_set():
+                                print(f"[SGLang LOG] {line}", end="")
+
+                # Start a background thread to read from the process continuously
+                thread = threading.Thread(
+                    target=_tail_process,
+                    args=(process, logs_buffer, stop_printing_event),
+                    daemon=True,
+                )
+                thread.start()
+
+                # Wait until the server is ready (or times out)
+                base_url = f"http://localhost:{port}"
+                try:
+                    wait_for_server(base_url, timeout=timeout)
+                except TimeoutError:
+                    # If the server doesn't come up, we might want to kill it:
+                    process.kill()
+                    raise
+
+                # Once server is ready, we tell the thread to stop printing further lines.
+                stop_printing_event.set()
+
+                # A convenience getter so the caller can see all logs so far (and future).
+                def get_logs() -> str:
+                    # Join them all into a single string, or you might return a list
+                    return "".join(logs_buffer)
+
+                # Let the user know server is up
+                print(f"Server ready on random port {port}!")
+
+                self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+                self.launch_kwargs["api_key"] = "local"
+                self.get_logs = get_logs
+                self.process = process
+                self.thread = thread
+                self.current_model = model
+
+                # If we get here, the launch was successful
+                return
+
+            except Exception as e:
+                retries += 1
+                print(
+                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
+                )
+                # Clean up any failed processes
+                if "process" in locals():
+                    try:
+                        process.kill()
+                    except:
+                        pass
+                if retries == max_retries:
+                    raise Exception(
+                        f"Failed to launch server after {max_retries} attempts"
+                    ) from e
+                # Wait a bit before retrying
+                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds

     def kill(self):
         from sglang.utils import terminate_process
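Note (illustrative): the rewritten `launch()` wraps the whole SGLang startup in a retry loop with exponential backoff capped at 30 seconds. The retry shape on its own, independent of SGLang:

```python
# Generic sketch of the retry-with-backoff structure used by launch().
import time


def launch_with_retries(launch_once, max_retries: int = 3):
    retries = 0
    while retries < max_retries:
        try:
            return launch_once()
        except Exception as e:
            retries += 1
            print(f"Failed to launch server (attempt {retries} of {max_retries}): {e}")
            if retries == max_retries:
                raise Exception(f"Failed to launch server after {max_retries} attempts") from e
            time.sleep(min(2**retries, 30))  # 2 s, 4 s, 8 s, ... capped at 30 s


# Example: fail twice, then succeed.
attempts = {"count": 0}


def flaky_launch():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("port not ready")
    return "server handle"


print(launch_with_retries(flaky_launch))
```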
@@ -184,7 +225,7 @@ class InferenceManager:
         print(f"Running inference for model {model}")
         # Monkeypatch:
         if model != self.current_model:
-            print(f"
+            print(f"Model changed from {model} to {self.current_model}")
             model = self.current_model
             request_json["model"] = model

@@ -214,6 +255,12 @@ class InferenceManager:
                 await self._session.close()
                 self._session = None
                 return None
+        except json.decoder.JSONDecodeError:
+            print(f"JSON Decode Error during inference: {content}")
+            return {
+                "error": "JSON Decode Error",
+                "content": content if content else "Content is null",
+            }
         except Exception as e:
             print(f"Error during inference: {e}")
             raise
@@ -241,6 +288,7 @@ class InferenceManager:
         tik = time.time()
         self.kill()
         print("Just killed server")
+        time.sleep(5)
         # Check that output directory exists and was created successfully
         print(f"Checking that output directory {output_dir} exists")
         if not os.path.exists(output_dir):
arbor/server/services/scripts/grpo_training.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Any, List, Optional, Union
 import torch
 import zmq
 from accelerate import Accelerator
-from accelerate.utils import gather
+from accelerate.utils import broadcast_object_list, gather, gather_object
 from datasets import Dataset, IterableDataset, load_dataset
 from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig  # type: ignore
 from torch.utils.data import Dataset
@@ -23,6 +23,7 @@ from transformers import (
     PreTrainedTokenizerBase,
     Trainer,
     TrainerCallback,
+    is_wandb_available,
 )
 from trl import GRPOConfig, GRPOTrainer
 from trl.data_utils import maybe_apply_chat_template
@@ -32,6 +33,9 @@ from arbor.server.services.comms.comms import (
     ArborServerCommsHandler,
 )

+if is_wandb_available():
+    import wandb
+
 last_step_time = None
 last_queue_pop_time = None

@@ -65,8 +69,9 @@ class ArborGRPOTrainer(GRPOTrainer):
         ] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
         comms_handler: Optional[ArborScriptCommsHandler] = None,
-        update_interval: Optional[int] = 5,
         lora: Optional[bool] = False,
+        # We do nothing with max_context_length right now
+        max_context_length: Optional[int] = None,
         **kwargs,
     ):

@@ -85,12 +90,12 @@ class ArborGRPOTrainer(GRPOTrainer):
         self.peft_config = peft_config
         self.scale_rewards = scale_rewards
         self.comms_handler = comms_handler
-        self.update_interval = update_interval

     def _generate_and_score_completions(
         self, batch: List[dict[str, Any]]
     ) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"

         # Process prompts and completions
         prompt_completion_texts = []
@@ -106,12 +111,12 @@ class ArborGRPOTrainer(GRPOTrainer):
         )

         # Tokenize prompts
-
+        prompts_text = [
             prompt_completion_text["prompt"]
             for prompt_completion_text in prompt_completion_texts
         ]
         prompt_inputs = self.processing_class(
-
+            prompts_text,
             return_tensors="pt",
             padding=True,
             padding_side="left",
@@ -124,12 +129,12 @@ class ArborGRPOTrainer(GRPOTrainer):
         )

         # Tokenize completions
-
+        completions_text = [
             prompt_completion_text["completion"]
             for prompt_completion_text in prompt_completion_texts
         ]
         completion_ids = self.processing_class(
-
+            completions_text,
             return_tensors="pt",
             padding=True,
             add_special_tokens=False,
@@ -156,6 +161,30 @@ class ArborGRPOTrainer(GRPOTrainer):
         # self._move_model_to_vllm()
         # self._last_loaded_step = self.state.global_step

+        prompt_ids = broadcast_object_list(prompt_ids)
+        prompt_mask = broadcast_object_list(prompt_mask)
+        completion_ids = broadcast_object_list(completion_ids)
+        completion_mask = broadcast_object_list(completion_mask)
+
+        process_slice = slice(
+            self.accelerator.process_index * len(batch),
+            (self.accelerator.process_index + 1) * len(batch),
+        )
+
+        prompt_ids = prompt_ids[process_slice]
+        prompt_mask = prompt_mask[process_slice]
+        completion_ids = completion_ids[process_slice]
+        completion_mask = completion_mask[process_slice]
+
+        is_eos = completion_ids == self.processing_class.eos_token_id
+
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            truncated_completions = ~is_eos.any(dim=1)
+            completion_mask = (
+                completion_mask * (~truncated_completions).unsqueeze(1).int()
+            )
+
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

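Note (illustrative): the new `mask_truncated_completions` branch zeroes the completion mask for any completion that never produced EOS, so truncated generations drop out of the loss. The same masking on a toy tensor (EOS id 2 is arbitrary):

```python
# Toy example of masking truncated completions; eos_token_id = 2 is arbitrary.
import torch

eos_token_id = 2
completion_ids = torch.tensor([[5, 7, 2, 0],   # ends with EOS, then padding
                               [5, 7, 9, 4]])  # truncated: no EOS emitted
completion_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 1, 1]])

is_eos = completion_ids == eos_token_id
truncated_completions = ~is_eos.any(dim=1)  # tensor([False,  True])
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
print(completion_mask)
# tensor([[1, 1, 1, 0],
#         [0, 0, 0, 0]])
```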
@@ -164,34 +193,29 @@ class ArborGRPOTrainer(GRPOTrainer):
         )

         logits_to_keep = completion_ids.size(1)
+        batch_size = (
+            self.args.per_device_train_batch_size
+            if mode == "train"
+            else self.args.per_device_eval_batch_size
+        )

         with torch.no_grad():
             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
             # computation here, and use per_token_logps.detach() instead.
-            if
+            if (
+                self.num_iterations > 1
+                or self.args.steps_per_generation
+                > self.args.gradient_accumulation_steps
+            ):
                 old_per_token_logps = self._get_per_token_logps(
-                    self.model,
-                )
-            else:
-                old_per_token_logps = None
-
-            if self.beta == 0.0:
-                ref_per_token_logps = None
-            elif self.ref_model is not None:
-                ref_per_token_logps = self._get_per_token_logps(
-                    self.ref_model,
+                    self.model,
                     prompt_completion_ids,
                     attention_mask,
                     logits_to_keep,
+                    batch_size,
                 )
             else:
-
-                ref_per_token_logps = self._get_per_token_logps(
-                    self.model,
-                    prompt_completion_ids,
-                    attention_mask,
-                    logits_to_keep,
-                )
+                old_per_token_logps = None

         rewards = torch.tensor(
             [example["reward"] for example in batch], dtype=torch.float32
@@ -219,7 +243,56 @@ class ArborGRPOTrainer(GRPOTrainer):
         )
         advantages = advantages[process_slice]

-
+        # Log the metrics
+        if mode == "train":
+            self.state.num_input_tokens_seen += (
+                self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+            )
+        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
+
+        # log completion lengths, mean, min, max
+        agg_completion_mask = self.accelerator.gather_for_metrics(
+            completion_mask.sum(1)
+        )
+        self._metrics[mode]["completions/mean_length"].append(
+            agg_completion_mask.float().mean().item()
+        )
+        self._metrics[mode]["completions/min_length"].append(
+            agg_completion_mask.float().min().item()
+        )
+        self._metrics[mode]["completions/max_length"].append(
+            agg_completion_mask.float().max().item()
+        )
+
+        # identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
+        term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_mask) / len(
+            agg_completion_mask
+        )
+        self._metrics[mode]["completions/clipped_ratio"].append(
+            clipped_completions_ratio
+        )
+        if len(term_completion_mask) == 0:
+            # edge case where no completed sequences are found
+            term_completion_mask = torch.zeros(1, device=device)
+        self._metrics[mode]["completions/mean_terminated_length"].append(
+            term_completion_mask.float().mean().item()
+        )
+        self._metrics[mode]["completions/min_terminated_length"].append(
+            term_completion_mask.float().min().item()
+        )
+        self._metrics[mode]["completions/max_terminated_length"].append(
+            term_completion_mask.float().max().item()
+        )
+
+        # Calculate mean reward
+        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
+        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+
+        # Log prompt and completion texts
+        self._textual_logs["prompt"].extend(gather_object(prompts_text))
+        self._textual_logs["completion"].extend(gather_object(completions_text))

         return {
             "prompt_ids": prompt_ids,
@@ -227,7 +300,6 @@ class ArborGRPOTrainer(GRPOTrainer):
             "completion_ids": completion_ids,
             "completion_mask": completion_mask,
             "old_per_token_logps": old_per_token_logps,
-            "ref_per_token_logps": ref_per_token_logps,
             "advantages": advantages,
         }

@@ -326,15 +398,10 @@ class CommandMonitor:
                    print(
                        f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
                    )
-                    # Wait until data queue is empty before saving
-
                    while (
                        time_since_last_step() <= 10
                        or get_time_since_last_queue_pop() <= 10
                    ):
-                        # print(
-                        #     f"Waiting for data queue to empty...{self.comms_handler.get_data_queue_size()}"
-                        # )
                        print(f"Waiting for steps to finish")
                        print(
                            f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
@@ -342,15 +409,12 @@ class CommandMonitor:
                        print(
                            f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
                        )
-                        time.sleep(5)
+                        time.sleep(5)
                    print("[Training Script] Saving model...")
-
                    if self.trainer.peft_config:
-
                        self.trainer.save_model(
                            output_dir=self.trainer.args.output_dir + "/adapter/"
                        )
-
                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
                            self.trainer.args.output_dir + "/adapter/",
                            config=self.trainer.peft_config,
@@ -373,6 +437,56 @@ class CommandMonitor:
                                "output_dir": self.trainer.args.output_dir,
                            }
                        )
+                elif command.get("command") == "save_checkpoint":
+                    print(
+                        f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
+                    )
+                    while (
+                        time_since_last_step() <= 10
+                        or get_time_since_last_queue_pop() <= 10
+                    ):
+                        print(f"Waiting for steps to finish")
+                        print(
+                            f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
+                        )
+                        print(
+                            f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
+                        )
+                        time.sleep(5)
+                    if self.trainer.peft_config:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
+                        )
+                        _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
+                            config=self.trainer.peft_config,
+                        )
+                        merged_model = _model_to_merge.merge_and_unload()
+                        merged_model.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                            safe_serialization=True,
+                        )
+                        self.trainer.processing_class.save_pretrained(
+                            self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+                    else:
+                        self.trainer.save_model(
+                            output_dir=self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/"
+                        )
+                    self.comms_handler.send_status(
+                        {
+                            "status": "checkpoint_saved",
+                            "checkpoint_name": command.get("checkpoint_name"),
+                            "output_dir": self.trainer.args.output_dir
+                            + f"/checkpoints/{command.get('checkpoint_name')}/",
+                        }
+                    )
+
        except Exception as e:
            print(e)
            self.comms_handler.send_status({"status": "error", "error": str(e)})
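Note (illustrative): after the waits above, each checkpoint lands under `output_dir/checkpoints/<checkpoint_name>/` (with an `adapter/` subfolder in the LoRA case), and the script reports back with a `checkpoint_saved` status. A small sketch of that payload; the paths are placeholders:

```python
# Illustrative reconstruction of the checkpoint_saved status payload; paths are made up.
import os

output_dir = "outputs/grpo_run"   # placeholder for trainer.args.output_dir
checkpoint_name = "ckpt_001"      # placeholder for the name sent by GRPOManager

checkpoint_dir = os.path.join(output_dir, "checkpoints", checkpoint_name) + "/"
status = {
    "status": "checkpoint_saved",
    "checkpoint_name": checkpoint_name,
    "output_dir": checkpoint_dir,
}
print(status)
```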
@@ -385,13 +499,15 @@ class CommandMonitor:
            for broadcast in self.comms_handler.receive_broadcast():
                print(f"!!!Received broadcast: {broadcast}")
                if broadcast.get("message") == "terminate":
-                    self.trainer.control.should_training_stop = True
-                    self.comms_handler.send_status(
-
-
-
-
-                    )
+                    # self.trainer.control.should_training_stop = True
+                    # self.comms_handler.send_status(
+                    #     {
+                    #         "status": "Received termination command",
+                    #         "process_id": self.trainer.accelerator.process_index,
+                    #     }
+                    # )
+                    if self.trainer.accelerator.is_main_process:
+                        self.trainer.accelerator.end_training()
        except Exception as e:
            self.comms_handler.send_status({"status": "error", "error": str(e)})

{arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.12
+Version: 0.1.13
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -15,7 +15,7 @@ Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
 Requires-Dist: torch
 Requires-Dist: transformers
-Requires-Dist: trl
+Requires-Dist: trl==0.17.0
 Requires-Dist: peft
 Requires-Dist: ray>=2.9
 Requires-Dist: setuptools<77.0.0,>=76.0.0
@@ -23,6 +23,7 @@ Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: sglang[all]>=0.4.5.post3
 Requires-Dist: sglang-router
+Requires-Dist: wandb
 Dynamic: license-file

 <p align="center">
{arbor_ai-0.1.12.dist-info → arbor_ai-0.1.13.dist-info}/RECORD
CHANGED
@@ -5,10 +5,10 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
 arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
 arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-arbor/server/api/models/schemas.py,sha256=
+arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
-arbor/server/api/routes/grpo.py,sha256=
+arbor/server/api/routes/grpo.py,sha256=AbQ_BHgk-Om5U0qSt_FeJfyBJ0vItpfrnCNtJgD6p5k,2245
 arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256
-arbor/server/services/inference_manager.py,sha256=
+arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
+arbor/server/services/inference_manager.py,sha256=NcsUI-pgf3cRhU6P3xlPx0dxhvgYrfGZkEEGORcHcis,12833
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
-arbor/server/services/scripts/grpo_training.py,sha256=
+arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
+arbor_ai-0.1.13.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.13.dist-info/METADATA,sha256=c0yScMpCiWYSFqVLjgk5TrRBuAVJK3aTBl0z0IPZ_8Y,2442
+arbor_ai-0.1.13.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
+arbor_ai-0.1.13.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.13.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.13.dist-info/RECORD,,
File without changes
File without changes
File without changes