liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +120 -63
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +88 -70
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +33 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +122 -0
  24. liger_kernel/transformers/model/gemma.py +19 -7
  25. liger_kernel/transformers/model/gemma2.py +22 -7
  26. liger_kernel/transformers/model/gemma3.py +52 -14
  27. liger_kernel/transformers/model/glm4.py +18 -5
  28. liger_kernel/transformers/model/glm4v.py +18 -5
  29. liger_kernel/transformers/model/glm4v_moe.py +25 -5
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +157 -0
  32. liger_kernel/transformers/model/llama.py +16 -6
  33. liger_kernel/transformers/model/llama4.py +18 -5
  34. liger_kernel/transformers/model/llava.py +18 -6
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +17 -7
  37. liger_kernel/transformers/model/mixtral.py +24 -9
  38. liger_kernel/transformers/model/mllama.py +14 -5
  39. liger_kernel/transformers/model/olmo2.py +18 -5
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +41 -5
  43. liger_kernel/transformers/model/phi3.py +16 -8
  44. liger_kernel/transformers/model/qwen2.py +18 -4
  45. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  46. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  47. liger_kernel/transformers/model/qwen3.py +22 -6
  48. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +17 -7
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +729 -4
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
  64. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 def lce_forward_deprecated(
@@ -147,7 +149,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -209,6 +211,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -217,8 +220,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -226,24 +230,32 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
+                shift_labels=shift_labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
+        output_tuple = (logits,) + outputs[1:]
+        if loss is not None:
+            output_tuple = (loss,) + output_tuple
+        if token_accuracy is not None:
+            output_tuple = output_tuple + (token_accuracy,)
+        return output_tuple
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
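
Every patched forward in this release follows the same pattern: the fused loss call now returns a richer result, and the new `unpack_cross_entropy_result` helper normalizes it into `(loss, z_loss, token_accuracy)`. The helper itself lives in `liger_kernel/chunked_loss`-adjacent code (`liger_kernel/transformers/model/loss_utils.py`), whose diff is not reproduced here; the sketch below only illustrates the behavior the hunks above rely on, and the exact tuple layout is an assumption.

```python
# Illustrative sketch only -- the real helper is defined in
# liger_kernel/transformers/model/loss_utils.py, which is not shown in this
# excerpt. It assumes the fused loss returns either a bare loss tensor or a
# (loss, z_loss, token_accuracy) tuple; the layout is an assumption.
def unpack_cross_entropy_result_sketch(result):
    """Normalize a fused cross-entropy result to (loss, z_loss, token_accuracy)."""
    if isinstance(result, tuple):
        # Pad short tuples so callers can always unpack three values.
        padded = tuple(result) + (None,) * (3 - len(result))
        return padded[:3]
    # A plain tensor means only the scalar loss was computed.
    return result, None, None
```

Call sites then read `loss, _, token_accuracy = unpack_cross_entropy_result(result)`, exactly as the hunks above and below do.
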
@@ -13,6 +13,8 @@ from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 logger = logging.getLogger(__name__)
 
@@ -158,7 +160,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -225,6 +227,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -233,8 +236,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -243,6 +247,7 @@ def lce_forward(
             final_logit_softcapping=self.config.final_logit_softcapping,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -252,17 +257,27 @@ def lce_forward(
             logits = logits * self.config.final_logit_softcapping
 
         loss = None
-        if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
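
The eager (non-fused) branch above applies final logit softcapping before handing logits to `self.loss_function`. Restated as a standalone helper, using only the three operations visible in the hunk:

```python
import torch


def apply_final_logit_softcapping(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Mirrors the eager path above: scale down by the cap, squash with tanh,
    # then rescale so logits stay within (-cap, cap).
    logits = logits / cap
    logits = torch.tanh(logits)
    return logits * cap
```

In the fused path the same capping is instead delegated to the kernel via the `final_logit_softcapping` argument of `LigerForCausalLMLoss`, so logits are never materialized.
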
@@ -7,12 +7,13 @@ import torch.nn as nn
 
 from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
 from transformers.utils import logging
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast
 
 logger = logging.get_logger(__name__)
 
@@ -33,7 +34,7 @@ def causal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -98,12 +99,14 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -112,26 +115,36 @@ def causal_forward(
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )
-
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping
-        if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **loss_kwargs,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
     )
 
 
@@ -153,7 +166,7 @@ def multimodal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, LigerGemma3CausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -229,6 +242,7 @@ def multimodal_forward(
 
     loss = None
     logits = None
+    token_accuracy = None
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels is None")
 
@@ -255,7 +269,9 @@ def multimodal_forward(
                 shift_labels = shift_labels.view(-1).to(hidden_device)
 
             lce = LigerFusedLinearCrossEntropyLoss()
-            loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+            result = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+            loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
         else:
             logits = self.lm_head(kept_hidden_states)
             if labels is not None:
@@ -275,19 +291,41 @@ def multimodal_forward(
             # Flatten the tokens
             loss_fct = nn.CrossEntropyLoss()
 
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+        elif shift_labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Gemma3CausalLMOutputWithPast(
+    return LigerGemma3CausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
         image_hidden_states=outputs.image_hidden_states,
+        token_accuracy=token_accuracy,
     )
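
The replacement return types (`LigerCausalLMOutputWithPast`, `LigerGemma3CausalLMOutputWithPast`, `LigerGlm4vMoeCausalLMOutputWithPast`, and so on) come from the new `liger_kernel/transformers/model/output_classes.py`, whose contents are not part of this excerpt. A minimal sketch of what such a class could look like, assuming it simply extends the corresponding Hugging Face output with an optional `token_accuracy` field:

```python
# Sketch under the stated assumption; the real definitions live in
# liger_kernel/transformers/model/output_classes.py and may differ.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class LigerCausalLMOutputWithPastSketch(CausalLMOutputWithPast):
    # Next-token prediction accuracy computed alongside the fused loss; None
    # when accuracy tracking is disabled or logits were materialized eagerly.
    token_accuracy: Optional[torch.Tensor] = None
```

Because the extra field defaults to `None` and is appended after the upstream fields, existing callers that index or unpack the output keep working unchanged.
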
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -91,6 +92,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -99,8 +101,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -108,21 +111,31 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
    else:
        logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -113,6 +114,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -121,8 +123,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -130,21 +133,31 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
    else:
        logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
@@ -4,10 +4,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeCausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -27,8 +28,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, Glm4vMoeCausalLMOutputWithPast]:
+) -> Union[Tuple, LigerGlm4vMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -90,6 +92,7 @@ def lce_forward(
     >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
     ```
     """
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -114,6 +117,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -122,8 +126,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
@@ -131,22 +136,37 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **kwargs,
         )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
    else:
        logits = self.lm_head(kept_hidden_states)
-        if labels is not None:
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )
 
-    return Glm4vMoeCausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Build output kwargs and include aux_loss only if present (depends on transformers version)
+    output_kwargs = dict(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
    )
+    if hasattr(outputs, "aux_loss"):
+        output_kwargs["aux_loss"] = outputs.aux_loss
+
+    # Return GLM4V MoE output with accuracy
+    return LigerGlm4vMoeCausalLMOutputWithPast(**output_kwargs)
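
The MoE variant cannot assume `aux_loss` exists on the backbone output across transformers versions, so it builds the kwargs dict first and forwards `aux_loss` only when present. The same guard pattern, isolated into a hypothetical helper (the name and signature are illustrative, not part of the package):

```python
# Hypothetical helper restating the version-tolerant construction above:
# optional fields are forwarded only when the backbone output exposes them.
def collect_output_kwargs(outputs, loss, logits, token_accuracy, optional_fields=("aux_loss",)):
    kwargs = dict(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
    for name in optional_fields:
        if hasattr(outputs, name):
            kwargs[name] = getattr(outputs, name)
    return kwargs
```

Keeping the conditional on the caller side avoids declaring `aux_loss` as a required field of the Liger output class, so the same class works whether or not the installed transformers release populates it.
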
@@ -0,0 +1,134 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    **kwargs,
+) -> LigerCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM
+
+    >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
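
A new per-model forward like the Hunyuan `lce_forward` above only takes effect once it is installed on the corresponding Hugging Face model class, which is what the greatly expanded `liger_kernel/transformers/monkey_patch.py` handles. A hedged sketch of that wiring, with the helper name and the Llama example chosen here for illustration rather than taken from the patch:

```python
# Hedged sketch; the real entry points live in
# liger_kernel/transformers/monkey_patch.py (not shown in this excerpt),
# and install_liger_forward below is an illustration, not the library API.
def install_liger_forward(model_cls, lce_forward):
    """Bind a Liger fused-loss forward onto a *ForCausalLM class in place."""
    model_cls.forward = lce_forward


# Illustrative usage with the existing Llama integration:
# from transformers.models.llama.modeling_llama import LlamaForCausalLM
# from liger_kernel.transformers.model.llama import lce_forward
# install_liger_forward(LlamaForCausalLM, lce_forward)
```

After patching, training code that passes `labels` (or `shift_labels`) in training mode never materializes the full logits tensor, and can read the optional `token_accuracy` field from the returned Liger output when it is populated.
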