liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the differences between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
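
A change that recurs across the `liger_kernel/transformers/model/*.py` diffs below is a refactor of the patched `lce_forward` functions: the `transformers` docstring decorators are dropped, `**loss_kwargs` becomes `**kwargs` (now also forwarded to the base model call), and a new optional `skip_logits` argument controls whether logits are materialized or the fused linear cross entropy path is taken. A minimal sketch of how the new argument might be exercised follows; it assumes the existing `apply_liger_kernel_to_mistral` patching API behaves as in prior releases, and the model id and inputs are placeholders.

    # Hedged sketch of the new skip_logits argument on the patched forwards.
    # Assumes apply_liger_kernel_to_mistral() works as in prior releases; the
    # model id and inputs are placeholders. Liger's Triton kernels need a GPU.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from liger_kernel.transformers import apply_liger_kernel_to_mistral

    apply_liger_kernel_to_mistral()  # install lce_forward before loading the model

    device = "cuda"
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").to(device)
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    batch = {k: v.to(device) for k, v in tokenizer("hello world", return_tensors="pt").items()}
    labels = batch["input_ids"].clone()

    model.train()
    # skip_logits=True: the loss comes from the fused linear cross entropy path
    # and outputs.logits stays None, so no full vocab-size logits tensor is built.
    out = model(**batch, labels=labels, skip_logits=True)
    print(out.loss, out.logits)

    # skip_logits=False: logits are materialized even in training mode and the
    # model's stock loss_function is used.
    out = model(**batch, labels=labels, skip_logits=False)
    print(out.loss, out.logits.shape)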
liger_kernel/transformers/model/mistral.py

@@ -7,18 +7,12 @@ import torch
 
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
-from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -33,7 +27,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -88,6 +83,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -95,18 +91,24 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -118,7 +120,7 @@ def lce_forward(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -131,6 +133,3 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
-
-
-# Note: Grad Acc is not fixed in mistral at transformer 4.46.1
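
The same `skip_logits` resolution logic recurs in each of the remaining model diffs below. Distilled into a standalone helper for readability (the function name is ours; the package inlines this logic in every `lce_forward` rather than sharing a function):

    # Distilled sketch of the skip_logits handling shared by the patched forwards.
    # resolve_skip_logits is an illustrative name, not part of liger_kernel.
    from typing import Optional

    import torch


    def resolve_skip_logits(
        training: bool,
        labels: Optional[torch.Tensor],
        shift_labels: Optional[torch.Tensor],
        skip_logits: Optional[bool],
    ) -> bool:
        # An explicit skip_logits=True without any labels cannot produce a loss.
        if skip_logits and labels is None and shift_labels is None:
            raise ValueError("skip_logits is True, but labels and shift_labels are None")
        # Default: skip materializing logits only when training with labels,
        # matching the previous `self.training and (labels is not None or ...)` check.
        if skip_logits is None:
            skip_logits = training and (labels is not None or shift_labels is not None)
        return skip_logits

When this resolves to True, the forward computes the loss with LigerForCausalLMLoss directly from the kept hidden states and lm_head.weight; otherwise lm_head materializes the logits and the model's stock loss_function is used.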
liger_kernel/transformers/model/mixtral.py

@@ -7,19 +7,13 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -164,7 +156,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -222,6 +215,7 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -229,26 +223,33 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
liger_kernel/transformers/model/mllama.py

@@ -8,17 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -135,8 +130,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -154,7 +147,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -212,6 +206,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -219,28 +214,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
liger_kernel/transformers/model/olmo2.py

@@ -6,18 +6,12 @@ from typing import Union
 import torch
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,7 +26,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -85,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -92,28 +88,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(
liger_kernel/transformers/model/paligemma.py

@@ -7,13 +7,9 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
-from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
-from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
 from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
 logger = logging.get_logger(__name__)
 
 
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -206,8 +200,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -224,6 +216,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
 ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
     r"""
@@ -339,7 +332,13 @@ def lce_forward(
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
liger_kernel/transformers/model/phi3.py

@@ -7,18 +7,12 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
-from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,6 +26,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -86,7 +81,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if self.training and labels is not None:
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -128,8 +130,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -144,7 +144,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -210,6 +211,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -217,28 +219,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
liger_kernel/transformers/model/qwen2.py

@@ -7,18 +7,12 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
-from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,6 +26,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -86,6 +81,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -127,8 +129,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -143,7 +143,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -196,6 +197,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -203,28 +205,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(