arbor-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/__init__.py +17 -0
  2. arbor/cli.py +83 -43
  3. arbor/client/arbor_client.py +259 -0
  4. arbor/server/api/models/schemas.py +3 -1
  5. arbor/server/api/routes/grpo.py +2 -6
  6. arbor/server/api/routes/inference.py +7 -3
  7. arbor/server/core/config.py +293 -7
  8. arbor/server/core/config_manager.py +100 -0
  9. arbor/server/main.py +26 -1
  10. arbor/server/services/comms/comms.py +13 -9
  11. arbor/server/services/file_manager.py +7 -4
  12. arbor/server/services/grpo_manager.py +98 -62
  13. arbor/server/services/health_manager.py +171 -0
  14. arbor/server/services/inference/vllm_client.py +6 -4
  15. arbor/server/services/inference_manager.py +40 -38
  16. arbor/server/services/job_manager.py +2 -2
  17. arbor/server/services/scripts/grpo_training.py +62 -281
  18. arbor/server/services/scripts/mmgrpo_training.py +510 -0
  19. arbor/server/services/scripts/sft_training.py +8 -5
  20. arbor/server/services/scripts/utils/callbacks.py +33 -0
  21. arbor/server/services/scripts/utils/comms_monitors.py +169 -0
  22. arbor/server/services/scripts/utils/dataset.py +176 -0
  23. arbor/server/services/scripts/utils/ingestion_monitor.py +35 -0
  24. arbor/server/services/scripts/utils/mock_server.py +124 -0
  25. arbor/server/services/training_manager.py +4 -4
  26. arbor/server/utils/logging.py +298 -0
  27. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/METADATA +8 -18
  28. arbor_ai-0.2.2.dist-info/RECORD +51 -0
  29. arbor_ai-0.2.1.dist-info/RECORD +0 -42
  30. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.2.1.dist-info → arbor_ai-0.2.2.dist-info}/top_level.txt +0 -0
@@ -5,24 +5,19 @@
5
5
 
6
6
  import argparse
7
7
  import json
8
- import os
9
8
  import random
10
- import shutil
11
9
  import signal
12
10
  import sys
13
11
  import threading
14
12
  import time
15
- from functools import lru_cache
16
13
  from typing import Any, List, Optional, Union
17
14
 
18
15
  import torch
19
16
  import trl.extras.vllm_client
20
17
  import zmq
21
- from accelerate import Accelerator
22
18
  from accelerate.utils import broadcast_object_list, gather, gather_object
23
19
  from datasets import Dataset, IterableDataset, load_dataset
24
- from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig # type: ignore
25
- from torch.utils.data import Dataset
20
+ from peft import LoraConfig, PeftConfig
26
21
  from transformers import (
27
22
  PreTrainedModel,
28
23
  PreTrainedTokenizerBase,
@@ -38,28 +33,19 @@ from arbor.server.services.comms.comms import (
38
33
  ArborServerCommsHandler,
39
34
  )
40
35
  from arbor.server.services.inference.vllm_client import VLLMClient
36
+ from arbor.server.services.scripts.utils.callbacks import WeightUpdateCallback
37
+ from arbor.server.services.scripts.utils.comms_monitors import CommandMonitor
38
+ from arbor.server.services.scripts.utils.dataset import BlockingRotatingQueueDataset
39
+ from arbor.server.services.scripts.utils.ingestion_monitor import IngestionMonitor
41
40
 
42
41
  trl.extras.vllm_client.VLLMClient = VLLMClient
43
42
 
44
- if is_wandb_available():
45
- import wandb
46
-
47
- last_step_time = None
48
- last_queue_pop_time = None
49
-
50
-
51
- def time_since_last_step():
52
- global last_step_time
53
- if last_step_time is None:
54
- return float("inf")
55
- return time.time() - last_step_time
43
+ from arbor.server.utils.logging import get_logger
56
44
 
45
+ logger = get_logger(__name__)
57
46
 
58
- def get_time_since_last_queue_pop():
59
- global last_queue_pop_time
60
- if last_queue_pop_time is None:
61
- return float("inf")
62
- return time.time() - last_queue_pop_time
47
+ if is_wandb_available():
48
+ import wandb
63
49
 
64
50
 
65
51
  class ArborGRPOTrainer(GRPOTrainer):
@@ -77,10 +63,7 @@ class ArborGRPOTrainer(GRPOTrainer):
77
63
  ] = (None, None),
