arbor-ai 0.1.14__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,18 @@
5
5
 
6
6
  import argparse
7
7
  import json
8
+ import os
8
9
  import random
10
+ import shutil
11
+ import signal
12
+ import sys
9
13
  import threading
10
14
  import time
11
15
  from functools import lru_cache
12
16
  from typing import Any, List, Optional, Union
13
17
 
14
18
  import torch
19
+ import trl.extras.vllm_client
15
20
  import zmq
16
21
  from accelerate import Accelerator
17
22
  from accelerate.utils import broadcast_object_list, gather, gather_object
@@ -32,6 +37,9 @@ from arbor.server.services.comms.comms import (
32
37
  ArborScriptCommsHandler,
33
38
  ArborServerCommsHandler,
34
39
  )
40
+ from arbor.server.services.inference.vllm_client import VLLMClient
41
+
42
+ trl.extras.vllm_client.VLLMClient = VLLMClient
35
43
 
36
44
  if is_wandb_available():
37
45
  import wandb
@@ -71,10 +79,10 @@ class ArborGRPOTrainer(GRPOTrainer):
71
79
  comms_handler: Optional[ArborScriptCommsHandler] = None,
72
80
  lora: Optional[bool] = False,
73
81
  # We do nothing with max_context_length right now
82
+ vllm_group_port: Optional[int] = None,
74
83
  max_context_length: Optional[int] = None,
75
84
  **kwargs,
76
85
  ):
77
-
78
86
  super().__init__(
79
87
  model=model,
80
88
  reward_funcs=[],
@@ -91,6 +99,33 @@ class ArborGRPOTrainer(GRPOTrainer):
91
99
  self.scale_rewards = scale_rewards
92
100
  self.comms_handler = comms_handler
93
101
 
102
+ self.vllm_client = None
103
+ args.use_vllm = True
104
+ self.use_vllm = True
105
+ if self.accelerator.is_main_process:
106
+ print(
107
+ f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
108
+ )
109
+ self.vllm_client = VLLMClient(
110
+ args.vllm_server_host,
111
+ args.vllm_server_port,
112
+ group_port=vllm_group_port,
113
+ connection_timeout=args.vllm_server_timeout,
114
+ )
115
+ self.vllm_client.init_communicator()
116
+
117
+ # vLLM specific sampling arguments
118
+ self.guided_decoding_regex = args.vllm_guided_decoding_regex
119
+
120
+ self._last_loaded_step = (
121
+ -1
122
+ ) # tag to avoid useless loading during grad accumulation
123
+
124
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
125
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
126
+ # synchronize all processes after vLLM has been fully initialized.
127
+ self.accelerator.wait_for_everyone()
128
+
94
129
  def _generate_and_score_completions(
95
130
  self, batch: List[dict[str, Any]]
96
131
  ) -> dict[str, Union[torch.Tensor, Any]]:
@@ -104,7 +139,11 @@ class ArborGRPOTrainer(GRPOTrainer):
104
139
  maybe_apply_chat_template(
105
140
  {
106
141
  "prompt": example["messages"],
107
- "completion": [example["completion"]],
142
+ "completion": (
143
+ example["completion"]
144
+ if isinstance(example["completion"], list)
145
+ else [example["completion"]]
146
+ ),
108
147
  },
109
148
  self.processing_class,
110
149
  )
@@ -133,15 +172,15 @@ class ArborGRPOTrainer(GRPOTrainer):
133
172
  prompt_completion_text["completion"]
134
173
  for prompt_completion_text in prompt_completion_texts
135
174
  ]
