arbor-ai 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/client/__init__.py +0 -0
- arbor/client/api.py +1 -0
- arbor/server/__init__.py +1 -0
- arbor/server/api/__init__.py +1 -0
- arbor/server/api/models/schemas.py +223 -0
- arbor/server/api/routes/__init__.py +0 -0
- arbor/server/api/routes/files.py +52 -0
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +117 -0
- arbor/server/core/__init__.py +1 -0
- arbor/server/core/config.py +47 -0
- arbor/server/core/logging.py +0 -0
- arbor/server/main.py +11 -0
- arbor/server/services/__init__.py +0 -0
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -0
- arbor/server/services/file_manager.py +289 -0
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +81 -0
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +561 -0
- arbor/server/utils/__init__.py +0 -0
- arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/METADATA +1 -1
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- arbor_ai-0.1.5.dist-info/RECORD +0 -8
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,310 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import random
|
4
|
+
import signal
|
5
|
+
import string
|
6
|
+
import subprocess
|
7
|
+
import sys
|
8
|
+
import threading
|
9
|
+
import time
|
10
|
+
from datetime import datetime
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import Optional
|
13
|
+
|
14
|
+
from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest
|
15
|
+
from arbor.server.core.config import Settings
|
16
|
+
from arbor.server.services.comms.comms import ArborServerCommsHandler
|
17
|
+
from arbor.server.services.inference_manager import InferenceManager
|
18
|
+
|
19
|
+
|
20
|
+
class GRPOManager:
    """Drive a GRPO training subprocess and keep the inference server in sync with it.

    ``initialize`` launches ``scripts/grpo_training.py`` under ``accelerate`` and
    talks to it through an ``ArborServerCommsHandler``; ``grpo_step`` forwards
    training batches; ``update_model`` asks the trainer to save and waits until
    the status thread has reloaded the inference server with the saved model.
    """

    def __init__(self, settings: "Settings"):
        self.settings = settings
        self.training_process = None  # Popen handle of the accelerate launcher
        self.current_model = None  # model name/path currently served for inference
        self.train_kwargs = None  # merged kwargs from find_training_args()
        self.server_comms_handler = None  # created in initialize()
        self.status_thread = None
        # True from update_model() requesting a save until the status thread has
        # finished (or failed) reloading the inference server.
        self.model_saved_and_reload_requested = False

        self.data_count = 0  # batches successfully handed to the trainer
        self.last_inference_update = 0  # only reported in grpo_step() log output
        # Route SIGINT/SIGTERM through terminate() so the trainer is shut down too.
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _signal_handler(self, signum, frame):
        """Handle SIGINT/SIGTERM: terminate training, then exit the process."""
        print("\nReceived keyboard interrupt. Shutting down gracefully...")
        # No InferenceManager is reachable from here; terminate() accepts None.
        self.terminate(None)
        sys.exit(0)

    def make_output_dir(
        self, model_name: str, run_suffix: Optional[str] = None
    ) -> tuple[str, str]:
        """Create a unique run name and the output directory path for it.

        Args:
            model_name: Model id; only the part after the last ``/`` is kept, lowercased.
            run_suffix: Suffix for the run name; a random 6-character alphanumeric
                suffix is generated when it is empty or None.

        Returns:
            ``(name, output_dir)`` where ``name`` is
            ``grpo:<model>:<suffix>:<YYYYmmdd_HHMMSS>`` and ``output_dir`` is the
            resolved ``<STORAGE_PATH>/models/<name>``.
        """
        model_name = model_name.split("/")[-1].lower()
        suffix = run_suffix or "".join(
            random.choices(string.ascii_letters + string.digits, k=6)
        )
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        name = f"grpo:{model_name}:{suffix}:{timestamp}"
        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)

    def find_training_args(self, request: "GRPOConfigRequest") -> dict:
        """Merge default training kwargs with the fields explicitly set on ``request``.

        Explicitly-set request fields (``model_dump(exclude_unset=True)``) override
        the defaults; currently the only default is a fresh ``output_dir``.
        """
        name, output_dir = self.make_output_dir(request.model, request.suffix)

        # Defaults for training; adjust here if we disagree with the Hugging Face defaults.
        default_train_kwargs = {
            "output_dir": output_dir,
        }

        train_kwargs = request.model_dump(exclude_unset=True)
        return {**default_train_kwargs, **(train_kwargs or {})}

    def process_training_args(self, train_kwargs: dict) -> tuple[dict, dict]:
        """Split merged kwargs into ``(trl_train_kwargs, arbor_train_kwargs)``.

        Keys not listed in either whitelist are dropped.
        """
        # NOTE: These also need to be in the GRPOConfigRequest
        trl_keys = [
            "output_dir",
            "temperature",
            "beta",
            "num_iterations",
            "num_generations",
            "per_device_train_batch_size",
            "learning_rate",
            "gradient_accumulation_steps",
            "gradient_checkpointing",
            "lr_scheduler_type",
            "max_prompt_length",
            "max_completion_length",
            "gradient_checkpointing_kwargs",
            "bf16",
            "scale_rewards",
            "max_grad_norm",
        ]
        trl_train_kwargs = {
            key: train_kwargs[key] for key in trl_keys if key in train_kwargs
        }

        arbor_keys = ["update_interval", "lora"]
        arbor_train_kwargs = {
            key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
        }

        return trl_train_kwargs, arbor_train_kwargs

    def initialize(
        self, request: "GRPOConfigRequest", inference_manager: "InferenceManager"
    ):
        """Start the training subprocess, its comms, and the inference server.

        Blocks until every accelerate process has connected to the comms handler,
        then launches the inference server on ``request.model``.
        """
        self.train_kwargs = self.find_training_args(request)

        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
            self.train_kwargs
        )

        self.current_model = request.model

        # The comms handler owns the ports the training script connects back to.
        self.server_comms_handler = ArborServerCommsHandler()

        script_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo_training.py"
        )

        training_config = self.settings.arbor_config.training
        my_env = os.environ.copy()
        my_env["CUDA_VISIBLE_DEVICES"] = training_config.gpu_ids

        # One accelerate process per comma-separated GPU id.
        num_processes = training_config.gpu_ids.count(",") + 1

        # sys.executable (not a bare "python" from PATH) so the trainer runs in
        # the same interpreter/virtualenv as this server.
        params = [
            sys.executable,
            "-m",
            "accelerate.commands.launch",
            "--num_processes",
            str(num_processes),
        ]
        if training_config.accelerate_config:
            params.extend(["--config_file", training_config.accelerate_config])
        params.extend(
            [
                script_path,
                # Comms args
                "--host",
                self.server_comms_handler.host,
                "--command_port",
                str(self.server_comms_handler.command_port),
                "--status_port",
                str(self.server_comms_handler.status_port),
                "--data_port",
                str(self.server_comms_handler.data_port),
                "--broadcast_port",
                str(self.server_comms_handler.broadcast_port),
                "--handshake_port",
                str(self.server_comms_handler.handshake_port),
                # Training args
                "--model",
                self.current_model,
                "--trl_train_kwargs",
                json.dumps(trl_train_kwargs),
                "--arbor_train_kwargs",
                json.dumps(arbor_train_kwargs),
            ]
        )
        print(f"Running following command\n: {' '.join(params)}")

        self.training_process = subprocess.Popen(
            params,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=my_env,
        )

        def _tail_process(proc):
            # Echo every trainer output line. Lines are not retained: nothing
            # reads them, and buffering them grew memory for the whole run.
            # Iteration ends at EOF, i.e. when the process closes its output.
            for line in proc.stdout:
                print(f"[GRPO LOG] {line}", end="")

        threading.Thread(
            target=_tail_process, args=(self.training_process,), daemon=True
        ).start()

        # Status updates (model saved / error / terminated) arrive on their own thread.
        self.status_thread = threading.Thread(
            target=self._handle_status_updates, args=(inference_manager,), daemon=True
        )
        self.status_thread.start()
        self.server_comms_handler.wait_for_clients(num_processes)

        # Launch the inference server
        print("Launching inference server...")
        inference_manager.launch(self.current_model)

    def _handle_status_updates(self, inference_manager: "InferenceManager"):
        """Consume trainer status messages; reload the inference model on ``model_saved``.

        Runs on the daemon thread started by ``initialize``. A failed reload is
        reported and the wait flags are still cleared, so callers blocked in
        ``update_model``/``grpo_step`` are released instead of waiting forever.
        """
        print("Starting status update handler...")
        try:
            for status in self.server_comms_handler.receive_status():
                print(f"Received status update: {status}")
                kind = status.get("status")
                if kind == "model_saved":
                    print("Updating inference model...")
                    # This status can arrive more than once; only act on the first
                    # one after update_model() requested a save.
                    if self._should_update_model():
                        try:
                            inference_manager.update_model(status["output_dir"])
                            # Publish the new model *before* releasing waiters so
                            # update_model() never returns the stale path.
                            self.current_model = status["output_dir"]
                            print("Model update complete")
                        except Exception as e:
                            print(f"Failed to update inference model: {e}")
                            # update_model() set this flag; clear it so inference
                            # requests stop waiting on a restart that failed.
                            inference_manager.restarting = False
                        finally:
                            self.model_saved_and_reload_requested = False
                elif kind == "error":
                    print(f"Training error: {status.get('error', 'Unknown error')}")
                elif kind == "terminated":
                    print("Training process terminated")
                    break
        except Exception as e:
            print(f"Error in status update handler: {e}")

    def grpo_step(
        self, request: "GRPORequest", inference_manager: "InferenceManager"
    ) -> str:
        """Forward one training batch to the trainer and return the current model.

        Waits while the inference server is restarting or a model reload is pending.
        A failed send is logged and not counted (best effort, as before).
        """
        while inference_manager.is_server_restarting():
            print("Inference manager restarting, waiting for GRPO step")
            time.sleep(5)

        while self._should_update_model():
            print(
                f"Waiting for model update. Data count: {self.data_count}, Last inference update: {self.last_inference_update}"
            )
            time.sleep(5)

        try:
            # Send the batch to the training process
            self.server_comms_handler.send_data(request.batch)
            self.data_count += 1
        except Exception as e:
            print(f"Failed to send batch to training process: {e}")

        # Model saves are requested explicitly via update_model(); the status
        # handler performs the inference-server reload when the save completes.
        return self.current_model

    def update_model(self, request, inference_manager: "InferenceManager"):
        """Ask the trainer to save, then block until the inference server reloads.

        Returns ``self.current_model`` after the status thread clears the
        reload flag. ``request`` is accepted for API compatibility and unused.
        """
        # NOTE(hack): flag the restart up front so inference requests and GRPO
        # steps pause while the trainer is saving.
        inference_manager.restarting = True
        self.model_saved_and_reload_requested = True
        try:
            self.server_comms_handler.send_command({"command": "save_model"})
        except Exception:
            # No save was requested, so no status update will ever clear these.
            self.model_saved_and_reload_requested = False
            inference_manager.restarting = False
            raise
        while self.model_saved_and_reload_requested:
            print(
                "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
            )
            time.sleep(5)
        return self.current_model

    def terminate(self, inference_manager: Optional["InferenceManager"]):
        """Stop the inference server and trainer, close comms, return the output dir.

        Args:
            inference_manager: Manager whose server should be killed, or ``None``
                (e.g. from the signal handler) to skip that step.

        Returns:
            ``train_kwargs["output_dir"]`` when known, otherwise ``None``.
        """
        comms = self.server_comms_handler
        try:
            # Stop the inference server
            if inference_manager is not None and inference_manager.process is not None:
                inference_manager.kill()

            # Tell all training processes to shut down via the broadcast channel.
            if comms is not None:
                comms.send_broadcast({"message": "terminate"})

            # Wait for training process to finish; don't leave it running on timeout.
            if self.training_process is not None:
                try:
                    self.training_process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print("Training process did not exit within 30 seconds, killing it")
                    self.training_process.kill()
                    self.training_process.wait(timeout=5)

        except Exception as e:
            print(f"Error during termination: {e}")
        finally:
            # Clean up ZMQ connections
            if comms is not None:
                comms.close()

        if self.train_kwargs and "output_dir" in self.train_kwargs:
            print(
                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
            )
            if not os.path.exists(self.train_kwargs["output_dir"]):
                print(
                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
                )
            return self.train_kwargs["output_dir"]
        else:
            print("Training terminated, no output directory specified")
            return None

    def _should_update_model(self):
        """Return True while a requested model save/reload is still pending."""
        return self.model_saved_and_reload_requested
