liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +25 -9
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +124 -64
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +13 -6
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +283 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +205 -19
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +50 -25
- liger_kernel/transformers/model/gemma2.py +55 -23
- liger_kernel/transformers/model/gemma3.py +117 -120
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +102 -25
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +36 -23
- liger_kernel/transformers/model/mixtral.py +45 -25
- liger_kernel/transformers/model/mllama.py +39 -22
- liger_kernel/transformers/model/olmo2.py +40 -20
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -177
- liger_kernel/transformers/model/qwen2.py +48 -21
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1678 -160
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +48 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +36 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
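
The bulk of the change is in `liger_kernel/transformers/monkey_patch.py` (+1678 −160), which adds new `apply_liger_kernel_to_*` entry points (SmolLM3, Llama4, Qwen3, Qwen3-MoE, GPT-OSS, Qwen2-VL, Qwen2.5-VL, ...) and extends the existing ones so they can also patch an already-instantiated model through the `model=` argument. A minimal usage sketch, assuming the entry points keep the signatures visible in the diff below (the checkpoint name is illustrative, not taken from the package):

```python
# Usage sketch (not part of the diff); the checkpoint name is illustrative.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Class-level patching: call before from_pretrained so new instances pick up Liger modules.
apply_liger_kernel_to_llama(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# Instance-level patching (the `model is not None` branches added in this diff):
# already_loaded = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
# apply_liger_kernel_to_llama(model=already_loaded)
```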
liger_kernel/transformers/monkey_patch.py:
@@ -2,7 +2,9 @@ import inspect
 import logging

 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional

 import transformers

@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,16 +29,24 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 transformer_version = version.parse(transformers.__version__)

 logger = logging.getLogger(__name__)
@@ -47,33 +59,82 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)


-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-    module
-
-
-
-
-
-
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        module.row_mode = row_mode
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


 def _patch_layer_norm_module(module, eps=1e-6):
-
-
-
-
-
-
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


 def apply_liger_kernel_to_granite(
@@ -204,10 +265,16 @@ def apply_liger_kernel_to_llama(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -227,6 +294,77 @@ def apply_liger_kernel_to_llama(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -261,13 +399,20 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.
-
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else: # if version < 4.49.0
             logger.warning(
-                "
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -306,6 +451,97 @@ def apply_liger_kernel_to_llava(
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -347,7 +583,7 @@ def apply_liger_kernel_to_mllama(

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,10 +599,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -375,13 +617,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -448,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -516,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -573,8 +835,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel

-
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

     if rope:
@@ -593,10 +855,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -647,7 +915,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -667,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -724,9 +999,10 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -748,15 +1024,18 @@ def apply_liger_kernel_to_gemma3_text(
             nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model

             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -818,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )

-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     apply_liger_kernel_to_gemma3_text(
@@ -829,7 +1108,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -897,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -908,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -926,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1240,7 @@ def apply_liger_kernel_to_paligemma(

         language_model = model.language_model

-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -960,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )

-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1021,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else: # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1045,70 +1341,54 @@ def apply_liger_kernel_to_qwen2(
         if rms_norm:
             _patch_rms_norm_module(decoder_layer.input_layernorm)
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-    print("Applied Liger kernels to Qwen2")


-def
+def apply_liger_kernel_to_qwen3(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    layer_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
-    NOTE: Qwen2-VL is not available in transformers<4.45.0
-
-    Args:
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is True.
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
-        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
-        loaded. Default is None.
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3 import modeling_qwen3
+    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model

-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward

     if rope:
-
+        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
     if rms_norm:
-
-
-    if layer_norm:
-        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
     if cross_entropy:
-
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
     if swiglu:
-
+        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

         # get the base model from the model instance
-        base_model:
-
-        if hasattr(model, "visual"):
-            # Patch Qwen2VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
-                if layer_norm:
-                    _patch_layer_norm_module(vision_block.norm1)
-                    _patch_layer_norm_module(vision_block.norm2)
+        base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -1120,7 +1400,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


-def
+def apply_liger_kernel_to_qwen3_moe(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -1129,74 +1409,69 @@ def apply_liger_kernel_to_qwen2_5_vl(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
-    NOTE: Qwen2.5-VL is not available in transformers<4.48.2
-
-    Args:
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is True.
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
-        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
-        loaded. Default is None.
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel

-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

     if rope:
-
+        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
     if rms_norm:
-
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
     if cross_entropy:
-
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
     if swiglu:
-
+        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

         # get the base model from the model instance
-        base_model:
-
-        if hasattr(model, "visual"):
-            # Patch Qwen2_5_VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
-                if rms_norm:
-                    _patch_rms_norm_module(vision_block.norm1)
-                    _patch_rms_norm_module(vision_block.norm2)
+        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


-def
+def apply_liger_kernel_to_gpt_oss(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool =
+    swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -1206,58 +1481,1135 @@ def apply_liger_kernel_to_phi3(
|
|
|
1206
1481
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1207
1482
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1208
1483
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1209
|
-
swiglu (bool): Whether to apply Liger's SwiGLU
|
|
1484
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
1485
|
+
Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
|
|
1210
1486
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1211
|
-
|
|
1487
|
+
loaded. Default is None.
|
|
1212
1488
|
"""
|
|
1489
|
+
if version.parse(transformers.__version__) < version.parse("4.55.0"):
|
|
1490
|
+
logger.warning("GPT-OSS support requires transformers >= 4.55.0")
|
|
1491
|
+
return
|
|
1492
|
+
|
|
1213
1493
|
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1214
1494
|
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1215
1495
|
)
|
|
1216
1496
|
|
|
1217
|
-
from transformers.models.
|
|
1218
|
-
from transformers.models.
|
|
1497
|
+
from transformers.models.gpt_oss import modeling_gpt_oss
|
|
1498
|
+
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
|
|
1219
1499
|
|
|
1220
1500
|
if rope:
|
|
1221
|
-
|
|
1501
|
+
modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
1502
|
+
|
|
1222
1503
|
if rms_norm:
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
1504
|
+
modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
|
|
1505
|
+
|
|
1226
1506
|
if cross_entropy:
|
|
1227
|
-
|
|
1228
|
-
|
|
1507
|
+
from transformers.loss.loss_utils import nn
|
|
1508
|
+
|
|
1509
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1229
1510
|
|
|
1230
|
-
nn.functional.cross_entropy = liger_cross_entropy
|
|
1231
|
-
else:
|
|
1232
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
1233
|
-
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
1234
1511
|
if fused_linear_cross_entropy:
|
|
1235
|
-
if
|
|
1236
|
-
|
|
1237
|
-
else:
|
|
1238
|
-
|
|
1239
|
-
|
|
1512
|
+
if model is not None:
|
|
1513
|
+
model.forward = MethodType(gpt_oss_lce_forward, model)
|
|
1514
|
+
else:
|
|
1515
|
+
modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
|
|
1516
|
+
|
|
1517
|
+
# Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
|
|
1518
|
+
# with clamping (swiglu_limit=7.0) and MXFP4 quantization
|
|
1519
|
+
|
|
1520
|
+
if model is not None:
|
|
1521
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1522
|
+
# instance variables that reference already-instantiated modules
|
|
1523
|
+
|
|
1524
|
+
# get the base model from the model instance
|
|
1525
|
+
base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
|
|
1526
|
+
|
|
1527
|
+
if rms_norm:
|
|
1528
|
+
_patch_rms_norm_module(base_model.norm)
|
|
1529
|
+
for decoder_layer in base_model.layers:
|
|
1530
|
+
if rms_norm:
|
|
1531
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
1532
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1533
|
+
|
|
1534
|
+
|
|
1535
|
+
def apply_liger_kernel_to_qwen2_vl(
|
|
1536
|
+
rope: bool = True,
|
|
1537
|
+
cross_entropy: bool = False,
|
|
1538
|
+
fused_linear_cross_entropy: bool = True,
|
|
1539
|
+
rms_norm: bool = True,
|
|
1540
|
+
layer_norm: bool = True,
|
|
1541
|
+
swiglu: bool = True,
|
|
1542
|
+
model: PreTrainedModel = None,
|
|
1543
|
+
) -> None:
|
|
1544
|
+
"""
|
|
1545
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
|
|
1546
|
+
NOTE: Qwen2-VL is not supported in transformers<4.52.4
|
|
1547
|
+
|
|
1548
|
+
Args:
|
|
1549
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1550
|
+
fused_linear_cross_entropy (bool):
|
|
1551
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1552
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1553
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1554
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1555
|
+
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
+
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
+    if layer_norm and model is None:
+        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+    if cross_entropy:
+        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
+
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if layer_norm:
+                    _patch_layer_norm_module(vision_block.norm1)
+                    _patch_layer_norm_module(vision_block.norm2)
+
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
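A minimal usage sketch for the patch above, assuming the function is re-exported from liger_kernel.transformers (as the package's other apply_* helpers are) and that a Qwen2-VL checkpoint is available; the checkpoint name is only illustrative. Module-level patching must run before the model is instantiated so the swapped classes are the ones the constructor builds:

    import torch
    from transformers import Qwen2VLForConditionalGeneration

    # Assumption: the function is importable from liger_kernel.transformers.
    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

    # Patch the Qwen2-VL modeling classes before the model is built, so the
    # class-level swaps (Qwen2RMSNorm -> LigerRMSNorm, Qwen2MLP -> LigerSwiGLUMLP, ...)
    # take effect for every module the constructor creates.
    apply_liger_kernel_to_qwen2_vl()

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",  # illustrative checkpoint; any Qwen2-VL checkpoint works
        torch_dtype=torch.bfloat16,
    )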
+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
+    NOTE: Qwen2.5-VL is not available in transformers<4.48.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
+
+    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            # Patch Qwen2_5_VisionTransformerPretrainedModel
+            for vision_block in model.visual.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
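A sketch of patching an already-loaded Qwen3-VL instance, assuming the function is re-exported from liger_kernel.transformers; the checkpoint path is a placeholder:

    from transformers import Qwen3VLForConditionalGeneration

    # Assumption: the function is importable from liger_kernel.transformers.
    from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

    # "path/to/qwen3-vl-checkpoint" is a placeholder; substitute a real Qwen3-VL checkpoint.
    model = Qwen3VLForConditionalGeneration.from_pretrained("path/to/qwen3-vl-checkpoint")

    # Passing the instance patches modules that already exist: forward is rebound via
    # MethodType, and the per-layer RMSNorm modules (including q_norm/k_norm when present)
    # are patched in place with offset=0.0 / casting_mode="llama", as in the branch
    # guarded by `if model is not None and rms_norm:` above.
    apply_liger_kernel_to_qwen3_vl(model=model)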
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(phi3_lce_forward, model)
+        else:
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
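A short sketch of the loss flags above, assuming the function is re-exported from liger_kernel.transformers. The two flags are mutually exclusive (enforced by the assert), so using the plain Liger cross entropy requires turning the fused path off explicitly:

    from liger_kernel.transformers import apply_liger_kernel_to_phi3  # assumed re-export

    # cross_entropy=True swaps nn.functional.cross_entropy inside transformers' loss
    # utils; fused_linear_cross_entropy=False skips rebinding the lce_forward.
    apply_liger_kernel_to_phi3(cross_entropy=True, fused_linear_cross_entropy=False)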
+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
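A usage sketch for an OLMo-2 model that has already been loaded, assuming the function is re-exported from liger_kernel.transformers; the checkpoint name is only illustrative:

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_olmo2  # assumed re-export

    # Illustrative checkpoint name; substitute the OLMo-2 checkpoint you actually use.
    model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")

    # With an existing instance, the class-level swaps above cannot reach modules that
    # were already constructed, so the function also walks base_model.layers and patches
    # each post_attention_layernorm / post_feedforward_layernorm (with in_place=False,
    # matching OLMo2's RMSNorm) and rebinds forward for the fused linear cross entropy.
    apply_liger_kernel_to_olmo2(model=model)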
+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_glm4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4 import modeling_glm4
+    from transformers.models.glm4.modeling_glm4 import Glm4Model
+
+    from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
+    if swiglu:
+        modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm, in_place=False)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
+
+
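A minimal sketch for GLM-4, assuming the function is re-exported from liger_kernel.transformers. Note that rope defaults to False and must stay off, since passing rope=True hits the NotImplementedError guard above:

    from liger_kernel.transformers import apply_liger_kernel_to_glm4  # assumed re-export

    # The defaults enable the GLM-4 RMSNorm variant, the Phi3-style SwiGLU MLP
    # (Glm4MLP is swapped for LigerPhi3SwiGLUMLP), and the fused linear cross entropy
    # forward; rope is left at False because Liger's rotary embedding is unavailable here.
    apply_liger_kernel_to_glm4()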
+def apply_liger_kernel_to_glm4v(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            _patch_rms_norm_module(vision_model.post_conv_layernorm)
+            _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                    if decoder_layer.mlp.shared_experts is not None:
+                        _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+            for decoder_layer in text_model.layers:
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_internvl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+    Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
+    However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+    NOTE: InternVL is not available in transformers<4.52.1
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+    import torch.nn as torch_nn
+
+    from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
+
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+    if rms_norm:
+        modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+            # NOTE: language_model and visual properties can be accessed throught conditional class.
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
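A usage sketch for InternVL, assuming the function is re-exported from liger_kernel.transformers; the checkpoint path is a placeholder for an HF-native InternVL checkpoint. Extra keyword arguments are forwarded to the apply_* function of the connected text backbone (looked up in MODEL_TYPE_TO_APPLY_LIGER_FN by model_type) after being filtered against that function's signature:

    from transformers import AutoModelForImageTextToText

    from liger_kernel.transformers import apply_liger_kernel_to_internvl  # assumed re-export

    # Placeholder checkpoint path; InternVL patching requires the loaded model instance.
    model = AutoModelForImageTextToText.from_pretrained("path/to/internvl-hf-checkpoint")

    # `swiglu` is not a parameter of this function itself; it only reaches text backbones
    # whose apply_* function accepts it, otherwise it is dropped with a warning.
    apply_liger_kernel_to_internvl(model=model, swiglu=True)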
+def apply_liger_kernel_to_smolvlm(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+    Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+    However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+    NOTE: SmolVLM is not available in transformers<4.50.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smolvlm import modeling_smolvlm
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+    if layer_norm and model is None:
+        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smolvlm_lce_forward, model)
+        else:
+            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+    if rms_norm:
+        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
+        if isinstance(model, SmolVLMForConditionalGeneration):
+            text_model = model.model.text_model
+            vision_model: SmolVLMVisionTransformer = model.model.vision_model
+        elif isinstance(model, SmolVLMModel):
+            text_model = model.text_model
+            vision_model: SmolVLMVisionTransformer = model.vision_model
+        else:
+            raise TypeError(
+                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch post_layernorm
+            _patch_layer_norm_module(vision_model.post_layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layers:
+                encoder_layer: SmolVLMEncoderLayer
+                _patch_layer_norm_module(encoder_layer.layer_norm1)
+                _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
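A quick sketch of the pre-initialization (module-level) path for SmolVLM, assuming the function is re-exported from liger_kernel.transformers; the class-level swap can be verified directly before any checkpoint is loaded:

    from transformers.models.smolvlm import modeling_smolvlm

    from liger_kernel.transformers import apply_liger_kernel_to_smolvlm  # assumed re-export
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    # Called without a model instance, the patch rewires the modeling module itself.
    apply_liger_kernel_to_smolvlm()
    assert modeling_smolvlm.SmolVLMRMSNorm is LigerRMSNorm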
+def apply_liger_kernel_to_falcon_h1(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.falcon_h1 import modeling_falcon_h1
+    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+    if rope:
+        logger.info("Apply liger rotary pos emb.")
+        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        logger.info("Apply liger RMSNorm")
+        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+    if swiglu:
+        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(falcon_h1_lce_forward, model)
+        else:
+            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
 
         # get the base model from the model instance
-        base_model:
+        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
-            _patch_rms_norm_module(base_model.
+            _patch_rms_norm_module(base_model.final_layernorm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp,
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.
+                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
 
 
-def
-    rope: bool =
+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
@@ -1265,17 +2617,17 @@ def apply_liger_kernel_to_olmo2(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
 
     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
```diff
@@ -1283,40 +2635,185 @@ def apply_liger_kernel_to_olmo2(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

-    from liger_kernel.transformers.model.
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

     if rope:
-
+        # It might enocunter nan issue
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
     if rms_norm:
-
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
     if swiglu:
-
+        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+            )
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
     if cross_entropy:
         from transformers.loss.loss_utils import nn

         nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

         # get the base model from the model instance
-        base_model:
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward

+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.
-                _patch_rms_norm_module(decoder_layer.
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
```
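Like the Qwen3Next patch, the two new Hunyuan v1 functions support both entry points: called with no `model`, they swap the modeling classes before construction; called with `model=`, they rebind `forward` via `MethodType` and walk the already-instantiated decoder layers to patch the RMSNorm and SwiGLU (and, for the MoE variant, per-expert MLP) modules in place. A minimal sketch of the second path, assuming the function is re-exported from `liger_kernel.transformers` and using a placeholder checkpoint name:

```python
# Illustrative sketch only: patch a model that has already been loaded.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe

model = AutoModelForCausalLM.from_pretrained("your-org/hunyuan-v1-moe-checkpoint")  # placeholder

# Defaults: rope, rms_norm, swiglu, and fused_linear_cross_entropy are all enabled.
# Passing the instance patches its existing layernorms and expert MLPs in place.
apply_liger_kernel_to_hunyuan_v1_moe(model=model)
```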
```diff
@@ -1325,7 +2822,14 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma2": apply_liger_kernel_to_gemma2,
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
+    "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
+    "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1333,11 +2837,26 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen3": apply_liger_kernel_to_qwen3,
+    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
+    "falcon_h1": apply_liger_kernel_to_falcon_h1,
+    "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
 }


```
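This registry maps a HuggingFace `model_type` string to its patch function, which is what lets a single entry point patch whatever architecture a loaded model happens to be. A simplified sketch of that dispatch follows; the helper name and import path here are assumptions, and the real lookup plus signature-based keyword filtering appear in the next hunk inside `_apply_liger_kernel_to_instance`:

```python
# Illustrative sketch only: dispatch by model_type, keeping only kwargs the
# chosen apply function accepts (mirrors the inspect.signature filtering below).
import inspect

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN  # path assumed


def apply_by_model_type(model, **kwargs):
    model_type = getattr(model.config, "model_type", None)
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        return  # no Liger patch registered for this architecture

    supported = inspect.signature(apply_fn).parameters
    apply_fn(model=model, **{k: v for k, v in kwargs.items() if k in supported})
```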
```diff
@@ -1395,7 +2914,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return

     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)

     # Filter out the keyword arguments that are not supported by the apply function
```