arbor-ai 0.2__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/__init__.py +17 -0
- arbor/cli.py +83 -43
- arbor/client/arbor_client.py +259 -0
- arbor/server/api/models/schemas.py +3 -1
- arbor/server/api/routes/grpo.py +2 -6
- arbor/server/api/routes/inference.py +7 -3
- arbor/server/core/config.py +293 -7
- arbor/server/core/config_manager.py +100 -0
- arbor/server/main.py +26 -1
- arbor/server/services/comms/comms.py +13 -9
- arbor/server/services/file_manager.py +7 -4
- arbor/server/services/grpo_manager.py +101 -63
- arbor/server/services/health_manager.py +171 -0
- arbor/server/services/inference/vllm_client.py +8 -5
- arbor/server/services/inference_manager.py +40 -38
- arbor/server/services/job_manager.py +2 -2
- arbor/server/services/scripts/grpo_training.py +62 -281
- arbor/server/services/scripts/mmgrpo_training.py +510 -0
- arbor/server/services/scripts/sft_training.py +8 -5
- arbor/server/services/scripts/utils/callbacks.py +33 -0
- arbor/server/services/scripts/utils/comms_monitors.py +169 -0
- arbor/server/services/scripts/utils/dataset.py +176 -0
- arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
- arbor/server/services/scripts/utils/mock_server.py +124 -0
- arbor/server/services/training_manager.py +4 -4
- arbor/server/utils/logging.py +298 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +18 -18
- arbor_ai-0.2.2.dist-info/RECORD +51 -0
- arbor_ai-0.2.dist-info/RECORD +0 -42
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.2.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
arbor/__init__.py
CHANGED
@@ -0,0 +1,17 @@
|
|
1
|
+
"""
|
2
|
+
Arbor - A framework for fine-tuning and managing language models
|
3
|
+
"""
|
4
|
+
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
6
|
+
|
7
|
+
def _resolve_version() -> str:
    """Return the installed version of the ``arbor-ai`` distribution.

    Falls back to ``"dev"`` when no distribution metadata is found (for
    example a source checkout that was never installed) and to ``"unknown"``
    for any other lookup failure, so importing the package never fails here.
    """
    try:
        return version("arbor-ai")
    except PackageNotFoundError:
        # Package is not installed, likely in development mode
        return "dev"
    except Exception:
        return "unknown"


__version__ = _resolve_version()
# Keep the package namespace identical to the inline try/except form.
del _resolve_version
|
14
|
+
|
15
|
+
from arbor.client.arbor_client import is_running, serve, stop
|
16
|
+
|
17
|
+
__all__ = ["__version__", "serve", "stop", "is_running"]
|
arbor/cli.py
CHANGED
@@ -4,58 +4,22 @@ from datetime import datetime
|
|
4
4
|
import click
|
5
5
|
import uvicorn
|
6
6
|
|
7
|
-
from arbor.server.core.config import
|
7
|
+
from arbor.server.core.config import Config
|
8
|
+
from arbor.server.core.config_manager import ConfigManager
|
8
9
|
from arbor.server.main import app
|
9
10
|
from arbor.server.services.file_manager import FileManager
|
10
11
|
from arbor.server.services.grpo_manager import GRPOManager
|
12
|
+
from arbor.server.services.health_manager import HealthManager
|
11
13
|
from arbor.server.services.inference_manager import InferenceManager
|
12
14
|
from arbor.server.services.job_manager import JobManager
|
13
15
|
from arbor.server.services.training_manager import TrainingManager
|
14
|
-
|
15
|
-
|
16
|
-
def make_log_dir(storage_path: str):
|
17
|
-
# Create a timestamped log directory under the storage path
|
18
|
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
19
|
-
log_dir = os.path.join(storage_path, "logs", timestamp)
|
20
|
-
os.makedirs(log_dir, exist_ok=True)
|
21
|
-
return log_dir
|
16
|
+
from arbor.client.arbor_client import create_app
|
22
17
|
|
23
18
|
|
24
19
|
@click.group()
def cli():
    # Root click command group; subcommands such as ``serve`` attach to it
    # via ``@cli.command()``. A docstring here would become the group's
    # --help text, so this stays a comment.
    ...
