liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,157 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.utils import can_return_tuple
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerInternVLCausalLMOutputWithPast
13
+
14
+
15
# Copied from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862
@can_return_tuple
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: Optional[torch.Tensor] = None,
    skip_logits: Optional[bool] = None,  # Added argument for liger-kernel
    **lm_kwargs,  # renamed from kwargs
) -> Union[Tuple, LigerInternVLCausalLMOutputWithPast]:
    r"""
    InternVL causal-LM forward that can compute the loss with Liger's fused linear cross entropy.

    Args:
        labels (`torch.LongTensor`, *optional*):
            Labels for the language-modeling loss.
        logits_to_keep (`int` or `torch.Tensor`, defaults to 0):
            If an int, only the last `logits_to_keep` positions of the hidden states are kept
            (0 keeps all positions). If a tensor, it is used directly as the sequence index.
        skip_logits (`bool`, *optional*):
            If True, the loss is computed by `LigerForCausalLMLoss` directly from the kept hidden
            states and `self.lm_head.weight`, and `logits` is returned as None. This requires `labels`
            or a `shift_labels` entry in `lm_kwargs`; otherwise a ValueError is raised.
            If None, it defaults to True in training mode when `labels` or `shift_labels` is provided.
        **lm_kwargs:
            Forwarded to `self.model` and to the loss. A `shift_labels` entry is popped after the
            model call and passed to the loss as `shift_labels`.

    Returns:
        A `LigerInternVLCausalLMOutputWithPast` carrying `token_accuracy`. When `return_dict` is
        False, a tuple `(loss?, logits, *model_outputs[1:], token_accuracy?)` is returned; the
        optional entries are present only when not None.

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, AutoModelForImageTextToText

    >>> torch_device = "cuda"
    >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
    >>> model = AutoModelForImageTextToText.from_pretrained(
    ...     "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
    ... )

    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
    ...             },
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
    ...             },
    ...             {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
    ...         ],
    ...     },
    ... ]

    >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
    >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
    >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
    The images depict the Statue of Liberty and the Golden Gate Bridge.
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        vision_feature_layer=vision_feature_layer,
        vision_feature_select_strategy=vision_feature_select_strategy,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        image_sizes=image_sizes,
        **lm_kwargs,
    )

    # Copied from llava.py
    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = lm_kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.text_config.hidden_size,
            **lm_kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:
        logits = self.lm_head(kept_hidden_states)
        # shift_labels alone is enough to compute a loss (same rule as the skip_logits branch),
        # and it must be forwarded explicitly because it was popped from lm_kwargs above.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.text_config.vocab_size,
                **lm_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = (loss,) + output if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerInternVLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=outputs.image_hidden_states,
        token_accuracy=token_accuracy,
    )
@@ -7,23 +7,25 @@ from typing import Union
7
7
  import torch
8
8
  import torch.nn.functional as F
9
9
 
10
+ from torch.distributed.fsdp import FullyShardedDataParallel
10
11
  from torch.nn import CrossEntropyLoss
11
12
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
- from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
13
- from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
14
- from transformers.utils import add_start_docstrings_to_model_forward
15
- from transformers.utils import replace_return_docstrings
16
13
  from transformers.utils.deprecation import deprecate_kwarg
17
14
 
15
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
18
16
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
19
17
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
18
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
19
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
20
+ from liger_kernel.utils import PEFT_AVAILABLE
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from transformers.cache_utils import Cache
23
24
 
25
+ if PEFT_AVAILABLE:
26
+ from peft.utils.other import ModulesToSaveWrapper
27
+
24
28
 
25
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
26
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
27
29
  def lce_forward_deprecated(
28
30
  self,
29
31
  input_ids: torch.LongTensor = None,
@@ -37,6 +39,7 @@ def lce_forward_deprecated(
37
39
  output_hidden_states: Optional[bool] = None,
38
40
  return_dict: Optional[bool] = None,
39
41
  cache_position: Optional[torch.LongTensor] = None,
42
+ skip_logits: Optional[bool] = None,
40
43
  ) -> Union[Tuple, CausalLMOutputWithPast]:
41
44
  r"""
