arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,18 @@
1
+ import datetime
2
+ import os
3
+ import subprocess
4
+ import sys
1
5
  from pathlib import Path
2
- from typing import Optional
6
+ from typing import Any, ClassVar, Dict, Optional
3
7
 
4
8
  import yaml
5
- from pydantic import BaseModel, ConfigDict
9
+ from pydantic import BaseModel
10
+
11
+ try:
12
+ from importlib.metadata import PackageNotFoundError, version
13
+ except ImportError:
14
+ # For Python < 3.8
15
+ from importlib_metadata import PackageNotFoundError, version
6
16
 
7
17
 
8
18
  class InferenceConfig(BaseModel):
@@ -19,16 +29,267 @@ class ArborConfig(BaseModel):
19
29
  training: TrainingConfig
20
30
 
21
31
 
22
- class Settings(BaseModel):
23
-
24
- STORAGE_PATH: str = "./storage"
32
+ class Config(BaseModel):
33
+ STORAGE_PATH: ClassVar[str] = str(Path.home() / ".arbor" / "storage")
25
34
  INACTIVITY_TIMEOUT: int = 30 # 5 seconds
26
35
  arbor_config: ArborConfig
27
36
 
37
+ @staticmethod
38
+ def validate_storage_path(storage_path: str):
39
+ """Validates a storage path, return True for success, False if failed."""
40
+ try:
41
+ if not Path(storage_path).exists():
42
+ return False
43
+ return True
44
+
45
+ except Exception as e:
46
+ return False
47
+
48
+ @classmethod
49
+ def set_storage_path(cls, storage_path: str):
50
+ """Set a valid storage path to use, return True for success, False if failed."""
51
+ if not cls.validate_storage_path(storage_path):
52
+ return False
53
+
54
+ cls.STORAGE_PATH = storage_path
55
+
56
+ return True
57
+
58
+ @staticmethod
59
+ def validate_storage_path(storage_path: str) -> None:
60
+ """Validates a storage path, raises exception if invalid."""
61
+ if not storage_path:
62
+ raise ValueError("Storage path cannot be empty")
63
+
64
+ path = Path(storage_path)
65
+
66
+ if not path.exists():
67
+ raise FileNotFoundError(f"Storage path does not exist: {storage_path}")
68
+
69
+ if not path.is_dir():
70
+ raise NotADirectoryError(f"Storage path is not a directory: {storage_path}")
71
+
72
+ # Check if we can write to the directory
73
+ if not os.access(path, os.W_OK):
74
+ raise PermissionError(
75
+ f"No write permission for storage path: {storage_path}"
76
+ )
77
+
78
+ @classmethod
79
+ def set_storage_path(cls, storage_path: str) -> None:
80
+ """Set a valid storage path to use, raises exception if invalid."""
81
+ cls.validate_storage_path(storage_path) # raises if invalid
82
+ cls.STORAGE_PATH = storage_path
83
+
84
+ @classmethod
85
+ def make_log_dir(cls, storage_path: str = None):
86
+ """Create a timestamped log directory under the storage path."""
87
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
88
+
89
+ log_dir = Path(
90
+ storage_path if storage_path else cls.STORAGE_PATH / "logs" / timestamp
91
+ )
92
+ log_dir.mkdir(exist_ok=True)
93
+
94
+ return log_dir
95
+
96
+ @staticmethod
97
+ def get_arbor_version() -> str:
98
+ """Get the installed version of arbor package."""
99
+ try:
100
+ return version("arbor-ai")
101
+ except PackageNotFoundError:
102
+ # Fallback to a default version if package not found
103
+ # This might happen in development mode
104
+ return "dev"
105
+ except Exception:
106
+ return "unknown"
107
+
108
+ @staticmethod
109
+ def get_cuda_version() -> str:
110
+ """Get CUDA runtime version."""
111
+ try:
112
+ import torch
113
+
114
+ if torch.cuda.is_available():
115
+ return torch.version.cuda
116
+ else:
117
+ return "not_available"
118
+ except ImportError:
119
+ try:
120
+ # Try getting CUDA version from nvcc
121
+ result = subprocess.run(
122
+ ["nvcc", "--version"], capture_output=True, text=True, timeout=5
123
+ )
124
+ if result.returncode == 0:
125
+ # Parse nvcc output for version
126
+ for line in result.stdout.split("\n"):
127
+ if "release" in line.lower():
128
+ # Extract version from line like "Cuda compilation tools, release 11.8, V11.8.89"
129
+ parts = line.split("release")
130
+ if len(parts) > 1:
131
+ version_part = parts[1].split(",")[0].strip()
132
+ return version_part
133
+ return "unknown"
134
+ except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
135
+ return "not_installed"
136
+ except Exception:
137
+ return "unknown"
138
+
139
+ @staticmethod
140
+ def get_nvidia_driver_version() -> str:
141
+ """Get NVIDIA driver version."""
142
+ try:
143
+ result = subprocess.run(
144
+ [
145
+ "nvidia-smi",
146
+ "--query-gpu=driver_version",
147
+ "--format=csv,noheader,nounits",
148
+ ],
149
+ capture_output=True,
150
+ text=True,
151
+ timeout=5,
152
+ )
153
+ if result.returncode == 0:
154
+ return result.stdout.strip().split("\n")[0]
155
+ return "unknown"
156
+ except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
157
+ return "not_installed"
158
+
159
+ @staticmethod
160
+ def get_python_package_version(package_name: str) -> str:
161
+ """Get version of a Python package."""
162
+ try:
163
+ return version(package_name)
164
+ except PackageNotFoundError:
165
+ return "not_installed"
166
+ except Exception:
167
+ return "unknown"
168
+
169
+ @classmethod
170
+ def get_ml_library_versions(cls) -> Dict[str, str]:
171
+ """Get versions of common ML libraries."""
172
+ libraries = {
173
+ "torch": "torch",
174
+ "transformers": "transformers",
175
+ "vllm": "vllm",
176
+ "trl": "trl",
177
+ "peft": "peft",
178
+ "accelerate": "accelerate",
179
+ "ray": "ray",
180
+ "wandb": "wandb",
181
+ "numpy": "numpy",
182
+ "pandas": "pandas",
183
+ "scikit-learn": "scikit-learn",
184
+ }
185
+
186
+ versions = {}
187
+ for lib_name, package_name in libraries.items():
188
+ versions[lib_name] = cls.get_python_package_version(package_name)
189
+
190
+ return versions
191
+
192
+ @classmethod
193
+ def get_cuda_library_versions(cls) -> Dict[str, str]:
194
+ """Get versions of CUDA-related libraries."""
195
+ cuda_info = {}
196
+
197
+ # CUDA runtime version
198
+ cuda_info["cuda_runtime"] = cls.get_cuda_version()
199
+
200
+ # NVIDIA driver version
201
+ cuda_info["nvidia_driver"] = cls.get_nvidia_driver_version()
202
+
203
+ # cuDNN version (if available through PyTorch)
204
+ try:
205
+ import torch
206
+
207
+ if torch.cuda.is_available() and hasattr(torch.backends.cudnn, "version"):
208
+ cuda_info["cudnn"] = str(torch.backends.cudnn.version())
209
+ else:
210
+ cuda_info["cudnn"] = "not_available"
211
+ except ImportError:
212
+ cuda_info["cudnn"] = "torch_not_installed"
213
+ except Exception:
214
+ cuda_info["cudnn"] = "unknown"
215
+
216
+ # NCCL version (if available through PyTorch)
217
+ try:
218
+ import torch
219
+
220
+ if torch.cuda.is_available() and hasattr(torch, "__version__"):
221
+ # NCCL version is often embedded in PyTorch build info
222
+ try:
223
+ import torch.distributed as dist
224
+
225
+ if hasattr(dist, "is_nccl_available") and dist.is_nccl_available():
226
+ # Try to get NCCL version from PyTorch
227
+ if hasattr(torch.cuda.nccl, "version"):
228
+ nccl_version = torch.cuda.nccl.version()
229
+ cuda_info["nccl"] = (
230
+ f"{nccl_version[0]}.{nccl_version[1]}.{nccl_version[2]}"
231
+ )
232
+ else:
233
+ cuda_info["nccl"] = "available"
234
+ else:
235
+ cuda_info["nccl"] = "not_available"
236
+ except Exception:
237
+ cuda_info["nccl"] = "unknown"
238
+ else:
239
+ cuda_info["nccl"] = "cuda_not_available"
240
+ except ImportError:
241
+ cuda_info["nccl"] = "torch_not_installed"
242
+ except Exception:
243
+ cuda_info["nccl"] = "unknown"
244
+
245
+ return cuda_info
246
+
28
247
  @classmethod
29
- def load_from_yaml(cls, yaml_path: str) -> "Settings":
248
+ def get_system_versions(cls) -> Dict[str, Any]:
249
+ """Get comprehensive version information for the system."""
250
+ return {
251
+ "arbor": cls.get_arbor_version(),
252
+ "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
253
+ "ml_libraries": cls.get_ml_library_versions(),
254
+ "cuda_stack": cls.get_cuda_library_versions(),
255
+ }
256
+
257
+ @classmethod
258
+ def _init_arbor_directories(cls):
259
+ arbor_root = Path.home() / ".arbor"
260
+ storage_dir = Path(cls.STORAGE_PATH)
261
+
262
+ arbor_root.mkdir(exist_ok=True)
263
+ storage_dir.mkdir(exist_ok=True)
264
+ (storage_dir / "logs").mkdir(exist_ok=True)
265
+ (storage_dir / "models").mkdir(exist_ok=True)
266
+ (storage_dir / "uploads").mkdir(exist_ok=True)
267
+
268
+ @classmethod
269
+ def use_default_config(cls) -> Optional[str]:
270
+ """Search for: ~/.arbor/config.yaml, else return None"""
271
+
272
+ # Check ~/.arbor/config.yaml
273
+ arbor_config = Path.home() / ".arbor" / "config.yaml"
274
+ if arbor_config.exists():
275
+ return str(arbor_config)
276
+
277
+ return None
278
+
279
+ @classmethod
280
+ def load_config_from_yaml(cls, yaml_path: str) -> "Config":
281
+ # If yaml file is not provided, try to use ~/.arbor/config.yaml
282
+ cls._init_arbor_directories()
283
+
284
+ if not yaml_path:
285
+ yaml_path = cls.use_default_config()
286
+
30
287
  if not yaml_path:
31
- raise ValueError("Config file path is required")
288
+ raise ValueError(
289
+ "No config file found. Please create ~/.arbor/config.yaml or "
290
+ "provide a config file path with --arbor-config"
291
+ )
292
+
32
293
  if not Path(yaml_path).exists():
33
294
  raise ValueError(f"Config file {yaml_path} does not exist")
34
295
 
@@ -42,6 +303,31 @@ class Settings(BaseModel):
42
303
  training=TrainingConfig(**config["training"]),
43
304
  )
44
305
  )
306
+
307
+ storage_path = config.get("storage_path")
308
+ if storage_path:
309
+ cls.set_storage_path(storage_path)
310
+
45
311
  return settings
46
312
  except Exception as e:
47
313
  raise ValueError(f"Error loading config file {yaml_path}: {e}")
314
+
315
+ @classmethod
316
+ def load_config_directly(
317
+ cls,
318
+ storage_path: str = None,
319
+ inference_gpus: str = "0",
320
+ training_gpus: str = "1,2",
321
+ ):
322
+ cls._init_arbor_directories()
323
+
324
+ # create settings without yaml file
325
+ config = ArborConfig(
326
+ inference=InferenceConfig(gpu_ids=inference_gpus),
327
+ training=TrainingConfig(gpu_ids=training_gpus),
328
+ )
329
+
330
+ if storage_path:
331
+ cls.set_storage_path(storage_path)
332
+
333
+ return cls(arbor_config=config)
@@ -0,0 +1,100 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import yaml
6
+
7
+ from arbor.server.core.config import Config
8
+
9
+
10
+ class ConfigManager:
11
+ def __init__(self):
12
+ self._init_arbor_directories()
13
+
14
+ def _init_arbor_directories(self):
15
+ arbor_root = Path.home() / ".arbor"
16
+ storage_dir = Path(self.STORAGE_PATH)
17
+
18
+ arbor_root.mkdir(exist_ok=True)
19
+ storage_dir.mkdir(exist_ok=True)
20
+ (storage_dir / "logs").mkdir(exist_ok=True)
21
+ (storage_dir / "models").mkdir(exist_ok=True)
22
+ (storage_dir / "uploads").mkdir(exist_ok=True)
23
+
24
+ @staticmethod
25
+ def get_default_config_path() -> Path:
26
+ return str(Path.home() / ".arbor" / "config.yaml")
27
+
28
+ @staticmethod
29
+ def get_config_template() -> Dict:
30
+ return {"inference": {"gpu_ids": "0"}, "training": {"gpu_ids": "1, 2"}}
31
+
32
+ @classmethod
33
+ def update_config(
34
+ cls,
35
+ inference_gpus: Optional[str] = None,
36
+ training_gpus: Optional[str] = None,
37
+ config_path: Optional[str] = None,
38
+ ) -> str:
39
+ """Update existing config or create new one."""
40
+
41
+ if config_path is None:
42
+ config_path = Config.use_default_config()
43
+ if config_path is None:
44
+ config_path = str(cls.get_default_config_path())
45
+
46
+ config_file = Path(config_path)
47
+ config_file.parent.mkdir(parents=True, exist_ok=True)
48
+
49
+ # Load existing config or use template
50
+ if config_file.exists():
51
+ with open(config_file, "r") as f:
52
+ config = yaml.safe_load(f) or {}
53
+ else:
54
+ config = cls.get_config_template()
55
+
56
+ # Update values given
57
+ if inference_gpus is not None:
58
+ if "inference" not in config:
59
+ config["inference"] = {}
60
+ config["inference"]["gpu_ids"] = str(inference_gpus)
61
+
62
+ if training_gpus is not None:
63
+ if "training" not in config:
64
+ config["training"] = {}
65
+ config["training"]["gpu_ids"] = str(training_gpus)
66
+
67
+ temp_path = config_file.with_suffix(".tmp")
68
+ try:
69
+ with open(temp_path, "w") as f:
70
+ yaml.dump(config, f, default_flow_style=False, default_style="'")
71
+ temp_path.rename(config_file)
72
+ except Exception:
73
+ if temp_path.exists():
74
+ temp_path.unlink()
75
+ raise
76
+
77
+ return str(config_file)
78
+
79
+ @classmethod
80
+ def validate_config_file(cls, config_path: str) -> Tuple[bool, str]:
81
+ """Validate a config file"""
82
+ try:
83
+ if not Path(config_path).exists():
84
+ return False, f"Config file does not exist: {config_path}"
85
+
86
+ # If we do have a config file, try to see if it will load
87
+ Config.load_config_from_yaml(config_path)
88
+ return True, "Config is valid"
89
+
90
+ except Exception as e:
91
+ return False, f"Invalid config: {e}"
92
+
93
+ @classmethod
94
+ def get_config_contents(cls, config_path: str) -> Tuple[bool, str]:
95
+ try:
96
+ with open(config_path, "r") as f:
97
+ config_content = f.read()
98
+ return True, config_content
99
+ except Exception as e:
100
+ return False, str(e)
arbor/server/main.py CHANGED
@@ -1,11 +1,36 @@
1
- from fastapi import FastAPI
1
+ from fastapi import FastAPI, Request
2
2
 
