liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
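Note: the patched model forwards shown in the hunks below all route the training loss through `LigerFusedLinearCrossEntropyLoss` from `liger_kernel/transformers/fused_linear_cross_entropy.py` (file 16). A minimal sketch of that call pattern, with made-up shapes and assuming a CUDA device since Liger's Triton kernels run on GPU:

```python
# Sketch only: the (weight, flattened hidden states, flattened labels) call
# pattern used by the patched forwards in this diff. Sizes are illustrative.
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size, num_tokens = 64, 1024, 8
weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)  # plays the role of lm_head.weight
hidden = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)  # flattened shifted hidden states
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")               # flattened shifted labels

lce = LigerFusedLinearCrossEntropyLoss()     # constructed with defaults, as in lce_forward_deprecated below
loss = lce(weight, hidden, labels)           # fused projection + cross entropy, no [N, vocab] logits tensor
loss.backward()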
liger_kernel/transformers/model/mllama.py (new file)
@@ -0,0 +1,274 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
+)
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    cross_attention_states: Optional[torch.LongTensor] = None,
+    cross_attention_mask: Optional[torch.LongTensor] = None,
+    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Copy paste mllama forward but replace torch cross entropy with liger fused linear cross entropy
+
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+    Returns:
+    Example:
+    ```python
+    >>> from transformers import AutoTokenizer, MllamaForCausalLM
+    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
+    >>> prompt = "If I had to write a haiku, it would be:"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
+    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    >>> print(result)
+    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
+    I love the idea of snowflakes gently falling, each one
+    ```
+    """
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        cross_attention_states=cross_attention_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        cross_attention_mask=cross_attention_mask,
+        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        kept_hidden_states = hidden_states[:, -num_logits_to_keep:, :]
+
+        shift_hidden_states = kept_hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    cross_attention_states: Optional[torch.LongTensor] = None,
+    cross_attention_mask: Optional[torch.LongTensor] = None,
+    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MllamaForCausalLM
+
+    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
+
+    >>> prompt = "If I had to write a haiku, it would be:"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
+    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    >>> print(result)
+    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
+    I love the idea of snowflakes gently falling, each one
+    ```
+    """
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        cross_attention_states=cross_attention_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        cross_attention_mask=cross_attention_mask,
+        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
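The new `lce_forward` above switches to `reduction="sum"` whenever the trainer passes `num_items_in_batch`, then divides by that count. A plain-PyTorch sketch (not the Liger kernel) of why sum-then-divide is preferred over a per-chunk mean when one batch is split across gradient-accumulation steps:

```python
# Summing per-token losses and dividing by the global token count gives a value
# that does not depend on how the tokens were split into micro-batches,
# unlike averaging within each chunk. Sizes below are made up.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)                 # 6 tokens, vocabulary of 10
labels = torch.randint(0, 10, (6,))
num_items_in_batch = 6                      # total target tokens across all chunks

chunks = [(logits[:2], labels[:2]), (logits[2:], labels[2:])]   # two accumulation steps
summed = sum(F.cross_entropy(lg, lb, reduction="sum") for lg, lb in chunks)
print(summed / num_items_in_batch)                   # matches the single-pass mean below
print(F.cross_entropy(logits, labels, reduction="mean"))
```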
liger_kernel/transformers/model/phi3.py
@@ -21,7 +21,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -108,10 +108,11 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
@@ -134,3 +135,140 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+    >>> prompt = "This is an example script ."
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    ```"""
+
+    from transformers.models.phi3.modeling_phi3 import logging
+
+    logger = logging.get_logger(__name__)
+
+    if (
+        use_cache
+        and self.config.rope_scaling
+        and cache_position is not None
+        and cache_position[0] == self.config.original_max_position_embeddings
+    ):
+        logger.warning(
+            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
+        )
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
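Both the deprecated and new forwards slice the hidden states with `num_logits_to_keep` on the inference path. A tiny illustration of the documented special case, where `0` keeps every position because `-0:` is the full slice:

```python
# -0 is just 0, so [:, -0:, :] keeps every position; any positive value
# keeps only the trailing positions needed for generation.
import torch

hidden_states = torch.randn(1, 5, 4)       # (batch, seq_len, hidden), made-up sizes
print(hidden_states[:, -0:, :].shape)      # torch.Size([1, 5, 4])  -> all positions
print(hidden_states[:, -1:, :].shape)      # torch.Size([1, 1, 4])  -> last token only
```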
liger_kernel/transformers/model/qwen2.py
@@ -21,7 +21,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -109,8 +109,9 @@ def lce_forward(

     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
@@ -133,3 +134,123 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+    >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
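These `lce_forward` variants only take effect once they are patched onto the Hugging Face model classes by `liger_kernel/transformers/monkey_patch.py` (file 29). A hedged sketch of how that is typically done, assuming `apply_liger_kernel_to_qwen2` is the exported entry point and using an illustrative checkpoint name:

```python
# Hedged sketch: the entry-point name is assumed from monkey_patch.py and the
# package docs; the checkpoint is only an example.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen2

apply_liger_kernel_to_qwen2()   # patch before the model class is instantiated
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
```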
liger_kernel/transformers/model/qwen2_vl.py
@@ -80,6 +80,7 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+    # FIXME: The code is outdated and not compatible with transformer >= 4.46.1

     output_attentions = (
         output_attentions
@@ -115,6 +116,11 @@ def lce_forward(
             inputs_embeds[video_mask] = video_embeds
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
+    # The code is copied from https://github.com/huggingface/transformers/pull/33487
+    if position_ids is None and input_ids is not None:
+        position_ids, _ = self.get_rope_index(
+            input_ids, image_grid_thw, video_grid_thw, attention_mask
+        )

     outputs = self.model(
         input_ids=None,
@@ -145,8 +151,9 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
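Across all of these patched forwards, the training branch never materializes the `[batch, seq_len, vocab]` logits tensor; only the fused kernel touches the vocabulary dimension. Back-of-envelope arithmetic with assumed sizes shows roughly what that avoids:

```python
# Assumed sizes for illustration only (not taken from the diff):
batch, seq_len, vocab = 4, 8192, 128256     # e.g. a Llama-3-sized vocabulary
bytes_per_elem = 4                          # fp32, as after the .float() upcast
logits_gib = batch * seq_len * vocab * bytes_per_elem / 2**30
print(f"full logits tensor would be ~{logits_gib:.1f} GiB")   # ~15.7 GiB
```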