arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,576 @@
+ ###############################################################################
+ # Initial Versions of this File Borrowed from Will Brown's Verifiers Library #
+ # https://github.com/willccbb/verifiers #
+ ###############################################################################
+
+ import argparse
+ import json
+ import random
+ import threading
+ import time
+ from functools import lru_cache
+ from typing import Any, List, Optional, Union
+
+ import torch
+ import zmq
+ from accelerate import Accelerator
+ from accelerate.utils import gather
+ from datasets import Dataset, IterableDataset, load_dataset
+ from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig  # type: ignore
+ from torch.utils.data import Dataset  # note: shadows datasets.Dataset imported above
+ from transformers import (
+     PreTrainedModel,
+     PreTrainedTokenizerBase,
+     Trainer,
+     TrainerCallback,
+ )
+ from trl import GRPOConfig, GRPOTrainer
+ from trl.data_utils import maybe_apply_chat_template
+
+ from arbor.server.services.comms.comms import (
+     ArborScriptCommsHandler,
+     ArborServerCommsHandler,
+ )
+
+ last_step_time = None
+ last_queue_pop_time = None
+
+
+ def time_since_last_step():
+     global last_step_time
+     if last_step_time is None:
+         return float("inf")
+     return time.time() - last_step_time
+
+
+ def get_time_since_last_queue_pop():
+     global last_queue_pop_time
+     if last_queue_pop_time is None:
+         return float("inf")
+     return time.time() - last_queue_pop_time
+
+
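+ # ArborGRPOTrainer overrides GRPO's generation step: completions and rewards
+ # arrive pre-computed from the Arbor server over the comms handler, so no
+ # reward functions are registered and no sampling happens inside the trainer.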
+ class ArborGRPOTrainer(GRPOTrainer):
+     def __init__(
+         self,
+         model: Union[str, PreTrainedModel],
+         scale_rewards: bool = True,
+         args: Optional[GRPOConfig] = None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         processing_class: Optional[PreTrainedTokenizerBase] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[
+             Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+         ] = (None, None),
+         peft_config: Optional["PeftConfig"] = None,
+         comms_handler: Optional[ArborScriptCommsHandler] = None,
+         update_interval: Optional[int] = 5,
+         lora: Optional[bool] = False,
+         **kwargs,
+     ):
+         super().__init__(
+             model=model,
+             reward_funcs=[],
+             args=args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             callbacks=callbacks,
+             optimizers=optimizers,
+             peft_config=peft_config,
+             **kwargs,
+         )
+         self.peft_config = peft_config
+         self.scale_rewards = scale_rewards
+         self.comms_handler = comms_handler
+         self.update_interval = update_interval
+
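+     # Unlike the upstream GRPOTrainer, this override does not generate text. It
+     # tokenizes the (prompt, completion) pairs received from the server, computes
+     # old/reference per-token log-probs as needed, and attaches advantages.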
+     def _generate_and_score_completions(
+         self, batch: List[dict[str, Any]]
+     ) -> dict[str, Union[torch.Tensor, Any]]:
+         device = self.accelerator.device
+
+         # Process prompts and completions
+         prompt_completion_texts = []
+         for example in batch:
+             prompt_completion_texts.append(
+                 maybe_apply_chat_template(
+                     {
+                         "prompt": example["messages"],
+                         "completion": [example["completion"]],
+                     },
+                     self.processing_class,
+                 )
+             )
+
+         # Tokenize prompts
+         prompt_texts = [
+             prompt_completion_text["prompt"]
+             for prompt_completion_text in prompt_completion_texts
+         ]
+         prompt_inputs = self.processing_class(
+             prompt_texts,
+             return_tensors="pt",
+             padding=True,
+             padding_side="left",
+             add_special_tokens=False,
+         ).to(device)
+         prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
+         prompt_ids, prompt_mask = (
+             prompt_inputs["input_ids"],
+             prompt_inputs["attention_mask"],
+         )
+
+         # Tokenize completions
+         completion_texts = [
+             prompt_completion_text["completion"]
+             for prompt_completion_text in prompt_completion_texts
+         ]
+         completion_inputs = self.processing_class(
+             completion_texts,
+             return_tensors="pt",
+             padding=True,
+             add_special_tokens=False,
+         ).to(device)
+         completion_ids, completion_mask = (
+             completion_inputs["input_ids"],
+             completion_inputs["attention_mask"],
+         )
+
+         if self.max_prompt_length is not None:
+             if prompt_ids.shape[1] > self.max_prompt_length:
+                 print(f"Truncating prompt to {self.max_prompt_length} tokens")
+                 prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+                 prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+         if self.max_completion_length is not None:
+             if completion_ids.shape[1] > self.max_completion_length:
+                 print(f"Truncating completion to {self.max_completion_length} tokens")
+                 completion_ids = completion_ids[:, : self.max_completion_length]
+                 completion_mask = completion_mask[:, : self.max_completion_length]
+
+         # Keeping this for when we switch to vLLM
+         # if self.state.global_step != self._last_loaded_step:
+         #     self._move_model_to_vllm()
+         #     self._last_loaded_step = self.state.global_step
+
+         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+
+         print(
+             f"prompt_completion_ids.shape (after truncation, if enabled): {prompt_completion_ids.shape}, "
+             f"prompt_ids.shape: {prompt_ids.shape}, completion_ids.shape: {completion_ids.shape}"
+         )
+
+         logits_to_keep = completion_ids.size(1)
+
+         with torch.no_grad():
+             # When num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
+             # computation here and use per_token_logps.detach() instead.
+             if self.num_iterations > 1:
+                 old_per_token_logps = self._get_per_token_logps(
+                     self.model, prompt_completion_ids, attention_mask, logits_to_keep
+                 )
+             else:
+                 old_per_token_logps = None
+
+             if self.beta == 0.0:
+                 ref_per_token_logps = None
+             elif self.ref_model is not None:
+                 ref_per_token_logps = self._get_per_token_logps(
+                     self.ref_model,
+                     prompt_completion_ids,
+                     attention_mask,
+                     logits_to_keep,
+                 )
+             else:
+                 with self.accelerator.unwrap_model(self.model).disable_adapter():
+                     ref_per_token_logps = self._get_per_token_logps(
+                         self.model,
+                         prompt_completion_ids,
+                         attention_mask,
+                         logits_to_keep,
+                     )
+
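+         # Group-relative advantages: rewards are gathered across processes and
+         # viewed as (num_prompts, num_generations); each completion's advantage
+         # is its reward minus the mean reward of its generation group.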
+         rewards = torch.tensor(
+             [example["reward"] for example in batch], dtype=torch.float32
+         ).to(device)
+         rewards = gather(rewards)
+         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
+             self.num_generations, dim=0
+         )
+         std_grouped_rewards = std_grouped_rewards.repeat_interleave(
+             self.num_generations, dim=0
+         )
+         advantages = rewards - mean_grouped_rewards
+
+         if self.scale_rewards:
+             # Normalize advantages by the per-group reward std (epsilon avoids division by zero)
+             advantages = advantages / (std_grouped_rewards + 1e-4)
+
+         # Slice to keep only the local part of the data
+         process_slice = slice(
+             self.accelerator.process_index * len(batch),
+             (self.accelerator.process_index + 1) * len(batch),
+         )
+         advantages = advantages[process_slice]
+
+         ## Logged Metrics Removed Here
+
+         return {
+             "prompt_ids": prompt_ids,
+             "prompt_mask": prompt_mask,
+             "completion_ids": completion_ids,
+             "completion_mask": completion_mask,
+             "old_per_token_logps": old_per_token_logps,
+             "ref_per_token_logps": ref_per_token_logps,
+             "advantages": advantages,
+         }
+
+
+ class LastStepTimeCallback(TrainerCallback):
+     "A callback that records the wall-clock time of the last completed training step"
+
+     def on_step_end(self, args, state, control, **kwargs):
+         global last_step_time
+         print(f"Time since last step: {time_since_last_step()}")
+         last_step_time = time.time()
+
+
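+ # A torch Dataset facade over the ZMQ data queue: __getitem__ blocks until a
+ # group of completions arrives. lru_cache memoizes _get_data per index so every
+ # read of the same idx reuses one received group, while completion_counters
+ # round-robins through the items in that group.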
+ class BlockingQueueDataset(Dataset):
+     def __init__(
+         self,
+         accelerator: Accelerator,
+         comms_handler: ArborScriptCommsHandler,
+         size=10_000,  # Nominal length only; items are streamed from the queue
+         maxsize=100,
+     ):
+         self.size = size
+         self.accelerator = accelerator
+         self.comms_handler = comms_handler
+         self.get_cached_data = lru_cache(maxsize=maxsize)(self._get_data)
+         self.completion_counters = {}
+
+     def __len__(self):
+         return self.size
+
+     def _get_data(self, idx):
+         rank = self.accelerator.process_index
+         world_size = self.accelerator.num_processes
+
+         if self.accelerator.is_main_process:
+             global last_queue_pop_time
+             last_queue_pop_time = time.time()
+
+         if idx not in self.completion_counters:
+             self.completion_counters[idx] = 0
+
+         try:
+             new_data = self.comms_handler.receive_data()
+         except Exception as e:
+             print(f"[rank {rank}] Error receiving data: {e}")
+             new_data = None
+
+         return new_data
+
+     def __getitem__(self, idx):
+         data = self.get_cached_data(idx)
+         # Hash of the data, to detect whether processes are using the same idx
+         # for the same data (currently unused)
+         data_hash = format(abs(hash(str(data))) % (16**8), "08x")
+
+         if data is None:
+             return None
+
+         counter = self.completion_counters.get(idx, 0)
+         item = data[counter]
+         self.completion_counters[idx] = (counter + 1) % len(data)
+         return item
+
+
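+ # Runs two daemon threads: one services server commands (currently save_model),
+ # the other listens for broadcasts (currently terminate) and flips the
+ # trainer's should_training_stop flag.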
+ class CommandMonitor:
+     def __init__(
+         self,
+         comms_handler: ArborScriptCommsHandler,
+         trainer: ArborGRPOTrainer,
+         base_model_name: str,
+     ):
+         self.comms_handler = comms_handler
+         self.trainer = trainer
+         self.base_model_name = base_model_name
+         self.command_thread = threading.Thread(
+             target=self._monitor_commands, daemon=True
+         )
+         self.command_thread.start()
+
+         self.broadcast_thread = threading.Thread(
+             target=self._monitor_broadcasts, daemon=True
+         )
+         self.broadcast_thread.start()
+
+     def _monitor_commands(self):
+         """Background thread that monitors for commands from the server."""
+         if not self.comms_handler:
+             return
+         try:
+             for command in self.comms_handler.receive_command():
+                 print(f"Main process received command: {command}")
+                 if (
+                     command.get("command") == "save_model"
+                     and self.trainer.accelerator.is_main_process
+                 ):
+                     print(
+                         f"[Training Script] Instructed to save model at {self.trainer.args.output_dir}"
+                     )
+                     # Wait until the data queue is empty and steps have settled before saving
+                     while (
+                         time_since_last_step() <= 10
+                         or get_time_since_last_queue_pop() <= 10
+                     ):
+                         print("Waiting for steps to finish")
+                         print(
+                             f"Time since last step: {time_since_last_step():.1f} (needs to be > 10)"
+                         )
+                         print(
+                             f"Time since last queue pop: {get_time_since_last_queue_pop():.1f} (needs to be > 10)"
+                         )
+                         time.sleep(5)  # Small delay to prevent busy waiting
+                     print("[Training Script] Saving model...")
+
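+                     # For LoRA runs, save the adapter first, then reload it with
+                     # AutoPeftModelForCausalLM and merge_and_unload so output_dir
+                     # ends up holding a standalone, fully-merged model.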
+                     if self.trainer.peft_config:
+                         self.trainer.save_model(
+                             output_dir=self.trainer.args.output_dir + "/adapter/"
+                         )
+
+                         _model_to_merge = AutoPeftModelForCausalLM.from_pretrained(
+                             self.trainer.args.output_dir + "/adapter/",
+                             config=self.trainer.peft_config,
+                         )
+                         merged_model = _model_to_merge.merge_and_unload()
+                         merged_model.save_pretrained(
+                             self.trainer.args.output_dir,
+                             safe_serialization=True,
+                         )
+                         self.trainer.processing_class.save_pretrained(
+                             self.trainer.args.output_dir
+                         )
+                     else:
+                         self.trainer.save_model()
+
+                     print("[Training Script] Model saved")
+                     self.comms_handler.send_status(
+                         {
+                             "status": "model_saved",
+                             "output_dir": self.trainer.args.output_dir,
+                         }
+                     )
+         except Exception as e:
+             print(e)
+             self.comms_handler.send_status({"status": "error", "error": str(e)})
+
+     def _monitor_broadcasts(self):
+         """Background thread that monitors for broadcasts from the server."""
+         if not self.comms_handler:
+             return
+         try:
+             for broadcast in self.comms_handler.receive_broadcast():
+                 print(f"Received broadcast: {broadcast}")
+                 if broadcast.get("message") == "terminate":
+                     self.trainer.control.should_training_stop = True
+                     self.comms_handler.send_status(
+                         {
+                             "status": "Received termination command",
+                             "process_id": self.trainer.accelerator.process_index,
+                         }
+                     )
+         except Exception as e:
+             self.comms_handler.send_status({"status": "error", "error": str(e)})
+
+
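+ # Entry point: wires up the ZMQ comms ports, builds the GRPOConfig/trainer pair,
+ # and starts training. With --debug, an in-process ArborServerCommsHandler
+ # feeds synthetic TL;DR batches so the script can run without a live server.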
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--debug", action="store_true")
+
+     pipe_args = parser.add_argument_group("Comms arguments")
+     pipe_args.add_argument("--host", default="localhost")
+     pipe_args.add_argument("--command_port", type=int, required=True)
+     pipe_args.add_argument("--status_port", type=int, required=True)
+     pipe_args.add_argument("--data_port", type=int, required=True)
+     pipe_args.add_argument("--broadcast_port", type=int, required=True)
+     pipe_args.add_argument("--handshake_port", type=int, required=True)
+
+     training_args = parser.add_argument_group("Training arguments")
+     training_args.add_argument(
+         "--model",
+         type=str,
+         help="Model to use for training",
+     )
+     training_args.add_argument(
+         "--trl_train_kwargs",
+         type=json.loads,
+         help="TRL training arguments as a JSON string",
+     )
+     training_args.add_argument(
+         "--arbor_train_kwargs",
+         type=json.loads,
+         help="Arbor-specific training arguments as a JSON string",
+     )
+
+     args = parser.parse_args()
+
+     if args.debug:
+         server_comms_handler = ArborServerCommsHandler(
+             host=args.host,
+         )
+
+         args.command_port = server_comms_handler.command_port
+         args.status_port = server_comms_handler.status_port
+         args.data_port = server_comms_handler.data_port
+         args.broadcast_port = server_comms_handler.broadcast_port
+         args.handshake_port = server_comms_handler.handshake_port
+
+         def debug_data_generator():
+             tldr_dataset = load_dataset("trl-lib/tldr", split="train")
+             idx = 0
+             for item in tldr_dataset:
+                 input_messages = [{"role": "user", "content": item["prompt"]}]
+                 completions = [
+                     {
+                         "role": "assistant",
+                         "content": "This is a test completion"
+                         + hex(random.randint(0, 0xFFFFFF))[2:],
+                     }
+                     for _ in range(8)
+                 ]
+
+                 rewards = [-abs(20 - len(c["content"])) for c in completions]
+                 batch = []
+                 for completion, reward in zip(completions, rewards):
+                     batch.append(
+                         {
+                             "messages": input_messages,
+                             "completion": completion,
+                             "reward": reward,
+                         }
+                     )
+                 server_comms_handler.send_data(batch)
+                 time.sleep(1)
+                 idx += 1
+
+                 if idx >= 25:
+                     server_comms_handler.send_command({"command": "save_model"})
+
+         debug_thread = threading.Thread(target=debug_data_generator, daemon=True)
+         debug_thread.start()
+
+         def status_listener():
+             # Need to set subscription for the PUB/SUB pattern
+             server_comms_handler.status_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+             for status in server_comms_handler.receive_status():
+                 print(f"Status: {status}")
+
+         status_listener_thread = threading.Thread(target=status_listener, daemon=True)
+         status_listener_thread.start()
+
+     comms_handler = None
+     try:
+         trl_train_args = {**(args.trl_train_kwargs or {})}
+         arbor_train_args = {**(args.arbor_train_kwargs or {})}
+
+         # TODO: These assertions should be done in some better way
+         assert "output_dir" in trl_train_args, "output_dir is required"
+         if "gradient_checkpointing_kwargs" in trl_train_args and arbor_train_args.get(
+             "lora", False
+         ):
+             print(
+                 "Setting gradient_checkpointing_kwargs to use_reentrant=False for LORA training"
+             )
+             trl_train_args["gradient_checkpointing_kwargs"] = {
+                 **(trl_train_args.get("gradient_checkpointing_kwargs") or {}),
+                 "use_reentrant": False,
+             }
+
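+         # Hard-coded LoRA hyperparameters: rank-16 adapters (alpha 64, dropout
+         # 0.05) on all attention and MLP projection matrices.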
+         lora_config = None
+         if arbor_train_args.get("lora", False):
+             print("Using LORA for PEFT")
+             lora_config = LoraConfig(
+                 r=16,
+                 lora_alpha=64,
+                 target_modules=[
+                     "q_proj",
+                     "k_proj",
+                     "v_proj",
+                     "o_proj",
+                     "up_proj",
+                     "down_proj",
+                     "gate_proj",
+                 ],
+                 task_type="CAUSAL_LM",
+                 lora_dropout=0.05,
+                 inference_mode=False,
+             )
+
+         training_args = GRPOConfig(
+             dataloader_num_workers=0,
+             shuffle_dataset=False,
+             **trl_train_args,
+         )
+
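+         # The trainer is constructed with a placeholder dataset because the
+         # Accelerator (and hence the process rank) does not exist until the
+         # trainer does; the real BlockingQueueDataset is swapped in below.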
+         trainer = ArborGRPOTrainer(
+             model=args.model,
+             args=training_args,
+             train_dataset=BlockingQueueDataset(None, None),
+             callbacks=[LastStepTimeCallback()],
+             peft_config=lora_config,
+             **arbor_train_args,
+         )
+         # Create the client-side comms handler
+         comms_handler = ArborScriptCommsHandler(
+             host=args.host,
+             command_port=args.command_port,
+             status_port=args.status_port,
+             data_port=args.data_port,
+             broadcast_port=args.broadcast_port,
+             handshake_port=args.handshake_port,
+             is_main_process=trainer.accelerator.is_main_process,
+         )
+         trainer.comms_handler = comms_handler
+
+         # Re-initialize the dataset with the actual accelerator and comms handler
+         trainer.train_dataset = BlockingQueueDataset(
+             accelerator=trainer.accelerator,
+             comms_handler=trainer.comms_handler,
+         )
+
+         command_monitor = CommandMonitor(
+             comms_handler=comms_handler,
+             trainer=trainer,
+             base_model_name=args.model,
+         )
+
+         print("Training...")
+         trainer.train()
+
+     except KeyboardInterrupt:
+         print("\nReceived interrupt, shutting down...")
+     except Exception as e:
+         print(f"Error: {e}")
+         if comms_handler:
+             comms_handler.send_status({"status": "error", "error": str(e)})
+         raise e
+     finally:
+         if comms_handler:
+             comms_handler.close()
+
+
+ if __name__ == "__main__":
+     main()