arbor-ai 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -49,10 +49,15 @@ class ArborServerCommsHandler:
         yield status
 
     def close(self):
+        self.command_socket.setsockopt(zmq.LINGER, 0)
         self.command_socket.close()
+        self.status_socket.setsockopt(zmq.LINGER, 0)
         self.status_socket.close()
+        self.data_socket.setsockopt(zmq.LINGER, 0)
         self.data_socket.close()
+        self.broadcast_socket.setsockopt(zmq.LINGER, 0)
         self.broadcast_socket.close()
+        self.handshake_socket.setsockopt(zmq.LINGER, 0)
         self.handshake_socket.close()
         self.context.term()
 
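For context on the change above: `zmq.LINGER` controls how long `close()` and `context.term()` may block while ZMQ tries to flush unsent messages, and the default is to wait indefinitely, so a handler with queued messages can hang on shutdown. Setting `LINGER` to 0 drops pending messages instead. A minimal standalone sketch of the pattern (the PUSH socket and endpoint are illustrative, not taken from the handler):

```python
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PUSH)
sock.connect("tcp://127.0.0.1:5555")  # no peer ever listens here
sock.send(b"queued message")          # queued internally, never delivered

# With the default LINGER (-1), ctx.term() would block forever trying
# to flush the queued message; LINGER=0 tells ZMQ to discard it.
sock.setsockopt(zmq.LINGER, 0)
sock.close()
ctx.term()  # returns promptly
```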
@@ -2,6 +2,7 @@ import json
 import os
 import random
 import signal
+import socket
 import string
 import subprocess
 import sys
@@ -120,12 +121,17 @@ class GRPOManager:
 
         num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
 
+        # This is the port for the accelerate main process
+        main_process_port = get_free_port()
+
         params = [
             "python",
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
             str(num_processes),
+            "--main_process_port",
+            str(main_process_port),
         ]
         if self.settings.arbor_config.training.accelerate_config:
             params.extend(
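Without an explicit `--main_process_port`, accelerate falls back to a fixed default rendezvous port (29500), so two training runs started concurrently on one machine collide with an address-already-in-use error; requesting a free port per run avoids this. A sketch of the resulting invocation, where the trailing script name is a hypothetical placeholder:

```python
import subprocess

main_process_port = 43117  # e.g. the value returned by get_free_port()

params = [
    "python", "-m", "accelerate.commands.launch",
    "--num_processes", "2",                         # one process per training GPU
    "--main_process_port", str(main_process_port),  # avoids the fixed default port
    "train_script.py",                              # hypothetical entry point
]
proc = subprocess.Popen(params)  # a child the manager can later terminate()
```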
@@ -276,7 +282,9 @@ class GRPOManager:
         inference_manager.kill()
 
         # Send termination command through REQ socket
-        self.server_comms_handler.send_broadcast({"message": "terminate"})
+        # self.server_comms_handler.send_broadcast({"message": "terminate"})
+        self.training_process.terminate()
+        print("Waiting for training process to finish")
 
         # Wait for training process to finish
         if self.training_process:
@@ -289,6 +297,21 @@ class GRPOManager:
         if self.server_comms_handler:
             self.server_comms_handler.close()
 
+        # Force kill training process if still running
+        if self.training_process and self.training_process.poll() is None:
+            self.training_process.kill()
+            self.training_process.wait()
+
+        # Reinitialize in case we want to start a new training run
+        self.training_process = None
+        self.current_model = None
+        self.server_comms_handler = None
+        self.status_thread = None
+        self.model_saved_and_reload_requested = False
+
+        self.data_count = 0
+        self.last_inference_update = 0
+
         if self.train_kwargs and "output_dir" in self.train_kwargs:
             print(
                 f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
@@ -297,9 +320,12 @@ class GRPOManager:
             print(
                 f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
             )
-            return self.train_kwargs["output_dir"]
+            output_dir = self.train_kwargs["output_dir"]
+            self.train_kwargs = None
+            return output_dir
         else:
             print("Training terminated, no output directory specified")
+            self.train_kwargs = None
             return None
 
     def _should_update_model(self):
@@ -308,3 +334,12 @@ class GRPOManager:
         #     >= self.train_kwargs["update_interval"]
         # )
         return self.model_saved_and_reload_requested
+
+
+def get_free_port() -> int:
+    """
+    Return a free TCP port on localhost.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("localhost", 0))
+        return s.getsockname()[1]
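Binding to port 0 asks the kernel for any unused port, and `getsockname()` reads back which one was assigned. Note the inherent race in this pattern: the port is released when the `with` block closes the socket, so another process could in principle claim it before accelerate binds it. A short usage sketch:

```python
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("localhost", 0))    # port 0: kernel picks any free port
    port = s.getsockname()[1]   # read back the assigned port number

# `port` is free again once the socket closes, but only until some other
# process grabs it; the window between here and accelerate binding it is
# small in practice, though not zero.
print(f"picked port {port}")
```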
@@ -61,7 +61,7 @@ class InferenceManager:
         my_env = os.environ.copy()
         my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
         n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-        # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching --guided-decoding-backend xgrammar"
+        # command = f"vllm serve {model} --port {port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --max_model_len 8192 --enable_prefix_caching"
         command = f"python -m sglang_router.launch_server --model-path {model} --dp-size {n_gpus} --router-policy round_robin --port {port} --host 0.0.0.0"
         print(f"Running command: {command}")
 
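Since `gpu_ids` is kept as a comma-separated string (e.g. `'1, 2'` in the README's config example), counting commas and adding one gives the GPU count, assuming no trailing comma. A quick illustration, with a slightly more defensive alternative (the alternative is a suggestion, not what the package does):

```python
gpu_ids = "1, 2"                  # example config value
n_gpus = gpu_ids.count(",") + 1   # -> 2

# Alternative that tolerates stray whitespace and empty entries:
n_gpus_alt = len([g for g in gpu_ids.split(",") if g.strip()])
assert n_gpus == n_gpus_alt == 2
```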
@@ -137,8 +137,6 @@ class InferenceManager:
         process = self.process
         thread = self.thread
 
-        terminate_process(process)
-
         # Clear references first
         self.process = None
         self.thread = None
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.8
+Version: 0.1.10
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -57,6 +57,7 @@ inference:
 training:
   gpu_ids: '1, 2'
 ```
+This will use `GPU:0` for inference, with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
 
 ### 2️⃣ Start the Server
 
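For reference, the surrounding config might look like the sketch below; the `inference` block's exact contents are an assumption based on the hunk context (which shows an `inference:` section directly above `training:`), not part of this diff:

```yaml
inference:
  gpu_ids: '0'       # GPU:0 serves inference requests (assumed value)
training:
  gpu_ids: '1, 2'    # GPU:1 and GPU:2 run the training processes
```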
@@ -17,18 +17,18 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256=T-f1TrNSi_kmxPOcpaDphS8Xf3UMUbricocc6fuaKIM,12077
-arbor/server/services/inference_manager.py,sha256=qR9xPiYs4Is24vgeF72w7Hbe8j_PGEbl-qewcvUV-dA,9731
+arbor/server/services/grpo_manager.py,sha256=50g90lV8qpol7fQp2SBTXUCrF5eOP8YdxDnMLM0XY0E,13311
+arbor/server/services/inference_manager.py,sha256=gHI-Biy3TtGkyWxIDKY-uqZZm_fiQJLktkPY8ezRvo8,9660
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/server/services/comms/comms.py,sha256=Dg08D2Fm5TAEiGyr0Qcr0uocabQpFD_sBVhxIkj9D2M,7424
+arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
 arbor/server/services/scripts/grpo_training.py,sha256=V36pCMZDJj2DdzquxScOddi9zP8EVPGWN3HGiftFfrY,21082
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.8.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
-arbor_ai-0.1.8.dist-info/METADATA,sha256=kAZBj176hfqSrvrcWb0Wz8_vU33yiZJ-ck9buyDF6Jg,2234
-arbor_ai-0.1.8.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
-arbor_ai-0.1.8.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
-arbor_ai-0.1.8.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
-arbor_ai-0.1.8.dist-info/RECORD,,
+arbor_ai-0.1.10.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.10.dist-info/METADATA,sha256=qnUBfdKczxenG5kPTcZgQVMnWimEUPExz7nONxBYpDQ,2413
+arbor_ai-0.1.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+arbor_ai-0.1.10.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.10.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.10.dist-info/RECORD,,