liger-kernel-nightly 0.5.10.dev20250610174206__py3-none-any.whl → 0.5.10.dev20250611064616__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -5,12 +5,13 @@ from typing import Union
 
 import torch
 
-from torch.nn import CrossEntropyLoss
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -34,14 +35,22 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
     Example:
 
@@ -73,78 +82,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -180,19 +131,10 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -204,5 +146,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
     )
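For context on the loss change above: the removed block shifted logits and labels by hand and ran torch's `CrossEntropyLoss`, whereas the new eval path defers to `self.loss_function(logits=..., labels=..., vocab_size=...)`, which in recent transformers applies the same next-token shift and `-100` masking internally (training still routes through `LigerForCausalLMLoss`). A minimal, self-contained sketch of that shifted cross entropy follows; `manual_shifted_ce` is an illustrative name and not part of either library.

```python
# Sketch only: reproduces what the deleted block computed, and what the
# model's loss_function is expected to compute for a causal LM.
import torch
import torch.nn.functional as F


def manual_shifted_ce(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Predict token t+1 from position t; positions labeled -100 are ignored.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return F.cross_entropy(shift_logits.float(), shift_labels, ignore_index=-100)


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 32
    logits = torch.randn(2, 8, vocab_size)
    labels = torch.randint(0, vocab_size, (2, 8))
    labels[:, :2] = -100  # masked prompt positions contribute no loss
    print(manual_shifted_ce(logits, labels, vocab_size))
```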
liger_kernel/transformers/model/qwen2_vl.py

@@ -5,14 +5,13 @@ from typing import Union
 
 import torch
 
-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,15 +34,20 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
 
     Example:
 
@@ -75,80 +79,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -184,23 +127,10 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     return Qwen2VLCausalLMOutputWithPast(
         loss=loss,
@@ -208,5 +138,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
    )
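Both patched forwards are now wrapped in `@can_return_tuple` from `transformers.utils`, which is why the manual `if not return_dict:` tuple-building branch could be removed from this file. The sketch below only illustrates the idea of such a decorator under stated assumptions; the names are hypothetical and it is not the transformers implementation.

```python
# Conceptual sketch: a decorator in the spirit of can_return_tuple lets the forward
# always build its dataclass output and converts it to a plain tuple on request.
from dataclasses import astuple, dataclass
from functools import wraps
from typing import Optional


@dataclass
class ToyOutput:
    loss: Optional[float] = None
    logits: Optional[list] = None


def toy_can_return_tuple(fn):
    @wraps(fn)
    def wrapper(*args, return_dict: Optional[bool] = None, **kwargs):
        out = fn(*args, **kwargs)
        if return_dict is False:
            # Drop None fields and hand back a tuple, mimicking the old manual branch.
            return tuple(v for v in astuple(out) if v is not None)
        return out

    return wrapper


@toy_can_return_tuple
def toy_forward(x: list) -> ToyOutput:
    return ToyOutput(loss=None, logits=x)


print(toy_forward([1, 2, 3]))                     # ToyOutput(loss=None, logits=[1, 2, 3])
print(toy_forward([1, 2, 3], return_dict=False))  # ([1, 2, 3],)
```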
liger_kernel/transformers/monkey_patch.py

@@ -1225,7 +1225,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not available in transformers<4.45.0
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1239,12 +1239,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
@@ -1266,24 +1273,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        # get the base model from the model instance
-        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if hasattr(model, "visual"):
-            # Patch Qwen2VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)
 
-        if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_5_vl(
@@ -1309,12 +1330,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
 
@@ -1333,24 +1361,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        # get the base model from the model instance
-        base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
            )
 
-        if hasattr(model, "visual"):
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
 
-        if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -1571,7 +1612,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
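The monkey patch now refuses to run on transformers older than 4.52.4, resolves the text and vision submodules from the wrapper classes instead of `base_model_prefix`, and registers `qwen2_vl_text` / `qwen2_5_vl_text` so the text-only backbone is also matched by model type. A hedged usage sketch follows, built only from the parameter names visible in this diff; the checkpoint name is just an example, and the exact defaults and exports should be checked against the installed liger-kernel version.

```python
# Sketch: patching an already-instantiated Qwen2-VL model in place (assumptions noted above).
import torch
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",  # example checkpoint
    torch_dtype=torch.bfloat16,
)

# Swaps RMSNorm/SwiGLU modules on the text backbone and, with
# fused_linear_cross_entropy=True, installs the lce_forward shown above.
# On transformers < 4.52.4 this call now logs a warning and returns without patching.
apply_liger_kernel_to_qwen2_vl(
    model=model,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
)
```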
*.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.10.dev20250610174206
+Version: 0.5.10.dev20250611064616
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
*.dist-info/RECORD

@@ -53,7 +53,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=zeqmbU__X965iSZ4ZO0Zq3kq6qvqfSgU7B3acxynL3Y,74605
+liger_kernel/transformers/monkey_patch.py,sha256=IWqNiimHL0895yo0TjQ3lN_Y8fKGesxC-bF5He6zB2g,77536
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
@@ -79,17 +79,17 @@ liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6B
 liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
 liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
 liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=k6jt1bTCJsKsZVGhTxqIbDzmnL8-B3CpWJOjLazswbo,9203
-liger_kernel/transformers/model/qwen2_vl.py,sha256=Cgs7-nPlKFifiDO9gqSI6np4vRUVCKiqoospT_vIi_M,9614
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=zEVVwotCXnAm3RRc8-1Nc8uitSWrwW4B9dYY2uOZDwg,6331
+liger_kernel/transformers/model/qwen2_vl.py,sha256=5vK-vtCDpKZ2w33xYp2BS8kQYWUbKMqaiKvQcI27Mss,5884
 liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
 liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/METADATA,sha256=T6CCI8j-_GLD4_OTFov5VFLiGK7sITnt6Ht6zVDPhqw,24309
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.10.dev20250610174206.dist-info/RECORD,,
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/METADATA,sha256=49fgwei-BXjeHZEaeZCYnr3xfKN_psmRALV3GmK6YCk,24309
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250611064616.dist-info/RECORD,,