liger-kernel-nightly 0.6.3.dev20251031170118__py3-none-any.whl → 0.6.3.dev20251105012545__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,6 +56,8 @@ if TYPE_CHECKING:
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
 
@@ -120,6 +122,8 @@ def __getattr__(name: str):
      "apply_liger_kernel_to_qwen3",
      "apply_liger_kernel_to_qwen3_moe",
      "apply_liger_kernel_to_qwen3_next",
+     "apply_liger_kernel_to_qwen3_vl",
+     "apply_liger_kernel_to_qwen3_vl_moe",
      "apply_liger_kernel_to_smollm3",
      "apply_liger_kernel_to_smolvlm",
  }
@@ -190,6 +194,8 @@ if _TRANSFORMERS_AVAILABLE:
      "apply_liger_kernel_to_qwen3",
      "apply_liger_kernel_to_qwen3_moe",
      "apply_liger_kernel_to_qwen3_next",
+     "apply_liger_kernel_to_qwen3_vl",
+     "apply_liger_kernel_to_qwen3_vl_moe",
      "apply_liger_kernel_to_smollm3",
      "apply_liger_kernel_to_smolvlm",
  ]
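
Taken together, the three hunks above register the new Qwen3-VL entry points in `liger_kernel/transformers/__init__.py`: the `TYPE_CHECKING` imports for static analyzers, the lazy-import name set consulted by `__getattr__`, and the public `__all__` list. A minimal usage sketch, assuming a `transformers` release that ships Qwen3-VL (the checkpoint id below is a placeholder):

```python
# Patch the Qwen3-VL classes before instantiating the model so that the
# module-level swaps (RoPE, RMSNorm, forward) take effect everywhere.
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

apply_liger_kernel_to_qwen3_vl()

from transformers import Qwen3VLForConditionalGeneration

model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")  # placeholder id
```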
@@ -0,0 +1,144 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLCausalLMOutputWithPast
+ from transformers.utils import can_return_tuple
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @can_return_tuple
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     pixel_values: Optional[torch.Tensor] = None,
+     pixel_values_videos: Optional[torch.FloatTensor] = None,
+     image_grid_thw: Optional[torch.LongTensor] = None,
+     video_grid_thw: Optional[torch.LongTensor] = None,
+     rope_deltas: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     second_per_grid_ts: Optional[torch.Tensor] = None,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, Qwen3VLCausalLMOutputWithPast]:
+     """
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+     pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
+         The tensors corresponding to the input videos. Pixel values can be obtained using
+         [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+         [`Qwen2_5_VLImageProcessor`] for processing videos.
+     image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+         The temporal, height, and width dimensions of each image's feature grid in the LLM.
+     video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+         The temporal, height, and width dimensions of each video's feature grid in the LLM.
+     rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+         The rope index difference between sequence length and multimodal rope.
+     second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+         The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+     >>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
+     >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
+
+     >>> messages = [
+     ...     {
+     ...         "role": "user",
+     ...         "content": [
+     ...             {"type": "image"},
+     ...             {"type": "text", "text": "What is shown in this image?"},
+     ...         ],
+     ...     },
+     ... ]
+     >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+     >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     >>> inputs = processor(text=[text], images=[image])
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+     ```"""
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     outputs = self.model(
+         input_ids=input_ids,
+         pixel_values=pixel_values,
+         pixel_values_videos=pixel_values_videos,
+         image_grid_thw=image_grid_thw,
+         video_grid_thw=video_grid_thw,
+         second_per_grid_ts=second_per_grid_ts,
+         position_ids=position_ids,
+         attention_mask=attention_mask,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     loss = None
+     logits = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.text_config.hidden_size,
+             **kwargs,
+         )
+     else:
+         logits = self.lm_head(hidden_states)
+
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return Qwen3VLCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         rope_deltas=outputs.rope_deltas,
+     )
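
The interesting part of this new forward is the `skip_logits` gate: during training, whenever `labels` or `shift_labels` are supplied, `LigerForCausalLMLoss` computes the loss directly from the hidden states and the `lm_head` weight, so the full `(batch * seq_len, vocab_size)` logits tensor is never materialized. A plain-PyTorch sketch of why that saves memory (chunked forward only; the real kernel is a fused Triton implementation that also handles label shifting and the backward pass):

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=1024):
    """CE(hidden @ weight.T, labels) without holding all logits at once."""
    total, count = hidden.new_zeros(()), 0
    for h, y in zip(hidden.split(chunk_size), labels.split(chunk_size)):
        logits = h @ weight.T  # only a (chunk, vocab) slice lives in memory
        mask = y != -100       # ignore masked positions, as in the HF convention
        if mask.any():
            total = total + F.cross_entropy(logits[mask], y[mask], reduction="sum")
            count += int(mask.sum())
    return total / max(count, 1)

hidden = torch.randn(8, 64)            # toy (tokens, hidden_size)
weight = torch.randn(1000, 64)         # toy (vocab_size, hidden_size)
labels = torch.randint(0, 1000, (8,))
print(chunked_linear_cross_entropy(hidden, weight, labels))
```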
@@ -0,0 +1,119 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeCausalLMOutputWithPast
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import load_balancing_loss_func
+ from transformers.utils import can_return_tuple
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @can_return_tuple
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     pixel_values: Optional[torch.Tensor] = None,
+     pixel_values_videos: Optional[torch.FloatTensor] = None,
+     image_grid_thw: Optional[torch.LongTensor] = None,
+     video_grid_thw: Optional[torch.LongTensor] = None,
+     rope_deltas: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     second_per_grid_ts: Optional[torch.Tensor] = None,
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, Qwen3VLMoeCausalLMOutputWithPast]:
+     """
+     Qwen3-VL-MoE forward with fused linear cross entropy support, mirroring the Qwen3-VL behaviour.
+     """
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     outputs = self.model(
+         input_ids=input_ids,
+         pixel_values=pixel_values,
+         pixel_values_videos=pixel_values_videos,
+         image_grid_thw=image_grid_thw,
+         video_grid_thw=video_grid_thw,
+         second_per_grid_ts=second_per_grid_ts,
+         position_ids=position_ids,
+         attention_mask=attention_mask,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     loss = None
+     logits = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     if skip_logits:
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.text_config.hidden_size,
+             **kwargs,
+         )
+     else:
+         logits = self.lm_head(hidden_states)
+
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+     # Compute auxiliary load-balancing loss for MoE when requested
+     aux_loss = None
+     if kwargs.get("output_router_logits", False):
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits,
+             self.config.text_config.num_experts,
+             self.config.text_config.num_experts_per_tok,
+             attention_mask,
+         )
+     # If we computed training loss, add the scaled aux loss to it
+     if loss is not None and aux_loss is not None:
+         loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return Qwen3VLMoeCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         rope_deltas=outputs.rope_deltas,
+         aux_loss=aux_loss,
+     )
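
Relative to the dense Qwen3-VL forward above, the MoE variant adds one step: when the caller requests `output_router_logits`, a load-balancing auxiliary loss is computed from the router logits and, if a main loss exists, folded into it scaled by `router_aux_loss_coef`. A toy illustration of the Switch-Transformer-style penalty such a function typically computes (conceptual only; the exact formula is whatever `load_balancing_loss_func` in `transformers` implements):

```python
import torch
import torch.nn.functional as F

num_experts = 4
router_logits = torch.randn(16, num_experts)            # 16 tokens, 4 experts
probs = F.softmax(router_logits, dim=-1)

frac_tokens = F.one_hot(probs.argmax(-1), num_experts).float().mean(0)  # tokens routed per expert
frac_probs = probs.mean(0)                                              # mean router prob per expert
aux_loss = num_experts * (frac_tokens * frac_probs).sum()               # equals 1.0 under uniform routing

loss = torch.tensor(2.31)        # stand-in for the main LM loss
router_aux_loss_coef = 0.001     # assumption: a typical small coefficient
loss = loss + router_aux_loss_coef * aux_loss
```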
@@ -30,10 +30,14 @@ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mi
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
  from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -1643,6 +1647,154 @@ def apply_liger_kernel_to_qwen2_5_vl(
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+ def apply_liger_kernel_to_qwen3_vl(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = False,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementations in HuggingFace Qwen3-VL models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_vl import modeling_qwen3_vl
+     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+     from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+     if rope:
+         modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+         modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+     if rms_norm:
+         modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(qwen3_vl_lce_forward, model)
+         else:
+             modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+     if model is not None and rms_norm:
+         if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+             text_model: Qwen3VLTextModel = model.language_model
+         elif isinstance(model, Qwen3VLTextModel):
+             text_model = model
+         else:
+             raise TypeError(
+                 f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+             )
+
+         _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+         if text_model is not None:
+             _patch_qwen3_vl_rms_norm(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                 _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                 self_attn = getattr(decoder_layer, "self_attn", None)
+                 if self_attn is not None:
+                     if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                         _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                     if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                         _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+ def apply_liger_kernel_to_qwen3_vl_moe(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = False,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementations in HuggingFace Qwen3-VL MoE models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+     from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+     if rope:
+         modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+         modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+     if rms_norm:
+         modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+         else:
+             modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+     if model is not None and rms_norm:
+         if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+             text_model: Qwen3VLMoeTextModel = model.language_model
+         elif isinstance(model, Qwen3VLMoeTextModel):
+             text_model = model
+         else:
+             raise TypeError(
+                 f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+             )
+
+         _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+         if text_model is not None:
+             _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                 _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                 self_attn = getattr(decoder_layer, "self_attn", None)
+                 if self_attn is not None:
+                     if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                         _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                     if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                         _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
  def apply_liger_kernel_to_phi3(
      rope: bool = True,
      cross_entropy: bool = False,
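
Both patch functions follow the same two-phase pattern as the existing Qwen2.5-VL entry: module-level swaps (the RoPE functions, the RMSNorm class, and optionally `forward`) affect models created afterwards, while the instance walk re-patches the already-built norm modules when a loaded `model` is passed in, since swapping the class cannot retroactively change existing submodules. A sketch of the post-load path (the checkpoint id is a placeholder):

```python
from transformers import Qwen3VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

# Load first, then patch: `forward` is rebound via MethodType, and the text
# stack's norms (including per-head q_norm/k_norm) are patched in place.
model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")  # placeholder id
apply_liger_kernel_to_qwen3_vl(model=model)
```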
@@ -2432,6 +2584,10 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
      "qwen3_next": apply_liger_kernel_to_qwen3_next,
+     "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+     "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+     "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+     "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
      "smollm3": apply_liger_kernel_to_smollm3,
      "phi3": apply_liger_kernel_to_phi3,
      "paligemma": apply_liger_kernel_to_paligemma,
@@ -1,3 +1,8 @@
+ from typing import Optional
+ from typing import Tuple
+
+ import torch
+
  from liger_kernel.ops.rope import LigerRopeFunction
 
 
@@ -18,3 +23,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """
 
      return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+ def liger_rotary_pos_emb_with_cast(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     position_ids: Optional[torch.Tensor] = None,
+     unsqueeze_dim: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+     q32 = q.to(torch.float32)
+     k32 = k.to(torch.float32)
+     cos32 = cos.to(torch.float32)
+     sin32 = sin.to(torch.float32)
+
+     q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+     return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+ def liger_rotary_pos_emb_with_cast_and_leading_batch(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     position_ids: Optional[torch.Tensor] = None,
+     unsqueeze_dim: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+     q32 = q.to(torch.float32).unsqueeze(0)
+     k32 = k.to(torch.float32).unsqueeze(0)
+     cos32 = cos.to(torch.float32).unsqueeze(0)
+     sin32 = sin.to(torch.float32).unsqueeze(0)
+
+     q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+     return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
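
Both wrappers bridge dtype and shape conventions between the Qwen3-VL modules and the Liger RoPE kernel: the rotation runs in fp32 and the result is cast back to the caller's dtype, and the `_leading_batch` variant additionally adds, then strips, a leading batch dimension (the `unsqueeze(0)`/`squeeze(0)` pair suggests the vision tower calls `apply_rotary_pos_emb_vision` with unbatched tensors). A pure-PyTorch reference of the cast pattern (the rotation math only, not the Triton kernel; runnable on CPU):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_with_cast_reference(q, k, cos, sin):
    orig_dtype = q.dtype
    q32, k32 = q.float(), k.float()
    cos32 = cos.float().unsqueeze(1)  # (B, 1, T, D): broadcast over heads
    sin32 = sin.float().unsqueeze(1)
    q_out = q32 * cos32 + rotate_half(q32) * sin32
    k_out = k32 * cos32 + rotate_half(k32) * sin32
    return q_out.to(orig_dtype), k_out.to(orig_dtype)

q = torch.randn(1, 4, 6, 8, dtype=torch.bfloat16)  # (B, H, T, D)
k = torch.randn(1, 4, 6, 8, dtype=torch.bfloat16)
cos = torch.randn(1, 6, 8)
sin = torch.randn(1, 6, 8)
q_out, k_out = rope_with_cast_reference(q, k, cos, sin)
assert q_out.dtype == torch.bfloat16  # fp32 compute, original dtype out
```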
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.6.3.dev20251031170118
+ Version: 0.6.3.dev20251105012545
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -42,7 +42,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=MAAd-YqPdG-j_sbrIE43nrICpA4xTg-dx6M06KWLMFU,9486
+ liger_kernel/transformers/__init__.py,sha256=iV1X0gH1JXwgeb7AeY8Ryv7q3r44MLQvSvn79yIVDzw,9874
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -59,12 +59,12 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
- liger_kernel/transformers/monkey_patch.py,sha256=3DLFMn2VusVcR6C5YElfpHJBRoJxvho0a2JoVdGqxHA,117266
+ liger_kernel/transformers/monkey_patch.py,sha256=O_kl0l56oHinVv-bwl1LU5nKPm6nA0YBjKTYmmwgRbk,124732
  liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
  liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyjBN9b3FYI,3004
- liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+ liger_kernel/transformers/rope.py,sha256=VMlDZI6zss9mLaLcN5XCE_ktmYRwAi_Eh4TIgO6NrIQ,2361
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
  liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
@@ -97,15 +97,17 @@ liger_kernel/transformers/model/qwen2_vl.py,sha256=ZeasFPGs-bxm2Y_E15mo0YNx5wwtK
  liger_kernel/transformers/model/qwen3.py,sha256=Q2aOg5erPrgVgRcqJm8sefLSDtvU1AD5B7aJnP7mRMM,4956
  liger_kernel/transformers/model/qwen3_moe.py,sha256=1CwTMCNFDYsjGoa_aHFBagtC5HuJTV-s0__5UvcjD3A,5686
  liger_kernel/transformers/model/qwen3_next.py,sha256=7To7azriAogxeE7oEvByKztH9154dnDiDVNHHm7PZK4,5632
+ liger_kernel/transformers/model/qwen3_vl.py,sha256=YU76HJ0A9kG5CUaZM4i9Bzci4eeXcNl_VSC2tsPWA3k,6301
+ liger_kernel/transformers/model/qwen3_vl_moe.py,sha256=ykNIvGBtmcTkn236lhmJHzU1IHVR1Kq1YYYlJ5ynhw4,4445
  liger_kernel/transformers/model/smollm3.py,sha256=0KWVkDtXbjsBKhJnaquV6vUUYyLtfmNwYH0sxJt-qTk,7667
  liger_kernel/transformers/model/smolvlm.py,sha256=yFpPKawLVo3zXzLjM7Y_T8FyRrPxVyp-YPFMM8m3k0c,6734
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/METADATA,sha256=tIRv5lazhwtKsdhSattKCeY8GFJaJgIXFrPQXIXNd6E,24777
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/RECORD,,
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/METADATA,sha256=MKC5NuGeIkIrDXRVDM3wv-p0cyVbwya5NujVcmSz-mQ,24777
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.6.3.dev20251105012545.dist-info/RECORD,,