arbor-ai 0.1.15__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/api/models/schemas.py +0 -1
- arbor/server/api/routes/grpo.py +0 -8
- arbor/server/api/routes/inference.py +16 -1
- arbor/server/services/grpo_manager.py +0 -8
- arbor/server/services/inference/vllm_client.py +3 -4
- arbor/server/services/inference/vllm_serve.py +1 -0
- arbor/server/services/scripts/grpo_training.py +8 -4
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/METADATA +7 -2
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/RECORD +13 -13
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/top_level.txt +0 -0
arbor/server/api/routes/grpo.py
CHANGED
@@ -38,14 +38,6 @@ def run_grpo_step(
     return GRPOStepResponse(status="success", **step_data)
 
 
-@router.post("/update_model", response_model=GRPOStepResponse)
-def update_model(request: Request):
-    grpo_manager = request.app.state.grpo_manager
-    inference_manager = request.app.state.inference_manager
-    update_model_data = grpo_manager.update_model(request, inference_manager)
-    return GRPOStepResponse(status="success", **update_model_data)
-
-
 @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
 def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
     grpo_manager = request.app.state.grpo_manager
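Note on the hunk above: the `/update_model` route is removed in 0.2, so a client written against 0.1.15 that still calls it will start receiving 404s. Below is a minimal, hedged sketch of a tolerant caller; the base URL, port, and route prefix are assumptions for illustration, not values taken from this diff.

```python
import requests

# Assumed local Arbor server address; adjust host, port, and any route
# prefix to match your deployment.
ARBOR_BASE_URL = "http://localhost:8000"


def try_update_model() -> None:
    # Against arbor-ai 0.2 this route no longer exists, so a 404 means
    # "nothing to do" rather than a real failure.
    resp = requests.post(f"{ARBOR_BASE_URL}/update_model", timeout=30)
    if resp.status_code == 404:
        print("Server >= 0.2: /update_model was removed; skipping.")
        return
    resp.raise_for_status()
    print(resp.json())
```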
arbor/server/api/routes/inference.py
CHANGED
@@ -19,10 +19,25 @@ async def run_inference(
     with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
         f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
 
+    request_model = raw_json["model"]
+    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
+    for prefix in prefixes:
+        if request_model.startswith(prefix):
+            request_model = request_model[len(prefix) :]
+
     # if a server isnt running, launch one
     if not inference_manager.is_server_running():
         print("No model is running, launching model...")
-        inference_manager.launch(
+        inference_manager.launch(request_model)
+
+    # if the requested model is different from the launched model, swap the server
+    if request_model != inference_manager.launched_model:
+        print(
+            f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
+        )
+        inference_manager.kill()
+        inference_manager.launch(request_model)
+        print(f"Model swapped to {request_model}")
 
     # forward the request to the inference server
     completion = await inference_manager.run_inference(raw_json)
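To make the added behavior concrete, here is a standalone sketch of the prefix stripping that now runs before any launch or swap decision. The prefix list mirrors the diff; the helper name and the example model strings are illustrative only.

```python
def normalize_model_name(model: str) -> str:
    # Mirrors the loop added above: strip at most one occurrence of each
    # known prefix, checked in order.
    prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
    for prefix in prefixes:
        if model.startswith(prefix):
            model = model[len(prefix) :]
    return model


if __name__ == "__main__":
    print(normalize_model_name("openai/Qwen/Qwen2.5-1.5B-Instruct"))  # -> Qwen/Qwen2.5-1.5B-Instruct
    print(normalize_model_name("arbor:some-local-checkpoint"))        # -> some-local-checkpoint
    print(normalize_model_name("plain-model-name"))                   # unchanged
```

The stripped name is then compared against `inference_manager.launched_model`; on a mismatch the running vLLM server is killed and relaunched with the requested model.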
arbor/server/services/grpo_manager.py
CHANGED
@@ -314,14 +314,6 @@ class GRPOManager:
             "last_checkpoint": self.last_checkpoint,
         }
 
-    def update_model(self, request, inference_manager: InferenceManager):
-        # No longer used
-        return {
-            "current_model": self.current_model,
-            "checkpoints": self.checkpoints,
-            "last_checkpoint": self.last_checkpoint,
-        }
-
     def checkpoint(
         self, request: GRPOCheckpointRequest, inference_manager: InferenceManager
     ):
arbor/server/services/inference/vllm_client.py
CHANGED
@@ -1,4 +1,4 @@
-# adapted from
+# adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
 
 import atexit
 import logging
@@ -8,7 +8,6 @@ from typing import Optional
 import httpx
 import requests
 import torch
-from openai import OpenAI
 from requests import ConnectionError
 from requests.adapters import HTTPAdapter
 from torch import nn
@@ -31,7 +30,7 @@ class InferenceBlockedError(Exception):
     pass
 
 
-class VLLMClient
+class VLLMClient:
     """
     A client class to interact with a vLLM server.
 
@@ -90,7 +89,7 @@ class VLLMClient(OpenAI):
                 "vLLM is not installed. Please install it with `pip install vllm`."
            )
 
-
+        self.base_url = f"http://{host}:{port}/v1"
         self.session = requests.Session()
         # Configure connection pooling to handle rapid requests better
         adapter = HTTPAdapter(
arbor/server/services/scripts/grpo_training.py
CHANGED
@@ -139,7 +139,11 @@ class ArborGRPOTrainer(GRPOTrainer):
             maybe_apply_chat_template(
                 {
                     "prompt": example["messages"],
-                    "completion":
+                    "completion": (
+                        example["completion"]
+                        if isinstance(example["completion"], list)
+                        else [example["completion"]]
+                    ),
                 },
                 self.processing_class,
             )
@@ -168,15 +172,15 @@ class ArborGRPOTrainer(GRPOTrainer):
             prompt_completion_text["completion"]
             for prompt_completion_text in prompt_completion_texts
         ]
-
+        completion_inputs = self.processing_class(
             completions_text,
             return_tensors="pt",
             padding=True,
             add_special_tokens=False,
         ).to(device)
         completion_ids, completion_mask = (
-
-
+            completion_inputs["input_ids"],
+            completion_inputs["attention_mask"],
         )
 
         if self.max_prompt_length is not None:
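The first hunk above normalizes `example["completion"]` so that a lone message dict is wrapped in a single-element list before `maybe_apply_chat_template` is applied. A tiny sketch of the same conditional in isolation; the helper name and the sample message are illustrative.

```python
from typing import Any, List


def as_completion_list(completion: Any) -> List[Any]:
    # Same conditional as in the diff: pass lists through, wrap anything else.
    return completion if isinstance(completion, list) else [completion]


single = {"role": "assistant", "content": "Hello!"}
print(as_completion_list(single))    # [{'role': 'assistant', 'content': 'Hello!'}]
print(as_completion_list([single]))  # unchanged, already a list
```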
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.
+Version: 0.2
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -40,7 +40,12 @@ Dynamic: license-file
 Install Arbor via pip:
 
 ```bash
-pip install arbor-ai
+pip install -U arbor-ai
+```
+
+Optionally, you can also install:
+```bash
+pip install flash-attn --no-build-isolation
 ```
 
 ---
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
 arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
 arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-arbor/server/api/models/schemas.py,sha256=
+arbor/server/api/models/schemas.py,sha256=394FHmIxAWVwED3z5tjnJCsyrgSWXg2SFWvMM1oKqOI,6177
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
-arbor/server/api/routes/grpo.py,sha256=
-arbor/server/api/routes/inference.py,sha256=
+arbor/server/api/routes/grpo.py,sha256=Yc4FxieuUbJ7Dbd-93uN4syQu9h2eQU4R9ZvnE_axRU,1982
+arbor/server/api/routes/inference.py,sha256=txLF4ANa0ZSaROrbvSaPZVFOSzn4so9e7mjNKnt2bcM,2182
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,26 +17,26 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256=
+arbor/server/services/grpo_manager.py,sha256=MDpOGN99WnNg4q_8974MkAnqcClOXy6fYcD2sFvs2Ho,18487
 arbor/server/services/inference_manager.py,sha256=a1c5zYbjk6fPM3egX2McKv7ZWPN7c-QH_Qogu9iay90,9597
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
 arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/server/services/inference/vllm_client.py,sha256=
-arbor/server/services/inference/vllm_serve.py,sha256=
+arbor/server/services/inference/vllm_client.py,sha256=P4etwX47VVMEaVWUOT-aP6_OONf8ZzniwXndmJujNxY,18250
+arbor/server/services/inference/vllm_serve.py,sha256=UZAGo7CyshR3-9fhXCTKhXeidqNqbY6LyU9DDNiX_Sw,109543
 arbor/server/services/scripts/dpo_training.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/server/services/scripts/grpo_training.py,sha256=
+arbor/server/services/scripts/grpo_training.py,sha256=6kXzMwn3rZXHdEn0xe_Kd9d7tbdYb76zE0zbi02xCm4,31314
 arbor/server/services/scripts/sft_training.py,sha256=jgDMxZn9RFH9ys_7OF9Is8pQ9V97O2KzWg22Gveh3yE,3410
 arbor/server/services/scripts/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/scripts/utils/arg_parser.py,sha256=ur_iyhc_Ie00tjq63vK4Sdeu2PGKwe6Dh6Iax2vw9jc,1022
 arbor/server/services/scripts/utils/dataset.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.
-arbor_ai-0.
-arbor_ai-0.
-arbor_ai-0.
-arbor_ai-0.
-arbor_ai-0.
+arbor_ai-0.2.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.2.dist-info/METADATA,sha256=LieUwdo2RQBgh5ukQJh-NHUA2_CBS1Dr9YqjSbgcEnM,2504
+arbor_ai-0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+arbor_ai-0.2.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.2.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.2.dist-info/RECORD,,
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/WHEEL
File without changes
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/entry_points.txt
File without changes
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/licenses/LICENSE
File without changes
{arbor_ai-0.1.15.dist-info → arbor_ai-0.2.dist-info}/top_level.txt
File without changes