liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
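Two of the new files above (items 7 and 14) add a Dynamic Tanh (DyT) layer, an elementwise replacement for a normalization layer. As a rough reference for the math the Triton kernel presumably implements, here is a minimal non-fused sketch; the class name, `init_alpha` default, and parameter layout are illustrative assumptions, not the actual `LigerDyT` API.

```python
import torch
import torch.nn as nn


class NaiveDyT(nn.Module):
    """Reference (non-fused) Dynamic Tanh: y = gamma * tanh(alpha * x) + beta."""

    def __init__(self, hidden_size: int, init_alpha: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), init_alpha))  # learnable scalar slope
        self.gamma = nn.Parameter(torch.ones(hidden_size))       # per-channel scale
        self.beta = nn.Parameter(torch.zeros(hidden_size))       # per-channel shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squash activations with a learnable-slope tanh instead of normalizing them.
        return self.gamma * torch.tanh(self.alpha * x) + self.beta


# Usage: drop-in wherever a LayerNorm/RMSNorm over the last dimension would go.
x = torch.randn(2, 16, 64)
print(NaiveDyT(64)(x).shape)  # torch.Size([2, 16, 64])
```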
liger_kernel/transformers/model/mllama.py

@@ -11,8 +11,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@@ -132,6 +134,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
@@ -150,7 +153,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -160,10 +163,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -215,24 +220,17 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
     if labels is not None:
         loss = self.loss_function(
             logits=logits,
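The inline FLCE block deleted above is exactly what the new `liger_kernel/transformers/model/loss_utils.py` helper centralizes. A minimal sketch of what `LigerForCausalLMLoss` presumably does, reconstructed from the deleted lines and mirroring the keyword arguments used at the call site (the real helper may accept additional kwargs not shown here):

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def naive_for_causal_lm_loss(hidden_states, lm_head_weight, labels, hidden_size, **loss_kwargs):
    """Illustrative reconstruction of the logic the loss_utils helper replaces."""
    # Shift so that tokens < n predict token n, as transformers' ForCausalLMLoss does.
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten tokens for the fused linear + cross-entropy kernel.
    shift_hidden_states = shift_hidden_states.view(-1, hidden_size)
    shift_labels = shift_labels.view(-1)

    # When the trainer supplies num_items_in_batch, sum and normalize manually
    # (gradient-accumulation safe); otherwise fall back to mean reduction.
    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss /= loss_kwargs["num_items_in_batch"]
    return loss
```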
liger_kernel/transformers/model/olmo2.py

@@ -10,10 +10,12 @@ from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
 from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -29,7 +31,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -39,10 +41,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -89,24 +93,17 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
    if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
     if labels is not None:
         loss = self.loss_function(
             logits=logits,
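The `slice_indices` line above is the whole behavioral change on the inference path: an `int` keeps the last N positions (0 keeps everything, since `slice(-0, None)` is the full range), while a 1D index tensor keeps arbitrary positions, which is what the packed-sequence case needs. A standalone illustration of that slicing with dummy shapes (the sizes are made up for the example):

```python
import torch

hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden), dummy values

# int: keep only the last `logits_to_keep` positions
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

# 1D tensor: keep explicit positions along the sequence dimension,
# e.g. the last token of each packed sub-sequence
logits_to_keep = torch.tensor([3, 7])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 2, 16])
```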
liger_kernel/transformers/model/paligemma.py (new file)

@@ -0,0 +1,397 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
+from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
+from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import is_torchdynamo_compiling
+from transformers.utils import logging
+from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+logger = logging.get_logger(__name__)
+
+
+@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    token_type_ids: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+    >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+    >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+    >>> prompt = "answer en Where is the cow standing?"
+    >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "answer en Where is the cow standing?\nbeach"
+    ```"""
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # the attention mask is turned 4d after, we keep track of the original one
+    input_attention_mask = attention_mask
+
+    if inputs_embeds is None:
+        # 1. Extra the input embeddings
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
+            selected_image_feature = image_outputs.last_hidden_state
+            image_features = self.multi_modal_projector(selected_image_feature)
+
+            if cache_position is None:
+                cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+            )
+
+        else:
+            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+            # generation with cache
+            if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                # TODO @molbap this will only work for dynamic cache.
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_seqlen = cache_position[-1] + 1
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses PaliGemma+ Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+    attention_mask = attention_mask.to(inputs_embeds.dtype)
+    outputs = self.language_model.model(
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :]
+        shift_labels = labels[..., 1:]
+
+        hidden_device = shift_hidden_states.device
+
+        if attention_mask is not None:
+            # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+            shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+            shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+            shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+        else:
+            shift_hidden_states = shift_hidden_states.contiguous()
+            shift_labels = shift_labels.contiguous()
+
+        # Flatten hidden state
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+        shift_labels = shift_labels.view(-1).to(hidden_device)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if input_attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                shift_attention_mask = input_attention_mask[..., 1:]
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return PaliGemmaCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    token_type_ids: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **lm_kwargs,
+) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+    >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+    >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+    >>> prompt = "answer en Where is the cow standing?"
+    >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "answer en Where is the cow standing?\nbeach"
+    ```"""
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if pixel_values is not None and inputs_embeds is not None:
+        raise ValueError(
+            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+        )
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    is_training = token_type_ids is not None and labels is not None
+
+    if inputs_embeds is None:
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed
+
+    # Merge text and images
+    if pixel_values is not None:
+        image_features = self.get_image_features(pixel_values)
+
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+            image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
+            raise ValueError(
+                f"Number of images does not match number of special image tokens in the input text. "
+                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+                "tokens from image embeddings."
+            )
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+    # mask out pad-token-ids in labels for BC
+    if labels is not None and self.pad_token_id in labels:
+        logger.warning_once(
+            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+        )
+        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+    )
+
+    outputs = self.language_model.model(
+        attention_mask=causal_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        logits_to_keep=logits_to_keep,
+        **lm_kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :]
+        shift_labels = labels[..., 1:]
+
+        hidden_device = shift_hidden_states.device
+
+        if attention_mask is not None:
+            # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+            shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+            shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+            shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+        else:
+            shift_hidden_states = shift_hidden_states.contiguous()
+            shift_labels = shift_labels.contiguous()
+
+        # Flatten hidden state
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+        shift_labels = shift_labels.view(-1).to(hidden_device)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return PaliGemmaCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        image_hidden_states=image_features if pixel_values is not None else None,
+    )
liger_kernel/transformers/model/phi3.py

@@ -11,8 +11,10 @@ from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
 from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@@ -125,6 +127,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -140,7 +143,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -150,10 +153,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -213,24 +218,17 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
     if labels is not None:
         loss = self.loss_function(
             logits=logits,
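These `lce_forward` implementations only take effect once `liger_kernel/transformers/monkey_patch.py` (item 32 above) swaps them into the Hugging Face model classes. As a hedged usage sketch, the long-standing entry point looks like the following; the exact set of `apply_liger_kernel_to_*` functions added in 0.5.7 (for example for the new Gemma 3, LLaVA, or PaliGemma paths) is not shown here, and the kwargs below are the commonly documented ones rather than an exhaustive list.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Llama modeling code in place so that forward() uses the fused
# linear + cross-entropy path and logits are not materialized during training.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # example checkpoint
```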