arbor-ai 0.1.12__tar.gz → 0.1.14__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {arbor_ai-0.1.12/arbor_ai.egg-info → arbor_ai-0.1.14}/PKG-INFO +2 -1
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/models/schemas.py +22 -1
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/grpo.py +15 -6
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/grpo_manager.py +65 -17
- arbor_ai-0.1.14/arbor/server/services/inference/sgl_router_launch_server.py +226 -0
- arbor_ai-0.1.14/arbor/server/services/inference_manager.py +404 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/scripts/grpo_training.py +159 -43
- arbor_ai-0.1.14/arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14/arbor_ai.egg-info}/PKG-INFO +2 -1
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/SOURCES.txt +2 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/requires.txt +1 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/pyproject.toml +3 -2
- arbor_ai-0.1.12/arbor/server/services/inference_manager.py +0 -299
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/LICENSE +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/README.md +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/cli.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/client/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/client/api.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/files.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/inference.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/jobs.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/config.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/core/logging.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/main.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/comms/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/comms/comms.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/dependencies.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/file_manager.py +0 -0
- {arbor_ai-0.1.12/arbor/server/utils → arbor_ai-0.1.14/arbor/server/services/inference}/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/job_manager.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/training_manager.py +0 -0
- /arbor_ai-0.1.12/arbor/server/utils/helpers.py → /arbor_ai-0.1.14/arbor/server/utils/__init__.py +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/dependency_links.txt +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/entry_points.txt +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor_ai.egg-info/top_level.txt +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/setup.cfg +0 -0
- {arbor_ai-0.1.12 → arbor_ai-0.1.14}/tests/test_cli.py +0 -0
{arbor_ai-0.1.12/arbor_ai.egg-info → arbor_ai-0.1.14}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.12
+Version: 0.1.14
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor

@@ -23,6 +23,7 @@ Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
 Requires-Dist: sglang[all]>=0.4.5.post3
 Requires-Dist: sglang-router
+Requires-Dist: wandb
 Dynamic: license-file
 
 <p align="center">
{arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/models/schemas.py

@@ -199,10 +199,16 @@ class GRPOConfigRequest(BaseModel):
     bf16: Optional[bool] = None
     scale_rewards: Optional[bool] = None
     max_grad_norm: Optional[float] = None
+    report_to: Optional[str] = None
+    log_completions: Optional[bool] = None
+    logging_steps: Optional[int] = None
+    mask_truncated_completions: Optional[bool] = None
+    # Arbor specific
+    max_context_length: Optional[int] = None
     lora: Optional[bool] = None
-    update_interval: Optional[int] = None
     # To name the run
     suffix: Optional[str] = None
+    generation_batch_size: Optional[int] = None
 
 
 class GRPOConfigResponse(BaseModel):
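Most of the new request fields mirror TRL `GRPOConfig` options that are now forwarded to the trainer (`report_to`, `log_completions`, `logging_steps`, `mask_truncated_completions`, `generation_batch_size`), while `max_context_length` is Arbor-specific and `update_interval` is dropped. A hedged sketch of a 0.1.14 config payload using only the new fields; the values are illustrative and any other required `GRPOConfigRequest` fields are omitted:

```python
# Illustrative payload for the updated GRPOConfigRequest (new 0.1.14 fields only).
# Values are assumptions; they are not taken from the package.
grpo_config = {
    "report_to": "wandb",             # forwarded to the TRL trainer
    "log_completions": True,
    "logging_steps": 10,
    "mask_truncated_completions": True,
    "generation_batch_size": 8,
    # Arbor-specific
    "max_context_length": 4096,       # also handed to the inference server at launch
    "lora": False,
    "suffix": "my-run",               # pre-existing field, shown for context
}
```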
@@ -216,8 +222,23 @@ class GRPOTerminateRequest(BaseModel):
 class GRPOTerminateResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: Optional[dict[str, str]] = None
+    last_checkpoint: Optional[str] = None
 
 
 class GRPOStepResponse(BaseModel):
     status: str
     current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: Optional[str] = None
+
+
+class GRPOCheckpointRequest(BaseModel):
+    checkpoint_name: str
+
+
+class GRPOCheckpointResponse(BaseModel):
+    status: str
+    current_model: str
+    checkpoints: dict[str, str]
+    last_checkpoint: str
{arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/api/routes/grpo.py

@@ -4,6 +4,8 @@ import subprocess
 from fastapi import APIRouter, BackgroundTasks, Request
 
 from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOCheckpointResponse,
     GRPOConfigRequest,
     GRPOConfigResponse,
     GRPORequest,
@@ -31,17 +33,24 @@ def run_grpo_step(
     inference_manager = request.app.state.inference_manager
     grpo_manager = request.app.state.grpo_manager
 
-    …
+    step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
 
-    return GRPOStepResponse(status="success", …
+    return GRPOStepResponse(status="success", **step_data)
 
 
 @router.post("/update_model", response_model=GRPOStepResponse)
 def update_model(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
-    …
-    return GRPOStepResponse(status="success", …
+    update_model_data = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", **update_model_data)
+
+
+@router.post("/checkpoint", response_model=GRPOCheckpointResponse)
+def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
+    grpo_manager = request.app.state.grpo_manager
+    checkpoint_data = grpo_manager.checkpoint(grpo_checkpoint_request)
+    return GRPOCheckpointResponse(status="success", **checkpoint_data)
 
 
 @router.post("/terminate", response_model=GRPOTerminateResponse)
@@ -50,5 +59,5 @@ def terminate_grpo(request: Request):
     grpo_manager = request.app.state.grpo_manager
     inference_manager = request.app.state.inference_manager
 
-    …
-    return GRPOTerminateResponse(status="success", …
+    terminate_data = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", **terminate_data)
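The new `/checkpoint` route exposes the manager's blocking `checkpoint()` call over HTTP. A minimal client sketch; the host, port, and the prefix under which this `APIRouter` is mounted are assumptions, only the `/checkpoint` suffix and the request/response fields come from the diff:

```python
import requests

# Assumed base URL; the real host, port, and router prefix depend on the Arbor deployment.
BASE_URL = "http://localhost:8000/v1/fine_tuning/grpo"

resp = requests.post(
    f"{BASE_URL}/checkpoint",
    json={"checkpoint_name": "step_100"},  # GRPOCheckpointRequest
)
resp.raise_for_status()
data = resp.json()  # GRPOCheckpointResponse
print(data["status"], data["last_checkpoint"])
print(data["checkpoints"])  # e.g. {"step_100": "<output_dir of the saved checkpoint>"}
```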
{arbor_ai-0.1.12 → arbor_ai-0.1.14}/arbor/server/services/grpo_manager.py

@@ -13,7 +13,11 @@ from datetime import datetime
 from pathlib import Path
 from typing import Optional
 
-from arbor.server.api.models.schemas import …
+from arbor.server.api.models.schemas import (
+    GRPOCheckpointRequest,
+    GRPOConfigRequest,
+    GRPORequest,
+)
 from arbor.server.core.config import Settings
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
@@ -28,7 +32,10 @@ class GRPOManager:
         self.server_comms_handler = None
         self.status_thread = None
         self.model_saved_and_reload_requested = False
+        self.saving_checkpoint = False
 
+        self.checkpoints = {}
+        self.last_checkpoint = None
         self.data_count = 0
         self.last_inference_update = 0
         # Set up signal handler
@@ -86,12 +93,17 @@ class GRPOManager:
             "bf16",
             "scale_rewards",
             "max_grad_norm",
+            "report_to",
+            "log_completions",
+            "logging_steps",
+            "generation_batch_size",
+            "mask_truncated_completions",
         ]
         trl_train_kwargs = {
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["…
+        arbor_keys = ["max_context_length", "lora"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -119,6 +131,8 @@ class GRPOManager:
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        # WandB can block the training process for login, so we silence it
+        my_env["WANDB_SILENT"] = "true"
 
         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
 
|
|
209
223
|
|
210
224
|
# Launch the inference server
|
211
225
|
print("Launching inference server...")
|
226
|
+
# launch_kwargs = {
|
227
|
+
# k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
|
228
|
+
# }
|
229
|
+
inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
|
230
|
+
"max_context_length", None
|
231
|
+
)
|
212
232
|
inference_manager.launch(self.current_model)
|
213
233
|
|
214
234
|
def _handle_status_updates(self, inference_manager: InferenceManager):
|
@@ -228,6 +248,12 @@ class GRPOManager:
                     self.model_saved_and_reload_requested = False
                     self.current_model = status["output_dir"]
                     print("Model update complete")
+                elif status["status"] == "checkpoint_saved":
+                    print("Received checkpoint saved status")
+                    self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
+                    self.last_checkpoint = status["checkpoint_name"]
+                    self.saving_checkpoint = False
+                    print("Checkpoint saved")
                 elif status["status"] == "error":
                     print(f"Training error: {status.get('error', 'Unknown error')}")
                 elif status["status"] == "terminated":
@@ -249,6 +275,10 @@ class GRPOManager:
             )
             time.sleep(5)
 
+        while self.saving_checkpoint:
+            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            time.sleep(5)
+
         try:
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
@@ -256,12 +286,11 @@ class GRPOManager:
         except Exception as e:
             print(f"Failed to send batch to training process: {e}")
 
-
-
-
-
-
-        return self.current_model
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
 
     def update_model(self, request, inference_manager: InferenceManager):
         if inference_manager._session:
@@ -286,18 +315,41 @@ class GRPOManager:
                 "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
             )
             time.sleep(5)
-        return
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
+
+    def checkpoint(self, request: GRPOCheckpointRequest):
+        self.saving_checkpoint = True
+        self.server_comms_handler.send_command(
+            {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
+        )
+        while self.saving_checkpoint:
+            print("Waiting for checkpoint to be saved...")
+            time.sleep(5)
+        return {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
 
     def terminate(self, inference_manager: InferenceManager):
         """Clean up resources and save the final model."""
+        termination_data = {
+            "current_model": self.current_model,
+            "checkpoints": self.checkpoints,
+            "last_checkpoint": self.last_checkpoint,
+        }
         try:
             # Stop the inference server
             if inference_manager.process is not None:
                 inference_manager.kill()
 
             # Send termination command through REQ socket
-            …
-            self.training_process.terminate()
+            self.server_comms_handler.send_broadcast({"message": "terminate"})
+            # self.training_process.terminate()
             print("Waiting for training process to finish")
 
             # Wait for training process to finish
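`checkpoint()` is a blocking handshake: the manager raises `saving_checkpoint`, sends a `save_checkpoint` command over the comms handler, and polls until `_handle_status_updates` (earlier hunk) sees a `checkpoint_saved` status and clears the flag. A sketch of the two message shapes involved; the keys come from the hunks above, the values and path are hypothetical:

```python
# Sent by GRPOManager.checkpoint() on the command channel:
command_msg = {"command": "save_checkpoint", "checkpoint_name": "step_100"}

# Expected back from the training script on the status channel; on receipt the manager
# records the checkpoint, updates last_checkpoint, and unblocks checkpoint()/grpo_step():
status_msg = {
    "status": "checkpoint_saved",
    "checkpoint_name": "step_100",
    "output_dir": "/tmp/arbor/run/checkpoints/step_100",  # hypothetical path
}
```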
@@ -336,17 +388,13 @@ class GRPOManager:
             )
             output_dir = self.train_kwargs["output_dir"]
             self.train_kwargs = None
-            return output_dir
         else:
             print("Training terminated, no output directory specified")
             self.train_kwargs = None
-            …
+
+        return termination_data
 
     def _should_update_model(self):
-        # return (
-        #     self.data_count - self.last_inference_update
-        #     >= self.train_kwargs["update_interval"]
-        # )
         return self.model_saved_and_reload_requested
 
 
arbor_ai-0.1.14/arbor/server/services/inference/sgl_router_launch_server.py

@@ -0,0 +1,226 @@
+import argparse
+import copy
+import json
+import logging
+import multiprocessing as mp
+import os
+import random
+import signal
+import sys
+import time
+from typing import List
+
+import requests
+import zmq
+from setproctitle import setproctitle
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available
+from sglang_router.launch_router import RouterArgs, launch_router
+
+
+def setup_logger():
+    logger = logging.getLogger("router")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+# Create new process group
+def run_server(server_args, dp_rank):
+    """
+    Note:
+
+    1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
+    This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=100)
+        └── Server Process 1 (PGID=100)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=100)
+            └── Scheduler 2
+            └── Detokenizer 2
+
+    2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=200)
+        └── Server Process 1 (PGID=300)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=400)
+            └── Scheduler 2
+            └── Detokenizer 2
+    """
+    # create new process group
+    os.setpgrp()
+
+    setproctitle("sglang::server")
+    # Set SGLANG_DP_RANK environment variable
+    os.environ["SGLANG_DP_RANK"] = str(dp_rank)
+
+    launch_server(server_args)
+
+
+def launch_server_process(
+    server_args: ServerArgs, worker_port: int, dp_id: int
+) -> mp.Process:
+    """Launch a single server process with the given args and port."""
+    server_args = copy.deepcopy(server_args)
+    server_args.port = worker_port
+    server_args.base_gpu_id = dp_id * server_args.tp_size
+    server_args.dp_size = 1
+
+    proc = mp.Process(target=run_server, args=(server_args, dp_id))
+    proc.start()
+    return proc
+
+
+def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
+    """Wait for server to be healthy by checking /health endpoint."""
+    start_time = time.time()
+    url = f"http://{host}:{port}/health"
+
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(url, timeout=5)
+            if response.status_code == 200:
+                return True
+        except requests.exceptions.RequestException:
+            pass
+        time.sleep(1)
+    return False
+
+
+def find_available_ports(base_port: int, count: int) -> List[int]:
+    """Find consecutive available ports starting from base_port."""
+    available_ports = []
+    current_port = base_port
+
+    while len(available_ports) < count:
+        if is_port_available(current_port):
+            available_ports.append(current_port)
+        current_port += random.randint(100, 1000)
+
+    return available_ports
+
+
+def cleanup_processes(processes: List[mp.Process]):
+    for process in processes:
+        logger.info(f"Terminating process group {process.pid}")
+        try:
+            os.killpg(process.pid, signal.SIGTERM)
+        except ProcessLookupError:
+            # Process group may already be terminated
+            pass
+
+    # Wait for processes to terminate
+    for process in processes:
+        process.join(timeout=5)
+        if process.is_alive():
+            logger.warning(
+                f"Process {process.pid} did not terminate gracefully, forcing kill"
+            )
+            try:
+                os.killpg(process.pid, signal.SIGKILL)
+            except ProcessLookupError:
+                pass
+
+    logger.info("All process groups terminated")
+
+
+def main():
+    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
+    mp.set_start_method("spawn")
+
+    parser = argparse.ArgumentParser(
+        description="Launch SGLang router and server processes"
+    )
+
+    ServerArgs.add_cli_args(parser)
+    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
+    parser.add_argument(
+        "--router-dp-worker-base-port",
+        type=int,
+        default=31000,
+        help="Base port number for data parallel workers",
+    )
+    parser.add_argument(
+        "--worker-urls-port",
+        type=int,
+        help="Port number for worker URLs publisher",
+    )
+
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
+
+    # Find available ports for workers
+    worker_ports = find_available_ports(
+        args.router_dp_worker_base_port, server_args.dp_size
+    )
+
+    # Start server processes
+    server_processes = []
+
+    for i, worker_port in enumerate(worker_ports):
+        logger.info(f"Launching DP server process {i} on port {worker_port}")
+        proc = launch_server_process(server_args, worker_port, i)
+        server_processes.append(proc)
+
+    signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
+    signal.signal(
+        signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
+    )
+    signal.signal(
+        signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
+    )
+
+    # Update router args with worker URLs
+    worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
+    router_args.worker_urls = worker_urls
+
+    # Publish worker URLs via ZMQ if port is specified
+    if args.worker_urls_port:
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.PUB)
+            socket.bind(f"tcp://*:{args.worker_urls_port}")
+            # Give subscribers time to connect
+            time.sleep(0.1)
+            socket.send_json({"type": "worker_urls", "urls": worker_urls})
+            logger.info(
+                f"Published worker URLs via ZMQ on port {args.worker_urls_port}"
+            )
+            socket.close()
+            context.term()
+        except Exception as e:
+            logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
+            cleanup_processes(server_processes)
+            sys.exit(1)
+
+    # Start the router
+    try:
+        launch_router(router_args)
+    except Exception as e:
+        logger.error(f"Failed to start router: {e}")
+        cleanup_processes(server_processes)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()