arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +12 -0
- arbor/server/api/routes/grpo.py +4 -1
- arbor/server/api/routes/inference.py +11 -16
- arbor/server/services/grpo_manager.py +179 -98
- arbor/server/services/inference/__init__.py +0 -0
- arbor/server/services/inference/vllm_client.py +445 -0
- arbor/server/services/inference/vllm_serve.py +2335 -0
- arbor/server/services/inference_manager.py +149 -219
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +157 -53
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/METADATA +4 -5
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/RECORD +20 -12
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/top_level.txt +0 -0
arbor/cli.py
CHANGED
@@ -1,3 +1,6 @@
+import os
+from datetime import datetime
+
 import click
 import uvicorn
 
@@ -10,6 +13,14 @@ from arbor.server.services.job_manager import JobManager
 from arbor.server.services.training_manager import TrainingManager
 
 
+def make_log_dir(storage_path: str):
+    # Create a timestamped log directory under the storage path
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    log_dir = os.path.join(storage_path, "logs", timestamp)
+    os.makedirs(log_dir, exist_ok=True)
+    return log_dir
+
+
 @click.group()
 def cli():
     pass
 
@@ -26,6 +37,7 @@ def create_app(arbor_config_path: str):
     """
     # Create new settings instance with overrides
    settings = Settings.load_from_yaml(arbor_config_path)
+    app.state.log_dir = make_log_dir(settings.STORAGE_PATH)
 
    # Initialize services with settings
    file_manager = FileManager(settings=settings)
arbor/server/api/routes/grpo.py
CHANGED
@@ -49,7 +49,10 @@ def update_model(request: Request):
 @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
 def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
     grpo_manager = request.app.state.grpo_manager
-
+    inference_manager = request.app.state.inference_manager
+    checkpoint_data = grpo_manager.checkpoint(
+        grpo_checkpoint_request, inference_manager
+    )
     return GRPOCheckpointResponse(status="success", **checkpoint_data)
 
 
arbor/server/api/routes/inference.py
CHANGED
@@ -1,4 +1,5 @@
-import
+import json
+import uuid
 
 from fastapi import APIRouter, Request
 
@@ -12,29 +13,23 @@ async def run_inference(
     inference_manager = request.app.state.inference_manager
     raw_json = await request.json()
 
-
-
-
-
+    # Generate a random hex ID
+    request_id = str(uuid.uuid4())
+    # Create requests directory if it doesn't exist
+    with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
+        f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
 
     # if a server isnt running, launch one
-    if (
-        not inference_manager.is_server_running()
-        and not inference_manager.is_server_restarting()
-    ):
+    if not inference_manager.is_server_running():
         print("No model is running, launching model...")
         inference_manager.launch(raw_json["model"])
 
-    if inference_manager.is_server_restarting():
-        print("Waiting for server to finish restarting...")
-        while inference_manager.is_server_restarting():
-            time.sleep(5)
-        # Update the model in the request
-        raw_json["model"] = inference_manager.current_model
-
     # forward the request to the inference server
     completion = await inference_manager.run_inference(raw_json)
 
+    with open(f"{request.app.state.log_dir}/inference_responses.jsonl", "a") as f:
+        f.write(json.dumps({"id": request_id, "response": completion}) + "\n")
+
     return completion
 
 
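Requests and responses are logged to separate JSONL files that share an id field, so they can be re-joined offline. A minimal sketch of that join, assuming the log directory created by make_log_dir above; the file names are the ones written by the route, while the pairing script itself is purely illustrative:

import json

def load_jsonl(path):
    # One JSON object per line
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

def pair_inference_logs(log_dir: str):
    # Join request and response records on their shared "id" field
    requests = {r["id"]: r["request"] for r in load_jsonl(f"{log_dir}/inference_requests.jsonl")}
    responses = {r["id"]: r["response"] for r in load_jsonl(f"{log_dir}/inference_responses.jsonl")}
    return [
        {"id": rid, "request": req, "response": responses.get(rid)}
        for rid, req in requests.items()
    ]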
arbor/server/services/grpo_manager.py
CHANGED
@@ -13,6 +13,8 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
+import psutil
+
 from arbor.server.api.models.schemas import (
     GRPOCheckpointRequest,
     GRPOConfigRequest,
@@ -31,8 +33,9 @@ class GRPOManager:
         self.train_kwargs = None
         self.server_comms_handler = None
         self.status_thread = None
-        self.model_saved_and_reload_requested = False
         self.saving_checkpoint = False
+        self.saving_model = False
+        self.terminating = False
 
         self.checkpoints = {}
         self.last_checkpoint = None
@@ -45,8 +48,10 @@ class GRPOManager:
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
         print("\nReceived keyboard interrupt. Shutting down gracefully...")
-
-
+        # Sleep for a bit to let async operations go through
+        time.sleep(2)
+        if self.training_process is not None:
+            self.cleanup_termination(None)
 
     def make_output_dir(
         self, model_name: str, run_suffix: Optional[str] = None
@@ -122,6 +127,17 @@ class GRPOManager:
|
|
122
127
|
|
123
128
|
self.current_model = request.model
|
124
129
|
|
130
|
+
# The inference server has to be launched before the training process
|
131
|
+
# Launch the inference server
|
132
|
+
# launch_kwargs = {
|
133
|
+
# k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
|
134
|
+
# }
|
135
|
+
inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
|
136
|
+
"max_context_length", None
|
137
|
+
)
|
138
|
+
print("Launching inference server...")
|
139
|
+
inference_manager.launch(self.current_model)
|
140
|
+
|
125
141
|
# Initialize ZMQ socket manager - no need for connection acceptance thread anymore
|
126
142
|
self.server_comms_handler = ArborServerCommsHandler()
|
127
143
|
|
@@ -171,6 +187,10 @@ class GRPOManager:
             str(self.server_comms_handler.broadcast_port),
             "--handshake_port",
             str(self.server_comms_handler.handshake_port),
+            "--vllm_port",
+            str(inference_manager.port),
+            "--vllm_group_port",
+            str(inference_manager.group_port),
             # Training args
             "--model",
             self.current_model,
@@ -221,33 +241,38 @@ class GRPOManager:
         self.status_thread.start()
         self.server_comms_handler.wait_for_clients(num_processes)
 
-
-
-
-
-
-
-
-
-
+    async def _handle_weight_update_start(self, inference_manager):
+        """Handle weight update start in the event loop"""
+        await inference_manager.start_weight_update()
+
+    async def _handle_weight_update_complete(self, inference_manager):
+        """Handle weight update complete in the event loop"""
+        await inference_manager.complete_weight_update()
+
+    def _run_in_loop(self, coro):
+        """Run a coroutine in the event loop from a thread"""
+        future = asyncio.run_coroutine_threadsafe(coro, self.event_loop)
+        return future.result()
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
         print("Starting status update handler...")
         try:
-
             for status in self.server_comms_handler.receive_status():
                 print(f"Received status update: {status}")
-                if status["status"] == "
+                if status["status"] == "weight_update_start":
+                    # Block inference calls by incrementing counter
+                    inference_manager.start_weight_update()
+                elif status["status"] == "weight_update_complete":
+                    # Decrement counter to potentially allow inference calls again
+                    inference_manager.complete_weight_update()
+                elif status["status"] == "model_saved":
                     print("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
-
-
-
-                    self.model_saved_and_reload_requested = False
-                    self.current_model = status["output_dir"]
-                    print("Model update complete")
+                    self.current_model = status["output_dir"]
+                    self.saving_model = False
+                    print("Model update complete")
                 elif status["status"] == "checkpoint_saved":
                     print("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
@@ -257,24 +282,19 @@ class GRPOManager:
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
+                    self.terminating = False
                     print("Training process terminated")
-                    break
         except Exception as e:
             print(f"Error in status update handler: {e}")
+            # Make sure to allow inference if there's an error
+            try:
+                inference_manager.complete_weight_update()
+            except:
+                pass
 
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
-        while inference_manager.is_server_restarting():
-            print("Inferece manager restarting, waiting for GRPO step")
-            time.sleep(5)
-
-        while self._should_update_model():
-            print(
-                f"Waiting for model update. Data count: {self.data_count}, Last inference update: {self.last_inference_update}"
-            )
-            time.sleep(5)
-
         while self.saving_checkpoint:
             print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
             time.sleep(5)
@@ -283,8 +303,10 @@ class GRPOManager:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
+
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")
+            raise
 
         return {
             "current_model": self.current_model,
@@ -293,35 +315,22 @@ class GRPOManager:
         }
 
     def update_model(self, request, inference_manager: InferenceManager):
-
-        # Create a new event loop if one doesn't exist
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        # Run the session closure in the event loop
-        loop.run_until_complete(inference_manager._session.close())
-        inference_manager._session = None
-
-        inference_manager.inference_count = 0
-        inference_manager.restarting = True
-
-        self.model_saved_and_reload_requested = True
-        self.server_comms_handler.send_command({"command": "save_model"})
-        while self.model_saved_and_reload_requested:
-            print(
-                "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
-            )
-            time.sleep(5)
+        # No longer used
         return {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
             "last_checkpoint": self.last_checkpoint,
         }
 
-    def checkpoint(
+    def checkpoint(
+        self, request: GRPOCheckpointRequest, inference_manager: InferenceManager
+    ):
+        while (
+            inference_manager.is_updating
+        ):  # Use the property instead of direct access
+            print("Waiting for weight updates to finish before checkpointing...")
+            time.sleep(5)
+
         self.saving_checkpoint = True
         self.server_comms_handler.send_command(
             {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
@@ -337,71 +346,143 @@ class GRPOManager:
 
     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        time.sleep(5)
+
+        while (
+            inference_manager and inference_manager.is_updating
+        ):  # Use the property instead of direct access
+            print("Waiting for final weight updates to finish before saving...")
+            time.sleep(5)
+
+        print("sending save model command")
+        self.saving_model = True
+        self.server_comms_handler.send_command({"command": "save_model"})
+        while self.saving_model:
+            print("Waiting for final model to be saved...")
+            time.sleep(5)
+
         termination_data = {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
             "last_checkpoint": self.last_checkpoint,
         }
-        try:
-            # Stop the inference server
-            if inference_manager.process is not None:
-                inference_manager.kill()
 
-
-
-
-
+        print("sending termination command")
+        self.terminating = True
+        self.server_comms_handler.send_command({"command": "terminate"})
+        print("Waiting for training process to finish...")
+
+        # Wait for at most 15 seconds for termination
+        start_time = time.time()
+        while self.terminating:
+            if time.time() - start_time > 15:
+                print(
+                    "Termination wait timed out after 15 seconds, proceeding with cleanup..."
+                )
+                break
+            print("Waiting for run to be terminated...")
+            time.sleep(3)
+
+        print("Doing cleanup")
+        self.cleanup_termination(inference_manager)
+
+        if self.train_kwargs and "output_dir" in self.train_kwargs:
+            print(
+                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
+            )
+            if not os.path.exists(self.train_kwargs["output_dir"]):
+                print(
+                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
+                )
+            output_dir = self.train_kwargs["output_dir"]
+            self.train_kwargs = None
+        else:
+            print("Training terminated, no output directory specified")
+            self.train_kwargs = None
+
+        return termination_data
 
-
+    def cleanup_termination(self, inference_manager):
+        try:
+            # Kill training process and all its children (accelerate launcher creates multiple processes)
            if self.training_process:
-
+                print("Terminating training process and its children...")
+                try:
+                    parent = psutil.Process(self.training_process.pid)
+                    # Get all child processes including grandchildren
+                    children = parent.children(recursive=True)
+
+                    # Send SIGTERM to children first
+                    for child in children:
+                        try:
+                            child.send_signal(signal.SIGTERM)
+                        except psutil.NoSuchProcess:
+                            pass
+
+                    # Send SIGTERM to parent
+                    parent.send_signal(signal.SIGTERM)
+
+                    # Wait for processes to terminate gracefully
+                    gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+                    # If any processes are still alive, force kill them
+                    for p in alive:
+                        try:
+                            p.kill()  # SIGKILL
+                        except psutil.NoSuchProcess:
+                            pass
+
+                except psutil.NoSuchProcess:
+                    print(f"Process {self.training_process.pid} not found")
+                except Exception as e:
+                    print(f"Error killing training process tree: {e}")
+                    # Fallback to basic termination
+                    self.training_process.terminate()
+                    try:
+                        self.training_process.wait(timeout=10)
+                    except subprocess.TimeoutExpired:
+                        self.training_process.kill()
+                        self.training_process.wait(timeout=10)
 
-        except Exception as e:
-            print(f"Error during termination: {e}")
-        finally:
            # Clean up ZMQ connections
            if self.server_comms_handler:
+                print("Closing ZMQ connections...")
                self.server_comms_handler.close()
 
-
-
-
-            self.training_process.wait()
+            if inference_manager and inference_manager.process is not None:
+                print("Killing inference manager...")
+                inference_manager.kill()
 
-            # Reinitialize
+            # Reinitialize in case we want to start a new training run
+            self.training_process = None
+            self.current_model = None
+            self.server_comms_handler = None
+            self.status_thread = None
+            self.data_count = 0
+            print("Cleanup completed successfully")
+        except Exception as e:
+            print(f"Error during cleanup: {e}")
+            # Still reset state even if cleanup fails
            self.training_process = None
            self.current_model = None
            self.server_comms_handler = None
            self.status_thread = None
-            self.model_saved_and_reload_requested = False
-
            self.data_count = 0
-            self.last_inference_update = 0
-
-            if self.train_kwargs and "output_dir" in self.train_kwargs:
-                print(
-                    f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-                )
-                if not os.path.exists(self.train_kwargs["output_dir"]):
-                    print(
-                        f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                    )
-                output_dir = self.train_kwargs["output_dir"]
-                self.train_kwargs = None
-            else:
-                print("Training terminated, no output directory specified")
-                self.train_kwargs = None
-
-            return termination_data
-
-    def _should_update_model(self):
-        return self.model_saved_and_reload_requested
 
 
 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
     """
-
-
-
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)
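The status handler above calls start_weight_update and complete_weight_update and polls is_updating on the inference manager; the rewritten inference_manager.py is not reproduced in this view, but its inline comments ("Block inference calls by incrementing counter") point to a counter-based gate. A minimal sketch of that pattern, assuming it is a simple thread-safe counter; everything except the three method names is illustrative, not the package's actual implementation:

import threading

class WeightUpdateGate:
    """Counter-based gate: inference is treated as blocked while any weight update is in flight."""

    def __init__(self):
        self._lock = threading.Lock()
        self._pending_updates = 0  # illustrative attribute name

    def start_weight_update(self):
        # Block inference calls by incrementing the counter
        with self._lock:
            self._pending_updates += 1

    def complete_weight_update(self):
        # Decrement the counter to potentially allow inference calls again
        with self._lock:
            self._pending_updates = max(0, self._pending_updates - 1)

    @property
    def is_updating(self) -> bool:
        with self._lock:
            return self._pending_updates > 0

The checkpoint and terminate paths in grpo_manager.py simply spin on is_updating before saving, which is why the error handler also calls complete_weight_update to avoid leaving the gate permanently closed.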