liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
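A recurring change in the hunks below is that the transformer-layer wrappers now import their autograd functions from the `liger_kernel.ops` package root instead of the individual op modules, which lines up with the new `liger_kernel/ops/__init__.py` (+141 lines) and the backend registry listed above. The sketch below is an assumption, not the released source: it only illustrates the kind of re-export that would make the shortened import paths resolve; the real file is much larger and presumably also wires in the Ascend backend dispatch.

```python
# Hypothetical sketch of liger_kernel/ops/__init__.py (not the actual 0.6.5 file):
# re-export the op autograd functions so `from liger_kernel.ops import ...` works.
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

__all__ = [
    "LigerKLDivLossFunction",
    "LigerLayerNormFunction",
    "LigerLlama4RopeFunction",
]
```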
liger_kernel/transformers/kl_div.py
@@ -1,6 +1,6 @@
 import torch.nn as nn

-from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+from liger_kernel.ops import LigerKLDivLossFunction


 class LigerKLDIVLoss(nn.KLDivLoss):
liger_kernel/transformers/layer_norm.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn

-from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops import LigerLayerNormFunction


 class LigerLayerNorm(nn.Module):
liger_kernel/transformers/llama4_rope.py
@@ -5,7 +5,7 @@ Supports both text and vision RoPE variants with fused operations for optimal pe

 import torch

-from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
+from liger_kernel.ops import LigerLlama4RopeFunction


 def liger_llama4_text_rotary_pos_emb(
liger_kernel/transformers/model/exaone4.py
@@ -0,0 +1,136 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    **kwargs,
+) -> LigerCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Exaone4ForCausalLM
+
+    >>> model = Exaone4ForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    # Remove output-control parameters that shouldn't be passed to loss functions
+    kwargs.pop("return_dict", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )

+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
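Throughout these model forwards the loss result is consumed as `loss, _, token_accuracy = unpack_cross_entropy_result(result)`. The helper itself is not part of this diff; the sketch below is a hypothetical stand-in (the function name suffix, the meaning of the middle element, and the tuple handling are assumptions, not the actual `loss_utils.py` code) showing the normalization such a helper would have to perform.

```python
from typing import Optional, Tuple, Union

import torch


def unpack_cross_entropy_result_sketch(
    result: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Hypothetical stand-in: always yield a (loss, extra, token_accuracy) triple."""
    if isinstance(result, tuple):
        loss, *rest = result
        extra = rest[0] if len(rest) > 0 else None           # discarded as `_` in the forwards above
        token_accuracy = rest[1] if len(rest) > 1 else None   # None unless accuracy tracking is enabled
        return loss, extra, token_accuracy
    # A bare tensor means only the loss was computed
    return result, None, None
```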
liger_kernel/transformers/model/falcon_h1.py
@@ -4,12 +4,12 @@ from typing import Union

 import torch

-from transformers.modeling_outputs import CausalLMOutputWithPast
-
 if TYPE_CHECKING:
     from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


 def lce_forward(
@@ -26,8 +26,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) -> Union[tuple, CausalLMOutputWithPast]:
+) -> Union[tuple, LigerCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -54,6 +55,7 @@ def lce_forward(
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -77,6 +79,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -85,8 +89,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and labels is not None

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -94,15 +99,24 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
liger_kernel/transformers/model/gemma.py
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


 def lce_forward_deprecated(
@@ -147,7 +149,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -209,6 +211,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None

     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -217,8 +220,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -226,6 +230,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None or shift_labels is not None:
@@ -238,13 +243,19 @@ def lce_forward(
             )

     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
+        output_tuple = (logits,) + outputs[1:]
+        if loss is not None:
+            output_tuple = (loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return output_tuple
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
liger_kernel/transformers/model/gemma2.py
@@ -7,12 +7,14 @@ from typing import Union
 import torch

 from torch.nn import CrossEntropyLoss
-from transformers.cache_utils import HybridCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

 logger = logging.getLogger(__name__)

@@ -22,7 +24,7 @@ def lce_forward_deprecated(
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[HybridCache] = None,
+    past_key_values: Optional[Cache] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -147,7 +149,7 @@ def lce_forward(
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[HybridCache] = None,
+    past_key_values: Optional[Cache] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -158,7 +160,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -225,6 +227,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None

     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -233,8 +236,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -243,6 +247,7 @@ def lce_forward(
             final_logit_softcapping=self.config.final_logit_softcapping,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

     else:
         logits = self.lm_head(kept_hidden_states)
@@ -262,13 +267,17 @@ def lce_forward(
             )

     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple

-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
liger_kernel/transformers/model/gemma3.py
@@ -6,13 +6,12 @@ import torch
 import torch.nn as nn

 from transformers.cache_utils import Cache
-from transformers.cache_utils import HybridCache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
 from transformers.utils import logging

-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast

 logger = logging.get_logger(__name__)

@@ -22,7 +21,7 @@ def causal_forward(
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[HybridCache] = None,
+    past_key_values: Optional[Cache] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -33,7 +32,7 @@ def causal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -98,12 +97,14 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None

     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -112,7 +113,7 @@ def causal_forward(
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )
-
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
@@ -129,15 +130,19 @@ def causal_forward(
             )

     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple

-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )


@@ -159,7 +164,7 @@ def multimodal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, LigerGemma3CausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -228,6 +233,7 @@ def multimodal_forward(
         **lm_kwargs,
     )

+    shift_labels = lm_kwargs.pop("shift_labels", None)
     hidden_states = outputs[0]

     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
@@ -235,6 +241,7 @@ def multimodal_forward(

     loss = None
     logits = None
+    token_accuracy = None
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels is None")

@@ -260,8 +267,17 @@ def multimodal_forward(
         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
         shift_labels = shift_labels.view(-1).to(hidden_device)

-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        result = LigerForCausalLMLoss(
+            hidden_states=shift_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            shift_labels=shift_labels,
+            final_logit_softcapping=getattr(self.config.text_config, "final_logit_softcapping", None),
+            **lm_kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
     else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
@@ -306,13 +322,16 @@ def multimodal_forward(

     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output

-    return Gemma3CausalLMOutputWithPast(
+    return LigerGemma3CausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         image_hidden_states=outputs.image_hidden_states,
+        token_accuracy=token_accuracy,
     )
liger_kernel/transformers/model/glm4.py
@@ -5,10 +5,11 @@ from typing import Union

 import torch

-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -91,6 +92,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None

     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -99,8 +101,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -108,6 +111,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

     else:
         logits = self.lm_head(kept_hidden_states)
@@ -120,10 +124,18 def lce_forward(
                 **kwargs,
             )

-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
liger_kernel/transformers/model/glm4v.py
@@ -5,10 +5,11 @@ from typing import Union

 import torch

-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -113,6 +114,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None

     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -121,8 +123,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)

+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -130,6 +133,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)

     else:
         logits = self.lm_head(kept_hidden_states)
@@ -142,10 +146,18 @@ def lce_forward(
                 **kwargs,
             )

-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
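Across these model forwards the pattern is the same: with `return_dict` enabled the forward returns a `LigerCausalLMOutputWithPast` (or the Gemma3 variant) carrying a new `token_accuracy` field, and with it disabled the plain tuple gains a trailing `token_accuracy` entry. The usage sketch below is hedged: the model, inputs, and field name are taken from the hunks above, and nothing else here is liger-kernel API.

```python
import torch


def read_loss_and_accuracy(model, input_ids: torch.Tensor, labels: torch.Tensor):
    """Sketch: pull the new token_accuracy field from a patched model's output."""
    out = model(input_ids=input_ids, labels=labels, return_dict=True)
    # LigerCausalLMOutputWithPast exposes loss, logits, and token_accuracy;
    # token_accuracy may be None when accuracy tracking is not enabled.
    return out.loss, out.token_accuracy
```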