arbor-ai 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,226 +0,0 @@
1
- import argparse
2
- import copy
3
- import json
4
- import logging
5
- import multiprocessing as mp
6
- import os
7
- import random
8
- import signal
9
- import sys
10
- import time
11
- from typing import List
12
-
13
- import requests
14
- import zmq
15
- from setproctitle import setproctitle
16
- from sglang.srt.entrypoints.http_server import launch_server
17
- from sglang.srt.server_args import ServerArgs
18
- from sglang.srt.utils import is_port_available
19
- from sglang_router.launch_router import RouterArgs, launch_router
20
-
21
-
22
def setup_logger():
    """Return the ``"router"`` logger, configured to emit INFO-and-above records.

    ``logging.getLogger("router")`` returns the same process-wide object on every
    call. A ``StreamHandler`` using the ``[Router (Python)] ...`` format is
    therefore attached only when the logger has no handlers yet. Without that
    guard, every additional call (or any other code that already attached a
    handler to this name) would make each message print once per handler.

    Returns:
        logging.Logger: the configured ``"router"`` logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# Module-wide logger used by the launcher helpers in this file.
logger = setup_logger()
39
-
40
-
41
def run_server(server_args, dp_rank):
    """
    Body of one spawned data-parallel server process.

    The function does four things in order:
    - moves this process into a brand-new process group;
    - names the process ``sglang::server``;
    - exports the data-parallel rank as the ``SGLANG_DP_RANK`` environment variable;
    - hands control to ``launch_server``.

    Why the new process group:

    1. Without os.setpgrp(), every process inherits the terminal's PGID. On
       Ctrl+C the terminal sends SIGINT to that whole group at once. Leaf
       processes may then die before their parents, which breaks the cleanup
       order and can leave orphaned processes behind.

        Terminal (PGID=100)
        └── Main Python Process (PGID=100)
            └── Server Process 1 (PGID=100)
                └── Scheduler 1
                └── Detokenizer 1
            └── Server Process 2 (PGID=100)
                └── Scheduler 2
                └── Detokenizer 2

    2. With os.setpgrp() called here, each server becomes the leader of its
       own group (PGID == its PID), and the children it starts inherit that
       group. The main process stays in the terminal's group, so Ctrl+C reaches
       only the main process. The main process is then responsible for tearing
       the servers down group by group (os.killpg).

        Terminal (PGID=100)
        └── Main Python Process (PGID=100)
            └── Server Process 1 (PGID=300)
                └── Scheduler 1
                └── Detokenizer 1
            └── Server Process 2 (PGID=400)
                └── Scheduler 2
                └── Detokenizer 2
    """
    # Detach first, before anything else runs, so every child the server
    # starts lands in the group led by this process.
    os.setpgrp()

    setproctitle("sglang::server")

    rank_env = {"SGLANG_DP_RANK": str(dp_rank)}
    os.environ.update(rank_env)

    launch_server(server_args)
77
-
78
-
79
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Spawn the server for data-parallel rank ``dp_id`` and return its started process.

    The caller's ``server_args`` is never modified. A deep copy is specialised
    for this rank:
    - it listens on ``worker_port``;
    - it runs a single DP rank (``dp_size = 1``);
    - it uses ``base_gpu_id = dp_id * tp_size``. Presumably each server uses
      ``tp_size`` GPUs starting at ``base_gpu_id``, so ranks get disjoint GPU
      blocks (TODO confirm against ServerArgs).

    The child process runs ``run_server(<copy>, dp_id)``.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.dp_size = 1
    worker_args.port = worker_port
    worker_args.base_gpu_id = dp_id * worker_args.tp_size

    child = mp.Process(target=run_server, args=(worker_args, dp_id))
    child.start()
    return child
91
-
92
-
93
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers HTTP 200 or time runs out.

    Args:
        host: Host the server listens on.
        port: Port the server listens on.
        timeout: Overall budget in seconds. Each probe's own timeout (at most
            5 s) and the pause between probes (at most 1 s) are capped at the
            remaining budget, so the call does not overshoot ``timeout``.

    Returns:
        True as soon as a probe returns status 200. False if the budget is
        exhausted first (immediately when ``timeout <= 0``). Connection errors
        and non-200 responses are treated as "not ready yet".
    """
    url = f"http://{host}:{port}/health"
    # Monotonic clock: unaffected by wall-clock adjustments (NTP, manual changes).
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            response = requests.get(url, timeout=min(5, remaining))
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass  # server not reachable yet; retry until the deadline
        time.sleep(max(0.0, min(1.0, deadline - time.monotonic())))
107
-
108
-
109
def find_available_ports(
    base_port: int, count: int, max_port: int = 65535
) -> List[int]:
    """Find ``count`` currently-free ports, searching upward from ``base_port``.

    The ports are NOT consecutive. After each probe, whether or not the port
    was free, the candidate advances by a random stride of 100-1000. Ports are
    only checked, not reserved, so another process may still grab one before
    it is used.

    Args:
        base_port: First port to probe.
        count: Number of free ports wanted; ``0`` returns ``[]``.
        max_port: Highest port number that may be probed (default 65535, the
            largest valid TCP port).

    Returns:
        The free ports found, in ascending order.

    Raises:
        RuntimeError: if the search passes ``max_port`` before ``count`` free
            ports are found. Without this bound the loop would run forever or
            probe invalid port numbers.
    """
    available_ports: List[int] = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Found only {len(available_ports)} of {count} available ports "
                f"searching from {base_port} up to {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports
120
-
121
-
122
def _signal_process_group(process: mp.Process, sig: int) -> None:
    """Send ``sig`` to the process group led by ``process`` (best effort, never raises on a vanished target)."""
    try:
        os.killpg(process.pid, sig)
    except ProcessLookupError:
        # No group with this PGID. If the child is still alive, it most likely
        # has not reached os.setpgrp() yet, so signal it directly instead of
        # leaking it. Otherwise everything already exited.
        if process.is_alive():
            try:
                os.kill(process.pid, sig)
            except ProcessLookupError:
                pass


def cleanup_processes(processes: List[mp.Process]):
    """Stop every server in ``processes`` together with its process group.

    Each server process runs ``run_server``, which calls ``os.setpgrp()``. Its
    PID is therefore also the ID of a process group, and signalling that group
    also reaches the children that remain in it.

    Every group first receives SIGTERM. Each process is then given up to 5
    seconds to exit; after that its group is sent SIGKILL and the process is
    joined so it does not linger as a zombie. Processes that were never started
    (``pid is None``) are skipped.
    """
    started = [p for p in processes if p.pid is not None]

    for process in started:
        logger.info(f"Terminating process group {process.pid}")
        _signal_process_group(process, signal.SIGTERM)

    # Wait for processes to terminate
    for process in started:
        process.join(timeout=5)
        if process.is_alive():
            logger.warning(
                f"Process {process.pid} did not terminate gracefully, forcing kill"
            )
            _signal_process_group(process, signal.SIGKILL)
            # Reap the killed child.
            process.join(timeout=5)

    logger.info("All process groups terminated")
144
-
145
-
146
def _publish_worker_urls(port: int, worker_urls: List[str]) -> None:
    """Broadcast ``worker_urls`` once on a ZMQ PUB socket bound to ``tcp://*:{port}``.

    The message is ``{"type": "worker_urls", "urls": [...]}``. A PUB socket
    drops messages for peers that are not yet subscribed, hence the short
    sleep before sending. The socket and context are released even if binding
    or sending fails; the exception is then re-raised to the caller.
    """
    context = zmq.Context()
    try:
        pub_socket = context.socket(zmq.PUB)
        try:
            pub_socket.bind(f"tcp://*:{port}")
            # Give subscribers time to connect
            time.sleep(0.1)
            pub_socket.send_json({"type": "worker_urls", "urls": worker_urls})
            logger.info(f"Published worker URLs via ZMQ on port {port}")
        finally:
            pub_socket.close()
    finally:
        context.term()


def main():
    """Launch one SGLang server per data-parallel rank, then run the router in front of them.

    Steps:
    1. Parse the combined server/router CLI.
    2. Pick ``dp_size`` free worker ports.
    3. Spawn one server process per rank.
    4. Optionally publish the worker URLs over ZMQ (``--worker-urls-port``).
    5. Call ``launch_router`` with those URLs.

    SIGINT, SIGTERM and SIGQUIT stop every launched server (via
    ``cleanup_processes``) and then exit with status ``128 + signum``. The
    handlers are installed before the first server is spawned. Each server
    moves itself into its own process group (``run_server`` calls
    ``os.setpgrp()``), so a terminal Ctrl+C does not reach it directly; these
    handlers are what stop it.

    If launching a server, publishing the URLs, or starting the router raises,
    the servers are cleaned up and the process exits with status 1.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )

    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )
    parser.add_argument(
        "--worker-urls-port",
        type=int,
        help="Port number for worker URLs publisher",
    )

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    server_processes: List[mp.Process] = []

    def _shutdown(signum, frame):
        cleanup_processes(server_processes)
        # Exit after cleanup. Merely returning would leave this launcher (and
        # the router) running with no workers behind it.
        sys.exit(128 + signum)

    # Install before spawning: the closure sees the live list, so a signal
    # arriving mid-launch still stops the servers started so far.
    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(sig, _shutdown)

    # Start server processes
    try:
        for i, worker_port in enumerate(worker_ports):
            logger.info(f"Launching DP server process {i} on port {worker_port}")
            server_processes.append(launch_server_process(server_args, worker_port, i))
    except Exception as e:
        logger.error(f"Failed to launch DP server processes: {e}")
        cleanup_processes(server_processes)
        sys.exit(1)

    # Update router args with worker URLs
    worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
    router_args.worker_urls = worker_urls

    # Publish worker URLs via ZMQ if port is specified
    if args.worker_urls_port:
        try:
            _publish_worker_urls(args.worker_urls_port, worker_urls)
        except Exception as e:
            logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
            cleanup_processes(server_processes)
            sys.exit(1)

    # Start the router
    try:
        launch_router(router_args)
    except Exception as e:
        logger.error(f"Failed to start router: {e}")
        cleanup_processes(server_processes)
        sys.exit(1)


if __name__ == "__main__":
    main()