|
27
22
|
|
28
|
-
|
29
|
-
def create_app(arbor_config_path: str):
|
30
|
-
"""Create and configure the Arbor API application
|
31
|
-
|
32
|
-
Args:
|
33
|
-
storage_path (str): Path to store models and uploaded training files
|
34
|
-
|
35
|
-
Returns:
|
36
|
-
FastAPI: Configured FastAPI application
|
37
|
-
"""
|
38
|
-
# Create new settings instance with overrides
|
39
|
-
settings = Settings.load_from_yaml(arbor_config_path)
|
40
|
-
app.state.log_dir = make_log_dir(settings.STORAGE_PATH)
|
41
|
-
|
42
|
-
# Initialize services with settings
|
43
|
-
file_manager = FileManager(settings=settings)
|
44
|
-
job_manager = JobManager(settings=settings)
|
45
|
-
training_manager = TrainingManager(settings=settings)
|
46
|
-
inference_manager = InferenceManager(settings=settings)
|
47
|
-
grpo_manager = GRPOManager(settings=settings)
|
48
|
-
# Inject settings into app state
|
49
|
-
app.state.settings = settings
|
50
|
-
app.state.file_manager = file_manager
|
51
|
-
app.state.job_manager = job_manager
|
52
|
-
app.state.training_manager = training_manager
|
53
|
-
app.state.inference_manager = inference_manager
|
54
|
-
app.state.grpo_manager = grpo_manager
|
55
|
-
|
56
|
-
return app
|
57
|
-
|
58
|
-
|
59
23
|
def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
|
60
24
|
"""Start the Arbor API server with a single function call"""
|
61
25
|
import socket
|
@@ -72,6 +36,7 @@ def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10
|
|
72
36
|
raise RuntimeError(f"Port {port} is already in use")
|
73
37
|
|
74
38
|
app = create_app(storage_path)
|
39
|
+
# configure_uvicorn_logging()
|
75
40
|
config = uvicorn.Config(app, host=host, port=port, log_level="info")
|
76
41
|
server = uvicorn.Server(config)
|
77
42
|
|
@@ -102,11 +67,86 @@ def stop_server(server):
|
|
102
67
|
@cli.command()
@click.option("--host", default="0.0.0.0", help="Host to bind to")
@click.option("--port", default=7453, help="Port to bind to")
@click.option("--arbor-config", required=False, help="Path to the Arbor config file")
def serve(host, port, arbor_config):
    """Start the Arbor API server"""
    # An explicit --arbor-config wins; otherwise ask Config for its default
    # location, and run the interactive first-time setup when none is found.
    config_path = arbor_config or Config.use_default_config()
    if config_path is None:
        config_path = run_first_time_setup()

    # Validate config exists and is readable before touching the server.
    valid, message = ConfigManager.validate_config_file(config_path)
    if not valid:
        click.echo(message)
        raise click.Abort()

    try:
        # create_app configures the shared ``app`` object in place, which is
        # the same object handed to uvicorn below.
        create_app(config_path)
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        click.echo(f"Failed to start server: {e}", err=True)
        raise click.Abort()
