arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +12 -0
- arbor/server/api/models/schemas.py +0 -1
- arbor/server/api/routes/grpo.py +4 -9
- arbor/server/api/routes/inference.py +24 -14
- arbor/server/services/grpo_manager.py +176 -103
- arbor/server/services/inference/vllm_client.py +444 -0
- arbor/server/services/inference/vllm_serve.py +2336 -0
- arbor/server/services/inference_manager.py +145 -272
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +165 -57
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/METADATA +10 -6
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/RECORD +20 -14
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/WHEEL +1 -1
- arbor/server/services/inference/sgl_router_launch_server.py +0 -226
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.14.dist-info → arbor_ai-0.2.dist-info}/top_level.txt +0 -0
@@ -5,13 +5,18 @@
|
|
5
5
|
|
6
6
|
import argparse
|
7
7
|
import json
|
8
|
+
import os
|
8
9
|
import random
|
10
|
+
import shutil
|
11
|
+
import signal
|
12
|
+
import sys
|
9
13
|
import threading
|
10
14
|
import time
|
11
15
|
from functools import lru_cache
|
12
16
|
from typing import Any, List, Optional, Union
|
13
17
|
|
14
18
|
import torch
|
19
|
+
import trl.extras.vllm_client
|
15
20
|
import zmq
|
16
21
|
from accelerate import Accelerator
|
17
22
|
from accelerate.utils import broadcast_object_list, gather, gather_object
|
@@ -32,6 +37,9 @@ from arbor.server.services.comms.comms import (
|
|
32
37
|
ArborScriptCommsHandler,
|
33
38
|
ArborServerCommsHandler,
|
34
39
|
)
|
40
|
+
from arbor.server.services.inference.vllm_client import VLLMClient
|
41
|
+
|
42
|
+
trl.extras.vllm_client.VLLMClient = VLLMClient
|
35
43
|
|
36
44
|
if is_wandb_available():
|
37
45
|
import wandb
|
@@ -71,10 +79,10 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
71
79
|
comms_handler: Optional[ArborScriptCommsHandler] = None,
|
72
80
|
lora: Optional[bool] = False,
|
73
81
|
# We do nothing with max_context_length right now
|
82
|
+
vllm_group_port: Optional[int] = None,
|
74
83
|
max_context_length: Optional[int] = None,
|
75
84
|
**kwargs,
|
76
85
|
):
|
77
|
-
|
78
86
|
super().__init__(
|
79
87
|
model=model,
|
80
88
|
reward_funcs=[],
|
@@ -91,6 +99,33 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
91
99
|
self.scale_rewards = scale_rewards
|
92
100
|
self.comms_handler = comms_handler
|
93
101
|
|
102
|
+
self.vllm_client = None
|
103
|
+
args.use_vllm = True
|
104
|
+
self.use_vllm = True
|
105
|
+
if self.accelerator.is_main_process:
|
106
|
+
print(
|
107
|
+
f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
|
108
|
+
)
|
109
|
+
self.vllm_client = VLLMClient(
|
110
|
+
args.vllm_server_host,
|
111
|
+
args.vllm_server_port,
|
112
|
+
group_port=vllm_group_port,
|
113
|
+
connection_timeout=args.vllm_server_timeout,
|
114
|
+
)
|
115
|
+
self.vllm_client.init_communicator()
|
116
|
+
|
117
|
+
# vLLM specific sampling arguments
|
118
|
+
self.guided_decoding_regex = args.vllm_guided_decoding_regex
|
119
|
+
|
120
|
+
self._last_loaded_step = (
|
121
|
+
-1
|
122
|
+
) # tag to avoid useless loading during grad accumulation
|
123
|
+
|
124
|
+
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
125
|
+
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
126
|
+
# synchronize all processes after vLLM has been fully initialized.
|
127
|
+
self.accelerator.wait_for_everyone()
|
128
|
+
|
94
129
|
def _generate_and_score_completions(
|
95
130
|
self, batch: List[dict[str, Any]]
|
96
131
|
) -> dict[str, Union[torch.Tensor, Any]]:
|
@@ -104,7 +139,11 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
104
139
|
maybe_apply_chat_template(
|
105
140
|
{
|
106
141
|
"prompt": example["messages"],
|
107
|
-
"completion":
|
142
|
+
"completion": (
|
143
|
+
example["completion"]
|
144
|
+
if isinstance(example["completion"], list)
|
145
|
+
else [example["completion"]]
|
146
|
+
),
|
108
147
|
},
|
109
148
|
self.processing_class,
|
110
149
|
)
|
@@ -133,15 +172,15 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
133
172
|
prompt_completion_text["completion"]
|
134
173
|
for prompt_completion_text in prompt_completion_texts
|
135
174
|
]
|
136
|
-
|
175
|
+
completion_inputs = self.processing_class(
|
137
176
|
completions_text,
|
138
177
|
return_tensors="pt",
|
139
178
|
padding=True,
|
140
179
|
add_special_tokens=False,
|
141
180
|
).to(device)
|
142
181
|
completion_ids, completion_mask = (
|
143
|
-
|
144
|
-
|
182
|
+
completion_inputs["input_ids"],
|
183
|
+
completion_inputs["attention_mask"],
|
145
184
|
)
|
146
185
|
|
147
186
|
if self.max_prompt_length is not None:
|
@@ -156,11 +195,6 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
156
195
|
completion_ids = completion_ids[:, : self.max_completion_length]
|
157
196
|
completion_mask = completion_mask[:, : self.max_completion_length]
|
158
197
|
|
159
|
-
# Keeping this for when we switch to vllm
|
160
|
-
# if self.state.global_step != self._last_loaded_step:
|
161
|
-
# self._move_model_to_vllm()
|
162
|
-
# self._last_loaded_step = self.state.global_step
|
163
|
-
|
164
198
|
prompt_ids = broadcast_object_list(prompt_ids)
|
165
199
|
prompt_mask = broadcast_object_list(prompt_mask)
|
166
200
|
completion_ids = broadcast_object_list(completion_ids)
|
@@ -178,6 +212,9 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
178
212
|
|
179
213
|
is_eos = completion_ids == self.processing_class.eos_token_id
|
180
214
|
|
215
|
+
# Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
|
216
|
+
completion_lengths = completion_mask.sum(1)
|
217
|
+
|
181
218
|
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
|
182
219
|
if self.mask_truncated_completions:
|
183
220
|
truncated_completions = ~is_eos.any(dim=1)
|
@@ -230,6 +267,10 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
230
267
|
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
|
231
268
|
self.num_generations, dim=0
|
232
269
|
)
|
270
|
+
is_std_zero = torch.isclose(
|
271
|
+
std_grouped_rewards, torch.zeros_like(std_grouped_rewards)
|
272
|
+
)
|
273
|
+
|
233
274
|
advantages = rewards - mean_grouped_rewards
|
234
275
|
|
235
276
|
if self.scale_rewards:
|
@@ -241,66 +282,72 @@ class ArborGRPOTrainer(GRPOTrainer):
|
|
241
282
|
self.accelerator.process_index * len(batch),
|
242
283
|
(self.accelerator.process_index + 1) * len(batch),
|
243
284
|
)
|
285
|
+
all_process_advantages = (
|
286
|
+
advantages.clone()
|
287
|
+
) # keep the aggregated advantages for logging
|
244
288
|
advantages = advantages[process_slice]
|
245
289
|
|
246
290
|
# Log the metrics
|
247
291
|
if mode == "train":
|
248
292
|
self.state.num_input_tokens_seen += (
|
249
|
-
self.accelerator.
|
293
|
+
self.accelerator.gather(attention_mask.sum()).sum().item()
|
250
294
|
)
|
251
295
|
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
|
252
296
|
|
253
|
-
#
|
254
|
-
|
255
|
-
completion_mask.sum(1)
|
256
|
-
)
|
297
|
+
# Log completion lengths, mean, min, max
|
298
|
+
agg_completion_lengths = self.accelerator.gather(completion_lengths)
|
257
299
|
self._metrics[mode]["completions/mean_length"].append(
|
258
|
-
|
300
|
+
agg_completion_lengths.float().mean().item()
|
259
301
|
)
|
260
302
|
self._metrics[mode]["completions/min_length"].append(
|
261
|
-
|
303
|
+
agg_completion_lengths.float().min().item()
|
262
304
|
)
|
263
305
|
self._metrics[mode]["completions/max_length"].append(
|
264
|
-
|
306
|
+
agg_completion_lengths.float().max().item()
|
265
307
|
)
|
266
308
|
|
267
|
-
#
|
268
|
-
agg_terminated_with_eos = self.accelerator.
|
269
|
-
|
270
|
-
clipped_completions_ratio = 1 - len(
|
271
|
-
|
309
|
+
# Identify sequences that terminated with EOS and log their lengths
|
310
|
+
agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
|
311
|
+
term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
|
312
|
+
clipped_completions_ratio = 1 - len(term_completion_lengths) / len(
|
313
|
+
agg_completion_lengths
|
272
314
|
)
|
273
315
|
self._metrics[mode]["completions/clipped_ratio"].append(
|
274
316
|
clipped_completions_ratio
|
275
317
|
)
|
276
|
-
if
|
277
|
-
|
278
|
-
|
318
|
+
if (
|
319
|
+
len(term_completion_lengths) == 0
|
320
|
+
): # edge case where no terminated sequences are found
|
321
|
+
term_completion_lengths = torch.zeros(1, device=device)
|
279
322
|
self._metrics[mode]["completions/mean_terminated_length"].append(
|
280
|
-
|
323
|
+
term_completion_lengths.float().mean().item()
|
281
324
|
)
|
282
325
|
self._metrics[mode]["completions/min_terminated_length"].append(
|
283
|
-
|
326
|
+
term_completion_lengths.float().min().item()
|
284
327
|
)
|
285
328
|
self._metrics[mode]["completions/max_terminated_length"].append(
|
286
|
-
|
329
|
+
term_completion_lengths.float().max().item()
|
287
330
|
)
|
288
331
|
|
289
|
-
# Calculate mean reward
|
332
|
+
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
|
290
333
|
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
|
291
334
|
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
|
335
|
+
self._metrics[mode]["frac_reward_zero_std"].append(
|
336
|
+
is_std_zero.float().mean().item()
|
337
|
+
)
|
292
338
|
|
293
339
|
# Log prompt and completion texts
|
294
340
|
self._textual_logs["prompt"].extend(gather_object(prompts_text))
|
295
341
|
self._textual_logs["completion"].extend(gather_object(completions_text))
|
342
|
+
self._textual_logs["advantages"].extend(all_process_advantages.tolist())
|
296
343
|
|
297
344
|
return {
|
298
345
|
"prompt_ids": prompt_ids,
|
299
346
|
"prompt_mask": prompt_mask,
|
300
347
|
"completion_ids": completion_ids,
|
301
348
|
"completion_mask": completion_mask,
|
302
|
-
"old_per_token_logps": old_per_token_logps,
|
303
349
|
"advantages": advantages,
|
350
|
+
"old_per_token_logps": old_per_token_logps,
|
304
351
|
}
|
305
352
|
|
306
353
|
|
@@ -313,6 +360,30 @@ class LastStepTimeCallback(TrainerCallback):
|
|
313
360
|
last_step_time = time.time()
|
314
361
|
|
315
362
|
|
363
|
+
class WeightUpdateCallback(TrainerCallback):
|
364
|
+
"""A callback that sends weight update completion status after each step"""
|
365
|
+
|
366
|
+
def __init__(self):
|
367
|
+
self.comms_handler = None
|
368
|
+
self.trainer = None
|
369
|
+
|
370
|
+
def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
|
371
|
+
self.comms_handler = comms_handler
|
372
|
+
|
373
|
+
def set_trainer(self, trainer):
|
374
|
+
self.trainer = trainer
|
375
|
+
|
376
|
+
def on_step_end(self, args, state, control, **kwargs):
|
377
|
+
if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
|
378
|
+
if state.global_step != self.trainer._last_loaded_step:
|
379
|
+
print("Updating inference model...")
|
380
|
+
self.comms_handler.send_status({"status": "weight_update_start"})
|
381
|
+
self.trainer._move_model_to_vllm()
|
382
|
+
self.trainer._last_loaded_step = state.global_step
|
383
|
+
print("[DEBUG] Sending weight update completion status")
|
384
|
+
self.comms_handler.send_status({"status": "weight_update_complete"})
|
385
|
+
|
386
|
+
|
316
387
|
class BlockingQueueDataset(Dataset):
|
317
388
|
def __init__(
|
318
389
|
self,
|
@@ -379,11 +450,6 @@ class CommandMonitor:
|
|
379
450
|
)
|
380
451
|
self.command_thread.start()
|
381
452
|
|
382
|
-
self.broadcast_thread = threading.Thread(
|
383
|
-
target=self._monitor_broadcasts, daemon=True
|
384
|
-
)
|
385
|
-
self.broadcast_thread.start()
|
386
|
-
|
387
453
|
def _monitor_commands(self):
|
388
454
|
"""Background thread that monitors for commands from the server."""
|
389
455
|
if not self.comms_handler:
|
@@ -478,6 +544,26 @@ class CommandMonitor:
|
|
478
544
|
output_dir=self.trainer.args.output_dir
|
479
545
|
+ f"/checkpoints/{command.get('checkpoint_name')}/"
|
480
546
|
)
|
547
|
+
|
548
|
+
# Copy checkpoint files to root output directory
|
549
|
+
checkpoint_dir = (
|
550
|
+
self.trainer.args.output_dir
|
551
|
+
+ f"/checkpoints/{command.get('checkpoint_name')}/"
|
552
|
+
)
|
553
|
+
root_dir = self.trainer.args.output_dir
|
554
|
+
|
555
|
+
# Copy all files from checkpoint dir to root dir, overwriting if they exist
|
556
|
+
# (effectively saves the checkpoint to the output directory)
|
557
|
+
for item in os.listdir(checkpoint_dir):
|
558
|
+
src = os.path.join(checkpoint_dir, item)
|
559
|
+
dst = os.path.join(root_dir, item)
|
560
|
+
if os.path.isdir(src):
|
561
|
+
if os.path.exists(dst):
|
562
|
+
shutil.rmtree(dst)
|
563
|
+
shutil.copytree(src, dst)
|
564
|
+
else:
|
565
|
+
shutil.copy2(src, dst)
|
566
|
+
|
481
567
|
self.comms_handler.send_status(
|
482
568
|
{
|
483
569
|
"status": "checkpoint_saved",
|
@@ -486,31 +572,21 @@ class CommandMonitor:
|
|
486
572
|
+ f"/checkpoints/{command.get('checkpoint_name')}/",
|
487
573
|
}
|
488
574
|
)
|
575
|
+
self.comms_handler.send_status(
|
576
|
+
{
|
577
|
+
"status": "model_saved",
|
578
|
+
"output_dir": self.trainer.args.output_dir,
|
579
|
+
}
|
580
|
+
)
|
581
|
+
elif command.get("command") == "terminate":
|
582
|
+
print("TERMINATED")
|
583
|
+
self.trainer.accelerator.end_training()
|
584
|
+
self.comms_handler.send_status({"status": "terminated"})
|
489
585
|
|
490
586
|
except Exception as e:
|
491
587
|
print(e)
|
492
588
|
self.comms_handler.send_status({"status": "error", "error": str(e)})
|
493
589
|
|
494
|
-
def _monitor_broadcasts(self):
|
495
|
-
"""Background thread that monitors for broadcasts from the server."""
|
496
|
-
if not self.comms_handler:
|
497
|
-
return
|
498
|
-
try:
|
499
|
-
for broadcast in self.comms_handler.receive_broadcast():
|
500
|
-
print(f"!!!Received broadcast: {broadcast}")
|
501
|
-
if broadcast.get("message") == "terminate":
|
502
|
-
# self.trainer.control.should_training_stop = True
|
503
|
-
# self.comms_handler.send_status(
|
504
|
-
# {
|
505
|
-
# "status": "Received termination command",
|
506
|
-
# "process_id": self.trainer.accelerator.process_index,
|
507
|
-
# }
|
508
|
-
# )
|
509
|
-
if self.trainer.accelerator.is_main_process:
|
510
|
-
self.trainer.accelerator.end_training()
|
511
|
-
except Exception as e:
|
512
|
-
self.comms_handler.send_status({"status": "error", "error": str(e)})
|
513
|
-
|
514
590
|
|
515
591
|
def main():
|
516
592
|
parser = argparse.ArgumentParser()
|
@@ -523,6 +599,8 @@ def main():
|
|
523
599
|
pipe_args.add_argument("--data_port", type=int, required=True)
|
524
600
|
pipe_args.add_argument("--broadcast_port", type=int, required=True)
|
525
601
|
pipe_args.add_argument("--handshake_port", type=int, required=True)
|
602
|
+
pipe_args.add_argument("--vllm_group_port", type=int, required=True)
|
603
|
+
pipe_args.add_argument("--vllm_port", type=int, required=True)
|
526
604
|
|
527
605
|
training_args = parser.add_argument_group("Training arguments")
|
528
606
|
training_args.add_argument(
|
@@ -544,6 +622,11 @@ def main():
|
|
544
622
|
args = parser.parse_args()
|
545
623
|
|
546
624
|
if args.debug:
|
625
|
+
# python grpo_training.py --debug
|
626
|
+
# --command_port 0 --status_port 0
|
627
|
+
# --data_port 0 --broadcast_port 0
|
628
|
+
# --handshake_port 0 --model Qwen/Qwen3-0.6B
|
629
|
+
# --trl_train_kwargs '{"output_dir": ".", "report_to": "none"}'
|
547
630
|
server_comms_handler = ArborServerCommsHandler(
|
548
631
|
host=args.host,
|
549
632
|
)
|
@@ -554,6 +637,11 @@ def main():
|
|
554
637
|
args.broadcast_port = server_comms_handler.broadcast_port
|
555
638
|
args.handshake_port = server_comms_handler.handshake_port
|
556
639
|
|
640
|
+
handshake_thread = threading.Thread(
|
641
|
+
target=server_comms_handler.wait_for_clients, args=(1,), daemon=True
|
642
|
+
)
|
643
|
+
handshake_thread.start()
|
644
|
+
|
557
645
|
def debug_data_generator():
|
558
646
|
tldr_dataset = load_dataset("trl-lib/tldr", split="train")
|
559
647
|
idx = 0
|
@@ -636,15 +724,18 @@ def main():
|
|
636
724
|
training_args = GRPOConfig(
|
637
725
|
dataloader_num_workers=0,
|
638
726
|
shuffle_dataset=False,
|
727
|
+
vllm_server_port=args.vllm_port,
|
639
728
|
**trl_train_args,
|
640
729
|
)
|
641
730
|
|
731
|
+
weight_update_callback = WeightUpdateCallback()
|
642
732
|
trainer = ArborGRPOTrainer(
|
643
733
|
model=args.model,
|
644
734
|
args=training_args,
|
645
735
|
train_dataset=BlockingQueueDataset(None, None),
|
646
|
-
callbacks=[LastStepTimeCallback()],
|
736
|
+
callbacks=[LastStepTimeCallback(), weight_update_callback],
|
647
737
|
peft_config=lora_config,
|
738
|
+
vllm_group_port=args.vllm_group_port,
|
648
739
|
**arbor_train_args,
|
649
740
|
)
|
650
741
|
# Create client handler
|
@@ -657,6 +748,8 @@ def main():
|
|
657
748
|
handshake_port=args.handshake_port,
|
658
749
|
is_main_process=trainer.accelerator.is_main_process,
|
659
750
|
)
|
751
|
+
weight_update_callback.set_comms_handler(comms_handler)
|
752
|
+
weight_update_callback.set_trainer(trainer)
|
660
753
|
trainer.comms_handler = comms_handler
|
661
754
|
|
662
755
|
# Initialize the dataset with the actual accelerator
|
@@ -671,6 +764,18 @@ def main():
|
|
671
764
|
base_model_name=args.model,
|
672
765
|
)
|
673
766
|
|
767
|
+
# Add signal handlers for graceful shutdown
|
768
|
+
def signal_handler(signum, frame):
|
769
|
+
print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
|
770
|
+
print("Ending training...")
|
771
|
+
trainer.accelerator.end_training()
|
772
|
+
print("Closing communications...")
|
773
|
+
comms_handler.close()
|
774
|
+
sys.exit(0)
|
775
|
+
|
776
|
+
signal.signal(signal.SIGINT, signal_handler)
|
777
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
778
|
+
|
674
779
|
print("Training...")
|
675
780
|
trainer.train()
|
676
781
|
|
@@ -681,7 +786,10 @@ def main():
|
|
681
786
|
comms_handler.send_status({"status": "error", "error": str(e)})
|
682
787
|
raise e
|
683
788
|
finally:
|
789
|
+
print("Cleaning up resources...")
|
790
|
+
trainer.accelerator.end_training()
|
684
791
|
comms_handler.close()
|
792
|
+
print("Cleanup complete")
|
685
793
|
|
686
794
|
|
687
795
|
if __name__ == "__main__":
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import argparse
|
2
|
+
import json
|
3
|
+
import random
|
4
|
+
import threading
|
5
|
+
import time
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import zmq
|
9
|
+
from peft import LoraConfig
|
10
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
11
|
+
from trl import SFTConfig, SFTTrainer, setup_chat_format
|
12
|
+
|
13
|
+
from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
|
14
|
+
|
15
|
+
|
16
|
+
def main():
|
17
|
+
parser = get_training_arg_parser()
|
18
|
+
parser.add_argument("--model", type=str, required=True)
|
19
|
+
parser.add_argument("--lora", type=bool, default=False)
|
20
|
+
args = parser.parse_args()
|
21
|
+
|
22
|
+
try:
|
23
|
+
trl_train_kwargs = {**(args.trl_config_kwargs or {})}
|
24
|
+
|
25
|
+
# TODO: These assertions should be done in some better way
|
26
|
+
assert "output_dir" in trl_train_kwargs, "output_dir is required"
|
27
|
+
if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
|
28
|
+
print(
|
29
|
+
"Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
|
30
|
+
)
|
31
|
+
trl_train_kwargs["gradient_checkpointing_kwargs"] = {
|
32
|
+
**(trl_train_kwargs.get("gradient_checkpointing_kwargs") or {}),
|
33
|
+
"use_reentrant": False,
|
34
|
+
}
|
35
|
+
|
36
|
+
lora_config = None
|
37
|
+
if args.lora:
|
38
|
+
print("Using LORA for PEFT")
|
39
|
+
lora_config = LoraConfig(
|
40
|
+
r=16,
|
41
|
+
lora_alpha=64,
|
42
|
+
target_modules=[
|
43
|
+
"q_proj",
|
44
|
+
"k_proj",
|
45
|
+
"v_proj",
|
46
|
+
"o_proj",
|
47
|
+
"up_proj",
|
48
|
+
"down_proj",
|
49
|
+
"gate_proj",
|
50
|
+
],
|
51
|
+
task_type="CAUSAL_LM",
|
52
|
+
lora_dropout=0.05,
|
53
|
+
inference_mode=False,
|
54
|
+
)
|
55
|
+
|
56
|
+
training_args = GRPOConfig(
|
57
|
+
dataloader_num_workers=0,
|
58
|
+
shuffle_dataset=False,
|
59
|
+
**trl_train_args,
|
60
|
+
)
|
61
|
+
|
62
|
+
trainer = ArborGRPOTrainer(
|
63
|
+
model=args.model,
|
64
|
+
args=training_args,
|
65
|
+
train_dataset=BlockingQueueDataset(None, None),
|
66
|
+
callbacks=[LastStepTimeCallback()],
|
67
|
+
peft_config=lora_config,
|
68
|
+
**arbor_train_args,
|
69
|
+
)
|
70
|
+
# Create client handler
|
71
|
+
comms_handler = ArborScriptCommsHandler(
|
72
|
+
host=args.host,
|
73
|
+
command_port=args.command_port,
|
74
|
+
status_port=args.status_port,
|
75
|
+
data_port=args.data_port,
|
76
|
+
broadcast_port=args.broadcast_port,
|
77
|
+
handshake_port=args.handshake_port,
|
78
|
+
is_main_process=trainer.accelerator.is_main_process,
|
79
|
+
)
|
80
|
+
trainer.comms_handler = comms_handler
|
81
|
+
|
82
|
+
# Initialize the dataset with the actual accelerator
|
83
|
+
trainer.train_dataset = BlockingQueueDataset(
|
84
|
+
accelerator=trainer.accelerator,
|
85
|
+
comms_handler=trainer.comms_handler,
|
86
|
+
)
|
87
|
+
|
88
|
+
command_monitor = CommandMonitor(
|
89
|
+
comms_handler=comms_handler,
|
90
|
+
trainer=trainer,
|
91
|
+
base_model_name=args.model,
|
92
|
+
)
|
93
|
+
|
94
|
+
print("Training...")
|
95
|
+
trainer.train()
|
96
|
+
|
97
|
+
except KeyboardInterrupt:
|
98
|
+
print("\nReceived interrupt, shutting down...")
|
99
|
+
except Exception as e:
|
100
|
+
print(f"Error: {e}")
|
101
|
+
comms_handler.send_status({"status": "error", "error": str(e)})
|
102
|
+
raise e
|
103
|
+
finally:
|
104
|
+
trainer.accelerator.end_training()
|
105
|
+
comms_handler.close()
|
106
|
+
|
107
|
+
|
108
|
+
if __name__ == "__main__":
|
109
|
+
main()
|
File without changes
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""The arg parser for the training scripts"""
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import json
|
5
|
+
|
6
|
+
|
7
|
+
def get_training_arg_parser():
|
8
|
+
parser = argparse.ArgumentParser()
|
9
|
+
parser.add_argument("--debug", action="store_true")
|
10
|
+
|
11
|
+
pipe_args = parser.add_argument_group("Comms arguments")
|
12
|
+
pipe_args.add_argument("--host", default="localhost")
|
13
|
+
pipe_args.add_argument("--command_port", type=int, required=True)
|
14
|
+
pipe_args.add_argument("--status_port", type=int, required=True)
|
15
|
+
pipe_args.add_argument("--data_port", type=int, required=True)
|
16
|
+
pipe_args.add_argument("--broadcast_port", type=int, required=True)
|
17
|
+
pipe_args.add_argument("--handshake_port", type=int, required=True)
|
18
|
+
|
19
|
+
training_args = parser.add_argument_group("Training arguments")
|
20
|
+
training_args.add_argument(
|
21
|
+
"--model",
|
22
|
+
type=str,
|
23
|
+
help="Model to use for training",
|
24
|
+
)
|
25
|
+
training_args.add_argument(
|
26
|
+
"--trl_config_kwargs",
|
27
|
+
type=json.loads,
|
28
|
+
help="Training configs as a JSON string",
|
29
|
+
)
|
30
|
+
|
31
|
+
return parser
|
File without changes
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: arbor-ai
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2
|
4
4
|
Summary: A framework for fine-tuning and managing language models
|
5
5
|
Author-email: Noah Ziems <nziems2@nd.edu>
|
6
6
|
Project-URL: Homepage, https://github.com/Ziems/arbor
|
@@ -8,21 +8,20 @@ Project-URL: Issues, https://github.com/Ziems/arbor/issues
|
|
8
8
|
Requires-Python: >=3.10
|
9
9
|
Description-Content-Type: text/markdown
|
10
10
|
License-File: LICENSE
|
11
|
+
Requires-Dist: torch>=2.6.0
|
11
12
|
Requires-Dist: fastapi
|
12
13
|
Requires-Dist: uvicorn
|
13
14
|
Requires-Dist: click
|
14
15
|
Requires-Dist: python-multipart
|
15
16
|
Requires-Dist: pydantic-settings
|
16
|
-
Requires-Dist:
|
17
|
+
Requires-Dist: vllm>=0.8.5.post1
|
17
18
|
Requires-Dist: transformers
|
18
|
-
Requires-Dist: trl
|
19
|
+
Requires-Dist: trl>=0.17.0
|
19
20
|
Requires-Dist: peft
|
20
21
|
Requires-Dist: ray>=2.9
|
21
22
|
Requires-Dist: setuptools<77.0.0,>=76.0.0
|
22
23
|
Requires-Dist: pyzmq>=26.4.0
|
23
24
|
Requires-Dist: pyyaml>=6.0.2
|
24
|
-
Requires-Dist: sglang[all]>=0.4.5.post3
|
25
|
-
Requires-Dist: sglang-router
|
26
25
|
Requires-Dist: wandb
|
27
26
|
Dynamic: license-file
|
28
27
|
|
@@ -41,7 +40,12 @@ Dynamic: license-file
|
|
41
40
|
Install Arbor via pip:
|
42
41
|
|
43
42
|
```bash
|
44
|
-
pip install arbor-ai
|
43
|
+
pip install -U arbor-ai
|
44
|
+
```
|
45
|
+
|
46
|
+
Optionally, you can also install:
|
47
|
+
```bash
|
48
|
+
pip install flash-attn --no-build-isolation
|
45
49
|
```
|
46
50
|
|
47
51
|
---
|