liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mixtral.py

@@ -7,18 +7,15 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -144,8 +141,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 # Ignore copy
 def lce_forward(
     self,
@@ -161,9 +157,10 @@ def lce_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -171,10 +168,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -218,32 +217,50 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
-                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
             )
-
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
@@ -256,17 +273,21 @@ def lce_forward(
             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
+        output_tuple = (logits,) + outputs[1:]
         if output_router_logits:
-            output = (aux_loss,) + output
-        return (loss,) + output if loss is not None else output
+            output_tuple = (aux_loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return (loss,) + output_tuple if loss is not None else output_tuple
 
-    return MoeCausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerMoeCausalLMOutputWithPast(
         loss=loss,
         aux_loss=aux_loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
+        router_logits=outputs.router_logits if return_dict else outputs[-1],
+        token_accuracy=token_accuracy,
     )
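Across these patched forwards, the old `num_logits_to_keep` slice is replaced by the `logits_to_keep` handling shown above, which accepts either an int or a 1D index tensor. A minimal, self-contained sketch of how that slice behaves (the shapes and the helper name `kept` are illustrative, not part of the package):

```python
import torch

# Illustrative shapes only: batch=2, seq_len=6, hidden=4.
hidden_states = torch.randn(2, 6, 4)

def kept(hidden_states, logits_to_keep):
    # Same expression as in the patched forwards: an int keeps the last N positions
    # (0 keeps everything, since slice(-0, None) == slice(0, None)); a 1D index tensor
    # selects arbitrary positions along the sequence dimension.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(kept(hidden_states, 0).shape)                        # torch.Size([2, 6, 4]) - all positions
print(kept(hidden_states, 1).shape)                        # torch.Size([2, 1, 4]) - last token only
print(kept(hidden_states, torch.tensor([0, 3, 5])).shape)  # torch.Size([2, 3, 4]) - packed-format indices
```

The tensor form is what makes the packed-sequence format workable, where batch and sequence length share a single dimension.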
liger_kernel/transformers/model/mllama.py

@@ -8,16 +8,14 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -133,8 +131,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -151,9 +148,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -161,10 +159,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -192,7 +192,9 @@ def lce_forward(
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+    # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
+    # but preserve it for loss function calls
+    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -208,40 +210,60 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **model_kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
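The `skip_logits` argument added to these forwards follows the same rule in every model: error out if the caller forces it without any labels, otherwise skip logit materialization only during training when `labels` or `shift_labels` are present. A small sketch restating that default (the function name `resolve_skip_logits` is local to this sketch, not a liger_kernel API):

```python
from typing import Optional

def resolve_skip_logits(skip_logits: Optional[bool], training: bool, has_labels: bool) -> bool:
    # Mirrors the branch in the patched forwards above.
    if skip_logits and not has_labels:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        return training and has_labels
    return skip_logits

assert resolve_skip_logits(None, training=True, has_labels=True) is True    # fused loss, no logits
assert resolve_skip_logits(None, training=False, has_labels=True) is False  # eval keeps logits
assert resolve_skip_logits(False, training=True, has_labels=True) is False  # explicit opt-out
```

Callers that need materialized logits during training (for example, for logging) can therefore pass `skip_logits=False` explicitly.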
liger_kernel/transformers/model/olmo2.py

@@ -5,17 +5,14 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -29,9 +26,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -39,10 +37,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -81,36 +81,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
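When `return_dict=False`, the olmo2 forward above now assembles a plain tuple: logits first, then the base model outputs, with the loss prepended and `token_accuracy` appended only when they are not `None`. A sketch of the resulting layouts, using placeholder values in place of real tensors (the helper `build_tuple` is local to this sketch):

```python
def build_tuple(logits, model_extras, loss=None, token_accuracy=None):
    # Same ordering as the `if not return_dict:` branch above.
    output = (logits,) + tuple(model_extras)
    output = ((loss,) + output) if loss is not None else output
    output = output + (token_accuracy,) if token_accuracy is not None else output
    return output

# Inference-style call: logits first, then whatever the base model returned.
print(build_tuple("logits", ["past_kv"]))                              # ('logits', 'past_kv')
# Training-style call with the fused loss and token accuracy enabled (logits not materialized).
print(build_tuple(None, ["past_kv"], loss=0.42, token_accuracy=0.9))   # (0.42, None, 'past_kv', 0.9)
```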
liger_kernel/transformers/model/olmo3.py (new file)

@@ -0,0 +1,142 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM
+
+    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+    ```
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: BaseModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
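The new olmo3.py module mirrors the other patched forwards. A hedged usage sketch follows: it assumes this nightly's monkey patching wires the module up through `AutoLigerKernelForCausalLM` (which `liger_kernel.transformers` exports; whether it covers Olmo-3 in this build is an assumption) and reuses the checkpoint name from the docstring example above.

```python
# Hedged sketch, not a verified recipe: assumes Liger's auto patching covers Olmo-3 in this release.
from transformers import AutoTokenizer

from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")

batch = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()

model.train()  # with labels present, the patched forward defaults to skip_logits=True
out = model(**batch)
print(out.loss)    # fused linear cross-entropy loss
print(out.logits)  # None on this path, per the skip_logits branch in the forward above
```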