arbor-ai 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- arbor/client/__init__.py +0 -0
- arbor/client/api.py +1 -0
- arbor/server/__init__.py +1 -0
- arbor/server/api/__init__.py +1 -0
- arbor/server/api/models/schemas.py +223 -0
- arbor/server/api/routes/__init__.py +0 -0
- arbor/server/api/routes/files.py +52 -0
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +117 -0
- arbor/server/core/__init__.py +1 -0
- arbor/server/core/config.py +47 -0
- arbor/server/core/logging.py +0 -0
- arbor/server/main.py +11 -0
- arbor/server/services/__init__.py +0 -0
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -0
- arbor/server/services/file_manager.py +289 -0
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +81 -0
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +561 -0
- arbor/server/utils/__init__.py +0 -0
- arbor/server/utils/helpers.py +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.7.dist-info}/METADATA +2 -2
- arbor_ai-0.1.7.dist-info/RECORD +34 -0
- arbor_ai-0.1.5.dist-info/RECORD +0 -8
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.7.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.7.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.7.dist-info}/top_level.txt +0 -0
arbor/client/__init__.py
ADDED
File without changes

arbor/client/api.py
ADDED
@@ -0,0 +1 @@
# Unused Right Now

arbor/server/__init__.py
ADDED
@@ -0,0 +1 @@
(adds a single blank line)

arbor/server/api/__init__.py
ADDED
@@ -0,0 +1 @@
(adds a single blank line)

arbor/server/api/models/schemas.py
ADDED
@@ -0,0 +1,223 @@
from enum import Enum
from typing import Any, Generic, List, Literal, Optional, TypeVar

from pydantic import BaseModel, ConfigDict

# Generic type for list items
T = TypeVar("T")


class PaginatedResponse(BaseModel, Generic[T]):
    object: str = "list"
    data: List[T]
    has_more: bool = False


class FileModel(BaseModel):
    id: str
    object: str = "file"
    bytes: int
    created_at: int
    filename: str
    purpose: str


class WandbConfig(BaseModel):
    project: str
    name: Optional[str] = None
    entity: Optional[str] = None
    tags: Optional[List[str]] = None


class IntegrationModel(BaseModel):
    type: str
    wandb: WandbConfig


class FineTuneRequest(BaseModel):
    model: str
    training_file: str  # id of uploaded jsonl file
    method: dict
    suffix: Optional[str] = None
    # UNUSED
    validation_file: Optional[str] = None
    integrations: Optional[List[IntegrationModel]] = []
    seed: Optional[int] = None


class ErrorModel(BaseModel):
    code: str
    message: str
    param: str | None = None


class SupervisedHyperparametersModel(BaseModel):
    batch_size: int | str = "auto"
    learning_rate_multiplier: float | str = "auto"
    n_epochs: int | str = "auto"


class DPOHyperparametersModel(BaseModel):
    beta: float | str = "auto"
    batch_size: int | str = "auto"
    learning_rate_multiplier: float | str = "auto"
    n_epochs: int | str = "auto"


class SupervisedModel(BaseModel):
    hyperparameters: SupervisedHyperparametersModel


class DpoModel(BaseModel):
    hyperparameters: DPOHyperparametersModel


class MethodModel(BaseModel):
    type: Literal["supervised"] | Literal["dpo"]
    supervised: SupervisedModel | None = None
    dpo: DpoModel | None = None


# https://platform.openai.com/docs/api-reference/fine-tuning/object
class JobStatus(Enum):
    PENDING = "pending"  # Not in OAI
    PENDING_PAUSE = "pending_pause"  # Not in OAI
    PENDING_RESUME = "pending_resume"  # Not in OAI
    PAUSED = "paused"  # Not in OAI
    VALIDATING_FILES = "validating_files"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
    PENDING_CANCEL = "pending_cancel"


# https://platform.openai.com/docs/api-reference/fine-tuning/object
class JobStatusModel(BaseModel):
    object: str = "fine_tuning.job"
    id: str
    fine_tuned_model: str | None = None
    status: JobStatus

    # UNUSED so commented out
    # model: str
    # created_at: int
    # error: ErrorModel | None = None
    # details: str = ""
    # finished_at: int
    # hyperparameters: None  # deprecated in OAI
    # organization_id: str
    # result_files: list[str]
    # trained_tokens: int | None = None  # None if not finished
    # training_file: str
    # validation_file: str
    # integrations: list[Integration]
    # seed: int
    # estimated_finish: int | None = None  # The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
    # method: MethodModel
    # metadata: dict[str, str]


class JobEventModel(BaseModel):
    object: str = "fine_tuning.job_event"
    id: str
    created_at: int
    level: str
    message: str
    data: dict[str, Any]
    type: str


class MetricsModel(BaseModel):
    step: int
    train_loss: float
    train_mean_token_accuracy: float
    valid_loss: float
    valid_mean_token_accuracy: float
    full_valid_loss: float
    full_valid_mean_token_accuracy: float


class JobCheckpointModel(BaseModel):
    object: str = "fine_tuning.job_checkpoint"
    id: str
    created_at: int
    fine_tuned_model_checkpoint: str
    step_number: int
    metrics: MetricsModel
    fine_tuning_job_id: str


class ChatCompletionMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatCompletionMessage]
    temperature: float | None = None
    top_p: float | None = None
    max_tokens: int | None = None


class ChatCompletionChoice(BaseModel):
    message: ChatCompletionMessage
    index: int
    finish_reason: Literal["stop", "length", "tool_calls"]


class ChatCompletionModel(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class GRPORequest(BaseModel):
    model: str
    update_inference_model: bool
    batch: List[dict]


class GRPOConfigRequest(BaseModel):
    model: str
    temperature: Optional[float] = None
    beta: Optional[float] = None
    num_iterations: Optional[int] = None
    num_generations: Optional[int] = None
    per_device_train_batch_size: Optional[int] = None
    learning_rate: Optional[float] = None
    gradient_accumulation_steps: Optional[int] = None
    gradient_checkpointing: Optional[bool] = None
    lr_scheduler_type: Optional[str] = None
    max_prompt_length: Optional[int] = None
    max_completion_length: Optional[int] = None
    gradient_checkpointing_kwargs: Optional[dict] = {}
    bf16: Optional[bool] = None
    scale_rewards: Optional[bool] = None
    max_grad_norm: Optional[float] = None
    lora: Optional[bool] = None
    update_interval: Optional[int] = None
    # To name the run
    suffix: Optional[str] = None


class GRPOConfigResponse(BaseModel):
    status: str


class GRPOTerminateRequest(BaseModel):
    status: Optional[str] = "success"


class GRPOTerminateResponse(BaseModel):
    status: str
    current_model: str


class GRPOStepResponse(BaseModel):
    status: str
    current_model: str

arbor/server/api/routes/__init__.py
ADDED
File without changes

arbor/server/api/routes/files.py
ADDED
@@ -0,0 +1,52 @@
from typing import Literal

from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile

from arbor.server.api.models.schemas import FileModel, PaginatedResponse
from arbor.server.services.file_manager import FileValidationError

# https://platform.openai.com/docs/api-reference/files/list
router = APIRouter()


@router.post("", response_model=FileModel)
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    purpose: Literal["assistants", "vision", "fine-tune", "batch"] = Body("fine-tune"),
):
    file_manager = request.app.state.file_manager
    if not file.filename.endswith(".jsonl"):
        raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")

    try:
        content = await file.read()
        # file_manager.validate_file_format(content) #TODO: add another flag to specify the types of files
        await file.seek(0)  # Reset file pointer to beginning
        return FileModel(**file_manager.save_uploaded_file(file))
    except FileValidationError as e:
        raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")


@router.get("", response_model=PaginatedResponse[FileModel])
def list_files(request: Request):
    file_manager = request.app.state.file_manager
    # PaginatedResponse only defines object/data/has_more (see schemas.py),
    # so fill those fields here.
    return PaginatedResponse(
        data=file_manager.get_files(),
        has_more=False,
    )


@router.get("/{file_id}", response_model=FileModel)
def get_file(request: Request, file_id: str):
    file_manager = request.app.state.file_manager
    return file_manager.get_file(file_id)


@router.delete("/{file_id}")
def delete_file(request: Request, file_id: str):
    file_manager = request.app.state.file_manager
    file_manager.delete_file(file_id)
    return {"message": "File deleted"}
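
As a usage sketch for these routes (assuming a server reachable at http://localhost:8000, which this diff does not pin down, and the third-party requests library):

import requests

BASE = "http://localhost:8000/v1/files"  # host and port are assumptions

# Upload a .jsonl file; other extensions are rejected with a 400.
with open("train.jsonl", "rb") as f:
    uploaded = requests.post(
        BASE,
        files={"file": ("train.jsonl", f)},
        data={"purpose": "fine-tune"},
    ).json()

# List all files, then fetch and delete the one just uploaded.
print(requests.get(BASE).json())
print(requests.get(f"{BASE}/{uploaded['id']}").json())
requests.delete(f"{BASE}/{uploaded['id']}")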

arbor/server/api/routes/grpo.py
ADDED
@@ -0,0 +1,54 @@
import os
import subprocess

from fastapi import APIRouter, BackgroundTasks, Request

from arbor.server.api.models.schemas import (
    GRPOConfigRequest,
    GRPOConfigResponse,
    GRPORequest,
    GRPOStepResponse,
    GRPOTerminateRequest,
    GRPOTerminateResponse,
)

router = APIRouter()


@router.post("/initialize", response_model=GRPOConfigResponse)
def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
    inference_manager = request.app.state.inference_manager
    grpo_manager = request.app.state.grpo_manager
    grpo_manager.initialize(grpo_config_request, inference_manager)
    return GRPOConfigResponse(status="success")


# Create a grpo job
@router.post("/step", response_model=GRPOStepResponse)
def run_grpo_step(
    request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
):
    inference_manager = request.app.state.inference_manager
    grpo_manager = request.app.state.grpo_manager

    current_model = grpo_manager.grpo_step(grpo_request, inference_manager)

    return GRPOStepResponse(status="success", current_model=current_model)


@router.post("/update_model", response_model=GRPOStepResponse)
def update_model(request: Request):
    grpo_manager = request.app.state.grpo_manager
    inference_manager = request.app.state.inference_manager
    current_model = grpo_manager.update_model(request, inference_manager)
    return GRPOStepResponse(status="success", current_model=current_model)


@router.post("/terminate", response_model=GRPOTerminateResponse)
def terminate_grpo(request: Request):
    # No body needed for this request at this moment
    grpo_manager = request.app.state.grpo_manager
    inference_manager = request.app.state.inference_manager

    final_model = grpo_manager.terminate(inference_manager)
    return GRPOTerminateResponse(status="success", current_model=final_model)
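
Taken together, the intended call order appears to be initialize, then repeated step calls, then terminate. A hedged client-side sketch (host/port assumed; the per-item batch keys are consumed by arbor/server/services/scripts/grpo_training.py, which this excerpt does not show, so the ones below are purely illustrative):

import requests

BASE = "http://localhost:8000/v1/fine_tuning/grpo"  # assumed host/port

# 1. Spin up the trainer and inference server for a base model.
requests.post(f"{BASE}/initialize", json={"model": "example-org/base-model"})

# 2. Run a training step on a scored batch (item shape is illustrative only).
step = requests.post(
    f"{BASE}/step",
    json={
        "model": "example-org/base-model",
        "update_inference_model": True,
        "batch": [
            {
                "messages": [{"role": "user", "content": "2+2?"}],
                "completion": {"role": "assistant", "content": "4"},
                "reward": 1.0,
            }
        ],
    },
).json()
print(step["current_model"])

# 3. Terminate training; the route takes no request body.
final = requests.post(f"{BASE}/terminate").json()
print(final["current_model"])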

arbor/server/api/routes/inference.py
ADDED
@@ -0,0 +1,53 @@
import time

from fastapi import APIRouter, Request

router = APIRouter()


@router.post("/completions")
async def run_inference(
    request: Request,
):
    inference_manager = request.app.state.inference_manager
    raw_json = await request.json()

    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
    for prefix in prefixes:
        if raw_json["model"].startswith(prefix):
            raw_json["model"] = raw_json["model"][len(prefix) :]

    # if a server isn't running, launch one
    if (
        not inference_manager.is_server_running()
        and not inference_manager.is_server_restarting()
    ):
        print("No model is running, launching model...")
        inference_manager.launch(raw_json["model"])

    if inference_manager.is_server_restarting():
        print("Waiting for server to finish restarting...")
        while inference_manager.is_server_restarting():
            time.sleep(5)
        # Update the model in the request
        raw_json["model"] = inference_manager.current_model

    # forward the request to the inference server
    completion = inference_manager.run_inference(raw_json)

    return completion


@router.post("/launch")
async def launch_inference(request: Request):
    inference_manager = request.app.state.inference_manager
    raw_json = await request.json()
    inference_manager.launch(raw_json["model"], raw_json["launch_kwargs"])
    return {"message": "Inference server launched"}


@router.post("/kill")
async def kill_inference(request: Request):
    inference_manager = request.app.state.inference_manager
    inference_manager.kill()
    return {"message": "Inference server killed"}
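
The completions route accepts a raw OpenAI-style chat payload, strips any of the known prefixes ("openai/", "huggingface/", "local:", "arbor:") from the model name, launches an inference server if none is running, and forwards the payload. A sketch (host/port assumed):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed host/port
    json={
        "model": "arbor:example-org/base-model",  # "arbor:" is stripped server-side
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.7,
    },
)
print(resp.json())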

arbor/server/api/routes/jobs.py
ADDED
@@ -0,0 +1,117 @@
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request

from arbor.server.api.models.schemas import (
    FineTuneRequest,
    JobCheckpointModel,
    JobEventModel,
    JobStatusModel,
    PaginatedResponse,
)

# JobStatus is the job manager's enum; jobs created by JobManager carry it.
from arbor.server.services.job_manager import JobStatus

router = APIRouter()


# Create a fine-tune job
@router.post("", response_model=JobStatusModel)
def create_fine_tune_job(
    request: Request,
    fine_tune_request: FineTuneRequest,
    background_tasks: BackgroundTasks,
):
    job_manager = request.app.state.job_manager
    file_manager = request.app.state.file_manager
    training_manager = request.app.state.training_manager

    job = job_manager.create_job()
    background_tasks.add_task(
        training_manager.fine_tune, fine_tune_request, job, file_manager
    )
    job.status = JobStatus.QUEUED
    return JobStatusModel(id=job.id, status=job.status.value)


# List fine-tune jobs (paginated)
@router.get("", response_model=PaginatedResponse[JobStatusModel])
def get_jobs(request: Request):
    job_manager = request.app.state.job_manager
    return PaginatedResponse(
        data=[
            JobStatusModel(id=job.id, status=job.status.value)
            for job in job_manager.get_jobs()
        ],
        has_more=False,
    )


# List fine-tuning events
@router.get("/{job_id}/events", response_model=PaginatedResponse[JobEventModel])
def get_job_events(request: Request, job_id: str):
    job_manager = request.app.state.job_manager
    job = job_manager.get_job(job_id)
    return PaginatedResponse(
        data=[
            JobEventModel(
                id=event.id,
                level=event.level,
                message=event.message,
                data=event.data,
                created_at=int(event.created_at.timestamp()),
                type="message",
            )
            for event in job.get_events()
        ],
        has_more=False,
    )


# List fine-tuning checkpoints
@router.get(
    "/{job_id}/checkpoints", response_model=PaginatedResponse[JobCheckpointModel]
)
def get_job_checkpoints(request: Request, job_id: str):
    job_manager = request.app.state.job_manager
    job = job_manager.get_job(job_id)
    return PaginatedResponse(
        data=[
            JobCheckpointModel(
                id=checkpoint.id,
                fine_tuned_model_checkpoint=checkpoint.fine_tuned_model_checkpoint,
                fine_tuning_job_id=checkpoint.fine_tuning_job_id,
                metrics=checkpoint.metrics,
                step_number=checkpoint.step_number,
            )
            for checkpoint in job.get_checkpoints()
        ],
        has_more=False,
    )


# Retrieve a fine-tune job by id
@router.get("/{job_id}", response_model=JobStatusModel)
def get_job_status(
    request: Request,
    job_id: str,
):
    job_manager = request.app.state.job_manager
    job = job_manager.get_job(job_id)
    return JobStatusModel(
        id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model
    )


# Cancel a fine-tune job
@router.post("/{job_id}/cancel", response_model=JobStatusModel)
def cancel_job(request: Request, job_id: str):
    job_manager = request.app.state.job_manager
    job = job_manager.get_job(job_id)

    # Only allow cancellation of jobs that aren't finished
    if job.status in [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED]:
        raise HTTPException(
            status_code=400, detail=f"Cannot cancel job with status {job.status.value}"
        )

    job.status = JobStatus.PENDING_CANCEL
    return JobStatusModel(id=job.id, status=job.status.value)
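
A sketch of the resulting job lifecycle (host/port assumed; the training file id comes from a prior upload to /v1/files):

import time

import requests

BASE = "http://localhost:8000/v1/fine_tuning/jobs"  # assumed host/port

job = requests.post(
    BASE,
    json={
        "model": "example-org/base-model",  # hypothetical
        "training_file": "file-abc123",  # id returned by POST /v1/files
        "method": {
            "type": "supervised",
            "supervised": {"hyperparameters": {"n_epochs": 3}},
        },
    },
).json()

# The fine-tune runs as a background task, so poll until it reaches
# a terminal status.
while True:
    status = requests.get(f"{BASE}/{job['id']}").json()
    if status["status"] in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(10)
print(status.get("fine_tuned_model"))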

arbor/server/core/__init__.py
ADDED
@@ -0,0 +1 @@
(adds a single blank line)

arbor/server/core/config.py
ADDED
@@ -0,0 +1,47 @@
from pathlib import Path
from typing import Optional

import yaml
from pydantic import BaseModel, ConfigDict


class InferenceConfig(BaseModel):
    gpu_ids: str = "0"


class TrainingConfig(BaseModel):
    gpu_ids: str = "0"
    accelerate_config: Optional[str] = None


class ArborConfig(BaseModel):
    inference: InferenceConfig
    training: TrainingConfig


class Settings(BaseModel):

    STORAGE_PATH: str = "./storage"
    INACTIVITY_TIMEOUT: int = 30  # seconds
    arbor_config: ArborConfig

    @classmethod
    def load_from_yaml(cls, yaml_path: str) -> "Settings":
        if not yaml_path:
            raise ValueError("Config file path is required")
        if not Path(yaml_path).exists():
            raise ValueError(f"Config file {yaml_path} does not exist")

        try:
            with open(yaml_path, "r") as f:
                config = yaml.safe_load(f)

            settings = cls(
                arbor_config=ArborConfig(
                    inference=InferenceConfig(**config["inference"]),
                    training=TrainingConfig(**config["training"]),
                )
            )
            return settings
        except Exception as e:
            raise ValueError(f"Error loading config file {yaml_path}: {e}")
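
Settings.load_from_yaml expects a YAML file with top-level inference and training sections matching the models above. A small round-trip sketch (the file name is arbitrary):

from pathlib import Path

from arbor.server.core.config import Settings

# Minimal config with the two sections load_from_yaml requires.
Path("arbor.yaml").write_text(
    "inference:\n"
    "  gpu_ids: '0'\n"
    "training:\n"
    "  gpu_ids: '1,2'\n"
)

settings = Settings.load_from_yaml("arbor.yaml")
print(settings.arbor_config.training.gpu_ids)  # -> 1,2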

arbor/server/core/logging.py
ADDED
File without changes
arbor/server/main.py
ADDED
@@ -0,0 +1,11 @@
from fastapi import FastAPI

from arbor.server.api.routes import files, grpo, inference, jobs

app = FastAPI(title="Arbor API")

# Include routers
app.include_router(files.router, prefix="/v1/files")
app.include_router(jobs.router, prefix="/v1/fine_tuning/jobs")
app.include_router(grpo.router, prefix="/v1/fine_tuning/grpo")
app.include_router(inference.router, prefix="/v1/chat")
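
Nothing in main.py attaches the file_manager, job_manager, training_manager, grpo_manager, or inference_manager objects that the routes read from app.state, so that wiring presumably happens in the CLI entry point, which this excerpt does not show. A hedged sketch of what launching the app could look like:

import uvicorn

from arbor.server.core.config import Settings
from arbor.server.main import app

settings = Settings.load_from_yaml("arbor.yaml")
app.state.settings = settings
# The manager constructors live in arbor/server/services/* (not shown in
# full here), so these calls and their signatures are assumptions:
# app.state.file_manager = FileManager(settings)
# app.state.job_manager = JobManager(settings)
# app.state.inference_manager = InferenceManager(settings)
# ...

uvicorn.run(app, host="0.0.0.0", port=8000)  # port is an arbitrary choice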

arbor/server/services/__init__.py
ADDED
File without changes

arbor/server/services/comms/__init__.py
ADDED
File without changes