arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import os
 import random
 import signal
@@ -9,12 +8,14 @@ import sys
 import threading
 import time
 from datetime import datetime
+from enum import Enum
 from typing import Any, Dict, Optional
 
-import aiohttp
+import psutil
 import requests
 
 from arbor.server.core.config import Settings
+from arbor.server.services.inference.vllm_client import VLLMClient
 
 
 class InferenceManager:
@@ -23,11 +24,14 @@ class InferenceManager:
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
-        self.restarting = False
         self._shutting_down = False
-        self.current_model = None
+        self.launched_model = None
         self.inference_count = 0
         self._session = None
+        self.port = None
+        self.group_port = None
+        self.vllm_client = None
+        self._is_updating = 0  # Counter for weight updates in progress
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -45,15 +49,7 @@ class InferenceManager:
     def is_server_running(self):
         return self.process is not None
 
-    def is_server_restarting(self):
-        return self.restarting
-
-    def launch(
-        self,
-        model: str,
-        launch_kwargs: Optional[Dict[str, Any]] = None,
-        max_retries: int = 3,
-    ):
+    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -65,116 +61,81 @@ class InferenceManager:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
 
-        retries = 0
-        while retries < max_retries:
-            try:
-                print(
-                    f"Attempt {retries + 1} of {max_retries} to launch server for model {model}"
-                )
-                print(
-                    f"Grabbing a free port to launch an SGLang server for model {model}"
-                )
-                port = get_free_port()
-                timeout = launch_kwargs.get("timeout", 1800)
-                my_env = os.environ.copy()
-                my_env["CUDA_VISIBLE_DEVICES"] = (
-                    self.settings.arbor_config.inference.gpu_ids
-                )
-                n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-                # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
-                command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
-                print(f"Running command: {command}")
-                if launch_kwargs.get("max_context_length"):
-                    command += (
-                        f" --context-length {launch_kwargs['max_context_length']}"
-                    )
-
-                # We will manually stream & capture logs.
-                process = subprocess.Popen(
-                    command.replace("\\\n", " ").replace("\\", " ").split(),
-                    text=True,
-                    stdout=subprocess.PIPE,  # We'll read from pipe
-                    stderr=subprocess.STDOUT,  # Merge stderr into stdout
-                    env=my_env,
-                )
-
-                # A threading.Event to control printing after the server is ready.
-                # This will store *all* lines (both before and after readiness).
-                print(f"SGLang server process started with PID {process.pid}.")
-                stop_printing_event = threading.Event()
-                logs_buffer = []
-
-                def _tail_process(proc, buffer, stop_event):
-                    while True:
-                        line = proc.stdout.readline()
-                        if not line and proc.poll() is not None:
-                            # Process ended and no new line
-                            break
-                        if line:
-                            buffer.append(line)
-                            # Print only if stop_event is not set
-                            if not stop_event.is_set():
-                                print(f"[SGLang LOG] {line}", end="")
-
-                # Start a background thread to read from the process continuously
-                thread = threading.Thread(
-                    target=_tail_process,
-                    args=(process, logs_buffer, stop_printing_event),
-                    daemon=True,
-                )
-                thread.start()
-
-                # Wait until the server is ready (or times out)
-                base_url = f"http://localhost:{port}"
-                try:
-                    wait_for_server(base_url, timeout=timeout)
-                except TimeoutError:
-                    # If the server doesn't come up, we might want to kill it:
-                    process.kill()
-                    raise
-
-                # Once server is ready, we tell the thread to stop printing further lines.
-                stop_printing_event.set()
-
-                # A convenience getter so the caller can see all logs so far (and future).
-                def get_logs() -> str:
-                    # Join them all into a single string, or you might return a list
-                    return "".join(logs_buffer)
-
-                # Let the user know server is up
-                print(f"Server ready on random port {port}!")
-
-                self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
-                self.launch_kwargs["api_key"] = "local"
-                self.get_logs = get_logs
-                self.process = process
-                self.thread = thread
-                self.current_model = model
-
-                # If we get here, the launch was successful
-                return
-
-            except Exception as e:
-                retries += 1
-                print(
-                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
-                )
-                # Clean up any failed processes
-                if "process" in locals():
-                    try:
-                        process.kill()
-                    except:
-                        pass
-                if retries == max_retries:
-                    raise Exception(
-                        f"Failed to launch server after {max_retries} attempts"
-                    ) from e
-                # Wait a bit before retrying
-                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds
+        print(f"Grabbing a free port to launch a vLLM server for model {model}")
+        self.port = get_free_port()
+        timeout = launch_kwargs.get("timeout", 1800)
+        my_env = os.environ.copy()
+        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
+        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"python -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"
+
+        if launch_kwargs.get("max_context_length"):
+            command += f" --max_model_len {launch_kwargs['max_context_length']}"
+
+        print(f"Running command: {command}")
+
+        # We will manually stream & capture logs.
+        process = subprocess.Popen(
+            command.replace("\\\n", " ").replace("\\", " ").split(),
+            text=True,
+            stdout=subprocess.PIPE,  # We'll read from pipe
+            stderr=subprocess.STDOUT,  # Merge stderr into stdout
+            env=my_env,
+        )
+
+        # A threading.Event to control printing after the server is ready.
+        # This will store *all* lines (both before and after readiness).
+        print(f"vLLM server process started with PID {process.pid}.")
+        stop_printing_event = threading.Event()
+        logs_buffer = []
+
+        def _tail_process(proc, buffer, stop_event):
+            while True:
+                line = proc.stdout.readline()
+                if not line and proc.poll() is not None:
+                    # Process ended and no new line
+                    break
+                if line:
+                    buffer.append(line)
+                    # Print only if stop_event is not set
+                    if not stop_event.is_set():
+                        print(f"[vLLM LOG] {line}", end="")
+
+        # Start a background thread to read from the process continuously
+        thread = threading.Thread(
+            target=_tail_process,
+            args=(process, logs_buffer, stop_printing_event),
+            daemon=True,
+        )
+        thread.start()
+
+        # A convenience getter so the caller can see all logs so far (and future).
+        def get_logs() -> str:
+            # Join them all into a single string, or you might return a list
+            return "".join(logs_buffer)
+
+        # Let the user know server is up
+        print(f"Server ready on random port {self.port}!")
+
+        # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+        # self.launch_kwargs["api_key"] = "local"
+        self.get_logs = get_logs
+        self.process = process
+        self.thread = thread
+        self.launched_model = model
+
+        # Get another free port for weight sync group communication
+        self.group_port = get_free_port()
+        self.vllm_client = VLLMClient(
+            port=self.port,
+            group_port=self.group_port,
+            connection_timeout=300,  # 5 minutes
+        )
+
+        # Once server is ready, we tell the thread to stop printing further lines.
+        stop_printing_event.set()
 
     def kill(self):
-        from sglang.utils import terminate_process
-
         if self.process is None:
             print("No running server to kill.")
             return
@@ -189,24 +150,7 @@ class InferenceManager:
         self.last_activity = None
 
         try:
-            # Handle nested signal case
-            if self._shutting_down:
-                process.kill()  # Go straight to SIGKILL if we're shutting down
-            else:
-                terminate_process(process)
-                try:
-                    process.wait(timeout=10)
-                except subprocess.TimeoutExpired:
-                    print(
-                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
-                    )
-                    process.kill()
-
-            process.wait(timeout=5)
-
-            if thread and thread.is_alive():
-                thread.join(timeout=5)
-
+            kill_vllm_server(process.pid)
         except Exception as e:
             print(f"Error during cleanup: {e}")
             try:
@@ -217,107 +161,65 @@ class InferenceManager:
         print("Server killed.")
 
     async def run_inference(self, request_json: dict):
+        # Check if weights are being updated
+        while self.is_updating:
+            # weights are being updated...waiting
+            # print("Weights are being updated, waiting...")
+            await asyncio.sleep(1)  # Small sleep to prevent busy waiting
+
         model = request_json["model"]
         prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
         print(f"Running inference for model {model}")
-        # Monkeypatch:
-        if model != self.current_model:
-            print(f"Model changed from {model} to {self.current_model}")
-            model = self.current_model
+
+        # Monkeypatch for GRPO runs:
+        # vllm complains if we don't give it the exact model name that was launched
+        # TODO: This should really throw an error unless in a GRPO run.
+        if model != self.launched_model:
+            # print(f"Model changed from {model} to {self.current_model}")
+            model = self.launched_model
         request_json["model"] = model
 
         # Update last_activity timestamp
         self.last_activity = datetime.now()
 
