arbor-ai 0.2-py3-none-any.whl → 0.2.2-py3-none-any.whl
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +101 -63
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +8 -5
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +18 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.dist-info/RECORD +0 -42
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/grpo_manager.py

@@ -3,7 +3,6 @@ import json
 import os
 import random
 import signal
-import socket
 import string
 import subprocess
 import sys
@@ -20,14 +19,17 @@ from arbor.server.api.models.schemas import (
     GRPOConfigRequest,
     GRPORequest,
 )
-from arbor.server.core.config import
+from arbor.server.core.config import Config
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 class GRPOManager:
-    def __init__(self,
-        self.
+    def __init__(self, config: Config):
+        self.config = config
         self.training_process = None
         self.current_model = None
         self.train_kwargs = None
@@ -47,7 +49,7 @@ class GRPOManager:
 
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
-
+        logger.info("Received keyboard interrupt. Shutting down gracefully...")
         # Sleep for a bit to let async operations go through
         time.sleep(2)
         if self.training_process is not None:
@@ -65,16 +67,14 @@ class GRPOManager:
         )
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         name = f"grpo:{model_name}:{suffix}:{timestamp}"
-        return name, str(Path(self.
+        return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
 
     def find_training_args(self, request: GRPOConfigRequest) -> dict:
         """Process the config request and return training arguments."""
         name, output_dir = self.make_output_dir(request.model, request.suffix)
 
         # Here are defaults for training. We can adjust them if we disagree w the huggingface defaults
-        default_train_kwargs = {
-            "output_dir": output_dir,
-        }
+        default_train_kwargs = {"output_dir": output_dir, "grpo_flavor": "grpo"}
 
         train_kwargs = request.model_dump(exclude_unset=True)
         return {**default_train_kwargs, **(train_kwargs or {})}
@@ -108,7 +108,7 @@ class GRPOManager:
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["max_context_length", "lora"]
+        arbor_keys = ["max_context_length", "lora", "wandb_kwargs", "grpo_flavor"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -121,7 +121,7 @@ class GRPOManager:
         """Initialize the training process with ZMQ-based communication."""
         self.train_kwargs = self.find_training_args(request)
 
-        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
+        self.trl_train_kwargs, self.arbor_train_kwargs = self.process_training_args(
             self.train_kwargs
         )
 
@@ -132,31 +132,34 @@ class GRPOManager:
         # launch_kwargs = {
         #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
         # }
-        inference_manager.launch_kwargs["max_context_length"] =
-            "max_context_length", None
+        inference_manager.launch_kwargs["max_context_length"] = (
+            self.arbor_train_kwargs.get("max_context_length", None)
         )
-
+        logger.info("Launching inference server...")
         inference_manager.launch(self.current_model)
 
         # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
         self.server_comms_handler = ArborServerCommsHandler()
 
         script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
-
+        script_name = {"mmgrpo": "mmgrpo_training.py", "grpo": "grpo_training.py"}[
+            self.arbor_train_kwargs["grpo_flavor"]
+        ]
+        script_path = os.path.join(script_dir, script_name)
 
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.training.gpu_ids
         # WandB can block the training process for login, so we silence it
         my_env["WANDB_SILENT"] = "true"
 
-        num_processes = self.
+        num_processes = self.config.arbor_config.training.gpu_ids.count(",") + 1
 
         # This is the port for the accelerate main process
         main_process_port = get_free_port()
 
         params = [
-
+            sys.executable,
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
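The config plumbing above is compact; here is a minimal sketch of what the new lines compute (the `gpu_ids` string is a hypothetical example, not a package default):

```python
# Hypothetical value; arbor reads this from config.arbor_config.training.gpu_ids.
gpu_ids = "0,1,2"

# One accelerate process per comma-separated GPU id.
num_processes = gpu_ids.count(",") + 1  # -> 3

# grpo_flavor selects the training script; find_training_args defaults it to "grpo".
script_name = {"mmgrpo": "mmgrpo_training.py", "grpo": "grpo_training.py"}["grpo"]
# -> "grpo_training.py"
```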
@@ -164,11 +167,11 @@ class GRPOManager:
             "--main_process_port",
             str(main_process_port),
         ]
-        if self.
+        if self.config.arbor_config.training.accelerate_config:
             params.extend(
                 [
                     "--config_file",
-                    self.
+                    self.config.arbor_config.training.accelerate_config,
                 ]
             )
         params.extend(
@@ -195,12 +198,12 @@ class GRPOManager:
                 "--model",
                 self.current_model,
                 "--trl_train_kwargs",
-                json.dumps(trl_train_kwargs),
+                json.dumps(self.trl_train_kwargs),
                 "--arbor_train_kwargs",
-                json.dumps(arbor_train_kwargs),
+                json.dumps(self.arbor_train_kwargs),
             ]
         )
-
+        logger.info(f"Running GRPO training command: {' '.join(params)}")
 
         self.training_process = subprocess.Popen(
             params,
@@ -222,9 +225,9 @@ class GRPOManager:
                     break
                 if line:
                     buffer.append(line)
-                    #
+                    # Log only if stop_event is not set
                    if not stop_event.is_set():
-
+                        logger.info(f"[GRPO LOG] {line.strip()}")
 
         # Start a background thread to read from the process continuously
         thread = threading.Thread(
@@ -256,10 +259,10 @@ class GRPOManager:
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
-
+        logger.info("Starting status update handler...")
         try:
             for status in self.server_comms_handler.receive_status():
-
+                logger.debug(f"Received status update: {status}")
                 if status["status"] == "weight_update_start":
                     # Block inference calls by incrementing counter
                     inference_manager.start_weight_update()
@@ -267,47 +270,85 @@ class GRPOManager:
                     # Decrement counter to potentially allow inference calls again
                     inference_manager.complete_weight_update()
                 elif status["status"] == "model_saved":
-
+                    logger.info("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
-                    self.current_model = status["output_dir"]
                     self.saving_model = False
-
+                    logger.info("Model update complete")
                 elif status["status"] == "checkpoint_saved":
-
+                    logger.info("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
                     self.last_checkpoint = status["checkpoint_name"]
                     self.saving_checkpoint = False
-
+                    logger.info("Checkpoint saved")
                 elif status["status"] == "error":
-
+                    error_msg = status.get("error", "Unknown error")
+                    logger.error(f"Training error: {error_msg}")
                 elif status["status"] == "terminated":
                     self.terminating = False
-
+                    logger.info("Training process terminated")
         except Exception as e:
-
+            logger.error(f"Error in status update handler: {e}")
             # Make sure to allow inference if there's an error
             try:
                 inference_manager.complete_weight_update()
             except:
                 pass
 
+    def validate_batch(self, batch):
+        if not isinstance(batch, list):
+            raise ValueError("Batch must be a list")
+
+        if self.arbor_train_kwargs["grpo_flavor"] == "mmgrpo":
+            for group in batch:
+                if not isinstance(group, list):
+                    raise ValueError("Each group in batch must be a list")
+                for item in group:
+                    if not isinstance(item, dict):
+                        raise ValueError("Each item in group must be a dictionary")
+                    required_keys = {"messages", "completion", "advantage"}
+                    if not all(key in item for key in required_keys):
+                        raise ValueError(
+                            f"Each item must contain keys: {required_keys}"
+                        )
+            return True
+        elif self.arbor_train_kwargs["grpo_flavor"] == "grpo":
+            for item in batch:
+                if not isinstance(item, dict):
+                    raise ValueError("Each item in batch must be a dictionary")
+                required_keys = {"messages", "completion", "reward"}
+                if not all(key in item for key in required_keys):
+                    raise ValueError(f"Each item must contain keys: {required_keys}")
+            return True
+        else:
+            raise NotImplementedError(
+                f"GRPO flavor batch validation not implemented for {self.arbor_train_kwargs['grpo_flavor']}"
+            )
+
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
         while self.saving_checkpoint:
-
+            logger.info(
+                "Saving checkpoint, pausing GRPO steps until checkpoint is saved..."
+            )
             time.sleep(5)
 
+        self.validate_batch(request.batch)
+
         try:
+
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
 
         except Exception as e:
-
+            logger.error(f"Failed to send batch to training process: {e}")
             raise
 
+        self.current_model = self.train_kwargs["output_dir"]
+        inference_manager.launched_model = self.current_model
+
         return {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
|
|
320
361
|
while (
|
321
362
|
inference_manager.is_updating
|
322
363
|
): # Use the property instead of direct access
|
323
|
-
|
364
|
+
logger.info("Waiting for weight updates to finish before checkpointing...")
|
324
365
|
time.sleep(5)
|
325
366
|
|
326
367
|
self.saving_checkpoint = True
|
@@ -328,7 +369,7 @@ class GRPOManager:
|
|
328
369
|
{"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
|
329
370
|
)
|
330
371
|
while self.saving_checkpoint:
|
331
|
-
|
372
|
+
logger.info("Waiting for checkpoint to be saved...")
|
332
373
|
time.sleep(5)
|
333
374
|
return {
|
334
375
|
"current_model": self.current_model,
|
@@ -343,14 +384,14 @@ class GRPOManager:
|
|
343
384
|
while (
|
344
385
|
inference_manager and inference_manager.is_updating
|
345
386
|
): # Use the property instead of direct access
|
346
|
-
|
387
|
+
logger.info("Waiting for final weight updates to finish before saving...")
|
347
388
|
time.sleep(5)
|
348
389
|
|
349
|
-
|
390
|
+
logger.info("Sending save model command")
|
350
391
|
self.saving_model = True
|
351
392
|
self.server_comms_handler.send_command({"command": "save_model"})
|
352
393
|
while self.saving_model:
|
353
|
-
|
394
|
+
logger.info("Waiting for final model to be saved...")
|
354
395
|
time.sleep(5)
|
355
396
|
|
356
397
|
termination_data = {
|
@@ -359,37 +400,34 @@ class GRPOManager:
|
|
359
400
|
"last_checkpoint": self.last_checkpoint,
|
360
401
|
}
|
361
402
|
|
362
|
-
|
403
|
+
logger.info("Sending termination command")
|
363
404
|
self.terminating = True
|
364
405
|
self.server_comms_handler.send_command({"command": "terminate"})
|
365
|
-
|
406
|
+
logger.info("Waiting for training process to finish...")
|
366
407
|
|
367
408
|
# Wait for at most 15 seconds for termination
|
368
409
|
start_time = time.time()
|
369
410
|
while self.terminating:
|
370
411
|
if time.time() - start_time > 15:
|
371
|
-
|
412
|
+
logger.warning(
|
372
413
|
"Termination wait timed out after 15 seconds, proceeding with cleanup..."
|
373
414
|
)
|
374
415
|
break
|
375
|
-
|
416
|
+
logger.info("Waiting for run to be terminated...")
|
376
417
|
time.sleep(3)
|
377
418
|
|
378
|
-
|
419
|
+
logger.info("Starting cleanup")
|
379
420
|
self.cleanup_termination(inference_manager)
|
380
421
|
|
381
422
|
if self.train_kwargs and "output_dir" in self.train_kwargs:
|
382
|
-
print(
|
383
|
-
f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
|
384
|
-
)
|
385
|
-
if not os.path.exists(self.train_kwargs["output_dir"]):
|
386
|
-
print(
|
387
|
-
f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
|
388
|
-
)
|
389
423
|
output_dir = self.train_kwargs["output_dir"]
|
424
|
+
logger.info(f"Training completed. Model saved to {output_dir}")
|
425
|
+
logger.info(f"Training logs and checkpoints are stored in: {output_dir}")
|
426
|
+
if not os.path.exists(output_dir):
|
427
|
+
logger.warning(f"Output directory {output_dir} does not exist")
|
390
428
|
self.train_kwargs = None
|
391
429
|
else:
|
392
|
-
|
430
|
+
logger.info("Training terminated, no output directory specified")
|
393
431
|
self.train_kwargs = None
|
394
432
|
|
395
433
|
return termination_data
|
@@ -398,7 +436,7 @@ class GRPOManager:
|
|
398
436
|
try:
|
399
437
|
# Kill training process and all its children (accelerate launcher creates multiple processes)
|
400
438
|
if self.training_process:
|
401
|
-
|
439
|
+
logger.info("Terminating training process and its children...")
|
402
440
|
try:
|
403
441
|
parent = psutil.Process(self.training_process.pid)
|
404
442
|
# Get all child processes including grandchildren
|
@@ -425,9 +463,9 @@ class GRPOManager:
|
|
425
463
|
pass
|
426
464
|
|
427
465
|
except psutil.NoSuchProcess:
|
428
|
-
|
466
|
+
logger.warning(f"Process {self.training_process.pid} not found")
|
429
467
|
except Exception as e:
|
430
|
-
|
468
|
+
logger.error(f"Error killing training process tree: {e}")
|
431
469
|
# Fallback to basic termination
|
432
470
|
self.training_process.terminate()
|
433
471
|
try:
|
@@ -438,11 +476,11 @@ class GRPOManager:
|
|
438
476
|
|
439
477
|
# Clean up ZMQ connections
|
440
478
|
if self.server_comms_handler:
|
441
|
-
|
479
|
+
logger.debug("Closing ZMQ connections...")
|
442
480
|
self.server_comms_handler.close()
|
443
481
|
|
444
482
|
if inference_manager and inference_manager.process is not None:
|
445
|
-
|
483
|
+
logger.info("Killing inference manager...")
|
446
484
|
inference_manager.kill()
|
447
485
|
|
448
486
|
# Reinitialize in case we want to start a new training run
|
@@ -451,9 +489,9 @@ class GRPOManager:
|
|
451
489
|
self.server_comms_handler = None
|
452
490
|
self.status_thread = None
|
453
491
|
self.data_count = 0
|
454
|
-
|
492
|
+
logger.info("Cleanup completed successfully")
|
455
493
|
except Exception as e:
|
456
|
-
|
494
|
+
logger.error(f"Error during cleanup: {e}")
|
457
495
|
# Still reset state even if cleanup fails
|
458
496
|
self.training_process = None
|
459
497
|
self.current_model = None
|
@@ -476,5 +514,5 @@ def get_free_port() -> int:
|
|
476
514
|
s.bind(("localhost", 0))
|
477
515
|
ports.append(s.getsockname()[1])
|
478
516
|
except Exception as e:
|
479
|
-
|
517
|
+
logger.error(f"Error binding to port: {e}")
|
480
518
|
return random.choice(ports)
|
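The `get_free_port` change only swaps `print` for `logger.error`; the underlying trick is binding to port 0 so the OS assigns an unused port. A self-contained sketch of the idea, simplified to a single probe rather than the package's sampling of several candidate ports:

```python
import socket

def find_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]  # the port number the OS assigned

print(find_free_port())  # e.g. 52731
```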
arbor/server/services/health_manager.py (new file)

@@ -0,0 +1,171 @@
+import platform
+from datetime import datetime
+from typing import Any, Dict
+
+import psutil
+
+from arbor.server.core.config import Config
+
+try:
+    import GPUtil
+
+    GPU_AVAILABLE = True
+except ImportError:
+    GPU_AVAILABLE = False
+
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+
+class HealthManager:
+    """Manages system health checks including GPU monitoring."""
+
+    def __init__(self, config: Config = None):
+        self.config = config
+
+    def get_gpu_info(self) -> Dict[str, Any]:
+        """Get GPU information including available and used GPUs."""
+        gpu_info = {
+            "gpus_available": 0,
+            "gpus_used": 0,
+            "gpu_details": [],
+            "cuda_available": False,
+            "gpu_library": "none",
+        }
+
+        # Try GPUtil first
+        if GPU_AVAILABLE:
+            try:
+                gpus = GPUtil.getGPUs()
+                gpu_info["gpus_available"] = len(gpus)
+                gpu_info["gpu_library"] = "GPUtil"
+
+                for i, gpu in enumerate(gpus):
+                    gpu_detail = {
+                        "id": gpu.id,
+                        "name": gpu.name,
+                        "memory_total": f"{gpu.memoryTotal}MB",
+                        "memory_used": f"{gpu.memoryUsed}MB",
+                        "memory_free": f"{gpu.memoryFree}MB",
+                        "utilization": f"{gpu.load * 100:.1f}%",
+                        "temperature": f"{gpu.temperature}°C",
+                    }
+                    gpu_info["gpu_details"].append(gpu_detail)
+
+                    # Consider GPU "used" if utilization > 10% or memory usage > 10%
+                    if gpu.load > 0.1 or (gpu.memoryUsed / gpu.memoryTotal) > 0.1:
+                        gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["error"] = f"GPUtil error: {str(e)}"
+
+        # Try PyTorch as fallback/additional info
+        if TORCH_AVAILABLE:
+            try:
+                gpu_info["cuda_available"] = torch.cuda.is_available()
+                if torch.cuda.is_available():
+                    cuda_count = torch.cuda.device_count()
+                    if not GPU_AVAILABLE:  # Only use torch info if GPUtil not available
+                        gpu_info["gpus_available"] = cuda_count
+                        gpu_info["gpu_library"] = "PyTorch"
+
+                        for i in range(cuda_count):
+                            props = torch.cuda.get_device_properties(i)
+                            memory_allocated = (
+                                torch.cuda.memory_allocated(i) / 1024**2
+                            )  # MB
+                            memory_cached = (
+                                torch.cuda.memory_reserved(i) / 1024**2
+                            )  # MB
+                            memory_total = props.total_memory / 1024**2  # MB
+
+                            gpu_detail = {
+                                "id": i,
+                                "name": props.name,
+                                "memory_total": f"{memory_total:.0f}MB",
+                                "memory_allocated": f"{memory_allocated:.0f}MB",
+                                "memory_cached": f"{memory_cached:.0f}MB",
+                                "compute_capability": f"{props.major}.{props.minor}",
+                            }
+                            gpu_info["gpu_details"].append(gpu_detail)
+
+                            # Consider GPU "used" if memory allocated > 100MB
+                            if memory_allocated > 100:
+                                gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["torch_error"] = f"PyTorch error: {str(e)}"
+
+        return gpu_info
+
+    def get_system_info(self) -> Dict[str, Any]:
+        """Get system information including CPU, memory, and disk usage."""
+        memory = psutil.virtual_memory()
+        disk = psutil.disk_usage("/")
+        cpu_percent = psutil.cpu_percent(interval=1)
+
+        return {
+            "platform": platform.system(),
+            "platform_release": platform.release(),
+            "platform_version": platform.version(),
+            "architecture": platform.machine(),
+            "processor": platform.processor(),
+            "cpu_usage": f"{cpu_percent}%",
+            "memory": {
+                "total": f"{memory.total / 1024**3:.2f}GB",
+                "available": f"{memory.available / 1024**3:.2f}GB",
+                "used": f"{memory.used / 1024**3:.2f}GB",
+                "percentage": f"{memory.percent}%",
+            },
+            "disk": {
+                "total": f"{disk.total / 1024**3:.2f}GB",
+                "free": f"{disk.free / 1024**3:.2f}GB",
+                "used": f"{disk.used / 1024**3:.2f}GB",
+                "percentage": f"{(disk.used / disk.total) * 100:.1f}%",
+            },
+            "gpu": self.get_gpu_info(),
+        }
+
+    def get_health_status(self) -> Dict[str, Any]:
+        """Get comprehensive health status including system and GPU information."""
+        version = self.config.get_arbor_version() if self.config else "unknown"
+        versions = (
+            self.config.get_system_versions() if self.config else {"arbor": version}
+        )
+
+        return {
+            "status": "healthy",
+            "version": version,  # Keep for backward compatibility
+            "versions": versions,  # Comprehensive version info
+            "timestamp": datetime.now().isoformat(),
+            "system": self.get_system_info(),
+        }
+
+    def is_healthy(self) -> bool:
+        """Check if the system is healthy based on various metrics."""
+        try:
+            # Check memory usage (unhealthy if > 90%)
+            memory = psutil.virtual_memory()
+            if memory.percent > 90:
+                print(f"Memory usage is {memory.percent}%")
+                return False
+
+            # Check disk usage (unhealthy if > 95%)
+            disk = psutil.disk_usage("/")
+            if (disk.used / disk.total) * 100 > 95:
+                print(f"Disk usage is {disk.used / disk.total * 100}%")
+                return False
+
+            # Check CPU usage (unhealthy if > 95% sustained)
+            cpu_percent = psutil.cpu_percent(interval=2)
+            if cpu_percent > 95:
+                print(f"CPU usage is {cpu_percent}%")
+                return False
+
+            return True
+        except Exception:
+            return False
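A hedged usage sketch for the new `HealthManager`, inferred from the code above rather than taken from package docs; constructing it without a `Config` makes the version info fall back to "unknown":

```python
from arbor.server.services.health_manager import HealthManager

health = HealthManager()  # config=None -> "version": "unknown"

print(health.is_healthy())  # False if memory > 90%, disk > 95%, or CPU > 95%

status = health.get_health_status()
print(status["status"])                           # "healthy"
print(status["system"]["gpu"]["gpus_available"])  # 0 if neither GPUtil nor CUDA is present
```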
arbor/server/services/inference/vllm_client.py

@@ -1,8 +1,10 @@
 # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
 
+import asyncio
 import atexit
 import logging
 import time
+import traceback
 from typing import Optional
 
 import httpx
@@ -130,11 +132,11 @@ class VLLMClient:
             ) from exc
         else:
             if response.status_code == 200:
-                logger.
+                logger.debug("Server is up!")
                 return None
 
-        # Retry logic: wait before
-        logger.
+        # Retry logic: wait before trying again
+        logger.debug(
             f"Server is not up yet. Retrying in {retry_interval} seconds..."
         )
         time.sleep(retry_interval)
@@ -239,7 +241,7 @@ class VLLMClient:
             response.raise_for_status()
             return response.json()
 
-        except httpx.
+        except httpx.TimeoutException:
             logger.error("Request timed out")
             raise
         except InferenceBlockedError:
@@ -253,7 +255,8 @@ class VLLMClient:
                     await asyncio.sleep(INFERENCE_RETRY_DELAY)
                 else:
                     logger.error(
-                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                        f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
+                        f"Stack trace:\n{traceback.format_exc()}"
                     )
                     raise
 
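The final hunk makes the give-up branch of the client's retry loop log the exception and a full stack trace. A standalone sketch of that pattern; the constant values and the broad `Exception` catch are illustrative assumptions, not the package's exact definitions:

```python
import asyncio
import logging
import traceback

logger = logging.getLogger(__name__)

MAX_INFERENCE_RETRIES = 3    # hypothetical value
INFERENCE_RETRY_DELAY = 1.0  # seconds, hypothetical value

async def call_with_retries(make_request):
    """Retry an async request, logging the final failure with its traceback."""
    for attempt in range(MAX_INFERENCE_RETRIES + 1):
        try:
            return await make_request()
        except Exception as e:
            if attempt < MAX_INFERENCE_RETRIES:
                await asyncio.sleep(INFERENCE_RETRY_DELAY)
            else:
                # Log the exception and full stack trace before re-raising,
                # mirroring the behavior added in the diff.
                logger.error(
                    f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
                    f"Stack trace:\n{traceback.format_exc()}"
                )
                raise
```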