liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/gemma.py
@@ -27,6 +27,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
 
@@ -81,7 +82,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -137,7 +145,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -189,6 +198,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -196,27 +206,34 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     if not return_dict:
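
The gemma.py hunks above replace the implicit "training mode with labels" check with an explicit, overridable skip_logits switch. The sketch below isolates that decision logic in a stand-alone helper so it can be read and tested outside the forward pass; resolve_skip_logits is an illustrative name, not part of liger-kernel's API.

# Minimal sketch of the skip_logits decision added in 0.6.0. The logic mirrors the diff above;
# the stand-alone helper itself is illustrative, not something liger-kernel exports.
from typing import Optional


def resolve_skip_logits(
    training: bool,
    labels=None,
    shift_labels=None,
    skip_logits: Optional[bool] = None,
) -> bool:
    # Explicitly asking to skip logits without any labels is an error:
    # there would be nothing left to compute a loss from.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    # Default: skip materializing logits only when training with labels available,
    # i.e. when the fused linear-cross-entropy path can produce the loss directly.
    if skip_logits is None:
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits


if __name__ == "__main__":
    print(resolve_skip_logits(training=True, labels=[1, 2, 3]))   # True -> fused loss, no logits tensor
    print(resolve_skip_logits(training=False, labels=[1, 2, 3]))  # False -> materialize logits

Passing skip_logits=False during training keeps the old behaviour of materializing logits; passing skip_logits=True at eval time with labels forces the memory-saving fused path.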
liger_kernel/transformers/model/gemma2.py
@@ -30,6 +30,8 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -76,6 +78,7 @@ def lce_forward_deprecated(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -83,7 +86,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -146,7 +156,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -203,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -210,11 +222,18 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -222,10 +241,10 @@ def lce_forward(
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
-            **loss_kwargs,
+            **kwargs,
         )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
@@ -234,7 +253,7 @@ def lce_forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
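
The gemma2.py hunks keep routing final_logit_softcapping into the fused loss and, on the materialized-logits path, start by dividing by the cap. Only that division step is visible in the excerpt; the tanh and rescale happen in lines outside the hunk, so the sketch below is a hedged reconstruction of the Gemma-2-style soft-capping formula cap * tanh(logits / cap), not a quote of the library code.

# Hedged sketch of final-logit soft-capping as Gemma-2-style models commonly apply it.
# Only the first step (division by the cap) appears in the diff above; the tanh and rescale
# are assumed from the standard formulation.
import torch


def softcap_logits(logits: torch.Tensor, final_logit_softcapping: float) -> torch.Tensor:
    logits = logits / final_logit_softcapping   # step shown in the diff
    logits = torch.tanh(logits)                 # squash to (-1, 1)
    return logits * final_logit_softcapping     # rescale so outputs stay within +/- cap


if __name__ == "__main__":
    x = torch.tensor([-100.0, 0.0, 100.0])
    print(softcap_logits(x, final_logit_softcapping=30.0))  # values bounded to roughly +/-30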
liger_kernel/transformers/model/gemma3.py
@@ -1,4 +1,3 @@
-from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -10,9 +9,7 @@ from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -20,7 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 logger = logging.get_logger(__name__)
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,6 +31,7 @@ def causal_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -101,7 +98,11 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -134,14 +135,13 @@ def causal_forward(
     )
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
     pixel_values: torch.FloatTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
     token_type_ids: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -151,22 +151,14 @@ def multimodal_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
     r"""
-    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
-    logits_to_keep (`int` or `torch.Tensor`, *optional*):
-        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-        This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
     Example:
 
@@ -175,23 +167,37 @@ def multimodal_forward(
     >>> import requests
     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
-    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
-
-    >>> prompt = "answer en Where is the cow standing?"
-    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
-    >>> image = Image.open(requests.get(url, stream=True).raw)
-
-    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
+    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "system",
+    ...         "content": [
+    ...             {"type": "text", "text": "You are a helpful assistant."}
+    ...         ]
+    ...     },
+    ...     {
+    ...         "role": "user", "content": [
+    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+    ...             {"type": "text", "text": "Where is the cat standing?"},
+    ...         ]
+    ...     },
+    ... ]
+
+    >>> inputs = processor.apply_chat_template(
+    ...     messages,
+    ...     tokenize=True,
+    ...     return_dict=True,
+    ...     return_tensors="pt",
+    ...     add_generation_prompt=True
+    ... )
     >>> # Generate
-    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> generate_ids = model.generate(**inputs)
     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "answer en Where is the cow standing?\nbeach"
-    ```"""
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+    ```
+    """
 
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -199,81 +205,38 @@ def multimodal_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    is_training = token_type_ids is not None and labels is not None
-
-    # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
-        special_image_mask = input_ids == self.config.image_token_index
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-    )
-    outputs = self.language_model.model(
-        attention_mask=causal_mask,
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        token_type_ids=token_type_ids,
+        attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
+        labels=labels,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
         **lm_kwargs,
     )
 
     hidden_states = outputs[0]
+
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
     loss = None
     logits = None
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
 
-    if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :]
+    if skip_logits:
+        shift_hidden_states = kept_hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
         hidden_device = shift_hidden_states.device
@@ -294,7 +257,7 @@ def multimodal_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
     else:
-        logits = self.language_model.lm_head(hidden_states)
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
@@ -315,6 +278,7 @@ def multimodal_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
@@ -325,5 +289,5 @@ def multimodal_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
+        image_hidden_states=outputs.image_hidden_states,
     )
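
The rewritten multimodal_forward now delegates image merging to self.model(...) and slices the returned hidden states with logits_to_keep before either loss path. The toy sketch below shows what that slice expression (copied from the diff) does for an int versus an index tensor; the shapes are placeholders.

# Sketch of the logits_to_keep slicing used above: an int keeps the last N positions,
# a 1-D index tensor keeps exactly those positions. Toy shapes only.
import torch

hidden_states = torch.randn(2, 10, 16)  # (batch, seq_len, hidden)

logits_to_keep = 1  # e.g. generation only needs the final position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

logits_to_keep = torch.tensor([3, 7, 9])  # packed-format style: explicit sequence indices
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 3, 16])

Note that logits_to_keep=0 yields slice(0, None), i.e. all positions are kept, which is the special case described in the removed docstring.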
liger_kernel/transformers/model/glm4.py
@@ -26,7 +26,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -79,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -86,28 +88,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
-    else:  # if in inference mode materialize logits
+    else:
        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(
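
When skip_logits resolves to True, these forwards hand the kept hidden states and the lm_head weight to LigerForCausalLMLoss instead of materializing logits. The reference below computes the same quantity in the plain, unfused way so the saving is visible: the (batch, seq, vocab) logits tensor built here is exactly what the fused, chunked kernel avoids. It is a reference for intuition, not the library's implementation.

# Unfused reference for the skip_logits branch: project hidden states with the lm_head weight,
# shift by one position, and take cross-entropy over the vocabulary. liger-kernel performs the
# same math fused and chunked so the full logits tensor never hits memory.
import torch
import torch.nn.functional as F


def reference_causal_lm_loss(hidden_states, lm_head_weight, labels, ignore_index=-100):
    logits = hidden_states @ lm_head_weight.T          # (batch, seq, vocab) -- the tensor the fused path avoids
    shift_logits = logits[:, :-1, :].contiguous()      # position t predicts token t+1
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )


if __name__ == "__main__":
    hidden = torch.randn(2, 8, 32)
    weight = torch.randn(100, 32)            # (vocab, hidden), same layout as nn.Linear.weight
    labels = torch.randint(0, 100, (2, 8))
    print(reference_causal_lm_loss(hidden, weight, labels))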
liger_kernel/transformers/model/llama.py
@@ -37,6 +37,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +92,15 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    # if in training mode, don't materialize logits
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -151,7 +160,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -204,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -214,28 +225,35 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,
             labels=labels,
             shift_labels=shift_labels,
-            **loss_kwargs,
+            **kwargs,
         )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
     if not return_dict:
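
Once the patched lce_forward is in place, callers can steer logits materialization explicitly instead of relying on model.training. The sketch below assumes apply_liger_kernel_to_llama (liger-kernel's documented patching entry point) and uses a placeholder checkpoint id and random input ids; treat it as a usage sketch, not a verbatim recipe from this release.

# Hedged usage sketch of the new skip_logits argument on a patched Llama model.
# "meta-llama/Llama-3.2-1B" is a placeholder checkpoint; any Llama-architecture model works.
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # monkey-patches LlamaForCausalLM.forward with liger's lce_forward

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

# Training-style call: fused linear cross-entropy, no (seq, vocab) logits tensor is built.
model.train()
out = model(input_ids=input_ids, labels=input_ids, skip_logits=True)
print(out.loss, out.logits)  # loss tensor, logits is None

# Force logits even though we are training with labels, e.g. for debugging or logging.
out = model(input_ids=input_ids, labels=input_ids, skip_logits=False)
print(out.loss, out.logits.shape)

Leaving skip_logits unset keeps the pre-0.6.0 default: logits are skipped only in training mode when labels (or shift_labels) are provided.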