arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- arbor/cli.py +12 -0
- arbor/server/api/routes/grpo.py +4 -1
- arbor/server/api/routes/inference.py +11 -16
- arbor/server/services/grpo_manager.py +179 -98
- arbor/server/services/inference/__init__.py +0 -0
- arbor/server/services/inference/vllm_client.py +445 -0
- arbor/server/services/inference/vllm_serve.py +2335 -0
- arbor/server/services/inference_manager.py +149 -219
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +157 -53
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/METADATA +4 -5
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/RECORD +20 -12
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/top_level.txt +0 -0
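
The headline change in `arbor/server/services/inference_manager.py` is the replacement of the previous SGLang-backed launch/restart flow with a dedicated vLLM server process (`vllm_serve.py`) driven through the new `VLLMClient`, plus a weight-update gate that pauses inference instead of restarting the server. As rough orientation before the diff, here is a minimal usage sketch based on the signatures shown below; the `Settings()` construction, the model name, and the request body are illustrative assumptions, not code from the package:

    import asyncio

    from arbor.server.core.config import Settings
    from arbor.server.services.inference_manager import InferenceManager

    async def main():
        # Assumption: the manager is constructed with the server Settings, as its
        # use of self.settings.arbor_config.inference.gpu_ids in the diff suggests.
        manager = InferenceManager(Settings())

        # launch() now takes only the model and optional launch kwargs
        # (the retry machinery and max_retries parameter were removed).
        manager.launch(
            "Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model name
            launch_kwargs={"max_context_length": 4096},
        )

        # run_inference() waits while weights are being updated, then forwards the
        # OpenAI-style body to the vLLM server via VLLMClient.chat().
        request = {
            "model": "Qwen/Qwen2.5-1.5B-Instruct",
            "messages": [{"role": "user", "content": "Hello!"}],
        }
        response = await manager.run_inference(request)
        print(response)

        manager.kill()

    asyncio.run(main())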
--- arbor/server/services/inference_manager.py (0.1.13)
+++ arbor/server/services/inference_manager.py (0.1.15)
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import os
 import random
 import signal
@@ -9,12 +8,14 @@ import sys
 import threading
 import time
 from datetime import datetime
+from enum import Enum
 from typing import Any, Dict, Optional

-import aiohttp
+import psutil
 import requests

 from arbor.server.core.config import Settings
+from arbor.server.services.inference.vllm_client import VLLMClient


 class InferenceManager:
@@ -23,11 +24,14 @@ class InferenceManager:
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
-        self.restarting = False
         self._shutting_down = False
-        self.current_model = None
+        self.launched_model = None
         self.inference_count = 0
         self._session = None
+        self.port = None
+        self.group_port = None
+        self.vllm_client = None
+        self._is_updating = 0  # Counter for weight updates in progress
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -45,15 +49,7 @@ class InferenceManager:
     def is_server_running(self):
         return self.process is not None

-    def
-        return self.restarting
-
-    def launch(
-        self,
-        model: str,
-        launch_kwargs: Optional[Dict[str, Any]] = None,
-        max_retries: int = 3,
-    ):
+    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
             print("Server is already launched.")
             return
@@ -65,116 +61,81 @@ class InferenceManager:
             if model.startswith(prefix):
                 model = model[len(prefix) :]

-        [... old lines 68-140 (body of the previous launch/retry implementation) are not rendered in the source diff ...]
-                    # Join them all into a single string, or you might return a list
-                    return "".join(logs_buffer)
-
-                # Let the user know server is up
-                print(f"Server ready on random port {port}!")
-
-                self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
-                self.launch_kwargs["api_key"] = "local"
-                self.get_logs = get_logs
-                self.process = process
-                self.thread = thread
-                self.current_model = model
-
-                # If we get here, the launch was successful
-                return
-
-            except Exception as e:
-                retries += 1
-                print(
-                    f"Failed to launch server (attempt {retries} of {max_retries}): {str(e)}"
-                )
-                # Clean up any failed processes
-                if "process" in locals():
-                    try:
-                        process.kill()
-                    except:
-                        pass
-                if retries == max_retries:
-                    raise Exception(
-                        f"Failed to launch server after {max_retries} attempts"
-                    ) from e
-                # Wait a bit before retrying
-                time.sleep(min(2**retries, 30))  # Exponential backoff, max 30 seconds
+        print(f"Grabbing a free port to launch a vLLM server for model {model}")
+        self.port = get_free_port()
+        timeout = launch_kwargs.get("timeout", 1800)
+        my_env = os.environ.copy()
+        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
+        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"python -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"
+
+        if launch_kwargs.get("max_context_length"):
+            command += f" --max_model_len {launch_kwargs['max_context_length']}"
+
+        print(f"Running command: {command}")
+
+        # We will manually stream & capture logs.
+        process = subprocess.Popen(
+            command.replace("\\\n", " ").replace("\\", " ").split(),
+            text=True,
+            stdout=subprocess.PIPE,  # We'll read from pipe
+            stderr=subprocess.STDOUT,  # Merge stderr into stdout
+            env=my_env,
+        )
+
+        # A threading.Event to control printing after the server is ready.
+        # This will store *all* lines (both before and after readiness).
+        print(f"vLLM server process started with PID {process.pid}.")
+        stop_printing_event = threading.Event()
+        logs_buffer = []
+
+        def _tail_process(proc, buffer, stop_event):
+            while True:
+                line = proc.stdout.readline()
+                if not line and proc.poll() is not None:
+                    # Process ended and no new line
+                    break
+                if line:
+                    buffer.append(line)
+                    # Print only if stop_event is not set
+                    if not stop_event.is_set():
+                        print(f"[vLLM LOG] {line}", end="")
+
+        # Start a background thread to read from the process continuously
+        thread = threading.Thread(
+            target=_tail_process,
+            args=(process, logs_buffer, stop_printing_event),
+            daemon=True,
+        )
+        thread.start()
+
+        # A convenience getter so the caller can see all logs so far (and future).
+        def get_logs() -> str:
+            # Join them all into a single string, or you might return a list
+            return "".join(logs_buffer)
+
+        # Let the user know server is up
+        print(f"Server ready on random port {self.port}!")
+
+        # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
+        # self.launch_kwargs["api_key"] = "local"
+        self.get_logs = get_logs
+        self.process = process
+        self.thread = thread
+        self.launched_model = model
+
+        # Get another free port for weight sync group communication
+        self.group_port = get_free_port()
+        self.vllm_client = VLLMClient(
+            port=self.port,
+            group_port=self.group_port,
+            connection_timeout=300,  # 5 minutes
+        )
+
+        # Once server is ready, we tell the thread to stop printing further lines.
+        stop_printing_event.set()

     def kill(self):
-        from sglang.utils import terminate_process
-
         if self.process is None:
             print("No running server to kill.")
             return
@@ -189,24 +150,7 @@ class InferenceManager:
         self.last_activity = None

         try:
-
-            if self._shutting_down:
-                process.kill()  # Go straight to SIGKILL if we're shutting down
-            else:
-                terminate_process(process)
-                try:
-                    process.wait(timeout=10)
-                except subprocess.TimeoutExpired:
-                    print(
-                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
-                    )
-                    process.kill()
-
-                process.wait(timeout=5)
-
-            if thread and thread.is_alive():
-                thread.join(timeout=5)
-
+            kill_vllm_server(process.pid)
         except Exception as e:
             print(f"Error during cleanup: {e}")
             try:
@@ -217,107 +161,65 @@ class InferenceManager:
        print("Server killed.")

     async def run_inference(self, request_json: dict):
+        # Check if weights are being updated
+        while self.is_updating:
+            # weights are being updated...waiting
+            # print("Weights are being updated, waiting...")
+            await asyncio.sleep(1)  # Small sleep to prevent busy waiting
+
         model = request_json["model"]
         prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
         print(f"Running inference for model {model}")
-        [... old lines 226-229 are not rendered in the source diff ...]
+
+        # Monkeypatch for GRPO runs:
+        # vllm complains if we don't give it the exact model name that was launched
+        # TODO: This should really throw an error unless in a GRPO run.
+        if model != self.launched_model:
+            # print(f"Model changed from {model} to {self.current_model}")
+            model = self.launched_model
         request_json["model"] = model

         # Update last_activity timestamp
         self.last_activity = datetime.now()

-        if self.process is None
+        if self.process is None:
             raise RuntimeError("Server is not running. Please launch it first.")

-
-        while self.restarting:
-            print("Inference is paused while server is restarting...")
-            await asyncio.sleep(5)
-        request_json["model"] = self.current_model
+        return await self.vllm_client.chat(json_body=request_json)

-        [... old lines 244-246 are not rendered in the source diff ...]
-            session = await self._ensure_session()
-            async with session.post(url, json=request_json) as response:
-                content = await response.content.read()
-                return json.loads(content)
-        except aiohttp.ClientError as e:
-            print(f"Connection error: {type(e).__name__}: {str(e)}")
-            # Try to close and recreate the session on error
-            if self._session:
-                await self._session.close()
-                self._session = None
-            return None
-        except json.decoder.JSONDecodeError:
-            print(f"JSON Decode Error during inference: {content}")
-            return {
-                "error": "JSON Decode Error",
-                "content": content if content else "Content is null",
-            }
-        except Exception as e:
-            print(f"Error during inference: {e}")
-            raise
-        finally:
-            self.inference_count -= 1
-
-    def update_model(self, output_dir):
-        print("Restarting server with new model...")
-        self.restarting = True
-
-        # Close existing session and reset inference count
-        if self._session:
-            # Create a new event loop if one doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Run the session closure in the event loop
-            loop.run_until_complete(self._session.close())
-            self._session = None
-        self.inference_count = 0
+    def start_weight_update(self):
+        """Block inference during weight updates"""
+        self._is_updating += 1

-        [... old lines 288-290 are not rendered in the source diff ...]
-        time.sleep(5)
-        # Check that output directory exists and was created successfully
-        print(f"Checking that output directory {output_dir} exists")
-        if not os.path.exists(output_dir):
-            raise RuntimeError(
-                f"Failed to save model - output directory {output_dir} does not exist"
-            )
+    def complete_weight_update(self):
+        """Allow inference after weight update is complete"""
+        self._is_updating = max(0, self._is_updating - 1)  # Prevent going negative

-        [... old lines 299-301 are not rendered in the source diff ...]
-        self.
-        print(f"Time taken to update model: {tok - tik} seconds")
-
-    async def _ensure_session(self):
-        if self._session is None or self._session.closed:
-            timeout = aiohttp.ClientTimeout(
-                total=None
-            )  # No timeout...If it hangs, this might be the issue.
-            self._session = aiohttp.ClientSession(timeout=timeout)
-        return self._session
+    @property
+    def is_updating(self):
+        """Check if any weight updates are in progress"""
+        return self._is_updating > 0


 def get_free_port() -> int:
     """
-    Return a free TCP port on localhost.
+    Return a randomly selected free TCP port on localhost from a selection of 3-4 ports.
     """
-    [... old lines 318-320 are not rendered in the source diff ...]
+    import random
+    import socket
+
+    ports = []
+    for _ in range(random.randint(5, 10)):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind(("localhost", 0))
+                ports.append(s.getsockname()[1])
+        except Exception as e:
+            print(f"Error binding to port: {e}")
+    return random.choice(ports)


 def wait_for_server(base_url: str, timeout: int = None) -> None:
@@ -345,3 +247,31 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
         except requests.exceptions.RequestException:
             # Server not up yet, wait and retry
             time.sleep(1)
+
+
+def kill_vllm_server(main_process_pid):
+    try:
+        # Get the parent process
+        parent = psutil.Process(main_process_pid)
+
+        # Get all child processes recursively
+        children = parent.children(recursive=True)
+
+        # Send SIGTERM to all child processes first
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+
+        # Send SIGTERM to parent process
+        parent.send_signal(signal.SIGTERM)
+
+        # Wait for processes to terminate gracefully
+        gone, alive = psutil.wait_procs(children + [parent], timeout=10)
+
+        # If any processes are still alive, force kill them
+        for p in alive:
+            p.kill()  # SIGKILL
+
+    except psutil.NoSuchProcess:
+        print(f"Process {main_process_pid} not found")
+    except Exception as e:
+        print(f"Error killing processes: {e}")
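
The removed `self.restarting` flag and `update_model()` restart path are replaced by a reference-counted gate (`_is_updating`, `start_weight_update()`, `complete_weight_update()`, and the `is_updating` property) that `run_inference()` polls before forwarding a request. A self-contained sketch of that gating behaviour follows, using a stand-in class with the same three members; the stand-in and the timings are illustrative, not code from the package:

    import asyncio

    class GateDemo:
        """Stand-in mirroring InferenceManager's weight-update gate from the diff."""

        def __init__(self):
            self._is_updating = 0  # counter, so overlapping updates stack safely

        def start_weight_update(self):
            self._is_updating += 1

        def complete_weight_update(self):
            self._is_updating = max(0, self._is_updating - 1)

        @property
        def is_updating(self):
            return self._is_updating > 0

    async def run_inference(gate, request_id):
        # Same polling loop run_inference() uses: wait until no update is in flight.
        while gate.is_updating:
            await asyncio.sleep(0.1)
        print(f"request {request_id} forwarded")

    async def update_weights(gate):
        gate.start_weight_update()
        try:
            await asyncio.sleep(0.5)  # stand-in for pushing weights to the vLLM server
        finally:
            gate.complete_weight_update()  # always release, even on failure

    async def main():
        gate = GateDemo()
        await asyncio.gather(
            update_weights(gate),
            run_inference(gate, 1),
            run_inference(gate, 2),
        )

    asyncio.run(main())

Because the gate is a counter rather than a boolean, concurrent weight updates (as the GRPO changes elsewhere in this release can trigger) only unblock inference once every update has completed.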