liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
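
The bulk of this diff appears to be the rewritten monkey-patch module (liger_kernel/transformers/monkey_patch.py, +2018 -244 above): most apply_liger_kernel_to_* entry points now also accept an already-instantiated model, and new entry points such as apply_liger_kernel_to_smollm3, apply_liger_kernel_to_llama4, apply_liger_kernel_to_gemma3, and apply_liger_kernel_to_qwen3 are added. A minimal usage sketch follows; it assumes the entry points are importable from liger_kernel.transformers.monkey_patch with the signatures shown in the hunks below, and the checkpoint name is only an illustrative example.

    # Hypothetical usage sketch based on the function signatures visible in the diff below.
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3

    # Patch the Qwen3 modeling classes before instantiating the model so that RoPE,
    # RMSNorm, SwiGLU, and the fused linear cross entropy forward are all replaced.
    apply_liger_kernel_to_qwen3(
        rope=True,
        rms_norm=True,
        swiglu=True,
        fused_linear_cross_entropy=True,
        cross_entropy=False,
    )
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

    # Alternatively, patch an already-instantiated model in place:
    # apply_liger_kernel_to_qwen3(model=model)
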
@@ -2,7 +2,9 @@ import inspect
 import logging

 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional

 import transformers

@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,16 +29,24 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 transformer_version = version.parse(transformers.__version__)

 logger = logging.getLogger(__name__)
@@ -47,33 +59,82 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)


-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-    module
-
-
-
-
-
-
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        module.row_mode = row_mode
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


 def _patch_layer_norm_module(module, eps=1e-6):
-
-
-
-
-
-
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


 def apply_liger_kernel_to_granite(
@@ -204,10 +265,16 @@ def apply_liger_kernel_to_llama(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -227,6 +294,77 @@ def apply_liger_kernel_to_llama(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -261,13 +399,20 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.
-
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -285,7 +430,7 @@ def apply_liger_kernel_to_llava(
                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
            text_liger_fn(**text_kwargs)
        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -300,12 +445,103 @@ def apply_liger_kernel_to_llava(
                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
            vision_liger_fn(**vision_kwargs)
        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -347,7 +583,7 @@ def apply_liger_kernel_to_mllama(

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,25 +599,35 @@
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
-
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -448,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -516,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -573,8 +835,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel

-
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

     if rope:
@@ -593,10 +855,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -647,7 +915,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -667,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -694,17 +969,16 @@ def apply_liger_kernel_to_gemma2(
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)


-def
+def apply_liger_kernel_to_gemma3_text(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
-    layer_norm: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -713,7 +987,6 @@ def apply_liger_kernel_to_paligemma(
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
@@ -723,97 +996,77 @@ def apply_liger_kernel_to_paligemma(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

-    from transformers.
-    from transformers.
-    from transformers.models.paligemma import modeling_paligemma
-    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
-    from transformers.models.siglip import modeling_siglip
-    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
-    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+    from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

-
-
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )

-
-
-        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+    if rope:
+        modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb

-
-
+    if rms_norm:
+        modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+    if geglu:
+        modeling_gemma3.Gemma3MLP = LigerGEGLUMLP

-    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
-    apply_liger_kernel_to_gemma(
-        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
-    )
-    apply_liger_kernel_to_gemma2(
-        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
-    )
     # Handle loss function
     if cross_entropy:
-
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-        if
-
-        else:
-
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if
-
-
-        vision_tower: SiglipVisionModel = model.vision_tower
-
-        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
-
-        for layer in vision_tower.vision_model.encoder.layers:
-            layer: SiglipEncoderLayer
-            if layer_norm:
-                _patch_layer_norm_module(layer.layer_norm1)
-                _patch_layer_norm_module(layer.layer_norm2)
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
+            # get the base model from the model instance
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model

-
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(base_model.norm)

-
-
-
-
-
-
-
-
-
+            for decoder_layer in base_model.layers:
+                decoder_layer: Gemma3DecoderLayer
+                if geglu:
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                if rms_norm:
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)

-        elif isinstance(language_model, Gemma2ForCausalLM):
-            apply_liger_kernel_to_gemma2(
-                rope=rope,
-                cross_entropy=False,
-                fused_linear_cross_entropy=False,
-                rms_norm=rms_norm,
-                geglu=geglu,
-                model=language_model,
-            )
         else:
-            raise TypeError(
-                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
-            )
+            raise TypeError("The model must be Gemma3ForCausalLM.")


-def
+def apply_liger_kernel_to_gemma3(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
     rms_norm: bool = True,
-
+    geglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -822,8 +1075,9 @@ def apply_liger_kernel_to_qwen2(
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
@@ -831,64 +1085,1378 @@ def apply_liger_kernel_to_qwen2(
|
|
|
831
1085
|
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
832
1086
|
)
|
|
833
1087
|
|
|
834
|
-
from transformers.models.
|
|
835
|
-
from transformers.models.
|
|
1088
|
+
from transformers.models.gemma3 import modeling_gemma3
|
|
1089
|
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
|
|
1090
|
+
from transformers.models.siglip import modeling_siglip
|
|
1091
|
+
from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
|
|
1092
|
+
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
|
836
1093
|
|
|
837
|
-
|
|
838
|
-
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
839
|
-
if rms_norm:
|
|
840
|
-
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
|
1094
|
+
from liger_kernel.transformers.model.gemma3 import multimodal_forward
|
|
841
1095
|
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
1096
|
+
_patch_rms_norm_module_for_gemma3 = partial(
|
|
1097
|
+
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
|
1098
|
+
)
|
|
845
1099
|
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
849
|
-
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1100
|
+
if layer_norm and model is None:
|
|
1101
|
+
modeling_siglip.nn.LayerNorm = LigerLayerNorm
|
|
850
1102
|
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
else: # if version < 4.46.1
|
|
855
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
856
|
-
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
|
|
1103
|
+
apply_liger_kernel_to_gemma3_text(
|
|
1104
|
+
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
|
|
1105
|
+
)
|
|
857
1106
|
|
|
858
|
-
if
|
|
859
|
-
|
|
1107
|
+
if cross_entropy:
|
|
1108
|
+
modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1109
|
+
|
|
1110
|
+
if fused_linear_cross_entropy:
|
|
1111
|
+
if model is not None:
|
|
1112
|
+
model.forward = MethodType(multimodal_forward, model)
|
|
1113
|
+
else:
|
|
1114
|
+
modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
|
|
860
1115
|
|
|
861
1116
|
if model is not None:
|
|
862
1117
|
# The model instance already exists, so we need to additionally patch the
|
|
863
1118
|
# instance variables that reference already-instantiated modules
|
|
864
1119
|
|
|
865
|
-
|
|
866
|
-
|
|
1120
|
+
if isinstance(model, Gemma3ForConditionalGeneration):
|
|
1121
|
+
if isinstance(model.model.vision_tower, SiglipVisionModel):
|
|
1122
|
+
vision_tower = model.model.vision_tower
|
|
867
1123
|
|
|
868
|
-
|
|
869
|
-
|
|
1124
|
+
_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
|
|
1125
|
+
|
|
1126
|
+
for layer in vision_tower.vision_model.encoder.layers:
|
|
1127
|
+
layer: SiglipEncoderLayer
|
|
1128
|
+
if layer_norm:
|
|
1129
|
+
_patch_layer_norm_module(layer.layer_norm1)
|
|
1130
|
+
_patch_layer_norm_module(layer.layer_norm2)
|
|
1131
|
+
else:
|
|
1132
|
+
raise TypeError("The vision tower must be SiglipVisionModel")
|
|
870
1133
|
|
|
871
|
-
for decoder_layer in base_model.layers:
|
|
872
|
-
if swiglu:
|
|
873
|
-
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
874
1134
|
if rms_norm:
|
|
875
|
-
|
|
876
|
-
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
877
|
-
print("Applied Liger kernels to Qwen2")
|
|
1135
|
+
_patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
|
|
878
1136
|
|
|
1137
|
+
apply_liger_kernel_to_gemma3_text(
|
|
1138
|
+
rope=rope,
|
|
1139
|
+
cross_entropy=False,
|
|
1140
|
+
fused_linear_cross_entropy=False,
|
|
1141
|
+
rms_norm=rms_norm,
|
|
1142
|
+
geglu=geglu,
|
|
1143
|
+
model=model.model.language_model,
|
|
1144
|
+
)
|
|
1145
|
+
|
|
1146
|
+
else:
|
|
1147
|
+
raise TypeError("The model must be Gemma3ForConditionalGeneration.")
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
def apply_liger_kernel_to_paligemma(
|
|
1151
|
+
rope: bool = True,
|
|
1152
|
+
cross_entropy: bool = False,
|
|
1153
|
+
fused_linear_cross_entropy: bool = True,
|
|
1154
|
+
layer_norm: bool = True,
|
|
1155
|
+
rms_norm: bool = True,
|
|
1156
|
+
geglu: bool = True,
|
|
1157
|
+
model: PreTrainedModel = None,
|
|
1158
|
+
) -> None:
|
|
1159
|
+
"""
|
|
1160
|
+
Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
|
|
1161
|
+
|
|
1162
|
+
Args:
|
|
1163
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
1164
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1165
|
+
fused_linear_cross_entropy (bool):
|
|
1166
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1167
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1168
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1169
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
1170
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1171
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
1172
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1173
|
+
loaded. Default is None.
|
|
1174
|
+
"""
|
|
1175
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1176
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1177
|
+
)
|
|
1178
|
+
|
|
1179
|
+
# PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
|
|
1180
|
+
|
|
1181
|
+
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
1182
|
+
from transformers.models.gemma.modeling_gemma import GemmaModel
|
|
1183
|
+
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
|
|
1184
|
+
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
|
|
1185
|
+
from transformers.models.paligemma import modeling_paligemma
|
|
1186
|
+
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
|
1187
|
+
from transformers.models.siglip import modeling_siglip
|
|
1188
|
+
from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
|
|
1189
|
+
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
|
1190
|
+
|
|
1191
|
+
from liger_kernel.transformers.model.paligemma import lce_forward
|
|
1192
|
+
from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
|
|
1193
|
+
|
|
1194
|
+
# The vision_tower is a SiglipVisionModel
|
|
1195
|
+
if layer_norm and model is None:
|
|
1196
|
+
modeling_siglip.nn.LayerNorm = LigerLayerNorm
|
|
1197
|
+
|
|
1198
|
+
# SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
|
|
1199
|
+
# The multi_modal_projector is Linear, nothing to do
|
|
1200
|
+
|
|
1201
|
+
# The language_model is GemmaForCausalLM or Gemma2ForCausalLM
|
|
1202
|
+
apply_liger_kernel_to_gemma(
|
|
1203
|
+
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
|
|
1204
|
+
)
|
|
1205
|
+
apply_liger_kernel_to_gemma2(
|
|
1206
|
+
rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
|
|
1207
|
+
)
|
|
1208
|
+
# Handle loss function
|
|
1209
|
+
if cross_entropy:
|
|
1210
|
+
modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1211
|
+
if fused_linear_cross_entropy:
|
|
1212
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
1213
|
+
if model is not None:
|
|
1214
|
+
model.forward = MethodType(lce_forward, model)
|
|
1215
|
+
else:
|
|
1216
|
+
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
|
|
1217
|
+
else: # if version < 4.46.1
|
|
1218
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1219
|
+
if model is not None:
|
|
1220
|
+
model.forward = MethodType(lce_forward_deprecated, model)
|
|
1221
|
+
else:
|
|
1222
|
+
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
|
|
1223
|
+
|
|
1224
|
+
if model is not None:
|
|
1225
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1226
|
+
# instance variables that reference already-instantiated modules
|
|
1227
|
+
|
|
1228
|
+
if not isinstance(model, PaliGemmaForConditionalGeneration):
|
|
1229
|
+
raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
|
|
1230
|
+
|
|
1231
|
+
vision_tower: SiglipVisionModel = model.model.vision_tower
|
|
1232
|
+
|
|
1233
|
+
_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
|
|
1234
|
+
|
|
1235
|
+
for layer in vision_tower.vision_model.encoder.layers:
|
|
1236
|
+
layer: SiglipEncoderLayer
|
|
1237
|
+
if layer_norm:
|
|
1238
|
+
_patch_layer_norm_module(layer.layer_norm1)
|
|
1239
|
+
_patch_layer_norm_module(layer.layer_norm2)
|
|
1240
|
+
|
|
1241
|
+
language_model = model.model.language_model
|
|
1242
|
+
|
|
1243
|
+
if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
|
|
1244
|
+
apply_liger_kernel_to_gemma(
|
|
1245
|
+
rope=rope,
|
|
1246
|
+
cross_entropy=False,
|
|
1247
|
+
fused_linear_cross_entropy=False,
|
|
1248
|
+
rms_norm=rms_norm,
|
|
1249
|
+
geglu=geglu,
|
|
1250
|
+
model=language_model,
|
|
1251
|
+
)
|
|
1252
|
+
|
|
1253
|
+
elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
|
|
1254
|
+
apply_liger_kernel_to_gemma2(
|
|
1255
|
+
rope=rope,
|
|
1256
|
+
cross_entropy=False,
|
|
1257
|
+
fused_linear_cross_entropy=False,
|
|
1258
|
+
rms_norm=rms_norm,
|
|
1259
|
+
geglu=geglu,
|
|
1260
|
+
model=language_model,
|
|
1261
|
+
)
|
|
1262
|
+
else:
|
|
1263
|
+
raise TypeError(
|
|
1264
|
+
"The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
|
|
1265
|
+
)
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def apply_liger_kernel_to_qwen2(
|
|
1269
|
+
rope: bool = True,
|
|
1270
|
+
cross_entropy: bool = False,
|
|
1271
|
+
fused_linear_cross_entropy: bool = True,
|
|
1272
|
+
rms_norm: bool = True,
|
|
1273
|
+
swiglu: bool = True,
|
|
1274
|
+
model: PreTrainedModel = None,
|
|
1275
|
+
) -> None:
|
|
1276
|
+
"""
|
|
1277
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
|
|
1278
|
+
|
|
1279
|
+
Args:
|
|
1280
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
1281
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1282
|
+
fused_linear_cross_entropy (bool):
|
|
1283
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1284
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1285
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1286
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1287
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
1288
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1289
|
+
loaded. Default is None.
|
|
1290
|
+
"""
|
|
1291
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1292
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1293
|
+
)
|
|
1294
|
+
|
|
1295
|
+
from transformers.models.qwen2 import modeling_qwen2
|
|
1296
|
+
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
|
|
1297
|
+
|
|
1298
|
+
if rope:
|
|
1299
|
+
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
1300
|
+
if rms_norm:
|
|
1301
|
+
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
|
1302
|
+
|
|
1303
|
+
if cross_entropy:
|
|
1304
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
1305
|
+
from transformers.loss.loss_utils import nn
|
|
1306
|
+
|
|
1307
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1308
|
+
else:
|
|
1309
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1310
|
+
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1311
|
+
|
|
1312
|
+
if fused_linear_cross_entropy:
|
|
1313
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
1314
|
+
if model is not None:
|
|
1315
|
+
model.forward = MethodType(qwen2_lce_forward, model)
|
|
1316
|
+
else:
|
|
1317
|
+
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
|
1318
|
+
else: # if version < 4.46.1
|
|
1319
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1320
|
+
if model is not None:
|
|
1321
|
+
model.forward = MethodType(qwen2_lce_forward_deprecated, model)
|
|
1322
|
+
else:
|
|
1323
|
+
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
|
|
1324
|
+
|
|
1325
|
+
if swiglu:
|
|
1326
|
+
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
|
1327
|
+
|
|
1328
|
+
if model is not None:
|
|
1329
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1330
|
+
# instance variables that reference already-instantiated modules
|
|
1331
|
+
|
|
1332
|
+
# get the base model from the model instance
|
|
1333
|
+
base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
|
|
1334
|
+
|
|
1335
|
+
if rms_norm:
|
|
1336
|
+
_patch_rms_norm_module(base_model.norm)
|
|
1337
|
+
|
|
1338
|
+
for decoder_layer in base_model.layers:
|
|
1339
|
+
if swiglu:
|
|
1340
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
1341
|
+
if rms_norm:
|
|
1342
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
1343
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
+def apply_liger_kernel_to_qwen3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3 import modeling_qwen3
+    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+
+    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
+
+    if rope:
+        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
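As a usage note, the version guard above turns the GPT-OSS patch into a logged no-op on older transformers releases rather than an error, so a call like the following sketch is safe even when the installed version is too old (the import location is assumed to follow the package's other patch helpers):

    from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

    # On transformers < 4.55.0 this only emits a warning and returns.
    apply_liger_kernel_to_gpt_oss(rms_norm=True, fused_linear_cross_entropy=True)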
+def apply_liger_kernel_to_qwen2_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
+
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
+    if layer_norm and model is None:
+        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+    if cross_entropy:
+        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
+
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if layer_norm:
+                    _patch_layer_norm_module(vision_block.norm1)
+                    _patch_layer_norm_module(vision_block.norm2)
+
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
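The instance-level branch above accepts the composite `Qwen2VLForConditionalGeneration`, the inner `Qwen2VLModel`, or the bare `Qwen2VLTextModel`. A minimal sketch of patching an already-loaded model, with the checkpoint path as a placeholder and the import location assumed to match the package's other helpers:

    from transformers import Qwen2VLForConditionalGeneration
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

    model = Qwen2VLForConditionalGeneration.from_pretrained("path/to/qwen2-vl")  # placeholder path
    # Swaps the forward for the fused linear cross entropy path and rewrites the
    # already-instantiated vision LayerNorm and text RMSNorm/SwiGLU modules on this instance.
    apply_liger_kernel_to_qwen2_vl(model=model)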
+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
+    NOTE: Qwen2.5-VL is not available in transformers<4.48.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
+
+    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            # Patch Qwen2_5_VisionTransformerPretrainedModel
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(phi3_lce_forward, model)
+        else:
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_glm4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4 import modeling_glm4
+    from transformers.models.glm4.modeling_glm4 import Glm4Model
+
+    from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
+    if swiglu:
+        modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_glm4v(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            _patch_rms_norm_module(vision_model.post_conv_layernorm)
+            _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                    if decoder_layer.mlp.shared_experts is not None:
+                        _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+            for decoder_layer in text_model.layers:
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
def apply_liger_kernel_to_internvl(
|
|
2336
|
+
cross_entropy: bool = False,
|
|
2337
|
+
fused_linear_cross_entropy: bool = True,
|
|
2338
|
+
rms_norm: bool = True,
|
|
2339
|
+
layer_norm: bool = True,
|
|
2340
|
+
model: Optional[PreTrainedModel] = None,
|
|
2341
|
+
**kwargs,
|
|
2342
|
+
) -> None:
|
|
2343
|
+
"""
|
|
2344
|
+
Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
|
|
2345
|
+
Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
|
|
2346
|
+
However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
|
|
2347
|
+
NOTE: InternVL is not available in transformers<4.52.1
|
|
2348
|
+
|
|
2349
|
+
Args:
|
|
2350
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2351
|
+
fused_linear_cross_entropy (bool):
|
|
2352
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2353
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2354
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
2355
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2356
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
2357
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2358
|
+
loaded. Default is None.
|
|
2359
|
+
"""
|
|
2360
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2361
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2362
|
+
)
|
|
2363
|
+
import torch.nn as torch_nn
|
|
2364
|
+
|
|
2365
|
+
from transformers.models.internvl import modeling_internvl
|
|
2366
|
+
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
|
|
2367
|
+
from transformers.models.internvl.modeling_internvl import InternVLModel
|
|
2368
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
|
|
2369
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionModel
|
|
2370
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
|
|
+
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
 
-
-
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+    if rms_norm:
+        modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
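The block above delegates text-backbone patching by filtering keyword arguments against the callee's signature before calling it. A minimal sketch of that filtering pattern follows; `patch_text_backbone` and `delegate` are hypothetical stand-ins, not functions from this package.

```python
import inspect


def patch_text_backbone(rms_norm: bool = True, swiglu: bool = True, model=None):
    # Hypothetical stand-in for a text-model apply_liger_kernel_to_* function.
    print(f"patched with rms_norm={rms_norm}, swiglu={swiglu}, model={model}")


def delegate(fn, **kwargs):
    # Keep only the kwargs the callee's signature accepts, warn about the rest.
    accepted = inspect.signature(fn).parameters
    dropped = set(kwargs) - set(accepted)
    if dropped:
        print(f"ignoring unsupported kwargs: {sorted(dropped)}")
    fn(**{k: v for k, v in kwargs.items() if k in accepted})


delegate(patch_text_backbone, rms_norm=True, swiglu=False, layer_norm=True)
```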
+def apply_liger_kernel_to_smolvlm(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     layer_norm: bool = True,
-
-
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
-
+    Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+    Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+    However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+    NOTE: SmolVLM is not available in transformers<4.50.0
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -898,7 +2466,6 @@ def apply_liger_kernel_to_qwen2_vl(
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
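Because the docstring above states that the SmolVLM patcher needs the model instance in order to reach the connected submodules, a post-load call might look like this sketch; the checkpoint id and the re-export from liger_kernel.transformers are assumptions.

```python
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers import apply_liger_kernel_to_smolvlm

# Assumed checkpoint id; any SmolVLM checkpoint with transformers>=4.50.0 should fit.
model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")

# Passing the instance lets the patcher also reach the connected text backbone
# and the already-instantiated vision LayerNorm modules.
apply_liger_kernel_to_smolvlm(model=model, fused_linear_cross_entropy=True, layer_norm=True)
```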
@@ -906,51 +2473,148 @@ def apply_liger_kernel_to_qwen2_vl(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
-    from transformers.models.
-    from transformers.models.
+    from transformers.models.smolvlm import modeling_smolvlm
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
 
-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+    if layer_norm and model is None:
+        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smolvlm_lce_forward, model)
+        else:
+            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+    if rms_norm:
+        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, SmolVLMForConditionalGeneration):
+            text_model = model.model.text_model
+            vision_model: SmolVLMVisionTransformer = model.model.vision_model
+        elif isinstance(model, SmolVLMModel):
+            text_model = model.text_model
+            vision_model: SmolVLMVisionTransformer = model.vision_model
+        else:
+            raise TypeError(
+                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch post_layernorm
+            _patch_layer_norm_module(vision_model.post_layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layers:
+                encoder_layer: SmolVLMEncoderLayer
+                _patch_layer_norm_module(encoder_layer.layer_norm1)
+                _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+def apply_liger_kernel_to_falcon_h1(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.falcon_h1 import modeling_falcon_h1
+    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
 
     if rope:
-
+        logger.info("Apply liger rotary pos emb.")
+        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-
-
-        if
-
+        logger.info("Apply liger RMSNorm")
+        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+    if swiglu:
+        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
     if cross_entropy:
-
+        logger.info("Apply liger cross entropy")
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-
-
-
+        if model is not None:
+            model.forward = MethodType(falcon_h1_lce_forward, model)
+        else:
+            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
-        # instance variables that reference already-instantiated modules
+        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
 
         # get the base model from the model instance
-        base_model:
-
-        if hasattr(model, "visual"):
-            # Patch Qwen2VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
-                if layer_norm:
-                    _patch_layer_norm_module(vision_block.norm1)
-                    _patch_layer_norm_module(vision_block.norm2)
+        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
-            _patch_rms_norm_module(base_model.
+            _patch_rms_norm_module(base_model.final_layernorm)
+
         for decoder_layer in base_model.layers:
             if swiglu:
                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.
+                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
 
 
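A usage sketch for the Falcon-H1 patcher above, contrasting module-level patching (before from_pretrained) with instance-level patching (passing model=); the checkpoint id is illustrative and the re-export from liger_kernel.transformers is assumed.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1

# Module-level patching: call before from_pretrained so new FalconH1RMSNorm
# modules are constructed as the Liger implementation.
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")  # illustrative id

# Instance-level patching instead: pass an already-loaded model so
# final_layernorm and the per-layer norms are patched in place.
# apply_liger_kernel_to_falcon_h1(model=model)
```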
-def
-    rope: bool =
+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
@@ -958,17 +2622,17 @@ def apply_liger_kernel_to_qwen2_5_vl(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
-    NOTE: Qwen2.5-VL is not available in transformers<4.48.2
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
 
     Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
@@ -976,47 +2640,129 @@ def apply_liger_kernel_to_qwen2_5_vl(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
 
-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
 
     if rope:
-
+        # It might enocunter nan issue
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
     if rms_norm:
-
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
     if cross_entropy:
-
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
     if swiglu:
-
+        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+            )
 
-
-
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
 
-
-
-
-
-
-
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
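For Qwen3-Next, `rope` defaults to False and passing True raises NotImplementedError in the hunk above, so a typical call leaves it off. A usage sketch under the same assumptions as before (illustrative checkpoint id, patcher re-exported from liger_kernel.transformers):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next

# rope is left at its False default; rope=True raises NotImplementedError above.
apply_liger_kernel_to_qwen3_next(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")  # illustrative id
```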
+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp,
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
-def
+def apply_liger_kernel_to_hunyuan_v1_moe(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -1025,67 +2771,57 @@ def apply_liger_kernel_to_phi3(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
-
-    Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is True.
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
-        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
-            loaded. Default is None.
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
-    from transformers.models.
-    from transformers.models.
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
 
     if rope:
-
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
     if rms_norm:
-
-
-        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
     if cross_entropy:
-
-
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
 
-            nn.functional.cross_entropy = liger_cross_entropy
-        else:
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if
-
-        else:
-
-
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
         # get the base model from the model instance
-        base_model:
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
-
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
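Unlike the dense variant, the MoE instance-level path above walks every expert in decoder_layer.mlp.experts rather than swapping a single MLP. A generic, self-contained sketch of that per-expert walk follows (toy classes, not this package's modules).

```python
import torch.nn as nn


class TinyExpert(nn.Module):
    # Toy expert MLP standing in for a HunYuan expert.
    def __init__(self, d: int = 8):
        super().__init__()
        self.gate_proj = nn.Linear(d, d)
        self.up_proj = nn.Linear(d, d)
        self.down_proj = nn.Linear(d, d)


class TinyMoEBlock(nn.Module):
    # Toy MoE block exposing .experts, like decoder_layer.mlp above.
    def __init__(self, num_experts: int = 4):
        super().__init__()
        self.experts = nn.ModuleList(TinyExpert() for _ in range(num_experts))


def patch_expert(expert: nn.Module) -> None:
    # Stand-in for _patch_swiglu_module(expert, LigerHunyuanV1SwiGLUMLP).
    expert.patched = True


moe = TinyMoEBlock()
for mlp_expert in moe.experts:
    patch_expert(mlp_expert)
print(all(getattr(e, "patched", False) for e in moe.experts))  # True
```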
-def
+def apply_liger_kernel_to_exaone4(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -1094,7 +2830,7 @@ def apply_liger_kernel_to_olmo2(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -1104,7 +2840,7 @@ def apply_liger_kernel_to_olmo2(
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
@@ -1112,47 +2848,70 @@ def apply_liger_kernel_to_olmo2(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
-    from transformers.models.
-    from transformers.models.
+    from transformers.models.exaone4 import modeling_exaone4
+    from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
 
-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
 
     if rope:
-
+        modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
+
     if rms_norm:
-
-
-
+        # EXAONE4 requires in_place=False to avoid gradient issues
+        class Exaone4LigerRMSNorm(LigerRMSNorm):
+            def __init__(self, hidden_size, eps=1e-6, **kwargs):
+                super().__init__(hidden_size, eps, **kwargs)
+                self.in_place = False
+
+        modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
+
     if cross_entropy:
         from transformers.loss.loss_utils import nn
 
         nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(exaone4_lce_forward, model)
+        else:
+            modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
+
+    if swiglu:
+        modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
         # get the base model from the model instance
-        base_model:
+        base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
-
+            _patch_rms_norm_module(base_model.norm, in_place=False)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
 
 
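The EXAONE4 hunk above swaps in an RMSNorm subclass with in_place forced off, because module-level patching assigns a class rather than a configured instance, so per-model defaults have to be baked into the constructor. A generic sketch of that pattern (toy classes, not the library's LigerRMSNorm):

```python
class RMSNorm:
    # Toy base class; the real LigerRMSNorm takes more arguments.
    def __init__(self, hidden_size: int, eps: float = 1e-6, in_place: bool = True):
        self.hidden_size, self.eps, self.in_place = hidden_size, eps, in_place


class NonInplaceRMSNorm(RMSNorm):
    # Bake the per-model default into the subclass, since the modeling code
    # will instantiate whatever class is assigned to the modeling module.
    def __init__(self, hidden_size: int, eps: float = 1e-6, **kwargs):
        super().__init__(hidden_size, eps, **kwargs)
        self.in_place = False


print(NonInplaceRMSNorm(4096).in_place)  # False
```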
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
+    "gemma3_text": apply_liger_kernel_to_gemma3_text,
+    "gemma3": apply_liger_kernel_to_gemma3,
+    "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
+    "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1160,11 +2919,27 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen3": apply_liger_kernel_to_qwen3,
+    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
+    "falcon_h1": apply_liger_kernel_to_falcon_h1,
+    "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+    "exaone4": apply_liger_kernel_to_exaone4,
 }
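The registry above is keyed by config.model_type; instance-level helpers look up the matching entry and filter kwargs against its signature, as _apply_liger_kernel_to_instance does further below. A self-contained sketch of that lookup with a hypothetical registry entry:

```python
import inspect


def apply_liger_kernel_to_demo(rms_norm: bool = True, model=None):
    # Hypothetical registry entry, not a real patcher from this package.
    print(f"demo patched, rms_norm={rms_norm}")


DEMO_REGISTRY = {"demo": apply_liger_kernel_to_demo}


def apply_to_instance(model, **kwargs):
    # Resolve the patcher by model_type, then pass only the kwargs it accepts.
    apply_fn = DEMO_REGISTRY.get(model.config.model_type)
    if apply_fn is None:
        return
    accepted = inspect.signature(apply_fn).parameters
    apply_fn(model=model, **{k: v for k, v in kwargs.items() if k in accepted})


class _Cfg:
    model_type = "demo"


class _Model:
    config = _Cfg()


apply_to_instance(_Model(), rms_norm=True, swiglu=False)
```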
 
 
@@ -1222,7 +2997,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
|