arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +98 -62
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +6 -4
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.1.dist-info/RECORD +0 -42
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/inference_manager.py

@@ -1,34 +1,34 @@
 import asyncio
 import os
-import random
 import signal
-import socket
 import subprocess
 import sys
 import threading
 import time
 from datetime import datetime
-from enum import Enum
 from typing import Any, Dict, Optional

 import psutil
 import requests

-from arbor.server.core.config import
+from arbor.server.core.config import Config
 from arbor.server.services.inference.vllm_client import VLLMClient
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)


 class InferenceManager:
-    def __init__(self,
-        self.
+    def __init__(self, config: Config):
+        self.config = config
         self.process = None
         self.launch_kwargs = {}
         self.last_activity = None
         self._shutting_down = False
-        self.launched_model = None
+        self.launched_model: Optional[str] = None
         self.inference_count = 0
         self._session = None
-        self.port = None
+        self.port: Optional[int] = None
         self.group_port = None
         self.vllm_client = None
         self._is_updating = 0 # Counter for weight updates in progress
@@ -37,21 +37,24 @@ class InferenceManager:
         signal.signal(signal.SIGTERM, self._signal_handler)

     def _signal_handler(self, signum, frame):
-
-
+        """Handle shutdown signals gracefully."""
+        logger.info(f"Received signal {signum}. Initiating graceful shutdown...")
+        try:
+            self.kill_server()
+        except Exception as e:
+            logger.error(f"Error during signal handler cleanup: {e}")
+            logger.info("Forced exit during cleanup...")
             os._exit(1)
+        logger.info("Received signal to terminate. Cleaning up...")
+        os._exit(0)

-
-
-        self.
-        sys.exit(0)
-
-    def is_server_running(self):
-        return self.process is not None
+    def is_server_running(self) -> bool:
+        """Check if vLLM server is running."""
+        return self.process is not None and self.process.poll() is None

     def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
         if self.is_server_running():
-
+            logger.info("Server is already launched.")
             return

         launch_kwargs = launch_kwargs or self.launch_kwargs
@@ -61,18 +64,17 @@
             if model.startswith(prefix):
                 model = model[len(prefix) :]

-
+        logger.info(f"Grabbing a free port to launch a vLLM server for model {model}")
         self.port = get_free_port()
-        timeout = launch_kwargs.get("timeout", 1800)
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.
-        n_gpus = self.
-        command = f"
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.inference.gpu_ids
+        n_gpus = self.config.arbor_config.inference.gpu_ids.count(",") + 1
+        command = f"{sys.executable} -m arbor.server.services.inference.vllm_serve --model {model} --port {self.port} --gpu-memory-utilization 0.9 --tensor-parallel-size {n_gpus} --enable_prefix_caching True"

         if launch_kwargs.get("max_context_length"):
             command += f" --max_model_len {launch_kwargs['max_context_length']}"

-
+        logger.info(f"Running command: {command}")

         # We will manually stream & capture logs.
         process = subprocess.Popen(
@@ -85,7 +87,7 @@

         # A threading.Event to control printing after the server is ready.
         # This will store *all* lines (both before and after readiness).
-
+        logger.info(f"vLLM server process started with PID {process.pid}.")
         stop_printing_event = threading.Event()
         logs_buffer = []

@@ -97,9 +99,11 @@
                     break
                 if line:
                     buffer.append(line)
-                    #
+                    # Log only if stop_event is not set
                     if not stop_event.is_set():
-
+                        logger.info(f"[vLLM LOG] {line.strip()}")
+                    else:
+                        logger.debug(f"[vLLM LOG] {line.strip()}")

         # Start a background thread to read from the process continuously
         thread = threading.Thread(
@@ -115,7 +119,7 @@
             return "".join(logs_buffer)

         # Let the user know server is up
-
+        logger.info(f"Server ready on random port {self.port}!")

         # self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
         # self.launch_kwargs["api_key"] = "local"
@@ -137,7 +141,7 @@

     def kill(self):
         if self.process is None:
-
+            logger.info("No running server to kill.")
             return

         process = self.process
@@ -152,19 +156,18 @@
         try:
             kill_vllm_server(process.pid)
         except Exception as e:
-
+            logger.error(f"Error during cleanup: {e}")
             try:
                 process.kill() # Final attempt to kill
             except:
                 pass

-
+        logger.info("Server killed.")

     async def run_inference(self, request_json: dict):
         # Check if weights are being updated
-        while self.
+        while self._is_updating:
             # weights are being updated...waiting
-            # print("Weights are being updated, waiting...")
             await asyncio.sleep(1) # Small sleep to prevent busy waiting

         model = request_json["model"]
@@ -172,13 +175,12 @@
         for prefix in prefixes:
             if model.startswith(prefix):
                 model = model[len(prefix) :]
-
+        logger.info(f"Running inference for model {model}")

         # Monkeypatch for GRPO runs:
         # vllm complains if we don't give it the exact model name that was launched
         # TODO: This should really throw an error unless in a GRPO run.
         if model != self.launched_model:
-            # print(f"Model changed from {model} to {self.current_model}")
             model = self.launched_model
             request_json["model"] = model

@@ -218,7 +220,7 @@ def get_free_port() -> int:
                s.bind(("localhost", 0))
                ports.append(s.getsockname()[1])
        except Exception as e:
-
+            logger.error(f"Error binding to port: {e}")
    return random.choice(ports)


@@ -272,6 +274,6 @@ def kill_vllm_server(main_process_pid):
                p.kill() # SIGKILL

    except psutil.NoSuchProcess:
-
+        logger.warning(f"Process {main_process_pid} not found")
    except Exception as e:
-
+        logger.error(f"Error killing processes: {e}")
arbor/server/services/job_manager.py

@@ -3,7 +3,7 @@ from datetime import datetime
 from typing import Literal

 from arbor.server.api.models.schemas import JobStatus
-from arbor.server.core.config import
+from arbor.server.core.config import Config


 class JobEvent:
@@ -58,7 +58,7 @@ class Job:


 class JobManager:
-    def __init__(self,
+    def __init__(self, config: Config):
         self.jobs = {}

     def get_job(self, job_id: str):