cciwon-code-review-cli 2.0.2 → 2.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
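For readers who want to verify or explore the comparison themselves, both published versions can be fetched from the public registry and diffed locally. The following is a minimal sketch, assuming the standard npm registry tarball URL layout (https://registry.npmjs.org/<name>/-/<name>-<version>.tgz) and tarballs small enough to hold in memory; the chosen file path is illustrative, not part of any registry tooling.

  # Minimal sketch (illustrative only): fetch both published versions and compare them locally.
  import io
  import tarfile
  import urllib.request
  from difflib import unified_diff

  def fetch_package(name, version):
      """Download one published version and return {member path: file bytes}."""
      url = f"https://registry.npmjs.org/{name}/-/{name}-{version}.tgz"
      with urllib.request.urlopen(url) as response:
          data = response.read()
      files = {}
      with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
          for member in tar.getmembers():
              if member.isfile():
                  files[member.name] = tar.extractfile(member).read()
      return files

  old = fetch_package("cciwon-code-review-cli", "2.0.2")
  new = fetch_package("cciwon-code-review-cli", "2.0.3")

  # Paths added, removed, or modified between the two versions.
  added = sorted(set(new) - set(old))
  removed = sorted(set(old) - set(new))
  changed = sorted(p for p in set(old) & set(new) if old[p] != new[p])
  print(f"added: {len(added)}  removed: {len(removed)}  changed: {len(changed)}")

  # Line-level diff for one text file present in both versions.
  path = "package/package.json"
  diff_lines = unified_diff(
      old[path].decode("utf-8", errors="replace").splitlines(),
      new[path].decode("utf-8", errors="replace").splitlines(),
      fromfile=f"{path}@2.0.2",
      tofile=f"{path}@2.0.3",
      lineterm="",
  )
  print("\n".join(diff_lines))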
Files changed (110)
  1. package/lib/chat-mode.js +7 -2
  2. package/package.json +1 -1
  3. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  4. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  39. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  40. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  41. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  42. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  44. package/unsloth_compiled_cache/Conv1d.py +70 -0
  45. package/unsloth_compiled_cache/Conv2d.py +70 -0
  46. package/unsloth_compiled_cache/Conv3d.py +70 -0
  47. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  48. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  49. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  50. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  51. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  52. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  53. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  54. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  55. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  56. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  57. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  58. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  59. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  60. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  61. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  62. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  63. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  64. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  65. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  66. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  67. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  68. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  69. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  70. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  71. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  72. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  73. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  74. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
@@ -0,0 +1,1255 @@
1
+ """
2
+ 2025.12.6
3
+ 2025.12.7
4
+ 4.57.1
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
30
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
31
+
32
+
33
+ import os
34
+ from typing import *
35
+ from dataclasses import dataclass, field
36
+ from packaging.version import Version
37
+ import torch
38
+ import numpy as np
39
+ from contextlib import nullcontext
40
+ from torch.nn import functional as F
41
+ import inspect
42
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
43
+ from transformers.training_args import ParallelMode
44
+
45
+ # Wrap trainer with padding to right and enable training mode
46
+ # Also patches W&B since multiple runs must use wandb.finish()
47
+ import functools
48
+ from types import MethodType
49
+ def prepare_for_training_mode(f):
50
+ @functools.wraps(f)
51
+ def wrapper(self, *args, **kwargs):
52
+ # Enable training mode
53
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
54
+ self.model.for_training()
55
+ output = f(self, *args, **kwargs)
56
+ # Return inference mode
57
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
58
+ self.model.for_inference()
59
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
60
+ try:
61
+ import wandb
62
+ wandb.finish()
63
+ except Exception:
64
+ pass
65
+ return output
66
+ return wrapper
67
+ pass
68
+
69
+ torch_compile_options = {
70
+ "epilogue_fusion" : True,
71
+ "max_autotune" : False,
72
+ "shape_padding" : True,
73
+ "trace.enabled" : False,
74
+ "triton.cudagraphs" : False,
75
+ }
76
+
77
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
78
+ def chunked_selective_log_softmax(logits, index):
79
+ # Split into 4 chunks only
80
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
81
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
82
+ all_per_token_logps = []
83
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
84
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
85
+ chunk_logits = chunk_logits.to(torch.float32)
86
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
87
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
88
+ per_token_logps = selected_logits - logsumexp_values
89
+ all_per_token_logps.append(per_token_logps)
90
+ pass
91
+ all_per_token_logps = torch.concat(all_per_token_logps)
92
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
93
+ return all_per_token_logps
94
+
95
+ def calculate_pad_tokens_in_prompt(
96
+ input_ids: torch.Tensor,
97
+ logits_to_keep: int,
98
+ pad_token_id: int
99
+ ) -> torch.Tensor:
100
+ """
101
+ Given a prompt tensor, returns the number of left-padding tokens in each sequence, e.g. [pad, pad, pad, cat] -> 3 tokens
102
+ """
103
+ if logits_to_keep >= input_ids.shape[1]:
104
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
105
+
106
+ prompt_section = input_ids[:, :-logits_to_keep]
107
+
108
+ padding_mask = (prompt_section == pad_token_id)
109
+
110
+ pad_token_counts = padding_mask.sum(dim=1)
111
+
112
+ return pad_token_counts
113
+
114
+ def create_completion_attention_mask(
115
+ completion_input_ids: torch.Tensor,
116
+ left_pad_tokens_per_prompt: torch.Tensor,
117
+ max_left_pad: int,
118
+ pad_token_id: int
119
+ ) -> torch.Tensor:
120
+ """
121
+ Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
122
+
123
+ Where p are extra prompt tokens left over from slicing the tensor, c are completion tokens,
123
+ and pad are pad tokens. This function builds a completion mask that zeroes out the pad
124
+ and p tokens, so in this example: [0,0,0,1,1,1,0,0,0]
126
+ """
127
+ batch_size, completion_len = completion_input_ids.shape
128
+ device = completion_input_ids.device
129
+
130
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
131
+
132
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
133
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
134
+
135
+ non_padding_mask = (completion_input_ids != pad_token_id)
136
+
137
+ final_mask = shift_mask & non_padding_mask
138
+
139
+ return final_mask
140
+
141
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
142
+ """
143
+ Moves all padding tokens in each sequence of a batch to the right.
144
+ """
145
+ mask = (tensor != pad_id)
146
+ # Must use stable=True since the binary mask is unordered
147
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
148
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
149
+ return packed_tensor
150
+
151
+ def align_logprobs_with_mask(
152
+ logprob_tensor: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ pad_value: float = 0.0
155
+ ) -> torch.Tensor:
156
+ """
157
+ Aligns a log probability tensor with a given attention mask.
158
+ """
159
+
160
+ device = logprob_tensor.device
161
+ batch_size, logprob_seq_len = logprob_tensor.shape
162
+ mask_seq_len = attention_mask.shape[1]
163
+
164
+ padded_logprobs = torch.full(
165
+ attention_mask.shape,
166
+ fill_value=pad_value,
167
+ dtype=logprob_tensor.dtype,
168
+ device=device
169
+ )
170
+
171
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
172
+
173
+ cols = torch.arange(logprob_seq_len, device=device)
174
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
175
+
176
+ # Create destination row indices
177
+ # Shape: [batch_size, logprob_seq_len]
178
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
179
+
180
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
181
+ # Create a mask to identify only the indices that are within the bounds
182
+ # of the target tensor's sequence length.
183
+ valid_mask = dest_indices < mask_seq_len
184
+
185
+ # Use this mask to select only the valid row indices, column indices,
186
+ # and the corresponding values from the logprob tensor.
187
+ # This flattens the selected elements into 1D tensors.
188
+ valid_rows = row_indices[valid_mask]
189
+ valid_cols = dest_indices[valid_mask]
190
+ valid_vals = logprob_tensor[valid_mask]
191
+
192
+ # Place the valid values into their correct positions in the padded tensor
193
+ # using a single, efficient advanced indexing operation.
194
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
195
+
196
+ return padded_logprobs
197
+ @dataclass
198
+ class UnslothXPOConfig(XPOConfig):
199
+ """
200
+
201
+ Configuration class for the [`XPOTrainer`].
202
+
203
+ Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
204
+
205
+ Parameters:
206
+ alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
207
+ Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
208
+ and the last alpha is used for the rest of the epochs.
209
+
210
+ """
211
+ vllm_sampling_params: Optional[Any] = field(
212
+ default = None,
213
+ metadata = {'help': 'vLLM SamplingParams'},
214
+ )
215
+ unsloth_num_chunks : Optional[int] = field(
216
+ default = -1,
217
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
218
+ )
219
+ max_seq_length : Optional[int] = field(
220
+ default = None,
221
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
222
+ )
223
+ def __init__(
224
+ self,
225
+ output_dir = None,
226
+ overwrite_output_dir = None,
227
+ do_train = False,
228
+ do_eval = False,
229
+ do_predict = False,
230
+ eval_strategy = 'no',
231
+ prediction_loss_only = False,
232
+ per_device_train_batch_size = 4,
233
+ per_device_eval_batch_size = 4,
234
+ per_gpu_train_batch_size = None,
235
+ per_gpu_eval_batch_size = None,
236
+ gradient_accumulation_steps = 2,
237
+ eval_accumulation_steps = 2,
238
+ eval_delay = 0,
239
+ torch_empty_cache_steps = 250,
240
+ learning_rate = 5e-05,
241
+ weight_decay = 0.01,
242
+ adam_beta1 = 0.9,
243
+ adam_beta2 = 0.999,
244
+ adam_epsilon = 1e-08,
245
+ max_grad_norm = 1.0,
246
+ num_train_epochs = 3.0,
247
+ max_steps = -1,
248
+ lr_scheduler_type = 'linear',
249
+ warmup_ratio = 0.1,
250
+ warmup_steps = 0,
251
+ log_level = 'passive',
252
+ log_level_replica = 'warning',
253
+ log_on_each_node = True,
254
+ logging_dir = None,
255
+ logging_strategy = 'steps',
256
+ logging_first_step = False,
257
+ logging_steps = 1,
258
+ logging_nan_inf_filter = False,
259
+ save_strategy = 'steps',
260
+ save_steps = 500,
261
+ save_total_limit = None,
262
+ save_safetensors = True,
263
+ save_on_each_node = False,
264
+ save_only_model = False,
265
+ restore_callback_states_from_checkpoint = False,
266
+ no_cuda = False,
267
+ use_cpu = False,
268
+ use_mps_device = False,
269
+ seed = 3407,
270
+ data_seed = 3407,
271
+ jit_mode_eval = False,
272
+ bf16 = False,
273
+ fp16 = False,
274
+ fp16_opt_level = 'O1',
275
+ half_precision_backend = 'auto',
276
+ bf16_full_eval = False,
277
+ fp16_full_eval = False,
278
+ tf32 = None,
279
+ local_rank = -1,
280
+ ddp_backend = None,
281
+ tpu_num_cores = None,
282
+ tpu_metrics_debug = False,
283
+ debug = '',
284
+ dataloader_drop_last = False,
285
+ eval_steps = None,
286
+ dataloader_num_workers = 0,
287
+ dataloader_prefetch_factor = None,
288
+ past_index = -1,
289
+ run_name = None,
290
+ disable_tqdm = None,
291
+ remove_unused_columns = True,
292
+ label_names = None,
293
+ load_best_model_at_end = False,
294
+ metric_for_best_model = None,
295
+ greater_is_better = None,
296
+ ignore_data_skip = False,
297
+ fsdp = None,
298
+ fsdp_min_num_params = 0,
299
+ fsdp_config = None,
300
+ fsdp_transformer_layer_cls_to_wrap = None,
301
+ accelerator_config = None,
302
+ parallelism_config = None,
303
+ deepspeed = None,
304
+ label_smoothing_factor = 0.0,
305
+ optim = 'adamw_8bit',
306
+ optim_args = None,
307
+ adafactor = False,
308
+ group_by_length = False,
309
+ length_column_name = 'length',
310
+ report_to = 'none',
311
+ project = 'huggingface',
312
+ trackio_space_id = 'trackio',
313
+ ddp_find_unused_parameters = None,
314
+ ddp_bucket_cap_mb = None,
315
+ ddp_broadcast_buffers = None,
316
+ dataloader_pin_memory = True,
317
+ dataloader_persistent_workers = False,
318
+ skip_memory_metrics = True,
319
+ use_legacy_prediction_loop = False,
320
+ push_to_hub = False,
321
+ resume_from_checkpoint = None,
322
+ hub_model_id = None,
323
+ hub_strategy = 'every_save',
324
+ hub_token = None,
325
+ hub_private_repo = None,
326
+ hub_always_push = False,
327
+ hub_revision = None,
328
+ gradient_checkpointing = True,
329
+ gradient_checkpointing_kwargs = None,
330
+ include_inputs_for_metrics = False,
331
+ eval_do_concat_batches = True,
332
+ fp16_backend = 'auto',
333
+ push_to_hub_model_id = None,
334
+ push_to_hub_organization = None,
335
+ push_to_hub_token = None,
336
+ mp_parameters = '',
337
+ auto_find_batch_size = False,
338
+ full_determinism = False,
339
+ torchdynamo = None,
340
+ ray_scope = 'last',
341
+ ddp_timeout = 1800,
342
+ torch_compile = False,
343
+ torch_compile_backend = None,
344
+ torch_compile_mode = None,
345
+ include_tokens_per_second = False,
346
+ include_num_input_tokens_seen = False,
347
+ neftune_noise_alpha = None,
348
+ optim_target_modules = None,
349
+ batch_eval_metrics = False,
350
+ eval_on_start = False,
351
+ use_liger_kernel = False,
352
+ liger_kernel_config = None,
353
+ eval_use_gather_object = False,
354
+ average_tokens_across_devices = True,
355
+ reward_model_path = None,
356
+ judge = None,
357
+ max_new_tokens = 64,
358
+ max_length = 512,
359
+ temperature = 0.9,
360
+ top_p = 1.0,
361
+ top_k = None,
362
+ min_p = None,
363
+ repetition_penalty = 1.0,
364
+ generation_kwargs = {},
365
+ use_transformers_paged = False,
366
+ cache_implementation = None,
367
+ missing_eos_penalty = None,
368
+ loss_type = 'sigmoid',
369
+ disable_dropout = True,
370
+ use_vllm = False,
371
+ vllm_model_impl = 'vllm',
372
+ vllm_guided_decoding_regex = None,
373
+ vllm_gpu_memory_utilization = 0.55,
374
+ vllm_mode = 'colocate',
375
+ vllm_server_base_url = None,
376
+ vllm_server_host = '0.0.0.0',
377
+ vllm_server_port = 8000,
378
+ vllm_server_timeout = 240.0,
379
+ vllm_tensor_parallel_size = 1,
380
+ ds3_gather_for_generation = True,
381
+ model_init_kwargs = None,
382
+ reward_weights = None,
383
+ dataset_num_proc = None,
384
+ gpu_memory_utilization = None,
385
+ vllm_sampling_params = None,
386
+ unsloth_num_chunks = -1,
387
+ max_seq_length = None,
388
+ **kwargs,
389
+ ):
390
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
391
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
392
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
393
+ output_dir = 'unsloth_training_checkpoints'
394
+ save_strategy = 'no'
395
+ if dataset_num_proc is None:
396
+ from multiprocessing import cpu_count
397
+ dataset_num_proc = min(max(cpu_count()+4, 2), 64)
398
+ if temperature <= 0:
399
+ raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
400
+ elif temperature >= 10:
401
+ raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
402
+
403
+
404
+ super().__init__(
405
+ output_dir = output_dir,
406
+ overwrite_output_dir = overwrite_output_dir,
407
+ do_train = do_train,
408
+ do_eval = do_eval,
409
+ do_predict = do_predict,
410
+ eval_strategy = eval_strategy,
411
+ prediction_loss_only = prediction_loss_only,
412
+ per_device_train_batch_size = per_device_train_batch_size,
413
+ per_device_eval_batch_size = per_device_eval_batch_size,
414
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
415
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
416
+ gradient_accumulation_steps = gradient_accumulation_steps,
417
+ eval_accumulation_steps = eval_accumulation_steps,
418
+ eval_delay = eval_delay,
419
+ torch_empty_cache_steps = torch_empty_cache_steps,
420
+ learning_rate = learning_rate,
421
+ weight_decay = weight_decay,
422
+ adam_beta1 = adam_beta1,
423
+ adam_beta2 = adam_beta2,
424
+ adam_epsilon = adam_epsilon,
425
+ max_grad_norm = max_grad_norm,
426
+ num_train_epochs = num_train_epochs,
427
+ max_steps = max_steps,
428
+ lr_scheduler_type = lr_scheduler_type,
429
+ warmup_ratio = warmup_ratio,
430
+ warmup_steps = warmup_steps,
431
+ log_level = log_level,
432
+ log_level_replica = log_level_replica,
433
+ log_on_each_node = log_on_each_node,
434
+ logging_dir = logging_dir,
435
+ logging_strategy = logging_strategy,
436
+ logging_first_step = logging_first_step,
437
+ logging_steps = logging_steps,
438
+ logging_nan_inf_filter = logging_nan_inf_filter,
439
+ save_strategy = save_strategy,
440
+ save_steps = save_steps,
441
+ save_total_limit = save_total_limit,
442
+ save_safetensors = save_safetensors,
443
+ save_on_each_node = save_on_each_node,
444
+ save_only_model = save_only_model,
445
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
446
+ no_cuda = no_cuda,
447
+ use_cpu = use_cpu,
448
+ use_mps_device = use_mps_device,
449
+ seed = seed,
450
+ data_seed = data_seed,
451
+ jit_mode_eval = jit_mode_eval,
452
+ bf16 = bf16,
453
+ fp16 = fp16,
454
+ fp16_opt_level = fp16_opt_level,
455
+ half_precision_backend = half_precision_backend,
456
+ bf16_full_eval = bf16_full_eval,
457
+ fp16_full_eval = fp16_full_eval,
458
+ tf32 = tf32,
459
+ local_rank = local_rank,
460
+ ddp_backend = ddp_backend,
461
+ tpu_num_cores = tpu_num_cores,
462
+ tpu_metrics_debug = tpu_metrics_debug,
463
+ debug = debug,
464
+ dataloader_drop_last = dataloader_drop_last,
465
+ eval_steps = eval_steps,
466
+ dataloader_num_workers = dataloader_num_workers,
467
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
468
+ past_index = past_index,
469
+ run_name = run_name,
470
+ disable_tqdm = disable_tqdm,
471
+ remove_unused_columns = remove_unused_columns,
472
+ label_names = label_names,
473
+ load_best_model_at_end = load_best_model_at_end,
474
+ metric_for_best_model = metric_for_best_model,
475
+ greater_is_better = greater_is_better,
476
+ ignore_data_skip = ignore_data_skip,
477
+ fsdp = fsdp,
478
+ fsdp_min_num_params = fsdp_min_num_params,
479
+ fsdp_config = fsdp_config,
480
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
481
+ accelerator_config = accelerator_config,
482
+ parallelism_config = parallelism_config,
483
+ deepspeed = deepspeed,
484
+ label_smoothing_factor = label_smoothing_factor,
485
+ optim = optim,
486
+ optim_args = optim_args,
487
+ adafactor = adafactor,
488
+ group_by_length = group_by_length,
489
+ length_column_name = length_column_name,
490
+ report_to = report_to,
491
+ project = project,
492
+ trackio_space_id = trackio_space_id,
493
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
494
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
495
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
496
+ dataloader_pin_memory = dataloader_pin_memory,
497
+ dataloader_persistent_workers = dataloader_persistent_workers,
498
+ skip_memory_metrics = skip_memory_metrics,
499
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
500
+ push_to_hub = push_to_hub,
501
+ resume_from_checkpoint = resume_from_checkpoint,
502
+ hub_model_id = hub_model_id,
503
+ hub_strategy = hub_strategy,
504
+ hub_token = hub_token,
505
+ hub_private_repo = hub_private_repo,
506
+ hub_always_push = hub_always_push,
507
+ hub_revision = hub_revision,
508
+ gradient_checkpointing = gradient_checkpointing,
509
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
510
+ include_inputs_for_metrics = include_inputs_for_metrics,
511
+ eval_do_concat_batches = eval_do_concat_batches,
512
+ fp16_backend = fp16_backend,
513
+ push_to_hub_model_id = push_to_hub_model_id,
514
+ push_to_hub_organization = push_to_hub_organization,
515
+ push_to_hub_token = push_to_hub_token,
516
+ mp_parameters = mp_parameters,
517
+ auto_find_batch_size = auto_find_batch_size,
518
+ full_determinism = full_determinism,
519
+ torchdynamo = torchdynamo,
520
+ ray_scope = ray_scope,
521
+ ddp_timeout = ddp_timeout,
522
+ torch_compile = torch_compile,
523
+ torch_compile_backend = torch_compile_backend,
524
+ torch_compile_mode = torch_compile_mode,
525
+ include_tokens_per_second = include_tokens_per_second,
526
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
527
+ neftune_noise_alpha = neftune_noise_alpha,
528
+ optim_target_modules = optim_target_modules,
529
+ batch_eval_metrics = batch_eval_metrics,
530
+ eval_on_start = eval_on_start,
531
+ use_liger_kernel = use_liger_kernel,
532
+ liger_kernel_config = liger_kernel_config,
533
+ eval_use_gather_object = eval_use_gather_object,
534
+ average_tokens_across_devices = average_tokens_across_devices,
535
+ reward_model_path = reward_model_path,
536
+ judge = judge,
537
+ max_new_tokens = max_new_tokens,
538
+ max_length = max_length,
539
+ temperature = temperature,
540
+ top_p = top_p,
541
+ top_k = top_k,
542
+ min_p = min_p,
543
+ repetition_penalty = repetition_penalty,
544
+ generation_kwargs = generation_kwargs,
545
+ use_transformers_paged = use_transformers_paged,
546
+ cache_implementation = cache_implementation,
547
+ missing_eos_penalty = missing_eos_penalty,
548
+ loss_type = loss_type,
549
+ disable_dropout = disable_dropout,
550
+ use_vllm = use_vllm,
551
+ vllm_model_impl = vllm_model_impl,
552
+ vllm_guided_decoding_regex = vllm_guided_decoding_regex,
553
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
554
+ vllm_mode = vllm_mode,
555
+ vllm_server_base_url = vllm_server_base_url,
556
+ vllm_server_host = vllm_server_host,
557
+ vllm_server_port = vllm_server_port,
558
+ vllm_server_timeout = vllm_server_timeout,
559
+ vllm_tensor_parallel_size = vllm_tensor_parallel_size,
560
+ ds3_gather_for_generation = ds3_gather_for_generation,
561
+ model_init_kwargs = model_init_kwargs,
562
+ reward_weights = reward_weights,
563
+ dataset_num_proc = dataset_num_proc,
564
+ gpu_memory_utilization = gpu_memory_utilization,**kwargs)
565
+ self.vllm_sampling_params = vllm_sampling_params
566
+ self.unsloth_num_chunks = unsloth_num_chunks
567
+ self.max_seq_length = max_seq_length
568
+ pass
569
+
570
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
571
+ """"""
572
+
573
+ _tag_names = ["trl", "xpo"]
574
+ _name = "XPO"
575
+ _paper = {
576
+ "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
577
+ "id": "2405.21046",
578
+ # docstyle-ignore
579
+ "citation": textwrap.dedent("""\
580
+ @article{jung2024binary,
581
+ title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
582
+ author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
583
+ year = 2024,
584
+ eprint = {arXiv:2405.21046}
585
+ }"""),
586
+ }
587
+
588
+ def __init__(
589
+ self,
590
+ model: Union[PreTrainedModel, nn.Module] = None,
591
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
592
+ reward_funcs: Optional[nn.Module] = None,
593
+ judge: Optional[BasePairwiseJudge] = None,
594
+ args: Optional[XPOConfig] = None,
595
+ data_collator: Optional[Callable] = None,
596
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
597
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
598
+ processing_class: Optional[
599
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
600
+ ] = None,
601
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
602
+ peft_config: Optional[dict] = None,
603
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
604
+ callbacks: Optional[list[TrainerCallback]] = None,
605
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
606
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
607
+ # Deprecated parameters
608
+ reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
609
+ ) -> None:
610
+ super().__init__(
611
+ model=model,
612
+ ref_model=ref_model,
613
+ judge=judge,
614
+ reward_funcs=reward_funcs,
615
+ reward_model=reward_model,
616
+ args=args,
617
+ data_collator=data_collator,
618
+ train_dataset=train_dataset,
619
+ eval_dataset=eval_dataset,
620
+ processing_class=processing_class,
621
+ reward_processing_classes=reward_processing_classes,
622
+ peft_config=peft_config,
623
+ compute_metrics=compute_metrics,
624
+ callbacks=callbacks,
625
+ optimizers=optimizers,
626
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
627
+ )
628
+
629
+ self._alpha = self.args.alpha
630
+
631
+ # Overwrite the stats dictionary to include XPO specific statistics
632
+ self.stats = {
633
+ # Remove "non_score_reward", "rlhf_reward", "scores"
634
+ # Add "loss/dpo", "loss/xpo"
635
+ "loss/dpo": [],
636
+ "loss/xpo": [],
637
+ "objective/kl": [],
638
+ "objective/entropy": [],
639
+ "rewards/chosen": [],
640
+ "rewards/rejected": [],
641
+ "rewards/accuracies": [],
642
+ "rewards/margins": [],
643
+ "logps/chosen": [],
644
+ "logps/rejected": [],
645
+ # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
646
+ "val/model_contain_eos_token": [],
647
+ "val/ref_contain_eos_token": [],
648
+ "alpha": [],
649
+ "beta": [],
650
+ }
651
+ if self.reward_funcs is not None:
652
+ if len(self.reward_funcs) != 1:
653
+ raise ValueError("XPOTrainer only supports one reward function/model.")
654
+ self.reward_funcs = self.reward_funcs[0]
655
+ self.stats["objective/model_scores"] = []
656
+ self.stats["objective/ref_scores"] = []
657
+ self.stats["objective/scores_margin"] = []
658
+
659
+ @property
660
+ def alpha(self):
661
+ if isinstance(self._alpha, list):
662
+ epoch = self.state.epoch
663
+ return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
664
+ else:
665
+ return self._alpha
666
+
667
+ def _generate_completions(self, prompts, model):
668
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
669
+ model_output = unwrapped_policy_model_for_gen.generate(
670
+ input_ids=prompts["input_ids"],
671
+ attention_mask=prompts["attention_mask"],
672
+ generation_config=self.generation_config,
673
+ )
674
+
675
+ actual_model_for_ref_generation: torch.nn.Module
676
+ if self.ref_model is None:
677
+ unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
678
+
679
+ if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
680
+ actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
681
+ else:
682
+ actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
683
+ else:
684
+ actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
685
+
686
+ with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
687
+ ref_output = final_ref_model_for_gen.generate(
688
+ input_ids=prompts["input_ids"],
689
+ attention_mask=prompts["attention_mask"],
690
+ generation_config=self.generation_config,
691
+ )
692
+
693
+ return model_output, ref_output
694
+
695
+ def _process_completions(self, model_output, ref_output, prompts):
696
+ context_length = prompts["input_ids"].shape[1]
697
+
698
+ # Process model completions
699
+ model_completion_ids = model_output[:, context_length:]
700
+ model_completion_ids, model_completion_mask = truncate_right(
701
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
702
+ )
703
+ model_data = {
704
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
705
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
706
+ "raw": prompts["raw"],
707
+ }
708
+
709
+ # Process reference model completions
710
+ ref_completion_ids = ref_output[:, context_length:]
711
+ ref_completion_ids, ref_completion_mask = truncate_right(
712
+ ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
713
+ )
714
+ ref_data = {
715
+ "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
716
+ "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
717
+ "raw": prompts["raw"],
718
+ }
719
+
720
+ return model_data, ref_data
721
+
722
+ def _compute_rewards(self, model_data, ref_data, context_length):
723
+ with torch.no_grad():
724
+ _, model_scores, _ = get_reward(
725
+ self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
726
+ )
727
+ _, ref_scores, _ = get_reward(
728
+ self.reward_funcs, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
729
+ )
730
+
731
+ # Apply EOS penalty if needed
732
+ if self.args.missing_eos_penalty is not None:
733
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
734
+ ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
735
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
736
+ ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
737
+
738
+ return model_scores, ref_scores
739
+
740
+ def _compute_judge(self, model_data, ref_data, context_length):
741
+ prompts = model_data["raw"]
742
+ model_data_completions = self.processing_class.batch_decode(
743
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
744
+ )
745
+ model_data_completions = [completion.strip() for completion in model_data_completions]
746
+
747
+ ref_data_completions = self.processing_class.batch_decode(
748
+ ref_data["input_ids"][:, context_length:], skip_special_tokens=True
749
+ )
750
+ ref_data_completions = [completion.strip() for completion in ref_data_completions]
751
+
752
+ if is_conversational({"prompt": prompts[0]}):
753
+ model_data_completions = [
754
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
755
+ ]
756
+ environment = jinja2.Environment()
757
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
758
+ prompts = [template.render(messages=message) for message in prompts]
759
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
760
+
761
+ ref_data_completions = [
762
+ [{"role": "assistant", "content": completion}] for completion in ref_data_completions
763
+ ]
764
+ ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
765
+
766
+ ranks_of_first_completion = self.judge.judge(
767
+ prompts,
768
+ list(zip(model_data_completions, ref_data_completions)),
769
+ )
770
+ # convert ranks to a True/False mask:
771
+ # when rank == 0, it means the first completion is the best
772
+ # when rank == 1, it means the second completion is the best
773
+ return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
774
+
775
+ def _compute_logprobs(self, model, model_data, ref_data, context_length):
776
+ def compute_logprobs_for_data(m, data):
777
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
778
+ logits = output.logits[:, context_length - 1 : -1]
779
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
780
+ return token_logprobs
781
+
782
+ # Compute logprobs for model completions
783
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
784
+ # Compute logprobs for model on reference completions (for XPO loss)
785
+ model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
786
+
787
+ # Compute logprobs for reference model completions
788
+ with torch.no_grad():
789
+ if self.ref_model is None:
790
+ with model.disable_adapter():
791
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
792
+ ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
793
+ else:
794
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
795
+ ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
796
+
797
+ # Mask padding tokens
798
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
799
+ ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
800
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
801
+ model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
802
+ ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
803
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
804
+
805
+ return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
806
+
807
+ def _compute_losses(
808
+ self,
809
+ model_logprobs_model_data,
810
+ model_logprobs_ref_data,
811
+ ref_logprobs_ref_data,
812
+ ref_logprobs_model_data,
813
+ chosen_mask,
814
+ ):
815
+ # Compute log probs
816
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
817
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
818
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
819
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
820
+
821
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
822
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
823
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
824
+
825
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
826
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
827
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
828
+
829
+ # Compute logits as the difference between chosen and rejected log ratios
830
+ logits = chosen_log_ratios - rejected_log_ratios
831
+
832
+ if self.args.loss_type == "sigmoid":
833
+ dpo_losses = -F.logsigmoid(self.beta * logits)
834
+ elif self.args.loss_type == "ipo":
835
+ dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
836
+ else:
837
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
838
+
839
+ # Compute XPO specific loss
840
+ xpo_losses = self.alpha * model_logprobs_ref_data_sum
841
+
842
+ # Total loss
843
+ loss = (dpo_losses + xpo_losses).mean()
844
+
845
+ return loss, dpo_losses, xpo_losses
846
+
847
+ def _log_statistics(
848
+ self,
849
+ model_data,
850
+ ref_data,
851
+ model_logprobs_model_data,
852
+ model_logprobs_ref_data,
853
+ ref_logprobs_ref_data,
854
+ ref_logprobs_model_data,
855
+ chosen_mask,
856
+ dpo_losses,
857
+ xpo_losses,
858
+ context_length,
859
+ model_scores=None,
860
+ ref_scores=None,
861
+ ):
862
+ # Helper function to gather and compute mean
863
+ def gather_mean(tensor):
864
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
865
+
866
+ # Log losses
867
+ self.stats["loss/dpo"].append(gather_mean(dpo_losses))
868
+ self.stats["loss/xpo"].append(gather_mean(xpo_losses))
869
+
870
+ # Log scores
871
+ if self.reward_funcs is not None:
872
+ self.stats["objective/model_scores"].append(gather_mean(model_scores))
873
+ self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
874
+ self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
875
+
876
+ # Log logprobs
877
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
878
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
879
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
880
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
881
+
882
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
883
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
884
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
885
+
886
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
887
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
888
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
889
+
890
+ self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
891
+ self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
892
+
893
+ # Log rewards
894
+ # Compute various statistics
895
+ chosen_rewards = chosen_log_ratios * self.beta
896
+ rejected_rewards = rejected_log_ratios * self.beta
897
+ self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
898
+ self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
899
+
900
+ # Calculate KL divergence for model and ref data
901
+ kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
902
+ kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
903
+ mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
904
+ self.stats["objective/kl"].append(gather_mean(mean_kl))
905
+
906
+ # Calculate entropy for model and ref data
907
+ entropy_model_data = -model_logprobs_model_data.sum(1)
908
+ entropy_ref_data = -model_logprobs_ref_data.sum(1)
909
+ mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
910
+ self.stats["objective/entropy"].append(gather_mean(mean_entropy))
911
+
912
+ # Calculate margins
913
+ margin = chosen_rewards - rejected_rewards
914
+ self.stats["rewards/margins"].append(gather_mean(margin.mean()))
915
+
916
+ # Calculate accuracy
917
+ accuracy = (margin > 0).float()
918
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
919
+
920
+ # Log EOS token statistics
921
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
922
+ ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
923
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
924
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
925
+
926
+ # Log alpha and beta
927
+ self.stats["alpha"].append(self.alpha)
928
+ self.stats["beta"].append(self.beta)
929
+
930
+ def training_step(
931
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
932
+ ) -> torch.Tensor:
933
+ model.train()
934
+
935
+ # Apply chat template and tokenize the input
936
+ batch_size = len(next(iter(inputs.values())))
937
+ prompts = inputs["prompt"]
938
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
939
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
940
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
941
+ inputs = self.data_collator(inputs)
942
+
943
+ # need the prompt_ only
944
+ inputs = self._prepare_inputs(inputs)
945
+ context_length = inputs["prompt_input_ids"].shape[1]
946
+ prompts = {
947
+ "input_ids": inputs["prompt_input_ids"],
948
+ "attention_mask": inputs["prompt_attention_mask"],
949
+ "raw": prompts,
950
+ }
951
+ del inputs
952
+
953
+ # Sample completions from both the model and the reference model
954
+ model_output, ref_output = self._generate_completions(prompts, model)
955
+
956
+ # Process model completions
957
+ model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
958
+
959
+ # Compute rewards
960
+ if self.reward_funcs is not None:
961
+ model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
962
+ chosen_mask = model_scores >= ref_scores
963
+ else:
964
+ model_scores, ref_scores = None, None
965
+ chosen_mask = self._compute_judge(model_data, ref_data, context_length)
966
+
967
+ # Compute logprobs
968
+ model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
969
+ self._compute_logprobs(model, model_data, ref_data, context_length)
970
+ )
971
+
972
+ # Compute loss
973
+ loss, dpo_losses, xpo_losses = self._compute_losses(
974
+ model_logprobs_model_data,
975
+ model_logprobs_ref_data,
976
+ ref_logprobs_ref_data,
977
+ ref_logprobs_model_data,
978
+ chosen_mask,
979
+ )
980
+
981
+ # Log everything
982
+ self._log_statistics(
983
+ model_data,
984
+ ref_data,
985
+ model_logprobs_model_data.detach(),
986
+ model_logprobs_ref_data.detach(),
987
+ ref_logprobs_ref_data,
988
+ ref_logprobs_model_data,
989
+ chosen_mask,
990
+ dpo_losses.detach(),
991
+ xpo_losses.detach(),
992
+ context_length,
993
+ model_scores,
994
+ ref_scores,
995
+ )
996
+
997
+ if (
998
+ self.args.torch_empty_cache_steps is not None
999
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
1000
+ ):
1001
+ empty_cache()
1002
+
1003
+ kwargs = {}
1004
+ # For LOMO optimizers you need to explicitly use the learning rate
1005
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
1006
+ kwargs["learning_rate"] = self._get_learning_rate()
1007
+
1008
+ if self.args.n_gpu > 1:
1009
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
1010
+
1011
+ self.accelerator.backward(loss, **kwargs)
1012
+
1013
+ return loss.detach() / self.args.gradient_accumulation_steps
1014
+ class UnslothXPOTrainer(_UnslothXPOTrainer):
1015
+ """
1016
+
1017
+ Trainer for Exploratory Preference Optimization (XPO).
1018
+
1019
+ It is implemented as a subclass of [`OnlineDPOTrainer`].
1020
+
1021
+ Args:
1022
+ model ([`~transformers.PreTrainedModel`]):
1023
+ The model to train, preferably an `AutoModelForCausalLM`.
1024
+ ref_model ([`PreTrainedModelWrapper`]):
1025
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation
1026
+ and loss. If no reference model is provided, the trainer will create a reference model with the same
1027
+ architecture as the model to be optimized.
1028
+ reward_funcs ([`~transformers.PreTrainedModel`]):
1029
+ The reward model to score completions with, preferably an
1030
+ [`~transformers.AutoModelForSequenceClassification`].
1031
+ judge ([`BasePairwiseJudge`]):
1032
+ The judge to use for pairwise comparison of model completions.
1033
+ args ([`XPOConfig`]):
1034
+ The XPO config arguments to use for training.
1035
+ data_collator ([`~transformers.DataCollator`]):
1036
+ The data collator to use for training. If None is specified, the default data collator
1037
+ ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
1038
+ sequences in the batch, given a dataset of paired sequences.
1039
+ train_dataset ([`~datasets.Dataset`]):
1040
+ The dataset to use for training.
1041
+ eval_dataset ([`~datasets.Dataset`]):
1042
+ The dataset to use for evaluation.
1043
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
1044
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1045
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1046
+ reuse the fine-tuned model.
1047
+ peft_config (`dict`):
1048
+ The peft config to use for training.
1049
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1050
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping strings to
1051
+ metric values.
1052
+ callbacks (`list[transformers.TrainerCallback]`):
1053
+ The callbacks to use for training.
1054
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1055
+ The optimizer and scheduler to use for training.
1056
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1057
+ The function to use to preprocess the logits before computing the metrics.
1058
+
1059
+ reward_model:
1060
+
1061
+ <Deprecated version="0.22.0">
1062
+
1063
+ This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
1064
+
1065
+ </Deprecated>
1066
+
1067
+ """
1068
+ def __init__(
1069
+ self,
1070
+ model = None,
1071
+ ref_model = None,
1072
+ reward_funcs = None,
1073
+ judge = None,
1074
+ args = None,
1075
+ data_collator = None,
1076
+ train_dataset = None,
1077
+ eval_dataset = None,
1078
+ processing_class = None,
1079
+ reward_processing_classes = None,
1080
+ peft_config = None,
1081
+ compute_metrics = None,
1082
+ callbacks = None,
1083
+ preprocess_logits_for_metrics = None,
1084
+ reward_model = None,
1085
+ **kwargs
1086
+ ):
1087
+ if args is None: args = UnslothXPOConfig()
1088
+ use_bf16 = getattr(args, 'bf16', False)
1089
+ if type(use_bf16) is not bool: use_bf16 = False
1090
+ use_fp16 = getattr(args, 'fp16', False)
1091
+ if type(use_fp16) is not bool: use_fp16 = False
1092
+ force_float32 = False
1093
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
1094
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
1095
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1096
+ force_float32 = True
1097
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1098
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
1099
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
1100
+ from unsloth_zoo.utils import _get_dtype
1101
+ dtype = _get_dtype(dtype)
1102
+ float16 = dtype == torch.float16
1103
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1104
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1105
+ if force_float32:
1106
+ # Forced float32 training
1107
+ args.fp16 = False
1108
+ args.bf16 = False
1109
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1110
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1111
+ # args.mixed_precision is a new argument which needs to be set now
1112
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1113
+ # Mixed precision training
1114
+ args.fp16 = float16
1115
+ args.bf16 = not float16
1116
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1117
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
1118
+ # args.mixed_precision is a new argument which needs to be set now
1119
+ elif mixed_precision_dtype == 'bfloat16':
1120
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
1121
+ args.fp16 = False
1122
+ args.bf16 = False
1123
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1124
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1125
+ # args.mixed_precision is a new argument which needs to be set now
1126
+
1127
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1128
+ args.eval_strategy = 'steps'
1129
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1130
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1131
+ if ga_steps is not None and ga_steps > 1:
1132
+ from transformers import __version__ as transformers_version
1133
+ if Version(transformers_version) <= Version('4.45.2'):
1134
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1135
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1136
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1137
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1138
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1139
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1140
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1141
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1142
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1143
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1144
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1145
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1146
+ if force_float32:
1147
+ args.bf16_full_eval = False
1148
+ args.fp16_full_eval = False
1149
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1150
+ args.bf16_full_eval = True
1151
+ args.fp16_full_eval = False
1152
+ elif not bf16_full_eval and not fp16_full_eval:
1153
+ args.bf16_full_eval = args.bf16
1154
+ args.fp16_full_eval = args.fp16
1155
+ _output_logits = False
1156
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1157
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1158
+ if _output_logits:
1159
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1160
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1161
+ pass
1162
+ else:
1163
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1164
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1165
+ if args_max_seq_length is None and model_max_seq_length is not None:
1166
+ max_seq_length = model.max_seq_length
1167
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1168
+ if model is not None and hasattr(model, 'for_training'):
1169
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1170
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1171
+ if 'processing_class' in locals():
1172
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1173
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1174
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1175
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1176
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1177
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1178
+ data_collator = TransformersDataCollatorForLanguageModeling(
1179
+ __tokenizer,
1180
+ mlm = False,
1181
+ mlm_probability = 0.0,
1182
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1183
+ )
1184
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1185
+ data_collator = DataCollatorForSeq2Seq(
1186
+ __tokenizer,
1187
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1188
+ )
1189
+ else:
1190
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1191
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1192
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1193
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1194
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1195
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1196
+ data_collator = DataCollatorForSeq2Seq(
1197
+ __tokenizer.tokenizer,
1198
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1199
+ )
1200
+ else:
1201
+ data_collator = TransformersDataCollatorForLanguageModeling(
1202
+ __tokenizer.tokenizer,
1203
+ mlm = False,
1204
+ mlm_probability = 0.0,
1205
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1206
+ )
1207
+ other_metrics = []
1208
+
1209
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1210
+ PatchRLStatistics('xpo_trainer', other_metrics)
1211
+
1212
+ # [TODO] Fix up DataParallel multiplying batch sizes
1213
+ # [TODO] DDP works, but DP seems to not work? [TODO]
1214
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
1215
+ if getattr(args, "_n_gpu", 1) != 1:
1216
+ args._n_gpu = 1
1217
+ if "model" in locals() and hasattr(model, "for_training"):
1218
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1219
+ super().__init__(
1220
+ model = model,
1221
+ ref_model = ref_model,
1222
+ reward_funcs = reward_funcs,
1223
+ judge = judge,
1224
+ args = args,
1225
+ data_collator = data_collator,
1226
+ train_dataset = train_dataset,
1227
+ eval_dataset = eval_dataset,
1228
+ processing_class = processing_class,
1229
+ reward_processing_classes = reward_processing_classes,
1230
+ peft_config = peft_config,
1231
+ compute_metrics = compute_metrics,
1232
+ callbacks = callbacks,
1233
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1234
+ reward_model = reward_model,**kwargs)
1235
+ if "model" in locals() and hasattr(model, "for_inference"):
1236
+ model.for_inference()
1237
+ if hasattr(self, 'neftune_hook_handle'):
1238
+ self.neftune_hook_handle.remove()
1239
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1240
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1241
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1242
+ pass
1243
+ if hasattr(self, 'accelerator'):
1244
+ scaler = self.accelerator.scaler
1245
+ current_model = model
1246
+ while hasattr(current_model, 'model'):
1247
+ current_model.accelerator_scaler = scaler
1248
+ current_model = current_model.model
1249
+ current_model.accelerator_scaler = scaler
1250
+ pass
1251
+ if hasattr(self, 'train'):
1252
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1253
+ pass
1254
+
1255
+ pass
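
For context, the generated UnslothXPOConfig defined above mirrors TRL's XPOConfig but swaps in Unsloth-specific defaults (seed and data_seed 3407, optim 'adamw_8bit', report_to 'none', gradient_checkpointing True) and adds three fields: vllm_sampling_params, unsloth_num_chunks and max_seq_length. Below is a minimal sketch of inspecting those defaults, assuming the package's unsloth_compiled_cache directory is available locally and that torch, transformers, trl and unsloth_zoo are installed; the sys.path manipulation and output_dir value are illustrative only, not an official API.

  # Illustrative sketch only: inspect the Unsloth-specific defaults of the generated config.
  import sys
  sys.path.insert(0, "package/unsloth_compiled_cache")

  from UnslothXPOTrainer import UnslothXPOConfig  # module added in this diff

  cfg = UnslothXPOConfig(output_dir="xpo-out")  # output_dir is a placeholder
  print(cfg.seed, cfg.data_seed)       # 3407 3407
  print(cfg.optim)                     # 8-bit AdamW by default (stock XPOConfig uses plain AdamW)
  print(cfg.gradient_checkpointing)    # True
  print(cfg.unsloth_num_chunks)        # -1, one of the three extra fields added by the wrapper

The UnslothXPOTrainer class itself keeps TRL's XPO training loop and mainly adds the precision, padding-side and data-collator housekeeping visible in its __init__ above.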