|
@@ -0,0 +1,275 @@
|
|
1
|
+
import os
|
2
|
+
import signal
|
3
|
+
import socket
|
4
|
+
import subprocess
|
5
|
+
import sys
|
6
|
+
import threading
|
7
|
+
import time
|
8
|
+
from datetime import datetime
|
9
|
+
from typing import Any, Dict, Optional
|
10
|
+
|
11
|
+
import requests
|
12
|
+
|
13
|
+
from arbor.server.core.config import Settings
|
14
|
+
|
15
|
+
|
16
|
+
class InferenceManager:
    """Launch, query and hot-swap a local SGLang (router) inference server."""

    # Routing prefixes stripped from model names, checked once each in this order.
    _MODEL_PREFIXES = ("openai/", "huggingface/", "local:", "arbor:")

    def __init__(self, settings: "Settings"):
        self.settings = settings
        self.process = None  # Popen handle of the running server, if any
        self.thread = None  # log-tailing thread of the running server
        self.get_logs = None  # set by launch(); returns captured server output
        self.launch_kwargs = {}
        self.last_activity = None
        self.restarting = False
        self._shutting_down = False
        self.current_model = None
        self.inference_count = 0  # inference requests currently in flight
        # Set up signal handler for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    @staticmethod
    def _strip_model_prefixes(model: str) -> str:
        """Remove known routing prefixes (e.g. ``openai/``) from ``model``."""
        for prefix in InferenceManager._MODEL_PREFIXES:
            if model.startswith(prefix):
                model = model[len(prefix) :]
        return model

    def _signal_handler(self, signum, frame):
        """Kill the server and exit; a second signal during cleanup forces exit."""
        if self._shutting_down:
            print("\nForced exit during cleanup...")
            os._exit(1)

        print("\nReceived signal to terminate. Cleaning up...")
        self._shutting_down = True
        self.kill()
        sys.exit(0)

    def is_server_running(self):
        """Return True if a server process has been launched and not killed."""
        return self.process is not None

    def is_server_restarting(self):
        """Return True while the server is being swapped to a new model."""
        return self.restarting

    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
        """Start an SGLang router server for ``model`` on a free port and wait until ready.

        Args:
            model: Model id or local path; routing prefixes are stripped.
            launch_kwargs: Options; only ``timeout`` (seconds, default 1800) is read.
                Defaults to ``self.launch_kwargs``.

        Raises:
            TimeoutError: if the server is not ready within the timeout (the
                process is killed first).
        """
        if self.is_server_running():
            print("Server is already launched.")
            return

        launch_kwargs = launch_kwargs or self.launch_kwargs

        model = self._strip_model_prefixes(model)

        print(f"Grabbing a free port to launch an SGLang server for model {model}")
        port = get_free_port()
        timeout = launch_kwargs.get("timeout", 1800)
        inference_config = self.settings.arbor_config.inference
        my_env = os.environ.copy()
        my_env["CUDA_VISIBLE_DEVICES"] = inference_config.gpu_ids
        # One data-parallel replica per comma-separated GPU id.
        n_gpus = inference_config.gpu_ids.count(",") + 1
        # argv is built directly (no string splitting) so model paths with
        # spaces or backslashes reach the server intact; sys.executable keeps
        # the server in this interpreter's environment.
        command = [
            sys.executable,
            "-m",
            "sglang_router.launch_server",
            "--model-path",
            model,
            "--dp-size",
            str(n_gpus),
            "--router-policy",
            "round_robin",
            "--port",
            str(port),
            "--host",
            "0.0.0.0",
        ]
        print(f"Running command: {' '.join(command)}")

        # We will manually stream & capture logs.
        process = subprocess.Popen(
            command,
            text=True,
            stdout=subprocess.PIPE,  # We'll read from pipe
            stderr=subprocess.STDOUT,  # Merge stderr into stdout
            env=my_env,
        )

        print(f"SGLang server process started with PID {process.pid}.")
        # Printing stops once the server is ready; every line is still buffered.
        stop_printing_event = threading.Event()
        logs_buffer = []

        def _tail_process(proc, buffer, stop_event):
            # Iteration ends at EOF, i.e. when the server closes its output.
            for line in proc.stdout:
                buffer.append(line)
                if not stop_event.is_set():
                    print(f"[SGLang LOG] {line}", end="")

        # Start a background thread to read from the process continuously
        thread = threading.Thread(
            target=_tail_process,
            args=(process, logs_buffer, stop_printing_event),
            daemon=True,
        )
        thread.start()

        # Wait until the server is ready (or times out)
        base_url = f"http://localhost:{port}"
        try:
            wait_for_server(base_url, timeout=timeout)
        except TimeoutError:
            # The server never came up; don't leave it running.
            process.kill()
            raise

        # Once server is ready, we tell the thread to stop printing further lines.
        stop_printing_event.set()

        def get_logs() -> str:
            """Return all server output captured so far as one string."""
            return "".join(logs_buffer)

        # Let the user know server is up
        print(f"Server ready on random port {port}!")

        self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
        self.launch_kwargs["api_key"] = "local"
        self.get_logs = get_logs
        self.process = process
        self.thread = thread
        self.current_model = model

    def kill(self):
        """Terminate the running server (if any) and clear its references."""
        if self.process is None:
            print("No running server to kill.")
            return

        # Imported lazily so sglang is only required once a server exists.
        from sglang.utils import terminate_process

        process = self.process
        thread = self.thread

        terminate_process(process)

        # Clear references first
        self.process = None
        self.thread = None
        self.get_logs = None
        self.last_activity = None

        try:
            # Handle nested signal case
            if self._shutting_down:
                process.kill()  # Go straight to SIGKILL if we're shutting down
            else:
                process.terminate()  # Try SIGTERM first
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    print(
                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
                    )
                    process.kill()

            process.wait(timeout=5)

            if thread and thread.is_alive():
                thread.join(timeout=5)

        except Exception as e:
            print(f"Error during cleanup: {e}")
            try:
                process.kill()  # Final attempt to kill
            except Exception:
                pass  # best effort: the process may already be gone

        print("Server killed.")

    def run_inference(self, request_json: dict):
        """Forward a chat-completions request to the running server.

        ``request_json["model"]`` is always rewritten to the currently served
        model. Waits while the server is restarting.

        Returns:
            The decoded JSON response, or ``None`` if the server disconnected.

        Raises:
            RuntimeError: if no server is running (after any restart has finished).
        """
        model = self._strip_model_prefixes(request_json["model"])
        print(f"Running inference for model {model}")
        # Monkeypatch: the served model changes on every update, so always
        # target whatever is currently loaded.
        if model != self.current_model:
            print(f"MONKEYPATCH: Model changed from {model} to {self.current_model}")
        request_json["model"] = self.current_model

        # Update last_activity timestamp
        self.last_activity = datetime.now()

        # Wait out a restart *before* checking the process: during a restart the
        # old server is killed (process is None) until the new one is launched.
        if self.restarting:
            while self.restarting:
                print("Inference is paused while server is restarting...")
                time.sleep(5)
            request_json["model"] = self.current_model

        if self.process is None or self.launch_kwargs.get("api_base") is None:
            raise RuntimeError("Server is not running. Please launch it first.")

        url = f"{self.launch_kwargs['api_base']}/chat/completions"
        # NOTE(review): a restart can begin between the check above and this
        # increment; such a request may hit a dying server (ConnectionError below).
        self.inference_count += 1
        try:
            response = requests.post(url, json=request_json)
            return response.json()
        except requests.exceptions.ConnectionError:
            print("Server disconnected...ignoring")
            return None
        except Exception as e:
            print(f"Error during inference: {e}")
            raise
        finally:
            self.inference_count -= 1

    def update_model(self, output_dir):
        """Replace the running server with one serving ``output_dir``.

        Waits for in-flight requests, kills the server, then relaunches it.
        ``restarting`` is cleared even if this fails, so waiting inference
        requests are released.

        Raises:
            RuntimeError: if ``output_dir`` does not exist.
        """
        print("Restarting server with new model...")
        self.restarting = True
        try:
            while self.inference_count > 0:
                print(
                    f"Waiting for inference requests to finish... {self.inference_count} remaining"
                )
                time.sleep(5)

            tik = time.time()
            self.kill()
            print("Just killed server")
            # Check that output directory exists and was created successfully
            print(f"Checking that output directory {output_dir} exists")
            if not os.path.exists(output_dir):
                raise RuntimeError(
                    f"Failed to save model - output directory {output_dir} does not exist"
                )

            print("Launching new server")
            self.launch(output_dir, self.launch_kwargs)
            tok = time.time()
        finally:
            self.restarting = False
        print(f"Time taken to update model: {tok - tik} seconds")