-        if self.process is None or self.launch_kwargs.get("api_base") is None:
+        if self.process is None:
             raise RuntimeError("Server is not running. Please launch it first.")
 
-        if self.restarting:
-            while self.restarting:
-                print("Inference is paused while server is restarting...")
-                await asyncio.sleep(5)
-            request_json["model"] = self.current_model
+        return await self.vllm_client.chat(json_body=request_json)
 
-        url = f"{self.launch_kwargs['api_base']}/chat/completions"
-        try:
-            self.inference_count += 1
-            session = await self._ensure_session()
-            async with session.post(url, json=request_json) as response:
-                content = await response.content.read()
-                return json.loads(content)
-        except aiohttp.ClientError as e:
-            print(f"Connection error: {type(e).__name__}: {str(e)}")
-            # Try to close and recreate the session on error
-            if self._session:
-                await self._session.close()
-                self._session = None
-            return None
-        except json.decoder.JSONDecodeError:
-            print(f"JSON Decode Error during inference: {content}")
-            return {
-                "error": "JSON Decode Error",
-                "content": content if content else "Content is null",
-            }
-        except Exception as e:
-            print(f"Error during inference: {e}")
-            raise
-        finally:
-            self.inference_count -= 1
-
-    def update_model(self, output_dir):
-        print("Restarting server with new model...")
-        self.restarting = True
-
-        # Close existing session and reset inference count
-        if self._session:
-            # Create a new event loop if one doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Run the session closure in the event loop
-            loop.run_until_complete(self._session.close())
-            self._session = None
-        self.inference_count = 0
+    def start_weight_update(self):
+        """Block inference during weight updates"""
+        self._is_updating += 1
 
-        tik = time.time()
-        self.kill()
-        print("Just killed server")
-        time.sleep(5)
-        # Check that output directory exists and was created successfully
-        print(f"Checking that output directory {output_dir} exists")
-        if not os.path.exists(output_dir):
-            raise RuntimeError(
-                f"Failed to save model - output directory {output_dir} does not exist"
-            )
+    def complete_weight_update(self):
+        """Allow inference after weight update is complete"""
+        self._is_updating = max(0, self._is_updating - 1)  # Prevent going negative
 
-        print("Launching new server")
-        self.launch(output_dir, self.launch_kwargs)
-        tok = time.time()
-        self.restarting = False
-        print(f"Time taken to update model: {tok - tik} seconds")
-
-    async def _ensure_session(self):
-        if self._session is None or self._session.closed:
-            timeout = aiohttp.ClientTimeout(
-                total=None
-            )  # No timeout...If it hangs, this might be the issue.
-            self._session = aiohttp.ClientSession(timeout=timeout)
-        return self._session
+    @property
+    def is_updating(self):
+        """Check if any weight updates are in progress"""
+        return self._is_updating > 0
 
 
 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
    """
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("localhost", 0))
-        return s.getsockname()[1]
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)
 
 
 def wait_for_server(base_url: str, timeout: int = None) -> None:
@@ -345,3 +247,31 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
         except requests.exceptions.RequestException:
             # Server not up yet, wait and retry
             time.sleep(1)
+
+
+def kill_vllm_server(main_process_pid):
+    try:
+        # Get the parent process
+        parent = psutil.Process(main_process_pid)
+
+        # Get all child processes recursively
+        children = parent.children(recursive=True)
+
+        # Send SIGTERM to all child processes first
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+
+        # Send SIGTERM to parent process
+        parent.send_signal(signal.SIGTERM)
+
+        # Wait for processes to terminate gracefully
+        gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+        # If any processes are still alive, force kill them
+        for p in alive:
+            p.kill()  # SIGKILL
+
+    except psutil.NoSuchProcess:
+        print(f"Process {main_process_pid} not found")
+    except Exception as e:
+        print(f"Error killing processes: {e}")