arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- arbor/cli.py +89 -5
- arbor/client/api.py +1 -2
- arbor/server/api/models/schemas.py +209 -5
- arbor/server/api/routes/files.py +39 -10
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +110 -7
- arbor/server/core/config.py +44 -7
- arbor/server/main.py +6 -5
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -16
- arbor/server/services/file_manager.py +270 -109
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +74 -69
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +337 -40
- arbor_ai-0.1.6.dist-info/METADATA +78 -0
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +2 -1
- arbor_ai-0.1.6.dist-info/entry_points.txt +2 -0
- arbor_ai-0.1.6.dist-info/top_level.txt +1 -0
- arbor/server/api/routes/training.py +0 -16
- arbor_ai-0.1.4.dist-info/METADATA +0 -97
- arbor_ai-0.1.4.dist-info/RECORD +0 -27
- arbor_ai-0.1.4.dist-info/entry_points.txt +0 -3
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info/licenses}/LICENSE +0 -0
--- a/arbor/server/services/training_manager.py
+++ b/arbor/server/services/training_manager.py
@@ -1,19 +1,40 @@
+import os
+import random
+import string
+import time
+from datetime import datetime
+from pathlib import Path
+
 from arbor.server.api.models.schemas import FineTuneRequest
-from arbor.server.
+from arbor.server.core.config import Settings
 from arbor.server.services.file_manager import FileManager
+from arbor.server.services.job_manager import Job, JobEvent, JobStatus

-class TrainingManager:
-    def __init__(self):
-        pass

-
+class TrainingManager:
+    def __init__(self, settings: Settings):
+        self.settings = settings
+
+    def make_output_dir(self, request: FineTuneRequest):
+        model_name = request.model.split("/")[-1].lower()
+        suffix = (
+            request.suffix
+            if request.suffix is not None
+            else "".join(random.choices(string.ascii_letters + string.digits, k=6))
+        )
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        name = f"ft:{model_name}:{suffix}:{timestamp}"
+        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
+
+    def find_train_args_sft(self, request: FineTuneRequest, file_manager: FileManager):
         file = file_manager.get_file(request.training_file)
         if file is None:
             raise ValueError(f"Training file {request.training_file} not found")

         data_path = file["path"]
-
+        file_manager.validate_file_format_sft(data_path)

+        name, output_dir = self.make_output_dir(request)

         default_train_kwargs = {
             "device": None,
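The new `make_output_dir` helper decides where checkpoints land: `ft:<model>:<suffix>:<timestamp>` under `<STORAGE_PATH>/models/`. A minimal standalone sketch of that naming logic, assuming a `STORAGE_PATH` of `./storage` (the real default comes from `Settings` in `arbor/server/core/config.py`, which is not shown in this excerpt):

```python
import random
import string
from datetime import datetime
from pathlib import Path

# Assumption: STORAGE_PATH defaults to "./storage"; the actual value comes from Settings.
STORAGE_PATH = "./storage"


def example_output_dir(model: str, suffix: str | None) -> str:
    model_name = model.split("/")[-1].lower()
    suffix = suffix or "".join(random.choices(string.ascii_letters + string.digits, k=6))
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"ft:{model_name}:{suffix}:{timestamp}"
    return str(Path(STORAGE_PATH).resolve() / "models" / name)


# e.g. .../storage/models/ft:qwen2.5-0.5b-instruct:demo:20250101_120000
print(example_output_dir("Qwen/Qwen2.5-0.5B-Instruct", "demo"))
```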
@@ -28,25 +49,225 @@ class TrainingManager:
             "output_dir": output_dir,
             "train_data_path": data_path,
         }
-        train_kwargs = {
-        train_kwargs={**default_train_kwargs, **(train_kwargs or {})}
-        output_dir = train_kwargs["output_dir"] # user might have changed the output_dir
+        train_kwargs = {"packing": False}
+        train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}

         return train_kwargs

+    def find_train_args_dpo(self, request: FineTuneRequest, file_manager: FileManager):
+        file = file_manager.get_file(request.training_file)
+        if file is None:
+            raise ValueError(f"Training file {request.training_file} not found")
+
+        data_path = file["path"]
+        file_manager.validate_file_format_dpo(data_path)
+
+        name, output_dir = self.make_output_dir(request)
+
+        default_train_kwargs = {
+            "device": "cuda:2",
+            "use_peft": False,
+            "num_train_epochs": 5,
+            "per_device_train_batch_size": 1,
+            "gradient_accumulation_steps": 8,
+            "learning_rate": 1e-5,
+            "packing": True,
+            "bf16": True,
+            "output_dir": output_dir,
+            "train_data_path": data_path,
+            "prompt_length": 1024,
+            "max_seq_length": 1512,
+            "use_peft": False,
+        }
+
+        # https://www.philschmid.de/dpo-align-llms-in-2024-with-trl#3-align-llm-with-trl-and-the-dpotrainer
+
+        train_kwargs = request.model_dump(exclude_unset=True)
+        train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
+
+        return train_kwargs

     def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
-        # Get logger for this job
-        logger = job.setup_logger("training")

         job.status = JobStatus.RUNNING
-
+        job.add_event(
+            JobEvent(level="info", message="Starting fine-tuning job", data={})
+        )
+
+        fine_tune_type = request.method["type"]
+        if fine_tune_type == "dpo":
+            self.dpo_fine_tune(request, job, file_manager)
+        else:
+            self.sft_fine_tune(request, job, file_manager)
+
+    def dpo_fine_tune(
+        self, request: FineTuneRequest, job: Job, file_manager: FileManager
+    ):
+        try:
+
+            job.status = JobStatus.RUNNING
+            job.add_event(
+                JobEvent(level="info", message="Starting fine-tuning job", data={})
+            )
+
+            train_kwargs = self.find_train_args_dpo(request, file_manager)
+            import torch
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                TrainingArguments,
+            )
+            from trl import DPOConfig, DPOTrainer, setup_chat_format
+
+            device = train_kwargs.get("device", None)
+            if device is None:
+                device = (
+                    "cuda"
+                    if torch.cuda.is_available()
+                    else "mps" if torch.backends.mps.is_available() else "cpu"
+                )
+
+            job.add_event(
+                JobEvent(level="info", message=f"Using device: {device}", data={})
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=request.model, device_map="auto"
+            )
+            tokenizer = AutoTokenizer.from_pretrained(
+                pretrained_model_name_or_path=request.model
+            )
+
+            try:
+                model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
+            except Exception:
+                pass
+
+            if tokenizer.pad_token_id is None:
+                job.add_event(
+                    JobEvent(
+                        level="info", message="Adding pad token to tokenizer", data={}
+                    )
+                )
+                tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
+
+            hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+            train_dataset = hf_dataset
+
+            use_peft = train_kwargs.get("use_peft", False)
+            peft_config = None
+
+            if use_peft:
+                from peft import LoraConfig
+
+                peft_config = LoraConfig(
+                    lora_alpha=128,
+                    lora_dropout=0.05,
+                    r=256,
+                    bias="none",
+                    target_modules="all-linear",
+                    task_type="CAUSAL_LM",
+                )
+
+            dpo_args = DPOConfig(
+                output_dir=train_kwargs["output_dir"],
+                num_train_epochs=train_kwargs["num_train_epochs"],
+            )
+
+            trainer = DPOTrainer(
+                model=model,
+                args=dpo_args,
+                train_dataset=train_dataset,
+                processing_class=tokenizer,
+                peft_config=peft_config,
+            )
+
+            trainer.train()
+
+            if job.status == JobStatus.PENDING_PAUSE:
+                trainer.accelerator.save_state(output_dir=dpo_args.output_dir)
+                current_step = trainer.state.global_step
+                job.status = JobStatus.PAUSED
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message="Training paused",
+                        data={"step": current_step},
+                    )
+                )
+
+            while job.status == JobStatus.PAUSED:
+                time.sleep(1) # Sleep to avoid busy waiting
+                if job.status == JobStatus.PENDING_RESUME:
+                    job.status = JobStatus.RUNNING
+                    job.add_event(JobEvent(level="info", message="Resuming training"))
+                    trainer.accelerator.load_state(input_dir=dpo_args.output_dir)
+                    trainer.train(resume_from_checkpoint=True)
+
+            if job.status == JobStatus.PENDING_CANCEL:
+                job.status = JobStatus.CANCELLED
+                job.add_event(JobEvent(level="info", message="Training cancelled"))
+
+                _cleanup(model, tokenizer, trainer)
+                raise Exception(
+                    "Training cancelled"
+                ) # not sure if this should be raised or just return None
+
+            job.add_event(
+                JobEvent(level="info", message="Training completed successfully")
+            )
+
+            job.add_event(JobEvent(level="info", message="Saving model", data={}))
+            # Save the model!
+            trainer.save_model()
+            job.add_event(
+                JobEvent(
+                    level="info",
+                    message="Model saved",
+                    data={"location": dpo_args.output_dir},
+                )
+            )
+
+            MERGE = True
+            if use_peft and MERGE:
+                from peft import AutoPeftModelForCausalLM
+
+                # Load PEFT model on CPU
+                model_ = AutoPeftModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path=dpo_args.output_dir,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                )
+
+                merged_model = model_.merge_and_unload()
+                merged_model.save_pretrained(
+                    dpo_args.output_dir, safe_serialization=True, max_shard_size="5GB"
+                )
+
+            _cleanup(model, tokenizer, trainer)
+
+            job.status = JobStatus.SUCCEEDED
+            job.fine_tuned_model = dpo_args.output_dir
+        except Exception as e:
+            job.add_event(
+                JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
+            )
+            job.status = JobStatus.FAILED
+            raise
+        finally:
+            pass
+
+        return dpo_args.output_dir
+
+    def sft_fine_tune(
+        self, request: FineTuneRequest, job: Job, file_manager: FileManager
+    ):

         try:
-            train_kwargs = self.
+            train_kwargs = self.find_train_args_sft(request, file_manager)

             import torch
-            from transformers import AutoModelForCausalLM, AutoTokenizer
+            from transformers import AutoModelForCausalLM, AutoTokenizer
             from trl import SFTConfig, SFTTrainer, setup_chat_format

             device = train_kwargs.get("device", None)
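`fine_tune` now dispatches on `request.method["type"]`: `"dpo"` goes to `dpo_fine_tune`, anything else falls through to `sft_fine_tune`. A hedged sketch of a request payload that would exercise the DPO path; only `model`, `suffix`, `training_file`, and `method["type"]` are visible in this diff, so the example values (and any other `FineTuneRequest` fields defined in `arbor/server/api/models/schemas.py`) are assumptions:

```python
# Hypothetical payload for illustration; values are examples, not documented defaults.
fine_tune_payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",  # HF model id passed to from_pretrained
    "training_file": "file-abc123",         # id of a previously uploaded training file
    "suffix": "demo",                       # optional; a random 6-char suffix is used if omitted
    "method": {"type": "dpo"},              # "dpo" -> dpo_fine_tune, anything else -> sft_fine_tune
}
```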
@@ -56,12 +277,16 @@ class TrainingManager:
                     if torch.cuda.is_available()
                     else "mps" if torch.backends.mps.is_available() else "cpu"
                 )
-
+            job.add_event(
+                JobEvent(level="info", message=f"Using device: {device}", data={})
+            )

             model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=request.model
             ).to(device)
-            tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer = AutoTokenizer.from_pretrained(
+                pretrained_model_name_or_path=request.model
+            )

             # Set up the chat format; generally only for non-chat model variants, hence the try-except.
             try:
@@ -70,21 +295,40 @@ class TrainingManager:
                 pass

             if tokenizer.pad_token_id is None:
-
+                job.add_event(
+                    JobEvent(
+                        level="info", message="Adding pad token to tokenizer", data={}
+                    )
+                )
                 tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})

-
-            if
+            job.add_event(JobEvent(level="info", message="Creating dataset", data={}))
+            if (
+                "max_seq_length" not in train_kwargs
+                or train_kwargs["max_seq_length"] is None
+            ):
                 train_kwargs["max_seq_length"] = 4096
-
-
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message=f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}",
+                        data={},
+                    )
+                )

+            job.add_event(JobEvent(level="info", message="Tokenizing dataset", data={}))
             hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
+
             def tokenize_function(example):
-                return encode_sft_example(
+                return encode_sft_example(
+                    example, tokenizer, train_kwargs["max_seq_length"]
+                )
+
             tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
             tokenized_dataset.set_format(type="torch")
-            tokenized_dataset = tokenized_dataset.filter(
+            tokenized_dataset = tokenized_dataset.filter(
+                lambda example: (example["labels"] != -100).any()
+            )

             USE_PEFT = train_kwargs.get("use_peft", False)
             peft_config = None
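The added `filter` drops any example whose `labels` are entirely `-100`, i.e. examples where masking the non-assistant turns left nothing for the loss to learn from. A small standalone illustration of that predicate (`-100` is the standard Hugging Face ignore index assumed here; the filter itself is taken from the diff above):

```python
import torch

# An example whose labels are all -100 contributes nothing to the loss.
fully_masked = {"labels": torch.tensor([-100, -100, -100])}
partly_masked = {"labels": torch.tensor([-100, 42, 7])}

keep = lambda example: bool((example["labels"] != -100).any())

print(keep(fully_masked))   # False -> filtered out of the training set
print(keep(partly_masked))  # True  -> kept
```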
@@ -125,20 +369,61 @@ class TrainingManager:
                 },
             )

-
+            job.add_event(JobEvent(level="info", message="Starting training"))
             trainer = SFTTrainer(
                 model=model,
                 args=sft_config,
                 train_dataset=tokenized_dataset,
                 peft_config=peft_config,
-
             )

             # Train!
             trainer.train()

+            if job.status == JobStatus.PENDING_PAUSE:
+                trainer.accelerator.save_state(output_dir=sft_config.output_dir)
+                current_step = trainer.state.global_step
+                job.status = JobStatus.PAUSED
+                job.add_event(
+                    JobEvent(
+                        level="info",
+                        message="Training paused",
+                        data={"step": current_step},
+                    )
+                )
+
+            while job.status == JobStatus.PAUSED:
+                time.sleep(1) # Sleep to avoid busy waiting
+                if job.status == JobStatus.PENDING_RESUME:
+                    job.status = JobStatus.RUNNING
+                    job.add_event(JobEvent(level="info", message="Resuming training"))
+                    trainer.accelerator.load_state(input_dir=sft_config.output_dir)
+                    trainer.train(resume_from_checkpoint=True)
+
+            if job.status == JobStatus.PENDING_CANCEL:
+                job.status = JobStatus.CANCELLED
+                job.add_event(JobEvent(level="info", message="Training cancelled"))
+
+                _cleanup(model, tokenizer, trainer)
+
+                raise Exception(
+                    "Training cancelled"
+                ) # not sure if this should be raised or just return None
+
+            job.add_event(
+                JobEvent(level="info", message="Training completed successfully")
+            )
+
+            job.add_event(JobEvent(level="info", message="Saving model", data={}))
             # Save the model!
             trainer.save_model()
+            job.add_event(
+                JobEvent(
+                    level="info",
+                    message="Model saved",
+                    data={"location": sft_config.output_dir},
+                )
+            )

             MERGE = True
             if USE_PEFT and MERGE:
@@ -156,27 +441,22 @@ class TrainingManager:
                     sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
                 )

-
-            import gc
+            _cleanup(model, tokenizer, trainer)

-            del model
-            del tokenizer
-            del trainer
-            gc.collect()
-            torch.cuda.empty_cache()
-
-            logger.info("Training completed successfully")
             job.status = JobStatus.SUCCEEDED
             job.fine_tuned_model = sft_config.output_dir
         except Exception as e:
-
+            job.add_event(
+                JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
+            )
             job.status = JobStatus.FAILED
             raise
         finally:
-
+            pass

         return sft_config.output_dir

+
 def dataset_from_file(data_path):
     """
     Creates a HuggingFace Dataset from a JSONL file.
@@ -218,7 +498,9 @@ def encode_sft_example(example, tokenizer, max_seq_length):
                 message_start_idx = 0
             else:
                 message_start_idx = tokenizer.apply_chat_template(
-                    conversation=messages[
+                    conversation=messages[
+                        :message_idx
+                    ],  # here marks the end of the previous messages
                     tokenize=True,
                     return_tensors="pt",
                     padding=False,
@@ -227,7 +509,10 @@ def encode_sft_example(example, tokenizer, max_seq_length):
                     add_generation_prompt=False,
                 ).shape[1]
             # next, we calculate the end index of this non-assistant message
-            if
+            if (
+                message_idx < len(messages) - 1
+                and messages[message_idx + 1]["role"] == "assistant"
+            ):
                 # for intermediate messages that follow with an assistant message, we need to
                 # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
                 # (e.g., `<|assistant|>`)
@@ -260,5 +545,17 @@ def encode_sft_example(example, tokenizer, max_seq_length):
     return {
         "input_ids": input_ids.flatten(),
         "labels": labels.flatten(),
-        "attention_mask": attention_mask.flatten()
-    }
+        "attention_mask": attention_mask.flatten(),
+    }
+
+
+def _cleanup(model, tokenizer, trainer):
+    import gc
+
+    import torch
+
+    del model
+    del tokenizer
+    del trainer
+    gc.collect()
+    torch.cuda.empty_cache()
--- /dev/null
+++ b/arbor_ai-0.1.6.dist-info/METADATA
@@ -0,0 +1,78 @@
+Metadata-Version: 2.4
+Name: arbor-ai
+Version: 0.1.6
+Summary: A framework for fine-tuning and managing language models
+Author-email: Noah Ziems <nziems2@nd.edu>
+Project-URL: Homepage, https://github.com/Ziems/arbor
+Project-URL: Issues, https://github.com/Ziems/arbor/issues
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: fastapi
+Requires-Dist: uvicorn
+Requires-Dist: click
+Requires-Dist: python-multipart
+Requires-Dist: pydantic-settings
+Requires-Dist: torch
+Requires-Dist: transformers
+Requires-Dist: trl
+Requires-Dist: peft
+Requires-Dist: ray>=2.9
+Requires-Dist: setuptools<77.0.0,>=76.0.0
+Requires-Dist: pyzmq>=26.4.0
+Requires-Dist: pyyaml>=6.0.2
+Requires-Dist: sglang>=0.4.5.post3
+Requires-Dist: sglang-router
+Dynamic: license-file
+
+<p align="center">
+  <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Description" width="400"/>
+</p>
+
+**A framework for optimizing DSPy programs with RL.**
+
+---
+
+## 🚀 Installation
+
+Install Arbor via pip:
+
+```bash
+pip install git+https://github.com/Ziems/arbor.git
+```
+
+---
+
+## ⚡ Quick Start
+
+### 1️⃣ Make an `arbor.yaml` File
+
+This is all dependent on your setup. Here is an example of one:
+```yaml
+inference:
+  gpu_ids: '0'
+
+training:
+  gpu_ids: '1, 2'
+```
+
+### 2️⃣ Start the Server
+
+**CLI:**
+
+```bash
+python -m arbor.cli serve --arbor-config arbor.yaml
+```
+
+### 3️⃣ Optimize a DSPy Program
+
+Follow the DSPy tutorials here to see usage examples:
+[DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
+
+---
+
+## 🙏 Acknowledgements
+
+Arbor builds on the shoulders of great work. We extend our thanks to:
+- **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
+- **[Hugging Face TRL library](https://github.com/huggingface/trl)**

--- /dev/null
+++ b/arbor_ai-0.1.6.dist-info/RECORD
@@ -0,0 +1,34 @@
+arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
+arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
+arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
+arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhcnrKIk,5594
+arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
+arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
+arbor/server/api/routes/inference.py,sha256=xlP-FMpOJAiiPZkE470l9mCR0ujLki8RrcO9hmTQD-k,1662
+arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
+arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
+arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
+arbor/server/services/grpo_manager.py,sha256=T-f1TrNSi_kmxPOcpaDphS8Xf3UMUbricocc6fuaKIM,12077
+arbor/server/services/inference_manager.py,sha256=qR9xPiYs4Is24vgeF72w7Hbe8j_PGEbl-qewcvUV-dA,9731
+arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
+arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
+arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/comms/comms.py,sha256=Dg08D2Fm5TAEiGyr0Qcr0uocabQpFD_sBVhxIkj9D2M,7424
+arbor/server/services/scripts/grpo_training.py,sha256=V36pCMZDJj2DdzquxScOddi9zP8EVPGWN3HGiftFfrY,21082
+arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor_ai-0.1.6.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.6.dist-info/METADATA,sha256=ot1XsjoFawGbNBmaAaxv29lE_uNh8TTdl8OANkTPTS8,1823
+arbor_ai-0.1.6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+arbor_ai-0.1.6.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.6.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.6.dist-info/RECORD,,

--- /dev/null
+++ b/arbor_ai-0.1.6.dist-info/top_level.txt
@@ -0,0 +1 @@
+arbor

--- a/arbor/server/api/routes/training.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from fastapi import APIRouter, BackgroundTasks, Depends
-
-from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
-from arbor.server.services.job_manager import JobManager, JobStatus
-from arbor.server.services.file_manager import FileManager
-from arbor.server.services.training_manager import TrainingManager
-from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
-
-router = APIRouter()
-
-@router.post("", response_model=JobStatusResponse)
-def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
-    job = job_manager.create_job()
-    background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
-    job.status = JobStatus.QUEUED
-    return JobStatusResponse(id=job.id, status=job.status.value)