arbor-ai 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/services/inference/__init__.py +0 -0
- arbor/server/services/inference/sgl_router_launch_server.py +226 -0
- arbor/server/services/inference_manager.py +68 -11
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/METADATA +2 -2
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/RECORD +9 -7
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.14.dist-info}/top_level.txt +0 -0
File without changes
|
@@ -0,0 +1,226 @@
|
|
1
|
+
import argparse
|
2
|
+
import copy
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import multiprocessing as mp
|
6
|
+
import os
|
7
|
+
import random
|
8
|
+
import signal
|
9
|
+
import sys
|
10
|
+
import time
|
11
|
+
from typing import List
|
12
|
+
|
13
|
+
import requests
|
14
|
+
import zmq
|
15
|
+
from setproctitle import setproctitle
|
16
|
+
from sglang.srt.entrypoints.http_server import launch_server
|
17
|
+
from sglang.srt.server_args import ServerArgs
|
18
|
+
from sglang.srt.utils import is_port_available
|
19
|
+
from sglang_router.launch_router import RouterArgs, launch_router
|
20
|
+
|
21
|
+
|
22
|
+
def setup_logger():
    """Return the ``router`` logger, configuring it on first use.

    The logger is set to INFO and, if it has no handler yet, gets a stderr
    ``StreamHandler`` whose records are prefixed with ``[Router (Python)]``.

    ``logging.getLogger("router")`` returns the same object on every call,
    and other imported modules may already have configured this logger name.
    Adding a handler unconditionally would therefore stack handlers and print
    every record more than once, so a handler is only attached when none is
    present.

    Returns:
        logging.Logger: the configured ``router`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# Module-wide logger used by the launcher functions below.
logger = setup_logger()
|
39
|
+
|
40
|
+
|
41
|
+
def run_server(server_args, dp_rank):
    """Entry point of one data-parallel SGLang server child process.

    Runs as the target of a ``multiprocessing.Process``. It moves itself into
    a new process group, sets the process title to ``sglang::server``,
    exports its rank as ``SGLANG_DP_RANK`` and then calls ``launch_server``.
    That call blocks for the lifetime of the server.

    Args:
        server_args: per-worker ``ServerArgs`` (port / GPU offset already
            specialised by the caller).
        dp_rank: data-parallel index of this worker.

    Why a new process group:

    1. Without ``os.setpgrp()`` every process shares the terminal's PGID.
       On Ctrl+C the terminal sends SIGINT to the whole group at once. Leaf
       processes may then die before their parents, which breaks the cleanup
       order and can leave orphaned processes::

           Terminal (PGID=100)
           └── Main Python Process (PGID=100)
               ├── Server Process 1 (PGID=100)
               │   ├── Scheduler 1
               │   └── Detokenizer 1
               └── Server Process 2 (PGID=100)
                   ├── Scheduler 2
                   └── Detokenizer 2

    2. With ``os.setpgrp()`` called here, in each server process, every
       server becomes the leader of its own group (PGID == its PID). The
       processes it spawns later inherit that group. The main process stays
       in the terminal's group, so a terminal SIGINT only reaches it. It can
       then tear down each server group in order with
       ``os.killpg(server_pid, ...)``::

           Terminal (PGID=100)
           └── Main Python Process (PGID=100)
               ├── Server Process 1 (PGID=300)
               │   ├── Scheduler 1
               │   └── Detokenizer 1
               └── Server Process 2 (PGID=400)
                   ├── Scheduler 2
                   └── Detokenizer 2
    """
    # Create a new process group led by this process (PGID == PID) so the
    # parent can signal this server and all of its descendants via killpg.
    os.setpgrp()

    setproctitle("sglang::server")
    # Exported for SGLang code running in this process; presumably used to
    # identify the data-parallel rank — TODO confirm against sglang.
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)

    launch_server(server_args)
|
77
|
+
|
78
|
+
|
79
|
+
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Start one data-parallel worker in a child process and return it.

    A deep copy of *server_args* is specialised for this worker, leaving the
    caller's object untouched:

    * ``port`` becomes *worker_port*;
    * ``base_gpu_id`` becomes ``dp_id * tp_size``, so consecutive workers use
      disjoint blocks of ``tp_size`` GPUs;
    * ``dp_size`` is forced to 1, because each child serves one replica.

    The child runs ``run_server`` with rank *dp_id* and is already started
    when this function returns.
    """
    worker_args = copy.deepcopy(server_args)
    gpu_offset = dp_id * worker_args.tp_size
    worker_args.port, worker_args.base_gpu_id, worker_args.dp_size = (
        worker_port,
        gpu_offset,
        1,
    )

    child = mp.Process(target=run_server, args=(worker_args, dp_id))
    child.start()
    return child
