liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
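Most of the changes land in liger_kernel/transformers/monkey_patch.py and the per-model forwards under liger_kernel/transformers/model/. As a point of reference, the sketch below shows how these patched forwards are typically activated; apply_liger_kernel_to_llama is a real entry point in liger_kernel.transformers, but the exact keyword arguments it accepts can differ between the two versions compared here, so treat this as an illustration rather than a pinned API.

```python
# Hedged usage sketch: enabling Liger's patched kernels on a Llama model.
# The specific keyword arguments below are illustrative and may vary by release.
from liger_kernel.transformers import apply_liger_kernel_to_llama
from transformers import AutoModelForCausalLM

# Patch the Llama modeling code in place before instantiating the model.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # route the training loss through the fused LCE path
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
```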
liger_kernel/transformers/model/mixtral.py
@@ -7,19 +7,15 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -146,8 +142,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -164,8 +158,9 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
-) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -222,29 +217,50 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else: # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+    else:
+        logits = self.lm_head(kept_hidden_states)
 
         loss = None
-        if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
@@ -257,17 +273,21 @@ def lce_forward(
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
+        output_tuple = (logits,) + outputs[1:]
         if output_router_logits:
-            output = (aux_loss,) + output
-        return (loss,) + output if loss is not None else output
+            output_tuple = (aux_loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return (loss,) + output_tuple if loss is not None else output_tuple
 
-    return MoeCausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
+        router_logits=outputs.router_logits if return_dict else outputs[-1],
+        token_accuracy=token_accuracy,
     )
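The pattern above replaces the old `self.training and labels is not None` check with an explicit, overridable `skip_logits` flag and adds `shift_labels`/`token_accuracy` plumbing; the same pattern repeats in the other patched forwards below. The snippet re-expresses just that decision logic as a standalone sketch for clarity; `resolve_skip_logits` is a hypothetical helper name, not part of the package.

```python
# Standalone re-expression of the skip_logits default introduced in the diff above.
from typing import Optional

import torch


def resolve_skip_logits(
    training: bool,
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
    skip_logits: Optional[bool] = None,
) -> bool:
    """Mirror the validation and default used by the patched lce_forward functions."""
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: only skip logit materialization in training mode when labels are available.
        skip_logits = training and (labels is not None or shift_labels is not None)
    return bool(skip_logits)


labels = torch.tensor([[1, 2, -100]])
assert resolve_skip_logits(training=True, labels=labels, shift_labels=None) is True
assert resolve_skip_logits(training=False, labels=labels, shift_labels=None) is False
```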

liger_kernel/transformers/model/mllama.py
@@ -8,17 +8,14 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -135,8 +132,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -154,8 +149,9 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -196,7 +192,9 @@ def lce_forward(
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+    # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
+    # but preserve it for loss function calls
+    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -212,41 +210,60 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **model_kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else: # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )

liger_kernel/transformers/model/olmo2.py
@@ -5,19 +5,14 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,8 +27,9 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -85,37 +81,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else: # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
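Unlike the old Olmo2 forward, which always returned `CausalLMOutputWithPast`, the patched version honors `return_dict=False` and appends `token_accuracy` to the end of the tuple when it was computed. The following standalone sketch mirrors that tuple assembly; `build_tuple_output` is a hypothetical helper, since the real code inlines this logic.

```python
# Illustration of the tuple layout produced when return_dict=False.
from typing import Optional, Tuple


def build_tuple_output(logits, extra_outputs: Tuple, loss=None, token_accuracy: Optional[float] = None) -> Tuple:
    output = (logits,) + extra_outputs
    output = ((loss,) + output) if loss is not None else output
    output = output + (token_accuracy,) if token_accuracy is not None else output
    return output


# loss leads (when present), then logits and the model's remaining outputs,
# with token_accuracy appended last only when the fused loss path computed it.
print(build_tuple_output("logits", ("past_key_values",), loss=0.73, token_accuracy=0.41))
# -> (0.73, 'logits', 'past_key_values', 0.41)
```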

liger_kernel/transformers/model/olmo3.py
@@ -0,0 +1,142 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM
+
+    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+    ```
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: BaseModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )

liger_kernel/transformers/model/output_classes.py
@@ -0,0 +1,147 @@
+"""
+Custom output classes for Liger-Kernel that extend transformers' ModelOutput classes
+with optional token accuracy field.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+
+# The following model-specific outputs are optional and depend on the installed
+# transformers version. Guard their imports so our module remains importable
+# even when those models are not available in the environment.
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast as _Gemma3CausalLMOutputWithPast
+except Exception:
+    _Gemma3CausalLMOutputWithPast = None
+
+try:
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import (
+        Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
+    )
+except Exception:
+    _Glm4vMoeCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.internvl.modeling_internvl import (
+        InternVLCausalLMOutputWithPast as _InternVLCausalLMOutputWithPast,
+    )
+except Exception:
+    _InternVLCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast as _LlavaCausalLMOutputWithPast
+except Exception:
+    _LlavaCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.paligemma.modeling_paligemma import (
+        PaliGemmaCausalLMOutputWithPast as _PaliGemmaCausalLMOutputWithPast,
+    )
+except Exception:
+    _PaliGemmaCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+        Qwen2_5_VLCausalLMOutputWithPast as _Qwen2_5_VLCausalLMOutputWithPast,
+    )
+except Exception:
+    _Qwen2_5_VLCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLCausalLMOutputWithPast as _Qwen2VLCausalLMOutputWithPast,
+    )
+except Exception:
+    _Qwen2VLCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+        Qwen3VLCausalLMOutputWithPast as _Qwen3VLCausalLMOutputWithPast,
+    )
+except Exception:
+    _Qwen3VLCausalLMOutputWithPast = None
+
+try:
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+        Qwen3VLMoeCausalLMOutputWithPast as _Qwen3VLMoeCausalLMOutputWithPast,
+    )
+except Exception:
+    _Qwen3VLMoeCausalLMOutputWithPast = None
+
+
+@dataclass
+class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
+    token_accuracy: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class LigerMoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
+    token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Gemma3CausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Glm4vMoeCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerGlm4vMoeCausalLMOutputWithPast(_Glm4vMoeCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _LlavaCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerLlavaCausalLMOutputWithPast(_LlavaCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _InternVLCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerInternVLCausalLMOutputWithPast(_InternVLCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _PaliGemmaCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerPaliGemmaCausalLMOutputWithPast(_PaliGemmaCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Qwen2_5_VLCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerQwen2_5_VLCausalLMOutputWithPast(_Qwen2_5_VLCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Qwen2VLCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerQwen2VLCausalLMOutputWithPast(_Qwen2VLCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Qwen3VLCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
+
+
+if _Qwen3VLMoeCausalLMOutputWithPast is not None:
+
+    @dataclass
+    class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast):
+        token_accuracy: Optional[torch.FloatTensor] = None
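
output_classes.py applies one pattern throughout: subclass a transformers ModelOutput dataclass and add a single optional `token_accuracy` field, guarding model-specific imports so the module stays importable on older transformers versions. Below is a minimal, self-contained sketch of that pattern; `ExampleOutputWithAccuracy` is illustrative only, while the shipped classes are the `Liger*OutputWithPast` ones above.

```python
# Minimal sketch of the output-class pattern: extend a transformers ModelOutput
# dataclass with an extra optional field. ExampleOutputWithAccuracy is hypothetical.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class ExampleOutputWithAccuracy(CausalLMOutputWithPast):
    token_accuracy: Optional[torch.FloatTensor] = None


out = ExampleOutputWithAccuracy(loss=torch.tensor(0.5), token_accuracy=torch.tensor(0.83))
print(out.loss, out.token_accuracy)  # attribute access works as on any ModelOutput
print(out["loss"])                   # dict-style access is also supported by ModelOutput
```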