|
98
|
+
|
99
|
+
|
100
|
+
def run_first_time_setup() -> str:
    """Interactively create an Arbor config file and return its path.

    Prompts for the GPU ids to use for inference and for training and for the
    location of the config file, writes the file with
    ``ConfigManager.update_config``, validates it, and echoes its contents.

    Returns:
        str: Path of the config file that was written.

    Raises:
        click.Abort: If the user cancels a prompt or any setup step fails.
    """
    click.echo("Welcome to Arbor!")
    click.echo("It looks like this is your first time running Arbor.")
    click.echo("Let's set up your configuration...\n")

    try:
        # Imported here: this module does not import get_logger at top level,
        # so referencing the bare name raised NameError, which the handler
        # below then reported as a failed setup.
        from arbor.server.utils.logging import get_logger

        logger = get_logger(__name__)

        # Get config details
        inference = click.prompt(
            "Which gpu ids should be used for inference (separated by comma)",
            default="0",
        )
        training = click.prompt(
            "Which gpu ids should be used for training (separated by comma)",
            default="1, 2",
        )
        click.echo()

        # Get config file path
        config_path = click.prompt(
            "Enter path to save config file in. We recommend (~/.arbor/config.yaml)",
            default=ConfigManager.get_default_config_path(),
        )
        logger.info(f"Config path selected: {config_path}")
        click.echo()

        # Update or create config at path; update_config returns the path used.
        config_path = ConfigManager.update_config(inference, training, config_path)
        click.echo(f"Created configuration at: {config_path}")

        # Check if it is a valid config file
        is_valid, msg = ConfigManager.validate_config_file(config_path)
        if not is_valid:
            raise click.ClickException(f"Invalid config file: {msg}")

        # Read and display the contents
        _, content = ConfigManager.get_config_contents(config_path)

        click.echo("\nConfiguration file contents:")
        click.echo("---")
        click.echo(content)
        click.echo("---")

        click.echo("\nSetup complete! Starting Arbor server...")
        return config_path

    except click.Abort:
        # The user cancelled a prompt (Ctrl-C / EOF): propagate as-is instead
        # of reporting it as a setup failure with an empty message.
        raise
    except Exception as e:
        click.echo(f"Failed initial setup of Arbor: {e}", err=True)
        raise click.Abort() from e
|
110
150
|
|
111
151
|
|
112
152
|
if __name__ == "__main__":
|
arbor/client/arbor_client.py
CHANGED
@@ -0,0 +1,259 @@
|
|
1
|
+
import asyncio
|
2
|
+
import os
|
3
|
+
import socket
|
4
|
+
import threading
|
5
|
+
import time
|
6
|
+
from datetime import datetime
|
7
|
+
|
8
|
+
import click
|
9
|
+
import requests
|
10
|
+
import uvicorn
|
11
|
+
|
12
|
+
from arbor.server.core.config import Config
|
13
|
+
from arbor.server.core.config_manager import ConfigManager
|
14
|
+
from arbor.server.main import app
|
15
|
+
from arbor.server.services.file_manager import FileManager
|
16
|
+
from arbor.server.services.grpo_manager import GRPOManager
|
17
|
+
from arbor.server.services.health_manager import HealthManager
|
18
|
+
from arbor.server.services.inference_manager import InferenceManager
|
19
|
+
from arbor.server.services.job_manager import JobManager
|
20
|
+
from arbor.server.services.training_manager import TrainingManager
|
21
|
+
|
22
|
+
# Global server state shared by serve(), stop() and is_running():
#   _server        - uvicorn.Server instance, assigned inside the server thread
#   _server_thread - daemon thread that runs the server's event loop
#   _server_loop   - asyncio event loop created by that thread
_server = _server_thread = _server_loop = None
# Host and port passed to the most recent serve() call.
_server_host = _server_port = None
|
28
|
+
|
29
|
+
|
30
|
+
def create_app(
    config_path: str = None,
    storage_path: str = None,
    inference_gpus: str = None,
    training_gpus: str = None,
):
    """Create and configure the Arbor API application.

    Builds a ``Config`` either from a YAML file or directly from a storage
    path and GPU ids. Creates a timestamped log directory, instantiates every
    service manager, and attaches them to the shared FastAPI ``app`` imported
    from ``arbor.server.main``. Calling it again overwrites the previously
    attached state.

    Args:
        config_path (str): Path to a YAML config file. When given, the other
            arguments are ignored.
        storage_path (str): Path to the storage directory; only used when
            ``config_path`` is not given.
        inference_gpus (str): GPU ids to use for inference, e.g. ``"0"``.
            Required when ``config_path`` is not given.
        training_gpus (str): GPU ids to use for training, e.g. ``"1,2"``.
            Required when ``config_path`` is not given.

    Returns:
        FastAPI: The shared application object, now configured.

    Raises:
        ValueError: If neither ``config_path`` nor both ``inference_gpus`` and
            ``training_gpus`` are provided.
    """
    # Create new config instance with overrides
    if config_path:
        config = Config.load_config_from_yaml(config_path)
    elif inference_gpus and training_gpus:
        config = Config.load_config_directly(
            storage_path, inference_gpus, training_gpus
        )
    else:
        raise ValueError(
            "Either 'config_path' must be provided, or 'inference_gpus', and 'training_gpus' must be provided"
        )

    app.state.log_dir = Config.make_log_dir(config.STORAGE_PATH)

    # Initialize services with config
    health_manager = HealthManager(config=config)
    file_manager = FileManager(config=config)
    job_manager = JobManager(config=config)
    training_manager = TrainingManager(config=config)
    inference_manager = InferenceManager(config=config)
    grpo_manager = GRPOManager(config=config)

    # Inject config into app state
    app.state.config = config
    app.state.file_manager = file_manager
    app.state.job_manager = job_manager
    app.state.training_manager = training_manager
    app.state.inference_manager = inference_manager
    app.state.grpo_manager = grpo_manager
    app.state.health_manager = health_manager

    return app
