liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,219 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from functools import partial
|
|
5
|
+
from types import MethodType
|
|
6
|
+
from typing import Callable
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import transformers
|
|
10
|
+
|
|
11
|
+
from packaging import version
|
|
12
|
+
from transformers import PreTrainedModel
|
|
13
|
+
|
|
1
14
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
15
|
+
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
2
16
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
|
3
|
-
from liger_kernel.transformers.
|
|
17
|
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
18
|
+
from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
|
|
19
|
+
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
|
|
20
|
+
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
|
|
21
|
+
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
|
|
22
|
+
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
|
|
23
|
+
from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
|
|
24
|
+
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
|
|
25
|
+
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
|
|
26
|
+
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
|
|
27
|
+
from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
|
|
28
|
+
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
|
|
29
|
+
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
|
|
30
|
+
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
|
|
31
|
+
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
|
|
32
|
+
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
|
|
33
|
+
from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
|
|
34
|
+
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
|
|
35
|
+
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
|
|
4
36
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
5
37
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
6
|
-
from liger_kernel.transformers.
|
|
38
|
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
|
|
39
|
+
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
|
|
40
|
+
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
|
|
41
|
+
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
42
|
+
|
|
43
|
+
try:
    import peft

    PEFT_AVAILABLE = True
except ImportError:
    # peft is optional; the norm-patching helpers only special-case its wrapper when present.
    PEFT_AVAILABLE = False

transformer_version = version.parse(transformers.__version__)

logger = logging.getLogger(__name__)
# Version threshold the patch functions use to choose between current and deprecated code paths.
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
# Built from SUPPORTED_TRANSFORMER_VERSION so the warning text cannot drift from the version check.
TRANSFORMER_DEPRECATION_WARNING = (
    f"Support for transformers versions < {SUPPORTED_TRANSFORMER_VERSION} will soon be discontinued due to issues "
    "with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. "
    "See details: https://github.com/huggingface/transformers/pull/34191"
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
58
|
+
# Binds a new method to a module instance so that self is passed as the first argument
|
|
59
|
+
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
    """Convert an already-instantiated RMSNorm-style module into a LigerRMSNorm in place.

    The configuration attributes ``offset``, ``casting_mode``, ``variance_epsilon``,
    ``in_place`` and ``row_mode`` are written onto the patched module(s). Their ``forward``,
    ``extra_repr`` and ``_get_name`` are then replaced with LigerRMSNorm's.

    Args:
        module: The norm module to patch. If PEFT is installed and ``module`` is a
            ``ModulesToSaveWrapper``, its ``modules_to_save.default`` and
            ``original_module`` are patched and the wrapper itself is left alone.
        offset: Stored as ``offset``.
        eps: Fallback epsilon, used only when the module has neither a truthy
            ``variance_epsilon`` nor a truthy ``eps`` attribute.
        casting_mode: Stored as ``casting_mode``.
        in_place: Stored as ``in_place``.
        row_mode: Stored as ``row_mode``.
    """
    # Norm implementations name their epsilon differently; take whichever is present.
    resolved_eps = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps

    is_peft_wrapper = PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper)
    if is_peft_wrapper:
        targets = (module.modules_to_save.default, module.original_module)
    else:
        targets = (module,)

    for target in targets:
        target.offset = offset
        target.casting_mode = casting_mode
        target.variance_epsilon = resolved_eps
        target.in_place = in_place
        target.row_mode = row_mode
        _bind_method_to_module(target, "forward", LigerRMSNorm.forward)
        _bind_method_to_module(target, "extra_repr", LigerRMSNorm.extra_repr)
        _bind_method_to_module(target, "_get_name", lambda self: LigerRMSNorm.__name__)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _patch_layer_norm_module(module, eps=1e-6):
    """Convert an already-instantiated LayerNorm-style module into a LigerLayerNorm in place.

    Each patched module gets ``variance_epsilon`` and ``hidden_size`` attributes. Its
    ``forward``, ``extra_repr`` and ``_get_name`` are replaced with LigerLayerNorm's.

    Args:
        module: The norm module to patch. If PEFT is installed and ``module`` is a
            ``ModulesToSaveWrapper``, the inner ``modules_to_save.default`` and
            ``original_module`` are each patched. The wrapper's own methods are not
            replaced, which matches how ``_patch_rms_norm_module`` handles the wrapper.
        eps: Fallback epsilon, used only when the module has neither a truthy
            ``variance_epsilon`` nor a truthy ``eps`` attribute.
    """

    def _patch_one(target):
        # Attribute names differ between norm implementations
        # (variance_epsilon/hidden_size vs. eps/normalized_shape); take whichever is present.
        target.variance_epsilon = getattr(target, "variance_epsilon", None) or getattr(target, "eps", None) or eps
        target.hidden_size = getattr(target, "hidden_size", None) or getattr(target, "normalized_shape", None)
        _bind_method_to_module(target, "forward", LigerLayerNorm.forward)
        _bind_method_to_module(target, "extra_repr", LigerLayerNorm.extra_repr)
        _bind_method_to_module(target, "_get_name", lambda self: LigerLayerNorm.__name__)

    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
        # Patch both inner copies, including modules_to_save.default.hidden_size, which the
        # previous version never set. Leave the wrapper's forward alone so it can still
        # choose which inner module runs.
        _patch_one(module.modules_to_save.default)
        _patch_one(module.original_module)
    else:
        _patch_one(module)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _patch_swiglu_module(module, liger_module):
    """Point an MLP instance's ``forward`` and ``_get_name`` at a Liger SwiGLU class.

    Args:
        module: The instantiated MLP to patch in place.
        liger_module: The Liger class that supplies ``forward`` and the name that
            ``_get_name`` reports.
    """
    liger_forward = liger_module.forward
    _bind_method_to_module(module, "forward", liger_forward)

    def _reported_name(self):
        return liger_module.__name__

    _bind_method_to_module(module, "_get_name", _reported_name)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _patch_geglu_module(module):
    """Point an MLP instance's ``forward`` and ``_get_name`` at LigerGEGLUMLP.

    Args:
        module: The instantiated MLP to patch in place.
    """
    geglu_forward = LigerGEGLUMLP.forward
    _bind_method_to_module(module, "forward", geglu_forward)

    def _reported_name(self):
        return LigerGEGLUMLP.__name__

    _bind_method_to_module(module, "_get_name", _reported_name)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def apply_liger_kernel_to_granite(
    rope: bool = True,
    cross_entropy: bool = True,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
        fused_linear_cross_entropy (bool):
            Not supported for Granite: passing True raises ``NotImplementedError`` before
            any patch is applied. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        NotImplementedError: If `fused_linear_cross_entropy` is True.

    Debugging notes:
        If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
        NOTE(review): this suggests `swiglu=True` (the default) may misbehave for Granite;
        confirm before relying on it.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    # Reject the unsupported option before touching modeling_granite, so a failed call does
    # not leave the transformers module partially patched.
    if fused_linear_cross_entropy:
        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
        # call, so we can't sidestep logit materialization. A bit more work
        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
        # for the logit output.
        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")

    from transformers.models.granite import modeling_granite
    from transformers.models.granite.modeling_granite import GraniteModel

    if swiglu:
        modeling_granite.GraniteMLP = LigerSwiGLUMLP

    if rms_norm:
        modeling_granite.GraniteRMSNorm = LigerRMSNorm

    if rope:
        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)

        # get the base model from the model instance
        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
