liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/model/llama.py (+81 -18)

@@ -7,23 +7,23 @@ from typing import Union
 import torch
 import torch.nn.functional as F
 
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
-from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
+if PEFT_AVAILABLE:
+    from peft.utils.other import ModulesToSaveWrapper
+
 
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -37,6 +37,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +92,15 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    # if in training mode, don't materialize logits
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -137,8 +146,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -153,7 +160,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -206,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -216,28 +225,35 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
-        loss = LigerForCausalLMLoss(
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = lce_maybe_trainable_lm_head(
+            self,
            hidden_states=kept_hidden_states,
-            lm_head_weight=self.lm_head.weight,
+            hidden_size=self.config.hidden_size,
            labels=labels,
            shift_labels=shift_labels,
-            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
-    else:  # if in inference mode materialize logits
+    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
    if not return_dict:
@@ -251,3 +267,50 @@ def lce_forward(
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
+
+
+def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    lm_head = self.lm_head
+
+    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
+    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
+    # from the unwrapped module.
+    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
+    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
+        lm_head = lm_head.modules_to_save.default
+
+    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
+    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
+    # so the module entire parameters are summoned and kept in memory during the kernel execution.
+    if isinstance(lm_head, FullyShardedDataParallel):
+        return _FSDPForwardRedirection()(
+            lm_head,
+            _liger_for_causal_lm_loss,
+            lm_head.module,
+            hidden_states,
+            hidden_size,
+            labels,
+            shift_labels,
+            **loss_kwargs,
+        )
+
+    # FSDP is not used so we can read the lm_head weights and call the kernel directly
+    return _liger_for_causal_lm_loss(
+        lm_head=self.lm_head,
+        hidden_states=hidden_states,
+        hidden_size=hidden_size,
+        labels=labels,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
+
+
+def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    return LigerForCausalLMLoss(
+        hidden_states=hidden_states,
+        lm_head_weight=lm_head.weight,
+        labels=labels,
+        hidden_size=hidden_size,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
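
The new `skip_logits` argument makes the "materialize logits or not" decision explicit instead of inferring it only from `self.training`. The following is a standalone sketch of that decision logic, distilled from the hunks above; the helper name `resolve_skip_logits` is illustrative and not part of the package.

```python
from typing import Optional


def resolve_skip_logits(
    training: bool,
    labels_present: bool,
    shift_labels_present: bool,
    skip_logits: Optional[bool] = None,
) -> bool:
    # Explicitly requesting the fused (logit-free) loss without any labels is an error.
    if skip_logits and not (labels_present or shift_labels_present):
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: skip materializing logits only when training with labels.
        skip_logits = training and (labels_present or shift_labels_present)
    return skip_logits


# Default behavior: fused loss during training with labels, full logits at inference.
assert resolve_skip_logits(training=True, labels_present=True, shift_labels_present=False) is True
assert resolve_skip_logits(training=False, labels_present=True, shift_labels_present=False) is False
```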

liger_kernel/transformers/model/llama4.py (new file, +108 -0)

@@ -0,0 +1,108 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Llama4ForCausalLM
+
+    >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=True,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+
+    if self.training and (labels is not None or shift_labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
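
The new Llama4 forward only keeps the hidden states it needs before calling the loss or the lm_head. A quick standalone illustration of the `logits_to_keep` slicing used above; the tensor sizes are hypothetical and not taken from the package.

```python
import torch

hidden_states = torch.randn(2, 6, 4)  # (batch, seq_len, hidden)

# int: keep only the last N positions (typical for generation); 0 keeps everything.
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 4])

# 1D tensor: keep explicit sequence indices, e.g. for packed tensor formats.
logits_to_keep = torch.tensor([0, 2, 5])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 3, 4])
```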

liger_kernel/transformers/model/llava.py (+95 -132)

@@ -5,19 +5,14 @@ from typing import Union
 
 import torch
 
-from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
-from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
+from torch.nn import CrossEntropyLoss
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
-from transformers.utils import replace_return_docstrings
-from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -33,6 +28,11 @@ def lce_forward_deprecated(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    image_sizes: torch.Tensor = None,
+    skip_logits: Optional[bool] = None,
+    **lm_kwargs,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
     Args:
@@ -41,10 +41,12 @@ def lce_forward_deprecated(
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
 
    Returns:
@@ -70,7 +72,6 @@ def lce_forward_deprecated(
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""
-
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -94,73 +95,24 @@ def lce_forward_deprecated(
    )
 
    if inputs_embeds is None:
-        # 1. Extra the input embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
 
-        # 2. Merge text and images
-        if pixel_values is not None and input_ids.shape[1] != 1:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-            # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
-            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-
-            image_features = self.multi_modal_projector(selected_image_feature)
-            inputs_embeds = inputs_embeds.to(image_features.dtype)
-            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids, attention_mask, labels
-            )
-
-        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-        # generation with cache
-        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
-            # Retrieve the first layer to inspect the logits and mask out the hidden states
-            # that are set to 0
-            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-            # Get the target length
-            target_length = input_ids.shape[1]
-            past_length = first_layer_past_key_value.shape[-1]
-
-            extended_attention_mask = torch.ones(
-                (attention_mask.shape[0], past_length),
-                dtype=attention_mask.dtype,
-                device=attention_mask.device,
-            )
-
-            # Filter out only the tokens that can be un-attended, this can happen
-            # if one uses Llava + Fused modules where the cache on the
-            # first iteration is already big enough, or if one passes custom cache
-            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-            new_batch_index = batch_index[valid_indices]
-            new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-            # Zero-out the places where we don't need to attend
-            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-            n_image_features = image_features.shape[0] * image_features.shape[1]
+    if pixel_values is not None:
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
 
-            if n_image_tokens != n_image_features:
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+            n_image_tokens = (input_ids == self.config.image_token_index).sum()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
-        special_image_mask = (
-            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-        )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
@@ -173,13 +125,19 @@ def lce_forward_deprecated(
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
+        cache_position=cache_position,
+        logits_to_keep=logits_to_keep,
+        **lm_kwargs,
    )
    hidden_states = outputs[0]
 
    loss = None
    logits = None
 
-    if self.training and (labels is not None):
+    # Overwrite skip_logits, since llava never materializes logits
+    skip_logits = labels is not None
+
+    if skip_logits:
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
@@ -194,7 +152,33 @@ def lce_forward_deprecated(
            shift_labels = labels[..., 1:].contiguous()
 
        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = lce(
+            self.language_model.lm_head.weight,
+            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
+            shift_labels.view(-1).to(shift_hidden_states.device),
+        )
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
 
    if not return_dict:
        # NOTE: This part has not been tested.
@@ -207,12 +191,10 @@ def lce_forward_deprecated(
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        image_hidden_states=image_features if pixel_values is not None else None,
    )
 
 
-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -231,6 +213,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     image_sizes: torch.Tensor = None,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
@@ -285,78 +268,58 @@ def lce_forward(
        else self.config.vision_feature_select_strategy
    )
 
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-    if pixel_values is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-        )
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(input_ids)
-
-    if pixel_values is not None:
-        image_features = self.get_image_features(
-            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-            image_sizes=image_sizes,
-        )
-
-        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            n_image_tokens = (input_ids == self.config.image_token_index).sum()
-            n_image_features = image_features.shape[0] * image_features.shape[1]
-            raise ValueError(
-                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    outputs = self.language_model.model(
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
+        vision_feature_layer=vision_feature_layer,
+        vision_feature_select_strategy=vision_feature_select_strategy,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        return_dict=True,
        cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
+        image_sizes=image_sizes,
        **lm_kwargs,
    )
    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    loss = None
+    shift_labels = lm_kwargs.pop("shift_labels", None)
    logits = None
+    loss = None
 
-    if self.training and (labels is not None):
-        # Shift so that tokens < n predict n
-        if attention_mask is not None:
-            # we use the input attention mask to shift the logits and labels, because it is 2D.
-            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
-            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
-            shift_hidden_states = hidden_states[..., :-1, :][
-                shift_attention_mask.to(hidden_states.device) != 0
-            ].contiguous()
-            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
-        else:
-            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(
-            self.language_model.lm_head.weight,
-            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
-            shift_labels.view(-1).to(shift_hidden_states.device),
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **lm_kwargs,
        )
 
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
+            )
+
    if not return_dict:
-        # NOTE: This part has not been tested.
-        output = outputs[1:]
+        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output
 
    return LlavaCausalLMOutputWithPast(
@@ -365,5 +328,5 @@ def lce_forward(
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
+        image_hidden_states=outputs.image_hidden_states,
    )
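
For reference, the deprecated LLaVA path above calls the fused linear cross entropy loss with the lm_head weight, flattened hidden states, and flattened labels, so the (tokens x vocab) logits matrix is never materialized. Below is a shape sketch with toy, hypothetical sizes (the kernel itself needs a CUDA device with Triton available); the call signature mirrors the one used in the diff.

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

batch, seq_len, hidden, vocab = 2, 8, 64, 128  # toy sizes for illustration
lm_head_weight = torch.randn(vocab, hidden, device="cuda", requires_grad=True)
shift_hidden_states = torch.randn(batch, seq_len - 1, hidden, device="cuda")
shift_labels = torch.randint(0, vocab, (batch, seq_len - 1), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(
    lm_head_weight,                                               # (vocab, hidden)
    shift_hidden_states.view(-1, shift_hidden_states.size(-1)),   # (batch * (seq_len - 1), hidden)
    shift_labels.view(-1),                                        # (batch * (seq_len - 1),)
)
print(loss)  # scalar cross-entropy loss, computed without a full logits tensor
```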