liger-kernel 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +12 -9
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +2 -2
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/functional.py +4 -0
  13. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  14. liger_kernel/transformers/jsd.py +75 -0
  15. liger_kernel/transformers/model/gemma.py +124 -1
  16. liger_kernel/transformers/model/llama.py +135 -4
  17. liger_kernel/transformers/model/mistral.py +3 -0
  18. liger_kernel/transformers/model/mixtral.py +153 -2
  19. liger_kernel/transformers/model/mllama.py +274 -0
  20. liger_kernel/transformers/model/phi3.py +140 -2
  21. liger_kernel/transformers/model/qwen2.py +123 -2
  22. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  23. liger_kernel/transformers/monkey_patch.py +158 -7
  24. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +60 -28
  25. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  26. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  27. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  28. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  29. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  30. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
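Besides the model patches shown below, 0.4.0 adds new JSD and fused-linear-JSD loss modules and an experimental int8/int2 matmul kernel. A minimal sketch of checking the upgrade, assuming only the module paths listed above (no class or function names inside them are assumed, hence the dynamic listing):

```python
# Sketch: confirm the modules added in 0.4.0 (per the file list above) are importable
# after upgrading, and list whatever they export. Only the module paths come from this
# diff; nothing about their contents is assumed here.
import importlib
from importlib.metadata import version

print("installed liger_kernel:", version("liger_kernel"))

for module_name in (
    "liger_kernel.transformers.jsd",
    "liger_kernel.transformers.fused_linear_jsd",
    "liger_kernel.ops.experimental.mm_int8int2",
):
    module = importlib.import_module(module_name)
    print(module_name, "->", [name for name in dir(module) if not name.startswith("_")])
```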
liger_kernel/transformers/model/mixtral.py
@@ -22,7 +22,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -103,7 +103,6 @@ def lce_forward(

     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    logits = logits.float()

     loss = None
     if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     elif labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -156,3 +157,153 @@ def lce_forward(
         attentions=outputs.attentions,
         router_logits=outputs.router_logits,
     )
+
+
+@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+# Ignore copy
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            )  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
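The pattern in the new `lce_forward` above (repeated for Mllama and Phi3 below) is that during training the `lm_head` projection is never run: the shifted hidden states and `lm_head.weight` go straight into `LigerFusedLinearCrossEntropyLoss`, so the full `(batch, seq, vocab)` logits tensor is never materialized. Below is a standalone sketch of that training path with toy tensors in place of a real model; the three-argument call and the `reduction`/`num_items_in_batch` handling mirror the diff, while the shapes, dtype, and data are illustrative assumptions (a CUDA device is needed for the Triton kernel).

```python
# Sketch of the training path from lce_forward, with toy tensors instead of a model.
# The call lce(weight, hidden, labels) matches the usage shown in the diff;
# the dimensions and random data below are illustrative only.
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

batch, seq, hidden_size, vocab_size = 2, 16, 64, 1000
hidden_states = torch.randn(batch, seq, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch, seq), device="cuda")
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)

# Shift so that tokens < n predict n, then flatten -- exactly as in the patched forward.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

# "sum" reduction plus manual normalization is what the forward does when the caller
# passes num_items_in_batch; here the count is just the number of unmasked targets.
num_items_in_batch = (shift_labels != -100).sum()
lce = LigerFusedLinearCrossEntropyLoss(reduction="sum")
loss = lce(lm_head_weight, shift_hidden, shift_labels) / num_items_in_batch
loss.backward()  # a (batch*seq, vocab_size) logits tensor is never materialized
```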
liger_kernel/transformers/model/mllama.py (new file)
@@ -0,0 +1,274 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
+)
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    cross_attention_states: Optional[torch.LongTensor] = None,
+    cross_attention_mask: Optional[torch.LongTensor] = None,
+    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Copy paste mllama forward but replace torch cross entropy with liger fused linear cross entropy
+
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+    Returns:
+    Example:
+    ```python
+    >>> from transformers import AutoTokenizer, MllamaForCausalLM
+    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
+    >>> prompt = "If I had to write a haiku, it would be:"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
+    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    >>> print(result)
+    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
+    I love the idea of snowflakes gently falling, each one
+    ```
+    """
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        cross_attention_states=cross_attention_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        cross_attention_mask=cross_attention_mask,
+        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        kept_hidden_states = hidden_states[:, -num_logits_to_keep:, :]
+
+        shift_hidden_states = kept_hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    cross_attention_states: Optional[torch.LongTensor] = None,
+    cross_attention_mask: Optional[torch.LongTensor] = None,
+    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MllamaForCausalLM
+
+    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
+
+    >>> prompt = "If I had to write a haiku, it would be:"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
+    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    >>> print(result)
+    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
+    I love the idea of snowflakes gently falling, each one
+    ```
+    """
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        cross_attention_states=cross_attention_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        cross_attention_mask=cross_attention_mask,
+        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
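Note the reduction handling shared by the new forwards: when the caller passes `num_items_in_batch` (as the upstream `ForCausalLMLoss` path does under gradient accumulation), the loss is computed with `reduction="sum"` and then divided by that global token count instead of taking a per-micro-batch mean. The toy arithmetic below (made-up numbers, no Liger dependency) shows why: averaging per-micro-batch means over-weights micro-batches that happen to contain fewer unmasked tokens, while sum-then-divide reproduces the single-large-batch mean.

```python
# Toy illustration of why the patched forwards prefer sum / num_items_in_batch
# over a per-micro-batch mean when gradients are accumulated. Numbers are made up.
import torch

# Per-token losses for two micro-batches with different numbers of unmasked tokens.
micro_a = torch.tensor([2.0, 4.0])            # 2 tokens
micro_b = torch.tensor([1.0, 1.0, 1.0, 1.0])  # 4 tokens
num_items_in_batch = micro_a.numel() + micro_b.numel()

# Averaging the per-micro-batch means over-weights the short micro-batch: 2.0
mean_of_means = (micro_a.mean() + micro_b.mean()) / 2
# Summing and dividing by the global token count matches one big batch: ~1.67
global_mean = (micro_a.sum() + micro_b.sum()) / num_items_in_batch

print(mean_of_means.item(), global_mean.item())
```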
liger_kernel/transformers/model/phi3.py
@@ -21,7 +21,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -108,10 +108,11 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
@@ -134,3 +135,140 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+    >>> prompt = "This is an example script ."
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    ```"""
+
+    from transformers.models.phi3.modeling_phi3 import logging
+
+    logger = logging.get_logger(__name__)
+
+    if (
+        use_cache
+        and self.config.rope_scaling
+        and cache_position is not None
+        and cache_position[0] == self.config.original_max_position_embeddings
+    ):
+        logger.warning(
+            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
+        )
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
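One detail worth noting in the inference branch shared by the patched forwards: `hidden_states[:, -num_logits_to_keep:, :]` relies on `-0` being `0` in Python, so the default `num_logits_to_keep=0` keeps every position, while any positive value keeps only the trailing tokens (the generation case). A quick standalone check:

```python
# Quick check of the slicing used in the inference branch of lce_forward:
# with num_logits_to_keep = 0, "-0:" is just "0:", i.e. all positions are kept.
import torch

hidden_states = torch.arange(2 * 5 * 3).reshape(2, 5, 3)  # (batch, seq, hidden)

for num_logits_to_keep in (0, 1, 2):
    kept = hidden_states[:, -num_logits_to_keep:, :]
    print(num_logits_to_keep, tuple(kept.shape))
# 0 -> (2, 5, 3)   all positions
# 1 -> (2, 1, 3)   last token only, the generation case
# 2 -> (2, 2, 3)
```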