liger-kernel-nightly 0.5.10.dev20250610174206__py3-none-any.whl → 0.5.10.dev20250611064616__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It reflects the changes between those versions as they appear in the registry and is provided for informational purposes only.
- liger_kernel/transformers/model/qwen2_5_vl.py +29 -87
- liger_kernel/transformers/model/qwen2_vl.py +26 -96
- liger_kernel/transformers/monkey_patch.py +66 -23
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250610174206.dist-info → liger_kernel_nightly-0.5.10.dev20250611064616.dist-info}/top_level.txt +0 -0
--- a/liger_kernel/transformers/model/qwen2_5_vl.py
+++ b/liger_kernel/transformers/model/qwen2_5_vl.py
@@ -5,12 +5,13 @@ from typing import Union

 import torch

-from torch.nn import CrossEntropyLoss
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -34,14 +35,22 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
-
-
-
-
-
-
-
-
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.

     Example:

@@ -73,78 +82,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -180,19 +131,10 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -204,5 +146,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
     )
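The block removed above computed the causal-LM loss by hand (float cast, shift, flatten, `CrossEntropyLoss`); the new code hands the same computation to `self.loss_function`, while the fused Liger path still goes through the `LigerForCausalLMLoss` import kept at the top of the file. For reference, a minimal standalone sketch of that shifted cross entropy; the tensor shapes are purely illustrative and not taken from the package:

```python
import torch
import torch.nn.functional as F

def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Predict token t+1 from the hidden state at position t, ignoring positions
    # labeled -100, exactly as the removed CrossEntropyLoss block did.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits.float(), shift_labels, ignore_index=-100)

logits = torch.randn(2, 8, 32)          # (batch, seq_len, vocab_size) -- toy sizes
labels = torch.randint(0, 32, (2, 8))   # real labels would mask prompt tokens with -100
print(shifted_cross_entropy(logits, labels, vocab_size=32))
```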
--- a/liger_kernel/transformers/model/qwen2_vl.py
+++ b/liger_kernel/transformers/model/qwen2_vl.py
@@ -5,14 +5,13 @@ from typing import Union

 import torch

-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple

 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,15 +34,20 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
-
-
-
-
-
-
-
-
-
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.

     Example:

@@ -75,80 +79,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -184,23 +127,10 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

     return Qwen2VLCausalLMOutputWithPast(
         loss=loss,
@@ -208,5 +138,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
     )
--- a/liger_kernel/transformers/monkey_patch.py
+++ b/liger_kernel/transformers/monkey_patch.py
@@ -1225,7 +1225,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4

     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1239,12 +1239,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

@@ -1266,24 +1273,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-
-
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )

-
-
-            for vision_block in
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)

-
-
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


 def apply_liger_kernel_to_qwen2_5_vl(
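For orientation, a usage sketch of the entry point patched above; this is not taken from the diff. The keyword names mirror the options referenced in the function's docstring, the import path follows the module layout shown in the RECORD file, and the checkpoint name and explicit argument values are only assumptions for illustration:

```python
# Usage sketch (assumed import path and explicit kwargs; not from this diff).
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Patch the already-instantiated model in place; with the change above this also
# covers the case where `model` is a bare Qwen2VLModel or Qwen2VLTextModel.
apply_liger_kernel_to_qwen2_vl(
    rms_norm=True,                    # text-model RMSNorm layers
    layer_norm=True,                  # vision-block LayerNorms
    swiglu=True,                      # decoder MLPs -> LigerSwiGLUMLP
    fused_linear_cross_entropy=True,  # route the loss through the Liger lce_forward path
    model=model,
)
```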
@@ -1309,12 +1330,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward

@@ -1333,24 +1361,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-
-
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )

-        if
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)

-        if
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


 def apply_liger_kernel_to_phi3(
@@ -1571,7 +1612,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
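The two new `*_text` keys extend model_type-based dispatch to the standalone text backbones, assuming (as the key names suggest) that those backbones report `config.model_type` values of `"qwen2_vl_text"` and `"qwen2_5_vl_text"`. A sketch of such a dispatch follows; `patch_with_liger` is a hypothetical helper, not part of the package, while `MODEL_TYPE_TO_APPLY_LIGER_FN` is the mapping patched above:

```python
# Illustrative dispatch helper (hypothetical name); the mapping itself lives in
# liger_kernel/transformers/monkey_patch.py as shown in the hunk above.
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

def patch_with_liger(model, **liger_kwargs) -> None:
    model_type = model.config.model_type
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger patch registered for model_type={model_type!r}")
    # e.g. "qwen2_vl_text" now resolves to apply_liger_kernel_to_qwen2_vl
    apply_fn(model=model, **liger_kwargs)
```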
--- a/liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/RECORD
@@ -53,7 +53,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=IWqNiimHL0895yo0TjQ3lN_Y8fKGesxC-bF5He6zB2g,77536
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
@@ -79,17 +79,17 @@ liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6B
 liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
 liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
 liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=
-liger_kernel/transformers/model/qwen2_vl.py,sha256=
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=zEVVwotCXnAm3RRc8-1Nc8uitSWrwW4B9dYY2uOZDwg,6331
+liger_kernel/transformers/model/qwen2_vl.py,sha256=5vK-vtCDpKZ2w33xYp2BS8kQYWUbKMqaiKvQcI27Mss,5884
 liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
 liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/METADATA,sha256=49fgwei-BXjeHZEaeZCYnr3xfKN_psmRALV3GmK6YCk,24309
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/RECORD,,