liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the package's registry page for more details.

Files changed (107):
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,8 @@ from transformers.utils.deprecation import deprecate_kwarg
15
15
  from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
16
16
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
17
17
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
18
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
19
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
18
20
  from liger_kernel.utils import PEFT_AVAILABLE
19
21
 
20
22
  if TYPE_CHECKING:
@@ -37,6 +39,7 @@ def lce_forward_deprecated(
37
39
  output_hidden_states: Optional[bool] = None,
38
40
  return_dict: Optional[bool] = None,
39
41
  cache_position: Optional[torch.LongTensor] = None,
42
+ skip_logits: Optional[bool] = None,
40
43
  ) -> Union[Tuple, CausalLMOutputWithPast]:
41
44
  r"""
42
45
  Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +94,15 @@ def lce_forward_deprecated(
91
94
  loss = None
92
95
  logits = None
93
96
 
94
- if self.training and (labels is not None):
97
+ # if in training mode, don't materialize logits
98
+ if skip_logits and labels is None:
99
+ raise ValueError("skip_logits is True, but labels is None")
100
+
101
+ if skip_logits is None:
102
+ # By default, if in training mode, don't materialize logits
103
+ skip_logits = self.training and labels is not None
104
+
105
+ if skip_logits:
95
106
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
96
107
  shift_labels = labels[..., 1:].contiguous()
97
108
 
@@ -153,7 +164,7 @@ def lce_forward(
153
164
  logits_to_keep: Union[int, torch.Tensor] = 0,
154
165
  skip_logits: Optional[bool] = None,
155
166
  **kwargs,
156
- ) -> Union[Tuple, CausalLMOutputWithPast]:
167
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
157
168
  r"""
158
169
  Args:
