arbor-ai 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. arbor_ai-0.1.6/PKG-INFO +78 -0
  2. arbor_ai-0.1.6/README.md +51 -0
  3. arbor_ai-0.1.6/arbor/cli.py +101 -0
  4. arbor_ai-0.1.6/arbor/client/api.py +1 -0
  5. arbor_ai-0.1.6/arbor/server/api/models/schemas.py +223 -0
  6. arbor_ai-0.1.6/arbor/server/api/routes/files.py +52 -0
  7. arbor_ai-0.1.6/arbor/server/api/routes/grpo.py +54 -0
  8. arbor_ai-0.1.6/arbor/server/api/routes/inference.py +53 -0
  9. arbor_ai-0.1.6/arbor/server/api/routes/jobs.py +117 -0
  10. arbor_ai-0.1.6/arbor/server/core/config.py +47 -0
  11. arbor_ai-0.1.6/arbor/server/main.py +11 -0
  12. arbor_ai-0.1.6/arbor/server/services/comms/comms.py +226 -0
  13. arbor_ai-0.1.6/arbor/server/services/file_manager.py +289 -0
  14. arbor_ai-0.1.6/arbor/server/services/grpo_manager.py +310 -0
  15. arbor_ai-0.1.6/arbor/server/services/inference_manager.py +275 -0
  16. arbor_ai-0.1.6/arbor/server/services/job_manager.py +81 -0
  17. arbor_ai-0.1.6/arbor/server/services/scripts/grpo_training.py +576 -0
  18. arbor_ai-0.1.6/arbor/server/services/training_manager.py +561 -0
  19. arbor_ai-0.1.6/arbor/server/utils/__init__.py +0 -0
  20. arbor_ai-0.1.6/arbor/server/utils/helpers.py +0 -0
  21. arbor_ai-0.1.6/arbor_ai.egg-info/PKG-INFO +78 -0
  22. arbor_ai-0.1.6/arbor_ai.egg-info/SOURCES.txt +38 -0
  23. arbor_ai-0.1.6/arbor_ai.egg-info/dependency_links.txt +1 -0
  24. arbor_ai-0.1.6/arbor_ai.egg-info/entry_points.txt +2 -0
  25. arbor_ai-0.1.6/arbor_ai.egg-info/requires.txt +15 -0
  26. arbor_ai-0.1.6/arbor_ai.egg-info/top_level.txt +1 -0
  27. arbor_ai-0.1.6/pyproject.toml +45 -0
  28. arbor_ai-0.1.6/setup.cfg +4 -0
  29. arbor_ai-0.1.6/tests/test_cli.py +58 -0
  30. arbor_ai-0.1.4/PKG-INFO +0 -97
  31. arbor_ai-0.1.4/README.md +0 -67
  32. arbor_ai-0.1.4/pyproject.toml +0 -55
  33. arbor_ai-0.1.4/src/arbor/cli.py +0 -17
  34. arbor_ai-0.1.4/src/arbor/client/api.py +0 -2
  35. arbor_ai-0.1.4/src/arbor/server/api/models/schemas.py +0 -19
  36. arbor_ai-0.1.4/src/arbor/server/api/routes/files.py +0 -23
  37. arbor_ai-0.1.4/src/arbor/server/api/routes/jobs.py +0 -14
  38. arbor_ai-0.1.4/src/arbor/server/api/routes/training.py +0 -16
  39. arbor_ai-0.1.4/src/arbor/server/core/config.py +0 -10
  40. arbor_ai-0.1.4/src/arbor/server/main.py +0 -10
  41. arbor_ai-0.1.4/src/arbor/server/services/dependencies.py +0 -16
  42. arbor_ai-0.1.4/src/arbor/server/services/file_manager.py +0 -128
  43. arbor_ai-0.1.4/src/arbor/server/services/job_manager.py +0 -76
  44. arbor_ai-0.1.4/src/arbor/server/services/training_manager.py +0 -264
  45. {arbor_ai-0.1.4 → arbor_ai-0.1.6}/LICENSE +0 -0
  46. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/__init__.py +0 -0
  47. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/client/__init__.py +0 -0
  48. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/__init__.py +0 -0
  49. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/api/__init__.py +0 -0
  50. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/api/routes/__init__.py +0 -0
  51. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/core/__init__.py +0 -0
  52. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/core/logging.py +0 -0
  53. {arbor_ai-0.1.4/src → arbor_ai-0.1.6}/arbor/server/services/__init__.py +0 -0
  54. {arbor_ai-0.1.4/src/arbor/server/utils → arbor_ai-0.1.6/arbor/server/services/comms}/__init__.py +0 -0
  55. /arbor_ai-0.1.4/src/arbor/server/utils/helpers.py → /arbor_ai-0.1.6/arbor/server/services/dependencies.py +0 -0
arbor_ai-0.1.6/PKG-INFO
@@ -0,0 +1,78 @@
+ Metadata-Version: 2.4
+ Name: arbor-ai
+ Version: 0.1.6
+ Summary: A framework for fine-tuning and managing language models
+ Author-email: Noah Ziems <nziems2@nd.edu>
+ Project-URL: Homepage, https://github.com/Ziems/arbor
+ Project-URL: Issues, https://github.com/Ziems/arbor/issues
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: fastapi
+ Requires-Dist: uvicorn
+ Requires-Dist: click
+ Requires-Dist: python-multipart
+ Requires-Dist: pydantic-settings
+ Requires-Dist: torch
+ Requires-Dist: transformers
+ Requires-Dist: trl
+ Requires-Dist: peft
+ Requires-Dist: ray>=2.9
+ Requires-Dist: setuptools<77.0.0,>=76.0.0
+ Requires-Dist: pyzmq>=26.4.0
+ Requires-Dist: pyyaml>=6.0.2
+ Requires-Dist: sglang>=0.4.5.post3
+ Requires-Dist: sglang-router
+ Dynamic: license-file
+
+ <p align="center">
+   <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Arbor logo" width="400"/>
+ </p>
+
+ **A framework for optimizing DSPy programs with RL.**
+
+ ---
+
+ ## 🚀 Installation
+
+ Install Arbor via pip:
+
+ ```bash
+ pip install git+https://github.com/Ziems/arbor.git
+ ```
+
+ ---
+
+ ## ⚡ Quick Start
+
+ ### 1️⃣ Make an `arbor.yaml` File
+
+ The contents depend on your setup. Here is an example:
+ ```yaml
+ inference:
+   gpu_ids: '0'
+
+ training:
+   gpu_ids: '1, 2'
+ ```
+
+ ### 2️⃣ Start the Server
+
+ **CLI:**
+
+ ```bash
+ python -m arbor.cli serve --arbor-config arbor.yaml
+ ```
+
+ ### 3️⃣ Optimize a DSPy Program
+
+ Follow the DSPy tutorials here to see usage examples:
+ [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
+
+ ---
+
+ ## 🙏 Acknowledgements
+
+ Arbor builds on the shoulders of great work. We extend our thanks to:
+ - **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
+ - **[Hugging Face TRL library](https://github.com/huggingface/trl)**
arbor_ai-0.1.6/README.md
@@ -0,0 +1,51 @@
+ <p align="center">
+   <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Arbor logo" width="400"/>
+ </p>
+
+ **A framework for optimizing DSPy programs with RL.**
+
+ ---
+
+ ## 🚀 Installation
+
+ Install Arbor via pip:
+
+ ```bash
+ pip install git+https://github.com/Ziems/arbor.git
+ ```
+
+ ---
+
+ ## ⚡ Quick Start
+
+ ### 1️⃣ Make an `arbor.yaml` File
+
+ The contents depend on your setup. Here is an example:
+ ```yaml
+ inference:
+   gpu_ids: '0'
+
+ training:
+   gpu_ids: '1, 2'
+ ```
+
+ ### 2️⃣ Start the Server
+
+ **CLI:**
+
+ ```bash
+ python -m arbor.cli serve --arbor-config arbor.yaml
+ ```
+
+ ### 3️⃣ Optimize a DSPy Program
+
+ Follow the DSPy tutorials here to see usage examples:
+ [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
+
+ ---
+
+ ## 🙏 Acknowledgements
+
+ Arbor builds on the shoulders of great work. We extend our thanks to:
+ - **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
+ - **[Hugging Face TRL library](https://github.com/huggingface/trl)**
arbor_ai-0.1.6/arbor/cli.py
@@ -0,0 +1,101 @@
+ import click
+ import uvicorn
+
+ from arbor.server.core.config import Settings
+ from arbor.server.main import app
+ from arbor.server.services.file_manager import FileManager
+ from arbor.server.services.grpo_manager import GRPOManager
+ from arbor.server.services.inference_manager import InferenceManager
+ from arbor.server.services.job_manager import JobManager
+ from arbor.server.services.training_manager import TrainingManager
+
+
+ @click.group()
+ def cli():
+     pass
+
+
+ def create_app(arbor_config_path: str):
+     """Create and configure the Arbor API application
+
+     Args:
+         arbor_config_path (str): Path to the Arbor YAML config file
+
+     Returns:
+         FastAPI: Configured FastAPI application
+     """
+     # Create new settings instance with overrides
+     settings = Settings.load_from_yaml(arbor_config_path)
+
+     # Initialize services with settings
+     file_manager = FileManager(settings=settings)
+     job_manager = JobManager(settings=settings)
+     training_manager = TrainingManager(settings=settings)
+     inference_manager = InferenceManager(settings=settings)
+     grpo_manager = GRPOManager(settings=settings)
+     # Inject settings into app state
+     app.state.settings = settings
+     app.state.file_manager = file_manager
+     app.state.job_manager = job_manager
+     app.state.training_manager = training_manager
+     app.state.inference_manager = inference_manager
+     app.state.grpo_manager = grpo_manager
+
+     return app
+
+
+ def start_server(host="0.0.0.0", port=7453, arbor_config_path="arbor.yaml", timeout=10):
+     """Start the Arbor API server with a single function call"""
+     import socket
+     import threading
+     import time
+     from contextlib import closing
+
+     def is_port_in_use(port):
+         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+             return sock.connect_ex(("localhost", port)) == 0
+
+     # First ensure the port is free
+     if is_port_in_use(port):
+         raise RuntimeError(f"Port {port} is already in use")
+
+     app = create_app(arbor_config_path)
+     config = uvicorn.Config(app, host=host, port=port, log_level="info")
+     server = uvicorn.Server(config)
+
+     def run_server():
+         server.run()
+
+     thread = threading.Thread(target=run_server, daemon=True)
+     thread.start()
+
+     # Wait for server to start
+     start_time = time.time()
+     while not is_port_in_use(port):
+         if time.time() - start_time > timeout:
+             raise TimeoutError(f"Server failed to start within {timeout} seconds")
+         time.sleep(0.1)
+
+     # Give it a little extra time to fully initialize
+     time.sleep(0.5)
+
+     return server
+
+
+ def stop_server(server):
+     """Stop the Arbor API server"""
+     server.should_exit = True
+
+
+ @cli.command()
+ @click.option("--host", default="0.0.0.0", help="Host to bind to")
+ @click.option("--port", default=7453, help="Port to bind to")
+ @click.option("--arbor-config", required=True, help="Path to the Arbor config file")
+ def serve(host, port, arbor_config):
+     """Start the Arbor API server"""
+     app = create_app(arbor_config)
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     cli()
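
Taken together, `start_server` and `stop_server` above let a script or test harness embed Arbor in-process. A minimal sketch, assuming an `arbor.yaml` like the README example sits in the working directory:

```python
# Minimal usage sketch for the helpers defined in arbor/cli.py above.
# The config path is an assumption; any valid Arbor YAML config works.
from arbor.cli import start_server, stop_server

server = start_server(port=7453, arbor_config_path="arbor.yaml")
try:
    # ... interact with the API at http://localhost:7453 ...
    pass
finally:
    stop_server(server)  # flips server.should_exit; uvicorn shuts down on its next loop tick
```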
arbor_ai-0.1.6/arbor/client/api.py
@@ -0,0 +1 @@
+ # Unused Right Now
arbor_ai-0.1.6/arbor/server/api/models/schemas.py
@@ -0,0 +1,223 @@
+ from enum import Enum
+ from typing import Any, Generic, List, Literal, Optional, TypeVar
+
+ from pydantic import BaseModel, ConfigDict
+
+ # Generic type for list items
+ T = TypeVar("T")
+
+
+ class PaginatedResponse(BaseModel, Generic[T]):
+     object: str = "list"
+     data: List[T]
+     has_more: bool = False
+
+
+ class FileModel(BaseModel):
+     id: str
+     object: str = "file"
+     bytes: int
+     created_at: int
+     filename: str
+     purpose: str
+
+
+ class WandbConfig(BaseModel):
+     project: str
+     name: Optional[str] = None
+     entity: Optional[str] = None
+     tags: Optional[List[str]] = None
+
+
+ class IntegrationModel(BaseModel):
+     type: str
+     wandb: WandbConfig
+
+
+ class FineTuneRequest(BaseModel):
+     model: str
+     training_file: str  # id of uploaded jsonl file
+     method: dict
+     suffix: Optional[str] = None
+     # UNUSED
+     validation_file: Optional[str] = None
+     integrations: Optional[List[IntegrationModel]] = []
+     seed: Optional[int] = None
+
+
+ class ErrorModel(BaseModel):
+     code: str
+     message: str
+     param: str | None = None
+
+
+ class SupervisedHyperparametersModel(BaseModel):
+     batch_size: int | str = "auto"
+     learning_rate_multiplier: float | str = "auto"
+     n_epochs: int | str = "auto"
+
+
+ class DPOHyperparametersModel(BaseModel):
+     beta: float | str = "auto"
+     batch_size: int | str = "auto"
+     learning_rate_multiplier: float | str = "auto"
+     n_epochs: int | str = "auto"
+
+
+ class SupervisedModel(BaseModel):
+     hyperparameters: SupervisedHyperparametersModel
+
+
+ class DpoModel(BaseModel):
+     hyperparameters: DPOHyperparametersModel
+
+
+ class MethodModel(BaseModel):
+     type: Literal["supervised"] | Literal["dpo"]
+     supervised: SupervisedModel | None = None
+     dpo: DpoModel | None = None
+
+
+ # https://platform.openai.com/docs/api-reference/fine-tuning/object
+ class JobStatus(Enum):
+     PENDING = "pending"  # Not in OAI
+     PENDING_PAUSE = "pending_pause"  # Not in OAI
+     PENDING_RESUME = "pending_resume"  # Not in OAI
+     PAUSED = "paused"  # Not in OAI
+     VALIDATING_FILES = "validating_files"
+     QUEUED = "queued"
+     RUNNING = "running"
+     SUCCEEDED = "succeeded"
+     FAILED = "failed"
+     CANCELLED = "cancelled"
+     PENDING_CANCEL = "pending_cancel"
+
+
+ # https://platform.openai.com/docs/api-reference/fine-tuning/object
+ class JobStatusModel(BaseModel):
+     object: str = "fine_tuning.job"
+     id: str
+     fine_tuned_model: str | None = None
+     status: JobStatus
+
+     # UNUSED so commented out
+     # model: str
+     # created_at: int
+     # error: ErrorModel | None = None
+     # details: str = ""
+     # finished_at: int
+     # hyperparameters: None  # deprecated in OAI
+     # organization_id: str
+     # result_files: list[str]
+     # trained_tokens: int | None = None  # None if not finished
+     # training_file: str
+     # validation_file: str
+     # integrations: list[Integration]
+     # seed: int
+     # estimated_finish: int | None = None  # The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running.
+     # method: MethodModel
+     # metadata: dict[str, str]
+
+
+ class JobEventModel(BaseModel):
+     object: str = "fine_tuning.job_event"
+     id: str
+     created_at: int
+     level: str
+     message: str
+     data: dict[str, Any]
+     type: str
+
+
+ class MetricsModel(BaseModel):
+     step: int
+     train_loss: float
+     train_mean_token_accuracy: float
+     valid_loss: float
+     valid_mean_token_accuracy: float
+     full_valid_loss: float
+     full_valid_mean_token_accuracy: float
+
+
+ class JobCheckpointModel(BaseModel):
+     object: str = "fine_tuning.job_checkpoint"
+     id: str
+     created_at: int
+     fine_tuned_model_checkpoint: str
+     step_number: int
+     metrics: MetricsModel
+     fine_tuning_job_id: str
+
+
+ class ChatCompletionMessage(BaseModel):
+     role: Literal["system", "user", "assistant"]
+     content: str
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatCompletionMessage]
+     temperature: float | None = None
+     top_p: float | None = None
+     max_tokens: int | None = None
+
+
+ class ChatCompletionChoice(BaseModel):
+     message: ChatCompletionMessage
+     index: int
+     finish_reason: Literal["stop", "length", "tool_calls"]
+
+
+ class ChatCompletionModel(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int
+     model: str
+     choices: List[ChatCompletionChoice]
+
+
+ class GRPORequest(BaseModel):
+     model: str
+     update_inference_model: bool
+     batch: List[dict]
+
+
+ class GRPOConfigRequest(BaseModel):
+     model: str
+     temperature: Optional[float] = None
+     beta: Optional[float] = None
+     num_iterations: Optional[int] = None
+     num_generations: Optional[int] = None
+     per_device_train_batch_size: Optional[int] = None
+     learning_rate: Optional[float] = None
+     gradient_accumulation_steps: Optional[int] = None
+     gradient_checkpointing: Optional[bool] = None
+     lr_scheduler_type: Optional[str] = None
+     max_prompt_length: Optional[int] = None
+     max_completion_length: Optional[int] = None
+     gradient_checkpointing_kwargs: Optional[dict] = {}
+     bf16: Optional[bool] = None
+     scale_rewards: Optional[bool] = None
+     max_grad_norm: Optional[float] = None
+     lora: Optional[bool] = None
+     update_interval: Optional[int] = None
+     # To name the run
+     suffix: Optional[str] = None
+
+
+ class GRPOConfigResponse(BaseModel):
+     status: str
+
+
+ class GRPOTerminateRequest(BaseModel):
+     status: Optional[str] = "success"
+
+
+ class GRPOTerminateResponse(BaseModel):
+     status: str
+     current_model: str
+
+
+ class GRPOStepResponse(BaseModel):
+     status: str
+     current_model: str
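
Since these are ordinary Pydantic models, clients can build and validate payloads before sending them over the wire. A sketch using the GRPO schemas above; the model id and the keys inside `batch` are illustrative, since `batch` is only typed as `List[dict]`:

```python
# Sketch: constructing GRPO payloads from the schemas above. Values are illustrative.
from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest

config = GRPOConfigRequest(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical model id
    num_generations=8,
    learning_rate=1e-5,
    lora=True,
    suffix="demo",
)
step = GRPORequest(
    model=config.model,
    update_inference_model=True,
    # The schema only requires a list of dicts; these keys are assumptions.
    batch=[{"prompt": "2+2=", "completion": "4", "reward": 1.0}],
)
print(config.model_dump(exclude_none=True))  # ready to send as JSON
```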
arbor_ai-0.1.6/arbor/server/api/routes/files.py
@@ -0,0 +1,52 @@
+ from typing import Literal
+
+ from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile
+
+ from arbor.server.api.models.schemas import FileModel, PaginatedResponse
+ from arbor.server.services.file_manager import FileValidationError
+
+ # https://platform.openai.com/docs/api-reference/files/list
+ router = APIRouter()
+
+
+ @router.post("", response_model=FileModel)
+ async def upload_file(
+     request: Request,
+     file: UploadFile = File(...),
+     purpose: Literal["assistants", "vision", "fine-tune", "batch"] = Body("fine-tune"),
+ ):
+     file_manager = request.app.state.file_manager
+     if not file.filename.endswith(".jsonl"):
+         raise HTTPException(status_code=400, detail="Only .jsonl files are allowed")
+
+     try:
+         content = await file.read()
+         # file_manager.validate_file_format(content)  # TODO: add another flag to specify the types of files
+         await file.seek(0)  # Reset file pointer to beginning
+         return FileModel(**file_manager.save_uploaded_file(file))
+     except FileValidationError as e:
+         raise HTTPException(status_code=400, detail=f"Invalid file format: {str(e)}")
+
+
+ @router.get("", response_model=PaginatedResponse[FileModel])
+ def list_files(request: Request):
+     file_manager = request.app.state.file_manager
+     files = file_manager.get_files()
+     return PaginatedResponse(
+         data=files,
+         # pagination is not implemented yet; everything comes back in one page
+         has_more=False,
+     )
+
+
+ @router.get("/{file_id}", response_model=FileModel)
+ def get_file(request: Request, file_id: str):
+     file_manager = request.app.state.file_manager
+     return file_manager.get_file(file_id)
+
+
+ @router.delete("/{file_id}")
+ def delete_file(request: Request, file_id: str):
+     file_manager = request.app.state.file_manager
+     file_manager.delete_file(file_id)
+     return {"message": "File deleted"}
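
The upload route expects a multipart body with a `file` part and a `purpose` form field. A sketch with `requests`; the `/v1/files` mount prefix is an assumption, since the router's prefix is set where it is included in the app (not shown in this diff):

```python
# Sketch: upload a .jsonl training file to the route above. URL prefix is assumed.
import requests

with open("train.jsonl", "rb") as f:
    resp = requests.post(
        "http://localhost:7453/v1/files",  # assumed mount point for this router
        files={"file": ("train.jsonl", f)},
        data={"purpose": "fine-tune"},
    )
resp.raise_for_status()
print(resp.json()["id"])  # FileModel.id, later usable as FineTuneRequest.training_file
```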
arbor_ai-0.1.6/arbor/server/api/routes/grpo.py
@@ -0,0 +1,54 @@
+ import os
+ import subprocess
+
+ from fastapi import APIRouter, BackgroundTasks, Request
+
+ from arbor.server.api.models.schemas import (
+     GRPOConfigRequest,
+     GRPOConfigResponse,
+     GRPORequest,
+     GRPOStepResponse,
+     GRPOTerminateRequest,
+     GRPOTerminateResponse,
+ )
+
+ router = APIRouter()
+
+
+ @router.post("/initialize", response_model=GRPOConfigResponse)
+ def initialize_grpo(request: Request, grpo_config_request: GRPOConfigRequest):
+     inference_manager = request.app.state.inference_manager
+     grpo_manager = request.app.state.grpo_manager
+     grpo_manager.initialize(grpo_config_request, inference_manager)
+     return GRPOConfigResponse(status="success")
+
+
+ # Run a single GRPO training step on the submitted batch
+ @router.post("/step", response_model=GRPOStepResponse)
+ def run_grpo_step(
+     request: Request, grpo_request: GRPORequest, background_tasks: BackgroundTasks
+ ):
+     inference_manager = request.app.state.inference_manager
+     grpo_manager = request.app.state.grpo_manager
+
+     current_model = grpo_manager.grpo_step(grpo_request, inference_manager)
+
+     return GRPOStepResponse(status="success", current_model=current_model)
+
+
+ @router.post("/update_model", response_model=GRPOStepResponse)
+ def update_model(request: Request):
+     grpo_manager = request.app.state.grpo_manager
+     inference_manager = request.app.state.inference_manager
+     current_model = grpo_manager.update_model(request, inference_manager)
+     return GRPOStepResponse(status="success", current_model=current_model)
+
+
+ @router.post("/terminate", response_model=GRPOTerminateResponse)
+ def terminate_grpo(request: Request):
+     # No body needed for this request at this moment
+     grpo_manager = request.app.state.grpo_manager
+     inference_manager = request.app.state.inference_manager
+
+     final_model = grpo_manager.terminate(inference_manager)
+     return GRPOTerminateResponse(status="success", current_model=final_model)
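
The intended lifecycle is one `/initialize`, many `/step` calls, then `/terminate`. A sketch of that flow over HTTP; the mount prefix and the keys inside `batch` are assumptions, as above:

```python
# Sketch of the GRPO lifecycle against the routes above: initialize -> step* -> terminate.
import requests

BASE = "http://localhost:7453/v1/fine_tuning/grpo"  # mount prefix is an assumption

requests.post(f"{BASE}/initialize", json={"model": "Qwen/Qwen2.5-1.5B-Instruct"}).raise_for_status()
resp = requests.post(
    f"{BASE}/step",
    json={
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        "update_inference_model": True,
        # GRPORequest only types batch as List[dict]; these keys are assumptions
        "batch": [{"prompt": "2+2=", "completion": "4", "reward": 1.0}],
    },
)
print(resp.json()["current_model"])  # may point at a newer checkpoint
requests.post(f"{BASE}/terminate").raise_for_status()
```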
arbor_ai-0.1.6/arbor/server/api/routes/inference.py
@@ -0,0 +1,53 @@
+ import time
+
+ from fastapi import APIRouter, Request
+
+ router = APIRouter()
+
+
+ @router.post("/completions")
+ async def run_inference(
+     request: Request,
+ ):
+     inference_manager = request.app.state.inference_manager
+     raw_json = await request.json()
+
+     prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
+     for prefix in prefixes:
+         if raw_json["model"].startswith(prefix):
+             raw_json["model"] = raw_json["model"][len(prefix) :]
+
+     # if a server isn't running, launch one
+     if (
+         not inference_manager.is_server_running()
+         and not inference_manager.is_server_restarting()
+     ):
+         print("No model is running, launching model...")
+         inference_manager.launch(raw_json["model"])
+
+     if inference_manager.is_server_restarting():
+         print("Waiting for server to finish restarting...")
+         while inference_manager.is_server_restarting():
+             time.sleep(5)
+         # Update the model in the request
+         raw_json["model"] = inference_manager.current_model
+
+     # forward the request to the inference server
+     completion = inference_manager.run_inference(raw_json)
+
+     return completion
+
+
+ @router.post("/launch")
+ async def launch_inference(request: Request):
+     inference_manager = request.app.state.inference_manager
+     raw_json = await request.json()
+     inference_manager.launch(raw_json["model"], raw_json["launch_kwargs"])
+     return {"message": "Inference server launched"}
+
+
+ @router.post("/kill")
+ async def kill_inference(request: Request):
+     inference_manager = request.app.state.inference_manager
+     inference_manager.kill()
+     return {"message": "Inference server killed"}
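
Because `/completions` strips the `openai/`, `huggingface/`, `local:`, and `arbor:` prefixes and lazily launches an inference server for the requested model, the first call can block while the model boots. A sketch of a call, with the mount prefix again an assumption:

```python
# Sketch: chat completion via the route above. The "/v1/chat" prefix is assumed.
import requests

resp = requests.post(
    "http://localhost:7453/v1/chat/completions",
    json={
        "model": "arbor:Qwen/Qwen2.5-1.5B-Instruct",  # "arbor:" is stripped server-side
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=600,  # generous: the first request may wait on a model launch
)
print(resp.json())
```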