arbor-ai 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/api/routes/inference.py +1 -1
- arbor/server/services/grpo_manager.py +15 -1
- arbor/server/services/inference_manager.py +39 -13
- arbor/server/services/scripts/grpo_training.py +0 -4
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/METADATA +1 -1
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/RECORD +10 -10
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.10.dist-info → arbor_ai-0.1.12.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,7 @@ async def run_inference(
|
|
33
33
|
raw_json["model"] = inference_manager.current_model
|
34
34
|
|
35
35
|
# forward the request to the inference server
|
36
|
-
completion = inference_manager.run_inference(raw_json)
|
36
|
+
completion = await inference_manager.run_inference(raw_json)
|
37
37
|
|
38
38
|
return completion
|
39
39
|
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import asyncio
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import random
|
@@ -263,8 +264,21 @@ class GRPOManager:
|
|
263
264
|
return self.current_model
|
264
265
|
|
265
266
|
def update_model(self, request, inference_manager: InferenceManager):
|
266
|
-
|
267
|
+
if inference_manager._session:
|
268
|
+
# Create a new event loop if one doesn't exist
|
269
|
+
try:
|
270
|
+
loop = asyncio.get_event_loop()
|
271
|
+
except RuntimeError:
|
272
|
+
loop = asyncio.new_event_loop()
|
273
|
+
asyncio.set_event_loop(loop)
|
274
|
+
|
275
|
+
# Run the session closure in the event loop
|
276
|
+
loop.run_until_complete(inference_manager._session.close())
|
277
|
+
inference_manager._session = None
|
278
|
+
|
279
|
+
inference_manager.inference_count = 0
|
267
280
|
inference_manager.restarting = True
|
281
|
+
|
268
282
|
self.model_saved_and_reload_requested = True
|
269
283
|
self.server_comms_handler.send_command({"command": "save_model"})
|
270
284
|
while self.model_saved_and_reload_requested:
|
@@ -1,3 +1,5 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
1
3
|
import os
|
2
4
|
import signal
|
3
5
|
import socket
|
@@ -8,6 +10,7 @@ import time
|
|
8
10
|
from datetime import datetime
|
9
11
|
from typing import Any, Dict, Optional
|
10
12
|
|
13
|
+
import aiohttp
|
11
14
|
import requests
|
12
15
|
|
13
16
|
from arbor.server.core.config import Settings
|
@@ -23,6 +26,7 @@ class InferenceManager:
|
|
23
26
|
self._shutting_down = False
|
24
27
|
self.current_model = None
|
25
28
|
self.inference_count = 0
|
29
|
+
self._session = None
|
26
30
|
# Set up signal handler for graceful shutdown
|
27
31
|
signal.signal(signal.SIGINT, self._signal_handler)
|
28
32
|
signal.signal(signal.SIGTERM, self._signal_handler)
|
@@ -62,7 +66,7 @@ class InferenceManager:
|
|
62
66
|
my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
|
63
67
|
n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
|
64
68
|
# command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
|
65
|
-
command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --
|
69
|
+
command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
|
66
70
|
print(f"Running command: {command}")
|
67
71
|
|
68
72
|
# We will manually stream & capture logs.
|
@@ -148,7 +152,7 @@ class InferenceManager:
|
|
148
152
|
if self._shutting_down:
|
149
153
|
process.kill() # Go straight to SIGKILL if we're shutting down
|
150
154
|
else:
|
151
|
-
process
|
155
|
+
terminate_process(process)
|
152
156
|
try:
|
153
157
|
process.wait(timeout=10)
|
154
158
|
except subprocess.TimeoutExpired:
|
@@ -171,7 +175,7 @@ class InferenceManager:
|
|
171
175
|
|
172
176
|
print("Server killed.")
|
173
177
|
|
174
|
-
def run_inference(self, request_json: dict):
|
178
|
+
async def run_inference(self, request_json: dict):
|
175
179
|
model = request_json["model"]
|
176
180
|
prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
|
177
181
|
for prefix in prefixes:
|
@@ -193,16 +197,22 @@ class InferenceManager:
|
|
193
197
|
if self.restarting:
|
194
198
|
while self.restarting:
|
195
199
|
print("Inference is paused while server is restarting...")
|
196
|
-
|
200
|
+
await asyncio.sleep(5)
|
197
201
|
request_json["model"] = self.current_model
|
198
202
|
|
199
203
|
url = f"{self.launch_kwargs['api_base']}/chat/completions"
|
200
204
|
try:
|
201
205
|
self.inference_count += 1
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
+
session = await self._ensure_session()
|
207
|
+
async with session.post(url, json=request_json) as response:
|
208
|
+
content = await response.content.read()
|
209
|
+
return json.loads(content)
|
210
|
+
except aiohttp.ClientError as e:
|
211
|
+
print(f"Connection error: {type(e).__name__}: {str(e)}")
|
212
|
+
# Try to close and recreate the session on error
|
213
|
+
if self._session:
|
214
|
+
await self._session.close()
|
215
|
+
self._session = None
|
206
216
|
return None
|
207
217
|
except Exception as e:
|
208
218
|
print(f"Error during inference: {e}")
|
@@ -214,11 +224,19 @@ class InferenceManager:
|
|
214
224
|
print("Restarting server with new model...")
|
215
225
|
self.restarting = True
|
216
226
|
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
227
|
+
# Close existing session and reset inference count
|
228
|
+
if self._session:
|
229
|
+
# Create a new event loop if one doesn't exist
|
230
|
+
try:
|
231
|
+
loop = asyncio.get_event_loop()
|
232
|
+
except RuntimeError:
|
233
|
+
loop = asyncio.new_event_loop()
|
234
|
+
asyncio.set_event_loop(loop)
|
235
|
+
|
236
|
+
# Run the session closure in the event loop
|
237
|
+
loop.run_until_complete(self._session.close())
|
238
|
+
self._session = None
|
239
|
+
self.inference_count = 0
|
222
240
|
|
223
241
|
tik = time.time()
|
224
242
|
self.kill()
|
@@ -236,6 +254,14 @@ class InferenceManager:
|
|
236
254
|
self.restarting = False
|
237
255
|
print(f"Time taken to update model: {tok - tik} seconds")
|
238
256
|
|
257
|
+
async def _ensure_session(self):
|
258
|
+
if self._session is None or self._session.closed:
|
259
|
+
timeout = aiohttp.ClientTimeout(
|
260
|
+
total=None
|
261
|
+
) # No timeout...If it hangs, this might be the issue.
|
262
|
+
self._session = aiohttp.ClientSession(timeout=timeout)
|
263
|
+
return self._session
|
264
|
+
|
239
265
|
|
240
266
|
def get_free_port() -> int:
|
241
267
|
"""
|
@@ -351,10 +351,6 @@ class CommandMonitor:
|
|
351
351
|
output_dir=self.trainer.args.output_dir + "/adapter/"
|
352
352
|
)
|
353
353
|
|
354
|
-
# base_model = AutoModelForCausalLM.from_pretrained(
|
355
|
-
# self.base_model_name
|
356
|
-
# ).to(self.trainer.accelerator.device)
|
357
|
-
|
358
354
|
_model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
|
359
355
|
self.trainer.args.output_dir + "/adapter/",
|
360
356
|
config=self.trainer.peft_config,
|
@@ -9,7 +9,7 @@ arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhc
|
|
9
9
|
arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
10
|
arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
|
11
11
|
arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
|
12
|
-
arbor/server/api/routes/inference.py,sha256=
|
12
|
+
arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
|
13
13
|
arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
|
14
14
|
arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
15
15
|
arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
|
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
17
17
|
arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
18
|
arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
|
20
|
-
arbor/server/services/grpo_manager.py,sha256=
|
21
|
-
arbor/server/services/inference_manager.py,sha256=
|
20
|
+
arbor/server/services/grpo_manager.py,sha256=TAU2BMHgbCgiAvKNVd2Y8N20SR4qEms3lChA4Z0ZzyY,13777
|
21
|
+
arbor/server/services/inference_manager.py,sha256=q4RVUqh1snGfW-AADkCqW8hC5x3WAZNe0jwXKOY5joU,10685
|
22
22
|
arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
|
23
23
|
arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
|
24
24
|
arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
|
26
|
-
arbor/server/services/scripts/grpo_training.py,sha256=
|
26
|
+
arbor/server/services/scripts/grpo_training.py,sha256=Q9jwnbRdXAv_jVgrChLX6IiB3BLZU1F3BP6mBV0DVik,20889
|
27
27
|
arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
29
|
-
arbor_ai-0.1.
|
30
|
-
arbor_ai-0.1.
|
31
|
-
arbor_ai-0.1.
|
32
|
-
arbor_ai-0.1.
|
33
|
-
arbor_ai-0.1.
|
34
|
-
arbor_ai-0.1.
|
29
|
+
arbor_ai-0.1.12.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
|
30
|
+
arbor_ai-0.1.12.dist-info/METADATA,sha256=upqnB_F9JDLytHm4AFrDnvPaOHdj8XiBCdrlam0rgRc,2413
|
31
|
+
arbor_ai-0.1.12.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
|
32
|
+
arbor_ai-0.1.12.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
|
33
|
+
arbor_ai-0.1.12.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
|
34
|
+
arbor_ai-0.1.12.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|