159
170
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -219,6 +230,8 @@ def lce_forward(
219
230
  shift_labels = kwargs.pop("shift_labels", None)
220
231
  logits = None
221
232
  loss = None
233
+ token_accuracy = None
234
+
222
235
  # if in training mode, don't materialize logits
223
236
  if skip_logits and labels is None and shift_labels is None:
224
237
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -227,8 +240,9 @@ def lce_forward(
227
240
  # By default, if in training mode, don't materialize logits
228
241
  skip_logits = self.training and (labels is not None or shift_labels is not None)
229
242
 
243
+ # Compute loss
230
244
  if skip_logits:
231
- loss = lce_maybe_trainable_lm_head(
245
+ result = lce_maybe_trainable_lm_head(
232
246
  self,
233
247
  hidden_states=kept_hidden_states,
234
248
  hidden_size=self.config.hidden_size,
@@ -236,27 +250,32 @@ def lce_forward(
236
250
  shift_labels=shift_labels,
237
251
  **kwargs,
238
252
  )
239
-
253
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
240
254
  else:
241
255
  logits = self.lm_head(kept_hidden_states)
242
- if labels is not None:
256
+ if labels is not None or shift_labels is not None:
243
257
  loss = self.loss_function(
244
258
  logits=logits,
245
259
  labels=labels,
260
+ shift_labels=shift_labels,
246
261
  vocab_size=self.config.vocab_size,
247
262
  **kwargs,
248
263
  )
249
264
 
250
265
  if not return_dict:
251
266
  output = (logits,) + outputs[1:]
252
- return (loss,) + output if loss is not None else output
267
+ output = ((loss,) + output) if loss is not None else output
268
+ output = output + (token_accuracy,) if token_accuracy is not None else output
269
+ return output
253
270
 
254
- return CausalLMOutputWithPast(
271
+ # Return custom output class with token_accuracy field
272
+ return LigerCausalLMOutputWithPast(
255
273
  loss=loss,
256
274
  logits=logits,
257
275
  past_key_values=outputs.past_key_values,
258
276
  hidden_states=outputs.hidden_states,
259
277
  attentions=outputs.attentions,
278
+ token_accuracy=token_accuracy,
260
279
  )
261
280
 
262
281
 
@@ -0,0 +1,121 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.cache_utils import Cache
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
+
14
+
15
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Llama4 causal-LM forward that computes the loss with Liger's fused linear
    cross entropy (``LigerForCausalLMLoss``) instead of materializing logits.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, defaults to 0):
            If an int, only the last `logits_to_keep` positions of the hidden states are kept
            (0 keeps every position, because `slice(-0, None)` is `slice(0, None)`). If a
            tensor, it is used directly as the index along the sequence dimension.
        skip_logits (`bool`, *optional*):
            If True, compute the loss with fused linear cross entropy and never materialize
            logits (`labels` or `shift_labels` is then required). If False, always compute
            logits with `self.lm_head`. If None (default), logits are skipped only when the
            module is in training mode and `labels` or `shift_labels` is provided.
        **kwargs:
            Forwarded to `self.model(...)` and to the loss function. `shift_labels`, if
            present, is popped only after the decoder call and is then used as pre-shifted
            labels for the loss.

    Returns:
        A `LigerCausalLMOutputWithPast` when `return_dict` is truthy. Otherwise a tuple laid
        out as `(loss?, logits, *outputs[1:], token_accuracy?)`, where the optional entries
        are present only when they are not None.

    Raises:
        ValueError: If `skip_logits` is True but neither `labels` nor `shift_labels` is given.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Llama4ForCausalLM

    >>> # "path/to/llama4-checkpoint" is a placeholder for a real Llama4 checkpoint.
    >>> model = Llama4ForCausalLM.from_pretrained("path/to/llama4-checkpoint")
    >>> tokenizer = AutoTokenizer.from_pretrained("path/to/llama4-checkpoint")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Run the decoder. return_dict=True is forced so the fields below can be read by name
    # (outputs.past_key_values, ...) regardless of the caller's return_dict.
    # NOTE(review): loss-only kwargs (e.g. accum_dtype / return_token_accuracy) are also
    # forwarded here; mllama filters accum_dtype for older transformers - confirm Llama4 is fine.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    has_targets = labels is not None or shift_labels is not None
    if skip_logits and not has_targets:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and has_targets

    # Compute loss
    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        # result is either a bare loss or a CrossEntropyOutput carrying token_accuracy.
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:  # materialize logits (inference, or skip_logits explicitly False)
        logits = self.lm_head(kept_hidden_states)
        if has_targets:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
@@ -11,6 +11,8 @@ from transformers.utils import is_torchdynamo_compiling
11
11
 
12
12
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
13
13
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
14
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
15
+ from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast
14
16
 
15
17
 
16
18
  def lce_forward_deprecated(
@@ -215,7 +217,7 @@ def lce_forward(
215
217
  image_sizes: torch.Tensor = None,
216
218
  skip_logits: Optional[bool] = None,
217
219
  **lm_kwargs,
218
- ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
220
+ ) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]:
219
221
  r"""
220
222
  Args:
221
223
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -293,6 +295,7 @@ def lce_forward(
293
295
  shift_labels = lm_kwargs.pop("shift_labels", None)
294
296
  logits = None
295
297
  loss = None
298
+ token_accuracy = None
296
299
 
297
300
  if skip_logits and labels is None and shift_labels is None:
298
301
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -302,7 +305,7 @@ def lce_forward(
302
305
  skip_logits = self.training and (labels is not None or shift_labels is not None)
303
306
 
304
307
  if skip_logits:
305
- loss = LigerForCausalLMLoss(
308
+ result = LigerForCausalLMLoss(
306
309
  hidden_states=kept_hidden_states,
307
310
  lm_head_weight=self.lm_head.weight,
308
311
  labels=labels,
@@ -310,23 +313,32 @@ def lce_forward(
310
313
  hidden_size=self.config.text_config.hidden_size,
311
314
  **lm_kwargs,
312
315
  )
316
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
313
317
 
314
318
  else:
315
319
  logits = self.lm_head(kept_hidden_states)
316
- if labels is not None:
320
+ if labels is not None or shift_labels is not None:
317
321
  loss = self.loss_function(
318
- logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
322
+ logits=logits,
323
+ labels=labels,
324
+ shift_labels=shift_labels,
325
+ vocab_size=self.config.text_config.vocab_size,
326
+ **lm_kwargs,
319
327
  )
320
328
 
321
329
  if not return_dict:
322
330
  output = (logits,) + outputs[1:]
323
- return (loss,) + output if loss is not None else output
331
+ output = (loss,) + output if loss is not None else output
332
+ output = output + (token_accuracy,) if token_accuracy is not None else output
333
+ return output
324
334
 
325
- return LlavaCausalLMOutputWithPast(
335
+ # Return custom output class with token_accuracy field
336
+ return LigerLlavaCausalLMOutputWithPast(
326
337
  loss=loss,
327
338
  logits=logits,
328
339
  past_key_values=outputs.past_key_values,
329
340
  hidden_states=outputs.hidden_states,
330
341
  attentions=outputs.attentions,
331
342
  image_hidden_states=outputs.image_hidden_states,
343
+ token_accuracy=token_accuracy,
332
344
  )
@@ -1,10 +1,28 @@
1
1
  from typing import Optional
2
+ from typing import Tuple
2
3
 
3
4
  import torch
4
5
  import torch.nn as nn
5
6
 
6
7
  import liger_kernel.transformers.functional as F
7
8
 
9
+ from liger_kernel.transformers.functional import CrossEntropyOutput
10
+
11
+
12
def unpack_cross_entropy_result(
    result,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Normalize a cross-entropy return value into ``(loss, z_loss, token_accuracy)``.

    Accepted shapes of ``result``:

    * a ``CrossEntropyOutput``: its ``loss``, ``z_loss`` and ``token_accuracy``
      attributes are returned;
    * a tuple: element 0 is the loss, and elements 1 and 2 (when present) are
      ``z_loss`` and ``token_accuracy``; missing trailing entries become None;
    * anything else (e.g. a bare loss tensor): returned as the loss, with None for
      the other two entries.
    """
    if isinstance(result, CrossEntropyOutput):
        return result.loss, result.z_loss, result.token_accuracy

    if not isinstance(result, tuple):
        # A bare loss carries no auxiliary outputs.
        return result, None, None

    loss = result[0]
    count = len(result)
    z_loss = result[1] if count >= 2 else None
    token_accuracy = result[2] if count >= 3 else None
    return loss, z_loss, token_accuracy
25
+
8
26
 
9
27
  def fixed_fused_linear_cross_entropy(
10
28
  hidden_states: torch.Tensor,
@@ -13,20 +31,31 @@ def fixed_fused_linear_cross_entropy(
13
31
  num_items_in_batch: Optional[int] = None,
14
32
  ignore_index: int = -100,
15
33
  final_logit_softcapping: Optional[float] = None,
34
+ accum_dtype: Optional[torch.dtype] = None,
35
+ return_token_accuracy: bool = False,
16
36
  **kwargs,
17
37
  ):
18
38
  reduction = "sum" if num_items_in_batch is not None else "mean"
19
- loss = F.liger_fused_linear_cross_entropy(
39
+ result = F.liger_fused_linear_cross_entropy(
20
40
  hidden_states,
21
41
  lm_head_weight,
22
42
  target,
23
43
  reduction=reduction,
24
44
  ignore_index=ignore_index,
25
45
  softcap=final_logit_softcapping,
46
+ accum_dtype=accum_dtype,
47
+ return_token_accuracy=return_token_accuracy,
48
+ **kwargs,
26
49
  )
50
+
51
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
52
+
27
53
  if reduction == "sum":
28
54
  loss = loss / num_items_in_batch
29
55
 
56
+ if return_token_accuracy:
57
+ return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy)
58
+
30
59
  return loss
31
60
 
32
61
 
@@ -39,6 +68,7 @@ def LigerForCausalLMLoss(
39
68
  ignore_index: int = -100,
40
69
  shift_labels: Optional[torch.Tensor] = None,
41
70
  final_logit_softcapping: Optional[float] = None,
71
+ return_token_accuracy: bool = False,
42
72
  **kwargs,
43
73
  ):
44
74
  # Skip upcast since intermediate values for the loss are all fp32 in kernel
@@ -52,13 +82,14 @@ def LigerForCausalLMLoss(
52
82
  shift_labels = shift_labels.view(-1)
53
83
  # Enable model parallelism
54
84
  shift_labels = shift_labels.to(hidden_states.device)
55
- loss = fixed_fused_linear_cross_entropy(
85
+ result = fixed_fused_linear_cross_entropy(
56
86
  hidden_states,
57
87
  lm_head_weight,
58
88
  shift_labels,
59
89
  num_items_in_batch,
60
90
  ignore_index,
61
91
  final_logit_softcapping,
92
+ return_token_accuracy=return_token_accuracy,
62
93
  **kwargs,
63
94
  )
64
- return loss
95
+ return result
@@ -6,10 +6,11 @@ from typing import Union
6
6
  import torch
7
7
 
8
8
  from transformers.cache_utils import Cache
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
9
  from transformers.utils.deprecation import deprecate_kwarg
11
10
 
12
11
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
12
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
13
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
14
 
14
15
 
15
16
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -29,7 +30,7 @@ def lce_forward(
29
30
  logits_to_keep: Union[int, torch.Tensor] = 0,
30
31
  skip_logits: Optional[bool] = None,
31
32
  **kwargs,
32
- ) -> Union[Tuple, CausalLMOutputWithPast]:
33
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
33
34
  r"""
34
35
  Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
35
36
 
@@ -94,6 +95,7 @@ def lce_forward(
94
95
  shift_labels = kwargs.pop("shift_labels", None)
95
96
  loss = None
96
97
  logits = None
98
+ token_accuracy = None
97
99
 
98
100
  if skip_logits and labels is None and shift_labels is None:
99
101
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -101,8 +103,9 @@ def lce_forward(
101
103
  if skip_logits is None:
102
104
  skip_logits = self.training and (labels is not None or shift_labels is not None)
103
105
 
106
+ # Compute loss
104
107
  if skip_logits:
105
- loss = LigerForCausalLMLoss(
108
+ result = LigerForCausalLMLoss(
106
109
  hidden_states=kept_hidden_states,
107
110
  lm_head_weight=self.lm_head.weight,
108
111
  labels=labels,
@@ -110,29 +113,33 @@ def lce_forward(
110
113
  hidden_size=self.config.hidden_size,
111
114
  **kwargs,
112
115
  )
116
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
113
117
 
114
118
  else:
115
119
  logits = self.lm_head(kept_hidden_states)
116
120
 
117
121
  loss = None
118
- if labels is not None:
122
+ if labels is not None or shift_labels is not None:
119
123
  loss = self.loss_function(
120
124
  logits=logits,
121
125
  labels=labels,
126
+ shift_labels=shift_labels,
122
127
  vocab_size=self.config.vocab_size,
123
128
  **kwargs,
124
129
  )
130
+
125
131
  if not return_dict:
126
- output = (logits,) + outputs[1:]
127
- return (loss,) + output if loss is not None else output
132
+ output_tuple = (logits,) + outputs[1:]
133
+ output = (loss,) + output_tuple if loss is not None else output_tuple
134
+ output = output + (token_accuracy,) if token_accuracy is not None else output
135
+ return output
128
136
 
129
- return CausalLMOutputWithPast(
137
+ # Return custom output class with token_accuracy field
138
+ return LigerCausalLMOutputWithPast(
130
139
  loss=loss,
131
140
  logits=logits,
132
141
  past_key_values=outputs.past_key_values,
133
142
  hidden_states=outputs.hidden_states,
134
143
  attentions=outputs.attentions,
144
+ token_accuracy=token_accuracy,
135
145
  )
136
-
137
-
138
- # Note: Grad Acc is not fixed in mistral at transformer 4.46.1
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg
12
12
 
13
13
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
14
14
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
15
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
16
+ from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
15
17
 
16
18
 
17
19
  def lce_forward_deprecated(
@@ -158,7 +160,7 @@ def lce_forward(
158
160
  logits_to_keep: Union[int, torch.Tensor] = 0,
159
161
  skip_logits: Optional[bool] = None,
160
162
  **kwargs,
161
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
163
+ ) -> Union[Tuple, LigerMoeCausalLMOutputWithPast]:
162
164
  r"""
163
165
  Args:
164
166
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -226,6 +228,7 @@ def lce_forward(
226
228
  shift_labels = kwargs.pop("shift_labels", None)
227
229
  logits = None
228
230
  loss = None
231
+ token_accuracy = None
229
232
 
230
233
  if skip_logits and labels is None and shift_labels is None:
231
234
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -234,8 +237,9 @@ def lce_forward(
234
237
  # By default, if in training mode, don't materialize logits
235
238
  skip_logits = self.training and (labels is not None or shift_labels is not None)
236
239
 
240
+ # Compute loss
237
241
  if skip_logits:
238
- loss = LigerForCausalLMLoss(
242
+ result = LigerForCausalLMLoss(
239
243
  hidden_states=kept_hidden_states,
240
244
  lm_head_weight=self.lm_head.weight,
241
245
  labels=labels,
@@ -243,13 +247,20 @@ def lce_forward(
243
247
  hidden_size=self.config.hidden_size,
244
248
  **kwargs,
245
249
  )
250
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
246
251
 
247
252
  else:
248
253
  logits = self.lm_head(kept_hidden_states)
249
254
 
250
255
  loss = None
251
- if labels is not None:
252
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
256
+ if labels is not None or shift_labels is not None:
257
+ loss = self.loss_function(
258
+ logits=logits,
259
+ labels=labels,
260
+ shift_labels=shift_labels,
261
+ vocab_size=self.vocab_size,
262
+ **kwargs,
263
+ )
253
264
  aux_loss = None
254
265
  if output_router_logits:
255
266
  aux_loss = load_balancing_loss_func(
@@ -262,17 +273,21 @@ def lce_forward(
262
273
  loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
263
274
 
264
275
  if not return_dict:
265
- output = (logits,) + outputs[1:]
276
+ output_tuple = (logits,) + outputs[1:]
266
277
  if output_router_logits:
267
- output = (aux_loss,) + output
268
- return (loss,) + output if loss is not None else output
278
+ output_tuple = (aux_loss,) + output_tuple
279
+ if token_accuracy is not None:
280
+ output_tuple = output_tuple + (token_accuracy,)
281
+ return (loss,) + output_tuple if loss is not None else output_tuple
269
282
 
270
- return MoeCausalLMOutputWithPast(
283
+ # Return custom output class with token_accuracy field
284
+ return LigerMoeCausalLMOutputWithPast(
271
285
  loss=loss,
272
286
  aux_loss=aux_loss,
273
287
  logits=logits,
274
288
  past_key_values=outputs.past_key_values,
275
289
  hidden_states=outputs.hidden_states,
276
290
  attentions=outputs.attentions,
277
- router_logits=outputs.router_logits,
291
+ router_logits=outputs.router_logits if return_dict else outputs[-1],
292
+ token_accuracy=token_accuracy,
278
293
  )
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg
12
12
 
13
13
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
14
14
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
15
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
16
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
15
17
 
16
18
 
17
19
  def lce_forward_deprecated(
@@ -149,7 +151,7 @@ def lce_forward(
149
151
  logits_to_keep: Union[int, torch.Tensor] = 0,
150
152
  skip_logits: Optional[bool] = None,
151
153
  **kwargs,
152
- ) -> Union[Tuple, CausalLMOutputWithPast]:
154
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
153
155
  r"""
154
156
  Args:
155
157
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -190,7 +192,9 @@ def lce_forward(
190
192
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
191
193
  )
192
194
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
193
-
195
+ # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
196
+ # but preserve it for loss function calls
197
+ model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
194
198
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
195
199
  outputs = self.model(
196
200
  input_ids=input_ids,
@@ -206,7 +210,7 @@ def lce_forward(
206
210
  output_hidden_states=output_hidden_states,
207
211
  return_dict=return_dict,
208
212
  cache_position=cache_position,
209
- **kwargs,
213
+ **model_kwargs,
210
214
  )
211
215
 
212
216
  hidden_states = outputs[0]
@@ -217,6 +221,7 @@ def lce_forward(
217
221
  shift_labels = kwargs.pop("shift_labels", None)
218
222
  logits = None
219
223
  loss = None
224
+ token_accuracy = None
220
225
 
221
226
  if skip_logits and labels is None and shift_labels is None:
222
227
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -226,7 +231,7 @@ def lce_forward(
226
231
  skip_logits = self.training and (labels is not None or shift_labels is not None)
227
232
 
228
233
  if skip_logits:
229
- loss = LigerForCausalLMLoss(
234
+ result = LigerForCausalLMLoss(
230
235
  hidden_states=kept_hidden_states,
231
236
  lm_head_weight=self.lm_head.weight,
232
237
  labels=labels,
@@ -234,25 +239,31 @@ def lce_forward(
234
239
  hidden_size=self.config.hidden_size,
235
240
  **kwargs,
236
241
  )
242
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
237
243
 
238
244
  else:
239
245
  logits = self.lm_head(kept_hidden_states)
240
- if labels is not None:
246
+ if labels is not None or shift_labels is not None:
241
247
  loss = self.loss_function(
242
248
  logits=logits,
243
249
  labels=labels,
250
+ shift_labels=shift_labels,
244
251
  vocab_size=self.config.vocab_size,
245
252
  **kwargs,
246
253
  )
247
254
 
248
255
  if not return_dict:
249
256
  output = (logits,) + outputs[1:]
250
- return (loss,) + output if loss is not None else output
257
+ output = (loss,) + output if loss is not None else output
258
+ output = output + (token_accuracy,) if token_accuracy is not None else output
259
+ return output
251
260
 
252
- return CausalLMOutputWithPast(
261
+ # Return custom output class with token_accuracy field
262
+ return LigerCausalLMOutputWithPast(
253
263
  loss=loss,
254
264
  logits=logits,
255
265
  past_key_values=outputs.past_key_values,
256
266
  hidden_states=outputs.hidden_states,
257
267
  attentions=outputs.attentions,
268
+ token_accuracy=token_accuracy,
258
269
  )