7
217
|
|
|
8
218
|
|
|
9
219
|
def apply_liger_kernel_to_llama(
|
|
@@ -12,6 +222,7 @@ def apply_liger_kernel_to_llama(
|
|
|
12
222
|
fused_linear_cross_entropy: bool = True,
|
|
13
223
|
rms_norm: bool = True,
|
|
14
224
|
swiglu: bool = True,
|
|
225
|
+
model: PreTrainedModel = None,
|
|
15
226
|
) -> None:
|
|
16
227
|
"""
|
|
17
228
|
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
|
@@ -20,18 +231,21 @@ def apply_liger_kernel_to_llama(
|
|
|
20
231
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
21
232
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
22
233
|
fused_linear_cross_entropy (bool):
|
|
23
|
-
Whether to apply Liger's fused
|
|
234
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
24
235
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
25
236
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
26
237
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
27
238
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
239
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
240
|
+
loaded. Default is None.
|
|
28
241
|
"""
|
|
29
242
|
|
|
30
|
-
assert not (
|
|
31
|
-
cross_entropy and fused_linear_cross_entropy
|
|
32
|
-
)
|
|
243
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
244
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
245
|
+
)
|
|
33
246
|
|
|
34
247
|
from transformers.models.llama import modeling_llama
|
|
248
|
+
from transformers.models.llama.modeling_llama import LlamaModel
|
|
35
249
|
|
|
36
250
|
if rope:
|
|
37
251
|
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
@@ -39,29 +253,439 @@ def apply_liger_kernel_to_llama(
|
|
|
39
253
|
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
|
40
254
|
if swiglu:
|
|
41
255
|
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
|
256
|
+
|
|
257
|
+
if cross_entropy:
|
|
258
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
259
|
+
from transformers.loss.loss_utils import nn
|
|
260
|
+
|
|
261
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
262
|
+
else:
|
|
263
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
264
|
+
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
265
|
+
|
|
266
|
+
if fused_linear_cross_entropy:
|
|
267
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
268
|
+
if model is not None:
|
|
269
|
+
model.forward = MethodType(llama_lce_forward, model)
|
|
270
|
+
else:
|
|
271
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
|
272
|
+
else: # if version < 4.46.1
|
|
273
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
274
|
+
if model is not None:
|
|
275
|
+
model.forward = MethodType(llama_lce_forward_deprecated, model)
|
|
276
|
+
else:
|
|
277
|
+
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
|
|
278
|
+
|
|
279
|
+
if model is not None:
|
|
280
|
+
# The model instance already exists, so we need to additionally patch the
|
|
281
|
+
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
|
282
|
+
|
|
283
|
+
# get the base model from the model instance
|
|
284
|
+
base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
|
|
285
|
+
|
|
286
|
+
if rms_norm:
|
|
287
|
+
_patch_rms_norm_module(base_model.norm)
|
|
288
|
+
|
|
289
|
+
for decoder_layer in base_model.layers:
|
|
290
|
+
if swiglu:
|
|
291
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
292
|
+
if rms_norm:
|
|
293
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
294
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def apply_liger_kernel_to_smollm3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace SmolLM3 with Liger kernels.

    Symbols in ``transformers.models.smollm3.modeling_smollm3`` are replaced first. When
    ``model`` is given, that instance's forward and its already-built norm/MLP submodules
    are patched as well.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.smollm3 import modeling_smollm3
    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model

    # Module-level replacements, picked up by models constructed after this call.
    if rope:
        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version < version.parse(SUPPORTED_TRANSFORMER_VERSION):
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
        else:
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
        else:
            model.forward = MethodType(smollm3_lce_forward, model)

    if model is None:
        return

    # An existing instance already holds SmolLM3RMSNorm / SmolLM3MLP objects, so patch them too.
    base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def apply_liger_kernel_to_llava(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    model: PreTrainedModel = None,
    **kwargs,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llava models.
    Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
    However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
    NOTE: Llava is not available in transformers<4.36.0

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
            Only applied for transformers >= 4.49.0; older versions log a warning instead.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
        **kwargs: Extra options (for example ``rope``, ``rms_norm``, ``swiglu``) forwarded to the
            Liger patch functions of the connected text and vision models. This only happens when
            ``model`` is given. Options a sub-model's patch function does not accept are dropped
            with a warning. ``cross_entropy`` and ``fused_linear_cross_entropy`` default to False
            for the sub-models.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.llava import modeling_llava

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            # Mirror the other patch functions in this module, which route cross entropy through
            # transformers' loss utilities on these versions.
            # NOTE(review): assumes the Llava forward computes its loss via those utilities here.
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
        # Kept from the original implementation for forwards that build `nn.CrossEntropyLoss` directly.
        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse("4.52.0"):
            lce_forward = llava_lce_forward
        elif transformer_version >= version.parse("4.49.0"):
            lce_forward = llava_lce_forward_deprecated
        else:  # if version < 4.49.0
            lce_forward = None
            logger.warning(
                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
            )

        if lce_forward is not None:
            if model is not None:
                model.forward = MethodType(lce_forward, model)
            else:
                modeling_llava.LlavaForConditionalGeneration.forward = lce_forward

    if model is None:
        return

    # Sub-model patches leave the loss alone by default; user kwargs may override these.
    sub_kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}

    def _patch_sub_model(model_name, submodule_attr):
        """Run the registered Liger patch for ``model_name`` on ``getattr(model, submodule_attr)``."""
        liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_name, None)
        if liger_fn:
            accept_params = inspect.signature(liger_fn).parameters
            remain_params = set(sub_kwargs) - set(accept_params)
            fn_kwargs = {k: v for k, v in sub_kwargs.items() if k not in remain_params}

            if remain_params:
                logger.warning(
                    f"These parameters are not supported by {model_name}. Enter the remaining {list(fn_kwargs.keys())} except for {list(remain_params)}\n"
                    f"Parameters accepted by {model_name}: {list(accept_params.keys())}"
                )
            # Resolve the submodule lazily, only when a patch function is actually registered.
            fn_kwargs["model"] = getattr(model, submodule_attr)
            liger_fn(**fn_kwargs)
        elif model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{model_name} is not supported by Liger kernel.")

    _patch_sub_model(model.config.text_config.model_type, "language_model")
    _patch_sub_model(model.config.vision_config.model_type, "vision_tower")
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def apply_liger_kernel_to_llama4(
|
|
455
|
+
rope: bool = True,
|
|
456
|
+
cross_entropy: bool = False,
|
|
457
|
+
fused_linear_cross_entropy: bool = True,
|
|
458
|
+
rms_norm: bool = True,
|
|
459
|
+
swiglu: bool = True,
|
|
460
|
+
model: PreTrainedModel = None,
|
|
461
|
+
layer_norm: bool = True,
|
|
462
|
+
) -> None:
|
|
463
|
+
"""
|
|
464
|
+
Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
|
|
465
|
+
|
|
466
|
+
Args:
|
|
467
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
468
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
469
|
+
fused_linear_cross_entropy (bool):
|
|
470
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
471
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
472
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
473
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
474
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
475
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
476
|
+
loaded. Default is None.
|
|
477
|
+
"""
|
|
478
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
479
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
from transformers.models.llama4 import modeling_llama4
|
|
483
|
+
from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
|
|
484
|
+
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
|
|
485
|
+
from transformers.models.llama4.modeling_llama4 import Llama4TextModel
|
|
486
|
+
from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
|
|
487
|
+
|
|
488
|
+
from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
|
|
489
|
+
|
|
490
|
+
if rope:
|
|
491
|
+
from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
|
|
492
|
+
|
|
493
|
+
apply_liger_llama4_rope_full(modeling_llama4)
|
|
494
|
+
if rms_norm:
|
|
495
|
+
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
|
|
496
|
+
if swiglu:
|
|
497
|
+
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
|
|
498
|
+
|
|
499
|
+
if cross_entropy:
|
|
500
|
+
modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
501
|
+
|
|
502
|
+
if fused_linear_cross_entropy:
|
|
503
|
+
modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
|
|
504
|
+
|
|
505
|
+
if model is not None:
|
|
506
|
+
# The model instance already exists, so we need to additionally patch the
|
|
507
|
+
# instance variables that reference already-instantiated modules
|
|
508
|
+
if isinstance(model, Llama4ForConditionalGeneration):
|
|
509
|
+
language_model: Llama4ForCausalLM = model.language_model
|
|
510
|
+
vision_model: Llama4VisionModel = model.vision_model
|
|
511
|
+
text_model: Llama4TextModel = language_model.model
|
|
512
|
+
elif isinstance(model, Llama4ForCausalLM):
|
|
513
|
+
text_model = model.model
|
|
514
|
+
vision_model = None
|
|
515
|
+
elif isinstance(model, Llama4TextModel):
|
|
516
|
+
text_model = model
|
|
517
|
+
vision_model = None
|
|
518
|
+
|
|
519
|
+
else:
|
|
520
|
+
raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
|
|
521
|
+
|
|
522
|
+
if text_model:
|
|
523
|
+
if rms_norm:
|
|
524
|
+
_patch_rms_norm_module(text_model.norm)
|
|
525
|
+
for decoder_layer in text_model.layers:
|
|
526
|
+
if swiglu:
|
|
527
|
+
if decoder_layer.is_moe_layer:
|
|
528
|
+
_patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
|
|
529
|
+
else:
|
|
530
|
+
_patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
|
|
531
|
+
if rms_norm:
|
|
532
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
533
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
534
|
+
|
|
535
|
+
if vision_model:
|
|
536
|
+
_patch_layer_norm_module(vision_model.layernorm_pre)
|
|
537
|
+
_patch_layer_norm_module(vision_model.layernorm_post)
|
|
538
|
+
|
|
539
|
+
for layer in vision_model.model.layers:
|
|
540
|
+
if layer_norm:
|
|
541
|
+
_patch_layer_norm_module(layer.input_layernorm)
|
|
542
|
+
_patch_layer_norm_module(layer.post_attention_layernorm)
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def apply_liger_kernel_to_mllama(
|
|
546
|
+
rope: bool = True,
|
|
547
|
+
cross_entropy: bool = False,
|
|
548
|
+
fused_linear_cross_entropy: bool = True,
|
|
549
|
+
layer_norm: bool = True,
|
|
550
|
+
rms_norm: bool = True,
|
|
551
|
+
swiglu: bool = True,
|
|
552
|
+
model: PreTrainedModel = None,
|
|
553
|
+
) -> None:
|
|
554
|
+
"""
|
|
555
|
+
Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
|
|
556
|
+
NOTE: MLlama is not available in transformers<4.45.0
|
|
557
|
+
|
|
558
|
+
Args:
|
|
559
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
560
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
561
|
+
fused_linear_cross_entropy (bool):
|
|
562
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
563
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
564
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
565
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
566
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
567
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
568
|
+
loaded. Default is None.
|
|
569
|
+
"""
|
|
570
|
+
|
|
571
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
572
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
from transformers.models.mllama import modeling_mllama
|
|
576
|
+
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
|
|
577
|
+
from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
|
|
578
|
+
from transformers.models.mllama.modeling_mllama import MllamaTextModel
|
|
579
|
+
from transformers.models.mllama.modeling_mllama import MllamaVisionModel
|
|
580
|
+
|
|
581
|
+
from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
|
|
582
|
+
from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
|
|
583
|
+
|
|
584
|
+
if rope:
|
|
585
|
+
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
586
|
+
if layer_norm and model is None:
|
|
587
|
+
modeling_mllama.nn.LayerNorm = LigerLayerNorm
|
|
588
|
+
if rms_norm:
|
|
589
|
+
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
|
|
590
|
+
if swiglu:
|
|
591
|
+
modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
|
|
42
592
|
if cross_entropy:
|
|
43
|
-
|
|
593
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
594
|
+
from transformers.loss.loss_utils import nn
|
|
595
|
+
|
|
596
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
597
|
+
else:
|
|
598
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
599
|
+
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
44
600
|
if fused_linear_cross_entropy:
|
|
45
|
-
|
|
601
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
602
|
+
if model is not None:
|
|
603
|
+
model.forward = MethodType(mllama_lce_forward, model)
|
|
604
|
+
else:
|
|
605
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
|
|
606
|
+
else: # if version < 4.46.1
|
|
607
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
608
|
+
if model is not None:
|
|
609
|
+
model.forward = MethodType(mllama_lce_forward_deprecated, model)
|
|
610
|
+
else:
|
|
611
|
+
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
|
|
612
|
+
|
|
613
|
+
if model is not None:
|
|
614
|
+
# The model instance already exists, so we need to additionally patch the
|
|
615
|
+
# instance variables that reference already-instantiated modules
|
|
616
|
+
|
|
617
|
+
if isinstance(model, MllamaForConditionalGeneration):
|
|
618
|
+
language_model: MllamaForCausalLM = model.language_model
|
|
619
|
+
vision_model: MllamaVisionModel = model.vision_model
|
|
620
|
+
if isinstance(language_model, MllamaForCausalLM):
|
|
621
|
+
text_model: MllamaTextModel = language_model.model
|
|
622
|
+
else:
|
|
623
|
+
text_model = language_model
|
|
624
|
+
elif isinstance(model, MllamaForCausalLM):
|
|
625
|
+
text_model = model.model
|
|
626
|
+
vision_model = None
|
|
627
|
+
elif isinstance(model, MllamaTextModel):
|
|
628
|
+
text_model = model
|
|
629
|
+
vision_model = None
|
|
630
|
+
|
|
631
|
+
else:
|
|
632
|
+
raise ValueError(f"Unsupported Mllama model type: {type(model)}")
|
|
633
|
+
|
|
634
|
+
if text_model:
|
|
635
|
+
if rms_norm:
|
|
636
|
+
_patch_rms_norm_module(text_model.norm)
|
|
637
|
+
for decoder_layer in text_model.layers:
|
|
638
|
+
if swiglu:
|
|
639
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
640
|
+
if rms_norm:
|
|
641
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
642
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
643
|
+
|
|
644
|
+
if vision_model:
|
|
645
|
+
_patch_layer_norm_module(vision_model.layernorm_pre)
|
|
646
|
+
_patch_layer_norm_module(vision_model.layernorm_post)
|
|
647
|
+
|
|
648
|
+
for layer in vision_model.transformer.layers:
|
|
649
|
+
if layer_norm:
|
|
650
|
+
_patch_layer_norm_module(layer.input_layernorm)
|
|
651
|
+
_patch_layer_norm_module(layer.post_attention_layernorm)
|
|
652
|
+
|
|
653
|
+
for layer in vision_model.global_transformer.layers:
|
|
654
|
+
if layer_norm:
|
|
655
|
+
_patch_layer_norm_module(layer.input_layernorm)
|
|
656
|
+
_patch_layer_norm_module(layer.post_attention_layernorm)
|
|
46
657
|
|
|
47
658
|
|
|
48
659
|
def apply_liger_kernel_to_mistral(
|
|
49
660
|
rope: bool = True,
|
|
50
|
-
cross_entropy: bool =
|
|
661
|
+
cross_entropy: bool = False,
|
|
662
|
+
fused_linear_cross_entropy: bool = True,
|
|
51
663
|
rms_norm: bool = True,
|
|
52
664
|
swiglu: bool = True,
|
|
665
|
+
model: PreTrainedModel = None,
|
|
53
666
|
) -> None:
|
|
54
667
|
"""
|
|
55
668
|
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
|
|
56
669
|
|
|
57
670
|
Args:
|
|
58
|
-
rope (bool): Whether to apply Liger's rotary position embedding. Default is
|
|
671
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
59
672
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
673
|
+
fused_linear_cross_entropy (bool):
|
|
674
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
675
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
676
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
677
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
60
678
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
61
679
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
680
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
681
|
+
loaded. Default is None.
|
|
62
682
|
"""
|
|
683
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
684
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
685
|
+
)
|
|
63
686
|
|
|
64
687
|
from transformers.models.mistral import modeling_mistral
|
|
688
|
+
from transformers.models.mistral.modeling_mistral import MistralModel
|
|
65
689
|
|
|
66
690
|
if rope:
|
|
67
691
|
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
@@ -69,62 +693,2233 @@ def apply_liger_kernel_to_mistral(
|
|
|
69
693
|
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
|
70
694
|
if cross_entropy:
|
|
71
695
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
696
|
+
if fused_linear_cross_entropy:
|
|
697
|
+
if transformer_version >= version.parse("4.49.0"):
|
|
698
|
+
if model is not None:
|
|
699
|
+
model.forward = MethodType(mistral_lce_forward, model)
|
|
700
|
+
else:
|
|
701
|
+
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
|
702
|
+
else:
|
|
703
|
+
logger.warning(
|
|
704
|
+
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
|
|
705
|
+
)
|
|
706
|
+
logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
|
|
707
|
+
|
|
72
708
|
if swiglu:
|
|
73
709
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
|
74
710
|
|
|
711
|
+
if model is not None:
|
|
712
|
+
# The model instance already exists, so we need to additionally patch the
|
|
713
|
+
# instance variables that reference already-instantiated modules
|
|
714
|
+
|
|
715
|
+
# get the base model from the model instance
|
|
716
|
+
base_model: MistralModel = getattr(model, model.base_model_prefix, model)
|
|
717
|
+
|
|
718
|
+
if rms_norm:
|
|
719
|
+
_patch_rms_norm_module(base_model.norm)
|
|
720
|
+
|
|
721
|
+
for decoder_layer in base_model.layers:
|
|
722
|
+
if swiglu:
|
|
723
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
724
|
+
if rms_norm:
|
|
725
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
726
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
727
|
+
|
|
75
728
|
|
|
76
729
|
def apply_liger_kernel_to_mixtral(
|
|
77
730
|
rope: bool = True,
|
|
78
|
-
cross_entropy: bool =
|
|
731
|
+
cross_entropy: bool = False,
|
|
732
|
+
fused_linear_cross_entropy: bool = True,
|
|
79
733
|
rms_norm: bool = True,
|
|
80
734
|
swiglu: bool = True,
|
|
735
|
+
model: PreTrainedModel = None,
|
|
81
736
|
) -> None:
|
|
82
737
|
"""
|
|
83
738
|
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
|
|
84
739
|
|
|
85
740
|
Args:
|
|
86
741
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
87
|
-
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
|
|
742
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
743
|
+
fused_linear_cross_entropy (bool):
|
|
744
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
745
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
746
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
88
747
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
89
748
|
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
749
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
750
|
+
loaded. Default is None.
|
|
90
751
|
"""
|
|
91
752
|
|
|
753
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
754
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
755
|
+
)
|
|
756
|
+
|
|
92
757
|
from transformers.models.mixtral import modeling_mixtral
|
|
758
|
+
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
|
93
759
|
|
|
94
760
|
if rope:
|
|
95
761
|
modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
96
762
|
if rms_norm:
|
|
97
|
-
modeling_mixtral.
|
|
763
|
+
modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
|
|
98
764
|
if cross_entropy:
|
|
99
|
-
|
|
765
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
766
|
+
from transformers.loss.loss_utils import nn
|
|
767
|
+
|
|
768
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
769
|
+
else:
|
|
770
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
771
|
+
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
772
|
+
|
|
773
|
+
if fused_linear_cross_entropy:
|
|
774
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
775
|
+
if model is not None:
|
|
776
|
+
model.forward = MethodType(mixtral_lce_forward, model)
|
|
777
|
+
else:
|
|
778
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
|
|
779
|
+
else: # if version < 4.46.1
|
|
780
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
781
|
+
if model is not None:
|
|
782
|
+
model.forward = MethodType(mixtral_lce_forward_deprecated, model)
|
|
783
|
+
else:
|
|
784
|
+
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
|
|
100
785
|
if swiglu:
|
|
101
786
|
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
|
|
102
787
|
|
|
788
|
+
if model is not None:
|
|
789
|
+
# The model instance already exists, so we need to additionally patch the
|
|
790
|
+
# instance variables that reference already-instantiated modules
|
|
791
|
+
|
|
792
|
+
# get the base model from the model instance
|
|
793
|
+
base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
|
|
794
|
+
|
|
795
|
+
if rms_norm:
|
|
796
|
+
_patch_rms_norm_module(base_model.norm)
|
|
797
|
+
|
|
798
|
+
for decoder_layer in base_model.layers:
|
|
799
|
+
if swiglu:
|
|
800
|
+
for expert in decoder_layer.block_sparse_moe.experts:
|
|
801
|
+
_patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
|
|
802
|
+
if rms_norm:
|
|
803
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
804
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
805
|
+
|
|
103
806
|
|
|
104
807
|
def apply_liger_kernel_to_gemma(
|
|
105
808
|
rope: bool = True,
|
|
106
|
-
cross_entropy: bool =
|
|
809
|
+
cross_entropy: bool = False,
|
|
810
|
+
fused_linear_cross_entropy: bool = True,
|
|
107
811
|
rms_norm: bool = True,
|
|
108
812
|
geglu: bool = True,
|
|
813
|
+
model: PreTrainedModel = None,
|
|
109
814
|
) -> None:
|
|
110
815
|
"""
|
|
111
|
-
Apply Liger kernels to replace original implementation in HuggingFace
|
|
112
|
-
to make GPU go burrr.
|
|
816
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma
|
|
817
|
+
(Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.
|
|
113
818
|
|
|
114
819
|
Args:
|
|
115
820
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
116
|
-
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
|
|
821
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
822
|
+
fused_linear_cross_entropy (bool):
|
|
823
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
824
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
825
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
117
826
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
118
827
|
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
828
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
829
|
+
loaded. Default is None.
|
|
119
830
|
"""
|
|
120
|
-
|
|
831
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
832
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
833
|
+
)
|
|
834
|
+
|
|
121
835
|
from transformers.models.gemma import modeling_gemma
|
|
836
|
+
from transformers.models.gemma.modeling_gemma import GemmaModel
|
|
837
|
+
|
|
838
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
|
|
839
|
+
|
|
840
|
+
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
|
|
122
841
|
|
|
123
842
|
if rope:
|
|
124
843
|
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
125
844
|
if rms_norm:
|
|
126
|
-
modeling_gemma.GemmaRMSNorm =
|
|
845
|
+
modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
|
|
127
846
|
if cross_entropy:
|
|
128
|
-
|
|
847
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
848
|
+
from transformers.loss.loss_utils import nn
|
|
849
|
+
|
|
850
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
851
|
+
else:
|
|
852
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
853
|
+
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
129
854
|
if geglu:
|
|
130
855
|
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
|
856
|
+
if fused_linear_cross_entropy:
|
|
857
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
858
|
+
if model is not None:
|
|
859
|
+
model.forward = MethodType(gemma_lce_forward, model)
|
|
860
|
+
else:
|
|
861
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
|
862
|
+
else: # if version < 4.46.1
|
|
863
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
864
|
+
if model is not None:
|
|
865
|
+
model.forward = MethodType(gemma_lce_forward_deprecated, model)
|
|
866
|
+
else:
|
|
867
|
+
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
|
|
868
|
+
|
|
869
|
+
if model is not None:
|
|
870
|
+
# The model instance already exists, so we need to additionally patch the
|
|
871
|
+
# instance variables that reference already-instantiated modules
|
|
872
|
+
|
|
873
|
+
# get the base model from the model instance
|
|
874
|
+
base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
|
|
875
|
+
|
|
876
|
+
if rms_norm:
|
|
877
|
+
_patch_rms_norm_module_for_gemma(base_model.norm)
|
|
878
|
+
|
|
879
|
+
for decoder_layer in base_model.layers:
|
|
880
|
+
if geglu:
|
|
881
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
882
|
+
if rms_norm:
|
|
883
|
+
_patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
|
|
884
|
+
_patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def apply_liger_kernel_to_gemma2(
|
|
888
|
+
rope: bool = True,
|
|
889
|
+
cross_entropy: bool = False,
|
|
890
|
+
fused_linear_cross_entropy: bool = True,
|
|
891
|
+
rms_norm: bool = True,
|
|
892
|
+
geglu: bool = True,
|
|
893
|
+
model: PreTrainedModel = None,
|
|
894
|
+
) -> None:
|
|
895
|
+
"""
|
|
896
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma2
|
|
897
|
+
(for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
|
|
898
|
+
|
|
899
|
+
Args:
|
|
900
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
901
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
902
|
+
fused_linear_cross_entropy (bool):
|
|
903
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
904
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
905
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
906
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
907
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
908
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
909
|
+
loaded. Default is None.
|
|
910
|
+
"""
|
|
911
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
912
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
913
|
+
)
|
|
914
|
+
|
|
915
|
+
from transformers.models.gemma2 import modeling_gemma2
|
|
916
|
+
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
|
|
917
|
+
|
|
918
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
|
|
919
|
+
|
|
920
|
+
_patch_rms_norm_module_for_gemma2 = partial(
|
|
921
|
+
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
|
922
|
+
)
|
|
923
|
+
|
|
924
|
+
if rope:
|
|
925
|
+
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
926
|
+
if rms_norm:
|
|
927
|
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
|
|
928
|
+
modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
|
|
929
|
+
if cross_entropy:
|
|
930
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
931
|
+
from transformers.loss.loss_utils import nn
|
|
932
|
+
|
|
933
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
934
|
+
else:
|
|
935
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
936
|
+
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
937
|
+
if fused_linear_cross_entropy:
|
|
938
|
+
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
|
939
|
+
if model is not None:
|
|
940
|
+
model.forward = MethodType(gemma2_lce_forward, model)
|
|
941
|
+
else:
|
|
942
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
|
|
943
|
+
else:
|
|
944
|
+
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
|
945
|
+
if model is not None:
|
|
946
|
+
model.forward = MethodType(gemma2_lce_forward_deprected, model)
|
|
947
|
+
else:
|
|
948
|
+
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
|
|
949
|
+
if geglu:
|
|
950
|
+
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
|
951
|
+
|
|
952
|
+
if model is not None:
|
|
953
|
+
# The model instance already exists, so we need to additionally patch the
|
|
954
|
+
# instance variables that reference already-instantiated modules
|
|
955
|
+
|
|
956
|
+
# get the base model from the model instance
|
|
957
|
+
base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
|
|
958
|
+
|
|
959
|
+
if rms_norm:
|
|
960
|
+
_patch_rms_norm_module_for_gemma2(base_model.norm)
|
|
961
|
+
|
|
962
|
+
for decoder_layer in base_model.layers:
|
|
963
|
+
if geglu:
|
|
964
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
965
|
+
if rms_norm:
|
|
966
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
|
|
967
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
|
|
968
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
|
|
969
|
+
_patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def apply_liger_kernel_to_gemma3_text(
|
|
973
|
+
rope: bool = True,
|
|
974
|
+
cross_entropy: bool = False,
|
|
975
|
+
fused_linear_cross_entropy: bool = True,
|
|
976
|
+
rms_norm: bool = True,
|
|
977
|
+
geglu: bool = True,
|
|
978
|
+
model: PreTrainedModel = None,
|
|
979
|
+
) -> None:
|
|
980
|
+
"""
|
|
981
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma3
|
|
982
|
+
|
|
983
|
+
Args:
|
|
984
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
985
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
986
|
+
fused_linear_cross_entropy (bool):
|
|
987
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
988
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
989
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
990
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
991
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
992
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
993
|
+
loaded. Default is None.
|
|
994
|
+
"""
|
|
995
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
996
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
997
|
+
)
|
|
998
|
+
|
|
999
|
+
from transformers.models.gemma3 import modeling_gemma3
|
|
1000
|
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
|
|
1001
|
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
|
|
1002
|
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
|
|
1003
|
+
|
|
1004
|
+
from liger_kernel.transformers.model.gemma3 import causal_forward
|
|
1005
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
|
|
1006
|
+
|
|
1007
|
+
_patch_rms_norm_module_for_gemma3 = partial(
|
|
1008
|
+
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
|
1009
|
+
)
|
|
1010
|
+
|
|
1011
|
+
if rope:
|
|
1012
|
+
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
1013
|
+
|
|
1014
|
+
if rms_norm:
|
|
1015
|
+
modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
|
|
1016
|
+
|
|
1017
|
+
if geglu:
|
|
1018
|
+
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
|
1019
|
+
|
|
1020
|
+
# Handle loss function
|
|
1021
|
+
if cross_entropy:
|
|
1022
|
+
from transformers.loss.loss_utils import nn
|
|
1023
|
+
|
|
1024
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
1025
|
+
|
|
1026
|
+
if fused_linear_cross_entropy:
|
|
1027
|
+
if model is not None:
|
|
1028
|
+
model.forward = MethodType(causal_forward, model)
|
|
1029
|
+
else:
|
|
1030
|
+
modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
|
|
1031
|
+
|
|
1032
|
+
if model is not None:
|
|
1033
|
+
# The model instance already exists, so we need to additionally patch the
|
|
1034
|
+
# instance variables that reference already-instantiated modules
|
|
1035
|
+
|
|
1036
|
+
if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
|
|
1037
|
+
# get the base model from the model instance
|
|
1038
|
+
base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
|
|
1039
|
+
|
|
1040
|
+
if rms_norm:
|
|
1041
|
+
_patch_rms_norm_module_for_gemma3(base_model.norm)
|
|
1042
|
+
|
|
1043
|
+
for decoder_layer in base_model.layers:
|
|
1044
|
+
decoder_layer: Gemma3DecoderLayer
|
|
1045
|
+
if geglu:
|
|
1046
|
+
_bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
|
|
1047
|
+
if rms_norm:
|
|
1048
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
|
|
1049
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
|
|
1050
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
|
|
1051
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
|
|
1052
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
|
|
1053
|
+
_patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
|
|
1054
|
+
|
|
1055
|
+
else:
|
|
1056
|
+
raise TypeError("The model must be Gemma3ForCausalLM.")
|
|
1057
|
+
|
|
1058
|
+
|
|
1059
|
+
def apply_liger_kernel_to_gemma3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma3 (multimodal).

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the SigLIP vision tower. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None. When given, it must be a `Gemma3ForConditionalGeneration`.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is given but is not a `Gemma3ForConditionalGeneration`, or its vision tower
            is not a `SiglipVisionModel`. Class-level patches have already been applied at that point.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.gemma3 import modeling_gemma3
    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
    from transformers.models.siglip import modeling_siglip
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
    from transformers.models.siglip.modeling_siglip import SiglipVisionModel

    from liger_kernel.transformers.model.gemma3 import multimodal_forward

    # _patch_rms_norm_module preconfigured with the Gemma-style settings used for Gemma3 norms.
    _patch_rms_norm_module_for_gemma3 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    # Class-level patch so vision towers constructed after this call use LigerLayerNorm.
    # NOTE(review): `modeling_siglip.nn` is presumably `torch.nn`, which would make this rebinding
    # process-wide rather than SigLIP-only -- confirm.
    if layer_norm and model is None:
        modeling_siglip.nn.LayerNorm = LigerLayerNorm

    # Patch the text-model classes; loss handling for the multimodal wrapper is done below,
    # so both loss options are disabled for this call.
    apply_liger_kernel_to_gemma3_text(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )

    if cross_entropy:
        modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(multimodal_forward, model)
        else:
            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    if not isinstance(model, Gemma3ForConditionalGeneration):
        raise TypeError("The model must be Gemma3ForConditionalGeneration.")

    vision_tower = model.vision_tower
    if not isinstance(vision_tower, SiglipVisionModel):
        raise TypeError("The vision tower must be SiglipVisionModel")

    # Honor `layer_norm` for the final vision LayerNorm as well as the per-layer norms.
    if layer_norm:
        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
        for layer in vision_tower.vision_model.encoder.layers:
            layer: SiglipEncoderLayer
            _patch_layer_norm_module(layer.layer_norm1)
            _patch_layer_norm_module(layer.layer_norm2)

    if rms_norm:
        _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)

    apply_liger_kernel_to_gemma3_text(
        rope=rope,
        cross_entropy=False,
        fused_linear_cross_entropy=False,
        rms_norm=rms_norm,
        geglu=geglu,
        model=model.language_model,
    )
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
def apply_liger_kernel_to_paligemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the SigLIP vision tower. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None. When given, it must be a `PaliGemmaForConditionalGeneration`.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is given but is not a `PaliGemmaForConditionalGeneration`, or its
            `language_model` is neither a Gemma nor a Gemma2 model.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
    from transformers.models.gemma.modeling_gemma import GemmaModel
    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
    from transformers.models.paligemma import modeling_paligemma
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
    from transformers.models.siglip import modeling_siglip
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
    from transformers.models.siglip.modeling_siglip import SiglipVisionModel

    from liger_kernel.transformers.model.paligemma import lce_forward
    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

    # The vision_tower is a SiglipVisionModel.
    # NOTE(review): `modeling_siglip.nn` is presumably `torch.nn`, so this rebinding is likely
    # process-wide -- confirm.
    if layer_norm and model is None:
        modeling_siglip.nn.LayerNorm = LigerLayerNorm

    # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
    # The multi_modal_projector is Linear, nothing to do

    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM. Which one is only known from
    # an instance, so both families are patched at class level.
    apply_liger_kernel_to_gemma(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )
    apply_liger_kernel_to_gemma2(
        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
    )

    # Handle loss function
    if cross_entropy:
        modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            lce_forward_impl = lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            lce_forward_impl = lce_forward_deprecated
        if model is not None:
            model.forward = MethodType(lce_forward_impl, model)
        else:
            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_impl

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    if not isinstance(model, PaliGemmaForConditionalGeneration):
        raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

    vision_tower: SiglipVisionModel = model.vision_tower

    # Honor `layer_norm` for the final vision LayerNorm as well as the per-layer norms.
    if layer_norm:
        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
        for layer in vision_tower.vision_model.encoder.layers:
            layer: SiglipEncoderLayer
            _patch_layer_norm_module(layer.layer_norm1)
            _patch_layer_norm_module(layer.layer_norm2)

    language_model = model.language_model

    if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
        apply_liger_kernel_to_gemma(
            rope=rope,
            cross_entropy=False,
            fused_linear_cross_entropy=False,
            rms_norm=rms_norm,
            geglu=geglu,
            model=language_model,
        )
    elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
        apply_liger_kernel_to_gemma2(
            rope=rope,
            cross_entropy=False,
            fused_linear_cross_entropy=False,
            rms_norm=rms_norm,
            geglu=geglu,
            model=language_model,
        )
    else:
        raise TypeError(
            "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
        )
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Swap HuggingFace Qwen2 components for their Liger kernel counterparts.

    Class-level attributes of `transformers.models.qwen2.modeling_qwen2` are rebound. If `model`
    is supplied, its already-built submodules are patched in place as well.

    Args:
        rope (bool): Use Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Use Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Use Liger's fused linear cross entropy loss, which never materializes the logits and is
            more memory efficient. Default is True. Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Use Liger's RMSNorm. Default is True.
        swiglu (bool): Use Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    is_supported_version = transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION)

    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

    if cross_entropy:
        if not is_supported_version:
            # Older transformers: the modeling module instantiates CrossEntropyLoss itself.
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
        else:
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if is_supported_version:
            chosen_forward = qwen2_lce_forward
        else:  # transformers older than SUPPORTED_TRANSFORMER_VERSION
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            chosen_forward = qwen2_lce_forward_deprecated
        if model is None:
            modeling_qwen2.Qwen2ForCausalLM.forward = chosen_forward
        else:
            model.forward = MethodType(chosen_forward, model)

    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # Submodules of an existing instance were built before the class swaps above,
    # so each one has to be patched individually.
    inner: Qwen2Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(inner.norm)

    for block in inner.layers:
        if swiglu:
            _patch_swiglu_module(block.mlp, LigerSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(block.input_layernorm)
            _patch_rms_norm_module(block.post_attention_layernorm)
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
def apply_liger_kernel_to_qwen3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.

    Args:
        rope (bool): Use Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Route `transformers.loss.loss_utils`'s cross entropy through Liger's
            implementation. Default is False.
        fused_linear_cross_entropy (bool): Replace the causal-LM forward with Liger's fused linear
            cross entropy forward. Default is True. Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Use Liger's RMSNorm. Default is True.
        swiglu (bool): Use Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model whose submodules should also be patched.
            Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3 import modeling_qwen3
    from transformers.models.qwen3.modeling_qwen3 import Qwen3Model

    from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward

    if rope:
        modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is None:
            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
        else:
            model.forward = MethodType(qwen3_lce_forward, model)

    if swiglu:
        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP

    if model is None:
        return

    # Modules inside an existing instance predate the class swaps above; patch them directly.
    inner: Qwen3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(inner.norm)
    for block in inner.layers:
        if swiglu:
            _patch_swiglu_module(block.mlp, LigerSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(block.input_layernorm)
            _patch_rms_norm_module(block.post_attention_layernorm)
|
|
1401
|
+
|
|
1402
|
+
|
|
1403
|
+
def apply_liger_kernel_to_qwen3_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to route `transformers.loss.loss_utils`'s cross entropy through
            Liger's implementation. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP (`LigerQwen3MoeSwiGLUMLP`) to the
            expert MLPs and to any dense MLP layers. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_moe import modeling_qwen3_moe
    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel

    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_moe_lce_forward
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

    if rope:
        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen3_moe_lce_forward, model)
        else:
            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_moe_lce_forward

    if swiglu:
        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)
    for decoder_layer in base_model.layers:
        if swiglu:
            mlp = decoder_layer.mlp
            experts = getattr(mlp, "experts", None)
            if experts is not None:
                # Sparse layer: patch every expert MLP.
                for mlp_expert in experts:
                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
            else:
                # Dense layer (presumably one listed in `config.mlp_only_layers` or skipped by
                # `decoder_sparse_step`) holds a plain MLP without `experts`. Patch it directly,
                # mirroring the class-level `Qwen3MoeMLP` swap above.
                _patch_swiglu_module(mlp, LigerQwen3MoeSwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(decoder_layer.input_layernorm)
            _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1460
|
+
|
|
1461
|
+
|
|
1462
|
+
def apply_liger_kernel_to_gpt_oss(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
    NOTE: GPT-OSS is supported in transformers >= 4.55.0; on older versions this function logs a
    warning and returns without patching anything.
    NOTE: SwiGLU patching is not implemented for GPT-OSS, which uses a custom expert
    implementation with clamping and MXFP4 quantization.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Accepted for API symmetry but not applied. Passing True only logs a warning.
            Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    if transformer_version < version.parse("4.55.0"):
        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    # SwiGLU patching is not implemented for GPT-OSS due to its custom expert implementation
    # with clamping (swiglu_limit=7.0) and MXFP4 quantization. Tell the caller rather than
    # silently ignoring an explicit request.
    if swiglu:
        logger.warning("SwiGLU patching is not supported for GPT-OSS; ignoring swiglu=True.")

    from transformers.models.gpt_oss import modeling_gpt_oss
    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel

    if rope:
        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(gpt_oss_lce_forward, model)
        else:
            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    base_model: GptOssModel = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)
    for decoder_layer in base_model.layers:
        if rms_norm:
            _patch_rms_norm_module(decoder_layer.input_layernorm)
            _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1533
|
+
|
|
1534
|
+
|
|
1535
|
+
def apply_liger_kernel_to_qwen2_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not supported in transformers<4.52.4; on older versions this function logs a
    warning and returns without patching anything.

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm (vision blocks). Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is given but is not a `Qwen2VLForConditionalGeneration`, `Qwen2VLModel`
            or `Qwen2VLTextModel`.
    """
    if transformer_version < version.parse("4.52.4"):
        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2_vl import modeling_qwen2_vl
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

    if rope:
        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm and model is None:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        # NOTE(review): with transformers >= 4.52.4 the Qwen2-VL forward appears to compute its loss
        # through `transformers.loss.loss_utils` rather than a module-level `CrossEntropyLoss`, so
        # patch that path as the Qwen2/Qwen3 patches do. The legacy assignment is kept in case a
        # build still references `CrossEntropyLoss` directly.
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen2_vl_lce_forward, model)
        else:
            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    text_model: Qwen2VLTextModel
    vision_model: Qwen2VisionTransformerPretrainedModel | None
    if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
        # Note: language_model and visual properties can be accessed through the conditional class for BC.
        # Not sure if it is subject to changes in the future.
        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
        text_model = model.language_model
        vision_model = model.visual
    elif isinstance(model, Qwen2VLTextModel):
        text_model = model
        vision_model = None
    else:
        # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
        raise TypeError(
            f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
        )

    # Patch Qwen2VisionTransformerPretrainedModel
    if vision_model is not None:
        for vision_block in vision_model.blocks:
            if layer_norm:
                _patch_layer_norm_module(vision_block.norm1)
                _patch_layer_norm_module(vision_block.norm2)

    # Patch Qwen2VLTextModel
    if text_model is not None:
        if rms_norm:
            _patch_rms_norm_module(text_model.norm)
        for decoder_layer in text_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1629
|
+
|
|
1630
|
+
|
|
1631
|
+
def apply_liger_kernel_to_qwen2_5_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
    NOTE: Liger's Qwen2.5-VL patching requires transformers >= 4.52.4; on older versions this
    function logs a warning and returns without patching anything.

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm (text layers and vision blocks). Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is given but is not a `Qwen2_5_VLForConditionalGeneration`,
            `Qwen2_5_VLModel` or `Qwen2_5_VLTextModel`.
    """
    if transformer_version < version.parse("4.52.4"):
        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
        return

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward

    if rope:
        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
    if rms_norm:
        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
    if cross_entropy:
        # NOTE(review): with transformers >= 4.52.4 the Qwen2.5-VL forward appears to compute its
        # loss through `transformers.loss.loss_utils` rather than a module-level `CrossEntropyLoss`,
        # so patch that path as the Qwen2/Qwen3 patches do. The legacy assignment is kept in case a
        # build still references `CrossEntropyLoss` directly.
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
        else:
            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
    if swiglu:
        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is None:
        return

    # The model instance already exists, so we need to additionally patch the
    # instance variables that reference already-instantiated modules.
    text_model: Qwen2_5_VLTextModel
    vision_model: Qwen2_5_VisionTransformerPretrainedModel | None
    if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
        # Note: language_model and visual properties can be accessed through the conditional class for BC.
        # Not sure if it is subject to changes in the future.
        # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
        text_model = model.language_model
        vision_model = model.visual
    elif isinstance(model, Qwen2_5_VLTextModel):
        text_model = model
        vision_model = None
    else:
        # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
        raise TypeError(
            f"Unsupported Qwen2.5-VL model type. `model` must be `Qwen2_5_VLForConditionalGeneration`, `Qwen2_5_VLModel` or `Qwen2_5_VLTextModel`. Got: {type(model)}"
        )

    # Patch Qwen2_5_VisionTransformerPretrainedModel (its block norms are RMSNorms).
    if vision_model is not None:
        for vision_block in vision_model.blocks:
            if rms_norm:
                _patch_rms_norm_module(vision_block.norm1)
                _patch_rms_norm_module(vision_block.norm2)

    # Patch Qwen2_5_VLTextModel
    if text_model is not None:
        if rms_norm:
            _patch_rms_norm_module(text_model.norm)
        for decoder_layer in text_model.layers:
            if swiglu:
                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1719
|
+
|
|
1720
|
+
|
|
1721
|
+
def apply_liger_kernel_to_qwen3_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Qwen3-VL with Liger kernels.

    Module-level names in `modeling_qwen3_vl` are replaced so that models built after this
    call use the Liger versions. When `model` is given, its already-built text RMSNorm
    modules are patched as well.

    Args:
        rope (bool): Replace both `apply_rotary_pos_emb` and `apply_rotary_pos_emb_vision`
            with Liger's implementations. Default is True.
        cross_entropy (bool): Replace the loss's cross entropy with Liger's. Default is False.
        fused_linear_cross_entropy (bool):
            Replace the conditional-generation forward with Liger's fused linear cross entropy
            forward. Default is True. Mutually exclusive with `cross_entropy`.
            Logits are not materialized when this is enabled, which saves memory.
        rms_norm (bool): Replace `Qwen3VLTextRMSNorm` with Liger's RMSNorm. Default is True.
        swiglu (bool): Accepted for signature parity with the other patchers; this function
            does not read it. Default is False.
        model (PreTrainedModel): Already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        TypeError: If `model` is given together with `rms_norm=True` and is not a
            `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_vl import modeling_qwen3_vl
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel

    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

    # Module-level replacements: affect modules/functions looked up after this point.
    if rope:
        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
    if rms_norm:
        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
        else:
            model.forward = MethodType(qwen3_vl_lce_forward, model)

    # Only RMSNorm needs instance-level patching for an existing model.
    if model is None or not rms_norm:
        return

    if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
        text_model: Qwen3VLTextModel = model.language_model
    elif isinstance(model, Qwen3VLTextModel):
        text_model = model
    else:
        raise TypeError(
            f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
        )

    patch_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

    if text_model is None:
        return

    patch_norm(text_model.norm)
    for layer in text_model.layers:
        patch_norm(layer.input_layernorm)
        patch_norm(layer.post_attention_layernorm)

        attn = getattr(layer, "self_attn", None)
        if attn is None:
            continue
        # The per-head query/key norms are optional; patch whichever are present.
        for norm_name in ("q_norm", "k_norm"):
            head_norm = getattr(attn, norm_name, None)
            if head_norm is not None:
                patch_norm(head_norm)
|
|
1796
|
+
|
|
1797
|
+
|
|
1798
|
+
def apply_liger_kernel_to_qwen3_vl_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding to both the text
            (`apply_rotary_pos_emb`) and vision (`apply_rotary_pos_emb_vision`) paths. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm to the text model. Default is True.
        swiglu (bool): Currently not used by this function; accepted for signature parity with
            the other `apply_liger_kernel_to_*` helpers. Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        TypeError: If `model` is given together with `rms_norm=True` and is not a
            `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel

    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

    if rope:
        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision

    if rms_norm:
        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
        else:
            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

    # The model instance already exists, so its already-built RMSNorm modules must be patched too.
    if model is not None and rms_norm:
        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
            text_model: Qwen3VLMoeTextModel = model.language_model
        elif isinstance(model, Qwen3VLMoeTextModel):
            text_model = model
        else:
            raise TypeError(
                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
            )

        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

        if text_model is not None:
            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
            for decoder_layer in text_model.layers:
                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
                self_attn = getattr(decoder_layer, "self_attn", None)
                if self_attn is not None:
                    # q_norm / k_norm are optional; patch only the ones that exist.
                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
|
|
1871
|
+
|
|
1872
|
+
|
|
1873
|
+
def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Phi3 with Liger kernels.

    Module-level names in `modeling_phi3` are replaced so that models built afterwards use
    the Liger implementations. When `model` is given, its already-built submodules are
    patched in place as well.

    Args:
        rope (bool): Swap in Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Swap in Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Swap in Liger's fused linear cross entropy forward. Default is True.
            Mutually exclusive with `cross_entropy`.
            Logits are not materialized when this is enabled, which saves memory.
        rms_norm (bool): Swap in Liger's RMSNorm. Default is True.
        swiglu (bool): Swap `Phi3MLP` for Liger's Phi3 SwiGLU MLP. Default is True.
        model (PreTrainedModel): Already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.phi3 import modeling_phi3
    from transformers.models.phi3.modeling_phi3 import Phi3Model

    # Module-level replacements.
    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
        else:
            model.forward = MethodType(phi3_lce_forward, model)

    if model is None:
        return

    # Existing instance: rebind forwards on the submodules it already holds.
    base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerPhi3SwiGLUMLP)
        if rms_norm:
            for norm_name in ("input_layernorm", "post_attention_layernorm"):
                _patch_rms_norm_module(getattr(layer, norm_name))