|
79
|
+
|
80
|
+
|
81
|
+
def serve(
    config_path: str = None,
    storage_path: str = None,
    inference_gpus: str = None,
    training_gpus: str = None,
    host: str = "0.0.0.0",
    port: int = 7453,
):
    """Start the Arbor API server.

    Starts the server in a background (daemon) thread and returns once the
    server is ready to accept requests. Use arbor.stop() to shutdown the server.

    Args:
        config_path: Path to YAML config file. Required unless both
            ``inference_gpus`` and ``training_gpus`` are given.
        storage_path: Storage directory path (optional; only used when
            ``config_path`` is not given).
        inference_gpus: GPU IDs for inference, e.g. "0,1". Required when
            ``config_path`` is not given.
        training_gpus: GPU IDs for training, e.g. "1,2,3". Required when
            ``config_path`` is not given.
        host: Host to bind to (default: "0.0.0.0")
        port: Port to bind to (default: 7453)

    Raises:
        ValueError: If neither ``config_path`` nor both GPU id strings are
            given (raised by ``create_app`` before any thread is started).
        TimeoutError: If the server is not ready within 60 seconds; the
            server is stopped before the error is re-raised.

    Example:
        import arbor
        arbor.serve(inference_gpus="0", training_gpus="1,2")
        # Server is now ready to accept requests
        # Later, to stop:
        arbor.stop()
    """
    global _server, _server_thread, _server_loop, _server_host, _server_port

    # Stop existing server if running
    if _server is not None:
        print("🌳 Stopping existing server...")
        stop()

    _server_host = host
    _server_port = port

    # Configure the shared app before starting the thread so configuration
    # errors surface directly to the caller.
    create_app(config_path, storage_path, inference_gpus, training_gpus)

    def run_server():
        global _server, _server_loop

        # Give the server its own event loop in this background thread.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        _server_loop = loop

        # uvicorn's ``loop`` setting names a loop implementation ("auto",
        # "asyncio", "uvloop") and must not be given a loop object; awaiting
        # Server.serve() simply runs on the loop that drives it.
        server = uvicorn.Server(uvicorn.Config(app, host=host, port=port))
        _server = server

        try:
            loop.run_until_complete(server.serve())
        except Exception as e:
            print(f"Server error: {e}")
        finally:
            loop.close()
            # A newer serve() call may already have installed a different
            # server; only clear globals that still refer to this thread's.
            if _server is server:
                _server = None
            if _server_loop is loop:
                _server_loop = None

    # Start server thread
    _server_thread = threading.Thread(target=run_server, daemon=True)
    _server_thread.start()

    print(f"🌳 Starting Arbor server on http://{host}:{port}...")

    # Wait for server to be ready
    try:
        _wait_for_server_ready(host, port, timeout=60)
        print("🌳 Arbor server is ready and accepting requests!")
    except TimeoutError as e:
        print(f"❌ {e}")
        # Try to stop the server if it failed to start properly
        stop()
        raise
