liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

@@ -2,7 +2,9 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional
 
 import transformers
 
@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,12 +29,13 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -54,7 +59,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
     # Check if the module is a PEFT ModulesToSaveWrapper
     # If it is, we need to patch the modules_to_save.default and original_modules
     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,26 +69,29 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
         module.original_module.offset = offset
         module.original_module.casting_mode = casting_mode
         module.original_module.variance_epsilon = (
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
     else:
         module.offset = offset
         module.casting_mode = casting_mode
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.in_place = in_place
+        module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
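For context on the instance-binding trick used throughout this file, `_bind_method_to_module` writes a bound method into the module's `__dict__` via the descriptor protocol, so only that one instance is affected. A minimal, hypothetical sketch of the same mechanism (the `DummyNorm` class and its size are illustrative, not from the package):

import torch
import torch.nn as nn

class DummyNorm(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

def _get_name(self):
    return "LigerRMSNorm"

norm = DummyNorm()
# Same idea as _bind_method_to_module: __get__ produces a method bound to this one
# instance, stored in the instance __dict__ so the class itself stays untouched.
norm.__dict__["_get_name"] = _get_name.__get__(norm, norm.__class__)
print(norm)  # repr now reports "LigerRMSNorm()" for this instance only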
@@ -105,28 +113,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
             module, "normalized_shape", None
         )
-        _bind_method_to_module(module.modules_to_save.default, "forward",
-        _bind_method_to_module(module.modules_to_save.default, "extra_repr",
-        _bind_method_to_module(module.original_module, "forward",
-        _bind_method_to_module(module.original_module, "extra_repr",
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
     else:
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
 
 
 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
 
 
 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
 
 
 def apply_liger_kernel_to_granite(
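A hedged example of what this instance patching looks like for an already-loaded module, assuming a Llama-style RMSNorm and the private helper's import path (both illustrative of how the patchers above are used, not documented API):

from transformers.models.llama.modeling_llama import LlamaRMSNorm
from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module  # private helper, path assumed

rms = LlamaRMSNorm(hidden_size=16, eps=1e-6)
_patch_rms_norm_module(rms, offset=0.0, casting_mode="llama", in_place=True, row_mode=None)
# rms keeps its original weights, but forward/extra_repr/_get_name now come from
# LigerRMSNorm, so repr(rms) reports LigerRMSNorm while the state_dict is unchanged.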
@@ -257,10 +265,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
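The new `model is not None` branches patch a single model instance instead of the class. A minimal sketch of why `MethodType` is used, with a toy class standing in for `LlamaForCausalLM` (illustrative only):

from types import MethodType

class Toy:
    def forward(self, x):
        return x

def lce_forward(self, x):
    return x * 2

inst = Toy()
inst.forward = MethodType(lce_forward, inst)  # binds the new forward to this instance only
other = Toy()
assert inst.forward(3) == 6 and other.forward(3) == 3  # other instances keep the class forward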
@@ -280,6 +294,77 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
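A hedged usage sketch of the new SmolLM3 entry point shown above; the checkpoint id is a placeholder and the import goes through the monkey_patch module where the function is defined:

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B",  # placeholder checkpoint id, substitute your own
    torch_dtype=torch.bfloat16,
)
# Instance-level patching: RMSNorm/SwiGLU submodules are rebound in place and
# forward is swapped for the fused linear cross entropy variant via MethodType.
apply_liger_kernel_to_smollm3(model=model)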
@@ -315,9 +400,15 @@ def apply_liger_kernel_to_llava(
         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse("4.52.0"):
-
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
         elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
-
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -339,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -354,12 +445,103 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
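A hedged sketch of the class-level (pre-load) path for the new Llama4 entry point above; the model id is a placeholder and the import path goes through the monkey_patch module where the function is defined:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

# Patching before loading swaps the module classes (RoPE, RMSNorm, SwiGLU) and the
# causal-LM forward globally, so the model is constructed with Liger modules.
apply_liger_kernel_to_llama4(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")  # placeholder model id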
@@ -401,7 +583,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -417,19 +599,28 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
-
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
@@ -503,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -571,10 +772,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -648,10 +855,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -723,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -805,7 +1024,10 @@ def apply_liger_kernel_to_gemma3_text(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -875,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     apply_liger_kernel_to_gemma3_text(
@@ -886,15 +1108,18 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
         if isinstance(model, Gemma3ForConditionalGeneration):
-            if isinstance(model.vision_tower, SiglipVisionModel):
-                vision_tower = model.vision_tower
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower
 
                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
 
@@ -907,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
                 raise TypeError("The vision tower must be SiglipVisionModel")
 
             if rms_norm:
-                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
 
             apply_liger_kernel_to_gemma3_text(
                 rope=rope,
@@ -915,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
                 fused_linear_cross_entropy=False,
                 rms_norm=rms_norm,
                 geglu=geglu,
-                model=model.language_model,
+                model=model.model.language_model,
             )
 
         else:
@@ -954,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -965,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -983,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -995,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
         if not isinstance(model, PaliGemmaForConditionalGeneration):
             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
 
-        vision_tower: SiglipVisionModel = model.vision_tower
+        vision_tower: SiglipVisionModel = model.model.vision_tower
 
         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
 
@@ -1005,9 +1238,9 @@ def apply_liger_kernel_to_paligemma(
             _patch_layer_norm_module(layer.layer_norm1)
             _patch_layer_norm_module(layer.layer_norm2)
 
-        language_model = model.language_model
+        language_model = model.model.language_model
 
-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -1017,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1078,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1102,7 +1341,6 @@ def apply_liger_kernel_to_qwen2(
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-    print("Applied Liger kernels to Qwen2")
 
 
 def apply_liger_kernel_to_qwen3(
@@ -1137,7 +1375,10 @@ def apply_liger_kernel_to_qwen3(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1192,7 +1433,10 @@ def apply_liger_kernel_to_qwen3_moe(
             nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1208,7 +1452,81 @@ def apply_liger_kernel_to_qwen3_moe(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
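A hedged usage sketch for the GPT-OSS entry point above; the model id is a placeholder and the import goes through the monkey_patch module where the function is defined:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss

model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # placeholder model id
# SwiGLU stays unpatched by default because of the clamped (swiglu_limit=7.0),
# MXFP4-quantized experts; RoPE, RMSNorm and the fused LCE forward are still applied.
apply_liger_kernel_to_gpt_oss(model=model)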
@@ -1260,23 +1578,25 @@ def apply_liger_kernel_to_qwen2_vl(
|
|
|
1260
1578
|
if rms_norm:
|
|
1261
1579
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
|
|
1262
1580
|
modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
|
|
1263
|
-
if layer_norm:
|
|
1581
|
+
if layer_norm and model is None:
|
|
1264
1582
|
modeling_qwen2_vl.LayerNorm = LigerLayerNorm
|
|
1265
1583
|
if cross_entropy:
|
|
1266
1584
|
modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1267
1585
|
if fused_linear_cross_entropy:
|
|
1268
|
-
|
|
1586
|
+
if model is not None:
|
|
1587
|
+
model.forward = MethodType(qwen2_vl_lce_forward, model)
|
|
1588
|
+
else:
|
|
1589
|
+
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
|
|
1269
1590
|
if swiglu:
|
|
1270
1591
|
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
|
|
1271
1592
|
|
|
1272
1593
|
if model is not None:
|
|
1273
1594
|
# The model instance already exists, so we need to additionally patch the
|
|
1274
1595
|
# instance variables that reference already-instantiated modules
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
|
|
1596
|
+
if isinstance(model, Qwen2VLForConditionalGeneration):
|
|
1597
|
+
text_model: Qwen2VLTextModel = model.model.language_model
|
|
1598
|
+
vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
|
|
1599
|
+
elif isinstance(model, Qwen2VLModel):
|
|
1280
1600
|
text_model: Qwen2VLTextModel = model.language_model
|
|
1281
1601
|
vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
|
|
1282
1602
|
elif isinstance(model, Qwen2VLTextModel):
|
|
@@ -1353,18 +1673,20 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
|
1353
1673
|
if cross_entropy:
|
|
1354
1674
|
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1355
1675
|
if fused_linear_cross_entropy:
|
|
1356
|
-
|
|
1676
|
+
if model is not None:
|
|
1677
|
+
model.forward = MethodType(qwen2_5_vl_lce_forward, model)
|
|
1678
|
+
else:
|
|
1679
|
+
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
|
|
1357
1680
|
if swiglu:
|
|
1358
1681
|
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
|
|
1359
1682
|
|
|
1360
1683
|
if model is not None:
|
|
1361
1684
|
# The model instance already exists, so we need to additionally patch the
|
|
1362
1685
|
# instance variables that reference already-instantiated modules
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
|
|
1686
|
+
if isinstance(model, Qwen2_5_VLForConditionalGeneration):
|
|
1687
|
+
text_model: Qwen2_5_VLTextModel = model.model.language_model
|
|
1688
|
+
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
|
|
1689
|
+
elif isinstance(model, Qwen2_5_VLModel):
|
|
1368
1690
|
text_model: Qwen2_5_VLTextModel = model.language_model
|
|
1369
1691
|
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
|
|
1370
1692
|
elif isinstance(model, Qwen2_5_VLTextModel):
|
|
@@ -1378,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
|
1378
1700
|
|
|
1379
1701
|
if vision_model is not None:
|
|
1380
1702
|
# Patch Qwen2_5_VisionTransformerPretrainedModel
|
|
1381
|
-
for vision_block in
|
|
1703
|
+
for vision_block in vision_model.blocks:
|
|
1382
1704
|
if rms_norm:
|
|
1383
1705
|
_patch_rms_norm_module(vision_block.norm1)
|
|
1384
1706
|
_patch_rms_norm_module(vision_block.norm2)
|
|
@@ -1394,69 +1716,220 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
|
1394
1716
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1395
1717
|
|
|
1396
1718
|
|
|
1397
|
-
def
|
|
1719
|
+
def apply_liger_kernel_to_qwen3_vl(
|
|
1398
1720
|
rope: bool = True,
|
|
1399
1721
|
cross_entropy: bool = False,
|
|
1400
1722
|
fused_linear_cross_entropy: bool = True,
|
|
1401
1723
|
rms_norm: bool = True,
|
|
1402
|
-
swiglu: bool =
|
|
1724
|
+
swiglu: bool = False,
|
|
1403
1725
|
model: PreTrainedModel = None,
|
|
1404
1726
|
) -> None:
|
|
1405
1727
|
"""
|
|
1406
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
1728
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
|
|
1407
1729
|
|
|
1408
1730
|
Args:
|
|
1409
|
-
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
1410
1731
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1411
1732
|
fused_linear_cross_entropy (bool):
|
|
1412
1733
|
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1413
1734
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1414
1735
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1415
1736
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1416
|
-
swiglu (bool): Whether to apply Liger's SwiGLU
|
|
1737
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
1417
1738
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1418
1739
|
loaded. Default is None.
|
|
1419
1740
|
"""
|
|
1741
|
+
|
|
1420
1742
|
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1421
1743
|
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1422
1744
|
)
|
|
1423
1745
|
|
|
1424
|
-
from transformers.models.
|
|
1425
|
-
from transformers.models.
|
|
1746
|
+
from transformers.models.qwen3_vl import modeling_qwen3_vl
|
|
1747
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
|
|
1748
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
|
|
1749
|
+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
|
|
1750
|
+
|
|
1751
|
+
from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
|
|
1426
1752
|
|
|
1427
1753
|
if rope:
|
|
1428
|
-
|
|
1754
|
+
modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
1755
|
+
modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
|
|
1756
|
+
|
|
1429
1757
|
if rms_norm:
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
1758
|
+
modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
|
|
1759
|
+
|
|
1433
1760
|
if cross_entropy:
|
|
1434
|
-
|
|
1435
|
-
|
|
1761
|
+
from transformers.loss.loss_utils import nn
|
|
1762
|
+
|
|
1763
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1436
1764
|
|
|
1437
|
-
nn.functional.cross_entropy = liger_cross_entropy
|
|
1438
|
-
else:
|
|
1439
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1440
|
-
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1441
1765
|
if fused_linear_cross_entropy:
|
|
1442
|
-
if
|
|
1443
|
-
|
|
1444
|
-
else:
|
|
1445
|
-
|
|
1446
|
-
|
|
1766
|
+
if model is not None:
|
|
1767
|
+
model.forward = MethodType(qwen3_vl_lce_forward, model)
|
|
1768
|
+
else:
|
|
1769
|
+
modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
|
|
1770
|
+
|
|
1771
|
+
if model is not None and rms_norm:
|
|
1772
|
+
if isinstance(model, Qwen3VLForConditionalGeneration):
|
|
1773
|
+
text_model: Qwen3VLTextModel = model.model.language_model
|
|
1774
|
+
elif isinstance(model, Qwen3VLModel):
|
|
1775
|
+
text_model: Qwen3VLTextModel = model.language_model
|
|
1776
|
+
elif isinstance(model, Qwen3VLTextModel):
|
|
1777
|
+
text_model = model
|
|
1778
|
+
else:
|
|
1779
|
+
raise TypeError(
|
|
1780
|
+
f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
|
|
1781
|
+
)
|
|
1447
1782
|
|
|
1448
|
-
|
|
1449
|
-
# The model instance already exists, so we need to additionally patch the
|
|
1450
|
-
# instance variables that reference already-instantiated modules
|
|
1783
|
+
_patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
|
|
1451
1784
|
|
|
1452
|
-
|
|
1453
|
-
|
|
1785
|
+
if text_model is not None:
|
|
1786
|
+
_patch_qwen3_vl_rms_norm(text_model.norm)
|
|
1787
|
+
for decoder_layer in text_model.layers:
|
|
1788
|
+
_patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
|
|
1789
|
+
_patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
|
|
1790
|
+
self_attn = getattr(decoder_layer, "self_attn", None)
|
|
1791
|
+
if self_attn is not None:
|
|
1792
|
+
if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
|
|
1793
|
+
_patch_qwen3_vl_rms_norm(self_attn.q_norm)
|
|
1794
|
+
if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
|
|
1795
|
+
_patch_qwen3_vl_rms_norm(self_attn.k_norm)
|
|
1454
1796
|
|
|
1455
|
-
if rms_norm:
|
|
1456
|
-
_patch_rms_norm_module(base_model.norm)
|
|
1457
1797
|
|
|
1458
|
-
|
|
1459
|
-
|
|
1798
|
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(phi3_lce_forward, model)
+        else:
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
                 _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
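The patch functions added above all follow the same contract: called with no `model`, they swap the Hugging Face modeling classes before a model is instantiated; called with an already-loaded `model`, they additionally patch the instantiated submodules in place. A minimal usage sketch (the checkpoint id is a placeholder and not part of this release):

import transformers
from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Option 1: patch the modeling classes first, then load the model.
apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True)
model = transformers.AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# Option 2: patch a model that is already loaded; its instantiated modules are patched in place.
model = transformers.AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
apply_liger_kernel_to_phi3(model=model)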
@@ -1507,7 +1980,10 @@ def apply_liger_kernel_to_olmo2(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1527,6 +2003,74 @@ def apply_liger_kernel_to_olmo2(
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 def apply_liger_kernel_to_glm4(
     rope: bool = False,
     cross_entropy: bool = False,
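The `fused_linear_cross_entropy` path referenced in the docstrings above replaces the usual lm_head matmul plus `nn.functional.cross_entropy` with a fused kernel that never materializes the full [tokens, vocab] logits tensor. A rough sketch of the underlying op, assuming the fused loss module keeps its (lin_weight, input, target) calling order; this is illustrative, not the patched `lce_forward` itself:

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

hidden_states = torch.randn(4 * 512, 4096, device="cuda", requires_grad=True)   # flattened [tokens, hidden]
lm_head_weight = torch.randn(128256, 4096, device="cuda", requires_grad=True)   # [vocab, hidden]
labels = torch.randint(0, 128256, (4 * 512,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss()
loss = loss_fn(lm_head_weight, hidden_states, labels)  # logits are computed chunk-wise, never stored whole
loss.backward()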
@@ -1571,7 +2115,10 @@ def apply_liger_kernel_to_glm4(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1593,6 +2140,764 @@ def apply_liger_kernel_to_glm4(
                 _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_glm4v(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            _patch_rms_norm_module(vision_model.post_conv_layernorm)
+            _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                    if decoder_layer.mlp.shared_experts is not None:
+                        _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+            for decoder_layer in text_model.layers:
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_internvl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+    Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
+    However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+    NOTE: InternVL is not available in transformers<4.52.1
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+    import torch.nn as torch_nn
+
+    from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
+
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+    if rms_norm:
+        modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
+def apply_liger_kernel_to_smolvlm(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+    Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+    However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+    NOTE: SmolVLM is not available in transformers<4.50.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smolvlm import modeling_smolvlm
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+    if layer_norm and model is None:
+        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smolvlm_lce_forward, model)
+        else:
+            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+    if rms_norm:
+        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, SmolVLMForConditionalGeneration):
+            text_model = model.model.text_model
+            vision_model: SmolVLMVisionTransformer = model.model.vision_model
+        elif isinstance(model, SmolVLMModel):
+            text_model = model.text_model
+            vision_model: SmolVLMVisionTransformer = model.vision_model
+        else:
+            raise TypeError(
+                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch post_layernorm
+            _patch_layer_norm_module(vision_model.post_layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layers:
+                encoder_layer: SmolVLMEncoderLayer
+                _patch_layer_norm_module(encoder_layer.layer_norm1)
+                _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+def apply_liger_kernel_to_falcon_h1(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.falcon_h1 import modeling_falcon_h1
+    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+    if rope:
+        logger.info("Apply liger rotary pos emb.")
+        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        logger.info("Apply liger RMSNorm")
+        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+    if swiglu:
+        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(falcon_h1_lce_forward, model)
+        else:
+            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+
+        # get the base model from the model instance
+        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.final_layernorm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        # It might encounter NaN issues
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+    if rms_norm:
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+    if swiglu:
+        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+            )
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_exaone4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.exaone4 import modeling_exaone4
+    from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
+
+    from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
+
+    if rope:
+        modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        # EXAONE4 requires in_place=False to avoid gradient issues
+        class Exaone4LigerRMSNorm(LigerRMSNorm):
+            def __init__(self, hidden_size, eps=1e-6, **kwargs):
+                super().__init__(hidden_size, eps, **kwargs)
+                self.in_place = False
+
+        modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(exaone4_lce_forward, model)
+        else:
+            modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
+
+    if swiglu:
+        modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
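apply_liger_kernel_to_internvl and apply_liger_kernel_to_smolvlm above reuse MODEL_TYPE_TO_APPLY_LIGER_FN to patch whatever language model is nested inside the multimodal wrapper, filtering the forwarded kwargs against the target function's signature before calling it. The same filtering idea in isolation (the helper name here is illustrative, not part of liger_kernel):

import inspect

def forward_to_text_patch(text_model, text_model_name, registry, **kwargs):
    patch_fn = registry.get(text_model_name)
    if patch_fn is None:
        return  # inner LM not supported by Liger kernel
    accepted = inspect.signature(patch_fn).parameters
    filtered = {k: v for k, v in kwargs.items() if k in accepted}  # drop unsupported options
    patch_fn(model=text_model, **filtered)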
@@ -1600,7 +2905,13 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
+    "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1608,6 +2919,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -1615,8 +2927,19 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
+    "falcon_h1": apply_liger_kernel_to_falcon_h1,
+    "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+    "exaone4": apply_liger_kernel_to_exaone4,
 }
 
 
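The new registry keys mirror the `model_type` strings from transformers' auto mapping, so a loaded model can be routed to the matching patch function by its config. A sketch of that dispatch, assuming only what the registry shown above provides (the helper itself is illustrative, not library code):

def patch_by_model_type(model, **kwargs):
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger kernel patch registered for {model.config.model_type!r}")
    apply_fn(model=model, **kwargs)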