|
91
|
+
|
92
|
+
|
93
|
+
def wait_for_server_health(
    host: str, port: int, timeout: int = 300, poll_interval: float = 1.0
) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers with HTTP 200.

    Args:
        host: hostname or IP address of the server to probe.
        port: TCP port of the server.
        timeout: overall time budget in seconds. A value <= 0 returns
            ``False`` immediately without sending any request.
        poll_interval: seconds to sleep after a failed probe.

    Returns:
        ``True`` as soon as a probe returns status 200, ``False`` once the
        time budget is exhausted.
    """
    url = f"http://{host}:{port}/health"
    # Monotonic clock: wall-clock adjustments (NTP, manual changes) must not
    # shorten or stretch the wait.
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            # A single probe must never outlive the overall deadline.
            response = requests.get(url, timeout=min(5, remaining))
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Server not reachable yet (refused, timed out, ...): retry.
            pass
        time.sleep(poll_interval)
|
107
|
+
|
108
|
+
|
109
|
+
def find_available_ports(
    base_port: int, count: int, max_port: int = 65535
) -> List[int]:
    """Return *count* free TCP ports at or above *base_port*.

    Candidates are checked with ``is_port_available`` starting at
    *base_port*. After each check, successful or not, the next candidate is
    a random 100-1000 ports higher. The returned ports are therefore
    ascending but NOT consecutive.

    Args:
        base_port: first port to check.
        count: number of ports wanted; ``<= 0`` returns ``[]``.
        max_port: highest port that may be checked (65535 is the TCP limit).

    Returns:
        The free ports found, in ascending order.

    Raises:
        RuntimeError: if the search passes *max_port* before *count* free
            ports are found. Without this bound the search would run past the
            valid port range. Depending on how ``is_port_available`` treats
            invalid ports, it might then never terminate.

    Note: a port reported free here can still be taken by another process
    before the worker actually binds it.
    """
    available_ports: List[int] = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Found only {len(available_ports)} of {count} free ports "
                f"between {base_port} and {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports
|
120
|
+
|
121
|
+
|
122
|
+
def _signal_group(process: mp.Process, sig: int) -> None:
    """Send *sig* to the process group led by *process*, or to the process itself if that group does not exist."""
    try:
        os.killpg(process.pid, sig)
    except ProcessLookupError:
        # No group with this id: everything already exited, or the child has
        # not reached os.setpgrp() yet (it is still in our group). Signal the
        # child directly if it has not exited; checking exitcode first also
        # reaps it, so we never signal a recycled PID.
        if process.exitcode is None:
            try:
                os.kill(process.pid, sig)
            except ProcessLookupError:
                pass


def cleanup_processes(processes: List[mp.Process]):
    """Terminate every server process group, escalating to SIGKILL.

    Each server calls ``os.setpgrp()`` at startup, so its process group id
    equals its PID. ``os.killpg(process.pid, ...)`` therefore reaches the
    server together with the processes it spawned.

    Phase 1 sends SIGTERM to every group. Phase 2 waits up to 5 s per
    process and sends SIGKILL to any group whose leader is still alive. It
    then joins the process so it does not remain a zombie. Processes that
    were never started (``pid is None``) are skipped.
    """
    started = [p for p in processes if p.pid is not None]

    for process in started:
        logger.info(f"Terminating process group {process.pid}")
        _signal_group(process, signal.SIGTERM)

    # Wait for processes to terminate
    for process in started:
        process.join(timeout=5)
        if process.is_alive():
            logger.warning(
                f"Process {process.pid} did not terminate gracefully, forcing kill"
            )
            _signal_group(process, signal.SIGKILL)
            # Reap the killed child.
            process.join(timeout=5)

    logger.info("All process groups terminated")
|
144
|
+
|
145
|
+
|
146
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: SGLang server args, ``--router-*`` args and launcher-specific options."""
    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )
    parser.add_argument(
        "--worker-urls-port",
        type=int,
        help="Port number for worker URLs publisher",
    )
    return parser


def _install_shutdown_handlers(server_processes: List[mp.Process]) -> None:
    """Make SIGINT, SIGTERM and SIGQUIT tear down all worker groups and then exit the launcher."""

    def _handle(signum, _frame):
        cleanup_processes(server_processes)
        # The router cannot serve anything once its workers are gone, so do
        # not resume; 128 + signum is the conventional "killed by signal" code.
        sys.exit(128 + signum)

    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(sig, _handle)


def _publish_worker_urls(
    port: int, worker_urls: List[str], subscriber_timeout: float = 10.0
) -> None:
    """Send ``{"type": "worker_urls", "urls": worker_urls}`` once on ``tcp://*:{port}``.

    A plain PUB socket silently drops messages sent before a subscriber has
    finished connecting (ZeroMQ's "slow joiner" problem), and a fixed short
    sleep cannot rule that out. An XPUB socket is used instead. It receives a
    notification when a subscription arrives, so we wait up to
    *subscriber_timeout* seconds for one before sending. If none arrives, the
    message is still sent and a warning is logged, since it will most likely
    be lost. Exceptions from binding or sending propagate to the caller. The
    socket and context are always released.
    """
    context = zmq.Context()
    socket = context.socket(zmq.XPUB)
    try:
        socket.bind(f"tcp://*:{port}")
        if socket.poll(timeout=int(subscriber_timeout * 1000), flags=zmq.POLLIN):
            # Consume the subscription notification; its content is irrelevant.
            socket.recv()
        else:
            logger.warning(
                f"No subscriber connected on port {port} within "
                f"{subscriber_timeout}s; worker URLs may not be received"
            )
        socket.send_json({"type": "worker_urls", "urls": worker_urls})
    finally:
        socket.close()
        context.term()


def main():
    """Launch ``dp_size`` SGLang server workers and an SGLang router in front of them.

    Steps:

    1. Parse the CLI arguments.
    2. Start one server process per data-parallel rank on ports chosen by
       ``find_available_ports``.
    3. Install signal handlers that tear the workers down.
    4. If ``--worker-urls-port`` is given, publish the worker URLs over
       ZeroMQ so a parent process can reach each worker directly.
    5. Run the router.

    If publishing or starting the router fails, the workers are cleaned up
    and the process exits with status 1.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    args = _build_arg_parser().parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    # Start server processes
    server_processes = []
    for i, worker_port in enumerate(worker_ports):
        logger.info(f"Launching DP server process {i} on port {worker_port}")
        server_processes.append(launch_server_process(server_args, worker_port, i))

    _install_shutdown_handlers(server_processes)

    # Update router args with worker URLs
    worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
    router_args.worker_urls = worker_urls

    # Publish worker URLs via ZMQ if port is specified
    if args.worker_urls_port:
        try:
            _publish_worker_urls(args.worker_urls_port, worker_urls)
            logger.info(
                f"Published worker URLs via ZMQ on port {args.worker_urls_port}"
            )
        except Exception as e:
            logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
            cleanup_processes(server_processes)
            sys.exit(1)

    # Start the router
    try:
        launch_router(router_args)
    except Exception as e:
        logger.error(f"Failed to start router: {e}")
        cleanup_processes(server_processes)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
@@ -13,6 +13,7 @@ from typing import Any, Dict, Optional
|
|
13
13
|
|
14
14
|
import aiohttp
|
15
15
|
import requests
|
16
|
+
import zmq
|
16
17
|
|
17
18
|
from arbor.server.core.config import Settings
|
18
19
|
|
@@ -28,6 +29,7 @@ class InferenceManager:
|
|
28
29
|
self.current_model = None
|
29
30
|
self.inference_count = 0
|
30
31
|
self._session = None
|
32
|
+
self.worker_urls = []
|
31
33
|
# Set up signal handler for graceful shutdown
|
32
34
|
signal.signal(signal.SIGINT, self._signal_handler)
|
33
35
|
signal.signal(signal.SIGTERM, self._signal_handler)
|
@@ -74,15 +76,17 @@ class InferenceManager:
|
|
74
76
|
print(
|
75
77
|
f"Grabbing a free port to launch an SGLang server for model {model}"
|
76
78
|
)
|
77
|
-
|
79
|
+
router_port = get_free_port()
|
80
|
+
dp_worker_base_port = get_free_port()
|
81
|
+
worker_urls_port = get_free_port() # Get a port for worker URLs
|
82
|
+
|
78
83
|
timeout = launch_kwargs.get("timeout", 1800)
|
79
84
|
my_env = os.environ.copy()
|
80
85
|
my_env["CUDA_VISIBLE_DEVICES"] = (
|
81
86
|
self.settings.arbor_config.inference.gpu_ids
|
82
87
|
)
|
83
88
|
n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
|
84
|
-
|
85
|
-
command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
|
89
|
+
command = f"python -m arbor.server.services.inference.sgl_router_launch_server --model-path {model} --dp-size {n_gpus} --port {router_port} --host 0.0.0.0 --disable-radix-cache --router-dp-worker-base-port {dp_worker_base_port} --worker-urls-port {worker_urls_port}"
|
86
90
|
print(f"Running command: {command}")
|
87
91
|
if launch_kwargs.get("max_context_length"):
|
88
92
|
command += (
|
@@ -124,8 +128,16 @@ class InferenceManager:
|
|
124
128
|
)
|
125
129
|
thread.start()
|
126
130
|
|
131
|
+
# Get worker URLs before waiting for server
|
132
|
+
try:
|
133
|
+
worker_urls = get_worker_urls(worker_urls_port)
|
134
|
+
print(f"Received worker URLs: {worker_urls}")
|
135
|
+
self.worker_urls = worker_urls
|
136
|
+
except TimeoutError as e:
|
137
|
+
raise Exception(f"Failed to get worker URLs: {e}")
|
138
|
+
|
127
139
|
# Wait until the server is ready (or times out)
|
128
|
-
base_url = f"http://localhost:{
|
140
|
+
base_url = f"http://localhost:{router_port}"
|
129
141
|
try:
|
130
142
|
wait_for_server(base_url, timeout=timeout)
|
131
143
|
except TimeoutError:
|
@@ -142,9 +154,9 @@ class InferenceManager:
|
|
142
154
|
return "".join(logs_buffer)
|
143
155
|
|
144
156
|
# Let the user know server is up
|
145
|
-
print(f"Server ready on random port {
|
157
|
+
print(f"Server ready on random port {router_port}!")
|
146
158
|
|
147
|
-
self.launch_kwargs["api_base"] = f"http://localhost:{
|
159
|
+
self.launch_kwargs["api_base"] = f"http://localhost:{router_port}/v1"
|
148
160
|
self.launch_kwargs["api_key"] = "local"
|
149
161
|
self.get_logs = get_logs
|
150
162
|
self.process = process
|
@@ -286,9 +298,10 @@ class InferenceManager:
|
|
286
298
|
self.inference_count = 0
|
287
299
|
|
288
300
|
tik = time.time()
|
289
|
-
self.kill()
|
290
|
-
print("Just killed server")
|
291
|
-
time.sleep(5)
|
301
|
+
# self.kill()
|
302
|
+
# print("Just killed server")
|
303
|
+
# time.sleep(5)
|
304
|
+
|
292
305
|
# Check that output directory exists and was created successfully
|
293
306
|
print(f"Checking that output directory {output_dir} exists")
|
294
307
|
if not os.path.exists(output_dir):
|
@@ -296,8 +309,27 @@ class InferenceManager:
|
|
296
309
|
f"Failed to save model - output directory {output_dir} does not exist"
|
297
310
|
)
|
298
311
|
|
299
|
-
print("
|
300
|
-
|
312
|
+
print("Directly updating weights from disk")
|
313
|
+
for worker_url in self.worker_urls:
|
314
|
+
print(f"Updating weights from disk for worker {worker_url}")
|
315
|
+
try:
|
316
|
+
response = requests.post(
|
317
|
+
f"{worker_url}/update_weights_from_disk",
|
318
|
+
json={"model_path": output_dir},
|
319
|
+
)
|
320
|
+
response_json = response.json()
|
321
|
+
print(f"Response from update_weights_from_disk: {response_json}")
|
322
|
+
# TODO: Check that the response is successful
|
323
|
+
except Exception as e:
|
324
|
+
print(f"Error during update_weights_from_disk: {e}")
|
325
|
+
print(f"Full error during update_weights_from_disk: {str(e)}")
|
326
|
+
if hasattr(e, "response") and e.response is not None:
|
327
|
+
print(f"Response status code: {e.response.status_code}")
|
328
|
+
print(f"Response text: {e.response.text}")
|
329
|
+
self.current_model = output_dir
|
330
|
+
|
331
|
+
# print("Launching new server")
|
332
|
+
# self.launch(output_dir, self.launch_kwargs)
|
301
333
|
tok = time.time()
|
302
334
|
self.restarting = False
|
303
335
|
print(f"Time taken to update model: {tok - tik} seconds")
|
@@ -345,3 +377,28 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
|
|
345
377
|
except requests.exceptions.RequestException:
|
346
378
|
# Server not up yet, wait and retry
|
347
379
|
time.sleep(1)
|
380
|
+
|
381
|
+
|
382
|
+
def get_worker_urls(zmq_port: int, timeout: float = 30.0) -> list:
    """Receive the DP worker URLs published by the router launcher.

    Subscribes to ``tcp://localhost:{zmq_port}`` and waits for one message of
    the form ``{"type": "worker_urls", "urls": [...]}``.

    Args:
        zmq_port: port the launcher's publisher socket binds to.
        timeout: seconds to wait for the message.

    Returns:
        The ``urls`` list from the message.

    Raises:
        TimeoutError: if no message arrives within *timeout* seconds.
        ValueError: if the first message received has an unexpected ``type``.
    """
    print(f"Attempting to get worker URLs on port {zmq_port} with timeout {timeout}s")
    context = zmq.Context()
    socket = context.socket(zmq.SUB)
    try:
        # Drop anything still queued when the socket closes. This includes
        # the subscription request to a publisher that never came up. With
        # the default infinite linger, context.term() below could then block
        # forever on the timeout path.
        socket.setsockopt(zmq.LINGER, 0)
        socket.connect(f"tcp://localhost:{zmq_port}")
        socket.setsockopt_string(zmq.SUBSCRIBE, "")  # Subscribe to all messages

        # Set a timeout for receiving
        socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))

        print("Waiting for worker URLs message...")
        message = socket.recv_json()
        print(f"Received message: {message}")
        if message.get("type") == "worker_urls":
            return message["urls"]
        raise ValueError(f"Unexpected message type: {message.get('type')}")
    except zmq.error.Again:
        raise TimeoutError(
            f"Timeout waiting for worker URLs on port {zmq_port}"
        ) from None
    finally:
        socket.close()
        context.term()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: arbor-ai
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.14
|
4
4
|
Summary: A framework for fine-tuning and managing language models
|
5
5
|
Author-email: Noah Ziems <nziems2@nd.edu>
|
6
6
|
Project-URL: Homepage, https://github.com/Ziems/arbor
|
@@ -15,7 +15,7 @@ Requires-Dist: python-multipart
|
|
15
15
|
Requires-Dist: pydantic-settings
|
16
16
|
Requires-Dist: torch
|
17
17
|
Requires-Dist: transformers
|
18
|
-
Requires-Dist: trl
|
18
|
+
Requires-Dist: trl
|
19
19
|
Requires-Dist: peft
|
20
20
|
Requires-Dist: ray>=2.9
|
21
21
|
Requires-Dist: setuptools<77.0.0,>=76.0.0
|
@@ -18,17 +18,19 @@ arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
|
|
18
18
|
arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
|
20
20
|
arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
|
21
|
-
arbor/server/services/inference_manager.py,sha256=
|
21
|
+
arbor/server/services/inference_manager.py,sha256=Ju39_7EWySzAAk7ftz-AzSNBEo0tlayloPVS0XRAp8E,15304
|
22
22
|
arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
|
23
23
|
arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
|
24
24
|
arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
|
26
|
+
arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
+
arbor/server/services/inference/sgl_router_launch_server.py,sha256=eqTW6nDqqoRMISHfv5ScBCrolqLBp9zyxPXqHUlP6uo,6988
|
26
28
|
arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
|
27
29
|
arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
30
|
arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
29
|
-
arbor_ai-0.1.
|
30
|
-
arbor_ai-0.1.
|
31
|
-
arbor_ai-0.1.
|
32
|
-
arbor_ai-0.1.
|
33
|
-
arbor_ai-0.1.
|
34
|
-
arbor_ai-0.1.
|
31
|
+
arbor_ai-0.1.14.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
|
32
|
+
arbor_ai-0.1.14.dist-info/METADATA,sha256=vw8RnMPdGi36ji4rpjAldkOuCbxxjV4MFVi6yW-0kas,2434
|
33
|
+
arbor_ai-0.1.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
34
|
+
arbor_ai-0.1.14.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
|
35
|
+
arbor_ai-0.1.14.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
|
36
|
+
arbor_ai-0.1.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|