42
45
  Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +94,15 @@ def lce_forward_deprecated(
91
94
  loss = None
92
95
  logits = None
93
96
 
94
- if self.training and (labels is not None):
97
+ # if in training mode, don't materialize logits
98
+ if skip_logits and labels is None:
99
+ raise ValueError("skip_logits is True, but labels is None")
100
+
101
+ if skip_logits is None:
102
+ # By default, if in training mode, don't materialize logits
103
+ skip_logits = self.training and labels is not None
104
+
105
+ if skip_logits:
95
106
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
96
107
  shift_labels = labels[..., 1:].contiguous()
97
108
 
@@ -137,8 +148,6 @@ def lce_forward_deprecated(
137
148
 
138
149
 
139
150
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
140
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
141
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
142
151
  def lce_forward(
143
152
  self,
144
153
  input_ids: torch.LongTensor = None,
@@ -153,8 +162,9 @@ def lce_forward(
153
162
  return_dict: Optional[bool] = None,
154
163
  cache_position: Optional[torch.LongTensor] = None,
155
164
  logits_to_keep: Union[int, torch.Tensor] = 0,
156
- **loss_kwargs,
157
- ) -> Union[Tuple, CausalLMOutputWithPast]:
165
+ skip_logits: Optional[bool] = None,
166
+ **kwargs,
167
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
158
168
  r"""
159
169
  Args:
160
170
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -206,44 +216,111 @@ def lce_forward(
206
216
  output_hidden_states=output_hidden_states,
207
217
  return_dict=return_dict,
208
218
  cache_position=cache_position,
219
+ **kwargs,
209
220
  )
210
221
 
211
222
  hidden_states = outputs[0]
223
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
224
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
225
+ kept_hidden_states = hidden_states[:, slice_indices, :]
212
226
 
213
227
  if self.config.pretraining_tp > 1:
214
228
  raise Exception("Liger Kernel does not support pretraining_tp!!")
215
229
 
230
+ shift_labels = kwargs.pop("shift_labels", None)
216
231
  logits = None
217
232
  loss = None
233
+ token_accuracy = None
234
+
218
235
  # if in training mode, don't materialize logits
219
- if self.training and (labels is not None):
220
- loss = LigerForCausalLMLoss(
221
- hidden_states=hidden_states,
222
- lm_head_weight=self.lm_head.weight,
223
- labels=labels,
236
+ if skip_logits and labels is None and shift_labels is None:
237
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
238
+
239
+ if skip_logits is None:
240
+ # By default, if in training mode, don't materialize logits
241
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
242
+
243
+ # Compute loss
244
+ if skip_logits:
245
+ result = lce_maybe_trainable_lm_head(
246
+ self,
247
+ hidden_states=kept_hidden_states,
224
248
  hidden_size=self.config.hidden_size,
225
- **loss_kwargs,
249
+ labels=labels,
250
+ shift_labels=shift_labels,
251
+ **kwargs,
226
252
  )
227
-
228
- else: # if in inference mode materialize logits
229
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
230
- logits = self.lm_head(hidden_states[:, slice_indices, :])
231
- if labels is not None:
253
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
254
+ else:
255
+ logits = self.lm_head(kept_hidden_states)
256
+ if labels is not None or shift_labels is not None:
232
257
  loss = self.loss_function(
233
258
  logits=logits,
234
259
  labels=labels,
260
+ shift_labels=shift_labels,
235
261
  vocab_size=self.config.vocab_size,
236
- **loss_kwargs,
262
+ **kwargs,
237
263
  )
238
264
 
239
265
  if not return_dict:
240
266
  output = (logits,) + outputs[1:]
241
- return (loss,) + output if loss is not None else output
267
+ output = ((loss,) + output) if loss is not None else output
268
+ output = output + (token_accuracy,) if token_accuracy is not None else output
269
+ return output
242
270
 
243
- return CausalLMOutputWithPast(
271
+ # Return custom output class with token_accuracy field
272
+ return LigerCausalLMOutputWithPast(
244
273
  loss=loss,
245
274
  logits=logits,
246
275
  past_key_values=outputs.past_key_values,
247
276
  hidden_states=outputs.hidden_states,
248
277
  attentions=outputs.attentions,
278
+ token_accuracy=token_accuracy,
279
+ )
280
+
281
+
282
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Compute the Liger causal-LM loss against ``self.lm_head``, unwrapping PEFT and FSDP wrappers.

    Args:
        self: The causal-LM module that owns ``lm_head``.
        hidden_states: Hidden states passed to ``LigerForCausalLMLoss``.
        hidden_size: Hidden size passed to ``LigerForCausalLMLoss``.
        labels: Labels for the loss, or None.
        shift_labels: Already-shifted labels for the loss, or None.
        **loss_kwargs: Extra keyword arguments forwarded to ``LigerForCausalLMLoss``.

    Returns:
        The value returned by ``LigerForCausalLMLoss`` (via ``_liger_for_causal_lm_loss``).
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used so we can read the lm_head weights and call the kernel directly.
    # Use the (possibly PEFT-unwrapped) local `lm_head`, not `self.lm_head`; otherwise the
    # unwrapping above would be ignored and the wrapper's weight would be read instead.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
316
+
317
+
318
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Call ``LigerForCausalLMLoss`` with ``lm_head.weight`` as the projection weight.

    Takes the head module (rather than its weight) so the weight is read at call time,
    which lets this be invoked either directly or through ``_FSDPForwardRedirection``.
    """
    head_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=head_weight,
        labels=labels,
        hidden_size=hidden_size,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
@@ -0,0 +1,121 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.cache_utils import Cache
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
+
14
+
15
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Llama4 causal-LM forward that can compute the loss with Liger's fused linear cross entropy.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    logits_to_keep (`int` or `torch.Tensor`, defaults to 0):
        If an int, only the last `logits_to_keep` positions of the hidden states are kept (0 keeps all
        positions). If a tensor, it is used directly as the sequence index.
    skip_logits (`bool`, *optional*):
        If True, the loss is computed by `LigerForCausalLMLoss` directly from the kept hidden states and
        `self.lm_head.weight`, and `logits` is returned as None. This requires `labels` or a `shift_labels`
        entry in `kwargs`; otherwise a ValueError is raised. If None, it defaults to True in training mode
        when `labels` or `shift_labels` is provided.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Llama4ForCausalLM

    >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # The base model is always asked for a ModelOutput so attributes can be read below.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:  # materialize logits
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )