liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +103 -61
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +30 -11
  10. liger_kernel/ops/kl_div.py +2 -2
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/dyt.py +20 -0
  13. liger_kernel/transformers/functional.py +5 -0
  14. liger_kernel/transformers/model/gemma.py +8 -16
  15. liger_kernel/transformers/model/gemma2.py +7 -16
  16. liger_kernel/transformers/model/llama.py +8 -15
  17. liger_kernel/transformers/model/llava.py +369 -0
  18. liger_kernel/transformers/model/loss_utils.py +57 -0
  19. liger_kernel/transformers/model/mistral.py +9 -10
  20. liger_kernel/transformers/model/mixtral.py +8 -15
  21. liger_kernel/transformers/model/mllama.py +8 -15
  22. liger_kernel/transformers/model/olmo2.py +8 -16
  23. liger_kernel/transformers/model/paligemma.py +397 -0
  24. liger_kernel/transformers/model/phi3.py +8 -15
  25. liger_kernel/transformers/model/qwen2.py +8 -15
  26. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  27. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  28. liger_kernel/transformers/monkey_patch.py +219 -13
  29. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
  30. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
  31. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  32. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
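
The notable additions in 0.5.6 are a DyT (dynamic tanh) kernel (`liger_kernel/ops/dyt.py` plus its `transformers` wrapper), a chunked PPO loss base (`fused_linear_ppo.py`, added alongside the removal of `fused_linear_rlhf.py`), fused-linear-cross-entropy forwards for LLaVA and PaliGemma, and a shared `LigerForCausalLMLoss` helper in `model/loss_utils.py` that the decoder model forwards now call. A minimal sketch of how the PaliGemma patch would typically be enabled follows; the entry-point name `apply_liger_kernel_to_paligemma` and its keyword argument follow liger-kernel's existing `apply_liger_kernel_to_*` convention and are assumptions, not verified against the 0.5.6 API.

```python
# Hedged sketch, not verified against the 0.5.6 API: the function name and kwarg
# below follow liger-kernel's apply_liger_kernel_to_* convention and may differ.
import torch
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_paligemma  # assumed export

# Patch the HF PaliGemma modeling code in place so that training uses the
# fused linear cross-entropy forward added in transformers/model/paligemma.py.
apply_liger_kernel_to_paligemma(fused_linear_cross_entropy=True)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-pt-224", torch_dtype=torch.bfloat16
)
model.train()  # the Liger loss path is only taken when self.training and labels are given
```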
liger_kernel/transformers/model/paligemma.py (new file)
@@ -0,0 +1,397 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import Cache
+ from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
+ from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import is_torchdynamo_compiling
+ from transformers.utils import logging
+ from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+ logger = logging.get_logger(__name__)
+
+
+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward_deprecated(
+     self,
+     input_ids: torch.LongTensor = None,
+     pixel_values: torch.FloatTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+     token_type_ids: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+     >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+     >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+     >>> prompt = "answer en Where is the cow standing?"
+     >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+
+     >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(**inputs, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "answer en Where is the cow standing?\nbeach"
+     ```"""
+
+     if (input_ids is None) ^ (inputs_embeds is not None):
+         raise ValueError(
+             "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+         )
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # the attention mask is turned 4d after, we keep track of the original one
+     input_attention_mask = attention_mask
+
+     if inputs_embeds is None:
+         # 1. Extra the input embeddings
+         inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         # 2. Merge text and images
+         if pixel_values is not None and input_ids.shape[1] != 1:
+             image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
+             selected_image_feature = image_outputs.last_hidden_state
+             image_features = self.multi_modal_projector(selected_image_feature)
+
+             if cache_position is None:
+                 cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                 image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+             )
+
+         else:
+             # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+             # generation with cache
+             if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+                 # Retrieve the first layer to inspect the logits and mask out the hidden states
+                 # that are set to 0
+                 # TODO @molbap this will only work for dynamic cache.
+                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                 # Get the target length
+                 target_seqlen = cache_position[-1] + 1
+                 extended_attention_mask = torch.ones(
+                     (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
+                     dtype=attention_mask.dtype,
+                     device=attention_mask.device,
+                 )
+                 # Filter out only the tokens that can be un-attended, this can happen
+                 # if one uses PaliGemma+ Fused modules where the cache on the
+                 # first iteration is already big enough, or if one passes custom cache
+                 valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                 new_batch_index = batch_index[valid_indices]
+                 new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                 # Zero-out the places where we don't need to attend
+                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                 attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+     attention_mask = attention_mask.to(inputs_embeds.dtype)
+     outputs = self.language_model.model(
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :]
+         shift_labels = labels[..., 1:]
+
+         hidden_device = shift_hidden_states.device
+
+         if attention_mask is not None:
+             # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+             # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+             shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+             shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+             shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+         else:
+             shift_hidden_states = shift_hidden_states.contiguous()
+             shift_labels = shift_labels.contiguous()
+
+         # Flatten hidden state
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+         shift_labels = shift_labels.view(-1).to(hidden_device)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         logits = self.language_model.lm_head(hidden_states)
+         if labels is not None:
+             shift_logits = logits[..., :-1, :]
+             shift_labels = labels[..., 1:]
+             if input_attention_mask is not None:
+                 # we use the input attention mask to shift the logits and labels, because it is 2D.
+                 shift_attention_mask = input_attention_mask[..., 1:]
+                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                 shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+             else:
+                 shift_logits = shift_logits.contiguous()
+                 shift_labels = shift_labels.contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+
+             flat_logits = shift_logits.view(-1, self.config.vocab_size)
+             flat_labels = shift_labels.view(-1).to(shift_logits.device)
+             loss = loss_fct(flat_logits, flat_labels)
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return PaliGemmaCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     pixel_values: torch.FloatTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+     token_type_ids: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     **lm_kwargs,
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+         logits_to_keep (`int` or `torch.Tensor`, *optional*):
+             If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+             If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+             This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+     >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+     >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+     >>> prompt = "answer en Where is the cow standing?"
+     >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+
+     >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(**inputs, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "answer en Where is the cow standing?\nbeach"
+     ```"""
+
+     if (input_ids is None) ^ (inputs_embeds is not None):
+         raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+     if pixel_values is not None and inputs_embeds is not None:
+         raise ValueError(
+             "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+         )
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     is_training = token_type_ids is not None and labels is not None
+
+     if inputs_embeds is None:
+         inputs_embeds = self.get_input_embeddings()(input_ids)
+
+     if cache_position is None:
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         cache_position = torch.arange(
+             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+         )
+
+     if position_ids is None:
+         position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
+
+     # Merge text and images
+     if pixel_values is not None:
+         image_features = self.get_image_features(pixel_values)
+
+         special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+         special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+         if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+             image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
+             raise ValueError(
+                 f"Number of images does not match number of special image tokens in the input text. "
+                 f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+                 "tokens from image embeddings."
+             )
+         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+     # mask out pad-token-ids in labels for BC
+     if labels is not None and self.pad_token_id in labels:
+         logger.warning_once(
+             "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+             "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+         )
+         labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+     causal_mask = self._update_causal_mask(
+         attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+     )
+
+     outputs = self.language_model.model(
+         attention_mask=causal_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         logits_to_keep=logits_to_keep,
+         **lm_kwargs,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :]
+         shift_labels = labels[..., 1:]
+
+         hidden_device = shift_hidden_states.device
+
+         if attention_mask is not None:
+             # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+             # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+             shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+             shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+             shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+         else:
+             shift_hidden_states = shift_hidden_states.contiguous()
+             shift_labels = shift_labels.contiguous()
+
+         # Flatten hidden state
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+         shift_labels = shift_labels.view(-1).to(hidden_device)
+
+         lce = LigerFusedLinearCrossEntropyLoss()
+         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+     else:
+         logits = self.language_model.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             shift_logits = logits[..., :-1, :]
+             shift_labels = labels[..., 1:]
+             if attention_mask is not None:
+                 # we use the input attention mask to shift the logits and labels, because it is 2D.
+                 # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                 shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                 shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+             else:
+                 shift_logits = shift_logits.contiguous()
+                 shift_labels = shift_labels.contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+
+             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+             flat_labels = shift_labels.view(-1).to(shift_logits.device)
+             loss = loss_fct(flat_logits, flat_labels)
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return PaliGemmaCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         image_hidden_states=image_features if pixel_values is not None else None,
+     )
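
In the training branch of both forwards above, the full `[batch, seq, vocab]` logits tensor is never materialized: the shifted hidden states and the `lm_head` weight are handed directly to `LigerFusedLinearCrossEntropyLoss`. A self-contained sketch of that call pattern follows; the shapes are illustrative (the real ones come from `config.text_config`), and a CUDA device is assumed since the loss is backed by Triton kernels.

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes only; PaliGemma's actual hidden size and vocab come from the config.
B, T, H, V = 2, 128, 1024, 32_000
hidden_states = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, V, (B, T), device="cuda")

# Shift so position t predicts token t + 1, then flatten, mirroring the forwards above.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, H)
shift_labels = labels[..., 1:].contiguous().view(-1)

lce = LigerFusedLinearCrossEntropyLoss()  # called as lce(lin_weight, input, target), as in the diff
loss = lce(lm_head_weight, shift_hidden, shift_labels)
loss.backward()  # gradients flow to both the hidden states and the lm_head weight
```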
liger_kernel/transformers/model/phi3.py
@@ -13,6 +13,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@@ -213,21 +214,13 @@ def lce_forward(
      loss = None
      # if in training mode, don't materialize logits
      if self.training and (labels is not None):
-         # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-         lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-         if reduction == "sum":
-             loss /= loss_kwargs["num_items_in_batch"]
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )

      else: # if in inference mode materialize logits
          logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
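
The inline block removed above is what `LigerForCausalLMLoss` in the new `liger_kernel/transformers/model/loss_utils.py` now centralizes for the decoder model forwards touched in this diff. A hedged reconstruction of what that helper presumably does, inferred from the deleted lines rather than read from the new file, is sketched below; the real implementation may differ in details such as argument names and ignore-index handling.

```python
from typing import Optional

import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def liger_for_causal_lm_loss_sketch(
    hidden_states: torch.Tensor,   # [batch, seq, hidden]
    lm_head_weight: torch.Tensor,  # [vocab, hidden]
    labels: torch.Tensor,          # [batch, seq]
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    **kwargs,
) -> torch.Tensor:
    # Shift so position t predicts token t + 1, then flatten tokens.
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Sum-then-normalize keeps the loss correct under gradient accumulation when the
    # caller supplies num_items_in_batch; otherwise fall back to a plain mean.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```

The new call sites forward `**loss_kwargs`, which is where `num_items_in_batch` arrives when a trainer supplies it; the deleted code read it from the same place.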
liger_kernel/transformers/model/qwen2.py
@@ -13,6 +13,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -199,21 +200,13 @@ def lce_forward(
      loss = None
      # if in training mode, don't materialize logits
      if self.training and (labels is not None):
-         # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-         lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-         if reduction == "sum":
-             loss /= loss_kwargs["num_items_in_batch"]
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )

      else: # if in inference mode materialize logits
          logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
liger_kernel/transformers/model/qwen2_5_vl.py
@@ -12,7 +12,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalL
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@@ -36,6 +36,7 @@ def lce_forward(
      rope_deltas: Optional[torch.LongTensor] = None,
      cache_position: Optional[torch.LongTensor] = None,
      second_per_grid_ts: Optional[torch.Tensor] = None,
+     **loss_kwargs,
  ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
      r"""
      Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -166,15 +167,13 @@ def lce_forward(
      logits = None

      if self.training and (labels is not None):
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # Flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         lce = LigerFusedLinearCrossEntropyLoss()
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )
      else:
          logits = self.lm_head(hidden_states)
          if labels is not None:
liger_kernel/transformers/model/qwen2_vl.py
@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings

- from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@@ -37,6 +37,7 @@ def lce_forward(
      video_grid_thw: Optional[torch.LongTensor] = None,
      rope_deltas: Optional[torch.LongTensor] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     **loss_kwargs,
  ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
      r"""
      Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -170,15 +171,13 @@ def lce_forward(
      logits = None

      if self.training and (labels is not None):
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # Flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         lce = LigerFusedLinearCrossEntropyLoss()
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )
      else:
          logits = self.lm_head(hidden_states)
          if labels is not None: