liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0

--- liger_kernel/transformers/model/qwen2_5_vl.py (0.5.9)
+++ liger_kernel/transformers/model/qwen2_5_vl.py (0.6.0)
@@ -5,18 +5,13 @@ from typing import Union
 
 import torch
 
-from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -36,17 +31,26 @@ def lce_forward(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
     Example:
 
@@ -78,78 +82,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -159,38 +105,36 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
     else:
         logits = self.lm_head(hidden_states)
+
+        loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -202,5 +146,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
    )
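
Note: the same `skip_logits` gating recurs in every patched forward in this release. An explicit `True` without labels is rejected, an explicit value always wins, and `None` falls back to the old default (skip logit materialization only when training with labels), which lets `LigerForCausalLMLoss` fuse the `lm_head` projection with the cross entropy so the full `[batch, seq, vocab]` logits tensor is never allocated. A minimal standalone sketch of that decision logic, using placeholder names rather than the real module:

```python
# Minimal sketch of the skip_logits gating added above; `training`, `labels` and
# `shift_labels` are placeholders for the corresponding model state and kwargs.
from typing import Any, Optional


def resolve_skip_logits(
    training: bool,
    labels: Optional[Any],
    shift_labels: Optional[Any],
    skip_logits: Optional[bool] = None,
) -> bool:
    # An explicit request to skip logits without any labels can never produce a loss.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    # Default: skip logit materialization only when training and a loss will be computed.
    if skip_logits is None:
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits


assert resolve_skip_logits(training=True, labels="y", shift_labels=None) is True
assert resolve_skip_logits(training=False, labels="y", shift_labels=None) is False
assert resolve_skip_logits(training=False, labels="y", shift_labels=None, skip_logits=True) is True
```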

--- liger_kernel/transformers/model/qwen2_vl.py (0.5.9)
+++ liger_kernel/transformers/model/qwen2_vl.py (0.6.0)
@@ -5,20 +5,13 @@ from typing import Union
 
 import torch
 
-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -37,18 +30,24 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
 
     Example:
 
@@ -80,80 +79,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -163,42 +101,36 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
     else:
         logits = self.lm_head(hidden_states)
+
+        loss = None
        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     return Qwen2VLCausalLMOutputWithPast(
         loss=loss,
@@ -206,5 +138,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
    )
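
Note: both VL forwards now delegate multimodal embedding and RoPE-delta handling to `self.model` and only keep the loss logic. As a rough, hedged usage sketch of how this code path is typically reached, the monkey patch from `liger_kernel.transformers` (e.g. `apply_liger_kernel_to_qwen2_vl`, whose exact defaults are not shown in this diff) swaps `lce_forward` in, after which a training step with `labels` takes the fused path and returns `logits=None`:

```python
# Hedged usage sketch; the patch function, checkpoint name and defaults are assumptions
# based on liger-kernel's public API, not guaranteed by this diff. Requires a CUDA GPU.
import torch
from transformers import Qwen2VLForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

apply_liger_kernel_to_qwen2_vl()  # installs an lce_forward like the one shown above

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16
).to("cuda")
model.train()

# Text-only step: pixel_values / pixel_values_videos stay None and are simply forwarded.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
out = model(input_ids=input_ids, labels=input_ids.clone())

print(out.logits)    # None: skip_logits defaulted to True (training + labels)
out.loss.backward()  # loss came from the fused linear cross entropy
```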

--- liger_kernel/transformers/model/qwen3.py (0.5.9)
+++ liger_kernel/transformers/model/qwen3.py (0.6.0)
@@ -5,16 +5,10 @@ from typing import Union
 import torch
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen3.modeling_qwen3 import _CONFIG_FOR_DOC
-from transformers.models.qwen3.modeling_qwen3 import QWEN3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
@@ -28,6 +22,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **kwargs,
 ) -> CausalLMOutputWithPast:
     r"""
@@ -88,8 +83,15 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -99,7 +101,7 @@ def lce_forward(
             **kwargs,
         )
 
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
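
Note: the Qwen3 patch gains the same override knob, so callers can force either path regardless of `self.training`. A hypothetical call sketch follows; `apply_liger_kernel_to_qwen3` and the checkpoint name are assumptions based on liger-kernel's patch API rather than anything shown in this diff, and a CUDA GPU is assumed for the Triton-backed fused loss:

```python
# Hypothetical call sketch illustrating the skip_logits override; names outside this
# diff (patch function, checkpoint id) are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from liger_kernel.transformers import apply_liger_kernel_to_qwen3

apply_liger_kernel_to_qwen3()
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").to("cuda")
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
batch = tok("Liger fused linear cross entropy", return_tensors="pt").to("cuda")

model.train()
# Default (skip_logits=None): training + labels -> fused loss, no logits tensor.
out = model(**batch, labels=batch["input_ids"])
print(out.loss, out.logits)          # tensor(...), None

# Explicit override: keep the logits even while training (costs a [B, T, V] tensor).
out = model(**batch, labels=batch["input_ids"], skip_logits=False)
print(out.loss, out.logits.shape)    # tensor(...), torch.Size([1, T, V])
```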

--- /dev/null
+++ liger_kernel/transformers/model/qwen3_moe.py (0.6.0)
@@ -0,0 +1,132 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> MoeCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+    else:  # if in inference model materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
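
Note: one detail worth calling out in this new file (and in the Qwen3 forward above) is the `logits_to_keep` slicing: `slice(-0, None)` equals `slice(0, None)`, so the integer `0` keeps every position, a positive integer keeps only the trailing positions, and a 1-D tensor selects explicit indices (useful for packed sequences). A small self-contained illustration with made-up shapes:

```python
# Self-contained illustration of the logits_to_keep slicing used above; shapes are made up.
import torch

hidden_states = torch.randn(2, 8, 16)  # [batch, seq, hidden]


def kept(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # Mirrors the line in the forward: an int becomes a trailing slice, a tensor is used as-is.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


print(kept(hidden_states, 0).shape)                        # torch.Size([2, 8, 16]) -- keep all
print(kept(hidden_states, 1).shape)                        # torch.Size([2, 1, 16]) -- last token only
print(kept(hidden_states, torch.tensor([3, 5, 7])).shape)  # torch.Size([2, 3, 16]) -- explicit indices
```

The router auxiliary loss is then added on top of whichever loss path was taken, scaled by `router_aux_loss_coef`, mirroring the upstream Mixtral-style `load_balancing_loss_func` usage imported at the top of the file.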