liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +111 -179
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +70 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +25 -16
- liger_kernel/transformers/model/gemma2.py +27 -14
- liger_kernel/transformers/model/gemma3.py +62 -106
- liger_kernel/transformers/model/glm4.py +16 -13
- liger_kernel/transformers/model/llama.py +81 -18
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -132
- liger_kernel/transformers/model/mistral.py +13 -14
- liger_kernel/transformers/model/mixtral.py +16 -15
- liger_kernel/transformers/model/mllama.py +16 -14
- liger_kernel/transformers/model/olmo2.py +16 -13
- liger_kernel/transformers/model/paligemma.py +8 -9
- liger_kernel/transformers/model/phi3.py +25 -16
- liger_kernel/transformers/model/qwen2.py +24 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
- liger_kernel/transformers/model/qwen2_vl.py +38 -106
- liger_kernel/transformers/model/qwen3.py +11 -9
- liger_kernel/transformers/model/qwen3_moe.py +132 -0
- liger_kernel/transformers/monkey_patch.py +424 -81
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
- liger_kernel-0.6.0.dist-info/RECORD +97 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel-0.5.9.dist-info/RECORD +0 -84
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

@@ -2,6 +2,7 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
 
 import transformers
@@ -35,6 +36,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 transformer_version = version.parse(transformers.__version__)
 
 logger = logging.getLogger(__name__)
@@ -47,23 +55,72 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-    module.offset = offset
-    module.casting_mode = casting_mode
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.in_place = in_place
-    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-    module.__class__.__name__ = LigerRMSNorm.__name__
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
+        module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        module.row_mode = row_mode
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.__class__.__name__ = LigerRMSNorm.__name__
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = module.normalized_shape
-    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-    module.__class__.__name__ = LigerLayerNorm.__name__
-
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
+        module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.__class__.__name__ = LigerLayerNorm.__name__
 
 
 def _patch_swiglu_module(module, liger_module):
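Aside: the `_bind_method_to_module` helper relied on above installs a plain function into an instance's `__dict__` as a bound method via the descriptor protocol, so the patch affects only that instance while other instances of the same class keep the original method. A minimal self-contained sketch of the trick, using toy classes rather than Liger's real modules:

    class Norm:
        def forward(self):
            return "original"

    def liger_forward(self):
        # Stands in for LigerRMSNorm.forward in the real code.
        return "liger"

    a, b = Norm(), Norm()
    # function.__get__(instance, cls) produces a bound method; storing it in the
    # instance __dict__ shadows the class attribute for this instance only.
    a.__dict__["forward"] = liger_forward.__get__(a, a.__class__)
    assert a.forward() == "liger"
    assert b.forward() == "original"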
@@ -204,10 +261,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
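This `if model is not None` split is the recurring pattern of the 0.6.0 release: when a model instance already exists, its `forward` is rebound with `types.MethodType` so only that instance is patched; otherwise the class attribute is replaced so models instantiated later pick up the fused path. A minimal sketch of the two modes (toy class, with a hypothetical `fused_forward` standing in for `llama_lce_forward`):

    from types import MethodType

    class CausalLM:
        def forward(self, x):
            return ("eager", x)

    def fused_forward(self, x):
        return ("fused", x)

    existing = CausalLM()
    existing.forward = MethodType(fused_forward, existing)  # instance-level patch
    assert existing.forward(1) == ("fused", 1)

    CausalLM.forward = fused_forward  # class-level patch for future instances
    assert CausalLM().forward(2) == ("fused", 2)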
@@ -261,13 +324,20 @@ def apply_liger_kernel_to_llava(
         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.49.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -306,6 +376,92 @@ def apply_liger_kernel_to_llava(
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
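Assuming `apply_liger_kernel_to_llama4` is re-exported from `liger_kernel.transformers` like the other apply_* entry points (the `__init__.py` change in the file list suggests so, but that file's diff is not shown here), calling it on an already-loaded model would look roughly like this; the checkpoint id is illustrative:

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_llama4  # assumed export

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")
    # Patches RMSNorm and SwiGLU modules in place on this instance's decoder layers.
    apply_liger_kernel_to_llama4(model=model, rms_norm=True, swiglu=True)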
@@ -347,7 +503,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,10 +519,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -375,13 +537,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -448,7 +614,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -516,10 +692,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -573,8 +755,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
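The `partial` pre-binding above is how family-specific defaults (Gemma's +1.0 weight offset and "gemma" casting mode) get baked into a reusable patcher once and then applied to every norm module. A minimal sketch of the same pattern with a toy patch function:

    from functools import partial

    def patch(module, offset=0.0, casting_mode="llama"):
        # A stand-in for _patch_rms_norm_module; just reports its configuration.
        return (offset, casting_mode)

    # Bind the Gemma defaults once, then reuse the configured patcher everywhere.
    patch_for_gemma = partial(patch, offset=1.0, casting_mode="gemma")
    assert patch_for_gemma(object()) == (1.0, "gemma")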
@@ -593,10 +775,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -647,7 +835,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -667,10 +856,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -724,9 +919,10 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -748,15 +944,18 @@ def apply_liger_kernel_to_gemma3_text(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -818,7 +1017,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     apply_liger_kernel_to_gemma3_text(
@@ -829,7 +1028,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -897,7 +1099,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -908,7 +1112,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -926,10 +1130,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1160,7 @@ def apply_liger_kernel_to_paligemma(
 
         language_model = model.language_model
 
-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -960,7 +1170,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1021,10 +1231,16 @@ def apply_liger_kernel_to_qwen2(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1080,7 +1296,10 @@ def apply_liger_kernel_to_qwen3(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1102,6 +1321,65 @@ def apply_liger_kernel_to_qwen3(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_qwen3_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_moe import modeling_qwen3_moe
+    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+    if swiglu:
+        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -1113,7 +1391,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1127,12 +1405,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
@@ -1141,12 +1426,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1154,24 +1442,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-
-
-        for vision_block in
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)
 
-
-
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-            _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_5_vl(
@@ -1197,12 +1499,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
 
@@ -1213,7 +1522,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1221,24 +1533,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
 
-        if
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-            _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -1287,10 +1612,16 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward_deprecated, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1341,11 +1672,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm =
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1353,7 +1685,10 @@ def apply_liger_kernel_to_olmo2(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1404,11 +1739,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm =
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
@@ -1416,7 +1752,10 @@ def apply_liger_kernel_to_glm4(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1446,6 +1785,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1455,8 +1796,11 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "olmo2": apply_liger_kernel_to_olmo2,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
+    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
@@ -1516,7 +1860,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function