arbor-ai 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbor/server/services/inference/sgl_router_launch_server.py (new file)
@@ -0,0 +1,226 @@
+import argparse
+import copy
+import json
+import logging
+import multiprocessing as mp
+import os
+import random
+import signal
+import sys
+import time
+from typing import List
+
+import requests
+import zmq
+from setproctitle import setproctitle
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available
+from sglang_router.launch_router import RouterArgs, launch_router
+
+
+def setup_logger():
+    logger = logging.getLogger("router")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+def run_server(server_args, dp_rank):
+    """
+    Note:
+
+    1. Without os.setpgrp(), all processes share the same PGID. When you press
+       Ctrl+C, the terminal sends SIGINT to all processes in the group
+       simultaneously. This can cause leaf processes to terminate first, which
+       breaks the cleanup order and produces orphaned processes.
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=100)
+        └── Server Process 1 (PGID=100)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=100)
+            └── Scheduler 2
+            └── Detokenizer 2
+
+    2. With os.setpgrp(), the main Python process and each server live in
+       separate groups. Now:
+
+    Terminal (PGID=100)
+    └── Main Python Process (PGID=200)
+        └── Server Process 1 (PGID=300)
+            └── Scheduler 1
+            └── Detokenizer 1
+        └── Server Process 2 (PGID=400)
+            └── Scheduler 2
+            └── Detokenizer 2
+    """
+    # Create a new process group so this server and its children can be
+    # signaled (and cleaned up) as a single unit.
+    os.setpgrp()
+
+    setproctitle("sglang::server")
+    # Tell SGLang which data-parallel rank this server instance is.
+    os.environ["SGLANG_DP_RANK"] = str(dp_rank)
+
+    launch_server(server_args)
+
+
+def launch_server_process(
+    server_args: ServerArgs, worker_port: int, dp_id: int
+) -> mp.Process:
+    """Launch a single server process with the given args and port."""
+    server_args = copy.deepcopy(server_args)
+    server_args.port = worker_port
+    server_args.base_gpu_id = dp_id * server_args.tp_size
+    server_args.dp_size = 1
+
+    proc = mp.Process(target=run_server, args=(server_args, dp_id))
+    proc.start()
+    return proc
+
+
+def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
+    """Wait for a server to become healthy by polling its /health endpoint."""
+    start_time = time.time()
+    url = f"http://{host}:{port}/health"
+
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(url, timeout=5)
+            if response.status_code == 200:
+                return True
+        except requests.exceptions.RequestException:
+            pass
+        time.sleep(1)
+    return False
+
+
+def find_available_ports(base_port: int, count: int) -> List[int]:
+    """Find `count` available ports, starting at base_port and stepping by a random offset."""
+    available_ports = []
+    current_port = base_port
+
+    while len(available_ports) < count:
+        if is_port_available(current_port):
+            available_ports.append(current_port)
+        current_port += random.randint(100, 1000)
+
+    return available_ports
+
+
+def cleanup_processes(processes: List[mp.Process]):
+    for process in processes:
+        logger.info(f"Terminating process group {process.pid}")
+        try:
+            # Each server called os.setpgrp(), so its PGID equals its PID.
+            os.killpg(process.pid, signal.SIGTERM)
+        except ProcessLookupError:
+            # Process group may already be terminated
+            pass
+
+    # Wait for processes to terminate
+    for process in processes:
+        process.join(timeout=5)
+        if process.is_alive():
+            logger.warning(
+                f"Process {process.pid} did not terminate gracefully, forcing kill"
+            )
+            try:
+                os.killpg(process.pid, signal.SIGKILL)
+            except ProcessLookupError:
+                pass
+
+    logger.info("All process groups terminated")
+
+
+def main():
+    # The CUDA runtime isn't fork-safe, which can lead to subtle bugs or
+    # crashes, so spawn fresh interpreters instead of forking.
+    mp.set_start_method("spawn")
+
+    parser = argparse.ArgumentParser(
+        description="Launch SGLang router and server processes"
+    )
+
+    ServerArgs.add_cli_args(parser)
+    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
+    parser.add_argument(
+        "--router-dp-worker-base-port",
+        type=int,
+        default=31000,
+        help="Base port number for data parallel workers",
+    )
+    parser.add_argument(
+        "--worker-urls-port",
+        type=int,
+        help="Port number for worker URLs publisher",
+    )
+
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
+
+    # Find available ports for workers
+    worker_ports = find_available_ports(
+        args.router_dp_worker_base_port, server_args.dp_size
+    )
+
+    # Start server processes
+    server_processes = []
+
+    for i, worker_port in enumerate(worker_ports):
+        logger.info(f"Launching DP server process {i} on port {worker_port}")
+        proc = launch_server_process(server_args, worker_port, i)
+        server_processes.append(proc)
+
+    signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
+    signal.signal(
+        signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
+    )
+    signal.signal(
+        signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
+    )
+
+    # Update router args with worker URLs
+    worker_urls = [f"http://{server_args.host}:{port}" for port in worker_ports]
+    router_args.worker_urls = worker_urls
+
+    # Publish worker URLs via ZMQ if a port is specified
+    if args.worker_urls_port:
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.PUB)
+            socket.bind(f"tcp://*:{args.worker_urls_port}")
+            # Give subscribers time to connect; PUB drops messages sent
+            # before any subscriber has finished connecting.
+            time.sleep(0.1)
+            socket.send_json({"type": "worker_urls", "urls": worker_urls})
+            logger.info(
+                f"Published worker URLs via ZMQ on port {args.worker_urls_port}"
+            )
+            socket.close()
+            context.term()
+        except Exception as e:
+            logger.error(f"Failed to publish worker URLs via ZMQ: {e}")
+            cleanup_processes(server_processes)
+            sys.exit(1)
+
+    # Start the router
+    try:
+        launch_router(router_args)
+    except Exception as e:
+        logger.error(f"Failed to start router: {e}")
+        cleanup_processes(server_processes)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
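
The run_server docstring above is the core of the launcher's process management: every server becomes its own process-group leader, so cleanup_processes can tear down a server together with its scheduler and detokenizer children via os.killpg, without the terminal's Ctrl+C reaching the leaves first. A minimal, self-contained sketch of that mechanism (not part of the diff; Unix-only, and the one-second sleep is just a crude way to let the child call os.setpgrp() before we signal it):

    import multiprocessing as mp
    import os
    import signal
    import time


    def child_worker():
        os.setpgrp()  # become leader of a new process group (PGID == this PID)
        time.sleep(60)


    if __name__ == "__main__":
        mp.set_start_method("spawn")
        proc = mp.Process(target=child_worker)
        proc.start()
        time.sleep(1)  # crude: give the child time to call os.setpgrp()
        # Signal the child's whole group; the parent's group is untouched.
        os.killpg(proc.pid, signal.SIGTERM)
        proc.join()
        print(f"child exit code: {proc.exitcode}")  # -15 means killed by SIGTERM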

arbor/server/services/inference_manager.py
@@ -13,6 +13,7 @@ from typing import Any, Dict, Optional
 
 import aiohttp
 import requests
+import zmq
 
 from arbor.server.core.config import Settings
 
@@ -28,6 +29,7 @@ class InferenceManager:
         self.current_model = None
         self.inference_count = 0
         self._session = None
+        self.worker_urls = []
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -74,15 +76,17 @@ class InferenceManager:
         print(
             f"Grabbing a free port to launch an SGLang server for model {model}"
         )
-        port = get_free_port()
+        router_port = get_free_port()
+        dp_worker_base_port = get_free_port()
+        worker_urls_port = get_free_port()  # Get a port for worker URLs
+
         timeout = launch_kwargs.get("timeout", 1800)
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = (
            self.settings.arbor_config.inference.gpu_ids
        )
         n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-        # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
-        command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
+        command = f"python -m arbor.server.services.inference.sgl_router_launch_server --model-path {model} --dp-size {n_gpus} --port {router_port} --host 0.0.0.0 --disable-radix-cache --router-dp-worker-base-port {dp_worker_base_port} --worker-urls-port {worker_urls_port}"
         print(f"Running command: {command}")
         if launch_kwargs.get("max_context_length"):
             command += (
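
For reference, the command the manager now builds (hunk above) expands to something like the following; the model path and concrete port numbers here are illustrative, since the three ports normally come from get_free_port():

    python -m arbor.server.services.inference.sgl_router_launch_server \
        --model-path Qwen/Qwen2.5-7B-Instruct \
        --dp-size 2 \
        --port 30000 \
        --host 0.0.0.0 \
        --disable-radix-cache \
        --router-dp-worker-base-port 31000 \
        --worker-urls-port 32000

The router listens on --port, the data-parallel workers take ports probed upward from --router-dp-worker-base-port, and the worker URLs are published once over ZMQ on --worker-urls-port.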
@@ -124,8 +128,16 @@ class InferenceManager:
         )
         thread.start()
 
+        # Get worker URLs before waiting for server
+        try:
+            worker_urls = get_worker_urls(worker_urls_port)
+            print(f"Received worker URLs: {worker_urls}")
+            self.worker_urls = worker_urls
+        except TimeoutError as e:
+            raise Exception(f"Failed to get worker URLs: {e}")
+
         # Wait until the server is ready (or times out)
-        base_url = f"http://localhost:{port}"
+        base_url = f"http://localhost:{router_port}"
         try:
             wait_for_server(base_url, timeout=timeout)
         except TimeoutError:
@@ -142,9 +154,9 @@ class InferenceManager:
             return "".join(logs_buffer)
 
         # Let the user know server is up
-        print(f"Server ready on random port {port}!")
+        print(f"Server ready on random port {router_port}!")
 
-        self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+        self.launch_kwargs["api_base"] = f"http://localhost:{router_port}/v1"
         self.launch_kwargs["api_key"] = "local"
         self.get_logs = get_logs
         self.process = process
@@ -286,9 +298,10 @@ class InferenceManager:
         self.inference_count = 0
 
         tik = time.time()
-        self.kill()
-        print("Just killed server")
-        time.sleep(5)
+        # self.kill()
+        # print("Just killed server")
+        # time.sleep(5)
+
         # Check that output directory exists and was created successfully
         print(f"Checking that output directory {output_dir} exists")
         if not os.path.exists(output_dir):
@@ -296,8 +309,27 @@ class InferenceManager:
                 f"Failed to save model - output directory {output_dir} does not exist"
             )
 
-        print("Launching new server")
-        self.launch(output_dir, self.launch_kwargs)
+        print("Directly updating weights from disk")
+        for worker_url in self.worker_urls:
+            print(f"Updating weights from disk for worker {worker_url}")
+            try:
+                response = requests.post(
+                    f"{worker_url}/update_weights_from_disk",
+                    json={"model_path": output_dir},
+                )
+                response_json = response.json()
+                print(f"Response from update_weights_from_disk: {response_json}")
+                # TODO: Check that the response is successful
+            except Exception as e:
+                print(f"Error during update_weights_from_disk: {e}")
+                print(f"Full error during update_weights_from_disk: {str(e)}")
+                if hasattr(e, "response") and e.response is not None:
+                    print(f"Response status code: {e.response.status_code}")
+                    print(f"Response text: {e.response.text}")
+        self.current_model = output_dir
+
+        # print("Launching new server")
+        # self.launch(output_dir, self.launch_kwargs)
         tok = time.time()
         self.restarting = False
         print(f"Time taken to update model: {tok - tik} seconds")
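
The TODO above ("Check that the response is successful") is the obvious next hardening step. A hedged sketch of what that check could look like, pulled out as a helper; the {"success": bool, "message": str} response shape is an assumption about SGLang's /update_weights_from_disk endpoint and should be verified against the installed SGLang version:

    import requests


    def update_worker_weights(worker_url: str, model_path: str) -> None:
        """Ask one SGLang worker to reload weights and fail loudly if it refuses."""
        response = requests.post(
            f"{worker_url}/update_weights_from_disk",
            json={"model_path": model_path},
            timeout=300,
        )
        response.raise_for_status()
        payload = response.json()
        # Assumed response shape: {"success": bool, "message": str}
        if not payload.get("success", False):
            raise RuntimeError(
                f"Weight update failed on {worker_url}: {payload.get('message', 'no message')}"
            )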
@@ -345,3 +377,28 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
         except requests.exceptions.RequestException:
             # Server not up yet, wait and retry
             time.sleep(1)
+
+
+def get_worker_urls(zmq_port: int, timeout: float = 30.0) -> list:
+    print(f"Attempting to get worker URLs on port {zmq_port} with timeout {timeout}s")
+    context = zmq.Context()
+    socket = context.socket(zmq.SUB)
+    socket.connect(f"tcp://localhost:{zmq_port}")
+    socket.setsockopt_string(zmq.SUBSCRIBE, "")  # Subscribe to all messages
+
+    # Set a timeout for receiving
+    socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))
+
+    try:
+        print("Waiting for worker URLs message...")
+        message = socket.recv_json()
+        print(f"Received message: {message}")
+        if message.get("type") == "worker_urls":
+            return message["urls"]
+        else:
+            raise ValueError(f"Unexpected message type: {message.get('type')}")
+    except zmq.error.Again:
+        raise TimeoutError(f"Timeout waiting for worker URLs on port {zmq_port}")
+    finally:
+        socket.close()
+        context.term()
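
get_worker_urls is the subscriber half of the handshake whose publisher half lives in sgl_router_launch_server.py. ZMQ PUB/SUB has a "slow joiner" problem: anything published before a subscriber finishes connecting is silently dropped, which is what the publisher's time.sleep(0.1) papers over and why the subscriber sets RCVTIMEO instead of blocking forever. A self-contained sketch of the same round trip on one machine (port 32000 and the sleep durations are arbitrary):

    import threading
    import time

    import zmq


    def publish(port: int, urls: list) -> None:
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PUB)
        sock.bind(f"tcp://*:{port}")
        time.sleep(0.5)  # let subscribers connect; PUB drops early messages
        sock.send_json({"type": "worker_urls", "urls": urls})
        sock.close()
        ctx.term()


    def subscribe(port: int, timeout_ms: int = 5000) -> list:
        ctx = zmq.Context()
        sock = ctx.socket(zmq.SUB)
        sock.connect(f"tcp://localhost:{port}")
        sock.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to all messages
        sock.setsockopt(zmq.RCVTIMEO, timeout_ms)  # raise zmq.error.Again on timeout
        try:
            return sock.recv_json()["urls"]
        finally:
            sock.close()
            ctx.term()


    if __name__ == "__main__":
        t = threading.Thread(target=publish, args=(32000, ["http://localhost:31000"]))
        t.start()
        print(subscribe(32000))  # ['http://localhost:31000']
        t.join()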

arbor_ai-0.1.13.dist-info/METADATA → arbor_ai-0.1.14.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.13
+Version: 0.1.14
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -15,7 +15,7 @@ Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
 Requires-Dist: torch
 Requires-Dist: transformers
-Requires-Dist: trl==0.17.0
+Requires-Dist: trl
 Requires-Dist: peft
 Requires-Dist: ray>=2.9
 Requires-Dist: setuptools<77.0.0,>=76.0.0

arbor_ai-0.1.13.dist-info/RECORD → arbor_ai-0.1.14.dist-info/RECORD
@@ -18,17 +18,19 @@ arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
 arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
-arbor/server/services/inference_manager.py,sha256=NcsUI-pgf3cRhU6P3xlPx0dxhvgYrfGZkEEGORcHcis,12833
+arbor/server/services/inference_manager.py,sha256=Ju39_7EWySzAAk7ftz-AzSNBEo0tlayloPVS0XRAp8E,15304
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
+arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/inference/sgl_router_launch_server.py,sha256=eqTW6nDqqoRMISHfv5ScBCrolqLBp9zyxPXqHUlP6uo,6988
 arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.13.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
-arbor_ai-0.1.13.dist-info/METADATA,sha256=c0yScMpCiWYSFqVLjgk5TrRBuAVJK3aTBl0z0IPZ_8Y,2442
-arbor_ai-0.1.13.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
-arbor_ai-0.1.13.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
-arbor_ai-0.1.13.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
-arbor_ai-0.1.13.dist-info/RECORD,,
+arbor_ai-0.1.14.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.14.dist-info/METADATA,sha256=vw8RnMPdGi36ji4rpjAldkOuCbxxjV4MFVi6yW-0kas,2434
+arbor_ai-0.1.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+arbor_ai-0.1.14.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.14.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.14.dist-info/RECORD,,

arbor_ai-0.1.13.dist-info/WHEEL → arbor_ai-0.1.14.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.6.0)
+Generator: setuptools (80.7.1)
 Root-Is-Purelib: true
 Tag: py3-none-any