liger-kernel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +0 -0
- liger_kernel/chunked_loss/dpo_loss.py +57 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +206 -0
- liger_kernel/chunked_loss/orpo_loss.py +63 -0
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +143 -30
- liger_kernel/ops/fused_linear_cross_entropy.py +20 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/rms_norm.py +27 -6
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +34 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +106 -64
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/METADATA +18 -82
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/RECORD +23 -16
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.group_norm import LigerGroupNormFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerGroupNorm(nn.Module):
    """Group Normalization module that delegates the computation to ``LigerGroupNormFunction``."""

    def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
        """
        A Group Normalization layer.

        Args:
            num_channels (int): Number of channels in the input tensor.
            num_groups (int): Number of groups to divide the channels into. Must divide
                ``num_channels`` evenly.
            eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
            bias (bool, optional): If ``True``, the bias is initialized with ``torch.randn``;
                otherwise it is initialized to zeros. In both cases the bias is a learnable
                ``nn.Parameter``. Default: ``False``.
            init_fn (str, optional): Initialization of the weight, ``"ones"`` or ``"zeros"``. Default: "ones".

        Raises:
            AssertionError: (when assertions are enabled) if ``init_fn`` is not ``"ones"``/``"zeros"``
                or ``num_channels`` is not divisible by ``num_groups``.
        """
        super().__init__()
        assert init_fn in [
            "ones",
            "zeros",
        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"

        assert (
            num_channels % num_groups == 0
        ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.eps = eps
        self.weight = nn.Parameter(
            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
        )
        # A bias parameter is always created (zeros when ``bias=False``) and always passed to
        # the kernel in ``forward``. NOTE(review): presumably the fused kernel requires a bias
        # tensor; confirm before making it optional.
        self.bias = nn.Parameter(
            torch.randn(num_channels) if bias else torch.zeros(num_channels)
        )
        # Same value as ``self.eps``; this attribute is the one forwarded to the kernel.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """
        Apply group normalization via ``LigerGroupNormFunction``.

        Args:
            hidden_states (torch.Tensor): Input of shape ``(batch_size, num_channels, *)``
                with at least 3 dimensions.

        Returns:
            torch.Tensor: The result of ``LigerGroupNormFunction.apply`` (presumably the
            normalized tensor with the input's shape; confirm against the kernel).

        Raises:
            AssertionError: (when assertions are enabled) if the input has fewer than
                3 dimensions or its channel dimension differs from ``num_channels``.
        """
        # hidden_states: (batch_size, num_channels, *)
        assert (
            hidden_states.dim() >= 3
        ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}"
        assert (
            hidden_states.size(1) == self.num_channels
        ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
        return LigerGroupNormFunction.apply(
            hidden_states,
            self.weight,
            self.bias,
            self.num_channels,
            self.num_groups,
            self.variance_epsilon,
        )

    def extra_repr(self):
        """Return the module summary shown by ``repr(module)``; uses only attributes set in ``__init__``."""
        return f"num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.nn import CrossEntropyLoss
|
|
6
|
+
from transformers.cache_utils import HybridCache
|
|
7
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
8
|
+
from transformers.models.gemma2.modeling_gemma2 import (
|
|
9
|
+
_CONFIG_FOR_DOC,
|
|
10
|
+
GEMMA2_INPUTS_DOCSTRING,
|
|
11
|
+
)
|
|
12
|
+
from transformers.utils import (
|
|
13
|
+
add_start_docstrings_to_model_forward,
|
|
14
|
+
replace_return_docstrings,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
|
18
|
+
LigerFusedLinearCrossEntropyLoss,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Module-level logger. NOTE(review): the forwards below call `logger.warning_once`, which is
# not a stdlib `logging.Logger` method; presumably it is attached to `logging.Logger` by
# transformers when it is imported. Confirm if the transformers imports are ever removed.
logger = logging.getLogger(name=__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Gemma2 causal-LM forward (older variant without `num_logits_to_keep` / `**loss_kwargs`).

    In training mode with `labels`, the loss is computed by Liger's fused linear cross entropy
    directly from the hidden states, so logits are never materialized (the returned `logits`
    is `None`). Otherwise logits are computed (soft-capped when
    `config.final_logit_softcapping` is set) and, if `labels` are given, a mean cross-entropy
    loss is computed on float-upcast logits.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM
    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n, then flatten to (num_tokens, hidden_size).
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        # Enable model parallelism: the fused kernel reads labels alongside the hidden
        # states, so keep them on the same device (the eval branch does the same).
        shift_labels = shift_labels.view(-1).to(shift_hidden_states.device)

        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping
        )
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            # Soft-cap: softcap * tanh(logits / softcap).
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Gemma2 causal-LM forward that, in training mode with `labels`, computes the loss with Liger's
    fused linear cross entropy instead of materializing logits (the returned `logits` is `None`).
    Otherwise logits are computed for the last `num_logits_to_keep` positions (soft-capped when
    `config.final_logit_softcapping` is set) and the loss, if `labels` are given, is delegated to
    `self.loss_function`.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

            loss_kwargs:
                Extra loss arguments. If `num_items_in_batch` is present and not `None`, the training
                loss is summed and divided by it; otherwise it is averaged. In the non-fused path they are
                forwarded to `self.loss_function`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        # Enable model parallelism: keep labels on the same device as the hidden states
        # the fused kernel consumes.
        shift_labels = shift_labels.view(-1).to(shift_hidden_states.device)

        # Decide on the value, not on key presence: an explicit `num_items_in_batch=None`
        # must fall back to "mean" (as transformers' own causal-LM loss does) instead of
        # selecting "sum" and then dividing by None.
        num_items_in_batch = loss_kwargs.get("num_items_in_batch")
        reduction = "sum" if num_items_in_batch is not None else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping,
            reduction=reduction,
        )

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            # Out-of-place division; avoids mutating the autograd Function's output in place.
            loss = loss / num_items_in_batch

    else:  # if in inference mode materialize logits
        # `-0:` selects every position, so num_logits_to_keep == 0 keeps all logits.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if self.config.final_logit_softcapping is not None:
            # Soft-cap: softcap * tanh(logits / softcap).
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from typing import List, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from packaging import version
|
|
4
5
|
from torch.nn import CrossEntropyLoss
|
|
6
|
+
from transformers import __version__ as transformers_version
|
|
5
7
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
|
6
8
|
_CONFIG_FOR_DOC,
|
|
7
9
|
QWEN2_VL_INPUTS_DOCSTRING,
|
|
@@ -80,8 +82,6 @@ def lce_forward(
|
|
|
80
82
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
81
83
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
82
84
|
```"""
|
|
83
|
-
# FIXME: The code is outdated and not compatible with transformer >= 4.46.1
|
|
84
|
-
|
|
85
85
|
output_attentions = (
|
|
86
86
|
output_attentions
|
|
87
87
|
if output_attentions is not None
|
|
@@ -100,27 +100,53 @@ def lce_forward(
|
|
|
100
100
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
|
101
101
|
if pixel_values is not None:
|
|
102
102
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
|
103
|
-
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
104
|
-
|
|
103
|
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
104
|
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
|
105
|
+
n_image_features = image_embeds.shape[0]
|
|
106
|
+
if n_image_tokens != n_image_features:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
109
|
+
)
|
|
110
|
+
image_mask = (
|
|
111
|
+
(input_ids == self.config.image_token_id)
|
|
112
|
+
.unsqueeze(-1)
|
|
113
|
+
.expand_as(inputs_embeds)
|
|
114
|
+
.to(inputs_embeds.device)
|
|
105
115
|
)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
inputs_embeds[image_mask] = image_embeds
|
|
116
|
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
117
|
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
|
118
|
+
|
|
110
119
|
if pixel_values_videos is not None:
|
|
111
120
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
|
112
|
-
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
113
|
-
|
|
121
|
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
122
|
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
|
123
|
+
n_video_features = video_embeds.shape[0]
|
|
124
|
+
if n_video_tokens != n_video_features:
|
|
125
|
+
raise ValueError(
|
|
126
|
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
|
127
|
+
)
|
|
128
|
+
video_mask = (
|
|
129
|
+
(input_ids == self.config.video_token_id)
|
|
130
|
+
.unsqueeze(-1)
|
|
131
|
+
.expand_as(inputs_embeds)
|
|
132
|
+
.to(inputs_embeds.device)
|
|
114
133
|
)
|
|
115
|
-
|
|
116
|
-
inputs_embeds
|
|
134
|
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
135
|
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
|
136
|
+
|
|
117
137
|
if attention_mask is not None:
|
|
118
138
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
119
|
-
|
|
120
|
-
if
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
139
|
+
|
|
140
|
+
if version.parse(transformers_version) > version.parse("4.46.2"):
|
|
141
|
+
# NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
|
|
142
|
+
# https://github.com/huggingface/transformers/issues/33401
|
|
143
|
+
# While correct, this breaks equivalence with past versions of Qwen2-VL from
|
|
144
|
+
# transformers and leads to failed tests or users noticing differences in results.
|
|
145
|
+
# TODO: remove above conditional when liger drops support for transformers<4.47.0
|
|
146
|
+
if position_ids is None and input_ids is not None:
|
|
147
|
+
position_ids, _ = self.get_rope_index(
|
|
148
|
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
|
149
|
+
)
|
|
124
150
|
|
|
125
151
|
outputs = self.model(
|
|
126
152
|
input_ids=None,
|