arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import os
 import random
 import signal
@@ -9,13 +8,14 @@ import sys
 import threading
 import time
 from datetime import datetime
+from enum import Enum
 from typing import Any, Dict, Optional
 
-import aiohttp
+import psutil
 import requests
-import zmq
 
 from arbor.server.core.config import Settings
+from arbor.server.services.inference.vllm_client import VLLMClient
 
 
 class InferenceManager:
@@ -24,12 +24,14 @@ class InferenceManager:
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
-        self.restarting = False
         self._shutting_down = False
-        self.current_model = None
+        self.launched_model = None
         self.inference_count = 0
         self._session = None
-        self.worker_urls = []
+        self.port = None
+        self.group_port = None
+        self.vllm_client = None
+        self._is_updating = 0  # Counter for weight updates in progress
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -47,15 +49,7 @@ class InferenceManager:
     def is_server_running(self):
         return self.process is not None
 
-    def is_server_restarting(self):
-        return self.restarting
-
-    def launch(
-        self,
-        model: str,
-        launch_kwargs: Optional[Dict[str, Any]] = None,
-        max_retries: int = 3,
-    ):
+    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -67,126 +61,81 @@ class InferenceManager:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
 
-        retries = 0
-        while retries < max_retries:
-            try:
-                print(
-                    f"Attempt {retries + 1} of {max_retries} to launch server for model {model}"
-                )
-                print(
-                    f"Grabbing a free port to launch an SGLang server for model {model}"
-                )
-                router_port = get_free_port()
-                dp_worker_base_port = get_free_port()
-                worker_urls_port = get_free_port()  # Get a port for worker URLs
-
-                timeout = launch_kwargs.get("timeout", 1800)
-                my_env = os.environ.copy()
-                my_env["CUDA_VISIBLE_DEVICES"] = (
-                    self.settings.arbor_config.inference.gpu_ids
-                )
-                n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-                command = f"python -m arbor.server.services.inference.sgl_router_launch_server --model-path {model} --dp-size {n_gpus} --port {router_port} --host 0.0.0.0 --disable-radix-cache --router-dp-worker-base-port {dp_worker_base_port} --worker-urls-port {worker_urls_port}"
-                print(f"Running command: {command}")
-                if launch_kwargs.get("max_context_length"):
-                    command += (
-                        f" --context-length {launch_kwargs['max_context_length']}"
-                    )
-
-                # We will manually stream & capture logs.
-                process = subprocess.Popen(
-                    command.replace("\\\n", " ").replace("\\", " ").split(),
-                    text=True,
-                    stdout=subprocess.PIPE,  # We'll read from pipe
-                    stderr=subprocess.STDOUT,  # Merge stderr into stdout
-                    env=my_env,
-                )
-
-                # A threading.Event to control printing after the server is ready.
-                # This will store *all* lines (both before and after readiness).
-                print(f"SGLang server process started with PID {process.pid}.")
-                stop_printing_event = threading.Event()
-                logs_buffer = []
-
-                def _tail_process(proc, buffer, stop_event):
-                    while True:
-                        line = proc.stdout.readline()
-                        if not line and proc.poll() is not None:
-                            # Process ended and no new line
-                            break
-                        if line:
-                            buffer.append(line)
-                            # Print only if stop_event is not set
-                            if not stop_event.is_set():
-                                print(f"[SGLang LOG] {line}", end="")
-
-                # Start a background thread to read from the process continuously
-                thread = threading.Thread(
-                    target=_tail_process,
-                    args=(process, logs_buffer, stop_printing_event),
-                    daemon=True,
-                )
-                thread.start()
-
-                # Get worker URLs before waiting for server
-                try:
-                    worker_urls = get_worker_urls(worker_urls_port)
-                    print(f"Received worker URLs: {worker_urls}")
-                    self.worker_urls = worker_urls
-                except TimeoutError as e:
-                    raise Exception(f"Failed to get worker URLs: {e}")
-
-                # Wait until the server is ready (or times out)
-                base_url = f"http://localhost:{router_port}"
-                try:
-                    wait_for_server(base_url, timeout=timeout)
-                except TimeoutError:
-                    # If the server doesn't come up, we might want to kill it:
-                    process.kill()
-                    raise
-
-                # Once server is ready, we tell the thread to stop printing further lines.
-                stop_printing_event.set()
-
-                # A convenience getter so the caller can see all logs so far (and future).
-                def get_logs() -> str:
-                    # Join them all into a single string, or you might return a list
-                    return "".join(logs_buffer)
-
-                # Let the user know server is up
-                print(f"Server ready on random port {router_port}!")
-
-                self.launch_kwargs["api_base"] = f"http://localhost:{router_port}/v1"
-                self.launch_kwargs["api_key"] = "local"
-                self.get_logs = get_logs
-                self.process = process
-                self.thread = thread
-                self.current_model = model
-
-                # If we get here, the launch was successful
-                return
-
-            except Exception as e:
-                retries += 1
-                print(
-                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
-                )
-                # Clean up any failed processes
-                if "process" in locals():
-                    try:
-                        process.kill()
-                    except:
-                        pass
-                if retries == max_retries:
-                    raise Exception(
-                        f"Failed to launch server after {max_retries} attempts"
                    ) from e
-                # Wait a bit before retrying
-                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds
+        print(f"Grabbing a free port to launch a vLLM server for model {model}")
+        self.port = get_free_port()
+        timeout = launch_kwargs.get("timeout", 1800)
+        my_env = os.environ.copy()
+        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
+        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"python -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"
+
+        if launch_kwargs.get("max_context_length"):
+            command += f" --max_model_len {launch_kwargs['max_context_length']}"
+
+        print(f"Running command: {command}")
+
+        # We will manually stream & capture logs.
+        process = subprocess.Popen(
+            command.replace("\\\n", " ").replace("\\", " ").split(),
+            text=True,
+            stdout=subprocess.PIPE,  # We'll read from pipe
+            stderr=subprocess.STDOUT,  # Merge stderr into stdout
+            env=my_env,
+        )
+
+        # A threading.Event to control printing after the server is ready.
+        # This will store *all* lines (both before and after readiness).
+        print(f"vLLM server process started with PID {process.pid}.")
+        stop_printing_event = threading.Event()
+        logs_buffer = []
+
+        def _tail_process(proc, buffer, stop_event):
+            while True:
+                line = proc.stdout.readline()
+                if not line and proc.poll() is not None:
+                    # Process ended and no new line
+                    break
+                if line:
+                    buffer.append(line)
+                    # Print only if stop_event is not set
+                    if not stop_event.is_set():
+                        print(f"[vLLM LOG] {line}", end="")
+
+        # Start a background thread to read from the process continuously
+        thread = threading.Thread(
+            target=_tail_process,
+            args=(process, logs_buffer, stop_printing_event),
+            daemon=True,
+        )
+        thread.start()
+
+        # A convenience getter so the caller can see all logs so far (and future).
+        def get_logs() -> str:
+            # Join them all into a single string, or you might return a list
+            return "".join(logs_buffer)
+
+        # Let the user know server is up
+        print(f"Server ready on random port {self.port}!")
+
+        # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+        # self.launch_kwargs["api_key"] = "local"
+        self.get_logs = get_logs
+        self.process = process
+        self.thread = thread
+        self.launched_model = model
+
+        # Get another free port for weight sync group communication
+        self.group_port = get_free_port()
+        self.vllm_client = VLLMClient(
+            port=self.port,
+            group_port=self.group_port,
+            connection_timeout=300,  # 5 minutes
+        )
+
+        # Once server is ready, we tell the thread to stop printing further lines.
+        stop_printing_event.set()
 
     def kill(self):
-        from sglang.utils import terminate_process
-
         if self.process is None:
             print("No running server to kill.")
             return
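
Note on the hunk above: the retry loop is gone, and the new launch path reduces to grab a free port, start `vllm_serve` as a subprocess, tail its logs on a daemon thread, then attach a `VLLMClient` on a second free port for weight sync. A minimal sketch of the caller's side follows; the `InferenceManager(settings)` construction, the `Settings()` default, and the model id are illustrative assumptions, not API confirmed by this diff:

```python
# Hypothetical caller of the new launch() path. Settings construction and
# the model id are assumptions for illustration only.
from arbor.server.core.config import Settings

settings = Settings()  # assumed to carry arbor_config.inference.gpu_ids, e.g. "0,1"
manager = InferenceManager(settings)  # assumed constructor signature

manager.launch(
    "Qwen/Qwen2.5-1.5B-Instruct",                # any HF model id; illustrative
    launch_kwargs={"max_context_length": 4096},  # forwarded as --max_model_len
)

# Logs captured by the daemon thread are available at any time:
print(manager.get_logs()[-2000:])
```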
@@ -201,24 +150,7 @@ class InferenceManager:
         self.last_activity = None
 
         try:
-            # Handle nested signal case
-            if self._shutting_down:
-                process.kill()  # Go straight to SIGKILL if we're shutting down
-            else:
-                terminate_process(process)
-                try:
-                    process.wait(timeout=10)
-                except subprocess.TimeoutExpired:
-                    print(
-                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
-                    )
-                    process.kill()
-
-            process.wait(timeout=5)
-
-            if thread and thread.is_alive():
-                thread.join(timeout=5)
-
+            kill_vllm_server(process.pid)
         except Exception as e:
             print(f"Error during cleanup: {e}")
         try:
@@ -229,127 +161,65 @@ class InferenceManager:
         print("Server killed.")
 
     async def run_inference(self, request_json: dict):
+        # Check if weights are being updated
+        while self.is_updating:
+            # weights are being updated...waiting
+            # print("Weights are being updated, waiting...")
+            await asyncio.sleep(1)  # Small sleep to prevent busy waiting
+
         model = request_json["model"]
         prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
         print(f"Running inference for model {model}")
-        # Monkeypatch:
-        if model != self.current_model:
-            print(f"Model changed from {model} to {self.current_model}")
-            model = self.current_model
+
+        # Monkeypatch for GRPO runs:
+        # vllm complains if we don't give it the exact model name that was launched
+        # TODO: This should really throw an error unless in a GRPO run.
+        if model != self.launched_model:
+            # print(f"Model changed from {model} to {self.current_model}")
+            model = self.launched_model
         request_json["model"] = model
 
         # Update last_activity timestamp
        self.last_activity = datetime.now()
 
-        if self.process is None or self.launch_kwargs.get("api_base") is None:
+        if self.process is None:
             raise RuntimeError("Server is not running. Please launch it first.")
 
-        if self.restarting:
-            while self.restarting:
-                print("Inference is paused while server is restarting...")
-                await asyncio.sleep(5)
-            request_json["model"] = self.current_model
+        return await self.vllm_client.chat(json_body=request_json)
 
-        url = f"{self.launch_kwargs['api_base']}/chat/completions"
-        try:
-            self.inference_count += 1
-            session = await self._ensure_session()
-            async with session.post(url, json=request_json) as response:
-                content = await response.content.read()
-                return json.loads(content)
-        except aiohttp.ClientError as e:
-            print(f"Connection error: {type(e).__name__}: {str(e)}")
-            # Try to close and recreate the session on error
-            if self._session:
-                await self._session.close()
-                self._session = None
-            return None
-        except json.decoder.JSONDecodeError:
-            print(f"JSON Decode Error during inference: {content}")
-            return {
-                "error": "JSON Decode Error",
-                "content": content if content else "Content is null",
-            }
-        except Exception as e:
-            print(f"Error during inference: {e}")
-            raise
-        finally:
-            self.inference_count -= 1
-
-    def update_model(self, output_dir):
-        print("Restarting server with new model...")
-        self.restarting = True
-
-        # Close existing session and reset inference count
-        if self._session:
-            # Create a new event loop if one doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Run the session closure in the event loop
-            loop.run_until_complete(self._session.close())
-            self._session = None
-        self.inference_count = 0
+    def start_weight_update(self):
+        """Block inference during weight updates"""
+        self._is_updating += 1
 
-        tik = time.time()
-        # self.kill()
-        # print("Just killed server")
-        # time.sleep(5)
+    def complete_weight_update(self):
+        """Allow inference after weight update is complete"""
+        self._is_updating = max(0, self._is_updating - 1)  # Prevent going negative
 
-        # Check that output directory exists and was created successfully
-        print(f"Checking that output directory {output_dir} exists")
-        if not os.path.exists(output_dir):
-            raise RuntimeError(
-                f"Failed to save model - output directory {output_dir} does not exist"
-            )
-
-        print("Directly updating weights from disk")
-        for worker_url in self.worker_urls:
-            print(f"Updating weights from disk for worker {worker_url}")
-            try:
-                response = requests.post(
-                    f"{worker_url}/update_weights_from_disk",
-                    json={"model_path": output_dir},
-                )
-                response_json = response.json()
-                print(f"Response from update_weights_from_disk: {response_json}")
-                # TODO: Check that the response is successful
-            except Exception as e:
-                print(f"Error during update_weights_from_disk: {e}")
-                print(f"Full error during update_weights_from_disk: {str(e)}")
-                if hasattr(e, "response") and e.response is not None:
-                    print(f"Response status code: {e.response.status_code}")
-                    print(f"Response text: {e.response.text}")
-        self.current_model = output_dir
-
-        # print("Launching new server")
-        # self.launch(output_dir, self.launch_kwargs)
-        tok = time.time()
-        self.restarting = False
-        print(f"Time taken to update model: {tok - tik} seconds")
-
-    async def _ensure_session(self):
-        if self._session is None or self._session.closed:
-            timeout = aiohttp.ClientTimeout(
-                total=None
-            )  # No timeout...If it hangs, this might be the issue.
-            self._session = aiohttp.ClientSession(timeout=timeout)
-        return self._session
+    @property
+    def is_updating(self):
+        """Check if any weight updates are in progress"""
+        return self._is_updating > 0
 
 
 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
     """
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("localhost", 0))
-        return s.getsockname()[1]
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)
 
 
 def wait_for_server(base_url: str, timeout: int = None) -> None:
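
Note on the hunk above: the `_is_updating` counter is a simple gate, with `run_inference()` polling `is_updating` before forwarding a request. A trainer is therefore expected to bracket every weight push with `start_weight_update()` / `complete_weight_update()`. A sketch of that pairing, assuming some `push_new_weights` coroutine that syncs weights over the group port (not part of this diff):

```python
# Hypothetical trainer-side pairing for the new weight-update gate.
# `push_new_weights` stands in for whatever actually syncs weights
# (e.g. through manager.vllm_client over group_port).

async def safe_weight_update(manager: "InferenceManager", push_new_weights) -> None:
    manager.start_weight_update()  # run_inference() now waits on is_updating
    try:
        await push_new_weights()
    finally:
        # Decrement even on failure so inference is never blocked forever;
        # complete_weight_update() clamps the counter at zero.
        manager.complete_weight_update()
```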
@@ -379,26 +249,29 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
             time.sleep(1)
 
 
-def get_worker_urls(zmq_port: int, timeout: float = 30.0) -> list:
-    print(f"Attempting to get worker URLs on port {zmq_port} with timeout {timeout}s")
-    context = zmq.Context()
-    socket = context.socket(zmq.SUB)
-    socket.connect(f"tcp://localhost:{zmq_port}")
-    socket.setsockopt_string(zmq.SUBSCRIBE, "")  # Subscribe to all messages
+def kill_vllm_server(main_process_pid):
+    try:
+        # Get the parent process
+        parent = psutil.Process(main_process_pid)
 
-    # Set a timeout for receiving
-    socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))
+        # Get all child processes recursively
+        children = parent.children(recursive=True)
 
-    try:
-        print("Waiting for worker URLs message...")
-        message = socket.recv_json()
-        print(f"Received message: {message}")
-        if message.get("type") == "worker_urls":
-            return message["urls"]
-        else:
-            raise ValueError(f"Unexpected message type: {message.get('type')}")
-    except zmq.error.Again:
-        raise TimeoutError(f"Timeout waiting for worker URLs on port {zmq_port}")
-    finally:
-        socket.close()
-        context.term()
+        # Send SIGTERM to all child processes first
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+
+        # Send SIGTERM to parent process
+        parent.send_signal(signal.SIGTERM)
+
+        # Wait for processes to terminate gracefully
+        gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+        # If any processes are still alive, force kill them
+        for p in alive:
+            p.kill()  # SIGKILL
+
+    except psutil.NoSuchProcess:
+        print(f"Process {main_process_pid} not found")
+    except Exception as e:
+        print(f"Error killing processes: {e}")