cciwon-code-review-cli 2.0.2 → 2.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/lib/api-client.js +1 -1
  2. package/lib/chat-mode.js +7 -2
  3. package/package.json +1 -1
  4. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  39. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  40. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  41. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  42. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  44. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  45. package/unsloth_compiled_cache/Conv1d.py +70 -0
  46. package/unsloth_compiled_cache/Conv2d.py +70 -0
  47. package/unsloth_compiled_cache/Conv3d.py +70 -0
  48. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  49. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  50. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  51. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  52. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  53. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  54. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  55. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  56. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  57. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  58. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  59. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  60. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  61. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  62. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  63. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  64. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  65. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  66. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  67. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  68. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  69. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  70. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  71. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  72. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  73. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  74. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  111. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
package/unsloth_compiled_cache/UnslothPRMTrainer.py (new file)
@@ -0,0 +1,979 @@
1
+ """
2
+ 2025.12.6
3
+ 2025.12.7
4
+ 4.57.1
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
30
+ from trl.trainer.prm_trainer import (BaseImageProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, nn, os, prepare_peft_model, textwrap, torch, warnings, Optional, PreTrainedModel, os, torch)
31
+
32
+
33
+ import os
34
+ from typing import *
35
+ from dataclasses import dataclass, field
36
+ from packaging.version import Version
37
+ import torch
38
+ import numpy as np
39
+ from contextlib import nullcontext
40
+ from torch.nn import functional as F
41
+ import inspect
42
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
43
+ from transformers.training_args import ParallelMode
44
+
45
+ # Wrap trainer with padding to right and enable training mode
46
+ # Also patches W&B since multiple runs must use wandb.finish()
47
+ import functools
48
+ from types import MethodType
49
+ def prepare_for_training_mode(f):
50
+ @functools.wraps(f)
51
+ def wrapper(self, *args, **kwargs):
52
+ # Enable training mode
53
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
54
+ self.model.for_training()
55
+ output = f(self, *args, **kwargs)
56
+ # Return inference mode
57
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
58
+ self.model.for_inference()
59
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
60
+ try:
61
+ import wandb
62
+ wandb.finish()
63
+ except:
64
+ pass
65
+ return output
66
+ return wrapper
67
+ pass
68
+
69
+ torch_compile_options = {
70
+ "epilogue_fusion" : True,
71
+ "max_autotune" : False,
72
+ "shape_padding" : True,
73
+ "trace.enabled" : False,
74
+ "triton.cudagraphs" : False,
75
+ }
76
+
77
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
78
+ def chunked_selective_log_softmax(logits, index):
79
+ # Split into 4 chunks only
80
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
81
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
82
+ all_per_token_logps = []
83
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
84
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
85
+ chunk_logits = chunk_logits.to(torch.float32)
86
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
87
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
88
+ per_token_logps = selected_logits - logsumexp_values
89
+ all_per_token_logps.append(per_token_logps)
90
+ pass
91
+ all_per_token_logps = torch.concat(all_per_token_logps)
92
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
93
+ return all_per_token_logps
94
+
95
+ def calculate_pad_tokens_in_prompt(
96
+ input_ids: torch.Tensor,
97
+ logits_to_keep: int,
98
+ pad_token_id: int
99
+ ) -> torch.Tensor:
100
+ """
101
+ Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
102
+ """
103
+ if logits_to_keep >= input_ids.shape[1]:
104
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
105
+
106
+ prompt_section = input_ids[:, :-logits_to_keep]
107
+
108
+ padding_mask = (prompt_section == pad_token_id)
109
+
110
+ pad_token_counts = padding_mask.sum(dim=1)
111
+
112
+ return pad_token_counts
113
+
114
+ def create_completion_attention_mask(
115
+ completion_input_ids: torch.Tensor,
116
+ left_pad_tokens_per_prompt: torch.Tensor,
117
+ max_left_pad: int,
118
+ pad_token_id: int
119
+ ) -> torch.Tensor:
120
+ """
121
+ Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
122
+
123
+ Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
124
+ and pad are pad tokens, this function would make a completion mask that would 0 out the pad
125
+ and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
126
+ """
127
+ batch_size, completion_len = completion_input_ids.shape
128
+ device = completion_input_ids.device
129
+
130
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
131
+
132
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
133
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
134
+
135
+ non_padding_mask = (completion_input_ids != pad_token_id)
136
+
137
+ final_mask = shift_mask & non_padding_mask
138
+
139
+ return final_mask
140
+
141
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
142
+ """
143
+ Moves all padding tokens in each sequence of a batch to the right.
144
+ """
145
+ mask = (tensor != pad_id)
146
+ # Must do stable=True since binary mark is unordered
147
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
148
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
149
+ return packed_tensor
150
+
151
+ def align_logprobs_with_mask(
152
+ logprob_tensor: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ pad_value: float = 0.0
155
+ ) -> torch.Tensor:
156
+ """
157
+ Aligns a log probability tensor with a given attention mask.
158
+ """
159
+
160
+ device = logprob_tensor.device
161
+ batch_size, logprob_seq_len = logprob_tensor.shape
162
+ mask_seq_len = attention_mask.shape[1]
163
+
164
+ padded_logprobs = torch.full(
165
+ attention_mask.shape,
166
+ fill_value=pad_value,
167
+ dtype=logprob_tensor.dtype,
168
+ device=device
169
+ )
170
+
171
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
172
+
173
+ cols = torch.arange(logprob_seq_len, device=device)
174
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
175
+
176
+ # Create destination row indices
177
+ # Shape: [batch_size, logprob_seq_len]
178
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
179
+
180
+ # --- 4. Filter out-of-bounds indices and perform assignment ---
181
+ # Create a mask to identify only the indices that are within the bounds
182
+ # of the target tensor's sequence length.
183
+ valid_mask = dest_indices < mask_seq_len
184
+
185
+ # Use this mask to select only the valid row indices, column indices,
186
+ # and the corresponding values from the logprob tensor.
187
+ # This flattens the selected elements into 1D tensors.
188
+ valid_rows = row_indices[valid_mask]
189
+ valid_cols = dest_indices[valid_mask]
190
+ valid_vals = logprob_tensor[valid_mask]
191
+
192
+ # Place the valid values into their correct positions in the padded tensor
193
+ # using a single, efficient advanced indexing operation.
194
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
195
+
196
+ return padded_logprobs
197
+ @dataclass
198
+ class UnslothPRMConfig(PRMConfig):
199
+ """
200
+
201
+ Configuration class for the [`PRMTrainer`].
202
+
203
+ This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
204
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
205
+ differ from those in [`~transformers.TrainingArguments`].
206
+
207
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
208
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
209
+ command line.
210
+
211
+ Parameters:
212
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
213
+ Maximum length of the sequences (prompt + completion) used for truncation.
214
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
215
+ Maximum length of the prompt used for truncation.
216
+ max_completion_length (`int`, *optional*):
217
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
218
+ disable_dropout (`bool`, *optional*, defaults to `True`):
219
+ Whether to disable dropout in the model.
220
+ step_separator (`str`, *optional*, defaults to `"\n"`):
221
+ Separator used to separate each step of the reasoning process.
222
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
223
+ Whether to train only on the last step.
224
+ dataset_num_proc (`int`, *optional*):
225
+ Number of processes to use for processing the dataset.
226
+
227
+ """
228
+ vllm_sampling_params: Optional[Any] = field(
229
+ default = None,
230
+ metadata = {'help': 'vLLM SamplingParams'},
231
+ )
232
+ unsloth_num_chunks : Optional[int] = field(
233
+ default = -1,
234
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
235
+ )
236
+ max_seq_length : Optional[int] = field(
237
+ default = None,
238
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
239
+ )
240
+ def __init__(
241
+ self,
242
+ output_dir = None,
243
+ overwrite_output_dir = None,
244
+ do_train = False,
245
+ do_eval = False,
246
+ do_predict = False,
247
+ eval_strategy = 'no',
248
+ prediction_loss_only = False,
249
+ per_device_train_batch_size = 4,
250
+ per_device_eval_batch_size = 4,
251
+ per_gpu_train_batch_size = None,
252
+ per_gpu_eval_batch_size = None,
253
+ gradient_accumulation_steps = 2,
254
+ eval_accumulation_steps = 2,
255
+ eval_delay = 0,
256
+ torch_empty_cache_steps = 250,
257
+ learning_rate = 5e-05,
258
+ weight_decay = 0.01,
259
+ adam_beta1 = 0.9,
260
+ adam_beta2 = 0.999,
261
+ adam_epsilon = 1e-08,
262
+ max_grad_norm = 1.0,
263
+ num_train_epochs = 3.0,
264
+ max_steps = -1,
265
+ lr_scheduler_type = 'linear',
266
+ warmup_ratio = 0.1,
267
+ warmup_steps = 0,
268
+ log_level = 'passive',
269
+ log_level_replica = 'warning',
270
+ log_on_each_node = True,
271
+ logging_dir = None,
272
+ logging_strategy = 'steps',
273
+ logging_first_step = False,
274
+ logging_steps = 1,
275
+ logging_nan_inf_filter = False,
276
+ save_strategy = 'steps',
277
+ save_steps = 500,
278
+ save_total_limit = None,
279
+ save_safetensors = True,
280
+ save_on_each_node = False,
281
+ save_only_model = False,
282
+ restore_callback_states_from_checkpoint = False,
283
+ no_cuda = False,
284
+ use_cpu = False,
285
+ use_mps_device = False,
286
+ seed = 3407,
287
+ data_seed = 3407,
288
+ jit_mode_eval = False,
289
+ bf16 = False,
290
+ fp16 = False,
291
+ fp16_opt_level = 'O1',
292
+ half_precision_backend = 'auto',
293
+ bf16_full_eval = False,
294
+ fp16_full_eval = False,
295
+ tf32 = None,
296
+ local_rank = -1,
297
+ ddp_backend = None,
298
+ tpu_num_cores = None,
299
+ tpu_metrics_debug = False,
300
+ debug = '',
301
+ dataloader_drop_last = False,
302
+ eval_steps = None,
303
+ dataloader_num_workers = 0,
304
+ dataloader_prefetch_factor = None,
305
+ past_index = -1,
306
+ run_name = None,
307
+ disable_tqdm = None,
308
+ remove_unused_columns = True,
309
+ label_names = None,
310
+ load_best_model_at_end = False,
311
+ metric_for_best_model = None,
312
+ greater_is_better = None,
313
+ ignore_data_skip = False,
314
+ fsdp = None,
315
+ fsdp_min_num_params = 0,
316
+ fsdp_config = None,
317
+ fsdp_transformer_layer_cls_to_wrap = None,
318
+ accelerator_config = None,
319
+ parallelism_config = None,
320
+ deepspeed = None,
321
+ label_smoothing_factor = 0.0,
322
+ optim = 'adamw_8bit',
323
+ optim_args = None,
324
+ adafactor = False,
325
+ group_by_length = False,
326
+ length_column_name = 'length',
327
+ report_to = 'none',
328
+ project = 'huggingface',
329
+ trackio_space_id = 'trackio',
330
+ ddp_find_unused_parameters = None,
331
+ ddp_bucket_cap_mb = None,
332
+ ddp_broadcast_buffers = None,
333
+ dataloader_pin_memory = True,
334
+ dataloader_persistent_workers = False,
335
+ skip_memory_metrics = True,
336
+ use_legacy_prediction_loop = False,
337
+ push_to_hub = False,
338
+ resume_from_checkpoint = None,
339
+ hub_model_id = None,
340
+ hub_strategy = 'every_save',
341
+ hub_token = None,
342
+ hub_private_repo = None,
343
+ hub_always_push = False,
344
+ hub_revision = None,
345
+ gradient_checkpointing = True,
346
+ gradient_checkpointing_kwargs = None,
347
+ include_inputs_for_metrics = False,
348
+ eval_do_concat_batches = True,
349
+ fp16_backend = 'auto',
350
+ push_to_hub_model_id = None,
351
+ push_to_hub_organization = None,
352
+ push_to_hub_token = None,
353
+ mp_parameters = '',
354
+ auto_find_batch_size = False,
355
+ full_determinism = False,
356
+ torchdynamo = None,
357
+ ray_scope = 'last',
358
+ ddp_timeout = 1800,
359
+ torch_compile = False,
360
+ torch_compile_backend = None,
361
+ torch_compile_mode = None,
362
+ include_tokens_per_second = False,
363
+ include_num_input_tokens_seen = False,
364
+ neftune_noise_alpha = None,
365
+ optim_target_modules = None,
366
+ batch_eval_metrics = False,
367
+ eval_on_start = False,
368
+ use_liger_kernel = False,
369
+ liger_kernel_config = None,
370
+ eval_use_gather_object = False,
371
+ average_tokens_across_devices = True,
372
+ max_length = 1024,
373
+ max_prompt_length = 512,
374
+ max_completion_length = None,
375
+ disable_dropout = True,
376
+ step_separator = '\
377
+ ',
378
+ train_on_last_step_only = False,
379
+ dataset_num_proc = None,
380
+ vllm_sampling_params = None,
381
+ unsloth_num_chunks = -1,
382
+ max_seq_length = None,
383
+ **kwargs,
384
+ ):
385
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
386
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
387
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
388
+ output_dir = 'unsloth_training_checkpoints'
389
+ save_strategy = 'no'
390
+ if dataset_num_proc is None:
391
+ from multiprocessing import cpu_count
392
+ dataset_num_proc = min(max(cpu_count()+4, 2), 64)
393
+
394
+ super().__init__(
395
+ output_dir = output_dir,
396
+ overwrite_output_dir = overwrite_output_dir,
397
+ do_train = do_train,
398
+ do_eval = do_eval,
399
+ do_predict = do_predict,
400
+ eval_strategy = eval_strategy,
401
+ prediction_loss_only = prediction_loss_only,
402
+ per_device_train_batch_size = per_device_train_batch_size,
403
+ per_device_eval_batch_size = per_device_eval_batch_size,
404
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
405
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
406
+ gradient_accumulation_steps = gradient_accumulation_steps,
407
+ eval_accumulation_steps = eval_accumulation_steps,
408
+ eval_delay = eval_delay,
409
+ torch_empty_cache_steps = torch_empty_cache_steps,
410
+ learning_rate = learning_rate,
411
+ weight_decay = weight_decay,
412
+ adam_beta1 = adam_beta1,
413
+ adam_beta2 = adam_beta2,
414
+ adam_epsilon = adam_epsilon,
415
+ max_grad_norm = max_grad_norm,
416
+ num_train_epochs = num_train_epochs,
417
+ max_steps = max_steps,
418
+ lr_scheduler_type = lr_scheduler_type,
419
+ warmup_ratio = warmup_ratio,
420
+ warmup_steps = warmup_steps,
421
+ log_level = log_level,
422
+ log_level_replica = log_level_replica,
423
+ log_on_each_node = log_on_each_node,
424
+ logging_dir = logging_dir,
425
+ logging_strategy = logging_strategy,
426
+ logging_first_step = logging_first_step,
427
+ logging_steps = logging_steps,
428
+ logging_nan_inf_filter = logging_nan_inf_filter,
429
+ save_strategy = save_strategy,
430
+ save_steps = save_steps,
431
+ save_total_limit = save_total_limit,
432
+ save_safetensors = save_safetensors,
433
+ save_on_each_node = save_on_each_node,
434
+ save_only_model = save_only_model,
435
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
436
+ no_cuda = no_cuda,
437
+ use_cpu = use_cpu,
438
+ use_mps_device = use_mps_device,
439
+ seed = seed,
440
+ data_seed = data_seed,
441
+ jit_mode_eval = jit_mode_eval,
442
+ bf16 = bf16,
443
+ fp16 = fp16,
444
+ fp16_opt_level = fp16_opt_level,
445
+ half_precision_backend = half_precision_backend,
446
+ bf16_full_eval = bf16_full_eval,
447
+ fp16_full_eval = fp16_full_eval,
448
+ tf32 = tf32,
449
+ local_rank = local_rank,
450
+ ddp_backend = ddp_backend,
451
+ tpu_num_cores = tpu_num_cores,
452
+ tpu_metrics_debug = tpu_metrics_debug,
453
+ debug = debug,
454
+ dataloader_drop_last = dataloader_drop_last,
455
+ eval_steps = eval_steps,
456
+ dataloader_num_workers = dataloader_num_workers,
457
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
458
+ past_index = past_index,
459
+ run_name = run_name,
460
+ disable_tqdm = disable_tqdm,
461
+ remove_unused_columns = remove_unused_columns,
462
+ label_names = label_names,
463
+ load_best_model_at_end = load_best_model_at_end,
464
+ metric_for_best_model = metric_for_best_model,
465
+ greater_is_better = greater_is_better,
466
+ ignore_data_skip = ignore_data_skip,
467
+ fsdp = fsdp,
468
+ fsdp_min_num_params = fsdp_min_num_params,
469
+ fsdp_config = fsdp_config,
470
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
471
+ accelerator_config = accelerator_config,
472
+ parallelism_config = parallelism_config,
473
+ deepspeed = deepspeed,
474
+ label_smoothing_factor = label_smoothing_factor,
475
+ optim = optim,
476
+ optim_args = optim_args,
477
+ adafactor = adafactor,
478
+ group_by_length = group_by_length,
479
+ length_column_name = length_column_name,
480
+ report_to = report_to,
481
+ project = project,
482
+ trackio_space_id = trackio_space_id,
483
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
484
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
485
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
486
+ dataloader_pin_memory = dataloader_pin_memory,
487
+ dataloader_persistent_workers = dataloader_persistent_workers,
488
+ skip_memory_metrics = skip_memory_metrics,
489
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
490
+ push_to_hub = push_to_hub,
491
+ resume_from_checkpoint = resume_from_checkpoint,
492
+ hub_model_id = hub_model_id,
493
+ hub_strategy = hub_strategy,
494
+ hub_token = hub_token,
495
+ hub_private_repo = hub_private_repo,
496
+ hub_always_push = hub_always_push,
497
+ hub_revision = hub_revision,
498
+ gradient_checkpointing = gradient_checkpointing,
499
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
500
+ include_inputs_for_metrics = include_inputs_for_metrics,
501
+ eval_do_concat_batches = eval_do_concat_batches,
502
+ fp16_backend = fp16_backend,
503
+ push_to_hub_model_id = push_to_hub_model_id,
504
+ push_to_hub_organization = push_to_hub_organization,
505
+ push_to_hub_token = push_to_hub_token,
506
+ mp_parameters = mp_parameters,
507
+ auto_find_batch_size = auto_find_batch_size,
508
+ full_determinism = full_determinism,
509
+ torchdynamo = torchdynamo,
510
+ ray_scope = ray_scope,
511
+ ddp_timeout = ddp_timeout,
512
+ torch_compile = torch_compile,
513
+ torch_compile_backend = torch_compile_backend,
514
+ torch_compile_mode = torch_compile_mode,
515
+ include_tokens_per_second = include_tokens_per_second,
516
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
517
+ neftune_noise_alpha = neftune_noise_alpha,
518
+ optim_target_modules = optim_target_modules,
519
+ batch_eval_metrics = batch_eval_metrics,
520
+ eval_on_start = eval_on_start,
521
+ use_liger_kernel = use_liger_kernel,
522
+ liger_kernel_config = liger_kernel_config,
523
+ eval_use_gather_object = eval_use_gather_object,
524
+ average_tokens_across_devices = average_tokens_across_devices,
525
+ max_length = max_length,
526
+ max_prompt_length = max_prompt_length,
527
+ max_completion_length = max_completion_length,
528
+ disable_dropout = disable_dropout,
529
+ step_separator = step_separator,
530
+ train_on_last_step_only = train_on_last_step_only,
531
+ dataset_num_proc = dataset_num_proc,**kwargs)
532
+ self.vllm_sampling_params = vllm_sampling_params
533
+ self.unsloth_num_chunks = unsloth_num_chunks
534
+ self.max_seq_length = max_seq_length
535
+ pass
536
+
537
+ class _UnslothPRMTrainer(BaseTrainer):
538
+ """"""
539
+
540
+ _tag_names = ["trl", "prm"]
541
+ _name = "PRM"
542
+ _paper = {
543
+ "title": "Solving math word problems with process-and outcome-based feedback",
544
+ "id": "2211.14275",
545
+ # docstyle-ignore
546
+ "citation": textwrap.dedent("""\
547
+ @article{uesato2022solving,
548
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
549
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
550
+ year = 2022,
551
+ journal = {arXiv preprint arXiv:2211.14275}
552
+ }"""),
553
+ }
554
+
555
+ def __init__(
556
+ self,
557
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
558
+ args: Optional[PRMConfig] = None,
559
+ data_collator: Optional[DataCollator] = None,
560
+ train_dataset: Optional[Dataset] = None,
561
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
562
+ processing_class: Optional[
563
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
564
+ ] = None,
565
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
566
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
567
+ callbacks: Optional[list[TrainerCallback]] = None,
568
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
569
+ None,
570
+ None,
571
+ ),
572
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
573
+ peft_config: Optional[dict] = None,
574
+ ):
575
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
576
+ warnings.warn(
577
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
578
+ "it and want it to remain, please share your comments here: "
579
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
580
+ "TRL_EXPERIMENTAL_SILENCE=1."
581
+ )
582
+ if False:
583
+ model = prepare_peft_model(model, peft_config, args)
584
+
585
+ # Disable dropout in the model
586
+ if args.disable_dropout:
587
+ disable_dropout_in_model(model)
588
+
589
+ if compute_metrics is None:
590
+ compute_metrics = compute_accuracy
591
+
592
+ if data_collator is None:
593
+ if processing_class is None:
594
+ raise ValueError(
595
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
596
+ )
597
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
598
+
599
+ if "input_ids" not in train_dataset.column_names:
600
+ with PartialState().main_process_first():
601
+ fn_kwargs = {
602
+ "tokenizer": processing_class,
603
+ "step_separator": args.step_separator,
604
+ "max_length": args.max_length,
605
+ "max_prompt_length": args.max_prompt_length,
606
+ "max_completion_length": args.max_completion_length,
607
+ "train_on_last_step_only": args.train_on_last_step_only,
608
+ }
609
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
610
+ train_dataset = train_dataset.map(
611
+ self.tokenize_row,
612
+ fn_kwargs=train_fn_kwargs,
613
+ num_proc=args.dataset_num_proc,
614
+ remove_columns=train_dataset.features,
615
+ desc="Tokenizing train dataset",
616
+ features=features.Features( # needed to avoid map to cast labels to bool
617
+ {
618
+ "labels": features.Sequence(features.Value("int64")),
619
+ "input_ids": features.Sequence(features.Value("int64")),
620
+ }
621
+ ),
622
+ )
623
+
624
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
625
+ if eval_dataset is not None:
626
+ eval_dataset = eval_dataset.map(
627
+ self.tokenize_row,
628
+ fn_kwargs=eval_fn_kwargs,
629
+ num_proc=args.dataset_num_proc,
630
+ remove_columns=eval_dataset.features,
631
+ desc="Tokenizing eval dataset",
632
+ features=features.Features( # needed to avoid map to cast labels to bool
633
+ {
634
+ "labels": features.Sequence(features.Value("int64")),
635
+ "input_ids": features.Sequence(features.Value("int64")),
636
+ }
637
+ ),
638
+ )
639
+
640
+ super().__init__(
641
+ model=model,
642
+ args=args,
643
+ data_collator=data_collator,
644
+ train_dataset=train_dataset,
645
+ eval_dataset=eval_dataset,
646
+ processing_class=processing_class,
647
+ model_init=model_init,
648
+ compute_metrics=compute_metrics,
649
+ callbacks=callbacks,
650
+ optimizers=optimizers,
651
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
652
+ )
653
+
654
+ # Add tags for models that have been loaded with the correct transformers version
655
+ if hasattr(self.model, "add_model_tags"):
656
+ self.model.add_model_tags(self._tag_names)
657
+
658
+ @staticmethod
659
+ def tokenize_row(
660
+ features,
661
+ tokenizer,
662
+ step_separator,
663
+ max_length,
664
+ max_prompt_length,
665
+ max_completion_length,
666
+ train_on_last_step_only,
667
+ is_eval,
668
+ ):
669
+ r"""
670
+ Tokenize a row of the dataset.
671
+
672
+ Args:
673
+ features (`dict[str, str]`):
674
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
675
+ tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
676
+ Tokenizer used to process the data.
677
+ step_separator (`str`):
678
+ Separator between steps in the completion.
679
+ max_length (`int` or `None`):
680
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
681
+ max_prompt_length (`int` or `None`):
682
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
683
+ max_completion_length (`int` or `None`):
684
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
685
+ train_on_last_step_only (`bool`):
686
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
687
+ token of the completion.
688
+ is_eval (`bool`):
689
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
690
+ `train_on_last_step_only` is set to `True`.
691
+
692
+ Returns:
693
+ `dict[str, list[int]]`:
694
+ Tokenized sequences with the keys `"input_ids"`, and `"labels".
695
+
696
+ Example:
697
+ ```python
698
+ >>> from transformers import AutoTokenizer
699
+
700
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
701
+ >>> features = {
702
+ ... "prompt": "Which number is larger, 9.8 or 9.11?",
703
+ ... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
704
+ ... "labels": [True, False],
705
+ ... }
706
+ >>> PRMTrainer.tokenize_row(
707
+ ... features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False
708
+ ... )
709
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
710
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
711
+ ```
712
+ """
713
+ # Tokenize the prompt and completions
714
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
715
+ completions_ids = [
716
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
717
+ ]
718
+ if train_on_last_step_only and not is_eval:
719
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
720
+ else:
721
+ labels = [int(label) for label in features["labels"]]
722
+
723
+ # Get the ID of the separator token and add it to the completions
724
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
725
+ completions_ids = [completion + separator_ids for completion in completions_ids]
726
+
727
+ # Create the label
728
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
729
+
730
+ # Join the completions and labels steps
731
+ completion_ids = list(chain(*completions_ids))
732
+ labels = list(chain(*labels))
733
+
734
+ if tokenizer.bos_token_id is not None:
735
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
736
+
737
+ # Truncate prompt and completion sequences
738
+ if max_prompt_length is not None:
739
+ prompt_ids = prompt_ids[-max_prompt_length:]
740
+ if max_completion_length is not None:
741
+ completion_ids = completion_ids[:max_completion_length]
742
+ labels = labels[:max_completion_length]
743
+
744
+ input_ids = prompt_ids + completion_ids
745
+ labels = [-100] * len(prompt_ids) + labels
746
+
747
+ if max_length is not None:
748
+ input_ids = input_ids[:max_length]
749
+ labels = labels[:max_length]
750
+
751
+ return {"input_ids": input_ids, "labels": labels}
752
+
753
+ # Ensure the model card is saved along with the checkpoint
754
+ def _save_checkpoint(self, model, trial):
755
+ if self.args.hub_model_id is None:
756
+ model_name = Path(self.args.output_dir).name
757
+ else:
758
+ model_name = self.args.hub_model_id.split("/")[-1]
759
+ self.create_model_card(model_name=model_name)
760
+ super()._save_checkpoint(model, trial)
761
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
762
+ """
763
+
764
+ Initialize PRMTrainer.
765
+
766
+ Args:
767
+ model ([`~transformers.PreTrainedModel`]):
768
+ The model to train, preferably an `AutoModelForTokenClassification`.
769
+ args ([`PRMConfig`]):
770
+ The arguments to use for training.
771
+ data_collator ([`~transformers.DataCollator`]):
772
+ The data collator to use for training. If None is specified, the default data collator
773
+ ([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the
774
+ maximum length of the sequences in the batch, given a dataset of paired sequences.
775
+ train_dataset ([`~datasets.Dataset`]):
776
+ The dataset to use for training.
777
+ eval_dataset ([`~datasets.Dataset`]):
778
+ The dataset to use for evaluation.
779
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
780
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
781
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
782
+ reuse the fine-tuned model.
783
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
784
+ The model initializer to use for training. If None is specified, the default model initializer will be
785
+ used.
786
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
787
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
788
+ will be used.
789
+ callbacks (`list[transformers.TrainerCallback]`):
790
+ The callbacks to use for training.
791
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
792
+ The optimizer and scheduler to use for training.
793
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
794
+ The function to use to preprocess the logits before computing the metrics.
795
+ peft_config (`dict`, defaults to `None`):
796
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
797
+ a PEFT model.
798
+
799
+ """
800
+ def __init__(
801
+ self,
802
+ model = None,
803
+ args = None,
804
+ data_collator = None,
805
+ train_dataset = None,
806
+ eval_dataset = None,
807
+ processing_class = None,
808
+ model_init = None,
809
+ compute_metrics = None,
810
+ callbacks = None,
811
+ preprocess_logits_for_metrics = None,
812
+ peft_config = None,
813
+ **kwargs
814
+ ):
815
+ if args is None: args = UnslothPRMConfig()
816
+ use_bf16 = getattr(args, 'bf16', False)
817
+ if type(use_bf16) is not bool: use_bf16 = False
818
+ use_fp16 = getattr(args, 'fp16', False)
819
+ if type(use_fp16) is not bool: use_fp16 = False
820
+ force_float32 = False
821
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
822
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
823
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
824
+ force_float32 = True
825
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
826
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
827
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
828
+ from unsloth_zoo.utils import _get_dtype
829
+ dtype = _get_dtype(dtype)
830
+ float16 = dtype == torch.float16
831
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
832
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
833
+ if force_float32:
834
+ # Forced float32 training
835
+ args.fp16 = False
836
+ args.bf16 = False
837
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
838
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
839
+ # args.mixed_precision is a new argument which needs to be set now
840
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
841
+ # Mixed precision training
842
+ args.fp16 = float16
843
+ args.bf16 = not float16
844
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
845
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
846
+ # args.mixed_precision is a new argument which needs to be set now
847
+ elif mixed_precision_dtype == 'bfloat16':
848
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
849
+ args.fp16 = False
850
+ args.bf16 = False
851
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
852
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
853
+ # args.mixed_precision is a new argument which needs to be set now
854
+
855
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
856
+ args.eval_strategy = 'steps'
857
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
858
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
859
+ if ga_steps is not None and ga_steps > 1:
860
+ from transformers import __version__ as transformers_version
861
+ if Version(transformers_version) <= Version('4.45.2'):
862
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
863
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
864
+ if getattr(args, 'eval_strategy', 'no') != 'no':
865
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
866
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
867
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
868
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
869
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
870
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
871
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
872
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
873
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
874
+ if force_float32:
875
+ args.bf16_full_eval = False
876
+ args.fp16_full_eval = False
877
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
878
+ args.bf16_full_eval = True
879
+ args.fp16_full_eval = False
880
+ elif not bf16_full_eval and not fp16_full_eval:
881
+ args.bf16_full_eval = args.bf16
882
+ args.fp16_full_eval = args.fp16
883
+ _output_logits = False
884
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
885
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
886
+ if _output_logits:
887
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
888
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
889
+ pass
890
+ else:
891
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
892
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
893
+ if args_max_seq_length is None and model_max_seq_length is not None:
894
+ max_seq_length = model.max_seq_length
895
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
896
+ if model is not None and hasattr(model, 'for_training'):
897
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
898
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
899
+ if 'processing_class' in locals():
900
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
901
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
902
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
903
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
904
+ if not isinstance(data_collator, UnslothVisionDataCollator):
905
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
906
+ data_collator = TransformersDataCollatorForLanguageModeling(
907
+ __tokenizer,
908
+ mlm = False,
909
+ mlm_probability = 0.0,
910
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
911
+ )
912
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
913
+ data_collator = DataCollatorForSeq2Seq(
914
+ __tokenizer,
915
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
916
+ )
917
+ else:
918
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
919
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
920
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
921
+ if not isinstance(data_collator, UnslothVisionDataCollator):
922
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
923
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
924
+ data_collator = DataCollatorForSeq2Seq(
925
+ __tokenizer.tokenizer,
926
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
927
+ )
928
+ else:
929
+ data_collator = TransformersDataCollatorForLanguageModeling(
930
+ __tokenizer.tokenizer,
931
+ mlm = False,
932
+ mlm_probability = 0.0,
933
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
934
+ )
935
+ other_metrics = []
936
+
937
+ from unsloth_zoo.logging_utils import PatchRLStatistics
938
+ PatchRLStatistics('prm_trainer', other_metrics)
939
+
940
+ # [TODO] Fix up DataParallel multiplying batch sizes
941
+ # [TODO] DDP works, but DP seems to not work? [TODO]
942
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
943
+ if getattr(args, "_n_gpu", 1) != 1:
944
+ args._n_gpu = 1
945
+ if "model" in locals() and hasattr(model, "for_training"):
946
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
947
+ super().__init__(
948
+ model = model,
949
+ args = args,
950
+ data_collator = data_collator,
951
+ train_dataset = train_dataset,
952
+ eval_dataset = eval_dataset,
953
+ processing_class = processing_class,
954
+ model_init = model_init,
955
+ compute_metrics = compute_metrics,
956
+ callbacks = callbacks,
957
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
958
+ peft_config = peft_config,**kwargs)
959
+ if "model" in locals() and hasattr(model, "for_inference"):
960
+ model.for_inference()
961
+ if hasattr(self, 'neftune_hook_handle'):
962
+ self.neftune_hook_handle.remove()
963
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
964
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
965
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
966
+ pass
967
+ if hasattr(self, 'accelerator'):
968
+ scaler = self.accelerator.scaler
969
+ current_model = model
970
+ while hasattr(current_model, 'model'):
971
+ current_model.accelerator_scaler = scaler
972
+ current_model = current_model.model
973
+ current_model.accelerator_scaler = scaler
974
+ pass
975
+ if hasattr(self, 'train'):
976
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
977
+ pass
978
+
979
+ pass
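
The hunk above defines, among other things, `chunked_selective_log_softmax` (whose inner loop, per its own comment, reproduces `selective_log_softmax` chunk by chunk) and `left_pack_padding` (which moves padding tokens to the right of each sequence). Two small illustrative sketches follow; they are not part of the package, the helper names are my own, and they assume only a recent PyTorch install.

A minimal unchunked reference for what `chunked_selective_log_softmax` computes, i.e. the log-probability of the token selected by `index` at every position:

```python
import torch

def selective_log_softmax_reference(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Same math as the chunk loop above: gathered logit minus logsumexp over the vocab.
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 8, 32)        # [batch, seq, vocab]
index = torch.randint(0, 32, (2, 8))  # chosen token ids per position
print(selective_log_softmax_reference(logits, index).shape)  # torch.Size([2, 8])
```

And a toy check of the packing behaviour documented for `left_pack_padding` ("moves all padding tokens in each sequence of a batch to the right"); the cast to int avoids relying on bool-tensor sorting:

```python
import torch

pad_id = 0
batch = torch.tensor([[pad_id, pad_id, 11, 12],
                      [pad_id, 21, 22, 23]])
# Stable, descending argsort on the non-pad mask keeps token order but pushes pads right.
order = torch.argsort((batch != pad_id).int(), dim=1, descending=True, stable=True)
print(torch.gather(batch, 1, order))
# tensor([[11, 12,  0,  0],
#         [21, 22, 23,  0]])
```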