|
|
1935
|
+
|
|
1936
|
+
|
|
1937
|
+
def apply_liger_kernel_to_olmo2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace OLMo2 with Liger kernels.

    Module-level names in `modeling_olmo2` are replaced so that models built afterwards use
    the Liger implementations. When `model` is given, its already-built submodules are
    patched in place as well.

    Args:
        rope (bool): Swap in Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Swap in Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Swap in Liger's fused linear cross entropy forward. Default is True.
            Mutually exclusive with `cross_entropy`.
            Logits are not materialized when this is enabled, which saves memory.
        rms_norm (bool): Swap `Olmo2RMSNorm` for `LigerRMSNormForOlmo2`. Default is True.
        swiglu (bool): Swap `Olmo2MLP` for Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): Already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.olmo2 import modeling_olmo2
    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model

    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

    # Module-level replacements.
    if rope:
        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
    if swiglu:
        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
        else:
            model.forward = MethodType(olmo2_lce_forward, model)

    if model is None:
        return

    # Existing instance: rebind forwards on the submodules it already holds.
    base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            # The per-layer norms are patched out-of-place, unlike the final norm above.
            for norm_name in ("post_attention_layernorm", "post_feedforward_layernorm"):
                _patch_rms_norm_module(getattr(layer, norm_name), in_place=False)
