arbor-ai 0.2.1__tar.gz → 0.2.3__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (62)
  1. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/PKG-INFO +17 -19
  2. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/README.md +14 -16
  3. arbor_ai-0.2.3/arbor/__init__.py +17 -0
  4. arbor_ai-0.2.3/arbor/cli.py +153 -0
  5. arbor_ai-0.2.3/arbor/client/arbor_client.py +259 -0
  6. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/models/schemas.py +3 -1
  7. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/routes/grpo.py +2 -6
  8. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/routes/inference.py +7 -3
  9. arbor_ai-0.2.3/arbor/server/core/config.py +333 -0
  10. arbor_ai-0.2.3/arbor/server/core/config_manager.py +100 -0
  11. arbor_ai-0.2.3/arbor/server/main.py +36 -0
  12. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/comms/comms.py +13 -9
  13. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/file_manager.py +7 -4
  14. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/grpo_manager.py +98 -62
  15. arbor_ai-0.2.3/arbor/server/services/health_manager.py +171 -0
  16. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/inference/vllm_client.py +6 -4
  17. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/inference_manager.py +40 -38
  18. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/job_manager.py +2 -2
  19. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/scripts/grpo_training.py +62 -281
  20. arbor_ai-0.2.3/arbor/server/services/scripts/mmgrpo_training.py +510 -0
  21. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/scripts/sft_training.py +8 -5
  22. arbor_ai-0.2.3/arbor/server/services/scripts/utils/callbacks.py +33 -0
  23. arbor_ai-0.2.3/arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  24. arbor_ai-0.2.3/arbor/server/services/scripts/utils/dataset.py +176 -0
  25. arbor_ai-0.2.3/arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  26. arbor_ai-0.2.3/arbor/server/services/scripts/utils/mock_server.py +124 -0
  27. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/training_manager.py +4 -4
  28. arbor_ai-0.2.3/arbor/server/utils/logging.py +298 -0
  29. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/PKG-INFO +17 -19
  30. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/SOURCES.txt +9 -0
  31. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/requires.txt +1 -1
  32. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/pyproject.toml +4 -4
  33. arbor_ai-0.2.1/arbor/cli.py +0 -113
  34. arbor_ai-0.2.1/arbor/server/core/config.py +0 -47
  35. arbor_ai-0.2.1/arbor/server/main.py +0 -11
  36. arbor_ai-0.2.1/arbor/server/services/scripts/utils/dataset.py +0 -0
  37. arbor_ai-0.2.1/arbor/server/utils/__init__.py +0 -0
  38. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/LICENSE +0 -0
  39. {arbor_ai-0.2.1/arbor → arbor_ai-0.2.3/arbor/client}/__init__.py +0 -0
  40. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/client/api.py +0 -0
  41. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/__init__.py +0 -0
  42. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/__init__.py +0 -0
  43. {arbor_ai-0.2.1/arbor/client → arbor_ai-0.2.3/arbor/server/api/routes}/__init__.py +0 -0
  44. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/routes/files.py +0 -0
  45. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/api/routes/jobs.py +0 -0
  46. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/core/__init__.py +0 -0
  47. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/core/logging.py +0 -0
  48. {arbor_ai-0.2.1/arbor/server/api/routes → arbor_ai-0.2.3/arbor/server/services}/__init__.py +0 -0
  49. {arbor_ai-0.2.1/arbor/server/services → arbor_ai-0.2.3/arbor/server/services/comms}/__init__.py +0 -0
  50. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/dependencies.py +0 -0
  51. {arbor_ai-0.2.1/arbor/server/services/comms → arbor_ai-0.2.3/arbor/server/services/inference}/__init__.py +0 -0
  52. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/inference/vllm_serve.py +0 -0
  53. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/scripts/dpo_training.py +0 -0
  54. {arbor_ai-0.2.1/arbor/server/services/inference → arbor_ai-0.2.3/arbor/server/services/scripts/utils}/__init__.py +0 -0
  55. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/services/scripts/utils/arg_parser.py +0 -0
  56. {arbor_ai-0.2.1/arbor/server/services/scripts → arbor_ai-0.2.3/arbor/server}/utils/__init__.py +0 -0
  57. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor/server/utils/helpers.py +0 -0
  58. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/dependency_links.txt +0 -0
  59. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/entry_points.txt +0 -0
  60. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/arbor_ai.egg-info/top_level.txt +0 -0
  61. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/setup.cfg +0 -0
  62. {arbor_ai-0.2.1 → arbor_ai-0.2.3}/tests/test_cli.py +0 -0
--- arbor_ai-0.2.1/PKG-INFO
+++ arbor_ai-0.2.3/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.2.1
+Version: 0.2.3
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
 Project-URL: Issues, https://github.com/Ziems/arbor/issues
-Requires-Python: >=3.10
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.6.0
@@ -14,7 +14,7 @@ Requires-Dist: uvicorn
 Requires-Dist: click
 Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
-Requires-Dist: vllm>=0.8.5.post1
+Requires-Dist: vllm==0.8.5.post1
 Requires-Dist: transformers
 Requires-Dist: trl>=0.17.0
 Requires-Dist: peft
@@ -43,7 +43,8 @@ Install Arbor via pip:
 pip install -U arbor-ai
 ```
 
-Optionally, you can also install:
+Optionally, you can also install flash attention to speed up inference. <br/>
+This can take 15+ minutes to install on some setups:
 ```bash
 pip install flash-attn --no-build-isolation
 ```
@@ -52,27 +53,16 @@ pip install flash-attn --no-build-isolation
 
 ## ⚡ Quick Start
 
-### 1️⃣ Make an `arbor.yaml` File
-
-This is all dependent on your setup. Here is an example of one:
-```yaml
-inference:
-  gpu_ids: '0'
-
-training:
-  gpu_ids: '1, 2'
-```
-Which will use the `GPU:0` for inference with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
-
-### 2️⃣ Start the Server
+### 1️⃣ Start the Server
 
 **CLI:**
 
 ```bash
-python -m arbor.cli serve --arbor-config arbor.yaml
+python -m arbor.cli serve
 ```
+On the first run you'll be asked which GPUs will be used for training and which for inference. For more that one GPU, separate the ids by comma: `1, 2`. Your config file will be saved in `~/.arbor/config.yaml` should you want to edit these configs in the future.
 
-### 3️⃣ Optimize a DSPy Program
+### 2️⃣ Optimize a DSPy Program
 
 Follow the DSPy tutorials here to see usage examples:
 [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
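
The first-run setup now writes its answers to `~/.arbor/config.yaml` instead of a per-project `arbor.yaml`. A minimal sketch for inspecting that file, assuming it keeps the `inference`/`training` → `gpu_ids` layout shown in the removed 0.2.1 README:

```python
# Hedged sketch: read the config written by first-run setup. Assumes the
# saved file keeps the inference/training `gpu_ids` keys from the 0.2.1 README.
from pathlib import Path

import yaml  # pip install pyyaml

cfg = yaml.safe_load(Path("~/.arbor/config.yaml").expanduser().read_text())
print("inference GPUs:", cfg["inference"]["gpu_ids"])
print("training GPUs:", cfg["training"]["gpu_ids"])
```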
@@ -89,6 +79,14 @@ export NCCL_P2P_DISABLE=1
 export NCCL_IB_DISABLE=1
 ```
 
+**NVCC**
+If you run into issues, double check that you have [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/) installed:
+```bash
+nvcc --version
+```
+If you don't have admin permissions, you can often install nvcc using conda.
+
+
 ## 🙏 Acknowledgements
 
 Arbor builds on the shoulders of great work. We extend our thanks to:

--- arbor_ai-0.2.1/README.md
+++ arbor_ai-0.2.3/README.md
@@ -16,7 +16,8 @@ Install Arbor via pip:
 pip install -U arbor-ai
 ```
 
-Optionally, you can also install:
+Optionally, you can also install flash attention to speed up inference. <br/>
+This can take 15+ minutes to install on some setups:
 ```bash
 pip install flash-attn --no-build-isolation
 ```
@@ -25,27 +26,16 @@ pip install flash-attn --no-build-isolation
 
 ## ⚡ Quick Start
 
-### 1️⃣ Make an `arbor.yaml` File
-
-This is all dependent on your setup. Here is an example of one:
-```yaml
-inference:
-  gpu_ids: '0'
-
-training:
-  gpu_ids: '1, 2'
-```
-Which will use the `GPU:0` for inference with `GPU:1` and `GPU:2` reserved for training. We generally recommend splitting the GPUs roughly evenly between inference and training.
-
-### 2️⃣ Start the Server
+### 1️⃣ Start the Server
 
 **CLI:**
 
 ```bash
-python -m arbor.cli serve --arbor-config arbor.yaml
+python -m arbor.cli serve
 ```
+On the first run you'll be asked which GPUs will be used for training and which for inference. For more that one GPU, separate the ids by comma: `1, 2`. Your config file will be saved in `~/.arbor/config.yaml` should you want to edit these configs in the future.
 
-### 3️⃣ Optimize a DSPy Program
+### 2️⃣ Optimize a DSPy Program
 
 Follow the DSPy tutorials here to see usage examples:
 [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
@@ -62,6 +52,14 @@ export NCCL_P2P_DISABLE=1
 export NCCL_IB_DISABLE=1
 ```
 
+**NVCC**
+If you run into issues, double check that you have [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/) installed:
+```bash
+nvcc --version
+```
+If you don't have admin permissions, you can often install nvcc using conda.
+
+
 ## 🙏 Acknowledgements
 
 Arbor builds on the shoulders of great work. We extend our thanks to:

--- /dev/null
+++ arbor_ai-0.2.3/arbor/__init__.py
@@ -0,0 +1,17 @@
+"""
+Arbor - A framework for fine-tuning and managing language models
+"""
+
+from importlib.metadata import PackageNotFoundError, version
+
+try:
+    __version__ = version("arbor-ai")
+except PackageNotFoundError:
+    # Package is not installed, likely in development mode
+    __version__ = "dev"
+except Exception:
+    __version__ = "unknown"
+
+from arbor.client.arbor_client import is_running, serve, stop
+
+__all__ = ["__version__", "serve", "stop", "is_running"]
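
The new package root resolves `__version__` from installed metadata and re-exports the client entry points, so the import-level surface looks like this (a sketch based only on the exports above):

```python
# Sketch based on the exports above: __version__ comes from installed package
# metadata and falls back to "dev" when running from a source checkout.
import arbor

print(arbor.__version__)   # e.g. "0.2.3", or "dev"
print(arbor.is_running())  # False until arbor.serve(...) has been called
```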

--- /dev/null
+++ arbor_ai-0.2.3/arbor/cli.py
@@ -0,0 +1,153 @@
+import os
+from datetime import datetime
+
+import click
+import uvicorn
+
+from arbor.server.core.config import Config
+from arbor.server.core.config_manager import ConfigManager
+from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.health_manager import HealthManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+from arbor.client.arbor_client import create_app
+
+
+@click.group()
+def cli():
+    pass
+
+def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
+    """Start the Arbor API server with a single function call"""
+    import socket
+    import threading
+    import time
+    from contextlib import closing
+
+    def is_port_in_use(port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            return sock.connect_ex(("localhost", port)) == 0
+
+    # First ensure the port is free
+    if is_port_in_use(port):
+        raise RuntimeError(f"Port {port} is already in use")
+
+    app = create_app(storage_path)
+    # configure_uvicorn_logging()
+    config = uvicorn.Config(app, host=host, port=port, log_level="info")
+    server = uvicorn.Server(config)
+
+    def run_server():
+        server.run()
+
+    thread = threading.Thread(target=run_server, daemon=True)
+    thread.start()
+
+    # Wait for server to start
+    start_time = time.time()
+    while not is_port_in_use(port):
+        if time.time() - start_time > timeout:
+            raise TimeoutError(f"Server failed to start within {timeout} seconds")
+        time.sleep(0.1)
+
+    # Give it a little extra time to fully initialize
+    time.sleep(0.5)
+
+    return server
+
+
+def stop_server(server):
+    """Stop the Arbor API server"""
+    server.should_exit = True
+
+
+@cli.command()
+@click.option("--host", default="0.0.0.0", help="Host to bind to")
+@click.option("--port", default=7453, help="Port to bind to")
+@click.option("--arbor-config", required=False, help="Path to the Arbor config file")
+def serve(host, port, arbor_config):
+    """Start the Arbor API server"""
+
+    if arbor_config:
+        config_path = arbor_config
+    else:
+        config_path = Config.use_default_config()
+
+    # If no config found, run first-time setup
+    if config_path is None:
+        config_path = run_first_time_setup()
+
+    # Validate config exists and is readable
+    is_valid, msg = ConfigManager.validate_config_file(config_path)
+
+    if not is_valid:
+        click.echo(msg)
+        raise click.Abort()
+
+    try:
+        create_app(config_path)
+        # Temporarily disable custom uvicorn logging configuration
+        # configure_uvicorn_logging()
+        uvicorn.run(app, host=host, port=port)
+    except Exception as e:
+        click.echo(f"Failed to start server: {e}", err=True)
+        raise click.Abort()
+
+
+def run_first_time_setup() -> str:
+    """Run first-time setup and return created config path"""
+    click.echo("Welcome to Arbor!")
+    click.echo("It looks like this is your first time running Arbor.")
+    click.echo("Let's set up your configuration...\n")
+
+    try:
+        # Get config details
+        inference = click.prompt(
+            "Which gpu ids should be used for inference (separated by comma)",
+            default="0",
+        )
+        training = click.prompt(
+            "Which gpu ids should be used for training (separated by comma)",
+            default="1, 2",
+        )
+        click.echo()
+
+        # Get config file path
+        config_path = click.prompt(
+            "Enter path to save config file in. We recommend (~/.arbor/config.yaml)",
+            default=ConfigManager.get_default_config_path(),
+        )
+        logger = get_logger(__name__)
+        logger.info(f"Config path selected: {config_path}")
+        click.echo()
+
+        # Update or create config at path
+        config_path = ConfigManager.update_config(inference, training, config_path)
+        click.echo(f"Created configuration at: {config_path}")
+
+        # Check if it is a valid config file
+        is_valid, msg = ConfigManager.validate_config_file(config_path)
+        if not is_valid:
+            raise click.ClickException(f"Invalid config file: {msg}")
+
+        # Read and display the contents
+        _, content = ConfigManager.get_config_contents(config_path)
+
+        click.echo("\nConfiguration file contents:")
+        click.echo("---")
+        click.echo(content)
+        click.echo("---")
+
+        click.echo("\nSetup complete! Starting Arbor server...")
+        return config_path
+
+    except Exception as e:
+        click.echo(f"Failed initial setup of Arbor: {e}", err=True)
+        raise click.Abort()
+
+
+if __name__ == "__main__":
+    cli()
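
`run_first_time_setup` drives `ConfigManager` interactively; the same helpers can create a config without prompts. A minimal sketch using only the calls visible in this file:

```python
# Hedged sketch: create and validate a config non-interactively, using the
# same ConfigManager calls that run_first_time_setup makes above.
from arbor.server.core.config_manager import ConfigManager

path = ConfigManager.update_config(
    "0",     # gpu ids for inference
    "1, 2",  # gpu ids for training
    ConfigManager.get_default_config_path(),
)
is_valid, msg = ConfigManager.validate_config_file(path)
print(path, is_valid, msg)
```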

--- /dev/null
+++ arbor_ai-0.2.3/arbor/client/arbor_client.py
@@ -0,0 +1,259 @@
+import asyncio
+import os
+import socket
+import threading
+import time
+from datetime import datetime
+
+import click
+import requests
+import uvicorn
+
+from arbor.server.core.config import Config
+from arbor.server.core.config_manager import ConfigManager
+from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.health_manager import HealthManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+
+# Global server state
+_server = None
+_server_thread = None
+_server_loop = None
+_server_host = None
+_server_port = None
+
+
+def create_app(
+    config_path: str = None,
+    storage_path: str = None,
+    inference_gpus: str = None,
+    training_gpus: str = None,
+):
+    """Create and configure the Arbor API application
+
+    Args:
+        arbor_config_path (str): Path to config file
+        storage_path (str): Path to storage directory
+        inference_gpus (str): gpu ids to use for inference
+        training_gpus (str): gpu ids to use for training
+
+    Returns:
+        FastAPI: Configured FastAPI application
+    """
+    # Create new config instance with overrides
+    if config_path:
+        config = Config.load_config_from_yaml(config_path)
+    elif inference_gpus and training_gpus:
+        config = Config.load_config_directly(
+            storage_path, inference_gpus, training_gpus
+        )
+    else:
+        raise ValueError(
+            "Either 'config_path' must be provided, or 'inference_gpus', and 'training_gpus' must be provided"
+        )
+
+    app.state.log_dir = Config.make_log_dir(config.STORAGE_PATH)
+
+    # Initialize services with config
+    health_manager = HealthManager(config=config)
+    file_manager = FileManager(config=config)
+    job_manager = JobManager(config=config)
+    training_manager = TrainingManager(config=config)
+    inference_manager = InferenceManager(config=config)
+    grpo_manager = GRPOManager(config=config)
+
+    # Inject config into app state
+    app.state.config = config
+    app.state.file_manager = file_manager
+    app.state.job_manager = job_manager
+    app.state.training_manager = training_manager
+    app.state.inference_manager = inference_manager
+    app.state.grpo_manager = grpo_manager
+    app.state.health_manager = health_manager
+
+    return app
+
+
+def serve(
+    config_path: str = None,
+    storage_path: str = None,
+    inference_gpus: str = None,
+    training_gpus: str = None,
+    host: str = "0.0.0.0",
+    port: int = 7453,
+):
+    """Start the Arbor API server.
+
+    Starts the server in a background thread and returns once the server is ready to accept requests.
+    Use arbor.stop() to shutdown the server.
+
+    Args:
+        config_path: Path to YAML config file (optional)
+        storage_path: Valid storage directory path (optional)
+        inference_gpus: GPU IDs for inference, e.g. "0,1" (optional, default 0)
+        training_gpus: GPU IDs for training, e.g. "1,2,3" (optional, default 1,2)
+        host: Host to bind to (default: "0.0.0.0")
+        port: Port to bind to (default: 7453)
+
+    Example:
+        import arbor
+        arbor.serve(inference_gpus="0", training_gpus="1,2")
+        # Server is now ready to accept requests
+        # Later, to stop:
+        arbor.stop()
+    """
+    global _server, _server_thread, _server_loop, _server_host, _server_port
+
+    # Stop existing server if running
+    if _server is not None:
+        print("🌳 Stopping existing server...")
+        stop()
+
+    _server_host = host
+    _server_port = port
+
+    create_app(config_path, storage_path, inference_gpus, training_gpus)
+
+    # Start server in background thread
+    def run_server():
+        global _server, _server_loop
+
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        _server_loop = loop
+
+        # Create uvicorn config and server
+        config = uvicorn.Config(app, host=host, port=port, loop=loop)
+        server = uvicorn.Server(config)
+        _server = server
+
+        # Run the server
+        try:
+            loop.run_until_complete(server.serve())
+        except Exception as e:
+            print(f"Server error: {e}")
+        finally:
+            loop.close()
+            _server = None
+            _server_loop = None
+
+    # Start server thread
+    _server_thread = threading.Thread(target=run_server, daemon=True)
+    _server_thread.start()
+
+    print(f"🌳 Starting Arbor server on http://{host}:{port}...")
+
+    # Wait for server to be ready
+    try:
+        _wait_for_server_ready(host, port, timeout=60)  # Increased timeout
+        print(f"🌳 Arbor server is ready and accepting requests!")
+    except TimeoutError as e:
+        print(f"❌ {e}")
+        # Try to stop the server if it failed to start properly
+        stop()
+        raise
+
+
+def stop():
+    """Stop the Arbor server if it's running."""
+    global _server, _server_thread, _server_loop
+
+    if _server is None:
+        print("🌳 No server running to stop.")
+        return
+
+    print("🌳 Stopping Arbor server...")
+
+    # Schedule server shutdown in the server's event loop
+    if _server_loop and _server:
+        try:
+            asyncio.run_coroutine_threadsafe(_server.shutdown(), _server_loop)
+        except Exception as e:
+            print(f"Error during shutdown: {e}")
+
+    # Wait for thread to finish
+    if _server_thread and _server_thread.is_alive():
+        _server_thread.join(timeout=5)
+
+    # Reset global state
+    _server = None
+    _server_thread = None
+    _server_loop = None
+
+    print("🌳 Arbor server stopped.")
+
+
+def is_running():
+    """Check if the Arbor server is currently running."""
+    return (
+        _server is not None and _server_thread is not None and _server_thread.is_alive()
+    )
+
+
+def _wait_for_server_ready(host, port, timeout=30):
+    """Wait for the server to be ready to accept requests."""
+    start_time = time.time()
+    last_error = None
+    port_open = False
+
+    print(f"🌳 Waiting for server to be ready at http://{host}:{port}...")
+
+    while time.time() - start_time < timeout:
+        # First check if the port is open
+        if not port_open:
+            try:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                sock.settimeout(1)
+                result = sock.connect_ex((host, port))
+                sock.close()
+                if result == 0:
+                    port_open = True
+                    print(f"🌳 Port {port} is now open, checking health endpoint...")
+                else:
+                    last_error = f"Port {port} not yet open"
+                    time.sleep(0.5)
+                    continue
+            except Exception as e:
+                last_error = f"Socket error: {e}"
+                time.sleep(0.5)
+                continue
+
+        # Now try the health check
+        try:
+            response = requests.get(f"http://{host}:{port}/health/simple", timeout=2)
+            if response.status_code == 200:
+                print(f"🌳 Server ready! Response: {response.json()}")
+                return
+            else:
+                last_error = f"Health check returned status {response.status_code}"
+        except requests.exceptions.ConnectionError as e:
+            last_error = f"Connection error: {e}"
+            port_open = False  # Port might have closed
+        except requests.exceptions.Timeout as e:
+            last_error = f"Timeout error: {e}"
+        except requests.exceptions.RequestException as e:
+            last_error = f"Request error: {e}"
+        except Exception as e:
+            last_error = f"Unexpected error: {e}"
+
+        # Print progress every 5 seconds
+        elapsed = time.time() - start_time
+        if int(elapsed) % 5 == 0 and elapsed >= 5:
+            print(
+                f"🌳 Still waiting... ({elapsed:.1f}s elapsed, port_open={port_open}, last error: {last_error})"
+            )
+
+        time.sleep(0.5)
+
+    raise TimeoutError(
+        f"Server did not become ready within {timeout} seconds. Last error: {last_error}"
+    )
+
+
+if __name__ == "__main__":
+    serve(inference_gpus="0, 1", training_gpus="2, 3")
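
Together these give an embeddable server that blocks until ready; usage follows the `serve()` docstring above, and `/health/simple` is the endpoint `_wait_for_server_ready` polls:

```python
# Usage per the serve() docstring above. /health/simple is the endpoint
# _wait_for_server_ready polls, so it is safe to query once serve() returns.
import requests

import arbor

arbor.serve(inference_gpus="0", training_gpus="1,2")  # returns once ready
print(arbor.is_running())  # True
print(requests.get("http://localhost:7453/health/simple").json())
arbor.stop()
```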

--- arbor_ai-0.2.1/arbor/server/api/models/schemas.py
+++ arbor_ai-0.2.3/arbor/server/api/models/schemas.py
@@ -178,7 +178,7 @@ class ChatCompletionModel(BaseModel):
 
 class GRPORequest(BaseModel):
     model: str
-    batch: List[dict]
+    batch: List[dict] | List[List[dict]]
 
 
 class GRPOConfigRequest(BaseModel):
@@ -205,6 +205,8 @@ class GRPOConfigRequest(BaseModel):
     # Arbor specific
     max_context_length: Optional[int] = None
     lora: Optional[bool] = None
+    grpo_flavor: Optional[Literal["grpo", "mmgrpo"]] = None
+    wandb_kwargs: Optional[dict] = None
     # To name the run
     suffix: Optional[str] = None
     generation_batch_size: Optional[int] = None
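
The widened `batch` annotation lets `/grpo/step` accept either the 0.2.1 flat list of dicts or a nested list of lists, presumably feeding the new `mmgrpo` flavor. A sketch with placeholder payloads:

```python
# Sketch of the two batch shapes the widened annotation admits. The dict
# contents are placeholders, not the fields Arbor's trainer actually expects.
from arbor.server.api.models.schemas import GRPORequest

flat = GRPORequest(model="my-model", batch=[{"reward": 1.0}, {"reward": 0.0}])
nested = GRPORequest(model="my-model", batch=[[{"reward": 1.0}], [{"reward": 0.0}]])
```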

--- arbor_ai-0.2.1/arbor/server/api/routes/grpo.py
+++ arbor_ai-0.2.3/arbor/server/api/routes/grpo.py
@@ -10,7 +10,6 @@ from arbor.server.api.models.schemas import (
     GRPOConfigResponse,
     GRPORequest,
     GRPOStepResponse,
-    GRPOTerminateRequest,
     GRPOTerminateResponse,
 )
 
@@ -27,12 +26,9 @@ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
 
 # Create a grpo job
 @router.post("/step", response_model=GRPOStepResponse)
-def run_grpo_step(
-    request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
-):
-    inference_manager = request.app.state.inference_manager
+def run_grpo_step(request: Request, grpo_request: GRPORequest):
     grpo_manager = request.app.state.grpo_manager
-
+    inference_manager = request.app.state.inference_manager
     step_data = grpo_manager.grpo_step(grpo_request, inference_manager)
 
     return GRPOStepResponse(status="success", **step_data)

--- arbor_ai-0.2.1/arbor/server/api/routes/inference.py
+++ arbor_ai-0.2.3/arbor/server/api/routes/inference.py
@@ -3,6 +3,10 @@ import uuid
 
 from fastapi import APIRouter, Request
 
+from arbor.server.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
 router = APIRouter()
 
 
@@ -27,17 +31,17 @@ async def run_inference(
 
     # if a server isnt running, launch one
     if not inference_manager.is_server_running():
-        print("No model is running, launching model...")
+        logger.info("No model is running, launching model...")
         inference_manager.launch(request_model)
 
     # if the requested model is different from the launched model, swap the server
     if request_model != inference_manager.launched_model:
-        print(
+        logger.info(
             f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
         )
         inference_manager.kill()
         inference_manager.launch(request_model)
-        print(f"Model swapped to {request_model}")
+        logger.info(f"Model swapped to {request_model}")
 
     # forward the request to the inference server
     completion = await inference_manager.run_inference(raw_json)
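
The print-to-logger migration uses `get_logger` from the new `arbor/server/utils/logging.py`; the module-level pattern, assuming stdlib-style logger semantics:

```python
# The module-level logger pattern this hunk adopts. Assuming get_logger
# returns a stdlib-style logger, the usual level methods apply.
from arbor.server.utils.logging import get_logger

logger = get_logger(__name__)
logger.info("No model is running, launching model...")
logger.warning("hypothetical warning message")
```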