|
160
|
+
|
161
|
+
|
162
|
+
def stop():
    """Stop the Arbor server if it's running.

    Signals the uvicorn server to exit and waits up to 5 seconds for the
    server thread to finish. Then clears the module-level server state.
    Prints a message and returns if no server is running.
    """
    global _server, _server_thread, _server_loop

    if _server is None:
        print("🌳 No server running to stop.")
        return

    print("🌳 Stopping Arbor server...")

    # Ask uvicorn to exit: Server.serve() polls ``should_exit`` in its main
    # loop and then performs its own graceful shutdown. Awaiting shutdown()
    # from another thread closes the sockets but never sets this flag, so the
    # main loop (and therefore the thread) would keep running.
    _server.should_exit = True

    # Wait for thread to finish (never join the current thread, e.g. when
    # stop() is invoked from inside the server).
    thread = _server_thread
    if (
        thread is not None
        and thread.is_alive()
        and thread is not threading.current_thread()
    ):
        thread.join(timeout=5)
        if thread.is_alive():
            print("🌳 Server thread did not exit within 5 seconds.")

    # Reset global state
    _server = None
    _server_thread = None
    _server_loop = None

    print("🌳 Arbor server stopped.")
|
189
|
+
|
190
|
+
|
191
|
+
def is_running():
    """Report whether the background Arbor server is alive.

    Returns:
        bool: True only when a server object has been created and its
        thread is still alive; False otherwise.
    """
    thread = _server_thread
    if _server is None or thread is None:
        return False
    return thread.is_alive()
|
196
|
+
|
197
|
+
|
198
|
+
def _probe_host(host):
    """Return an address a local client can connect to for a server bound to *host*.

    Wildcard bind addresses accept connections on every interface but are
    not valid connect targets on every platform, so they are mapped to the
    matching loopback address; any other host is returned unchanged.
    """
    if host in ("", "0.0.0.0"):
        return "127.0.0.1"
    if host == "::":
        return "::1"
    return host


def _wait_for_server_ready(host, port, timeout=30, poll_interval=0.5):
    """Wait for the server to be ready to accept requests.

    First waits until *port* accepts TCP connections, then polls the
    ``/health/simple`` endpoint until it answers with HTTP 200.

    Args:
        host: Host the server was bound to; wildcard addresses are probed
            through the loopback interface.
        port: Port the server listens on.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between attempts.

    Raises:
        TimeoutError: If the server is not ready within *timeout* seconds.
            The message includes the last error observed.
    """
    probe_host = _probe_host(host)
    # IPv6 literals must be bracketed inside URLs.
    url_host = f"[{probe_host}]" if ":" in probe_host else probe_host
    health_url = f"http://{url_host}:{port}/health/simple"

    start_time = time.monotonic()
    deadline = start_time + timeout
    next_report = start_time + 5
    last_error = None
    port_open = False

    print(f"🌳 Waiting for server to be ready at http://{url_host}:{port}...")

    while time.monotonic() < deadline:
        # First check if the port is open
        if not port_open:
            try:
                # create_connection handles IPv4 and IPv6, and the context
                # manager closes the probe socket even when connecting fails.
                with socket.create_connection((probe_host, port), timeout=1):
                    pass
            except OSError as e:
                last_error = f"Port {port} not yet open ({e})"
            else:
                port_open = True
                print(f"🌳 Port {port} is now open, checking health endpoint...")

        # Now try the health check
        if port_open:
            try:
                response = requests.get(health_url, timeout=2)
                if response.status_code == 200:
                    # A healthy server must not be rejected just because the
                    # body is not JSON.
                    try:
                        body = response.json()
                    except ValueError:
                        body = response.text
                    print(f"🌳 Server ready! Response: {body}")
                    return
                last_error = f"Health check returned status {response.status_code}"
            except requests.exceptions.ConnectionError as e:
                last_error = f"Connection error: {e}"
                port_open = False  # Port might have closed
            except requests.exceptions.Timeout as e:
                last_error = f"Timeout error: {e}"
            except requests.exceptions.RequestException as e:
                last_error = f"Request error: {e}"
            except Exception as e:
                last_error = f"Unexpected error: {e}"

        # Report progress at most once per 5 seconds, including while the
        # port is still closed.
        now = time.monotonic()
        if now >= next_report:
            print(
                f"🌳 Still waiting... ({now - start_time:.1f}s elapsed, port_open={port_open}, last error: {last_error})"
            )
            next_report = now + 5

        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

    raise TimeoutError(
        f"Server did not become ready within {timeout} seconds. Last error: {last_error}"
    )
|
256
|
+
|
257
|
+
|
258
|
+
if __name__ == "__main__":
    # serve() returns as soon as the server is ready and runs it in a daemon
    # thread, so keep the main thread alive until the server exits or the
    # user presses Ctrl-C; otherwise the process would exit immediately.
    serve(inference_gpus="0, 1", training_gpus="2, 3")
    try:
        while is_running():
            time.sleep(1)
    except KeyboardInterrupt:
        stop()
|
arbor/server/api/models/schemas.py
CHANGED
@@ -178,7 +178,7 @@ class ChatCompletionModel(BaseModel):
|
|
178
178
|
|
179
179
|
class GRPORequest(BaseModel):
    # Model identifier for this GRPO step.
    model: str
    # Training samples: either a flat list of dicts or a list of lists of
    # dicts (presumably one inner list per group — confirm with the GRPO
    # trainers that consume it).
    batch: List[dict] | List[List[dict]]
|
182
182
|
|
183
183
|
|
184
184
|
class GRPOConfigRequest(BaseModel):
|
@@ -205,6 +205,8 @@ class GRPOConfigRequest(BaseModel):
|
|
205
205
|
# Arbor specific
|
206
206
|
max_context_length: Optional[int] = None
|
207
207
|
lora: Optional[bool] = None
|
208
|
+
grpo_flavor: Optional[Literal["grpo", "mmgrpo"]] = None
|
209
|
+
wandb_kwargs: Optional[dict] = None
|
208
210
|
# To name the run
|
209
211
|
suffix: Optional[str] = None
|
210
212
|
generation_batch_size: Optional[int] = None
|
arbor/server/api/routes/grpo.py
CHANGED
@@ -10,7 +10,6 @@ from arbor.server.api.models.schemas import (
|
|
10
10
|
GRPOConfigResponse,
|
11
11
|
GRPORequest,
|
12
12
|
GRPOStepResponse,
|
13
|
-
GRPOTerminateRequest,
|
14
13
|
GRPOTerminateResponse,
|
15
14
|
)
|
16
15
|
|
@@ -27,12 +26,9 @@ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
|
|
27
26
|
|
28
27
|
# Create a grpo job
|
29
28
|
@router.post("/step", response_model=GRPOStepResponse)
def run_grpo_step(request: Request, grpo_request: GRPORequest):
    # Delegate one GRPO step to the app-wide GRPO manager, passing along the
    # inference manager, and wrap the returned dict in the response model.
    # (Comments instead of a docstring: FastAPI publishes docstrings as the
    # endpoint's OpenAPI description.)
    state = request.app.state
    step_data = state.grpo_manager.grpo_step(grpo_request, state.inference_manager)
    return GRPOStepResponse(status="success", **step_data)
|
arbor/server/api/routes/inference.py
CHANGED
@@ -3,6 +3,10 @@ import uuid
|
|
3
3
|
|
4
4
|
from fastapi import APIRouter, Request
|
5
5
|
|
6
|
+
from arbor.server.utils.logging import get_logger
|
7
|
+
|
8
|
+
logger = get_logger(__name__)
|
9
|
+
|
6
10
|
router = APIRouter()
|
7
11
|
|
8
12
|
|
@@ -27,17 +31,17 @@ async def run_inference(
|
|
27
31
|
|
28
32
|
# if a server isnt running, launch one
|
29
33
|
if not inference_manager.is_server_running():
|
30
|
-
|
34
|
+
logger.info("No model is running, launching model...")
|
31
35
|
inference_manager.launch(request_model)
|
32
36
|
|
33
37
|
# if the requested model is different from the launched model, swap the server
|
34
38
|
if request_model != inference_manager.launched_model:
|
35
|
-
|
39
|
+
logger.info(
|
36
40
|
f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
|
37
41
|
)
|
38
42
|
inference_manager.kill()
|
39
43
|
inference_manager.launch(request_model)
|
40
|
-
|
44
|
+
logger.info(f"Model swapped to {request_model}")
|
41
45
|
|
42
46
|
# forward the request to the inference server
|
43
47
|
completion = await inference_manager.run_inference(raw_json)
|