arbor-ai 0.1.15__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -178,7 +178,6 @@ class ChatCompletionModel(BaseModel):
178
178
 
179
179
  class GRPORequest(BaseModel):
180
180
  model: str
181
- update_inference_model: bool
182
181
  batch: List[dict]
183
182
 
184
183
 
@@ -38,14 +38,6 @@ def run_grpo_step(
38
38
  return GRPOStepResponse(status="success", **step_data)
39
39
 
40
40
 
41
- @router.post("/update_model", response_model=GRPOStepResponse)
42
- def update_model(request: Request):
43
- grpo_manager = request.app.state.grpo_manager
44
- inference_manager = request.app.state.inference_manager
45
- update_model_data = grpo_manager.update_model(request, inference_manager)
46
- return GRPOStepResponse(status="success", **update_model_data)
47
-
48
-
49
41
  @router.post("/checkpoint", response_model=GRPOCheckpointResponse)
50
42
  def checkpoint(request: Request, grpo_checkpoint_request: GRPOCheckpointRequest):
51
43
  grpo_manager = request.app.state.grpo_manager
@@ -19,10 +19,25 @@ async def run_inference(
19
19
  with open(f"{request.app.state.log_dir}/inference_requests.jsonl", "a") as f:
20
20
  f.write(json.dumps({"id": request_id, "request": raw_json}) + "\n")
21
21
 
22
+ request_model = raw_json["model"]
23
+ prefixes = ["openai/", "huggingface/", "local:", "arbor:"]
24
+ for prefix in prefixes:
25
+ if request_model.startswith(prefix):
26
+ request_model = request_model[len(prefix) :]
27
+
22
28
  # if a server isn't running, launch one
23
29
  if not inference_manager.is_server_running():
24
30
  print("No model is running, launching model...")
25
- inference_manager.launch(raw_json["model"])
31
+ inference_manager.launch(request_model)
32
+
33
+ # if the requested model is different from the launched model, swap the server
34
+ if request_model != inference_manager.launched_model:
35
+ print(
36
+ f"Model changed from {inference_manager.launched_model} to {request_model}, swapping server..."
37
+ )
38
+ inference_manager.kill()
39
+ inference_manager.launch(request_model)
40
+ print(f"Model swapped to {request_model}")
26
41
 
27
42
  # forward the request to the inference server
28
43
  completion = await inference_manager.run_inference(raw_json)
@@ -270,7 +270,6 @@ class GRPOManager:
270
270
  print("Updating inference model...")
271
271
  # There is a case where this status is sent multiple times
272
272
  # We need to make sure we only update the model once
273
- self.current_model = status["output_dir"]
274
273
  self.saving_model = False
275
274
  print("Model update complete")
276
275
  elif status["status"] == "checkpoint_saved":
@@ -308,14 +307,9 @@ class GRPOManager:
308
307
  print(f"Failed to send batch to training process: {e}")
309
308
  raise
310
309
 
311
- return {
312
- "current_model": self.current_model,
313
- "checkpoints": self.checkpoints,
314
- "last_checkpoint": self.last_checkpoint,
315
- }
310
+ self.current_model = self.train_kwargs["output_dir"]
311
+ inference_manager.launched_model = self.current_model
316
312
 
317
- def update_model(self, request, inference_manager: InferenceManager):
318
- # No longer used
319
313
  return {
320
314
  "current_model": self.current_model,
321
315
  "checkpoints": self.checkpoints,
@@ -1,5 +1,6 @@
1
- # adapted from trl/extras/vllm_client.py (huggingface/trl)
1
+ # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
2
2
 
3
+ import asyncio
3
4
  import atexit
4
5
  import logging
5
6
  import time
@@ -8,7 +9,6 @@ from typing import Optional
8
9
  import httpx
9
10
  import requests
10
11
  import torch
11
- from openai import OpenAI
12
12
  from requests import ConnectionError
13
13
  from requests.adapters import HTTPAdapter
14
14
  from torch import nn
@@ -31,7 +31,7 @@ class InferenceBlockedError(Exception):
31
31
  pass
32
32
 
33
33
 
34
- class VLLMClient(OpenAI):
34
+ class VLLMClient:
35
35
  """
36
36
  A client class to interact with a vLLM server.
37
37
 
@@ -90,7 +90,7 @@ class VLLMClient(OpenAI):
90
90
  "vLLM is not installed. Please install it with `pip install vllm`."
91
91
  )
92
92
 
93
- super().__init__(base_url=f"http://{host}:{port}/v1", api_key="local")
93
+ self.base_url = f"http://{host}:{port}/v1"
94
94
  self.session = requests.Session()
95
95
  # Configure connection pooling to handle rapid requests better
96
96
  adapter = HTTPAdapter(
@@ -240,7 +240,7 @@ class VLLMClient(OpenAI):
240
240
  response.raise_for_status()
241
241
  return response.json()
242
242
 
243
- except httpx.TimeoutError:
243
+ except httpx.TimeoutException:
244
244
  logger.error("Request timed out")
245
245
  raise
246
246
  except InferenceBlockedError:
@@ -1,3 +1,4 @@
1
+ # adapted from Will Brown's verifiers library (https://github.com/willccbb/verifiers)
1
2
  """
2
3
  OpenAI-compatible vLLM server with weight synchronization.
3
4
 
@@ -139,7 +139,11 @@ class ArborGRPOTrainer(GRPOTrainer):
139
139
  maybe_apply_chat_template(
140
140
  {
141
141
  "prompt": example["messages"],
142
- "completion": [example["completion"]],
142
+ "completion": (
143
+ example["completion"]
144
+ if isinstance(example["completion"], list)
145
+ else [example["completion"]]
146
+ ),
143
147
  },
144
148
  self.processing_class,
145
149
  )
@@ -168,15 +172,15 @@ class ArborGRPOTrainer(GRPOTrainer):
168
172
  prompt_completion_text["completion"]
169
173
  for prompt_completion_text in prompt_completion_texts
170
174
  ]
171
- completion_ids = self.processing_class(
175
+ completion_inputs = self.processing_class(
172
176
  completions_text,
173
177
  return_tensors="pt",
174
178
  padding=True,
175
179
  add_special_tokens=False,
176
180
  ).to(device)
177
181
  completion_ids, completion_mask = (
178
- completion_ids["input_ids"],
179
- completion_ids["attention_mask"],
182
+ completion_inputs["input_ids"],
183
+ completion_inputs["attention_mask"],
180
184
  )
181
185
 
182
186
  if self.max_prompt_length is not None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arbor-ai
3
- Version: 0.1.15
3
+ Version: 0.2.1
4
4
  Summary: A framework for fine-tuning and managing language models
5
5
  Author-email: Noah Ziems <nziems2@nd.edu>
6
6
  Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -40,7 +40,12 @@ Dynamic: license-file
40
40
  Install Arbor via pip:
41
41
 
42
42
  ```bash
43
- pip install arbor-ai
43
+ pip install -U arbor-ai
44
+ ```
45
+
46
+ Optionally, you can also install:
47
+ ```bash
48
+ pip install flash-attn --no-build-isolation
44
49
  ```
45
50
 
46
51
  ---
@@ -74,6 +79,16 @@ Follow the DSPy tutorials here to see usage examples:
74
79
 
75
80
  ---
76
81
 
82
+ ### Troubleshooting
83
+
84
+ **NCCL Errors**
85
+ Certain GPU setups, particularly with newer GPUs, seem to have issues with NCCL that cause Arbor to crash. Oftentimes these can be fixed with the following environment variables:
86
+
87
+ ```bash
88
+ export NCCL_P2P_DISABLE=1
89
+ export NCCL_IB_DISABLE=1
90
+ ```
91
+
77
92
  ## 🙏 Acknowledgements
78
93
 
79
94
  Arbor builds on the shoulders of great work. We extend our thanks to:
@@ -5,11 +5,11 @@ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
5
5
  arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
6
  arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
7
7
  arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
- arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
8
+ arbor/server/api/models/schemas.py,sha256=394FHmIxAWVwED3z5tjnJCsyrgSWXg2SFWvMM1oKqOI,6177
9
9
  arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
11
- arbor/server/api/routes/grpo.py,sha256=QrWwj44-EenOyDwtiAO7OJPPGe8CyNaxCUTDlqfJs4g,2338
12
- arbor/server/api/routes/inference.py,sha256=JI4lm7zWrUqgMadWA0JuTD13hq6kGQpTLcuklhOH7f8,1547
11
+ arbor/server/api/routes/grpo.py,sha256=Yc4FxieuUbJ7Dbd-93uN4syQu9h2eQU4R9ZvnE_axRU,1982
12
+ arbor/server/api/routes/inference.py,sha256=txLF4ANa0ZSaROrbvSaPZVFOSzn4so9e7mjNKnt2bcM,2182
13
13
  arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
14
14
  arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
15
  arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,26 +17,26 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
17
17
  arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
20
- arbor/server/services/grpo_manager.py,sha256=y5gOko_RmyjQqvzlR79_PPZgMwMwCMJiaeygCG5qS-A,18761
20
+ arbor/server/services/grpo_manager.py,sha256=jY4kc7wlKKoi7RigjJiH1VaxX6qJCOxyEc0oYCkqPlQ,18549
21
21
  arbor/server/services/inference_manager.py,sha256=a1c5zYbjk6fPM3egX2McKv7ZWPN7c-QH_Qogu9iay90,9597
22
22
  arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
23
23
  arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
24
24
  arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
26
26
  arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
- arbor/server/services/inference/vllm_client.py,sha256=X0v6zGHuaROGniWw_VCkzeWWuAHq0PlwtrFjTngCT4k,18285
28
- arbor/server/services/inference/vllm_serve.py,sha256=GdcaQStGKLj4J1kAnAnnI07R0X3A-bPoj7Tvagxsias,109457
27
+ arbor/server/services/inference/vllm_client.py,sha256=06-VfdcwKqq8_ZRWaER3OnSVLtvL87bLdljSrkXfm-A,18269
28
+ arbor/server/services/inference/vllm_serve.py,sha256=UZAGo7CyshR3-9fhXCTKhXeidqNqbY6LyU9DDNiX_Sw,109543
29
29
  arbor/server/services/scripts/dpo_training.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- arbor/server/services/scripts/grpo_training.py,sha256=qjYSinOhi9-vvKY-gqGARwUgDQXYGDHlp9ZLwqKE1rw,31123
30
+ arbor/server/services/scripts/grpo_training.py,sha256=6kXzMwn3rZXHdEn0xe_Kd9d7tbdYb76zE0zbi02xCm4,31314
31
31
  arbor/server/services/scripts/sft_training.py,sha256=jgDMxZn9RFH9ys_7OF9Is8pQ9V97O2KzWg22Gveh3yE,3410
32
32
  arbor/server/services/scripts/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
33
  arbor/server/services/scripts/utils/arg_parser.py,sha256=ur_iyhc_Ie00tjq63vK4Sdeu2PGKwe6Dh6Iax2vw9jc,1022
34
34
  arbor/server/services/scripts/utils/dataset.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
35
  arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
- arbor_ai-0.1.15.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
38
- arbor_ai-0.1.15.dist-info/METADATA,sha256=GMGq6nbWEbRZxsJG2u7DhnMj6qCSTvssMVUN4ASs2BA,2413
39
- arbor_ai-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
- arbor_ai-0.1.15.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
41
- arbor_ai-0.1.15.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
42
- arbor_ai-0.1.15.dist-info/RECORD,,
37
+ arbor_ai-0.2.1.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
38
+ arbor_ai-0.2.1.dist-info/METADATA,sha256=34XAZBm8OLlsSBicLmRn_hhbltn0pDNlAj5WOjn9LtE,2791
39
+ arbor_ai-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
+ arbor_ai-0.2.1.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
41
+ arbor_ai-0.2.1.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
42
+ arbor_ai-0.2.1.dist-info/RECORD,,