liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/llava.py

@@ -5,19 +5,16 @@ from typing import Union

  import torch

- from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
- from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
+ from torch.nn import CrossEntropyLoss
  from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import is_torchdynamo_compiling
- from transformers.utils import replace_return_docstrings
- from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerLlavaCausalLMOutputWithPast


- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
  self,
  input_ids: torch.LongTensor = None,
@@ -33,6 +30,11 @@ def lce_forward_deprecated(
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: torch.Tensor = None,
+ skip_logits: Optional[bool] = None,
+ **lm_kwargs,
  ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
  r"""
  Args:
@@ -41,10 +43,12 @@ def lce_forward_deprecated(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).


  Returns:
@@ -70,7 +74,6 @@ def lce_forward_deprecated(
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
  ```"""
-
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -94,73 +97,24 @@ def lce_forward_deprecated(
  )

  if inputs_embeds is None:
- # 1. Extra the input embeddings
  inputs_embeds = self.get_input_embeddings()(input_ids)

- # 2. Merge text and images
- if pixel_values is not None and input_ids.shape[1] != 1:
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-
- if vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
- else:
- raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-
- image_features = self.multi_modal_projector(selected_image_feature)
- inputs_embeds = inputs_embeds.to(image_features.dtype)
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
- image_features, inputs_embeds, input_ids, attention_mask, labels
- )
-
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
- # generation with cache
- elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
- # Retrieve the first layer to inspect the logits and mask out the hidden states
- # that are set to 0
- first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
- # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
- batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
- # Get the target length
- target_length = input_ids.shape[1]
- past_length = first_layer_past_key_value.shape[-1]
-
- extended_attention_mask = torch.ones(
- (attention_mask.shape[0], past_length),
- dtype=attention_mask.dtype,
- device=attention_mask.device,
- )
-
- # Filter out only the tokens that can be un-attended, this can happen
- # if one uses Llava + Fused modules where the cache on the
- # first iteration is already big enough, or if one passes custom cache
- valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
- new_batch_index = batch_index[valid_indices]
- new_non_attended_tokens = non_attended_tokens[valid_indices]
-
- # Zero-out the places where we don't need to attend
- extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
- attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
- # TODO: @raushan retain only the new behavior after v4.47
- elif image_features is not None:
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0] * image_features.shape[1]
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )

- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
  raise ValueError(
  f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  )
- special_image_mask = (
- (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- )
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

@@ -173,13 +127,19 @@ def lce_forward_deprecated(
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
  )
  hidden_states = outputs[0]

  loss = None
  logits = None

- if self.training and (labels is not None):
+ # Overwrite skip_logits, since llava never materializes logits
+ skip_logits = labels is not None
+
+ if skip_logits:
  # Shift so that tokens < n predict n
  if attention_mask is not None:
  # we use the input attention mask to shift the logits and labels, because it is 2D.
@@ -194,7 +154,33 @@ def lce_forward_deprecated(
  shift_labels = labels[..., 1:].contiguous()

  lce = LigerFusedLinearCrossEntropyLoss()
- loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ loss = lce(
+ self.language_model.lm_head.weight,
+ shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
+ shift_labels.view(-1).to(shift_hidden_states.device),
+ )
+ else:
+ logits = self.language_model.lm_head(hidden_states)
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)

  if not return_dict:
  # NOTE: This part has not been tested.
@@ -207,12 +193,10 @@ def lce_forward_deprecated(
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
  )


- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -231,8 +215,9 @@ def lce_forward(
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
  image_sizes: torch.Tensor = None,
+ skip_logits: Optional[bool] = None,
  **lm_kwargs,
- ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerLlavaCausalLMOutputWithPast]:
  r"""
  Args:
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -285,85 +270,75 @@ def lce_forward(
  else self.config.vision_feature_select_strategy
  )

- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model.model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
  attention_mask=attention_mask,
  position_ids=position_ids,
  past_key_values=past_key_values,
  inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
  use_cache=use_cache,
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
  cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
  **lm_kwargs,
  )
  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

- loss = None
+ shift_labels = lm_kwargs.pop("shift_labels", None)
  logits = None
-
- if self.training and (labels is not None):
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
- shift_hidden_states = hidden_states[..., :-1, :][
- shift_attention_mask.to(hidden_states.device) != 0
- ].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_hidden_states = hidden_states[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- lce = LigerFusedLinearCrossEntropyLoss()
- loss = lce(
- self.language_model.lm_head.weight,
- shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
- shift_labels.view(-1).to(shift_hidden_states.device),
+ loss = None
+ token_accuracy = None
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
+ result = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.text_config.hidden_size,
+ **lm_kwargs,
  )
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+ else:
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **lm_kwargs,
+ )

  if not return_dict:
- # NOTE: This part has not been tested.
- output = outputs[1:]
- return (loss,) + output if loss is not None else output
+ output = (logits,) + outputs[1:]
+ output = (loss,) + output if loss is not None else output
+ output = output + (token_accuracy,) if token_accuracy is not None else output
+ return output

- return LlavaCausalLMOutputWithPast(
+ # Return custom output class with token_accuracy field
+ return LigerLlavaCausalLMOutputWithPast(
  loss=loss,
  logits=logits,
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ token_accuracy=token_accuracy,
  )
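The reworked llava `lce_forward` above hinges on two small pieces of control flow: the `logits_to_keep` slice that trims hidden states before the LM head, and the `skip_logits` default that routes training-with-labels calls into the fused Liger loss instead of materializing logits. Below is a minimal standalone sketch of that logic; the helper names `keep_last_hidden_states` and `resolve_skip_logits` are illustrative and not part of the package.

```python
from typing import Optional, Union

import torch


def keep_last_hidden_states(
    hidden_states: torch.Tensor, logits_to_keep: Union[int, torch.Tensor]
) -> torch.Tensor:
    # Mirrors the slicing in the patched forward: an int keeps the last N positions
    # (0 keeps everything), a 1D index tensor selects explicit sequence positions.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


def resolve_skip_logits(
    training: bool,
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
    skip_logits: Optional[bool],
) -> bool:
    # Mirrors the default in the patched forward: an explicit skip_logits wins,
    # otherwise logits are skipped only when training with labels available.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits


hidden = torch.randn(2, 16, 64)  # (batch, seq_len, hidden_size)
print(keep_last_hidden_states(hidden, 0).shape)  # keeps all 16 positions
print(keep_last_hidden_states(hidden, 1).shape)  # keeps only the final position
print(resolve_skip_logits(True, torch.zeros(2, 16, dtype=torch.long), None, None))  # True
print(resolve_skip_logits(False, None, None, None))  # False
```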
liger_kernel/transformers/model/loss_utils.py

@@ -1,28 +1,61 @@
+ from typing import Optional
+ from typing import Tuple
+
+ import torch
  import torch.nn as nn

  import liger_kernel.transformers.functional as F

+ from liger_kernel.transformers.functional import CrossEntropyOutput
+
+
+ def unpack_cross_entropy_result(
+ result,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ if isinstance(result, CrossEntropyOutput):
+ return result.loss, result.z_loss, result.token_accuracy
+
+ if isinstance(result, tuple):
+ loss = result[0]
+ z_loss = result[1] if len(result) > 1 else None
+ token_accuracy = result[2] if len(result) > 2 else None
+ return loss, z_loss, token_accuracy
+
+ return result, None, None
+

  def fixed_fused_linear_cross_entropy(
- hidden_states,
- lm_head_weight,
- target,
- num_items_in_batch: int = None,
+ hidden_states: torch.Tensor,
+ lm_head_weight: torch.Tensor,
+ target: torch.Tensor,
+ num_items_in_batch: Optional[int] = None,
  ignore_index: int = -100,
+ final_logit_softcapping: Optional[float] = None,
+ accum_dtype: Optional[torch.dtype] = None,
+ return_token_accuracy: bool = False,
  **kwargs,
  ):
  reduction = "sum" if num_items_in_batch is not None else "mean"
- loss = F.liger_fused_linear_cross_entropy(
+ result = F.liger_fused_linear_cross_entropy(
  hidden_states,
  lm_head_weight,
  target,
  reduction=reduction,
  ignore_index=ignore_index,
+ softcap=final_logit_softcapping,
+ accum_dtype=accum_dtype,
+ return_token_accuracy=return_token_accuracy,
  **kwargs,
  )
+
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
  if reduction == "sum":
  loss = loss / num_items_in_batch

+ if return_token_accuracy:
+ return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy)
+
  return loss


@@ -31,27 +64,32 @@ def LigerForCausalLMLoss(
  lm_head_weight,
  labels,
  hidden_size: int,
- num_items_in_batch: int = None,
+ num_items_in_batch: Optional[int] = None,
  ignore_index: int = -100,
+ shift_labels: Optional[torch.Tensor] = None,
+ final_logit_softcapping: Optional[float] = None,
+ return_token_accuracy: bool = False,
  **kwargs,
  ):
  # Skip upcast since intermediate values for the loss are all fp32 in kernel
- labels = labels.to(hidden_states.device)
- # Shift so that token < n predict n
- labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
- shift_labels = labels[..., 1:].contiguous()
+ if shift_labels is None:
+ # Shift so that token < n predict n
+ labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
+ shift_labels = labels[..., 1:].contiguous()

  # Flatten the tokens
  hidden_states = hidden_states.view(-1, hidden_size)
  shift_labels = shift_labels.view(-1)
  # Enable model parallelism
  shift_labels = shift_labels.to(hidden_states.device)
- loss = fixed_fused_linear_cross_entropy(
+ result = fixed_fused_linear_cross_entropy(
  hidden_states,
  lm_head_weight,
  shift_labels,
  num_items_in_batch,
  ignore_index,
+ final_logit_softcapping,
+ return_token_accuracy=return_token_accuracy,
  **kwargs,
  )
- return loss
+ return result
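The `unpack_cross_entropy_result` helper added above normalizes three possible return shapes (a `CrossEntropyOutput`-style object, a tuple, or a bare loss tensor) into a `(loss, z_loss, token_accuracy)` triple, so callers such as `fixed_fused_linear_cross_entropy` can treat them uniformly. Here is a self-contained sketch of that normalization, using a stand-in dataclass instead of the real `liger_kernel.transformers.functional.CrossEntropyOutput`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class FakeCrossEntropyOutput:
    # Stand-in for liger_kernel's CrossEntropyOutput; field names follow the diff above.
    loss: torch.Tensor
    z_loss: Optional[torch.Tensor] = None
    token_accuracy: Optional[torch.Tensor] = None


def unpack(result) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Same three-way dispatch as unpack_cross_entropy_result in the diff.
    if isinstance(result, FakeCrossEntropyOutput):
        return result.loss, result.z_loss, result.token_accuracy
    if isinstance(result, tuple):
        loss = result[0]
        z_loss = result[1] if len(result) > 1 else None
        token_accuracy = result[2] if len(result) > 2 else None
        return loss, z_loss, token_accuracy
    return result, None, None


loss = torch.tensor(1.23)
acc = torch.tensor(0.75)
print(unpack(loss))                                                   # (loss, None, None)
print(unpack((loss, None, acc)))                                      # (loss, None, acc)
print(unpack(FakeCrossEntropyOutput(loss=loss, token_accuracy=acc)))  # (loss, None, acc)
```

Note also that `fixed_fused_linear_cross_entropy` requests `reduction="sum"` whenever `num_items_in_batch` is given and then divides by that count, which keeps the loss scale consistent under gradient accumulation.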
liger_kernel/transformers/model/mistral.py

@@ -5,19 +5,15 @@ from typing import Union

  import torch

- from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
- from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
- from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -31,8 +27,10 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- **loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
  r"""
  Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy

@@ -43,6 +41,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

  Example:
@@ -80,49 +84,62 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = kwargs.pop("shift_labels", None)
  loss = None
  logits = None
+ token_accuracy = None

- if self.training and (labels is not None):
- loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ # Compute loss
+ if skip_logits:
+ result = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)

  else:
- logits = self.lm_head(hidden_states)
- if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Ensure tensors are on the same device
- shift_labels = shift_labels.to(shift_logits.device)
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits, shift_labels)
+ logits = self.lm_head(kept_hidden_states)
+
+ loss = None
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )

  if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ output_tuple = (logits,) + outputs[1:]
+ output = (loss,) + output_tuple if loss is not None else output_tuple
+ output = output + (token_accuracy,) if token_accuracy is not None else output
+ return output

- return CausalLMOutputWithPast(
+ # Return custom output class with token_accuracy field
+ return LigerCausalLMOutputWithPast(
  loss=loss,
  logits=logits,
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
+ token_accuracy=token_accuracy,
  )
-
-
- # Note: Grad Acc is not fixed in mistral at transformer 4.46.1
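Both the llava and mistral forwards above now share the same pattern: either hand `kept_hidden_states` plus the `lm_head` weight to `LigerForCausalLMLoss`, or fall back to materializing logits and calling `self.loss_function`. The sketch below is a plain-PyTorch reference of the quantity the fused path computes; shapes and names are illustrative, and the snippet deliberately materializes the full logits tensor that the fused Liger kernel avoids holding in memory (it processes the projection and loss chunk by chunk instead).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, hidden_size, vocab_size = 2, 8, 32, 100

hidden_states = torch.randn(batch, seq_len, hidden_size)
lm_head_weight = torch.randn(vocab_size, hidden_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Shift so that tokens < n predict n, padding the tail with ignore_index (-100),
# matching LigerForCausalLMLoss in the diff above.
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous()

# Reference computation: project hidden states to logits, then cross entropy with
# mean reduction over non-ignored tokens. The fused kernel computes the same value
# without ever allocating the full (batch * seq_len, vocab_size) logits tensor.
flat_hidden = hidden_states.view(-1, hidden_size)
flat_labels = shift_labels.view(-1)
logits = flat_hidden @ lm_head_weight.T
loss = F.cross_entropy(logits, flat_labels, ignore_index=-100, reduction="mean")
print(loss)
```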