arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/__init__.py CHANGED
@@ -0,0 +1,17 @@
1
+ """
2
+ Arbor - A framework for fine-tuning and managing language models
3
+ """
4
+
5
+ from importlib.metadata import PackageNotFoundError, version
6
+
7
+ try:
8
+ __version__ = version("arbor-ai")
9
+ except PackageNotFoundError:
10
+ # Package is not installed, likely in development mode
11
+ __version__ = "dev"
12
+ except Exception:
13
+ __version__ = "unknown"
14
+
15
+ from arbor.client.arbor_client import is_running, serve, stop
16
+
17
+ __all__ = ["__version__", "serve", "stop", "is_running"]
arbor/cli.py CHANGED
@@ -4,58 +4,22 @@ from datetime import datetime
4
4
  import click
5
5
  import uvicorn
6
6
 
7
- from arbor.server.core.config import Settings
7
+ from arbor.server.core.config import Config
8
+ from arbor.server.core.config_manager import ConfigManager
8
9
  from arbor.server.main import app
9
10
  from arbor.server.services.file_manager import FileManager
10
11
  from arbor.server.services.grpo_manager import GRPOManager
12
+ from arbor.server.services.health_manager import HealthManager
11
13
  from arbor.server.services.inference_manager import InferenceManager
12
14
  from arbor.server.services.job_manager import JobManager
13
15
  from arbor.server.services.training_manager import TrainingManager
14
-
15
-
16
- def make_log_dir(storage_path: str):
17
- # Create a timestamped log directory under the storage path
18
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
19
- log_dir = os.path.join(storage_path, "logs", timestamp)
20
- os.makedirs(log_dir, exist_ok=True)
21
- return log_dir
16
+ from arbor.client.arbor_client import create_app
22
17
 
23
18
 
24
19
  @click.group()
25
20
  def cli():
26
21
  pass
27
22
 
28
-
29
- def create_app(arbor_config_path: str):
30
- """Create and configure the Arbor API application
31
-
32
- Args:
33
- storage_path (str): Path to store models and uploaded training files
34
-
35
- Returns:
36
- FastAPI: Configured FastAPI application
37
- """
38
- # Create new settings instance with overrides
39
- settings = Settings.load_from_yaml(arbor_config_path)
40
- app.state.log_dir = make_log_dir(settings.STORAGE_PATH)
41
-
42
- # Initialize services with settings
43
- file_manager = FileManager(settings=settings)
44
- job_manager = JobManager(settings=settings)
45
- training_manager = TrainingManager(settings=settings)
46
- inference_manager = InferenceManager(settings=settings)
47
- grpo_manager = GRPOManager(settings=settings)
48
- # Inject settings into app state
49
- app.state.settings = settings
50
- app.state.file_manager = file_manager
51
- app.state.job_manager = job_manager
52
- app.state.training_manager = training_manager
53
- app.state.inference_manager = inference_manager
54
- app.state.grpo_manager = grpo_manager
55
-
56
- return app
57
-
58
-
59
23
  def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
60
24
  """Start the Arbor API server with a single function call"""
61
25
  import socket
@@ -72,6 +36,7 @@ def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10
72
36
  raise RuntimeError(f"Port {port} is already in use")
73
37
 
74
38
  app = create_app(storage_path)
39
+ # configure_uvicorn_logging()
75
40
  config = uvicorn.Config(app, host=host, port=port, log_level="info")
76
41
  server = uvicorn.Server(config)
77
42
 
@@ -102,11 +67,86 @@ def stop_server(server):
102
67
  @cli.command()
103
68
  @click.option("--host", default="0.0.0.0", help="Host to bind to")
104
69
  @click.option("--port", default=7453, help="Port to bind to")
105
- @click.option("--arbor-config", required=True, help="Path to the Arbor config file")
70
+ @click.option("--arbor-config", required=False, help="Path to the Arbor config file")
106
71
  def serve(host, port, arbor_config):
107
72
  """Start the Arbor API server"""
