arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
arbor/cli.py CHANGED
@@ -1,17 +1,101 @@
 import click
 import uvicorn
+
+from arbor.server.core.config import Settings
 from arbor.server.main import app
+from arbor.server.services.file_manager import FileManager
+from arbor.server.services.grpo_manager import GRPOManager
+from arbor.server.services.inference_manager import InferenceManager
+from arbor.server.services.job_manager import JobManager
+from arbor.server.services.training_manager import TrainingManager
+
 
 @click.group()
 def cli():
     pass
 
+
+def create_app(arbor_config_path: str):
+    """Create and configure the Arbor API application
+
+    Args:
+        storage_path (str): Path to store models and uploaded training files
+
+    Returns:
+        FastAPI: Configured FastAPI application
+    """
+    # Create new settings instance with overrides
+    settings = Settings.load_from_yaml(arbor_config_path)
+
+    # Initialize services with settings
+    file_manager = FileManager(settings=settings)
+    job_manager = JobManager(settings=settings)
+    training_manager = TrainingManager(settings=settings)
+    inference_manager = InferenceManager(settings=settings)
+    grpo_manager = GRPOManager(settings=settings)
+    # Inject settings into app state
+    app.state.settings = settings
+    app.state.file_manager = file_manager
+    app.state.job_manager = job_manager
+    app.state.training_manager = training_manager
+    app.state.inference_manager = inference_manager
+    app.state.grpo_manager = grpo_manager
+
+    return app
+
+
+def start_server(host="0.0.0.0", port=7453, storage_path="./storage", timeout=10):
+    """Start the Arbor API server with a single function call"""
+    import socket
+    import threading
+    import time
+    from contextlib import closing
+
+    def is_port_in_use(port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            return sock.connect_ex(("localhost", port)) == 0
+
+    # First ensure the port is free
+    if is_port_in_use(port):
+        raise RuntimeError(f"Port {port} is already in use")
+
+    app = create_app(storage_path)
+    config = uvicorn.Config(app, host=host, port=port, log_level="info")
+    server = uvicorn.Server(config)
+
+    def run_server():
+        server.run()
+
+    thread = threading.Thread(target=run_server, daemon=True)
+    thread.start()
+
+    # Wait for server to start
+    start_time = time.time()
+    while not is_port_in_use(port):
+        if time.time() - start_time > timeout:
+            raise TimeoutError(f"Server failed to start within {timeout} seconds")
+        time.sleep(0.1)
+
+    # Give it a little extra time to fully initialize
+    time.sleep(0.5)
+
+    return server
+
+
+def stop_server(server):
+    """Stop the Arbor API server"""
+    server.should_exit = True
+
+
 @cli.command()
-@click.option('--host', default='0.0.0.0', help='Host to bind to')
-@click.option('--port', default=8000, help='Port to bind to')
-def serve(host, port):
+@click.option("--host", default="0.0.0.0", help="Host to bind to")
+@click.option("--port", default=7453, help="Port to bind to")
+@click.option("--arbor-config", required=True, help="Path to the Arbor config file")
+def serve(host, port, arbor_config):
     """Start the Arbor API server"""
+    app = create_app(arbor_config)
     uvicorn.run(app, host=host, port=port)
 
-if __name__ == '__main__':
-    cli()
+
+if __name__ == "__main__":
+    cli()
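The new `create_app`/`start_server` helpers make the server embeddable from Python rather than only via the CLI. A minimal sketch of that flow; note that `start_server` forwards its `storage_path` argument straight into `create_app`, which hands it to `Settings.load_from_yaml`, so in practice the value must be a path to an Arbor YAML config (the `arbor.yaml` filename here is illustrative):

```python
from arbor.cli import start_server, stop_server

# start_server runs uvicorn on a daemon thread and polls the port until it
# answers (or the timeout elapses). The storage_path argument is fed to
# create_app -> Settings.load_from_yaml, so pass a YAML config path.
server = start_server(port=7453, storage_path="arbor.yaml")  # illustrative path
try:
    # ... issue HTTP requests against http://localhost:7453 ...
    pass
finally:
    stop_server(server)  # flips server.should_exit; uvicorn then drains and stops
```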
arbor/client/api.py CHANGED
@@ -1,2 +1 @@
-from typing import Optional, Dict, Any
-
+# Unused Right Now
arbor/server/api/models/schemas.py CHANGED
@@ -1,6 +1,19 @@
-from pydantic import BaseModel
+from enum import Enum
+from typing import Any, Generic, List, Literal, Optional, TypeVar
 
-class FileResponse(BaseModel):
+from pydantic import BaseModel, ConfigDict
+
+# Generic type for list items
+T = TypeVar("T")
+
+
+class PaginatedResponse(BaseModel, Generic[T]):
+    object: str = "list"
+    data: List[T]
+    has_more: bool = False
+
+
+class FileModel(BaseModel):
     id: str
     object: str = "file"
     bytes: int
@@ -8,12 +21,203 @@ class FileResponse(BaseModel):
     filename: str
     purpose: str
 
+
+class WandbConfig(BaseModel):
+    project: str
+    name: Optional[str] = None
+    entity: Optional[str] = None
+    tags: Optional[List[str]] = None
+
+
+class IntegrationModel(BaseModel):
+    type: str
+    wandb: WandbConfig
+
+
 class FineTuneRequest(BaseModel):
     model: str
     training_file: str # id of uploaded jsonl file
+    method: dict
+    suffix: Optional[str] = None
+    # UNUSED
+    validation_file: Optional[str] = None
+    integrations: Optional[List[IntegrationModel]] = []
+    seed: Optional[int] = None
+
+
+class ErrorModel(BaseModel):
+    code: str
+    message: str
+    param: str | None = None
+
+
+class SupervisedHyperparametersModel(BaseModel):
+    batch_size: int | str = "auto"
+    learning_rate_multiplier: float | str = "auto"
+    n_epochs: int | str = "auto"
+
+
+class DPOHyperparametersModel(BaseModel):
+    beta: float | str = "auto"
+    batch_size: int | str = "auto"
+    learning_rate_multiplier: float | str = "auto"
+    n_epochs: int | str = "auto"
 
-class JobStatusResponse(BaseModel):
+
+class SupervisedModel(BaseModel):
+    hyperparameters: SupervisedHyperparametersModel
+
+
+class DpoModel(BaseModel):
+    hyperparameters: DPOHyperparametersModel
+
+
+class MethodModel(BaseModel):
+    type: Literal["supervised"] | Literal["dpo"]
+    supervised: SupervisedModel | None = None
+    dpo: DpoModel | None = None
+
+
+# https://platform.openai.com/docs/api-reference/fine-tuning/object
+class JobStatus(Enum):
+    PENDING = "pending" # Not in OAI
+    PENDING_PAUSE = "pending_pause" # Not in OAI
+    PENDING_RESUME = "pending_resume" # Not in OAI
+    PAUSED = "paused" # Not in OAI
+    VALIDATING_FILES = "validating_files"
+    QUEUED = "queued"
+    RUNNING = "running"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+    PENDING_CANCEL = "pending_cancel"
+
+
+# https://platform.openai.com/docs/api-reference/fine-tuning/object
+class JobStatusModel(BaseModel):
+    object: str = "fine_tuning.job"
+    id: str
+    fine_tuned_model: str | None = None
+    status: JobStatus
+
+    # UNUSED so commented out
+    # model: str
+    # created_at: int
+    # error: ErrorModel | None = None
+    # details: str = ""
+    # finished_at: int
+    # hyperparameters: None # deprecated in OAI
+    # organization_id: str
+    # result_files: list[str]
+    # trained_tokens: int | None = None # None if not finished
+    # training_file: str
+    # validation_file: str
+    # integrations: list[Integration]
+    # seed: int
+    # estimated_finish: int | None = None # The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
+    # method: MethodModel
+    # metadata: dict[str, str]
+
+
+class JobEventModel(BaseModel):
+    object: str = "fine_tuning.job_event"
     id: str
+    created_at: int
+    level: str
+    message: str
+    data: dict[str, Any]
+    type: str
+
+
+class MetricsModel(BaseModel):
+    step: int
+    train_loss: float
+    train_mean_token_accuracy: float
+    valid_loss: float
+    valid_mean_token_accuracy: float
+    full_valid_loss: float
+    full_valid_mean_token_accuracy: float
+
+
+class JobCheckpointModel(BaseModel):
+    object: str = "fine_tuning.job_checkpoint"
+    id: str
+    created_at: int
+    fine_tuned_model_checkpoint: str
+    step_number: int
+    metrics: MetricsModel
+    fine_tuning_job_id: str
+
+
+class ChatCompletionMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatCompletionMessage]
+    temperature: float | None = None
+    top_p: float | None = None
+    max_tokens: int | None = None
+
+
+class ChatCompletionChoice(BaseModel):
+    message: ChatCompletionMessage
+    index: int
+    finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+class ChatCompletionModel(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatCompletionChoice]
+
+
+class GRPORequest(BaseModel):
+    model: str
+    update_inference_model: bool
+    batch: List[dict]
+
+
+class GRPOConfigRequest(BaseModel):
+    model: str
+    temperature: Optional[float] = None
+    beta: Optional[float] = None
+    num_iterations: Optional[int] = None
+    num_generations: Optional[int] = None
+    per_device_train_batch_size: Optional[int] = None
+    learning_rate: Optional[float] = None
+    gradient_accumulation_steps: Optional[int] = None
+    gradient_checkpointing: Optional[bool] = None
+    lr_scheduler_type: Optional[str] = None
+    max_prompt_length: Optional[int] = None
+    max_completion_length: Optional[int] = None
+    gradient_checkpointing_kwargs: Optional[dict] = {}
+    bf16: Optional[bool] = None
+    scale_rewards: Optional[bool] = None
+    max_grad_norm: Optional[float] = None
+    lora: Optional[bool] = None
+    update_interval: Optional[int] = None
+    # To name the run
+    suffix: Optional[str] = None
+
+
+class GRPOConfigResponse(BaseModel):
+    status: str
+
+
+class GRPOTerminateRequest(BaseModel):
+    status: Optional[str] = "success"
+
+
+class GRPOTerminateResponse(BaseModel):
+    status: str
+    current_model: str
+
+
+class GRPOStepResponse(BaseModel):
     status: str
-    details: str = ""
-    fine_tuned_model: str | None = None
+    current_model: str
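The reworked schemas track the shapes of OpenAI's fine-tuning API (the linked `fine-tuning/object` reference). As a quick illustration, a request body that validates against `FineTuneRequest`; the `method` field is a free-form `dict`, shown here following the `MethodModel` layout, and the model name and file id are illustrative:

```python
from arbor.server.api.models.schemas import FineTuneRequest, JobStatus

req = FineTuneRequest(
    model="HuggingFaceTB/SmolLM2-135M-Instruct",  # illustrative base model
    training_file="file-abc123",  # id of a previously uploaded .jsonl file
    method={
        "type": "supervised",
        "supervised": {"hyperparameters": {"n_epochs": "auto"}},
    },
)
assert JobStatus.QUEUED.value == "queued"  # statuses serialize to OpenAI-style strings
```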
arbor/server/api/routes/files.py CHANGED
@@ -1,23 +1,52 @@
-from fastapi import APIRouter, UploadFile, File, Depends, HTTPException
-from arbor.server.services.file_manager import FileManager
-from arbor.server.api.models.schemas import FileResponse
-from arbor.server.services.dependencies import get_file_manager
+from typing import Literal
+
+from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile
+
+from arbor.server.api.models.schemas import FileModel, PaginatedResponse
 from arbor.server.services.file_manager import FileValidationError
 
+# https://platform.openai.com/docs/api-reference/files/list
 router = APIRouter()
 
-@router.post("", response_model=FileResponse)
+
+@router.post("", response_model=FileModel)
 async def upload_file(
+    request: Request,
     file: UploadFile = File(...),
-    file_manager: FileManager = Depends(get_file_manager)
+    purpose: Literal["assistants", "vision", "fine-tune", "batch"] = Body("fine-tune"),
 ):
-    if not file.filename.endswith('.jsonl'):
+    file_manager = request.app.state.file_manager
+    if not file.filename.endswith(".jsonl"):
         raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")
 
     try:
         content = await file.read()
-        file_manager.validate_file_format(content)
+        # file_manager.validate_file_format(content) #TODO: add another flag to specify the types of files
         await file.seek(0) # Reset file pointer to beginning
-        return file_manager.save_uploaded_file(file)
+        return FileModel(**file_manager.save_uploaded_file(file))
     except FileValidationError as e:
-        raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+
+
+@router.get("", response_model=PaginatedResponse[FileModel])
+def list_files(request: Request):
+    file_manager = request.app.state.file_manager
+    return PaginatedResponse(
+        items=file_manager.get_files(),
+        total=len(file_manager.get_files()),
+        page=1,
+        page_size=10,
+    )
+
+
+@router.get("/{file_id}", response_model=FileModel)
+def get_file(request: Request, file_id: str):
+    file_manager = request.app.state.file_manager
+    return file_manager.get_file(file_id)
+
+
+@router.delete("/{file_id}")
+def delete_file(request: Request, file_id: str):
+    file_manager = request.app.state.file_manager
+    file_manager.delete_file(file_id)
+    return {"message": "File deleted"}
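Since the handlers now pull `file_manager` off `request.app.state`, the endpoints can be exercised with plain HTTP and no dependency overrides. A hedged sketch using `requests`; the base URL assumes the router is mounted at `/v1/files`, which this diff does not show, so adjust to however `arbor.server.main` mounts it:

```python
import requests

BASE = "http://localhost:7453/v1/files"  # assumed mount point; not shown in this diff

# Upload: only .jsonl files are accepted, otherwise the route returns 400.
with open("train.jsonl", "rb") as f:
    uploaded = requests.post(BASE, files={"file": f}).json()

requests.get(f"{BASE}/{uploaded['id']}")     # fetch one FileModel
requests.delete(f"{BASE}/{uploaded['id']}")  # -> {"message": "File deleted"}
```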
arbor/server/api/routes/grpo.py ADDED
@@ -0,0 +1,54 @@
+import os
+import subprocess
+
+from fastapi import APIRouter, BackgroundTasks, Request
+
+from arbor.server.api.models.schemas import (
+    GRPOConfigRequest,
+    GRPOConfigResponse,
+    GRPORequest,
+    GRPOStepResponse,
+    GRPOTerminateRequest,
+    GRPOTerminateResponse,
+)
+
+router = APIRouter()
+
+
+@router.post("/initialize", response_model=GRPOConfigResponse)
+def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
+    inference_manager = request.app.state.inference_manager
+    grpo_manager = request.app.state.grpo_manager
+    grpo_manager.initialize(grpo_config_request, inference_manager)
+    return GRPOConfigResponse(status="success")
+
+
+# Create a grpo job
+@router.post("/step", response_model=GRPOStepResponse)
+def run_grpo_step(
+    request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
+):
+    inference_manager = request.app.state.inference_manager
+    grpo_manager = request.app.state.grpo_manager
+
+    current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+
+    return GRPOStepResponse(status="success", current_model=current_model)
+
+
+@router.post("/update_model", response_model=GRPOStepResponse)
+def update_model(request: Request):
+    grpo_manager = request.app.state.grpo_manager
+    inference_manager = request.app.state.inference_manager
+    current_model = grpo_manager.update_model(request, inference_manager)
+    return GRPOStepResponse(status="success", current_model=current_model)
+
+
+@router.post("/terminate", response_model=GRPOTerminateResponse)
+def terminate_grpo(request: Request):
+    # No body needed for this request at this moment
+    grpo_manager = request.app.state.grpo_manager
+    inference_manager = request.app.state.inference_manager
+
+    final_model = grpo_manager.terminate(inference_manager)
+    return GRPOTerminateResponse(status="success", current_model=final_model)
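Taken together, the GRPO routes define a simple lifecycle: configure the trainer, stream batches through repeated `/step` calls, and `/terminate` to get the final model. A hedged HTTP sketch; the `/v1/fine_tuning/grpo` mount point, the model name, and the contents of each batch dict are assumptions (the batch format is whatever `GRPOManager.grpo_step` consumes):

```python
import requests

GRPO = "http://localhost:7453/v1/fine_tuning/grpo"  # assumed mount point

requests.post(f"{GRPO}/initialize", json={"model": "Qwen/Qwen2-0.5B-Instruct"})

batch: list[dict] = []  # rollout dicts in whatever format GRPOManager expects
step = requests.post(
    f"{GRPO}/step",
    json={"model": "Qwen/Qwen2-0.5B-Instruct", "update_inference_model": True, "batch": batch},
).json()
print(step["current_model"])  # checkpoint to sample from on the next round

final = requests.post(f"{GRPO}/terminate").json()["current_model"]
```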
arbor/server/api/routes/inference.py ADDED
@@ -0,0 +1,53 @@
+import time
+
+from fastapi import APIRouter, Request
+
+router = APIRouter()
+
+
+@router.post("/completions")
+async def run_inference(
+    request: Request,
+):
+    inference_manager = request.app.state.inference_manager
+    raw_json = await request.json()
+
+    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
+    for prefix in prefixes:
+        if raw_json["model"].startswith(prefix):
+            raw_json["model"] = raw_json["model"][len(prefix) :]
+
+    # if a server isnt running, launch one
+    if (
+        not inference_manager.is_server_running()
+        and not inference_manager.is_server_restarting()
+    ):
+        print("No model is running, launching model...")
+        inference_manager.launch(raw_json["model"])
+
+    if inference_manager.is_server_restarting():
+        print("Waiting for server to finish restarting...")
+        while inference_manager.is_server_restarting():
+            time.sleep(5)
+        # Update the model in the request
+        raw_json["model"] = inference_manager.current_model
+
+    # forward the request to the inference server
+    completion = inference_manager.run_inference(raw_json)
+
+    return completion
+
+
+@router.post("/launch")
+async def launch_inference(request: Request):
+    inference_manager = request.app.state.inference_manager
+    raw_json = await request.json()
+    inference_manager.launch(raw_json["model"], raw_json["launch_kwargs"])
+    return {"message": "Inference server launched"}
+
+
+@router.post("/kill")
+async def kill_inference(request: Request):
+    inference_manager = request.app.state.inference_manager
+    inference_manager.kill()
+    return {"message": "Inference server killed"}
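The `/completions` proxy normalizes the incoming model string before deciding whether to launch a server, so provider-prefixed names resolve to the same underlying model. The stripping logic, pulled out as a standalone snippet (the `strip_model_prefix` helper is ours, for illustration):

```python
# Mirrors the loop in run_inference: each known prefix is removed from the
# front of the model name before the request is forwarded.
prefixes = ["openai/", "huggingface/", "local:", "arbor:"]

def strip_model_prefix(model: str) -> str:
    for prefix in prefixes:
        if model.startswith(prefix):
            model = model[len(prefix):]
    return model

assert strip_model_prefix("arbor:Qwen/Qwen2-0.5B-Instruct") == "Qwen/Qwen2-0.5B-Instruct"
```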
arbor/server/api/routes/jobs.py CHANGED
@@ -1,14 +1,117 @@
-from fastapi import APIRouter, Depends
-from arbor.server.services.job_manager import JobManager
-from arbor.server.services.dependencies import get_job_manager
-from arbor.server.api.models.schemas import JobStatusResponse
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
+
+from arbor.server.api.models.schemas import (
+    FineTuneRequest,
+    JobCheckpointModel,
+    JobEventModel,
+    JobStatus,
+    JobStatusModel,
+    PaginatedResponse,
+)
+from arbor.server.services.job_manager import JobStatus
 
 router = APIRouter()
 
-@router.get("/{job_id}", response_model=JobStatusResponse)
+
+# Create a fine-tune job
+@router.post("", response_model=JobStatusModel)
+def create_fine_tune_job(
+    request: Request,
+    fine_tune_request: FineTuneRequest,
+    background_tasks: BackgroundTasks,
+):
+    job_manager = request.app.state.job_manager
+    file_manager = request.app.state.file_manager
+    training_manager = request.app.state.training_manager
+
+    job = job_manager.create_job()
+    background_tasks.add_task(
+        training_manager.fine_tune, fine_tune_request, job, file_manager
+    )
+    job.status = JobStatus.QUEUED
+    return JobStatusModel(id=job.id, status=job.status.value)
+
+
+# List fine-tune jobs (paginated)
+@router.get("", response_model=PaginatedResponse[JobStatusModel])
+def get_jobs(request: Request):
+    job_manager = request.app.state.job_manager
+    return PaginatedResponse(
+        data=[
+            JobStatusModel(id=job.id, status=job.status.value)
+            for job in job_manager.get_jobs()
+        ],
+        has_more=False,
+    )
+
+
+# List fine-tuning events
+@router.get("/{job_id}/events", response_model=PaginatedResponse[JobEventModel])
+def get_job_events(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+    return PaginatedResponse(
+        data=[
+            JobEventModel(
+                id=event.id,
+                level=event.level,
+                message=event.message,
+                data=event.data,
+                created_at=int(event.created_at.timestamp()),
+                type="message",
+            )
+            for event in job.get_events()
+        ],
+        has_more=False,
+    )
+
+
+# List fine-tuning checkpoints
+@router.get(
+    "/{job_id}/checkpoints", response_model=PaginatedResponse[JobCheckpointModel]
+)
+def get_job_checkpoints(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+    return PaginatedResponse(
+        data=[
+            JobCheckpointModel(
+                id=checkpoint.id,
+                fine_tuned_model_checkpoint=checkpoint.fine_tuned_model_checkpoint,
+                fine_tuning_job_id=checkpoint.fine_tuning_job_id,
+                metrics=checkpoint.metrics,
+                step_number=checkpoint.step_number,
+            )
+            for checkpoint in job.get_checkpoints()
+        ],
+        has_more=False,
+    )
+
+
+# Retrieve a fine-tune job by id
+@router.get("/{job_id}", response_model=JobStatusModel)
 def get_job_status(
+    request: Request,
     job_id: str,
-    job_manager: JobManager = Depends(get_job_manager)
 ):
+    job_manager = request.app.state.job_manager
     job = job_manager.get_job(job_id)
-    return JobStatusResponse(id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model)
+    return JobStatusModel(
+        id=job_id, status=job.status.value, fine_tuned_model=job.fine_tuned_model
+    )
+
+
+# Cancel a fine-tune job
+@router.post("/{job_id}/cancel", response_model=JobStatusModel)
+def cancel_job(request: Request, job_id: str):
+    job_manager = request.app.state.job_manager
+    job = job_manager.get_job(job_id)
+
+    # Only allow cancellation of jobs that aren't finished
+    if job.status in [JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED]:
+        raise HTTPException(
+            status_code=400, detail=f"Cannot cancel job with status {job.status.value}"
+        )
+
+    job.status = JobStatus.PENDING_CANCEL
+    return JobStatusModel(id=job.id, status=job.status.value)
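End to end, the new routes give the usual OpenAI-style job lifecycle: create, poll, inspect events, cancel. A hedged sketch; the `/v1/fine_tuning/jobs` mount point is assumed, and the model and file id are illustrative:

```python
import time

import requests

JOBS = "http://localhost:7453/v1/fine_tuning/jobs"  # assumed mount point

job = requests.post(
    JOBS,
    json={
        "model": "HuggingFaceTB/SmolLM2-135M-Instruct",  # illustrative
        "training_file": "file-abc123",  # id from the files endpoint
        "method": {"type": "supervised"},
    },
).json()

# Poll until the background fine_tune task finishes (statuses come from JobStatus).
while requests.get(f"{JOBS}/{job['id']}").json()["status"] in ("queued", "running"):
    time.sleep(5)

events = requests.get(f"{JOBS}/{job['id']}/events").json()["data"]
```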
arbor/server/core/config.py CHANGED
@@ -1,10 +1,47 @@
-from pydantic_settings import BaseSettings
+from pathlib import Path
+from typing import Optional
 
-class Settings(BaseSettings):
-    UPLOADS_DIR: str = "uploads"
-    MODEL_CACHE_DIR: str = "model_cache"
+import yaml
+from pydantic import BaseModel, ConfigDict
 
-    class Config:
-        env_file = ".env"
 
-settings = Settings()
+class InferenceConfig(BaseModel):
+    gpu_ids: str = "0"
+
+
+class TrainingConfig(BaseModel):
+    gpu_ids: str = "0"
+    accelerate_config: Optional[str] = None
+
+
+class ArborConfig(BaseModel):
+    inference: InferenceConfig
+    training: TrainingConfig
+
+
+class Settings(BaseModel):
+
+    STORAGE_PATH: str = "./storage"
+    INACTIVITY_TIMEOUT: int = 30 # 5 seconds
+    arbor_config: ArborConfig
+
+    @classmethod
+    def load_from_yaml(cls, yaml_path: str) -> "Settings":
+        if not yaml_path:
+            raise ValueError("Config file path is required")
+        if not Path(yaml_path).exists():
+            raise ValueError(f"Config file {yaml_path} does not exist")
+
+        try:
+            with open(yaml_path, "r") as f:
+                config = yaml.safe_load(f)
+
+            settings = cls(
+                arbor_config=ArborConfig(
+                    inference=InferenceConfig(**config["inference"]),
+                    training=TrainingConfig(**config["training"]),
+                )
+            )
+            return settings
+        except Exception as e:
+            raise ValueError(f"Error loading config file {yaml_path}: {e}")
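`load_from_yaml` reads exactly two top-level mappings, `inference` and `training`, each feeding the corresponding pydantic model. A minimal sketch of writing and loading such a config (the file name and GPU ids are illustrative):

```python
from pathlib import Path

from arbor.server.core.config import Settings

# Both sections are required by load_from_yaml; gpu_ids defaults to "0" and
# training.accelerate_config is optional.
Path("arbor.yaml").write_text(
    "inference:\n"
    "  gpu_ids: '0'\n"
    "training:\n"
    "  gpu_ids: '1,2'\n"
)

settings = Settings.load_from_yaml("arbor.yaml")
print(settings.arbor_config.training.gpu_ids)  # -> 1,2
```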