arbor-ai 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,18 @@
 
 import argparse
 import json
+import os
 import random
+import shutil
+import signal
+import sys
 import threading
 import time
 from functools import lru_cache
 from typing import Any, List, Optional, Union
 
 import torch
+import trl.extras.vllm_client
 import zmq
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, gather, gather_object
@@ -32,6 +37,9 @@ from arbor.server.services.comms.comms import (
     ArborScriptCommsHandler,
     ArborServerCommsHandler,
 )
+from arbor.server.services.inference.vllm_client import VLLMClient
+
+trl.extras.vllm_client.VLLMClient = VLLMClient
 
 if is_wandb_available():
     import wandb
@@ -71,10 +79,10 @@ class ArborGRPOTrainer(GRPOTrainer):
         comms_handler: Optional[ArborScriptCommsHandler] = None,
         lora: Optional[bool] = False,
         # We do nothing with max_context_length right now
+        vllm_group_port: Optional[int] = None,
         max_context_length: Optional[int] = None,
         **kwargs,
     ):
-
         super().__init__(
             model=model,
             reward_funcs=[],
@@ -91,6 +99,33 @@ class ArborGRPOTrainer(GRPOTrainer):
         self.scale_rewards = scale_rewards
         self.comms_handler = comms_handler
 
+        self.vllm_client = None
+        args.use_vllm = True
+        self.use_vllm = True
+        if self.accelerator.is_main_process:
+            print(
+                f"Initializing vLLM client with server port {args.vllm_server_port} and group port {vllm_group_port}"
+            )
+            self.vllm_client = VLLMClient(
+                args.vllm_server_host,
+                args.vllm_server_port,
+                group_port=vllm_group_port,
+                connection_timeout=args.vllm_server_timeout,
+            )
+            self.vllm_client.init_communicator()
+
+        # vLLM specific sampling arguments
+        self.guided_decoding_regex = args.vllm_guided_decoding_regex
+
+        self._last_loaded_step = (
+            -1
+        )  # tag to avoid useless loading during grad accumulation
+
+        # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+        # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+        # synchronize all processes after vLLM has been fully initialized.
+        self.accelerator.wait_for_everyone()
+
     def _generate_and_score_completions(
         self, batch: List[dict[str, Any]]
     ) -> dict[str, Union[torch.Tensor, Any]]:
@@ -156,11 +191,6 @@ class ArborGRPOTrainer(GRPOTrainer):
         completion_ids = completion_ids[:, : self.max_completion_length]
         completion_mask = completion_mask[:, : self.max_completion_length]
 
-        # Keeping this for when we switch to vllm
-        # if self.state.global_step != self._last_loaded_step:
-        #     self._move_model_to_vllm()
-        #     self._last_loaded_step = self.state.global_step
-
         prompt_ids = broadcast_object_list(prompt_ids)
         prompt_mask = broadcast_object_list(prompt_mask)
         completion_ids = broadcast_object_list(completion_ids)
@@ -178,6 +208,9 @@ class ArborGRPOTrainer(GRPOTrainer):
 
         is_eos = completion_ids == self.processing_class.eos_token_id
 
+        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
+        completion_lengths = completion_mask.sum(1)
+
         # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
         if self.mask_truncated_completions:
             truncated_completions = ~is_eos.any(dim=1)
@@ -230,6 +263,10 @@ class ArborGRPOTrainer(GRPOTrainer):
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(
             self.num_generations, dim=0
         )
+        is_std_zero = torch.isclose(
+            std_grouped_rewards, torch.zeros_like(std_grouped_rewards)
+        )
+
         advantages = rewards - mean_grouped_rewards
 
         if self.scale_rewards:
@@ -241,66 +278,72 @@ class ArborGRPOTrainer(GRPOTrainer):
             self.accelerator.process_index * len(batch),
             (self.accelerator.process_index + 1) * len(batch),
         )
+        all_process_advantages = (
+            advantages.clone()
+        )  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
         # Log the metrics
         if mode == "train":
             self.state.num_input_tokens_seen += (
-                self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+                self.accelerator.gather(attention_mask.sum()).sum().item()
             )
             self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
 
-        # log completion lengths, mean, min, max
-        agg_completion_mask = self.accelerator.gather_for_metrics(
-            completion_mask.sum(1)
-        )
+        # Log completion lengths, mean, min, max
+        agg_completion_lengths = self.accelerator.gather(completion_lengths)
         self._metrics[mode]["completions/mean_length"].append(
-            agg_completion_mask.float().mean().item()
+            agg_completion_lengths.float().mean().item()
         )
         self._metrics[mode]["completions/min_length"].append(
-            agg_completion_mask.float().min().item()
+            agg_completion_lengths.float().min().item()
         )
         self._metrics[mode]["completions/max_length"].append(
-            agg_completion_mask.float().max().item()
+            agg_completion_lengths.float().max().item()
         )
 