136
- completion_ids = self.processing_class(
175
+ completion_inputs = self.processing_class(
137
176
  completions_text,
138
177
  return_tensors="pt",
139
178
  padding=True,
140
179
  add_special_tokens=False,
141
180
  ).to(device)
142
181
  completion_ids, completion_mask = (
143
- completion_ids["input_ids"],
144
- completion_ids["attention_mask"],
182
+ completion_inputs["input_ids"],
183
+ completion_inputs["attention_mask"],
145
184
  )
146
185
 
147
186
  if self.max_prompt_length is not None:
@@ -156,11 +195,6 @@ class ArborGRPOTrainer(GRPOTrainer):
156
195
  completion_ids = completion_ids[:, : self.max_completion_length]
157
196
  completion_mask = completion_mask[:, : self.max_completion_length]
158
197
 
159
- # Keeping this for when we switch to vllm
160
- # if self.state.global_step != self._last_loaded_step:
161
- # self._move_model_to_vllm()
162
- # self._last_loaded_step = self.state.global_step
163
-
164
198
  prompt_ids = broadcast_object_list(prompt_ids)
165
199
  prompt_mask = broadcast_object_list(prompt_mask)
166
200
  completion_ids = broadcast_object_list(completion_ids)
@@ -178,6 +212,9 @@ class ArborGRPOTrainer(GRPOTrainer):
178
212
 
179
213
  is_eos = completion_ids == self.processing_class.eos_token_id
180
214
 
215
+ # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
216
+ completion_lengths = completion_mask.sum(1)
217
+
181
218
  # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
182
219
  if self.mask_truncated_completions:
183
220
  truncated_completions = ~is_eos.any(dim=1)
@@ -230,6 +267,10 @@ class ArborGRPOTrainer(GRPOTrainer):
230
267
  std_grouped_rewards = std_grouped_rewards.repeat_interleave(
231
268
  self.num_generations, dim=0
232
269
  )
270
+ is_std_zero = torch.isclose(
271
+ std_grouped_rewards, torch.zeros_like(std_grouped_rewards)
272
+ )
273
+
233
274
  advantages = rewards - mean_grouped_rewards
234
275
 
235
276
  if self.scale_rewards:
@@ -241,66 +282,72 @@ class ArborGRPOTrainer(GRPOTrainer):
241
282
  self.accelerator.process_index * len(batch),
242
283
  (self.accelerator.process_index + 1) * len(batch),
243
284
  )
285
+ all_process_advantages = (
286
+ advantages.clone()
287
+ ) # keep the aggregated advantages for logging
244
288
  advantages = advantages[process_slice]
245
289
 
246
290
  # Log the metrics
247
291
  if mode == "train":
248
292
  self.state.num_input_tokens_seen += (
249
- self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
293
+ self.accelerator.gather(attention_mask.sum()).sum().item()
250
294
  )
251
295
  self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
252
296
 
253
- # log completion lengths, mean, min, max
254
- agg_completion_mask = self.accelerator.gather_for_metrics(
255
- completion_mask.sum(1)
256
- )
297
+ # Log completion lengths, mean, min, max
298
+ agg_completion_lengths = self.accelerator.gather(completion_lengths)
257
299
  self._metrics[mode]["completions/mean_length"].append(
258
- agg_completion_mask.float().mean().item()
300
+ agg_completion_lengths.float().mean().item()
259
301
  )
260
302
  self._metrics[mode]["completions/min_length"].append(
261
- agg_completion_mask.float().min().item()
303
+ agg_completion_lengths.float().min().item()
262
304
  )
263
305
  self._metrics[mode]["completions/max_length"].append(
264
- agg_completion_mask.float().max().item()
306
+ agg_completion_lengths.float().max().item()
265
307
  )
266
308
 
