liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/grpo_loss.py
@@ -0,0 +1,98 @@
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction
+
+
+ def triton_grpo_loss(
+     logits,
+     old_logp,
+     ref_logp,
+     completion_ids,
+     advantages,
+     completion_mask=None,
+     temperature=0.9,
+     beta=0.04,
+     eps_low=0.2,
+     eps_high=0.4,
+     inplace=True,
+ ):
+     assert logits is not None and completion_ids is not None and advantages is not None, (
+         "must provide logits, completion_ids and advantages"
+     )
+
+     return GrpoLossFunction.apply(
+         logits,
+         old_logp,
+         ref_logp,
+         completion_ids,
+         advantages,
+         completion_mask,
+         temperature,
+         beta,
+         eps_low,
+         eps_high,
+         inplace,
+     )
+
+
+ # This is a demo of how to use grpo_loss in GRPOTrainer. The trl version must be 0.16.
+ """
+ import torch
+ import trl
+ assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+ from trl.extras.profiling import profiling_decorator
+
+ @profiling_decorator
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+     # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+     return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
+
+ @profiling_decorator
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+     if return_outputs:
+         raise ValueError("The GRPOTrainer does not support returning outputs")
+     # Compute the per-token log probabilities for the model
+
+     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+     completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+     input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+     logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+     ref_per_token_logps = inputs["ref_per_token_logps"]
+     advantages = inputs["advantages"]
+     old_per_token_logps = inputs["old_per_token_logps"]
+
+     per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
+         logits,
+         old_per_token_logps,
+         ref_per_token_logps,
+         completion_ids,
+         advantages,
+         completion_mask,
+         self.temperature,
+         self.beta,
+         self.epsilon_low,
+         self.epsilon_high,
+     )
+     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+     # Log the metrics
+     mode = "eval" if self.control.should_evaluate else "train"
+
+     if self.beta != 0.0:
+         mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+         self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+     clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+     self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+     return loss
+
+ trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+ trl.GRPOTrainer.compute_loss = compute_loss
+ trigger = None
+ """
+
+ # Add the following line at the top of grpo.py in open-r1:
+ """
+ from liger_kernel.transformers.grpo_loss import trigger
+ """
liger_kernel/transformers/model/gemma.py
@@ -8,18 +8,12 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
- from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -33,6 +27,7 @@ def lce_forward_deprecated(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""

@@ -87,7 +82,14 @@ def lce_forward_deprecated(
      loss = None
      logits = None

-     if self.training and (labels is not None):
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     if skip_logits:
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

@@ -129,8 +131,6 @@


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -145,7 +145,8 @@ def lce_forward(
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
-     **loss_kwargs,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
      Args:
@@ -197,6 +198,7 @@
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
@@ -204,27 +206,34 @@
      slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
      kept_hidden_states = hidden_states[:, slice_indices, :]

-     shift_labels = loss_kwargs.pop("shift_labels", None)
+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None or shift_labels is not None):
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
          loss = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
              shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
-             **loss_kwargs,
+             **kwargs,
          )
-     else: # if in inference mode materialize logits
+     else:
          logits = self.lm_head(kept_hidden_states)
          if labels is not None:
              loss = self.loss_function(
                  logits=logits,
                  labels=labels,
                  vocab_size=self.config.vocab_size,
-                 **loss_kwargs,
+                 **kwargs,
              )

      if not return_dict:
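
The `skip_logits` keyword introduced above is threaded through every patched causal-LM forward in this release (gemma2, gemma3, glm4, and the other model files listed earlier), so callers can force or suppress logit materialization instead of relying solely on `self.training`. A minimal sketch of how it could be exercised, assuming `apply_liger_kernel_to_gemma` from `liger_kernel.transformers` installs the `lce_forward` above on `GemmaForCausalLM`; the checkpoint name is illustrative and a CUDA device is assumed.

```python
# Sketch only: assumes apply_liger_kernel_to_gemma() swaps in the lce_forward shown above,
# so the patched forward accepts the extra skip_logits kwarg.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_gemma

apply_liger_kernel_to_gemma()  # monkey-patch Gemma before the model is instantiated

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16).cuda().train()
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

batch = tokenizer("Liger fuses the lm_head projection with the loss.", return_tensors="pt").to("cuda")
labels = batch["input_ids"].clone()

# Default: training mode + labels => skip_logits resolves to True, so only the fused loss is returned.
out = model(**batch, labels=labels)
assert out.loss is not None and out.logits is None

# skip_logits=False forces logits to be materialized even during training (e.g. for debugging or metrics).
out = model(**batch, labels=labels, skip_logits=False)
assert out.loss is not None and out.logits is not None
```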
liger_kernel/transformers/model/gemma2.py
@@ -9,10 +9,6 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import HybridCache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
- from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -34,6 +30,8 @@ def lce_forward_deprecated(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
      Args:
@@ -80,6 +78,7 @@
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
@@ -87,7 +86,14 @@
      loss = None
      logits = None

-     if self.training and (labels is not None):
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     if skip_logits:
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()

@@ -136,8 +142,6 @@


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -152,7 +156,8 @@ def lce_forward(
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
-     **loss_kwargs,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
      Args:
@@ -209,6 +214,7 @@
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
@@ -216,11 +222,18 @@
      slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
      kept_hidden_states = hidden_states[:, slice_indices, :]

-     shift_labels = loss_kwargs.pop("shift_labels", None)
+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None or shift_labels is not None):
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
          loss = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
@@ -228,10 +241,10 @@
              shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
              final_logit_softcapping=self.config.final_logit_softcapping,
-             **loss_kwargs,
+             **kwargs,
          )

-     else: # if in inference mode materialize logits
+     else:
          logits = self.lm_head(kept_hidden_states)
          if self.config.final_logit_softcapping is not None:
              logits = logits / self.config.final_logit_softcapping
@@ -240,7 +253,7 @@

          loss = None
          if labels is not None:
-             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+             loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

      if not return_dict:
          output = (logits,) + outputs[1:]
liger_kernel/transformers/model/gemma3.py
@@ -1,4 +1,3 @@
- from typing import List
  from typing import Optional
  from typing import Tuple
  from typing import Union
@@ -9,14 +8,8 @@ import torch.nn as nn
  from transformers.cache_utils import Cache
  from transformers.cache_utils import HybridCache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
- from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import is_torchdynamo_compiling
  from transformers.utils import logging
- from transformers.utils import replace_return_docstrings
- from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -24,9 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
  logger = logging.get_logger(__name__)


- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def causal_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -41,6 +31,7 @@ def causal_forward(
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
      **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
@@ -107,7 +98,11 @@
      shift_labels = loss_kwargs.pop("shift_labels", None)
      loss = None
      logits = None
-     if self.training and (labels is not None or shift_labels is not None):
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
          loss = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
@@ -140,16 +135,13 @@
      )


- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def multimodal_forward(
      self,
      input_ids: torch.LongTensor = None,
      pixel_values: torch.FloatTensor = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+     past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
      token_type_ids: Optional[torch.LongTensor] = None,
      cache_position: Optional[torch.LongTensor] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -159,22 +151,14 @@ def multimodal_forward(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
      **lm_kwargs,
- ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+ ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
      r"""
-     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-         config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
-     logits_to_keep (`int` or `torch.Tensor`, *optional*):
-         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-         This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-     Returns:
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

      Example:

@@ -183,23 +167,37 @@
      >>> import requests
      >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

-     >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
-     >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
-
-     >>> prompt = "answer en Where is the cow standing?"
-     >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
-     >>> image = Image.open(requests.get(url, stream=True).raw)
-
-     >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
+     >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+     >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+     >>> messages = [
+     ...     {
+     ...         "role": "system",
+     ...         "content": [
+     ...             {"type": "text", "text": "You are a helpful assistant."}
+     ...         ]
+     ...     },
+     ...     {
+     ...         "role": "user", "content": [
+     ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+     ...             {"type": "text", "text": "Where is the cat standing?"},
+     ...         ]
+     ...     },
+     ... ]
+
+     >>> inputs = processor.apply_chat_template(
+     ...     messages,
+     ...     tokenize=True,
+     ...     return_dict=True,
+     ...     return_tensors="pt",
+     ...     add_generation_prompt=True
+     ... )
      >>> # Generate
-     >>> generate_ids = model.generate(**inputs, max_length=30)
+     >>> generate_ids = model.generate(**inputs)
      >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-     "answer en Where is the cow standing?\nbeach"
-     ```"""
-
-     if (input_ids is None) ^ (inputs_embeds is not None):
-         raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+     "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+     ```
+     """

      output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
@@ -207,81 +205,38 @@
      )
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-     is_training = token_type_ids is not None and labels is not None
-
-     # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-     if input_ids is not None and self.config.image_token_index >= self.vocab_size:
-         special_image_mask = input_ids == self.config.image_token_index
-         llm_input_ids = input_ids.clone()
-         llm_input_ids[special_image_mask] = 0
-     else:
-         llm_input_ids = input_ids
-
-     if inputs_embeds is None:
-         inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-     if cache_position is None:
-         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-         cache_position = torch.arange(
-             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-         )
-
-     if position_ids is None:
-         position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-     # Merge text and images
-     if pixel_values is not None:
-         image_features = self.get_image_features(pixel_values)
-
-         if input_ids is None:
-             special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                 torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-             )
-         else:
-             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-         if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-             image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-             raise ValueError(
-                 f"Number of images does not match number of special image tokens in the input text. "
-                 f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                 "tokens from image embeddings."
-             )
-         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-     # mask out pad-token-ids in labels for BC
-     if labels is not None and self.pad_token_id in labels:
-         logger.warning_once(
-             "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-             "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-         )
-         labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-     causal_mask = self._update_causal_mask(
-         attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-     )
-     outputs = self.language_model.model(
-         attention_mask=causal_mask,
+     outputs = self.model(
+         input_ids=input_ids,
+         pixel_values=pixel_values,
+         token_type_ids=token_type_ids,
+         attention_mask=attention_mask,
          position_ids=position_ids,
          past_key_values=past_key_values,
          inputs_embeds=inputs_embeds,
          use_cache=use_cache,
+         labels=labels,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
-         logits_to_keep=logits_to_keep,
          **lm_kwargs,
      )

      hidden_states = outputs[0]
+
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
      loss = None
      logits = None
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None)

-     if self.training and (labels is not None):
-         shift_hidden_states = hidden_states[..., :-1, :]
+     if skip_logits:
+         shift_hidden_states = kept_hidden_states[..., :-1, :]
          shift_labels = labels[..., 1:]

          hidden_device = shift_hidden_states.device
@@ -302,7 +257,7 @@ def multimodal_forward(
          lce = LigerFusedLinearCrossEntropyLoss()
          loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
      else:
-         logits = self.language_model.lm_head(hidden_states)
+         logits = self.lm_head(kept_hidden_states)
          if labels is not None:
              # Upcast to float if we need to compute the loss to avoid potential precision issues
              logits = logits.float()
@@ -323,6 +278,7 @@ def multimodal_forward(
              flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
              flat_labels = shift_labels.view(-1).to(shift_logits.device)
              loss = loss_fct(flat_logits, flat_labels)
+
      if not return_dict:
          output = (logits,) + outputs[1:]
          return (loss,) + output if loss is not None else output
@@ -333,5 +289,5 @@ def multimodal_forward(
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
-         image_hidden_states=image_features if pixel_values is not None else None,
+         image_hidden_states=outputs.image_hidden_states,
      )
liger_kernel/transformers/model/glm4.py
@@ -6,18 +6,12 @@ from typing import Union
  import torch

  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
- from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -32,7 +26,8 @@
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
-     **loss_kwargs,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
      Args:
@@ -85,6 +80,7 @@
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )

      hidden_states = outputs[0]
@@ -92,28 +88,35 @@
      slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
      kept_hidden_states = hidden_states[:, slice_indices, :]

-     shift_labels = loss_kwargs.pop("shift_labels", None)
+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None or shift_labels is not None):
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
          loss = LigerForCausalLMLoss(
              hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
              shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
-             **loss_kwargs,
+             **kwargs,
          )

-     else: # if in inference mode materialize logits
+     else:
          logits = self.lm_head(kept_hidden_states)
          if labels is not None:
              loss = self.loss_function(
                  logits=logits,
                  labels=labels,
                  vocab_size=self.config.vocab_size,
-                 **loss_kwargs,
+                 **kwargs,
              )

      return CausalLMOutputWithPast(