108
- app = create_app(arbor_config)
109
- uvicorn.run(app, host=host, port=port)
73
+
74
+ if arbor_config:
75
+ config_path = arbor_config
76
+ else:
77
+ config_path = Config.use_default_config()
78
+
79
+ # If no config found, run first-time setup
80
+ if config_path is None:
81
+ config_path = run_first_time_setup()
82
+
83
+ # Validate config exists and is readable
84
+ is_valid, msg = ConfigManager.validate_config_file(config_path)
85
+
86
+ if not is_valid:
87
+ click.echo(msg)
88
+ raise click.Abort()
89
+
90
+ try:
91
+ create_app(config_path)
92
+ # Temporarily disable custom uvicorn logging configuration
93
+ # configure_uvicorn_logging()
94
+ uvicorn.run(app, host=host, port=port)
95
+ except Exception as e:
96
+ click.echo(f"Failed to start server: {e}", err=True)
97
+ raise click.Abort()
98
+
99
+
100
def run_first_time_setup() -> str:
    """Interactively create an Arbor config file and return its path.

    Prompts for the GPU ids to use for inference and for training, and for
    where to save the config file. The answers are written through
    ``ConfigManager.update_config``, the resulting file is validated, and its
    contents are echoed back to the user.

    Returns:
        The config path returned by ``ConfigManager.update_config``.

    Raises:
        click.Abort: If the user cancels a prompt, or if any setup step fails
            (failures are reported on stderr before aborting).
    """
    click.echo("Welcome to Arbor!")
    click.echo("It looks like this is your first time running Arbor.")
    click.echo("Let's set up your configuration...\n")

    try:
        # get_logger is not among cli.py's visible imports, so referencing the
        # bare name raised NameError and aborted every first-time setup.
        # Import it locally from the logging utilities module the inference
        # routes also use.
        from arbor.server.utils.logging import get_logger

        # Get config details
        inference = click.prompt(
            "Which gpu ids should be used for inference (separated by comma)",
            default="0",
        )
        training = click.prompt(
            "Which gpu ids should be used for training (separated by comma)",
            default="1, 2",
        )
        click.echo()

        # Get config file path
        config_path = click.prompt(
            "Enter path to save config file in. We recommend (~/.arbor/config.yaml)",
            default=ConfigManager.get_default_config_path(),
        )
        logger = get_logger(__name__)
        logger.info(f"Config path selected: {config_path}")
        click.echo()

        # Update or create the config; the returned path is the one that is
        # validated and displayed below.
        config_path = ConfigManager.update_config(inference, training, config_path)
        click.echo(f"Created configuration at: {config_path}")

        # Check if it is a valid config file
        is_valid, msg = ConfigManager.validate_config_file(config_path)
        if not is_valid:
            raise click.ClickException(f"Invalid config file: {msg}")

        # Read and display the contents
        _, content = ConfigManager.get_config_contents(config_path)

        click.echo("\nConfiguration file contents:")
        click.echo("---")
        click.echo(content)
        click.echo("---")

        click.echo("\nSetup complete! Starting Arbor server...")
        return config_path

    except click.Abort:
        # The user cancelled a prompt; propagate without reporting a failure.
        raise
    except Exception as e:
        click.echo(f"Failed initial setup of Arbor: {e}", err=True)
        raise click.Abort() from e
110
150
 
111
151
 
112
152
  if __name__ == "__main__":
arbor/client/arbor_client.py CHANGED
@@ -0,0 +1,259 @@
1
+ import asyncio
2
+ import os
3
+ import socket
4
+ import threading
5
+ import time
6
+ from datetime import datetime
7
+
8
+ import click
9
+ import requests
10
+ import uvicorn
11
+
12
+ from arbor.server.core.config import Config
13
+ from arbor.server.core.config_manager import ConfigManager
14
+ from arbor.server.main import app
15
+ from arbor.server.services.file_manager import FileManager
16
+ from arbor.server.services.grpo_manager import GRPOManager
17
+ from arbor.server.services.health_manager import HealthManager
18
+ from arbor.server.services.inference_manager import InferenceManager
19
+ from arbor.server.services.job_manager import JobManager
20
+ from arbor.server.services.training_manager import TrainingManager
21
+
22
+ # Global server state
23
+ _server = None
24
+ _server_thread = None
25
+ _server_loop = None
26
+ _server_host = None
27
+ _server_port = None
28
+
29
+
30
def create_app(
    config_path: str | None = None,
    storage_path: str | None = None,
    inference_gpus: str | None = None,
    training_gpus: str | None = None,
):
    """Create and configure the Arbor API application.

    Builds a ``Config`` and creates its log directory with
    ``Config.make_log_dir``. It then constructs the service managers and
    stores the config and managers on the shared ``app.state``. The same
    module-level ``app`` object is reused, so a later call replaces the state
    set by an earlier one.

    Args:
        config_path: Path to a YAML config file. When given, it takes
            precedence and the remaining arguments are ignored.
        storage_path: Storage directory. Only used when ``config_path`` is
            not given.
        inference_gpus: GPU ids to use for inference, e.g. "0". Required,
            together with ``training_gpus``, when ``config_path`` is not given.
        training_gpus: GPU ids to use for training, e.g. "1,2". Required,
            together with ``inference_gpus``, when ``config_path`` is not given.

    Returns:
        FastAPI: The shared ``app`` from ``arbor.server.main``, configured in
        place.

    Raises:
        ValueError: If neither ``config_path`` nor both GPU id strings are
            provided.
    """
    # Create new config instance with overrides
    if config_path:
        config = Config.load_config_from_yaml(config_path)
    elif inference_gpus and training_gpus:
        config = Config.load_config_directly(
            storage_path, inference_gpus, training_gpus
        )
    else:
        raise ValueError(
            "Either 'config_path' must be provided, or 'inference_gpus', and 'training_gpus' must be provided"
        )

    app.state.log_dir = Config.make_log_dir(config.STORAGE_PATH)

    # Initialize services with config
    health_manager = HealthManager(config=config)
    file_manager = FileManager(config=config)
    job_manager = JobManager(config=config)
    training_manager = TrainingManager(config=config)
    inference_manager = InferenceManager(config=config)
    grpo_manager = GRPOManager(config=config)

    # Inject config and services into app state so route handlers can reach
    # them through request.app.state.
    app.state.config = config
    app.state.file_manager = file_manager
    app.state.job_manager = job_manager
    app.state.training_manager = training_manager
    app.state.inference_manager = inference_manager
    app.state.grpo_manager = grpo_manager
    app.state.health_manager = health_manager

    return app
79
+
80
+
81
def serve(
    config_path: str | None = None,
    storage_path: str | None = None,
    inference_gpus: str | None = None,
    training_gpus: str | None = None,
    host: str = "0.0.0.0",
    port: int = 7453,
):
    """Start the Arbor API server in a background thread.

    Returns once the server answers its health check. Any server previously
    started by this module is stopped first. Use ``arbor.stop()`` to shut the
    server down.

    Args:
        config_path: Path to a YAML config file. When given, it takes
            precedence over the storage/GPU arguments.
        storage_path: Storage directory, used together with the GPU ids when
            no ``config_path`` is given.
        inference_gpus: GPU ids for inference, e.g. "0,1". Required, together
            with ``training_gpus``, when no ``config_path`` is given. No
            default is applied.
        training_gpus: GPU ids for training, e.g. "1,2,3". Required, together
            with ``inference_gpus``, when no ``config_path`` is given. No
            default is applied.
        host: Host to bind to (default: "0.0.0.0").
        port: Port to bind to (default: 7453).

    Raises:
        ValueError: From ``create_app`` when neither a config path nor both
            GPU id strings are supplied.
        TimeoutError: If the server is not ready within 60 seconds. The
            server is stopped before the error is re-raised.

    Example:
        import arbor
        arbor.serve(inference_gpus="0", training_gpus="1,2")
        # Server is now ready to accept requests
        # Later, to stop:
        arbor.stop()
    """
    global _server, _server_thread, _server_host, _server_port

    # Stop existing server if running
    if _server is not None:
        print("🌳 Stopping existing server...")
        stop()

    _server_host = host
    _server_port = port

    create_app(config_path, storage_path, inference_gpus, training_gpus)

    # Build the uvicorn server before starting the thread, so the module-level
    # handle is published immediately and stop() can always reach it. Note
    # that uvicorn's ``loop`` option names a loop implementation ("auto",
    # "asyncio", ...), not an event-loop object. It is therefore left at its
    # default, and the thread below drives its own loop.
    server = uvicorn.Server(uvicorn.Config(app, host=host, port=port))
    _server = server

    def run_server():
        global _server, _server_loop

        # Each server thread runs its own event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        _server_loop = loop

        try:
            loop.run_until_complete(server.serve())
        except Exception as e:
            print(f"Server error: {e}")
        finally:
            loop.close()
            # Clear the globals only if they still describe this server. A
            # later serve() call may already have replaced them.
            if _server is server:
                _server = None
                _server_loop = None

    # Daemon thread: it does not keep the interpreter alive on its own.
    _server_thread = threading.Thread(target=run_server, daemon=True)
    _server_thread.start()

    print(f"🌳 Starting Arbor server on http://{host}:{port}...")

    # Wait for server to be ready
    try:
        _wait_for_server_ready(host, port, timeout=60)
        print("🌳 Arbor server is ready and accepting requests!")
    except TimeoutError as e:
        print(f"❌ {e}")
        # Try to stop the server if it failed to start properly
        stop()
        raise