3
3
  from arbor.server.api.routes import files, grpo, inference, jobs
4
+ from arbor.server.utils.logging import apply_uvicorn_formatting
4
5
 
5
6
  app = FastAPI(title="Arbor API")
6
7
 
8
+
9
+ @app.on_event("startup")
10
+ async def startup_event():
11
+ """Configure uvicorn logging after the app starts up."""
12
+ apply_uvicorn_formatting()
13
+
14
+
7
15
  # Include routers
8
16
  app.include_router(files.router, prefix="/v1/files")
9
17
  app.include_router(jobs.router, prefix="/v1/fine_tuning/jobs")
10
18
  app.include_router(grpo.router, prefix="/v1/fine_tuning/grpo")
11
19
  app.include_router(inference.router, prefix="/v1/chat")
20
+
21
+
22
+ @app.get("/health")
23
+ def health_check(request: Request):
24
+ """Enhanced health check with system and GPU information."""
25
+ health_manager = request.app.state.health_manager
26
+ return health_manager.get_health_status()
27
+
28
+
29
+ @app.get("/health/simple")
30
+ def simple_health_check(request: Request):
31
+ """Simple health check that returns just the status."""
32
+ health_manager = request.app.state.health_manager
33
+ return {
34
+ "status": "healthy" if health_manager.is_healthy() else "unhealthy",
35
+ "timestamp": health_manager.get_health_status()["timestamp"],
36
+ }
@@ -6,6 +6,10 @@ import time
6
6
 
7
7
  import zmq
8
8
 
