liger-kernel-nightly 0.6.3.dev20251031170118__py3-none-any.whl → 0.6.3.dev20251101160510__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/transformers/__init__.py

@@ -56,6 +56,8 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe  # noqa: F401
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm  # noqa: F401

@@ -120,6 +122,8 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
         "apply_liger_kernel_to_qwen3_next",
+        "apply_liger_kernel_to_qwen3_vl",
+        "apply_liger_kernel_to_qwen3_vl_moe",
         "apply_liger_kernel_to_smollm3",
         "apply_liger_kernel_to_smolvlm",
     }
@@ -190,6 +194,8 @@ if _TRANSFORMERS_AVAILABLE:
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
         "apply_liger_kernel_to_qwen3_next",
+        "apply_liger_kernel_to_qwen3_vl",
+        "apply_liger_kernel_to_qwen3_vl_moe",
         "apply_liger_kernel_to_smollm3",
         "apply_liger_kernel_to_smolvlm",
     ]
liger_kernel/transformers/model/qwen3_vl.py (new file)

@@ -0,0 +1,144 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+@can_return_tuple
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    pixel_values: Optional[torch.Tensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    second_per_grid_ts: Optional[torch.Tensor] = None,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, Qwen3VLCausalLMOutputWithPast]:
+    """
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)`):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+    Example:
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+    >>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": [
+    ...             {"type": "image"},
+    ...             {"type": "text", "text": "What is shown in this image?"},
+    ...         ],
+    ...     },
+    ... ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    loss = None
+    logits = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **kwargs,
+        )
+    else:
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return Qwen3VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=outputs.rope_deltas,
+    )
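
The `skip_logits` gating above is the heart of the patch: while training with labels, the fused Liger loss consumes the hidden states directly and the logits are never materialized. A minimal runnable sketch of that decision, written as a hypothetical standalone helper (not part of the package):

```python
def resolve_skip_logits(skip_logits, training, labels=None, shift_labels=None):
    # Mirrors the gating in lce_forward above (hypothetical helper for illustration).
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: take the fused path only while training with labels available.
        skip_logits = training and (labels is not None or shift_labels is not None)
    return skip_logits

assert resolve_skip_logits(None, training=True, labels=[0]) is True    # fused loss, logits stay None
assert resolve_skip_logits(None, training=False, labels=[0]) is False  # eval: logits materialized
assert resolve_skip_logits(None, training=True) is False               # no labels: logits materialized
```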
liger_kernel/transformers/model/qwen3_vl_moe.py (new file)

@@ -0,0 +1,121 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+    Qwen3VLMoeCausalLMOutputWithPast,
+    load_balancing_loss_func,
+)
+from transformers.utils import can_return_tuple
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+@can_return_tuple
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    pixel_values: Optional[torch.Tensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    second_per_grid_ts: Optional[torch.Tensor] = None,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, Qwen3VLMoeCausalLMOutputWithPast]:
+    """
+    Qwen3-VL-MoE forward with fused linear cross entropy support, mirroring Qwen3-VL behaviour.
+    """
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    loss = None
+    logits = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **kwargs,
+        )
+    else:
+        logits = self.lm_head(hidden_states)
+
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+    # Compute auxiliary load-balancing loss for MoE when requested
+    aux_loss = None
+    if kwargs.get("output_router_logits", False):
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.config.text_config.num_experts,
+            self.config.text_config.num_experts_per_tok,
+            attention_mask,
+        )
+    # If we computed training loss, add the scaled aux loss to it
+    if loss is not None and aux_loss is not None:
+        loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return Qwen3VLMoeCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=outputs.rope_deltas,
+        aux_loss=aux_loss,
+    )
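
When the caller passes `output_router_logits=True`, the MoE forward adds the scaled load-balancing loss onto the main loss. A toy illustration of that combination; the numbers are made up, and in the real path the coefficient comes from `config.text_config.router_aux_loss_coef` and `aux_loss` from `load_balancing_loss_func`:

```python
import torch

router_aux_loss_coef = 0.001  # illustrative value
loss = torch.tensor(2.5)      # stand-in for the language modeling loss
aux_loss = torch.tensor(4.0)  # stand-in for the load-balancing loss

loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)
print(loss)  # tensor(2.5040)
```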
liger_kernel/transformers/monkey_patch.py

@@ -6,6 +6,7 @@ from types import MethodType
 from typing import Callable
 from typing import Optional
 
+import torch
 import transformers
 
 from packaging import version
@@ -30,10 +31,16 @@ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mi
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
 from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
-from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import (
+    liger_rotary_pos_emb,
+    liger_rotary_pos_emb_with_cast,
+    liger_rotary_pos_emb_with_cast_and_leading_batch,
+)
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -57,6 +64,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
+
 def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
     # Check if the module is a PEFT ModulesToSaveWrapper
     # If it is, we need to patch the modules_to_save.default and original_modules
@@ -1643,6 +1651,157 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+
+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,
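
A hedged usage sketch for the two new entry points (the checkpoint id is illustrative; the keyword values shown are just the defaults from the signatures above):

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

# Option 1: patch the transformers classes before the model is constructed.
apply_liger_kernel_to_qwen3_vl(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

# Option 2: patch an already-instantiated model in place.
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration

model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # illustrative id
apply_liger_kernel_to_qwen3_vl(model=model)
```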
@@ -2432,6 +2590,10 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
     "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
liger_kernel/transformers/rope.py

@@ -1,3 +1,7 @@
+from typing import Optional, Tuple
+
+import torch
+
 from liger_kernel.ops.rope import LigerRopeFunction
 
 
@@ -18,3 +22,46 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_with_cast(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Run RoPE in float32 for numerical stability, then cast back to the inputs' dtypes.
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32)
+    k32 = k.to(torch.float32)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+def liger_rotary_pos_emb_with_cast_and_leading_batch(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Same float32 round-trip as above, but adds a leading batch dimension for callers
+    # that pass unbatched tensors (e.g. the vision tower) and strips it on the way out.
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32).unsqueeze(0)
+    k32 = k.to(torch.float32).unsqueeze(0)
+    cos32 = cos.to(torch.float32).unsqueeze(0)
+    sin32 = sin.to(torch.float32).unsqueeze(0)
+
+    q_out, k_out = liger_rotary_pos_emb(
+        q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim
+    )
+    return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
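
A quick sanity check for the cast wrapper above. The shapes follow the usual HF attention layout and are assumptions on my part, and `LigerRopeFunction` is a Triton kernel, so this needs a CUDA device:

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast

B, H, T, D = 2, 4, 16, 64  # illustrative batch, heads, seq len, head dim
q = torch.randn(B, H, T, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, T, D, dtype=torch.bfloat16, device="cuda")
cos = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")
sin = torch.randn(B, T, D, dtype=torch.bfloat16, device="cuda")

q_out, k_out = liger_rotary_pos_emb_with_cast(q, k, cos, sin)
# The rotation runs in float32 internally; results come back in the original dtypes.
assert q_out.dtype == torch.bfloat16 and k_out.dtype == torch.bfloat16
assert q_out.shape == q.shape and k_out.shape == k.shape
```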
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.6.3.dev20251031170118
+Version: 0.6.3.dev20251101160510
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/RECORD

@@ -42,7 +42,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
 liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
 liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
-liger_kernel/transformers/__init__.py,sha256=MAAd-YqPdG-j_sbrIE43nrICpA4xTg-dx6M06KWLMFU,9486
+liger_kernel/transformers/__init__.py,sha256=iV1X0gH1JXwgeb7AeY8Ryv7q3r44MLQvSvn79yIVDzw,9874
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
 liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -59,12 +59,12 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
 liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
-liger_kernel/transformers/monkey_patch.py,sha256=3DLFMn2VusVcR6C5YElfpHJBRoJxvho0a2JoVdGqxHA,117266
+liger_kernel/transformers/monkey_patch.py,sha256=Qo5phPCiSF_w29R5AiDO382penkmzuEijv_iNenuuHc,124681
 liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
 liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyjBN9b3FYI,3004
-liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
+liger_kernel/transformers/rope.py,sha256=SoOyYArsioIQzp6eZo6vnFumISf06Gl3O8WWkMmr-gQ,2360
 liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
 liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
 liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
97
97
  liger_kernel/transformers/model/qwen3.py,sha256=Q2aOg5erPrgVgRcqJm8sefLSDtvU1AD5B7aJnP7mRMM,4956
98
98
  liger_kernel/transformers/model/qwen3_moe.py,sha256=1CwTMCNFDYsjGoa_aHFBagtC5HuJTV-s0__5UvcjD3A,5686
99
99
  liger_kernel/transformers/model/qwen3_next.py,sha256=7To7azriAogxeE7oEvByKztH9154dnDiDVNHHm7PZK4,5632
100
+ liger_kernel/transformers/model/qwen3_vl.py,sha256=YU76HJ0A9kG5CUaZM4i9Bzci4eeXcNl_VSC2tsPWA3k,6301
101
+ liger_kernel/transformers/model/qwen3_vl_moe.py,sha256=0WuGA-pg5hzKPKc_B3d32qyzXMlkVi3_wlNu9d0KLOg,4392
100
102
  liger_kernel/transformers/model/smollm3.py,sha256=0KWVkDtXbjsBKhJnaquV6vUUYyLtfmNwYH0sxJt-qTk,7667
101
103
  liger_kernel/transformers/model/smolvlm.py,sha256=yFpPKawLVo3zXzLjM7Y_T8FyRrPxVyp-YPFMM8m3k0c,6734
102
104
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
103
105
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
104
106
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
105
107
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
106
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
107
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/METADATA,sha256=tIRv5lazhwtKsdhSattKCeY8GFJaJgIXFrPQXIXNd6E,24777
108
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
109
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
110
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
111
- liger_kernel_nightly-0.6.3.dev20251031170118.dist-info/RECORD,,
108
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
109
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/METADATA,sha256=rsY01xVUY_8qxjoUXKklmwMso2nGFtFS5caQA2iDGlE,24777
110
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
111
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
112
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
113
+ liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/RECORD,,