160
+
161
+
162
def stop():
    """Stop the Arbor server started by ``serve()``, if one is running.

    Signals the uvicorn server to exit and waits up to 5 seconds for its
    thread to finish. The module-level server state is then cleared.
    Prints a notice and returns if no server is running.
    """
    global _server, _server_thread, _server_loop

    if _server is None:
        print("🌳 No server running to stop.")
        return

    print("🌳 Stopping Arbor server...")

    # uvicorn's serve() keeps running its main loop until ``should_exit`` is
    # set, and then performs its own graceful shutdown. Scheduling
    # ``Server.shutdown()`` from outside closed the sockets but left serve()
    # looping, so the thread outlived the join below. Setting the flag is the
    # mechanism uvicorn's own signal handling uses.
    _server.should_exit = True

    # Wait for the thread to finish. Joining the current thread would raise,
    # so skip the join when stop() is called from the server thread itself.
    thread = _server_thread
    if (
        thread is not None
        and thread.is_alive()
        and thread is not threading.current_thread()
    ):
        thread.join(timeout=5)
        if thread.is_alive():
            print("🌳 Warning: server thread did not exit within 5 seconds.")

    # Reset global state
    _server = None
    _server_thread = None
    _server_loop = None

    print("🌳 Arbor server stopped.")
189
+
190
+
191
def is_running():
    """Report whether the background Arbor server is alive.

    Returns:
        True only when a server handle exists and its thread is still
        alive; False otherwise.
    """
    if _server is None or _server_thread is None:
        return False
    return _server_thread.is_alive()
196
+
197
+
198
def _probe_host(host):
    """Return an address a local client can connect to for a server bound to *host*.

    Wildcard bind addresses ("0.0.0.0", "::", "") are not connectable on
    every platform, so they are mapped to the matching loopback address.
    Any other host is returned unchanged.
    """
    wildcard_to_loopback = {"": "127.0.0.1", "0.0.0.0": "127.0.0.1", "::": "::1"}
    return wildcard_to_loopback.get(host, host)


