liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/paligemma.py

@@ -216,6 +216,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
  **lm_kwargs,
  ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
  r"""
@@ -331,7 +332,13 @@ def lce_forward(
  loss = None
  logits = None
 
- if self.training and (labels is not None):
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels is None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None)
+
+ if skip_logits:
  shift_hidden_states = hidden_states[..., :-1, :]
  shift_labels = labels[..., 1:]
 
liger_kernel/transformers/model/phi3.py

@@ -26,6 +26,7 @@ def lce_forward_deprecated(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
  loss = None
  logits = None
 
- if self.training and labels is not None:
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels is None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and labels is not None
+
+ if skip_logits:
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
 
@@ -136,7 +144,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -202,6 +211,7 @@ def lce_forward(
  output_attentions=output_attentions,
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -209,28 +219,35 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  if not return_dict:
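
The same three-step skip_logits guard recurs in every patched forward in this release: reject an explicit skip when there are no labels to compute a loss against, default to skipping logits only when training with labels, then branch between the fused-loss path and plain logits materialization. Distilled here as a standalone helper for reference (the helper itself is illustrative and not part of the package):

```python
from typing import Optional


def resolve_skip_logits(skip_logits: Optional[bool], training: bool, labels) -> bool:
    """Mirrors the skip_logits resolution added to each lce_forward above."""
    # Explicitly requesting the fused path without labels is an error.
    if skip_logits and labels is None:
        raise ValueError("skip_logits is True, but labels is None")
    # Unset: skip logits (use the fused linear cross entropy) only when
    # training with labels; otherwise materialize logits as before.
    if skip_logits is None:
        skip_logits = training and labels is not None
    return skip_logits
```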
liger_kernel/transformers/model/qwen2.py

@@ -26,6 +26,7 @@ def lce_forward_deprecated(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -80,6 +81,13 @@ def lce_forward_deprecated(
  loss = None
  logits = None
 
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels is None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and labels is not None
+
  if self.training and (labels is not None):
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
@@ -135,7 +143,8 @@ def lce_forward(
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -188,6 +197,7 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
@@ -195,28 +205,35 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
  vocab_size=self.config.vocab_size,
- **loss_kwargs,
+ **kwargs,
  )
 
  return CausalLMOutputWithPast(
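
A hedged usage sketch (not part of the diff): once Liger's Qwen2 patch is applied, the patched lce_forward above accepts the new skip_logits argument, so a caller can override the training-time default. The model id, inputs, and dtype below are placeholders, and a CUDA device is assumed since the underlying kernels are Triton-based.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the Qwen2 modeling code before instantiating the model.
apply_liger_kernel_to_qwen2(fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

batch = tokenizer("hello world", return_tensors="pt").to("cuda")
labels = batch["input_ids"].clone()

model.train()
# Default: training with labels, so logits are skipped and only the fused loss is returned.
out = model(**batch, labels=labels)
print(out.loss, out.logits)  # loss tensor, None

# Force logits to be materialized even though labels are provided.
out = model(**batch, labels=labels, skip_logits=False)
print(out.loss, out.logits.shape)
```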
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -5,12 +5,13 @@ from typing import Union
 
  import torch
 
- from torch.nn import CrossEntropyLoss
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+ from transformers.utils import can_return_tuple
 
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+ @can_return_tuple
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -30,17 +31,26 @@ def lce_forward(
  rope_deltas: Optional[torch.LongTensor] = None,
  cache_position: Optional[torch.LongTensor] = None,
  second_per_grid_ts: Optional[torch.Tensor] = None,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
  r"""
- Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+ The tensors corresponding to the input videos. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+ [`Qwen2_5_VLImageProcessor`] for processing videos.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
  Example:
 
@@ -72,78 +82,20 @@ def lce_forward(
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
  ```"""
+
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  )
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
- if inputs_embeds is None:
- inputs_embeds = self.model.embed_tokens(input_ids)
- if pixel_values is not None:
- pixel_values = pixel_values.type(self.visual.dtype)
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
- n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
- n_image_features = image_embeds.shape[0]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
-
- mask = input_ids == self.config.image_token_id
- mask_unsqueezed = mask.unsqueeze(-1)
- mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
- image_mask = mask_expanded.to(inputs_embeds.device)
-
- image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
- if pixel_values_videos is not None:
- pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
- video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
- n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
- n_video_features = video_embeds.shape[0]
- if n_video_tokens != n_video_features:
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
-
- mask = input_ids == self.config.video_token_id
- mask_unsqueezed = mask.unsqueeze(-1)
- mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
- video_mask = mask_expanded.to(inputs_embeds.device)
-
- video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
- if attention_mask is not None:
- attention_mask = attention_mask.to(inputs_embeds.device)
-
- # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
- if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
- # calculate RoPE index once per generation in the pre-fill stage only
- if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
- position_ids, rope_deltas = self.get_rope_index(
- input_ids,
- image_grid_thw,
- video_grid_thw,
- second_per_grid_ts,
- attention_mask,
- )
- self.rope_deltas = rope_deltas
- # then use the prev pre-calculated rope-deltas to get the correct position ids
- else:
- batch_size, seq_length, _ = inputs_embeds.shape
- delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
- position_ids = torch.arange(seq_length, device=inputs_embeds.device)
- position_ids = position_ids.view(1, -1).expand(batch_size, -1)
- if cache_position is not None: # otherwise `deltas` is an int `0`
- delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
- position_ids = position_ids.add(delta)
- position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
  outputs = self.model(
- input_ids=None,
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
  position_ids=position_ids,
  attention_mask=attention_mask,
  past_key_values=past_key_values,
@@ -153,38 +105,36 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  loss = None
  logits = None
 
- if self.training and (labels is not None or shift_labels is not None):
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
  else:
  logits = self.lm_head(hidden_states)
+
+ loss = None
  if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
  if not return_dict:
  output = (logits,) + outputs[1:]
@@ -196,5 +146,5 @@ def lce_forward(
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
- rope_deltas=rope_deltas,
+ rope_deltas=outputs.rope_deltas,
  )
liger_kernel/transformers/model/qwen2_vl.py

@@ -5,14 +5,13 @@ from typing import Union
 
  import torch
 
- from packaging import version
- from torch.nn import CrossEntropyLoss
- from transformers import __version__ as transformers_version
  from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+ from transformers.utils import can_return_tuple
 
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+ @can_return_tuple
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -31,18 +30,24 @@ def lce_forward(
  video_grid_thw: Optional[torch.LongTensor] = None,
  rope_deltas: Optional[torch.LongTensor] = None,
  cache_position: Optional[torch.LongTensor] = None,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
  r"""
- Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+ The tensors corresponding to the input videos. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+ [`Qwen2VLImageProcessor`] for processing videos.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
 
  Example:
 
@@ -74,80 +79,19 @@ def lce_forward(
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
  ```"""
+
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  )
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
- if inputs_embeds is None:
- inputs_embeds = self.model.embed_tokens(input_ids)
- if pixel_values is not None:
- pixel_values = pixel_values.type(self.visual.get_dtype())
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
- n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
- n_image_features = image_embeds.shape[0]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_mask = (
- (input_ids == self.config.image_token_id)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
- image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
- if pixel_values_videos is not None:
- pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
- video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
- n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
- n_video_features = video_embeds.shape[0]
- if n_video_tokens != n_video_features:
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_mask = (
- (input_ids == self.config.video_token_id)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
- video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
- if attention_mask is not None:
- attention_mask = attention_mask.to(inputs_embeds.device)
-
- if version.parse(transformers_version) > version.parse("4.46.3"):
- # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
- # https://github.com/huggingface/transformers/issues/33401
- # While correct, this breaks equivalence with past versions of Qwen2-VL from
- # transformers and leads to failed tests or users noticing differences in results.
- # TODO: remove above conditional when liger drops support for transformers<4.47.0
- # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
- if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
- # calculate RoPE index once per generation in the pre-fill stage only
- if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
- position_ids, rope_deltas = self.get_rope_index(
- input_ids, image_grid_thw, video_grid_thw, attention_mask
- )
- self.rope_deltas = rope_deltas
- # then use the prev pre-calculated rope-deltas to get the correct position ids
- else:
- batch_size, seq_length, _ = inputs_embeds.shape
- delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
- position_ids = torch.arange(seq_length, device=inputs_embeds.device)
- position_ids = position_ids.view(1, -1).expand(batch_size, -1)
- if cache_position is not None: # otherwise `deltas` is an int `0`
- delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
- position_ids = position_ids.add(delta)
- position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
  outputs = self.model(
- input_ids=None,
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
  position_ids=position_ids,
  attention_mask=attention_mask,
  past_key_values=past_key_values,
@@ -157,42 +101,36 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs[0]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  loss = None
  logits = None
 
- if self.training and (labels is not None or shift_labels is not None):
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
  else:
  logits = self.lm_head(hidden_states)
+
+ loss = None
  if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
  return Qwen2VLCausalLMOutputWithPast(
  loss=loss,
@@ -200,5 +138,5 @@ def lce_forward(
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
- rope_deltas=rope_deltas,
+ rope_deltas=outputs.rope_deltas,
  )
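
In both VL forwards the removed CrossEntropyLoss block is replaced by LigerForCausalLMLoss, which fuses the lm_head projection with the cross entropy so the full (batch, seq_len, vocab) logits tensor is never materialized. A minimal sketch of that call follows, with illustrative shapes and a CUDA device assumed (the kernels are Triton-based); the keyword arguments mirror the hunks above:

```python
import torch

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

batch, seq_len, hidden_size, vocab_size = 2, 16, 64, 128
hidden_states = torch.randn(batch, seq_len, hidden_size, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch, seq_len), device="cuda")

# Same call shape as in the forwards above: unshifted labels go in, and the
# loss comes back without a full logits tensor being allocated.
loss = LigerForCausalLMLoss(
    hidden_states=hidden_states,
    lm_head_weight=lm_head_weight,
    labels=labels,
    shift_labels=None,
    hidden_size=hidden_size,
)
loss.backward()
```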
liger_kernel/transformers/model/qwen3.py

@@ -22,6 +22,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
  **kwargs,
  ) -> CausalLMOutputWithPast:
  r"""
@@ -82,8 +83,15 @@ def lce_forward(
  shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
- # if in training mode, don't materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
@@ -93,7 +101,7 @@ def lce_forward(
  **kwargs,
  )
 
- else: # if in inference mode materialize logits
+ else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
liger_kernel/transformers/model/qwen3_moe.py

@@ -25,7 +25,8 @@ def lce_forward(
  output_router_logits: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
  ) -> MoeCausalLMOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -80,6 +81,7 @@ def lce_forward(
  output_hidden_states=output_hidden_states,
  output_router_logits=output_router_logits,
  cache_position=cache_position,
+ **kwargs,
  )
 
  hidden_states = outputs.last_hidden_state
@@ -87,24 +89,26 @@ def lce_forward(
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  kept_hidden_states = hidden_states[:, slice_indices, :]
 
- shift_labels = loss_kwargs.pop("shift_labels", None)
+ shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
 
- # if in training mode, do not materialize logits
- if self.training and (labels is not None or shift_labels is not None):
+ if skip_logits is None:
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
  loss = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
- **loss_kwargs,
+ **kwargs,
  )
  else: # if in inference model materialize logits
  logits = self.lm_head(kept_hidden_states)
  if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 
  aux_loss = None
  if output_router_logits: