liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107181130__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +12 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/rms_norm.py +126 -49
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/gemma3.py +1 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +118 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +3 -3
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +27 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/METADATA +9 -3
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/RECORD +58 -46
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/gpt_oss.py (new file)

````diff
@@ -0,0 +1,211 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> LigerMoeCausalLMOutputWithPast:
+    r"""
+    Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations.
+
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings.
+        past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
+            sequential decoding. See `past_key_values` input for more details.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+            (see `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors
+            for more detail.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+        skip_logits (`bool`, *optional*):
+            Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training
+            when labels are provided (to save memory), and `False` during inference.
+
+    Returns:
+        `LigerMoeCausalLMOutputWithPast`: An output object containing:
+        - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction), including the auxiliary load balancing loss.
+        - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            Auxiliary load balancing loss for the sparse MoE modules.
+        - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            Note: logits are `None` during training when `skip_logits=True` to save memory.
+        - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed):
+            Cached key and value projection states for faster sequential decoding.
+        - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer.
+        - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attentions weights after the attention softmax.
+        - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+            Router logits of the MoE layers, useful to compute the auxiliary loss and z_loss.
+        - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            Token-level prediction accuracy.
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GptOssForCausalLM
+    >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
+
+    >>> # Apply Liger Kernel patches for optimized performance
+    >>> apply_liger_kernel_to_gpt_oss()
+
+    >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Inference: Forward pass returns logits
+    >>> outputs = model(**inputs)
+    >>> outputs.logits.shape
+    torch.Size([1, 12, 201088])
+
+    >>> # Get next token prediction
+    >>> next_token_logits = outputs.logits[:, -1, :]
+    >>> predicted_token_id = next_token_logits.argmax(dim=-1)
+
+    >>> # Training: Forward pass with labels returns loss
+    >>> labels = inputs.input_ids.clone()
+    >>> outputs = model(**inputs, labels=labels)
+    >>> outputs.loss
+    tensor(2.6454)
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+    else:  # if in inference model materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.vocab_size,
+                **kwargs,
+            )
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+    return LigerMoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+        token_accuracy=token_accuracy,
+    )
````
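The `skip_logits` path in this new forward avoids materializing full-vocabulary logits during training and computes the loss through the fused linear cross entropy kernel instead. A minimal sketch of how the two modes differ, assuming the patch has been applied and a GPT-OSS checkpoint is available; the checkpoint name is taken from the docstring example above and the behavior shown in the comments follows that docstring, not an executed run:

```python
import torch
from transformers import AutoTokenizer, GptOssForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

apply_liger_kernel_to_gpt_oss()  # rebinds GptOssForCausalLM.forward to the Liger lce_forward

model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
labels = inputs.input_ids.clone()

model.train()
out = model(**inputs, labels=labels)      # skip_logits defaults to True: fused linear CE path
print(out.loss, out.logits)               # scalar loss; logits is None (not materialized)

out = model(**inputs, labels=labels, skip_logits=False)  # force logit materialization
print(out.logits.shape)                   # (batch, seq_len, vocab_size)
```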
liger_kernel/transformers/monkey_patch.py

```diff
@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward

@@ -34,8 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.rope import
-from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

@@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")

@@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")

@@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
             if isinstance(language_model, MllamaForCausalLM):
                 text_model: MllamaTextModel = language_model.model
             else:

@@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
         # instance variables that reference already-instantiated modules

         if isinstance(model, Gemma3ForConditionalGeneration):
-            if isinstance(model.vision_tower, SiglipVisionModel):
-                vision_tower = model.vision_tower
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower

                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)


@@ -1132,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
                 raise TypeError("The vision tower must be SiglipVisionModel")

             if rms_norm:
-                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

             apply_liger_kernel_to_gemma3_text(
                 rope=rope,

@@ -1140,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
                 fused_linear_cross_entropy=False,
                 rms_norm=rms_norm,
                 geglu=geglu,
-                model=model.language_model,
+                model=model.model.language_model,
             )

         else:

@@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
         if not isinstance(model, PaliGemmaForConditionalGeneration):
             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

-        vision_tower: SiglipVisionModel = model.vision_tower
+        vision_tower: SiglipVisionModel = model.model.vision_tower

         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)


@@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
             _patch_layer_norm_module(layer.layer_norm1)
             _patch_layer_norm_module(layer.layer_norm2)

-        language_model = model.language_model
+        language_model = model.model.language_model

         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(

@@ -1459,6 +1459,79 @@ def apply_liger_kernel_to_qwen3_moe(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,

@@ -1520,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
             text_model: Qwen2VLTextModel = model.language_model
             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2VLTextModel):

@@ -1611,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-
-
-
-        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
             text_model: Qwen2_5_VLTextModel = model.language_model
             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2_5_VLTextModel):

@@ -1629,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

         if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
-            for vision_block in
+            for vision_block in vision_model.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)

@@ -1680,8 +1751,8 @@ def apply_liger_kernel_to_qwen3_vl(
     from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

     if rope:
-        modeling_qwen3_vl.apply_rotary_pos_emb =
-        modeling_qwen3_vl.apply_rotary_pos_emb_vision =
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision

     if rms_norm:
         modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm

@@ -1698,7 +1769,9 @@ def apply_liger_kernel_to_qwen3_vl(
         modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model,
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
             text_model: Qwen3VLTextModel = model.language_model
         elif isinstance(model, Qwen3VLTextModel):
             text_model = model

@@ -1755,8 +1828,8 @@ def apply_liger_kernel_to_qwen3_vl_moe(
     from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

     if rope:
-        modeling_qwen3_vl_moe.apply_rotary_pos_emb =
-        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision =
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision

     if rms_norm:
         modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm

@@ -1773,7 +1846,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
         modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model,
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
             text_model: Qwen3VLMoeTextModel = model.language_model
         elif isinstance(model, Qwen3VLMoeTextModel):
             text_model = model

@@ -2118,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
             text_model: Glm4vTextModel = model.language_model
             vision_model: Glm4vVisionModel = model.visual
         elif isinstance(model, Glm4vTextModel):

@@ -2208,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
-
-
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
             text_model: Glm4vMoeTextModel = model.language_model
             vision_model: Glm4vMoeVisionModel = model.visual
             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE

@@ -2314,8 +2390,10 @@ def apply_liger_kernel_to_internvl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model,
-
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
             text_model = model.language_model
             vision_model: InternVLVisionModel = model.vision_tower
         else:

@@ -2752,6 +2830,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
     "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
```
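With `"gpt_oss"` registered in `MODEL_TYPE_TO_APPLY_LIGER_FN`, the kernels can be applied either at load time through the auto wrapper or to an already-instantiated model. A minimal sketch, assuming the standard Liger entry points and the keyword arguments shown in the new `apply_liger_kernel_to_gpt_oss` signature; the checkpoint name is illustrative:

```python
from transformers import GptOssForCausalLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_gpt_oss

# Option 1: the auto wrapper looks up the config's model_type ("gpt_oss") in
# MODEL_TYPE_TO_APPLY_LIGER_FN and applies the corresponding patch function.
model = AutoLigerKernelForCausalLM.from_pretrained("openai/gpt-oss-20b")

# Option 2: patch an existing instance; RMSNorm modules are patched in place and
# forward is rebound to the Liger lce_forward via MethodType, as shown above.
model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
apply_liger_kernel_to_gpt_oss(model=model, rope=True, rms_norm=True, fused_linear_cross_entropy=True)
```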
liger_kernel/transformers/multi_token_attention.py

```diff
@@ -5,7 +5,7 @@ import torch.nn as nn

 from torch.nn.modules.utils import _pair

-from liger_kernel.ops
+from liger_kernel.ops import LigerMultiTokenAttentionFunction


 class LigerMultiTokenAttention(nn.Module):
```
liger_kernel/transformers/rms_norm.py

```diff
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn

-from liger_kernel.ops
+from liger_kernel.ops import LigerRMSNormFunction


 class LigerRMSNorm(nn.Module):

@@ -14,13 +14,18 @@ class LigerRMSNorm(nn.Module):
         init_fn="ones",
         in_place=True,
         row_mode=None,
+        elementwise_affine=True,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        else:
+            self.register_parameter("weight", None)
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
             eps,
             offset,

@@ -41,7 +46,7 @@ class LigerRMSNorm(nn.Module):
         )

     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+        return f"weight_shape={tuple(self.weight.shape) if self.weight is not None else None}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"


 class LigerRMSNormForGemma(LigerRMSNorm):
```