def _wait_for_server_ready(host, port, timeout=30):
    """Block until the server at *host*:*port* answers its health check.

    First polls until the TCP port accepts connections, then polls
    ``/health/simple`` until it returns HTTP 200. There is a 0.5 s pause
    between attempts. Wildcard bind addresses are probed via loopback.

    Args:
        host: Host the server was bound to.
        port: Port the server was bound to.
        timeout: Seconds to keep polling before giving up.

    Raises:
        TimeoutError: If the server is not ready within *timeout* seconds.
            The message includes the last error observed.
    """
    start_time = time.time()
    last_error = None
    port_open = False
    next_report = 5.0

    probe_host = _probe_host(host)
    # IPv6 literals must be bracketed inside a URL.
    url_host = f"[{probe_host}]" if ":" in probe_host else probe_host
    health_url = f"http://{url_host}:{port}/health/simple"

    print(f"🌳 Waiting for server to be ready at http://{host}:{port}...")

    while time.time() - start_time < timeout:
        # First check if the port is open
        if not port_open:
            try:
                # create_connection supports IPv4 and IPv6. The with-block
                # closes the socket on every path.
                with socket.create_connection((probe_host, port), timeout=1):
                    pass
            except ConnectionRefusedError:
                last_error = f"Port {port} not yet open"
                time.sleep(0.5)
                continue
            except OSError as e:
                last_error = f"Socket error: {e}"
                time.sleep(0.5)
                continue
            port_open = True
            print(f"🌳 Port {port} is now open, checking health endpoint...")

        # Now try the health check
        try:
            response = requests.get(health_url, timeout=2)
            if response.status_code == 200:
                print(f"🌳 Server ready! Response: {response.json()}")
                return
            last_error = f"Health check returned status {response.status_code}"
        except requests.exceptions.ConnectionError as e:
            last_error = f"Connection error: {e}"
            port_open = False  # Port might have closed
        except requests.exceptions.Timeout as e:
            last_error = f"Timeout error: {e}"
        except requests.exceptions.RequestException as e:
            last_error = f"Request error: {e}"
        except Exception as e:
            last_error = f"Unexpected error: {e}"

        # Report progress once per 5-second mark. The original check
        # int(elapsed) % 5 could fire on several consecutive iterations.
        elapsed = time.time() - start_time
        if elapsed >= next_report:
            print(
                f"🌳 Still waiting... ({elapsed:.1f}s elapsed, port_open={port_open}, last error: {last_error})"
            )
            next_report = (int(elapsed // 5) + 1) * 5

        time.sleep(0.5)

    raise TimeoutError(
        f"Server did not become ready within {timeout} seconds. Last error: {last_error}"
    )
256
+
257
+
258
if __name__ == "__main__":
    serve(inference_gpus="0, 1", training_gpus="2, 3")
    # serve() returns as soon as the server is ready, and the server runs in
    # a daemon thread that dies when the interpreter exits. Keep the process
    # alive until the server stops or the user presses Ctrl-C.
    try:
        while is_running():
            time.sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        stop()
arbor/server/api/models/schemas.py CHANGED
@@ -178,7 +178,7 @@ class ChatCompletionModel(BaseModel):
178
178
 
179
179
  class GRPORequest(BaseModel):
180
180
  model: str
181
- batch: List[dict]
181
+ batch: List[dict] | List[List[dict]]
182
182
 
183
183
 
184
184
  class GRPOConfigRequest(BaseModel):
@@ -205,6 +205,8 @@ class GRPOConfigRequest(BaseModel):
205
205
  # Arbor specific
206
206
  max_context_length: Optional[int] = None
207
207
  lora: Optional[bool] = None
208
+ grpo_flavor: Optional[Literal["grpo", "mmgrpo"]] = None
209
+ wandb_kwargs: Optional[dict] = None
208
210
  # To name the run
209
211
  suffix: Optional[str] = None
210
212
  generation_batch_size: Optional[int] = None
arbor/server/api/routes/grpo.py CHANGED
@@ -10,7 +10,6 @@ from arbor.server.api.models.schemas import (
10
10
  GRPOConfigResponse,
11
11
  GRPORequest,
12
12
  GRPOStepResponse,
13
- GRPOTerminateRequest,
14
13
  GRPOTerminateResponse,
15
14
  )
16
15
 
@@ -27,12 +26,9 @@ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
27
26
 
28
27
  # Create a grpo job
29
28
  @router.post("/step", response_model=GRPOStepResponse)
30
- def run_grpo_step(
31
- request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
32
- ):
33
- inference_manager = request.app.state.inference_manager
29
+ def run_grpo_step(request: Request, grpo_request: GRPORequest):
34
30
  grpo_manager = request.app.state.grpo_manager
35
-
31
+ inference_manager = request.app.state.inference_manager
36
32
  step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
37
33
 
38
34
  return GRPOStepResponse(status="success", **step_data)
arbor/server/api/routes/inference.py CHANGED
@@ -3,6 +3,10 @@ import uuid
3
3
 
4
4
  from fastapi import APIRouter, Request
5
5
 
6
+ from arbor.server.utils.logging import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
6
10
  router = APIRouter()
7
11
 
8
12
 
@@ -27,17 +31,17 @@ async def run_inference(
27
31
 
28
32
  # if a server isnt running, launch one
29
33
  if not inference_manager.is_server_running():
30
- print("No model is running, launching model...")
34
+ logger.info("No model is running, launching model...")
31
35
  inference_manager.launch(request_model)
32
36
 
33
37
  # if the requested model is different from the launched model, swap the server
34
38
  if request_model != inference_manager.launched_model:
35
- print(
39
+ logger.info(
36
40
  f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
37
41
  )
38
42
  inference_manager.kill()
39
43
  inference_manager.launch(request_model)
40
- print(f"Model swapped to {request_model}")
44
+ logger.info(f"Model swapped to {request_model}")
41
45
 
42
46
  # forward the request to the inference server
43
47
  completion = await inference_manager.run_inference(raw_json)