9
+ from arbor.server.utils.logging import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
9
13
 
10
14
  class ArborServerCommsHandler:
11
15
  """Handles socket communication between manager and training process"""
@@ -64,14 +68,14 @@ class ArborServerCommsHandler:
64
68
  def wait_for_clients(self, expected_count):
65
69
  connected_clients = []
66
70
  while len(connected_clients) < expected_count:
67
- print(f"Waiting for {expected_count} clients to connect...")
71
+ logger.info(f"Waiting for {expected_count} clients to connect...")
68
72
  msg = self.handshake_socket.recv_json()
69
73
  if msg.get("type") == "hello":
70
74
  client_id = msg.get("client_id")
71
75
  connected_clients.append(client_id)
72
76
  self.handshake_socket.send_json({"status": "ack"})
73
- print(f"Received handshake from {client_id}")
74
- print(f"All {expected_count} clients connected!")
77
+ logger.info(f"Received handshake from {client_id}")
78
+ logger.info(f"All {expected_count} clients connected!")
75
79
 
76
80
 
77
81
  class ArborScriptCommsHandler:
@@ -138,7 +142,7 @@ class ArborScriptCommsHandler:
138
142
  data = self.data_socket.recv_json()
139
143
  self.data_queue.put(data)
140
144
  except Exception as e:
141
- print(f"Error receiving data: {e}")
145
+ logger.error(f"Error receiving data: {e}")
142
146
  break
143
147
 
144
148
  self.receiver_thread = threading.Thread(target=_receiver, daemon=True)
@@ -170,7 +174,7 @@ class ArborScriptCommsHandler:
170
174
  return f"{socket.gethostname()}_{os.getpid()}"
171
175
 
172
176
  def _send_handshake(self):
173
- print(f"Sending handshake to {self.handshake_socket}")
177
+ logger.debug(f"Sending handshake to {self.handshake_socket}")
174
178
  self.handshake_socket.send_json(
175
179
  {"type": "hello", "client_id": self._get_client_id()}
176
180
  )
@@ -187,12 +191,12 @@ if __name__ == "__main__":
187
191
 
188
192
  def _client_thread(script_comms):
189
193
  for data in script_comms.receive_data():
190
- print("Client received data:", data)
194
+ logger.info("Client received data:", data)
191
195
 
192
196
  server_comms = ArborServerCommsHandler()
193
197
  t1 = threading.Thread(target=_server_thread, args=(server_comms,))
194
198
  t1.start()
195
- print("Server started")
199
+ logger.info("Server started")
196
200
 
197
201
  client_threads = []
198
202
  script_comms_list = []
@@ -222,9 +226,9 @@ if __name__ == "__main__":
222
226
  for t in client_threads:
223
227
  t.join()
224
228
  except KeyboardInterrupt:
225
- print("Keyboard interrupt")
229
+ logger.info("Keyboard interrupt")
226
230
  except Exception as e:
227
- print(f"Error: {e}")
231
+ logger.error(f"Error: {e}")
228
232
  finally:
229
233
  for script_comms in script_comms_list:
230
234
  script_comms.close()
@@ -7,7 +7,10 @@ from pathlib import Path
7
7
 
8
8
  from fastapi import UploadFile
9
9
 
10
- from arbor.server.core.config import Settings
10
+ from arbor.server.core.config import Config
11
+ from arbor.server.utils.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
11
14
 
12
15
 
13
16
  class FileValidationError(Exception):
@@ -17,8 +20,8 @@ class FileValidationError(Exception):
17
20
 
18
21
 
19
22
  class FileManager:
20
- def __init__(self, settings: Settings):
21
- self.uploads_dir = Path(settings.STORAGE_PATH) / "uploads"
23
+ def __init__(self, config: Config):
24
+ self.uploads_dir = Path(config.STORAGE_PATH) / "uploads"
22
25
  self.uploads_dir.mkdir(parents=True, exist_ok=True)
23
26
  self.files = self.load_files_from_uploads()
24
27
 
@@ -284,6 +287,6 @@ class FileManager:
284
287
  }
285
288
  fout.write(json.dumps(new_line) + "\n")
286
289
  except Exception as e:
287
- print(f"Error parsing line {line_num}: {e}")
290
+ logger.error(f"Error parsing line {line_num}: {e}")
288
291
 
289
292
  os.replace(output_path, file_path)