|
240
|
+
|
241
|
+
|
242
|
+
def get_free_port() -> int:
    """
    Ask the OS for an unused TCP port on localhost and return its number.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Port 0 makes the kernel pick any free ephemeral port.
        probe.bind(("localhost", 0))
        _host, port = probe.getsockname()
    return port
|
249
|
+
|
250
|
+
|
251
|
+
def wait_for_server(
    base_url: str,
    timeout: Optional[int] = None,
    poll_interval: float = 1.0,
    request_timeout: float = 10.0,
) -> None:
    """
    Wait for the server to be ready by polling the /v1/models endpoint.

    Args:
        base_url: The base URL of the server (e.g. http://localhost:1234)
        timeout: Maximum time to wait in seconds. None (or 0) means wait forever.
        poll_interval: Seconds to sleep between unsuccessful probes.
        request_timeout: Per-probe HTTP timeout in seconds, so a single hung
            connection cannot block past ``timeout``.

    Raises:
        TimeoutError: if the server does not answer 200 within ``timeout``.
    """
    start_time = time.time()
    while True:
        try:
            response = requests.get(
                f"{base_url}/v1/models",
                headers={"Authorization": "Bearer None"},
                timeout=request_timeout,
            )
        except requests.exceptions.RequestException:
            # Server not accepting connections yet.
            response = None

        if response is not None and response.status_code == 200:
            # A small extra sleep to ensure server is fully up.
            time.sleep(5)
            return

        # Checked on every failed attempt, including connection errors, so the
        # timeout is enforced while the server is still starting.
        if timeout and (time.time() - start_time) > timeout:
            raise TimeoutError("Server did not become ready within timeout period")

        # Also sleep after non-200 answers instead of busy-looping.
        time.sleep(poll_interval)
|
@@ -0,0 +1,81 @@
|
|
1
|
+
import uuid
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import Literal, Optional
|
4
|
+
|
5
|
+
from arbor.server.api.models.schemas import JobStatus
|
6
|
+
from arbor.server.core.config import Settings
|
7
|
+
|
8
|
+
|
9
|
+
class JobEvent:
    """A timestamped log entry (info/warning/error) attached to a job."""

    def __init__(
        self,
        level: Literal["info", "warning", "error"],
        message: str,
        data: Optional[dict] = None,
    ):
        """Create an event.

        Args:
            level: Severity of the event.
            message: Human-readable description.
            data: Extra payload; each event gets its own fresh dict when omitted
                (the previous ``{}`` default was shared by every event).
        """
        self.level = level
        self.message = message
        self.data = {} if data is None else data

        self.id = f"ftevent-{uuid.uuid4()}"
        self.created_at = datetime.now()
|
19
|
+
|
20
|
+
|
21
|
+
class JobCheckpoint:
    """A model checkpoint recorded for a fine-tuning job at a given step."""

    def __init__(
        self,
        fine_tuned_model_checkpoint: str,
        fine_tuning_job_id: str,
        metrics: dict,
        step_number: int,
    ):
        """Record a checkpoint with a fresh ``ftckpt-`` id and creation time.

        Args:
            fine_tuned_model_checkpoint: Identifier/path of the saved checkpoint.
            fine_tuning_job_id: Id of the job that produced it.
            metrics: Metrics reported at this checkpoint.
            step_number: Training step at which it was taken.
        """
        self.id = f"ftckpt-{uuid.uuid4()}"
        self.fine_tuned_model_checkpoint, self.fine_tuning_job_id = (
            fine_tuned_model_checkpoint,
            fine_tuning_job_id,
        )
        self.metrics, self.step_number = metrics, step_number
        self.created_at = datetime.now()
|
35
|
+
|
36
|
+
|
37
|
+
class Job:
    """A fine-tuning job: its status, resulting model, events and checkpoints."""

    def __init__(self, status: JobStatus):
        """Create a job with a fresh ``ftjob-`` id and empty event/checkpoint lists."""
        self.id = f"ftjob-{uuid.uuid4()}"
        self.status = status
        self.fine_tuned_model = None  # filled in once training produces a model
        self.events: list[JobEvent] = []
        self.checkpoints: list[JobCheckpoint] = []

        self.created_at = datetime.now()

    def add_event(self, event: JobEvent):
        """Append ``event`` to this job's event log."""
        events = self.events
        events.append(event)

    def get_events(self) -> list[JobEvent]:
        """Return the job's event list itself (not a copy)."""
        events = self.events
        return events

    def add_checkpoint(self, checkpoint: JobCheckpoint):
        """Append ``checkpoint`` to this job's checkpoint list."""
        checkpoints = self.checkpoints
        checkpoints.append(checkpoint)

    def get_checkpoints(self) -> list[JobCheckpoint]:
        """Return the job's checkpoint list itself (not a copy)."""
        checkpoints = self.checkpoints
        return checkpoints
|
58
|
+
|
59
|
+
|
60
|
+
class JobManager:
    """In-memory registry of jobs keyed by job id."""

    def __init__(self, settings: Settings):
        # ``settings`` is accepted for parity with the other managers but unused.
        self.jobs = {}

    def get_job(self, job_id: str):
        """Return the job with ``job_id``; raise ``ValueError`` if unknown."""
        if job_id in self.jobs:
            return self.jobs[job_id]
        raise ValueError(f"Job {job_id} not found")

    def create_job(self):
        """Create, register and return a new job in ``PENDING`` status."""
        new_job = Job(status=JobStatus.PENDING)
        self.jobs[new_job.id] = new_job
        return new_job

    def get_jobs(self):
        """Return all registered jobs as a new list, in insertion order."""
        return [*self.jobs.values()]

    def get_active_job(self):
        """Return the first job whose status is ``RUNNING``, or ``None``."""
        return next(
            (job for job in self.jobs.values() if job.status == JobStatus.RUNNING),
            None,
        )
|