cciwon-code-review-cli 2.0.1 → 2.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/bin/code-review.js +1 -1
  2. package/lib/chat-mode.js +7 -2
  3. package/package.json +1 -1
  4. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  39. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  40. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  41. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  42. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  44. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  45. package/unsloth_compiled_cache/Conv1d.py +70 -0
  46. package/unsloth_compiled_cache/Conv2d.py +70 -0
  47. package/unsloth_compiled_cache/Conv3d.py +70 -0
  48. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  49. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  50. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  51. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  52. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  53. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  54. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  55. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  56. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  57. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  58. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  59. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  60. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  61. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  62. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  63. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  64. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  65. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  66. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  67. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  68. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  69. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  70. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  71. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  72. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  73. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  74. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  111. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
@@ -0,0 +1,1504 @@
1
+ """
2
+ 2025.12.6
3
+ 2025.12.7
4
+ 4.57.1
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
30
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, BaseTrainer, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, warnings, Optional, PeftModel, is_peft_available, os, torch)
31
+
32
+
33
+ import os
34
+ from typing import *
35
+ from dataclasses import dataclass, field
36
+ from packaging.version import Version
37
+ import torch
38
+ import numpy as np
39
+ from contextlib import nullcontext
40
+ from torch.nn import functional as F
41
+ import inspect
42
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
43
+ from transformers.training_args import ParallelMode
44
+
45
+ # Wrap the trainer so padding is on the right and training mode is enabled
46
+ # Also patches W&B since multiple runs must use wandb.finish()
47
+ import functools
48
+ from types import MethodType
49
+ def prepare_for_training_mode(f):
50
+ @functools.wraps(f)
51
+ def wrapper(self, *args, **kwargs):
52
+ # Enable training mode
53
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
54
+ self.model.for_training()
55
+ output = f(self, *args, **kwargs)
56
+ # Return inference mode
57
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
58
+ self.model.for_inference()
59
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
60
+ try:
61
+ import wandb
62
+ wandb.finish()
63
+ except:
64
+ pass
65
+ return output
66
+ return wrapper
67
+ pass
68
+
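A minimal usage sketch of the decorator above (editor's illustration; `DummyModel` and `DummyTrainer` are hypothetical stand-ins, not part of the package):

from types import MethodType

class DummyModel:
    def for_training(self):  print("switched to training mode")
    def for_inference(self): print("switched back to inference mode")

class DummyTrainer:
    def __init__(self): self.model = DummyModel()
    def train(self): print("running the training loop"); return "done"

trainer = DummyTrainer()
# Rebind train() so it is wrapped by prepare_for_training_mode (defined above)
trainer.train = MethodType(prepare_for_training_mode(DummyTrainer.train), trainer)
trainer.train()  # training mode -> training loop -> inference mode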
69
+ torch_compile_options = {
70
+ "epilogue_fusion" : True,
71
+ "max_autotune" : False,
72
+ "shape_padding" : True,
73
+ "trace.enabled" : False,
74
+ "triton.cudagraphs" : False,
75
+ }
76
+
77
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
78
+ def chunked_selective_log_softmax(logits, index):
79
+ # Split into 4 chunks only
80
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
81
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
82
+ all_per_token_logps = []
83
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
84
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
85
+ chunk_logits = chunk_logits.to(torch.float32)
86
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
87
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
88
+ per_token_logps = selected_logits - logsumexp_values
89
+ all_per_token_logps.append(per_token_logps)
90
+ pass
91
+ all_per_token_logps = torch.concat(all_per_token_logps)
92
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
93
+ return all_per_token_logps
94
+
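For reference, each chunk above computes the standard selective log-softmax, i.e. the log-probability of the chosen token. A small editorial sketch (made-up shapes) checking that identity against F.log_softmax:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 5)            # (batch, seq_len, vocab)
index  = torch.randint(0, 5, (2, 3))     # chosen token ids

reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
manual    = (torch.gather(logits.float(), -1, index.unsqueeze(-1)).squeeze(-1)
             - torch.logsumexp(logits.float(), dim=-1))
assert torch.allclose(reference, manual, atol=1e-6)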
95
+ def calculate_pad_tokens_in_prompt(
96
+ input_ids: torch.Tensor,
97
+ logits_to_keep: int,
98
+ pad_token_id: int
99
+ ) -> torch.Tensor:
100
+ """
101
+ Given the prompt tensor, returns the number of left-padding tokens in each sequence, e.g. [pad, pad, pad, cat] -> 3 tokens.
102
+ """
103
+ if logits_to_keep >= input_ids.shape[1]:
104
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
105
+
106
+ prompt_section = input_ids[:, :-logits_to_keep]
107
+
108
+ padding_mask = (prompt_section == pad_token_id)
109
+
110
+ pad_token_counts = padding_mask.sum(dim=1)
111
+
112
+ return pad_token_counts
113
+
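A small editorial usage sketch for the helper above, with made-up token ids (pad_token_id = 0, logits_to_keep = 2 so the last two positions count as completion):

import torch

input_ids = torch.tensor([
    [0, 0, 7, 8, 9],   # prompt section [0, 0, 7] -> 2 left-pad tokens
    [0, 5, 6, 7, 8],   # prompt section [0, 5, 6] -> 1 left-pad token
])
print(calculate_pad_tokens_in_prompt(input_ids, logits_to_keep=2, pad_token_id=0))
# -> tensor([2, 1])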
114
+ def create_completion_attention_mask(
115
+ completion_input_ids: torch.Tensor,
116
+ left_pad_tokens_per_prompt: torch.Tensor,
117
+ max_left_pad: int,
118
+ pad_token_id: int
119
+ ) -> torch.Tensor:
120
+ """
121
+ Given a sequence such as [p, p, p, c, c, c, pad, pad, pad],
122
+
123
+ where p are extra prompt tokens left over from slicing the tensor, c are completion tokens,
124
+ and pad are pad tokens, build a completion mask that zeroes out the pad and p tokens:
125
+ [0, 0, 0, 1, 1, 1, 0, 0, 0] in this example.
126
+ """
127
+ batch_size, completion_len = completion_input_ids.shape
128
+ device = completion_input_ids.device
129
+
130
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
131
+
132
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
133
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
134
+
135
+ non_padding_mask = (completion_input_ids != pad_token_id)
136
+
137
+ final_mask = shift_mask & non_padding_mask
138
+
139
+ return final_mask
140
+
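An editorial sketch of the mask construction above, using made-up ids (pad_token_id = 0): the leftover prompt token at the front and the trailing pads are both zeroed out:

import torch

completion_ids = torch.tensor([[3, 4, 5, 0, 0]])   # p, c, c, pad, pad (first slot is a leftover prompt token)
left_pads      = torch.tensor([1])                 # this row's prompt had 1 left-pad token
mask = create_completion_attention_mask(
    completion_ids, left_pad_tokens_per_prompt=left_pads, max_left_pad=2, pad_token_id=0
)
print(mask.int())  # -> tensor([[0, 1, 1, 0, 0]])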
141
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
142
+ """
143
+ Moves all padding tokens in each sequence of a batch to the right.
144
+ """
145
+ mask = (tensor != pad_id)
146
+ # Must use stable=True since the binary mask is full of ties and their original order must be preserved
147
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
148
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
149
+ return packed_tensor
150
+
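A quick editorial sketch of the packing above, with made-up ids (pad_id = 0):

import torch

x = torch.tensor([
    [0, 0, 4, 5],
    [0, 9, 0, 7],
])
print(left_pack_padding(x, pad_id=0))
# -> tensor([[4, 5, 0, 0],
#            [9, 7, 0, 0]])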
151
+ def align_logprobs_with_mask(
152
+ logprob_tensor: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ pad_value: float = 0.0
155
+ ) -> torch.Tensor:
156
+ """
157
+ Aligns a log probability tensor with a given attention mask.
158
+ """
159
+
160
+ device = logprob_tensor.device
161
+ batch_size, logprob_seq_len = logprob_tensor.shape
162
+ mask_seq_len = attention_mask.shape[1]
163
+
164
+ padded_logprobs = torch.full(
165
+ attention_mask.shape,
166
+ fill_value=pad_value,
167
+ dtype=logprob_tensor.dtype,
168
+ device=device
169
+ )
170
+
171
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
172
+
173
+ cols = torch.arange(logprob_seq_len, device=device)
174
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
175
+
176
+ # Create destination row indices
177
+ # Shape: [batch_size, logprob_seq_len]
178
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
179
+
180
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
181
+ # Create a mask to identify only the indices that are within the bounds
182
+ # of the target tensor's sequence length.
183
+ valid_mask = dest_indices < mask_seq_len
184
+
185
+ # Use this mask to select only the valid row indices, column indices,
186
+ # and the corresponding values from the logprob tensor.
187
+ # This flattens the selected elements into 1D tensors.
188
+ valid_rows = row_indices[valid_mask]
189
+ valid_cols = dest_indices[valid_mask]
190
+ valid_vals = logprob_tensor[valid_mask]
191
+
192
+ # Place the valid values into their correct positions in the padded tensor
193
+ # using a single, efficient advanced indexing operation.
194
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
195
+
196
+ return padded_logprobs
197
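An editorial sketch of the alignment above with made-up numbers: the log-probs are shifted right by the number of leading masked positions in the attention mask:

import torch

logprobs = torch.tensor([[-0.5, -1.0, -2.0]])
mask     = torch.tensor([[0, 0, 1, 1, 1, 1]])   # two left-pad positions, four attended
print(align_logprobs_with_mask(logprobs, mask))
# -> [[0.0, 0.0, -0.5, -1.0, -2.0, 0.0]]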
+ @dataclass
198
+ class UnslothPPOConfig(PPOConfig):
199
+ """
200
+
201
+ Configuration class for the [`PPOTrainer`].
202
+
203
+ This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
204
+ please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
205
+ values in this class may differ from those in [`~transformers.TrainingArguments`].
206
+
207
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
208
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
209
+ command line.
210
+
211
+ Parameters:
212
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
213
+ Name of this experiment.
214
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
215
+ Path to the reward model.
216
+ model_adapter_name (`str`, *optional*):
217
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
218
+ ref_adapter_name (`str`, *optional*):
219
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
220
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
221
+ Number of epochs to train.
222
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
223
+ Whether to whiten the rewards.
224
+ kl_coef (`float`, *optional*, defaults to `0.05`):
225
+ KL coefficient.
226
+ kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
227
+ Which estimator for KL-Divergence to use from [Approximating KL
228
+ Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
229
+ estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
230
+ better estimator". Cannot be set to "k2", as it is used for logging purposes.
231
+ cliprange (`float`, *optional*, defaults to `0.2`):
232
+ Clip range.
233
+ vf_coef (`float`, *optional*, defaults to `0.1`):
234
+ Value function coefficient.
235
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
236
+ Clip range for the value function.
237
+ gamma (`float`, *optional*, defaults to `1.0`):
238
+ Discount factor.
239
+ lam (`float`, *optional*, defaults to `0.95`):
240
+ Lambda value for GAE.
241
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
242
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
243
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
244
+ capacity of a single GPU, albeit at the cost of slower generation.
245
+
246
+ """
247
+ vllm_sampling_params: Optional[Any] = field(
248
+ default = None,
249
+ metadata = {'help': 'vLLM SamplingParams'},
250
+ )
251
+ unsloth_num_chunks : Optional[int] = field(
252
+ default = -1,
253
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
254
+ )
255
+
256
+ def __init__(
257
+ self,
258
+ output_dir = None,
259
+ overwrite_output_dir = None,
260
+ do_train = False,
261
+ do_eval = False,
262
+ do_predict = False,
263
+ eval_strategy = 'no',
264
+ prediction_loss_only = False,
265
+ per_device_train_batch_size = 4,
266
+ per_device_eval_batch_size = 4,
267
+ per_gpu_train_batch_size = None,
268
+ per_gpu_eval_batch_size = None,
269
+ gradient_accumulation_steps = 2,
270
+ eval_accumulation_steps = 2,
271
+ eval_delay = 0,
272
+ torch_empty_cache_steps = 250,
273
+ learning_rate = 5e-05,
274
+ weight_decay = 0.01,
275
+ adam_beta1 = 0.9,
276
+ adam_beta2 = 0.999,
277
+ adam_epsilon = 1e-08,
278
+ max_grad_norm = 1.0,
279
+ num_train_epochs = 3.0,
280
+ max_steps = -1,
281
+ lr_scheduler_type = 'linear',
282
+ warmup_ratio = 0.1,
283
+ warmup_steps = 0,
284
+ log_level = 'passive',
285
+ log_level_replica = 'warning',
286
+ log_on_each_node = True,
287
+ logging_dir = None,
288
+ logging_strategy = 'steps',
289
+ logging_first_step = False,
290
+ logging_steps = 1,
291
+ logging_nan_inf_filter = False,
292
+ save_strategy = 'steps',
293
+ save_steps = 500,
294
+ save_total_limit = None,
295
+ save_safetensors = True,
296
+ save_on_each_node = False,
297
+ save_only_model = False,
298
+ restore_callback_states_from_checkpoint = False,
299
+ no_cuda = False,
300
+ use_cpu = False,
301
+ use_mps_device = False,
302
+ seed = 3407,
303
+ data_seed = 3407,
304
+ jit_mode_eval = False,
305
+ bf16 = False,
306
+ fp16 = False,
307
+ fp16_opt_level = 'O1',
308
+ half_precision_backend = 'auto',
309
+ bf16_full_eval = False,
310
+ fp16_full_eval = False,
311
+ tf32 = None,
312
+ local_rank = -1,
313
+ ddp_backend = None,
314
+ tpu_num_cores = None,
315
+ tpu_metrics_debug = False,
316
+ debug = '',
317
+ dataloader_drop_last = False,
318
+ eval_steps = None,
319
+ dataloader_num_workers = 0,
320
+ dataloader_prefetch_factor = None,
321
+ past_index = -1,
322
+ run_name = None,
323
+ disable_tqdm = None,
324
+ remove_unused_columns = True,
325
+ label_names = None,
326
+ load_best_model_at_end = False,
327
+ metric_for_best_model = None,
328
+ greater_is_better = None,
329
+ ignore_data_skip = False,
330
+ fsdp = None,
331
+ fsdp_min_num_params = 0,
332
+ fsdp_config = None,
333
+ fsdp_transformer_layer_cls_to_wrap = None,
334
+ accelerator_config = None,
335
+ parallelism_config = None,
336
+ deepspeed = None,
337
+ label_smoothing_factor = 0.0,
338
+ optim = 'adamw_8bit',
339
+ optim_args = None,
340
+ adafactor = False,
341
+ group_by_length = False,
342
+ length_column_name = 'length',
343
+ report_to = 'none',
344
+ project = 'huggingface',
345
+ trackio_space_id = 'trackio',
346
+ ddp_find_unused_parameters = None,
347
+ ddp_bucket_cap_mb = None,
348
+ ddp_broadcast_buffers = None,
349
+ dataloader_pin_memory = True,
350
+ dataloader_persistent_workers = False,
351
+ skip_memory_metrics = True,
352
+ use_legacy_prediction_loop = False,
353
+ push_to_hub = False,
354
+ resume_from_checkpoint = None,
355
+ hub_model_id = None,
356
+ hub_strategy = 'every_save',
357
+ hub_token = None,
358
+ hub_private_repo = None,
359
+ hub_always_push = False,
360
+ hub_revision = None,
361
+ gradient_checkpointing = True,
362
+ gradient_checkpointing_kwargs = None,
363
+ include_inputs_for_metrics = False,
364
+ eval_do_concat_batches = True,
365
+ fp16_backend = 'auto',
366
+ push_to_hub_model_id = None,
367
+ push_to_hub_organization = None,
368
+ push_to_hub_token = None,
369
+ mp_parameters = '',
370
+ auto_find_batch_size = False,
371
+ full_determinism = False,
372
+ torchdynamo = None,
373
+ ray_scope = 'last',
374
+ ddp_timeout = 1800,
375
+ torch_compile = False,
376
+ torch_compile_backend = None,
377
+ torch_compile_mode = None,
378
+ include_tokens_per_second = False,
379
+ include_num_input_tokens_seen = False,
380
+ neftune_noise_alpha = None,
381
+ optim_target_modules = None,
382
+ batch_eval_metrics = False,
383
+ eval_on_start = False,
384
+ use_liger_kernel = False,
385
+ liger_kernel_config = None,
386
+ eval_use_gather_object = False,
387
+ average_tokens_across_devices = True,
388
+ dataset_num_proc = None,
389
+ num_mini_batches = 1,
390
+ total_episodes = None,
391
+ local_rollout_forward_batch_size = 64,
392
+ num_sample_generations = 10,
393
+ response_length = 53,
394
+ stop_token = None,
395
+ stop_token_id = None,
396
+ temperature = 0.7,
397
+ missing_eos_penalty = None,
398
+ sft_model_path = 'EleutherAI/pythia-160m',
399
+ world_size = None,
400
+ num_total_batches = None,
401
+ micro_batch_size = None,
402
+ local_batch_size = None,
403
+ batch_size = None,
404
+ local_mini_batch_size = None,
405
+ mini_batch_size = None,
406
+ exp_name = 'ppo_config',
407
+ reward_model_path = 'EleutherAI/pythia-160m',
408
+ model_adapter_name = None,
409
+ ref_adapter_name = None,
410
+ num_ppo_epochs = 4,
411
+ whiten_rewards = False,
412
+ kl_coef = 0.05,
413
+ kl_estimator = 'k1',
414
+ cliprange = 0.2,
415
+ vf_coef = 0.1,
416
+ cliprange_value = 0.2,
417
+ gamma = 1.0,
418
+ lam = 0.95,
419
+ ds3_gather_for_generation = True,
420
+ vllm_sampling_params = None,
421
+ unsloth_num_chunks = -1,
422
+
423
+ **kwargs,
424
+ ):
425
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
426
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
427
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
428
+ output_dir = 'unsloth_training_checkpoints'
429
+ save_strategy = 'no'
430
+ if dataset_num_proc is None:
431
+ from multiprocessing import cpu_count
432
+ dataset_num_proc = min(max(cpu_count()+4, 2), 64)
433
+ if temperature <= 0:
434
+ raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
435
+ elif temperature >= 10:
436
+ raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
437
+
438
+
439
+ super().__init__(
440
+ output_dir = output_dir,
441
+ overwrite_output_dir = overwrite_output_dir,
442
+ do_train = do_train,
443
+ do_eval = do_eval,
444
+ do_predict = do_predict,
445
+ eval_strategy = eval_strategy,
446
+ prediction_loss_only = prediction_loss_only,
447
+ per_device_train_batch_size = per_device_train_batch_size,
448
+ per_device_eval_batch_size = per_device_eval_batch_size,
449
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
450
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
451
+ gradient_accumulation_steps = gradient_accumulation_steps,
452
+ eval_accumulation_steps = eval_accumulation_steps,
453
+ eval_delay = eval_delay,
454
+ torch_empty_cache_steps = torch_empty_cache_steps,
455
+ learning_rate = learning_rate,
456
+ weight_decay = weight_decay,
457
+ adam_beta1 = adam_beta1,
458
+ adam_beta2 = adam_beta2,
459
+ adam_epsilon = adam_epsilon,
460
+ max_grad_norm = max_grad_norm,
461
+ num_train_epochs = num_train_epochs,
462
+ max_steps = max_steps,
463
+ lr_scheduler_type = lr_scheduler_type,
464
+ warmup_ratio = warmup_ratio,
465
+ warmup_steps = warmup_steps,
466
+ log_level = log_level,
467
+ log_level_replica = log_level_replica,
468
+ log_on_each_node = log_on_each_node,
469
+ logging_dir = logging_dir,
470
+ logging_strategy = logging_strategy,
471
+ logging_first_step = logging_first_step,
472
+ logging_steps = logging_steps,
473
+ logging_nan_inf_filter = logging_nan_inf_filter,
474
+ save_strategy = save_strategy,
475
+ save_steps = save_steps,
476
+ save_total_limit = save_total_limit,
477
+ save_safetensors = save_safetensors,
478
+ save_on_each_node = save_on_each_node,
479
+ save_only_model = save_only_model,
480
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
481
+ no_cuda = no_cuda,
482
+ use_cpu = use_cpu,
483
+ use_mps_device = use_mps_device,
484
+ seed = seed,
485
+ data_seed = data_seed,
486
+ jit_mode_eval = jit_mode_eval,
487
+ bf16 = bf16,
488
+ fp16 = fp16,
489
+ fp16_opt_level = fp16_opt_level,
490
+ half_precision_backend = half_precision_backend,
491
+ bf16_full_eval = bf16_full_eval,
492
+ fp16_full_eval = fp16_full_eval,
493
+ tf32 = tf32,
494
+ local_rank = local_rank,
495
+ ddp_backend = ddp_backend,
496
+ tpu_num_cores = tpu_num_cores,
497
+ tpu_metrics_debug = tpu_metrics_debug,
498
+ debug = debug,
499
+ dataloader_drop_last = dataloader_drop_last,
500
+ eval_steps = eval_steps,
501
+ dataloader_num_workers = dataloader_num_workers,
502
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
503
+ past_index = past_index,
504
+ run_name = run_name,
505
+ disable_tqdm = disable_tqdm,
506
+ remove_unused_columns = remove_unused_columns,
507
+ label_names = label_names,
508
+ load_best_model_at_end = load_best_model_at_end,
509
+ metric_for_best_model = metric_for_best_model,
510
+ greater_is_better = greater_is_better,
511
+ ignore_data_skip = ignore_data_skip,
512
+ fsdp = fsdp,
513
+ fsdp_min_num_params = fsdp_min_num_params,
514
+ fsdp_config = fsdp_config,
515
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
516
+ accelerator_config = accelerator_config,
517
+ parallelism_config = parallelism_config,
518
+ deepspeed = deepspeed,
519
+ label_smoothing_factor = label_smoothing_factor,
520
+ optim = optim,
521
+ optim_args = optim_args,
522
+ adafactor = adafactor,
523
+ group_by_length = group_by_length,
524
+ length_column_name = length_column_name,
525
+ report_to = report_to,
526
+ project = project,
527
+ trackio_space_id = trackio_space_id,
528
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
529
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
530
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
531
+ dataloader_pin_memory = dataloader_pin_memory,
532
+ dataloader_persistent_workers = dataloader_persistent_workers,
533
+ skip_memory_metrics = skip_memory_metrics,
534
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
535
+ push_to_hub = push_to_hub,
536
+ resume_from_checkpoint = resume_from_checkpoint,
537
+ hub_model_id = hub_model_id,
538
+ hub_strategy = hub_strategy,
539
+ hub_token = hub_token,
540
+ hub_private_repo = hub_private_repo,
541
+ hub_always_push = hub_always_push,
542
+ hub_revision = hub_revision,
543
+ gradient_checkpointing = gradient_checkpointing,
544
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
545
+ include_inputs_for_metrics = include_inputs_for_metrics,
546
+ eval_do_concat_batches = eval_do_concat_batches,
547
+ fp16_backend = fp16_backend,
548
+ push_to_hub_model_id = push_to_hub_model_id,
549
+ push_to_hub_organization = push_to_hub_organization,
550
+ push_to_hub_token = push_to_hub_token,
551
+ mp_parameters = mp_parameters,
552
+ auto_find_batch_size = auto_find_batch_size,
553
+ full_determinism = full_determinism,
554
+ torchdynamo = torchdynamo,
555
+ ray_scope = ray_scope,
556
+ ddp_timeout = ddp_timeout,
557
+ torch_compile = torch_compile,
558
+ torch_compile_backend = torch_compile_backend,
559
+ torch_compile_mode = torch_compile_mode,
560
+ include_tokens_per_second = include_tokens_per_second,
561
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
562
+ neftune_noise_alpha = neftune_noise_alpha,
563
+ optim_target_modules = optim_target_modules,
564
+ batch_eval_metrics = batch_eval_metrics,
565
+ eval_on_start = eval_on_start,
566
+ use_liger_kernel = use_liger_kernel,
567
+ liger_kernel_config = liger_kernel_config,
568
+ eval_use_gather_object = eval_use_gather_object,
569
+ average_tokens_across_devices = average_tokens_across_devices,
570
+ dataset_num_proc = dataset_num_proc,
571
+ num_mini_batches = num_mini_batches,
572
+ total_episodes = total_episodes,
573
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
574
+ num_sample_generations = num_sample_generations,
575
+ response_length = response_length,
576
+ stop_token = stop_token,
577
+ stop_token_id = stop_token_id,
578
+ temperature = temperature,
579
+ missing_eos_penalty = missing_eos_penalty,
580
+ sft_model_path = sft_model_path,
581
+ world_size = world_size,
582
+ num_total_batches = num_total_batches,
583
+ micro_batch_size = micro_batch_size,
584
+ local_batch_size = local_batch_size,
585
+ batch_size = batch_size,
586
+ local_mini_batch_size = local_mini_batch_size,
587
+ mini_batch_size = mini_batch_size,
588
+ exp_name = exp_name,
589
+ reward_model_path = reward_model_path,
590
+ model_adapter_name = model_adapter_name,
591
+ ref_adapter_name = ref_adapter_name,
592
+ num_ppo_epochs = num_ppo_epochs,
593
+ whiten_rewards = whiten_rewards,
594
+ kl_coef = kl_coef,
595
+ kl_estimator = kl_estimator,
596
+ cliprange = cliprange,
597
+ vf_coef = vf_coef,
598
+ cliprange_value = cliprange_value,
599
+ gamma = gamma,
600
+ lam = lam,
601
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
602
+ self.vllm_sampling_params = vllm_sampling_params
603
+ self.unsloth_num_chunks = unsloth_num_chunks
604
+
605
+ pass
606
+
607
+ class _UnslothPPOTrainer(BaseTrainer):
608
+ """"""
609
+
610
+ _tag_names = ["trl", "ppo"]
611
+ _name = "PPO"
612
+ _paper = {
613
+ "title": "Fine-Tuning Language Models from Human Preferences",
614
+ "id": "1909.08593",
615
+ # docstyle-ignore
616
+ "citation": textwrap.dedent("""\
617
+ @article{mziegler2019fine-tuning,
618
+ title = {{Fine-Tuning Language Models from Human Preferences}},
619
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
620
+ year = 2019,
621
+ eprint = {arXiv:1909.08593}
622
+ }"""),
623
+ }
624
+
625
+ def __init__(
626
+ self,
627
+ args: PPOConfig,
628
+ processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
629
+ model: nn.Module,
630
+ ref_model: Optional[nn.Module],
631
+ reward_model: nn.Module,
632
+ train_dataset: Dataset,
633
+ value_model: nn.Module,
634
+ data_collator: Optional[DataCollatorWithPadding] = None,
635
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
636
+ # less commonly used
637
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
638
+ callbacks: Optional[list[TrainerCallback]] = None,
639
+ peft_config: Optional["PeftConfig"] = None,
640
+ ) -> None:
641
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
642
+ warnings.warn(
643
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
644
+ "it and want it to remain, please share your comments here: "
645
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
646
+ "TRL_EXPERIMENTAL_SILENCE=1."
647
+ )
648
+ if ref_model is model:
649
+ raise ValueError(
650
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
651
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
652
+ )
653
+
654
+ self.args = args
655
+ self.processing_class = processing_class
656
+ self.policy_model = model
657
+
658
+ # Define the collator if not provided
659
+ if data_collator is None:
660
+ data_collator = DataCollatorWithPadding(self.processing_class)
661
+
662
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
663
+ if args.stop_token and args.stop_token_id:
664
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
665
+ elif args.stop_token:
666
+ if args.stop_token == "eos":
667
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
668
+ else:
669
+ raise ValueError(
670
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
671
+ )
672
+ else:
673
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
674
+
675
+ # Check that the kl estimator is valid
676
+ if self.args.kl_estimator not in {"k1", "k3"}:
677
+ raise ValueError(
678
+ "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
679
+ "appears to be a strictly better estimator). See "
680
+ "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
681
+ )
682
+
683
+ # peft support
684
+ if not is_peft_available() and peft_config is not None:
685
+ raise ImportError(
686
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
687
+ )
688
+ elif is_peft_available() and peft_config is not None:
689
+ # if model is a peft model and we have a peft_config, we merge and unload it first
690
+ if isinstance(self.policy_model, PeftModel):
691
+ self.policy_model = self.policy_model.merge_and_unload()
692
+
693
+ # get peft model with the given config
694
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
695
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
696
+ peft_module_casting_to_bf16(self.policy_model)
697
+
698
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
699
+ self.model_adapter_name = args.model_adapter_name
700
+ self.ref_adapter_name = args.ref_adapter_name
701
+
702
+ if ref_model:
703
+ self.ref_model = ref_model
704
+ elif self.is_peft_model:
705
+ self.ref_model = None
706
+ else:
707
+ self.ref_model = create_reference_model(self.policy_model)
708
+
709
+ self.reward_model = reward_model
710
+ self.train_dataset = train_dataset
711
+ self.train_dataset_len = len(train_dataset)
712
+ self.value_model = value_model
713
+ self.data_collator = data_collator
714
+ self.eval_dataset = eval_dataset
715
+ self.optimizer, self.lr_scheduler = optimizers
716
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
717
+
718
+ #########
719
+ # calculate various batch sizes
720
+ #########
721
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
722
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
723
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
724
+ self.accelerator = accelerator
725
+ args.world_size = accelerator.num_processes
726
+ args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
727
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
728
+ args.batch_size = int(args.local_batch_size * args.world_size)
729
+ args.mini_batch_size = exact_div(
730
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
731
+ )
732
+ args.local_mini_batch_size = exact_div(
733
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
734
+ )
735
+ if args.whiten_rewards:
736
+ assert args.local_mini_batch_size >= 8, (
737
+ f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
738
+ )
739
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
740
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
741
+ args.num_total_batches = math.ceil(
742
+ args.total_episodes / args.batch_size
743
+ ) # we may train for more than `total_episodes`
744
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
745
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
746
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
747
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
748
+ if args.num_sample_generations > 0:
749
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
750
+ self.local_dataloader_batch_size = args.local_batch_size
751
+
752
+ #########
753
+ # setup model, optimizer, and others
754
+ #########
755
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
756
+ if module is not None:
757
+ disable_dropout_in_model(module)
758
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
759
+ self.model.config = self.policy_model.config # needed for pushing to hub
760
+ self.create_optimizer_and_scheduler(
761
+ num_training_steps=args.num_total_batches
762
+ ) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level
763
+
764
+ #########
765
+ # trainer specifics
766
+ #########
767
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
768
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
769
+ self.callback_handler = CallbackHandler(
770
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
771
+ )
772
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
773
+ self.control = TrainerControl()
774
+ self.state = OnlineTrainerState(
775
+ is_local_process_zero=self.is_local_process_zero(),
776
+ is_world_process_zero=self.is_world_process_zero(),
777
+ stateful_callbacks=[
778
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
779
+ ],
780
+ )
781
+ self.current_flos = 0
782
+ self.hp_search_backend = None
783
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
784
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
785
+ # Create distant repo and output directory if needed
786
+ self.hub_model_id = None
787
+ if self.args.push_to_hub:
788
+ self.init_hf_repo()
789
+ if self.args.should_save:
790
+ os.makedirs(self.args.output_dir, exist_ok=True)
791
+
792
+ # Add tags for models that have been loaded with the correct transformers version
793
+ if hasattr(self.model, "add_model_tags"):
794
+ self.model.add_model_tags(self._tag_names)
795
+
796
+ #########
797
+ # setup dataloader
798
+ #########
799
+ self.dataloader = DataLoader(
800
+ self.train_dataset,
801
+ batch_size=self.local_dataloader_batch_size,
802
+ shuffle=True,
803
+ collate_fn=self.data_collator,
804
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
805
+ )
806
+ # sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
807
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
808
+ torch.manual_seed(args.seed)
809
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
810
+ torch.manual_seed(self.local_seed) # reset the local seed again
811
+
812
+ self.eval_dataloader = DataLoader(
813
+ self.eval_dataset,
814
+ batch_size=args.per_device_eval_batch_size,
815
+ collate_fn=self.data_collator,
816
+ drop_last=True,
817
+ ) # no need to shuffle eval dataset
818
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
819
+
820
+ if self.is_deepspeed_enabled:
821
+ self.reward_model = prepare_deepspeed(
822
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
823
+ )
824
+
825
+ if self.ref_model is None:
826
+ if not self.is_peft_model:
827
+ raise ValueError("No reference model and model is not a Peft model.")
828
+ else:
829
+ self.ref_model = prepare_deepspeed(
830
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
831
+ )
832
+ else:
833
+ if self.ref_model is None:
834
+ if not self.is_peft_model:
835
+ raise ValueError("No reference model and model is not a Peft model.")
836
+ else:
837
+ self.ref_model = self.ref_model.to(self.accelerator.device)
838
+ self.reward_model = self.reward_model.to(self.accelerator.device)
839
+
840
+ def get_train_dataloader(self) -> DataLoader:
841
+ return self.dataloader
842
+
843
+ def get_eval_dataloader(self) -> DataLoader:
844
+ return self.eval_dataloader
845
+
846
+ @contextmanager
847
+ def null_ref_context(self):
848
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
849
+ with (
850
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
851
+ if self.is_peft_model and not self.ref_adapter_name
852
+ else nullcontext()
853
+ ):
854
+ if self.ref_adapter_name:
855
+ self.model.policy.set_adapter(self.ref_adapter_name)
856
+ yield
857
+ if self.ref_adapter_name:
858
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
859
+
860
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
861
+ backup_model = self.model
862
+ self.model = self.model.policy # save only the policy
863
+
864
+ if self.is_deepspeed_enabled:
865
+ backup_deepspeed = self.deepspeed
866
+ self.deepspeed = self.model
867
+
868
+ super().save_model(output_dir, _internal_call)
869
+
870
+ self.model = backup_model
871
+
872
+ if self.is_deepspeed_enabled:
873
+ self.deepspeed = backup_deepspeed
874
+
875
+ def train(self):
876
+ args = self.args
877
+ accelerator = self.accelerator
878
+ optimizer = self.optimizer
879
+ model = self.model
880
+ ref_policy = self.ref_model
881
+ reward_model = self.reward_model
882
+ processing_class = self.processing_class
883
+ dataloader = self.dataloader
884
+ device = accelerator.device
885
+
886
+ def repeat_generator():
887
+ while True:
888
+ yield from dataloader
889
+
890
+ iter_dataloader = iter(repeat_generator())
891
+ generation_config = GenerationConfig(
892
+ max_new_tokens=args.response_length,
893
+ temperature=(args.temperature + 1e-7),
894
+ top_k=0.0,
895
+ top_p=1.0,
896
+ do_sample=True,
897
+ )
898
+
899
+ accelerator.print("===training policy===")
900
+ start_time = time.time()
901
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
902
+ approxkl_stats = torch.zeros(stats_shape, device=device)
903
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
904
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
905
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
906
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
907
+ entropy_stats = torch.zeros(stats_shape, device=device)
908
+ ratio_stats = torch.zeros(stats_shape, device=device)
909
+ model.train()
910
+
911
+ # trainer state initialization
912
+ self.state.global_step = 0
913
+ self.state.episode = 0
914
+ self.state.max_steps = args.num_total_batches
915
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
916
+ # Compute absolute values for logging, eval, and save if given as ratio
917
+ if args.logging_steps is not None:
918
+ if args.logging_steps < 1:
919
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
920
+ else:
921
+ self.state.logging_steps = args.logging_steps
922
+ if args.eval_steps is not None:
923
+ if args.eval_steps < 1:
924
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
925
+ else:
926
+ self.state.eval_steps = args.eval_steps
927
+ if args.save_steps is not None:
928
+ if args.save_steps < 1:
929
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
930
+ else:
931
+ self.state.save_steps = args.save_steps
932
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
933
+
934
+ # backward compatibility
935
+ if self.is_deepspeed_enabled:
936
+ self.deepspeed = self.model
937
+ self.model_wrapped = self.model
938
+
939
+ for update in range(1, args.num_total_batches + 1):
940
+ self.state.episode += 1 * args.batch_size
941
+ data = next(iter_dataloader)
942
+ with torch.no_grad():
943
+ queries = data["input_ids"].to(device)
944
+ context_length = queries.shape[1]
945
+ responses = []
946
+ postprocessed_responses = []
947
+ logprobs = []
948
+ ref_logprobs = []
949
+ scores = []
950
+ sequence_lengths = []
951
+ values = []
952
+ with unwrap_model_for_generation(
953
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
954
+ ) as unwrapped_model:
955
+ query_responses, logitss = batch_generation(
956
+ unwrapped_model.policy,
957
+ queries,
958
+ args.local_rollout_forward_batch_size,
959
+ processing_class.pad_token_id,
960
+ generation_config,
961
+ )
962
+
963
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
964
+ query = queries[i : i + args.local_rollout_forward_batch_size]
965
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
966
+ response = query_response[:, context_length:]
967
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
968
+ logprob = selective_log_softmax(logits, response)
969
+ del logits
970
+ empty_cache()
971
+
972
+ if ref_policy is None:
973
+ with self.null_ref_context():
974
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
975
+ else:
976
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
977
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
978
+ ref_logits /= args.temperature + 1e-7
979
+ ref_logprob = selective_log_softmax(ref_logits, response)
980
+ del ref_output, ref_logits
981
+ empty_cache()
982
+
983
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
984
+ postprocessed_response = response
985
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
986
+ postprocessed_response = truncate_response(
987
+ self.stop_token_id, processing_class.pad_token_id, response
988
+ )
989
+
990
+ # Response Processing 2. run reward model on the truncated responses
991
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
992
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
993
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
994
+ full_value, _, _ = get_reward(
995
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
996
+ )
997
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
998
+ _, score, _ = get_reward(
999
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
1000
+ )
1001
+
1002
+ responses.append(response)
1003
+ postprocessed_responses.append(postprocessed_response)
1004
+ logprobs.append(logprob)
1005
+ ref_logprobs.append(ref_logprob)
1006
+ sequence_lengths.append(sequence_length)
1007
+ scores.append(score)
1008
+ values.append(value)
1009
+ responses = torch.cat(responses, 0)
1010
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
1011
+ logprobs = torch.cat(logprobs, 0)
1012
+ ref_logprobs = torch.cat(ref_logprobs, 0)
1013
+ sequence_lengths = torch.cat(sequence_lengths, 0)
1014
+ scores = torch.cat(scores, 0)
1015
+ values = torch.cat(values, 0)
1016
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
1017
+ empty_cache()
1018
+ gc.collect()
1019
+
1020
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
1021
+ # Completions not passing that filter will receive a lower score.
1022
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
1023
+ if self.args.missing_eos_penalty is not None:
1024
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
1025
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
1026
+
1027
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
1028
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
1029
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
1030
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
1031
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
1032
+ sequence_lengths_p1 = sequence_lengths + 1
1033
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
1034
+ values = torch.masked_fill(values, padding_mask_p1, 0)
1035
+
1036
+ # 4. compute rewards
1037
+ # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
1038
+ logr = ref_logprobs - logprobs
1039
+ kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3
1040
+ non_score_reward = -args.kl_coef * kl
1041
+ rewards = non_score_reward.clone()
1042
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
1043
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
1044
+ rewards[[actual_start, actual_end]] += scores
1045
+
1046
+ # 5. whiten rewards
1047
+ if args.whiten_rewards:
1048
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
1049
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
1050
+
1051
+ # 6. compute advantages and returns
1052
+ lastgaelam = 0
1053
+ advantages_reversed = []
1054
+ gen_length = responses.shape[1]
1055
+ for t in reversed(range(gen_length)):
1056
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
1057
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
1058
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
1059
+ advantages_reversed.append(lastgaelam)
1060
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
1061
+ returns = advantages + values
1062
+ advantages = masked_whiten(advantages, ~padding_mask)
1063
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
1064
+ empty_cache()
1065
+
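Editor's aside on the block above: it is the standard GAE recursion. A toy sketch with made-up rewards and values, mirroring `lastgaelam = delta + gamma * lam * lastgaelam`:

import torch

gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.0]])   # toy per-token rewards
values  = torch.tensor([[0.2, 0.4, 0.6]])   # toy value estimates

lastgaelam, advantages_reversed = 0.0, []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values   # shape (1, 3)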
1066
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
1067
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
1068
+ b_inds = np.random.permutation(args.local_batch_size)
1069
+ minibatch_idx = 0
1070
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
1071
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
1072
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
1073
+ gradient_accumulation_idx = 0
1074
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
1075
+ with accelerator.accumulate(model):
1076
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
1077
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
1078
+ mb_advantage = advantages[micro_batch_inds]
1079
+ mb_responses = responses[micro_batch_inds]
1080
+ mb_query_responses = query_responses[micro_batch_inds]
1081
+ mb_logprobs = logprobs[micro_batch_inds]
1082
+ mb_return = returns[micro_batch_inds]
1083
+ mb_values = values[micro_batch_inds]
1084
+
1085
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
1086
+ logits = output.logits[:, context_length - 1 : -1]
1087
+ logits /= args.temperature + 1e-7
1088
+ new_logprobs = selective_log_softmax(logits, mb_responses)
1089
+ new_logprobs = torch.masked_fill(
1090
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
1091
+ )
1092
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
1093
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
1094
+ vpredclipped = torch.clamp(
1095
+ vpred,
1096
+ mb_values - args.cliprange_value,
1097
+ mb_values + args.cliprange_value,
1098
+ )
1099
+ vf_losses1 = torch.square(vpred - mb_return)
1100
+ vf_losses2 = torch.square(vpredclipped - mb_return)
1101
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
1102
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
1103
+ vf_clipfrac = masked_mean(
1104
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
1105
+ )
1106
+ logprobs_diff = new_logprobs - mb_logprobs
1107
+ ratio = torch.exp(logprobs_diff)
1108
+ pg_losses = -mb_advantage * ratio
1109
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
1110
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
1111
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
1112
+ loss = pg_loss + args.vf_coef * vf_loss
1113
+ accelerator.backward(loss)
1114
+ optimizer.step()
1115
+ optimizer.zero_grad()
1116
+ with torch.no_grad():
1117
+ pg_clipfrac = masked_mean(
1118
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
1119
+ )
1120
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype)
1121
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
1122
+ approxkl = 0.5 * (logprobs_diff**2).mean()
1123
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
1124
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
1125
+ pg_clipfrac
1126
+ )
1127
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
1128
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
1129
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
1130
+ vf_clipfrac
1131
+ )
1132
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
1133
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
1134
+ gradient_accumulation_idx += 1
1135
+ minibatch_idx += 1
1136
+ # del everything and empty cache
1137
+ # fmt: off
1138
+ del (
1139
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
1140
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
1141
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
1142
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
1143
+ )
1144
+ # fmt: on
1145
+ empty_cache()
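The micro-batch step above combines PPO's clipped surrogate policy loss with a clipped value loss; the trainer averages both with masked_mean so padded positions are ignored. A condensed sketch of the two objectives using plain means (illustrative only, default hyperparameter values are assumptions):

import torch

def ppo_loss_sketch(new_logprobs, old_logprobs, advantages, vpred, old_values, returns,
                    cliprange: float = 0.2, cliprange_value: float = 0.2, vf_coef: float = 0.1):
    # Clipped surrogate objective: take the worse of the unclipped and clipped terms.
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_loss = torch.max(-advantages * ratio,
                        -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)).mean()
    # Clipped value loss: predictions are kept within cliprange_value of the old values,
    # and the worse of the two squared errors is penalized.
    vpred_clipped = torch.clamp(vpred, old_values - cliprange_value, old_values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2, (vpred_clipped - returns) ** 2).mean()
    return pg_loss + vf_coef * vf_loss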
+ with torch.no_grad():
+ mean_kl = kl.sum(1).mean()
+ mean_entropy = (-logprobs).sum(1).mean()
+ mean_non_score_reward = non_score_reward.sum(1).mean()
+ rlhf_reward = mean_non_score_reward + scores.mean()
+ eps = int(self.state.episode / (time.time() - start_time))
+ metrics = {}
+ metrics["eps"] = eps
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+ metrics["objective/non_score_reward"] = (
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+ )
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
+ metrics["episode"] = self.state.episode
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
+ self.state.global_step += 1
+ self.log(metrics)
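Each metric above is a per-process scalar that is gathered across ranks before averaging, so the logged number is a global statistic rather than a rank-local one. The pattern in isolation (the tensor value is a placeholder):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
local_stat = torch.tensor(0.5, device=accelerator.device)  # placeholder per-process value
global_avg = accelerator.gather_for_metrics(local_stat).mean().item()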
+
+ self.lr_scheduler.step()
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
+ empty_cache()
+ gc.collect()
+
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
+ self.generate_completions(sampling=True)
+ empty_cache()
+ del (
+ query_responses,
+ responses,
+ postprocessed_responses,
+ logprobs,
+ ref_logprobs,
+ values,
+ sequence_lengths,
+ contain_eos_token,
+ sequence_lengths_p1,
+ response_idxs,
+ padding_mask,
+ padding_mask_p1,
+ rewards,
+ actual_start,
+ actual_end,
+ advantages,
+ returns,
+ )
+ empty_cache()
+
+ # HF trainer specifics
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def generate_completions(self, sampling: bool = False):
+ args = self.args
+ processing_class = self.processing_class
+ generation_config = GenerationConfig(
+ max_new_tokens=self.args.response_length,
+ temperature=(0.01 + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+
+ table = defaultdict(list)
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ for batch in self.eval_dataloader:
+ query = batch["input_ids"]
+ with torch.no_grad():
+ context_length = query.shape[1]
+ query_response, _ = batch_generation(
+ unwrapped_model.policy,
+ query,
+ query.shape[0],
+ processing_class.pad_token_id,
+ generation_config,
+ )
+ response = query_response[:, context_length:]
+ postprocessed_response = response
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
+ postprocessed_response = truncate_response(
+ self.stop_token_id, processing_class.pad_token_id, response
+ )
+ table["query"].extend(
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
+ )
+ table["model response"].extend(
+ gather_object(processing_class.batch_decode(postprocessed_response))
+ )
+
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+ _, score, _ = get_reward(
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+ )
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
+
+ if sampling:
+ break
+ df = pd.DataFrame(table)
+
+ if self.accelerator.is_main_process:
+ if is_rich_available():
+ print_rich_table(df.iloc[0 : 0 + 5])
+ if "wandb" in args.report_to:
+ import wandb
+
+ if wandb.run is not None:
+ wandb.log({"completions": wandb.Table(dataframe=df)})
+
+ if "comet_ml" in args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
+ # Ensure the model card is saved along with the checkpoint
+ def _save_checkpoint(self, model, trial):
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+ self.create_model_card(model_name=model_name)
+ super()._save_checkpoint(model, trial)
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
+ """
+ Trainer for Proximal Policy Optimization (PPO).
+
+ For details on PPO, see the paper: [Proximal Policy Optimization
+ Algorithms](https://huggingface.co/papers/1707.06347).
+
+ Args:
+ args ([`PPOConfig`]):
+ Training arguments.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
+ Class to process the data.
+ model (`torch.nn.Module`):
+ Model to be trained. This is the policy model.
+ ref_model (`torch.nn.Module`, *optional*):
+ Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
+ reward_model (`torch.nn.Module`):
+ Reward model used to compute the rewards.
+ train_dataset ([`~datasets.Dataset`]):
+ Dataset for training.
+ value_model (`torch.nn.Module`):
+ Value model used to predict the value of a state.
+ data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
+ Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
+ using the `processing_class`.
+ eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
+ Dataset for evaluation.
+ optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
+ Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
+ optimizer and the learning rate scheduler are created using the
+ [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
+ callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
+ Callbacks to use during training.
+ peft_config ([`~peft.PeftConfig`], *optional*):
+ PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
+ will be wrapped with the specified PEFT adapter.
+
+ """
+ def __init__(
+ self,
+ args,
+ processing_class,
+ model,
+ ref_model,
+ reward_model,
+ train_dataset,
+ value_model,
+ data_collator = None,
+ eval_dataset = None,
+ callbacks = None,
+ peft_config = None,
+ **kwargs
+ ):
+ if args is None: args = UnslothPPOConfig()
+ use_bf16 = getattr(args, 'bf16', False)
+ if type(use_bf16) is not bool: use_bf16 = False
+ use_fp16 = getattr(args, 'fp16', False)
+ if type(use_fp16) is not bool: use_fp16 = False
+ force_float32 = False
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
+ force_float32 = True
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
+ from unsloth_zoo.utils import _get_dtype
+ dtype = _get_dtype(dtype)
+ float16 = dtype == torch.float16
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+ if force_float32:
+ # Forced float32 training
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+ # args.mixed_precision is a new argument which needs to be set now
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+ # Mixed precision training
+ args.fp16 = float16
+ args.bf16 = not float16
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
+ # args.mixed_precision is a new argument which needs to be set now
+ elif mixed_precision_dtype == 'bfloat16':
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
+ # args.mixed_precision is a new argument which needs to be set now
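A condensed restatement of the precision selection above as a standalone helper (illustrative only; the real code additionally lets explicit user-set fp16/bf16 flags fall through unchanged):

def pick_precision_sketch(model_is_float16: bool, force_float32: bool, unsloth_mixed_precision: str = 'float32'):
    # Returns (fp16, bf16, ACCELERATE_MIXED_PRECISION) for the regimes handled above.
    if force_float32:
        return False, False, 'no'    # forced float32 training
    if unsloth_mixed_precision == 'bfloat16':
        return False, False, 'no'    # full bf16 finetuning, no autocast
    if model_is_float16:
        return True, False, 'fp16'   # fp16 weights -> fp16 mixed precision
    return False, True, 'bf16'       # bf16 weights -> bf16 mixed precision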
+
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+ args.eval_strategy = 'steps'
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+ if ga_steps is not None and ga_steps > 1:
+ from transformers import __version__ as transformers_version
+ if Version(transformers_version) <= Version('4.45.2'):
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+ if getattr(args, 'eval_strategy', 'no') != 'no':
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+ if force_float32:
+ args.bf16_full_eval = False
+ args.fp16_full_eval = False
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+ args.bf16_full_eval = True
+ args.fp16_full_eval = False
+ elif not bf16_full_eval and not fp16_full_eval:
+ args.bf16_full_eval = args.bf16
+ args.fp16_full_eval = args.fp16
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+ data_collator = TransformersDataCollatorForLanguageModeling(
+ __tokenizer,
+ mlm = False,
+ mlm_probability = 0.0,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+ data_collator = DataCollatorForSeq2Seq(
+ __tokenizer,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ else:
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
+ data_collator = DataCollatorForSeq2Seq(
+ __tokenizer.tokenizer,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ else:
+ data_collator = TransformersDataCollatorForLanguageModeling(
+ __tokenizer.tokenizer,
+ mlm = False,
+ mlm_probability = 0.0,
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
+ )
+ other_metrics = []
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('ppo_trainer', other_metrics)
+
+ # [TODO] Fix up DataParallel multiplying batch sizes
+ # [TODO] DDP works, but DP seems to not work? [TODO]
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
+ if getattr(args, "_n_gpu", 1) != 1:
+ args._n_gpu = 1
+ if "model" in locals() and hasattr(model, "for_training"):
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
+ super().__init__(
+ args = args,
+ processing_class = processing_class,
+ model = model,
+ ref_model = ref_model,
+ reward_model = reward_model,
+ train_dataset = train_dataset,
+ value_model = value_model,
+ data_collator = data_collator,
+ eval_dataset = eval_dataset,
+ callbacks = callbacks,
+ peft_config = peft_config,**kwargs)
+ if "model" in locals() and hasattr(model, "for_inference"):
+ model.for_inference()
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+ if hasattr(self, 'accelerator'):
+ scaler = self.accelerator.scaler
+ current_model = model
+ while hasattr(current_model, 'model'):
+ current_model.accelerator_scaler = scaler
+ current_model = current_model.model
+ current_model.accelerator_scaler = scaler
+ pass
+ if hasattr(self, 'train'):
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
+ pass
+
+ pass
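For orientation, a hypothetical end-to-end construction of this trainer, mirroring the arguments documented in the class docstring; the checkpoint names, dummy dataset, and config values below are placeholders, not part of the package:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-org/policy-model")  # placeholder checkpoint
policy = AutoModelForCausalLM.from_pretrained("my-org/policy-model")
value_model = AutoModelForSequenceClassification.from_pretrained("my-org/value-model", num_labels=1)
reward_model = AutoModelForSequenceClassification.from_pretrained("my-org/reward-model", num_labels=1)

# Tiny dummy prompt dataset; PPO training expects pre-tokenized "input_ids".
train_dataset = Dataset.from_dict({"input_ids": [[1, 2, 3, 4]] * 8})

trainer = UnslothPPOTrainer(
    args=UnslothPPOConfig(output_dir="ppo-out", per_device_train_batch_size=1),
    processing_class=tokenizer,
    model=policy,
    ref_model=None,               # a copy of the policy is created when None
    reward_model=reward_model,
    train_dataset=train_dataset,
    value_model=value_model,
)
trainer.train()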