arbor-ai 0.2-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +101 -63
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +8 -5
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +18 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/server/services/grpo_manager.py (+101 -63)

@@ -3,7 +3,6 @@ import json
 import os
 import random
 import signal
-import socket
 import string
 import subprocess
 import sys
@@ -20,14 +19,17 @@ from arbor.server.api.models.schemas import (
     GRPOConfigRequest,
     GRPORequest,
 )
-from arbor.server.core.config import Settings
+from arbor.server.core.config import Config
 from arbor.server.services.comms.comms import ArborServerCommsHandler
 from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
 class GRPOManager:
-    def __init__(self, settings: Settings):
-        self.settings = settings
+    def __init__(self, config: Config):
+        self.config = config
         self.training_process = None
         self.current_model = None
         self.train_kwargs = None
@@ -47,7 +49,7 @@ class GRPOManager:
 
     def _signal_handler(self, signum, frame):
         """Handle keyboard interrupt (SIGINT) gracefully."""
-        print("\nReceived keyboard interrupt. Shutting down gracefully...")
+        logger.info("Received keyboard interrupt. Shutting down gracefully...")
         # Sleep for a bit to let async operations go through
         time.sleep(2)
         if self.training_process is not None:
@@ -65,16 +67,14 @@ class GRPOManager:
         )
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         name = f"grpo:{model_name}:{suffix}:{timestamp}"
-        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
+        return name, str(Path(self.config.STORAGE_PATH).resolve() / "models" / name)
 
     def find_training_args(self, request: GRPOConfigRequest) -> dict:
         """Process the config request and return training arguments."""
        name, output_dir = self.make_output_dir(request.model, request.suffix)
 
         # Here are defaults for training. We can adjust them if we disagree w the huggingface defaults
-        default_train_kwargs = {
-            "output_dir": output_dir,
-        }
+        default_train_kwargs = {"output_dir": output_dir, "grpo_flavor": "grpo"}
 
         train_kwargs = request.model_dump(exclude_unset=True)
         return {**default_train_kwargs, **(train_kwargs or {})}
@@ -108,7 +108,7 @@ class GRPOManager:
             key: train_kwargs[key] for key in trl_keys if key in train_kwargs
         }
 
-        arbor_keys = ["max_context_length", "lora"]
+        arbor_keys = ["max_context_length", "lora", "wandb_kwargs", "grpo_flavor"]
         arbor_train_kwargs = {
             key: train_kwargs[key] for key in arbor_keys if key in train_kwargs
         }
@@ -121,7 +121,7 @@ class GRPOManager:
         """Initialize the training process with ZMQ-based communication."""
         self.train_kwargs = self.find_training_args(request)
 
-        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
+        self.trl_train_kwargs, self.arbor_train_kwargs = self.process_training_args(
             self.train_kwargs
         )
 
@@ -132,31 +132,34 @@ class GRPOManager:
         # launch_kwargs = {
         #     k: v for k, v in arbor_train_kwargs.items() if k in ["max_context_length"]
         # }
-        inference_manager.launch_kwargs["max_context_length"] = arbor_train_kwargs.get(
-            "max_context_length", None
+        inference_manager.launch_kwargs["max_context_length"] = (
+            self.arbor_train_kwargs.get("max_context_length", None)
         )
-        print("Launching inference server...")
+        logger.info("Launching inference server...")
         inference_manager.launch(self.current_model)
 
         # Initialize ZMQ socket manager - no need for connection acceptance thread anymore
         self.server_comms_handler = ArborServerCommsHandler()
 
         script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts")
-        script_path = os.path.join(script_dir, "grpo_training.py")
+        script_name = {"mmgrpo": "mmgrpo_training.py", "grpo": "grpo_training.py"}[
+            self.arbor_train_kwargs["grpo_flavor"]
+        ]
+        script_path = os.path.join(script_dir, script_name)
 
         # Start the training process with ZMQ ports
         my_env = os.environ.copy()
-        my_env["CUDA_VISIBLE_DEVICES"] = self.settings.arbor_config.training.gpu_ids
+        my_env["CUDA_VISIBLE_DEVICES"] = self.config.arbor_config.training.gpu_ids
         # WandB can block the training process for login, so we silence it
         my_env["WANDB_SILENT"] = "true"
 
-        num_processes = self.settings.arbor_config.training.gpu_ids.count(",") + 1
+        num_processes = self.config.arbor_config.training.gpu_ids.count(",") + 1
 
         # This is the port for the accelerate main process
         main_process_port = get_free_port()
 
         params = [
-            "python",
+            sys.executable,
             "-m",
             "accelerate.commands.launch",
             "--num_processes",
@@ -164,11 +167,11 @@ class GRPOManager:
             "--main_process_port",
             str(main_process_port),
         ]
-        if self.settings.arbor_config.training.accelerate_config:
+        if self.config.arbor_config.training.accelerate_config:
             params.extend(
                 [
                     "--config_file",
-                    self.settings.arbor_config.training.accelerate_config,
+                    self.config.arbor_config.training.accelerate_config,
                 ]
             )
         params.extend(
@@ -195,12 +198,12 @@ class GRPOManager:
                 "--model",
                 self.current_model,
                 "--trl_train_kwargs",
-                json.dumps(trl_train_kwargs),
+                json.dumps(self.trl_train_kwargs),
                 "--arbor_train_kwargs",
-                json.dumps(arbor_train_kwargs),
+                json.dumps(self.arbor_train_kwargs),
             ]
         )
-        print(f"Running following command\n: {' '.join(params)}")
+        logger.info(f"Running GRPO training command: {' '.join(params)}")
 
         self.training_process = subprocess.Popen(
             params,
@@ -222,9 +225,9 @@ class GRPOManager:
                     break
                 if line:
                     buffer.append(line)
-                    # Print only if stop_event is not set
+                    # Log only if stop_event is not set
                     if not stop_event.is_set():
-                        print(f"[GRPO LOG] {line}", end="")
+                        logger.info(f"[GRPO LOG] {line.strip()}")
 
         # Start a background thread to read from the process continuously
         thread = threading.Thread(
@@ -256,10 +259,10 @@ class GRPOManager:
 
     def _handle_status_updates(self, inference_manager: InferenceManager):
         """Handle status updates from training process using ZMQ SUB socket"""
-        print("Starting status update handler...")
+        logger.info("Starting status update handler...")
         try:
             for status in self.server_comms_handler.receive_status():
-                print(f"Received status update: {status}")
+                logger.debug(f"Received status update: {status}")
                 if status["status"] == "weight_update_start":
                     # Block inference calls by incrementing counter
                     inference_manager.start_weight_update()
@@ -267,47 +270,85 @@ class GRPOManager:
                     # Decrement counter to potentially allow inference calls again
                     inference_manager.complete_weight_update()
                 elif status["status"] == "model_saved":
-                    print("Updating inference model...")
+                    logger.info("Updating inference model...")
                     # There is a case where this status is sent multiple times
                     # We need to make sure we only update the model once
-                    self.current_model = status["output_dir"]
                     self.saving_model = False
-                    print("Model update complete")
+                    logger.info("Model update complete")
                 elif status["status"] == "checkpoint_saved":
-                    print("Received checkpoint saved status")
+                    logger.info("Received checkpoint saved status")
                     self.checkpoints[status["checkpoint_name"]] = status["output_dir"]
                     self.last_checkpoint = status["checkpoint_name"]
                     self.saving_checkpoint = False
-                    print("Checkpoint saved")
+                    logger.info("Checkpoint saved")
                 elif status["status"] == "error":
-                    print(f"Training error: {status.get('error', 'Unknown error')}")
+                    error_msg = status.get("error", "Unknown error")
+                    logger.error(f"Training error: {error_msg}")
                 elif status["status"] == "terminated":
                     self.terminating = False
-                    print("Training process terminated")
+                    logger.info("Training process terminated")
         except Exception as e:
-            print(f"Error in status update handler: {e}")
+            logger.error(f"Error in status update handler: {e}")
            # Make sure to allow inference if there's an error
             try:
                 inference_manager.complete_weight_update()
             except:
                 pass
 
+    def validate_batch(self, batch):
+        if not isinstance(batch, list):
+            raise ValueError("Batch must be a list")
+
+        if self.arbor_train_kwargs["grpo_flavor"] == "mmgrpo":
+            for group in batch:
+                if not isinstance(group, list):
+                    raise ValueError("Each group in batch must be a list")
+                for item in group:
+                    if not isinstance(item, dict):
+                        raise ValueError("Each item in group must be a dictionary")
+                    required_keys = {"messages", "completion", "advantage"}
+                    if not all(key in item for key in required_keys):
+                        raise ValueError(
+                            f"Each item must contain keys: {required_keys}"
+                        )
+            return True
+        elif self.arbor_train_kwargs["grpo_flavor"] == "grpo":
+            for item in batch:
+                if not isinstance(item, dict):
+                    raise ValueError("Each item in batch must be a dictionary")
+                required_keys = {"messages", "completion", "reward"}
+                if not all(key in item for key in required_keys):
+                    raise ValueError(f"Each item must contain keys: {required_keys}")
+            return True
+        else:
+            raise NotImplementedError(
+                f"GRPO flavor batch validation not implemented for {self.arbor_train_kwargs['grpo_flavor']}"
+            )
+
     def grpo_step(
         self, request: GRPORequest, inference_manager: InferenceManager
     ) -> str:
         while self.saving_checkpoint:
-            print("Saving checkpoint, pausing GRPO steps until checkpoint is saved...")
+            logger.info(
+                "Saving checkpoint, pausing GRPO steps until checkpoint is saved..."
+            )
             time.sleep(5)
 
+        self.validate_batch(request.batch)
+
         try:
+
             # Send the batch to the training process
             self.server_comms_handler.send_data(request.batch)
             self.data_count += 1
 
         except Exception as e:
-            print(f"Failed to send batch to training process: {e}")
+            logger.error(f"Failed to send batch to training process: {e}")
             raise
 
+        self.current_model = self.train_kwargs["output_dir"]
+        inference_manager.launched_model = self.current_model
+
         return {
             "current_model": self.current_model,
             "checkpoints": self.checkpoints,
@@ -320,7 +361,7 @@ class GRPOManager:
         while (
             inference_manager.is_updating
         ):  # Use the property instead of direct access
-            print("Waiting for weight updates to finish before checkpointing...")
+            logger.info("Waiting for weight updates to finish before checkpointing...")
             time.sleep(5)
 
         self.saving_checkpoint = True
@@ -328,7 +369,7 @@ class GRPOManager:
             {"command": "save_checkpoint", "checkpoint_name": request.checkpoint_name}
         )
         while self.saving_checkpoint:
-            print("Waiting for checkpoint to be saved...")
+            logger.info("Waiting for checkpoint to be saved...")
             time.sleep(5)
         return {
             "current_model": self.current_model,
@@ -343,14 +384,14 @@ class GRPOManager:
         while (
             inference_manager and inference_manager.is_updating
         ):  # Use the property instead of direct access
-            print("Waiting for final weight updates to finish before saving...")
+            logger.info("Waiting for final weight updates to finish before saving...")
             time.sleep(5)
 
-        print("sending save model command")
+        logger.info("Sending save model command")
         self.saving_model = True
         self.server_comms_handler.send_command({"command": "save_model"})
         while self.saving_model:
-            print("Waiting for final model to be saved...")
+            logger.info("Waiting for final model to be saved...")
             time.sleep(5)
 
         termination_data = {
@@ -359,37 +400,34 @@ class GRPOManager:
             "last_checkpoint": self.last_checkpoint,
         }
 
-        print("sending termination command")
+        logger.info("Sending termination command")
         self.terminating = True
         self.server_comms_handler.send_command({"command": "terminate"})
-        print("Waiting for training process to finish...")
+        logger.info("Waiting for training process to finish...")
 
         # Wait for at most 15 seconds for termination
         start_time = time.time()
         while self.terminating:
             if time.time() - start_time > 15:
-                print(
+                logger.warning(
                     "Termination wait timed out after 15 seconds, proceeding with cleanup..."
                 )
                 break
-            print("Waiting for run to be terminated...")
+            logger.info("Waiting for run to be terminated...")
             time.sleep(3)
 
-        print("Doing cleanup")
+        logger.info("Starting cleanup")
         self.cleanup_termination(inference_manager)
 
         if self.train_kwargs and "output_dir" in self.train_kwargs:
-            print(
-                f"Training completed. Model saved to {self.train_kwargs['output_dir']}"
-            )
-            if not os.path.exists(self.train_kwargs["output_dir"]):
-                print(
-                    f"Warning: Output directory {self.train_kwargs['output_dir']} does not exist"
-                )
             output_dir = self.train_kwargs["output_dir"]
+            logger.info(f"Training completed. Model saved to {output_dir}")
+            logger.info(f"Training logs and checkpoints are stored in: {output_dir}")
+            if not os.path.exists(output_dir):
+                logger.warning(f"Output directory {output_dir} does not exist")
             self.train_kwargs = None
         else:
-            print("Training terminated, no output directory specified")
+            logger.info("Training terminated, no output directory specified")
             self.train_kwargs = None
 
         return termination_data
@@ -398,7 +436,7 @@ class GRPOManager:
         try:
             # Kill training process and all its children (accelerate launcher creates multiple processes)
             if self.training_process:
-                print("Terminating training process and its children...")
+                logger.info("Terminating training process and its children...")
                 try:
                     parent = psutil.Process(self.training_process.pid)
                     # Get all child processes including grandchildren
@@ -425,9 +463,9 @@ class GRPOManager:
                            pass
 
                 except psutil.NoSuchProcess:
-                    print(f"Process {self.training_process.pid} not found")
+                    logger.warning(f"Process {self.training_process.pid} not found")
                 except Exception as e:
-                    print(f"Error killing training process tree: {e}")
+                    logger.error(f"Error killing training process tree: {e}")
                     # Fallback to basic termination
                     self.training_process.terminate()
                     try:
@@ -438,11 +476,11 @@ class GRPOManager:
 
             # Clean up ZMQ connections
             if self.server_comms_handler:
-                print("Closing ZMQ connections...")
+                logger.debug("Closing ZMQ connections...")
                 self.server_comms_handler.close()
 
             if inference_manager and inference_manager.process is not None:
-                print("Killing inference manager...")
+                logger.info("Killing inference manager...")
                 inference_manager.kill()
 
             # Reinitialize in case we want to start a new training run
@@ -451,9 +489,9 @@ class GRPOManager:
             self.server_comms_handler = None
             self.status_thread = None
             self.data_count = 0
-            print("Cleanup completed successfully")
+            logger.info("Cleanup completed successfully")
         except Exception as e:
-            print(f"Error during cleanup: {e}")
+            logger.error(f"Error during cleanup: {e}")
            # Still reset state even if cleanup fails
             self.training_process = None
             self.current_model = None
@@ -476,5 +514,5 @@ def get_free_port() -> int:
            s.bind(("localhost", 0))
            ports.append(s.getsockname()[1])
        except Exception as e:
-            print(f"Error binding to port: {e}")
+            logger.error(f"Error binding to port: {e}")
    return random.choice(ports)

arbor/server/services/health_manager.py (new file, +171 -0)

@@ -0,0 +1,171 @@
+import platform
+from datetime import datetime
+from typing import Any, Dict
+
+import psutil
+
+from arbor.server.core.config import Config
+
+try:
+    import GPUtil
+
+    GPU_AVAILABLE = True
+except ImportError:
+    GPU_AVAILABLE = False
+
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+
+class HealthManager:
+    """Manages system health checks including GPU monitoring."""
+
+    def __init__(self, config: Config = None):
+        self.config = config
+
+    def get_gpu_info(self) -> Dict[str, Any]:
+        """Get GPU information including available and used GPUs."""
+        gpu_info = {
+            "gpus_available": 0,
+            "gpus_used": 0,
+            "gpu_details": [],
+            "cuda_available": False,
+            "gpu_library": "none",
+        }
+
+        # Try GPUtil first
+        if GPU_AVAILABLE:
+            try:
+                gpus = GPUtil.getGPUs()
+                gpu_info["gpus_available"] = len(gpus)
+                gpu_info["gpu_library"] = "GPUtil"
+
+                for i, gpu in enumerate(gpus):
+                    gpu_detail = {
+                        "id": gpu.id,
+                        "name": gpu.name,
+                        "memory_total": f"{gpu.memoryTotal}MB",
+                        "memory_used": f"{gpu.memoryUsed}MB",
+                        "memory_free": f"{gpu.memoryFree}MB",
+                        "utilization": f"{gpu.load * 100:.1f}%",
+                        "temperature": f"{gpu.temperature}°C",
+                    }
+                    gpu_info["gpu_details"].append(gpu_detail)
+
+                    # Consider GPU "used" if utilization > 10% or memory usage > 10%
+                    if gpu.load > 0.1 or (gpu.memoryUsed / gpu.memoryTotal) > 0.1:
+                        gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["error"] = f"GPUtil error: {str(e)}"
+
+        # Try PyTorch as fallback/additional info
+        if TORCH_AVAILABLE:
+            try:
+                gpu_info["cuda_available"] = torch.cuda.is_available()
+                if torch.cuda.is_available():
+                    cuda_count = torch.cuda.device_count()
+                    if not GPU_AVAILABLE:  # Only use torch info if GPUtil not available
+                        gpu_info["gpus_available"] = cuda_count
+                        gpu_info["gpu_library"] = "PyTorch"
+
+                        for i in range(cuda_count):
+                            props = torch.cuda.get_device_properties(i)
+                            memory_allocated = (
+                                torch.cuda.memory_allocated(i) / 1024**2
+                            )  # MB
+                            memory_cached = (
+                                torch.cuda.memory_reserved(i) / 1024**2
+                            )  # MB
+                            memory_total = props.total_memory / 1024**2  # MB
+
+                            gpu_detail = {
+                                "id": i,
+                                "name": props.name,
+                                "memory_total": f"{memory_total:.0f}MB",
+                                "memory_allocated": f"{memory_allocated:.0f}MB",
+                                "memory_cached": f"{memory_cached:.0f}MB",
+                                "compute_capability": f"{props.major}.{props.minor}",
+                            }
+                            gpu_info["gpu_details"].append(gpu_detail)
+
+                            # Consider GPU "used" if memory allocated > 100MB
+                            if memory_allocated > 100:
+                                gpu_info["gpus_used"] += 1
+
+            except Exception as e:
+                gpu_info["torch_error"] = f"PyTorch error: {str(e)}"
+
+        return gpu_info
+
+    def get_system_info(self) -> Dict[str, Any]:
+        """Get system information including CPU, memory, and disk usage."""
+        memory = psutil.virtual_memory()
+        disk = psutil.disk_usage("/")
+        cpu_percent = psutil.cpu_percent(interval=1)
+
+        return {
+            "platform": platform.system(),
+            "platform_release": platform.release(),
+            "platform_version": platform.version(),
+            "architecture": platform.machine(),
+            "processor": platform.processor(),
+            "cpu_usage": f"{cpu_percent}%",
+            "memory": {
+                "total": f"{memory.total / 1024**3:.2f}GB",
+                "available": f"{memory.available / 1024**3:.2f}GB",
+                "used": f"{memory.used / 1024**3:.2f}GB",
+                "percentage": f"{memory.percent}%",
+            },
+            "disk": {
+                "total": f"{disk.total / 1024**3:.2f}GB",
+                "free": f"{disk.free / 1024**3:.2f}GB",
+                "used": f"{disk.used / 1024**3:.2f}GB",
+                "percentage": f"{(disk.used / disk.total) * 100:.1f}%",
+            },
+            "gpu": self.get_gpu_info(),
+        }
+
+    def get_health_status(self) -> Dict[str, Any]:
+        """Get comprehensive health status including system and GPU information."""
+        version = self.config.get_arbor_version() if self.config else "unknown"
+        versions = (
+            self.config.get_system_versions() if self.config else {"arbor": version}
+        )
+
+        return {
+            "status": "healthy",
+            "version": version,  # Keep for backward compatibility
+            "versions": versions,  # Comprehensive version info
+            "timestamp": datetime.now().isoformat(),
+            "system": self.get_system_info(),
+        }
+
+    def is_healthy(self) -> bool:
+        """Check if the system is healthy based on various metrics."""
+        try:
+            # Check memory usage (unhealthy if > 90%)
+            memory = psutil.virtual_memory()
+            if memory.percent > 90:
+                print(f"Memory usage is {memory.percent}%")
+                return False
+
+            # Check disk usage (unhealthy if > 95%)
+            disk = psutil.disk_usage("/")
+            if (disk.used / disk.total) * 100 > 95:
+                print(f"Disk usage is {disk.used / disk.total * 100}%")
+                return False
+
+            # Check CPU usage (unhealthy if > 95% sustained)
+            cpu_percent = psutil.cpu_percent(interval=2)
+            if cpu_percent > 95:
+                print(f"CPU usage is {cpu_percent}%")
+                return False
+
+            return True
+        except Exception:
+            return False
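
The new HealthManager is self-contained, so a minimal usage sketch follows directly from the code above. It assumes the class is constructed by hand; how arbor's API routes actually wire it up is not part of this diff.

# Minimal sketch, assuming direct construction (config is optional per __init__).
from arbor.server.services.health_manager import HealthManager

health = HealthManager(config=None)

if health.is_healthy():
    status = health.get_health_status()
    print(status["status"])                           # "healthy"
    print(status["versions"])                         # {"arbor": "unknown"} without a config
    print(status["system"]["gpu"]["gpus_available"])  # GPU count via GPUtil or PyTorch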

arbor/server/services/inference/vllm_client.py (+8 -5)

@@ -1,8 +1,10 @@
 # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
 
+import asyncio
 import atexit
 import logging
 import time
+import traceback
 from typing import Optional
 
 import httpx
@@ -130,11 +132,11 @@ class VLLMClient:
                    ) from exc
            else:
                if response.status_code == 200:
-                    logger.info("Server is up!")
+                    logger.debug("Server is up!")
                    return None
 
            # Retry logic: wait before trying again
-            logger.info(
+            logger.debug(
                f"Server is not up yet. Retrying in {retry_interval} seconds..."
            )
            time.sleep(retry_interval)
@@ -239,7 +241,7 @@ class VLLMClient:
                response.raise_for_status()
                return response.json()
 
-            except httpx.TimeoutError:
+            except httpx.TimeoutException:
                logger.error("Request timed out")
                raise
            except InferenceBlockedError:
@@ -253,7 +255,8 @@ class VLLMClient:
                    await asyncio.sleep(INFERENCE_RETRY_DELAY)
                else:
                    logger.error(
-                        f"Request failed after {MAX_INFERENCE_RETRIES} retries"
+                        f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
+                        f"Stack trace:\n{traceback.format_exc()}"
                    )
                    raise
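
Taken together, the vllm_client.py changes switch the timeout handler to httpx.TimeoutException (httpx does not define a TimeoutError) and attach the exception plus a traceback to the final retry failure. A condensed sketch of that retry pattern is below; the constants and the surrounding request code are assumed from the hunk context, not copied from the package.

# Condensed sketch of the retry/timeout handling adjusted above (names assumed).
import asyncio
import logging
import traceback

import httpx

logger = logging.getLogger(__name__)

MAX_INFERENCE_RETRIES = 3
INFERENCE_RETRY_DELAY = 1.0


async def post_with_retries(client: httpx.AsyncClient, url: str, payload: dict) -> dict:
    for retry in range(MAX_INFERENCE_RETRIES):
        try:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response.json()
        except httpx.TimeoutException:
            # httpx raises TimeoutException, not TimeoutError, hence the rename in the diff
            logger.error("Request timed out")
            raise
        except Exception as e:
            if retry < MAX_INFERENCE_RETRIES - 1:
                await asyncio.sleep(INFERENCE_RETRY_DELAY)
            else:
                logger.error(
                    f"Request failed after {MAX_INFERENCE_RETRIES} retries. Error: {e}\n"
                    f"Stack trace:\n{traceback.format_exc()}"
                )
                raise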