arbor-ai 0.2.tar.gz → 0.2.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arbor_ai-0.2 → arbor_ai-0.2.2}/PKG-INFO +18 -18
- {arbor_ai-0.2 → arbor_ai-0.2.2}/README.md +16 -16
- arbor_ai-0.2.2/arbor/__init__.py +17 -0
- arbor_ai-0.2.2/arbor/cli.py +153 -0
- arbor_ai-0.2.2/arbor/client/arbor_client.py +259 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/models/schemas.py +3 -1
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/routes/grpo.py +2 -6
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/routes/inference.py +7 -3
- arbor_ai-0.2.2/arbor/server/core/config.py +333 -0
- arbor_ai-0.2.2/arbor/server/core/config_manager.py +100 -0
- arbor_ai-0.2.2/arbor/server/main.py +36 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/comms/comms.py +13 -9
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/file_manager.py +7 -4
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/grpo_manager.py +101 -63
- arbor_ai-0.2.2/arbor/server/services/health_manager.py +171 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/inference/vllm_client.py +8 -5
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/inference_manager.py +40 -38
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/job_manager.py +2 -2
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/scripts/grpo_training.py +62 -281
- arbor_ai-0.2.2/arbor/server/services/scripts/mmgrpo_training.py +510 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/scripts/sft_training.py +8 -5
- arbor_ai-0.2.2/arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor_ai-0.2.2/arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor_ai-0.2.2/arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor_ai-0.2.2/arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor_ai-0.2.2/arbor/server/services/scripts/utils/mock_server.py +124 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/training_manager.py +4 -4
- arbor_ai-0.2.2/arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/PKG-INFO +18 -18
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/SOURCES.txt +9 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/requires.txt +1 -1
- {arbor_ai-0.2 → arbor_ai-0.2.2}/pyproject.toml +3 -3
- arbor_ai-0.2/arbor/cli.py +0 -113
- arbor_ai-0.2/arbor/server/core/config.py +0 -47
- arbor_ai-0.2/arbor/server/main.py +0 -11
- arbor_ai-0.2/arbor/server/services/scripts/utils/dataset.py +0 -0
- arbor_ai-0.2/arbor/server/utils/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/LICENSE +0 -0
- {arbor_ai-0.2/arbor → arbor_ai-0.2.2/arbor/client}/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/client/api.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/__init__.py +0 -0
- {arbor_ai-0.2/arbor/client → arbor_ai-0.2.2/arbor/server/api/routes}/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/routes/files.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/api/routes/jobs.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/core/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/core/logging.py +0 -0
- {arbor_ai-0.2/arbor/server/api/routes → arbor_ai-0.2.2/arbor/server/services}/__init__.py +0 -0
- {arbor_ai-0.2/arbor/server/services → arbor_ai-0.2.2/arbor/server/services/comms}/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/dependencies.py +0 -0
- {arbor_ai-0.2/arbor/server/services/comms → arbor_ai-0.2.2/arbor/server/services/inference}/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/inference/vllm_serve.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/scripts/dpo_training.py +0 -0
- {arbor_ai-0.2/arbor/server/services/inference → arbor_ai-0.2.2/arbor/server/services/scripts/utils}/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/services/scripts/utils/arg_parser.py +0 -0
- {arbor_ai-0.2/arbor/server/services/scripts → arbor_ai-0.2.2/arbor/server}/utils/__init__.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/dependency_links.txt +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/entry_points.txt +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/arbor_ai.egg-info/top_level.txt +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/setup.cfg +0 -0
- {arbor_ai-0.2 → arbor_ai-0.2.2}/tests/test_cli.py +0 -0
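Highlights of 0.2.2, as the hunks below show: the `vllm` dependency is pinned, the manual `arbor.yaml` setup step is replaced by an interactive first-run configuration saved to `~/.arbor/config.yaml`, a programmatic client API (`arbor.serve` / `arbor.stop` / `arbor.is_running`) is added, and new health-check, logging, and `mmgrpo` training modules land on the server side. The hunks follow the order of the file list above.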
````diff
--- arbor_ai-0.2/PKG-INFO
+++ arbor_ai-0.2.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.2
+Version: 0.2.2
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -14,7 +14,7 @@ Requires-Dist: uvicorn
 Requires-Dist: click
 Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
-Requires-Dist: vllm
+Requires-Dist: vllm==0.8.5.post1
 Requires-Dist: transformers
 Requires-Dist: trl>=0.17.0
 Requires-Dist: peft
@@ -43,7 +43,8 @@ Install Arbor via pip:
 pip install -U arbor-ai
 ```
 
-Optionally, you can also install
+Optionally, you can also install flash attention to speed up inference. <br/>
+This can take 15+ minutes to install on some setups:
 ```bash
 pip install flash-attn --no-build-isolation
 ```
@@ -52,33 +53,32 @@ pip install flash-attn --no-build-isolation
 
 ## ⚡ Quick Start
 
-### 1️⃣
-
-This is all dependent on your setup. Here is an example of one:
-```yaml
-inference:
-  gpu_ids: '0'
-
-training:
-  gpu_ids: '1, 2'
-```
-Which will use the `GPU:0` for inference with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
-
-### 2️⃣ Start the Server
+### 1️⃣ Start the Server
 
 **CLI:**
 
 ```bash
-python -m arbor.cli serve
+python -m arbor.cli serve
 ```
+On the first run you'll be asked which GPUs will be used for training and which for inference. For more that one GPU, separate the ids by comma: `1, 2`. Your config file will be saved in `~/.arbor/config.yaml` should you want to edit these configs in the future.
 
-###
+### 2️⃣ Optimize a DSPy Program
 
 Follow the DSPy tutorials here to see usage examples:
 [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
 
 ---
 
+### Troubleshooting
+
+**NCCL Errors**
+Certain GPU setups, particularly with newer GPUs, seem to have issues with NCCL that cause Arbor to crash. Often times of these can be fixed with the following environment variables:
+
+```bash
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+```
+
 ## 🙏 Acknowledgements
 
 Arbor builds on the shoulders of great work. We extend our thanks to:
````
````diff
--- arbor_ai-0.2/README.md
+++ arbor_ai-0.2.2/README.md
@@ -16,7 +16,8 @@ Install Arbor via pip:
 pip install -U arbor-ai
 ```
 
-Optionally, you can also install
+Optionally, you can also install flash attention to speed up inference. <br/>
+This can take 15+ minutes to install on some setups:
 ```bash
 pip install flash-attn --no-build-isolation
 ```
@@ -25,33 +26,32 @@ pip install flash-attn --no-build-isolation
 
 ## ⚡ Quick Start
 
-### 1️⃣
-
-This is all dependent on your setup. Here is an example of one:
-```yaml
-inference:
-  gpu_ids: '0'
-
-training:
-  gpu_ids: '1, 2'
-```
-Which will use the `GPU:0` for inference with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
-
-### 2️⃣ Start the Server
+### 1️⃣ Start the Server
 
 **CLI:**
 
 ```bash
-python -m arbor.cli serve
+python -m arbor.cli serve
 ```
+On the first run you'll be asked which GPUs will be used for training and which for inference. For more that one GPU, separate the ids by comma: `1, 2`. Your config file will be saved in `~/.arbor/config.yaml` should you want to edit these configs in the future.
 
-###
+### 2️⃣ Optimize a DSPy Program
 
 Follow the DSPy tutorials here to see usage examples:
 [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
 
 ---
 
+### Troubleshooting
+
+**NCCL Errors**
+Certain GPU setups, particularly with newer GPUs, seem to have issues with NCCL that cause Arbor to crash. Often times of these can be fixed with the following environment variables:
+
+```bash
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+```
+
 ## 🙏 Acknowledgements
 
 Arbor builds on the shoulders of great work. We extend our thanks to:
````
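For anyone embedding Arbor in a Python process rather than a shell, a rough equivalent of the README's NCCL workaround is to set the variables via `os.environ` before NCCL initializes; `arbor.serve` here is the new in-process entry point exported by `arbor/__init__.py` below (a sketch, not from the package):

```python
import os

# Mirror the README's NCCL workaround; must be set before CUDA/NCCL start.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

import arbor

arbor.serve(inference_gpus="0", training_gpus="1,2")
```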
```diff
--- /dev/null
+++ arbor_ai-0.2.2/arbor/__init__.py
@@ -0,0 +1,17 @@
+"""
+Arbor - A framework for fine-tuning and managing language models
+"""
+
+from importlib.metadata import PackageNotFoundError, version
+
+try:
+    __version__ = version("arbor-ai")
+except PackageNotFoundError:
+    # Package is not installed, likely in development mode
+    __version__ = "dev"
+except Exception:
+    __version__ = "unknown"
+
+from arbor.client.arbor_client import is_running, serve, stop
+
+__all__ = ["__version__", "serve", "stop", "is_running"]
```
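Taken together with the `serve` docstring in `arbor_client.py` further down, the intended lifecycle of this new top-level API looks roughly like the following (a sketch based on the exported names; the default port 7453 comes from `serve`):

```python
import arbor

print(arbor.__version__)  # "0.2.2", or "dev" from a source checkout

arbor.serve(inference_gpus="0", training_gpus="1,2")  # returns once ready
assert arbor.is_running()

# ... send OpenAI-compatible requests to http://localhost:7453 ...

arbor.stop()  # shuts down the background server thread
assert not arbor.is_running()
```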
```diff
--- /dev/null
+++ arbor_ai-0.2.2/arbor/cli.py
@@ -0,0 +1,153 @@
+import os
+from datetime import datetime
+
+import click
+import uvicorn
+
+from arbor.server.core.config import Config
+from arbor.server.core.config_manager import ConfigManager
+from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.health_manager import HealthManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+from arbor.client.arbor_client import create_app
+
+
+@click.group()
+def cli():
+    pass
+
+def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
+    """Start the Arbor API server with a single function call"""
+    import socket
+    import threading
+    import time
+    from contextlib import closing
+
+    def is_port_in_use(port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            return sock.connect_ex(("localhost", port)) == 0
+
+    # First ensure the port is free
+    if is_port_in_use(port):
+        raise RuntimeError(f"Port {port} is already in use")
+
+    app = create_app(storage_path)
+    # configure_uvicorn_logging()
+    config = uvicorn.Config(app, host=host, port=port, log_level="info")
+    server = uvicorn.Server(config)
+
+    def run_server():
+        server.run()
+
+    thread = threading.Thread(target=run_server, daemon=True)
+    thread.start()
+
+    # Wait for server to start
+    start_time = time.time()
+    while not is_port_in_use(port):
+        if time.time() - start_time > timeout:
+            raise TimeoutError(f"Server failed to start within {timeout} seconds")
+        time.sleep(0.1)
+
+    # Give it a little extra time to fully initialize
+    time.sleep(0.5)
+
+    return server
+
+
+def stop_server(server):
+    """Stop the Arbor API server"""
+    server.should_exit = True
+
+
+@cli.command()
+@click.option("--host", default="0.0.0.0", help="Host to bind to")
+@click.option("--port", default=7453, help="Port to bind to")
+@click.option("--arbor-config", required=False, help="Path to the Arbor config file")
+def serve(host, port, arbor_config):
+    """Start the Arbor API server"""
+
+    if arbor_config:
+        config_path = arbor_config
+    else:
+        config_path = Config.use_default_config()
+
+    # If no config found, run first-time setup
+    if config_path is None:
+        config_path = run_first_time_setup()
+
+    # Validate config exists and is readable
+    is_valid, msg = ConfigManager.validate_config_file(config_path)
+
+    if not is_valid:
+        click.echo(msg)
+        raise click.Abort()
+
+    try:
+        create_app(config_path)
+        # Temporarily disable custom uvicorn logging configuration
+        # configure_uvicorn_logging()
+        uvicorn.run(app, host=host, port=port)
+    except Exception as e:
+        click.echo(f"Failed to start server: {e}", err=True)
+        raise click.Abort()
+
+
+def run_first_time_setup() -> str:
+    """Run first-time setup and return created config path"""
+    click.echo("Welcome to Arbor!")
+    click.echo("It looks like this is your first time running Arbor.")
+    click.echo("Let's set up your configuration...\n")
+
+    try:
+        # Get config details
+        inference = click.prompt(
+            "Which gpu ids should be used for inference (separated by comma)",
+            default="0",
+        )
+        training = click.prompt(
+            "Which gpu ids should be used for training (separated by comma)",
+            default="1, 2",
+        )
+        click.echo()
+
+        # Get config file path
+        config_path = click.prompt(
+            "Enter path to save config file in. We recommend (~/.arbor/config.yaml)",
+            default=ConfigManager.get_default_config_path(),
+        )
+        logger = get_logger(__name__)
+        logger.info(f"Config path selected: {config_path}")
+        click.echo()
+
+        # Update or create config at path
+        config_path = ConfigManager.update_config(inference, training, config_path)
+        click.echo(f"Created configuration at: {config_path}")
+
+        # Check if it is a valid config file
+        is_valid, msg = ConfigManager.validate_config_file(config_path)
+        if not is_valid:
+            raise click.ClickException(f"Invalid config file: {msg}")
+
+        # Read and display the contents
+        _, content = ConfigManager.get_config_contents(config_path)
+
+        click.echo("\nConfiguration file contents:")
+        click.echo("---")
+        click.echo(content)
+        click.echo("---")
+
+        click.echo("\nSetup complete! Starting Arbor server...")
+        return config_path
+
+    except Exception as e:
+        click.echo(f"Failed initial setup of Arbor: {e}", err=True)
+        raise click.Abort()
+
+
+if __name__ == "__main__":
+    cli()
```
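The readiness check inside `start_server` is a small reusable pattern: `connect_ex` returns an errno instead of raising, and 0 means something is already listening on the port. A self-contained sketch:

```python
import socket
from contextlib import closing

def is_port_in_use(port: int, host: str = "localhost") -> bool:
    # connect_ex returns 0 on a successful connection, i.e. a server
    # is already bound and accepting on (host, port).
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.settimeout(1)
        return sock.connect_ex((host, port)) == 0

print(is_port_in_use(7453))  # Arbor's default port
```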
```diff
--- /dev/null
+++ arbor_ai-0.2.2/arbor/client/arbor_client.py
@@ -0,0 +1,259 @@
+import asyncio
+import os
+import socket
+import threading
+import time
+from datetime import datetime
+
+import click
+import requests
+import uvicorn
+
+from arbor.server.core.config import Config
+from arbor.server.core.config_manager import ConfigManager
+from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.health_manager import HealthManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+
+# Global server state
+_server = None
+_server_thread = None
+_server_loop = None
+_server_host = None
+_server_port = None
+
+
+def create_app(
+    config_path: str = None,
+    storage_path: str = None,
+    inference_gpus: str = None,
+    training_gpus: str = None,
+):
+    """Create and configure the Arbor API application
+
+    Args:
+        arbor_config_path (str): Path to config file
+        storage_path (str): Path to storage directory
+        inference_gpus (str): gpu ids to use for inference
+        training_gpus (str): gpu ids to use for training
+
+    Returns:
+        FastAPI: Configured FastAPI application
+    """
+    # Create new config instance with overrides
+    if config_path:
+        config = Config.load_config_from_yaml(config_path)
+    elif inference_gpus and training_gpus:
+        config = Config.load_config_directly(
+            storage_path, inference_gpus, training_gpus
+        )
+    else:
+        raise ValueError(
+            "Either 'config_path' must be provided, or 'inference_gpus', and 'training_gpus' must be provided"
+        )
+
+    app.state.log_dir = Config.make_log_dir(config.STORAGE_PATH)
+
+    # Initialize services with config
+    health_manager = HealthManager(config=config)
+    file_manager = FileManager(config=config)
+    job_manager = JobManager(config=config)
+    training_manager = TrainingManager(config=config)
+    inference_manager = InferenceManager(config=config)
+    grpo_manager = GRPOManager(config=config)
+
+    # Inject config into app state
+    app.state.config = config
+    app.state.file_manager = file_manager
+    app.state.job_manager = job_manager
+    app.state.training_manager = training_manager
+    app.state.inference_manager = inference_manager
+    app.state.grpo_manager = grpo_manager
+    app.state.health_manager = health_manager
+
+    return app
+
+
+def serve(
+    config_path: str = None,
+    storage_path: str = None,
+    inference_gpus: str = None,
+    training_gpus: str = None,
+    host: str = "0.0.0.0",
+    port: int = 7453,
+):
+    """Start the Arbor API server.
+
+    Starts the server in a background thread and returns once the server is ready to accept requests.
+    Use arbor.stop() to shutdown the server.
+
+    Args:
+        config_path: Path to YAML config file (optional)
+        storage_path: Valid storage directory path (optional)
+        inference_gpus: GPU IDs for inference, e.g. "0,1" (optional, default 0)
+        training_gpus: GPU IDs for training, e.g. "1,2,3" (optional, default 1,2)
+        host: Host to bind to (default: "0.0.0.0")
+        port: Port to bind to (default: 7453)
+
+    Example:
+        import arbor
+        arbor.serve(inference_gpus="0", training_gpus="1,2")
+        # Server is now ready to accept requests
+        # Later, to stop:
+        arbor.stop()
+    """
+    global _server, _server_thread, _server_loop, _server_host, _server_port
+
+    # Stop existing server if running
+    if _server is not None:
+        print("🌳 Stopping existing server...")
+        stop()
+
+    _server_host = host
+    _server_port = port
+
+    create_app(config_path, storage_path, inference_gpus, training_gpus)
+
+    # Start server in background thread
+    def run_server():
+        global _server, _server_loop
+
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        _server_loop = loop
+
+        # Create uvicorn config and server
+        config = uvicorn.Config(app, host=host, port=port, loop=loop)
+        server = uvicorn.Server(config)
+        _server = server
+
+        # Run the server
+        try:
+            loop.run_until_complete(server.serve())
+        except Exception as e:
+            print(f"Server error: {e}")
+        finally:
+            loop.close()
+            _server = None
+            _server_loop = None
+
+    # Start server thread
+    _server_thread = threading.Thread(target=run_server, daemon=True)
+    _server_thread.start()
+
+    print(f"🌳 Starting Arbor server on http://{host}:{port}...")
+
+    # Wait for server to be ready
+    try:
+        _wait_for_server_ready(host, port, timeout=60)  # Increased timeout
+        print(f"🌳 Arbor server is ready and accepting requests!")
+    except TimeoutError as e:
+        print(f"❌ {e}")
+        # Try to stop the server if it failed to start properly
+        stop()
+        raise
+
+
+def stop():
+    """Stop the Arbor server if it's running."""
+    global _server, _server_thread, _server_loop
+
+    if _server is None:
+        print("🌳 No server running to stop.")
+        return
+
+    print("🌳 Stopping Arbor server...")
+
+    # Schedule server shutdown in the server's event loop
+    if _server_loop and _server:
+        try:
+            asyncio.run_coroutine_threadsafe(_server.shutdown(), _server_loop)
+        except Exception as e:
+            print(f"Error during shutdown: {e}")
+
+    # Wait for thread to finish
+    if _server_thread and _server_thread.is_alive():
+        _server_thread.join(timeout=5)
+
+    # Reset global state
+    _server = None
+    _server_thread = None
+    _server_loop = None
+
+    print("🌳 Arbor server stopped.")
+
+
+def is_running():
+    """Check if the Arbor server is currently running."""
+    return (
+        _server is not None and _server_thread is not None and _server_thread.is_alive()
+    )
+
+
+def _wait_for_server_ready(host, port, timeout=30):
+    """Wait for the server to be ready to accept requests."""
+    start_time = time.time()
+    last_error = None
+    port_open = False
+
+    print(f"🌳 Waiting for server to be ready at http://{host}:{port}...")
+
+    while time.time() - start_time < timeout:
+        # First check if the port is open
+        if not port_open:
+            try:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                sock.settimeout(1)
+                result = sock.connect_ex((host, port))
+                sock.close()
+                if result == 0:
+                    port_open = True
+                    print(f"🌳 Port {port} is now open, checking health endpoint...")
+                else:
+                    last_error = f"Port {port} not yet open"
+                    time.sleep(0.5)
+                    continue
+            except Exception as e:
+                last_error = f"Socket error: {e}"
+                time.sleep(0.5)
+                continue
+
+        # Now try the health check
+        try:
+            response = requests.get(f"http://{host}:{port}/health/simple", timeout=2)
+            if response.status_code == 200:
+                print(f"🌳 Server ready! Response: {response.json()}")
+                return
+            else:
+                last_error = f"Health check returned status {response.status_code}"
+        except requests.exceptions.ConnectionError as e:
+            last_error = f"Connection error: {e}"
+            port_open = False  # Port might have closed
+        except requests.exceptions.Timeout as e:
+            last_error = f"Timeout error: {e}"
+        except requests.exceptions.RequestException as e:
+            last_error = f"Request error: {e}"
+        except Exception as e:
+            last_error = f"Unexpected error: {e}"
+
+        # Print progress every 5 seconds
+        elapsed = time.time() - start_time
+        if int(elapsed) % 5 == 0 and elapsed >= 5:
+            print(
+                f"🌳 Still waiting... ({elapsed:.1f}s elapsed, port_open={port_open}, last error: {last_error})"
+            )
+
+        time.sleep(0.5)
+
+    raise TimeoutError(
+        f"Server did not become ready within {timeout} seconds. Last error: {last_error}"
+    )
+
+
+if __name__ == "__main__":
+    serve(inference_gpus="0, 1", training_gpus="2, 3")
```
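An external client that wants to block until the server is healthy can mirror `_wait_for_server_ready` in a few lines (a sketch; the `/health/simple` path is taken from the code above):

```python
import time
import requests

def wait_until_healthy(base_url="http://localhost:7453", timeout=60.0):
    # Poll the health endpoint used by _wait_for_server_ready until it
    # answers 200, or give up after `timeout` seconds.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            resp = requests.get(f"{base_url}/health/simple", timeout=2)
            if resp.status_code == 200:
                return resp.json()
        except requests.RequestException:
            pass  # not up yet; keep polling
        time.sleep(0.5)
    raise TimeoutError(f"Server not healthy after {timeout}s")
```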
```diff
--- arbor_ai-0.2/arbor/server/api/models/schemas.py
+++ arbor_ai-0.2.2/arbor/server/api/models/schemas.py
@@ -178,7 +178,7 @@ class ChatCompletionModel(BaseModel):
 
 class GRPORequest(BaseModel):
     model: str
-    batch: List[dict]
+    batch: List[dict] | List[List[dict]]
 
 
 class GRPOConfigRequest(BaseModel):
@@ -205,6 +205,8 @@ class GRPOConfigRequest(BaseModel):
     # Arbor specific
     max_context_length: Optional[int] = None
     lora: Optional[bool] = None
+    grpo_flavor: Optional[Literal["grpo", "mmgrpo"]] = None
+    wandb_kwargs: Optional[dict] = None
     # To name the run
     suffix: Optional[str] = None
     generation_batch_size: Optional[int] = None
```
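The widened `batch` type means a step request may now carry either a flat list of sample dicts or a list of trajectories (lists of dicts), presumably to support the new `mmgrpo` flavor. A sketch of what now validates; the field names inside the dicts are illustrative, not from the schema:

```python
from typing import List
from pydantic import BaseModel

class GRPORequest(BaseModel):  # mirrors the updated schema
    model: str
    batch: List[dict] | List[List[dict]]

flat = GRPORequest(model="m", batch=[{"reward": 1.0}])
nested = GRPORequest(
    model="m",
    batch=[[{"reward": 1.0}], [{"reward": 0.0}, {"reward": 0.5}]],
)
```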
```diff
--- arbor_ai-0.2/arbor/server/api/routes/grpo.py
+++ arbor_ai-0.2.2/arbor/server/api/routes/grpo.py
@@ -10,7 +10,6 @@ from arbor.server.api.models.schemas import (
     GRPOConfigResponse,
     GRPORequest,
     GRPOStepResponse,
-    GRPOTerminateRequest,
     GRPOTerminateResponse,
 )
 
@@ -27,12 +26,9 @@ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
 
 # Create a grpo job
 @router.post("/step", response_model=GRPOStepResponse)
-def run_grpo_step(
-    request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
-):
-    inference_manager = request.app.state.inference_manager
+def run_grpo_step(request: Request, grpo_request: GRPORequest):
     grpo_manager = request.app.state.grpo_manager
-
+    inference_manager = request.app.state.inference_manager
     step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
 
     return GRPOStepResponse(status="success", **step_data)
```
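With `background_tasks` dropped from the signature, `/step` is now a plain synchronous handler. A request against it might look like the following sketch; the route prefix and payload fields are assumptions, since only the `/step` suffix and the `GRPORequest` schema are shown:

```python
import requests

resp = requests.post(
    "http://localhost:7453/grpo/step",  # prefix assumed; only "/step" appears above
    json={"model": "my-model", "batch": [{"reward": 1.0}]},
)
print(resp.json())  # GRPOStepResponse, e.g. {"status": "success", ...}
```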
```diff
--- arbor_ai-0.2/arbor/server/api/routes/inference.py
+++ arbor_ai-0.2.2/arbor/server/api/routes/inference.py
@@ -3,6 +3,10 @@ import uuid
 
 from fastapi import APIRouter, Request
 
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
 router = APIRouter()
 
 
@@ -27,17 +31,17 @@ async def run_inference(
 
     # if a server isnt running, launch one
    if not inference_manager.is_server_running():
-
+        logger.info("No model is running, launching model...")
        inference_manager.launch(request_model)
 
     # if the requested model is different from the launched model, swap the server
    if request_model != inference_manager.launched_model:
-
+        logger.info(
            f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
        )
        inference_manager.kill()
        inference_manager.launch(request_model)
-
+        logger.info(f"Model swapped to {request_model}")
 
     # forward the request to the inference server
     completion = await inference_manager.run_inference(raw_json)
```
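The `get_logger` helper itself is not shown in this diff (it lives in the new `arbor/server/utils/logging.py`, +298 lines), but the call sites suggest a conventional module-level wrapper around `logging.getLogger`. A minimal stand-in, under that assumption:

```python
import logging

def get_logger(name: str) -> logging.Logger:
    # Hypothetical stand-in for arbor.server.utils.logging.get_logger:
    # one stream handler per named logger, INFO level, timestamped output.
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
        )
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

logger = get_logger(__name__)
logger.info("No model is running, launching model...")
```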