arbor-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +98 -62
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +6 -4
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/METADATA +17 -19
- arbor_ai-0.2.3.dist-info/RECORD +51 -0
- arbor_ai-0.2.1.dist-info/RECORD +0 -42
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/top_level.txt +0 -0
arbor/server/services/grpo_manager.py

@@ -3,7 +3,6 @@ import json
 import os
 import random
 import signal
-import socket
 import string
 import subprocess
 import sys
@@ -20,14 +19,17 @@ from arbor.server.api.models.schemas import (
     GRPOConfigRequest,
     GRPORequest,
 )
-from arbor.server.core.config import
+from arbor.server.core.config import Config
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 class GRPOManager:
-    def __init__(self,
-        self.
+    def __init__(self, config: Config):
+        self.config = config
         self.training_process = None
         self.current_model = None
         self.train_kwargs = None
@@ -47,7 +49,7 @@ class GRPOManager:
 
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
-
+        logger.info("Received keyboard interrupt. Shutting down gracefully...")
         # Sleep for a bit to let async operations go through
         time.sleep(2)
         if self.training_process is not None:
@@ -65,16 +67,14 @@ class GRPOManager:
         )
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         name = f"grpo:{model_name}:{suffix}:{timestamp}"
-        return name, str(Path(self.
+        return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
 
     def find_training_args(self, request: GRPOConfigRequest) -> dict:
         """Process the config request and return training arguments."""
         name, output_dir = self.make_output_dir(request.model, request.suffix)
 
         # Here are defaults for training. We can adjust them if we disagree w the huggingface defaults
-        default_train_kwargs = {
-            "output_dir": output_dir,
-        }
+        default_train_kwargs = {"output_dir": output_dir, "grpo_flavor": "grpo"}
 
         train_kwargs = request.model_dump(exclude_unset=True)
         return {**default_train_kwargs, **(train_kwargs or {})}
@@ -108,7 +108,7 @@ class GRPOManager:
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["max_context_length", "lora"]
+        arbor_keys = ["max_context_length", "lora", "wandb_kwargs", "grpo_flavor"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -121,7 +121,7 @@ class GRPOManager:
         """Initialize the training process with ZMQ-based communication."""
         self.train_kwargs = self.find_training_args(request)
 
-        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
+        self.trl_train_kwargs, self.arbor_train_kwargs = self.process_training_args(
             self.train_kwargs
         )
 
@@ -132,31 +132,34 @@ class GRPOManager:
         # launch_kwargs = {
         # k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
         # }
-        inference_manager.launch_kwargs["max_context_length"] =
-            "max_context_length", None
+        inference_manager.launch_kwargs["max_context_length"] = (
+            self.arbor_train_kwargs.get("max_context_length", None)
         )
-
+        logger.info("Launching inference server...")
         inference_manager.launch(self.current_model)
 
         # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
         self.server_comms_handler = ArborServerCommsHandler()
 
         script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
-
+        script_name = {"mmgrpo": "mmgrpo_training.py", "grpo": "grpo_training.py"}[
+            self.arbor_train_kwargs["grpo_flavor"]
+        ]
+        script_path = os.path.join(script_dir, script_name)
 
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.training.gpu_ids
         # WandB can block the training process for login, so we silence it
         my_env["WANDB_SILENT"] = "true"
 
-        num_processes = self.
+        num_processes = self.config.arbor_config.training.gpu_ids.count(",") + 1
 
         # This is the port for the accelerate main process
        main_process_port = get_free_port()
 
         params = [
-
+            sys.executable,
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
@@ -164,11 +167,11 @@ class GRPOManager:
             "--main_process_port",
             str(main_process_port),
         ]
-        if self.
+        if self.config.arbor_config.training.accelerate_config:
             params.extend(
                 [
                     "--config_file",
-                    self.
+                    self.config.arbor_config.training.accelerate_config,
                 ]
             )
         params.extend(
@@ -195,12 +198,12 @@ class GRPOManager:
                 "--model",
                 self.current_model,
                 "--trl_train_kwargs",
-                json.dumps(trl_train_kwargs),
+                json.dumps(self.trl_train_kwargs),
                 "--arbor_train_kwargs",
-                json.dumps(arbor_train_kwargs),
+                json.dumps(self.arbor_train_kwargs),
             ]
         )
-
+        logger.info(f"Running GRPO training command: {' '.join(params)}")
 
         self.training_process = subprocess.Popen(
             params,
@@ -222,9 +225,9 @@ class GRPOManager:
                     break
                 if line:
                     buffer.append(line)
-                    #
+                    # Log only if stop_event is not set
                     if not stop_event.is_set():
-
+                        logger.info(f"[GRPO LOG] {line.strip()}")
 
         # Start a background thread to read from the process continuously
         thread = threading.Thread(
@@ -256,10 +259,10 @@ class GRPOManager:
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
-
+        logger.info("Starting status update handler...")
         try:
             for status in self.server_comms_handler.receive_status():
-
+                logger.debug(f"Received status update: {status}")
                 if status["status"] == "weight_update_start":
                     # Block inference calls by incrementing counter
                     inference_manager.start_weight_update()
@@ -267,44 +270,80 @@ class GRPOManager:
                     # Decrement counter to potentially allow inference calls again
                     inference_manager.complete_weight_update()
                 elif status["status"] == "model_saved":
-
+                    logger.info("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
                     self.saving_model = False
-
+                    logger.info("Model update complete")
                 elif status["status"] == "checkpoint_saved":
-
+                    logger.info("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
                     self.last_checkpoint = status["checkpoint_name"]
                     self.saving_checkpoint = False
-
+                    logger.info("Checkpoint saved")
                 elif status["status"] == "error":
-
+                    error_msg = status.get("error", "Unknown error")
+                    logger.error(f"Training error: {error_msg}")
                 elif status["status"] == "terminated":
                     self.terminating = False
-
+                    logger.info("Training process terminated")
         except Exception as e:
-
+            logger.error(f"Error in status update handler: {e}")
             # Make sure to allow inference if there's an error
             try:
                 inference_manager.complete_weight_update()
             except:
                 pass
 
+    def validate_batch(self, batch):
+        if not isinstance(batch, list):
+            raise ValueError("Batch must be a list")
+
+        if self.arbor_train_kwargs["grpo_flavor"] == "mmgrpo":
+            for group in batch:
+                if not isinstance(group, list):
+                    raise ValueError("Each group in batch must be a list")
+                for item in group:
+                    if not isinstance(item, dict):
+                        raise ValueError("Each item in group must be a dictionary")
+                    required_keys = {"messages", "completion", "advantage"}
+                    if not all(key in item for key in required_keys):
+                        raise ValueError(
+                            f"Each item must contain keys: {required_keys}"
+                        )
+            return True
+        elif self.arbor_train_kwargs["grpo_flavor"] == "grpo":
+            for item in batch:
+                if not isinstance(item, dict):
+                    raise ValueError("Each item in batch must be a dictionary")
+                required_keys = {"messages", "completion", "reward"}
+                if not all(key in item for key in required_keys):
+                    raise ValueError(f"Each item must contain keys: {required_keys}")
+            return True
+        else:
+            raise NotImplementedError(
+                f"GRPO flavor batch validation not implemented for {self.arbor_train_kwargs['grpo_flavor']}"
+            )
+
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
         while self.saving_checkpoint:
-
+            logger.info(
+                "Saving checkpoint, pausing GRPO steps until checkpoint is saved..."
+            )
             time.sleep(5)
 
+        self.validate_batch(request.batch)
+
         try:
+
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
 
         except Exception as e:
-
+            logger.error(f"Failed to send batch to training process: {e}")
             raise
 
         self.current_model = self.train_kwargs["output_dir"]
@@ -322,7 +361,7 @@ class GRPOManager:
         while (
             inference_manager.is_updating
         ):  # Use the property instead of direct access
-
+            logger.info("Waiting for weight updates to finish before checkpointing...")
             time.sleep(5)
 
         self.saving_checkpoint = True
@@ -330,7 +369,7 @@ class GRPOManager:
             {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
         )
         while self.saving_checkpoint:
-
+            logger.info("Waiting for checkpoint to be saved...")
             time.sleep(5)
         return {
             "current_model": self.current_model,
@@ -345,14 +384,14 @@ class GRPOManager:
         while (
             inference_manager and inference_manager.is_updating
         ):  # Use the property instead of direct access
-
+            logger.info("Waiting for final weight updates to finish before saving...")
            time.sleep(5)
 
-
+        logger.info("Sending save model command")
         self.saving_model = True
         self.server_comms_handler.send_command({"command": "save_model"})
         while self.saving_model:
-
+            logger.info("Waiting for final model to be saved...")
             time.sleep(5)
 
         termination_data = {
@@ -361,37 +400,34 @@ class GRPOManager:
             "last_checkpoint": self.last_checkpoint,
         }
 
-
+        logger.info("Sending termination command")
         self.terminating = True
         self.server_comms_handler.send_command({"command": "terminate"})
-
+        logger.info("Waiting for training process to finish...")
 
         # Wait for at most 15 seconds for termination
         start_time = time.time()
         while self.terminating:
             if time.time() - start_time > 15:
-
+                logger.warning(
                     "Termination wait timed out after 15 seconds, proceeding with cleanup..."
                 )
                 break
-
+            logger.info("Waiting for run to be terminated...")
             time.sleep(3)
 
-
+        logger.info("Starting cleanup")
         self.cleanup_termination(inference_manager)
 
         if self.train_kwargs and "output_dir" in self.train_kwargs:
-            print(
-                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-            )
-            if not os.path.exists(self.train_kwargs["output_dir"]):
-                print(
-                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                )
             output_dir = self.train_kwargs["output_dir"]
+            logger.info(f"Training completed. Model saved to {output_dir}")
+            logger.info(f"Training logs and checkpoints are stored in: {output_dir}")
+            if not os.path.exists(output_dir):
+                logger.warning(f"Output directory {output_dir} does not exist")
             self.train_kwargs = None
         else:
-
+            logger.info("Training terminated, no output directory specified")
             self.train_kwargs = None
 
         return termination_data
@@ -400,7 +436,7 @@ class GRPOManager:
         try:
             # Kill training process and all its children (accelerate launcher creates multiple processes)
             if self.training_process:
-
+                logger.info("Terminating training process and its children...")
                 try:
                     parent = psutil.Process(self.training_process.pid)
                     # Get all child processes including grandchildren
@@ -427,9 +463,9 @@ class GRPOManager:
                         pass
 
                 except psutil.NoSuchProcess:
-
+                    logger.warning(f"Process {self.training_process.pid} not found")
                 except Exception as e:
-
+                    logger.error(f"Error killing training process tree: {e}")
                     # Fallback to basic termination
                     self.training_process.terminate()
                     try:
@@ -440,11 +476,11 @@ class GRPOManager:
 
             # Clean up ZMQ connections
             if self.server_comms_handler:
-
+                logger.debug("Closing ZMQ connections...")
                 self.server_comms_handler.close()
 
             if inference_manager and inference_manager.process is not None:
-
+                logger.info("Killing inference manager...")
                 inference_manager.kill()
 
             # Reinitialize in case we want to start a new training run
@@ -453,9 +489,9 @@ class GRPOManager:
             self.server_comms_handler = None
             self.status_thread = None
             self.data_count = 0
-
+            logger.info("Cleanup completed successfully")
         except Exception as e:
-
+            logger.error(f"Error during cleanup: {e}")
             # Still reset state even if cleanup fails
             self.training_process = None
             self.current_model = None
@@ -478,5 +514,5 @@ def get_free_port() -> int:
             s.bind(("localhost", 0))
             ports.append(s.getsockname()[1])
         except Exception as e:
-
+            logger.error(f"Error binding to port: {e}")
     return random.choice(ports)
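Note on the new `validate_batch` method above: batch shape now depends on the `grpo_flavor` training argument. The default "grpo" flavor expects a flat list of dicts carrying `messages`, `completion`, and `reward`; the "mmgrpo" flavor expects a list of groups, each itself a list of dicts carrying `messages`, `completion`, and `advantage`. A minimal sketch of batches that would pass these checks (the field values are illustrative; the code only checks types and key presence):

# Illustrative payloads only; validate_batch checks types and key presence, not values.
grpo_batch = [
    {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "completion": {"role": "assistant", "content": "4"},
        "reward": 1.0,
    }
]

mmgrpo_batch = [
    [  # each group is itself a list of items
        {
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
            "completion": {"role": "assistant", "content": "4"},
            "advantage": 0.7,
        }
    ]
]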
arbor/server/services/health_manager.py (new file)

@@ -0,0 +1,171 @@
+import platform
+from datetime import datetime
+from typing import Any, Dict
+
+import psutil
+
+from arbor.server.core.config import Config
+
+try:
+    import GPUtil
+
+    GPU_AVAILABLE = True
+except ImportError:
+    GPU_AVAILABLE = False
+
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+
+class HealthManager:
+    """Manages system health checks including GPU monitoring."""
+
+    def __init__(self, config: Config = None):
+        self.config = config
+
+    def get_gpu_info(self) -> Dict[str, Any]:
+        """Get GPU information including available and used GPUs."""
+        gpu_info = {
+            "gpus_available": 0,
+            "gpus_used": 0,
+            "gpu_details": [],
+            "cuda_available": False,
+            "gpu_library": "none",
+        }
+
+        # Try GPUtil first
+        if GPU_AVAILABLE:
+            try:
+                gpus = GPUtil.getGPUs()
+                gpu_info["gpus_available"] = len(gpus)
+                gpu_info["gpu_library"] = "GPUtil"
+
+                for i, gpu in enumerate(gpus):
+                    gpu_detail = {
+                        "id": gpu.id,
+                        "name": gpu.name,
+                        "memory_total": f"{gpu.memoryTotal}MB",
+                        "memory_used": f"{gpu.memoryUsed}MB",
+                        "memory_free": f"{gpu.memoryFree}MB",
+                        "utilization": f"{gpu.load * 100:.1f}%",
+                        "temperature": f"{gpu.temperature}°C",
+                    }
+                    gpu_info["gpu_details"].append(gpu_detail)
+
+                    # Consider GPU "used" if utilization > 10% or memory usage > 10%
+                    if gpu.load > 0.1 or (gpu.memoryUsed / gpu.memoryTotal) > 0.1:
+                        gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["error"] = f"GPUtil error: {str(e)}"
+
+        # Try PyTorch as fallback/additional info
+        if TORCH_AVAILABLE:
+            try:
+                gpu_info["cuda_available"] = torch.cuda.is_available()
+                if torch.cuda.is_available():
+                    cuda_count = torch.cuda.device_count()
+                    if not GPU_AVAILABLE:  # Only use torch info if GPUtil not available
+                        gpu_info["gpus_available"] = cuda_count
+                        gpu_info["gpu_library"] = "PyTorch"
+
+                    for i in range(cuda_count):
+                        props = torch.cuda.get_device_properties(i)
+                        memory_allocated = (
+                            torch.cuda.memory_allocated(i) / 1024**2
+                        )  # MB
+                        memory_cached = (
+                            torch.cuda.memory_reserved(i) / 1024**2
+                        )  # MB
+                        memory_total = props.total_memory / 1024**2  # MB
+
+                        gpu_detail = {
+                            "id": i,
+                            "name": props.name,
+                            "memory_total": f"{memory_total:.0f}MB",
+                            "memory_allocated": f"{memory_allocated:.0f}MB",
+                            "memory_cached": f"{memory_cached:.0f}MB",
+                            "compute_capability": f"{props.major}.{props.minor}",
+                        }
+                        gpu_info["gpu_details"].append(gpu_detail)
+
+                        # Consider GPU "used" if memory allocated > 100MB
+                        if memory_allocated > 100:
+                            gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["torch_error"] = f"PyTorch error: {str(e)}"
+
+        return gpu_info
+
+    def get_system_info(self) -> Dict[str, Any]:
+        """Get system information including CPU, memory, and disk usage."""
+        memory = psutil.virtual_memory()
+        disk = psutil.disk_usage("/")
+        cpu_percent = psutil.cpu_percent(interval=1)
+
+        return {
+            "platform": platform.system(),
+            "platform_release": platform.release(),
+            "platform_version": platform.version(),
+            "architecture": platform.machine(),
+            "processor": platform.processor(),
+            "cpu_usage": f"{cpu_percent}%",
+            "memory": {
+                "total": f"{memory.total / 1024**3:.2f}GB",
+                "available": f"{memory.available / 1024**3:.2f}GB",
+                "used": f"{memory.used / 1024**3:.2f}GB",
+                "percentage": f"{memory.percent}%",
+            },
+            "disk": {
+                "total": f"{disk.total / 1024**3:.2f}GB",
+                "free": f"{disk.free / 1024**3:.2f}GB",
+                "used": f"{disk.used / 1024**3:.2f}GB",
+                "percentage": f"{(disk.used / disk.total) * 100:.1f}%",
+            },
+            "gpu": self.get_gpu_info(),
+        }
+
+    def get_health_status(self) -> Dict[str, Any]:
+        """Get comprehensive health status including system and GPU information."""
+        version = self.config.get_arbor_version() if self.config else "unknown"
+        versions = (
+            self.config.get_system_versions() if self.config else {"arbor": version}
+        )
+
+        return {
+            "status": "healthy",
+            "version": version,  # Keep for backward compatibility
+            "versions": versions,  # Comprehensive version info
+            "timestamp": datetime.now().isoformat(),
+            "system": self.get_system_info(),
+        }
+
+    def is_healthy(self) -> bool:
+        """Check if the system is healthy based on various metrics."""
+        try:
+            # Check memory usage (unhealthy if > 90%)
+            memory = psutil.virtual_memory()
+            if memory.percent > 90:
+                print(f"Memory usage is {memory.percent}%")
+                return False
+
+            # Check disk usage (unhealthy if > 95%)
+            disk = psutil.disk_usage("/")
+            if (disk.used / disk.total) * 100 > 95:
+                print(f"Disk usage is {disk.used / disk.total * 100}%")
+                return False
+
+            # Check CPU usage (unhealthy if > 95% sustained)
+            cpu_percent = psutil.cpu_percent(interval=2)
+            if cpu_percent > 95:
+                print(f"CPU usage is {cpu_percent}%")
+                return False
+
+            return True
+        except Exception:
+            return False
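The `HealthManager` above is a self-contained service; how it is wired into the API is not shown in this section (see the `main.py` and route changes in the file list), so the call below is a rough usage sketch rather than the package's actual endpoint code:

from arbor.server.services.health_manager import HealthManager

# Without a Config, version fields fall back to "unknown".
health = HealthManager(config=None)

status = health.get_health_status()
print(status["status"])                           # "healthy"
print(status["system"]["gpu"]["gpus_available"])  # 0 if neither GPUtil nor torch sees a GPU

# is_healthy() treats memory > 90%, disk > 95%, or CPU > 95% as unhealthy.
if not health.is_healthy():
    print("System is under resource pressure")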
arbor/server/services/inference/vllm_client.py

@@ -4,6 +4,7 @@ import asyncio
 import atexit
 import logging
 import time
+import traceback
 from typing import Optional
 
 import httpx
@@ -131,11 +132,11 @@ class VLLMClient:
                     ) from exc
             else:
                 if response.status_code == 200:
-                    logger.
+                    logger.debug("Server is up!")
                     return None
 
-            # Retry logic: wait before
-            logger.
+            # Retry logic: wait before tryng again
+            logger.debug(
                 f"Server is not up yet. Retrying in {retry_interval} seconds..."
             )
             time.sleep(retry_interval)
@@ -254,7 +255,8 @@ class VLLMClient:
                     await asyncio.sleep(INFERENCE_RETRY_DELAY)
                 else:
                     logger.error(
-                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                        f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
+                        f"Stack trace:\n{traceback.format_exc()}"
                     )
                     raise
 
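The `vllm_client.py` hunks import `traceback` so that a request which still fails after `MAX_INFERENCE_RETRIES` attempts logs the exception and a full stack trace rather than a bare message. A stripped-down sketch of that retry pattern; the constants and the helper name are stand-ins, since the surrounding method body is not shown in this diff:

import asyncio
import logging
import traceback

logger = logging.getLogger(__name__)

MAX_INFERENCE_RETRIES = 3      # assumed values; the real constants live in vllm_client.py
INFERENCE_RETRY_DELAY = 1.0


async def call_with_retries(make_request):
    """Retry an async request, logging the error and stack trace on final failure."""
    for attempt in range(MAX_INFERENCE_RETRIES + 1):
        try:
            return await make_request()
        except Exception as e:
            if attempt < MAX_INFERENCE_RETRIES:
                # Back off briefly, then try again.
                await asyncio.sleep(INFERENCE_RETRY_DELAY)
            else:
                # Out of retries: log the error plus a full stack trace, then re-raise.
                logger.error(
                    f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
                    f"Stack trace:\n{traceback.format_exc()}"
                )
                raise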