arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +12 -0
- arbor/server/api/models/schemas.py +0 -1
- arbor/server/api/routes/grpo.py +4 -9
- arbor/server/api/routes/inference.py +24 -14
- arbor/server/services/grpo_manager.py +176 -103
- arbor/server/services/inference/vllm_client.py +444 -0
- arbor/server/services/inference/vllm_serve.py +2336 -0
- arbor/server/services/inference_manager.py +145 -272
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +165 -57
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/METADATA +10 -6
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/RECORD +20 -14
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/WHEEL +1 -1
- arbor/server/services/inference/sgl_router_launch_server.py +0 -226
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/top_level.txt +0 -0
--- a/arbor/server/services/inference_manager.py
+++ b/arbor/server/services/inference_manager.py
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import os
 import random
 import signal
@@ -9,13 +8,14 @@ import sys
 import threading
 import time
 from datetime import datetime
+from enum import Enum
 from typing import Any, Dict, Optional

-import aiohttp
+import psutil
 import requests
-import zmq

 from arbor.server.core.config import Settings
+from arbor.server.services.inference.vllm_client import VLLMClient


 class InferenceManager:
@@ -24,12 +24,14 @@ class InferenceManager:
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
-        self.restarting = False
         self._shutting_down = False
-        self.
+        self.launched_model = None
         self.inference_count = 0
         self._session = None
-        self.
+        self.port = None
+        self.group_port = None
+        self.vllm_client = None
+        self._is_updating = 0  # Counter for weight updates in progress
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -47,15 +49,7 @@ class InferenceManager:
     def is_server_running(self):
         return self.process is not None

-    def
-        return self.restarting
-
-    def launch(
-        self,
-        model: str,
-        launch_kwargs: Optional[Dict[str, Any]] = None,
-        max_retries: int = 3,
-    ):
+    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -67,126 +61,81 @@ class InferenceManager:
             if model.startswith(prefix):
                 model = model[len(prefix) :]

-[... old lines 70-142 (earlier body of the previous launch implementation) are not shown in the source diff view ...]
-                except TimeoutError:
-                    # If the server doesn't come up, we might want to kill it:
-                    process.kill()
-                    raise
-
-                # Once server is ready, we tell the thread to stop printing further lines.
-                stop_printing_event.set()
-
-                # A convenience getter so the caller can see all logs so far (and future).
-                def get_logs() -> str:
-                    # Join them all into a single string, or you might return a list
-                    return "".join(logs_buffer)
-
-                # Let the user know server is up
-                print(f"Server ready on random port {router_port}!")
-
-                self.launch_kwargs["api_base"] = f"http://localhost:{router_port}/v1"
-                self.launch_kwargs["api_key"] = "local"
-                self.get_logs = get_logs
-                self.process = process
-                self.thread = thread
-                self.current_model = model
-
-                # If we get here, the launch was successful
-                return
-
-            except Exception as e:
-                retries += 1
-                print(
-                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
-                )
-                # Clean up any failed processes
-                if "process" in locals():
-                    try:
-                        process.kill()
-                    except:
-                        pass
-                if retries == max_retries:
-                    raise Exception(
-                        f"Failed to launch server after {max_retries} attempts"
-                    ) from e
-                # Wait a bit before retrying
-                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds
+        print(f"Grabbing a free port to launch a vLLM server for model {model}")
+        self.port = get_free_port()
+        timeout = launch_kwargs.get("timeout", 1800)
+        my_env = os.environ.copy()
+        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
+        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"python -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"
+
+        if launch_kwargs.get("max_context_length"):
+            command += f" --max_model_len {launch_kwargs['max_context_length']}"
+
+        print(f"Running command: {command}")
+
+        # We will manually stream & capture logs.
+        process = subprocess.Popen(
+            command.replace("\\\n", " ").replace("\\", " ").split(),
+            text=True,
+            stdout=subprocess.PIPE,  # We'll read from pipe
+            stderr=subprocess.STDOUT,  # Merge stderr into stdout
+            env=my_env,
+        )
+
+        # A threading.Event to control printing after the server is ready.
+        # This will store *all* lines (both before and after readiness).
+        print(f"vLLM server process started with PID {process.pid}.")
+        stop_printing_event = threading.Event()
+        logs_buffer = []
+
+        def _tail_process(proc, buffer, stop_event):
+            while True:
+                line = proc.stdout.readline()
+                if not line and proc.poll() is not None:
+                    # Process ended and no new line
+                    break
+                if line:
+                    buffer.append(line)
+                    # Print only if stop_event is not set
+                    if not stop_event.is_set():
+                        print(f"[vLLM LOG] {line}", end="")
+
+        # Start a background thread to read from the process continuously
+        thread = threading.Thread(
+            target=_tail_process,
+            args=(process, logs_buffer, stop_printing_event),
+            daemon=True,
+        )
+        thread.start()
+
+        # A convenience getter so the caller can see all logs so far (and future).
+        def get_logs() -> str:
+            # Join them all into a single string, or you might return a list
+            return "".join(logs_buffer)
+
+        # Let the user know server is up
+        print(f"Server ready on random port {self.port}!")
+
+        # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+        # self.launch_kwargs["api_key"] = "local"
+        self.get_logs = get_logs
+        self.process = process
+        self.thread = thread
+        self.launched_model = model
+
+        # Get another free port for weight sync group communication
+        self.group_port = get_free_port()
+        self.vllm_client = VLLMClient(
+            port=self.port,
+            group_port=self.group_port,
+            connection_timeout=300,  # 5 minutes
+        )
+
+        # Once server is ready, we tell the thread to stop printing further lines.
+        stop_printing_event.set()

     def kill(self):
-        from sglang.utils import terminate_process
-
         if self.process is None:
             print("No running server to kill.")
             return
@@ -201,24 +150,7 @@ class InferenceManager:
         self.last_activity = None

         try:
-
-            if self._shutting_down:
-                process.kill()  # Go straight to SIGKILL if we're shutting down
-            else:
-                terminate_process(process)
-                try:
-                    process.wait(timeout=10)
-                except subprocess.TimeoutExpired:
-                    print(
-                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
-                    )
-                    process.kill()
-
-            process.wait(timeout=5)
-
-            if thread and thread.is_alive():
-                thread.join(timeout=5)
-
+            kill_vllm_server(process.pid)
         except Exception as e:
             print(f"Error during cleanup: {e}")
             try:
@@ -229,127 +161,65 @@ class InferenceManager:
         print("Server killed.")

     async def run_inference(self, request_json: dict):
+        # Check if weights are being updated
+        while self.is_updating:
+            # weights are being updated...waiting
+            # print("Weights are being updated, waiting...")
+            await asyncio.sleep(1)  # Small sleep to prevent busy waiting
+
         model = request_json["model"]
         prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
         print(f"Running inference for model {model}")
-
-
-
-
+
+        # Monkeypatch for GRPO runs:
+        # vllm complains if we don't give it the exact model name that was launched
+        # TODO: This should really throw an error unless in a GRPO run.
+        if model != self.launched_model:
+            # print(f"Model changed from {model} to {self.current_model}")
+            model = self.launched_model
             request_json["model"] = model

         # Update last_activity timestamp
         self.last_activity = datetime.now()

-        if self.process is None
+        if self.process is None:
             raise RuntimeError("Server is not running. Please launch it first.")

-
-        while self.restarting:
-            print("Inference is paused while server is restarting...")
-            await asyncio.sleep(5)
-            request_json["model"] = self.current_model
+        return await self.vllm_client.chat(json_body=request_json)

-
-
-
-            session = await self._ensure_session()
-            async with session.post(url, json=request_json) as response:
-                content = await response.content.read()
-                return json.loads(content)
-        except aiohttp.ClientError as e:
-            print(f"Connection error: {type(e).__name__}: {str(e)}")
-            # Try to close and recreate the session on error
-            if self._session:
-                await self._session.close()
-                self._session = None
-            return None
-        except json.decoder.JSONDecodeError:
-            print(f"JSON Decode Error during inference: {content}")
-            return {
-                "error": "JSON Decode Error",
-                "content": content if content else "Content is null",
-            }
-        except Exception as e:
-            print(f"Error during inference: {e}")
-            raise
-        finally:
-            self.inference_count -= 1
-
-    def update_model(self, output_dir):
-        print("Restarting server with new model...")
-        self.restarting = True
-
-        # Close existing session and reset inference count
-        if self._session:
-            # Create a new event loop if one doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Run the session closure in the event loop
-            loop.run_until_complete(self._session.close())
-            self._session = None
-        self.inference_count = 0
+    def start_weight_update(self):
+        """Block inference during weight updates"""
+        self._is_updating += 1

-
-
-
-        # time.sleep(5)
+    def complete_weight_update(self):
+        """Allow inference after weight update is complete"""
+        self._is_updating = max(0, self._is_updating - 1)  # Prevent going negative

-
-
-        if
-
-                f"Failed to save model - output directory {output_dir} does not exist"
-            )
-
-        print("Directly updating weights from disk")
-        for worker_url in self.worker_urls:
-            print(f"Updating weights from disk for worker {worker_url}")
-            try:
-                response = requests.post(
-                    f"{worker_url}/update_weights_from_disk",
-                    json={"model_path": output_dir},
-                )
-                response_json = response.json()
-                print(f"Response from update_weights_from_disk: {response_json}")
-                # TODO: Check that the response is successful
-            except Exception as e:
-                print(f"Error during update_weights_from_disk: {e}")
-                print(f"Full error during update_weights_from_disk: {str(e)}")
-                if hasattr(e, "response") and e.response is not None:
-                    print(f"Response status code: {e.response.status_code}")
-                    print(f"Response text: {e.response.text}")
-        self.current_model = output_dir
-
-        # print("Launching new server")
-        # self.launch(output_dir, self.launch_kwargs)
-        tok = time.time()
-        self.restarting = False
-        print(f"Time taken to update model: {tok - tik} seconds")
-
-    async def _ensure_session(self):
-        if self._session is None or self._session.closed:
-            timeout = aiohttp.ClientTimeout(
-                total=None
-            )  # No timeout...If it hangs, this might be the issue.
-            self._session = aiohttp.ClientSession(timeout=timeout)
-        return self._session
+    @property
+    def is_updating(self):
+        """Check if any weight updates are in progress"""
+        return self._is_updating > 0


 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
     """
-
-
-
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)


 def wait_for_server(base_url: str, timeout: int = None) -> None:
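Note on the new weight-update gate: the counter-based start_weight_update / complete_weight_update / is_updating API above replaces the old restarting flag, and run_inference now waits whenever the counter is non-zero. A minimal sketch of how a trainer-side caller might wrap an update (illustrative only, not taken from the package; the helper below is hypothetical and the actual weight-transfer step is elided):

    import asyncio

    from arbor.server.services.inference_manager import InferenceManager


    async def update_weights_safely(manager: InferenceManager) -> None:
        """Hypothetical helper: pause new inference while weights are swapped."""
        manager.start_weight_update()         # increment the in-progress counter
        try:
            # ... perform the real weight transfer here (not shown in this diff) ...
            await asyncio.sleep(0)            # placeholder for the actual update step
        finally:
            manager.complete_weight_update()  # decrement; inference resumes at zero

Because the gate is a counter rather than a boolean, overlapping updates only unblock inference once every caller has completed.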
@@ -379,26 +249,29 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
             time.sleep(1)


-def
-
-
-
-    socket.connect(f"tcp://localhost:{zmq_port}")
-    socket.setsockopt_string(zmq.SUBSCRIBE, "")  # Subscribe to all messages
+def kill_vllm_server(main_process_pid):
+    try:
+        # Get the parent process
+        parent = psutil.Process(main_process_pid)

-
-
+        # Get all child processes recursively
+        children = parent.children(recursive=True)

-[... old lines 392-404 were also removed; their content is not shown in the source diff view ...]
+        # Send SIGTERM to all child processes first
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+
+        # Send SIGTERM to parent process
+        parent.send_signal(signal.SIGTERM)
+
+        # Wait for processes to terminate gracefully
+        gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+        # If any processes are still alive, force kill them
+        for p in alive:
+            p.kill()  # SIGKILL
+
+    except psutil.NoSuchProcess:
+        print(f"Process {main_process_pid} not found")
+    except Exception as e:
+        print(f"Error killing processes: {e}")
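Taken together, the 0.2 changes swap the sglang router for a directly managed vLLM server plus a VLLMClient used for chat and weight-sync communication. A rough end-to-end sketch of driving the new InferenceManager (illustrative only; the Settings construction, the settings-only constructor, the timeout kwarg value, and the model name are assumptions not confirmed by this diff):

    import asyncio

    from arbor.server.core.config import Settings
    from arbor.server.services.inference_manager import InferenceManager


    async def main() -> None:
        settings = Settings()                   # assumption: default construction works
        manager = InferenceManager(settings)    # assumption: settings-only constructor
        manager.launch("Qwen/Qwen2.5-1.5B-Instruct", {"timeout": 1800})  # spawns vllm_serve and a VLLMClient

        request = {
            "model": "Qwen/Qwen2.5-1.5B-Instruct",
            "messages": [{"role": "user", "content": "Hello!"}],
        }
        response = await manager.run_inference(request)  # proxied through vllm_client.chat
        print(response)

        manager.kill()                          # terminates the vLLM process tree


    if __name__ == "__main__":
        asyncio.run(main())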