liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
|
@@ -1,43 +1,51 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import logging
|
|
3
|
+
|
|
3
4
|
from functools import partial
|
|
5
|
+
from types import MethodType
|
|
4
6
|
from typing import Callable
|
|
7
|
+
from typing import Optional
|
|
5
8
|
|
|
6
9
|
import transformers
|
|
10
|
+
|
|
7
11
|
from packaging import version
|
|
8
12
|
from transformers import PreTrainedModel
|
|
9
13
|
|
|
10
14
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
15
|
+
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
11
16
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
|
12
17
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
18
|
+
from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
|
|
13
19
|
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
|
|
14
|
-
from liger_kernel.transformers.model.gemma import
|
|
15
|
-
|
|
16
|
-
|
|
20
|
+
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
|
|
21
|
+
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
|
|
22
|
+
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
|
|
17
23
|
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
|
|
18
|
-
from liger_kernel.transformers.model.llama import
|
|
19
|
-
|
|
20
|
-
|
|
24
|
+
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
|
|
25
|
+
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
|
|
26
|
+
from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
|
|
21
27
|
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
|
|
22
28
|
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
|
|
23
|
-
from liger_kernel.transformers.model.mixtral import
|
|
24
|
-
lce_forward_deprecated as mixtral_lce_forward_deprecated,
|
|
25
|
-
)
|
|
29
|
+
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
|
|
26
30
|
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
27
|
-
from liger_kernel.transformers.model.phi3 import (
|
|
28
|
-
lce_forward_deprecated as phi3_lce_forward_deprecated,
|
|
29
|
-
)
|
|
30
31
|
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
31
|
-
from liger_kernel.transformers.model.qwen2 import
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
|
|
33
|
+
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
|
|
34
|
+
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
|
|
34
35
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
35
36
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
36
|
-
from liger_kernel.transformers.
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
|
|
38
|
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
|
|
39
|
+
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
|
|
40
|
+
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
|
|
41
|
+
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
import peft
|
|
45
|
+
|
|
46
|
+
PEFT_AVAILABLE = True
|
|
47
|
+
except ImportError:
|
|
48
|
+
PEFT_AVAILABLE = False
|
|
41
49
|
|
|
42
50
|
transformer_version = version.parse(transformers.__version__)
|
|
43
51
|
|
|
@@ -51,23 +59,161 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
|
51
59
|
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
|
52
60
|
|
|
53
61
|
|
|
54
|
-
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
|
|
55
|
-
module
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
+
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
|
|
63
|
+
# Check if the module is a PEFT ModulesToSaveWrapper
|
|
64
|
+
# If it is, we need to patch the modules_to_save.default and original_modules
|
|
65
|
+
if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
|
|
66
|
+
module.modules_to_save.default.offset = offset
|
|
67
|
+
module.modules_to_save.default.casting_mode = casting_mode
|
|
68
|
+
module.modules_to_save.default.variance_epsilon = (
|
|
69
|
+
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
70
|
+
)
|
|
71
|
+
module.modules_to_save.default.in_place = in_place
|
|
72
|
+
module.modules_to_save.default.row_mode = row_mode
|
|
73
|
+
module.original_module.offset = offset
|
|
74
|
+
module.original_module.casting_mode = casting_mode
|
|
75
|
+
module.original_module.variance_epsilon = (
|
|
76
|
+
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
77
|
+
)
|
|
78
|
+
module.original_module.in_place = in_place
|
|
79
|
+
module.original_module.row_mode = row_mode
|
|
80
|
+
_bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
|
|
81
|
+
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
|
|
82
|
+
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
|
|
83
|
+
_bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
|
|
84
|
+
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
85
|
+
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
86
|
+
else:
|
|
87
|
+
module.offset = offset
|
|
88
|
+
module.casting_mode = casting_mode
|
|
89
|
+
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
90
|
+
module.in_place = in_place
|
|
91
|
+
module.row_mode = row_mode
|
|
92
|
+
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
|
|
93
|
+
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
|
|
94
|
+
_bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
62
95
|
|
|
63
96
|
|
|
64
97
|
def _patch_layer_norm_module(module, eps=1e-6):
|
|
65
|
-
module
|
|
66
|
-
|
|
98
|
+
# Check if the module is a PEFT ModulesToSaveWrapper
|
|
99
|
+
# If it is, we need to patch the modules_to_save.default and original_modules
|
|
100
|
+
if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
|
|
101
|
+
module.hidden_size = module.normalized_shape
|
|
102
|
+
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
|
|
103
|
+
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
104
|
+
module.modules_to_save.default.variance_epsilon = (
|
|
105
|
+
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
106
|
+
)
|
|
107
|
+
module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
|
|
108
|
+
module, "normalized_shape", None
|
|
109
|
+
)
|
|
110
|
+
module.original_module.variance_epsilon = (
|
|
111
|
+
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
112
|
+
)
|
|
113
|
+
module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
|
|
114
|
+
module, "normalized_shape", None
|
|
115
|
+
)
|
|
116
|
+
_bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
|
|
117
|
+
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
|
|
118
|
+
_bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
|
|
119
|
+
_bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
120
|
+
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
121
|
+
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
122
|
+
else:
|
|
123
|
+
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
|
124
|
+
module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
|
|
125
|
+
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
|
|
126
|
+
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
127
|
+
_bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _patch_swiglu_module(module, liger_module):
|
|
131
|
+
_bind_method_to_module(module, "forward", liger_module.forward)
|
|
132
|
+
_bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _patch_geglu_module(module):
|
|
136
|
+
_bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
|
|
137
|
+
_bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def apply_liger_kernel_to_granite(
|
|
141
|
+
rope: bool = True,
|
|
142
|
+
cross_entropy: bool = True,
|
|
143
|
+
fused_linear_cross_entropy: bool = False,
|
|
144
|
+
rms_norm: bool = True,
|
|
145
|
+
swiglu: bool = True,
|
|
146
|
+
model: PreTrainedModel = None,
|
|
147
|
+
) -> None:
|
|
148
|
+
"""
|
|
149
|
+
Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
153
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
154
|
+
fused_linear_cross_entropy (bool):
|
|
155
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
156
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
157
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
158
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
159
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
160
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
161
|
+
loaded. Default is None.
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
Debugging notes:
|
|
166
|
+
If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
|
|
167
|
+
"""
|
|
168
|
+
|
|
169
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
170
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
67
171
|
)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
172
|
+
|
|
173
|
+
from transformers.models.granite import modeling_granite
|
|
174
|
+
from transformers.models.granite.modeling_granite import GraniteModel
|
|
175
|
+
|
|
176
|
+
if swiglu:
|
|
177
|
+
modeling_granite.GraniteMLP = LigerSwiGLUMLP
|
|
178
|
+
|
|
179
|
+
if rms_norm:
|
|
180
|
+
modeling_granite.GraniteRMSNorm = LigerRMSNorm
|
|
181
|
+
|
|
182
|
+
if rope:
|
|
183
|
+
modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
184
|
+
|
|
185
|
+
if cross_entropy:
|
|
186
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
187
|
+
from transformers.loss.loss_utils import nn
|
|
188
|
+
|
|
189
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
190
|
+
else:
|
|
191
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
192
|
+
modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
193
|
+
|
|
194
|
+
if fused_linear_cross_entropy:
|
|
195
|
+
raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
|
|
196
|
+
# NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
|
|
197
|
+
# call, so we can't sidestep logit materialization. A bit more work
|
|
198
|
+
# would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
|
|
199
|
+
# for the logit output.
|
|
200
|
+
|
|
201
|
+
if model is not None:
|
|
202
|
+
# The model instance already exists, so we need to additionally patch the
|
|
203
|
+
# instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
|
|
204
|
+
|
|
205
|
+
# get the base model from the model instance
|
|
206
|
+
base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
|
|
207
|
+
|
|
208
|
+
if rms_norm:
|
|
209
|
+
_patch_rms_norm_module(base_model.norm)
|
|
210
|
+
|
|
211
|
+
for decoder_layer in base_model.layers:
|
|
212
|
+
if swiglu:
|
|
213
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
214
|
+
if rms_norm:
|
|
215
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
216
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
71
217
|
|
|
72
218
|
|
|
73
219
|
def apply_liger_kernel_to_llama(
|
|
@@ -94,11 +240,12 @@ def apply_liger_kernel_to_llama(
|
|
|
94
240
|
loaded. Default is None.
|
|
95
241
|
"""
|
|
96
242
|
|
|
97
|
-
assert not (
|
|
98
|
-
cross_entropy and fused_linear_cross_entropy
|
|
99
|
-
)
|
|
243
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
244
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
245
|
+
)
|
|
100
246
|
|
|
101
247
|
from transformers.models.llama import modeling_llama
|
|
248
|
+
from transformers.models.llama.modeling_llama import LlamaModel
|
|
102
249
|
|
|
103
250
|
if rope:
|
|
104
251
|
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
@@ -106,42 +253,295 @@ def apply_liger_kernel_to_llama(
|
|
|
106
253
|
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
|
107
254
|
if swiglu:
|
|
108
255
|
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
|
256
|
+
|
|
109
257
|
if cross_entropy:
|
|
110
|
-
|
|
258
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
259
|
+
from transformers.loss.loss_utils import nn
|
|
260
|
+
|
|
261
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
262
|
+
else:
|
|
263
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
264
|
+
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
265
|
+
|
|
111
266
|
if fused_linear_cross_entropy:
|
|
112
267
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
113
|
-
|
|
268
|
+
if model is not None:
|
|
269
|
+
model.forward = MethodType(llama_lce_forward, model)
|
|
270
|
+
else:
|
|
271
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
|
114
272
|
else: # if version < 4.46.1
|
|
115
273
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
116
|
-
|
|
274
|
+
if model is not None:
|
|
275
|
+
model.forward = MethodType(llama_lce_forward_deprecated, model)
|
|
276
|
+
else:
|
|
277
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
|
|
117
278
|
|
|
118
279
|
if model is not None:
|
|
119
280
|
# The model instance already exists, so we need to additionally patch the
|
|
120
281
|
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
|
121
282
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
283
|
+
# get the base model from the model instance
|
|
284
|
+
base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
|
|
285
|
+
|
|
286
|
+
if rms_norm:
|
|
287
|
+
_patch_rms_norm_module(base_model.norm)
|
|
288
|
+
|
|
289
|
+
for decoder_layer in base_model.layers:
|
|
290
|
+
if swiglu:
|
|
291
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
292
|
+
if rms_norm:
|
|
293
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
294
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def apply_liger_kernel_to_smollm3(
|
|
298
|
+
rope: bool = True,
|
|
299
|
+
cross_entropy: bool = False,
|
|
300
|
+
fused_linear_cross_entropy: bool = True,
|
|
301
|
+
rms_norm: bool = True,
|
|
302
|
+
swiglu: bool = True,
|
|
303
|
+
model: PreTrainedModel = None,
|
|
304
|
+
) -> None:
|
|
305
|
+
"""
|
|
306
|
+
Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
310
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
311
|
+
fused_linear_cross_entropy (bool):
|
|
312
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
313
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
314
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
315
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
316
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
317
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
318
|
+
loaded. Default is None.
|
|
319
|
+
"""
|
|
320
|
+
|
|
321
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
322
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
from transformers.models.smollm3 import modeling_smollm3
|
|
326
|
+
from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
|
|
327
|
+
|
|
328
|
+
if rope:
|
|
329
|
+
modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
330
|
+
if rms_norm:
|
|
331
|
+
modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
|
|
332
|
+
if swiglu:
|
|
333
|
+
modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
|
|
334
|
+
|
|
335
|
+
if cross_entropy:
|
|
336
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
337
|
+
from transformers.loss.loss_utils import nn
|
|
338
|
+
|
|
339
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
340
|
+
else:
|
|
341
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
342
|
+
modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
343
|
+
|
|
344
|
+
if fused_linear_cross_entropy:
|
|
345
|
+
if model is not None:
|
|
346
|
+
model.forward = MethodType(smollm3_lce_forward, model)
|
|
128
347
|
else:
|
|
129
|
-
|
|
130
|
-
|
|
348
|
+
modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
|
|
349
|
+
|
|
350
|
+
if model is not None:
|
|
351
|
+
# The model instance already exists, so we need to additionally patch the
|
|
352
|
+
# instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
|
|
353
|
+
|
|
354
|
+
# get the base model from the model instance
|
|
355
|
+
base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
|
|
131
356
|
|
|
132
357
|
if rms_norm:
|
|
133
358
|
_patch_rms_norm_module(base_model.norm)
|
|
134
359
|
|
|
135
360
|
for decoder_layer in base_model.layers:
|
|
136
361
|
if swiglu:
|
|
137
|
-
|
|
138
|
-
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
139
|
-
)
|
|
362
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
140
363
|
if rms_norm:
|
|
141
364
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
142
365
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
143
366
|
|
|
144
367
|
|
|
368
|
+
def apply_liger_kernel_to_llava(
|
|
369
|
+
cross_entropy: bool = False,
|
|
370
|
+
fused_linear_cross_entropy: bool = True,
|
|
371
|
+
model: PreTrainedModel = None,
|
|
372
|
+
**kwargs,
|
|
373
|
+
) -> None:
|
|
374
|
+
"""
|
|
375
|
+
Apply Liger kernels to replace original implementation in HuggingFace Llava models.
|
|
376
|
+
Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
|
|
377
|
+
However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
|
|
378
|
+
NOTE: Llava is not available in transformers<4.36.0
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
382
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
383
|
+
fused_linear_cross_entropy (bool):
|
|
384
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
385
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
386
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
387
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
388
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
389
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
390
|
+
loaded. Default is None.
|
|
391
|
+
"""
|
|
392
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
393
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
from transformers.models.llava import modeling_llava
|
|
397
|
+
|
|
398
|
+
if cross_entropy:
|
|
399
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
400
|
+
modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
401
|
+
if fused_linear_cross_entropy:
|
|
402
|
+
if transformer_version >= version.parse("4.52.0"):
|
|
403
|
+
if model is not None:
|
|
404
|
+
model.forward = MethodType(llava_lce_forward, model)
|
|
405
|
+
else:
|
|
406
|
+
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
|
|
407
|
+
elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
|
|
408
|
+
if model is not None:
|
|
409
|
+
model.forward = MethodType(llava_lce_forward_deprecated, model)
|
|
410
|
+
else:
|
|
411
|
+
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
|
|
412
|
+
else: # if version < 4.49.0
|
|
413
|
+
logger.warning(
|
|
414
|
+
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
if model is not None:
|
|
418
|
+
text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
|
|
419
|
+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
|
|
420
|
+
vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
|
|
421
|
+
|
|
422
|
+
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
|
|
423
|
+
if text_liger_fn:
|
|
424
|
+
accept_params = inspect.signature(text_liger_fn).parameters
|
|
425
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
426
|
+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
427
|
+
|
|
428
|
+
if remain_params:
|
|
429
|
+
logger.warning(
|
|
430
|
+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
|
|
431
|
+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
|
|
432
|
+
)
|
|
433
|
+
text_kwargs["model"] = model.language_model
|
|
434
|
+
text_liger_fn(**text_kwargs)
|
|
435
|
+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
436
|
+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
|
|
437
|
+
|
|
438
|
+
if vision_liger_fn:
|
|
439
|
+
accept_params = inspect.signature(vision_liger_fn).parameters
|
|
440
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
441
|
+
vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
442
|
+
|
|
443
|
+
if remain_params:
|
|
444
|
+
logger.warning(
|
|
445
|
+
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
|
|
446
|
+
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
|
|
447
|
+
)
|
|
448
|
+
vision_kwargs["model"] = model.vision_tower
|
|
449
|
+
vision_liger_fn(**vision_kwargs)
|
|
450
|
+
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
451
|
+
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def apply_liger_kernel_to_llama4(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
    layer_norm: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision model's LayerNorm modules
            (`layernorm_pre`, `layernorm_post` and the per-layer norms). Only has an effect when `model` is a
            `Llama4ForConditionalGeneration` instance. Default is True.

    Raises:
        ValueError: If `model` is given but is not a `Llama4ForConditionalGeneration`, `Llama4ForCausalLM`
            or `Llama4TextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.llama4 import modeling_llama4
    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel

    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward

    if rope:
        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full

        apply_liger_llama4_rope_full(modeling_llama4)
    if rms_norm:
        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP

    if cross_entropy:
        # Same strategy as the other patch functions in this module: on supported transformers
        # versions the loss goes through `transformers.loss.loss_utils`, so patch the functional
        # cross entropy there; older versions fall back to the module-level class alias.
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        if isinstance(model, Llama4ForConditionalGeneration):
            language_model = model.language_model
            vision_model: Llama4VisionModel = model.vision_model
            # `language_model` may be either the Llama4ForCausalLM wrapper or the bare text model.
            if isinstance(language_model, Llama4ForCausalLM):
                text_model: Llama4TextModel = language_model.model
            else:
                text_model = language_model
        elif isinstance(model, Llama4ForCausalLM):
            text_model = model.model
            vision_model = None
        elif isinstance(model, Llama4TextModel):
            text_model = model
            vision_model = None
        else:
            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")

        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    # For MoE layers only `feed_forward.shared_expert` is patched with the SwiGLU MLP;
                    # the rest of the MoE block is left as-is.
                    if decoder_layer.is_moe_layer:
                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
                    else:
                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

        # Every vision LayerNorm (including pre/post) is governed by `layer_norm`.
        if vision_model is not None and layer_norm:
            _patch_layer_norm_module(vision_model.layernorm_pre)
            _patch_layer_norm_module(vision_model.layernorm_post)

            for layer in vision_model.model.layers:
                _patch_layer_norm_module(layer.input_layernorm)
                _patch_layer_norm_module(layer.post_attention_layernorm)
|
|
543
|
+
|
|
544
|
+
|
|
145
545
|
def apply_liger_kernel_to_mllama(
|
|
146
546
|
rope: bool = True,
|
|
147
547
|
cross_entropy: bool = False,
|
|
@@ -168,39 +568,47 @@ def apply_liger_kernel_to_mllama(
|
|
|
168
568
|
loaded. Default is None.
|
|
169
569
|
"""
|
|
170
570
|
|
|
171
|
-
assert not (
|
|
172
|
-
cross_entropy and fused_linear_cross_entropy
|
|
173
|
-
)
|
|
571
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
572
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
573
|
+
)
|
|
174
574
|
|
|
175
575
|
from transformers.models.mllama import modeling_mllama
|
|
176
|
-
from transformers.models.mllama.modeling_mllama import
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
MllamaVisionModel,
|
|
181
|
-
)
|
|
576
|
+
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
|
|
577
|
+
from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
|
|
578
|
+
from transformers.models.mllama.modeling_mllama import MllamaTextModel
|
|
579
|
+
from transformers.models.mllama.modeling_mllama import MllamaVisionModel
|
|
182
580
|
|
|
183
581
|
from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
|
|
184
|
-
from liger_kernel.transformers.model.mllama import
|
|
185
|
-
lce_forward_deprecated as mllama_lce_forward_deprecated,
|
|
186
|
-
)
|
|
582
|
+
from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
|
|
187
583
|
|
|
188
584
|
if rope:
|
|
189
585
|
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
190
|
-
if layer_norm:
|
|
586
|
+
if layer_norm and model is None:
|
|
191
587
|
modeling_mllama.nn.LayerNorm = LigerLayerNorm
|
|
192
588
|
if rms_norm:
|
|
193
589
|
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
|
|
194
590
|
if swiglu:
|
|
195
591
|
modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
|
|
196
592
|
if cross_entropy:
|
|
197
|
-
|
|
593
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
594
|
+
from transformers.loss.loss_utils import nn
|
|
595
|
+
|
|
596
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
597
|
+
else:
|
|
598
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
599
|
+
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
198
600
|
if fused_linear_cross_entropy:
|
|
199
601
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
200
|
-
|
|
602
|
+
if model is not None:
|
|
603
|
+
model.forward = MethodType(mllama_lce_forward, model)
|
|
604
|
+
else:
|
|
605
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
|
|
201
606
|
else: # if version < 4.46.1
|
|
202
607
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
203
|
-
|
|
608
|
+
if model is not None:
|
|
609
|
+
model.forward = MethodType(mllama_lce_forward_deprecated, model)
|
|
610
|
+
else:
|
|
611
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
|
|
204
612
|
|
|
205
613
|
if model is not None:
|
|
206
614
|
# The model instance already exists, so we need to additionally patch the
|
|
@@ -209,13 +617,17 @@ def apply_liger_kernel_to_mllama(
|
|
|
209
617
|
if isinstance(model, MllamaForConditionalGeneration):
|
|
210
618
|
language_model: MllamaForCausalLM = model.language_model
|
|
211
619
|
vision_model: MllamaVisionModel = model.vision_model
|
|
212
|
-
|
|
620
|
+
if isinstance(language_model, MllamaForCausalLM):
|
|
621
|
+
text_model: MllamaTextModel = language_model.model
|
|
622
|
+
else:
|
|
623
|
+
text_model = language_model
|
|
213
624
|
elif isinstance(model, MllamaForCausalLM):
|
|
214
625
|
text_model = model.model
|
|
215
626
|
vision_model = None
|
|
216
627
|
elif isinstance(model, MllamaTextModel):
|
|
217
628
|
text_model = model
|
|
218
629
|
vision_model = None
|
|
630
|
+
|
|
219
631
|
else:
|
|
220
632
|
raise ValueError(f"Unsupported Mllama model type: {type(model)}")
|
|
221
633
|
|
|
@@ -224,9 +636,7 @@ def apply_liger_kernel_to_mllama(
|
|
|
224
636
|
_patch_rms_norm_module(text_model.norm)
|
|
225
637
|
for decoder_layer in text_model.layers:
|
|
226
638
|
if swiglu:
|
|
227
|
-
|
|
228
|
-
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
229
|
-
)
|
|
639
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
230
640
|
if rms_norm:
|
|
231
641
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
232
642
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -258,7 +668,7 @@ def apply_liger_kernel_to_mistral(
|
|
|
258
668
|
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
|
|
259
669
|
|
|
260
670
|
Args:
|
|
261
|
-
rope (bool): Whether to apply Liger's rotary position embedding. Default is
|
|
671
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
262
672
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
263
673
|
fused_linear_cross_entropy (bool):
|
|
264
674
|
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
@@ -270,11 +680,12 @@ def apply_liger_kernel_to_mistral(
|
|
|
270
680
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
271
681
|
loaded. Default is None.
|
|
272
682
|
"""
|
|
273
|
-
assert not (
|
|
274
|
-
cross_entropy and fused_linear_cross_entropy
|
|
275
|
-
)
|
|
683
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
684
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
685
|
+
)
|
|
276
686
|
|
|
277
687
|
from transformers.models.mistral import modeling_mistral
|
|
688
|
+
from transformers.models.mistral.modeling_mistral import MistralModel
|
|
278
689
|
|
|
279
690
|
if rope:
|
|
280
691
|
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
@@ -283,7 +694,17 @@ def apply_liger_kernel_to_mistral(
|
|
|
283
694
|
if cross_entropy:
|
|
284
695
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
285
696
|
if fused_linear_cross_entropy:
|
|
286
|
-
|
|
697
|
+
if transformer_version >= version.parse("4.49.0"):
|
|
698
|
+
if model is not None:
|
|
699
|
+
model.forward = MethodType(mistral_lce_forward, model)
|
|
700
|
+
else:
|
|
701
|
+
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
|
702
|
+
else:
|
|
703
|
+
logger.warning(
|
|
704
|
+
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
|
|
705
|
+
)
|
|
706
|
+
logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
|
|
707
|
+
|
|
287
708
|
if swiglu:
|
|
288
709
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
|
289
710
|
|
|
@@ -291,21 +712,15 @@ def apply_liger_kernel_to_mistral(
|
|
|
291
712
|
# The model instance already exists, so we need to additionally patch the
|
|
292
713
|
# instance variables that reference already-instantiated modules
|
|
293
714
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
base_model = model.model
|
|
297
|
-
else:
|
|
298
|
-
# Direct MistralModel
|
|
299
|
-
base_model = model
|
|
715
|
+
# get the base model from the model instance
|
|
716
|
+
base_model: MistralModel = getattr(model, model.base_model_prefix, model)
|
|
300
717
|
|
|
301
718
|
if rms_norm:
|
|
302
719
|
_patch_rms_norm_module(base_model.norm)
|
|
303
720
|
|
|
304
721
|
for decoder_layer in base_model.layers:
|
|
305
722
|
if swiglu:
|
|
306
|
-
|
|
307
|
-
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
308
|
-
)
|
|
723
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
309
724
|
if rms_norm:
|
|
310
725
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
311
726
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -335,24 +750,38 @@ def apply_liger_kernel_to_mixtral(
|
|
|
335
750
|
loaded. Default is None.
|
|
336
751
|
"""
|
|
337
752
|
|
|
338
|
-
assert not (
|
|
339
|
-
cross_entropy and fused_linear_cross_entropy
|
|
340
|
-
)
|
|
753
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
754
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
755
|
+
)
|
|
341
756
|
|
|
342
757
|
from transformers.models.mixtral import modeling_mixtral
|
|
758
|
+
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
|
343
759
|
|
|
344
760
|
if rope:
|
|
345
761
|
modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
346
762
|
if rms_norm:
|
|
347
763
|
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
|
|
348
764
|
if cross_entropy:
|
|
349
|
-
|
|
765
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
766
|
+
from transformers.loss.loss_utils import nn
|
|
767
|
+
|
|
768
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
769
|
+
else:
|
|
770
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
771
|
+
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
772
|
+
|
|
350
773
|
if fused_linear_cross_entropy:
|
|
351
774
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
352
|
-
|
|
775
|
+
if model is not None:
|
|
776
|
+
model.forward = MethodType(mixtral_lce_forward, model)
|
|
777
|
+
else:
|
|
778
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
|
|
353
779
|
else: # if version < 4.46.1
|
|
354
780
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
355
|
-
|
|
781
|
+
if model is not None:
|
|
782
|
+
model.forward = MethodType(mixtral_lce_forward_deprecated, model)
|
|
783
|
+
else:
|
|
784
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
|
|
356
785
|
if swiglu:
|
|
357
786
|
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
|
|
358
787
|
|
|
@@ -360,12 +789,8 @@ def apply_liger_kernel_to_mixtral(
|
|
|
360
789
|
# The model instance already exists, so we need to additionally patch the
|
|
361
790
|
# instance variables that reference already-instantiated modules
|
|
362
791
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
base_model = model.model
|
|
366
|
-
else:
|
|
367
|
-
# Direct MixtralModel
|
|
368
|
-
base_model = model
|
|
792
|
+
# get the base model from the model instance
|
|
793
|
+
base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
|
|
369
794
|
|
|
370
795
|
if rms_norm:
|
|
371
796
|
_patch_rms_norm_module(base_model.norm)
|
|
@@ -373,9 +798,7 @@ def apply_liger_kernel_to_mixtral(
|
|
|
373
798
|
for decoder_layer in base_model.layers:
|
|
374
799
|
if swiglu:
|
|
375
800
|
for expert in decoder_layer.block_sparse_moe.experts:
|
|
376
|
-
|
|
377
|
-
expert, "forward", LigerBlockSparseTop2MLP.forward
|
|
378
|
-
)
|
|
801
|
+
_patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
|
|
379
802
|
if rms_norm:
|
|
380
803
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
381
804
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -405,54 +828,57 @@ def apply_liger_kernel_to_gemma(
|
|
|
405
828
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
406
829
|
loaded. Default is None.
|
|
407
830
|
"""
|
|
408
|
-
assert not (
|
|
409
|
-
cross_entropy and fused_linear_cross_entropy
|
|
410
|
-
)
|
|
831
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
832
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
833
|
+
)
|
|
411
834
|
|
|
412
835
|
from transformers.models.gemma import modeling_gemma
|
|
836
|
+
from transformers.models.gemma.modeling_gemma import GemmaModel
|
|
413
837
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
)
|
|
418
|
-
_patch_rms_norm_module_for_gemma = partial(
|
|
419
|
-
_patch_rms_norm_module, casting_mode="gemma", offset=1.0
|
|
420
|
-
)
|
|
838
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
|
|
839
|
+
|
|
840
|
+
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
|
|
421
841
|
|
|
422
842
|
if rope:
|
|
423
843
|
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
424
844
|
if rms_norm:
|
|
425
845
|
modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
|
|
426
846
|
if cross_entropy:
|
|
427
|
-
|
|
847
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
848
|
+
from transformers.loss.loss_utils import nn
|
|
849
|
+
|
|
850
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
851
|
+
else:
|
|
852
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
853
|
+
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
428
854
|
if geglu:
|
|
429
855
|
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
|
430
856
|
if fused_linear_cross_entropy:
|
|
431
857
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
432
|
-
|
|
858
|
+
if model is not None:
|
|
859
|
+
model.forward = MethodType(gemma_lce_forward, model)
|
|
860
|
+
else:
|
|
861
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
|
433
862
|
else: # if version < 4.46.1
|
|
434
863
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
435
|
-
|
|
864
|
+
if model is not None:
|
|
865
|
+
model.forward = MethodType(gemma_lce_forward_deprecated, model)
|
|
866
|
+
else:
|
|
867
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
|
|
436
868
|
|
|
437
869
|
if model is not None:
|
|
438
870
|
# The model instance already exists, so we need to additionally patch the
|
|
439
871
|
# instance variables that reference already-instantiated modules
|
|
440
872
|
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
base_model = model.model
|
|
444
|
-
else:
|
|
445
|
-
# Direct GemmaModel
|
|
446
|
-
base_model = model
|
|
873
|
+
# get the base model from the model instance
|
|
874
|
+
base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
|
|
447
875
|
|
|
448
876
|
if rms_norm:
|
|
449
877
|
_patch_rms_norm_module_for_gemma(base_model.norm)
|
|
450
878
|
|
|
451
879
|
for decoder_layer in base_model.layers:
|
|
452
880
|
if geglu:
|
|
453
|
-
|
|
454
|
-
decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
|
|
455
|
-
)
|
|
881
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
456
882
|
if rms_norm:
|
|
457
883
|
_patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
|
|
458
884
|
_patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
|
|
@@ -460,7 +886,8 @@ def apply_liger_kernel_to_gemma(
|
|
|
460
886
|
|
|
461
887
|
def apply_liger_kernel_to_gemma2(
|
|
462
888
|
rope: bool = True,
|
|
463
|
-
cross_entropy: bool =
|
|
889
|
+
cross_entropy: bool = False,
|
|
890
|
+
fused_linear_cross_entropy: bool = True,
|
|
464
891
|
rms_norm: bool = True,
|
|
465
892
|
geglu: bool = True,
|
|
466
893
|
model: PreTrainedModel = None,
|
|
@@ -471,65 +898,1107 @@ def apply_liger_kernel_to_gemma2(
|
|
|
471
898
|
|
|
472
899
|
Args:
|
|
473
900
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
474
|
-
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
|
|
475
|
-
|
|
476
|
-
|
|
901
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
902
|
+
fused_linear_cross_entropy (bool):
|
|
903
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
904
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
905
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
906
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
907
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
908
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
909
|
+
loaded. Default is None.
|
|
910
|
+
"""
|
|
911
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
912
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
913
|
+
)
|
|
914
|
+
|
|
915
|
+
from transformers.models.gemma2 import modeling_gemma2
|
|
916
|
+
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
|
|
917
|
+
|
|
918
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
|
|
919
|
+
|
|
920
|
+
_patch_rms_norm_module_for_gemma2 = partial(
|
|
921
|
+
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
|
922
|
+
)
|
|
923
|
+
|
|
924
|
+
if rope:
|
|
925
|
+
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
926
|
+
if rms_norm:
|
|
927
|
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
|
|
928
|
+
modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
|
|
929
|
+
if cross_entropy:
|
|
930
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
931
|
+
from transformers.loss.loss_utils import nn
|
|
932
|
+
|
|
933
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
934
|
+
else:
|
|
935
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
936
|
+
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
937
|
+
if fused_linear_cross_entropy:
|
|
938
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
939
|
+
if model is not None:
|
|
940
|
+
model.forward = MethodType(gemma2_lce_forward, model)
|
|
941
|
+
else:
|
|
942
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
|
|
943
|
+
else:
|
|
944
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
945
|
+
if model is not None:
|
|
946
|
+
model.forward = MethodType(gemma2_lce_forward_deprected, model)
|
|
947
|
+
else:
|
|
948
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
|
|
949
|
+
if geglu:
|
|
950
|
+
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
|
951
|
+
|
|
952
|
+
if model is not None:
|
|
953
|
+
# The model instance already exists, so we need to additionally patch the
|
|
954
|
+
# instance variables that reference already-instantiated modules
|
|
955
|
+
|
|
956
|
+
# get the base model from the model instance
|
|
957
|
+
base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
|
|
958
|
+
|
|
959
|
+
if rms_norm:
|
|
960
|
+
_patch_rms_norm_module_for_gemma2(base_model.norm)
|
|
961
|
+
|
|
962
|
+
for decoder_layer in base_model.layers:
|
|
963
|
+
if geglu:
|
|
964
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
965
|
+
if rms_norm:
|
|
966
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
|
|
967
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
|
|
968
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
|
|
969
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def apply_liger_kernel_to_gemma3_text(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma3 text models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
            When `model` is a bare `Gemma3TextModel`, no forward is replaced on the instance.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Must be a `Gemma3ForCausalLM` or `Gemma3TextModel`. Default is None.

    Raises:
        TypeError: If `model` is given but is neither a `Gemma3ForCausalLM` nor a `Gemma3TextModel`.
            The check runs before anything is patched.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.gemma3 import modeling_gemma3
    from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

    from liger_kernel.transformers.model.gemma3 import causal_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

    # Validate up front so an unsupported model does not leave the module half-patched.
    if model is not None and not isinstance(model, (Gemma3ForCausalLM, Gemma3TextModel)):
        raise TypeError("The model must be Gemma3ForCausalLM or Gemma3TextModel.")

    _patch_rms_norm_module_for_gemma3 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    if rope:
        modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3

    if geglu:
        modeling_gemma3.Gemma3MLP = LigerGEGLUMLP

    # Handle loss function
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
        elif isinstance(model, Gemma3ForCausalLM):
            model.forward = MethodType(causal_forward, model)
        # `causal_forward` replaces Gemma3ForCausalLM.forward; binding it onto a bare
        # Gemma3TextModel would swap the backbone forward for a causal-LM one, so skip it.

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model

        if rms_norm:
            _patch_rms_norm_module_for_gemma3(base_model.norm)

        for decoder_layer in base_model.layers:
            decoder_layer: Gemma3DecoderLayer
            if geglu:
                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
                _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
                _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
                _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
                _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
|
|
1057
|
+
|
|
1058
|
+
|
|
1059
|
+
def apply_liger_kernel_to_gemma3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma3 multimodal models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the SigLIP vision tower (including its
            `post_layernorm`). Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Must be a `Gemma3ForConditionalGeneration`. Default is None.

    Raises:
        TypeError: If `model` is given but is not a `Gemma3ForConditionalGeneration`, or its vision tower is
            not a `SiglipVisionModel`. Both checks run before anything is patched.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.gemma3 import modeling_gemma3
    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
    from transformers.models.siglip import modeling_siglip
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
    from transformers.models.siglip.modeling_siglip import SiglipVisionModel

    from liger_kernel.transformers.model.gemma3 import multimodal_forward

    # Validate up front so an unsupported model does not leave the module half-patched.
    if model is not None:
        if not isinstance(model, Gemma3ForConditionalGeneration):
            raise TypeError("The model must be Gemma3ForConditionalGeneration.")
        if not isinstance(model.vision_tower, SiglipVisionModel):
            raise TypeError("The vision tower must be SiglipVisionModel")

    _patch_rms_norm_module_for_gemma3 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    # Class-level swap only when no instance is given; with an instance, its existing
    # LayerNorm modules are patched individually below.
    # NOTE(review): `modeling_siglip.nn` is presumably `torch.nn`, so this rebinds
    # `nn.LayerNorm` process-wide — confirm this is intended.
    if layer_norm and model is None:
        modeling_siglip.nn.LayerNorm = LigerLayerNorm

    # Class-level patches for the Gemma3 text stack; the loss is handled here at the multimodal level.
    apply_liger_kernel_to_gemma3_text(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )

    if cross_entropy:
        # NOTE(review): `modeling_gemma3.nn` is presumably `torch.nn`, so this swaps
        # `CrossEntropyLoss` process-wide — confirm this is intended.
        modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(multimodal_forward, model)
        else:
            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        vision_tower = model.vision_tower

        # Every vision LayerNorm (including post_layernorm) is governed by `layer_norm`.
        if layer_norm:
            _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

            for layer in vision_tower.vision_model.encoder.layers:
                layer: SiglipEncoderLayer
                _patch_layer_norm_module(layer.layer_norm1)
                _patch_layer_norm_module(layer.layer_norm2)

        if rms_norm:
            _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)

        apply_liger_kernel_to_gemma3_text(
            rope=rope,
            cross_entropy=False,
            fused_linear_cross_entropy=False,
            rms_norm=rms_norm,
            geglu=geglu,
            model=model.language_model,
        )
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
def apply_liger_kernel_to_paligemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma

    The SigLIP vision tower receives LayerNorm patches, the multi-modal projector is left untouched,
    and the Gemma / Gemma2 language model is patched by delegating to `apply_liger_kernel_to_gemma`
    and `apply_liger_kernel_to_gemma2`. Loss patches are applied here at the PaliGemma level, so the
    delegated calls always run with both loss options disabled.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision tower. Default is True.
            When False, no vision-tower LayerNorm is patched, including the final `post_layernorm`.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        TypeError: If `model` is given and is not a `PaliGemmaForConditionalGeneration`, or if its
            language model is neither a Gemma nor a Gemma2 model.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.gemma.modeling_gemma import GemmaModel
    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
    from transformers.models.paligemma import modeling_paligemma
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
    from transformers.models.siglip import modeling_siglip
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
    from transformers.models.siglip.modeling_siglip import SiglipVisionModel

    from liger_kernel.transformers.model.paligemma import lce_forward
    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

    # The vision_tower is a SiglipVisionModel
    if layer_norm and model is None:
        # NOTE(review): `modeling_siglip.nn` is presumably the `torch.nn` module object itself, in which
        # case this rebinds `nn.LayerNorm` for every module that looks it up through `torch.nn`,
        # not only SigLIP — confirm this global effect is intended.
        modeling_siglip.nn.LayerNorm = LigerLayerNorm

    # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
    # The multi_modal_projector is Linear, nothing to do

    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM; either family may back it, so both
    # are patched at module level. Loss patching is handled below for PaliGemma itself.
    apply_liger_kernel_to_gemma(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )
    apply_liger_kernel_to_gemma2(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )

    # Handle loss function
    if cross_entropy:
        modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            if model is not None:
                model.forward = MethodType(lce_forward, model)
            else:
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            if model is not None:
                model.forward = MethodType(lce_forward_deprecated, model)
            else:
                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if not isinstance(model, PaliGemmaForConditionalGeneration):
            raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

        vision_tower: SiglipVisionModel = model.vision_tower

        # Every vision-tower LayerNorm, including the final `post_layernorm`, follows the
        # `layer_norm` flag so that `layer_norm=False` leaves the vision tower untouched.
        if layer_norm:
            _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

            for layer in vision_tower.vision_model.encoder.layers:
                layer: SiglipEncoderLayer
                _patch_layer_norm_module(layer.layer_norm1)
                _patch_layer_norm_module(layer.layer_norm2)

        language_model = model.language_model

        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
            apply_liger_kernel_to_gemma(
                rope=rope,
                cross_entropy=False,
                fused_linear_cross_entropy=False,
                rms_norm=rms_norm,
                geglu=geglu,
                model=language_model,
            )

        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
            apply_liger_kernel_to_gemma2(
                rope=rope,
                cross_entropy=False,
                fused_linear_cross_entropy=False,
                rms_norm=rms_norm,
                geglu=geglu,
                model=language_model,
            )
        else:
            raise TypeError(
                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
            )
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Qwen2 modeling code, and optionally an already-loaded model, with Liger kernels.

    Args:
        rope (bool): Replace `apply_rotary_pos_emb` in the Qwen2 modeling module. Default is True.
        cross_entropy (bool): Use Liger's cross entropy. On supported transformers versions this
            replaces the functional cross entropy seen by `transformers.loss.loss_utils`; on older
            versions it replaces `CrossEntropyLoss` in the Qwen2 modeling module. Default is False.
        fused_linear_cross_entropy (bool): Swap the causal-LM forward for Liger's fused linear cross
            entropy forward, so logits are not materialized. Default is True.
            Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Use Liger's RMSNorm. Default is True.
        swiglu (bool): Use Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model whose existing submodules should also be
            patched. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

    # At most one of the two loss options can be active (asserted above), so the version
    # check and the deprecation warning happen at most once.
    if cross_entropy or fused_linear_cross_entropy:
        is_supported = transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION)
        if not is_supported:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)

        if cross_entropy:
            if is_supported:
                from transformers.loss.loss_utils import nn

                nn.functional.cross_entropy = liger_cross_entropy
            else:
                modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
        else:
            forward_fn = qwen2_lce_forward if is_supported else qwen2_lce_forward_deprecated
            if model is None:
                modeling_qwen2.Qwen2ForCausalLM.forward = forward_fn
            else:
                model.forward = MethodType(forward_fn, model)

    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # The instance already exists: its submodules were built before the class-level
    # patches above, so patch them in place.
    base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            for norm in (layer.input_layernorm, layer.post_attention_layernorm):
                _patch_rms_norm_module(norm)
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
def apply_liger_kernel_to_qwen3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Qwen3 modeling code, and optionally an already-loaded model, with Liger kernels.

    Args:
        rope (bool): Replace `apply_rotary_pos_emb` in the Qwen3 modeling module. Default is True.
        cross_entropy (bool): Replace the functional cross entropy seen by
            `transformers.loss.loss_utils` with Liger's. Default is False.
        fused_linear_cross_entropy (bool): Replace `Qwen3ForCausalLM.forward` (or `model.forward` when
            `model` is given) with Liger's fused linear cross entropy forward. Default is True.
            Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Use Liger's RMSNorm. Default is True.
        swiglu (bool): Use Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model whose existing submodules should also be
            patched. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model

    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward

    if rope:
        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
        else:
            model.forward = MethodType(qwen3_lce_forward, model)

    if swiglu:
        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP

    if model is None:
        return

    # The instance already exists: patch the submodules it has already built.
    base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            for norm in (layer.input_layernorm, layer.post_attention_layernorm):
                _patch_rms_norm_module(norm)
|
|
1401
|
+
|
|
1402
|
+
|
|
1403
|
+
def apply_liger_kernel_to_qwen3_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to replace the functional cross entropy seen by
            `transformers.loss.loss_utils` with Liger's. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy forward to `Qwen3MoeForCausalLM`
            (or to `model` when given). Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP (`LigerQwen3MoeSwiGLUMLP`) to every
            `Qwen3MoeMLP`, i.e. both MoE experts and any dense MLP layers. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_moe import modeling_qwen3_moe
    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel

    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_moe_lce_forward
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

    if rope:
        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen3_moe_lce_forward, model)
        else:
            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_moe_lce_forward

    if swiglu:
        # Replaces the MLP class used for experts and for dense layers alike.
        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                mlp = decoder_layer.mlp
                experts = getattr(mlp, "experts", None)
                if experts is not None:
                    # Sparse MoE block: patch each expert MLP.
                    for mlp_expert in experts:
                        _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
                else:
                    # A layer whose `mlp` has no `experts` is a dense MLP (presumably configured via
                    # `mlp_only_layers` / `decoder_sparse_step`); patch it directly instead of failing
                    # with AttributeError, mirroring the class-level `Qwen3MoeMLP` replacement above.
                    _patch_swiglu_module(mlp, LigerQwen3MoeSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1460
|
+
|
|
1461
|
+
|
|
1462
|
+
def apply_liger_kernel_to_qwen2_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not supported in transformers<4.52.4; with older versions this function logs a
    warning and returns without patching anything.

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm (vision blocks). Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        TypeError: If `model` is given and is not a `Qwen2VLForConditionalGeneration`, `Qwen2VLModel`
            or `Qwen2VLTextModel`.
    """
    if transformer_version < version.parse("4.52.4"):
        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2_vl import modeling_qwen2_vl
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

    if rope:
        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm and model is None:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        # Kept for code paths that still look up `CrossEntropyLoss` on this module.
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
        # The stock forward on transformers >= 4.52.4 presumably computes its loss through
        # `transformers.loss.loss_utils` rather than the module-level `CrossEntropyLoss`, so patch the
        # functional cross entropy used there as well (same approach as the Qwen2/Qwen3 patches);
        # otherwise `cross_entropy=True` would have no effect.
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen2_vl_lce_forward, model)
        else:
            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
            text_model: Qwen2VLTextModel = model.language_model
            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
        elif isinstance(model, Qwen2VLTextModel):
            text_model: Qwen2VLTextModel = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
            )

        # Patch Qwen2VisionTransformerPretrainedModel
        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if layer_norm:
                    _patch_layer_norm_module(vision_block.norm1)
                    _patch_layer_norm_module(vision_block.norm2)

        # Patch Qwen2VLTextModel
        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1556
|
+
|
|
1557
|
+
|
|
1558
|
+
def apply_liger_kernel_to_qwen2_5_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
    NOTE: Liger patching of Qwen2.5-VL requires transformers>=4.52.4; with older versions this function
    logs a warning and returns without patching anything.

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm (text model and vision blocks). Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        TypeError: If `model` is given and is not a `Qwen2_5_VLForConditionalGeneration`,
            `Qwen2_5_VLModel` or `Qwen2_5_VLTextModel`.
    """
    if transformer_version < version.parse("4.52.4"):
        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward

    if rope:
        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
    if rms_norm:
        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
    if cross_entropy:
        # Kept for code paths that still look up `CrossEntropyLoss` on this module.
        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
        # The stock forward on transformers >= 4.52.4 presumably computes its loss through
        # `transformers.loss.loss_utils` rather than the module-level `CrossEntropyLoss`, so patch the
        # functional cross entropy used there as well (same approach as the Qwen2/Qwen3 patches);
        # otherwise `cross_entropy=True` would have no effect.
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
        else:
            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
    if swiglu:
        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
            text_model: Qwen2_5_VLTextModel = model.language_model
            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
        elif isinstance(model, Qwen2_5_VLTextModel):
            text_model: Qwen2_5_VLTextModel = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported Qwen2_5_VL model type. `model` must be `Qwen2_5_VLForConditionalGeneration`, `Qwen2_5_VLModel` or `Qwen2_5_VLTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            # Patch Qwen2_5_VisionTransformerPretrainedModel
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)

        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1646
|
+
|
|
1647
|
+
|
|
1648
|
+
def apply_liger_kernel_to_qwen3_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Qwen3-VL modeling code, and optionally an already-loaded model, with Liger kernels.

    Args:
        rope (bool): Replace `apply_rotary_pos_emb` and `apply_rotary_pos_emb_vision` in the Qwen3-VL
            modeling module with Liger's casting variants. Default is True.
        cross_entropy (bool): Replace the functional cross entropy seen by
            `transformers.loss.loss_utils` with Liger's. Default is False.
        fused_linear_cross_entropy (bool): Replace `Qwen3VLForConditionalGeneration.forward` (or
            `model.forward` when `model` is given) with Liger's fused linear cross entropy forward.
            Default is True. Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Replace `Qwen3VLTextRMSNorm` with Liger's RMSNorm. For a loaded model, also
            patch the text model's final norm, each layer's input/post-attention norms, and the
            attention `q_norm`/`k_norm` when present (offset=0.0, casting_mode="llama"). Default is True.
        swiglu (bool): Accepted for interface compatibility; this function currently applies no
            SwiGLU patch. Default is False.
        model (PreTrainedModel): An already-loaded model whose existing submodules should also be
            patched. Default is None.

    Raises:
        TypeError: If `model` is given with `rms_norm=True` and is not a
            `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_vl import modeling_qwen3_vl
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel

    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

    if rope:
        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch

    if rms_norm:
        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
        else:
            model.forward = MethodType(qwen3_vl_lce_forward, model)

    # Instance-level patching only concerns RMSNorm modules.
    if model is None or not rms_norm:
        return

    if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
        text_model: Qwen3VLTextModel = model.language_model
    elif isinstance(model, Qwen3VLTextModel):
        text_model = model
    else:
        raise TypeError(
            f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
        )

    patch_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

    if text_model is None:
        return

    patch_norm(text_model.norm)
    for layer in text_model.layers:
        patch_norm(layer.input_layernorm)
        patch_norm(layer.post_attention_layernorm)

        attention = getattr(layer, "self_attn", None)
        if attention is None:
            continue
        # q_norm / k_norm are optional on the attention module; patch whichever exist.
        for norm_name in ("q_norm", "k_norm"):
            qk_norm = getattr(attention, norm_name, None)
            if qk_norm is not None:
                patch_norm(qk_norm)
|
|
1723
|
+
|
|
1724
|
+
|
|
1725
|
+
def apply_liger_kernel_to_qwen3_vl_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.

    Args:
        rope (bool): Whether to replace `apply_rotary_pos_emb` and `apply_rotary_pos_emb_vision` in the
            Qwen3-VL MoE modeling module with Liger's casting variants. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss (replaces the functional cross
            entropy seen by `transformers.loss.loss_utils`). Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. For a loaded model this also patches the text
            model's final norm, per-layer input/post-attention norms, and attention `q_norm`/`k_norm`
            when present. Default is True.
        swiglu (bool): Accepted for interface compatibility; this function currently applies no SwiGLU
            patch. Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        TypeError: If `model` is given with `rms_norm=True` and is not a
            `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel

    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

    if rope:
        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch

    if rms_norm:
        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
        else:
            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

    # Instance-level patching only concerns RMSNorm modules of the text model.
    if model is not None and rms_norm:
        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
            text_model: Qwen3VLMoeTextModel = model.language_model
        elif isinstance(model, Qwen3VLMoeTextModel):
            text_model = model
        else:
            raise TypeError(
                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
            )

        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

        if text_model is not None:
            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
            for decoder_layer in text_model.layers:
                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
                self_attn = getattr(decoder_layer, "self_attn", None)
                if self_attn is not None:
                    # q_norm / k_norm are optional; patch only those that exist.
                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
|
|
1798
|
+
|
|
1799
|
+
|
|
1800
|
+
def apply_liger_kernel_to_phi3(
|
|
1801
|
+
rope: bool = True,
|
|
1802
|
+
cross_entropy: bool = False,
|
|
1803
|
+
fused_linear_cross_entropy: bool = True,
|
|
1804
|
+
rms_norm: bool = True,
|
|
1805
|
+
swiglu: bool = True,
|
|
1806
|
+
model: PreTrainedModel = None,
|
|
1807
|
+
) -> None:
|
|
1808
|
+
"""
|
|
1809
|
+
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
|
|
1810
|
+
|
|
1811
|
+
Args:
|
|
1812
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
1813
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1814
|
+
fused_linear_cross_entropy (bool):
|
|
1815
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1816
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1817
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1818
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1819
|
+
swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
|
|
1820
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1821
|
+
loaded. Default is None.
|
|
1822
|
+
"""
|
|
1823
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1824
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1825
|
+
)
|
|
1826
|
+
|
|
1827
|
+
from transformers.models.phi3 import modeling_phi3
|
|
1828
|
+
from transformers.models.phi3.modeling_phi3 import Phi3Model
|
|
1829
|
+
|
|
1830
|
+
if rope:
|
|
1831
|
+
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
|
|
1832
|
+
if rms_norm:
|
|
1833
|
+
modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
|
|
1834
|
+
if swiglu:
|
|
1835
|
+
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
1836
|
+
if cross_entropy:
|
|
1837
|
+
from transformers.loss.loss_utils import nn
|
|
1838
|
+
|
|
1839
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1840
|
+
if fused_linear_cross_entropy:
|
|
1841
|
+
if model is not None:
|
|
1842
|
+
model.forward = MethodType(phi3_lce_forward, model)
|
|
1843
|
+
else:
|
|
1844
|
+
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
|
1845
|
+
|
|
1846
|
+
if model is not None:
|
|
1847
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1848
|
+
# instance variables that reference already-instantiated modules
|
|
1849
|
+
|
|
1850
|
+
# get the base model from the model instance
|
|
1851
|
+
base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
|
|
1852
|
+
|
|
1853
|
+
if rms_norm:
|
|
1854
|
+
_patch_rms_norm_module(base_model.norm)
|
|
1855
|
+
|
|
1856
|
+
for decoder_layer in base_model.layers:
|
|
1857
|
+
if swiglu:
|
|
1858
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
|
|
1859
|
+
if rms_norm:
|
|
1860
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
1861
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1862
|
+
|
|
1863
|
+
|
|
1864
|
+
def apply_liger_kernel_to_olmo2(
|
|
1865
|
+
rope: bool = True,
|
|
1866
|
+
cross_entropy: bool = False,
|
|
1867
|
+
fused_linear_cross_entropy: bool = True,
|
|
1868
|
+
rms_norm: bool = True,
|
|
1869
|
+
swiglu: bool = True,
|
|
1870
|
+
model: PreTrainedModel = None,
|
|
1871
|
+
) -> None:
|
|
1872
|
+
"""
|
|
1873
|
+
Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
|
|
1874
|
+
|
|
1875
|
+
Args:
|
|
1876
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
1877
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1878
|
+
fused_linear_cross_entropy (bool):
|
|
1879
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1880
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1881
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1882
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1883
|
+
swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
|
|
1884
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
1885
|
+
loaded. Default is None.
|
|
1886
|
+
"""
|
|
1887
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1888
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
1889
|
+
)
|
|
1890
|
+
|
|
1891
|
+
from transformers.models.olmo2 import modeling_olmo2
|
|
1892
|
+
from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
|
|
1893
|
+
|
|
1894
|
+
from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
|
|
1895
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
|
|
1896
|
+
|
|
1897
|
+
if rope:
|
|
1898
|
+
modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
1899
|
+
if rms_norm:
|
|
1900
|
+
modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
|
|
1901
|
+
if swiglu:
|
|
1902
|
+
modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
|
|
1903
|
+
if cross_entropy:
|
|
1904
|
+
from transformers.loss.loss_utils import nn
|
|
1905
|
+
|
|
1906
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1907
|
+
if fused_linear_cross_entropy:
|
|
1908
|
+
if model is not None:
|
|
1909
|
+
model.forward = MethodType(olmo2_lce_forward, model)
|
|
1910
|
+
else:
|
|
1911
|
+
modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
|
|
1912
|
+
|
|
1913
|
+
if model is not None:
|
|
1914
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1915
|
+
# instance variables that reference already-instantiated modules
|
|
1916
|
+
|
|
1917
|
+
# get the base model from the model instance
|
|
1918
|
+
base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
|
|
1919
|
+
|
|
1920
|
+
if rms_norm:
|
|
1921
|
+
_patch_rms_norm_module(base_model.norm)
|
|
1922
|
+
|
|
1923
|
+
for decoder_layer in base_model.layers:
|
|
1924
|
+
if swiglu:
|
|
1925
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
1926
|
+
if rms_norm:
|
|
1927
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
|
|
1928
|
+
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
|
|
1929
|
+
|
|
1930
|
+
|
|
1931
|
+
def apply_liger_kernel_to_glm4(
|
|
1932
|
+
rope: bool = False,
|
|
1933
|
+
cross_entropy: bool = False,
|
|
1934
|
+
fused_linear_cross_entropy: bool = True,
|
|
1935
|
+
rms_norm: bool = True,
|
|
1936
|
+
swiglu: bool = True,
|
|
1937
|
+
model: PreTrainedModel = None,
|
|
1938
|
+
) -> None:
|
|
1939
|
+
"""
|
|
1940
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
|
|
1941
|
+
|
|
1942
|
+
Args:
|
|
1943
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
1944
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
1945
|
+
fused_linear_cross_entropy (bool):
|
|
1946
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
1947
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
1948
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
1949
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
1950
|
+
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
|
|
477
1951
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
478
1952
|
loaded. Default is None.
|
|
479
1953
|
"""
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
LigerRMSNormForGemma2 = partial(
|
|
483
|
-
LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
|
|
484
|
-
)
|
|
485
|
-
_patch_rms_norm_module_for_gemma2 = partial(
|
|
486
|
-
_patch_rms_norm_module, offset=1.0, casting_mode="gemma"
|
|
1954
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
1955
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
487
1956
|
)
|
|
488
1957
|
|
|
1958
|
+
from transformers.models.glm4 import modeling_glm4
|
|
1959
|
+
from transformers.models.glm4.modeling_glm4 import Glm4Model
|
|
1960
|
+
|
|
1961
|
+
from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
|
|
1962
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
|
1963
|
+
|
|
489
1964
|
if rope:
|
|
490
|
-
|
|
1965
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
|
491
1966
|
if rms_norm:
|
|
492
|
-
|
|
493
|
-
|
|
1967
|
+
modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
|
|
1968
|
+
if swiglu:
|
|
1969
|
+
modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
|
|
494
1970
|
if cross_entropy:
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
1971
|
+
from transformers.loss.loss_utils import nn
|
|
1972
|
+
|
|
1973
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1974
|
+
if fused_linear_cross_entropy:
|
|
1975
|
+
if model is not None:
|
|
1976
|
+
model.forward = MethodType(glm4_lce_forward, model)
|
|
1977
|
+
else:
|
|
1978
|
+
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
|
|
498
1979
|
|
|
499
1980
|
if model is not None:
|
|
500
1981
|
# The model instance already exists, so we need to additionally patch the
|
|
501
1982
|
# instance variables that reference already-instantiated modules
|
|
502
1983
|
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
base_model = model.model
|
|
506
|
-
else:
|
|
507
|
-
# Direct Gemma2Model
|
|
508
|
-
base_model = model
|
|
1984
|
+
# get the base model from the model instance
|
|
1985
|
+
base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
|
|
509
1986
|
|
|
510
1987
|
if rms_norm:
|
|
511
|
-
|
|
1988
|
+
_patch_rms_norm_module(base_model.norm, in_place=False)
|
|
512
1989
|
|
|
513
1990
|
for decoder_layer in base_model.layers:
|
|
514
|
-
if
|
|
515
|
-
|
|
516
|
-
decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
|
|
517
|
-
)
|
|
1991
|
+
if swiglu:
|
|
1992
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
|
|
518
1993
|
if rms_norm:
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
)
|
|
523
|
-
_patch_rms_norm_module_for_gemma2(
|
|
524
|
-
decoder_layer.pre_feedforward_layernorm
|
|
525
|
-
)
|
|
526
|
-
_patch_rms_norm_module_for_gemma2(
|
|
527
|
-
decoder_layer.post_feedforward_layernorm
|
|
528
|
-
)
|
|
1994
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
|
|
1995
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
|
|
1996
|
+
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
|
|
1997
|
+
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
|
|
529
1998
|
|
|
530
1999
|
|
|
531
|
-
def
|
|
532
|
-
rope: bool =
|
|
2000
|
+
def apply_liger_kernel_to_glm4v(
|
|
2001
|
+
rope: bool = False,
|
|
533
2002
|
cross_entropy: bool = False,
|
|
534
2003
|
fused_linear_cross_entropy: bool = True,
|
|
535
2004
|
rms_norm: bool = True,
|
|
@@ -537,150 +2006,469 @@ def apply_liger_kernel_to_qwen2(
|
|
|
537
2006
|
model: PreTrainedModel = None,
|
|
538
2007
|
) -> None:
|
|
539
2008
|
"""
|
|
540
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
2009
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
|
|
541
2010
|
|
|
542
2011
|
Args:
|
|
543
|
-
rope (bool): Whether to apply Liger's rotary position embedding. Default is
|
|
2012
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
544
2013
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
545
2014
|
fused_linear_cross_entropy (bool):
|
|
546
2015
|
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
547
2016
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
548
2017
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
549
2018
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
550
|
-
swiglu (bool): Whether to apply Liger's SwiGLU
|
|
2019
|
+
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
|
|
551
2020
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
552
2021
|
loaded. Default is None.
|
|
553
2022
|
"""
|
|
554
|
-
assert not (
|
|
555
|
-
cross_entropy and fused_linear_cross_entropy
|
|
556
|
-
)
|
|
2023
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2024
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2025
|
+
)
|
|
557
2026
|
|
|
558
|
-
from transformers.models.
|
|
2027
|
+
from transformers.models.glm4v import modeling_glm4v
|
|
2028
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
|
|
2029
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vModel
|
|
2030
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
|
|
2031
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
|
|
2032
|
+
|
|
2033
|
+
from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
|
|
2034
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
|
559
2035
|
|
|
560
2036
|
if rope:
|
|
561
|
-
|
|
2037
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
|
562
2038
|
if rms_norm:
|
|
563
|
-
|
|
2039
|
+
modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
|
|
564
2040
|
if cross_entropy:
|
|
565
|
-
|
|
2041
|
+
from transformers.loss.loss_utils import nn
|
|
566
2042
|
|
|
567
|
-
|
|
2043
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
568
2044
|
if fused_linear_cross_entropy:
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
574
|
-
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
|
|
575
|
-
|
|
576
|
-
if swiglu:
|
|
577
|
-
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
|
2045
|
+
if model is not None:
|
|
2046
|
+
model.forward = MethodType(glm4v_lce_forward, model)
|
|
2047
|
+
else:
|
|
2048
|
+
modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
|
|
578
2049
|
|
|
579
2050
|
if model is not None:
|
|
580
2051
|
# The model instance already exists, so we need to additionally patch the
|
|
581
2052
|
# instance variables that reference already-instantiated modules
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
#
|
|
585
|
-
|
|
2053
|
+
if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
|
|
2054
|
+
# Note: language_model and visual properties can be accessed throught conditional class for BC.
|
|
2055
|
+
# Not sure if it is subject to changes in the future.
|
|
2056
|
+
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
|
|
2057
|
+
text_model: Glm4vTextModel = model.language_model
|
|
2058
|
+
vision_model: Glm4vVisionModel = model.visual
|
|
2059
|
+
elif isinstance(model, Glm4vTextModel):
|
|
2060
|
+
text_model: Glm4vTextModel = model
|
|
2061
|
+
vision_model = None
|
|
586
2062
|
else:
|
|
587
|
-
#
|
|
588
|
-
|
|
2063
|
+
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
|
|
2064
|
+
raise TypeError(
|
|
2065
|
+
f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
|
|
2066
|
+
)
|
|
589
2067
|
|
|
590
|
-
if
|
|
591
|
-
|
|
2068
|
+
if vision_model is not None:
|
|
2069
|
+
for vision_block in vision_model.blocks:
|
|
2070
|
+
if rms_norm:
|
|
2071
|
+
_patch_rms_norm_module(vision_block.norm1)
|
|
2072
|
+
_patch_rms_norm_module(vision_block.norm2)
|
|
2073
|
+
if swiglu:
|
|
2074
|
+
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
|
|
592
2075
|
|
|
593
|
-
|
|
594
|
-
if swiglu:
|
|
595
|
-
_bind_method_to_module(
|
|
596
|
-
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
597
|
-
)
|
|
2076
|
+
if text_model is not None:
|
|
598
2077
|
if rms_norm:
|
|
599
|
-
_patch_rms_norm_module(
|
|
600
|
-
|
|
601
|
-
|
|
2078
|
+
_patch_rms_norm_module(text_model.norm)
|
|
2079
|
+
for decoder_layer in text_model.layers:
|
|
2080
|
+
if swiglu:
|
|
2081
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
|
|
2082
|
+
if rms_norm:
|
|
2083
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2084
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2085
|
+
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
|
|
2086
|
+
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
|
|
602
2087
|
|
|
603
2088
|
|
|
604
|
-
def
|
|
2089
|
+
def apply_liger_kernel_to_glm4v_moe(
|
|
2090
|
+
rope: bool = False,
|
|
605
2091
|
cross_entropy: bool = False,
|
|
606
2092
|
fused_linear_cross_entropy: bool = True,
|
|
607
2093
|
rms_norm: bool = True,
|
|
608
|
-
layer_norm: bool = True,
|
|
609
2094
|
swiglu: bool = True,
|
|
610
2095
|
model: PreTrainedModel = None,
|
|
611
2096
|
) -> None:
|
|
612
2097
|
"""
|
|
613
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
614
|
-
NOTE: Qwen2-VL is not available in transformers<4.45.0
|
|
2098
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
|
|
615
2099
|
|
|
616
2100
|
Args:
|
|
2101
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
617
2102
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
618
2103
|
fused_linear_cross_entropy (bool):
|
|
619
2104
|
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
620
2105
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
621
2106
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
622
2107
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
623
|
-
|
|
624
|
-
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
2108
|
+
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
|
|
625
2109
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
626
2110
|
loaded. Default is None.
|
|
627
2111
|
"""
|
|
628
|
-
assert not (
|
|
629
|
-
cross_entropy and fused_linear_cross_entropy
|
|
630
|
-
)
|
|
2112
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2113
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2114
|
+
)
|
|
631
2115
|
|
|
632
|
-
from transformers.models.
|
|
2116
|
+
from transformers.models.glm4v_moe import modeling_glm4v_moe
|
|
2117
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
|
|
2118
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
|
|
2119
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
|
|
2120
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
|
|
2121
|
+
|
|
2122
|
+
from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
|
|
2123
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
|
2124
|
+
|
|
2125
|
+
if rope:
|
|
2126
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
|
2127
|
+
if rms_norm:
|
|
2128
|
+
modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
|
|
2129
|
+
modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
|
|
2130
|
+
if cross_entropy:
|
|
2131
|
+
from transformers.loss.loss_utils import nn
|
|
2132
|
+
|
|
2133
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2134
|
+
if fused_linear_cross_entropy:
|
|
2135
|
+
if model is not None:
|
|
2136
|
+
model.forward = MethodType(glm4v_moe_lce_forward, model)
|
|
2137
|
+
else:
|
|
2138
|
+
modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
|
|
2139
|
+
|
|
2140
|
+
if model is not None:
|
|
2141
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2142
|
+
# instance variables that reference already-instantiated modules
|
|
2143
|
+
if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
|
|
2144
|
+
# Note: language_model and visual properties can be accessed throught conditional class for BC.
|
|
2145
|
+
# Not sure if it is subject to changes in the future.
|
|
2146
|
+
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
|
|
2147
|
+
text_model: Glm4vMoeTextModel = model.language_model
|
|
2148
|
+
vision_model: Glm4vMoeVisionModel = model.visual
|
|
2149
|
+
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
|
|
2150
|
+
elif isinstance(model, Glm4vMoeTextModel):
|
|
2151
|
+
text_model: Glm4vMoeTextModel = model
|
|
2152
|
+
vision_model = None
|
|
2153
|
+
else:
|
|
2154
|
+
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
|
|
2155
|
+
raise TypeError(
|
|
2156
|
+
f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
|
|
2157
|
+
)
|
|
2158
|
+
|
|
2159
|
+
if vision_model is not None:
|
|
2160
|
+
_patch_rms_norm_module(vision_model.post_conv_layernorm)
|
|
2161
|
+
_patch_rms_norm_module(vision_model.post_layernorm)
|
|
2162
|
+
for vision_block in vision_model.blocks:
|
|
2163
|
+
if rms_norm:
|
|
2164
|
+
_patch_rms_norm_module(vision_block.norm1)
|
|
2165
|
+
_patch_rms_norm_module(vision_block.norm2)
|
|
2166
|
+
if swiglu:
|
|
2167
|
+
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
|
|
2168
|
+
|
|
2169
|
+
if text_model is not None:
|
|
2170
|
+
if rms_norm:
|
|
2171
|
+
_patch_rms_norm_module(text_model.norm)
|
|
2172
|
+
for decoder_layer in text_model.layers:
|
|
2173
|
+
if swiglu:
|
|
2174
|
+
decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
2175
|
+
if rms_norm:
|
|
2176
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2177
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2178
|
+
if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
|
|
2179
|
+
experts = getattr(decoder_layer.mlp, "experts", None)
|
|
2180
|
+
if experts is not None:
|
|
2181
|
+
for expert in experts:
|
|
2182
|
+
_patch_swiglu_module(expert, LigerSwiGLUMLP)
|
|
2183
|
+
if decoder_layer.mlp.shared_experts is not None:
|
|
2184
|
+
_patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
|
|
2185
|
+
for decoder_layer in text_model.layers:
|
|
2186
|
+
if rms_norm:
|
|
2187
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2188
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
633
2189
|
|
|
634
|
-
|
|
635
|
-
|
|
2190
|
+
|
|
2191
|
+
def apply_liger_kernel_to_internvl(
|
|
2192
|
+
cross_entropy: bool = False,
|
|
2193
|
+
fused_linear_cross_entropy: bool = True,
|
|
2194
|
+
rms_norm: bool = True,
|
|
2195
|
+
layer_norm: bool = True,
|
|
2196
|
+
model: Optional[PreTrainedModel] = None,
|
|
2197
|
+
**kwargs,
|
|
2198
|
+
) -> None:
|
|
2199
|
+
"""
|
|
2200
|
+
Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
|
|
2201
|
+
Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
|
|
2202
|
+
However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
|
|
2203
|
+
NOTE: InternVL is not available in transformers<4.52.1
|
|
2204
|
+
|
|
2205
|
+
Args:
|
|
2206
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2207
|
+
fused_linear_cross_entropy (bool):
|
|
2208
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2209
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2210
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
2211
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2212
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
2213
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2214
|
+
loaded. Default is None.
|
|
2215
|
+
"""
|
|
2216
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2217
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
636
2218
|
)
|
|
2219
|
+
import torch.nn as torch_nn
|
|
2220
|
+
|
|
2221
|
+
from transformers.models.internvl import modeling_internvl
|
|
2222
|
+
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
|
|
2223
|
+
from transformers.models.internvl.modeling_internvl import InternVLModel
|
|
2224
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
|
|
2225
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionModel
|
|
2226
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
|
|
2227
|
+
|
|
2228
|
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
2229
|
+
from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
|
|
2230
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
2231
|
+
|
|
2232
|
+
if layer_norm and model is None:
|
|
2233
|
+
modeling_internvl.nn.LayerNorm = LigerLayerNorm
|
|
2234
|
+
|
|
2235
|
+
if cross_entropy:
|
|
2236
|
+
logger.info("Apply liger cross entropy")
|
|
637
2237
|
|
|
638
|
-
|
|
2238
|
+
from transformers.loss.loss_utils import nn
|
|
639
2239
|
|
|
2240
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2241
|
+
if fused_linear_cross_entropy:
|
|
2242
|
+
modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
|
|
640
2243
|
if rms_norm:
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
if
|
|
644
|
-
|
|
2244
|
+
modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
|
|
2245
|
+
|
|
2246
|
+
if model is not None:
|
|
2247
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2248
|
+
# instance variables that reference already-instantiated modules
|
|
2249
|
+
if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
|
|
2250
|
+
# NOTE: language_model and visual properties can be accessed throught conditional class.
|
|
2251
|
+
text_model = model.language_model
|
|
2252
|
+
vision_model: InternVLVisionModel = model.vision_tower
|
|
2253
|
+
else:
|
|
2254
|
+
raise TypeError(
|
|
2255
|
+
f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
|
|
2256
|
+
)
|
|
2257
|
+
|
|
2258
|
+
text_model_name = model.config.text_config.model_type
|
|
2259
|
+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
|
|
2260
|
+
|
|
2261
|
+
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
|
|
2262
|
+
if text_liger_fn:
|
|
2263
|
+
accept_params = inspect.signature(text_liger_fn).parameters
|
|
2264
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
2265
|
+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
2266
|
+
|
|
2267
|
+
if remain_params:
|
|
2268
|
+
logger.warning(
|
|
2269
|
+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
|
|
2270
|
+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
|
|
2271
|
+
)
|
|
2272
|
+
text_kwargs["model"] = text_model
|
|
2273
|
+
text_liger_fn(**text_kwargs)
|
|
2274
|
+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
2275
|
+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
|
|
2276
|
+
|
|
2277
|
+
# Patch vision model RMSNorm layers
|
|
2278
|
+
if rms_norm:
|
|
2279
|
+
for encoder_layer in vision_model.encoder.layer:
|
|
2280
|
+
encoder_layer: InternVLVisionLayer
|
|
2281
|
+
if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
|
|
2282
|
+
_patch_rms_norm_module(encoder_layer.attention.q_norm)
|
|
2283
|
+
if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
|
|
2284
|
+
_patch_rms_norm_module(encoder_layer.attention.k_norm)
|
|
2285
|
+
|
|
2286
|
+
# Patch vision model LayerNorm layers
|
|
2287
|
+
if layer_norm:
|
|
2288
|
+
# Patch layernorm
|
|
2289
|
+
if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
|
|
2290
|
+
_patch_layer_norm_module(vision_model.layernorm)
|
|
2291
|
+
|
|
2292
|
+
# Patch encoder layers
|
|
2293
|
+
for encoder_layer in vision_model.encoder.layer:
|
|
2294
|
+
encoder_layer: InternVLVisionLayer
|
|
2295
|
+
if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
|
|
2296
|
+
_patch_layer_norm_module(encoder_layer.layernorm_before)
|
|
2297
|
+
if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
|
|
2298
|
+
_patch_layer_norm_module(encoder_layer.layernorm_after)
|
|
2299
|
+
|
|
2300
|
+
|
|
2301
|
+
def apply_liger_kernel_to_smolvlm(
|
|
2302
|
+
cross_entropy: bool = False,
|
|
2303
|
+
fused_linear_cross_entropy: bool = True,
|
|
2304
|
+
rms_norm: bool = True,
|
|
2305
|
+
layer_norm: bool = True,
|
|
2306
|
+
model: Optional[PreTrainedModel] = None,
|
|
2307
|
+
**kwargs,
|
|
2308
|
+
) -> None:
|
|
2309
|
+
"""
|
|
2310
|
+
Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
|
|
2311
|
+
Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
|
|
2312
|
+
However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
|
|
2313
|
+
NOTE: SmolVLM is not available in transformers<4.50.0
|
|
2314
|
+
|
|
2315
|
+
Args:
|
|
2316
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2317
|
+
fused_linear_cross_entropy (bool):
|
|
2318
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2319
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2320
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
2321
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2322
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
2323
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2324
|
+
loaded. Default is None.
|
|
2325
|
+
"""
|
|
2326
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2327
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2328
|
+
)
|
|
2329
|
+
|
|
2330
|
+
from transformers.models.smolvlm import modeling_smolvlm
|
|
2331
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
|
|
2332
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
|
|
2333
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
|
|
2334
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
|
|
2335
|
+
|
|
2336
|
+
from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
|
|
2337
|
+
|
|
2338
|
+
# Patch LayerNorm for vision model if model is not provided (pre-initialization)
|
|
2339
|
+
if layer_norm and model is None:
|
|
2340
|
+
modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
|
|
2341
|
+
|
|
645
2342
|
if cross_entropy:
|
|
646
|
-
|
|
2343
|
+
logger.info("Apply liger cross entropy")
|
|
2344
|
+
|
|
2345
|
+
from transformers.loss.loss_utils import nn
|
|
2346
|
+
|
|
2347
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
647
2348
|
if fused_linear_cross_entropy:
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
2349
|
+
if model is not None:
|
|
2350
|
+
model.forward = MethodType(smolvlm_lce_forward, model)
|
|
2351
|
+
else:
|
|
2352
|
+
modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
|
|
2353
|
+
if rms_norm:
|
|
2354
|
+
modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
|
|
651
2355
|
|
|
652
2356
|
if model is not None:
|
|
653
2357
|
# The model instance already exists, so we need to additionally patch the
|
|
654
2358
|
# instance variables that reference already-instantiated modules
|
|
2359
|
+
if isinstance(model, SmolVLMForConditionalGeneration):
|
|
2360
|
+
text_model = model.model.text_model
|
|
2361
|
+
vision_model: SmolVLMVisionTransformer = model.model.vision_model
|
|
2362
|
+
elif isinstance(model, SmolVLMModel):
|
|
2363
|
+
text_model = model.text_model
|
|
2364
|
+
vision_model: SmolVLMVisionTransformer = model.vision_model
|
|
2365
|
+
else:
|
|
2366
|
+
raise TypeError(
|
|
2367
|
+
f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
|
|
2368
|
+
)
|
|
2369
|
+
|
|
2370
|
+
text_model_name = model.config.text_config.model_type
|
|
2371
|
+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
|
|
2372
|
+
|
|
2373
|
+
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
|
|
2374
|
+
if text_liger_fn:
|
|
2375
|
+
accept_params = inspect.signature(text_liger_fn).parameters
|
|
2376
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
2377
|
+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
2378
|
+
|
|
2379
|
+
if remain_params:
|
|
2380
|
+
logger.warning(
|
|
2381
|
+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
|
|
2382
|
+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
|
|
2383
|
+
)
|
|
2384
|
+
text_kwargs["model"] = text_model
|
|
2385
|
+
text_liger_fn(**text_kwargs)
|
|
2386
|
+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
2387
|
+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
|
|
2388
|
+
|
|
2389
|
+
# Patch vision model LayerNorm layers
|
|
2390
|
+
if layer_norm:
|
|
2391
|
+
# Patch post_layernorm
|
|
2392
|
+
_patch_layer_norm_module(vision_model.post_layernorm)
|
|
2393
|
+
|
|
2394
|
+
# Patch encoder layers
|
|
2395
|
+
for encoder_layer in vision_model.encoder.layers:
|
|
2396
|
+
encoder_layer: SmolVLMEncoderLayer
|
|
2397
|
+
_patch_layer_norm_module(encoder_layer.layer_norm1)
|
|
2398
|
+
_patch_layer_norm_module(encoder_layer.layer_norm2)
|
|
2399
|
+
|
|
2400
|
+
|
|
2401
|
+
def apply_liger_kernel_to_falcon_h1(
|
|
2402
|
+
rope: bool = True,
|
|
2403
|
+
cross_entropy: bool = False,
|
|
2404
|
+
fused_linear_cross_entropy: bool = True,
|
|
2405
|
+
rms_norm: bool = True,
|
|
2406
|
+
swiglu: bool = False,
|
|
2407
|
+
model: PreTrainedModel = None,
|
|
2408
|
+
) -> None:
|
|
2409
|
+
"""
|
|
2410
|
+
Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
|
|
2411
|
+
Args:
|
|
2412
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
2413
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
2414
|
+
fused_linear_cross_entropy (bool):
|
|
2415
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
2416
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2417
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
2418
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
|
2419
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
2420
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2421
|
+
loaded. Default is None.
|
|
2422
|
+
"""
|
|
2423
|
+
|
|
2424
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2425
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2426
|
+
)
|
|
2427
|
+
|
|
2428
|
+
from transformers.models.falcon_h1 import modeling_falcon_h1
|
|
2429
|
+
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
|
|
2430
|
+
|
|
2431
|
+
if rope:
|
|
2432
|
+
logger.info("Apply liger rotary pos emb.")
|
|
2433
|
+
modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2434
|
+
if rms_norm:
|
|
2435
|
+
logger.info("Apply liger RMSNorm")
|
|
2436
|
+
modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
|
|
2437
|
+
if swiglu:
|
|
2438
|
+
logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
|
|
655
2439
|
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
2440
|
+
if cross_entropy:
|
|
2441
|
+
logger.info("Apply liger cross entropy")
|
|
2442
|
+
from transformers.loss.loss_utils import nn
|
|
2443
|
+
|
|
2444
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2445
|
+
|
|
2446
|
+
if fused_linear_cross_entropy:
|
|
2447
|
+
if model is not None:
|
|
2448
|
+
model.forward = MethodType(falcon_h1_lce_forward, model)
|
|
659
2449
|
else:
|
|
660
|
-
|
|
661
|
-
base_model = model
|
|
2450
|
+
modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
|
|
662
2451
|
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
2452
|
+
if model is not None:
|
|
2453
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2454
|
+
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
|
2455
|
+
|
|
2456
|
+
# get the base model from the model instance
|
|
2457
|
+
base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
|
|
669
2458
|
|
|
670
2459
|
if rms_norm:
|
|
671
|
-
_patch_rms_norm_module(base_model.
|
|
2460
|
+
_patch_rms_norm_module(base_model.final_layernorm)
|
|
2461
|
+
|
|
672
2462
|
for decoder_layer in base_model.layers:
|
|
673
2463
|
if swiglu:
|
|
674
|
-
|
|
675
|
-
decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
|
|
676
|
-
)
|
|
2464
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
677
2465
|
if rms_norm:
|
|
678
2466
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
679
|
-
_patch_rms_norm_module(decoder_layer.
|
|
2467
|
+
_patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
|
|
680
2468
|
|
|
681
2469
|
|
|
682
|
-
def
|
|
683
|
-
rope: bool =
|
|
2470
|
+
def apply_liger_kernel_to_qwen3_next(
|
|
2471
|
+
rope: bool = False,
|
|
684
2472
|
cross_entropy: bool = False,
|
|
685
2473
|
fused_linear_cross_entropy: bool = True,
|
|
686
2474
|
rms_norm: bool = True,
|
|
@@ -688,77 +2476,125 @@ def apply_liger_kernel_to_phi3(
|
|
|
688
2476
|
model: PreTrainedModel = None,
|
|
689
2477
|
) -> None:
|
|
690
2478
|
"""
|
|
691
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
2479
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
|
|
692
2480
|
|
|
693
2481
|
Args:
|
|
694
|
-
rope (bool): Whether to apply Liger's rotary position embedding. Default is
|
|
2482
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
695
2483
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
696
2484
|
fused_linear_cross_entropy (bool):
|
|
697
2485
|
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
698
2486
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
699
2487
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
700
2488
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
701
|
-
swiglu (bool): Whether to apply Liger's
|
|
2489
|
+
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
|
|
702
2490
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
703
2491
|
loaded. Default is None.
|
|
704
2492
|
"""
|
|
705
|
-
assert not (
|
|
706
|
-
cross_entropy and fused_linear_cross_entropy
|
|
707
|
-
)
|
|
2493
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2494
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2495
|
+
)
|
|
708
2496
|
|
|
709
|
-
from transformers.models.
|
|
2497
|
+
from transformers.models.qwen3_next import modeling_qwen3_next
|
|
2498
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
|
|
2499
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
|
|
2500
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
|
|
2501
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
|
2502
|
+
|
|
2503
|
+
from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
|
|
2504
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
|
|
2505
|
+
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
|
|
710
2506
|
|
|
711
2507
|
if rope:
|
|
712
|
-
|
|
2508
|
+
# It might enocunter nan issue
|
|
2509
|
+
# modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2510
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
|
|
713
2511
|
if rms_norm:
|
|
714
|
-
|
|
715
|
-
if swiglu:
|
|
716
|
-
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
|
|
2512
|
+
modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
|
|
717
2513
|
if cross_entropy:
|
|
718
|
-
|
|
2514
|
+
from transformers.loss.loss_utils import nn
|
|
2515
|
+
|
|
2516
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
719
2517
|
if fused_linear_cross_entropy:
|
|
720
|
-
if
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
2518
|
+
if model is not None:
|
|
2519
|
+
if isinstance(model, Qwen3NextForCausalLM):
|
|
2520
|
+
model.forward = MethodType(qwen3_next_lce_forward, model)
|
|
2521
|
+
else:
|
|
2522
|
+
raise TypeError(
|
|
2523
|
+
f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
|
|
2524
|
+
)
|
|
2525
|
+
else:
|
|
2526
|
+
modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
|
|
2527
|
+
if swiglu:
|
|
2528
|
+
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
|
|
2529
|
+
modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
|
|
725
2530
|
|
|
726
2531
|
if model is not None:
|
|
727
2532
|
# The model instance already exists, so we need to additionally patch the
|
|
728
2533
|
# instance variables that reference already-instantiated modules
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
# The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
|
|
732
|
-
base_model = model.model
|
|
2534
|
+
if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
|
|
2535
|
+
base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
|
|
733
2536
|
else:
|
|
734
|
-
|
|
735
|
-
|
|
2537
|
+
raise TypeError(
|
|
2538
|
+
f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
|
|
2539
|
+
)
|
|
736
2540
|
|
|
737
2541
|
if rms_norm:
|
|
738
2542
|
_patch_rms_norm_module(base_model.norm)
|
|
739
2543
|
|
|
740
2544
|
for decoder_layer in base_model.layers:
|
|
741
|
-
if swiglu:
|
|
742
|
-
_bind_method_to_module(
|
|
743
|
-
decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
|
|
744
|
-
)
|
|
745
2545
|
if rms_norm:
|
|
746
2546
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
747
2547
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
748
2548
|
|
|
2549
|
+
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
|
|
2550
|
+
if swiglu:
|
|
2551
|
+
if isinstance(decoder_layer.mlp, Qwen3NextMLP):
|
|
2552
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
|
|
2553
|
+
if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
|
|
2554
|
+
_patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
|
|
2555
|
+
experts = getattr(decoder_layer.mlp, "experts", None)
|
|
2556
|
+
if experts is not None:
|
|
2557
|
+
for expert in experts:
|
|
2558
|
+
_patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
|
|
2559
|
+
|
|
749
2560
|
|
|
750
2561
|
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
751
2562
|
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
752
2563
|
"gemma": apply_liger_kernel_to_gemma,
|
|
753
2564
|
"gemma2": apply_liger_kernel_to_gemma2,
|
|
2565
|
+
"gemma3_text": apply_liger_kernel_to_gemma3_text,
|
|
2566
|
+
"gemma3": apply_liger_kernel_to_gemma3,
|
|
2567
|
+
"glm4": apply_liger_kernel_to_glm4,
|
|
2568
|
+
"glm4v": apply_liger_kernel_to_glm4v,
|
|
2569
|
+
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
|
|
2570
|
+
"internvl": apply_liger_kernel_to_internvl,
|
|
754
2571
|
"llama": apply_liger_kernel_to_llama,
|
|
2572
|
+
"llama4_text": apply_liger_kernel_to_llama4,
|
|
2573
|
+
"llama4": apply_liger_kernel_to_llama4,
|
|
2574
|
+
"llava": apply_liger_kernel_to_llava,
|
|
2575
|
+
"granite": apply_liger_kernel_to_granite,
|
|
755
2576
|
"mllama": apply_liger_kernel_to_mllama,
|
|
756
2577
|
"mllama_text_model": apply_liger_kernel_to_mllama,
|
|
757
2578
|
"mistral": apply_liger_kernel_to_mistral,
|
|
758
2579
|
"mixtral": apply_liger_kernel_to_mixtral,
|
|
2580
|
+
"olmo2": apply_liger_kernel_to_olmo2,
|
|
759
2581
|
"qwen2": apply_liger_kernel_to_qwen2,
|
|
2582
|
+
"qwen3": apply_liger_kernel_to_qwen3,
|
|
2583
|
+
"qwen3_moe": apply_liger_kernel_to_qwen3_moe,
|
|
760
2584
|
"qwen2_vl": apply_liger_kernel_to_qwen2_vl,
|
|
2585
|
+
"qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
|
|
2586
|
+
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
|
|
2587
|
+
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
|
|
2588
|
+
"qwen3_next": apply_liger_kernel_to_qwen3_next,
|
|
2589
|
+
"qwen3_vl": apply_liger_kernel_to_qwen3_vl,
|
|
2590
|
+
"qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
|
|
2591
|
+
"qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2592
|
+
"qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2593
|
+
"smollm3": apply_liger_kernel_to_smollm3,
|
|
761
2594
|
"phi3": apply_liger_kernel_to_phi3,
|
|
2595
|
+
"paligemma": apply_liger_kernel_to_paligemma,
|
|
2596
|
+
"falcon_h1": apply_liger_kernel_to_falcon_h1,
|
|
2597
|
+
"smolvlm": apply_liger_kernel_to_smolvlm,
|
|
762
2598
|
}
|
|
763
2599
|
|
|
764
2600
|
|
|
@@ -782,24 +2618,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
|
|
|
782
2618
|
return
|
|
783
2619
|
|
|
784
2620
|
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
|
|
785
|
-
logger.info(
|
|
786
|
-
f"There are currently no Liger kernels supported for model type: {model_type}."
|
|
787
|
-
)
|
|
2621
|
+
logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
|
|
788
2622
|
return
|
|
789
2623
|
|
|
790
2624
|
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
|
|
791
2625
|
apply_fn_signature = inspect.signature(apply_fn)
|
|
792
2626
|
|
|
793
2627
|
# Filter out the keyword arguments that are not supported by the apply function
|
|
794
|
-
applicable_kwargs = {
|
|
795
|
-
key: value
|
|
796
|
-
for key, value in kwargs.items()
|
|
797
|
-
if key in apply_fn_signature.parameters
|
|
798
|
-
}
|
|
2628
|
+
applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
|
|
799
2629
|
|
|
800
|
-
logger.info(
|
|
801
|
-
f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
|
|
802
|
-
)
|
|
2630
|
+
logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
|
|
803
2631
|
|
|
804
2632
|
# Assume this is invoked pre-model initialization, so we only need to patch transformers code
|
|
805
2633
|
apply_fn(**applicable_kwargs)
|
|
@@ -813,32 +2641,21 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
|
|
|
813
2641
|
- model: the model instance to apply Liger kernels to
|
|
814
2642
|
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
|
|
815
2643
|
"""
|
|
816
|
-
model_type = getattr(model, "config", None) and getattr(
|
|
817
|
-
model.config, "model_type", None
|
|
818
|
-
)
|
|
2644
|
+
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
|
|
819
2645
|
|
|
820
2646
|
if not model_type:
|
|
821
|
-
logger.info(
|
|
822
|
-
"Model type could not be determined from model config. No Liger kernels will be applied."
|
|
823
|
-
)
|
|
2647
|
+
logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
|
|
824
2648
|
return
|
|
825
2649
|
|
|
826
2650
|
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
|
|
827
|
-
logger.info(
|
|
828
|
-
f"There are currently no Liger kernels supported for model type: {model_type}."
|
|
829
|
-
)
|
|
2651
|
+
logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
|
|
830
2652
|
return
|
|
831
2653
|
|
|
832
2654
|
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
|
|
833
|
-
|
|
834
2655
|
apply_fn_signature = inspect.signature(apply_fn)
|
|
835
2656
|
|
|
836
2657
|
# Filter out the keyword arguments that are not supported by the apply function
|
|
837
|
-
applicable_kwargs = {
|
|
838
|
-
key: value
|
|
839
|
-
for key, value in kwargs.items()
|
|
840
|
-
if key in apply_fn_signature.parameters
|
|
841
|
-
}
|
|
2658
|
+
applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
|
|
842
2659
|
logger.info(
|
|
843
2660
|
f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
|
|
844
2661
|
)
|