arbor-ai 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/api/routes/inference.py +1 -1
- arbor/server/services/grpo_manager.py +30 -1
- arbor/server/services/inference_manager.py +38 -12
- arbor/server/services/scripts/grpo_training.py +0 -4
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/METADATA +2 -1
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/RECORD +10 -10
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.9.dist-info → arbor_ai-0.1.11.dist-info}/top_level.txt +0 -0
arbor/server/api/routes/inference.py
@@ -33,7 +33,7 @@ async def run_inference(
     raw_json["model"] = inference_manager.current_model
 
     # forward the request to the inference server
-    completion = inference_manager.run_inference(raw_json)
+    completion = await inference_manager.run_inference(raw_json)
 
     return completion
 
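Note on this hunk: `run_inference` becomes a coroutine in this release, so the route must `await` it; keeping the old call shape would hand the route a coroutine object instead of the completion payload. A minimal sketch of the difference, using a stand-in manager class rather than Arbor's real one:

```python
import asyncio


class FakeInferenceManager:
    """Stand-in for Arbor's InferenceManager, for illustration only."""

    async def run_inference(self, raw_json: dict) -> dict:
        await asyncio.sleep(0)  # stands in for the real HTTP round-trip
        return {"model": raw_json["model"], "choices": []}


async def main() -> None:
    manager = FakeInferenceManager()
    raw_json = {"model": "demo"}

    unawaited = manager.run_inference(raw_json)
    print(type(unawaited))  # <class 'coroutine'> -- the old call shape is now a bug
    print(await unawaited)  # awaiting yields the actual payload


asyncio.run(main())
```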
arbor/server/services/grpo_manager.py
@@ -1,7 +1,9 @@
+import asyncio
 import json
 import os
 import random
 import signal
+import socket
 import string
 import subprocess
 import sys
arbor/server/services/grpo_manager.py
@@ -120,12 +122,17 @@ class GRPOManager:
 
         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
 
+        # This is the port for the accelerate main process
+        main_process_port = get_free_port()
+
         params = [
             "python",
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
             str(num_processes),
+            "--main_process_port",
+            str(main_process_port),
         ]
         if self.settings.arbor_config.training.accelerate_config:
             params.extend(
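Why the new flag: accelerate's launcher otherwise falls back to a fixed main-process port (torch.distributed's usual 29500), so two concurrent launches on one host can collide; asking the OS for a fresh port per launch avoids that. A sketch of how the assembled command comes out, with placeholder values (`gpu_ids` here is illustrative, not read from a real config):

```python
import socket


def get_free_port() -> int:
    """Same idea as the helper this diff adds to grpo_manager.py."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


gpu_ids = "1, 2"  # two training GPUs, matching the README example
num_processes = gpu_ids.count(",") + 1
main_process_port = get_free_port()

params = [
    "python",
    "-m",
    "accelerate.commands.launch",
    "--num_processes",
    str(num_processes),
    "--main_process_port",
    str(main_process_port),
]
print(" ".join(params))
# e.g. python -m accelerate.commands.launch --num_processes 2 --main_process_port 49517
```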
arbor/server/services/grpo_manager.py
@@ -257,8 +264,21 @@ class GRPOManager:
         return self.current_model
 
     def update_model(self, request, inference_manager: InferenceManager):
-
+        if inference_manager._session:
+            # Create a new event loop if one doesn't exist
+            try:
+                loop = asyncio.get_event_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+            # Run the session closure in the event loop
+            loop.run_until_complete(inference_manager._session.close())
+            inference_manager._session = None
+
+        inference_manager.inference_count = 0
         inference_manager.restarting = True
+
         self.model_saved_and_reload_requested = True
         self.server_comms_handler.send_command({"command": "save_model"})
         while self.model_saved_and_reload_requested:
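The subtle part of this hunk is driving the async `session.close()` from a synchronous method. One caveat worth flagging: `asyncio.get_event_loop()` with no running loop is deprecated on newer Pythons (and slated to raise `RuntimeError`), and `run_until_complete` fails if a loop is already running in the same thread, so this pattern only works from plain sync code. A reduced sketch, with a dummy coroutine standing in for the session close:

```python
import asyncio


async def close_session() -> None:
    """Stand-in for inference_manager._session.close()."""
    print("session closed")


def sync_cleanup() -> None:
    # Reuse the thread's loop if one exists; otherwise make one.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    # Drive the coroutine to completion from synchronous code.
    loop.run_until_complete(close_session())


sync_cleanup()
```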
arbor/server/services/grpo_manager.py
@@ -328,3 +348,12 @@ class GRPOManager:
         # >= self.train_kwargs["update_interval"]
         # )
         return self.model_saved_and_reload_requested
+
+
+def get_free_port() -> int:
+    """
+    Return a free TCP port on localhost.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("localhost", 0))
+        return s.getsockname()[1]
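The helper works by binding to port 0, which asks the OS to pick an unused ephemeral port; `getsockname()` reports which one was chosen, and leaving the `with` block releases it for accelerate to claim. There is an inherent small race (another process could take the port between this return and the launch), which is usually acceptable in practice. A quick usage check:

```python
import socket


def get_free_port() -> int:
    """Return a free TCP port on localhost (same as the diff's helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))  # port 0 = let the OS choose
        return s.getsockname()[1]


port = get_free_port()
print(f"--main_process_port {port}")

# The port is bindable again once the helper's socket is closed.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("localhost", port))
```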
arbor/server/services/inference_manager.py
@@ -1,3 +1,5 @@
+import asyncio
+import json
 import os
 import signal
 import socket
@@ -8,6 +10,7 @@ import time
 from datetime import datetime
 from typing import Any, Dict, Optional
 
+import aiohttp
 import requests
 
 from arbor.server.core.config import Settings
@@ -23,6 +26,7 @@ class InferenceManager:
         self._shutting_down = False
         self.current_model = None
         self.inference_count = 0
+        self._session = None
         # Set up signal handler for graceful shutdown
         signal.signal(signal.SIGINT, self._signal_handler)
         signal.signal(signal.SIGTERM, self._signal_handler)
@@ -62,7 +66,7 @@ class InferenceManager:
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
         n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
         # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
-        command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --
+        command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
         print(f"Running command: {command}")
 
         # We will manually stream & capture logs.
@@ -171,7 +175,7 @@ class InferenceManager:
 
         print("Server killed.")
 
-    def run_inference(self, request_json: dict):
+    async def run_inference(self, request_json: dict):
         model = request_json["model"]
         prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
         for prefix in prefixes:
@@ -193,16 +197,22 @@ class InferenceManager:
         if self.restarting:
             while self.restarting:
                 print("Inference is paused while server is restarting...")
-
+                await asyncio.sleep(5)
             request_json["model"] = self.current_model
 
         url = f"{self.launch_kwargs['api_base']}/chat/completions"
         try:
             self.inference_count += 1
-
-
-
-
+            session = await self._ensure_session()
+            async with session.post(url, json=request_json) as response:
+                content = await response.content.read()
+                return json.loads(content)
+        except aiohttp.ClientError as e:
+            print(f"Connection error: {type(e).__name__}: {str(e)}")
+            # Try to close and recreate the session on error
+            if self._session:
+                await self._session.close()
+                self._session = None
             return None
         except Exception as e:
             print(f"Error during inference: {e}")
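Two details of the new request path are worth calling out: the body is read raw and parsed with `json.loads` (rather than `response.json()`, which raises on responses lacking a JSON content type), and any `aiohttp.ClientError` tears down the pooled session so the next call builds a fresh one via `_ensure_session`. A reduced, self-contained sketch of that flow (the URL and payload shapes are placeholders):

```python
import json

import aiohttp


async def post_completion(
    session: aiohttp.ClientSession, url: str, payload: dict
) -> dict | None:
    """Reduced sketch of the new run_inference request path."""
    try:
        async with session.post(url, json=payload) as response:
            content = await response.content.read()
            return json.loads(content)
    except aiohttp.ClientError as e:
        # A connection-level failure may have poisoned the pooled session;
        # closing it here lets the caller lazily rebuild it next time.
        print(f"Connection error: {type(e).__name__}: {e}")
        await session.close()
        return None
```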
arbor/server/services/inference_manager.py
@@ -214,11 +224,19 @@ class InferenceManager:
         print("Restarting server with new model...")
         self.restarting = True
 
-
-
-
-
-
+        # Close existing session and reset inference count
+        if self._session:
+            # Create a new event loop if one doesn't exist
+            try:
+                loop = asyncio.get_event_loop()
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+            # Run the session closure in the event loop
+            loop.run_until_complete(self._session.close())
+            self._session = None
+        self.inference_count = 0
 
         tik = time.time()
         self.kill()
@@ -236,6 +254,14 @@ class InferenceManager:
         self.restarting = False
         print(f"Time taken to update model: {tok - tik} seconds")
 
+    async def _ensure_session(self):
+        if self._session is None or self._session.closed:
+            timeout = aiohttp.ClientTimeout(
+                total=None
+            )  # No timeout...If it hangs, this might be the issue.
+            self._session = aiohttp.ClientSession(timeout=timeout)
+        return self._session
+
 
 def get_free_port() -> int:
     """
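`_ensure_session` is a lazy singleton for the HTTP session: one `ClientSession`, and therefore one connection pool, is shared across all inference calls instead of being created per request, and `ClientTimeout(total=None)` lifts aiohttp's default 5-minute total deadline so long generations are not cut off. The same pattern in isolation, with an illustrative holder class:

```python
import aiohttp


class SessionHolder:
    """Illustrative holder mirroring InferenceManager's lazy session reuse."""

    def __init__(self) -> None:
        self._session: aiohttp.ClientSession | None = None

    async def _ensure_session(self) -> aiohttp.ClientSession:
        # Rebuild only if the session was never created or was closed
        # (e.g. after a connection error or a model reload).
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=None)  # no overall deadline
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session
```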
arbor/server/services/scripts/grpo_training.py
@@ -351,10 +351,6 @@ class CommandMonitor:
                 output_dir=self.trainer.args.output_dir + "/adapter/"
             )
 
-            # base_model = AutoModelForCausalLM.from_pretrained(
-            #     self.base_model_name
-            # ).to(self.trainer.accelerator.device)
-
             _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
                 self.trainer.args.output_dir + "/adapter/",
                 config=self.trainer.peft_config,
arbor_ai-0.1.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.9
+Version: 0.1.11
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -57,6 +57,7 @@ inference:
 training:
   gpu_ids: '1, 2'
 ```
+Which will use the `GPU:0` for inference with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
 
 ### 2️⃣ Start the Server
 
arbor_ai-0.1.11.dist-info/RECORD
@@ -9,7 +9,7 @@ arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhc
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
 arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
-arbor/server/api/routes/inference.py,sha256=
+arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256=
-arbor/server/services/inference_manager.py,sha256=
+arbor/server/services/grpo_manager.py,sha256=TAU2BMHgbCgiAvKNVd2Y8N20SR4qEms3lChA4Z0ZzyY,13777
+arbor/server/services/inference_manager.py,sha256=YVHXqwBm9vEmgKzKdMKQdLdw6qkUTl5BjHTnW-3yfo0,10699
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
-arbor/server/services/scripts/grpo_training.py,sha256=
+arbor/server/services/scripts/grpo_training.py,sha256=Q9jwnbRdXAv_jVgrChLX6IiB3BLZU1F3BP6mBV0DVik,20889
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
+arbor_ai-0.1.11.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.11.dist-info/METADATA,sha256=04deKUBx8A_5j4_OU39_09873sHhs-jKZwMOeRSU3GA,2413
+arbor_ai-0.1.11.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
+arbor_ai-0.1.11.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.11.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.11.dist-info/RECORD,,
Files without content changes (renamed with the dist-info directory only): entry_points.txt, licenses/LICENSE, top_level.txt.