cciwon-code-review-cli 2.0.2 → 2.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/lib/api-client.js +1 -1
  2. package/lib/chat-mode.js +7 -2
  3. package/package.json +1 -1
  4. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  39. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  40. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  41. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  42. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  44. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  45. package/unsloth_compiled_cache/Conv1d.py +70 -0
  46. package/unsloth_compiled_cache/Conv2d.py +70 -0
  47. package/unsloth_compiled_cache/Conv3d.py +70 -0
  48. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  49. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  50. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  51. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  52. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  53. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  54. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  55. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  56. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  57. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  58. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  59. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  60. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  61. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  62. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  63. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  64. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  65. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  66. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  67. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  68. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  69. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  70. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  71. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  72. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  73. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  74. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  111. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
package/unsloth_compiled_cache/UnslothGKDTrainer.py
@@ -0,0 +1,1157 @@
1
+ """
2
+ 2025.12.6
3
+ 2025.12.7
4
+ 4.57.1
5
+ 0.24.0
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth auto generated code
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ from torch import Tensor
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn import functional as F
29
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
30
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, warnings)
31
+
32
+
33
+ import os
34
+ from typing import *
35
+ from dataclasses import dataclass, field
36
+ from packaging.version import Version
37
+ import torch
38
+ import numpy as np
39
+ from contextlib import nullcontext
40
+ from torch.nn import functional as F
41
+ import inspect
42
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
43
+ from transformers.training_args import ParallelMode
44
+
45
+ # Wrap trainer with padding to right and enable training mode
46
+ # Also patches W&B since multiple runs must use wandb.finish()
47
+ import functools
48
+ from types import MethodType
49
+ def prepare_for_training_mode(f):
50
+ @functools.wraps(f)
51
+ def wrapper(self, *args, **kwargs):
52
+ # Enable training mode
53
+ if hasattr(self, 'model') and hasattr(self.model, "for_training"):
54
+ self.model.for_training()
55
+ output = f(self, *args, **kwargs)
56
+ # Return inference mode
57
+ if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
58
+ self.model.for_inference()
59
+ # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
60
+ try:
61
+ import wandb
62
+ wandb.finish()
63
+ except:
64
+ pass
65
+ return output
66
+ return wrapper
67
+ pass
68
+
69
+ torch_compile_options = {
70
+ "epilogue_fusion" : True,
71
+ "max_autotune" : False,
72
+ "shape_padding" : True,
73
+ "trace.enabled" : False,
74
+ "triton.cudagraphs" : False,
75
+ }
76
+
77
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
78
+ def chunked_selective_log_softmax(logits, index):
79
+ # Split into 4 chunks only
80
+ chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
81
+ chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
82
+ all_per_token_logps = []
83
+ # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
84
+ for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
85
+ chunk_logits = chunk_logits.to(torch.float32)
86
+ selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
87
+ logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
88
+ per_token_logps = selected_logits - logsumexp_values
89
+ all_per_token_logps.append(per_token_logps)
90
+ pass
91
+ all_per_token_logps = torch.concat(all_per_token_logps)
92
+ all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
93
+ return all_per_token_logps
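# Illustrative sketch (not from the package): the chunked helper above is
# numerically equivalent to a plain selective log-softmax; shapes are assumed
# to be logits (batch, seq, vocab) and index (batch, seq) of token ids.
import torch

def selective_log_softmax_reference(logits, index):
    # Per-token log-probability of the indexed token, computed without chunking.
    logps = torch.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

# Example: ref has shape (2, 8) of per-token log-probs.
ref = selective_log_softmax_reference(torch.randn(2, 8, 32), torch.randint(0, 32, (2, 8)))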
94
+
95
+ def calculate_pad_tokens_in_prompt(
96
+ input_ids: torch.Tensor,
97
+ logits_to_keep: int,
98
+ pad_token_id: int
99
+ ) -> torch.Tensor:
100
+ """
101
+ Given a prompt tensor, returns the number of left-padding tokens in each sequence, e.g. [pad, pad, pad, cat] = 3 tokens.
102
+ """
103
+ if logits_to_keep >= input_ids.shape[1]:
104
+ raise ValueError("logits_to_keep must be smaller than the sequence length.")
105
+
106
+ prompt_section = input_ids[:, :-logits_to_keep]
107
+
108
+ padding_mask = (prompt_section == pad_token_id)
109
+
110
+ pad_token_counts = padding_mask.sum(dim=1)
111
+
112
+ return pad_token_counts
113
+
114
+ def create_completion_attention_mask(
115
+ completion_input_ids: torch.Tensor,
116
+ left_pad_tokens_per_prompt: torch.Tensor,
117
+ max_left_pad: int,
118
+ pad_token_id: int
119
+ ) -> torch.Tensor:
120
+ """
121
+ Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
122
+
123
+ Where p are extra prompt tokens left over from slicing the tensor, c are completion tokens,
124
+ and pad are padding tokens, this function builds a completion mask that zeroes out the pad
125
+ and p tokens, giving [0,0,0,1,1,1,0,0,0] in this example.
126
+ """
127
+ batch_size, completion_len = completion_input_ids.shape
128
+ device = completion_input_ids.device
129
+
130
+ num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
131
+
132
+ indices = torch.arange(completion_len, device=device).unsqueeze(0)
133
+ shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
134
+
135
+ non_padding_mask = (completion_input_ids != pad_token_id)
136
+
137
+ final_mask = shift_mask & non_padding_mask
138
+
139
+ return final_mask
140
+
141
+ def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
142
+ """
143
+ Moves all padding tokens in each sequence of a batch to the right.
144
+ """
145
+ mask = (tensor != pad_id)
146
+ # Must use stable=True so elements with equal mask values keep their original order
147
+ sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
148
+ packed_tensor = torch.gather(tensor, 1, sorted_indices)
149
+ return packed_tensor
150
+
151
+ def align_logprobs_with_mask(
152
+ logprob_tensor: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ pad_value: float = 0.0
155
+ ) -> torch.Tensor:
156
+ """
157
+ Aligns a log probability tensor with a given attention mask.
158
+ """
159
+
160
+ device = logprob_tensor.device
161
+ batch_size, logprob_seq_len = logprob_tensor.shape
162
+ mask_seq_len = attention_mask.shape[1]
163
+
164
+ padded_logprobs = torch.full(
165
+ attention_mask.shape,
166
+ fill_value=pad_value,
167
+ dtype=logprob_tensor.dtype,
168
+ device=device
169
+ )
170
+
171
+ left_pad_counts = torch.argmax(attention_mask, dim=1)
172
+
173
+ cols = torch.arange(logprob_seq_len, device=device)
174
+ dest_indices = left_pad_counts.unsqueeze(1) + cols
175
+
176
+ # Create destination row indices
177
+ # Shape: [batch_size, logprob_seq_len]
178
+ row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
179
+
180
+ # --- Filter out-of-bounds indices and perform assignment ---
181
+ # Create a mask to identify only the indices that are within the bounds
182
+ # of the target tensor's sequence length.
183
+ valid_mask = dest_indices < mask_seq_len
184
+
185
+ # Use this mask to select only the valid row indices, column indices,
186
+ # and the corresponding values from the logprob tensor.
187
+ # This flattens the selected elements into 1D tensors.
188
+ valid_rows = row_indices[valid_mask]
189
+ valid_cols = dest_indices[valid_mask]
190
+ valid_vals = logprob_tensor[valid_mask]
191
+
192
+ # Place the valid values into their correct positions in the padded tensor
193
+ # using a single, efficient advanced indexing operation.
194
+ padded_logprobs[valid_rows, valid_cols] = valid_vals
195
+
196
+ return padded_logprobs
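# Illustrative sketch (not from the package), assuming the padding helpers
# defined above are in scope; 0 stands in for a hypothetical pad_token_id.
import torch

ids = torch.tensor([[0, 0, 5, 6, 7, 8],
                    [0, 5, 6, 7, 8, 9]])
# Count left-padding inside the prompt section (everything except the last
# logits_to_keep positions): here the prompt sections are [0,0,5] and [0,5,6].
calculate_pad_tokens_in_prompt(ids, logits_to_keep=3, pad_token_id=0)   # tensor([2, 1])

# Pack non-pad tokens to the left, pushing padding to the right.
left_pack_padding(torch.tensor([[0, 0, 5, 6],
                                [7, 0, 8, 0]]), pad_id=0)
# -> tensor([[5, 6, 0, 0],
#            [7, 8, 0, 0]])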
197
+ @dataclass
198
+ class UnslothGKDConfig(GKDConfig):
199
+ """
200
+
201
+ Configuration class for [`GKDTrainer`].
202
+
203
+ This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
204
+ please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
205
+
206
+ Args:
207
+ temperature (`float`, *optional*, defaults to `0.9`):
208
+ Temperature for sampling. The higher the temperature, the more random the completions.
209
+ lmbda (`float`, *optional*, defaults to `0.5`):
210
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
211
+ student-generated outputs).
212
+ beta (`float`, *optional*, defaults to `0.5`):
213
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
214
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
215
+ max_new_tokens (`int`, *optional*, defaults to `128`):
216
+ Maximum number of tokens to generate per completion.
217
+ teacher_model_name_or_path (`str`, *optional*):
218
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
219
+ trained.
220
+ teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
221
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
222
+ from a string.
223
+ disable_dropout (`bool`, *optional*, defaults to `True`):
224
+ Whether to disable dropout in the model.
225
+ seq_kd (`bool`, *optional*, defaults to `False`):
226
+ Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
227
+ teacher-generated output).
228
+
229
+ """
230
+ vllm_sampling_params: Optional[Any] = field(
231
+ default = None,
232
+ metadata = {'help': 'vLLM SamplingParams'},
233
+ )
234
+ unsloth_num_chunks : Optional[int] = field(
235
+ default = -1,
236
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
237
+ )
238
+ max_seq_length : Optional[int] = field(
239
+ default = None,
240
+ metadata = {'help': 'Maximum sequence length to truncate to.'},
241
+ )
242
+ def __init__(
243
+ self,
244
+ output_dir = None,
245
+ overwrite_output_dir = None,
246
+ do_train = False,
247
+ do_eval = False,
248
+ do_predict = False,
249
+ eval_strategy = 'no',
250
+ prediction_loss_only = False,
251
+ per_device_train_batch_size = 4,
252
+ per_device_eval_batch_size = 4,
253
+ per_gpu_train_batch_size = None,
254
+ per_gpu_eval_batch_size = None,
255
+ gradient_accumulation_steps = 2,
256
+ eval_accumulation_steps = 2,
257
+ eval_delay = 0,
258
+ torch_empty_cache_steps = 250,
259
+ learning_rate = 5e-05,
260
+ weight_decay = 0.01,
261
+ adam_beta1 = 0.9,
262
+ adam_beta2 = 0.999,
263
+ adam_epsilon = 1e-08,
264
+ max_grad_norm = 1.0,
265
+ num_train_epochs = 3.0,
266
+ max_steps = -1,
267
+ lr_scheduler_type = 'linear',
268
+ warmup_ratio = 0.1,
269
+ warmup_steps = 0,
270
+ log_level = 'passive',
271
+ log_level_replica = 'warning',
272
+ log_on_each_node = True,
273
+ logging_dir = None,
274
+ logging_strategy = 'steps',
275
+ logging_first_step = False,
276
+ logging_steps = 1,
277
+ logging_nan_inf_filter = False,
278
+ save_strategy = 'steps',
279
+ save_steps = 500,
280
+ save_total_limit = None,
281
+ save_safetensors = True,
282
+ save_on_each_node = False,
283
+ save_only_model = False,
284
+ restore_callback_states_from_checkpoint = False,
285
+ no_cuda = False,
286
+ use_cpu = False,
287
+ use_mps_device = False,
288
+ seed = 3407,
289
+ data_seed = 3407,
290
+ jit_mode_eval = False,
291
+ bf16 = False,
292
+ fp16 = False,
293
+ fp16_opt_level = 'O1',
294
+ half_precision_backend = 'auto',
295
+ bf16_full_eval = False,
296
+ fp16_full_eval = False,
297
+ tf32 = None,
298
+ local_rank = -1,
299
+ ddp_backend = None,
300
+ tpu_num_cores = None,
301
+ tpu_metrics_debug = False,
302
+ debug = '',
303
+ dataloader_drop_last = False,
304
+ eval_steps = None,
305
+ dataloader_num_workers = 0,
306
+ dataloader_prefetch_factor = None,
307
+ past_index = -1,
308
+ run_name = None,
309
+ disable_tqdm = None,
310
+ remove_unused_columns = True,
311
+ label_names = None,
312
+ load_best_model_at_end = False,
313
+ metric_for_best_model = None,
314
+ greater_is_better = None,
315
+ ignore_data_skip = False,
316
+ fsdp = None,
317
+ fsdp_min_num_params = 0,
318
+ fsdp_config = None,
319
+ fsdp_transformer_layer_cls_to_wrap = None,
320
+ accelerator_config = None,
321
+ parallelism_config = None,
322
+ deepspeed = None,
323
+ label_smoothing_factor = 0.0,
324
+ optim = 'adamw_8bit',
325
+ optim_args = None,
326
+ adafactor = False,
327
+ group_by_length = False,
328
+ length_column_name = 'length',
329
+ report_to = 'none',
330
+ project = 'huggingface',
331
+ trackio_space_id = 'trackio',
332
+ ddp_find_unused_parameters = None,
333
+ ddp_bucket_cap_mb = None,
334
+ ddp_broadcast_buffers = None,
335
+ dataloader_pin_memory = True,
336
+ dataloader_persistent_workers = False,
337
+ skip_memory_metrics = True,
338
+ use_legacy_prediction_loop = False,
339
+ push_to_hub = False,
340
+ resume_from_checkpoint = None,
341
+ hub_model_id = None,
342
+ hub_strategy = 'every_save',
343
+ hub_token = None,
344
+ hub_private_repo = None,
345
+ hub_always_push = False,
346
+ hub_revision = None,
347
+ gradient_checkpointing = True,
348
+ gradient_checkpointing_kwargs = None,
349
+ include_inputs_for_metrics = False,
350
+ eval_do_concat_batches = True,
351
+ fp16_backend = 'auto',
352
+ push_to_hub_model_id = None,
353
+ push_to_hub_organization = None,
354
+ push_to_hub_token = None,
355
+ mp_parameters = '',
356
+ auto_find_batch_size = False,
357
+ full_determinism = False,
358
+ torchdynamo = None,
359
+ ray_scope = 'last',
360
+ ddp_timeout = 1800,
361
+ torch_compile = False,
362
+ torch_compile_backend = None,
363
+ torch_compile_mode = None,
364
+ include_tokens_per_second = False,
365
+ include_num_input_tokens_seen = False,
366
+ neftune_noise_alpha = None,
367
+ optim_target_modules = None,
368
+ batch_eval_metrics = False,
369
+ eval_on_start = False,
370
+ use_liger_kernel = False,
371
+ liger_kernel_config = None,
372
+ eval_use_gather_object = False,
373
+ average_tokens_across_devices = True,
374
+ model_init_kwargs = None,
375
+ chat_template_path = None,
376
+ dataset_text_field = 'text',
377
+ dataset_kwargs = None,
378
+ dataset_num_proc = None,
379
+ eos_token = None,
380
+ pad_token = None,
381
+ max_length = 1024,
382
+ packing = False,
383
+ packing_strategy = 'bfd',
384
+ padding_free = False,
385
+ pad_to_multiple_of = None,
386
+ eval_packing = None,
387
+ completion_only_loss = None,
388
+ assistant_only_loss = False,
389
+ loss_type = 'nll',
390
+ activation_offloading = False,
391
+ temperature = 0.9,
392
+ lmbda = 0.5,
393
+ beta = 0.5,
394
+ max_new_tokens = 128,
395
+ teacher_model_name_or_path = None,
396
+ teacher_model_init_kwargs = None,
397
+ disable_dropout = True,
398
+ seq_kd = False,
399
+ vllm_sampling_params = None,
400
+ unsloth_num_chunks = -1,
401
+ max_seq_length = None,
402
+ **kwargs,
403
+ ):
404
+ if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
405
+ if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
406
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
407
+ output_dir = 'unsloth_training_checkpoints'
408
+ save_strategy = 'no'
409
+ if dataset_num_proc is None:
410
+ from multiprocessing import cpu_count
411
+ dataset_num_proc = min(max(cpu_count()+4, 2), 64)
412
+ if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
413
+ from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
414
+ if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
415
+ from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
416
+ pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
417
+
418
+ if temperature <= 0:
419
+ raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
420
+ elif temperature >= 10:
421
+ raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
422
+
423
+
424
+ super().__init__(
425
+ output_dir = output_dir,
426
+ overwrite_output_dir = overwrite_output_dir,
427
+ do_train = do_train,
428
+ do_eval = do_eval,
429
+ do_predict = do_predict,
430
+ eval_strategy = eval_strategy,
431
+ prediction_loss_only = prediction_loss_only,
432
+ per_device_train_batch_size = per_device_train_batch_size,
433
+ per_device_eval_batch_size = per_device_eval_batch_size,
434
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
435
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
436
+ gradient_accumulation_steps = gradient_accumulation_steps,
437
+ eval_accumulation_steps = eval_accumulation_steps,
438
+ eval_delay = eval_delay,
439
+ torch_empty_cache_steps = torch_empty_cache_steps,
440
+ learning_rate = learning_rate,
441
+ weight_decay = weight_decay,
442
+ adam_beta1 = adam_beta1,
443
+ adam_beta2 = adam_beta2,
444
+ adam_epsilon = adam_epsilon,
445
+ max_grad_norm = max_grad_norm,
446
+ num_train_epochs = num_train_epochs,
447
+ max_steps = max_steps,
448
+ lr_scheduler_type = lr_scheduler_type,
449
+ warmup_ratio = warmup_ratio,
450
+ warmup_steps = warmup_steps,
451
+ log_level = log_level,
452
+ log_level_replica = log_level_replica,
453
+ log_on_each_node = log_on_each_node,
454
+ logging_dir = logging_dir,
455
+ logging_strategy = logging_strategy,
456
+ logging_first_step = logging_first_step,
457
+ logging_steps = logging_steps,
458
+ logging_nan_inf_filter = logging_nan_inf_filter,
459
+ save_strategy = save_strategy,
460
+ save_steps = save_steps,
461
+ save_total_limit = save_total_limit,
462
+ save_safetensors = save_safetensors,
463
+ save_on_each_node = save_on_each_node,
464
+ save_only_model = save_only_model,
465
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
466
+ no_cuda = no_cuda,
467
+ use_cpu = use_cpu,
468
+ use_mps_device = use_mps_device,
469
+ seed = seed,
470
+ data_seed = data_seed,
471
+ jit_mode_eval = jit_mode_eval,
472
+ bf16 = bf16,
473
+ fp16 = fp16,
474
+ fp16_opt_level = fp16_opt_level,
475
+ half_precision_backend = half_precision_backend,
476
+ bf16_full_eval = bf16_full_eval,
477
+ fp16_full_eval = fp16_full_eval,
478
+ tf32 = tf32,
479
+ local_rank = local_rank,
480
+ ddp_backend = ddp_backend,
481
+ tpu_num_cores = tpu_num_cores,
482
+ tpu_metrics_debug = tpu_metrics_debug,
483
+ debug = debug,
484
+ dataloader_drop_last = dataloader_drop_last,
485
+ eval_steps = eval_steps,
486
+ dataloader_num_workers = dataloader_num_workers,
487
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
488
+ past_index = past_index,
489
+ run_name = run_name,
490
+ disable_tqdm = disable_tqdm,
491
+ remove_unused_columns = remove_unused_columns,
492
+ label_names = label_names,
493
+ load_best_model_at_end = load_best_model_at_end,
494
+ metric_for_best_model = metric_for_best_model,
495
+ greater_is_better = greater_is_better,
496
+ ignore_data_skip = ignore_data_skip,
497
+ fsdp = fsdp,
498
+ fsdp_min_num_params = fsdp_min_num_params,
499
+ fsdp_config = fsdp_config,
500
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
501
+ accelerator_config = accelerator_config,
502
+ parallelism_config = parallelism_config,
503
+ deepspeed = deepspeed,
504
+ label_smoothing_factor = label_smoothing_factor,
505
+ optim = optim,
506
+ optim_args = optim_args,
507
+ adafactor = adafactor,
508
+ group_by_length = group_by_length,
509
+ length_column_name = length_column_name,
510
+ report_to = report_to,
511
+ project = project,
512
+ trackio_space_id = trackio_space_id,
513
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
514
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
515
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
516
+ dataloader_pin_memory = dataloader_pin_memory,
517
+ dataloader_persistent_workers = dataloader_persistent_workers,
518
+ skip_memory_metrics = skip_memory_metrics,
519
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
520
+ push_to_hub = push_to_hub,
521
+ resume_from_checkpoint = resume_from_checkpoint,
522
+ hub_model_id = hub_model_id,
523
+ hub_strategy = hub_strategy,
524
+ hub_token = hub_token,
525
+ hub_private_repo = hub_private_repo,
526
+ hub_always_push = hub_always_push,
527
+ hub_revision = hub_revision,
528
+ gradient_checkpointing = gradient_checkpointing,
529
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
530
+ include_inputs_for_metrics = include_inputs_for_metrics,
531
+ eval_do_concat_batches = eval_do_concat_batches,
532
+ fp16_backend = fp16_backend,
533
+ push_to_hub_model_id = push_to_hub_model_id,
534
+ push_to_hub_organization = push_to_hub_organization,
535
+ push_to_hub_token = push_to_hub_token,
536
+ mp_parameters = mp_parameters,
537
+ auto_find_batch_size = auto_find_batch_size,
538
+ full_determinism = full_determinism,
539
+ torchdynamo = torchdynamo,
540
+ ray_scope = ray_scope,
541
+ ddp_timeout = ddp_timeout,
542
+ torch_compile = torch_compile,
543
+ torch_compile_backend = torch_compile_backend,
544
+ torch_compile_mode = torch_compile_mode,
545
+ include_tokens_per_second = include_tokens_per_second,
546
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
547
+ neftune_noise_alpha = neftune_noise_alpha,
548
+ optim_target_modules = optim_target_modules,
549
+ batch_eval_metrics = batch_eval_metrics,
550
+ eval_on_start = eval_on_start,
551
+ use_liger_kernel = use_liger_kernel,
552
+ liger_kernel_config = liger_kernel_config,
553
+ eval_use_gather_object = eval_use_gather_object,
554
+ average_tokens_across_devices = average_tokens_across_devices,
555
+ model_init_kwargs = model_init_kwargs,
556
+ chat_template_path = chat_template_path,
557
+ dataset_text_field = dataset_text_field,
558
+ dataset_kwargs = dataset_kwargs,
559
+ dataset_num_proc = dataset_num_proc,
560
+ eos_token = eos_token,
561
+ pad_token = pad_token,
562
+ max_length = max_length,
563
+ packing = packing,
564
+ packing_strategy = packing_strategy,
565
+ padding_free = padding_free,
566
+ pad_to_multiple_of = pad_to_multiple_of,
567
+ eval_packing = eval_packing,
568
+ completion_only_loss = completion_only_loss,
569
+ assistant_only_loss = assistant_only_loss,
570
+ loss_type = loss_type,
571
+ activation_offloading = activation_offloading,
572
+ temperature = temperature,
573
+ lmbda = lmbda,
574
+ beta = beta,
575
+ max_new_tokens = max_new_tokens,
576
+ teacher_model_name_or_path = teacher_model_name_or_path,
577
+ teacher_model_init_kwargs = teacher_model_init_kwargs,
578
+ disable_dropout = disable_dropout,
579
+ seq_kd = seq_kd,**kwargs)
580
+ self.vllm_sampling_params = vllm_sampling_params
581
+ self.unsloth_num_chunks = unsloth_num_chunks
582
+ self.max_seq_length = max_seq_length
583
+ pass
584
+
585
+ class _UnslothGKDTrainer(SFTTrainer):
586
+ """"""
587
+
588
+ _tag_names = ["trl", "gkd"]
589
+ _name = "GKD"
590
+ _paper = {
591
+ "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
592
+ "id": "2306.13649",
593
+ # docstyle-ignore
594
+ "citation": textwrap.dedent("""\
595
+ @inproceedings{agarwal2024on-policy,
596
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
597
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
598
+ year = 2024,
599
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
600
+ publisher = {OpenReview.net},
601
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
602
+ }"""),
603
+ }
604
+
605
+ def __init__(
606
+ self,
607
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
608
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
609
+ args: Optional[GKDConfig] = None,
610
+ data_collator: Optional[DataCollator] = None, # type: ignore
611
+ train_dataset: Optional[Dataset] = None,
612
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
613
+ processing_class: Optional[
614
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
615
+ ] = None,
616
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
617
+ callbacks: Optional[list[TrainerCallback]] = None,
618
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
619
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
620
+ peft_config: Optional["PeftConfig"] = None,
621
+ formatting_func: Optional[Callable] = None,
622
+ ):
623
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
624
+ warnings.warn(
625
+ "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
626
+ "it and want it to remain, please share your comments here: "
627
+ "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
628
+ "TRL_EXPERIMENTAL_SILENCE=1."
629
+ )
630
+ # Ensure Trainer does not drop non-signature columns used by the collator [e.g., "prompts"]
631
+ args.remove_unused_columns = False
632
+ # Respect a user-provided data_collator; otherwise, provide a ChatML collator that builds inputs from the raw conversational fields.
633
+ if data_collator is None:
634
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
635
+
636
+ # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator,
637
+ # so that raw conversational fields [e.g., "messages"] remain available to the collator.
638
+ if args.dataset_kwargs is None:
639
+ args.dataset_kwargs = {"skip_prepare_dataset": True}
640
+ else:
641
+ args.dataset_kwargs["skip_prepare_dataset"] = True
642
+
643
+ # Liger fused GKD loss [JSD]
644
+ self.use_liger_gkd_loss = False
645
+ if args.use_liger_kernel:
646
+ self.liger_jsd_loss = LigerFusedLinearJSDLoss(
647
+ beta=args.beta,
648
+ ignore_index=-100,
649
+ temperature=args.temperature,
650
+ compiled=False,
651
+ )
652
+ self.use_liger_gkd_loss = True
653
+
654
+ super().__init__(
655
+ model,
656
+ args=args,
657
+ data_collator=data_collator,
658
+ train_dataset=train_dataset,
659
+ eval_dataset=eval_dataset,
660
+ processing_class=processing_class,
661
+ compute_metrics=compute_metrics,
662
+ callbacks=callbacks,
663
+ optimizers=optimizers,
664
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
665
+ peft_config=peft_config,
666
+ formatting_func=formatting_func,
667
+ )
668
+
669
+ if args.teacher_model_init_kwargs is None:
670
+ teacher_model_init_kwargs = {}
671
+ elif not isinstance(teacher_model, str):
672
+ raise ValueError(
673
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
674
+ )
675
+ else:
676
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
677
+ teacher_model_init_kwargs["dtype"] = (
678
+ teacher_model_init_kwargs["dtype"]
679
+ if teacher_model_init_kwargs["dtype"] in ["auto", None]
680
+ else getattr(torch, teacher_model_init_kwargs["dtype"])
681
+ )
682
+
683
+ if isinstance(teacher_model, str):
684
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
685
+
686
+ # Disable dropout in the model
687
+ if args.disable_dropout:
688
+ disable_dropout_in_model(self.model)
689
+
690
+ if self.is_deepspeed_enabled:
691
+ self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
692
+ else:
693
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
694
+
695
+ self.lmbda = args.lmbda
696
+ self.beta = args.beta
697
+ self.temperature = args.temperature
698
+ self.seq_kd = args.seq_kd
699
+
700
+ self.generation_config = GenerationConfig(
701
+ max_new_tokens=args.max_new_tokens,
702
+ temperature=args.temperature,
703
+ do_sample=True,
704
+ top_k=0,
705
+ use_cache=False if args.gradient_checkpointing else True,
706
+ pad_token_id=self.processing_class.pad_token_id,
707
+ )
708
+ # Set custom EOS tokens if they are specified by the model's generation
709
+ # config. This is important for models with the Llama 3 chat template,
710
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
711
+ # turns or messages.
712
+ if (
713
+ hasattr(self.model.generation_config, "eos_token_id")
714
+ and self.model.generation_config.eos_token_id is not None
715
+ ):
716
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
717
+
718
+ @staticmethod
719
+ def generalized_jsd_loss(
720
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
721
+ ):
722
+ """
723
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
724
+ of https://huggingface.co/papers/2306.13649 for the definition.
725
+
726
+ Args:
727
+ student_logits:
728
+ Tensor of shape (batch_size, sequence_length, vocab_size)
729
+ teacher_logits:
730
+ Tensor of shape (batch_size, sequence_length, vocab_size)
731
+ labels:
732
+ Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
733
+ loss
734
+ beta:
735
+ Interpolation coefficient between 0 and 1 (default: 0.5)
736
+ temperature:
737
+ Softmax temperature (default: 1.0)
738
+ reduction:
739
+ Specifies the reduction to apply to the output (default: 'batchmean')
740
+
741
+ Returns:
742
+ loss: Scalar tensor with the generalized JSD loss
743
+ """
744
+
745
+ # Apply temperature scaling
746
+ student_logits = student_logits / temperature
747
+ teacher_logits = teacher_logits / temperature
748
+
749
+ # Compute log probabilities for student and probabilities for teacher
750
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
751
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
752
+
753
+ if beta == 0:
754
+ jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
755
+ elif beta == 1:
756
+ jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
757
+ else:
758
+ # Compute the log of the mixture distribution
759
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
760
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
761
+ mixture_log_probs = torch.logsumexp(
762
+ torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
763
+ dim=0,
764
+ )
765
+
766
+ # Compute KL divergences using F.kl_div
767
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
768
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
769
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
770
+
771
+ # Compute the Generalized Jensen-Shannon Divergence
772
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
773
+
774
+ # Masking
775
+ if labels is not None:
776
+ mask = labels != -100
777
+ jsd = jsd[mask]
778
+
779
+ # Apply reduction
780
+ if reduction == "batchmean":
781
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
782
+ elif reduction == "sum":
783
+ return jsd.sum()
784
+ elif reduction == "mean":
785
+ return jsd.mean()
786
+ else:
787
+ return jsd
788
+
789
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
790
+ if self.use_liger_gkd_loss:
791
+ # Forward only through the base models (avoid lm_head to save memory)
792
+ unwrapped_student = self.accelerator.unwrap_model(model)
793
+ if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
794
+ base_student = unwrapped_student.get_decoder()
795
+ else:
796
+ base_student = getattr(
797
+ unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
798
+ )
799
+
800
+ student_outputs = base_student(
801
+ input_ids=inputs["input_ids"],
802
+ attention_mask=inputs["attention_mask"],
803
+ output_hidden_states=True,
804
+ use_cache=False,
805
+ )
806
+
807
+ self.teacher_model.eval()
808
+ unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
809
+ if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None:
810
+ base_teacher = unwrapped_teacher.get_decoder()
811
+ else:
812
+ base_teacher = getattr(
813
+ unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher
814
+ )
815
+ with torch.no_grad():
816
+ teacher_outputs = base_teacher(
817
+ input_ids=inputs["input_ids"],
818
+ attention_mask=inputs["attention_mask"],
819
+ output_hidden_states=True,
820
+ use_cache=False,
821
+ )
822
+
823
+ # hidden states (shifted)
824
+ student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
825
+ teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
826
+
827
+ # labels mask and labels (shifted)
828
+ labels_mask = inputs["labels"] != -100
829
+ masked_input_ids = torch.where(
830
+ labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)
831
+ )
832
+ true_labels = masked_input_ids[:, 1:].contiguous()
833
+
834
+ # heads
835
+ student_head = unwrapped_student.get_output_embeddings()
836
+ teacher_head = unwrapped_teacher.get_output_embeddings()
837
+
838
+ # liger fused jsd loss
839
+ loss = self.liger_jsd_loss(
840
+ student_input=student_hidden,
841
+ student_weight=student_head.weight,
842
+ teacher_input=teacher_hidden,
843
+ teacher_weight=teacher_head.weight,
844
+ true_labels=true_labels,
845
+ student_bias=getattr(student_head, "bias", None),
846
+ teacher_bias=getattr(teacher_head, "bias", None),
847
+ )
848
+ else:
849
+ # compute student output
850
+ student_outputs = model(
851
+ input_ids=inputs["input_ids"],
852
+ attention_mask=inputs["attention_mask"],
853
+ )
854
+
855
+ # compute teacher output in eval mode
856
+ self.teacher_model.eval()
857
+ with torch.no_grad():
858
+ teacher_outputs = self.teacher_model(
859
+ input_ids=inputs["input_ids"],
860
+ attention_mask=inputs["attention_mask"],
861
+ )
862
+
863
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
864
+ prompt_lengths = inputs["prompts"].shape[1]
865
+ shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :]
866
+ shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :]
867
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
868
+
869
+ # compute loss
870
+ loss = self.generalized_jsd_loss(
871
+ student_logits=shifted_student_logits,
872
+ teacher_logits=shifted_teacher_logits,
873
+ labels=shifted_labels,
874
+ beta=self.beta,
875
+ )
876
+
877
+ # empty cache
878
+ empty_cache()
879
+
880
+ # Return loss
881
+ return (loss, student_outputs) if return_outputs else loss
882
+
883
+ @staticmethod
884
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
885
+ # Generate output with respect to the prompt-only
886
+ generated_outputs = model.generate(
887
+ input_ids=inputs["prompts"],
888
+ attention_mask=inputs.get("prompt_attention_mask", None),
889
+ generation_config=generation_config,
890
+ return_dict_in_generate=True,
891
+ )
892
+
893
+ # Get the generated token IDs
894
+ generated_tokens = generated_outputs.sequences
895
+ # Calculate new attention mask
896
+ new_attention_mask = torch.ones_like(generated_tokens)
897
+ new_labels = generated_tokens.clone()
898
+
899
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
900
+ if pad_token_id is not None:
901
+ new_labels[new_labels == pad_token_id] = -100
902
+ new_attention_mask[generated_tokens == pad_token_id] = 0
903
+
904
+ return generated_tokens, new_attention_mask, new_labels
905
+
906
+ def training_step(
907
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
908
+ ) -> torch.Tensor:
909
+ """
910
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
911
+
912
+ This method implements the on-policy learning approach described in the GKD paper. With probability
913
+ `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
914
+ the original inputs.
915
+ """
916
+ if self.seq_kd:
917
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
918
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
919
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
920
+ )
921
+ inputs["input_ids"] = new_input_ids
922
+ inputs["attention_mask"] = new_attention_mask
923
+ inputs["labels"] = new_labels
924
+ if random.random() <= self.lmbda:
925
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
926
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
927
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
928
+ )
929
+ inputs["input_ids"] = new_input_ids
930
+ inputs["attention_mask"] = new_attention_mask
931
+ inputs["labels"] = new_labels
932
+
933
+ loss = super().training_step(model, inputs, num_items_in_batch)
934
+ return loss
935
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
936
+ """
937
+ Trainer for Generalized Knowledge Distillation (GKD) of language models.
938
+
939
+ For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated
940
+ Mistakes](https://huggingface.co/papers/2306.13649).
941
+
942
+ Args:
943
+ model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
944
+ Model to be trained, or the string identifier of the model to be instantiated from a pretrained model.
945
+ teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
946
+ Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a
947
+ pretrained model.
948
+ args ([`GKDConfig`], *optional*):
949
+ Training arguments.
950
+ data_collator ([`~transformers.DataCollator`], *optional*):
951
+ Data collator to batch samples from the dataset. It defaults to a [`DataCollatorForChatML`] using the
952
+ `processing_class`.
953
+ train_dataset ([`~datasets.Dataset`], *optional*):
954
+ Dataset for training.
955
+ eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
956
+ Dataset for evaluation.
957
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
958
+ Class to process the data.
959
+ compute_metrics (`Callable`, *optional*):
960
+ Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
961
+ dictionary mapping metric names (strings) to float values.
962
+ callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
963
+ Callbacks to use during training.
964
+ optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
965
+ Tuple containing the optimizer and the learning rate scheduler to use for training.
966
+ preprocess_logits_for_metrics (`Callable`, *optional*):
967
+ Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
968
+ return the logits to be used for metrics computation.
969
+ peft_config ([`~peft.PeftConfig`], *optional*):
970
+ PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
971
+ wrapped with the specified PEFT adapter.
972
+ formatting_func (`Callable`, *optional*):
973
+ Function to format the dataset. Must take in an example and return an example.
974
+
975
+ """
976
+ def __init__(
977
+ self,
978
+ model = None,
979
+ teacher_model = None,
980
+ args = None,
981
+ data_collator = None,
982
+ train_dataset = None,
983
+ eval_dataset = None,
984
+ processing_class = None,
985
+ compute_metrics = None,
986
+ callbacks = None,
987
+ preprocess_logits_for_metrics = None,
988
+ peft_config = None,
989
+ formatting_func = None,
990
+ **kwargs
991
+ ):
992
+ if args is None: args = UnslothGKDConfig()
993
+ use_bf16 = getattr(args, 'bf16', False)
994
+ if type(use_bf16) is not bool: use_bf16 = False
995
+ use_fp16 = getattr(args, 'fp16', False)
996
+ if type(use_fp16) is not bool: use_fp16 = False
997
+ force_float32 = False
998
+ full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
999
+ if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
1000
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1001
+ force_float32 = True
1002
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1003
+ dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
1004
+ if dtype is None: dtype = model.get_input_embeddings().weight.dtype
1005
+ from unsloth_zoo.utils import _get_dtype
1006
+ dtype = _get_dtype(dtype)
1007
+ float16 = dtype == torch.float16
1008
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1009
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1010
+ if force_float32:
1011
+ # Forced float32 training
1012
+ args.fp16 = False
1013
+ args.bf16 = False
1014
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1015
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1016
+ # args.mixed_precision is a new argument which needs to be set now
1017
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1018
+ # Mixed precision training
1019
+ args.fp16 = float16
1020
+ args.bf16 = not float16
1021
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1022
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
1023
+ # args.mixed_precision is a new argument which needs to be set now
1024
+ elif mixed_precision_dtype == 'bfloat16':
1025
+ # Both False since bfloat16 full finetuning doesn't do any autocasting.
1026
+ args.fp16 = False
1027
+ args.bf16 = False
1028
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1029
+ if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
1030
+ # args.mixed_precision is a new argument which needs to be set now
1031
+
1032
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1033
+ args.eval_strategy = 'steps'
1034
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1035
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1036
+ if ga_steps is not None and ga_steps > 1:
1037
+ from transformers import __version__ as transformers_version
1038
+ if Version(transformers_version) <= Version('4.45.2'):
1039
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1040
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1041
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1042
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1043
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1044
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1045
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1046
+ if type(fp16_full_eval) is not bool: fp16_full_eval = False
1047
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1048
+ if type(bf16_full_eval) is not bool: bf16_full_eval = False
1049
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1050
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1051
+ if force_float32:
1052
+ args.bf16_full_eval = False
1053
+ args.fp16_full_eval = False
1054
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1055
+ args.bf16_full_eval = True
1056
+ args.fp16_full_eval = False
1057
+ elif not bf16_full_eval and not fp16_full_eval:
1058
+ args.bf16_full_eval = args.bf16
1059
+ args.fp16_full_eval = args.fp16
1060
+ _output_logits = False
1061
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1062
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1063
+ if _output_logits:
1064
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1065
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1066
+ pass
1067
+ else:
1068
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1069
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1070
+ if args_max_seq_length is None and model_max_seq_length is not None:
1071
+ max_seq_length = model.max_seq_length
1072
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1073
+ if model is not None and hasattr(model, 'for_training'):
1074
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1075
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1076
+ if 'processing_class' in locals():
1077
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1078
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1079
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1080
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1081
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1082
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1083
+ data_collator = TransformersDataCollatorForLanguageModeling(
1084
+ __tokenizer,
1085
+ mlm = False,
1086
+ mlm_probability = 0.0,
1087
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1088
+ )
1089
+ elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1090
+ data_collator = DataCollatorForSeq2Seq(
1091
+ __tokenizer,
1092
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1093
+ )
1094
+ else:
1095
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1096
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1097
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1098
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1099
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1100
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1101
+ data_collator = DataCollatorForSeq2Seq(
1102
+ __tokenizer.tokenizer,
1103
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1104
+ )
1105
+ else:
1106
+ data_collator = TransformersDataCollatorForLanguageModeling(
1107
+ __tokenizer.tokenizer,
1108
+ mlm = False,
1109
+ mlm_probability = 0.0,
1110
+ pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
1111
+ )
1112
+ other_metrics = []
1113
+
1114
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1115
+ PatchRLStatistics('gkd_trainer', other_metrics)
1116
+
1117
+ # [TODO] Fix up DataParallel multiplying batch sizes
1118
+ # [TODO] DDP works, but DP seems to not work? [TODO]
1119
+ if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
1120
+ if getattr(args, "_n_gpu", 1) != 1:
1121
+ args._n_gpu = 1
1122
+ if "model" in locals() and hasattr(model, "for_training"):
1123
+ model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
1124
+ super().__init__(
1125
+ model = model,
1126
+ teacher_model = teacher_model,
1127
+ args = args,
1128
+ data_collator = data_collator,
1129
+ train_dataset = train_dataset,
1130
+ eval_dataset = eval_dataset,
1131
+ processing_class = processing_class,
1132
+ compute_metrics = compute_metrics,
1133
+ callbacks = callbacks,
1134
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1135
+ peft_config = peft_config,
1136
+ formatting_func = formatting_func,**kwargs)
1137
+ if "model" in locals() and hasattr(model, "for_inference"):
1138
+ model.for_inference()
1139
+ if hasattr(self, 'neftune_hook_handle'):
1140
+ self.neftune_hook_handle.remove()
1141
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1142
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1143
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1144
+ pass
1145
+ if hasattr(self, 'accelerator'):
1146
+ scaler = self.accelerator.scaler
1147
+ current_model = model
1148
+ while hasattr(current_model, 'model'):
1149
+ current_model.accelerator_scaler = scaler
1150
+ current_model = current_model.model
1151
+ current_model.accelerator_scaler = scaler
1152
+ pass
1153
+ if hasattr(self, 'train'):
1154
+ self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
1155
+ pass
1156
+
1157
+ pass
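
For reference, the generalized Jensen-Shannon divergence loss defined in this file can be exercised on its own. The snippet below is a minimal sketch (not part of the package) that assumes the module above, with its TRL and Unsloth dependencies, has been imported; shapes follow the docstring of `generalized_jsd_loss`.

import torch

student = torch.randn(1, 4, 16)             # (batch, seq, vocab)
teacher = torch.randn(1, 4, 16)
labels  = torch.tensor([[5, 7, -100, 3]])   # -100 marks ignored label positions

loss = _UnslothGKDTrainer.generalized_jsd_loss(
    student_logits=student,
    teacher_logits=teacher,
    labels=labels,
    beta=0.5,          # 0.0 -> KL divergence, 1.0 -> inverse KL, per the docstring
    temperature=1.0,
)
print(loss)            # scalar tensor ("batchmean" reduction)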
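A minimal construction sketch for the wrapper classes defined above follows; the model identifiers, dataset name, and output path are hypothetical placeholders, the dataset is assumed to carry the chat-style fields expected by `DataCollatorForChatML`, and the argument names mirror the `UnslothGKDConfig` and `UnslothGKDTrainer` signatures in this file.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical identifiers; any causal-LM checkpoint would do.
student_id, teacher_id = "student-model-id", "teacher-model-id"

model = AutoModelForCausalLM.from_pretrained(student_id)
tokenizer = AutoTokenizer.from_pretrained(student_id)

config = UnslothGKDConfig(
    output_dir="gkd-output",        # hypothetical path
    per_device_train_batch_size=1,
    max_new_tokens=64,
    lmbda=0.5,                      # fraction of on-policy student generations
    beta=0.5,                       # JSD interpolation coefficient
)

trainer = UnslothGKDTrainer(
    model=model,
    teacher_model=teacher_id,       # strings are loaded via AutoModelForCausalLM
    args=config,
    processing_class=tokenizer,
    train_dataset=load_dataset("org/chat-dataset", split="train"),  # hypothetical dataset
)
trainer.train()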