arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/inference_manager.py

@@ -1,34 +1,34 @@
 import asyncio
 import os
-import random
 import signal
-import socket
 import subprocess
 import sys
 import threading
 import time
 from datetime import datetime
-from enum import Enum
 from typing import Any, Dict, Optional

 import psutil
 import requests

-from arbor.server.core.config import Settings
+from arbor.server.core.config import Config
 from arbor.server.services.inference.vllm_client import VLLMClient
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)


 class InferenceManager:
-    def __init__(self, settings: Settings):
-        self.settings = settings
+    def __init__(self, config: Config):
+        self.config = config
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
         self._shutting_down = False
-        self.launched_model = None
+        self.launched_model: Optional[str] = None
         self.inference_count = 0
         self._session = None
-        self.port = None
+        self.port: Optional[int] = None
         self.group_port = None
         self.vllm_client = None
         self._is_updating = 0 # Counter for weight updates in progress
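Two cross-cutting changes appear in this hunk and recur throughout the release: every service constructor now takes a `Config` instead of a `Settings` object, and modules obtain a logger at import time via the new `arbor.server.utils.logging.get_logger`. A minimal sketch of a caller adapting to 0.2.2; how a `Config` is actually constructed is not shown in this diff, so that part is an assumption:

    from arbor.server.core.config import Config  # renamed from Settings in 0.2.2
    from arbor.server.services.inference_manager import InferenceManager
    from arbor.server.utils.logging import get_logger

    logger = get_logger(__name__)  # the new module-level logging pattern

    config = Config()  # hypothetical construction; Config's loader API is not in this diff
    manager = InferenceManager(config)  # the 0.2.1 signature was InferenceManager(settings)
    logger.info("Inference manager constructed; no server launched yet: %s", manager.port)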
@@ -37,21 +37,24 @@ class InferenceManager:
         signal.signal(signal.SIGTERM, self._signal_handler)

     def _signal_handler(self, signum, frame):
-        if self._shutting_down:
-            print("\nForced exit during cleanup...")
+        """Handle shutdown signals gracefully."""
+        logger.info(f"Received signal {signum}. Initiating graceful shutdown...")
+        try:
+            self.kill_server()
+        except Exception as e:
+            logger.error(f"Error during signal handler cleanup: {e}")
+            logger.info("Forced exit during cleanup...")
             os._exit(1)
+        logger.info("Received signal to terminate. Cleaning up...")
+        os._exit(0)

-        print("\nReceived signal to terminate. Cleaning up...")
-        self._shutting_down = True
-        self.kill()
-        sys.exit(0)
-
-    def is_server_running(self):
-        return self.process is not None
+    def is_server_running(self) -> bool:
+        """Check if vLLM server is running."""
+        return self.process is not None and self.process.poll() is None

     def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
-            print("Server is already launched.")
+            logger.info("Server is already launched.")
             return

         launch_kwargs = launch_kwargs or self.launch_kwargs
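The rewritten `is_server_running` is a behavioral fix, not just a rename: `self.process is not None` stays true after the child process dies, whereas `Popen.poll()` returns `None` only while it is still alive. A standalone illustration of the distinction, using only the standard library:

    import subprocess

    proc = subprocess.Popen(["sleep", "0.1"])
    print(proc is not None)  # True forever once launched: the 0.2.1 check
    print(proc.poll())       # None while the child is still running
    proc.wait()
    print(proc.poll())       # 0 after exit, so the 0.2.2 check reports not-running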
@@ -61,18 +64,17 @@ class InferenceManager:
             if model.startswith(prefix):
                 model = model[len(prefix) :]

-        print(f"Grabbing a free port to launch a vLLM server for model {model}")
+        logger.info(f"Grabbing a free port to launch a vLLM server for model {model}")
         self.port = get_free_port()
-        timeout = launch_kwargs.get("timeout", 1800)
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.inference.gpu_ids
-        n_gpus = self.settings.arbor_config.inference.gpu_ids.count(",") + 1
-        command = f"python -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.inference.gpu_ids
+        n_gpus = self.config.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"{sys.executable} -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"

         if launch_kwargs.get("max_context_length"):
             command += f" --max_model_len {launch_kwargs['max_context_length']}"

-        print(f"Running command: {command}")
+        logger.info(f"Running command: {command}")

         # We will manually stream & capture logs.
         process = subprocess.Popen(
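Replacing the hardcoded `python` with `{sys.executable}` pins the vLLM subprocess to the interpreter running Arbor itself (and therefore its virtualenv), instead of whatever `python` happens to resolve to on `PATH`. The difference is easy to see:

    import shutil
    import sys

    print(sys.executable)          # the interpreter running this process, e.g. a venv's python
    print(shutil.which("python"))  # what a bare "python" would launch; may differ, or be None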
@@ -85,7 +87,7 @@ class InferenceManager:

         # A threading.Event to control printing after the server is ready.
         # This will store *all* lines (both before and after readiness).
-        print(f"vLLM server process started with PID {process.pid}.")
+        logger.info(f"vLLM server process started with PID {process.pid}.")
         stop_printing_event = threading.Event()
         logs_buffer = []

@@ -97,9 +99,11 @@ class InferenceManager:
                     break
                 if line:
                     buffer.append(line)
-                    # Print only if stop_event is not set
+                    # Log only if stop_event is not set
                     if not stop_event.is_set():
-                        print(f"[vLLM LOG] {line}", end="")
+                        logger.info(f"[vLLM LOG] {line.strip()}")
+                    else:
+                        logger.debug(f"[vLLM LOG] {line.strip()}")

         # Start a background thread to read from the process continuously
         thread = threading.Thread(
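Only the middle of the log-streaming helper is visible in this hunk. The surrounding pattern is a daemon thread that continuously drains the child's stdout (a full pipe buffer would block vLLM), with an `Event` that demotes output once the server is ready; in 0.2.2 demoted lines still reach the logger at debug level instead of being dropped. A condensed, self-contained sketch of that pattern, with names borrowed from the hunk:

    import subprocess
    import threading

    def stream_output(pipe, stop_event, buffer):
        # Drain continuously so the child never blocks on a full pipe.
        for line in iter(pipe.readline, ""):
            buffer.append(line)
            if not stop_event.is_set():
                print(f"[vLLM LOG] {line.strip()}")  # 0.2.2 routes this through logger.info

    stop_printing_event = threading.Event()
    logs_buffer = []
    proc = subprocess.Popen(["echo", "server ready"], stdout=subprocess.PIPE, text=True)
    threading.Thread(
        target=stream_output,
        args=(proc.stdout, stop_printing_event, logs_buffer),
        daemon=True,  # do not keep the parent alive just to read logs
    ).start()
    proc.wait()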
@@ -115,7 +119,7 @@ class InferenceManager:
             return "".join(logs_buffer)

         # Let the user know server is up
-        print(f"Server ready on random port {self.port}!")
+        logger.info(f"Server ready on random port {self.port}!")

         # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
         # self.launch_kwargs["api_key"] = "local"
@@ -137,7 +141,7 @@ class InferenceManager:

     def kill(self):
         if self.process is None:
-            print("No running server to kill.")
+            logger.info("No running server to kill.")
             return

         process = self.process
@@ -152,19 +156,18 @@ class InferenceManager:
         try:
             kill_vllm_server(process.pid)
         except Exception as e:
-            print(f"Error during cleanup: {e}")
+            logger.error(f"Error during cleanup: {e}")
             try:
                 process.kill() # Final attempt to kill
             except:
                 pass

-        print("Server killed.")
+        logger.info("Server killed.")

     async def run_inference(self, request_json: dict):
         # Check if weights are being updated
-        while self.is_updating:
+        while self._is_updating:
             # weights are being updated...waiting
-            # print("Weights are being updated, waiting...")
             await asyncio.sleep(1) # Small sleep to prevent busy waiting

         model = request_json["model"]
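The loop condition now reads the `_is_updating` counter that `__init__` initializes; no `is_updating` attribute appears anywhere in this diff. The gate itself remains a one-second polling loop. Purely as a comparison, not Arbor's implementation, an `asyncio.Event`-based gate would wake waiting requests the moment the last update finishes:

    import asyncio

    class WeightUpdateGate:
        """Hypothetical alternative, not part of arbor: block inference
        while one or more weight updates are in flight."""

        def __init__(self):
            self._updates = 0
            self._idle = asyncio.Event()
            self._idle.set()  # no updates in flight initially

        def begin_update(self):
            self._updates += 1
            self._idle.clear()

        def end_update(self):
            self._updates -= 1
            if self._updates == 0:
                self._idle.set()  # wakes every waiter at once

        async def wait_until_idle(self):
            await self._idle.wait()  # returns immediately if already idle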
@@ -172,13 +175,12 @@ class InferenceManager:
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
-        print(f"Running inference for model {model}")
+        logger.info(f"Running inference for model {model}")

         # Monkeypatch for GRPO runs:
         # vllm complains if we don't give it the exact model name that was launched
         # TODO: This should really throw an error unless in a GRPO run.
         if model != self.launched_model:
-            # print(f"Model changed from {model} to {self.current_model}")
             model = self.launched_model
             request_json["model"] = model

@@ -218,7 +220,7 @@ def get_free_port() -> int:
                 s.bind(("localhost", 0))
                 ports.append(s.getsockname()[1])
         except Exception as e:
-            print(f"Error binding to port: {e}")
+            logger.error(f"Error binding to port: {e}")
     return random.choice(ports)


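`get_free_port` (shown only in part here) probes for free ports by binding to port 0 and letting the OS choose, then picks one of the probed ports at random. This is best-effort rather than a reservation: the probe socket closes before vLLM binds, so another process can claim the port in the gap. A self-contained sketch of the probing pattern the visible lines imply:

    import random
    import socket

    def probe_free_ports(attempts=5):
        # Bind to port 0 so the OS assigns a free ephemeral port.
        # The port is released on close, so this is a probe, not a reservation.
        ports = []
        for _ in range(attempts):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("localhost", 0))
                    ports.append(s.getsockname()[1])
            except Exception as e:
                print(f"Error binding to port: {e}")
        return random.choice(ports)

    print(probe_free_ports())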
@@ -272,6 +274,6 @@ def kill_vllm_server(main_process_pid):
             p.kill() # SIGKILL

     except psutil.NoSuchProcess:
-        print(f"Process {main_process_pid} not found")
+        logger.warning(f"Process {main_process_pid} not found")
     except Exception as e:
-        print(f"Error killing processes: {e}")
+        logger.error(f"Error killing processes: {e}")
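`kill_vllm_server` (its body sits mostly outside this hunk) appears to tear down the whole process tree via psutil, which matters because tensor-parallel vLLM forks worker processes that a plain kill of the parent would orphan. A generic sketch of the child-first teardown pattern, using only documented psutil calls; this is an illustration, not Arbor's exact code:

    import psutil

    def kill_process_tree(pid):
        # Terminate children before the parent so GPU workers are not orphaned.
        try:
            parent = psutil.Process(pid)
            children = parent.children(recursive=True)
            for child in children:
                child.terminate()  # polite SIGTERM first
            _, alive = psutil.wait_procs(children, timeout=5)
            for child in alive:
                child.kill()  # SIGKILL any stragglers
            parent.terminate()
            try:
                parent.wait(timeout=5)
            except psutil.TimeoutExpired:
                parent.kill()  # SIGKILL the parent as a last resort
        except psutil.NoSuchProcess:
            pass  # already gone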
arbor/server/services/job_manager.py

@@ -3,7 +3,7 @@ from datetime import datetime
 from typing import Literal

 from arbor.server.api.models.schemas import JobStatus
-from arbor.server.core.config import Settings
+from arbor.server.core.config import Config


 class JobEvent:
@@ -58,7 +58,7 @@ class Job:


 class JobManager:
-    def __init__(self, settings: Settings):
+    def __init__(self, config: Config):
         self.jobs = {}

     def get_job(self, job_id: str):