267
- # identify sequences that terminated with EOS and log their lengths
268
- agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
269
- term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
270
- clipped_completions_ratio = 1 - len(term_completion_mask) / len(
271
- agg_completion_mask
309
+ # Identify sequences that terminated with EOS and log their lengths
310
+ agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
311
+ term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
312
+ clipped_completions_ratio = 1 - len(term_completion_lengths) / len(
313
+ agg_completion_lengths
272
314
  )
273
315
  self._metrics[mode]["completions/clipped_ratio"].append(
274
316
  clipped_completions_ratio
275
317
  )
276
- if len(term_completion_mask) == 0:
277
- # edge case where no completed sequences are found
278
- term_completion_mask = torch.zeros(1, device=device)
318
+ if (
319
+ len(term_completion_lengths) == 0
320
+ ): # edge case where no terminated sequences are found
321
+ term_completion_lengths = torch.zeros(1, device=device)
279
322
  self._metrics[mode]["completions/mean_terminated_length"].append(
280
- term_completion_mask.float().mean().item()
323
+ term_completion_lengths.float().mean().item()
281
324
  )
282
325
  self._metrics[mode]["completions/min_terminated_length"].append(
283
- term_completion_mask.float().min().item()
326
+ term_completion_lengths.float().min().item()
284
327
  )
285
328
  self._metrics[mode]["completions/max_terminated_length"].append(
286
- term_completion_mask.float().max().item()
329
+ term_completion_lengths.float().max().item()
287
330
  )
288
331
 
289
- # Calculate mean reward
332
+ # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
290
333
  self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
291
334
  self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
335
+ self._metrics[mode]["frac_reward_zero_std"].append(
336
+ is_std_zero.float().mean().item()
337
+ )
292
338
 
293
339
  # Log prompt and completion texts
294
340
  self._textual_logs["prompt"].extend(gather_object(prompts_text))
295
341
  self._textual_logs["completion"].extend(gather_object(completions_text))
342
+ self._textual_logs["advantages"].extend(all_process_advantages.tolist())
296
343
 
297
344
  return {
298
345
  "prompt_ids": prompt_ids,
299
346
  "prompt_mask": prompt_mask,
300
347
  "completion_ids": completion_ids,
301
348
  "completion_mask": completion_mask,
302
- "old_per_token_logps": old_per_token_logps,
303
349
  "advantages": advantages,
350
+ "old_per_token_logps": old_per_token_logps,
304
351
  }
305
352
 
306
353
 
@@ -313,6 +360,30 @@ class LastStepTimeCallback(TrainerCallback):
313
360
  last_step_time = time.time()
314
361
 
315
362
 
363
class WeightUpdateCallback(TrainerCallback):
    """Sync trainer weights to the vLLM inference server after each step.

    Both ``set_comms_handler`` and ``set_trainer`` must be called before the
    callback does anything. On each ``on_step_end`` whose ``global_step``
    differs from ``trainer._last_loaded_step``, it calls
    ``trainer._move_model_to_vllm()`` and records the step. The main process
    additionally brackets the update with ``weight_update_start`` /
    ``weight_update_complete`` status messages to the Arbor server.
    """

    def __init__(self):
        # Both are injected after construction (see set_comms_handler / set_trainer).
        self.comms_handler = None
        self.trainer = None

    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
        """Attach the script-side comms handler used to report status."""
        self.comms_handler = comms_handler

    def set_trainer(self, trainer):
        """Attach the trainer whose weights are pushed to vLLM."""
        self.trainer = trainer

    def on_step_end(self, args, state, control, **kwargs):
        """Push updated weights to vLLM at most once per global step.

        This runs on every process. TRL's ``_move_model_to_vllm`` is written
        to be called on all ranks: it gathers ZeRO-3-sharded parameters
        collectively, and only the main process talks to the vLLM server.
        Invoking it on the main process alone can therefore deadlock
        multi-GPU runs. Status messages are still sent from the main process
        only.
        """
        if not (self.comms_handler and self.trainer):
            return
        # Skip if this step's weights were already pushed.
        if state.global_step == self.trainer._last_loaded_step:
            return

        is_main = self.comms_handler.is_main_process
        if is_main:
            print("Updating inference model...")
            self.comms_handler.send_status({"status": "weight_update_start"})

        self.trainer._move_model_to_vllm()
        self.trainer._last_loaded_step = state.global_step

        if is_main:
            self.comms_handler.send_status({"status": "weight_update_complete"})
385
+
386
+
316
387
  class BlockingQueueDataset(Dataset):
317
388
  def __init__(
318
389
  self,
@@ -379,11 +450,6 @@ class CommandMonitor:
379
450
  )
380
451
  self.command_thread.start()
381
452
 
382
- self.broadcast_thread = threading.Thread(
383
- target=self._monitor_broadcasts, daemon=True
384
- )
385
- self.broadcast_thread.start()
386
-
387
453
  def _monitor_commands(self):
388
454
  """Background thread that monitors for commands from the server."""
389
455
  if not self.comms_handler:
@@ -478,6 +544,26 @@ class CommandMonitor:
478
544
  output_dir=self.trainer.args.output_dir
479
545
  + f"/checkpoints/{command.get('checkpoint_name')}/"
480
546
  )
