liger-kernel-nightly 0.6.3.dev20251028143010__py3-none-any.whl → 0.6.3.dev20251101160510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/model/qwen3_vl.py +144 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +121 -0
- liger_kernel/transformers/monkey_patch.py +163 -1
- liger_kernel/transformers/rope.py +47 -0
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/RECORD +11 -9
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.3.dev20251028143010.dist-info → liger_kernel_nightly-0.6.3.dev20251101160510.dist-info}/top_level.txt +0 -0
|
@@ -56,6 +56,8 @@ if TYPE_CHECKING:
|
|
|
56
56
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
|
|
57
57
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
|
|
58
58
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
|
|
59
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
|
|
60
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
|
|
59
61
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
|
|
60
62
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
|
|
61
63
|
|
|
@@ -120,6 +122,8 @@ def __getattr__(name: str):
|
|
|
120
122
|
"apply_liger_kernel_to_qwen3",
|
|
121
123
|
"apply_liger_kernel_to_qwen3_moe",
|
|
122
124
|
"apply_liger_kernel_to_qwen3_next",
|
|
125
|
+
"apply_liger_kernel_to_qwen3_vl",
|
|
126
|
+
"apply_liger_kernel_to_qwen3_vl_moe",
|
|
123
127
|
"apply_liger_kernel_to_smollm3",
|
|
124
128
|
"apply_liger_kernel_to_smolvlm",
|
|
125
129
|
}
|
|
@@ -190,6 +194,8 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
190
194
|
"apply_liger_kernel_to_qwen3",
|
|
191
195
|
"apply_liger_kernel_to_qwen3_moe",
|
|
192
196
|
"apply_liger_kernel_to_qwen3_next",
|
|
197
|
+
"apply_liger_kernel_to_qwen3_vl",
|
|
198
|
+
"apply_liger_kernel_to_qwen3_vl_moe",
|
|
193
199
|
"apply_liger_kernel_to_smollm3",
|
|
194
200
|
"apply_liger_kernel_to_smolvlm",
|
|
195
201
|
]
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLCausalLMOutputWithPast
|
|
9
|
+
from transformers.utils import can_return_tuple
|
|
10
|
+
|
|
11
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@can_return_tuple
|
|
15
|
+
def lce_forward(
|
|
16
|
+
self,
|
|
17
|
+
input_ids: torch.LongTensor = None,
|
|
18
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
19
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
20
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
21
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
22
|
+
labels: Optional[torch.LongTensor] = None,
|
|
23
|
+
use_cache: Optional[bool] = None,
|
|
24
|
+
output_attentions: Optional[bool] = None,
|
|
25
|
+
output_hidden_states: Optional[bool] = None,
|
|
26
|
+
return_dict: Optional[bool] = None,
|
|
27
|
+
pixel_values: Optional[torch.Tensor] = None,
|
|
28
|
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
|
29
|
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
|
30
|
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
31
|
+
rope_deltas: Optional[torch.LongTensor] = None,
|
|
32
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
33
|
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
34
|
+
skip_logits: Optional[bool] = None,
|
|
35
|
+
**kwargs,
|
|
36
|
+
) -> Union[Tuple, Qwen3VLCausalLMOutputWithPast]:
|
|
37
|
+
"""
|
|
38
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
39
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
40
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
41
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
42
|
+
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
|
43
|
+
The tensors corresponding to the input videos. Pixel values can be obtained using
|
|
44
|
+
[`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
|
|
45
|
+
[`Qwen2_5_VLImageProcessor`] for processing videos.
|
|
46
|
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
|
47
|
+
The temporal, height and width of feature shape of each image in LLM.
|
|
48
|
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
|
49
|
+
The temporal, height and width of feature shape of each video in LLM.
|
|
50
|
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
|
51
|
+
The rope index difference between sequence length and multimodal rope.
|
|
52
|
+
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
|
53
|
+
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
|
54
|
+
Example:
|
|
55
|
+
```python
|
|
56
|
+
>>> from PIL import Image
|
|
57
|
+
>>> import requests
|
|
58
|
+
>>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
|
59
|
+
>>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL")
|
|
60
|
+
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
|
|
61
|
+
>>> messages = [
|
|
62
|
+
{
|
|
63
|
+
"role": "user",
|
|
64
|
+
"content": [
|
|
65
|
+
{"type": "image"},
|
|
66
|
+
{"type": "text", "text": "What is shown in this image?"},
|
|
67
|
+
],
|
|
68
|
+
},
|
|
69
|
+
]
|
|
70
|
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
71
|
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
72
|
+
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
73
|
+
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
|
74
|
+
>>> # Generate
|
|
75
|
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
76
|
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
77
|
+
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
78
|
+
```"""
|
|
79
|
+
|
|
80
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
81
|
+
output_hidden_states = (
|
|
82
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
83
|
+
)
|
|
84
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
85
|
+
|
|
86
|
+
outputs = self.model(
|
|
87
|
+
input_ids=input_ids,
|
|
88
|
+
pixel_values=pixel_values,
|
|
89
|
+
pixel_values_videos=pixel_values_videos,
|
|
90
|
+
image_grid_thw=image_grid_thw,
|
|
91
|
+
video_grid_thw=video_grid_thw,
|
|
92
|
+
second_per_grid_ts=second_per_grid_ts,
|
|
93
|
+
position_ids=position_ids,
|
|
94
|
+
attention_mask=attention_mask,
|
|
95
|
+
past_key_values=past_key_values,
|
|
96
|
+
inputs_embeds=inputs_embeds,
|
|
97
|
+
use_cache=use_cache,
|
|
98
|
+
output_attentions=output_attentions,
|
|
99
|
+
output_hidden_states=output_hidden_states,
|
|
100
|
+
return_dict=return_dict,
|
|
101
|
+
cache_position=cache_position,
|
|
102
|
+
**kwargs,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
hidden_states = outputs[0]
|
|
106
|
+
|
|
107
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
108
|
+
loss = None
|
|
109
|
+
logits = None
|
|
110
|
+
|
|
111
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
112
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
113
|
+
|
|
114
|
+
if skip_logits is None:
|
|
115
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
116
|
+
|
|
117
|
+
if skip_logits:
|
|
118
|
+
loss = LigerForCausalLMLoss(
|
|
119
|
+
hidden_states=hidden_states,
|
|
120
|
+
lm_head_weight=self.lm_head.weight,
|
|
121
|
+
labels=labels,
|
|
122
|
+
shift_labels=shift_labels,
|
|
123
|
+
hidden_size=self.config.text_config.hidden_size,
|
|
124
|
+
**kwargs,
|
|
125
|
+
)
|
|
126
|
+
else:
|
|
127
|
+
logits = self.lm_head(hidden_states)
|
|
128
|
+
|
|
129
|
+
loss = None
|
|
130
|
+
if labels is not None:
|
|
131
|
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
|
132
|
+
|
|
133
|
+
if not return_dict:
|
|
134
|
+
output = (logits,) + outputs[1:]
|
|
135
|
+
return (loss,) + output if loss is not None else output
|
|
136
|
+
|
|
137
|
+
return Qwen3VLCausalLMOutputWithPast(
|
|
138
|
+
loss=loss,
|
|
139
|
+
logits=logits,
|
|
140
|
+
past_key_values=outputs.past_key_values,
|
|
141
|
+
hidden_states=outputs.hidden_states,
|
|
142
|
+
attentions=outputs.attentions,
|
|
143
|
+
rope_deltas=outputs.rope_deltas,
|
|
144
|
+
)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from transformers.utils import can_return_tuple
|
|
9
|
+
|
|
10
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
11
|
+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
|
|
12
|
+
Qwen3VLMoeCausalLMOutputWithPast,
|
|
13
|
+
load_balancing_loss_func,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@can_return_tuple
|
|
18
|
+
def lce_forward(
|
|
19
|
+
self,
|
|
20
|
+
input_ids: torch.LongTensor = None,
|
|
21
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
22
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
23
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
24
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
25
|
+
labels: Optional[torch.LongTensor] = None,
|
|
26
|
+
use_cache: Optional[bool] = None,
|
|
27
|
+
output_attentions: Optional[bool] = None,
|
|
28
|
+
output_hidden_states: Optional[bool] = None,
|
|
29
|
+
return_dict: Optional[bool] = None,
|
|
30
|
+
pixel_values: Optional[torch.Tensor] = None,
|
|
31
|
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
|
32
|
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
|
33
|
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
34
|
+
rope_deltas: Optional[torch.LongTensor] = None,
|
|
35
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
36
|
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
37
|
+
skip_logits: Optional[bool] = None,
|
|
38
|
+
**kwargs,
|
|
39
|
+
) -> Union[Tuple, Qwen3VLMoeCausalLMOutputWithPast]:
|
|
40
|
+
"""
|
|
41
|
+
Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
45
|
+
output_hidden_states = (
|
|
46
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
47
|
+
)
|
|
48
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
49
|
+
|
|
50
|
+
outputs = self.model(
|
|
51
|
+
input_ids=input_ids,
|
|
52
|
+
pixel_values=pixel_values,
|
|
53
|
+
pixel_values_videos=pixel_values_videos,
|
|
54
|
+
image_grid_thw=image_grid_thw,
|
|
55
|
+
video_grid_thw=video_grid_thw,
|
|
56
|
+
second_per_grid_ts=second_per_grid_ts,
|
|
57
|
+
position_ids=position_ids,
|
|
58
|
+
attention_mask=attention_mask,
|
|
59
|
+
past_key_values=past_key_values,
|
|
60
|
+
inputs_embeds=inputs_embeds,
|
|
61
|
+
use_cache=use_cache,
|
|
62
|
+
output_attentions=output_attentions,
|
|
63
|
+
output_hidden_states=output_hidden_states,
|
|
64
|
+
return_dict=return_dict,
|
|
65
|
+
cache_position=cache_position,
|
|
66
|
+
**kwargs,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
hidden_states = outputs[0]
|
|
70
|
+
|
|
71
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
72
|
+
loss = None
|
|
73
|
+
logits = None
|
|
74
|
+
|
|
75
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
76
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
77
|
+
|
|
78
|
+
if skip_logits is None:
|
|
79
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
80
|
+
|
|
81
|
+
if skip_logits:
|
|
82
|
+
loss = LigerForCausalLMLoss(
|
|
83
|
+
hidden_states=hidden_states,
|
|
84
|
+
lm_head_weight=self.lm_head.weight,
|
|
85
|
+
labels=labels,
|
|
86
|
+
shift_labels=shift_labels,
|
|
87
|
+
hidden_size=self.config.text_config.hidden_size,
|
|
88
|
+
**kwargs,
|
|
89
|
+
)
|
|
90
|
+
else:
|
|
91
|
+
logits = self.lm_head(hidden_states)
|
|
92
|
+
|
|
93
|
+
if labels is not None:
|
|
94
|
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
|
95
|
+
|
|
96
|
+
# Compute auxiliary load-balancing loss for MoE when requested
|
|
97
|
+
aux_loss = None
|
|
98
|
+
if kwargs.get("output_router_logits", False):
|
|
99
|
+
aux_loss = load_balancing_loss_func(
|
|
100
|
+
outputs.router_logits,
|
|
101
|
+
self.config.text_config.num_experts,
|
|
102
|
+
self.config.text_config.num_experts_per_tok,
|
|
103
|
+
attention_mask,
|
|
104
|
+
)
|
|
105
|
+
# If we computed training loss, add the scaled aux loss to it
|
|
106
|
+
if loss is not None and aux_loss is not None:
|
|
107
|
+
loss = loss + self.config.text_config.router_aux_loss_coef * aux_loss.to(loss.device)
|
|
108
|
+
|
|
109
|
+
if not return_dict:
|
|
110
|
+
output = (logits,) + outputs[1:]
|
|
111
|
+
return (loss,) + output if loss is not None else output
|
|
112
|
+
|
|
113
|
+
return Qwen3VLMoeCausalLMOutputWithPast(
|
|
114
|
+
loss=loss,
|
|
115
|
+
logits=logits,
|
|
116
|
+
past_key_values=outputs.past_key_values,
|
|
117
|
+
hidden_states=outputs.hidden_states,
|
|
118
|
+
attentions=outputs.attentions,
|
|
119
|
+
rope_deltas=outputs.rope_deltas,
|
|
120
|
+
aux_loss=aux_loss,
|
|
121
|
+
)
|
|
@@ -6,6 +6,7 @@ from types import MethodType
|
|
|
6
6
|
from typing import Callable
|
|
7
7
|
from typing import Optional
|
|
8
8
|
|
|
9
|
+
import torch
|
|
9
10
|
import transformers
|
|
10
11
|
|
|
11
12
|
from packaging import version
|
|
@@ -30,10 +31,16 @@ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mi
|
|
|
30
31
|
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
31
32
|
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
32
33
|
from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
|
|
34
|
+
from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
|
|
35
|
+
from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
|
|
33
36
|
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
|
|
34
37
|
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
|
|
35
38
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
36
|
-
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
39
|
+
from liger_kernel.transformers.rope import (
|
|
40
|
+
liger_rotary_pos_emb,
|
|
41
|
+
liger_rotary_pos_emb_with_cast,
|
|
42
|
+
liger_rotary_pos_emb_with_cast_and_leading_batch,
|
|
43
|
+
)
|
|
37
44
|
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
|
|
38
45
|
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
|
|
39
46
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
@@ -57,6 +64,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
|
57
64
|
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
|
58
65
|
|
|
59
66
|
|
|
67
|
+
|
|
60
68
|
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
|
|
61
69
|
# Check if the module is a PEFT ModulesToSaveWrapper
|
|
62
70
|
# If it is, we need to patch the modules_to_save.default and original_modules
|
|
@@ -1643,6 +1651,156 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
|
1643
1651
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1644
1652
|
|
|
1645
1653
|
|
|
1654
|
+
|
|
1655
|
+
def apply_liger_kernel_to_qwen3_vl(
|
|
1656
|
+
rope: bool = True,
|
|
1657
|
+
cross_entropy: bool = False,
|
|
1658
|
+
fused_linear_cross_entropy: bool = True,
|
|
1659
|
+
rms_norm: bool = True,
|
|
1660
|
+
swiglu: bool = False,
|
|
1661
|
+
model: PreTrainedModel = None,
|
|
1662
|
+
) -> None:
|
|
1663
|
+
"""
|
|
1664
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
|
|
1665
|
+
|
|
1666
|
+
Args:
|
|
1667
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1668
|
+
fused_linear_cross_entropy (bool):
|
|
1669
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1670
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1671
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1672
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1673
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
1674
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1675
|
+
loaded. Default is None.
|
|
1676
|
+
"""
|
|
1677
|
+
|
|
1678
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1679
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1680
|
+
)
|
|
1681
|
+
|
|
1682
|
+
from transformers.models.qwen3_vl import modeling_qwen3_vl
|
|
1683
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
|
|
1684
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
|
|
1685
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
|
|
1686
|
+
|
|
1687
|
+
if rope:
|
|
1688
|
+
modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
|
|
1689
|
+
modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
|
|
1690
|
+
|
|
1691
|
+
|
|
1692
|
+
if rms_norm:
|
|
1693
|
+
modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
|
|
1694
|
+
|
|
1695
|
+
if cross_entropy:
|
|
1696
|
+
from transformers.loss.loss_utils import nn
|
|
1697
|
+
|
|
1698
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1699
|
+
|
|
1700
|
+
if fused_linear_cross_entropy:
|
|
1701
|
+
if model is not None:
|
|
1702
|
+
model.forward = MethodType(qwen3_vl_lce_forward, model)
|
|
1703
|
+
else:
|
|
1704
|
+
modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
|
|
1705
|
+
|
|
1706
|
+
if model is not None and rms_norm:
|
|
1707
|
+
if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
|
|
1708
|
+
text_model: Qwen3VLTextModel = model.language_model
|
|
1709
|
+
elif isinstance(model, Qwen3VLTextModel):
|
|
1710
|
+
text_model = model
|
|
1711
|
+
else:
|
|
1712
|
+
raise TypeError(
|
|
1713
|
+
f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
|
|
1714
|
+
)
|
|
1715
|
+
|
|
1716
|
+
_patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
|
|
1717
|
+
|
|
1718
|
+
if text_model is not None:
|
|
1719
|
+
_patch_qwen3_vl_rms_norm(text_model.norm)
|
|
1720
|
+
for decoder_layer in text_model.layers:
|
|
1721
|
+
_patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
|
|
1722
|
+
_patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
|
|
1723
|
+
self_attn = getattr(decoder_layer, "self_attn", None)
|
|
1724
|
+
if self_attn is not None:
|
|
1725
|
+
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
|
|
1726
|
+
_patch_qwen3_vl_rms_norm(self_attn.q_norm)
|
|
1727
|
+
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
|
|
1728
|
+
_patch_qwen3_vl_rms_norm(self_attn.k_norm)
|
|
1729
|
+
|
|
1730
|
+
|
|
1731
|
+
def apply_liger_kernel_to_qwen3_vl_moe(
|
|
1732
|
+
rope: bool = True,
|
|
1733
|
+
cross_entropy: bool = False,
|
|
1734
|
+
fused_linear_cross_entropy: bool = True,
|
|
1735
|
+
rms_norm: bool = True,
|
|
1736
|
+
swiglu: bool = False,
|
|
1737
|
+
model: PreTrainedModel = None,
|
|
1738
|
+
) -> None:
|
|
1739
|
+
"""
|
|
1740
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
|
|
1741
|
+
|
|
1742
|
+
Args:
|
|
1743
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1744
|
+
fused_linear_cross_entropy (bool):
|
|
1745
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
1746
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1747
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
1748
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1749
|
+
loaded. Default is None.
|
|
1750
|
+
"""
|
|
1751
|
+
|
|
1752
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1753
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1754
|
+
)
|
|
1755
|
+
|
|
1756
|
+
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
|
|
1757
|
+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
|
|
1758
|
+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
|
|
1759
|
+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
|
|
1760
|
+
|
|
1761
|
+
if rope:
|
|
1762
|
+
modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
|
|
1763
|
+
modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
|
|
1764
|
+
|
|
1765
|
+
if rms_norm:
|
|
1766
|
+
modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
|
|
1767
|
+
|
|
1768
|
+
if cross_entropy:
|
|
1769
|
+
from transformers.loss.loss_utils import nn
|
|
1770
|
+
|
|
1771
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1772
|
+
|
|
1773
|
+
if fused_linear_cross_entropy:
|
|
1774
|
+
if model is not None:
|
|
1775
|
+
model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
|
|
1776
|
+
else:
|
|
1777
|
+
modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
|
|
1778
|
+
|
|
1779
|
+
if model is not None and rms_norm:
|
|
1780
|
+
if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
|
|
1781
|
+
text_model: Qwen3VLMoeTextModel = model.language_model
|
|
1782
|
+
elif isinstance(model, Qwen3VLMoeTextModel):
|
|
1783
|
+
text_model = model
|
|
1784
|
+
else:
|
|
1785
|
+
raise TypeError(
|
|
1786
|
+
f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
|
|
1787
|
+
)
|
|
1788
|
+
|
|
1789
|
+
_patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
|
|
1790
|
+
|
|
1791
|
+
if text_model is not None:
|
|
1792
|
+
_patch_qwen3_vl_moe_rms_norm(text_model.norm)
|
|
1793
|
+
for decoder_layer in text_model.layers:
|
|
1794
|
+
_patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
|
|
1795
|
+
_patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
|
|
1796
|
+
self_attn = getattr(decoder_layer, "self_attn", None)
|
|
1797
|
+
if self_attn is not None:
|
|
1798
|
+
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
|
|
1799
|
+
_patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
|
|
1800
|
+
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
|
|
1801
|
+
_patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
|
|
1802
|
+
|
|
1803
|
+
|
|
1646
1804
|
def apply_liger_kernel_to_phi3(
|
|
1647
1805
|
rope: bool = True,
|
|
1648
1806
|
cross_entropy: bool = False,
|
|
@@ -2432,6 +2590,10 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
2432
2590
|
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
|
|
2433
2591
|
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
|
|
2434
2592
|
"qwen3_next": apply_liger_kernel_to_qwen3_next,
|
|
2593
|
+
"qwen3_vl": apply_liger_kernel_to_qwen3_vl,
|
|
2594
|
+
"qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
|
|
2595
|
+
"qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2596
|
+
"qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2435
2597
|
"smollm3": apply_liger_kernel_to_smollm3,
|
|
2436
2598
|
"phi3": apply_liger_kernel_to_phi3,
|
|
2437
2599
|
"paligemma": apply_liger_kernel_to_paligemma,
|
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
1
5
|
from liger_kernel.ops.rope import LigerRopeFunction
|
|
2
6
|
|
|
3
7
|
|
|
@@ -18,3 +22,46 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
|
18
22
|
"""
|
|
19
23
|
|
|
20
24
|
return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def liger_rotary_pos_emb_with_cast(
|
|
28
|
+
q: torch.Tensor,
|
|
29
|
+
k: torch.Tensor,
|
|
30
|
+
cos: torch.Tensor,
|
|
31
|
+
sin: torch.Tensor,
|
|
32
|
+
position_ids: Optional[torch.Tensor] = None,
|
|
33
|
+
unsqueeze_dim: int = 1,
|
|
34
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
35
|
+
|
|
36
|
+
orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
|
|
37
|
+
|
|
38
|
+
q32 = q.to(torch.float32)
|
|
39
|
+
k32 = k.to(torch.float32)
|
|
40
|
+
cos32 = cos.to(torch.float32)
|
|
41
|
+
sin32 = sin.to(torch.float32)
|
|
42
|
+
|
|
43
|
+
q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
|
|
44
|
+
return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def liger_rotary_pos_emb_with_cast_and_leading_batch(
|
|
48
|
+
q: torch.Tensor,
|
|
49
|
+
k: torch.Tensor,
|
|
50
|
+
cos: torch.Tensor,
|
|
51
|
+
sin: torch.Tensor,
|
|
52
|
+
position_ids: Optional[torch.Tensor] = None,
|
|
53
|
+
unsqueeze_dim: int = 1,
|
|
54
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
|
|
58
|
+
|
|
59
|
+
q32 = q.to(torch.float32).unsqueeze(0)
|
|
60
|
+
k32 = k.to(torch.float32).unsqueeze(0)
|
|
61
|
+
cos32 = cos.to(torch.float32).unsqueeze(0)
|
|
62
|
+
sin32 = sin.to(torch.float32).unsqueeze(0)
|
|
63
|
+
|
|
64
|
+
q_out, k_out = liger_rotary_pos_emb(
|
|
65
|
+
q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim
|
|
66
|
+
)
|
|
67
|
+
return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
|
|
@@ -42,7 +42,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
|
|
|
42
42
|
liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
|
|
43
43
|
liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
|
|
44
44
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
|
|
45
|
-
liger_kernel/transformers/__init__.py,sha256=
|
|
45
|
+
liger_kernel/transformers/__init__.py,sha256=iV1X0gH1JXwgeb7AeY8Ryv7q3r44MLQvSvn79yIVDzw,9874
|
|
46
46
|
liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
|
|
47
47
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
|
48
48
|
liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
|
|
@@ -59,12 +59,12 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
|
|
|
59
59
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
|
60
60
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
|
61
61
|
liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
|
|
62
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
|
62
|
+
liger_kernel/transformers/monkey_patch.py,sha256=Qo5phPCiSF_w29R5AiDO382penkmzuEijv_iNenuuHc,124681
|
|
63
63
|
liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
|
|
64
64
|
liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
|
|
65
65
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
|
66
66
|
liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyjBN9b3FYI,3004
|
|
67
|
-
liger_kernel/transformers/rope.py,sha256=
|
|
67
|
+
liger_kernel/transformers/rope.py,sha256=SoOyYArsioIQzp6eZo6vnFumISf06Gl3O8WWkMmr-gQ,2360
|
|
68
68
|
liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
|
|
69
69
|
liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
|
|
70
70
|
liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
|
|
@@ -97,15 +97,17 @@ liger_kernel/transformers/model/qwen2_vl.py,sha256=ZeasFPGs-bxm2Y_E15mo0YNx5wwtK
|
|
|
97
97
|
liger_kernel/transformers/model/qwen3.py,sha256=Q2aOg5erPrgVgRcqJm8sefLSDtvU1AD5B7aJnP7mRMM,4956
|
|
98
98
|
liger_kernel/transformers/model/qwen3_moe.py,sha256=1CwTMCNFDYsjGoa_aHFBagtC5HuJTV-s0__5UvcjD3A,5686
|
|
99
99
|
liger_kernel/transformers/model/qwen3_next.py,sha256=7To7azriAogxeE7oEvByKztH9154dnDiDVNHHm7PZK4,5632
|
|
100
|
+
liger_kernel/transformers/model/qwen3_vl.py,sha256=YU76HJ0A9kG5CUaZM4i9Bzci4eeXcNl_VSC2tsPWA3k,6301
|
|
101
|
+
liger_kernel/transformers/model/qwen3_vl_moe.py,sha256=0WuGA-pg5hzKPKc_B3d32qyzXMlkVi3_wlNu9d0KLOg,4392
|
|
100
102
|
liger_kernel/transformers/model/smollm3.py,sha256=0KWVkDtXbjsBKhJnaquV6vUUYyLtfmNwYH0sxJt-qTk,7667
|
|
101
103
|
liger_kernel/transformers/model/smolvlm.py,sha256=yFpPKawLVo3zXzLjM7Y_T8FyRrPxVyp-YPFMM8m3k0c,6734
|
|
102
104
|
liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
|
|
103
105
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
104
106
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
105
107
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
106
|
-
liger_kernel_nightly-0.6.3.dev20251028143010.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
107
|
-
liger_kernel_nightly-0.6.3.
|
|
108
|
-
liger_kernel_nightly-0.6.3.dev20251028143010.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
109
|
-
liger_kernel_nightly-0.6.3.dev20251028143010.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
110
|
-
liger_kernel_nightly-0.6.3.dev20251028143010.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
111
|
-
liger_kernel_nightly-0.6.3.dev20251028143010.dist-info/RECORD,,
|
|
108
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
109
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/METADATA,sha256=rsY01xVUY_8qxjoUXKklmwMso2nGFtFS5caQA2iDGlE,24777
|
|
110
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
111
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
112
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
113
|
+
liger_kernel_nightly-0.6.3.dev20251101160510.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|