arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +89 -5
- arbor/client/api.py +1 -2
- arbor/server/api/models/schemas.py +209 -5
- arbor/server/api/routes/files.py +39 -10
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +110 -7
- arbor/server/core/config.py +44 -7
- arbor/server/main.py +6 -5
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -16
- arbor/server/services/file_manager.py +270 -109
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +74 -69
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +337 -40
- arbor_ai-0.1.6.dist-info/METADATA +78 -0
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +2 -1
- arbor_ai-0.1.6.dist-info/entry_points.txt +2 -0
- arbor_ai-0.1.6.dist-info/top_level.txt +1 -0
- arbor/server/api/routes/training.py +0 -16
- arbor_ai-0.1.4.dist-info/METADATA +0 -97
- arbor_ai-0.1.4.dist-info/RECORD +0 -27
- arbor_ai-0.1.4.dist-info/entry_points.txt +0 -3
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info/licenses}/LICENSE +0 -0
arbor/cli.py
CHANGED
```diff
@@ -1,17 +1,101 @@
 import click
 import uvicorn
+
+from arbor.server.core.config import Settings
 from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+
 
 @click.group()
 def cli():
     pass
 
+
+def create_app(arbor_config_path: str):
+    """Create and configure the Arbor API application
+
+    Args:
+        storage_path (str): Path to store models and uploaded training files
+
+    Returns:
+        FastAPI: Configured FastAPI application
+    """
+    # Create new settings instance with overrides
+    settings = Settings.load_from_yaml(arbor_config_path)
+
+    # Initialize services with settings
+    file_manager = FileManager(settings=settings)
+    job_manager = JobManager(settings=settings)
+    training_manager = TrainingManager(settings=settings)
+    inference_manager = InferenceManager(settings=settings)
+    grpo_manager = GRPOManager(settings=settings)
+    # Inject settings into app state
+    app.state.settings = settings
+    app.state.file_manager = file_manager
+    app.state.job_manager = job_manager
+    app.state.training_manager = training_manager
+    app.state.inference_manager = inference_manager
+    app.state.grpo_manager = grpo_manager
+
+    return app
+
+
+def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
+    """Start the Arbor API server with a single function call"""
+    import socket
+    import threading
+    import time
+    from contextlib import closing
+
+    def is_port_in_use(port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            return sock.connect_ex(("localhost", port)) == 0
+
+    # First ensure the port is free
+    if is_port_in_use(port):
+        raise RuntimeError(f"Port {port} is already in use")
+
+    app = create_app(storage_path)
+    config = uvicorn.Config(app, host=host, port=port, log_level="info")
+    server = uvicorn.Server(config)
+
+    def run_server():
+        server.run()
+
+    thread = threading.Thread(target=run_server, daemon=True)
+    thread.start()
+
+    # Wait for server to start
+    start_time = time.time()
+    while not is_port_in_use(port):
+        if time.time() - start_time > timeout:
+            raise TimeoutError(f"Server failed to start within {timeout} seconds")
+        time.sleep(0.1)
+
+    # Give it a little extra time to fully initialize
+    time.sleep(0.5)
+
+    return server
+
+
+def stop_server(server):
+    """Stop the Arbor API server"""
+    server.should_exit = True
+
+
 @cli.command()
-@click.option(
-@click.option(
-
+@click.option("--host", default="0.0.0.0", help="Host to bind to")
+@click.option("--port", default=7453, help="Port to bind to")
+@click.option("--arbor-config", required=True, help="Path to the Arbor config file")
+def serve(host, port, arbor_config):
     """Start the Arbor API server"""
+    app = create_app(arbor_config)
     uvicorn.run(app, host=host, port=port)
 
-
-
+
+if __name__ == "__main__":
+    cli()
```
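The new `create_app`/`start_server`/`stop_server` helpers make the server embeddable from Python as well as the CLI. A minimal sketch of driving them from a script (note, as an observation on the diff itself, that `start_server` still forwards its `storage_path` argument to `create_app`, which in 0.1.6 expects the path to a YAML config file):

```python
# Minimal sketch: run Arbor in-process. start_server() forwards storage_path
# to create_app(), which now takes an Arbor YAML config path.
from arbor.cli import start_server, stop_server

server = start_server(port=7453, storage_path="arbor.yaml")
try:
    # ... issue requests against http://localhost:7453 ...
    pass
finally:
    stop_server(server)  # sets server.should_exit; the daemon thread winds down
```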
arbor/client/api.py
CHANGED
```diff
@@ -1,2 +1 @@
-
-
+# Unused Right Now
```
arbor/server/api/models/schemas.py
CHANGED
```diff
@@ -1,6 +1,19 @@
-from
+from enum import Enum
+from typing import Any, Generic, List, Literal, Optional, TypeVar
 
-
+from pydantic import BaseModel, ConfigDict
+
+# Generic type for list items
+T = TypeVar("T")
+
+
+class PaginatedResponse(BaseModel, Generic[T]):
+    object: str = "list"
+    data: List[T]
+    has_more: bool = False
+
+
+class FileModel(BaseModel):
     id: str
     object: str = "file"
     bytes: int
@@ -8,12 +21,203 @@ class FileResponse(BaseModel):
     filename: str
     purpose: str
 
+
+class WandbConfig(BaseModel):
+    project: str
+    name: Optional[str] = None
+    entity: Optional[str] = None
+    tags: Optional[List[str]] = None
+
+
+class IntegrationModel(BaseModel):
+    type: str
+    wandb: WandbConfig
+
+
 class FineTuneRequest(BaseModel):
     model: str
     training_file: str  # id of uploaded jsonl file
+    method: dict
+    suffix: Optional[str] = None
+    # UNUSED
+    validation_file: Optional[str] = None
+    integrations: Optional[List[IntegrationModel]] = []
+    seed: Optional[int] = None
+
+
+class ErrorModel(BaseModel):
+    code: str
+    message: str
+    param: str | None = None
+
+
+class SupervisedHyperparametersModel(BaseModel):
+    batch_size: int | str = "auto"
+    learning_rate_multiplier: float | str = "auto"
+    n_epochs: int | str = "auto"
+
+
+class DPOHyperparametersModel(BaseModel):
+    beta: float | str = "auto"
+    batch_size: int | str = "auto"
+    learning_rate_multiplier: float | str = "auto"
+    n_epochs: int | str = "auto"
 
-
+
+class SupervisedModel(BaseModel):
+    hyperparameters: SupervisedHyperparametersModel
+
+
+class DpoModel(BaseModel):
+    hyperparameters: DPOHyperparametersModel
+
+
+class MethodModel(BaseModel):
+    type: Literal["supervised"] | Literal["dpo"]
+    supervised: SupervisedModel | None = None
+    dpo: DpoModel | None = None
+
+
+# https://platform.openai.com/docs/api-reference/fine-tuning/object
+class JobStatus(Enum):
+    PENDING = "pending"  # Not in OAI
+    PENDING_PAUSE = "pending_pause"  # Not in OAI
+    PENDING_RESUME = "pending_resume"  # Not in OAI
+    PAUSED = "paused"  # Not in OAI
+    VALIDATING_FILES = "validating_files"
+    QUEUED = "queued"
+    RUNNING = "running"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+    PENDING_CANCEL = "pending_cancel"
+
+
+# https://platform.openai.com/docs/api-reference/fine-tuning/object
+class JobStatusModel(BaseModel):
+    object: str = "fine_tuning.job"
+    id: str
+    fine_tuned_model: str | None = None
+    status: JobStatus
+
+    # UNUSED so commented out
+    # model: str
+    # created_at: int
+    # error: ErrorModel | None = None
+    # details: str = ""
+    # finished_at: int
+    # hyperparameters: None  # deprecated in OAI
+    # organization_id: str
+    # result_files: list[str]
+    # trained_tokens: int | None = None  # None if not finished
+    # training_file: str
+    # validation_file: str
+    # integrations: list[Integration]
+    # seed: int
+    # estimated_finish: int | None = None  # The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
+    # method: MethodModel
+    # metadata: dict[str, str]
+
+
+class JobEventModel(BaseModel):
+    object: str = "fine_tuning.job_event"
     id: str
+    created_at: int
+    level: str
+    message: str
+    data: dict[str, Any]
+    type: str
+
+
+class MetricsModel(BaseModel):
+    step: int
+    train_loss: float
+    train_mean_token_accuracy: float
+    valid_loss: float
+    valid_mean_token_accuracy: float
+    full_valid_loss: float
+    full_valid_mean_token_accuracy: float
+
+
+class JobCheckpointModel(BaseModel):
+    object: str = "fine_tuning.job_checkpoint"
+    id: str
+    created_at: int
+    fine_tuned_model_checkpoint: str
+    step_number: int
+    metrics: MetricsModel
+    fine_tuning_job_id: str
+
+
+class ChatCompletionMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatCompletionMessage]
+    temperature: float | None = None
+    top_p: float | None = None
+    max_tokens: int | None = None
+
+
+class ChatCompletionChoice(BaseModel):
+    message: ChatCompletionMessage
+    index: int
+    finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+class ChatCompletionModel(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+
+
+class GRPORequest(BaseModel):
+    model: str
+    update_inference_model: bool
+    batch: List[dict]
+
+
+class GRPOConfigRequest(BaseModel):
+    model: str
+    temperature: Optional[float] = None
+    beta: Optional[float] = None
+    num_iterations: Optional[int] = None
+    num_generations: Optional[int] = None
+    per_device_train_batch_size: Optional[int] = None
+    learning_rate: Optional[float] = None
+    gradient_accumulation_steps: Optional[int] = None
+    gradient_checkpointing: Optional[bool] = None
+    lr_scheduler_type: Optional[str] = None
+    max_prompt_length: Optional[int] = None
+    max_completion_length: Optional[int] = None
+    gradient_checkpointing_kwargs: Optional[dict] = {}
+    bf16: Optional[bool] = None
+    scale_rewards: Optional[bool] = None
+    max_grad_norm: Optional[float] = None
+    lora: Optional[bool] = None
+    update_interval: Optional[int] = None
+    # To name the run
+    suffix: Optional[str] = None
+
+
+class GRPOConfigResponse(BaseModel):
+    status: str
+
+
+class GRPOTerminateRequest(BaseModel):
+    status: Optional[str] = "success"
+
+
+class GRPOTerminateResponse(BaseModel):
+    status: str
+    current_model: str
+
+
+class GRPOStepResponse(BaseModel):
     status: str
-
-    fine_tuned_model: str | None = None
+    current_model: str
```
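`PaginatedResponse[T]` is the new generic envelope mirroring OpenAI's `{"object": "list", "data": [...], "has_more": false}` list shape. A minimal sketch with a stand-in item model (the `Item` class is hypothetical, for illustration only):

```python
# Sketch: the generic list envelope in use, with a made-up payload model.
from pydantic import BaseModel

from arbor.server.api.models.schemas import PaginatedResponse

class Item(BaseModel):  # stand-in payload type, not part of arbor
    name: str

page = PaginatedResponse[Item](data=[Item(name="a"), Item(name="b")], has_more=False)
print(page.model_dump())
# {'object': 'list', 'data': [{'name': 'a'}, {'name': 'b'}], 'has_more': False}
```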
arbor/server/api/routes/files.py
CHANGED
```diff
@@ -1,23 +1,52 @@
-from
-
-from
-
+from typing import Literal
+
+from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile
+
+from arbor.server.api.models.schemas import FileModel, PaginatedResponse
 from arbor.server.services.file_manager import FileValidationError
 
+# https://platform.openai.com/docs/api-reference/files/list
 router = APIRouter()
 
-
+
+@router.post("", response_model=FileModel)
 async def upload_file(
+    request: Request,
     file: UploadFile = File(...),
-
+    purpose: Literal["assistants", "vision", "fine-tune", "batch"] = Body("fine-tune"),
 ):
-
+    file_manager = request.app.state.file_manager
+    if not file.filename.endswith(".jsonl"):
         raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")
 
     try:
         content = await file.read()
-        file_manager.validate_file_format(content)
+        # file_manager.validate_file_format(content)  # TODO: add another flag to specify the types of files
         await file.seek(0)  # Reset file pointer to beginning
-        return file_manager.save_uploaded_file(file)
+        return FileModel(**file_manager.save_uploaded_file(file))
     except FileValidationError as e:
-        raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+
+
+@router.get("", response_model=PaginatedResponse[FileModel])
+def list_files(request: Request):
+    file_manager = request.app.state.file_manager
+    return PaginatedResponse(
+        items=file_manager.get_files(),
+        total=len(file_manager.get_files()),
+        page=1,
+        page_size=10,
+    )
+
+
+@router.get("/{file_id}", response_model=FileModel)
+def get_file(request: Request, file_id: str):
+    file_manager = request.app.state.file_manager
+    return file_manager.get_file(file_id)
+
+
+@router.delete("/{file_id}")
+def delete_file(request: Request, file_id: str):
+    file_manager = request.app.state.file_manager
+    file_manager.delete_file(file_id)
+    return {"message": "File deleted"}
```
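Worth noting: `list_files` builds `PaginatedResponse(items=..., total=..., page=..., page_size=...)` while the schema above declares `data` and `has_more`, so the upload route is the one exercised below. A sketch of uploading a training file; the `/v1/files` mount prefix is an assumption, since the router is mounted in `arbor/server/main.py`, which this diff section doesn't show:

```python
# Sketch: upload a .jsonl training file, sending purpose as a form field.
# The /v1/files prefix is assumed; the router only defines "" and "/{file_id}".
import requests

with open("train.jsonl", "rb") as f:
    resp = requests.post(
        "http://localhost:7453/v1/files",
        files={"file": ("train.jsonl", f)},
        data={"purpose": "fine-tune"},
    )
resp.raise_for_status()
file_id = resp.json()["id"]  # used later when creating a fine-tune job
```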
arbor/server/api/routes/grpo.py
ADDED
```diff
@@ -0,0 +1,54 @@
+import os
+import subprocess
+
+from fastapi import APIRouter, BackgroundTasks, Request
+
+from arbor.server.api.models.schemas import (
+    GRPOConfigRequest,
+    GRPOConfigResponse,
+    GRPORequest,
+    GRPOStepResponse,
+    GRPOTerminateRequest,
+    GRPOTerminateResponse,
+)
+
+router = APIRouter()
+
+
+@router.post("/initialize", response_model=GRPOConfigResponse)
+def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
+    inference_manager = request.app.state.inference_manager
+    grpo_manager = request.app.state.grpo_manager
+    grpo_manager.initialize(grpo_config_request, inference_manager)
+    return GRPOConfigResponse(status="success")
+
+
+# Create a grpo job
+@router.post("/step", response_model=GRPOStepResponse)
+def run_grpo_step(
+    request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
+):
+    inference_manager = request.app.state.inference_manager
+    grpo_manager = request.app.state.grpo_manager
+
+    current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+
+    return GRPOStepResponse(status="success", current_model=current_model)
+
+
+@router.post("/update_model", response_model=GRPOStepResponse)
+def update_model(request: Request):
+    grpo_manager = request.app.state.grpo_manager
+    inference_manager = request.app.state.inference_manager
+    current_model = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", current_model=current_model)
+
+
+@router.post("/terminate", response_model=GRPOTerminateResponse)
+def terminate_grpo(request: Request):
+    # No body needed for this request at this moment
+    grpo_manager = request.app.state.grpo_manager
+    inference_manager = request.app.state.inference_manager
+
+    final_model = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", current_model=final_model)
```
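The intended lifecycle is initialize → step (repeated) → terminate, with `update_model` available to force a weight sync. A hedged sketch of that flow over HTTP; the mount prefix and the shape of each `batch` element (just `List[dict]` in the schema) are assumptions:

```python
# Sketch of the GRPO lifecycle over HTTP. BASE is a hypothetical mount
# prefix; the real one lives in arbor/server/main.py, not shown in this diff.
import requests

BASE = "http://localhost:7453/v1/fine_tuning/grpo"  # assumed prefix
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"  # example model name

# 1. Configure the trainer (GRPOConfigRequest fields are optional overrides)
requests.post(f"{BASE}/initialize", json={"model": MODEL, "num_generations": 8})

# 2. Repeatedly send scored rollouts; each step may advance the policy
batch = [{"prompt": "...", "completion": "...", "reward": 1.0}]  # element shape assumed
step = requests.post(
    f"{BASE}/step",
    json={"model": MODEL, "update_inference_model": True, "batch": batch},
).json()
print(step["current_model"])

# 3. Tear down and get the final checkpoint name
final = requests.post(f"{BASE}/terminate").json()
print(final["current_model"])
```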
arbor/server/api/routes/inference.py
ADDED
```diff
@@ -0,0 +1,53 @@
+import time
+
+from fastapi import APIRouter, Request
+
+router = APIRouter()
+
+
+@router.post("/completions")
+async def run_inference(
+    request: Request,
+):
+    inference_manager = request.app.state.inference_manager
+    raw_json = await request.json()
+
+    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
+    for prefix in prefixes:
+        if raw_json["model"].startswith(prefix):
+            raw_json["model"] = raw_json["model"][len(prefix) :]
+
+    # if a server isnt running, launch one
+    if (
+        not inference_manager.is_server_running()
+        and not inference_manager.is_server_restarting()
+    ):
+        print("No model is running, launching model...")
+        inference_manager.launch(raw_json["model"])
+
+    if inference_manager.is_server_restarting():
+        print("Waiting for server to finish restarting...")
+        while inference_manager.is_server_restarting():
+            time.sleep(5)
+        # Update the model in the request
+        raw_json["model"] = inference_manager.current_model
+
+    # forward the request to the inference server
+    completion = inference_manager.run_inference(raw_json)
+
+    return completion
+
+
+@router.post("/launch")
+async def launch_inference(request: Request):
+    inference_manager = request.app.state.inference_manager
+    raw_json = await request.json()
+    inference_manager.launch(raw_json["model"], raw_json["launch_kwargs"])
+    return {"message": "Inference server launched"}
+
+
+@router.post("/kill")
+async def kill_inference(request: Request):
+    inference_manager = request.app.state.inference_manager
+    inference_manager.kill()
+    return {"message": "Inference server killed"}
```
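Before proxying, `run_inference` normalizes the requested model name by stripping recognized scheme prefixes. That step, mirrored in isolation:

```python
# The model-name normalization from run_inference(), mirrored in isolation:
# each recognized prefix is stripped if the name starts with it.
prefixes = ["openai/", "huggingface/", "local:", "arbor:"]

def normalize(model: str) -> str:
    for prefix in prefixes:
        if model.startswith(prefix):
            model = model[len(prefix):]
    return model

assert normalize("huggingface/Qwen/Qwen2.5-0.5B") == "Qwen/Qwen2.5-0.5B"
assert normalize("arbor:ft:qwen2.5") == "ft:qwen2.5"
assert normalize("plain-model") == "plain-model"  # unrecognized names pass through
```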
arbor/server/api/routes/jobs.py
CHANGED
```diff
@@ -1,14 +1,117 @@
-from fastapi import APIRouter,
-
-from arbor.server.
-
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
+
+from arbor.server.api.models.schemas import (
+    FineTuneRequest,
+    JobCheckpointModel,
+    JobEventModel,
+    JobStatus,
+    JobStatusModel,
+    PaginatedResponse,
+)
+from arbor.server.services.job_manager import JobStatus
 
 router = APIRouter()
 
-
+
+# Create a fine-tune job
+@router.post("", response_model=JobStatusModel)
+def create_fine_tune_job(
+    request: Request,
+    fine_tune_request: FineTuneRequest,
+    background_tasks: BackgroundTasks,
+):
+    job_manager = request.app.state.job_manager
+    file_manager = request.app.state.file_manager
+    training_manager = request.app.state.training_manager
+
+    job = job_manager.create_job()
+    background_tasks.add_task(
+        training_manager.fine_tune, fine_tune_request, job, file_manager
+    )
+    job.status = JobStatus.QUEUED
+    return JobStatusModel(id=job.id, status=job.status.value)
+
+
+# List fine-tune jobs (paginated)
+@router.get("", response_model=PaginatedResponse[JobStatusModel])
+def get_jobs(request: Request):
+    job_manager = request.app.state.job_manager
+    return PaginatedResponse(
+        data=[
+            JobStatusModel(id=job.id, status=job.status.value)
+            for job in job_manager.get_jobs()
+        ],
+        has_more=False,
+    )
+
+
+# List fine-tuning events
+@router.get("/{job_id}/events", response_model=PaginatedResponse[JobEventModel])
+def get_job_events(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+    return PaginatedResponse(
+        data=[
+            JobEventModel(
+                id=event.id,
+                level=event.level,
+                message=event.message,
+                data=event.data,
+                created_at=int(event.created_at.timestamp()),
+                type="message",
+            )
+            for event in job.get_events()
+        ],
+        has_more=False,
+    )
+
+
+# List fine-tuning checkpoints
+@router.get(
+    "/{job_id}/checkpoints", response_model=PaginatedResponse[JobCheckpointModel]
+)
+def get_job_checkpoints(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+    return PaginatedResponse(
+        data=[
+            JobCheckpointModel(
+                id=checkpoint.id,
+                fine_tuned_model_checkpoint=checkpoint.fine_tuned_model_checkpoint,
+                fine_tuning_job_id=checkpoint.fine_tuning_job_id,
+                metrics=checkpoint.metrics,
+                step_number=checkpoint.step_number,
+            )
+            for checkpoint in job.get_checkpoints()
+        ],
+        has_more=False,
+    )
+
+
+# Retrieve a fine-tune job by id
+@router.get("/{job_id}", response_model=JobStatusModel)
 def get_job_status(
+    request: Request,
     job_id: str,
-    job_manager: JobManager = Depends(get_job_manager)
 ):
+    job_manager = request.app.state.job_manager
     job = job_manager.get_job(job_id)
-    return
+    return JobStatusModel(
+        id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model
+    )
+
+
+# Cancel a fine-tune job
+@router.post("/{job_id}/cancel", response_model=JobStatusModel)
+def cancel_job(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+
+    # Only allow cancellation of jobs that aren't finished
+    if job.status in [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED]:
+        raise HTTPException(
+            status_code=400, detail=f"Cannot cancel job with status {job.status.value}"
+        )
+
+    job.status = JobStatus.PENDING_CANCEL
+    return JobStatusModel(id=job.id, status=job.status.value)
```
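A sketch of the resulting job lifecycle from a client: create the job, then poll `GET /{job_id}` until it reaches a terminal state. The `/v1/fine_tuning/jobs` mount prefix is an assumption, and `method` is a free-form `dict` in `FineTuneRequest`, so its shape below merely follows the `MethodModel` schema:

```python
# Sketch: create a fine-tune job and poll to a terminal status.
# BASE is a hypothetical mount prefix (set in arbor/server/main.py).
import time
import requests

BASE = "http://localhost:7453/v1/fine_tuning/jobs"  # assumed prefix
file_id = "file-abc123"  # id returned by the files upload sketch above

job = requests.post(
    BASE,
    json={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # example model name
        "training_file": file_id,
        "method": {"type": "supervised", "supervised": {"hyperparameters": {}}},
    },
).json()

while True:
    status = requests.get(f"{BASE}/{job['id']}").json()
    if status["status"] in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(5)
print(status.get("fine_tuned_model"))
```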
arbor/server/core/config.py
CHANGED
```diff
@@ -1,10 +1,47 @@
-from
+from pathlib import Path
+from typing import Optional
 
-
-
-    MODEL_CACHE_DIR: str = "model_cache"
+import yaml
+from pydantic import BaseModel, ConfigDict
 
-    class Config:
-        env_file = ".env"
 
-
+class InferenceConfig(BaseModel):
+    gpu_ids: str = "0"
+
+
+class TrainingConfig(BaseModel):
+    gpu_ids: str = "0"
+    accelerate_config: Optional[str] = None
+
+
+class ArborConfig(BaseModel):
+    inference: InferenceConfig
+    training: TrainingConfig
+
+
+class Settings(BaseModel):
+
+    STORAGE_PATH: str = "./storage"
+    INACTIVITY_TIMEOUT: int = 30  # 5 seconds
+    arbor_config: ArborConfig
+
+    @classmethod
+    def load_from_yaml(cls, yaml_path: str) -> "Settings":
+        if not yaml_path:
+            raise ValueError("Config file path is required")
+        if not Path(yaml_path).exists():
+            raise ValueError(f"Config file {yaml_path} does not exist")
+
+        try:
+            with open(yaml_path, "r") as f:
+                config = yaml.safe_load(f)
+
+            settings = cls(
+                arbor_config=ArborConfig(
+                    inference=InferenceConfig(**config["inference"]),
+                    training=TrainingConfig(**config["training"]),
+                )
+            )
+            return settings
+        except Exception as e:
+            raise ValueError(f"Error loading config file {yaml_path}: {e}")
```
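A sketch of the YAML shape `load_from_yaml` expects: top-level `inference` and `training` mappings that feed `InferenceConfig` and `TrainingConfig` (field names taken from the classes above; the GPU assignments are example values):

```python
# Sketch: write a minimal config and load it through Settings.load_from_yaml().
from pathlib import Path

from arbor.server.core.config import Settings

Path("arbor.yaml").write_text(
    "inference:\n"
    "  gpu_ids: '0'\n"       # GPUs reserved for the inference server
    "training:\n"
    "  gpu_ids: '1,2'\n"     # GPUs reserved for training
)
settings = Settings.load_from_yaml("arbor.yaml")
print(settings.arbor_config.training.gpu_ids)  # -> 1,2
```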