cciwon-code-review-cli 2.0.2 → 2.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/lib/api-client.js +1 -1
  2. package/lib/chat-mode.js +7 -2
  3. package/package.json +1 -1
  4. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  39. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  40. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  41. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  42. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  44. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  45. package/unsloth_compiled_cache/Conv1d.py +70 -0
  46. package/unsloth_compiled_cache/Conv2d.py +70 -0
  47. package/unsloth_compiled_cache/Conv3d.py +70 -0
  48. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  49. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  50. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  51. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  52. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  53. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  54. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  55. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  56. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  57. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  58. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  59. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  60. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  61. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  62. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  63. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  64. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  65. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  66. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  67. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  68. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  69. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  70. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  71. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  72. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  73. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  74. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  111. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
@@ -0,0 +1,2313 @@
1
+ """
2
+ 2025.12.6
3
+ 2025.12.7
4
+ 4.57.1
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
30
+ from trl.trainer.online_dpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BasePairwiseJudge, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FSDP, GenerationConfig, GuidedDecodingParams, IterableDataset, LLM, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, SIMPLE_CHAT_TEMPLATE, SamplingParams, Trainer, TrainerCallback, Union, VLLMClient, apply_chat_template, broadcast_object_list, create_reference_model, disable_dropout_in_model, empty_cache, ensure_master_addr_port, gather_object, is_conversational, is_flash_attn_2_available, is_peft_model, is_vllm_available, jinja2, logger, logging, maybe_apply_chat_template, nn, nullcontext, os, pad, prepare_deepspeed, prepare_fsdp, prepare_peft_model, profiling_context, re, seed_worker, textwrap, torch, truncate_right, unwrap_model_for_generation, version, warnings, wraps, F, LLM, apply_chat_template, is_conversational, os, re, F, FSDP, LLM, is_peft_model, nn, nullcontext, os, re, version, F, Optional, PreTrainedModel, Trainer, logger, os, re, torch, F, FSDP, LLM, nn, os, re, F, FSDP, nn, re, torch)
31
+
32
+
33
+ import os
34
+ from typing import *
35
+ from dataclasses import dataclass, field
36
+ from packaging.version import Version
37
+ import torch
38
+ import numpy as np
39
+ from contextlib import nullcontext
40
+ from torch.nn import functional as F
41
+ import inspect
42
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
43
+ from transformers.training_args import ParallelMode
44
+
45
+ # Wrap trainer with padding to right and enable training mode
46
+ # Also patches W&B since multiple runs must use wandb.finish()
47
+ import functools
48
+ from types import MethodType
49
+ def prepare_for_training_mode(f):
50
+ @functools.wraps(f)
51
+ def wrapper(self, *args, **kwargs):
52
+ # Enable training mode
53
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
54
+ self.model.for_training()
55
+ output = f(self, *args, **kwargs)
56
+ # Return inference mode
57
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
58
+ self.model.for_inference()
59
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
60
+ try:
61
+ import wandb
62
+ wandb.finish()
63
+ except:
64
+ pass
65
+ return output
66
+ return wrapper
67
+ pass
68
+
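As an aside, a minimal sketch (not part of the package) of how a wrapper like prepare_for_training_mode is typically bound onto a trainer instance; the _ToyModel/_ToyTrainer names are made up for illustration, and the wandb.finish() step is omitted:

import functools
from types import MethodType

def prepare_for_training_mode(f):
    # Simplified copy of the wrapper above (W&B patching omitted).
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        if hasattr(self, "model") and hasattr(self.model, "for_training"):
            self.model.for_training()
        output = f(self, *args, **kwargs)
        if hasattr(self, "model") and hasattr(self.model, "for_inference"):
            self.model.for_inference()
        return output
    return wrapper

class _ToyModel:
    def __init__(self):
        self.mode = "inference"
    def for_training(self):
        self.mode = "training"
    def for_inference(self):
        self.mode = "inference"

class _ToyTrainer:
    def __init__(self):
        self.model = _ToyModel()
    def train(self):
        return f"ran train() while model.mode={self.model.mode}"

trainer = _ToyTrainer()
# Bind the wrapped train() onto the instance, the way the generated cache wraps HF trainers.
trainer.train = MethodType(prepare_for_training_mode(_ToyTrainer.train), trainer)
print(trainer.train())       # ran train() while model.mode=training
print(trainer.model.mode)    # inference (restored after the call)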
69
+ torch_compile_options = {
70
+ "epilogue_fusion" : True,
71
+ "max_autotune" : False,
72
+ "shape_padding" : True,
73
+ "trace.enabled" : False,
74
+ "triton.cudagraphs" : False,
75
+ }
76
+
77
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
78
+ def chunked_selective_log_softmax(logits, index):
79
+ # Split into 4 chunks only
80
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
81
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
82
+ all_per_token_logps = []
83
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
84
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
85
+ chunk_logits = chunk_logits.to(torch.float32)
86
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
87
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
88
+ per_token_logps = selected_logits - logsumexp_values
89
+ all_per_token_logps.append(per_token_logps)
90
+ pass
91
+ all_per_token_logps = torch.concat(all_per_token_logps)
92
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
93
+ return all_per_token_logps
94
+
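For reference, an illustrative (non-package) check of what the chunked routine computes: the same selected-token log-probabilities as a plain log_softmax followed by gather, just processed in four chunks to lower peak memory.

import torch
import torch.nn.functional as F

def selective_log_softmax_direct(logits, index):
    # log-probability of the selected token under the softmax over the vocab dimension
    logps = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

batch, seq, vocab = 2, 8, 32
logits = torch.randn(batch, seq, vocab)
index = torch.randint(0, vocab, (batch, seq))

direct = selective_log_softmax_direct(logits, index)
# chunked_selective_log_softmax(logits, index) above should agree with `direct`
# up to float32 rounding while materialising only a quarter of the flattened logits at a time.
print(direct.shape)  # torch.Size([2, 8])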
95
+ def calculate_pad_tokens_in_prompt(
96
+ input_ids: torch.Tensor,
97
+ logits_to_keep: int,
98
+ pad_token_id: int
99
+ ) -> torch.Tensor:
100
+ """
101
+ Given a prompt tensor, return the number of left-padding tokens in each sequence, e.g. [pad, pad, pad, cat] = 3 tokens.
102
+ """
103
+ if logits_to_keep >= input_ids.shape[1]:
104
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
105
+
106
+ prompt_section = input_ids[:, :-logits_to_keep]
107
+
108
+ padding_mask = (prompt_section == pad_token_id)
109
+
110
+ pad_token_counts = padding_mask.sum(dim=1)
111
+
112
+ return pad_token_counts
113
+
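A small illustrative example (made-up pad id and token ids, not part of the package) of the count described in the docstring above:

import torch

pad_id = 0
# Two sequences of length 6; with logits_to_keep=2 only the first 4 columns
# count as the prompt section, so left padding is counted there.
input_ids = torch.tensor([
    [pad_id, pad_id, pad_id, 11, 12, 13],   # 3 left-pad tokens in the prompt
    [pad_id, 21, 22, 23, 24, 25],           # 1 left-pad token in the prompt
])
logits_to_keep = 2
prompt_section = input_ids[:, :-logits_to_keep]
print((prompt_section == pad_id).sum(dim=1))  # tensor([3, 1])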
114
+ def create_completion_attention_mask(
115
+ completion_input_ids: torch.Tensor,
116
+ left_pad_tokens_per_prompt: torch.Tensor,
117
+ max_left_pad: int,
118
+ pad_token_id: int
119
+ ) -> torch.Tensor:
120
+ """
121
+ Given that we have a sequence [p,p,p,c,c,c,pad,pad,pad],
122
+
123
+ where p are extra prompt tokens picked up when slicing the tensor, c are completion tokens,
124
+ and pad are padding tokens, this function builds a completion mask that zeroes out the pad
125
+ and p tokens, so in this example the mask is [0,0,0,1,1,1,0,0,0].
126
+ """
127
+ batch_size, completion_len = completion_input_ids.shape
128
+ device = completion_input_ids.device
129
+
130
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
131
+
132
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
133
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
134
+
135
+ non_padding_mask = (completion_input_ids != pad_token_id)
136
+
137
+ final_mask = shift_mask & non_padding_mask
138
+
139
+ return final_mask
140
+
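An illustrative walk-through (hypothetical token ids) of the docstring example, where the [p,p,p,c,c,c,pad,pad,pad] slice becomes the mask [0,0,0,1,1,1,0,0,0]:

import torch

pad_id = 0
# One completion slice: 3 leftover prompt tokens (7), 3 completion tokens, 3 pads.
completion_input_ids = torch.tensor([[7, 7, 7, 101, 102, 103, pad_id, pad_id, pad_id]])
left_pad_tokens_per_prompt = torch.tensor([0])  # this prompt had no left padding
max_left_pad = 3                                # widest left padding in the batch

num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
indices = torch.arange(completion_input_ids.shape[1]).unsqueeze(0)
shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)    # hides the 3 leftover prompt tokens
non_padding_mask = completion_input_ids != pad_id          # hides the trailing pads
print((shift_mask & non_padding_mask).int())               # tensor([[0, 0, 0, 1, 1, 1, 0, 0, 0]])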
141
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
142
+ """
143
+ Moves all padding tokens in each sequence of a batch to the right.
144
+ """
145
+ mask = (tensor != pad_id)
146
+ # Must do stable=True since the binary mask is unordered
147
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
148
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
149
+ return packed_tensor
150
+
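An illustrative example (not part of the package) of the stable left-packing performed above:

import torch

pad_id = 0
x = torch.tensor([
    [pad_id, 5, pad_id, 6, 7],
    [1, pad_id, pad_id, 2, pad_id],
])
mask = (x != pad_id)
# stable=True keeps the relative order of the non-pad tokens intact
order = torch.argsort(mask.int(), dim=1, descending=True, stable=True)
print(torch.gather(x, 1, order))
# tensor([[5, 6, 7, 0, 0],
#         [1, 2, 0, 0, 0]])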
151
+ def align_logprobs_with_mask(
152
+ logprob_tensor: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ pad_value: float = 0.0
155
+ ) -> torch.Tensor:
156
+ """
157
+ Aligns a log probability tensor with a given attention mask.
158
+ """
159
+
160
+ device = logprob_tensor.device
161
+ batch_size, logprob_seq_len = logprob_tensor.shape
162
+ mask_seq_len = attention_mask.shape[1]
163
+
164
+ padded_logprobs = torch.full(
165
+ attention_mask.shape,
166
+ fill_value=pad_value,
167
+ dtype=logprob_tensor.dtype,
168
+ device=device
169
+ )
170
+
171
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
172
+
173
+ cols = torch.arange(logprob_seq_len, device=device)
174
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
175
+
176
+ # Create destination row indices
177
+ # Shape: [batch_size, logprob_seq_len]
178
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
179
+
180
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
181
+ # Create a mask to identify only the indices that are within the bounds
182
+ # of the target tensor's sequence length.
183
+ valid_mask = dest_indices < mask_seq_len
184
+
185
+ # Use this mask to select only the valid row indices, column indices,
186
+ # and the corresponding values from the logprob tensor.
187
+ # This flattens the selected elements into 1D tensors.
188
+ valid_rows = row_indices[valid_mask]
189
+ valid_cols = dest_indices[valid_mask]
190
+ valid_vals = logprob_tensor[valid_mask]
191
+
192
+ # Place the valid values into their correct positions in the padded tensor
193
+ # using a single, efficient advanced indexing operation.
194
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
195
+
196
+ return padded_logprobs
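A one-row illustration (made-up numbers) of the alignment performed above: the logprob row is shifted to start at the first non-padded position of the attention mask, and the padded slots keep pad_value:

import torch

logprob_tensor = torch.tensor([[-0.5, -1.0, -2.0]])
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])       # two left-pad positions

padded = torch.zeros(attention_mask.shape)             # pad_value = 0.0
left_pad_counts = torch.argmax(attention_mask, dim=1)  # first non-pad column -> 2
dest = left_pad_counts.unsqueeze(1) + torch.arange(logprob_tensor.shape[1])
rows = torch.arange(1).unsqueeze(1).expand_as(dest)
padded[rows, dest] = logprob_tensor
print(padded)  # tensor([[ 0.0000,  0.0000, -0.5000, -1.0000, -2.0000]])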
197
+ def vLLMSamplingParams(**kwargs):
198
+ from vllm import SamplingParams
199
+
200
+ sampling_params = SamplingParams(**kwargs)
201
+ sampling_params._set_kwargs = kwargs
202
+ return sampling_params
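A hypothetical call (requires the optional vllm dependency; the argument values are placeholders) showing how the helper above is meant to be used with the config defined below:

# Build vLLM sampling parameters and keep the original kwargs attached to them.
sampling_params = vLLMSamplingParams(temperature = 0.7, top_p = 0.9, max_tokens = 128)
# The result can then be passed to UnslothOnlineDPOConfig via `vllm_sampling_params = sampling_params`.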
203
+ @dataclass
204
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
205
+ """
206
+
207
+ Configuration class for the [`OnlineDPOTrainer`].
208
+
209
+ This class includes only the parameters that are specific to Online DPO training. For a full list of training
210
+ arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
211
+ class may differ from those in [`~transformers.TrainingArguments`].
212
+
213
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
214
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
215
+ command line.
216
+
217
+ Parameters:
218
+ reward_model_path (`str`, *optional*):
219
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
220
+ judge (`str`, *optional*):
221
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
222
+ max_new_tokens (`int`, *optional*, defaults to `64`):
223
+ Maximum number of tokens to generate per completion.
224
+ max_length (`int`, *optional*, defaults to `256`):
225
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
226
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
227
+ possible.
228
+ temperature (`float`, *optional*, defaults to `0.9`):
229
+ Temperature for sampling. The higher the temperature, the more random the completions.
230
+ missing_eos_penalty (`float`, *optional*):
231
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage the model to
232
+ generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
233
+ value. This parameter only works when using `reward_funcs` and not when using `judge`.
234
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
235
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
236
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
237
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
238
+ selected for each new epoch and the last β is used for the rest of the epochs.
239
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
240
+ Type of loss to use. Possible values are:
241
+
242
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
243
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
244
+
245
+ dataset_num_proc (`int`, *optional*):
246
+ Number of processes to use for processing the dataset.
247
+
248
+ <Deprecated version="0.22.0">
249
+
250
+ This parameter is deprecated and will be removed in version 0.25.0. Since OnlineDPO does not involve
251
+ dataset preparation, you can safely remove it.
252
+
253
+ </Deprecated>
254
+
255
+ disable_dropout (`bool`, *optional*, defaults to `True`):
256
+ Whether to disable dropout in the model and reference model.
257
+
258
+ > Parameters that control generation
259
+
260
+ top_p (`float`, *optional*, defaults to `1.0`):
261
+ Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
262
+ `1.0` to consider all tokens.
263
+ top_k (`int`, *optional*):
264
+ Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
265
+ disabled and all tokens are considered.
266
+ min_p (`float`, *optional*):
267
+ Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
268
+ value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
269
+ repetition_penalty (`float`, *optional*, defaults to `1.0`):
270
+ Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
271
+ Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
272
+ tokens.
273
+ use_transformers_paged (`bool`, *optional*, defaults to `False`):
274
+ Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers`
275
+ paged implementation will be used for generation instead of the default padded implementation. This
276
+ parameter is only effective when `use_vllm` is set to `False`.
277
+ cache_implementation (`str`, *optional*):
278
+ Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
279
+ generation_kwargs (`dict[str, Any]`, *optional*):
280
+ Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
281
+ `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
282
+ generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
283
+ with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
284
+
285
+ > Parameters that control generation acceleration powered by vLLM
286
+
287
+ use_vllm (`bool`, *optional*, defaults to `False`):
288
+ Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
289
+ instead of the default model.generate(). Requires `vllm` to be installed.
290
+ vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
291
+ Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
292
+ the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
293
+ implementation.
294
+ vllm_mode (`str`, *optional*, defaults to `"server"`):
295
+ Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
296
+ `"colocate"`.
297
+
298
+ - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
299
+ server is running (start with `trl vllm-serve`).
300
+ - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
301
+ separate server but may cause resource contention with training.
302
+ vllm_guided_decoding_regex (`str`, *optional*):
303
+ Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
304
+
305
+ > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
306
+
307
+ vllm_server_base_url (`str`, *optional*):
308
+ Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
309
+ `vllm_server_port` are ignored.
310
+ vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
311
+ Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
312
+ vllm_server_port (`int`, *optional*, defaults to `8000`):
313
+ Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
314
+ vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
315
+ Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
316
+ timeout, a `ConnectionError` is raised.
317
+
318
+ > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
319
+
320
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.55`):
321
+ Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
322
+ `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
323
+ launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
324
+ vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
325
+ Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
326
+ `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
327
+ launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
328
+
329
+ > Other parameters
330
+
331
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
332
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
333
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
334
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
335
+ with vLLM generation.
336
+ model_init_kwargs (`dict[str, Any]`, *optional*):
337
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
338
+ string.
339
+
340
+ """
341
+ vllm_sampling_params: Optional[Any] = field(
342
+ default = None,
343
+ metadata = {'help': 'vLLM SamplingParams'},
344
+ )
345
+ unsloth_num_chunks : Optional[int] = field(
346
+ default = -1,
347
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
348
+ )
349
+ max_seq_length : Optional[int] = field(
350
+ default = None,
351
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
352
+ )
353
+ def __init__(
354
+ self,
355
+ output_dir = None,
356
+ overwrite_output_dir = None,
357
+ do_train = False,
358
+ do_eval = False,
359
+ do_predict = False,
360
+ eval_strategy = 'no',
361
+ prediction_loss_only = False,
362
+ per_device_train_batch_size = 4,
363
+ per_device_eval_batch_size = 4,
364
+ per_gpu_train_batch_size = None,
365
+ per_gpu_eval_batch_size = None,
366
+ gradient_accumulation_steps = 2,
367
+ eval_accumulation_steps = 2,
368
+ eval_delay = 0,
369
+ torch_empty_cache_steps = 250,
370
+ learning_rate = 5e-05,
371
+ weight_decay = 0.01,
372
+ adam_beta1 = 0.9,
373
+ adam_beta2 = 0.999,
374
+ adam_epsilon = 1e-08,
375
+ max_grad_norm = 1.0,
376
+ num_train_epochs = 3.0,
377
+ max_steps = -1,
378
+ lr_scheduler_type = 'linear',
379
+ warmup_ratio = 0.1,
380
+ warmup_steps = 0,
381
+ log_level = 'passive',
382
+ log_level_replica = 'warning',
383
+ log_on_each_node = True,
384
+ logging_dir = None,
385
+ logging_strategy = 'steps',
386
+ logging_first_step = False,
387
+ logging_steps = 1,
388
+ logging_nan_inf_filter = False,
389
+ save_strategy = 'steps',
390
+ save_steps = 500,
391
+ save_total_limit = None,
392
+ save_safetensors = True,
393
+ save_on_each_node = False,
394
+ save_only_model = False,
395
+ restore_callback_states_from_checkpoint = False,
396
+ no_cuda = False,
397
+ use_cpu = False,
398
+ use_mps_device = False,
399
+ seed = 3407,
400
+ data_seed = 3407,
401
+ jit_mode_eval = False,
402
+ bf16 = False,
403
+ fp16 = False,
404
+ fp16_opt_level = 'O1',
405
+ half_precision_backend = 'auto',
406
+ bf16_full_eval = False,
407
+ fp16_full_eval = False,
408
+ tf32 = None,
409
+ local_rank = -1,
410
+ ddp_backend = None,
411
+ tpu_num_cores = None,
412
+ tpu_metrics_debug = False,
413
+ debug = '',
414
+ dataloader_drop_last = False,
415
+ eval_steps = None,
416
+ dataloader_num_workers = 0,
417
+ dataloader_prefetch_factor = None,
418
+ past_index = -1,
419
+ run_name = None,
420
+ disable_tqdm = None,
421
+ remove_unused_columns = True,
422
+ label_names = None,
423
+ load_best_model_at_end = False,
424
+ metric_for_best_model = None,
425
+ greater_is_better = None,
426
+ ignore_data_skip = False,
427
+ fsdp = None,
428
+ fsdp_min_num_params = 0,
429
+ fsdp_config = None,
430
+ fsdp_transformer_layer_cls_to_wrap = None,
431
+ accelerator_config = None,
432
+ parallelism_config = None,
433
+ deepspeed = None,
434
+ label_smoothing_factor = 0.0,
435
+ optim = 'adamw_8bit',
436
+ optim_args = None,
437
+ adafactor = False,
438
+ group_by_length = False,
439
+ length_column_name = 'length',
440
+ report_to = 'none',
441
+ project = 'huggingface',
442
+ trackio_space_id = 'trackio',
443
+ ddp_find_unused_parameters = None,
444
+ ddp_bucket_cap_mb = None,
445
+ ddp_broadcast_buffers = None,
446
+ dataloader_pin_memory = True,
447
+ dataloader_persistent_workers = False,
448
+ skip_memory_metrics = True,
449
+ use_legacy_prediction_loop = False,
450
+ push_to_hub = False,
451
+ resume_from_checkpoint = None,
452
+ hub_model_id = None,
453
+ hub_strategy = 'every_save',
454
+ hub_token = None,
455
+ hub_private_repo = None,
456
+ hub_always_push = False,
457
+ hub_revision = None,
458
+ gradient_checkpointing = True,
459
+ gradient_checkpointing_kwargs = None,
460
+ include_inputs_for_metrics = False,
461
+ eval_do_concat_batches = True,
462
+ fp16_backend = 'auto',
463
+ push_to_hub_model_id = None,
464
+ push_to_hub_organization = None,
465
+ push_to_hub_token = None,
466
+ mp_parameters = '',
467
+ auto_find_batch_size = False,
468
+ full_determinism = False,
469
+ torchdynamo = None,
470
+ ray_scope = 'last',
471
+ ddp_timeout = 1800,
472
+ torch_compile = False,
473
+ torch_compile_backend = None,
474
+ torch_compile_mode = None,
475
+ include_tokens_per_second = False,
476
+ include_num_input_tokens_seen = False,
477
+ neftune_noise_alpha = None,
478
+ optim_target_modules = None,
479
+ batch_eval_metrics = False,
480
+ eval_on_start = False,
481
+ use_liger_kernel = False,
482
+ liger_kernel_config = None,
483
+ eval_use_gather_object = False,
484
+ average_tokens_across_devices = True,
485
+ reward_model_path = None,
486
+ judge = None,
487
+ max_new_tokens = 64,
488
+ max_length = 512,
489
+ temperature = 0.9,
490
+ top_p = 1.0,
491
+ top_k = None,
492
+ min_p = None,
493
+ repetition_penalty = 1.0,
494
+ generation_kwargs = {},
495
+ use_transformers_paged = False,
496
+ cache_implementation = None,
497
+ missing_eos_penalty = None,
498
+ loss_type = 'sigmoid',
499
+ disable_dropout = True,
500
+ use_vllm = False,
501
+ vllm_model_impl = 'vllm',
502
+ vllm_guided_decoding_regex = None,
503
+ vllm_gpu_memory_utilization = 0.55,
504
+ vllm_mode = 'colocate',
505
+ vllm_server_base_url = None,
506
+ vllm_server_host = '0.0.0.0',
507
+ vllm_server_port = 8000,
508
+ vllm_server_timeout = 240.0,
509
+ vllm_tensor_parallel_size = 1,
510
+ ds3_gather_for_generation = True,
511
+ model_init_kwargs = None,
512
+ reward_weights = None,
513
+ dataset_num_proc = None,
514
+ gpu_memory_utilization = None,
515
+ vllm_sampling_params = None,
516
+ unsloth_num_chunks = -1,
517
+ max_seq_length = None,
518
+ **kwargs,
519
+ ):
520
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
521
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
522
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
523
+ output_dir = 'unsloth_training_checkpoints'
524
+ save_strategy = 'no'
525
+ if dataset_num_proc is None:
526
+ from multiprocessing import cpu_count
527
+ dataset_num_proc = min(max(cpu_count()+4, 2), 64)
528
+ if temperature <= 0:
529
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
530
+ elif temperature >= 10:
531
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
532
+
533
+
534
+ super().__init__(
535
+ output_dir = output_dir,
536
+ overwrite_output_dir = overwrite_output_dir,
537
+ do_train = do_train,
538
+ do_eval = do_eval,
539
+ do_predict = do_predict,
540
+ eval_strategy = eval_strategy,
541
+ prediction_loss_only = prediction_loss_only,
542
+ per_device_train_batch_size = per_device_train_batch_size,
543
+ per_device_eval_batch_size = per_device_eval_batch_size,
544
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
545
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
546
+ gradient_accumulation_steps = gradient_accumulation_steps,
547
+ eval_accumulation_steps = eval_accumulation_steps,
548
+ eval_delay = eval_delay,
549
+ torch_empty_cache_steps = torch_empty_cache_steps,
550
+ learning_rate = learning_rate,
551
+ weight_decay = weight_decay,
552
+ adam_beta1 = adam_beta1,
553
+ adam_beta2 = adam_beta2,
554
+ adam_epsilon = adam_epsilon,
555
+ max_grad_norm = max_grad_norm,
556
+ num_train_epochs = num_train_epochs,
557
+ max_steps = max_steps,
558
+ lr_scheduler_type = lr_scheduler_type,
559
+ warmup_ratio = warmup_ratio,
560
+ warmup_steps = warmup_steps,
561
+ log_level = log_level,
562
+ log_level_replica = log_level_replica,
563
+ log_on_each_node = log_on_each_node,
564
+ logging_dir = logging_dir,
565
+ logging_strategy = logging_strategy,
566
+ logging_first_step = logging_first_step,
567
+ logging_steps = logging_steps,
568
+ logging_nan_inf_filter = logging_nan_inf_filter,
569
+ save_strategy = save_strategy,
570
+ save_steps = save_steps,
571
+ save_total_limit = save_total_limit,
572
+ save_safetensors = save_safetensors,
573
+ save_on_each_node = save_on_each_node,
574
+ save_only_model = save_only_model,
575
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
576
+ no_cuda = no_cuda,
577
+ use_cpu = use_cpu,
578
+ use_mps_device = use_mps_device,
579
+ seed = seed,
580
+ data_seed = data_seed,
581
+ jit_mode_eval = jit_mode_eval,
582
+ bf16 = bf16,
583
+ fp16 = fp16,
584
+ fp16_opt_level = fp16_opt_level,
585
+ half_precision_backend = half_precision_backend,
586
+ bf16_full_eval = bf16_full_eval,
587
+ fp16_full_eval = fp16_full_eval,
588
+ tf32 = tf32,
589
+ local_rank = local_rank,
590
+ ddp_backend = ddp_backend,
591
+ tpu_num_cores = tpu_num_cores,
592
+ tpu_metrics_debug = tpu_metrics_debug,
593
+ debug = debug,
594
+ dataloader_drop_last = dataloader_drop_last,
595
+ eval_steps = eval_steps,
596
+ dataloader_num_workers = dataloader_num_workers,
597
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
598
+ past_index = past_index,
599
+ run_name = run_name,
600
+ disable_tqdm = disable_tqdm,
601
+ remove_unused_columns = remove_unused_columns,
602
+ label_names = label_names,
603
+ load_best_model_at_end = load_best_model_at_end,
604
+ metric_for_best_model = metric_for_best_model,
605
+ greater_is_better = greater_is_better,
606
+ ignore_data_skip = ignore_data_skip,
607
+ fsdp = fsdp,
608
+ fsdp_min_num_params = fsdp_min_num_params,
609
+ fsdp_config = fsdp_config,
610
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
611
+ accelerator_config = accelerator_config,
612
+ parallelism_config = parallelism_config,
613
+ deepspeed = deepspeed,
614
+ label_smoothing_factor = label_smoothing_factor,
615
+ optim = optim,
616
+ optim_args = optim_args,
617
+ adafactor = adafactor,
618
+ group_by_length = group_by_length,
619
+ length_column_name = length_column_name,
620
+ report_to = report_to,
621
+ project = project,
622
+ trackio_space_id = trackio_space_id,
623
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
624
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
625
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
626
+ dataloader_pin_memory = dataloader_pin_memory,
627
+ dataloader_persistent_workers = dataloader_persistent_workers,
628
+ skip_memory_metrics = skip_memory_metrics,
629
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
630
+ push_to_hub = push_to_hub,
631
+ resume_from_checkpoint = resume_from_checkpoint,
632
+ hub_model_id = hub_model_id,
633
+ hub_strategy = hub_strategy,
634
+ hub_token = hub_token,
635
+ hub_private_repo = hub_private_repo,
636
+ hub_always_push = hub_always_push,
637
+ hub_revision = hub_revision,
638
+ gradient_checkpointing = gradient_checkpointing,
639
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
640
+ include_inputs_for_metrics = include_inputs_for_metrics,
641
+ eval_do_concat_batches = eval_do_concat_batches,
642
+ fp16_backend = fp16_backend,
643
+ push_to_hub_model_id = push_to_hub_model_id,
644
+ push_to_hub_organization = push_to_hub_organization,
645
+ push_to_hub_token = push_to_hub_token,
646
+ mp_parameters = mp_parameters,
647
+ auto_find_batch_size = auto_find_batch_size,
648
+ full_determinism = full_determinism,
649
+ torchdynamo = torchdynamo,
650
+ ray_scope = ray_scope,
651
+ ddp_timeout = ddp_timeout,
652
+ torch_compile = torch_compile,
653
+ torch_compile_backend = torch_compile_backend,
654
+ torch_compile_mode = torch_compile_mode,
655
+ include_tokens_per_second = include_tokens_per_second,
656
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
657
+ neftune_noise_alpha = neftune_noise_alpha,
658
+ optim_target_modules = optim_target_modules,
659
+ batch_eval_metrics = batch_eval_metrics,
660
+ eval_on_start = eval_on_start,
661
+ use_liger_kernel = use_liger_kernel,
662
+ liger_kernel_config = liger_kernel_config,
663
+ eval_use_gather_object = eval_use_gather_object,
664
+ average_tokens_across_devices = average_tokens_across_devices,
665
+ reward_model_path = reward_model_path,
666
+ judge = judge,
667
+ max_new_tokens = max_new_tokens,
668
+ max_length = max_length,
669
+ temperature = temperature,
670
+ top_p = top_p,
671
+ top_k = top_k,
672
+ min_p = min_p,
673
+ repetition_penalty = repetition_penalty,
674
+ generation_kwargs = generation_kwargs,
675
+ use_transformers_paged = use_transformers_paged,
676
+ cache_implementation = cache_implementation,
677
+ missing_eos_penalty = missing_eos_penalty,
678
+ loss_type = loss_type,
679
+ disable_dropout = disable_dropout,
680
+ use_vllm = use_vllm,
681
+ vllm_model_impl = vllm_model_impl,
682
+ vllm_guided_decoding_regex = vllm_guided_decoding_regex,
683
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
684
+ vllm_mode = vllm_mode,
685
+ vllm_server_base_url = vllm_server_base_url,
686
+ vllm_server_host = vllm_server_host,
687
+ vllm_server_port = vllm_server_port,
688
+ vllm_server_timeout = vllm_server_timeout,
689
+ vllm_tensor_parallel_size = vllm_tensor_parallel_size,
690
+ ds3_gather_for_generation = ds3_gather_for_generation,
691
+ model_init_kwargs = model_init_kwargs,
692
+ reward_weights = reward_weights,
693
+ dataset_num_proc = dataset_num_proc,
694
+ gpu_memory_utilization = gpu_memory_utilization,**kwargs)
695
+ self.vllm_sampling_params = vllm_sampling_params
696
+ self.unsloth_num_chunks = unsloth_num_chunks
697
+ self.max_seq_length = max_seq_length
698
+ pass
699
+
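For orientation, a minimal hypothetical construction of the config defined above (values are placeholders, not tuned recommendations); only the arguments shown are set, everything else keeps the defaults listed in __init__:

config = UnslothOnlineDPOConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    learning_rate = 5e-5,
    max_new_tokens = 64,
    max_length = 512,
    temperature = 0.9,
    use_vllm = False,
    max_seq_length = 2048,
)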
700
+ class _UnslothOnlineDPOTrainer(BaseTrainer):
701
+ r""""""
702
+
703
+ _tag_names = ["trl", "online-dpo"]
704
+ _name = "Online DPO"
705
+ _paper = {
706
+ "title": "Direct Language Model Alignment from Online AI Feedback",
707
+ "id": "2402.04792",
708
+ # docstyle-ignore
709
+ "citation": textwrap.dedent("""\
710
+ @article{guo2024direct,
711
+ title = {{Direct Language Model Alignment from Online AI Feedback}},
712
+ author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
713
+ year = 2024,
714
+ eprint = {arXiv:2402.04792}
715
+ }"""),
716
+ }
717
+
718
+ def __init__(
719
+ self,
720
+ model: Union[PreTrainedModel, nn.Module, str],
721
+ ref_model: Union[PreTrainedModel, nn.Module, None] = None,
722
+ reward_funcs: Optional[Union[RewardFunc, list[RewardFunc]]] = None,
723
+ judge: Optional[BasePairwiseJudge] = None,
724
+ args: Optional[OnlineDPOConfig] = None,
725
+ data_collator: Optional[DataCollator] = None,
726
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
727
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
728
+ processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
729
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
730
+ peft_config: Optional["PeftConfig"] = None,
731
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
732
+ callbacks: Optional[list[TrainerCallback]] = None,
733
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
734
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
735
+ # Deprecated parameters
736
+ reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
737
+ reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
738
+ ) -> None:
739
+
740
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
741
+ if (getattr(args, 'use_vllm', False) == False):
742
+ args.use_vllm = True
743
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
744
+ warnings.warn(
745
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
746
+ "it and want it to remain, please share your comments here: "
747
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
748
+ "TRL_EXPERIMENTAL_SILENCE=1."
749
+ )
750
+ if ref_model is model:
751
+ raise ValueError(
752
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
753
+ "same as `model`, either omit the `ref_model` argument or pass `None`."
754
+ )
755
+
756
+ self.ref_model = ref_model
757
+
758
+ # Handle deprecated parameters for backward compatibility
759
+ if reward_model is not None:
760
+ warnings.warn(
761
+ "The `reward_model` parameter is deprecated and will be removed in version 0.25.0. "
762
+ "Please use `reward_funcs` instead. For example, change `reward_model=model` to `reward_funcs=model`.",
763
+ )
764
+ # Convert old reward_model to new reward_funcs format
765
+ if reward_funcs is None:
766
+ reward_funcs = reward_model
767
+ else:
768
+ warnings.warn(
769
+ "Both `reward_model` and `reward_funcs` are provided. Using `reward_funcs` and ignoring "
770
+ "`reward_model`.",
771
+ )
772
+
773
+ if reward_processing_class is not None:
774
+ warnings.warn(
775
+ "The `reward_processing_class` parameter is deprecated and will be removed in version 0.25.0. "
776
+ "Please use `reward_processing_classes` instead. For example, change "
777
+ "`reward_processing_class=tokenizer` to `reward_processing_classes=tokenizer`.",
778
+ )
779
+ # Convert old reward_processing_class to new reward_processing_classes format
780
+ if reward_processing_classes is None:
781
+ reward_processing_classes = reward_processing_class
782
+ else:
783
+ warnings.warn(
784
+ "Both `reward_processing_class` and `reward_processing_classes` are provided. Using "
785
+ "`reward_processing_classes` and ignoring `reward_processing_class`.",
786
+ )
787
+
788
+ # Validate reward configuration - must have exactly one of: judge, or reward_funcs
789
+ reward_configs = sum(x is not None for x in [judge, reward_funcs])
790
+ if reward_configs == 0:
791
+ raise ValueError("One of `judge` or `reward_funcs` must be provided.")
792
+ elif reward_configs > 1:
793
+ if judge is not None:
794
+ logger.warning(
795
+ "Both `judge` and `reward_funcs` are provided. Using `judge` and ignoring `reward_funcs`.",
796
+ UserWarning,
797
+ )
798
+ reward_funcs = None
799
+ self.judge = judge
800
+
801
+ # Handle reward_funcs
802
+ if reward_funcs is not None:
803
+ if not isinstance(reward_funcs, list):
804
+ reward_funcs = [reward_funcs]
805
+ self.reward_func_names = []
806
+
807
+ # Process reward functions [convert strings to models, collect names]
808
+ model_init_kwargs = args.model_init_kwargs or {}
809
+ for i, reward_func in enumerate(reward_funcs):
810
+ if isinstance(reward_func, str):
811
+ # Load model from string path
812
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
813
+ reward_func, num_labels=1, **model_init_kwargs
814
+ )
815
+ if isinstance(reward_funcs[i], nn.Module):
816
+ self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
817
+ else:
818
+ self.reward_func_names.append(reward_funcs[i].__name__)
819
+ self.reward_funcs = reward_funcs
820
+
821
+ # Handle reward processing classes for reward_funcs
822
+ if reward_processing_classes is None:
823
+ reward_processing_classes = [None] * len(reward_funcs)
824
+ elif not isinstance(reward_processing_classes, list):
825
+ reward_processing_classes = [reward_processing_classes]
826
+ else:
827
+ if len(reward_processing_classes) != len(reward_funcs):
828
+ raise ValueError(
829
+ "The number of reward processing classes must match the number of reward functions."
830
+ )
831
+
832
+ self.reward_processing_classes = []
833
+ for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs):
834
+ if isinstance(reward_func, PreTrainedModel):
835
+ if reward_processing_class_i is None:
836
+ reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
837
+ if reward_processing_class_i.pad_token_id is None:
838
+ reward_processing_class_i.pad_token = reward_processing_class_i.eos_token
839
+ # Set pad token ID on reward model config
840
+ reward_func.config.pad_token_id = reward_processing_class_i.pad_token_id
841
+ self.reward_processing_classes.append(reward_processing_class_i)
842
+ else:
843
+ self.reward_funcs = None
844
+ self.reward_func_names = []
845
+ self.reward_processing_classes = []
846
+
847
+ # Handle reward_weights
848
+ if reward_funcs is not None:
849
+ if args.reward_weights is not None:
850
+ if len(args.reward_weights) != len(self.reward_funcs):
851
+ raise ValueError(
852
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
853
+ f"functions ({len(self.reward_funcs)})"
854
+ )
855
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
856
+ else:
857
+ self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32)
858
+ else:
859
+ self.reward_weights = None
860
+
861
+ if args.missing_eos_penalty is not None and reward_funcs is None and judge is None:
862
+ # Check if this is the old reward_model case
863
+ if reward_model is not None:
864
+ logger.warning(
865
+ "The `missing_eos_penalty` parameter is deprecated when used with the deprecated `reward_model` parameter. "
866
+ "Please use `reward_funcs` instead of `reward_model` to continue using this feature.",
867
+ FutureWarning,
868
+ stacklevel=2,
869
+ )
870
+ else:
871
+ raise ValueError("`missing_eos_penalty` is only supported when `reward_funcs` is provided.")
872
+
873
+ if args is None:
874
+ raise ValueError("`args` must be provided.")
875
+
876
+ # Check that the processing_class is provided
877
+ if processing_class is None:
878
+ raise ValueError("`processing_class` must be provided.")
879
+
880
+ model_init_kwargs = args.model_init_kwargs or {}
881
+ if isinstance(model, str):
882
+ model_id = model
883
+
884
+ # Handle dtype in model_init_kwargs
885
+ dtype = model_init_kwargs.get("dtype")
886
+ if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
887
+ pass
888
+ elif isinstance(dtype, str):
889
+ dtype = getattr(torch, dtype)
890
+ model_init_kwargs["dtype"] = dtype
891
+ else:
892
+ raise ValueError(
893
+ "Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string "
894
+ f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
895
+ )
896
+
897
+ model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
898
+ else:
899
+ if args.model_init_kwargs is not None:
900
+ raise ValueError(
901
+ "You passed `model_init_kwargs` to the `OnlineDPOConfig`, but your model is already instantiated. "
902
+ "This argument can only be used when the `model` argument is a string."
903
+ )
904
+ self.is_encoder_decoder = model.config.is_encoder_decoder
905
+ self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
906
+
907
+ if False:
908
+ model = prepare_peft_model(model, peft_config, args)
909
+
910
+ # Enable gradient checkpointing if requested
911
+ if args.gradient_checkpointing:
912
+ model = self._enable_gradient_checkpointing(model, args)
913
+
914
+ # Disable dropout in the model and reference model
915
+ if args.disable_dropout:
916
+ disable_dropout_in_model(model)
917
+ if self.ref_model is not None:
918
+ disable_dropout_in_model(self.ref_model)
919
+
920
+ # Handle the ref_model
921
+ # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
922
+ # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
923
+ # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
924
+ if ref_model is None: # No ref model provided, the most common case
925
+ if False:
926
+ self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
927
+ else:
928
+ self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
929
+ else: # rare case, the user provided a ref model
930
+ self.ref_model = ref_model
931
+ self.ref_model.eval()
932
+
933
+ # Disable the gradient and set the reward model in eval mode
934
+ if reward_funcs is not None:
935
+ for reward_func in reward_funcs:
936
+ if isinstance(reward_func, PreTrainedModel):
937
+ reward_func.eval()
938
+
939
+ self.max_length = args.max_length
940
+
941
+ self.stats = {
942
+ "objective/kl": [],
943
+ "objective/entropy": [],
944
+ "objective/non_score_reward": [],
945
+ "rewards/chosen": [],
946
+ "rewards/rejected": [],
947
+ "rewards/accuracies": [],
948
+ "rewards/margins": [],
949
+ "logps/chosen": [],
950
+ "logps/rejected": [],
951
+ "val/contain_eos_token": [],
952
+ "beta": [],
953
+ }
954
+ if self.reward_funcs is not None:
955
+ self.stats["objective/rlhf_reward"] = []
956
+ self.stats["objective/scores_margin"] = []
957
+ self.stats["objective/scores"] = []
958
+
959
+ # Store generation parameters for later use
960
+ self.use_vllm = args.use_vllm
961
+ self.num_generations = 2 # Generate 2 completions per prompt for Online DPO
962
+ self.temperature = args.temperature
963
+ self.top_p = args.top_p
964
+ self.top_k = args.top_k
965
+ self.min_p = args.min_p
966
+ self.repetition_penalty = args.repetition_penalty
967
+ self.use_transformers_paged = args.use_transformers_paged
968
+ self.vllm_mode = args.vllm_mode if args.use_vllm else None
969
+ self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization
970
+ self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size
971
+ self.vllm_model_impl = args.vllm_model_impl
972
+
973
+ # Handle pad token for processors or tokenizers
974
+ if isinstance(processing_class, ProcessorMixin):
975
+ tokenizer = processing_class.tokenizer
976
+ elif isinstance(processing_class, PreTrainedTokenizerBase):
977
+ tokenizer = processing_class
978
+ else:
979
+ raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
980
+
981
+ if tokenizer.pad_token is None:
982
+ tokenizer.pad_token = tokenizer.eos_token
983
+
984
+ self.pad_token = tokenizer.pad_token
985
+ self.pad_token_id = tokenizer.pad_token_id
986
+ self.eos_token_id = tokenizer.eos_token_id
987
+
988
+ # Vision tokens for VLM support
989
+ self.image_token_id = getattr(processing_class, "image_token_id", None)
990
+ self.vision_start_token_id = getattr(processing_class, "vision_start_token_id", None)
991
+ self.vision_end_token_id = getattr(processing_class, "vision_end_token_id", None)
992
+ # Get the image token string for token collapsing
993
+ self.image_token = None
994
+ if self.image_token_id is not None:
995
+ self.image_token = tokenizer.decode([self.image_token_id])
996
+
997
+ # Define the collator if not provided
998
+ if data_collator is None:
999
+ data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id)
1000
+
1001
+ # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
1002
+ # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
1003
+ # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
1004
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
1005
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
1006
+ # that the warning has already been issued.
1007
+ model.warnings_issued["estimate_tokens"] = True
1008
+
1009
+ super().__init__(
1010
+ model=model,
1011
+ args=args,
1012
+ data_collator=data_collator,
1013
+ train_dataset=train_dataset,
1014
+ eval_dataset=eval_dataset,
1015
+ processing_class=processing_class,
1016
+ compute_metrics=compute_metrics,
1017
+ callbacks=callbacks,
1018
+ optimizers=optimizers,
1019
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
1020
+ )
1021
+
1022
+ # Add tags for models that have been loaded with the correct transformers version
1023
+ if hasattr(self.model, "add_model_tags"):
1024
+ self.model.add_model_tags(self._tag_names)
1025
+
1026
+ self._beta = args.beta
1027
+
1028
+ # Set up generation configuration and vLLM after super[].__init__
1029
+ if self.use_vllm:
1030
+ if not is_vllm_available():
1031
+ raise ImportError(
1032
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
1033
+ "`pip install trl[vllm]` to use it."
1034
+ )
1035
+
1036
+ if self.vllm_mode == "server":
1037
+ if self.accelerator.is_main_process:
1038
+ if args.vllm_server_base_url is not None:
1039
+ base_url = args.vllm_server_base_url
1040
+ else:
1041
+ base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
1042
+ self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
1043
+ self.vllm_client.init_communicator(device=torch.cuda.current_device())
1044
+ else:
1045
+ self.vllm_client = None
1046
+ elif self.vllm_mode == "colocate":
1047
+ vllm_kwargs = {
1048
+ "model": model.name_or_path,
1049
+ "tensor_parallel_size": self.vllm_tensor_parallel_size,
1050
+ "gpu_memory_utilization": self.vllm_gpu_memory_utilization,
1051
+ "model_impl": self.vllm_model_impl,
1052
+ "max_num_seqs": self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size,
1053
+ "max_model_len": args.max_length + args.max_new_tokens,
1054
+ "distributed_executor_backend": "external_launcher",
1055
+ "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size,
1056
+ "max_num_batched_tokens": 4096,
1057
+ }
1058
+ os.environ["RANK"] = str(self.accelerator.process_index)
1059
+ os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
1060
+ os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
1061
+ ensure_master_addr_port()
1062
+
1063
+ self.llm = model.vllm_engine
1064
+ else:
1065
+ raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
1066
+ self.guided_decoding_regex = args.vllm_guided_decoding_regex
1067
+ self._last_loaded_step = -1
1068
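+ # Sampling parameters for vLLM: n=2 draws two completions per prompt so they can later be split into a chosen/rejected pair for the online DPO loss.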
+ generation_params = {
1069
+ "n": 2,
1070
+ "repetition_penalty": self.repetition_penalty,
1071
+ "temperature": self.temperature,
1072
+ "top_p": self.top_p,
1073
+ "top_k": -1 if self.top_k is None else self.top_k,
1074
+ "min_p": 0.0 if self.min_p is None else self.min_p,
1075
+ "max_tokens": args.max_new_tokens,
1076
+ "detokenize": False,
1077
+ }
1078
+ if args.generation_kwargs is not None:
1079
+ generation_params.update(args.generation_kwargs)
1080
+ if self.guided_decoding_regex:
1081
+ generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex)
1082
+ self.generation_config = SamplingParams(**generation_params)
1083
+ self.accelerator.wait_for_everyone()
1084
+ else:
1085
+ # Set up transformers generation config
1086
+ generation_kwargs = {
1087
+ "max_new_tokens": args.max_new_tokens,
1088
+ "do_sample": True,
1089
+ "pad_token_id": self.pad_token_id,
1090
+ "bos_token_id": tokenizer.bos_token_id,
1091
+ "eos_token_id": self.eos_token_id,
1092
+ "temperature": self.temperature,
1093
+ "top_k": self.top_k,
1094
+ "top_p": self.top_p,
1095
+ "repetition_penalty": self.repetition_penalty,
1096
+ "use_cache": True if not self.args.gradient_checkpointing else False,
1097
+ }
1098
+ # Add min_p if supported
1099
+ if self.min_p is not None:
1100
+ generation_kwargs["min_p"] = self.min_p
1101
+ if args.generation_kwargs is not None:
1102
+ generation_kwargs.update(args.generation_kwargs)
1103
+ # Remove None values
1104
+ generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
1105
+ self.generation_config = GenerationConfig(**generation_kwargs)
1106
+
1107
+ if self.ref_model is not None:
1108
+ if self.is_deepspeed_enabled:
1109
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
1110
+ elif self.is_fsdp_enabled:
1111
+ self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
1112
+ else:
1113
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
1114
+ if self.reward_funcs is not None:
1115
+ for i, reward_func in enumerate(self.reward_funcs):
1116
+ if isinstance(reward_func, PreTrainedModel):
1117
+ if self.is_deepspeed_enabled:
1118
+ self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
1119
+ else:
1120
+ # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
1121
+ self.reward_funcs[i] = self.accelerator.prepare_model(
1122
+ reward_func, evaluation_mode=True, device_placement=True
1123
+ )
1124
+
1125
+ @property
1126
+ def beta(self):
1127
+ if isinstance(self._beta, list):
1128
+ epoch = int(self.state.epoch)  # state.epoch is a float; cast so it can index the per-epoch beta list
1129
+ return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
1130
+ else:
1131
+ return self._beta
1132
+
1133
+ @staticmethod
1134
+ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
1135
+ """Tokenize a single row from a DPO specific dataset."""
1136
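+ # Example (illustrative): for feature = {"prompt": "Hello"} with a decoder-only tokenizer, this returns
+ # {"prompt_input_ids": [...], "prompt_attention_mask": [...]}, with the BOS token prepended if not already present.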
+ if not is_encoder_decoder:
1137
+ batch = tokenizer(feature["prompt"], add_special_tokens=False)
1138
+ # Add BOS token to head of prompt. Avoid adding if it's already there
1139
+ if tokenizer.bos_token_id is not None:
1140
+ prompt_len_input_ids = len(batch["input_ids"])
1141
+ if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
1142
+ batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
1143
+ batch["attention_mask"] = [1] + batch["attention_mask"]
1144
+ else:
1145
+ batch = tokenizer(feature["prompt"], add_special_tokens=True)
1146
+ batch = {f"prompt_{key}": value for key, value in batch.items()}
1147
+ return batch
1148
+
1149
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
1150
+ @wraps(Trainer.get_train_dataloader)
1151
+ def get_train_dataloader(self) -> DataLoader:
1152
+ if self.train_dataset is None:
1153
+ raise ValueError("Trainer: training requires a train_dataset.")
1154
+
1155
+ train_dataset = self.train_dataset
1156
+ data_collator = self.data_collator
1157
+ dataloader_params = {
1158
+ "batch_size": self._train_batch_size,
1159
+ "collate_fn": data_collator,
1160
+ "num_workers": self.args.dataloader_num_workers,
1161
+ "pin_memory": self.args.dataloader_pin_memory,
1162
+ "persistent_workers": self.args.dataloader_persistent_workers,
1163
+ }
1164
+
1165
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
1166
+ dataloader_params["sampler"] = self._get_train_sampler()
1167
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
1168
+ dataloader_params["worker_init_fn"] = seed_worker
1169
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
1170
+
1171
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
1172
+
1173
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
1174
+ @wraps(Trainer.get_eval_dataloader)
1175
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
1176
+ if eval_dataset is None and self.eval_dataset is None:
1177
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
1178
+
1179
+ # If we have persistent workers, don't do a fork bomb especially as eval datasets
1180
+ # don't change during training
1181
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
1182
+ if (
1183
+ hasattr(self, "_eval_dataloaders")
1184
+ and dataloader_key in self._eval_dataloaders
1185
+ and self.args.dataloader_persistent_workers
1186
+ ):
1187
+ return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
1188
+
1189
+ eval_dataset = (
1190
+ self.eval_dataset[eval_dataset]
1191
+ if isinstance(eval_dataset, str)
1192
+ else eval_dataset
1193
+ if eval_dataset is not None
1194
+ else self.eval_dataset
1195
+ )
1196
+ data_collator = self.data_collator
1197
+
1198
+ dataloader_params = {
1199
+ "batch_size": self.args.eval_batch_size,
1200
+ "collate_fn": data_collator,
1201
+ "num_workers": self.args.dataloader_num_workers,
1202
+ "pin_memory": self.args.dataloader_pin_memory,
1203
+ "persistent_workers": self.args.dataloader_persistent_workers,
1204
+ }
1205
+
1206
+ if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
1207
+ dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
1208
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
1209
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
1210
+
1211
+ # accelerator.free_memory() will destroy the references, so
1212
+ # we need to store the non-prepared version
1213
+ eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
1214
+ if self.args.dataloader_persistent_workers:
1215
+ if hasattr(self, "_eval_dataloaders"):
1216
+ self._eval_dataloaders[dataloader_key] = eval_dataloader
1217
+ else:
1218
+ self._eval_dataloaders = {dataloader_key: eval_dataloader}
1219
+
1220
+ return self.accelerator.prepare(eval_dataloader)
1221
+
1222
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPOConfig) -> PreTrainedModel:
1223
+ """Enables gradient checkpointing for the model."""
1224
+ # Ensure use_cache is disabled
1225
+ model.config.use_cache = False
1226
+
1227
+ # Enable gradient checkpointing on the base model for PEFT
1228
+ if is_peft_model(model):
1229
+ model.base_model.gradient_checkpointing_enable()
1230
+ # Enable gradient checkpointing for non-PEFT models
1231
+ else:
1232
+ model.gradient_checkpointing_enable()
1233
+
1234
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
1235
+ use_reentrant = (
1236
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
1237
+ )
1238
+
1239
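+ # With reentrant checkpointing (the default when "use_reentrant" is not specified), the inputs must require grad so gradients still reach the checkpointed layers, e.g. when embeddings are frozen under PEFT.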
+ if use_reentrant:
1240
+ model.enable_input_require_grads()
1241
+
1242
+ return model
1243
+
1244
+ def _generate_vllm(self, prompts, images=None):
1245
+ eos_token_id = self.eos_token_id
1246
+ pad_token_id = self.pad_token_id
1247
+
1248
+ # Generate completion_ids and prompt_ids based on mode
1249
+ if self.vllm_mode == "server":
1250
+ completion_ids, prompt_ids = self._generate_vllm_server(prompts, images)
1251
+ elif self.vllm_mode == "colocate":
1252
+ completion_ids, prompt_ids = self._generate_vllm_colocate(prompts, images)
1253
+
1254
+ # Shared padding, masking, and tensor conversion logic
1255
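+ # Prompts are left-padded and completions are right-padded; completions that stopped early without an EOS token get one appended before padding.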
+ max_prompt_length = max(len(ids) for ids in prompt_ids)
1256
+ prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
1257
+ prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
1258
+ max_tokens = self.generation_config.max_tokens
1259
+ completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
1260
+ completion_ids = [
1261
+ ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
1262
+ for ids in completion_ids
1263
+ ]
1264
+ completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
1265
+
1266
+ # Convert to tensors
1267
+ prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
1268
+ prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
1269
+ completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
1270
+ completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
1271
+
1272
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
1273
+
1274
+ def _generate_vllm_server(self, prompts, images=None):
1275
+ """Generate completions using vLLM server mode"""
1276
+ has_images = images is not None
1277
+
1278
+ # Update vLLM server weights if needed
1279
+ if hasattr(self, "_last_loaded_step") and self.state.global_step != self._last_loaded_step:
1280
+ self._move_model_to_vllm()
1281
+ self._last_loaded_step = self.state.global_step
1282
+ elif not hasattr(self, "_last_loaded_step"):
1283
+ self._move_model_to_vllm()
1284
+ self._last_loaded_step = self.state.global_step
1285
+
1286
+ # Apply chat template if conversational
1287
+ if is_conversational({"prompt": prompts[0]}):
1288
+ prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts]
1289
+ else:
1290
+ prompts_text = prompts
1291
+ # Gather all prompts to main process
1292
+ all_prompts = gather_object(prompts_text)
1293
+ if has_images:
1294
+ all_images = gather_object(images)
1295
+
1296
+ if self.accelerator.is_main_process:
1297
+ # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
1298
+ # num_generations outputs for each one. This is faster than generating outputs for each duplicate
1299
+ # prompt individually.
1300
+ ordered_set_of_prompts = all_prompts[:: self.num_generations]
1301
+ if has_images:
1302
+ ordered_set_of_images = all_images[:: self.num_generations]
1303
+ else:
1304
+ ordered_set_of_images = None
1305
+ completion_ids = self.vllm_client.generate(
1306
+ prompts=ordered_set_of_prompts,
1307
+ images=ordered_set_of_images,
1308
+ n=self.num_generations,
1309
+ repetition_penalty=self.repetition_penalty,
1310
+ temperature=self.temperature,
1311
+ top_p=self.top_p,
1312
+ top_k=-1 if self.top_k is None else self.top_k,
1313
+ min_p=0.0 if self.min_p is None else self.min_p,
1314
+ max_tokens=self.generation_config.max_tokens,
1315
+ guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
1316
+ generation_kwargs=self.args.generation_kwargs,
1317
+ )
1318
+ # Flatten: each prompt generates 2 completions
1319
+ completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
1320
+ else:
1321
+ completion_ids = [None] * (len(all_prompts) * 2)
1322
+
1323
+ # Broadcast completions to all processes
1324
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
1325
+
1326
+ # Each process takes its slice
1327
+ process_slice = slice(
1328
+ self.accelerator.process_index * len(prompts) * 2,
1329
+ (self.accelerator.process_index + 1) * len(prompts) * 2,
1330
+ )
1331
+ completion_ids = completion_ids[process_slice]
1332
+
1333
+ # Create prompt_ids by tokenizing locally
1334
+ prompt_inputs = self.processing_class(
1335
+ text=prompts_text,
1336
+ return_tensors="pt",
1337
+ padding=True,
1338
+ padding_side="left",
1339
+ add_special_tokens=False,
1340
+ )
1341
+ prompt_ids = []
1342
+ for prompt_tokens in prompt_inputs["input_ids"]:
1343
+ prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions
1344
+ return completion_ids, prompt_ids
1345
+
1346
+ def _generate_vllm_colocate(self, prompts, images=None):
1347
+ """Generate completions using vLLM colocate mode"""
1348
+ # Update model weights if needed - only after gradient accumulation completes
1349
+ if self.state.global_step != self._last_loaded_step:
1350
+ self._move_model_to_vllm()
1351
+ self._last_loaded_step = self.state.global_step
1352
+
1353
+ # Apply chat template if conversational
1354
+ if is_conversational({"prompt": prompts[0]}):
1355
+ prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts]
1356
+ else:
1357
+ prompts_text = prompts
1358
+
1359
+ # Prepare vLLM inputs with images if available
1360
+ if images is not None:
1361
+ vllm_inputs = []
1362
+ for prompt, image in zip(prompts_text, images):
1363
+ if image is not None:
1364
+ vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
1365
+ else:
1366
+ vllm_inputs.append(prompt)
1367
+ else:
1368
+ vllm_inputs = prompts_text
1369
+
1370
+ lora_name = 'online_dpo_trainer_lora_model_' + os.environ.get('CUDA_VISIBLE_DEVICES', '0').replace(',', '')
+ lora_request = self.model.load_lora(lora_name, load_tensors = True)
+ outputs = self.llm.generate(vllm_inputs, self.generation_config, use_tqdm = False, lora_request = lora_request)
1371
+
1372
+ completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
1373
+ prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
1374
+
1375
+ return completion_ids, prompt_ids
1376
+
1377
+ def _move_model_to_vllm(self):
1378
+ """Synchronize model weights to vLLM server with support for PEFT, DeepSpeed, and FSDP"""
1379
+ # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
1380
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
1381
+ zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
1382
+ if zero_stage_3:
1383
+ import deepspeed
1384
+
1385
+ gather_if_zero3 = deepspeed.zero.GatheredParameters
1386
+ else:
1387
+ gather_if_zero3 = nullcontext
1388
+
1389
+ if is_peft_model(self.model):
1390
+ # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
1391
+ # merging adapters in a sharded manner is not supported.
1392
+ # TODO: does this work with FSDP?
1393
+ with gather_if_zero3(list(self.model.parameters())):
1394
+ self.model.merge_adapter()
1395
+
1396
+ # Update vLLM weights while parameters are gathered
1397
+ if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
1398
+ # Update vLLM weights while parameters are gathered
1399
+ # For PEFT with FSDP we need to use the memory efficient post-order traversal
1400
+ fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
1401
+ fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
1402
+ if fsdp_version == 1:
1403
+ # use memory-efficient post-order traversal for FSDP
1404
+ self._sync_fsdp1_params_to_vllm(self.model)
1405
+ elif fsdp_version == 2:
1406
+ self._sync_fsdp2_params_to_vllm(self.model)
1407
+ else:
1408
+ # DeepSpeed ZeRO-3 with PEFT
1409
+ for name, param in self.model.named_parameters():
1410
+ # When using PEFT, we need to recover the original parameter name and discard some parameters
1411
+ name = name.removeprefix("base_model.model.").replace(".base_layer", "")
1412
+ if self.model.prefix in name:
1413
+ continue
1414
+ # For modules_to_save parameters, skip the retained copy of the original module; its prefix is stripped just below
1415
+ if "original_module" in name:
1416
+ continue
1417
+ name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
1418
+
1419
+ if self.vllm_mode == "server" and self.accelerator.is_main_process:
1420
+ self.vllm_client.update_named_param(name, param.data)
1421
+ elif self.vllm_mode == "colocate":
1422
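+ # Colocate mode: no per-parameter push here; weights are instead picked up through the LoRA request passed to llm.generate in _generate_vllm_colocate (see load_tensors=True).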
+
1423
+ pass
1424
+
1425
+ pass
1426
+ # Unmerge adapters while parameters are still gathered
1427
+ self.model.unmerge_adapter()
1428
+ # Parameters will automatically be repartitioned when exiting the context
1429
+ else:
1430
+ # For non-PEFT models, simply gather (if needed) and update each parameter individually.
1431
+ if self.is_fsdp_enabled:
1432
+ fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
1433
+ fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
1434
+ if fsdp_version == 1:
1435
+ self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
1436
+ elif fsdp_version == 2:
1437
+ self._sync_fsdp2_params_to_vllm(self.model)
1438
+ else:
1439
+ for name, param in self.model.named_parameters():
1440
+ name = self._fix_param_name_to_vllm(name)
1441
+ with gather_if_zero3([param]):
1442
+ if self.vllm_mode == "server" and self.accelerator.is_main_process:
1443
+ self.vllm_client.update_named_param(name, param.data)
1444
+ elif self.vllm_mode == "colocate":
1445
+
1446
+ pass
1447
+
1448
+ pass
1449
+
1450
+ # Reset cache on vLLM
1451
+ if self.vllm_mode == "server" and self.accelerator.is_main_process:
1452
+ self.vllm_client.reset_prefix_cache()
1453
+ elif self.vllm_mode == "colocate":
1454
+ self.llm.reset_prefix_cache()
1455
+
1456
+ def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
1457
+ """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
1458
+ # For FSDP1, we need to recurse into children and also use summon_full_params
1459
+ if visited is None:
1460
+ visited = set()
1461
+ for child_name, child_module in module.named_children():
1462
+ child_prefix = f"{prefix}.{child_name}" if prefix else child_name
1463
+ self._sync_fsdp1_params_to_vllm(
1464
+ child_module, prefix=child_prefix, visited=visited
1465
+ ) # recurse into the child
1466
+
1467
+ if isinstance(module, FSDP):
1468
+ with FSDP.summon_full_params(module, recurse=False, writeback=False):
1469
+ for param_name, param in module.named_parameters():
1470
+ full_name = f"{prefix}.{param_name}" if prefix else param_name
1471
+ full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
1472
+
1473
+ if full_name in visited:
1474
+ continue # skip FSDP subtrees already traversed
1475
+ visited.add(full_name)
1476
+
1477
+ if self.vllm_mode == "server" and self.accelerator.is_main_process:
1478
+ self.vllm_client.update_named_param(full_name, param.data)
1479
+ elif self.vllm_mode == "colocate":
1480
+
1481
+ pass
1482
+
1483
+ pass
1484
+
1485
+ def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
1486
+ # For FSDP2, module already covers all parameters, so no need for recursion
1487
+ for name, param in module.state_dict().items():
1488
+ if param.is_cpu:
1489
+ param = param.to(torch.device("cuda"))
1490
+ param = param.full_tensor()
1491
+
1492
+ if self.vllm_mode == "server" and self.accelerator.is_main_process:
1493
+ self.vllm_client.update_named_param(name, param)
1494
+ elif self.vllm_mode == "colocate":
1495
+
1496
+ pass
1497
+
1498
+ pass
1499
+
1500
+ def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
1501
+ """Clean parameter names for vLLM compatibility"""
1502
+ extra_prefixes = extra_prefixes or []
1503
+ prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
1504
+ for prefix in prefixes:
1505
+ name = name.replace(prefix, "")
1506
+ return name
1507
+
1508
+ def process_vision_row(
1509
+ self, features: dict[str, Union[list, torch.Tensor]], processing_class=None
1510
+ ) -> dict[str, list[int]]:
1511
+ """
1512
+ Process a vision row for VLM models (adapted from DPO trainer)
1513
+ """
1514
+ processor = processing_class or self.processing_class
1515
+ processed_features = processor(images=[features["image"]], text=features["prompt"], add_special_tokens=False)
1516
+
1517
+ prompt_input_ids = processed_features["input_ids"][0]
1518
+
1519
+ # Create the output dict with required fields
1520
+ output = {
1521
+ "prompt_input_ids": prompt_input_ids,
1522
+ "prompt_attention_mask": processed_features["attention_mask"][0],
1523
+ }
1524
+
1525
+ # Add vision-specific fields
1526
+ if "pixel_values" in processed_features:
1527
+ output["pixel_values"] = processed_features["pixel_values"][0]
1528
+ if "pixel_attention_mask" in processed_features:
1529
+ output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
1530
+ if "image_sizes" in processed_features:
1531
+ output["image_sizes"] = processed_features["image_sizes"][0]
1532
+
1533
+ return output
1534
+
1535
+ def _generate(self, model, prompts, images=None):
1536
+ """Generate completions using the model"""
1537
+ device = next(model.parameters()).device
1538
+ eos_token_id = self.eos_token_id
1539
+ pad_token_id = self.pad_token_id
1540
+
1541
+ # Apply chat template and tokenize the input
1542
+ inputs = [{"prompt": prompt} for prompt in prompts]
1543
+
1544
+ # Add images if provided (VLM support)
1545
+ if images is not None:
1546
+ for i, image in enumerate(images):
1547
+ inputs[i]["image"] = image
1548
+
1549
+ # Apply chat template to get text prompts
1550
+ prompts_text = [maybe_apply_chat_template(x, self.processing_class)["prompt"] for x in inputs]
1551
+
1552
+ # Handle image token collapsing/removal
1553
+ # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
1554
+ # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
1555
+ # image size. We need to handle this properly.
1556
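+ # For example, a prompt rendered as "<image><image><image> Describe this" is collapsed back to "<image> Describe this" before re-processing (the literal token string here is illustrative; the real one comes from decoding image_token_id).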
+ if self.image_token is not None and images is not None:
1557
+ escaped_img_token = re.escape(self.image_token)
1558
+ # Search for the image token in the chat template
1559
+ if hasattr(self.processing_class, "chat_template") and self.processing_class.chat_template:
1560
+ if re.search(escaped_img_token, self.processing_class.chat_template):
1561
+ # Collapse repeated image tokens back into a single token
1562
+ prompts_text = [
1563
+ re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
1564
+ ]
1565
+ else:
1566
+ # If the chat template doesn't use the image token, remove all instances
1567
+ if self.vision_end_token_id is not None:
1568
+ escaped_eoi_token = re.escape(
1569
+ self.processing_class.tokenizer.decode([self.vision_end_token_id])
1570
+ )
1571
+ prompts_text = [
1572
+ re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
1573
+ ]
1574
+ else:
1575
+ # If vision_end_token_id is None, just remove the image tokens
1576
+ prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
1577
+
1578
+ # Prepare kwargs for processing class
1579
+ kwargs = {}
1580
+ if images is not None:
1581
+ kwargs = {"images": [[img] for img in images]}
1582
+
1583
+ # Process inputs using the processing class (handles both VLM and LLM)
1584
+ prompt_inputs = self.processing_class(
1585
+ text=prompts_text,
1586
+ return_tensors="pt",
1587
+ padding=True,
1588
+ padding_side="left",
1589
+ add_special_tokens=False,
1590
+ **kwargs,
1591
+ )
1592
+
1593
+ prompt_inputs = {k: v.to(device) for k, v in prompt_inputs.items()}
1594
+ # Convert vision inputs to model's dtype for proper computation
1595
+ if "pixel_values" in prompt_inputs:
1596
+ # Handle DataParallel wrapped models
1597
+ model_dtype = getattr(model, "dtype", None)
1598
+ if model_dtype is None and hasattr(model, "module"):
1599
+ model_dtype = model.module.dtype
1600
+ if model_dtype is not None:
1601
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].to(model_dtype)
1602
+
1603
+ # Sample 2 completions per prompt of size `max_new_tokens` from the model
1604
+ prompt_ids = prompt_inputs["input_ids"].repeat(2, 1)
1605
+ prompt_mask = prompt_inputs["attention_mask"].repeat(2, 1)
1606
+
1607
+ # Prepare vision inputs if available
1608
+ vision_generation_kwargs = {}
1609
+ if self.is_vision_model and images is not None:
1610
+ if "pixel_values" in prompt_inputs:
1611
+ vision_generation_kwargs["pixel_values"] = prompt_inputs["pixel_values"].repeat(2, 1, 1, 1)
1612
+ if "pixel_attention_mask" in prompt_inputs:
1613
+ vision_generation_kwargs["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"].repeat(2, 1)
1614
+ if "image_sizes" in prompt_inputs:
1615
+ vision_generation_kwargs["image_sizes"] = prompt_inputs["image_sizes"].repeat(2, 1)
1616
+ if "image_grid_thw" in prompt_inputs:
1617
+ vision_generation_kwargs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(2, 1)
1618
+
1619
+ if self.use_transformers_paged:
1620
+ previous_attn = self.model_wrapped.config._attn_implementation
1621
+
1622
+ if is_flash_attn_2_available():
1623
+ self.model_wrapped.config._attn_implementation = "paged_attention"
1624
+ else:
1625
+ self.model_wrapped.config._attn_implementation = "sdpa_paged"
1626
+ with (
1627
+ profiling_context(self, "transformers.generate_batch"),
1628
+ unwrap_model_for_generation(
1629
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1630
+ ) as unwrapped_model,
1631
+ torch.no_grad(),
1632
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
1633
+ ):
1634
+ # Cast to the appropriate dtype based on training configuration
1635
+ if self.args.bf16:
1636
+ unwrapped_model.to(torch.bfloat16)
1637
+ elif self.args.fp16:
1638
+ unwrapped_model.to(torch.float16)
1639
+ with torch.inference_mode():
1640
+ all_outputs = unwrapped_model.generate_batch(
1641
+ prompt_ids.tolist(),
1642
+ generation_config=self.generation_config,
1643
+ progress_bar=False,
1644
+ )
1645
+ unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
1646
+ completion_ids = [output.generated_tokens for output in all_outputs.values()]
1647
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
1648
+ completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
1649
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
1650
+ # Restore the original attention implementation, training mode
1651
+ self.model_wrapped.config._attn_implementation = previous_attn
1652
+
1653
+ # Extract completion_ids and create completion_mask
1654
+ prompt_length = prompt_ids.size(1)
1655
+ completion_ids = prompt_completion_ids[:, prompt_length:]
1656
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
1657
+
1658
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
1659
+ else:
1660
+ # Regular generation path
1661
+ with (
1662
+ profiling_context(self, "transformers.generate"),
1663
+ unwrap_model_for_generation(
1664
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1665
+ ) as unwrapped_model,
1666
+ torch.no_grad(),
1667
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
1668
+ ):
1669
+ # Setup cache implementation if specified
1670
+ if self.args.cache_implementation is not None:
1671
+ unwrapped_model.generation_config.cache_implementation = self.args.cache_implementation
1672
+
1673
+ # Standard generation
1674
+ output = unwrapped_model.generate(
1675
+ input_ids=prompt_ids,
1676
+ attention_mask=prompt_mask,
1677
+ generation_config=self.generation_config,
1678
+ **vision_generation_kwargs,
1679
+ )
1680
+
1681
+ completion_ids = output[:, prompt_ids.size(1) :]
1682
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
1683
+
1684
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
1685
+
1686
+ def _calculate_rewards_from_functions(self, prompts, completions, completion_ids_list, **reward_kwargs):
1687
+ """
1688
+ Calculate rewards using reward functions
1689
+ """
1690
+ device = self.accelerator.device
1691
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
1692
+
1693
+ # Add trainer state to reward kwargs for dynamic reward shaping
1694
+ reward_kwargs["trainer_state"] = self.state
1695
+
1696
+ for i, (reward_func, reward_processing_class) in enumerate(
1697
+ zip(self.reward_funcs, self.reward_processing_classes)
1698
+ ):
1699
+ if isinstance(reward_func, nn.Module): # Model-based reward function
1700
+ # Handle conversational vs text input
1701
+ if is_conversational({"prompt": prompts[0]}):
1702
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
1703
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
1704
+ else:
1705
+ texts = [p + c for p, c in zip(prompts, completions)]
1706
+
1707
+ # Tokenize and get reward scores
1708
+ reward_inputs = reward_processing_class(
1709
+ text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
1710
+ )
1711
+ reward_inputs = {k: v.to(device) for k, v in reward_inputs.items()}
1712
+
1713
+ with torch.inference_mode():
1714
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
1715
+ else:
1716
+ # Custom reward function
1717
+ output_reward_func = reward_func(
1718
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
1719
+ )
1720
+ # Convert None values to NaN
1721
+ output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
1722
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
1723
+
1724
+ # Weight and sum across all reward functions
1725
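+ # nansum ignores the NaN entries, so a reward function that returns None for a sample simply contributes nothing to that sample's total reward.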
+ if self.reward_weights is not None:
1726
+ total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
1727
+ else:
1728
+ total_rewards = rewards_per_func.nansum(dim=1)
1729
+
1730
+ return total_rewards
1731
+
1732
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs=None):
1733
+ # Get the number of tokens to truncate from prompt
1734
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
1735
+
1736
+ # Truncate left to avoid oom
1737
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
1738
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
1739
+
1740
+ # Concat the prompt and completion
1741
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
1742
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
1743
+
1744
+ # Prepare model kwargs with vision inputs if available
1745
+ model_kwargs = {"attention_mask": prompt_completion_mask}
1746
+ if vision_inputs is not None:
1747
+ if "pixel_values" in vision_inputs:
1748
+ model_kwargs["pixel_values"] = vision_inputs["pixel_values"]
1749
+ if "pixel_attention_mask" in vision_inputs:
1750
+ model_kwargs["pixel_attention_mask"] = vision_inputs["pixel_attention_mask"]
1751
+ if "image_sizes" in vision_inputs:
1752
+ model_kwargs["image_sizes"] = vision_inputs["image_sizes"]
1753
+ if "image_grid_thw" in vision_inputs:
1754
+ model_kwargs["image_grid_thw"] = vision_inputs["image_grid_thw"]
1755
+
1756
+ # Get the logprobs of the completions from the model
1757
+ output = model(prompt_completion_ids, **model_kwargs)
1758
+
1759
+ # There is 1 offset, because the model predicts the next token
1760
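+ # For example, with a 3-token prompt and 2 completion tokens, logits[:, 2:-1] are the two distributions that predict the completion tokens.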
+ prompt_len = prompt_ids.size(1)
1761
+ start_idx = prompt_len - 1 if prompt_len > 0 else 0
1762
+ # Only slice off the last logit when we have a prompt, otherwise we need all logits
1763
+ end_idx = -1 if prompt_len > 0 else None
1764
+ logits = output.logits[:, start_idx:end_idx]
1765
+
1766
+ # Take the completion tokens logprob
1767
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
1768
+ return logprobs
1769
+
1770
+ def training_step(
1771
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
1772
+ ) -> torch.Tensor:
1773
+ model.train()
1774
+
1775
+ prompts = inputs["prompt"]
1776
+ batch_size = len(prompts)
1777
+
1778
+ # Handle images for VLM support
1779
+ has_images = "image" in inputs
1780
+ images = None
1781
+ if has_images:
1782
+ images = inputs["image"]
1783
+ # Convert conversational prompts to include image tokens
1784
+ for prompt in prompts:
1785
+ if isinstance(prompt, list):
1786
+ for message in prompt:
1787
+ if not isinstance(message, dict):
1788
+ continue
1789
+ content = message.get("content")
1790
+ role = message.get("role")
1791
+ if isinstance(content, str):
1792
+ if role == "user":
1793
+ message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
1794
+ elif role == "system":
1795
+ message["content"] = [{"type": "text", "text": content}]
1796
+
1797
+ if self.args.use_vllm:
1798
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(prompts, images)
1799
+ else:
1800
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images)
1801
+
1802
+ contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1)
1803
+
1804
+ # Extract vision inputs if available for VLM support
1805
+ vision_inputs = None
1806
+ if has_images and self.is_vision_model and not self.args.use_vllm:
1807
+ # For vision models with transformers generation, we need to prepare vision inputs
1808
+ # Process the images to get vision inputs that can be passed through the forward pass
1809
+ vision_inputs = {}
1810
+ kwargs = {"images": [[img] for img in images]}
1811
+ processed = self.processing_class(
1812
+ text=[""] * len(images), # Dummy text for vision processing
1813
+ return_tensors="pt",
1814
+ **kwargs,
1815
+ )
1816
+ # Handle DataParallel wrapped models
1817
+ model_device = getattr(model, "device", None)
1818
+ model_dtype = getattr(model, "dtype", None)
1819
+ if model_device is None and hasattr(model, "module"):
1820
+ model_device = model.module.device
1821
+ model_dtype = model.module.dtype
1822
+ # Move vision tensors to device and convert to model dtype
1823
+ # Need to duplicate for 2 completions per prompt
1824
+ if "pixel_values" in processed:
1825
+ vision_inputs["pixel_values"] = (
1826
+ processed["pixel_values"].to(model_device, dtype=model_dtype).repeat(2, 1, 1, 1)
1827
+ )
1828
+ if "pixel_attention_mask" in processed:
1829
+ vision_inputs["pixel_attention_mask"] = processed["pixel_attention_mask"].to(model_device).repeat(2, 1)
1830
+ if "image_sizes" in processed:
1831
+ vision_inputs["image_sizes"] = processed["image_sizes"].to(model_device).repeat(2, 1)
1832
+ if "image_grid_thw" in processed:
1833
+ vision_inputs["image_grid_thw"] = processed["image_grid_thw"].to(model_device).repeat(2, 1)
1834
+
1835
+ logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs)
1836
+ with torch.no_grad():
1837
+ if self.ref_model is not None:
1838
+ ref_logprobs = self._forward(
1839
+ self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs
1840
+ )
1841
+ else: # peft case: we just need to disable the adapter
1842
+ with self.model.disable_adapter():
1843
+ ref_logprobs = self._forward(
1844
+ self.model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs
1845
+ )
1846
+
1847
+ # Decode the completions, and format them if the input is conversational
1848
+ device = logprobs.device
1849
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1850
+ if is_conversational({"prompt": prompts[0]}):
1851
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
1852
+
1853
+ # Get the reward from reward functions, judge, or deprecated reward_model
1854
+ if self.reward_funcs is not None:
1855
+ # First create completion_ids_list for custom reward functions
1856
+ completion_ids_list = [completion_ids[i].tolist() for i in range(completion_ids.shape[0])]
1857
+
1858
+ # Extract additional fields from inputs for reward functions
1859
+ reward_kwargs = {}
1860
+ keys = [key for key in inputs if key not in ["prompt"]]
1861
+ for key in keys:
1862
+ if isinstance(inputs[key], (list, tuple)):
1863
+ # Repeat input fields to match number of completions (2 per prompt)
1864
+ reward_kwargs[key] = inputs[key] * 2
1865
+ else:
1866
+ reward_kwargs[key] = inputs[key]
1867
+
1868
+ # Calculate rewards using reward functions
1869
+ rewards = self._calculate_rewards_from_functions(
1870
+ prompts=2 * prompts, completions=completions, completion_ids_list=completion_ids_list, **reward_kwargs
1871
+ )
1872
+
1873
+ # Apply missing EOS penalty if configured
1874
+ if self.args.missing_eos_penalty is not None:
1875
+ rewards[~contain_eos_token] -= self.args.missing_eos_penalty
1876
+
1877
+ # Split rewards into chosen/rejected pairs
1878
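+ # first_half holds the reward of each prompt's first completion and second_half the second; mask is True where the first completion scores at least as high, making it the chosen sample.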
+ first_half, second_half = rewards.split(batch_size)
1879
+ mask = first_half >= second_half
1880
+ elif self.judge is not None:
1881
+ # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
1882
+ # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
1883
+ # independent of the model's chat template, we use the raw conversation data, and apply our own chat
1884
+ # template to it.
1885
+ if is_conversational({"prompt": prompts[0]}):
1886
+ environment = jinja2.Environment()
1887
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
1888
+ prompts = [template.render(messages=prompt) for prompt in prompts]
1889
+ completions = [template.render(messages=completion) for completion in completions]
1890
+
1891
+ ranks_of_first_completion = self.judge.judge(
1892
+ prompts, list(zip(completions[:batch_size], completions[batch_size:]))
1893
+ )
1894
+
1895
+ # convert ranks to a True/False mask:
1896
+ # when rank == 0, it means the first completion is the best
1897
+ # when rank == 1, it means the second completion is the best
1898
+ mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
1899
+
1900
+ batch_range = torch.arange(batch_size, device=device)
1901
+ chosen_indices = batch_range + (~mask * batch_size)
1902
+ rejected_indices = batch_range + (mask * batch_size)
1903
+
1904
+ # Build tensor so that the first half is the chosen examples and the second half the rejected examples
1905
+ cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
1906
+ cr_logprobs = logprobs[cr_indices]
1907
+ cr_ref_logprobs = ref_logprobs[cr_indices]
1908
+
1909
+ # mask out the padding tokens
1910
+ padding_mask = ~completion_mask.bool()
1911
+ cr_padding_mask = padding_mask[cr_indices]
1912
+
1913
+ cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
1914
+ cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
1915
+
1916
+ # Split the chosen and rejected examples
1917
+ chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
1918
+ chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
1919
+ pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
1920
+ ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
1921
+
1922
+ logits = pi_logratios - ref_logratios
1923
+
1924
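+ # logits is the margin between the policy and reference log-ratios of chosen vs. rejected; "sigmoid" gives the standard DPO loss -log(sigmoid(beta * logits)), "ipo" gives (logits - 1/(2*beta))**2.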
+ if self.args.loss_type == "sigmoid":
1925
+ losses = -F.logsigmoid(self.beta * logits)
1926
+ elif self.args.loss_type == "ipo":
1927
+ losses = (logits - 1 / (2 * self.beta)) ** 2
1928
+ else:
1929
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
1930
+
1931
+ loss = losses.mean()
1932
+
1933
+ # Log everything
1934
+ if self.reward_funcs is not None:
1935
+ # When using reward_funcs, we have rewards instead of scores
1936
+ scores_margin = rewards[chosen_indices] - rewards[rejected_indices]
1937
+ self.stats["objective/scores_margin"].append(
1938
+ self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
1939
+ )
1940
+ self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(rewards.mean()).mean().item())
1941
+ self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
1942
+ self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
1943
+ self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
1944
+
1945
+ kl = logprobs - ref_logprobs
1946
+ mean_kl = kl.sum(1).mean()
1947
+ self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1948
+ non_score_reward = (-self.beta * kl).sum(1)
1949
+ mean_non_score_reward = non_score_reward.mean()
1950
+ self.stats["objective/non_score_reward"].append(
1951
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
1952
+ )
1953
+ if self.reward_funcs is not None:
1954
+ # Calculate RLHF reward by combining rewards with non_score_reward
1955
+ rlhf_reward = rewards + non_score_reward
1956
+ self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
1957
+
1958
+ mean_entropy = -logprobs.sum(1).mean()
1959
+ self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
1960
+ chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
1961
+ gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
1962
+ self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
1963
+ rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
1964
+ gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
1965
+ self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
1966
+ margin = gathered_chosen_rewards - gathered_rejected_rewards
1967
+ self.stats["rewards/margins"].append(margin.mean().item())
1968
+ accuracy = margin > 0
1969
+ self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
1970
+ self.stats["beta"].append(self.beta)
1971
+
1972
+ if (
1973
+ self.args.torch_empty_cache_steps is not None
1974
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
1975
+ ):
1976
+ empty_cache()
1977
+
1978
+ kwargs = {}
1979
+
1980
+ # For LOMO optimizers you need to explicitly use the learning rate
1981
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
1982
+ kwargs["learning_rate"] = self._get_learning_rate()
1983
+
1984
+ if self.args.n_gpu > 1:
1985
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
1986
+
1987
+ self.accelerator.backward(loss, **kwargs)
1988
+
1989
+ return loss.detach() / self.args.gradient_accumulation_steps
1990
+
1991
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
1992
+ def _maybe_log_save_evaluate(
1993
+ self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
1994
+ ):
1995
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
1996
+ logs: dict[str, float] = {}
1997
+
1998
+ # all_gather + mean() to get average loss over all processes
1999
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
2000
+
2001
+ # reset tr_loss to zero
2002
+ tr_loss -= tr_loss
2003
+
2004
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2005
+ if grad_norm is not None:
2006
+ logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
2007
+ if learning_rate is not None:
2008
+ logs["learning_rate"] = learning_rate
2009
+ else:
2010
+ logs["learning_rate"] = self._get_learning_rate()
2011
+
2012
+ # Add our metrics
2013
+ for key, val in self.stats.items():
2014
+ logs[key] = sum(val) / len(val)
2015
+ self.stats = {key: [] for key in self.stats} # reset stats
2016
+
2017
+ self._total_loss_scalar += tr_loss_scalar
2018
+ self._globalstep_last_logged = self.state.global_step
2019
+ self.store_flos()
2020
+ self.log(logs, start_time)
2021
+
2022
+ metrics = None
2023
+ if self.control.should_evaluate:
2024
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
2025
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
2026
+
2027
+ if self.args.save_strategy == "best":
2028
+ self.control.should_save = is_new_best_metric
2029
+
2030
+ if self.control.should_save:
2031
+ self._save_checkpoint(model, trial)
2032
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
2033
+
2034
+ # Ensure the model card is saved along with the checkpoint
2035
+ def _save_checkpoint(self, model, trial):
2036
+ if self.args.hub_model_id is None:
2037
+ model_name = Path(self.args.output_dir).name
2038
+ else:
2039
+ model_name = self.args.hub_model_id.split("/")[-1]
2040
+ self.create_model_card(model_name=model_name)
2041
+ super()._save_checkpoint(model, trial)
2042
+ class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
2043
+ """
2044
+
2045
+ Initialize OnlineDPOTrainer.
2046
+
2047
+ Args:
2048
+ model (`Union[str, nn.Module, PreTrainedModel]`):
2049
+ Model to be trained. Can be either:
2050
+
2051
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
2052
+ path to a *directory* containing model weights saved using
2053
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
2054
+ using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
2055
+ `args.model_init_kwargs`.
2056
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
2057
+ ref_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `None`):
2058
+ The reference model to use for training. If None is specified, the reference model will be created from the
2059
+ model.
2060
+ judge ([`BasePairwiseJudge`]):
2061
+ The judge to use for pairwise comparison of model completions.
2062
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*):
2063
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
2064
+ functions with the prompts and completions and sum the rewards. Can be either:
2065
+
2066
+ - A single reward function: Can be a string (path to model), a [`~transformers.PreTrainedModel`], or a
2067
+ custom callable function.
2068
+ - A list of reward functions: Must all be of compatible types.
2069
+
2070
+ Note: Only one of `judge` or `reward_funcs` should be provided.
2071
+ args ([`OnlineDPOConfig`]):
2072
+ The online DPO config arguments to use for training.
2073
+ data_collator ([`~transformers.DataCollator`]):
2074
+ The data collator to use for training. If None is specified, the default data collator
2075
+ ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
2076
+ sequences in the batch, given a dataset of paired sequences.
2077
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
2078
+ The dataset to use for training.
2079
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
2080
+ The dataset to use for evaluation.
2081
+ processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
2082
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
2083
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
2084
+ reuse the fine-tuned model.
2085
+ reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
2086
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
2087
+
2088
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
2089
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
2090
+
2091
+ If set to `None`, the tokenizer for each model-based reward function is automatically loaded using
2092
+ [`~transformers.AutoTokenizer.from_pretrained`].
2093
+ peft_config ([`~peft.PeftConfig`], *optional*):
2094
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
2095
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
2096
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping
2097
+ metric names to metric values.
2098
+ callbacks (`list[transformers.TrainerCallback]`):
2099
+ The callbacks to use for training.
2100
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
2101
+ The optimizer and scheduler to use for training.
2102
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
2103
+ The function to use to preprocess the logits before computing the metrics.
2104
+
2105
+ reward_model:
2106
+
2107
+ <Deprecated version="0.22.0">
2108
+
2109
+ This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
2110
+
2111
+ </Deprecated>
2112
+
2113
+ """
2114
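+ # Minimal usage sketch (names below are illustrative; assumes a prepared model, tokenizer, judge and a prompt-only dataset):
+ #   trainer = UnslothOnlineDPOTrainer(model = model, judge = judge, args = UnslothOnlineDPOConfig(output_dir = "outputs"),
+ #       train_dataset = dataset, processing_class = tokenizer)
+ #   trainer.train()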
+ def __init__(
2115
+ self,
2116
+ model,
2117
+ ref_model = None,
2118
+ reward_funcs = None,
2119
+ judge = None,
2120
+ args = None,
2121
+ data_collator = None,
2122
+ train_dataset = None,
2123
+ eval_dataset = None,
2124
+ processing_class = None,
2125
+ reward_processing_classes = None,
2126
+ peft_config = None,
2127
+ compute_metrics = None,
2128
+ callbacks = None,
2129
+ preprocess_logits_for_metrics = None,
2130
+ reward_model = None,
2131
+ reward_processing_class = None,
2132
+ **kwargs
2133
+ ):
2134
+ if args is None: args = UnslothOnlineDPOConfig()
2135
+ use_bf16 = getattr(args, 'bf16', False)
2136
+ if type(use_bf16) is not bool: use_bf16 = False
2137
+ use_fp16 = getattr(args, 'fp16', False)
2138
+ if type(use_fp16) is not bool: use_fp16 = False
2139
+ force_float32 = False
2140
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
2141
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
2142
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
2143
+ force_float32 = True
2144
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
2145
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
2146
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
2147
+ from unsloth_zoo.utils import _get_dtype
2148
+ dtype = _get_dtype(dtype)
2149
+ float16 = dtype == torch.float16
2150
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
2151
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
2152
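+ # Precision selection below: forced float32 disables fp16/bf16 entirely; otherwise mixed precision follows the model dtype (float16 models autocast in fp16, bfloat16 models in bf16), and full bfloat16 finetuning runs without autocast.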
+ if force_float32:
2153
+ # Forced float32 training
2154
+ args.fp16 = False
2155
+ args.bf16 = False
2156
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
2157
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
2158
+ # args.mixed_precision is a new argument which needs to be set now
2159
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
2160
+ # Mixed precision training
2161
+ args.fp16 = float16
2162
+ args.bf16 = not float16
2163
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
2164
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
2165
+ # args.mixed_precision is a new argument which needs to be set now
2166
+ elif mixed_precision_dtype == 'bfloat16':
2167
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
2168
+ args.fp16 = False
2169
+ args.bf16 = False
2170
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
2171
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
2172
+ # args.mixed_precision is a new argument which needs to be set now
2173
+
2174
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
2175
+ args.eval_strategy = 'steps'
2176
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
2177
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
2178
+ if ga_steps is not None and ga_steps > 1:
2179
+ from transformers import __version__ as transformers_version
2180
+ if Version(transformers_version) <= Version('4.45.2'):
2181
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
2182
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
2183
+ if getattr(args, 'eval_strategy', 'no') != 'no':
2184
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
2185
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
2186
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
2187
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
2188
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
2189
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
2190
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
2191
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
2192
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
2193
+ if force_float32:
2194
+ args.bf16_full_eval = False
2195
+ args.fp16_full_eval = False
2196
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
2197
+ args.bf16_full_eval = True
2198
+ args.fp16_full_eval = False
2199
+ elif not bf16_full_eval and not fp16_full_eval:
2200
+ args.bf16_full_eval = args.bf16
2201
+ args.fp16_full_eval = args.fp16
2202
+ _output_logits = False
2203
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
2204
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
2205
+ if _output_logits:
2206
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
2207
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
2208
+ pass
2209
+ else:
2210
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
2211
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
2212
+ if args_max_seq_length is None and model_max_seq_length is not None:
2213
+ max_seq_length = model.max_seq_length
2214
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
2215
+ if model is not None and hasattr(model, 'for_training'):
2216
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
2217
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
2218
+ if 'processing_class' in locals():
2219
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
2220
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
2221
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
2222
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
2223
+ if not isinstance(data_collator, UnslothVisionDataCollator):
2224
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
2225
+ data_collator = TransformersDataCollatorForLanguageModeling(
2226
+ __tokenizer,
2227
+ mlm = False,
2228
+ mlm_probability = 0.0,
2229
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
2230
+ )
2231
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
2232
+ data_collator = DataCollatorForSeq2Seq(
2233
+ __tokenizer,
2234
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
2235
+ )
2236
+ else:
2237
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
2238
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
2239
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
2240
+ if not isinstance(data_collator, UnslothVisionDataCollator):
2241
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
2242
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
2243
+ data_collator = DataCollatorForSeq2Seq(
2244
+ __tokenizer.tokenizer,
2245
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
2246
+ )
2247
+ else:
2248
+ data_collator = TransformersDataCollatorForLanguageModeling(
2249
+ __tokenizer.tokenizer,
2250
+ mlm = False,
2251
+ mlm_probability = 0.0,
2252
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
2253
+ )
2254
+ other_metrics = []
2255
+
2256
+ from unsloth_zoo.logging_utils import PatchRLStatistics
2257
+ PatchRLStatistics('online_dpo_trainer', other_metrics)
2258
+
2259
+ # [TODO] Fix up DataParallel multiplying batch sizes
2260
+ # [TODO] DDP works, but DP seems to not work? [TODO]
2261
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
2262
+ if getattr(args, "_n_gpu", 1) != 1:
2263
+ args._n_gpu = 1
2264
+ if "model" in locals() and hasattr(model, "for_training"):
2265
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
2266
+ super().__init__(
2267
+ model = model,
2268
+ ref_model = ref_model,
2269
+ reward_funcs = reward_funcs,
2270
+ judge = judge,
2271
+ args = args,
2272
+ data_collator = data_collator,
2273
+ train_dataset = train_dataset,
2274
+ eval_dataset = eval_dataset,
2275
+ processing_class = processing_class,
2276
+ reward_processing_classes = reward_processing_classes,
2277
+ peft_config = peft_config,
2278
+ compute_metrics = compute_metrics,
2279
+ callbacks = callbacks,
2280
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
2281
+ reward_model = reward_model,
2282
+ reward_processing_class = reward_processing_class, **kwargs)
2283
+ if "model" in locals() and hasattr(model, "for_inference"):
2284
+ model.for_inference()
2285
+ if hasattr(self, 'neftune_hook_handle'):
2286
+ self.neftune_hook_handle.remove()
2287
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
2288
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
2289
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
2290
+ pass
2291
+ if hasattr(self, 'accelerator'):
2292
+ scaler = self.accelerator.scaler
2293
+ current_model = model
2294
+ while hasattr(current_model, 'model'):
2295
+ current_model.accelerator_scaler = scaler
2296
+ current_model = current_model.model
2297
+ current_model.accelerator_scaler = scaler
2298
+ pass
2299
+ if hasattr(self, 'train'):
2300
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
2301
+ pass
2302
+
2303
+ pass
2304
+
2305
+
2306
+ if hasattr(logger, "addFilter"):
2307
+ import logging
2308
+ class HideLoggingMessage(logging.Filter):
2309
+ def __init__(self, text): self.text = text
2310
+ def filter(self, x): return not (self.text in x.getMessage())
2311
+ pass
2312
+ logger.addFilter(HideLoggingMessage("`use_cache=True`"))
2313
+