arbor-ai 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/client/__init__.py +0 -0
  2. arbor/client/api.py +1 -0
  3. arbor/server/__init__.py +1 -0
  4. arbor/server/api/__init__.py +1 -0
  5. arbor/server/api/models/schemas.py +223 -0
  6. arbor/server/api/routes/__init__.py +0 -0
  7. arbor/server/api/routes/files.py +52 -0
  8. arbor/server/api/routes/grpo.py +54 -0
  9. arbor/server/api/routes/inference.py +53 -0
  10. arbor/server/api/routes/jobs.py +117 -0
  11. arbor/server/core/__init__.py +1 -0
  12. arbor/server/core/config.py +47 -0
  13. arbor/server/core/logging.py +0 -0
  14. arbor/server/main.py +11 -0
  15. arbor/server/services/__init__.py +0 -0
  16. arbor/server/services/comms/__init__.py +0 -0
  17. arbor/server/services/comms/comms.py +226 -0
  18. arbor/server/services/dependencies.py +0 -0
  19. arbor/server/services/file_manager.py +289 -0
  20. arbor/server/services/grpo_manager.py +310 -0
  21. arbor/server/services/inference_manager.py +275 -0
  22. arbor/server/services/job_manager.py +81 -0
  23. arbor/server/services/scripts/grpo_training.py +576 -0
  24. arbor/server/services/training_manager.py +561 -0
  25. arbor/server/utils/__init__.py +0 -0
  26. arbor/server/utils/helpers.py +0 -0
  27. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/METADATA +1 -1
  28. arbor_ai-0.1.6.dist-info/RECORD +34 -0
  29. arbor_ai-0.1.5.dist-info/RECORD +0 -8
  30. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/top_level.txt +0 -0
arbor/client/api.py ADDED
@@ -0,0 +1 @@
+ # Unused Right Now
arbor/server/__init__.py ADDED
@@ -0,0 +1 @@
+
arbor/server/api/__init__.py ADDED
@@ -0,0 +1 @@
+
arbor/server/api/models/schemas.py ADDED
@@ -0,0 +1,223 @@
+ from enum import Enum
+ from typing import Any, Generic, List, Literal, Optional, TypeVar
+
+ from pydantic import BaseModel, ConfigDict
+
+ # Generic type for list items
+ T = TypeVar("T")
+
+
+ class PaginatedResponse(BaseModel, Generic[T]):
+     object: str = "list"
+     data: List[T]
+     has_more: bool = False
+
+
+ class FileModel(BaseModel):
+     id: str
+     object: str = "file"
+     bytes: int
+     created_at: int
+     filename: str
+     purpose: str
+
+
+ class WandbConfig(BaseModel):
+     project: str
+     name: Optional[str] = None
+     entity: Optional[str] = None
+     tags: Optional[List[str]] = None
+
+
+ class IntegrationModel(BaseModel):
+     type: str
+     wandb: WandbConfig
+
+
+ class FineTuneRequest(BaseModel):
+     model: str
+     training_file: str  # id of uploaded jsonl file
+     method: dict
+     suffix: Optional[str] = None
+     # UNUSED
+     validation_file: Optional[str] = None
+     integrations: Optional[List[IntegrationModel]] = []
+     seed: Optional[int] = None
+
+
+ class ErrorModel(BaseModel):
+     code: str
+     message: str
+     param: str | None = None
+
+
+ class SupervisedHyperparametersModel(BaseModel):
+     batch_size: int | str = "auto"
+     learning_rate_multiplier: float | str = "auto"
+     n_epochs: int | str = "auto"
+
+
+ class DPOHyperparametersModel(BaseModel):
+     beta: float | str = "auto"
+     batch_size: int | str = "auto"
+     learning_rate_multiplier: float | str = "auto"
+     n_epochs: int | str = "auto"
+
+
+ class SupervisedModel(BaseModel):
+     hyperparameters: SupervisedHyperparametersModel
+
+
+ class DpoModel(BaseModel):
+     hyperparameters: DPOHyperparametersModel
+
+
+ class MethodModel(BaseModel):
+     type: Literal["supervised"] | Literal["dpo"]
+     supervised: SupervisedModel | None = None
+     dpo: DpoModel | None = None
+
+
+ # https://platform.openai.com/docs/api-reference/fine-tuning/object
+ class JobStatus(Enum):
+     PENDING = "pending"  # Not in OAI
+     PENDING_PAUSE = "pending_pause"  # Not in OAI
+     PENDING_RESUME = "pending_resume"  # Not in OAI
+     PAUSED = "paused"  # Not in OAI
+     VALIDATING_FILES = "validating_files"
+     QUEUED = "queued"
+     RUNNING = "running"
+     SUCCEEDED = "succeeded"
+     FAILED = "failed"
+     CANCELLED = "cancelled"
+     PENDING_CANCEL = "pending_cancel"
+
+
+ # https://platform.openai.com/docs/api-reference/fine-tuning/object
+ class JobStatusModel(BaseModel):
+     object: str = "fine_tuning.job"
+     id: str
+     fine_tuned_model: str | None = None
+     status: JobStatus
+
+     # UNUSED so commented out
+     # model: str
+     # created_at: int
+     # error: ErrorModel | None = None
+     # details: str = ""
+     # finished_at: int
+     # hyperparameters: None  # deprecated in OAI
+     # organization_id: str
+     # result_files: list[str]
+     # trained_tokens: int | None = None  # None if not finished
+     # training_file: str
+     # validation_file: str
+     # integrations: list[Integration]
+     # seed: int
+     # estimated_finish: int | None = None  # The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
+     # method: MethodModel
+     # metadata: dict[str, str]
+
+
+ class JobEventModel(BaseModel):
+     object: str = "fine_tuning.job_event"
+     id: str
+     created_at: int
+     level: str
+     message: str
+     data: dict[str, Any]
+     type: str
+
+
+ class MetricsModel(BaseModel):
+     step: int
+     train_loss: float
+     train_mean_token_accuracy: float
+     valid_loss: float
+     valid_mean_token_accuracy: float
+     full_valid_loss: float
+     full_valid_mean_token_accuracy: float
+
+
+ class JobCheckpointModel(BaseModel):
+     object: str = "fine_tuning.job_checkpoint"
+     id: str
+     created_at: int
+     fine_tuned_model_checkpoint: str
+     step_number: int
+     metrics: MetricsModel
+     fine_tuning_job_id: str
+
+
+ class ChatCompletionMessage(BaseModel):
+     role: Literal["system", "user", "assistant"]
+     content: str
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatCompletionMessage]
+     temperature: float | None = None
+     top_p: float | None = None
+     max_tokens: int | None = None
+
+
+ class ChatCompletionChoice(BaseModel):
+     message: ChatCompletionMessage
+     index: int
+     finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+ class ChatCompletionModel(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int
+     model: str
+     choices: List[ChatCompletionChoice]
+
+
+ class GRPORequest(BaseModel):
+     model: str
+     update_inference_model: bool
+     batch: List[dict]
+
+
+ class GRPOConfigRequest(BaseModel):
+     model: str
+     temperature: Optional[float] = None
+     beta: Optional[float] = None
+     num_iterations: Optional[int] = None
+     num_generations: Optional[int] = None
+     per_device_train_batch_size: Optional[int] = None
+     learning_rate: Optional[float] = None
+     gradient_accumulation_steps: Optional[int] = None
+     gradient_checkpointing: Optional[bool] = None
+     lr_scheduler_type: Optional[str] = None
+     max_prompt_length: Optional[int] = None
+     max_completion_length: Optional[int] = None
+     gradient_checkpointing_kwargs: Optional[dict] = {}
+     bf16: Optional[bool] = None
+     scale_rewards: Optional[bool] = None
+     max_grad_norm: Optional[float] = None
+     lora: Optional[bool] = None
+     update_interval: Optional[int] = None
+     # To name the run
+     suffix: Optional[str] = None
+
+
+ class GRPOConfigResponse(BaseModel):
+     status: str
+
+
+ class GRPOTerminateRequest(BaseModel):
+     status: Optional[str] = "success"
+
+
+ class GRPOTerminateResponse(BaseModel):
+     status: str
+     current_model: str
+
+
+ class GRPOStepResponse(BaseModel):
+     status: str
+     current_model: str
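
Taken together, these schemas track the shapes of OpenAI's fine-tuning and chat APIs, plus Arbor-specific GRPO types. As a quick orientation, a minimal sketch of constructing a request payload with them; the model and file ids below are placeholders, not values shipped with the package:

    from arbor.server.api.models.schemas import FineTuneRequest

    request = FineTuneRequest(
        model="example-org/example-model",  # placeholder model id
        training_file="file-abc123",  # placeholder; real ids come from the files endpoint
        method={"type": "supervised"},
    )
    print(request.model_dump())  # pydantic v2, consistent with the ConfigDict import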
arbor/server/api/routes/files.py ADDED
@@ -0,0 +1,52 @@
+ from typing import Literal
+
+ from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile
+
+ from arbor.server.api.models.schemas import FileModel, PaginatedResponse
+ from arbor.server.services.file_manager import FileValidationError
+
+ # https://platform.openai.com/docs/api-reference/files/list
+ router = APIRouter()
+
+
+ @router.post("", response_model=FileModel)
+ async def upload_file(
+     request: Request,
+     file: UploadFile = File(...),
+     purpose: Literal["assistants", "vision", "fine-tune", "batch"] = Body("fine-tune"),
+ ):
+     file_manager = request.app.state.file_manager
+     if not file.filename.endswith(".jsonl"):
+         raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")
+
+     try:
+         content = await file.read()
+         # file_manager.validate_file_format(content)  # TODO: add another flag to specify the types of files
+         await file.seek(0)  # Reset file pointer to beginning
+         return FileModel(**file_manager.save_uploaded_file(file))
+     except FileValidationError as e:
+         raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+
+
+ @router.get("", response_model=PaginatedResponse[FileModel])
+ def list_files(request: Request):
+     file_manager = request.app.state.file_manager
+     # PaginatedResponse defines object/data/has_more
+     return PaginatedResponse(
+         data=file_manager.get_files(),
+         has_more=False,
+     )
+
+
+ @router.get("/{file_id}", response_model=FileModel)
+ def get_file(request: Request, file_id: str):
+     file_manager = request.app.state.file_manager
+     return file_manager.get_file(file_id)
+
+
+ @router.delete("/{file_id}")
+ def delete_file(request: Request, file_id: str):
+     file_manager = request.app.state.file_manager
+     file_manager.delete_file(file_id)
+     return {"message": "File deleted"}
arbor/server/api/routes/grpo.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import subprocess
+
+ from fastapi import APIRouter, BackgroundTasks, Request
+
+ from arbor.server.api.models.schemas import (
+     GRPOConfigRequest,
+     GRPOConfigResponse,
+     GRPORequest,
+     GRPOStepResponse,
+     GRPOTerminateRequest,
+     GRPOTerminateResponse,
+ )
+
+ router = APIRouter()
+
+
+ @router.post("/initialize", response_model=GRPOConfigResponse)
+ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
+     inference_manager = request.app.state.inference_manager
+     grpo_manager = request.app.state.grpo_manager
+     grpo_manager.initialize(grpo_config_request, inference_manager)
+     return GRPOConfigResponse(status="success")
+
+
+ # Run a single GRPO training step
+ @router.post("/step", response_model=GRPOStepResponse)
+ def run_grpo_step(
+     request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
+ ):
+     inference_manager = request.app.state.inference_manager
+     grpo_manager = request.app.state.grpo_manager
+
+     current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+
+     return GRPOStepResponse(status="success", current_model=current_model)
+
+
+ @router.post("/update_model", response_model=GRPOStepResponse)
+ def update_model(request: Request):
+     grpo_manager = request.app.state.grpo_manager
+     inference_manager = request.app.state.inference_manager
+     current_model = grpo_manager.update_model(request, inference_manager)
+     return GRPOStepResponse(status="success", current_model=current_model)
+
+
+ @router.post("/terminate", response_model=GRPOTerminateResponse)
+ def terminate_grpo(request: Request):
+     # No body needed for this request at this moment
+     grpo_manager = request.app.state.grpo_manager
+     inference_manager = request.app.state.inference_manager
+
+     final_model = grpo_manager.terminate(inference_manager)
+     return GRPOTerminateResponse(status="success", current_model=final_model)
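
A rough client-side sketch of the intended call sequence (initialize, step repeatedly, terminate). The batch item keys below are placeholders; the real shape is defined by grpo_manager, which this diff does not show:

    import requests

    BASE = "http://localhost:8000/v1/fine_tuning/grpo"  # assumed host/port

    requests.post(f"{BASE}/initialize", json={"model": "example-org/example-model"})
    step = requests.post(
        f"{BASE}/step",
        json={
            "model": "example-org/example-model",  # placeholder model id
            "update_inference_model": True,
            # Placeholder batch item; actual keys depend on grpo_manager
            "batch": [{"prompt": "2+2=", "completion": "4", "reward": 1.0}],
        },
    ).json()
    print(step["current_model"])
    requests.post(f"{BASE}/terminate")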
arbor/server/api/routes/inference.py ADDED
@@ -0,0 +1,53 @@
+ import time
+
+ from fastapi import APIRouter, Request
+
+ router = APIRouter()
+
+
+ @router.post("/completions")
+ async def run_inference(
+     request: Request,
+ ):
+     inference_manager = request.app.state.inference_manager
+     raw_json = await request.json()
+
+     prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
+     for prefix in prefixes:
+         if raw_json["model"].startswith(prefix):
+             raw_json["model"] = raw_json["model"][len(prefix) :]
+
+     # If a server isn't running, launch one
+     if (
+         not inference_manager.is_server_running()
+         and not inference_manager.is_server_restarting()
+     ):
+         print("No model is running, launching model...")
+         inference_manager.launch(raw_json["model"])
+
+     if inference_manager.is_server_restarting():
+         print("Waiting for server to finish restarting...")
+         while inference_manager.is_server_restarting():
+             time.sleep(5)
+         # Update the model in the request
+         raw_json["model"] = inference_manager.current_model
+
+     # Forward the request to the inference server
+     completion = inference_manager.run_inference(raw_json)
+
+     return completion
+
+
+ @router.post("/launch")
+ async def launch_inference(request: Request):
+     inference_manager = request.app.state.inference_manager
+     raw_json = await request.json()
+     inference_manager.launch(raw_json["model"], raw_json["launch_kwargs"])
+     return {"message": "Inference server launched"}
+
+
+ @router.post("/kill")
+ async def kill_inference(request: Request):
+     inference_manager = request.app.state.inference_manager
+     inference_manager.kill()
+     return {"message": "Inference server killed"}
arbor/server/api/routes/jobs.py ADDED
@@ -0,0 +1,117 @@
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
+
+ from arbor.server.api.models.schemas import (
+     FineTuneRequest,
+     JobCheckpointModel,
+     JobEventModel,
+     JobStatusModel,
+     PaginatedResponse,
+ )
+ from arbor.server.services.job_manager import JobStatus
+
+ router = APIRouter()
+
+
+ # Create a fine-tune job
+ @router.post("", response_model=JobStatusModel)
+ def create_fine_tune_job(
+     request: Request,
+     fine_tune_request: FineTuneRequest,
+     background_tasks: BackgroundTasks,
+ ):
+     job_manager = request.app.state.job_manager
+     file_manager = request.app.state.file_manager
+     training_manager = request.app.state.training_manager
+
+     job = job_manager.create_job()
+     background_tasks.add_task(
+         training_manager.fine_tune, fine_tune_request, job, file_manager
+     )
+     job.status = JobStatus.QUEUED
+     return JobStatusModel(id=job.id, status=job.status.value)
+
+
+ # List fine-tune jobs (paginated)
+ @router.get("", response_model=PaginatedResponse[JobStatusModel])
+ def get_jobs(request: Request):
+     job_manager = request.app.state.job_manager
+     return PaginatedResponse(
+         data=[
+             JobStatusModel(id=job.id, status=job.status.value)
+             for job in job_manager.get_jobs()
+         ],
+         has_more=False,
+     )
+
+
+ # List fine-tuning events
+ @router.get("/{job_id}/events", response_model=PaginatedResponse[JobEventModel])
+ def get_job_events(request: Request, job_id: str):
+     job_manager = request.app.state.job_manager
+     job = job_manager.get_job(job_id)
+     return PaginatedResponse(
+         data=[
+             JobEventModel(
+                 id=event.id,
+                 level=event.level,
+                 message=event.message,
+                 data=event.data,
+                 created_at=int(event.created_at.timestamp()),
+                 type="message",
+             )
+             for event in job.get_events()
+         ],
+         has_more=False,
+     )
+
+
+ # List fine-tuning checkpoints
+ @router.get(
+     "/{job_id}/checkpoints", response_model=PaginatedResponse[JobCheckpointModel]
+ )
+ def get_job_checkpoints(request: Request, job_id: str):
+     job_manager = request.app.state.job_manager
+     job = job_manager.get_job(job_id)
+     return PaginatedResponse(
+         data=[
+             JobCheckpointModel(
+                 id=checkpoint.id,
+                 created_at=checkpoint.created_at,  # required by the schema; attribute assumed present on checkpoint
+                 fine_tuned_model_checkpoint=checkpoint.fine_tuned_model_checkpoint,
+                 fine_tuning_job_id=checkpoint.fine_tuning_job_id,
+                 metrics=checkpoint.metrics,
+                 step_number=checkpoint.step_number,
+             )
+             for checkpoint in job.get_checkpoints()
+         ],
+         has_more=False,
+     )
+
+
+ # Retrieve a fine-tune job by id
+ @router.get("/{job_id}", response_model=JobStatusModel)
+ def get_job_status(
+     request: Request,
+     job_id: str,
+ ):
+     job_manager = request.app.state.job_manager
+     job = job_manager.get_job(job_id)
+     return JobStatusModel(
+         id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model
+     )
+
+
+ # Cancel a fine-tune job
+ @router.post("/{job_id}/cancel", response_model=JobStatusModel)
+ def cancel_job(request: Request, job_id: str):
+     job_manager = request.app.state.job_manager
+     job = job_manager.get_job(job_id)
+
+     # Only allow cancellation of jobs that aren't finished
+     if job.status in [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED]:
+         raise HTTPException(
+             status_code=400, detail=f"Cannot cancel job with status {job.status.value}"
+         )
+
+     job.status = JobStatus.PENDING_CANCEL
+     return JobStatusModel(id=job.id, status=job.status.value)
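
A client-side sketch of the job lifecycle against these routes; the host, model id, and file id are placeholders:

    import time

    import requests

    BASE = "http://localhost:8000/v1/fine_tuning/jobs"  # assumed host/port

    # "file-abc123" stands in for an id returned by /v1/files
    job = requests.post(
        BASE,
        json={
            "model": "example-org/example-model",  # placeholder model id
            "training_file": "file-abc123",
            "method": {"type": "supervised"},
        },
    ).json()

    # The fine-tune runs as a background task; poll until it finishes
    while requests.get(f"{BASE}/{job['id']}").json()["status"] not in (
        "succeeded",
        "failed",
        "cancelled",
    ):
        time.sleep(10)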
arbor/server/core/__init__.py ADDED
@@ -0,0 +1 @@
+
arbor/server/core/config.py ADDED
@@ -0,0 +1,47 @@
+ from pathlib import Path
+ from typing import Optional
+
+ import yaml
+ from pydantic import BaseModel, ConfigDict
+
+
+ class InferenceConfig(BaseModel):
+     gpu_ids: str = "0"
+
+
+ class TrainingConfig(BaseModel):
+     gpu_ids: str = "0"
+     accelerate_config: Optional[str] = None
+
+
+ class ArborConfig(BaseModel):
+     inference: InferenceConfig
+     training: TrainingConfig
+
+
+ class Settings(BaseModel):
+
+     STORAGE_PATH: str = "./storage"
+     INACTIVITY_TIMEOUT: int = 30  # seconds
+     arbor_config: ArborConfig
+
+     @classmethod
+     def load_from_yaml(cls, yaml_path: str) -> "Settings":
+         if not yaml_path:
+             raise ValueError("Config file path is required")
+         if not Path(yaml_path).exists():
+             raise ValueError(f"Config file {yaml_path} does not exist")
+
+         try:
+             with open(yaml_path, "r") as f:
+                 config = yaml.safe_load(f)
+
+             settings = cls(
+                 arbor_config=ArborConfig(
+                     inference=InferenceConfig(**config["inference"]),
+                     training=TrainingConfig(**config["training"]),
+                 )
+             )
+             return settings
+         except Exception as e:
+             raise ValueError(f"Error loading config file {yaml_path}: {e}")
arbor/server/main.py ADDED
@@ -0,0 +1,11 @@
+ from fastapi import FastAPI
+
+ from arbor.server.api.routes import files, grpo, inference, jobs
+
+ app = FastAPI(title="Arbor API")
+
+ # Include routers
+ app.include_router(files.router, prefix="/v1/files")
+ app.include_router(jobs.router, prefix="/v1/fine_tuning/jobs")
+ app.include_router(grpo.router, prefix="/v1/fine_tuning/grpo")
+ app.include_router(inference.router, prefix="/v1/chat")
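
Note that every route reads its manager off app.state (file_manager, job_manager, training_manager, grpo_manager, inference_manager), and nothing in this file attaches them; presumably the packaged CLI entry point does that wiring. A hedged sketch of serving the app manually, assuming uvicorn is installed and that state is populated first:

    import uvicorn

    from arbor.server.main import app

    # The routes expect app.state.* managers to exist; that wiring is not
    # shown in this diff, so this assumes it happens before serving.
    uvicorn.run(app, host="127.0.0.1", port=8000)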