liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/functional.py +62 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +62 -98
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/monkey_patch.py +304 -70
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -216,6 +216,7 @@ def lce_forward(
|
|
|
216
216
|
output_hidden_states: Optional[bool] = None,
|
|
217
217
|
return_dict: Optional[bool] = None,
|
|
218
218
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
219
|
+
skip_logits: Optional[bool] = None,
|
|
219
220
|
**lm_kwargs,
|
|
220
221
|
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
|
|
221
222
|
r"""
|
|
@@ -331,7 +332,13 @@ def lce_forward(
|
|
|
331
332
|
loss = None
|
|
332
333
|
logits = None
|
|
333
334
|
|
|
334
|
-
if
|
|
335
|
+
if skip_logits and labels is None:
|
|
336
|
+
raise ValueError("skip_logits is True, but labels is None")
|
|
337
|
+
|
|
338
|
+
if skip_logits is None:
|
|
339
|
+
skip_logits = self.training and (labels is not None)
|
|
340
|
+
|
|
341
|
+
if skip_logits:
|
|
335
342
|
shift_hidden_states = hidden_states[..., :-1, :]
|
|
336
343
|
shift_labels = labels[..., 1:]
|
|
337
344
|
|
|
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
|
|
|
26
26
|
output_hidden_states: Optional[bool] = None,
|
|
27
27
|
return_dict: Optional[bool] = None,
|
|
28
28
|
cache_position: Optional[torch.LongTensor] = None,
|
|
29
|
+
skip_logits: Optional[bool] = None,
|
|
29
30
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
30
31
|
r"""
|
|
31
32
|
Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
|
|
|
80
81
|
loss = None
|
|
81
82
|
logits = None
|
|
82
83
|
|
|
83
|
-
if
|
|
84
|
+
if skip_logits and labels is None:
|
|
85
|
+
raise ValueError("skip_logits is True, but labels is None")
|
|
86
|
+
|
|
87
|
+
if skip_logits is None:
|
|
88
|
+
# By default, if in training mode, don't materialize logits
|
|
89
|
+
skip_logits = self.training and labels is not None
|
|
90
|
+
|
|
91
|
+
if skip_logits:
|
|
84
92
|
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
|
|
85
93
|
shift_labels = labels[..., 1:].contiguous()
|
|
86
94
|
|
|
@@ -136,7 +144,8 @@ def lce_forward(
|
|
|
136
144
|
return_dict: Optional[bool] = None,
|
|
137
145
|
cache_position: Optional[torch.LongTensor] = None,
|
|
138
146
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
139
|
-
|
|
147
|
+
skip_logits: Optional[bool] = None,
|
|
148
|
+
**kwargs,
|
|
140
149
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
141
150
|
r"""
|
|
142
151
|
Args:
|
|
@@ -202,6 +211,7 @@ def lce_forward(
|
|
|
202
211
|
output_attentions=output_attentions,
|
|
203
212
|
output_hidden_states=output_hidden_states,
|
|
204
213
|
return_dict=return_dict,
|
|
214
|
+
**kwargs,
|
|
205
215
|
)
|
|
206
216
|
|
|
207
217
|
hidden_states = outputs[0]
|
|
@@ -209,28 +219,35 @@ def lce_forward(
|
|
|
209
219
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
210
220
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
211
221
|
|
|
212
|
-
shift_labels =
|
|
222
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
213
223
|
logits = None
|
|
214
224
|
loss = None
|
|
215
|
-
|
|
216
|
-
if
|
|
225
|
+
|
|
226
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
227
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
228
|
+
|
|
229
|
+
if skip_logits is None:
|
|
230
|
+
# By default, if in training mode, don't materialize logits
|
|
231
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
232
|
+
|
|
233
|
+
if skip_logits:
|
|
217
234
|
loss = LigerForCausalLMLoss(
|
|
218
235
|
hidden_states=kept_hidden_states,
|
|
219
236
|
lm_head_weight=self.lm_head.weight,
|
|
220
237
|
labels=labels,
|
|
221
238
|
shift_labels=shift_labels,
|
|
222
239
|
hidden_size=self.config.hidden_size,
|
|
223
|
-
**
|
|
240
|
+
**kwargs,
|
|
224
241
|
)
|
|
225
242
|
|
|
226
|
-
else:
|
|
243
|
+
else:
|
|
227
244
|
logits = self.lm_head(kept_hidden_states)
|
|
228
245
|
if labels is not None:
|
|
229
246
|
loss = self.loss_function(
|
|
230
247
|
logits=logits,
|
|
231
248
|
labels=labels,
|
|
232
249
|
vocab_size=self.config.vocab_size,
|
|
233
|
-
**
|
|
250
|
+
**kwargs,
|
|
234
251
|
)
|
|
235
252
|
|
|
236
253
|
if not return_dict:
|
|
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
|
|
|
26
26
|
output_hidden_states: Optional[bool] = None,
|
|
27
27
|
return_dict: Optional[bool] = None,
|
|
28
28
|
cache_position: Optional[torch.LongTensor] = None,
|
|
29
|
+
skip_logits: Optional[bool] = None,
|
|
29
30
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
30
31
|
r"""
|
|
31
32
|
Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -80,6 +81,13 @@ def lce_forward_deprecated(
|
|
|
80
81
|
loss = None
|
|
81
82
|
logits = None
|
|
82
83
|
|
|
84
|
+
if skip_logits and labels is None:
|
|
85
|
+
raise ValueError("skip_logits is True, but labels is None")
|
|
86
|
+
|
|
87
|
+
if skip_logits is None:
|
|
88
|
+
# By default, if in training mode, don't materialize logits
|
|
89
|
+
skip_logits = self.training and labels is not None
|
|
90
|
+
|
|
83
91
|
if self.training and (labels is not None):
|
|
84
92
|
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
|
|
85
93
|
shift_labels = labels[..., 1:].contiguous()
|
|
@@ -135,7 +143,8 @@ def lce_forward(
|
|
|
135
143
|
return_dict: Optional[bool] = None,
|
|
136
144
|
cache_position: Optional[torch.LongTensor] = None,
|
|
137
145
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
138
|
-
|
|
146
|
+
skip_logits: Optional[bool] = None,
|
|
147
|
+
**kwargs,
|
|
139
148
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
140
149
|
r"""
|
|
141
150
|
Args:
|
|
@@ -188,6 +197,7 @@ def lce_forward(
|
|
|
188
197
|
output_hidden_states=output_hidden_states,
|
|
189
198
|
return_dict=return_dict,
|
|
190
199
|
cache_position=cache_position,
|
|
200
|
+
**kwargs,
|
|
191
201
|
)
|
|
192
202
|
|
|
193
203
|
hidden_states = outputs[0]
|
|
@@ -195,28 +205,35 @@ def lce_forward(
|
|
|
195
205
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
196
206
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
197
207
|
|
|
198
|
-
shift_labels =
|
|
208
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
199
209
|
logits = None
|
|
200
210
|
loss = None
|
|
201
|
-
|
|
202
|
-
if
|
|
211
|
+
|
|
212
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
213
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
214
|
+
|
|
215
|
+
if skip_logits is None:
|
|
216
|
+
# By default, if in training mode, don't materialize logits
|
|
217
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
218
|
+
|
|
219
|
+
if skip_logits:
|
|
203
220
|
loss = LigerForCausalLMLoss(
|
|
204
221
|
hidden_states=kept_hidden_states,
|
|
205
222
|
lm_head_weight=self.lm_head.weight,
|
|
206
223
|
labels=labels,
|
|
207
224
|
shift_labels=shift_labels,
|
|
208
225
|
hidden_size=self.config.hidden_size,
|
|
209
|
-
**
|
|
226
|
+
**kwargs,
|
|
210
227
|
)
|
|
211
228
|
|
|
212
|
-
else:
|
|
229
|
+
else:
|
|
213
230
|
logits = self.lm_head(kept_hidden_states)
|
|
214
231
|
if labels is not None:
|
|
215
232
|
loss = self.loss_function(
|
|
216
233
|
logits=logits,
|
|
217
234
|
labels=labels,
|
|
218
235
|
vocab_size=self.config.vocab_size,
|
|
219
|
-
**
|
|
236
|
+
**kwargs,
|
|
220
237
|
)
|
|
221
238
|
|
|
222
239
|
return CausalLMOutputWithPast(
|
|
@@ -5,12 +5,13 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from torch.nn import CrossEntropyLoss
|
|
9
8
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
|
|
9
|
+
from transformers.utils import can_return_tuple
|
|
10
10
|
|
|
11
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
@can_return_tuple
|
|
14
15
|
def lce_forward(
|
|
15
16
|
self,
|
|
16
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -30,17 +31,26 @@ def lce_forward(
|
|
|
30
31
|
rope_deltas: Optional[torch.LongTensor] = None,
|
|
31
32
|
cache_position: Optional[torch.LongTensor] = None,
|
|
32
33
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
33
|
-
|
|
34
|
+
skip_logits: Optional[bool] = None,
|
|
35
|
+
**kwargs,
|
|
34
36
|
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
|
35
37
|
r"""
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
38
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
39
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
40
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
41
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
42
|
+
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
|
43
|
+
The tensors corresponding to the input videos. Pixel values can be obtained using
|
|
44
|
+
[`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
|
|
45
|
+
[`Qwen2_5_VLImageProcessor`] for processing videos.
|
|
46
|
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
|
47
|
+
The temporal, height and width of feature shape of each image in LLM.
|
|
48
|
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
|
49
|
+
The temporal, height and width of feature shape of each video in LLM.
|
|
50
|
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
|
51
|
+
The rope index difference between sequence length and multimodal rope.
|
|
52
|
+
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
|
53
|
+
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
|
44
54
|
|
|
45
55
|
Example:
|
|
46
56
|
|
|
@@ -72,78 +82,20 @@ def lce_forward(
|
|
|
72
82
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
73
83
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
74
84
|
```"""
|
|
85
|
+
|
|
75
86
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
76
87
|
output_hidden_states = (
|
|
77
88
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
78
89
|
)
|
|
79
90
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
80
91
|
|
|
81
|
-
if inputs_embeds is None:
|
|
82
|
-
inputs_embeds = self.model.embed_tokens(input_ids)
|
|
83
|
-
if pixel_values is not None:
|
|
84
|
-
pixel_values = pixel_values.type(self.visual.dtype)
|
|
85
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
86
|
-
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
|
87
|
-
n_image_features = image_embeds.shape[0]
|
|
88
|
-
if n_image_tokens != n_image_features:
|
|
89
|
-
raise ValueError(
|
|
90
|
-
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
mask = input_ids == self.config.image_token_id
|
|
94
|
-
mask_unsqueezed = mask.unsqueeze(-1)
|
|
95
|
-
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
|
96
|
-
image_mask = mask_expanded.to(inputs_embeds.device)
|
|
97
|
-
|
|
98
|
-
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
99
|
-
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
|
100
|
-
|
|
101
|
-
if pixel_values_videos is not None:
|
|
102
|
-
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
|
103
|
-
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
104
|
-
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
|
105
|
-
n_video_features = video_embeds.shape[0]
|
|
106
|
-
if n_video_tokens != n_video_features:
|
|
107
|
-
raise ValueError(
|
|
108
|
-
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
mask = input_ids == self.config.video_token_id
|
|
112
|
-
mask_unsqueezed = mask.unsqueeze(-1)
|
|
113
|
-
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
|
114
|
-
video_mask = mask_expanded.to(inputs_embeds.device)
|
|
115
|
-
|
|
116
|
-
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
117
|
-
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
|
118
|
-
|
|
119
|
-
if attention_mask is not None:
|
|
120
|
-
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
121
|
-
|
|
122
|
-
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
|
123
|
-
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
|
124
|
-
# calculate RoPE index once per generation in the pre-fill stage only
|
|
125
|
-
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
|
126
|
-
position_ids, rope_deltas = self.get_rope_index(
|
|
127
|
-
input_ids,
|
|
128
|
-
image_grid_thw,
|
|
129
|
-
video_grid_thw,
|
|
130
|
-
second_per_grid_ts,
|
|
131
|
-
attention_mask,
|
|
132
|
-
)
|
|
133
|
-
self.rope_deltas = rope_deltas
|
|
134
|
-
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
|
135
|
-
else:
|
|
136
|
-
batch_size, seq_length, _ = inputs_embeds.shape
|
|
137
|
-
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
|
|
138
|
-
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
|
139
|
-
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
|
140
|
-
if cache_position is not None: # otherwise `deltas` is an int `0`
|
|
141
|
-
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
|
142
|
-
position_ids = position_ids.add(delta)
|
|
143
|
-
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
|
144
|
-
|
|
145
92
|
outputs = self.model(
|
|
146
|
-
input_ids=
|
|
93
|
+
input_ids=input_ids,
|
|
94
|
+
pixel_values=pixel_values,
|
|
95
|
+
pixel_values_videos=pixel_values_videos,
|
|
96
|
+
image_grid_thw=image_grid_thw,
|
|
97
|
+
video_grid_thw=video_grid_thw,
|
|
98
|
+
second_per_grid_ts=second_per_grid_ts,
|
|
147
99
|
position_ids=position_ids,
|
|
148
100
|
attention_mask=attention_mask,
|
|
149
101
|
past_key_values=past_key_values,
|
|
@@ -153,38 +105,36 @@ def lce_forward(
|
|
|
153
105
|
output_hidden_states=output_hidden_states,
|
|
154
106
|
return_dict=return_dict,
|
|
155
107
|
cache_position=cache_position,
|
|
108
|
+
**kwargs,
|
|
156
109
|
)
|
|
157
110
|
|
|
158
111
|
hidden_states = outputs[0]
|
|
159
112
|
|
|
160
|
-
shift_labels =
|
|
113
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
161
114
|
loss = None
|
|
162
115
|
logits = None
|
|
163
116
|
|
|
164
|
-
if
|
|
117
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
118
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
119
|
+
|
|
120
|
+
if skip_logits is None:
|
|
121
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
122
|
+
|
|
123
|
+
if skip_logits:
|
|
165
124
|
loss = LigerForCausalLMLoss(
|
|
166
125
|
hidden_states=hidden_states,
|
|
167
126
|
lm_head_weight=self.lm_head.weight,
|
|
168
127
|
labels=labels,
|
|
169
128
|
shift_labels=shift_labels,
|
|
170
129
|
hidden_size=self.config.hidden_size,
|
|
171
|
-
**
|
|
130
|
+
**kwargs,
|
|
172
131
|
)
|
|
173
132
|
else:
|
|
174
133
|
logits = self.lm_head(hidden_states)
|
|
134
|
+
|
|
135
|
+
loss = None
|
|
175
136
|
if labels is not None:
|
|
176
|
-
|
|
177
|
-
logits = logits.float()
|
|
178
|
-
# Shift so that tokens < n predict n
|
|
179
|
-
shift_logits = logits[..., :-1, :].contiguous()
|
|
180
|
-
shift_labels = labels[..., 1:].contiguous()
|
|
181
|
-
# Flatten the tokens
|
|
182
|
-
loss_fct = CrossEntropyLoss()
|
|
183
|
-
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
184
|
-
shift_labels = shift_labels.view(-1)
|
|
185
|
-
# Enable model parallelism
|
|
186
|
-
shift_labels = shift_labels.to(shift_logits.device)
|
|
187
|
-
loss = loss_fct(shift_logits, shift_labels)
|
|
137
|
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
|
|
188
138
|
|
|
189
139
|
if not return_dict:
|
|
190
140
|
output = (logits,) + outputs[1:]
|
|
@@ -196,5 +146,5 @@ def lce_forward(
|
|
|
196
146
|
past_key_values=outputs.past_key_values,
|
|
197
147
|
hidden_states=outputs.hidden_states,
|
|
198
148
|
attentions=outputs.attentions,
|
|
199
|
-
rope_deltas=rope_deltas,
|
|
149
|
+
rope_deltas=outputs.rope_deltas,
|
|
200
150
|
)
|
|
@@ -5,14 +5,13 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from packaging import version
|
|
9
|
-
from torch.nn import CrossEntropyLoss
|
|
10
|
-
from transformers import __version__ as transformers_version
|
|
11
8
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
|
|
9
|
+
from transformers.utils import can_return_tuple
|
|
12
10
|
|
|
13
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
14
12
|
|
|
15
13
|
|
|
14
|
+
@can_return_tuple
|
|
16
15
|
def lce_forward(
|
|
17
16
|
self,
|
|
18
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -31,18 +30,24 @@ def lce_forward(
|
|
|
31
30
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
32
31
|
rope_deltas: Optional[torch.LongTensor] = None,
|
|
33
32
|
cache_position: Optional[torch.LongTensor] = None,
|
|
34
|
-
|
|
33
|
+
skip_logits: Optional[bool] = None,
|
|
34
|
+
**kwargs,
|
|
35
35
|
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
|
36
36
|
r"""
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
37
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
38
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
39
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
40
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
41
|
+
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
|
42
|
+
The tensors corresponding to the input videos. Pixel values can be obtained using
|
|
43
|
+
[`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
|
|
44
|
+
[`Qwen2VLImageProcessor`] for processing videos.
|
|
45
|
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
|
46
|
+
The temporal, height and width of feature shape of each image in LLM.
|
|
47
|
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
|
48
|
+
The temporal, height and width of feature shape of each video in LLM.
|
|
49
|
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
|
50
|
+
The rope index difference between sequence length and multimodal rope.
|
|
46
51
|
|
|
47
52
|
Example:
|
|
48
53
|
|
|
@@ -74,80 +79,19 @@ def lce_forward(
|
|
|
74
79
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
75
80
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
76
81
|
```"""
|
|
82
|
+
|
|
77
83
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
78
84
|
output_hidden_states = (
|
|
79
85
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
80
86
|
)
|
|
81
87
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
82
88
|
|
|
83
|
-
if inputs_embeds is None:
|
|
84
|
-
inputs_embeds = self.model.embed_tokens(input_ids)
|
|
85
|
-
if pixel_values is not None:
|
|
86
|
-
pixel_values = pixel_values.type(self.visual.get_dtype())
|
|
87
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
88
|
-
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
|
89
|
-
n_image_features = image_embeds.shape[0]
|
|
90
|
-
if n_image_tokens != n_image_features:
|
|
91
|
-
raise ValueError(
|
|
92
|
-
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
93
|
-
)
|
|
94
|
-
image_mask = (
|
|
95
|
-
(input_ids == self.config.image_token_id)
|
|
96
|
-
.unsqueeze(-1)
|
|
97
|
-
.expand_as(inputs_embeds)
|
|
98
|
-
.to(inputs_embeds.device)
|
|
99
|
-
)
|
|
100
|
-
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
101
|
-
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
|
102
|
-
|
|
103
|
-
if pixel_values_videos is not None:
|
|
104
|
-
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
|
105
|
-
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
106
|
-
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
|
107
|
-
n_video_features = video_embeds.shape[0]
|
|
108
|
-
if n_video_tokens != n_video_features:
|
|
109
|
-
raise ValueError(
|
|
110
|
-
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
|
111
|
-
)
|
|
112
|
-
video_mask = (
|
|
113
|
-
(input_ids == self.config.video_token_id)
|
|
114
|
-
.unsqueeze(-1)
|
|
115
|
-
.expand_as(inputs_embeds)
|
|
116
|
-
.to(inputs_embeds.device)
|
|
117
|
-
)
|
|
118
|
-
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
119
|
-
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
|
120
|
-
|
|
121
|
-
if attention_mask is not None:
|
|
122
|
-
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
123
|
-
|
|
124
|
-
if version.parse(transformers_version) > version.parse("4.46.3"):
|
|
125
|
-
# NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
|
|
126
|
-
# https://github.com/huggingface/transformers/issues/33401
|
|
127
|
-
# While correct, this breaks equivalence with past versions of Qwen2-VL from
|
|
128
|
-
# transformers and leads to failed tests or users noticing differences in results.
|
|
129
|
-
# TODO: remove above conditional when liger drops support for transformers<4.47.0
|
|
130
|
-
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
|
131
|
-
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
|
132
|
-
# calculate RoPE index once per generation in the pre-fill stage only
|
|
133
|
-
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
|
134
|
-
position_ids, rope_deltas = self.get_rope_index(
|
|
135
|
-
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
|
136
|
-
)
|
|
137
|
-
self.rope_deltas = rope_deltas
|
|
138
|
-
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
|
139
|
-
else:
|
|
140
|
-
batch_size, seq_length, _ = inputs_embeds.shape
|
|
141
|
-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
|
142
|
-
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
|
143
|
-
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
|
144
|
-
if cache_position is not None: # otherwise `deltas` is an int `0`
|
|
145
|
-
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
|
146
|
-
position_ids = position_ids.add(delta)
|
|
147
|
-
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
|
148
|
-
|
|
149
89
|
outputs = self.model(
|
|
150
|
-
input_ids=
|
|
90
|
+
input_ids=input_ids,
|
|
91
|
+
pixel_values=pixel_values,
|
|
92
|
+
pixel_values_videos=pixel_values_videos,
|
|
93
|
+
image_grid_thw=image_grid_thw,
|
|
94
|
+
video_grid_thw=video_grid_thw,
|
|
151
95
|
position_ids=position_ids,
|
|
152
96
|
attention_mask=attention_mask,
|
|
153
97
|
past_key_values=past_key_values,
|
|
@@ -157,42 +101,36 @@ def lce_forward(
|
|
|
157
101
|
output_hidden_states=output_hidden_states,
|
|
158
102
|
return_dict=return_dict,
|
|
159
103
|
cache_position=cache_position,
|
|
104
|
+
**kwargs,
|
|
160
105
|
)
|
|
161
106
|
|
|
162
107
|
hidden_states = outputs[0]
|
|
163
108
|
|
|
164
|
-
shift_labels =
|
|
109
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
165
110
|
loss = None
|
|
166
111
|
logits = None
|
|
167
112
|
|
|
168
|
-
if
|
|
113
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
114
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
115
|
+
|
|
116
|
+
if skip_logits is None:
|
|
117
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
118
|
+
|
|
119
|
+
if skip_logits:
|
|
169
120
|
loss = LigerForCausalLMLoss(
|
|
170
121
|
hidden_states=hidden_states,
|
|
171
122
|
lm_head_weight=self.lm_head.weight,
|
|
172
123
|
labels=labels,
|
|
173
124
|
shift_labels=shift_labels,
|
|
174
125
|
hidden_size=self.config.hidden_size,
|
|
175
|
-
**
|
|
126
|
+
**kwargs,
|
|
176
127
|
)
|
|
177
128
|
else:
|
|
178
129
|
logits = self.lm_head(hidden_states)
|
|
130
|
+
|
|
131
|
+
loss = None
|
|
179
132
|
if labels is not None:
|
|
180
|
-
|
|
181
|
-
logits = logits.float()
|
|
182
|
-
# Shift so that tokens < n predict n
|
|
183
|
-
shift_logits = logits[..., :-1, :].contiguous()
|
|
184
|
-
shift_labels = labels[..., 1:].contiguous()
|
|
185
|
-
# Flatten the tokens
|
|
186
|
-
loss_fct = CrossEntropyLoss()
|
|
187
|
-
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
188
|
-
shift_labels = shift_labels.view(-1)
|
|
189
|
-
# Enable model parallelism
|
|
190
|
-
shift_labels = shift_labels.to(shift_logits.device)
|
|
191
|
-
loss = loss_fct(shift_logits, shift_labels)
|
|
192
|
-
|
|
193
|
-
if not return_dict:
|
|
194
|
-
output = (logits,) + outputs[1:]
|
|
195
|
-
return (loss,) + output if loss is not None else output
|
|
133
|
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
|
|
196
134
|
|
|
197
135
|
return Qwen2VLCausalLMOutputWithPast(
|
|
198
136
|
loss=loss,
|
|
@@ -200,5 +138,5 @@ def lce_forward(
|
|
|
200
138
|
past_key_values=outputs.past_key_values,
|
|
201
139
|
hidden_states=outputs.hidden_states,
|
|
202
140
|
attentions=outputs.attentions,
|
|
203
|
-
rope_deltas=rope_deltas,
|
|
141
|
+
rope_deltas=outputs.rope_deltas,
|
|
204
142
|
)
|
|
@@ -22,6 +22,7 @@ def lce_forward(
|
|
|
22
22
|
output_hidden_states: Optional[bool] = None,
|
|
23
23
|
cache_position: Optional[torch.LongTensor] = None,
|
|
24
24
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
25
|
+
skip_logits: Optional[bool] = None,
|
|
25
26
|
**kwargs,
|
|
26
27
|
) -> CausalLMOutputWithPast:
|
|
27
28
|
r"""
|
|
@@ -82,8 +83,15 @@ def lce_forward(
|
|
|
82
83
|
shift_labels = kwargs.pop("shift_labels", None)
|
|
83
84
|
logits = None
|
|
84
85
|
loss = None
|
|
85
|
-
|
|
86
|
-
if
|
|
86
|
+
|
|
87
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
88
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
89
|
+
|
|
90
|
+
if skip_logits is None:
|
|
91
|
+
# By default, if in training mode, don't materialize logits
|
|
92
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
93
|
+
|
|
94
|
+
if skip_logits:
|
|
87
95
|
loss = LigerForCausalLMLoss(
|
|
88
96
|
hidden_states=kept_hidden_states,
|
|
89
97
|
lm_head_weight=self.lm_head.weight,
|
|
@@ -93,7 +101,7 @@ def lce_forward(
|
|
|
93
101
|
**kwargs,
|
|
94
102
|
)
|
|
95
103
|
|
|
96
|
-
else:
|
|
104
|
+
else:
|
|
97
105
|
logits = self.lm_head(kept_hidden_states)
|
|
98
106
|
if labels is not None:
|
|
99
107
|
loss = self.loss_function(
|
|
@@ -25,7 +25,8 @@ def lce_forward(
|
|
|
25
25
|
output_router_logits: Optional[bool] = None,
|
|
26
26
|
cache_position: Optional[torch.LongTensor] = None,
|
|
27
27
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
28
|
-
|
|
28
|
+
skip_logits: Optional[bool] = None,
|
|
29
|
+
**kwargs,
|
|
29
30
|
) -> MoeCausalLMOutputWithPast:
|
|
30
31
|
r"""
|
|
31
32
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -80,6 +81,7 @@ def lce_forward(
|
|
|
80
81
|
output_hidden_states=output_hidden_states,
|
|
81
82
|
output_router_logits=output_router_logits,
|
|
82
83
|
cache_position=cache_position,
|
|
84
|
+
**kwargs,
|
|
83
85
|
)
|
|
84
86
|
|
|
85
87
|
hidden_states = outputs.last_hidden_state
|
|
@@ -87,24 +89,26 @@ def lce_forward(
|
|
|
87
89
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
88
90
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
89
91
|
|
|
90
|
-
shift_labels =
|
|
92
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
91
93
|
logits = None
|
|
92
94
|
loss = None
|
|
93
95
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
+
if skip_logits is None:
|
|
97
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
98
|
+
|
|
99
|
+
if skip_logits:
|
|
96
100
|
loss = LigerForCausalLMLoss(
|
|
97
101
|
hidden_states=kept_hidden_states,
|
|
98
102
|
lm_head_weight=self.lm_head.weight,
|
|
99
103
|
labels=labels,
|
|
100
104
|
shift_labels=shift_labels,
|
|
101
105
|
hidden_size=self.config.hidden_size,
|
|
102
|
-
**
|
|
106
|
+
**kwargs,
|
|
103
107
|
)
|
|
104
108
|
else: # if in inference model materialize logits
|
|
105
109
|
logits = self.lm_head(kept_hidden_states)
|
|
106
110
|
if labels is not None:
|
|
107
|
-
loss = self.loss_function(logits, labels, self.vocab_size, **
|
|
111
|
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
|
108
112
|
|
|
109
113
|
aux_loss = None
|
|
110
114
|
if output_router_logits:
|