arbor-ai 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,7 +33,7 @@ async def run_inference(
33
33
  raw_json["model"] = inference_manager.current_model
34
34
 
35
35
  # forward the request to the inference server
36
- completion = inference_manager.run_inference(raw_json)
36
+ completion = await inference_manager.run_inference(raw_json)
37
37
 
38
38
  return completion
39
39
 
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import json
2
3
  import os
3
4
  import random
@@ -263,8 +264,21 @@ class GRPOManager:
263
264
  return self.current_model
264
265
 
265
266
  def update_model(self, request, inference_manager: InferenceManager):
266
- # THIS IS HACKY AND NEEDS TO BE FIXED BEFORE RELEASE
267
+ if inference_manager._session:
268
+ # Create a new event loop if one doesn't exist
269
+ try:
270
+ loop = asyncio.get_event_loop()
271
+ except RuntimeError:
272
+ loop = asyncio.new_event_loop()
273
+ asyncio.set_event_loop(loop)
274
+
275
+ # Run the session closure in the event loop
276
+ loop.run_until_complete(inference_manager._session.close())
277
+ inference_manager._session = None
278
+
279
+ inference_manager.inference_count = 0
267
280
  inference_manager.restarting = True
281
+
268
282
  self.model_saved_and_reload_requested = True
269
283
  self.server_comms_handler.send_command({"command": "save_model"})
270
284
  while self.model_saved_and_reload_requested:
@@ -1,3 +1,5 @@
1
+ import asyncio
2
+ import json
1
3
  import os
2
4
  import signal
3
5
  import socket
@@ -8,6 +10,7 @@ import time
8
10
  from datetime import datetime
9
11
  from typing import Any, Dict, Optional
10
12
 
13
+ import aiohttp
11
14
  import requests
12
15
 
13
16
  from arbor.server.core.config import Settings
@@ -23,6 +26,7 @@ class InferenceManager:
23
26
  self._shutting_down = False
24
27
  self.current_model = None
25
28
  self.inference_count = 0
29
+ self._session = None
26
30
  # Set up signal handler for graceful shutdown
27
31
  signal.signal(signal.SIGINT, self._signal_handler)
28
32
  signal.signal(signal.SIGTERM, self._signal_handler)
@@ -62,7 +66,7 @@ class InferenceManager:
62
66
  my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
63
67
  n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
64
68
  # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
65
- command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --router-policy round_robin --port {port} --host 0.0.0.0"
69
+ command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --port {port} --host 0.0.0.0 --disable-radix-cache"
66
70
  print(f"Running command: {command}")
67
71
 
68
72
  # We will manually stream & capture logs.
@@ -171,7 +175,7 @@ class InferenceManager:
171
175
 
172
176
  print("Server killed.")
173
177
 
174
- def run_inference(self, request_json: dict):
178
+ async def run_inference(self, request_json: dict):
175
179
  model = request_json["model"]
176
180
  prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
177
181
  for prefix in prefixes:
@@ -193,16 +197,22 @@ class InferenceManager:
193
197
  if self.restarting:
194
198
  while self.restarting:
195
199
  print("Inference is paused while server is restarting...")
196
- time.sleep(5)
200
+ await asyncio.sleep(5)
197
201
  request_json["model"] = self.current_model
198
202
 
199
203
  url = f"{self.launch_kwargs['api_base']}/chat/completions"
200
204
  try:
201
205
  self.inference_count += 1
202
- response = requests.post(url, json=request_json)
203
- return response.json()
204
- except requests.exceptions.ConnectionError:
205
- print("Server disconnected...ignoring")
206
+ session = await self._ensure_session()
207
+ async with session.post(url, json=request_json) as response:
208
+ content = await response.content.read()
209
+ return json.loads(content)
210
+ except aiohttp.ClientError as e:
211
+ print(f"Connection error: {type(e).__name__}: {str(e)}")
212
+ # Try to close and recreate the session on error
213
+ if self._session:
214
+ await self._session.close()
215
+ self._session = None
206
216
  return None
207
217
  except Exception as e:
208
218
  print(f"Error during inference: {e}")
@@ -214,11 +224,19 @@ class InferenceManager:
214
224
  print("Restarting server with new model...")
215
225
  self.restarting = True
216
226
 
217
- while self.inference_count > 0:
218
- print(
219
- f"Waiting for inference requests to finish... {self.inference_count} remaining"
220
- )
221
- time.sleep(5)
227
+ # Close existing session and reset inference count
228
+ if self._session:
229
+ # Create a new event loop if one doesn't exist
230
+ try:
231
+ loop = asyncio.get_event_loop()
232
+ except RuntimeError:
233
+ loop = asyncio.new_event_loop()
234
+ asyncio.set_event_loop(loop)
235
+
236
+ # Run the session closure in the event loop
237
+ loop.run_until_complete(self._session.close())
238
+ self._session = None
239
+ self.inference_count = 0
222
240
 
223
241
  tik = time.time()
224
242
  self.kill()
@@ -236,6 +254,14 @@ class InferenceManager:
236
254
  self.restarting = False
237
255
  print(f"Time taken to update model: {tok - tik} seconds")
238
256
 
257
+ async def _ensure_session(self):
258
+ if self._session is None or self._session.closed:
259
+ timeout = aiohttp.ClientTimeout(
260
+ total=None
261
+ ) # No timeout...If it hangs, this might be the issue.
262
+ self._session = aiohttp.ClientSession(timeout=timeout)
263
+ return self._session
264
+
239
265
 
240
266
  def get_free_port() -> int:
241
267
  """
@@ -351,10 +351,6 @@ class CommandMonitor:
351
351
  output_dir=self.trainer.args.output_dir + "/adapter/"
352
352
  )
353
353
 
354
- # base_model = AutoModelForCausalLM.from_pretrained(
355
- # self.base_model_name
356
- # ).to(self.trainer.accelerator.device)
357
-
358
354
  _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
359
355
  self.trainer.args.output_dir + "/adapter/",
360
356
  config=self.trainer.peft_config,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arbor-ai
3
- Version: 0.1.10
3
+ Version: 0.1.11
4
4
  Summary: A framework for fine-tuning and managing language models
5
5
  Author-email: Noah Ziems <nziems2@nd.edu>
6
6
  Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -9,7 +9,7 @@ arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhc
9
9
  arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
11
11
  arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
12
- arbor/server/api/routes/inference.py,sha256=xlP-FMpOJAiiPZkE470l9mCR0ujLki8RrcO9hmTQD-k,1662
12
+ arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
13
13
  arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
14
14
  arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
17
17
  arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
20
- arbor/server/services/grpo_manager.py,sha256=50g90lV8qpol7fQp2SBTXUCrF5eOP8YdxDnMLM0XY0E,13311
21
- arbor/server/services/inference_manager.py,sha256=gHI-Biy3TtGkyWxIDKY-uqZZm_fiQJLktkPY8ezRvo8,9660
20
+ arbor/server/services/grpo_manager.py,sha256=TAU2BMHgbCgiAvKNVd2Y8N20SR4qEms3lChA4Z0ZzyY,13777
21
+ arbor/server/services/inference_manager.py,sha256=YVHXqwBm9vEmgKzKdMKQdLdw6qkUTl5BjHTnW-3yfo0,10699
22
22
  arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
23
23
  arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
24
24
  arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
26
- arbor/server/services/scripts/grpo_training.py,sha256=V36pCMZDJj2DdzquxScOddi9zP8EVPGWN3HGiftFfrY,21082
26
+ arbor/server/services/scripts/grpo_training.py,sha256=Q9jwnbRdXAv_jVgrChLX6IiB3BLZU1F3BP6mBV0DVik,20889
27
27
  arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
- arbor_ai-0.1.10.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
30
- arbor_ai-0.1.10.dist-info/METADATA,sha256=qnUBfdKczxenG5kPTcZgQVMnWimEUPExz7nONxBYpDQ,2413
31
- arbor_ai-0.1.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
- arbor_ai-0.1.10.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
33
- arbor_ai-0.1.10.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
34
- arbor_ai-0.1.10.dist-info/RECORD,,
29
+ arbor_ai-0.1.11.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
30
+ arbor_ai-0.1.11.dist-info/METADATA,sha256=04deKUBx8A_5j4_OU39_09873sHhs-jKZwMOeRSU3GA,2413
31
+ arbor_ai-0.1.11.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
+ arbor_ai-0.1.11.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
33
+ arbor_ai-0.1.11.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
34
+ arbor_ai-0.1.11.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5