liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.4.dev20251121224847__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic; see the registry listing for details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +133 -79
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +59 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +38 -6
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +31 -8
- liger_kernel/transformers/model/gemma3.py +100 -110
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1278 -116
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/METADATA +29 -24
- liger_kernel_nightly-0.6.4.dev20251121224847.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/top_level.txt +0 -0
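The bulk of this release's change is in liger_kernel/transformers/monkey_patch.py, whose diff is reproduced below. As a rough usage sketch (not taken from the diff: the model id and the choice of kwargs are placeholders), the new entry points such as apply_liger_kernel_to_smollm3 follow the same convention as the existing patch functions: call them before the model is constructed to patch the transformers classes globally, or pass an already-loaded model to patch that instance in place.

# Sketch only: assumes a transformers build with SmolLM3 support; the model id is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3

# Option 1: patch the transformers classes first, then construct the model.
apply_liger_kernel_to_smollm3(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")

# Option 2: patch a model that was already loaded.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
apply_liger_kernel_to_smollm3(model=model)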
@@ -2,7 +2,9 @@ import inspect
 import logging

 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional

 import transformers

@@ -13,6 +15,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
@@ -25,12 +28,14 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -76,8 +81,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
     else:
         module.offset = offset
         module.casting_mode = casting_mode
@@ -86,7 +91,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


 def _patch_layer_norm_module(module, eps=1e-6):
@@ -108,28 +113,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
             module, "normalized_shape", None
         )
-        _bind_method_to_module(module.modules_to_save.default, "forward",
-        _bind_method_to_module(module.modules_to_save.default, "extra_repr",
-        _bind_method_to_module(module.original_module, "forward",
-        _bind_method_to_module(module.original_module, "extra_repr",
-        module.modules_to_save.default
-        module.original_module
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
     else:
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-        module
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


 def apply_liger_kernel_to_granite(
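The `_get_name` bindings added above make patched modules report the Liger class name when printed: torch.nn.Module builds its repr() header from self._get_name(), so rebinding that method on an instance changes how that one module displays without re-creating it. A small sketch of the mechanism with stand-in names (not the real Liger classes):

# Sketch only: shows what binding "_get_name" to an instance does.
from types import MethodType

import torch.nn as nn


def _bind_method_to_module(module, method_name: str, new_method):
    # same idea as the helper in monkey_patch.py: attach the function as a bound method
    setattr(module, method_name, MethodType(new_method, module))


norm = nn.LayerNorm(16)
print(norm)  # LayerNorm((16,), eps=1e-05, elementwise_affine=True)

_bind_method_to_module(norm, "_get_name", lambda self: "LigerLayerNorm")
print(norm)  # now reports LigerLayerNorm(...) even though the instance was never re-created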
@@ -260,10 +265,16 @@ def apply_liger_kernel_to_llama(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -283,6 +294,77 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
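A pattern repeated throughout these hunks: every apply_liger_kernel_to_* function now branches on whether a model instance was passed in. Assigning to the class attribute affects every model constructed afterwards, while types.MethodType binds the Liger forward to one already-constructed instance. A minimal, self-contained sketch of the two paths (stand-in class and forward, not the real transformers or Liger code):

# Sketch of the class-vs-instance patching pattern (stand-in names).
from types import MethodType


class CausalLM:
    def forward(self, x):
        return f"original forward({x})"


def lce_forward(self, x):
    # stand-in for e.g. llama_lce_forward / smollm3_lce_forward
    return f"liger fused-linear-CE forward({x})"


def apply_patch(model=None):
    if model is not None:
        # patch only this instance; other CausalLM objects keep the original forward
        model.forward = MethodType(lce_forward, model)
    else:
        # patch the class; affects every instance created afterwards
        CausalLM.forward = lce_forward


m = CausalLM()
apply_patch(model=m)
print(m.forward("batch"))       # liger fused-linear-CE forward(batch)
print(CausalLM().forward("x"))  # original forward(x)

Instance-level patching is what allows a call such as apply_liger_kernel_to_llama(model=model) to take effect on a model that was loaded before the patch was applied.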
@@ -318,9 +400,15 @@ def apply_liger_kernel_to_llava(
         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse("4.52.0"):
-
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
         elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
-
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -363,6 +451,97 @@ def apply_liger_kernel_to_llava(
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -404,7 +583,7 @@ def apply_liger_kernel_to_mllama(

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -420,10 +599,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -432,7 +617,10 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
@@ -506,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

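The Mistral hunk above gates the fused-linear-cross-entropy patch on the installed transformers version instead of applying it unconditionally. A minimal sketch of that gate in isolation; the 4.49.0 threshold comes from the hunk, the rest is illustrative:

# Sketch of the version gate pattern used above.
import transformers
from packaging import version

transformer_version = version.parse(transformers.__version__)

if transformer_version >= version.parse("4.49.0"):
    print("new enough: the Liger fused linear cross entropy forward would be patched in")
else:
    print("too old: a warning is logged and MistralForCausalLM.forward is left untouched")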
@@ -574,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -651,10 +855,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -726,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -808,7 +1024,10 @@ def apply_liger_kernel_to_gemma3_text(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -878,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )

-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     apply_liger_kernel_to_gemma3_text(
@@ -889,7 +1108,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -957,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -968,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -986,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1010,7 +1240,7 @@ def apply_liger_kernel_to_paligemma(

         language_model = model.language_model

-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -1020,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )

-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1081,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1105,7 +1341,6 @@ def apply_liger_kernel_to_qwen2(
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-    print("Applied Liger kernels to Qwen2")


 def apply_liger_kernel_to_qwen3(
@@ -1140,7 +1375,10 @@ def apply_liger_kernel_to_qwen3(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1195,7 +1433,10 @@ def apply_liger_kernel_to_qwen3_moe(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1264,12 +1505,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1357,7 +1601,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

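Several hunks in this range change `if layer_norm:` to `if layer_norm and model is None:` (mllama, gemma3, paligemma, qwen2-vl). A class-level swap such as `modeling_siglip.nn.LayerNorm = LigerLayerNorm` only affects LayerNorm modules constructed after the assignment, so it is now skipped when an already-built model instance is being patched; those existing submodules are handled by `_patch_layer_norm_module` instead. A small sketch of why the guard matters, using stand-in classes rather than transformers:

# Sketch: class-level swaps only affect modules built afterwards (stand-in classes).
import torch.nn as nn


class LigerLayerNormStandIn(nn.LayerNorm):
    pass


class VisionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(8)


existing = VisionBlock()              # built before any patching
original_layer_norm = nn.LayerNorm

nn.LayerNorm = LigerLayerNormStandIn  # class-level swap, analogous to modeling_*.nn.LayerNorm = LigerLayerNorm
rebuilt = VisionBlock()               # built after the swap

print(type(existing.norm).__name__)   # LayerNorm -> still needs instance-level patching
print(type(rebuilt.norm).__name__)    # LigerLayerNormStandIn

nn.LayerNorm = original_layer_norm    # undo the global swap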
@@ -1398,141 +1645,160 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


-def 
+def apply_liger_kernel_to_qwen3_vl(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool =
+    swiglu: bool = False,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

     if rope:
-
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
     if rms_norm:
-
-
-        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
     if cross_entropy:
-
-            from transformers.loss.loss_utils import nn
+        from transformers.loss.loss_utils import nn

-
-        else:
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
-    if fused_linear_cross_entropy:
-        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
-        else:  # if version < 4.46.1
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+        nn.functional.cross_entropy = liger_cross_entropy

-    if
-
-
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

-
-
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )

-
-            _patch_rms_norm_module(base_model.norm)
+    _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

-
-
-
-
-
-
+    if text_model is not None:
+        _patch_qwen3_vl_rms_norm(text_model.norm)
+        for decoder_layer in text_model.layers:
+            _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+            _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+            self_attn = getattr(decoder_layer, "self_attn", None)
+            if self_attn is not None:
+                if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                    _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                    _patch_qwen3_vl_rms_norm(self_attn.k_norm)


-def 
+def apply_liger_kernel_to_qwen3_vl_moe(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool =
+    swiglu: bool = False,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused linear cross entropy loss. Default is
-            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )

-    from transformers.models.
-    from transformers.models.
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel

-    from liger_kernel.transformers.model.
-    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

     if rope:
-
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
     if rms_norm:
-
-
-        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
     if cross_entropy:
         from transformers.loss.loss_utils import nn

         nn.functional.cross_entropy = liger_cross_entropy
-    if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward

-    if
-
-
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

-
-
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )

-
-            _patch_rms_norm_module(base_model.norm)
+    _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

-
-
-
-
-
-
+    if text_model is not None:
+        _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+        for decoder_layer in text_model.layers:
+            _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+            _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+            self_attn = getattr(decoder_layer, "self_attn", None)
+            if self_attn is not None:
+                if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                    _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                    _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)


-def 
-    rope: bool =
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
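In the Qwen3-VL and Qwen3-VL-MoE paths above, a per-architecture patcher is prebuilt with functools.partial and then applied to every RMSNorm in the text stack, including the attention q_norm/k_norm when those modules exist. A minimal sketch of that pattern with a stand-in patch function (the layer names are illustrative):

# Sketch of the partial-based patcher; not the real _patch_rms_norm_module.
from functools import partial


def patch_rms_norm(module, offset=0.0, casting_mode="llama"):
    print(f"patched {module} (offset={offset}, casting_mode={casting_mode})")


patch_qwen3_vl_rms_norm = partial(patch_rms_norm, offset=0.0, casting_mode="llama")

decoder_layers = [
    {"input_layernorm": "layers.0.input_layernorm", "q_norm": "layers.0.self_attn.q_norm", "k_norm": None},
]
for layer in decoder_layers:
    patch_qwen3_vl_rms_norm(layer["input_layernorm"])
    for name in ("q_norm", "k_norm"):
        if layer.get(name) is not None:
            patch_qwen3_vl_rms_norm(layer[name])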
@@ -1540,10 +1806,209 @@ def apply_liger_kernel_to_glm4(
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(phi3_lce_forward, model)
+        else:
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+def apply_liger_kernel_to_glm4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -1575,7 +2040,10 @@ def apply_liger_kernel_to_glm4(
|
|
|
1575
2040
|
|
|
1576
2041
|
nn.functional.cross_entropy = liger_cross_entropy
|
|
1577
2042
|
if fused_linear_cross_entropy:
|
|
1578
|
-
|
|
2043
|
+
if model is not None:
|
|
2044
|
+
model.forward = MethodType(glm4_lce_forward, model)
|
|
2045
|
+
else:
|
|
2046
|
+
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
|
|
1579
2047
|
|
|
1580
2048
|
if model is not None:
|
|
1581
2049
|
# The model instance already exists, so we need to additionally patch the
|
|
@@ -1597,6 +2065,684 @@ def apply_liger_kernel_to_glm4(
|
|
|
1597
2065
|
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
|
|
1598
2066
|
|
|
1599
2067
|
|
|
2068
|
+
def apply_liger_kernel_to_glm4v(
|
|
2069
|
+
rope: bool = False,
|
|
2070
|
+
cross_entropy: bool = False,
|
|
2071
|
+
fused_linear_cross_entropy: bool = True,
|
|
2072
|
+
rms_norm: bool = True,
|
|
2073
|
+
swiglu: bool = True,
|
|
2074
|
+
model: PreTrainedModel = None,
|
|
2075
|
+
) -> None:
|
|
2076
|
+
"""
|
|
2077
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
|
|
2078
|
+
|
|
2079
|
+
Args:
|
|
2080
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
2081
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2082
|
+
fused_linear_cross_entropy (bool):
|
|
2083
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2084
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2085
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2086
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2087
|
+
swiglu (bool): Whether to apply Liger's SwiGLU to Glm4MLP. Default is True.
|
|
2088
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2089
|
+
loaded. Default is None.
|
|
2090
|
+
"""
|
|
2091
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2092
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2093
|
+
)
|
|
2094
|
+
|
|
2095
|
+
from transformers.models.glm4v import modeling_glm4v
|
|
2096
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
|
|
2097
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vModel
|
|
2098
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
|
|
2099
|
+
from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
|
|
2100
|
+
|
|
2101
|
+
from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
|
|
2102
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
|
2103
|
+
|
|
2104
|
+
if rope:
|
|
2105
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
|
2106
|
+
if rms_norm:
|
|
2107
|
+
modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
|
|
2108
|
+
if cross_entropy:
|
|
2109
|
+
from transformers.loss.loss_utils import nn
|
|
2110
|
+
|
|
2111
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2112
|
+
if fused_linear_cross_entropy:
|
|
2113
|
+
if model is not None:
|
|
2114
|
+
model.forward = MethodType(glm4v_lce_forward, model)
|
|
2115
|
+
else:
|
|
2116
|
+
modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
|
|
2117
|
+
|
|
2118
|
+
if model is not None:
|
|
2119
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2120
|
+
# instance variables that reference already-instantiated modules
|
|
2121
|
+
if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
|
|
2122
|
+
# Note: language_model and visual properties can be accessed through the conditional class for BC.
|
|
2123
|
+
# This may be subject to change in the future.
|
|
2124
|
+
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
|
|
2125
|
+
text_model: Glm4vTextModel = model.language_model
|
|
2126
|
+
vision_model: Glm4vVisionModel = model.visual
|
|
2127
|
+
elif isinstance(model, Glm4vTextModel):
|
|
2128
|
+
text_model: Glm4vTextModel = model
|
|
2129
|
+
vision_model = None
|
|
2130
|
+
else:
|
|
2131
|
+
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
|
|
2132
|
+
raise TypeError(
|
|
2133
|
+
f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
|
|
2134
|
+
)
|
|
2135
|
+
|
|
2136
|
+
if vision_model is not None:
|
|
2137
|
+
for vision_block in vision_model.blocks:
|
|
2138
|
+
if rms_norm:
|
|
2139
|
+
_patch_rms_norm_module(vision_block.norm1)
|
|
2140
|
+
_patch_rms_norm_module(vision_block.norm2)
|
|
2141
|
+
if swiglu:
|
|
2142
|
+
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
|
|
2143
|
+
|
|
2144
|
+
if text_model is not None:
|
|
2145
|
+
if rms_norm:
|
|
2146
|
+
_patch_rms_norm_module(text_model.norm)
|
|
2147
|
+
for decoder_layer in text_model.layers:
|
|
2148
|
+
if swiglu:
|
|
2149
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
|
|
2150
|
+
if rms_norm:
|
|
2151
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2152
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2153
|
+
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
|
|
2154
|
+
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
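An illustrative sketch (not part of this diff) of the two ways the GLM-4V patch above can be applied; the loader class and model id are assumptions for illustration only:
from transformers import AutoModelForImageTextToText
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v

# Option A: patch module-level classes before the model is instantiated.
apply_liger_kernel_to_glm4v(rms_norm=True, swiglu=True)
# Option B: pass an existing instance so the already-instantiated vision and text
# submodules (RMSNorm, MLPs) are patched as well, as implemented above.
model = AutoModelForImageTextToText.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")  # placeholder id
apply_liger_kernel_to_glm4v(rms_norm=True, swiglu=True, model=model)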
|
|
2155
|
+
|
|
2156
|
+
|
|
2157
|
+
def apply_liger_kernel_to_glm4v_moe(
|
|
2158
|
+
rope: bool = False,
|
|
2159
|
+
cross_entropy: bool = False,
|
|
2160
|
+
fused_linear_cross_entropy: bool = True,
|
|
2161
|
+
rms_norm: bool = True,
|
|
2162
|
+
swiglu: bool = True,
|
|
2163
|
+
model: PreTrainedModel = None,
|
|
2164
|
+
) -> None:
|
|
2165
|
+
"""
|
|
2166
|
+
Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
|
|
2167
|
+
|
|
2168
|
+
Args:
|
|
2169
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
2170
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2171
|
+
fused_linear_cross_entropy (bool):
|
|
2172
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2173
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2174
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2175
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2176
|
+
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
|
|
2177
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2178
|
+
loaded. Default is None.
|
|
2179
|
+
"""
|
|
2180
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2181
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2182
|
+
)
|
|
2183
|
+
|
|
2184
|
+
from transformers.models.glm4v_moe import modeling_glm4v_moe
|
|
2185
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
|
|
2186
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
|
|
2187
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
|
|
2188
|
+
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
|
|
2189
|
+
|
|
2190
|
+
from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
|
|
2191
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
|
2192
|
+
|
|
2193
|
+
if rope:
|
|
2194
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
|
2195
|
+
if rms_norm:
|
|
2196
|
+
modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
|
|
2197
|
+
modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
|
|
2198
|
+
if cross_entropy:
|
|
2199
|
+
from transformers.loss.loss_utils import nn
|
|
2200
|
+
|
|
2201
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2202
|
+
if fused_linear_cross_entropy:
|
|
2203
|
+
if model is not None:
|
|
2204
|
+
model.forward = MethodType(glm4v_moe_lce_forward, model)
|
|
2205
|
+
else:
|
|
2206
|
+
modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
|
|
2207
|
+
|
|
2208
|
+
if model is not None:
|
|
2209
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2210
|
+
# instance variables that reference already-instantiated modules
|
|
2211
|
+
if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
|
|
2212
|
+
# Note: language_model and visual properties can be accessed through the conditional class for BC.
|
|
2213
|
+
# This may be subject to change in the future.
|
|
2214
|
+
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
|
|
2215
|
+
text_model: Glm4vMoeTextModel = model.language_model
|
|
2216
|
+
vision_model: Glm4vMoeVisionModel = model.visual
|
|
2217
|
+
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
|
|
2218
|
+
elif isinstance(model, Glm4vMoeTextModel):
|
|
2219
|
+
text_model: Glm4vMoeTextModel = model
|
|
2220
|
+
vision_model = None
|
|
2221
|
+
else:
|
|
2222
|
+
# Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
|
|
2223
|
+
raise TypeError(
|
|
2224
|
+
f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
|
|
2225
|
+
)
|
|
2226
|
+
|
|
2227
|
+
if vision_model is not None:
|
|
2228
|
+
_patch_rms_norm_module(vision_model.post_conv_layernorm)
|
|
2229
|
+
_patch_rms_norm_module(vision_model.post_layernorm)
|
|
2230
|
+
for vision_block in vision_model.blocks:
|
|
2231
|
+
if rms_norm:
|
|
2232
|
+
_patch_rms_norm_module(vision_block.norm1)
|
|
2233
|
+
_patch_rms_norm_module(vision_block.norm2)
|
|
2234
|
+
if swiglu:
|
|
2235
|
+
_patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
|
|
2236
|
+
|
|
2237
|
+
if text_model is not None:
|
|
2238
|
+
if rms_norm:
|
|
2239
|
+
_patch_rms_norm_module(text_model.norm)
|
|
2240
|
+
for decoder_layer in text_model.layers:
|
|
2241
|
+
if swiglu:
|
|
2242
|
+
decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
2243
|
+
if rms_norm:
|
|
2244
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2245
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2246
|
+
if isinstance(decoder_layer.mlp, getattr(modeling_glm4v_moe, "Glm4vMoeTextMoE", ())):  # avoids NameError when only the text model was passed
|
|
2247
|
+
experts = getattr(decoder_layer.mlp, "experts", None)
|
|
2248
|
+
if experts is not None:
|
|
2249
|
+
for expert in experts:
|
|
2250
|
+
_patch_swiglu_module(expert, LigerSwiGLUMLP)
|
|
2251
|
+
if decoder_layer.mlp.shared_experts is not None:
|
|
2252
|
+
_patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
|
|
2253
|
+
for decoder_layer in text_model.layers:
|
|
2254
|
+
if rms_norm:
|
|
2255
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2256
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2257
|
+
|
|
2258
|
+
|
|
2259
|
+
def apply_liger_kernel_to_internvl(
|
|
2260
|
+
cross_entropy: bool = False,
|
|
2261
|
+
fused_linear_cross_entropy: bool = True,
|
|
2262
|
+
rms_norm: bool = True,
|
|
2263
|
+
layer_norm: bool = True,
|
|
2264
|
+
model: Optional[PreTrainedModel] = None,
|
|
2265
|
+
**kwargs,
|
|
2266
|
+
) -> None:
|
|
2267
|
+
"""
|
|
2268
|
+
Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
|
|
2269
|
+
Due to the characteristics of InternVL, the `model` instance must be passed so that Liger-Kernel's patches can also be applied to the language model connected to InternVL.
|
|
2270
|
+
However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
|
|
2271
|
+
NOTE: InternVL is not available in transformers<4.52.1
|
|
2272
|
+
|
|
2273
|
+
Args:
|
|
2274
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2275
|
+
fused_linear_cross_entropy (bool):
|
|
2276
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2277
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2278
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2279
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2280
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
2281
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2282
|
+
loaded. Default is None.
|
|
2283
|
+
"""
|
|
2284
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2285
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2286
|
+
)
|
|
2287
|
+
import torch.nn as torch_nn
|
|
2288
|
+
|
|
2289
|
+
from transformers.models.internvl import modeling_internvl
|
|
2290
|
+
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
|
|
2291
|
+
from transformers.models.internvl.modeling_internvl import InternVLModel
|
|
2292
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
|
|
2293
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionModel
|
|
2294
|
+
from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
|
|
2295
|
+
|
|
2296
|
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
2297
|
+
from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
|
|
2298
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
2299
|
+
|
|
2300
|
+
if layer_norm and model is None:
|
|
2301
|
+
modeling_internvl.nn.LayerNorm = LigerLayerNorm
|
|
2302
|
+
|
|
2303
|
+
if cross_entropy:
|
|
2304
|
+
logger.info("Apply liger cross entropy")
|
|
2305
|
+
|
|
2306
|
+
from transformers.loss.loss_utils import nn
|
|
2307
|
+
|
|
2308
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2309
|
+
if fused_linear_cross_entropy:
|
|
2310
|
+
modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
|
|
2311
|
+
if rms_norm:
|
|
2312
|
+
modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
|
|
2313
|
+
|
|
2314
|
+
if model is not None:
|
|
2315
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2316
|
+
# instance variables that reference already-instantiated modules
|
|
2317
|
+
if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
|
|
2318
|
+
# NOTE: language_model and visual properties can be accessed through the conditional class.
|
|
2319
|
+
text_model = model.language_model
|
|
2320
|
+
vision_model: InternVLVisionModel = model.vision_tower
|
|
2321
|
+
else:
|
|
2322
|
+
raise TypeError(
|
|
2323
|
+
f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
|
|
2324
|
+
)
|
|
2325
|
+
|
|
2326
|
+
text_model_name = model.config.text_config.model_type
|
|
2327
|
+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
|
|
2328
|
+
|
|
2329
|
+
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
|
|
2330
|
+
if text_liger_fn:
|
|
2331
|
+
accept_params = inspect.signature(text_liger_fn).parameters
|
|
2332
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
2333
|
+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
2334
|
+
|
|
2335
|
+
if remain_params:
|
|
2336
|
+
logger.warning(
|
|
2337
|
+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
|
|
2338
|
+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
|
|
2339
|
+
)
|
|
2340
|
+
text_kwargs["model"] = text_model
|
|
2341
|
+
text_liger_fn(**text_kwargs)
|
|
2342
|
+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
2343
|
+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
|
|
2344
|
+
|
|
2345
|
+
# Patch vision model RMSNorm layers
|
|
2346
|
+
if rms_norm:
|
|
2347
|
+
for encoder_layer in vision_model.encoder.layer:
|
|
2348
|
+
encoder_layer: InternVLVisionLayer
|
|
2349
|
+
if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
|
|
2350
|
+
_patch_rms_norm_module(encoder_layer.attention.q_norm)
|
|
2351
|
+
if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
|
|
2352
|
+
_patch_rms_norm_module(encoder_layer.attention.k_norm)
|
|
2353
|
+
|
|
2354
|
+
# Patch vision model LayerNorm layers
|
|
2355
|
+
if layer_norm:
|
|
2356
|
+
# Patch layernorm
|
|
2357
|
+
if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
|
|
2358
|
+
_patch_layer_norm_module(vision_model.layernorm)
|
|
2359
|
+
|
|
2360
|
+
# Patch encoder layers
|
|
2361
|
+
for encoder_layer in vision_model.encoder.layer:
|
|
2362
|
+
encoder_layer: InternVLVisionLayer
|
|
2363
|
+
if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
|
|
2364
|
+
_patch_layer_norm_module(encoder_layer.layernorm_before)
|
|
2365
|
+
if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
|
|
2366
|
+
_patch_layer_norm_module(encoder_layer.layernorm_after)
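A self-contained sketch (not part of this diff) of the keyword filtering used above; `_demo_apply` is a hypothetical stand-in for a text-model apply function, and the point is that only parameters it accepts are forwarded while the cross-entropy options are forced off for the inner call:
import inspect

def _demo_apply(rope=True, rms_norm=True, swiglu=True, model=None):
    # Stand-in for an apply_liger_kernel_to_* function of the text backbone.
    return {"rope": rope, "rms_norm": rms_norm, "swiglu": swiglu}

kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, "layer_norm": True} | {"rms_norm": True}
accept_params = inspect.signature(_demo_apply).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
print(text_kwargs)  # {'rms_norm': True}; unsupported keys are dropped with a warning in the real code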
|
|
2367
|
+
|
|
2368
|
+
|
|
2369
|
+
def apply_liger_kernel_to_smolvlm(
|
|
2370
|
+
cross_entropy: bool = False,
|
|
2371
|
+
fused_linear_cross_entropy: bool = True,
|
|
2372
|
+
rms_norm: bool = True,
|
|
2373
|
+
layer_norm: bool = True,
|
|
2374
|
+
model: Optional[PreTrainedModel] = None,
|
|
2375
|
+
**kwargs,
|
|
2376
|
+
) -> None:
|
|
2377
|
+
"""
|
|
2378
|
+
Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
|
|
2379
|
+
Due to the characteristics of SmolVLM, the `model` instance must be passed so that Liger-Kernel's patches can also be applied to the language model connected to SmolVLM.
|
|
2380
|
+
However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
|
|
2381
|
+
NOTE: SmolVLM is not available in transformers<4.50.0
|
|
2382
|
+
|
|
2383
|
+
Args:
|
|
2384
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2385
|
+
fused_linear_cross_entropy (bool):
|
|
2386
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2387
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2388
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2389
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2390
|
+
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
|
|
2391
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2392
|
+
loaded. Default is None.
|
|
2393
|
+
"""
|
|
2394
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2395
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2396
|
+
)
|
|
2397
|
+
|
|
2398
|
+
from transformers.models.smolvlm import modeling_smolvlm
|
|
2399
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
|
|
2400
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
|
|
2401
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
|
|
2402
|
+
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
|
|
2403
|
+
|
|
2404
|
+
from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
|
|
2405
|
+
|
|
2406
|
+
# Patch LayerNorm for vision model if model is not provided (pre-initialization)
|
|
2407
|
+
if layer_norm and model is None:
|
|
2408
|
+
modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
|
|
2409
|
+
|
|
2410
|
+
if cross_entropy:
|
|
2411
|
+
logger.info("Apply liger cross entropy")
|
|
2412
|
+
|
|
2413
|
+
from transformers.loss.loss_utils import nn
|
|
2414
|
+
|
|
2415
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2416
|
+
if fused_linear_cross_entropy:
|
|
2417
|
+
if model is not None:
|
|
2418
|
+
model.forward = MethodType(smolvlm_lce_forward, model)
|
|
2419
|
+
else:
|
|
2420
|
+
modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
|
|
2421
|
+
if rms_norm:
|
|
2422
|
+
modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
|
|
2423
|
+
|
|
2424
|
+
if model is not None:
|
|
2425
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2426
|
+
# instance variables that reference already-instantiated modules
|
|
2427
|
+
if isinstance(model, SmolVLMForConditionalGeneration):
|
|
2428
|
+
text_model = model.model.text_model
|
|
2429
|
+
vision_model: SmolVLMVisionTransformer = model.model.vision_model
|
|
2430
|
+
elif isinstance(model, SmolVLMModel):
|
|
2431
|
+
text_model = model.text_model
|
|
2432
|
+
vision_model: SmolVLMVisionTransformer = model.vision_model
|
|
2433
|
+
else:
|
|
2434
|
+
raise TypeError(
|
|
2435
|
+
f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
|
|
2436
|
+
)
|
|
2437
|
+
|
|
2438
|
+
text_model_name = model.config.text_config.model_type
|
|
2439
|
+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
|
|
2440
|
+
|
|
2441
|
+
kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
|
|
2442
|
+
if text_liger_fn:
|
|
2443
|
+
accept_params = inspect.signature(text_liger_fn).parameters
|
|
2444
|
+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
|
|
2445
|
+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
|
|
2446
|
+
|
|
2447
|
+
if remain_params:
|
|
2448
|
+
logger.warning(
|
|
2449
|
+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
|
|
2450
|
+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
|
|
2451
|
+
)
|
|
2452
|
+
text_kwargs["model"] = text_model
|
|
2453
|
+
text_liger_fn(**text_kwargs)
|
|
2454
|
+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
2455
|
+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
|
|
2456
|
+
|
|
2457
|
+
# Patch vision model LayerNorm layers
|
|
2458
|
+
if layer_norm:
|
|
2459
|
+
# Patch post_layernorm
|
|
2460
|
+
_patch_layer_norm_module(vision_model.post_layernorm)
|
|
2461
|
+
|
|
2462
|
+
# Patch encoder layers
|
|
2463
|
+
for encoder_layer in vision_model.encoder.layers:
|
|
2464
|
+
encoder_layer: SmolVLMEncoderLayer
|
|
2465
|
+
_patch_layer_norm_module(encoder_layer.layer_norm1)
|
|
2466
|
+
_patch_layer_norm_module(encoder_layer.layer_norm2)
|
|
2467
|
+
|
|
2468
|
+
|
|
2469
|
+
def apply_liger_kernel_to_falcon_h1(
|
|
2470
|
+
rope: bool = True,
|
|
2471
|
+
cross_entropy: bool = False,
|
|
2472
|
+
fused_linear_cross_entropy: bool = True,
|
|
2473
|
+
rms_norm: bool = True,
|
|
2474
|
+
swiglu: bool = False,
|
|
2475
|
+
model: PreTrainedModel = None,
|
|
2476
|
+
) -> None:
|
|
2477
|
+
"""
|
|
2478
|
+
Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
|
|
2479
|
+
Args:
|
|
2480
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
2481
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2482
|
+
fused_linear_cross_entropy (bool):
|
|
2483
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2484
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2485
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2486
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2487
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
2488
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2489
|
+
loaded. Default is None.
|
|
2490
|
+
"""
|
|
2491
|
+
|
|
2492
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2493
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2494
|
+
)
|
|
2495
|
+
|
|
2496
|
+
from transformers.models.falcon_h1 import modeling_falcon_h1
|
|
2497
|
+
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
|
|
2498
|
+
|
|
2499
|
+
if rope:
|
|
2500
|
+
logger.info("Apply liger rotary pos emb.")
|
|
2501
|
+
modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2502
|
+
if rms_norm:
|
|
2503
|
+
logger.info("Apply liger RMSNorm")
|
|
2504
|
+
modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
|
|
2505
|
+
if swiglu:
|
|
2506
|
+
logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
|
|
2507
|
+
|
|
2508
|
+
if cross_entropy:
|
|
2509
|
+
logger.info("Apply liger cross entropy")
|
|
2510
|
+
from transformers.loss.loss_utils import nn
|
|
2511
|
+
|
|
2512
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2513
|
+
|
|
2514
|
+
if fused_linear_cross_entropy:
|
|
2515
|
+
if model is not None:
|
|
2516
|
+
model.forward = MethodType(falcon_h1_lce_forward, model)
|
|
2517
|
+
else:
|
|
2518
|
+
modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
|
|
2519
|
+
|
|
2520
|
+
if model is not None:
|
|
2521
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2522
|
+
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
|
2523
|
+
|
|
2524
|
+
# get the base model from the model instance
|
|
2525
|
+
base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
|
|
2526
|
+
|
|
2527
|
+
if rms_norm:
|
|
2528
|
+
_patch_rms_norm_module(base_model.final_layernorm)
|
|
2529
|
+
|
|
2530
|
+
for decoder_layer in base_model.layers:
|
|
2531
|
+
if swiglu:
|
|
2532
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
2533
|
+
if rms_norm:
|
|
2534
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2535
|
+
_patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
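A minimal sketch (not part of this diff) of applying the Falcon-H1 patch above before loading; the model id is a placeholder:
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1

# With no `model` argument, the class-level patches (RoPE, RMSNorm, fused LCE forward)
# take effect for any Falcon-H1 model instantiated afterwards.
apply_liger_kernel_to_falcon_h1()
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")  # placeholder id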
|
|
2536
|
+
|
|
2537
|
+
|
|
2538
|
+
def apply_liger_kernel_to_qwen3_next(
|
|
2539
|
+
rope: bool = False,
|
|
2540
|
+
cross_entropy: bool = False,
|
|
2541
|
+
fused_linear_cross_entropy: bool = True,
|
|
2542
|
+
rms_norm: bool = True,
|
|
2543
|
+
swiglu: bool = True,
|
|
2544
|
+
model: PreTrainedModel = None,
|
|
2545
|
+
) -> None:
|
|
2546
|
+
"""
|
|
2547
|
+
Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
|
|
2548
|
+
|
|
2549
|
+
Args:
|
|
2550
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
|
|
2551
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2552
|
+
fused_linear_cross_entropy (bool):
|
|
2553
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2554
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2555
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
|
|
2556
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2557
|
+
swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
|
|
2558
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2559
|
+
loaded. Default is None.
|
|
2560
|
+
"""
|
|
2561
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2562
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2563
|
+
)
|
|
2564
|
+
|
|
2565
|
+
from transformers.models.qwen3_next import modeling_qwen3_next
|
|
2566
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
|
|
2567
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
|
|
2568
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
|
|
2569
|
+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
|
2570
|
+
|
|
2571
|
+
from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
|
|
2572
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
|
|
2573
|
+
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
|
|
2574
|
+
|
|
2575
|
+
if rope:
|
|
2576
|
+
# It might encounter a NaN issue
|
|
2577
|
+
# modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2578
|
+
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
|
|
2579
|
+
if rms_norm:
|
|
2580
|
+
modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
|
|
2581
|
+
if cross_entropy:
|
|
2582
|
+
from transformers.loss.loss_utils import nn
|
|
2583
|
+
|
|
2584
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2585
|
+
if fused_linear_cross_entropy:
|
|
2586
|
+
if model is not None:
|
|
2587
|
+
if isinstance(model, Qwen3NextForCausalLM):
|
|
2588
|
+
model.forward = MethodType(qwen3_next_lce_forward, model)
|
|
2589
|
+
else:
|
|
2590
|
+
raise TypeError(
|
|
2591
|
+
f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
|
|
2592
|
+
)
|
|
2593
|
+
else:
|
|
2594
|
+
modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
|
|
2595
|
+
if swiglu:
|
|
2596
|
+
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
|
|
2597
|
+
modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
|
|
2598
|
+
|
|
2599
|
+
if model is not None:
|
|
2600
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2601
|
+
# instance variables that reference already-instantiated modules
|
|
2602
|
+
if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
|
|
2603
|
+
base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
|
|
2604
|
+
else:
|
|
2605
|
+
raise TypeError(
|
|
2606
|
+
f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
|
|
2607
|
+
)
|
|
2608
|
+
|
|
2609
|
+
if rms_norm:
|
|
2610
|
+
_patch_rms_norm_module(base_model.norm)
|
|
2611
|
+
|
|
2612
|
+
for decoder_layer in base_model.layers:
|
|
2613
|
+
if rms_norm:
|
|
2614
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2615
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2616
|
+
|
|
2617
|
+
# Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
|
|
2618
|
+
if swiglu:
|
|
2619
|
+
if isinstance(decoder_layer.mlp, Qwen3NextMLP):
|
|
2620
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
|
|
2621
|
+
if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
|
|
2622
|
+
_patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
|
|
2623
|
+
experts = getattr(decoder_layer.mlp, "experts", None)
|
|
2624
|
+
if experts is not None:
|
|
2625
|
+
for expert in experts:
|
|
2626
|
+
_patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
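An illustrative sketch (not part of this diff) for the Qwen3-Next patch above; note that `rope=True` deliberately raises NotImplementedError, and the model id is a placeholder:
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")  # placeholder id
# rope stays False; dense MLPs, MoE experts, and the shared expert are patched as above.
apply_liger_kernel_to_qwen3_next(rope=False, rms_norm=True, swiglu=True, model=model)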
|
|
2627
|
+
|
|
2628
|
+
|
|
2629
|
+
def apply_liger_kernel_to_hunyuan_v1_dense(
|
|
2630
|
+
rope: bool = True,
|
|
2631
|
+
cross_entropy: bool = False,
|
|
2632
|
+
fused_linear_cross_entropy: bool = True,
|
|
2633
|
+
rms_norm: bool = True,
|
|
2634
|
+
swiglu: bool = True,
|
|
2635
|
+
model: PreTrainedModel = None,
|
|
2636
|
+
) -> None:
|
|
2637
|
+
"""
|
|
2638
|
+
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
|
|
2639
|
+
"""
|
|
2640
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2641
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2642
|
+
)
|
|
2643
|
+
|
|
2644
|
+
from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
|
|
2645
|
+
from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
|
|
2646
|
+
|
|
2647
|
+
from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
|
|
2648
|
+
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
|
|
2649
|
+
|
|
2650
|
+
if rope:
|
|
2651
|
+
modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2652
|
+
|
|
2653
|
+
if rms_norm:
|
|
2654
|
+
modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
|
|
2655
|
+
|
|
2656
|
+
if cross_entropy:
|
|
2657
|
+
from transformers.loss.loss_utils import nn
|
|
2658
|
+
|
|
2659
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2660
|
+
|
|
2661
|
+
if fused_linear_cross_entropy:
|
|
2662
|
+
if model is not None:
|
|
2663
|
+
model.forward = MethodType(hunyuan_v1_lce_forward, model)
|
|
2664
|
+
else:
|
|
2665
|
+
modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
|
|
2666
|
+
|
|
2667
|
+
if swiglu:
|
|
2668
|
+
modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
|
|
2669
|
+
|
|
2670
|
+
if model is not None:
|
|
2671
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2672
|
+
# instance variables that reference already-instantiated modules
|
|
2673
|
+
|
|
2674
|
+
# get the base model from the model instance
|
|
2675
|
+
base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
|
|
2676
|
+
|
|
2677
|
+
if rms_norm:
|
|
2678
|
+
_patch_rms_norm_module(base_model.norm)
|
|
2679
|
+
for decoder_layer in base_model.layers:
|
|
2680
|
+
if swiglu:
|
|
2681
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
|
|
2682
|
+
if rms_norm:
|
|
2683
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2684
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2685
|
+
|
|
2686
|
+
|
|
2687
|
+
def apply_liger_kernel_to_hunyuan_v1_moe(
|
|
2688
|
+
rope: bool = True,
|
|
2689
|
+
cross_entropy: bool = False,
|
|
2690
|
+
fused_linear_cross_entropy: bool = True,
|
|
2691
|
+
rms_norm: bool = True,
|
|
2692
|
+
swiglu: bool = True,
|
|
2693
|
+
model: PreTrainedModel = None,
|
|
2694
|
+
) -> None:
|
|
2695
|
+
"""
|
|
2696
|
+
Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
|
|
2697
|
+
"""
|
|
2698
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2699
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2700
|
+
)
|
|
2701
|
+
|
|
2702
|
+
from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
|
|
2703
|
+
from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
|
|
2704
|
+
|
|
2705
|
+
from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
|
|
2706
|
+
from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
|
|
2707
|
+
|
|
2708
|
+
if rope:
|
|
2709
|
+
modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2710
|
+
|
|
2711
|
+
if rms_norm:
|
|
2712
|
+
modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
|
|
2713
|
+
|
|
2714
|
+
if cross_entropy:
|
|
2715
|
+
from transformers.loss.loss_utils import nn
|
|
2716
|
+
|
|
2717
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2718
|
+
|
|
2719
|
+
if fused_linear_cross_entropy:
|
|
2720
|
+
if model is not None:
|
|
2721
|
+
model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
|
|
2722
|
+
else:
|
|
2723
|
+
modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
|
|
2724
|
+
|
|
2725
|
+
if swiglu:
|
|
2726
|
+
modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
|
|
2727
|
+
|
|
2728
|
+
if model is not None:
|
|
2729
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2730
|
+
# instance variables that reference already-instantiated modules
|
|
2731
|
+
|
|
2732
|
+
# get the base model from the model instance
|
|
2733
|
+
base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
|
|
2734
|
+
|
|
2735
|
+
if rms_norm:
|
|
2736
|
+
_patch_rms_norm_module(base_model.norm)
|
|
2737
|
+
for decoder_layer in base_model.layers:
|
|
2738
|
+
if swiglu:
|
|
2739
|
+
for mlp_expert in decoder_layer.mlp.experts:
|
|
2740
|
+
_patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
|
|
2741
|
+
if rms_norm:
|
|
2742
|
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
2743
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2744
|
+
|
|
2745
|
+
|
|
1600
2746
|
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
1601
2747
|
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
1602
2748
|
"gemma": apply_liger_kernel_to_gemma,
|
|
@@ -1604,7 +2750,12 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1604
2750
|
"gemma3_text": apply_liger_kernel_to_gemma3_text,
|
|
1605
2751
|
"gemma3": apply_liger_kernel_to_gemma3,
|
|
1606
2752
|
"glm4": apply_liger_kernel_to_glm4,
|
|
2753
|
+
"glm4v": apply_liger_kernel_to_glm4v,
|
|
2754
|
+
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
|
|
2755
|
+
"internvl": apply_liger_kernel_to_internvl,
|
|
1607
2756
|
"llama": apply_liger_kernel_to_llama,
|
|
2757
|
+
"llama4_text": apply_liger_kernel_to_llama4,
|
|
2758
|
+
"llama4": apply_liger_kernel_to_llama4,
|
|
1608
2759
|
"llava": apply_liger_kernel_to_llava,
|
|
1609
2760
|
"granite": apply_liger_kernel_to_granite,
|
|
1610
2761
|
"mllama": apply_liger_kernel_to_mllama,
|
|
@@ -1612,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1612
2763
|
"mistral": apply_liger_kernel_to_mistral,
|
|
1613
2764
|
"mixtral": apply_liger_kernel_to_mixtral,
|
|
1614
2765
|
"olmo2": apply_liger_kernel_to_olmo2,
|
|
2766
|
+
"olmo3": apply_liger_kernel_to_olmo3,
|
|
1615
2767
|
"qwen2": apply_liger_kernel_to_qwen2,
|
|
1616
2768
|
"qwen3": apply_liger_kernel_to_qwen3,
|
|
1617
2769
|
"qwen3_moe": apply_liger_kernel_to_qwen3_moe,
|
|
@@ -1619,8 +2771,18 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1619
2771
|
"qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
|
|
1620
2772
|
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
|
|
1621
2773
|
"qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
|
|
2774
|
+
"qwen3_next": apply_liger_kernel_to_qwen3_next,
|
|
2775
|
+
"qwen3_vl": apply_liger_kernel_to_qwen3_vl,
|
|
2776
|
+
"qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
|
|
2777
|
+
"qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2778
|
+
"qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
|
|
2779
|
+
"smollm3": apply_liger_kernel_to_smollm3,
|
|
1622
2780
|
"phi3": apply_liger_kernel_to_phi3,
|
|
1623
2781
|
"paligemma": apply_liger_kernel_to_paligemma,
|
|
2782
|
+
"falcon_h1": apply_liger_kernel_to_falcon_h1,
|
|
2783
|
+
"smolvlm": apply_liger_kernel_to_smolvlm,
|
|
2784
|
+
"hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
|
|
2785
|
+
"hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
|
|
1624
2786
|
}
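A hedged sketch (not part of this diff) of how such a mapping can be used to dispatch by `config.model_type`; `_apply_liger_by_model_type` is a hypothetical helper for illustration, not an API defined in this file:
def _apply_liger_by_model_type(model, **kwargs):
    # Look up the apply function registered for this architecture and forward the
    # instance, mirroring what the multimodal patches above do for their text backbones.
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
    if apply_fn is None:
        raise ValueError(f"{model.config.model_type} is not supported by Liger kernel.")
    apply_fn(model=model, **kwargs)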
|
|
1625
2787
|
|
|
1626
2788
|
|