liger-kernel-nightly 0.6.3.dev20251105190428__py3-none-any.whl → 0.6.3.dev20251105235313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. liger_kernel/ops/cross_entropy.py +59 -9
  2. liger_kernel/ops/fused_linear_cross_entropy.py +27 -4
  3. liger_kernel/transformers/cross_entropy.py +8 -3
  4. liger_kernel/transformers/functional.py +24 -6
  5. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  6. liger_kernel/transformers/model/falcon_h1.py +19 -5
  7. liger_kernel/transformers/model/gemma.py +17 -6
  8. liger_kernel/transformers/model/gemma2.py +14 -5
  9. liger_kernel/transformers/model/gemma3.py +25 -12
  10. liger_kernel/transformers/model/glm4.py +16 -4
  11. liger_kernel/transformers/model/glm4v.py +16 -4
  12. liger_kernel/transformers/model/glm4v_moe.py +19 -4
  13. liger_kernel/transformers/model/internvl.py +12 -5
  14. liger_kernel/transformers/model/llama.py +14 -5
  15. liger_kernel/transformers/model/llama4.py +16 -4
  16. liger_kernel/transformers/model/llava.py +12 -4
  17. liger_kernel/transformers/model/loss_utils.py +31 -3
  18. liger_kernel/transformers/model/mistral.py +15 -6
  19. liger_kernel/transformers/model/mixtral.py +16 -7
  20. liger_kernel/transformers/model/mllama.py +12 -4
  21. liger_kernel/transformers/model/olmo2.py +16 -4
  22. liger_kernel/transformers/model/output_classes.py +147 -0
  23. liger_kernel/transformers/model/paligemma.py +22 -5
  24. liger_kernel/transformers/model/phi3.py +14 -7
  25. liger_kernel/transformers/model/qwen2.py +16 -3
  26. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  27. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  28. liger_kernel/transformers/model/qwen3.py +18 -5
  29. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  30. liger_kernel/transformers/model/qwen3_next.py +17 -5
  31. liger_kernel/transformers/model/qwen3_vl.py +11 -5
  32. liger_kernel/transformers/model/qwen3_vl_moe.py +12 -5
  33. liger_kernel/transformers/model/smollm3.py +15 -6
  34. liger_kernel/transformers/monkey_patch.py +4 -2
  35. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/METADATA +1 -1
  36. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/RECORD +40 -39
  37. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/LICENSE +0 -0
  38. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/NOTICE +0 -0
  39. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/WHEEL +0 -0
  40. {liger_kernel_nightly-0.6.3.dev20251105190428.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/top_level.txt +0 -0
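
The common thread in the hunks below is that the Liger cross-entropy path now reports a token-accuracy metric alongside the loss: each patched forward unpacks the loss result with `unpack_cross_entropy_result` and threads `token_accuracy` into a new `Liger*OutputWithPast` class from `output_classes.py`. As a rough sketch of what that helper likely does, inferred only from the call sites visible in this diff (the actual implementation in `liger_kernel/transformers/model/loss_utils.py` may differ, and the middle element is assumed here to be an auxiliary z-loss):

# Hypothetical sketch, inferred from call sites such as
# `loss, _, token_accuracy = unpack_cross_entropy_result(result)` below.
# The real helper in liger_kernel/transformers/model/loss_utils.py may differ.
from typing import Any, Optional, Tuple

import torch


def unpack_cross_entropy_result(
    result: Any,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Normalize a cross-entropy result into (loss, z_loss, token_accuracy)."""
    if isinstance(result, tuple):
        # Pad so one-, two-, and three-element results all unpack safely.
        padded = list(result) + [None, None]
        return padded[0], padded[1], padded[2]
    # A bare tensor means only the loss was computed.
    return result, None, None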
@@ -13,6 +13,8 @@ from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 logger = logging.getLogger(__name__)
 
@@ -158,7 +160,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -225,6 +227,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -233,8 +236,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -243,6 +247,7 @@ def lce_forward(
            final_logit_softcapping=self.config.final_logit_softcapping,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -262,13 +267,17 @@ def lce_forward(
        )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
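
The `LigerCausalLMOutputWithPast` constructed above comes from the new `output_classes.py` module (file 22 in the list). Judging purely from the fields passed at the call sites, it behaves like `transformers`' `CausalLMOutputWithPast` with one extra optional field; below is a minimal sketch under that assumption (the shipped class may define more):

# Minimal sketch of the new output class, assuming it simply extends the
# standard transformers output with a token_accuracy field. The shipped
# liger_kernel/transformers/model/output_classes.py may define more.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
    # Fraction of label tokens predicted correctly; populated only when the
    # fused linear cross-entropy path runs (i.e., logits are not materialized).
    token_accuracy: Optional[torch.Tensor] = None

The model-specific variants used in later hunks (LigerGemma3CausalLMOutputWithPast, LigerGlm4vMoeCausalLMOutputWithPast, LigerInternVLCausalLMOutputWithPast, LigerLlavaCausalLMOutputWithPast) appear to follow the same pattern on top of their respective base output classes.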
@@ -7,12 +7,13 @@ import torch.nn as nn
 
 from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
 from transformers.utils import logging
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast
 
 logger = logging.get_logger(__name__)
 
@@ -33,7 +34,7 @@ def causal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -98,12 +99,14 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
     if skip_logits is None:
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -112,7 +115,7 @@ def causal_forward(
            final_logit_softcapping=self.config.final_logit_softcapping,
            **loss_kwargs,
        )
-
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
        logits = self.lm_head(kept_hidden_states)
        if self.config.final_logit_softcapping is not None:
@@ -129,15 +132,19 @@ def causal_forward(
        )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
+        output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
+        return output_tuple
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
 
 
@@ -159,7 +166,7 @@ def multimodal_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, LigerGemma3CausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -235,6 +242,7 @@ def multimodal_forward(
 
     loss = None
     logits = None
+    token_accuracy = None
     if skip_logits and labels is None:
         raise ValueError("skip_logits is True, but labels is None")
 
@@ -261,7 +269,9 @@ def multimodal_forward(
            shift_labels = shift_labels.view(-1).to(hidden_device)
 
            lce = LigerFusedLinearCrossEntropyLoss()
-            loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+            result = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+            loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
     else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
@@ -306,13 +316,16 @@ def multimodal_forward(
 
     if not return_dict:
        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Gemma3CausalLMOutputWithPast(
+    return LigerGemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
+        token_accuracy=token_accuracy,
    )
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -91,6 +92,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -99,8 +101,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -108,6 +111,7 @@ def lce_forward(
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -120,10 +124,18 @@ def lce_forward(
        **kwargs,
     )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -28,7 +29,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -113,6 +114,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -121,8 +123,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -130,6 +133,7 @@ def lce_forward(
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -142,10 +146,18 @@ def lce_forward(
        **kwargs,
     )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
@@ -4,10 +4,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeCausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@@ -27,8 +28,9 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, Glm4vMoeCausalLMOutputWithPast]:
+) -> Union[Tuple, LigerGlm4vMoeCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -90,6 +92,7 @@ def lce_forward(
     >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
     ```
     """
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
@@ -114,6 +117,7 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -122,8 +126,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -131,6 +136,7 @@ def lce_forward(
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -143,11 +149,20 @@ def lce_forward(
        **kwargs,
     )
 
-    return Glm4vMoeCausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return GLM4V MoE output with accuracy (using dict syntax to add extra field)
+    return LigerGlm4vMoeCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
+        aux_loss=outputs.aux_loss,
+        token_accuracy=token_accuracy,
    )
@@ -5,10 +5,11 @@ from typing import Union
 
 import torch
 
-from transformers.models.internvl.modeling_internvl import InternVLCausalLMOutputWithPast
 from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast
 
 
 # Copied from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862
@@ -33,7 +34,7 @@ def lce_forward(
     image_sizes: Optional[torch.Tensor] = None,
     skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
     **lm_kwargs,  # renamed from kwargs
-) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
+) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]:
     r"""
     Example:
 
@@ -111,6 +112,7 @@ def lce_forward(
     shift_labels = lm_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -120,7 +122,7 @@ def lce_forward(
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -128,6 +130,7 @@ def lce_forward(
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -138,13 +141,17 @@ def lce_forward(
 
     if not return_dict:
        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return InternVLCausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerInternVLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
+        token_accuracy=token_accuracy,
    )
@@ -15,6 +15,8 @@ from transformers.utils.deprecation import deprecate_kwarg
 from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
@@ -162,7 +164,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -228,6 +230,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -236,8 +240,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = lce_maybe_trainable_lm_head(
+        result = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
@@ -245,7 +250,7 @@ def lce_forward(
            shift_labels=shift_labels,
            **kwargs,
        )
-
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
@@ -259,14 +264,18 @@ def lce_forward(
 
     if not return_dict:
        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
 
 
@@ -6,9 +6,10 @@ from typing import Union
 import torch
 
 from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
 def lce_forward(
@@ -26,7 +27,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -78,9 +79,11 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
+    # Compute loss
     if self.training and (labels is not None or shift_labels is not None):
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -88,6 +91,7 @@ def lce_forward(
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:  # if in inference mode materialize logits
        logits = self.lm_head(kept_hidden_states)
@@ -100,10 +104,18 @@ def lce_forward(
        **kwargs,
     )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
@@ -11,6 +11,8 @@ from transformers.utils import is_torchdynamo_compiling
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast
 
 
 def lce_forward_deprecated(
@@ -215,7 +217,7 @@ def lce_forward(
     image_sizes: torch.Tensor = None,
     skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -293,6 +295,7 @@ def lce_forward(
     shift_labels = lm_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
 
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -302,7 +305,7 @@ def lce_forward(
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
     if skip_logits:
-        loss = LigerForCausalLMLoss(
+        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
@@ -310,6 +313,7 @@ def lce_forward(
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
        logits = self.lm_head(kept_hidden_states)
@@ -324,13 +328,17 @@ def lce_forward(
 
     if not return_dict:
        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output = (loss,) + output if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return LlavaCausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerLlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
+        token_accuracy=token_accuracy,
    )
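
Downstream code can read the new metric off either return style: it is the `token_accuracy` attribute on the dict-style outputs, or the trailing tuple element when `return_dict=False` and the metric was computed. A small, hypothetical helper illustrating that (not part of the package; names are illustrative only):

# Hypothetical convenience helper; not part of the package. It relies only on
# the token_accuracy attribute that the patched forwards above populate.
from typing import Any, Optional

import torch


def extract_token_accuracy(outputs: Any) -> Optional[float]:
    """Return token accuracy as a plain float if the patched forward produced it."""
    token_accuracy = getattr(outputs, "token_accuracy", None)
    if token_accuracy is None:
        return None
    if isinstance(token_accuracy, torch.Tensor):
        return token_accuracy.detach().float().item()
    return float(token_accuracy)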