547
+
548
+ # Copy checkpoint files to root output directory
549
+ checkpoint_dir = (
550
+ self.trainer.args.output_dir
551
+ + f"/checkpoints/{command.get('checkpoint_name')}/"
552
+ )
553
+ root_dir = self.trainer.args.output_dir
554
+
555
+ # Copy all files from checkpoint dir to root dir, overwriting if they exist
556
+ # (effectively saves the checkpoint to the output directory)
557
+ for item in os.listdir(checkpoint_dir):
558
+ src = os.path.join(checkpoint_dir, item)
559
+ dst = os.path.join(root_dir, item)
560
+ if os.path.isdir(src):
561
+ if os.path.exists(dst):
562
+ shutil.rmtree(dst)
563
+ shutil.copytree(src, dst)
564
+ else:
565
+ shutil.copy2(src, dst)
566
+
481
567
  self.comms_handler.send_status(
482
568
  {
483
569
  "status": "checkpoint_saved",
@@ -486,31 +572,21 @@ class CommandMonitor:
486
572
  + f"/checkpoints/{command.get('checkpoint_name')}/",
487
573
  }
488
574
  )
575
+ self.comms_handler.send_status(
576
+ {
577
+ "status": "model_saved",
578
+ "output_dir": self.trainer.args.output_dir,
579
+ }
580
+ )
581
+ elif command.get("command") == "terminate":
582
+ print("TERMINATED")
583
+ self.trainer.accelerator.end_training()
584
+ self.comms_handler.send_status({"status": "terminated"})
489
585
 
490
586
  except Exception as e:
491
587
  print(e)
492
588
  self.comms_handler.send_status({"status": "error", "error": str(e)})
493
589
 
494
- def _monitor_broadcasts(self):
495
- """Background thread that monitors for broadcasts from the server."""
496
- if not self.comms_handler:
497
- return
498
- try:
499
- for broadcast in self.comms_handler.receive_broadcast():
500
- print(f"!!!Received broadcast: {broadcast}")
501
- if broadcast.get("message") == "terminate":
502
- # self.trainer.control.should_training_stop = True
503
- # self.comms_handler.send_status(
504
- # {
505
- # "status": "Received termination command",
506
- # "process_id": self.trainer.accelerator.process_index,
507
- # }
508
- # )
509
- if self.trainer.accelerator.is_main_process:
510
- self.trainer.accelerator.end_training()
511
- except Exception as e:
512
- self.comms_handler.send_status({"status": "error", "error": str(e)})
513
-
514
590
 
515
591
  def main():
516
592
  parser = argparse.ArgumentParser()
@@ -523,6 +599,8 @@ def main():
523
599
  pipe_args.add_argument("--data_port", type=int, required=True)
524
600
  pipe_args.add_argument("--broadcast_port", type=int, required=True)
525
601
  pipe_args.add_argument("--handshake_port", type=int, required=True)
602
+ pipe_args.add_argument("--vllm_group_port", type=int, required=True)
603
+ pipe_args.add_argument("--vllm_port", type=int, required=True)
526
604
 