-        # identify sequences that terminated with EOS and log their lengths
-        agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
-        term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_mask) / len(
-            agg_completion_mask
+        # Identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
+        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(
+            agg_completion_lengths
         )
         self._metrics[mode]["completions/clipped_ratio"].append(
             clipped_completions_ratio
         )
-        if len(term_completion_mask) == 0:
-            # edge case where no completed sequences are found
-            term_completion_mask = torch.zeros(1, device=device)
+        if (
+            len(term_completion_lengths) == 0
+        ):  # edge case where no terminated sequences are found
+            term_completion_lengths = torch.zeros(1, device=device)
         self._metrics[mode]["completions/mean_terminated_length"].append(
-            term_completion_mask.float().mean().item()
+            term_completion_lengths.float().mean().item()
         )
         self._metrics[mode]["completions/min_terminated_length"].append(
-            term_completion_mask.float().min().item()
+            term_completion_lengths.float().min().item()
         )
         self._metrics[mode]["completions/max_terminated_length"].append(
-            term_completion_mask.float().max().item()
+            term_completion_lengths.float().max().item()
        )
 
-        # Calculate mean reward
+        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+        self._metrics[mode]["frac_reward_zero_std"].append(
+            is_std_zero.float().mean().item()
+        )
 
         # Log prompt and completion texts
         self._textual_logs["prompt"].extend(gather_object(prompts_text))
         self._textual_logs["completion"].extend(gather_object(completions_text))
+        self._textual_logs["advantages"].extend(all_process_advantages.tolist())
 
         return {
             "prompt_ids": prompt_ids,
             "prompt_mask": prompt_mask,
             "completion_ids": completion_ids,
             "completion_mask": completion_mask,
-            "old_per_token_logps": old_per_token_logps,
             "advantages": advantages,
+            "old_per_token_logps": old_per_token_logps,
         }
 
 
@@ -313,6 +356,30 @@ class LastStepTimeCallback(TrainerCallback):
         last_step_time = time.time()
 
 
+class WeightUpdateCallback(TrainerCallback):
+    """A callback that sends weight update completion status after each step"""
+
+    def __init__(self):
+        self.comms_handler = None
+        self.trainer = None
+
+    def set_comms_handler(self, comms_handler: ArborScriptCommsHandler):
+        self.comms_handler = comms_handler
+
+    def set_trainer(self, trainer):
+        self.trainer = trainer
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.comms_handler and self.comms_handler.is_main_process and self.trainer:
+            if state.global_step != self.trainer._last_loaded_step:
+                print("Updating inference model...")
+                self.comms_handler.send_status({"status": "weight_update_start"})
+                self.trainer._move_model_to_vllm()
+                self.trainer._last_loaded_step = state.global_step
+                print("[DEBUG] Sending weight update completion status")
+                self.comms_handler.send_status({"status": "weight_update_complete"})
+
+
 class BlockingQueueDataset(Dataset):
     def __init__(
         self,
@@ -379,11 +446,6 @@ class CommandMonitor:
         )
         self.command_thread.start()
 
-        self.broadcast_thread = threading.Thread(
-            target=self._monitor_broadcasts, daemon=True
-        )
-        self.broadcast_thread.start()
-
     def _monitor_commands(self):
         """Background thread that monitors for commands from the server."""
         if not self.comms_handler:
@@ -478,6 +540,26 @@ class CommandMonitor:
                         output_dir=self.trainer.args.output_dir
                         + f"/checkpoints/{command.get('checkpoint_name')}/"
                     )
+
+                    # Copy checkpoint files to root output directory
+                    checkpoint_dir = (
+                        self.trainer.args.output_dir
+                        + f"/checkpoints/{command.get('checkpoint_name')}/"
+                    )
+                    root_dir = self.trainer.args.output_dir
+
+                    # Copy all files from checkpoint dir to root dir, overwriting if they exist
+                    # (effectively saves the checkpoint to the output directory)
+                    for item in os.listdir(checkpoint_dir):
+                        src = os.path.join(checkpoint_dir, item)
+                        dst = os.path.join(root_dir, item)
+                        if os.path.isdir(src):
+                            if os.path.exists(dst):
+                                shutil.rmtree(dst)
+                            shutil.copytree(src, dst)
+                        else:
+                            shutil.copy2(src, dst)
+
                     self.comms_handler.send_status(
                         {
                             "status": "checkpoint_saved",
@@ -486,31 +568,21 @@ class CommandMonitor:
                             + f"/checkpoints/{command.get('checkpoint_name')}/",
                         }
                     )