|
|
2002
|
+
|
|
2003
|
+
|
|
2004
|
+
def apply_liger_kernel_to_olmo3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Olmo3 with Liger kernels.

    This reuses the OLMo2 Liger components: `LigerRMSNormForOlmo2` stands in for
    `Olmo3RMSNorm`. When `model` is given, its already-built submodules are patched in
    place as well.

    Args:
        rope (bool): Swap in Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Swap in Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Swap in Liger's fused linear cross entropy forward. Default is True.
            Mutually exclusive with `cross_entropy`.
            Logits are not materialized when this is enabled, which saves memory.
        rms_norm (bool): Swap `Olmo3RMSNorm` for `LigerRMSNormForOlmo2`. Default is True.
        swiglu (bool): Swap `Olmo3MLP` for Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): Already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.olmo3 import modeling_olmo3
    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model

    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

    # Module-level replacements (Olmo3 shares OLMo2's Liger components).
    if rope:
        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2
    if swiglu:
        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
        else:
            model.forward = MethodType(olmo3_lce_forward, model)

    if model is None:
        return

    # Existing instance: rebind forwards on the submodules it already holds.
    base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            # The per-layer norms are patched out-of-place, unlike the final norm above.
            for norm_name in ("post_attention_layernorm", "post_feedforward_layernorm"):
                _patch_rms_norm_module(getattr(layer, norm_name), in_place=False)
