arbor-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/METADATA +17 -19
  28. arbor_ai-0.2.3.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.3.dist-info}/top_level.txt +0 -0
arbor/server/services/grpo_manager.py

@@ -3,7 +3,6 @@ import json
 import os
 import random
 import signal
-import socket
 import string
 import subprocess
 import sys
@@ -20,14 +19,17 @@ from arbor.server.api.models.schemas import (
     GRPOConfigRequest,
     GRPORequest,
 )
-from arbor.server.core.config import Settings
+from arbor.server.core.config import Config
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 class GRPOManager:
-    def __init__(self, settings: Settings):
-        self.settings = settings
+    def __init__(self, config: Config):
+        self.config = config
         self.training_process = None
         self.current_model = None
         self.train_kwargs = None
@@ -47,7 +49,7 @@ class GRPOManager:
 
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
-        print("\nReceived keyboard interrupt. Shutting down gracefully...")
+        logger.info("Received keyboard interrupt. Shutting down gracefully...")
         # Sleep for a bit to let async operations go through
         time.sleep(2)
         if self.training_process is not None:
@@ -65,16 +67,14 @@ class GRPOManager:
         )
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         name = f"grpo:{model_name}:{suffix}:{timestamp}"
-        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
+        return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
 
     def find_training_args(self, request: GRPOConfigRequest) -> dict:
         """Process the config request and return training arguments."""
         name, output_dir = self.make_output_dir(request.model, request.suffix)
 
         # Here are defaults for training. We can adjust them if we disagree w the huggingface defaults
-        default_train_kwargs = {
-            "output_dir": output_dir,
-        }
+        default_train_kwargs = {"output_dir": output_dir, "grpo_flavor": "grpo"}
 
         train_kwargs = request.model_dump(exclude_unset=True)
         return {**default_train_kwargs, **(train_kwargs or {})}
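
Because `model_dump(exclude_unset=True)` only includes fields the caller actually set, the new `grpo_flavor` default survives unless the request overrides it. A minimal sketch of the merge, assuming a caller that sets `grpo_flavor` and `max_context_length` (the `output_dir` value is a placeholder):

```python
# How find_training_args combines defaults with the request kwargs.
defaults = {
    "output_dir": "/storage/models/grpo:my-model:suffix:20250101_120000",  # placeholder
    "grpo_flavor": "grpo",
}
request_kwargs = {"grpo_flavor": "mmgrpo", "max_context_length": 4096}  # only fields the caller set

train_kwargs = {**defaults, **request_kwargs}
# -> {"output_dir": "...", "grpo_flavor": "mmgrpo", "max_context_length": 4096}
```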
@@ -108,7 +108,7 @@ class GRPOManager:
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["max_context_length", "lora"]
+        arbor_keys = ["max_context_length", "lora", "wandb_kwargs", "grpo_flavor"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -121,7 +121,7 @@ class GRPOManager:
         """Initialize the training process with ZMQ-based communication."""
         self.train_kwargs = self.find_training_args(request)
 
-        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
+        self.trl_train_kwargs, self.arbor_train_kwargs = self.process_training_args(
             self.train_kwargs
         )
 
@@ -132,31 +132,34 @@ class GRPOManager:
         # launch_kwargs = {
         #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
         # }
-        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
-            "max_context_length", None
+        inference_manager.launch_kwargs["max_context_length"] = (
+            self.arbor_train_kwargs.get("max_context_length", None)
         )
-        print("Launching inference server...")
+        logger.info("Launching inference server...")
         inference_manager.launch(self.current_model)
 
         # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
         self.server_comms_handler = ArborServerCommsHandler()
 
         script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
-        script_path = os.path.join(script_dir, "grpo_training.py")
+        script_name = {"mmgrpo": "mmgrpo_training.py", "grpo": "grpo_training.py"}[
+            self.arbor_train_kwargs["grpo_flavor"]
+        ]
+        script_path = os.path.join(script_dir, script_name)
 
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.training.gpu_ids
         # WandB can block the training process for login, so we silence it
         my_env["WANDB_SILENT"] = "true"
 
-        num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
+        num_processes = self.config.arbor_config.training.gpu_ids.count(",") + 1
 
         # This is the port for the accelerate main process
         main_process_port = get_free_port()
 
         params = [
-            "python",
+            sys.executable,
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
@@ -164,11 +167,11 @@ class GRPOManager:
             "--main_process_port",
             str(main_process_port),
         ]
-        if self.settings.arbor_config.training.accelerate_config:
+        if self.config.arbor_config.training.accelerate_config:
             params.extend(
                 [
                     "--config_file",
-                    self.settings.arbor_config.training.accelerate_config,
+                    self.config.arbor_config.training.accelerate_config,
                 ]
             )
         params.extend(
@@ -195,12 +198,12 @@ class GRPOManager:
                 "--model",
                 self.current_model,
                 "--trl_train_kwargs",
-                json.dumps(trl_train_kwargs),
+                json.dumps(self.trl_train_kwargs),
                 "--arbor_train_kwargs",
-                json.dumps(arbor_train_kwargs),
+                json.dumps(self.arbor_train_kwargs),
             ]
         )
-        print(f"Running following command\n: {' '.join(params)}")
+        logger.info(f"Running GRPO training command: {' '.join(params)}")
 
         self.training_process = subprocess.Popen(
             params,
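
Pieced together across the hunks above, the `accelerate` subprocess invocation now looks roughly like the sketch below. This is illustrative only: the port, model name, and script path are placeholders, and arguments that do not appear in this diff (for example the ZMQ/communication flags passed to the training script) are omitted.

```python
# Roughly what `params` evaluates to after the 0.2.3 changes, for gpu_ids="0,1".
params = [
    sys.executable, "-m", "accelerate.commands.launch",
    "--num_processes", "2",                      # gpu_ids.count(",") + 1
    "--main_process_port", "52317",              # from get_free_port(); placeholder
    # "--config_file", "<accelerate_config>",    # only when configured
    "<script_dir>/mmgrpo_training.py",           # chosen by grpo_flavor; path is a placeholder
    "--model", "<current_model>",
    "--trl_train_kwargs", json.dumps(self.trl_train_kwargs),
    "--arbor_train_kwargs", json.dumps(self.arbor_train_kwargs),
    # ...plus communication arguments not shown in this diff
]
```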
@@ -222,9 +225,9 @@ class GRPOManager:
                     break
                 if line:
                     buffer.append(line)
-                    # Print only if stop_event is not set
+                    # Log only if stop_event is not set
                     if not stop_event.is_set():
-                        print(f"[GRPO LOG] {line}", end="")
+                        logger.info(f"[GRPO LOG] {line.strip()}")
 
         # Start a background thread to read from the process continuously
         thread = threading.Thread(
@@ -256,10 +259,10 @@ class GRPOManager:
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
-        print("Starting status update handler...")
+        logger.info("Starting status update handler...")
         try:
             for status in self.server_comms_handler.receive_status():
-                print(f"Received status update: {status}")
+                logger.debug(f"Received status update: {status}")
                 if status["status"] == "weight_update_start":
                     # Block inference calls by incrementing counter
                     inference_manager.start_weight_update()
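
The handler above (continued in the next hunk) dispatches on plain dicts received over the ZMQ SUB socket. A sketch of the message shapes it recognizes; the field values are illustrative, not taken from the training scripts:

```python
# Status messages _handle_status_updates dispatches on; values are examples only.
example_status_messages = [
    {"status": "weight_update_start"},
    {"status": "weight_update_complete"},
    {"status": "model_saved"},
    {"status": "checkpoint_saved",
     "checkpoint_name": "ckpt_100",                        # placeholder
     "output_dir": "/storage/models/grpo:.../ckpt_100"},   # placeholder
    {"status": "error", "error": "CUDA out of memory"},    # placeholder message
    {"status": "terminated"},
]
```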
@@ -267,44 +270,80 @@ class GRPOManager:
                     # Decrement counter to potentially allow inference calls again
                     inference_manager.complete_weight_update()
                 elif status["status"] == "model_saved":
-                    print("Updating inference model...")
+                    logger.info("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
                     self.saving_model = False
-                    print("Model update complete")
+                    logger.info("Model update complete")
                 elif status["status"] == "checkpoint_saved":
-                    print("Received checkpoint saved status")
+                    logger.info("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
                     self.last_checkpoint = status["checkpoint_name"]
                     self.saving_checkpoint = False
-                    print("Checkpoint saved")
+                    logger.info("Checkpoint saved")
                 elif status["status"] == "error":
-                    print(f"Training error: {status.get('error', 'Unknown error')}")
+                    error_msg = status.get("error", "Unknown error")
+                    logger.error(f"Training error: {error_msg}")
                 elif status["status"] == "terminated":
                     self.terminating = False
-                    print("Training process terminated")
+                    logger.info("Training process terminated")
         except Exception as e:
-            print(f"Error in status update handler: {e}")
+            logger.error(f"Error in status update handler: {e}")
             # Make sure to allow inference if there's an error
             try:
                 inference_manager.complete_weight_update()
             except:
                 pass
 
+    def validate_batch(self, batch):
+        if not isinstance(batch, list):
+            raise ValueError("Batch must be a list")
+
+        if self.arbor_train_kwargs["grpo_flavor"] == "mmgrpo":
+            for group in batch:
+                if not isinstance(group, list):
+                    raise ValueError("Each group in batch must be a list")
+                for item in group:
+                    if not isinstance(item, dict):
+                        raise ValueError("Each item in group must be a dictionary")
+                    required_keys = {"messages", "completion", "advantage"}
+                    if not all(key in item for key in required_keys):
+                        raise ValueError(
+                            f"Each item must contain keys: {required_keys}"
+                        )
+            return True
+        elif self.arbor_train_kwargs["grpo_flavor"] == "grpo":
+            for item in batch:
+                if not isinstance(item, dict):
+                    raise ValueError("Each item in batch must be a dictionary")
+                required_keys = {"messages", "completion", "reward"}
+                if not all(key in item for key in required_keys):
+                    raise ValueError(f"Each item must contain keys: {required_keys}")
+            return True
+        else:
+            raise NotImplementedError(
+                f"GRPO flavor batch validation not implemented for {self.arbor_train_kwargs['grpo_flavor']}"
+            )
+
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
         while self.saving_checkpoint:
-            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            logger.info(
+                "Saving checkpoint, pausing GRPO steps until checkpoint is saved..."
+            )
             time.sleep(5)
 
+        self.validate_batch(request.batch)
+
         try:
+
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
 
         except Exception as e:
-            print(f"Failed to send batch to training process: {e}")
+            logger.error(f"Failed to send batch to training process: {e}")
             raise
 
         self.current_model = self.train_kwargs["output_dir"]
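
Together with the `grpo_flavor` selection earlier in the file, the new `validate_batch` pins down the wire format `grpo_step` accepts: a flat list of reward-scored items for the "grpo" flavor, and a list of groups of advantage-weighted items for "mmgrpo". A minimal sketch of payloads that would pass validation; only the required keys come from the code above, and the chat-message and completion shapes are assumptions:

```python
# Hypothetical request.batch payloads that satisfy validate_batch().

grpo_batch = [
    {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],  # shape assumed
        "completion": {"role": "assistant", "content": "4"},          # shape assumed
        "reward": 1.0,
    },
]

mmgrpo_batch = [
    [  # one group: a list of items, each carrying its own advantage
        {
            "messages": [{"role": "user", "content": "Plan the next step."}],
            "completion": {"role": "assistant", "content": "Search the docs."},
            "advantage": 0.42,
        },
    ],
]
```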
@@ -322,7 +361,7 @@ class GRPOManager:
         while (
             inference_manager.is_updating
         ):  # Use the property instead of direct access
-            print("Waiting for weight updates to finish before checkpointing...")
+            logger.info("Waiting for weight updates to finish before checkpointing...")
             time.sleep(5)
 
         self.saving_checkpoint = True
@@ -330,7 +369,7 @@ class GRPOManager:
             {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
         )
         while self.saving_checkpoint:
-            print("Waiting for checkpoint to be saved...")
+            logger.info("Waiting for checkpoint to be saved...")
             time.sleep(5)
         return {
             "current_model": self.current_model,
@@ -345,14 +384,14 @@ class GRPOManager:
         while (
             inference_manager and inference_manager.is_updating
         ):  # Use the property instead of direct access
-            print("Waiting for final weight updates to finish before saving...")
+            logger.info("Waiting for final weight updates to finish before saving...")
             time.sleep(5)
 
-        print("sending save model command")
+        logger.info("Sending save model command")
         self.saving_model = True
         self.server_comms_handler.send_command({"command": "save_model"})
         while self.saving_model:
-            print("Waiting for final model to be saved...")
+            logger.info("Waiting for final model to be saved...")
             time.sleep(5)
 
         termination_data = {
@@ -361,37 +400,34 @@ class GRPOManager:
             "last_checkpoint": self.last_checkpoint,
         }
 
-        print("sending termination command")
+        logger.info("Sending termination command")
         self.terminating = True
         self.server_comms_handler.send_command({"command": "terminate"})
-        print("Waiting for training process to finish...")
+        logger.info("Waiting for training process to finish...")
 
         # Wait for at most 15 seconds for termination
         start_time = time.time()
         while self.terminating:
             if time.time() - start_time > 15:
-                print(
+                logger.warning(
                     "Termination wait timed out after 15 seconds, proceeding with cleanup..."
                 )
                 break
-            print("Waiting for run to be terminated...")
+            logger.info("Waiting for run to be terminated...")
             time.sleep(3)
 
-        print("Doing cleanup")
+        logger.info("Starting cleanup")
        self.cleanup_termination(inference_manager)
 
         if self.train_kwargs and "output_dir" in self.train_kwargs:
-            print(
-                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-            )
-            if not os.path.exists(self.train_kwargs["output_dir"]):
-                print(
-                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                )
             output_dir = self.train_kwargs["output_dir"]
+            logger.info(f"Training completed. Model saved to {output_dir}")
+            logger.info(f"Training logs and checkpoints are stored in: {output_dir}")
+            if not os.path.exists(output_dir):
+                logger.warning(f"Output directory {output_dir} does not exist")
             self.train_kwargs = None
         else:
-            print("Training terminated, no output directory specified")
+            logger.info("Training terminated, no output directory specified")
             self.train_kwargs = None
 
         return termination_data
@@ -400,7 +436,7 @@ class GRPOManager:
         try:
             # Kill training process and all its children (accelerate launcher creates multiple processes)
             if self.training_process:
-                print("Terminating training process and its children...")
+                logger.info("Terminating training process and its children...")
                 try:
                     parent = psutil.Process(self.training_process.pid)
                     # Get all child processes including grandchildren
@@ -427,9 +463,9 @@ class GRPOManager:
                         pass
 
                 except psutil.NoSuchProcess:
-                    print(f"Process {self.training_process.pid} not found")
+                    logger.warning(f"Process {self.training_process.pid} not found")
                 except Exception as e:
-                    print(f"Error killing training process tree: {e}")
+                    logger.error(f"Error killing training process tree: {e}")
                     # Fallback to basic termination
                     self.training_process.terminate()
                     try:
@@ -440,11 +476,11 @@ class GRPOManager:
 
             # Clean up ZMQ connections
             if self.server_comms_handler:
-                print("Closing ZMQ connections...")
+                logger.debug("Closing ZMQ connections...")
                 self.server_comms_handler.close()
 
             if inference_manager and inference_manager.process is not None:
-                print("Killing inference manager...")
+                logger.info("Killing inference manager...")
                 inference_manager.kill()
 
             # Reinitialize in case we want to start a new training run
@@ -453,9 +489,9 @@ class GRPOManager:
             self.server_comms_handler = None
             self.status_thread = None
             self.data_count = 0
-            print("Cleanup completed successfully")
+            logger.info("Cleanup completed successfully")
         except Exception as e:
-            print(f"Error during cleanup: {e}")
+            logger.error(f"Error during cleanup: {e}")
             # Still reset state even if cleanup fails
             self.training_process = None
             self.current_model = None
@@ -478,5 +514,5 @@ def get_free_port() -> int:
             s.bind(("localhost", 0))
             ports.append(s.getsockname()[1])
         except Exception as e:
-            print(f"Error binding to port: {e}")
+            logger.error(f"Error binding to port: {e}")
     return random.choice(ports)
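
For context, the pattern this helper relies on is binding a socket to port 0 so the OS hands out a free ephemeral port, collecting a few candidates, and picking one at random, which reduces collisions when several launches race for ports. A standalone sketch of that pattern, not the file's exact code; the candidate count and the empty-list guard are assumptions:

```python
import random
import socket

def get_free_port_sketch(num_candidates: int = 5) -> int:
    """Ask the OS for several free ephemeral ports and return one at random."""
    ports = []
    for _ in range(num_candidates):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("localhost", 0))          # port 0 -> OS assigns a free port
                ports.append(s.getsockname()[1])
        except Exception as e:
            print(f"Error binding to port: {e}")
    if not ports:  # guard added for the sketch; the original's behavior here is not shown
        raise RuntimeError("Could not find a free port")
    return random.choice(ports)
```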
arbor/server/services/health_manager.py (new file)

@@ -0,0 +1,171 @@
+import platform
+from datetime import datetime
+from typing import Any, Dict
+
+import psutil
+
+from arbor.server.core.config import Config
+
+try:
+    import GPUtil
+
+    GPU_AVAILABLE = True
+except ImportError:
+    GPU_AVAILABLE = False
+
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+
+class HealthManager:
+    """Manages system health checks including GPU monitoring."""
+
+    def __init__(self, config: Config = None):
+        self.config = config
+
+    def get_gpu_info(self) -> Dict[str, Any]:
+        """Get GPU information including available and used GPUs."""
+        gpu_info = {
+            "gpus_available": 0,
+            "gpus_used": 0,
+            "gpu_details": [],
+            "cuda_available": False,
+            "gpu_library": "none",
+        }
+
+        # Try GPUtil first
+        if GPU_AVAILABLE:
+            try:
+                gpus = GPUtil.getGPUs()
+                gpu_info["gpus_available"] = len(gpus)
+                gpu_info["gpu_library"] = "GPUtil"
+
+                for i, gpu in enumerate(gpus):
+                    gpu_detail = {
+                        "id": gpu.id,
+                        "name": gpu.name,
+                        "memory_total": f"{gpu.memoryTotal}MB",
+                        "memory_used": f"{gpu.memoryUsed}MB",
+                        "memory_free": f"{gpu.memoryFree}MB",
+                        "utilization": f"{gpu.load * 100:.1f}%",
+                        "temperature": f"{gpu.temperature}°C",
+                    }
+                    gpu_info["gpu_details"].append(gpu_detail)
+
+                    # Consider GPU "used" if utilization > 10% or memory usage > 10%
+                    if gpu.load > 0.1 or (gpu.memoryUsed / gpu.memoryTotal) > 0.1:
+                        gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["error"] = f"GPUtil error: {str(e)}"
+
+        # Try PyTorch as fallback/additional info
+        if TORCH_AVAILABLE:
+            try:
+                gpu_info["cuda_available"] = torch.cuda.is_available()
+                if torch.cuda.is_available():
+                    cuda_count = torch.cuda.device_count()
+                    if not GPU_AVAILABLE:  # Only use torch info if GPUtil not available
+                        gpu_info["gpus_available"] = cuda_count
+                        gpu_info["gpu_library"] = "PyTorch"
+
+                    for i in range(cuda_count):
+                        props = torch.cuda.get_device_properties(i)
+                        memory_allocated = (
+                            torch.cuda.memory_allocated(i) / 1024**2
+                        )  # MB
+                        memory_cached = (
+                            torch.cuda.memory_reserved(i) / 1024**2
+                        )  # MB
+                        memory_total = props.total_memory / 1024**2  # MB
+
+                        gpu_detail = {
+                            "id": i,
+                            "name": props.name,
+                            "memory_total": f"{memory_total:.0f}MB",
+                            "memory_allocated": f"{memory_allocated:.0f}MB",
+                            "memory_cached": f"{memory_cached:.0f}MB",
+                            "compute_capability": f"{props.major}.{props.minor}",
+                        }
+                        gpu_info["gpu_details"].append(gpu_detail)
+
+                        # Consider GPU "used" if memory allocated > 100MB
+                        if memory_allocated > 100:
+                            gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["torch_error"] = f"PyTorch error: {str(e)}"
+
+        return gpu_info
+
+    def get_system_info(self) -> Dict[str, Any]:
+        """Get system information including CPU, memory, and disk usage."""
+        memory = psutil.virtual_memory()
+        disk = psutil.disk_usage("/")
+        cpu_percent = psutil.cpu_percent(interval=1)
+
+        return {
+            "platform": platform.system(),
+            "platform_release": platform.release(),
+            "platform_version": platform.version(),
+            "architecture": platform.machine(),
+            "processor": platform.processor(),
+            "cpu_usage": f"{cpu_percent}%",
+            "memory": {
+                "total": f"{memory.total / 1024**3:.2f}GB",
+                "available": f"{memory.available / 1024**3:.2f}GB",
+                "used": f"{memory.used / 1024**3:.2f}GB",
+                "percentage": f"{memory.percent}%",
+            },
+            "disk": {
+                "total": f"{disk.total / 1024**3:.2f}GB",
+                "free": f"{disk.free / 1024**3:.2f}GB",
+                "used": f"{disk.used / 1024**3:.2f}GB",
+                "percentage": f"{(disk.used / disk.total) * 100:.1f}%",
+            },
+            "gpu": self.get_gpu_info(),
+        }
+
+    def get_health_status(self) -> Dict[str, Any]:
+        """Get comprehensive health status including system and GPU information."""
+        version = self.config.get_arbor_version() if self.config else "unknown"
+        versions = (
+            self.config.get_system_versions() if self.config else {"arbor": version}
+        )
+
+        return {
+            "status": "healthy",
+            "version": version,  # Keep for backward compatibility
+            "versions": versions,  # Comprehensive version info
+            "timestamp": datetime.now().isoformat(),
+            "system": self.get_system_info(),
+        }
+
+    def is_healthy(self) -> bool:
+        """Check if the system is healthy based on various metrics."""
+        try:
+            # Check memory usage (unhealthy if > 90%)
+            memory = psutil.virtual_memory()
+            if memory.percent > 90:
+                print(f"Memory usage is {memory.percent}%")
+                return False
+
+            # Check disk usage (unhealthy if > 95%)
+            disk = psutil.disk_usage("/")
+            if (disk.used / disk.total) * 100 > 95:
+                print(f"Disk usage is {disk.used / disk.total * 100}%")
+                return False
+
+            # Check CPU usage (unhealthy if > 95% sustained)
+            cpu_percent = psutil.cpu_percent(interval=2)
+            if cpu_percent > 95:
+                print(f"CPU usage is {cpu_percent}%")
+                return False
+
+            return True
+        except Exception:
+            return False
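
`HealthManager` is a new service in 0.2.3. A minimal usage sketch, assuming it is called directly (how the server wires it into routes is not shown in this diff); note that `get_system_info()` and `is_healthy()` block for a second or two because `psutil.cpu_percent` is called with an interval:

```python
from arbor.server.services.health_manager import HealthManager

# Config is optional; without one the reported version falls back to "unknown".
health = HealthManager(config=None)

status = health.get_health_status()
print(status["status"])                            # "healthy"
print(status["versions"])                          # {"arbor": "unknown"} without a Config
print(status["system"]["gpu"]["gpus_available"])   # 0 when neither GPUtil nor CUDA is present

if not health.is_healthy():
    print("System under memory, disk, or CPU pressure")
```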
arbor/server/services/inference/vllm_client.py

@@ -4,6 +4,7 @@ import asyncio
 import atexit
 import logging
 import time
+import traceback
 from typing import Optional
 
 import httpx
@@ -131,11 +132,11 @@ class VLLMClient:
                    ) from exc
            else:
                if response.status_code == 200:
-                    logger.info("Server is up!")
+                    logger.debug("Server is up!")
                    return None
 
            # Retry logic: wait before trying again
-            logger.info(
+            logger.debug(
                f"Server is not up yet. Retrying in {retry_interval} seconds..."
            )
            time.sleep(retry_interval)
@@ -254,7 +255,8 @@ class VLLMClient:
                    await asyncio.sleep(INFERENCE_RETRY_DELAY)
                else:
                    logger.error(
-                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                        f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
+                        f"Stack trace:\n{traceback.format_exc()}"
                    )
                    raise
 