arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- arbor/cli.py +12 -0
- arbor/server/api/routes/grpo.py +4 -1
- arbor/server/api/routes/inference.py +11 -16
- arbor/server/services/grpo_manager.py +179 -98
- arbor/server/services/inference/__init__.py +0 -0
- arbor/server/services/inference/vllm_client.py +445 -0
- arbor/server/services/inference/vllm_serve.py +2335 -0
- arbor/server/services/inference_manager.py +149 -219
- arbor/server/services/scripts/dpo_training.py +0 -0
- arbor/server/services/scripts/grpo_training.py +157 -53
- arbor/server/services/scripts/sft_training.py +109 -0
- arbor/server/services/scripts/utils/__init__.py +0 -0
- arbor/server/services/scripts/utils/arg_parser.py +31 -0
- arbor/server/services/scripts/utils/dataset.py +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/METADATA +4 -5
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/RECORD +20 -12
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/WHEEL +1 -1
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/entry_points.txt +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/top_level.txt +0 -0
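Taken together, the changes swap Arbor's sglang-based inference stack for a vLLM-based one: the `sglang[all]`/`sglang-router` dependencies are dropped in favor of `vllm>=0.8.5.post1` (with a `trl>=0.17.0` and `torch>=2.6.0` floor), a new `arbor/server/services/inference/` package vendors a vLLM client and serve module, the GRPO training script gains live weight synchronization into the vLLM server plus graceful-shutdown handling, and a new SFT training script and shared argument-parser utilities are added.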
arbor/server/services/scripts/grpo_training.py

@@ -5,13 +5,18 @@
 
 import argparse
 import json
+import os
 import random
+import shutil
+import signal
+import sys
 import threading
 import time
 from functools import lru_cache
 from typing import Any, List, Optional, Union
 
 import torch
+import trl.extras.vllm_client
 import zmq
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, gather, gather_object
@@ -32,6 +37,9 @@ from arbor.server.services.comms.comms import (
     ArborScriptCommsHandler,
     ArborServerCommsHandler,
 )
+from arbor.server.services.inference.vllm_client import VLLMClient
+
+trl.extras.vllm_client.VLLMClient = VLLMClient
 
 if is_wandb_available():
     import wandb
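The assignment on new line 42 is a module-level monkeypatch: TRL's GRPO internals resolve `VLLMClient` through the `trl.extras.vllm_client` module object, so rebinding that attribute before the trainer is constructed makes TRL pick up Arbor's client everywhere. A minimal sketch of the pattern, assuming trl is installed (the stub class is hypothetical, not Arbor's actual client):

```python
# Rebind a module attribute so later lookups through the module object
# resolve to the replacement class.
import trl.extras.vllm_client

class StubVLLMClient:  # hypothetical stand-in for Arbor's VLLMClient
    def __init__(self, host, server_port, group_port=None, connection_timeout=0.0):
        self.host, self.server_port = host, server_port
        self.group_port = group_port  # the extra coordination port Arbor threads through

trl.extras.vllm_client.VLLMClient = StubVLLMClient

# Any code that resolves the attribute at call time now gets the stub.
# (Code that did `from ... import VLLMClient` *before* the patch keeps the old binding.)
assert trl.extras.vllm_client.VLLMClient is StubVLLMClient
```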
@@ -71,10 +79,10 @@ class ArborGRPOTrainer(GRPOTrainer):
         comms_handler: Optional[ArborScriptCommsHandler] = None,
         lora: Optional[bool] = False,
         # We do nothing with max_context_length right now
+        vllm_group_port: Optional[int] = None,
         max_context_length: Optional[int] = None,
         **kwargs,
     ):
-
         super().__init__(
             model=model,
             reward_funcs=[],
@@ -91,6 +99,33 @@ class ArborGRPOTrainer(GRPOTrainer):
         self.scale_rewards = scale_rewards
         self.comms_handler = comms_handler
 
+        self.vllm_client = None
+        args.use_vllm = True
+        self.use_vllm = True
+        if self.accelerator.is_main_process:
+            print(
+                f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
+            )
+            self.vllm_client = VLLMClient(
+                args.vllm_server_host,
+                args.vllm_server_port,
+                group_port=vllm_group_port,
+                connection_timeout=args.vllm_server_timeout,
+            )
+            self.vllm_client.init_communicator()
+
+        # vLLM specific sampling arguments
+        self.guided_decoding_regex = args.vllm_guided_decoding_regex
+
+        self._last_loaded_step = (
+            -1
+        )  # tag to avoid useless loading during grad accumulation
+
+        # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+        # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+        # synchronize all processes after vLLM has been fully initialized.
+        self.accelerator.wait_for_everyone()
+
     def _generate_and_score_completions(
         self, batch: List[dict[str, Any]]
     ) -> dict[str, Union[torch.Tensor, Any]]:
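The constructor additions follow a common multi-process pattern: only rank 0 opens the connection to the inference server, then every rank hits a barrier so no process races ahead while the client and its communicator are still being set up. A minimal sketch of the same pattern, with the vLLM plumbing replaced by a hypothetical factory:

```python
# "Initialize on the main process, then barrier" with accelerate.
from accelerate import Accelerator

def connect_to_inference_server():
    """Hypothetical factory; stands in for VLLMClient(...) + init_communicator()."""
    return object()

accelerator = Accelerator()
client = None
if accelerator.is_main_process:
    client = connect_to_inference_server()
accelerator.wait_for_everyone()  # all ranks block here until rank 0 finishes
```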
@@ -156,11 +191,6 @@ class ArborGRPOTrainer(GRPOTrainer):
             completion_ids = completion_ids[:, : self.max_completion_length]
             completion_mask = completion_mask[:, : self.max_completion_length]
 
-        # Keeping this for when we switch to vllm
-        # if self.state.global_step != self._last_loaded_step:
-        #     self._move_model_to_vllm()
-        #     self._last_loaded_step = self.state.global_step
-
         prompt_ids = broadcast_object_list(prompt_ids)
         prompt_mask = broadcast_object_list(prompt_mask)
         completion_ids = broadcast_object_list(completion_ids)
@@ -178,6 +208,9 @@ class ArborGRPOTrainer(GRPOTrainer):
 
         is_eos = completion_ids == self.processing_class.eos_token_id
 
+        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
+        completion_lengths = completion_mask.sum(1)
+
         # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
         if self.mask_truncated_completions:
             truncated_completions = ~is_eos.any(dim=1)
@@ -230,6 +263,10 @@ class ArborGRPOTrainer(GRPOTrainer):
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(
             self.num_generations, dim=0
         )
+        is_std_zero = torch.isclose(
+            std_grouped_rewards, torch.zeros_like(std_grouped_rewards)
+        )
+
        advantages = rewards - mean_grouped_rewards
 
         if self.scale_rewards:
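The new `is_std_zero` tensor feeds the `frac_reward_zero_std` metric logged further down: in GRPO, a group whose completions all receive the same reward has zero standard deviation and therefore contributes no advantage signal. A toy check of the bookkeeping, mirroring the `repeat_interleave` layout above:

```python
import torch

num_generations = 4
# Two groups of 4 completions: group 0 is degenerate, group 1 is informative.
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0,
                        0.0, 1.0, 0.0, 1.0])
grouped = rewards.view(-1, num_generations)
mean_grouped_rewards = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped_rewards = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)
is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards))
advantages = rewards - mean_grouped_rewards
print(is_std_zero.float().mean().item())  # 0.5 -> half the samples sit in flat groups
print(advantages[:4])                     # all zeros: group 0 yields no gradient signal
```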
@@ -241,66 +278,72 @@ class ArborGRPOTrainer(GRPOTrainer):
             self.accelerator.process_index * len(batch),
             (self.accelerator.process_index + 1) * len(batch),
         )
+        all_process_advantages = (
+            advantages.clone()
+        )  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
         # Log the metrics
         if mode == "train":
             self.state.num_input_tokens_seen += (
-                self.accelerator.
+                self.accelerator.gather(attention_mask.sum()).sum().item()
             )
             self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
 
-        #
-
-            completion_mask.sum(1)
-        )
+        # Log completion lengths, mean, min, max
+        agg_completion_lengths = self.accelerator.gather(completion_lengths)
         self._metrics[mode]["completions/mean_length"].append(
-
+            agg_completion_lengths.float().mean().item()
         )
         self._metrics[mode]["completions/min_length"].append(
-
+            agg_completion_lengths.float().min().item()
         )
         self._metrics[mode]["completions/max_length"].append(
-
+            agg_completion_lengths.float().max().item()
         )
 
-        #
-        agg_terminated_with_eos = self.accelerator.
-
-        clipped_completions_ratio = 1 - len(
-
+        # Identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
+        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(
+            agg_completion_lengths
         )
         self._metrics[mode]["completions/clipped_ratio"].append(
             clipped_completions_ratio
         )
-        if
-
-
+        if (
+            len(term_completion_lengths) == 0
+        ):  # edge case where no terminated sequences are found
+            term_completion_lengths = torch.zeros(1, device=device)
         self._metrics[mode]["completions/mean_terminated_length"].append(
-
+            term_completion_lengths.float().mean().item()
         )
         self._metrics[mode]["completions/min_terminated_length"].append(
-
+            term_completion_lengths.float().min().item()
         )
         self._metrics[mode]["completions/max_terminated_length"].append(
-
+            term_completion_lengths.float().max().item()
         )
 
-        # Calculate mean reward
+        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+        self._metrics[mode]["frac_reward_zero_std"].append(
+            is_std_zero.float().mean().item()
+        )
 
         # Log prompt and completion texts
         self._textual_logs["prompt"].extend(gather_object(prompts_text))
         self._textual_logs["completion"].extend(gather_object(completions_text))
+        self._textual_logs["advantages"].extend(all_process_advantages.tolist())
 
         return {
             "prompt_ids": prompt_ids,
             "prompt_mask": prompt_mask,
             "completion_ids": completion_ids,
             "completion_mask": completion_mask,
-            "old_per_token_logps": old_per_token_logps,
             "advantages": advantages,
+            "old_per_token_logps": old_per_token_logps,
         }
 
 
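The rewritten length logging distinguishes completions that emitted EOS from those cut off at `max_completion_length`; `completions/clipped_ratio` is the fraction that never terminated. A toy version of that computation:

```python
import torch

# Three completions of max length 4; rows 0 and 2 emit EOS, row 1 is clipped.
is_eos = torch.tensor([[0, 0, 1, 0],
                       [0, 0, 0, 0],
                       [1, 0, 0, 0]], dtype=torch.bool)
completion_lengths = torch.tensor([3, 4, 1])
terminated = is_eos.any(dim=1)                        # tensor([True, False, True])
term_completion_lengths = completion_lengths[terminated]
clipped_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
print(clipped_ratio)  # 0.333... -> one of three completions hit the length cap
```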
@@ -313,6 +356,30 @@ class LastStepTimeCallback(TrainerCallback):
         last_step_time = time.time()
 
 
+class WeightUpdateCallback(TrainerCallback):
+    """A callback that sends weight update completion status after each step"""
+
+    def __init__(self):
+        self.comms_handler = None
+        self.trainer = None
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def set_trainer(self, trainer):
+        self.trainer = trainer
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
+            if state.global_step != self.trainer._last_loaded_step:
+                print("Updating inference model...")
+                self.comms_handler.send_status({"status": "weight_update_start"})
+                self.trainer._move_model_to_vllm()
+                self.trainer._last_loaded_step = state.global_step
+                print("[DEBUG] Sending weight update completion status")
+                self.comms_handler.send_status({"status": "weight_update_complete"})
+
+
 class BlockingQueueDataset(Dataset):
     def __init__(
         self,
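`WeightUpdateCallback` picks up the commented-out `_move_model_to_vllm` logic removed from `_generate_and_score_completions` above: syncing in `on_step_end` pushes weights once per optimizer step rather than on every generation call, and the `_last_loaded_step` check skips redundant pushes. A stripped-down callback in the same shape, with the comms/vLLM plumbing stubbed out:

```python
from transformers import TrainerCallback

class StepSyncCallback(TrainerCallback):
    """Toy analogue of WeightUpdateCallback: fires once per optimizer step."""

    def __init__(self):
        self.last_synced_step = -1

    def on_step_end(self, args, state, control, **kwargs):
        # state.global_step only advances after a full optimizer step, so this
        # check naturally dedupes across gradient-accumulation micro-batches.
        if state.global_step != self.last_synced_step:
            print(f"step {state.global_step}: push weights to inference server here")
            self.last_synced_step = state.global_step
```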
@@ -379,11 +446,6 @@ class CommandMonitor:
         )
         self.command_thread.start()
 
-        self.broadcast_thread = threading.Thread(
-            target=self._monitor_broadcasts, daemon=True
-        )
-        self.broadcast_thread.start()
-
     def _monitor_commands(self):
         """Background thread that monitors for commands from the server."""
         if not self.comms_handler:
@@ -478,6 +540,26 @@ class CommandMonitor:
                         output_dir=self.trainer.args.output_dir
                         + f"/checkpoints/{command.get('checkpoint_name')}/"
                     )
+
+                    # Copy checkpoint files to root output directory
+                    checkpoint_dir = (
+                        self.trainer.args.output_dir
+                        + f"/checkpoints/{command.get('checkpoint_name')}/"
+                    )
+                    root_dir = self.trainer.args.output_dir
+
+                    # Copy all files from checkpoint dir to root dir, overwriting if they exist
+                    # (effectively saves the checkpoint to the output directory)
+                    for item in os.listdir(checkpoint_dir):
+                        src = os.path.join(checkpoint_dir, item)
+                        dst = os.path.join(root_dir, item)
+                        if os.path.isdir(src):
+                            if os.path.exists(dst):
+                                shutil.rmtree(dst)
+                            shutil.copytree(src, dst)
+                        else:
+                            shutil.copy2(src, dst)
+
                     self.comms_handler.send_status(
                         {
                             "status": "checkpoint_saved",
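On Python 3.8+, the manual listdir/rmtree/copytree loop above can be expressed with a single stdlib call, shown here for comparison; the semantics differ slightly in that `dirs_exist_ok` merges into existing subdirectories instead of deleting them first. Paths below are placeholders:

```python
import shutil

checkpoint_dir = "outputs/checkpoints/ckpt-name/"  # hypothetical paths
root_dir = "outputs/"

# Overwrites files that already exist under root_dir, keeps everything else.
shutil.copytree(checkpoint_dir, root_dir, dirs_exist_ok=True)
```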
@@ -486,31 +568,21 @@ class CommandMonitor:
                             + f"/checkpoints/{command.get('checkpoint_name')}/",
                         }
                     )
+                    self.comms_handler.send_status(
+                        {
+                            "status": "model_saved",
+                            "output_dir": self.trainer.args.output_dir,
+                        }
+                    )
+                elif command.get("command") == "terminate":
+                    print("TERMINATED")
+                    self.trainer.accelerator.end_training()
+                    self.comms_handler.send_status({"status": "terminated"})
 
         except Exception as e:
             print(e)
             self.comms_handler.send_status({"status": "error", "error": str(e)})
 
-    def _monitor_broadcasts(self):
-        """Background thread that monitors for broadcasts from the server."""
-        if not self.comms_handler:
-            return
-        try:
-            for broadcast in self.comms_handler.receive_broadcast():
-                print(f"!!!Received broadcast: {broadcast}")
-                if broadcast.get("message") == "terminate":
-                    # self.trainer.control.should_training_stop = True
-                    # self.comms_handler.send_status(
-                    #     {
-                    #         "status": "Received termination command",
-                    #         "process_id": self.trainer.accelerator.process_index,
-                    #     }
-                    # )
-                    if self.trainer.accelerator.is_main_process:
-                        self.trainer.accelerator.end_training()
-        except Exception as e:
-            self.comms_handler.send_status({"status": "error", "error": str(e)})
-
 
 def main():
     parser = argparse.ArgumentParser()
@@ -523,6 +595,8 @@ def main():
     pipe_args.add_argument("--data_port", type=int, required=True)
     pipe_args.add_argument("--broadcast_port", type=int, required=True)
     pipe_args.add_argument("--handshake_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_group_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_port", type=int, required=True)
 
     training_args = parser.add_argument_group("Training arguments")
     training_args.add_argument(
@@ -544,6 +618,11 @@ def main():
     args = parser.parse_args()
 
     if args.debug:
+        # python grpo_training.py --debug
+        # --command_port 0 --status_port 0
+        # --data_port 0 --broadcast_port 0
+        # --handshake_port 0 --model Qwen/Qwen3-0.6B
+        # --trl_train_kwargs '{"output_dir": ".", "report_to": "none"}'
         server_comms_handler = ArborServerCommsHandler(
             host=args.host,
         )
@@ -554,6 +633,11 @@ def main():
         args.broadcast_port = server_comms_handler.broadcast_port
         args.handshake_port = server_comms_handler.handshake_port
 
+        handshake_thread = threading.Thread(
+            target=server_comms_handler.wait_for_clients, args=(1,), daemon=True
+        )
+        handshake_thread.start()
+
         def debug_data_generator():
             tldr_dataset = load_dataset("trl-lib/tldr", split="train")
             idx = 0
@@ -636,15 +720,18 @@ def main():
     training_args = GRPOConfig(
         dataloader_num_workers=0,
         shuffle_dataset=False,
+        vllm_server_port=args.vllm_port,
         **trl_train_args,
     )
 
+    weight_update_callback = WeightUpdateCallback()
     trainer = ArborGRPOTrainer(
         model=args.model,
         args=training_args,
         train_dataset=BlockingQueueDataset(None, None),
-        callbacks=[LastStepTimeCallback()],
+        callbacks=[LastStepTimeCallback(), weight_update_callback],
         peft_config=lora_config,
+        vllm_group_port=args.vllm_group_port,
         **arbor_train_args,
     )
     # Create client handler
@@ -657,6 +744,8 @@ def main():
         handshake_port=args.handshake_port,
         is_main_process=trainer.accelerator.is_main_process,
     )
+    weight_update_callback.set_comms_handler(comms_handler)
+    weight_update_callback.set_trainer(trainer)
     trainer.comms_handler = comms_handler
 
     # Initialize the dataset with the actual accelerator
@@ -671,6 +760,18 @@ def main():
         base_model_name=args.model,
     )
 
+    # Add signal handlers for graceful shutdown
+    def signal_handler(signum, frame):
+        print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
+        print("Ending training...")
+        trainer.accelerator.end_training()
+        print("Closing communications...")
+        comms_handler.close()
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
     print("Training...")
     trainer.train()
 
@@ -681,7 +782,10 @@ def main():
         comms_handler.send_status({"status": "error", "error": str(e)})
         raise e
     finally:
+        print("Cleaning up resources...")
+        trainer.accelerator.end_training()
         comms_handler.close()
+        print("Cleanup complete")
 
 
 if __name__ == "__main__":
arbor/server/services/scripts/sft_training.py (new file)

@@ -0,0 +1,109 @@
+import argparse
+import json
+import random
+import threading
+import time
+
+import torch
+import zmq
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import SFTConfig, SFTTrainer, setup_chat_format
+
+from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
+
+
+def main():
+    parser = get_training_arg_parser()
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--lora", type=bool, default=False)
+    args = parser.parse_args()
+
+    try:
+        trl_train_kwargs = {**(args.trl_config_kwargs or {})}
+
+        # TODO: These assertions should be done in some better way
+        assert "output_dir" in trl_train_kwargs, "output_dir is required"
+        if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
+            print(
+                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
+            )
+            trl_train_kwargs["gradient_checkpointing_kwargs"] = {
+                **(trl_train_kwargs.get("gradient_checkpointing_kwargs") or {}),
+                "use_reentrant": False,
+            }
+
+        lora_config = None
+        if args.lora:
+            print("Using LORA for PEFT")
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=64,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "down_proj",
+                    "gate_proj",
+                ],
+                task_type="CAUSAL_LM",
+                lora_dropout=0.05,
+                inference_mode=False,
+            )
+
+        training_args = GRPOConfig(
+            dataloader_num_workers=0,
+            shuffle_dataset=False,
+            **trl_train_args,
+        )
+
+        trainer = ArborGRPOTrainer(
+            model=args.model,
+            args=training_args,
+            train_dataset=BlockingQueueDataset(None, None),
+            callbacks=[LastStepTimeCallback()],
+            peft_config=lora_config,
+            **arbor_train_args,
+        )
+        # Create client handler
+        comms_handler = ArborScriptCommsHandler(
+            host=args.host,
+            command_port=args.command_port,
+            status_port=args.status_port,
+            data_port=args.data_port,
+            broadcast_port=args.broadcast_port,
+            handshake_port=args.handshake_port,
+            is_main_process=trainer.accelerator.is_main_process,
+        )
+        trainer.comms_handler = comms_handler
+
+        # Initialize the dataset with the actual accelerator
+        trainer.train_dataset = BlockingQueueDataset(
+            accelerator=trainer.accelerator,
+            comms_handler=trainer.comms_handler,
+        )
+
+        command_monitor = CommandMonitor(
+            comms_handler=comms_handler,
+            trainer=trainer,
+            base_model_name=args.model,
+        )
+
+        print("Training...")
+        trainer.train()
+
+    except KeyboardInterrupt:
+        print("\nReceived interrupt, shutting down...")
+    except Exception as e:
+        print(f"Error: {e}")
+        comms_handler.send_status({"status": "error", "error": str(e)})
+        raise e
+    finally:
+        trainer.accelerator.end_training()
+        comms_handler.close()
+
+
+if __name__ == "__main__":
+    main()
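As shipped, the body of this new script still constructs `GRPOConfig` and `ArborGRPOTrainer` and references names its imports do not define (`trl_train_args`, `arbor_train_args`, `BlockingQueueDataset`, `LastStepTimeCallback`, `ArborScriptCommsHandler`, `CommandMonitor`), so it reads as an in-progress port of the GRPO script; only `SFTConfig`/`SFTTrainer` are actually imported. For reference, a minimal standalone TRL SFT run using those imports might look like the sketch below (model id and output dir are placeholders; the dataset is the same `trl-lib/tldr` used by the GRPO debug path):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/tldr", split="train")
config = SFTConfig(output_dir="sft_out")  # placeholder output dir
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",              # placeholder model id
    args=config,
    train_dataset=dataset,
)
trainer.train()
```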
File without changes
arbor/server/services/scripts/utils/arg_parser.py (new file)

@@ -0,0 +1,31 @@
+"""The arg parser for the training scripts"""
+
+import argparse
+import json
+
+
+def get_training_arg_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--debug", action="store_true")
+
+    pipe_args = parser.add_argument_group("Comms arguments")
+    pipe_args.add_argument("--host", default="localhost")
+    pipe_args.add_argument("--command_port", type=int, required=True)
+    pipe_args.add_argument("--status_port", type=int, required=True)
+    pipe_args.add_argument("--data_port", type=int, required=True)
+    pipe_args.add_argument("--broadcast_port", type=int, required=True)
+    pipe_args.add_argument("--handshake_port", type=int, required=True)
+
+    training_args = parser.add_argument_group("Training arguments")
+    training_args.add_argument(
+        "--model",
+        type=str,
+        help="Model to use for training",
+    )
+    training_args.add_argument(
+        "--trl_config_kwargs",
+        type=json.loads,
+        help="Training configs as a JSON string",
+    )
+
+    return parser
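Because `--trl_config_kwargs` uses `type=json.loads`, argparse parses the raw string into a dict before the training script sees it. A quick usage sketch, assuming the package is importable (port values mirror the debug invocation in the GRPO script):

```python
from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser

parser = get_training_arg_parser()
args = parser.parse_args([
    "--command_port", "0", "--status_port", "0", "--data_port", "0",
    "--broadcast_port", "0", "--handshake_port", "0",
    "--trl_config_kwargs", '{"output_dir": ".", "report_to": "none"}',
])
print(args.trl_config_kwargs["output_dir"])  # "." -- already a dict, not a string
```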
File without changes
{arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.13
+Version: 0.1.15
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -8,21 +8,20 @@ Project-URL: Issues, https://github.com/Ziems/arbor/issues
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: torch>=2.6.0
 Requires-Dist: fastapi
 Requires-Dist: uvicorn
 Requires-Dist: click
 Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
-Requires-Dist:
+Requires-Dist: vllm>=0.8.5.post1
 Requires-Dist: transformers
-Requires-Dist: trl
+Requires-Dist: trl>=0.17.0
 Requires-Dist: peft
 Requires-Dist: ray>=2.9
 Requires-Dist: setuptools<77.0.0,>=76.0.0
 Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
-Requires-Dist: sglang[all]>=0.4.5.post3
-Requires-Dist: sglang-router
 Requires-Dist: wandb
 Dynamic: license-file
 
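After upgrading, the new dependency set (torch floor added, sglang gone, vllm and a trl minimum added) can be confirmed from the installed distribution's metadata:

```python
from importlib.metadata import requires, version

print(version("arbor-ai"))   # expect "0.1.15"
for req in requires("arbor-ai"):
    print(req)               # should include vllm>=0.8.5.post1 and trl>=0.17.0,
                             # and no longer list sglang[all] or sglang-router
```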
{arbor_ai-0.1.13.dist-info → arbor_ai-0.1.15.dist-info}/RECORD

@@ -1,5 +1,5 @@
 arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/cli.py,sha256=
+arbor/cli.py,sha256=6S_nT93Zof6nB01n-xA7hSzssREzY13Oyh_jrElyTTY,3490
 arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
 arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
@@ -8,8 +8,8 @@ arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,
 arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
-arbor/server/api/routes/grpo.py,sha256=
-arbor/server/api/routes/inference.py,sha256=
+arbor/server/api/routes/grpo.py,sha256=QrWwj44-EenOyDwtiAO7OJPPGe8CyNaxCUTDlqfJs4g,2338
+arbor/server/api/routes/inference.py,sha256=JI4lm7zWrUqgMadWA0JuTD13hq6kGQpTLcuklhOH7f8,1547
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,18 +17,26 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256
-arbor/server/services/inference_manager.py,sha256=
+arbor/server/services/grpo_manager.py,sha256=y5gOko_RmyjQqvzlR79_PPZgMwMwCMJiaeygCG5qS-A,18761
+arbor/server/services/inference_manager.py,sha256=a1c5zYbjk6fPM3egX2McKv7ZWPN7c-QH_Qogu9iay90,9597
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
-arbor/server/services/
+arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/inference/vllm_client.py,sha256=X0v6zGHuaROGniWw_VCkzeWWuAHq0PlwtrFjTngCT4k,18285
+arbor/server/services/inference/vllm_serve.py,sha256=GdcaQStGKLj4J1kAnAnnI07R0X3A-bPoj7Tvagxsias,109457
+arbor/server/services/scripts/dpo_training.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/scripts/grpo_training.py,sha256=qjYSinOhi9-vvKY-gqGARwUgDQXYGDHlp9ZLwqKE1rw,31123
+arbor/server/services/scripts/sft_training.py,sha256=jgDMxZn9RFH9ys_7OF9Is8pQ9V97O2KzWg22Gveh3yE,3410
+arbor/server/services/scripts/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/scripts/utils/arg_parser.py,sha256=ur_iyhc_Ie00tjq63vK4Sdeu2PGKwe6Dh6Iax2vw9jc,1022
+arbor/server/services/scripts/utils/dataset.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
-arbor_ai-0.1.
+arbor_ai-0.1.15.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.15.dist-info/METADATA,sha256=GMGq6nbWEbRZxsJG2u7DhnMj6qCSTvssMVUN4ASs2BA,2413
+arbor_ai-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+arbor_ai-0.1.15.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.15.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.15.dist-info/RECORD,,
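Each RECORD row is `path,sha256=<digest>,<size>`, where the digest is an unpadded URL-safe base64 SHA-256 of the file, per the wheel spec. A small sketch for recomputing an entry when verifying an installed file:

```python
import base64
import hashlib
import os

def record_entry(path: str) -> str:
    """Rebuild a wheel RECORD line for one file (verification aid)."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return f"{path},sha256={b64},{os.path.getsize(path)}"

# e.g. record_entry("arbor/cli.py") should reproduce the cli.py line above.
```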
File without changes
File without changes
File without changes