|
|
2070
|
+
|
|
2071
|
+
|
|
2072
|
+
def apply_liger_kernel_to_glm4(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace GLM-4 with Liger kernels.

    Module-level names in `modeling_glm4` are replaced so that models built afterwards use
    the Liger implementations. When `model` is given, its already-built submodules are
    patched in place as well.

    Args:
        rope (bool): Not supported for GLM-4; passing True raises NotImplementedError.
            Default is False.
        cross_entropy (bool): Swap in Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Swap in Liger's fused linear cross entropy forward. Default is True.
            Mutually exclusive with `cross_entropy`.
            Logits are not materialized when this is enabled, which saves memory.
        rms_norm (bool): Swap `Glm4RMSNorm` for `LigerRMSNormForGlm4`. Default is True.
        swiglu (bool): Swap `Glm4MLP` for Liger's Phi3 SwiGLU MLP. Default is True.
        model (PreTrainedModel): Already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        NotImplementedError: If `rope` is True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4 import modeling_glm4
    from transformers.models.glm4.modeling_glm4 import Glm4Model

    from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    # The rope check runs before any patching, so a rejected call leaves modules untouched.
    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    if rms_norm:
        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
    if swiglu:
        modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
        else:
            model.forward = MethodType(glm4_lce_forward, model)

    if model is None:
        return

    # Existing instance: rebind forwards on the submodules it already holds.
    base_model: Glm4Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm, in_place=False)

    layer_norm_names = (
        "input_layernorm",
        "post_attention_layernorm",
        "post_self_attn_layernorm",
        "post_mlp_layernorm",
    )
    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerPhi3SwiGLUMLP)
        if rms_norm:
            # Every GLM-4 norm is patched out-of-place.
            for norm_name in layer_norm_names:
                _patch_rms_norm_module(getattr(layer, norm_name), in_place=False)