78
64
  peft_config: Optional["PeftConfig"] = None,
79
65
  comms_handler: Optional[ArborScriptCommsHandler] = None,
80
- lora: Optional[bool] = False,
81
- # We do nothing with max_context_length right now
82
66
  vllm_group_port: Optional[int] = None,
83
- max_context_length: Optional[int] = None,
84
67
  **kwargs,
85
68
  ):
86
69
  super().__init__(
@@ -103,7 +86,7 @@ class ArborGRPOTrainer(GRPOTrainer):
103
86
  args.use_vllm = True
104
87
  self.use_vllm = True
105
88
  if self.accelerator.is_main_process:
106
- print(
89
+ logger.info(
107
90
  f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
108
91
  )
109
92
  self.vllm_client = VLLMClient(
@@ -185,13 +168,15 @@ class ArborGRPOTrainer(GRPOTrainer):
185
168
 
186
169
  if self.max_prompt_length is not None:
187
170
  if prompt_ids.shape[1] > self.max_prompt_length:
188
- print(f"Truncating prompt to {self.max_prompt_length} tokens")
171
+ logger.info(f"Truncating prompt to {self.max_prompt_length} tokens")
189
172
  prompt_ids = prompt_ids[:, -self.max_prompt_length :]
190
173
  prompt_mask = prompt_mask[:, -self.max_prompt_length :]
191
174
 
192
175
  if self.max_completion_length is not None:
193
176
  if completion_ids.shape[1] > self.max_completion_length:
194
- print(f"Truncating completion to {self.max_completion_length} tokens")
177
+ logger.info(
178
+ f"Truncating completion to {self.max_completion_length} tokens"
179
+ )
195
180
  completion_ids = completion_ids[:, : self.max_completion_length]
196
181
  completion_mask = completion_mask[:, : self.max_completion_length]
197
182
 
@@ -225,7 +210,7 @@ class ArborGRPOTrainer(GRPOTrainer):
225
210
  prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
226
211
  attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
227
212
 
228
- print(
213
+ logger.info(
229
214
  f"prompt_completion_ids.shape (after truncation, if enabled): {prompt_completion_ids.shape}, prompt_ids.shape: {prompt_ids.shape}, completion_ids.shape: {completion_ids.shape}"
230
215
  )
231
216
 
@@ -354,238 +339,14 @@ class ArborGRPOTrainer(GRPOTrainer):
354
339
  class LastStepTimeCallback(TrainerCallback):
355
340
  "A callback that prints a message at the beginning of training"
356
341
 
357
- def on_step_end(self, args, state, control, **kwargs):
358
- global last_step_time
359
- print(f"Time since last step: {time_since_last_step()}")
360
- last_step_time = time.time()
361
-
362
-
363
- class WeightUpdateCallback(TrainerCallback):
364
- """A callback that sends weight update completion status after each step"""
365
-
366
- def __init__(self):
367
- self.comms_handler = None
368
- self.trainer = None
369
-
370
- def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
371
- self.comms_handler = comms_handler
372
-
373
- def set_trainer(self, trainer):
374
- self.trainer = trainer
342
+ def __init__(self, ingestion_monitor: IngestionMonitor):
343
+ self.ingestion_monitor = ingestion_monitor
375
344
 
376
345
  def on_step_end(self, args, state, control, **kwargs):
377
- if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
378
- if state.global_step != self.trainer._last_loaded_step:
379
- print("Updating inference model...")
380
- self.comms_handler.send_status({"status": "weight_update_start"})
381
- self.trainer._move_model_to_vllm()
382
- self.trainer._last_loaded_step = state.global_step
383
- print("[DEBUG] Sending weight update completion status")
384
- self.comms_handler.send_status({"status": "weight_update_complete"})
385
-
386
-
387
- class BlockingQueueDataset(Dataset):
388
- def __init__(
389
- self,
390
- accelerator: Accelerator,
391
- comms_handler: ArborScriptCommsHandler,
392
- size=10_000, # Just a random number
393
- maxsize=100,
394
- ):
395
- self.size = size
396
- self.accelerator = accelerator
397
- self.comms_handler = comms_handler
398
- self.get_cached_data = lru_cache(maxsize=maxsize)(self._get_data)
399
- self.completion_counters = {}
400
-
401
- def __len__(self):
402
- return self.size
403
-
404
- def _get_data(self, idx):
405
- rank = self.accelerator.process_index
406
- world_size = self.accelerator.num_processes
407
-
408
- if self.accelerator.is_main_process:
409
- global last_queue_pop_time
410
- last_queue_pop_time = time.time()
411
-
412
- if idx not in self.completion_counters:
413
- self.completion_counters[idx] = 0
414
-
415
- try:
416
- new_data = self.comms_handler.receive_data()
417
-
418
- except Exception as e:
419
- print(f"[rank {rank}] Error receiving data: {e}")
420
- new_data = None
421
-
422
- return new_data
423
-
424
- def __getitem__(self, idx):
425
- data = self.get_cached_data(idx)
426
- # Create hash of data to detect if processes are using the same idx for the same data
427
- data_hash = format(abs(hash(str(data))) % (16**8), "08x")
428
-
429
- if data is None:
430
- return None
431
-
432
- counter = self.completion_counters.get(idx, 0)
433
- item = data[counter]
434
- self.completion_counters[idx] = (counter + 1) % len(data)
435
- return item
436
-
437
-
438
- class CommandMonitor:
439
- def __init__(
440
- self,
441
- comms_handler: ArborScriptCommsHandler,
442
- trainer: ArborGRPOTrainer,
443
- base_model_name: str,
444
- ):
445
- self.comms_handler = comms_handler
446
- self.trainer = trainer
447
- self.base_model_name = base_model_name
448
- self.command_thread = threading.Thread(
449
- target=self._monitor_commands, daemon=True
346
+ logger.info(
347
+ f"Time since last step: {self.ingestion_monitor.time_since_last_step()}"
450
348
  )
451
- self.command_thread.start()
452
-
453
- def _monitor_commands(self):
454
- """Background thread that monitors for commands from the server."""
455
- if not self.comms_handler:
456
- return
457
- try:
458
- for command in self.comms_handler.receive_command():
459
- print(f"Main process received command: {command}")
460
- if (
461
- command.get("command") == "save_model"
462
- and self.trainer.accelerator.is_main_process
463
- ):
464
- print(
465
- f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
466
- )
467
- while (
468
- time_since_last_step() <= 10
469
- or get_time_since_last_queue_pop() <= 10
470
- ):
471
- print(f"Waiting for steps to finish")
472
- print(
473
- f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
474
- )
475
- print(
476
- f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
477
- )
478
- time.sleep(5)
479
- print("[Training Script] Saving model...")
480
- if self.trainer.peft_config:
481
- self.trainer.save_model(
482
- output_dir=self.trainer.args.output_dir + "/adapter/"
483
- )
484
- _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
485
- self.trainer.args.output_dir + "/adapter/",
486
- config=self.trainer.peft_config,
487
- )
488
- merged_model = _model_to_merge.merge_and_unload()
489
- merged_model.save_pretrained(
490
- self.trainer.args.output_dir,
491
- safe_serialization=True,
492
- )
493
- self.trainer.processing_class.save_pretrained(
494
- self.trainer.args.output_dir
495
- )
496
- else:
497
- self.trainer.save_model()
498
-
499
- print("[Training Script] Model saved")
500
- self.comms_handler.send_status(
501
- {
502
- "status": "model_saved",
503
- "output_dir": self.trainer.args.output_dir,
504
- }
505
- )
506
- elif command.get("command") == "save_checkpoint":
507
- print(
508
- f"[Training Script] Instructed to save checkpoint {command.get('checkpoint_name')}"
509
- )
510
- while (
511
- time_since_last_step() <= 10
512
- or get_time_since_last_queue_pop() <= 10
513
- ):
514
- print(f"Waiting for steps to finish")
515
- print(
516
- f"Time since last step: {time_since_last_step():.1f} (needs to be >= 10)"
517
- )
518
- print(
519
- f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be >= 10)"
520
- )
521
- time.sleep(5)
522
- if self.trainer.peft_config:
523
- self.trainer.save_model(
524
- output_dir=self.trainer.args.output_dir
525
- + f"/checkpoints/{command.get('checkpoint_name')}/adapter/"
526
- )
527
- _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
528
- self.trainer.args.output_dir
529
- + f"/checkpoints/{command.get('checkpoint_name')}/adapter/",
530
- config=self.trainer.peft_config,
531
- )
532
- merged_model = _model_to_merge.merge_and_unload()
533
- merged_model.save_pretrained(
534
- self.trainer.args.output_dir
535
- + f"/checkpoints/{command.get('checkpoint_name')}/",
536
- safe_serialization=True,
537
- )
538
- self.trainer.processing_class.save_pretrained(
539
- self.trainer.args.output_dir
540
- + f"/checkpoints/{command.get('checkpoint_name')}/"
541
- )
542
- else:
543
- self.trainer.save_model(
544
- output_dir=self.trainer.args.output_dir
545
- + f"/checkpoints/{command.get('checkpoint_name')}/"
546
- )
547
-
548
- # Copy checkpoint files to root output directory
549
- checkpoint_dir = (
550
- self.trainer.args.output_dir
551
- + f"/checkpoints/{command.get('checkpoint_name')}/"
552
- )
553
- root_dir = self.trainer.args.output_dir
554
-
555
- # Copy all files from checkpoint dir to root dir, overwriting if they exist
556
- # (effectively saves the checkpoint to the output directory)
557
- for item in os.listdir(checkpoint_dir):
558
- src = os.path.join(checkpoint_dir, item)
559
- dst = os.path.join(root_dir, item)
560
- if os.path.isdir(src):
561
- if os.path.exists(dst):
562
- shutil.rmtree(dst)
563
- shutil.copytree(src, dst)
564
- else:
565
- shutil.copy2(src, dst)
566
-
567
- self.comms_handler.send_status(
568
- {
569
- "status": "checkpoint_saved",
570
- "checkpoint_name": command.get("checkpoint_name"),
571
- "output_dir": self.trainer.args.output_dir
572
- + f"/checkpoints/{command.get('checkpoint_name')}/",
573
- }
574
- )
575
- self.comms_handler.send_status(
576
- {
577
- "status": "model_saved",
578
- "output_dir": self.trainer.args.output_dir,
579
- }
580
- )
581
- elif command.get("command") == "terminate":
582
- print("TERMINATED")
583
- self.trainer.accelerator.end_training()
584
- self.comms_handler.send_status({"status": "terminated"})
585
-
586
- except Exception as e:
587
- print(e)
588
- self.comms_handler.send_status({"status": "error", "error": str(e)})
349
+ self.ingestion_monitor.set_last_step_time()
589
350
 
590
351
 
591
352
  def main():
@@ -679,7 +440,7 @@ def main():
679
440
  # Need to set subscription for PUB/SUB pattern
680
441
  server_comms_handler.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
681
442
  for status in server_comms_handler.receive_status():
682
- print(f"Status: {status}")
443
+ logger.info(f"Status: {status}")
683
444
 
684
445
  status_listener_thread = threading.Thread(target=status_listener, daemon=True)
685
446
  status_listener_thread.start()
@@ -693,7 +454,7 @@ def main():
693
454
  if "gradient_checkpointing_kwargs" in trl_train_args and arbor_train_args.get(
694
455
  "lora", False
695
456
  ):
696
- print(
457
+ logger.info(
697
458
  "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
698
459
  )
699
460
  trl_train_args["gradient_checkpointing_kwargs"] = {
@@ -703,7 +464,7 @@ def main():
703
464
 
704
465
  lora_config = None
705
466
  if arbor_train_args.get("lora", False):
706
- print("Using LORA for PEFT")
467
+ logger.info("Using LORA for PEFT")
707
468
  lora_config = LoraConfig(
708
469
  r=16,
709
470
  lora_alpha=64,
@@ -721,6 +482,12 @@ def main():
721
482
  inference_mode=False,
722
483
  )
723
484
 
485
+ if "report_to" in trl_train_args and trl_train_args["report_to"] == "wandb":
486
+ import wandb
487
+
488
+ if "wandb_kwargs" in arbor_train_args and arbor_train_args["wandb_kwargs"]:
489
+ wandb.init(**arbor_train_args["wandb_kwargs"])
490
+
724
491
  training_args = GRPOConfig(
725
492
  dataloader_num_workers=0,
726
493
  shuffle_dataset=False,
@@ -728,15 +495,24 @@ def main():
728
495
  **trl_train_args,
729
496
  )
730
497
 
731
- weight_update_callback = WeightUpdateCallback()
498
+ # Create ingestion monitor
499
+ ingestion_monitor = IngestionMonitor()
500
+
501
+ train_dataset = BlockingRotatingQueueDataset(
502
+ ingestion_monitor=ingestion_monitor,
503
+ )
504
+
505
+ weight_update_callback = WeightUpdateCallback(
506
+ ingestion_monitor=ingestion_monitor,
507
+ )
508
+
732
509
  trainer = ArborGRPOTrainer(
733
510
  model=args.model,
734
511
  args=training_args,
735
- train_dataset=BlockingQueueDataset(None, None),
736
- callbacks=[LastStepTimeCallback(), weight_update_callback],
512
+ train_dataset=train_dataset,
513
+ callbacks=[LastStepTimeCallback(ingestion_monitor), weight_update_callback],
737
514
  peft_config=lora_config,
738
515
  vllm_group_port=args.vllm_group_port,
739
- **arbor_train_args,
740
516
  )
741
517
  # Create client handler
742
518
  comms_handler = ArborScriptCommsHandler(
@@ -748,48 +524,53 @@ def main():
748
524
  handshake_port=args.handshake_port,
749
525
  is_main_process=trainer.accelerator.is_main_process,
750
526
  )
527
+
528
+ train_dataset.set_comms_handler(comms_handler)
529
+ train_dataset.set_accelerator(trainer.accelerator)
530
+
751
531
  weight_update_callback.set_comms_handler(comms_handler)
752
532
  weight_update_callback.set_trainer(trainer)
753
533
  trainer.comms_handler = comms_handler
754
534
 
755
- # Initialize the dataset with the actual accelerator
756
- trainer.train_dataset = BlockingQueueDataset(
757
- accelerator=trainer.accelerator,
758
- comms_handler=trainer.comms_handler,
759
- )
760
-
761
535
  command_monitor = CommandMonitor(
762
536
  comms_handler=comms_handler,
763
537
  trainer=trainer,
764
538
  base_model_name=args.model,
539
+ ingestion_monitor=ingestion_monitor,
765
540
  )
541
+ command_monitor.start()
766
542
 
767
543
  # Add signal handlers for graceful shutdown
768
544
  def signal_handler(signum, frame):
769
- print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
770
- print("Ending training...")
545
+ logger.info(f"\nReceived signal {signum}. Initiating graceful shutdown...")
546
+ logger.info("Ending training...")
771
547
  trainer.accelerator.end_training()
772
- print("Closing communications...")
548
+ logger.info("Closing communications...")
773
549
  comms_handler.close()
774
550
  sys.exit(0)
775
551
 
776
552
  signal.signal(signal.SIGINT, signal_handler)
777
553
  signal.signal(signal.SIGTERM, signal_handler)
778
554
 
779
- print("Training...")
780
- trainer.train()
555
+ logger.info("Starting training...")
556
+ try:
557
+ trainer.train()
558
+ except Exception as e:
559
+ logger.error(f"Error during training: {e}")
560
+ logger.error(f"Error type: {type(e).__name__}")
561
+ raise
781
562
 
782
563
  except KeyboardInterrupt:
783
- print("\nReceived interrupt, shutting down...")
564
+ logger.info("\nReceived interrupt, shutting down...")
784
565
  except Exception as e:
785
- print(f"Error: {e}")
566
+ logger.error(f"Error: {e}")
786
567
  comms_handler.send_status({"status": "error", "error": str(e)})
787
568
  raise e
788
569
  finally:
789
- print("Cleaning up resources...")
570
+ logger.info("Cleaning up resources...")
790
571
  trainer.accelerator.end_training()
791
572
  comms_handler.close()
792
- print("Cleanup complete")
573
+ logger.info("Cleanup complete")
793
574
 
794
575
 
795
576
  if __name__ == "__main__":