527
605
  training_args = parser.add_argument_group("Training arguments")
528
606
  training_args.add_argument(
@@ -544,6 +622,11 @@ def main():
544
622
  args = parser.parse_args()
545
623
 
546
624
  if args.debug:
625
+ # python grpo_training.py --debug
626
+ # --command_port 0 --status_port 0
627
+ # --data_port 0 --broadcast_port 0
628
+ # --handshake_port 0 --model Qwen/Qwen3-0.6B
629
+ # --trl_train_kwargs '{"output_dir": ".", "report_to": "none"}'
547
630
  server_comms_handler = ArborServerCommsHandler(
548
631
  host=args.host,
549
632
  )
@@ -554,6 +637,11 @@ def main():
554
637
  args.broadcast_port = server_comms_handler.broadcast_port
555
638
  args.handshake_port = server_comms_handler.handshake_port
556
639
 
640
+ handshake_thread = threading.Thread(
641
+ target=server_comms_handler.wait_for_clients, args=(1,), daemon=True
642
+ )
643
+ handshake_thread.start()
644
+
557
645
  def debug_data_generator():
558
646
  tldr_dataset = load_dataset("trl-lib/tldr", split="train")
559
647
  idx = 0
@@ -636,15 +724,18 @@ def main():
636
724
  training_args = GRPOConfig(
637
725
  dataloader_num_workers=0,
638
726
  shuffle_dataset=False,
727
+ vllm_server_port=args.vllm_port,
639
728
  **trl_train_args,
640
729
  )
641
730
 
731
+ weight_update_callback = WeightUpdateCallback()
642
732
  trainer = ArborGRPOTrainer(
643
733
  model=args.model,
644
734
  args=training_args,
645
735
  train_dataset=BlockingQueueDataset(None, None),
646
- callbacks=[LastStepTimeCallback()],
736
+ callbacks=[LastStepTimeCallback(), weight_update_callback],
647
737
  peft_config=lora_config,
738
+ vllm_group_port=args.vllm_group_port,
648
739
  **arbor_train_args,
649
740
  )
650
741
  # Create client handler
@@ -657,6 +748,8 @@ def main():
657
748
  handshake_port=args.handshake_port,
658
749
  is_main_process=trainer.accelerator.is_main_process,
659
750
  )
751
+ weight_update_callback.set_comms_handler(comms_handler)
752
+ weight_update_callback.set_trainer(trainer)
660
753
  trainer.comms_handler = comms_handler
661
754
 
662
755
  # Initialize the dataset with the actual accelerator
@@ -671,6 +764,18 @@ def main():
671
764
  base_model_name=args.model,
672
765
  )
673
766
 
767
+ # Add signal handlers for graceful shutdown
768
+ def signal_handler(signum, frame):
769
+ print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
770
+ print("Ending training...")
771
+ trainer.accelerator.end_training()
772
+ print("Closing communications...")
773
+ comms_handler.close()
774
+ sys.exit(0)
775
+
776
+ signal.signal(signal.SIGINT, signal_handler)
777
+ signal.signal(signal.SIGTERM, signal_handler)
778
+
674
779
  print("Training...")
675
780
  trainer.train()
676
781
 
@@ -681,7 +786,10 @@ def main():
681
786
  comms_handler.send_status({"status": "error", "error": str(e)})
682
787
  raise e
683
788
  finally:
789
+ print("Cleaning up resources...")
790
+ trainer.accelerator.end_training()
684
791
  comms_handler.close()
792
+ print("Cleanup complete")
685
793
 
686
794
 
687
795
  if __name__ == "__main__":
@@ -0,0 +1,109 @@
1
+ import argparse
2
+ import json
3
+ import random
4
+ import threading
5
+ import time
6
+
7
+ import torch
8
+ import zmq
9
+ from peft import LoraConfig
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from trl import SFTConfig, SFTTrainer, setup_chat_format
12
+
13
+ from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
14
+
15
+
16
def main():
    """Entry point: build an Arbor-managed trainer and train on streamed data.

    Parses the shared comms/training CLI (``get_training_arg_parser``) plus a
    ``--lora`` switch. It then:

    * builds the trainer and a script-side comms handler;
    * swaps in a ``BlockingQueueDataset`` bound to the real accelerator;
    * starts a ``CommandMonitor``;
    * runs ``trainer.train()``.

    Unexpected errors are printed, reported to the server as
    ``{"status": "error", ...}`` when the comms handler exists, and
    re-raised. The accelerator and comms handler are always released on exit
    if they were created.
    """
    # Names used below that the module header does not import.
    from arbor.server.services.comms.comms import ArborScriptCommsHandler
    from trl import GRPOConfig

    # TODO(review): ArborGRPOTrainer, BlockingQueueDataset, LastStepTimeCallback
    # and CommandMonitor are neither defined nor imported in this file. They
    # are defined in the GRPO training script; import them from that module
    # (its import path is not visible here), otherwise this raises NameError.

    def _parse_bool(value):
        # ``type=bool`` would turn the string "False" into True, so parse explicitly.
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = get_training_arg_parser()
    # get_training_arg_parser() already registers --model; adding it a second
    # time raises argparse.ArgumentError, so its required-ness is checked
    # after parsing instead.
    # Accept a bare `--lora` as well as `--lora True` / `--lora False`.
    parser.add_argument(
        "--lora", type=_parse_bool, nargs="?", const=True, default=False
    )
    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: --model")

    # Bound before the try so the except/finally clauses never raise
    # UnboundLocalError (and hide the real error) when setup fails early.
    trainer = None
    comms_handler = None
    try:
        trl_train_kwargs = {**(args.trl_config_kwargs or {})}

        # TODO: These validations should be done in some better way
        if "output_dir" not in trl_train_kwargs:
            # Explicit raise rather than assert, which is stripped under -O.
            raise ValueError("output_dir is required")
        if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
            print(
                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
            )
            trl_train_kwargs["gradient_checkpointing_kwargs"] = {
                **(trl_train_kwargs.get("gradient_checkpointing_kwargs") or {}),
                "use_reentrant": False,
            }

        lora_config = None
        if args.lora:
            print("Using LORA for PEFT")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=64,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "up_proj",
                    "down_proj",
                    "gate_proj",
                ],
                task_type="CAUSAL_LM",
                lora_dropout=0.05,
                inference_mode=False,
            )

        training_args = GRPOConfig(
            dataloader_num_workers=0,
            shuffle_dataset=False,
            **trl_train_kwargs,
        )

        # The shared parser exposes no Arbor-specific trainer options, so no
        # extra keyword arguments are forwarded to the trainer.
        arbor_train_args = {}

        trainer = ArborGRPOTrainer(
            model=args.model,
            args=training_args,
            # Placeholder; replaced below once the accelerator exists.
            train_dataset=BlockingQueueDataset(None, None),
            callbacks=[LastStepTimeCallback()],
            peft_config=lora_config,
            **arbor_train_args,
        )
        # Create client handler
        comms_handler = ArborScriptCommsHandler(
            host=args.host,
            command_port=args.command_port,
            status_port=args.status_port,
            data_port=args.data_port,
            broadcast_port=args.broadcast_port,
            handshake_port=args.handshake_port,
            is_main_process=trainer.accelerator.is_main_process,
        )
        trainer.comms_handler = comms_handler

        # Initialize the dataset with the actual accelerator
        trainer.train_dataset = BlockingQueueDataset(
            accelerator=trainer.accelerator,
            comms_handler=trainer.comms_handler,
        )

        # Kept in a local so the monitor stays referenced while training runs.
        command_monitor = CommandMonitor(
            comms_handler=comms_handler,
            trainer=trainer,
            base_model_name=args.model,
        )

        print("Training...")
        trainer.train()

    except KeyboardInterrupt:
        print("\nReceived interrupt, shutting down...")
    except Exception as e:
        print(f"Error: {e}")
        if comms_handler is not None:
            comms_handler.send_status({"status": "error", "error": str(e)})
        raise
    finally:
        if trainer is not None:
            trainer.accelerator.end_training()
        if comms_handler is not None:
            comms_handler.close()


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,31 @@
1
+ """The arg parser for the training scripts"""
2
+
3
+ import argparse
4
+ import json
5
+
6
+
7
def get_training_arg_parser():
    """Build the command-line parser shared by the Arbor training scripts.

    The parser defines:

    * ``--debug``: a boolean switch.
    * A "Comms arguments" group: ``--host`` (default ``"localhost"``) plus the
      required integer ports ``--command_port``, ``--status_port``,
      ``--data_port``, ``--broadcast_port`` and ``--handshake_port``.
    * A "Training arguments" group: ``--model`` and ``--trl_config_kwargs``,
      the latter decoded from a JSON string with ``json.loads``.

    Returns:
        A new ``argparse.ArgumentParser``. Callers may register further
        script-specific arguments before calling ``parse_args``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--debug", action="store_true")

    comms_group = arg_parser.add_argument_group("Comms arguments")
    comms_group.add_argument("--host", default="localhost")
    # Every comms port is a mandatory integer; registered in a fixed order so
    # the --help listing is stable.
    for port_flag in (
        "--command_port",
        "--status_port",
        "--data_port",
        "--broadcast_port",
        "--handshake_port",
    ):
        comms_group.add_argument(port_flag, type=int, required=True)

    training_group = arg_parser.add_argument_group("Training arguments")
    training_group.add_argument(
        "--model", type=str, help="Model to use for training"
    )
    training_group.add_argument(
        "--trl_config_kwargs",
        type=json.loads,
        help="Training configs as a JSON string",
    )
    return arg_parser
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arbor-ai
3
- Version: 0.1.14
3
+ Version: 0.2
4
4
  Summary: A framework for fine-tuning and managing language models
5
5
  Author-email: Noah Ziems <nziems2@nd.edu>
6
6
  Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -8,21 +8,20 @@ Project-URL: Issues, https://github.com/Ziems/arbor/issues
8
8
  Requires-Python: >=3.10
9
9
  Description-Content-Type: text/markdown
10
10
  License-File: LICENSE
11
+ Requires-Dist: torch>=2.6.0
11
12
  Requires-Dist: fastapi
12
13
  Requires-Dist: uvicorn
13
14
  Requires-Dist: click
14
15
  Requires-Dist: python-multipart
15
16
  Requires-Dist: pydantic-settings
16
- Requires-Dist: torch
17
+ Requires-Dist: vllm>=0.8.5.post1
17
18
  Requires-Dist: transformers
18
- Requires-Dist: trl
19
+ Requires-Dist: trl>=0.17.0
19
20
  Requires-Dist: peft
20
21
  Requires-Dist: ray>=2.9
21
22
  Requires-Dist: setuptools<77.0.0,>=76.0.0
22
23
  Requires-Dist: pyzmq>=26.4.0
23
24
  Requires-Dist: pyyaml>=6.0.2
24
- Requires-Dist: sglang[all]>=0.4.5.post3
25
- Requires-Dist: sglang-router
26
25
  Requires-Dist: wandb
27
26
  Dynamic: license-file
28
27
 
@@ -41,7 +40,12 @@ Dynamic: license-file
41
40
  Install Arbor via pip:
42
41
 
43
42
  ```bash
44
- pip install arbor-ai
43
+ pip install -U arbor-ai
44
+ ```
45
+
46
+ Optionally, you can also install:
47
+ ```bash
48
+ pip install flash-attn --no-build-isolation
45
49
  ```
46
50
 
47
51
  ---