|
|
2139
|
+
|
|
2140
|
+
|
|
2141
|
+
def apply_liger_kernel_to_glm4v(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not supported for GLM-4v: passing True raises NotImplementedError.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLPs. Default is True.
            Text decoder MLPs get `LigerPhi3SwiGLUMLP` and vision-block MLPs get
            `LigerSwiGLUMLP`. There is no module-level MLP replacement, so this only takes
            effect on an already-loaded `model`.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        NotImplementedError: If `rope` is True.
        TypeError: If `model` is not a `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4v import modeling_glm4v
    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel

    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    if rms_norm:
        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(glm4v_lce_forward, model)
        else:
            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules.
        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
            text_model: Glm4vTextModel = model.language_model
            vision_model: Glm4vVisionModel = model.visual
        elif isinstance(model, Glm4vTextModel):
            text_model = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)
                if swiglu:
                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)

        if text_model is not None:
            # Text-model norms are patched with in_place=False, matching apply_liger_kernel_to_glm4
            # for the same set of norms (input/post_attention/post_self_attn/post_mlp).
            # NOTE(review): in-place backward is presumably unsafe here because the
            # post_self_attn / post_mlp norm outputs feed residual additions — confirm.
            if rms_norm:
                _patch_rms_norm_module(text_model.norm, in_place=False)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
|
|
2228
|
+
|
|
2229
|
+
|
|
2230
|
+
def apply_liger_kernel_to_glm4v_moe(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not supported for GLM4v_moe: passing True raises NotImplementedError.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
            The following are patched: dense text MLPs, each routed expert and the shared
            experts of MoE text layers, and vision-block MLPs. There is no module-level MLP
            replacement, so this only takes effect on an already-loaded `model`.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If `cross_entropy` and `fused_linear_cross_entropy` are both True.
        NotImplementedError: If `rope` is True.
        TypeError: If `model` is not a `Glm4vMoeForConditionalGeneration`, `Glm4vMoeModel` or `Glm4vMoeTextModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.glm4v_moe import modeling_glm4v_moe
    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel

    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

    if rope:
        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
    if rms_norm:
        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(glm4v_moe_lce_forward, model)
        else:
            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules.

        # Bound before branching so that both the multimodal and the text-only paths can
        # recognise MoE decoder layers (previously unbound for a bare text model).
        Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE

        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
            # Note: language_model and visual properties can be accessed through the conditional class for BC.
            # Not sure if it is subject to changes in the future.
            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
            text_model: Glm4vMoeTextModel = model.language_model
            vision_model: Glm4vMoeVisionModel = model.visual
        elif isinstance(model, Glm4vMoeTextModel):
            text_model = model
            vision_model = None
        else:
            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
            raise TypeError(
                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
            )

        if vision_model is not None:
            if rms_norm:
                _patch_rms_norm_module(vision_model.post_conv_layernorm)
                _patch_rms_norm_module(vision_model.post_layernorm)
            for vision_block in vision_model.blocks:
                if rms_norm:
                    _patch_rms_norm_module(vision_block.norm1)
                    _patch_rms_norm_module(vision_block.norm2)
                if swiglu:
                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)

        if text_model is not None:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    mlp = decoder_layer.mlp
                    if isinstance(mlp, Glm4vMoeTextMoE):
                        # A MoE block is not itself a SwiGLU MLP: patch its experts instead.
                        # _patch_swiglu_module rebinds in place; its result is never assigned
                        # back, so the registered `mlp` submodule is left intact.
                        experts = getattr(mlp, "experts", None)
                        if experts is not None:
                            for expert in experts:
                                _patch_swiglu_module(expert, LigerSwiGLUMLP)
                        shared_experts = getattr(mlp, "shared_experts", None)
                        if shared_experts is not None:
                            _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
                    else:
                        _patch_swiglu_module(mlp, LigerSwiGLUMLP)
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2330
|
+
|
|
2331
|
+
|
|
2332
|
+
def apply_liger_kernel_to_internvl(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    model: Optional[PreTrainedModel] = None,
    **kwargs,
) -> None:
    """
    Patch HuggingFace InternVL models with Liger kernels.

    Module/class-level patches are always applied. When an instantiated ``model`` is passed, its existing
    modules are patched too. The attached language model is patched by the ``apply_liger_kernel_to_*``
    function that ``MODEL_TYPE_TO_APPLY_LIGER_FN`` registers for ``model.config.text_config.model_type``.
    That only happens when ``model`` is given. Connecting a language model that Liger-Kernel does not
    support may cause unexpected side effects.
    NOTE: InternVL is not available in transformers<4.52.1

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss.
            Default is True. Cannot be combined with `cross_entropy`. When True, logits are not
            materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. This covers the vision q/k norms and is always
            forwarded to the language model's patch function. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm to the vision tower. Default is True.
        model (PreTrainedModel): An already-loaded InternVL instance to patch in place. Default is None.
        **kwargs: Extra options forwarded to the language model's patch function. Options it does not
            accept are dropped with a warning. `cross_entropy` and `fused_linear_cross_entropy` default
            to False for the language model unless given here.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is neither `InternVLForConditionalGeneration` nor `InternVLModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )
    import torch.nn as torch_nn

    from transformers.models.internvl import modeling_internvl
    from transformers.models.internvl.modeling_internvl import (
        InternVLForConditionalGeneration,
        InternVLModel,
        InternVLVisionLayer,
        InternVLVisionModel,
        InternVLVisionRMSNorm,
    )

    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    # ---- Module/class-level patches: they affect modules constructed after this call. ----
    if model is None and layer_norm:
        # NOTE(review): `modeling_internvl.nn` is presumably `torch.nn` itself. If so, this replaces
        # `torch.nn.LayerNorm` for everything built afterwards, not just InternVL. Confirm this is intended.
        modeling_internvl.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        logger.info("Apply liger cross entropy")

        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        # Patched on the class rather than bound to an instance, so existing instances pick it up as well.
        modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
    if rms_norm:
        modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

    if model is None:
        return

    # ---- Instance-level patches for modules that already exist. ----
    if not isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
        raise TypeError(
            f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
        )
    # Both supported classes expose `language_model` and `vision_tower`.
    text_model = model.language_model
    vision_model: InternVLVisionModel = model.vision_tower

    # Hand the language model to its own Liger patch function, if one is registered for its model type.
    text_model_name = model.config.text_config.model_type
    text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
    if not text_liger_fn:
        if text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
    else:
        forwarded = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs, "rms_norm": rms_norm}
        accept_params = inspect.signature(text_liger_fn).parameters
        remain_params = set(forwarded) - set(accept_params)
        text_kwargs = {name: value for name, value in forwarded.items() if name in accept_params}

        if remain_params:
            logger.warning(
                f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
            )
        text_kwargs["model"] = text_model
        text_liger_fn(**text_kwargs)

    # Vision tower RMSNorm: only q/k norms that really are InternVLVisionRMSNorm instances get patched.
    if rms_norm:
        for encoder_layer in vision_model.encoder.layer:
            encoder_layer: InternVLVisionLayer
            attention = encoder_layer.attention
            for qk_norm in (attention.q_norm, attention.k_norm):
                if isinstance(qk_norm, InternVLVisionRMSNorm):
                    _patch_rms_norm_module(qk_norm)

    # Vision tower LayerNorm: the final layernorm first, then each encoder layer's before/after norms.
    if layer_norm:
        if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
            _patch_layer_norm_module(vision_model.layernorm)

        for encoder_layer in vision_model.encoder.layer:
            for layer_norm_module in (encoder_layer.layernorm_before, encoder_layer.layernorm_after):
                if isinstance(layer_norm_module, torch_nn.LayerNorm):
                    _patch_layer_norm_module(layer_norm_module)
|
|
2440
|
+
|
|
2441
|
+
|
|
2442
|
+
def apply_liger_kernel_to_smolvlm(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    model: Optional[PreTrainedModel] = None,
    **kwargs,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
    Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
    However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
    NOTE: SmolVLM is not available in transformers<4.50.0

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
            On an existing instance, it is only bound to `SmolVLMForConditionalGeneration`. `smolvlm_lce_forward`
            replaces that class's `forward`, so binding it to other classes (e.g. `SmolVLMModel`) would break them.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Must be a `SmolVLMForConditionalGeneration` or `SmolVLMModel`. Default is None.
        **kwargs: Extra options forwarded to the text model's Liger patch function. Options it does not accept
            are dropped with a warning.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        TypeError: If `model` is neither `SmolVLMForConditionalGeneration` nor `SmolVLMModel`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.smolvlm import modeling_smolvlm
    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer

    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward

    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
    if layer_norm and model is None:
        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        logger.info("Apply liger cross entropy")

        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
        elif isinstance(model, SmolVLMForConditionalGeneration):
            model.forward = MethodType(smolvlm_lce_forward, model)
        else:
            # smolvlm_lce_forward is the replacement for SmolVLMForConditionalGeneration.forward. Binding it to any
            # other instance (e.g. a bare SmolVLMModel) would break that instance's forward, so skip it here.
            # Unsupported types are rejected with a TypeError below.
            logger.info(
                f"fused_linear_cross_entropy only applies to SmolVLMForConditionalGeneration. Skipping for {type(model)}"
            )
    if rms_norm:
        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        if isinstance(model, SmolVLMForConditionalGeneration):
            text_model = model.model.text_model
            vision_model: SmolVLMVisionTransformer = model.model.vision_model
        elif isinstance(model, SmolVLMModel):
            text_model = model.text_model
            vision_model: SmolVLMVisionTransformer = model.vision_model
        else:
            raise TypeError(
                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
            )

        # Hand the text model to its own Liger patch function, if one is registered for its model type.
        text_model_name = model.config.text_config.model_type
        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)

        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
        if text_liger_fn:
            accept_params = inspect.signature(text_liger_fn).parameters
            remain_params = set(kwargs) - set(accept_params)
            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

            if remain_params:
                logger.warning(
                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                )
            text_kwargs["model"] = text_model
            text_liger_fn(**text_kwargs)
        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
            logger.warning(f"{text_model_name} is not supported by Liger kernel.")

        # Patch vision model LayerNorm layers
        if layer_norm:
            # Patch post_layernorm
            _patch_layer_norm_module(vision_model.post_layernorm)

            # Patch encoder layers
            for encoder_layer in vision_model.encoder.layers:
                encoder_layer: SmolVLMEncoderLayer
                _patch_layer_norm_module(encoder_layer.layer_norm1)
                _patch_layer_norm_module(encoder_layer.layer_norm2)
|
|
2540
|
+
|
|
2541
|
+
|
|
2542
|
+
def apply_liger_kernel_to_falcon_h1(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,
    model: Optional[PreTrainedModel] = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Accepted for API compatibility only. LigerSwiGLUMLP is not available for Falcon-H1, so
            True just logs a warning and changes nothing. Default is False.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.falcon_h1 import modeling_falcon_h1
    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model

    if rope:
        logger.info("Apply liger rotary pos emb.")
        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        logger.info("Apply liger RMSNorm")
        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
    if swiglu:
        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")

    if cross_entropy:
        logger.info("Apply liger cross entropy")
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(falcon_h1_lce_forward, model)
        else:
            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (the RMSNorm layers).

        # get the base model from the model instance
        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)

        # The MLPs are deliberately left untouched even when `swiglu` is True, in line with the warning above.
        # FalconH1MLP applies its gate/down multipliers (config.mlp_multipliers), and LigerSwiGLUMLP.forward
        # would silently drop them.
        if rms_norm:
            _patch_rms_norm_module(base_model.final_layernorm)
            for decoder_layer in base_model.layers:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
|
|
2609
|
+
|
|
2610
|
+
|
|
2611
|
+
def apply_liger_kernel_to_qwen3_next(
    rope: bool = False,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: Optional[PreTrainedModel] = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
            Not supported for Qwen3-Next: passing True raises NotImplementedError.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Must be a `Qwen3NextForCausalLM` or `Qwen3NextModel`. Fused linear cross entropy
            additionally requires `Qwen3NextForCausalLM`. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
        NotImplementedError: If `rope` is True.
        TypeError: If `model` has an unsupported type, or if fused linear cross entropy is requested for a
            model that is not `Qwen3NextForCausalLM`.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_next import modeling_qwen3_next
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

    if rope:
        # Liger's rotary embedding might encounter NaN issues with Qwen3-Next, so it is not offered.
        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
    if rms_norm:
        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is not None:
            if isinstance(model, Qwen3NextForCausalLM):
                model.forward = MethodType(qwen3_next_lce_forward, model)
            else:
                raise TypeError(
                    f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
                )
        else:
            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
    if swiglu:
        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
            base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
        else:
            raise TypeError(
                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
            )

        def _patch_qwen3_next_rms_norm(module) -> None:
            # Qwen3NextRMSNorm uses Gemma-style zero-centred weights: it computes in float32 and scales by
            # (1 + weight). That is why the class-level patch above uses LigerRMSNormForQwen3Next rather than
            # LigerRMSNorm. The default `_patch_rms_norm_module` settings (offset 0.0, llama casting) would compute
            # `x_norm * weight` instead, so the instance patch uses the same convention as the class patch.
            # in_place=False is a conservative choice for the backward pass.
            _patch_rms_norm_module(module, offset=1.0, casting_mode="gemma", in_place=False)

        if rms_norm:
            _patch_qwen3_next_rms_norm(base_model.norm)

        for decoder_layer in base_model.layers:
            if rms_norm:
                _patch_qwen3_next_rms_norm(decoder_layer.input_layernorm)
                _patch_qwen3_next_rms_norm(decoder_layer.post_attention_layernorm)

            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
            if swiglu:
                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
                    experts = getattr(decoder_layer.mlp, "experts", None)
                    if experts is not None:
                        for expert in experts:
                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
|
|
2700
|
+
|
|
2701
|
+
|
|
2702
|
+
def apply_liger_kernel_to_hunyuan_v1_dense(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss.
            Default is True. Cannot be combined with `cross_entropy`.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): An already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model

    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    # ---- Module/class-level patches. ----
    modeling = modeling_hunyuan_v1_dense
    if rope:
        modeling.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling.HunYuanDenseV1RMSNorm = LigerRMSNorm
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            modeling.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
        else:
            model.forward = MethodType(hunyuan_v1_lce_forward, model)
    if swiglu:
        modeling.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP

    if model is None:
        return

    # ---- Instance-level patches for modules that were already built. ----
    base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
    if rms_norm:
        _patch_rms_norm_module(base_model.norm)
    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerHunyuanV1SwiGLUMLP)
        if rms_norm:
            _patch_rms_norm_module(layer.input_layernorm)
            _patch_rms_norm_module(layer.post_attention_layernorm)
|
|
2758
|
+
|
|
2759
|
+
|
|
2760
|
+
def apply_liger_kernel_to_hunyuan_v1_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: Optional[PreTrainedModel] = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss.
            Default is True. Cannot be combined with `cross_entropy`.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. On an existing instance, this patches the
            per-layer `mlp.experts` modules. Default is True.
        model (PreTrainedModel): An already-loaded model to patch in place. Default is None.

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model

    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    if rope:
        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
        else:
            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward

    if swiglu:
        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                # Guard like the other MoE patchers in this module: skip layers whose mlp has no `experts`.
                experts = getattr(decoder_layer.mlp, "experts", None)
                if experts is not None:
                    for mlp_expert in experts:
                        _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2817
|
+
|
|
2818
|
+
|
|
2819
|
+
# Registry from a HuggingFace `config.model_type` string (the keys used in
# transformers/models/auto/modeling_auto.py) to the function that patches that architecture.
# Several keys intentionally share a function: "*_text" entries are the text-config model types
# of multimodal models, and they reuse the multimodal patcher.
# Insertion order is kept stable on purpose; add new architectures at the end.
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    # Gemma family
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "gemma3_text": apply_liger_kernel_to_gemma3_text,
    "gemma3": apply_liger_kernel_to_gemma3,
    # GLM / GPT-OSS / InternVL
    "glm4": apply_liger_kernel_to_glm4,
    "glm4v": apply_liger_kernel_to_glm4v,
    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
    "gpt_oss": apply_liger_kernel_to_gpt_oss,
    "internvl": apply_liger_kernel_to_internvl,
    # Llama family and derivatives
    "llama": apply_liger_kernel_to_llama,
    "llama4_text": apply_liger_kernel_to_llama4,
    "llama4": apply_liger_kernel_to_llama4,
    "llava": apply_liger_kernel_to_llava,
    "granite": apply_liger_kernel_to_granite,
    "mllama": apply_liger_kernel_to_mllama,
    "mllama_text_model": apply_liger_kernel_to_mllama,
    # Mistral / OLMo
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    "olmo2": apply_liger_kernel_to_olmo2,
    "olmo3": apply_liger_kernel_to_olmo3,
    # Qwen family
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen3": apply_liger_kernel_to_qwen3,
    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
    "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
    "qwen3_next": apply_liger_kernel_to_qwen3_next,
    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
    # Other architectures
    "smollm3": apply_liger_kernel_to_smollm3,
    "phi3": apply_liger_kernel_to_phi3,
    "paligemma": apply_liger_kernel_to_paligemma,
    "falcon_h1": apply_liger_kernel_to_falcon_h1,
    "smolvlm": apply_liger_kernel_to_smolvlm,
    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
}
|
|
2861
|
+
|
|
2862
|
+
|
|
2863
|
+
def _apply_liger_kernel(model_type: str, **kwargs) -> None:
    """
    Applies Liger kernels based on the specified model type. The custom
    kernels for the specified model type will be applied with the provided
    keyword arguments, otherwise the default configuration will be used.

    ** Note: Calling _apply_liger_kernel() after model initialization
    will not be able to fully patch models. This must be called before model initialization.
    If the model has already been instantiated, use `_apply_liger_kernel_to_instance` instead,
    which also patches the modules that already exist on the instance.

    Args:
        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
          and specified in the model's config.json
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
          Keyword arguments that function does not accept are ignored (and reported at debug level).
    """
    if not model_type:
        logger.info("Model type was not provided. No Liger kernels will be applied.")
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
    accepted_params = inspect.signature(apply_fn).parameters

    # Filter out the keyword arguments that are not supported by the apply function,
    # but report them instead of dropping them silently.
    applicable_kwargs = {key: value for key, value in kwargs.items() if key in accepted_params}
    ignored_keys = [key for key in kwargs if key not in accepted_params]
    if ignored_keys:
        logger.debug(f"Ignoring keyword arguments not accepted by {apply_fn.__name__}: {ignored_keys}")

    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")

    # Assume this is invoked pre-model initialization, so we only need to patch transformers code
    apply_fn(**applicable_kwargs)
|
|
2896
|
+
|
|
2897
|
+
|
|
2898
|
+
def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
    """
    Patch an already-instantiated model with the Liger kernels registered for its ``config.model_type``.

    Does nothing (apart from an info log) when the model has no config/model_type, or when no Liger
    patch function is registered for that model type.

    Args:
        - model: the model instance to apply Liger kernels to
        - kwargs: keyword arguments forwarded to the matching apply_liger_kernel_to_* function.
          Keys that function does not accept are dropped.
    """
    config = getattr(model, "config", None)
    model_type = getattr(config, "model_type", None) if config else None

    if not model_type:
        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
        return

    # Keep only the keyword arguments that the selected patch function actually declares.
    supported_params = inspect.signature(apply_fn).parameters
    applicable_kwargs = {}
    for param_name, param_value in kwargs.items():
        if param_name in supported_params:
            applicable_kwargs[param_name] = param_value

    logger.info(
        f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    apply_fn(model=model, **applicable_kwargs)
|