+                    self.comms_handler.send_status(
+                        {
+                            "status": "model_saved",
+                            "output_dir": self.trainer.args.output_dir,
+                        }
+                    )
+                elif command.get("command") == "terminate":
+                    print("TERMINATED")
+                    self.trainer.accelerator.end_training()
+                    self.comms_handler.send_status({"status": "terminated"})
 
         except Exception as e:
             print(e)
             self.comms_handler.send_status({"status": "error", "error": str(e)})
 
-    def _monitor_broadcasts(self):
-        """Background thread that monitors for broadcasts from the server."""
-        if not self.comms_handler:
-            return
-        try:
-            for broadcast in self.comms_handler.receive_broadcast():
-                print(f"!!!Received broadcast: {broadcast}")
-                if broadcast.get("message") == "terminate":
-                    # self.trainer.control.should_training_stop = True
-                    # self.comms_handler.send_status(
-                    #     {
-                    #         "status": "Received termination command",
-                    #         "process_id": self.trainer.accelerator.process_index,
-                    #     }
-                    # )
-                    if self.trainer.accelerator.is_main_process:
-                        self.trainer.accelerator.end_training()
-        except Exception as e:
-            self.comms_handler.send_status({"status": "error", "error": str(e)})
-
 
 def main():
     parser = argparse.ArgumentParser()
@@ -523,6 +595,8 @@ def main():
     pipe_args.add_argument("--data_port", type=int, required=True)
     pipe_args.add_argument("--broadcast_port", type=int, required=True)
     pipe_args.add_argument("--handshake_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_group_port", type=int, required=True)
+    pipe_args.add_argument("--vllm_port", type=int, required=True)
 
     training_args = parser.add_argument_group("Training arguments")
     training_args.add_argument(
@@ -544,6 +618,11 @@ def main():
     args = parser.parse_args()
 
     if args.debug:
+        # python grpo_training.py --debug
+        # --command_port 0 --status_port 0
+        # --data_port 0 --broadcast_port 0
+        # --handshake_port 0 --model Qwen/Qwen3-0.6B
+        # --trl_train_kwargs '{"output_dir": ".", "report_to": "none"}'
         server_comms_handler = ArborServerCommsHandler(
             host=args.host,
         )
@@ -554,6 +633,11 @@ def main():
         args.broadcast_port = server_comms_handler.broadcast_port
         args.handshake_port = server_comms_handler.handshake_port
 
+        handshake_thread = threading.Thread(
+            target=server_comms_handler.wait_for_clients, args=(1,), daemon=True
+        )
+        handshake_thread.start()
+
         def debug_data_generator():
             tldr_dataset = load_dataset("trl-lib/tldr", split="train")
             idx = 0
@@ -636,15 +720,18 @@ def main():
         training_args = GRPOConfig(
             dataloader_num_workers=0,
             shuffle_dataset=False,
+            vllm_server_port=args.vllm_port,
             **trl_train_args,
         )
 
+        weight_update_callback = WeightUpdateCallback()
         trainer = ArborGRPOTrainer(
             model=args.model,
             args=training_args,
             train_dataset=BlockingQueueDataset(None, None),
-            callbacks=[LastStepTimeCallback()],
+            callbacks=[LastStepTimeCallback(), weight_update_callback],
             peft_config=lora_config,
+            vllm_group_port=args.vllm_group_port,
             **arbor_train_args,
         )
         # Create client handler
@@ -657,6 +744,8 @@ def main():
             handshake_port=args.handshake_port,
             is_main_process=trainer.accelerator.is_main_process,
         )
+        weight_update_callback.set_comms_handler(comms_handler)
+        weight_update_callback.set_trainer(trainer)
         trainer.comms_handler = comms_handler
 
         # Initialize the dataset with the actual accelerator
@@ -671,6 +760,18 @@ def main():
             base_model_name=args.model,
         )
 
+        # Add signal handlers for graceful shutdown
+        def signal_handler(signum, frame):
+            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
+            print("Ending training...")
+            trainer.accelerator.end_training()
+            print("Closing communications...")
+            comms_handler.close()
+            sys.exit(0)
+
+        signal.signal(signal.SIGINT, signal_handler)
+        signal.signal(signal.SIGTERM, signal_handler)
+
         print("Training...")
         trainer.train()
 
@@ -681,7 +782,10 @@ def main():
         comms_handler.send_status({"status": "error", "error": str(e)})
         raise e
     finally:
