liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  6. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  7. liger_kernel/ops/geglu.py +1 -1
  8. liger_kernel/ops/layer_norm.py +126 -89
  9. liger_kernel/ops/multi_token_attention.py +207 -0
  10. liger_kernel/ops/rms_norm.py +267 -56
  11. liger_kernel/ops/rope.py +1 -1
  12. liger_kernel/ops/softmax.py +201 -0
  13. liger_kernel/ops/sparsemax.py +62 -50
  14. liger_kernel/ops/swiglu.py +1 -1
  15. liger_kernel/transformers/__init__.py +8 -0
  16. liger_kernel/transformers/functional.py +67 -0
  17. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  18. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  19. liger_kernel/transformers/model/gemma.py +25 -8
  20. liger_kernel/transformers/model/gemma2.py +27 -8
  21. liger_kernel/transformers/model/gemma3.py +63 -99
  22. liger_kernel/transformers/model/glm4.py +16 -7
  23. liger_kernel/transformers/model/llama.py +25 -7
  24. liger_kernel/transformers/model/llama4.py +108 -0
  25. liger_kernel/transformers/model/llava.py +95 -124
  26. liger_kernel/transformers/model/mistral.py +13 -8
  27. liger_kernel/transformers/model/mixtral.py +16 -7
  28. liger_kernel/transformers/model/mllama.py +16 -7
  29. liger_kernel/transformers/model/olmo2.py +16 -7
  30. liger_kernel/transformers/model/paligemma.py +8 -1
  31. liger_kernel/transformers/model/phi3.py +25 -8
  32. liger_kernel/transformers/model/qwen2.py +24 -7
  33. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  34. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  35. liger_kernel/transformers/model/qwen3.py +11 -3
  36. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  37. liger_kernel/transformers/model/smollm3.py +189 -0
  38. liger_kernel/transformers/monkey_patch.py +389 -82
  39. liger_kernel/transformers/multi_token_attention.py +64 -0
  40. liger_kernel/transformers/rms_norm.py +40 -4
  41. liger_kernel/transformers/softmax.py +12 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
  44. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
  45. liger_kernel/transformers/gema3_rms.py +0 -8
  46. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
  47. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
  48. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
@@ -5,11 +5,12 @@ from typing import Union
 
  import torch
 
+ from torch.nn import CrossEntropyLoss
  from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
  from transformers.utils import is_torchdynamo_compiling
- from transformers.utils.deprecation import deprecate_kwarg
 
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
  def lce_forward_deprecated(
@@ -27,6 +28,11 @@ def lce_forward_deprecated(
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: torch.Tensor = None,
+ skip_logits: Optional[bool] = None,
+ **lm_kwargs,
  ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
  r"""
  Args:
@@ -35,10 +41,12 @@ def lce_forward_deprecated(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
 
 
  Returns:
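
The `logits_to_keep` semantics documented above pair with the slicing pattern that appears later in this diff; a minimal sketch of the two accepted forms (the shapes and index values below are made up for illustration):

    import torch

    hidden_states = torch.randn(2, 16, 64)  # (batch, seq_len, hidden)

    # int form: keep the last N positions (0 becomes slice(0, None), i.e. keep all)
    logits_to_keep = 1
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept = hidden_states[:, slice_indices, :]  # shape (2, 1, 64)

    # 1D tensor form: explicit sequence positions to keep, e.g. for packed sequences
    logits_to_keep = torch.tensor([3, 7, 15])
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept = hidden_states[:, slice_indices, :]  # shape (2, 3, 64)
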
@@ -64,7 +72,6 @@ def lce_forward_deprecated(
  >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
  ```"""
-
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -88,73 +95,24 @@ def lce_forward_deprecated(
  )
 
  if inputs_embeds is None:
- # 1. Extra the input embeddings
  inputs_embeds = self.get_input_embeddings()(input_ids)
 
- # 2. Merge text and images
- if pixel_values is not None and input_ids.shape[1] != 1:
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-
- if vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
- else:
- raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-
- image_features = self.multi_modal_projector(selected_image_feature)
- inputs_embeds = inputs_embeds.to(image_features.dtype)
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
- image_features, inputs_embeds, input_ids, attention_mask, labels
- )
-
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
- # generation with cache
- elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
- # Retrieve the first layer to inspect the logits and mask out the hidden states
- # that are set to 0
- first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
- # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
- batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
- # Get the target length
- target_length = input_ids.shape[1]
- past_length = first_layer_past_key_value.shape[-1]
-
- extended_attention_mask = torch.ones(
- (attention_mask.shape[0], past_length),
- dtype=attention_mask.dtype,
- device=attention_mask.device,
- )
-
- # Filter out only the tokens that can be un-attended, this can happen
- # if one uses Llava + Fused modules where the cache on the
- # first iteration is already big enough, or if one passes custom cache
- valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
- new_batch_index = batch_index[valid_indices]
- new_non_attended_tokens = non_attended_tokens[valid_indices]
-
- # Zero-out the places where we don't need to attend
- extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
- attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
- # TODO: @raushan retain only the new behavior after v4.47
- elif image_features is not None:
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0] * image_features.shape[1]
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
 
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
  raise ValueError(
  f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  )
- special_image_mask = (
- (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- )
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
@@ -167,13 +125,19 @@ def lce_forward_deprecated(
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
  )
  hidden_states = outputs[0]
 
  loss = None
  logits = None
 
- if self.training and (labels is not None):
+ # Overwrite skip_logits, since llava never materializes logits
+ skip_logits = labels is not None
+
+ if skip_logits:
  # Shift so that tokens < n predict n
  if attention_mask is not None:
  # we use the input attention mask to shift the logits and labels, because it is 2D.
@@ -188,7 +152,33 @@ def lce_forward_deprecated(
  shift_labels = labels[..., 1:].contiguous()
 
  lce = LigerFusedLinearCrossEntropyLoss()
- loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ loss = lce(
+ self.language_model.lm_head.weight,
+ shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
+ shift_labels.view(-1).to(shift_hidden_states.device),
+ )
+ else:
+ logits = self.language_model.lm_head(hidden_states)
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
 
  if not return_dict:
  # NOTE: This part has not been tested.
@@ -201,10 +191,10 @@ def lce_forward_deprecated(
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
  )
 
 
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -223,6 +213,7 @@ def lce_forward(
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
  image_sizes: torch.Tensor = None,
+ skip_logits: Optional[bool] = None,
  **lm_kwargs,
  ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
  r"""
@@ -277,78 +268,58 @@ def lce_forward(
  else self.config.vision_feature_select_strategy
  )
 
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model.model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
  attention_mask=attention_mask,
  position_ids=position_ids,
  past_key_values=past_key_values,
  inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
  use_cache=use_cache,
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
  cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
  **lm_kwargs,
  )
  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
 
- loss = None
+ shift_labels = lm_kwargs.pop("shift_labels", None)
  logits = None
+ loss = None
 
- if self.training and (labels is not None):
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
- shift_hidden_states = hidden_states[..., :-1, :][
- shift_attention_mask.to(hidden_states.device) != 0
- ].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_hidden_states = hidden_states[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- lce = LigerFusedLinearCrossEntropyLoss()
- loss = lce(
- self.language_model.lm_head.weight,
- shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
- shift_labels.view(-1).to(shift_hidden_states.device),
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.text_config.hidden_size,
+ **lm_kwargs,
  )
 
+ else:
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
+ )
+
  if not return_dict:
- # NOTE: This part has not been tested.
- output = outputs[1:]
+ output = (logits,) + outputs[1:]
  return (loss,) + output if loss is not None else output
 
  return LlavaCausalLMOutputWithPast(
@@ -357,5 +328,5 @@ def lce_forward(
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
  )
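
The llava.py hunks above and the per-model hunks below all introduce the same `skip_logits` gating around loss computation. A condensed sketch of that decision logic, using a hypothetical helper `resolve_skip_logits` (not part of the package, just an illustration of the branching copied from the diff):

    from typing import Optional

    def resolve_skip_logits(skip_logits: Optional[bool], training: bool, labels, shift_labels) -> bool:
        # Explicit skip_logits=True without any labels is an error in the patched forwards.
        if skip_logits and labels is None and shift_labels is None:
            raise ValueError("skip_logits is True, but labels and shift_labels are None")
        # Default: during training with labels present, skip materializing logits and
        # compute the loss via the fused linear cross entropy path instead.
        if skip_logits is None:
            skip_logits = training and (labels is not None or shift_labels is not None)
        return bool(skip_logits)

    # The resulting defaults:
    assert resolve_skip_logits(None, training=True, labels=[0], shift_labels=None) is True
    assert resolve_skip_logits(None, training=False, labels=[0], shift_labels=None) is False
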
@@ -27,7 +27,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -82,6 +83,7 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -89,18 +91,24 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  loss = None
  logits = None
 
- if self.training and (labels is not None or shift_labels is not None):
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  else:
@@ -112,7 +120,7 @@ def lce_forward(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
  if not return_dict:
  output = (logits,) + outputs[1:]
@@ -125,6 +133,3 @@ def lce_forward(
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
  )
-
-
- # Note: Grad Acc is not fixed in mistral at transformer 4.46.1
@@ -156,7 +156,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
  r"""
  Args:
@@ -214,6 +215,7 @@ def lce_forward(
  output_router_logits=output_router_logits,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -221,26 +223,33 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
 
  loss = None
  if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  aux_loss = None
  if output_router_logits:
  aux_loss = load_balancing_loss_func(
@@ -147,7 +147,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -205,6 +206,7 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -212,28 +214,35 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  if not return_dict:
@@ -26,7 +26,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -79,6 +80,7 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -86,28 +88,35 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  return CausalLMOutputWithPast(
@@ -216,6 +216,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
  **lm_kwargs,
  ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
  r"""
@@ -331,7 +332,13 @@ def lce_forward(
  loss = None
  logits = None
 
- if self.training and (labels is not None):
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels is None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None)
+
+ if skip_logits:
  shift_hidden_states = hidden_states[..., :-1, :]
  shift_labels = labels[..., 1:]
 
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
  loss = None
  logits = None
 
- if self.training and labels is not None:
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels is None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and labels is not None
+
+ if skip_logits:
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
 
@@ -136,7 +144,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -202,6 +211,7 @@ def lce_forward(
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -209,28 +219,35 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  if not return_dict: