liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
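The model hunks below (gemma2, llama, mixtral) all route the training loss through `LigerFusedLinearCrossEntropyLoss`, which takes the `lm_head` weight plus flattened hidden states and labels so the full logits tensor is never materialized. A minimal sketch of that call pattern, assuming a CUDA device with Triton available; the `softcap`/`reduction` keywords and the argument order are taken from the call sites in this diff, while the shapes, dtype, and softcap value are illustrative:

```python
# Hedged sketch (not from the package docs): driving the fused linear + cross-entropy
# loss the patched forwards below rely on. Assumes a CUDA device; sizes are arbitrary.
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size = 2048, 32000
lm_head_weight = torch.randn(
    vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
shift_hidden_states = torch.randn(16, hidden_size, device="cuda", dtype=torch.bfloat16)  # (batch*seq, hidden)
shift_labels = torch.randint(0, vocab_size, (16,), device="cuda")                        # (batch*seq,)

# softcap mirrors Gemma-2's final_logit_softcapping (30.0 in its config); reduction="mean" is the default path.
lce = LigerFusedLinearCrossEntropyLoss(softcap=30.0, reduction="mean")
loss = lce(lm_head_weight, shift_hidden_states, shift_labels)  # full logits are never materialized
loss.backward()
```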
liger_kernel/transformers/model/gemma2.py (new file)
@@ -0,0 +1,277 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import HybridCache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import (
+    _CONFIG_FOR_DOC,
+    GEMMA2_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten
+
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping
+        )
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        logits = self.lm_head(hidden_states)
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping,
+            reduction=reduction,
+        )
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else: # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
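The `reduction="sum"` / `num_items_in_batch` branch added above (and repeated in the llama and mixtral hunks below) is the gradient-accumulation fix: averaging per micro-batch and then averaging across accumulation steps over-weights micro-batches with fewer valid tokens, whereas summing and dividing by the global token count reproduces the single-batch mean. A toy illustration in plain PyTorch, with made-up per-token losses:

```python
# Toy illustration of the reduction="sum" / num_items_in_batch logic above
# (plain PyTorch, no Liger dependency; the per-token losses are invented numbers).
import torch

micro_a = torch.tensor([1.0, 1.0])            # micro-batch with 2 valid (non -100) tokens
micro_b = torch.tensor([3.0, 3.0, 3.0, 3.0])  # micro-batch with 4 valid tokens
num_items_in_batch = micro_a.numel() + micro_b.numel()  # 6

# Naive gradient accumulation: mean per micro-batch, then average the two steps.
mean_of_means = (micro_a.mean() + micro_b.mean()) / 2                  # 2.0 (over-weights micro_a)

# What the patched forwards do: sum per micro-batch, divide by the global token count.
sum_then_scale = (micro_a.sum() + micro_b.sum()) / num_items_in_batch  # ~2.333

# Reference: the mean over all tokens as if they were a single batch.
global_mean = torch.cat([micro_a, micro_b]).mean()                     # ~2.333
print(mean_of_means.item(), sum_then_scale.item(), global_mean.item())
```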
liger_kernel/transformers/model/llama.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,17 +17,20 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyLoss,
 )
 
+if TYPE_CHECKING:
+    from transformers.cache_utils import Cache
+
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(
     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
@@ -120,8 +123,9 @@ def lce_forward(
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
@@ -144,3 +148,130 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    if self.config.pretraining_tp > 1:
+        raise Exception("Liger Kernel does not support pretraining_tp!!")
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else: # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
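In the inference branch above, `num_logits_to_keep` trims the hidden states before the `lm_head` projection, so generation can request only the last position instead of a full `(batch, seq_len, vocab)` logits tensor. A small sketch of the slicing semantics (shapes are arbitrary); note that `-0 == 0`, so the default value keeps every position:

```python
# Slicing semantics behind hidden_states[:, -num_logits_to_keep:, :] in the
# inference branch above. Arbitrary shapes; no Liger dependency.
import torch

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden)

full = hidden_states[:, -0:, :]   # num_logits_to_keep = 0 -> all 5 positions
last = hidden_states[:, -1:, :]   # num_logits_to_keep = 1 -> only the final position
print(full.shape, last.shape)     # torch.Size([2, 5, 8]) torch.Size([2, 1, 8])
```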
liger_kernel/transformers/model/mistral.py
@@ -136,3 +136,6 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+# Note: Grad Acc is not fixed in mistral at transformer 4.46.1
liger_kernel/transformers/model/mixtral.py
@@ -22,7 +22,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 @replace_return_docstrings(
     output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
 )
-def lce_forward(
+def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -103,7 +103,6 @@ def lce_forward(
 
     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    logits = logits.float()
 
     loss = None
     if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     elif labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -156,3 +157,153 @@ def lce_forward(
         attentions=outputs.attentions,
         router_logits=outputs.router_logits,
     )
+
+
+@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+# Ignore copy
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else: # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            ) # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
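The tail of the Mixtral forward still folds the MoE auxiliary load-balancing loss into whichever LM loss was computed (Liger FLCE in training, `self.loss_function` otherwise). A toy sketch of that combination; the values and coefficient are made up, and in the real code `aux_loss` comes from `load_balancing_loss_func` and the coefficient from the model config:

```python
# How the tail of the Mixtral forward above combines the auxiliary loss with the
# LM loss. Invented numbers; stand-ins for load_balancing_loss_func output and
# self.router_aux_loss_coef.
import torch

lm_loss = torch.tensor(2.30)
aux_loss = torch.tensor(0.12)   # stand-in for load_balancing_loss_func(...)
router_aux_loss_coef = 0.02     # stand-in for self.router_aux_loss_coef

total_loss = lm_loss + router_aux_loss_coef * aux_loss.to(lm_loss.device)  # keep both on one device
print(total_loss)  # tensor(2.3024)
```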