+        print("Cleaning up resources...")
+        trainer.accelerator.end_training()
         comms_handler.close()
+        print("Cleanup complete")
 
 
 if __name__ == "__main__":
@@ -0,0 +1,109 @@
+import argparse
+import json
+import random
+import threading
+import time
+
+import torch
+import zmq
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import SFTConfig, SFTTrainer, setup_chat_format
+
+from arbor.server.services.scripts.utils.arg_parser import get_training_arg_parser
+
+
+def main():
+    parser = get_training_arg_parser()
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--lora", type=bool, default=False)
+    args = parser.parse_args()
+
+    try:
+        trl_train_kwargs = {**(args.trl_config_kwargs or {})}
+
+        # TODO: These assertions should be done in some better way
+        assert "output_dir" in trl_train_kwargs, "output_dir is required"
+        if "gradient_checkpointing_kwargs" in trl_train_kwargs and args.lora:
+            print(
+                "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
+            )
+            trl_train_kwargs["gradient_checkpointing_kwargs"] = {
+                **(trl_train_kwargs.get("gradient_checkpointing_kwargs") or {}),
+                "use_reentrant": False,
+            }
+
+        lora_config = None
+        if args.lora:
+            print("Using LORA for PEFT")
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=64,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "down_proj",
+                    "gate_proj",
+                ],
+                task_type="CAUSAL_LM",
+                lora_dropout=0.05,
+                inference_mode=False,
+            )
+
+        training_args = GRPOConfig(
+            dataloader_num_workers=0,
+            shuffle_dataset=False,
+            **trl_train_args,
+        )
+
+        trainer = ArborGRPOTrainer(
+            model=args.model,
+            args=training_args,
+            train_dataset=BlockingQueueDataset(None, None),
+            callbacks=[LastStepTimeCallback()],
+            peft_config=lora_config,
+            **arbor_train_args,
+        )
+        # Create client handler
+        comms_handler = ArborScriptCommsHandler(
+            host=args.host,
+            command_port=args.command_port,
+            status_port=args.status_port,
+            data_port=args.data_port,
+            broadcast_port=args.broadcast_port,
+            handshake_port=args.handshake_port,
+            is_main_process=trainer.accelerator.is_main_process,
+        )
+        trainer.comms_handler = comms_handler
+
+        # Initialize the dataset with the actual accelerator
+        trainer.train_dataset = BlockingQueueDataset(
+            accelerator=trainer.accelerator,
+            comms_handler=trainer.comms_handler,
+        )
+
+        command_monitor = CommandMonitor(
+            comms_handler=comms_handler,
+            trainer=trainer,
+            base_model_name=args.model,
+        )
+
+        print("Training...")
+        trainer.train()
+
+    except KeyboardInterrupt:
+        print("\nReceived interrupt, shutting down...")
+    except Exception as e:
+        print(f"Error: {e}")
+        comms_handler.send_status({"status": "error", "error": str(e)})
+        raise e
+    finally:
+        trainer.accelerator.end_training()
+        comms_handler.close()
+
+
+if __name__ == "__main__":
+    main()
File without changes
@@ -0,0 +1,31 @@
+"""The arg parser for the training scripts"""
+
+import argparse
+import json
+
+
+def get_training_arg_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--debug", action="store_true")
+
+    pipe_args = parser.add_argument_group("Comms arguments")
+    pipe_args.add_argument("--host", default="localhost")
+    pipe_args.add_argument("--command_port", type=int, required=True)
+    pipe_args.add_argument("--status_port", type=int, required=True)
+    pipe_args.add_argument("--data_port", type=int, required=True)
+    pipe_args.add_argument("--broadcast_port", type=int, required=True)
+    pipe_args.add_argument("--handshake_port", type=int, required=True)
+
+    training_args = parser.add_argument_group("Training arguments")
+    training_args.add_argument(
+        "--model",
+        type=str,
+        help="Model to use for training",
+    )
+    training_args.add_argument(
+        "--trl_config_kwargs",
+        type=json.loads,
+        help="Training configs as a JSON string",
+    )
+
+    return parser
File without changes
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arbor-ai
-Version: 0.1.13
+Version: 0.1.15
 Summary: A framework for fine-tuning and managing language models
 Author-email: Noah Ziems <nziems2@nd.edu>
 Project-URL: Homepage, https://github.com/Ziems/arbor
@@ -8,21 +8,20 @@ Project-URL: Issues, https://github.com/Ziems/arbor/issues
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: torch>=2.6.0
 Requires-Dist: fastapi
 Requires-Dist: uvicorn
 Requires-Dist: click
 Requires-Dist: python-multipart
 Requires-Dist: pydantic-settings
-Requires-Dist: torch
+Requires-Dist: vllm>=0.8.5.post1
 Requires-Dist: transformers
-Requires-Dist: trl==0.17.0
+Requires-Dist: trl>=0.17.0
 Requires-Dist: peft
 Requires-Dist: ray>=2.9
 Requires-Dist: setuptools<77.0.0,>=76.0.0
 Requires-Dist: pyzmq>=26.4.0
 Requires-Dist: pyyaml>=6.0.2
-Requires-Dist: sglang[all]>=0.4.5.post3
-Requires-Dist: sglang-router
 Requires-Dist: wandb
 Dynamic: license-file
 
@@ -1,5 +1,5 @@
 arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
+arbor/cli.py,sha256=6S_nT93Zof6nB01n-xA7hSzssREzY13Oyh_jrElyTTY,3490
 arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
 arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
@@ -8,8 +8,8 @@ arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,
 arbor/server/api/models/schemas.py,sha256=KCHav1nPFbQEynrcO-MObhRmoOrdFvfGuVogApynOCA,6210
 arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
-arbor/server/api/routes/grpo.py,sha256=AbQ_BHgk-Om5U0qSt_FeJfyBJ0vItpfrnCNtJgD6p5k,2245
-arbor/server/api/routes/inference.py,sha256=Zy4ciN6vdRgu0-sFFnEeTZB-4XnLjEDH-atU7roIKSs,1668
+arbor/server/api/routes/grpo.py,sha256=QrWwj44-EenOyDwtiAO7OJPPGe8CyNaxCUTDlqfJs4g,2338
+arbor/server/api/routes/inference.py,sha256=JI4lm7zWrUqgMadWA0JuTD13hq6kGQpTLcuklhOH7f8,1547
 arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
 arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
@@ -17,18 +17,26 @@ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
-arbor/server/services/grpo_manager.py,sha256=-_0xjENvIrOAtHACkFPMYox9YAeckHbpX2FkrmKrWuU,15448
-arbor/server/services/inference_manager.py,sha256=NcsUI-pgf3cRhU6P3xlPx0dxhvgYrfGZkEEGORcHcis,12833
+arbor/server/services/grpo_manager.py,sha256=y5gOko_RmyjQqvzlR79_PPZgMwMwCMJiaeygCG5qS-A,18761
+arbor/server/services/inference_manager.py,sha256=a1c5zYbjk6fPM3egX2McKv7ZWPN7c-QH_Qogu9iay90,9597
 arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
 arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
 arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/services/comms/comms.py,sha256=3KN3mzwPvfW2_L5hq02JdAk6yOMyhY0_pBz-DDr5A3o,7694
-arbor/server/services/scripts/grpo_training.py,sha256=eMT5cIMolAzhukANH1WRmPdxIkvLbsbrggdGFCMGMHc,26474
+arbor/server/services/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/inference/vllm_client.py,sha256=X0v6zGHuaROGniWw_VCkzeWWuAHq0PlwtrFjTngCT4k,18285
+arbor/server/services/inference/vllm_serve.py,sha256=GdcaQStGKLj4J1kAnAnnI07R0X3A-bPoj7Tvagxsias,109457
+arbor/server/services/scripts/dpo_training.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/scripts/grpo_training.py,sha256=qjYSinOhi9-vvKY-gqGARwUgDQXYGDHlp9ZLwqKE1rw,31123
+arbor/server/services/scripts/sft_training.py,sha256=jgDMxZn9RFH9ys_7OF9Is8pQ9V97O2KzWg22Gveh3yE,3410
+arbor/server/services/scripts/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+arbor/server/services/scripts/utils/arg_parser.py,sha256=ur_iyhc_Ie00tjq63vK4Sdeu2PGKwe6Dh6Iax2vw9jc,1022
+arbor/server/services/scripts/utils/dataset.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-arbor_ai-0.1.13.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
-arbor_ai-0.1.13.dist-info/METADATA,sha256=c0yScMpCiWYSFqVLjgk5TrRBuAVJK3aTBl0z0IPZ_8Y,2442
-arbor_ai-0.1.13.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
-arbor_ai-0.1.13.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
-arbor_ai-0.1.13.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
-arbor_ai-0.1.13.dist-info/RECORD,,
+arbor_ai-0.1.15.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
+arbor_ai-0.1.15.dist-info/METADATA,sha256=GMGq6nbWEbRZxsJG2u7DhnMj6qCSTvssMVUN4ASs2BA,2413
+arbor_ai-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+arbor_ai-0.1.15.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
+arbor_ai-0.1.15.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
+arbor_ai-0.1.15.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.6.0)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 