arbor-ai 0.1.15__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/server/api/models/schemas.py +0 -1
- arbor/server/api/routes/grpo.py +0 -8
- arbor/server/api/routes/inference.py +16 -1
- arbor/server/services/grpo_manager.py +2 -8
- arbor/server/services/inference/vllm_client.py +5 -5
- arbor/server/services/inference/vllm_serve.py +1 -0
- arbor/server/services/scripts/grpo_training.py +8 -4
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/METADATA +17 -2
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/RECORD +13 -13
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/WHEEL +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.15.dist-info → arbor_ai-0.2.1.dist-info}/top_level.txt +0 -0
arbor/server/api/routes/grpo.py
CHANGED
@@ -38,14 +38,6 @@ def run_grpo_step(
|
|
38
38
|
return GRPOStepResponse(status="success", **step_data)
|
39
39
|
|
40
40
|
|
41
|
-
@router.post("/update_model", response_model=GRPOStepResponse)
|
42
|
-
def update_model(request: Request):
|
43
|
-
grpo_manager = request.app.state.grpo_manager
|
44
|
-
inference_manager = request.app.state.inference_manager
|
45
|
-
update_model_data = grpo_manager.update_model(request, inference_manager)
|
46
|
-
return GRPOStepResponse(status="success", **update_model_data)
|
47
|
-
|
48
|
-
|
49
41
|
@router.post("/checkpoint", response_model=GRPOCheckpointResponse)
|
50
42
|
def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
|
51
43
|
grpo_manager = request.app.state.grpo_manager
|
@@ -19,10 +19,25 @@ async def run_inference(
|
|
19
19
|
with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
|
20
20
|
f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
|
21
21
|
|
22
|
+
request_model = raw_json["model"]
|
23
|
+
prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
|
24
|
+
for prefix in prefixes:
|
25
|
+
if request_model.startswith(prefix):
|
26
|
+
request_model = request_model[len(prefix) :]
|
27
|
+
|
22
28
|
# if a server isnt running, launch one
|
23
29
|
if not inference_manager.is_server_running():
|
24
30
|
print("No model is running, launching model...")
|
25
|
-
inference_manager.launch(
|
31
|
+
inference_manager.launch(request_model)
|
32
|
+
|
33
|
+
# if the requested model is different from the launched model, swap the server
|
34
|
+
if request_model != inference_manager.launched_model:
|
35
|
+
print(
|
36
|
+
f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
|
37
|
+
)
|
38
|
+
inference_manager.kill()
|
39
|
+
inference_manager.launch(request_model)
|
40
|
+
print(f"Model swapped to {request_model}")
|
26
41
|
|
27
42
|
# forward the request to the inference server
|
28
43
|
completion = await inference_manager.run_inference(raw_json)
|
@@ -270,7 +270,6 @@ class GRPOManager:
|
|
270
270
|
print("Updating inference model...")
|
271
271
|
# There is a case where this status is sent multiple times
|
272
272
|
# We need to make sure we only update the model once
|
273
|
-
self.current_model = status["output_dir"]
|
274
273
|
self.saving_model = False
|
275
274
|
print("Model update complete")
|
276
275
|
elif status["status"] == "checkpoint_saved":
|
@@ -308,14 +307,9 @@ class GRPOManager:
|
|
308
307
|
print(f"Failed to send batch to training process: {e}")
|
309
308
|
raise
|
310
309
|
|
311
|
-
|
312
|
-
|
313
|
-
"checkpoints": self.checkpoints,
|
314
|
-
"last_checkpoint": self.last_checkpoint,
|
315
|
-
}
|
310
|
+
self.current_model = self.train_kwargs["output_dir"]
|
311
|
+
inference_manager.launched_model = self.current_model
|
316
312
|
|
317
|
-
def update_model(self, request, inference_manager: InferenceManager):
|
318
|
-
# No longer used
|
319
313
|
return {
|
320
314
|
"current_model": self.current_model,
|
321
315
|
"checkpoints": self.checkpoints,
|
@@ -1,5 +1,6 @@
|
|
1
|
-
# adapted from
|
1
|
+
# adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
|
2
2
|
|
3
|
+
import asyncio
|
3
4
|
import atexit
|
4
5
|
import logging
|
5
6
|
import time
|
@@ -8,7 +9,6 @@ from typing import Optional
|
|
8
9
|
import httpx
|
9
10
|
import requests
|
10
11
|
import torch
|
11
|
-
from openai import OpenAI
|
12
12
|
from requests import ConnectionError
|
13
13
|
from requests.adapters import HTTPAdapter
|
14
14
|
from torch import nn
|
@@ -31,7 +31,7 @@ class InferenceBlockedError(Exception):
|
|
31
31
|
pass
|
32
32
|
|
33
33
|
|
34
|
-
class VLLMClient
|
34
|
+
class VLLMClient:
|
35
35
|
"""
|
36
36
|
A client class to interact with a vLLM server.
|
37
37
|
|
@@ -90,7 +90,7 @@ class VLLMClient(OpenAI):
|
|
90
90
|
"vLLM is not installed. Please install it with `pip install vllm`."
|
91
91
|
)
|
92
92
|
|
93
|
-
|
93
|
+
self.base_url = f"http://{host}:{port}/v1"
|
94
94
|
self.session = requests.Session()
|
95
95
|
# Configure connection pooling to handle rapid requests better
|
96
96
|
adapter = HTTPAdapter(
|
@@ -240,7 +240,7 @@ class VLLMClient(OpenAI):
|
|
240
240
|
response.raise_for_status()
|
241
241
|
return response.json()
|
242
242
|
|
243
|
-
except httpx.
|
243
|
+
except httpx.TimeoutException:
|
244
244
|
logger.error("Request timed out")
|
245
245
|
raise
|
246
246
|
except InferenceBlockedError:
|
@@ -139,7 +139,11 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
139
139
|
maybe_apply_chat_template(
|
140
140
|
{
|
141
141
|
"prompt": example["messages"],
|
142
|
-
"completion":
|
142
|
+
"completion": (
|
143
|
+
example["completion"]
|
144
|
+
if isinstance(example["completion"], list)
|
145
|
+
else [example["completion"]]
|
146
|
+
),
|
143
147
|
},
|
144
148
|
self.processing_class,
|
145
149
|
)
|
@@ -168,15 +172,15 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
168
172
|
prompt_completion_text["completion"]
|
169
173
|
for prompt_completion_text in prompt_completion_texts
|
170
174
|
]
|
171
|
-
|
175
|
+
completion_inputs = self.processing_class(
|
172
176
|
completions_text,
|
173
177
|
return_tensors="pt",
|
174
178
|
padding=True,
|
175
179
|
add_special_tokens=False,
|
176
180
|
).to(device)
|
177
181
|
completion_ids, completion_mask = (
|
178
|
-
|
179
|
-
|
182
|
+
completion_inputs["input_ids"],
|
183
|
+
completion_inputs["attention_mask"],
|
180
184
|
)
|
181
185
|
|
182
186
|
if self.max_prompt_length is not None:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: arbor-ai
|
3
|
-
Version: 0.1
|
3
|
+
Version: 0.2.1
|
4
4
|
Summary: A framework for fine-tuning and managing language models
|
5
5
|
Author-email: Noah Ziems <nziems2@nd.edu>
|
6
6
|
Project-URL: Homepage, https://github.com/Ziems/arbor
|
@@ -40,7 +40,12 @@ Dynamic: license-file
|
|
40
40
|
Install Arbor via pip:
|
41
41
|
|
42
42
|
```bash
|
43
|
-
pip install arbor-ai
|
43
|
+
pip install -U arbor-ai
|
44
|
+
```
|
45
|
+
|
46
|
+
Optionally, you can also install:
|
47
|
+
```bash
|
48
|
+
pip install flash-attn --no-build-isolation
|
44
49
|
```
|
45
50
|
|
46
51
|
---
|
@@ -74,6 +79,16 @@ Follow the DSPy tutorials here to see usage examples:
|
|
74
79
|
|
75
80
|
---
|
76
81
|
|
82
|
+
### Troubleshooting
|
83
|
+
|
84
|
+
**NCCL Errors**
|
85
|
+
Certain GPU setups, particularly those with newer GPUs, seem to have issues with NCCL that cause Arbor to crash. Oftentimes, these can be fixed by setting the following environment variables:
|
86
|
+
|
87
|
+
```bash
|
88
|
+
export NCCL_P2P_DISABLE=1
|
89
|
+
export NCCL_IB_DISABLE=1
|
90
|
+
```
|
91
|
+
|
77
92
|
## 🙏 Acknowledgements
|
78
93
|
|
79
94
|
Arbor builds on the shoulders of great work. We extend our thanks to:
|
@@ -5,11 +5,11 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
|
|
5
5
|
arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
6
6
|
arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
|
7
7
|
arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
8
|
-
arbor/server/api/models/schemas.py,sha256=
|
8
|
+
arbor/server/api/models/schemas.py,sha256=394FHmIxAWVwED3z5tjnJCsyrgSWXg2SFWvMM1oKqOI,6177
|
9
9
|
arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
10
|
arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
|
11
|
-
arbor/server/api/routes/grpo.py,sha256=
|
12
|
-
arbor/server/api/routes/inference.py,sha256=
|
11
|
+
arbor/server/api/routes/grpo.py,sha256=Yc4FxieuUbJ7Dbd-93uN4syQu9h2eQU4R9ZvnE_axRU,1982
|
12
|
+
arbor/server/api/routes/inference.py,sha256=txLF4ANa0ZSaROrbvSaPZVFOSzn4so9e7mjNKnt2bcM,2182
|
13
13
|
arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
|
14
14
|
arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
15
15
|
arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
|
@@ -17,26 +17,26 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
17
17
|
arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
18
|
arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
|
20
|
-
arbor/server/services/grpo_manager.py,sha256=
|
20
|
+
arbor/server/services/grpo_manager.py,sha256=jY4kc7wlKKoi7RigjJiH1VaxX6qJCOxyEc0oYCkqPlQ,18549
|
21
21
|
arbor/server/services/inference_manager.py,sha256=a1c5zYbjk6fPM3egX2McKv7ZWPN7c-QH_Qogu9iay90,9597
|
22
22
|
arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
|
23
23
|
arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
|
24
24
|
arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
|
26
26
|
arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
|
-
arbor/server/services/inference/vllm_client.py,sha256=
|
28
|
-
arbor/server/services/inference/vllm_serve.py,sha256=
|
27
|
+
arbor/server/services/inference/vllm_client.py,sha256=06-VfdcwKqq8_ZRWaER3OnSVLtvL87bLdljSrkXfm-A,18269
|
28
|
+
arbor/server/services/inference/vllm_serve.py,sha256=UZAGo7CyshR3-9fhXCTKhXeidqNqbY6LyU9DDNiX_Sw,109543
|
29
29
|
arbor/server/services/scripts/dpo_training.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
30
|
-
arbor/server/services/scripts/grpo_training.py,sha256=
|
30
|
+
arbor/server/services/scripts/grpo_training.py,sha256=6kXzMwn3rZXHdEn0xe_Kd9d7tbdYb76zE0zbi02xCm4,31314
|
31
31
|
arbor/server/services/scripts/sft_training.py,sha256=jgDMxZn9RFH9ys_7OF9Is8pQ9V97O2KzWg22Gveh3yE,3410
|
32
32
|
arbor/server/services/scripts/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
33
|
arbor/server/services/scripts/utils/arg_parser.py,sha256=ur_iyhc_Ie00tjq63vK4Sdeu2PGKwe6Dh6Iax2vw9jc,1022
|
34
34
|
arbor/server/services/scripts/utils/dataset.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
35
|
arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
36
36
|
arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
|
-
arbor_ai-0.1.
|
38
|
-
arbor_ai-0.1.
|
39
|
-
arbor_ai-0.1.
|
40
|
-
arbor_ai-0.1.
|
41
|
-
arbor_ai-0.1.
|
42
|
-
arbor_ai-0.1.
|
37
|
+
arbor_ai-0.2.1.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
|
38
|
+
arbor_ai-0.2.1.dist-info/METADATA,sha256=34XAZBm8OLlsSBicLmRn_hhbltn0pDNlAj5WOjn9LtE,2791
|
39
|
+
arbor_ai-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
40
|
+
arbor_ai-0.2.1.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
|
41
|
+
arbor_ai-0.2.1.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
|
42
|
+
arbor_ai-0.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|